

Poster

Continual Learning with Global Alignment

Xueying Bai · Jinghuan Shang · Yifan Sun · Niranjan Balasubramanian

Thu 12 Dec 11 a.m. PST — 2 p.m. PST

Abstract:

Continual learning aims to learn new tasks sequentially without losing knowledge of previous tasks, a failure known as catastrophic forgetting. One cause of forgetting is interference between the gradients of different tasks' losses. When the gradients of the current task's loss point in directions that oppose those of previous tasks' losses, updating the model for the current task can erase previous tasks' knowledge. In this paper, we first identify causes of this interference and hypothesize that correlations between data representations are a key factor. We then propose a method that promotes appropriate correlations between the data representations of arbitrary tasks (i.e., global alignment) during individual task learning. Specifically, we learn each data representation as a task-specific composition of pre-trained token representations shared across all tasks, so that correlations between different tasks' data representations are grounded in correlations between pre-trained token representations. We explore different ways to learn such compositions. Without experience replay, our model achieves state-of-the-art (SOTA) performance in continual learning tasks. It also achieves advanced class-incremental performance through task-incremental training.
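The link between gradient interference and representation correlations can be illustrated with a toy example. The sketch below is not the paper's model: it assumes a single linear layer with squared loss, for which the gradient is the residual times the input representation, so the inner product of two tasks' gradients equals the product of their residuals times the inner product of their representations. All names and data points (`grad_squared_loss`, `interferes`, `x_old`, `x_orth`, `x_anti`) are hypothetical.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def grad_squared_loss(w, x, y):
    """Gradient of 0.5 * (w.x - y)^2 with respect to w: residual * x."""
    residual = dot(w, x) - y
    return [residual * xi for xi in x]


def interferes(g_new, g_old):
    """True when the two gradients point in opposing directions, so a
    small step along -g_new increases the old task's loss (first order)."""
    return dot(g_new, g_old) < 0


# Shared weights and hypothetical representations (assumptions for illustration).
w = [0.0, 0.0]
x_old, y_old = [1.0, 0.0], 1.0   # previous task's example
x_orth = [0.0, 1.0]              # new-task representation orthogonal to x_old
x_anti = [-1.0, 0.0]             # new-task representation anti-correlated with x_old
y_new = 1.0

g_old = grad_squared_loss(w, x_old, y_old)
g_orth = grad_squared_loss(w, x_orth, y_new)
g_anti = grad_squared_loss(w, x_anti, y_new)

print(interferes(g_orth, g_old))  # orthogonal representations: no interference
print(interferes(g_anti, g_old))  # anti-correlated representations, same-sign residuals: interference
```

In this toy setting the sign of the gradient inner product depends on both the residuals and the representation correlation, which is one simple way to see why controlling correlations between tasks' representations could reduce interference.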
