=======================
== Zafir Stojanovski ==
=======================

Overcoming catastrophic forgetting with Elastic Weight Consolidation (EWC)

In this post, we are going to discuss a seminal method from the field of Continual Learning called Elastic Weight Consolidation, also known as EWC.

Background

Continual Learning (CL) tackles learning from sequential data streams with limited retention and retraining capacity. It requires efficient use of past data and adaptation to new contexts under changing distributions without catastrophic forgetting.

Let’s consider a simple example. Suppose we have a neural network with two tasks, where the first task is to classify between cats and dogs, and the second task is to classify between cars and bikes. The network is trained on the first task, and then the second task is introduced. The network is then trained on the second task, and the weights are updated. However, the weights that were learned for the first task are now overwritten, and the network forgets how to classify between cats and dogs.

More precisely, in Continual Learning a model \(f_\theta\) is trained on a sequence of \(T\) tasks, where for each task \(t \in \{ 1, \ldots, T \} \) the learner only gets access to a subset of samples of the given task: \( D_t = \{ (x_i, y_i) \}_{i=1}^{N_t} \). However, at the end the model is evaluated on its joint performance across all tasks, therefore we should aim to optimize:

\[ \theta^{\star} = \arg\min_{\theta} \sum_{t=1}^{T} \mathbb{E}_{(x, y) \sim D_t} [ \mathcal{L} (f_{\theta}(x), y)] \]

The main challenge is that at the time of task \(t\), the model has no access to data from previous tasks \(\tilde{t} \in \{1, \ldots, t-1\}\), thereby violating the typical IID data assumption.
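
To make this setup concrete, below is a minimal sketch of naive sequential training, the regime in which catastrophic forgetting shows up. Here `YourFavoriteModel`, `make_task_loaders`, and `evaluate` are hypothetical placeholders; the only important point is that while training on task \(t\) we only touch \(D_t\), yet the final evaluation averages over all tasks.

import torch
import torch.nn.functional as F

model = YourFavoriteModel()                  # any classifier (placeholder)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

task_loaders = make_task_loaders()           # one DataLoader per task D_1, ..., D_T (placeholder)

for loader in task_loaders:                  # tasks arrive one at a time
    for inputs, labels in loader:            # only the current task's data is accessible here
        opt.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        opt.step()

# The final evaluation is joint: average performance over all tasks seen so far.
accuracies = [evaluate(model, loader) for loader in task_loaders]
print(sum(accuracies) / len(accuracies))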

Overcoming catastrophic forgetting in neural networks

One of the seminal works addressing catastrophic forgetting in neural networks is the method called Elastic Weight Consolidation (EWC), first proposed in “Overcoming catastrophic forgetting in neural networks” by Kirkpatrick et al. (2017).

In order to get a better intuition for the final formulation of the method, let us first consider the Bayesian perspective on training neural networks.

From Bayes rule, we have:

\[ p(\theta | D) = \frac{p(D | \theta) p(\theta)}{p(D)} \]

where \(p(\theta | D)\) is the posterior distribution of the weights given the data, \(p(D | \theta)\) is the likelihood of the data given the weights, \(p(\theta)\) is the prior distribution of the weights, and \(p(D)\) is the marginal likelihood of the data (also known as evidence).

Taking the log of the posterior, we have:

\[ \log p(\theta | D) = \log p(D | \theta) + \log p(\theta) - \log p(D) \]

The goal is to find the optimal configuration of parameters \( \theta^{\star} \) that maximizes the (log) posterior:

\[ \theta^{\star} = \arg\max_{\theta} \log p(\theta | D) \]

In the case of 2 independent tasks s.t. \(D = \{A, B\}\), we can re-write the log-posterior as:

\[ \begin{align*} \log(p(\theta | D)) &= \log \left(\frac{p(B | A, \theta) p(\theta | A) p(A)}{p(B|A) p(A)}\right)\\ &= \log(p(B|\theta)) + \log (p(\theta | A)) - \log(p(B)) &\text{(conditional independence of A and B)} \\ &\approx \log(p(B|\theta)) + \log (p(\theta | A)) &\text{($\log(p(B))$ is const.)} \end{align*} \]

The log-likelihood \(\log p(B|\theta)\) corresponds to the (negative) loss on task B. Notice, however, that the posterior \(p(\theta|A)\) is in general intractable for neural nets, and for this reason we will resort to Laplace’s approximation.

To begin with, consider the second order Taylor expansion of the log-likelihood \(l(\theta)\) around task \(A\)’s optimal parameters, \(\theta_A^{*}\):

\[ \begin{align*} l(\theta) &\approx \underbrace{l(\theta_{A}^{*})}_{\text{const.}} + \underbrace{\left( \left.\frac{\partial l(\theta)}{\partial \theta} \right\vert_{\theta_{A}^{*}}\right)^t}_{0} \left(\theta - \theta_{A}^{*}\right) + \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 l(\theta)}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \\ &\approx \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 l(\theta)}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} \]

Analogously, we can write the approximation for the log-posterior \(\log p(\theta | A)\) around the same point:

\[ \begin{align*} \log(p(\theta|A)) \approx \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} \]

Now, by exponentiating both sides and performing some algebraic manipulations, we can write the posterior as a multivariate Gaussian distribution:

\[ \begin{align*} p(\theta | A) &\approx \text{exp}\left( \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \right) \\ &= \text{exp}\left(-\frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left(\left. -\frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right)^{-1}\right)^{-1} \left(\theta - \theta_{A}^*\right) \right) \end{align*} \]

Therefore, we obtain the following Laplace approximation:

\[ \begin{align*} p(\theta | A) \approx \mathcal{N}\left(\theta_{A}^{*}, \left(- \left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right)^{-1} \right) \end{align*} \]
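
As a quick sanity check, consider a hypothetical one-dimensional toy example where the log-posterior is exactly quadratic, say \(\log p(\theta|A) = -(\theta - 2)^2 + \text{const}\). The mode is \(\theta_A^* = 2\), the second derivative is \(-2\), and the Laplace approximation recovers the corresponding Gaussian exactly:

\[ \begin{align*} p(\theta | A) \approx \mathcal{N}\left(2, \, \left(-(-2)\right)^{-1}\right) = \mathcal{N}\left(2, \tfrac{1}{2}\right) \end{align*} \]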

Notice that in the Laplace approximation above, the expectation of the term \(-\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}}\) is the Fisher Information matrix, which is a measure of the curvature of the log-likelihood around the optimal parameters \(\theta_{A}^{*}\):

\[ \begin{align*} \mathbb{I}_{A} = \mathbb{E} \left[ -\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right] \end{align*} \]

Note that the Fisher Information matrix can be computed efficiently, using only first-order derivatives of the log-likelihood:

\[ \begin{align*} \mathbb{I}_{A} &= \mathbb{E} \left[ -\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right] \\ &= \mathbb{E} \left[ \left( \frac{\partial \log(p(\theta|A))}{\partial \theta} \right) \left.\left( \frac{\partial \log(p(\theta|A))}{\partial \theta} \right)^t \right\vert_{\theta_{A}^{*}} \right] \end{align*} \]

However, given that neural networks today can have billions of parameters, materializing the entire Fisher Information matrix is computationally infeasible due to its quadratic complexity in the number of parameters. For this reason, EWC approximates the Fisher Information matrix by only considering the diagonal elements of the matrix.

This means that we only need to compute the gradients of the log-likelihood with respect to the parameters, and then square them!
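
As a rough sketch of what this looks like in PyTorch, here is one way to accumulate a diagonal Fisher from per-sample squared gradients. This is a common empirical-Fisher variant that uses the observed labels; `net` (a classifier) and `loader` (a DataLoader over task A's data) are assumed placeholders.

import torch
import torch.nn.functional as F

def diagonal_fisher(net, loader):
    # Accumulate per-parameter squared gradients of the per-sample log-likelihood.
    fisher = {n: torch.zeros_like(p) for n, p in net.named_parameters()}
    n_samples = 0
    for inputs, labels in loader:
        for x, y in zip(inputs, labels):
            net.zero_grad()
            log_probs = F.log_softmax(net(x.unsqueeze(0)), dim=1)
            log_likelihood = log_probs[0, y]          # log-likelihood of the observed label
            log_likelihood.backward()
            for n, p in net.named_parameters():
                fisher[n] += p.grad.detach() ** 2     # square the first-order gradients
            n_samples += 1
    return {n: f / n_samples for n, f in fisher.items()}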

Now, going back and plugging in the approximation for \(\log(p(\theta | A))\) (with an additional weight \(\lambda\) on the quadratic term), we get:

\[ \begin{align*} \log(p(\theta | D)) &\approx \log(p(B|\theta)) + \log (p(\theta | A)) \\ &\approx \log(p(B|\theta)) + \frac{\lambda}{2} \left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} \]

where \(\lambda\) is a hyperparameter that trades off learning task B against not forgetting task A. Simplifying further:

\[ \begin{align*} \log(p(\theta | D)) &\approx \log(p(B|\theta)) - \frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right) \\ \implies -\log(p(\theta | D)) &\approx -\log(p(B|\theta)) + \frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right) \\ \underbrace{\mathcal{L}(\theta)}_{\text{total loss}} &\approx \underbrace{\mathcal{L}_B(\theta)}_{\text{loss on B}} + \underbrace{\frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right)}_{\text{regularizer}} \end{align*} \]
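
In code, this total loss is just the task-B loss plus a per-parameter quadratic penalty anchored at \(\theta_A^*\). A minimal sketch, assuming `fisher_A` and `theta_A_star` are dictionaries keyed by parameter name (e.g. produced by a routine like the `diagonal_fisher` sketch above), and `lam` is the \(\lambda\) hyperparameter:

import torch.nn.functional as F

def ewc_loss(net, outputs, labels, fisher_A, theta_A_star, lam):
    # loss on task B
    task_loss = F.cross_entropy(outputs, labels)
    # EWC regularizer: Fisher-weighted squared distance from task A's optimum
    penalty = sum(
        (fisher_A[n] * (p - theta_A_star[n]) ** 2).sum()
        for n, p in net.named_parameters()
    )
    return task_loss + (lam / 2) * penalty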

Progress and Compress

In the previous section we looked at the case when our dataset consists only of two tasks \(A\) and \(B\). Now, let us consider the more general case when we have a sequence of \(k\) different tasks:

\[ \begin{align*} p(\theta | T_{1:k}) &= \frac{p(T_{1:k} | \theta) p(\theta)}{p(T_{1:k})} \\ &\approx p(T_{1:k} | \theta) p(\theta) &\text{($p(T_{1:k})$ is const.)} \\ &\approx \left(\prod_{i=1}^k p(T_i | \theta) \right) p(\theta) &\text{(cond. independence)}\\ &= p(T_1 | \theta) p(T_2 | \theta) \ldots p(T_{k-1} | \theta) p(T_k | \theta) p(\theta) &\text{(expand prod)}\\ &= p(T_1 | \theta) p(T_2 | \theta) \ldots p(T_{k-1} | \theta) p(\theta) p(T_k | \theta) &\text{(reorder)} \\ &\approx p(\theta|T_{1:k-1})p(T_k|\theta) &\text{(same approx as above)} \\ \end{align*} \]

This means that the posterior of \(\theta\) given all tasks up to \(k\) can be computed sequentially, by first computing it for the first \(k-1\) tasks, and then updating it with the likelihood (alternatively, the loss) for the \(k\)-th task.

Maximizing the posterior is equivalent to minimizing the negative log-posterior. Therefore, we can further obtain:

\[ \begin{align*} -\log(p(\theta | T_{1:k})) &\approx -\log (p(\theta | T_{1:k-1})) - \log(p(T_{k}| \theta)) \\ &= - \log(p(T_{k}| \theta)) - \log (p(\theta | T_{1:k-1})) &\text{(reorder)} \\ &\approx \underbrace{-\log(p(T_{k}|\theta))}_{\text{loss on task $T_k$}} + \underbrace{\frac{1}{2}\sum_{j=1}^{k-1} \left\Vert \theta - \theta_j^* \right\Vert_{F_j}^2}_{\text{regularizer}} &\text{(see previous section)} \end{align*} \]

Note that this formulation (where \(\left\Vert \theta - \theta_j^* \right\Vert_{F_j}^2 = (\theta - \theta_j^*)^t F_j (\theta - \theta_j^*)\)) requires keeping a mean and a Fisher for each task, making the computational cost linear in the number of tasks. Alternatively, one can apply Laplace’s approximation to the whole posterior \(p(\theta|T_{1:k})\), rather than to the individual likelihood terms, resulting in:

\[ \begin{align*} -\log(p(T_k | \theta)) + \frac{1}{2} \left\Vert \theta - \theta_{k-1}^* \right\Vert_{\sum_{j=1}^{k-1}F_j}^2 \end{align*} \]

This means that we only need to keep the latest Maximum-A-Posteriori (MAP) parameters, along with a running sum of Fishers.

What the paper “Progress & Compress: A scalable framework for continual learning” by Schwarz et al. (2018) suggests instead is to keep a single running average of the Fisher Information matrices, which keeps the memory and computational cost constant in the number of tasks.

More precisely, let \(\theta_{i-1}^{*}\), \(F_{i-1}^{*}\) be the MAP parameters and the overall Fisher after the presentation of \(i-1\) tasks. Then, the loss for the \(i\)-th task is defined as:

\[ \begin{align*} -\log(p(T_i | \theta)) + \frac{1}{2}\left\Vert \theta - \theta_{i-1}^{*}\right\Vert_{\gamma F_{i-1}^{*}}^2 \end{align*} \]

where \(\gamma < 1\) is a hyperparameter that down-weights the contribution of earlier tasks, i.e. it gradually forgets old Fisher terms instead of letting them accumulate without bound.

If \(\theta_{i}^{*}\) are the optimal MAP parameters, and \(F_i\) the Fisher for task \(i\), then the overall Fisher is updated as:

\[ \begin{align*} F_i^* = \gamma F_{i-1}^* + F_i \end{align*} \]

The authors refer to this modified method as online EWC.

Below is a pseudo-code implementation of the method, adapted from this repo for better readability:

import torch
import torch.nn as nn
import torch.nn.functional as F


class OnlineEWC:
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.net = YourFavoriteModel().to(self.device)  # must expose get_params() / get_grads() returning flat tensors
        self.opt = torch.optim.SGD(self.net.parameters(), lr=1e-3)
        self.loss = nn.CrossEntropyLoss()
        self.logsoft = nn.LogSoftmax(dim=1)

        self.batch_size = 32
        self.gamma = 0.9        # decay factor for the running Fisher
        self.e_lambda = 1e-3    # regularization strength (the lambda from the derivation)

        self.checkpoint = None  # the MAP parameters for the previous task
        self.fish = None        # the running (discounted) sum of Fisher Information matrices
    

    def penalty(self):
        """
        The regularization term in the total loss function (see derivation above).
        """
        if self.checkpoint is None:
            # For the first task there is no previous task, therefore regularization is 0
            return torch.tensor(0.0).to(self.device)
        else:
            penalty = (self.fish * ((self.net.get_params() - self.checkpoint) ** 2)).sum()
            return penalty


    def observe(self, inputs, labels):
        """
        This method is called at each iteration to update the model's weights.

        Args:
            inputs: the input samples
            labels: the corresponding labels
        """

        inputs, labels = inputs.to(self.device), labels.to(self.device)
        self.opt.zero_grad()
        outputs = self.net(inputs)
        # task loss plus the (online) EWC penalty
        loss = self.loss(outputs, labels) + self.e_lambda * self.penalty()
        loss.backward()
        self.opt.step()
        return loss.item()


    def end_task(self, dataset):
        """
        This method is called at the end of each task to update the 
        Fisher Information matrix and the MAP parameters.

        Args:
            dataset: the dataset object containing samples from the last task.
        """
        fish = torch.zeros_like(self.net.get_params())

        for inputs, labels in dataset.train_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            for inp, lab in zip(inputs, labels):
                self.opt.zero_grad()
                output = self.net(inp.unsqueeze(0))
                # per-sample log-likelihood of the observed label
                loss = -F.nll_loss(self.logsoft(output), lab.unsqueeze(0), reduction='none')
                exp_cond_prob = torch.mean(torch.exp(loss.detach().clone()))
                loss = torch.mean(loss)
                loss.backward()
                # diagonal Fisher estimate: (probability-weighted) squared gradients
                fish += exp_cond_prob * self.net.get_grads() ** 2

        fish /= (len(dataset.train_loader) * self.batch_size)
        if self.fish is None:
            self.fish = fish
        else:
            self.fish = self.gamma * self.fish + fish

        self.checkpoint = self.net.get_params().data.clone()
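
For completeness, here is how the class above might be driven from a training script. This is a hypothetical sketch: `task_datasets` is assumed to be a list of dataset objects, each exposing a `train_loader`, matching what `end_task` expects.

ewc = OnlineEWC()

for dataset in task_datasets:
    for epoch in range(5):                    # however many epochs per task
        for inputs, labels in dataset.train_loader:
            ewc.observe(inputs, labels)       # gradient step with the EWC penalty
    ewc.end_task(dataset)                     # consolidate: update the Fisher and the MAP checkpoint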