Overcoming catastrophic forgetting with Elastic Weight Consolidation (EWC)
In this post, we are going to discuss a seminal method from the field of Continual Learning called Elastic Weight Consolidation, also known as EWC.
Background
Continual Learning (CL) tackles learning from sequential data streams with limited retention and retraining capacity. It requires efficient use of past data and adaptation to new contexts under changing distributions without catastrophic forgetting.
Let’s consider a simple example. Suppose we want to train a single neural network on two tasks in sequence: the first task is to classify cats versus dogs, and the second is to classify cars versus bikes. The network is first trained on the first task, and then the second task is introduced. As the network is trained on the second task, its weights are updated; in the process, the weights that were learned for the first task are overwritten, and the network forgets how to classify cats and dogs.
More precisely, in Continual Learning a model \(f_\theta\) is trained on a sequence of \(T\) tasks, where for each task \(t \in \{ 1, \ldots, T \} \) the learner only gets access to the samples of that task: \( D_t = \{ (x_i, y_i) \}_{i=1}^{N_t} \). However, at the end the model is evaluated on its joint performance across all tasks, therefore we should aim to optimize:
\[ \theta^{\star} = \arg\min_{\theta} \sum_{t=1}^{T} \mathbb{E}_{(x, y) \sim D_t} \left[ \mathcal{L} (f_{\theta}(x), y) \right] \]The main challenge is that at the time of task \(t\), the model has no access to data from previous tasks \(\tilde{t} \in \{1, \ldots, t-1\}\), which violates the typical IID data assumption.
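To make the setting concrete, here is a minimal sketch of such a sequential training loop in PyTorch. The classification model and the list of per-task loaders (task_loaders) are placeholders assumed for illustration; the key point is that at step \(t\) only the data of task \(t\) is visited, and nothing protects the weights learned on earlier tasks.
import torch
import torch.nn.functional as F

def train_sequentially(model, task_loaders, epochs_per_task=1, lr=1e-3):
    """Train the same model on each task in turn, without revisiting old data."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for t, loader in enumerate(task_loaders):   # tasks arrive one after another
        for _ in range(epochs_per_task):
            for inputs, labels in loader:       # only D_t is available here
                opt.zero_grad()
                loss = F.cross_entropy(model(inputs), labels)
                loss.backward()
                opt.step()
        # after finishing task t we simply move on; plain SGD like this forgets task t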
Overcoming catastrophic forgetting in neural networks
One of the seminal works addressing catastrophic forgetting in neural networks is the method called Elastic Weight Consolidation (EWC), first proposed in “Overcoming catastrophic forgetting in neural networks” by Kirkpatrick et al. (2017).
In order to get a better intuition for the final formulation of the method, let us first consider the Bayesian perspective on training neural networks.
From Bayes rule, we have:
\[ p(\theta | D) = \frac{p(D | \theta) p(\theta)}{p(D)} \]where \(p(\theta | D)\) is the posterior distribution of the weights given the data, \(p(D | \theta)\) is the likelihood of the data given the weights, \(p(\theta)\) is the prior distribution of the weights, and \(p(D)\) is the marginal likelihood of the data (also known as evidence).
Taking the log of the posterior, we have:
\[ \log p(\theta | D) = \log p(D | \theta) + \log p(\theta) - \log p(D) \]The goal is to find the optimal configuration of parameters \( \theta^{\star} \) that maximizes the (log) posterior:
\[ \theta^{\star} = \arg\max_{\theta} \log p(\theta | D) \]In the case of two independent tasks, such that \(D = \{A, B\}\), we can re-write the log-posterior as:
\[ \begin{align*} \log(p(\theta | D)) &= \log \left(\frac{p(B | A, \theta) p(\theta | A) p(A)}{p(B|A) p(A)}\right)\\ &= \log(p(B|\theta)) + \log (p(\theta | A)) - \log(p(B)) &\text{(conditional independence of A and B)} \\ &\approx \log(p(B|\theta)) + \log (p(\theta | A)) &\text{($\log(p(B))$ is const.)} \end{align*} \]The log-likelihood \(\log(p(B|\theta))\) corresponds to (the negative of) the loss on task B. Notice, however, that the posterior \(p(\theta|A)\) is in general intractable for neural nets, and for this reason we will resort to Laplace’s approximation.
Writing \(l(\theta) \equiv \log(p(\theta | A))\) and Taylor-expanding it to second order around the mode \(\theta_{A}^{*}\), where the gradient vanishes, we get:
\[ \begin{align*} l(\theta) &\approx \underbrace{l(\theta_{A}^{*})}_{\text{const.}} + \left(\theta - \theta_{A}^{*}\right)^t \underbrace{\left( \left.\frac{\partial l(\theta)}{\partial \theta} \right\vert_{\theta_{A}^{*}}\right)}_{0} + \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 l(\theta)}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \\ &\approx \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 l(\theta)}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} \]\[ \begin{align*} \log(p(\theta|A)) \approx \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} \]Exponentiating both sides:
\[ \begin{align*} p(\theta | A) &\approx \text{exp}\left( \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \right) \\ &= \text{exp}\left(-\frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left(\left. -\frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right)^{-1}\right)^{-1} \left(\theta - \theta_{A}^*\right) \right) \end{align*} \]which is, up to normalization, a Gaussian:
\[ \begin{align*} p(\theta | A) \approx \mathcal{N}\left(\theta_{A}^{*}, \left(- \left. \frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right)^{-1} \right) \end{align*} \]Notice that in the formula above, the term \(-\left. \frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}}\) is (in expectation) the Fisher Information matrix, which measures the curvature of the log-likelihood around the optimal parameters \(\theta_{A}^{*}\):
\[ \begin{align*} \mathbb{I}_{A} = \mathbb{E} \left[ -\left. \frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right] \end{align*} \]First, let us note that the Fisher Information matrix can be efficiently computed using first-order derivatives of the log-likelihood:
\[ \begin{align*} \mathbb{I}_{A} &= \mathbb{E} \left[ -\left. \frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right] \\ &= \mathbb{E} \left[ \left.\left( \frac{\partial \log(p(\theta|A))}{\partial \theta} \right) \left( \frac{\partial \log(p(\theta|A))}{\partial \theta} \right)^t \right\vert_{\theta_{A}^{*}} \right] \end{align*} \]However, given that neural networks today can have billions of parameters, materializing the entire Fisher Information matrix is computationally infeasible due to its quadratic complexity in the number of parameters. For this reason, EWC approximates the Fisher Information matrix by only keeping its diagonal elements.
This means that we only need to compute the gradients of the log-likelihood with respect to the parameters, and then square them!
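For instance, the diagonal Fisher can be estimated with a helper like the following sketch. Here model and data_loader are placeholders, and using the observed labels (rather than labels sampled from the model's own predictive distribution) yields the so-called empirical Fisher; the estimator used in practice may differ in such details.
import torch
import torch.nn.functional as F

def diagonal_fisher(model, data_loader):
    """Estimate the diagonal of the Fisher Information matrix of `model`."""
    fisher = [torch.zeros_like(p) for p in model.parameters()]
    n_samples = 0
    for inputs, labels in data_loader:
        for x, y in zip(inputs, labels):        # per-sample gradients
            model.zero_grad()
            log_probs = F.log_softmax(model(x.unsqueeze(0)), dim=1)
            ll = log_probs[0, y]                # log-likelihood of the observed label
            ll.backward()
            for f, p in zip(fisher, model.parameters()):
                f += p.grad.detach() ** 2       # squared gradient -> diagonal Fisher
            n_samples += 1
    return [f / n_samples for f in fisher]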
\[ \begin{align*} \log(p(\theta | D)) &\approx \log(p(B|\theta)) + \log (p(\theta | A)) \\ &\approx \log(p(B|\theta)) + \frac{\lambda}{2} \left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial \theta^2} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} \]where \(\lambda\) is a hyperparameter controlling how strongly the parameters important for task A are protected (the pure Laplace approximation corresponds to \(\lambda = 1\)). Recognizing the negative Hessian as the Fisher Information matrix and negating everything to obtain a loss to minimize:
\[ \begin{align*} \log(p(\theta | D)) &\approx \log(p(B|\theta)) - \frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right) \\ \implies -\log(p(\theta | D)) &\approx -\log(p(B|\theta)) + \frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right) \\ \underbrace{\mathcal{L}(\theta)}_{\text{total loss}} &\approx \underbrace{\mathcal{L}_B(\theta)}_{\text{loss on B}} + \underbrace{\frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right)}_{\text{regularizer}} \end{align*} \]
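With a diagonal Fisher, the regularizer reduces to an elementwise weighted squared distance to the old parameters. A minimal sketch, assuming fisher_diag and theta_A_star are lists of tensors aligned with model.parameters() (e.g. as produced by the diagonal_fisher sketch above):
def ewc_loss(loss_B, model, fisher_diag, theta_A_star, lam=1.0):
    """Task-B loss plus the EWC quadratic penalty anchored at the task-A optimum."""
    penalty = 0.0
    for p, f, p_star in zip(model.parameters(), fisher_diag, theta_A_star):
        penalty = penalty + (f * (p - p_star) ** 2).sum()
    return loss_B + 0.5 * lam * penalty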
Progress and Compress
So far we have only considered two tasks. The same Bayesian argument extends to a sequence of tasks \(T_1, \ldots, T_k\):
\[ \begin{align*} p(\theta | T_{1:k}) &= \frac{p(T_{1:k} | \theta) p(\theta)}{p(T_{1:k})} \\ &\propto p(T_{1:k} | \theta) p(\theta) &\text{($p(T_{1:k})$ is const. in $\theta$)} \\ &= \left(\prod_{i=1}^k p(T_i | \theta) \right) p(\theta) &\text{(cond. independence)}\\ &= p(T_1 | \theta) p(T_2 | \theta) \ldots p(T_{k-1} | \theta) p(T_k | \theta) p(\theta) &\text{(expand prod)}\\ &= p(T_1 | \theta) p(T_2 | \theta) \ldots p(T_{k-1} | \theta) p(\theta) p(T_k | \theta) &\text{(reorder)} \\ &\propto p(\theta|T_{1:k-1})p(T_k|\theta) &\text{(same argument applied to $T_{1:k-1}$)} \\ \end{align*} \]This means that the posterior of \(\theta\) given all tasks up to \(k\) can be computed sequentially, by first computing it for the first \(k-1\) tasks, and then updating it with the likelihood (alternatively, the loss) of the \(k\)-th task.
Taking negative logs, the loss after task \(k\) becomes:
\[ \begin{align*} -\log(p(\theta | T_{1:k})) &\approx -\log (p(\theta | T_{1:k-1})) - \log(p(T_{k}| \theta)) \\ &= - \log(p(T_{k}| \theta)) - \log (p(\theta | T_{1:k-1})) &\text{(reorder)} \\ &\approx \underbrace{-\log(p(T_{k}|\theta))}_{\text{loss on task $T_k$}} + \underbrace{\frac{1}{2}\sum_{j=1}^{k-1} \left\Vert \theta - \theta_j^* \right\Vert_{F_j}^2}_{\text{regularizer}} &\text{(see previous section)} \end{align*} \]where \(\left\Vert v \right\Vert_{F}^2 = v^t F v\) and \(F_j\) is the Fisher Information matrix of task \(j\). If we further anchor all of the quadratic penalties at the most recent MAP estimate \(\theta_{k-1}^*\), the regularizer collapses into a single term:
\[ \begin{align*} -\log(p(T_k | \theta)) + \frac{1}{2} \left\Vert \theta - \theta_{k-1}^* \right\Vert_{\sum_{j=1}^{k-1}F_j}^2 \end{align*} \]This means that we only need to keep the latest Maximum-A-Posteriori (MAP) parameters, along with a running sum of Fisher matrices.
What the paper “Progress & Compress: A scalable framework for continual learning” by Schwarz et al. (2018) suggests instead is to keep a running, exponentially decayed average of the Fisher Information matrices, so that older tasks are gradually down-weighted while the memory cost stays constant in the number of tasks. The loss for task \(i\) then becomes:
\[ \begin{align*} -\log(p(T_i | \theta)) + \frac{1}{2}\left\Vert \theta - \theta_{i-1}^{*}\right\Vert_{\gamma F_{i-1}^{*}}^2 \end{align*} \]where \(\gamma < 1\) is a hyperparameter that gradually removes (down-weights) the penalty terms associated with previous presentations of earlier tasks, and the consolidated Fisher is updated after each task as:
\[ \begin{align*} F_i^* = \gamma F_{i-1}^* + F_i \end{align*} \]where \(F_i\) is the Fisher Information matrix computed on task \(i\). The authors refer to this modified method as online EWC.
Following is a pseudo-code implementation of the method, adapted from this repo for better readability:
import torch
from torch import nn
from torch.nn import functional as F

class OnlineEWC:
    def __init__(self):
        self.net = YourFavoriteModel()  # assumed to expose get_params() / get_grads() as flat vectors
        self.opt = torch.optim.SGD(self.net.parameters(), lr=1e-3)
        self.loss = nn.CrossEntropyLoss()
        self.logsoft = nn.LogSoftmax(dim=1)
        self.device = next(self.net.parameters()).device
        self.batch_size = 32
        self.gamma = 0.9
        self.e_lambda = 1e-3
        self.checkpoint = None  # the MAP parameters for the previous task
        self.fish = None        # the running (decayed) sum of Fisher Information matrices

    def penalty(self):
        """
        The regularization term in the total loss function (see derivation above).
        """
        if self.checkpoint is None:
            # For the first task there is no previous task, therefore regularization is 0
            return torch.tensor(0.0).to(self.device)
        else:
            penalty = (self.fish * ((self.net.get_params() - self.checkpoint) ** 2)).sum()
            return penalty

    def observe(self, inputs, labels):
        """
        This method is called at each iteration to update the model's weights.
        Args:
            inputs: the input samples
            labels: the corresponding labels
        """
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels) + self.e_lambda * self.penalty()
        loss.backward()
        self.opt.step()
        return loss.item()

    def end_task(self, dataset):
        """
        This method is called at the end of each task to update the
        Fisher Information matrix and the MAP parameters.
        Args:
            dataset: the dataset object containing samples from the last task.
        """
        fish = torch.zeros_like(self.net.get_params())
        for inputs, labels in dataset.train_loader:
            for inp, lab in zip(inputs, labels):
                self.opt.zero_grad()
                output = self.net(inp.unsqueeze(0))
                # per-sample log-likelihood of the observed label
                loss = -F.nll_loss(self.logsoft(output), lab.unsqueeze(0), reduction='none')
                exp_cond_prob = torch.mean(torch.exp(loss.detach().clone()))
                loss = torch.mean(loss)
                loss.backward()
                # squared gradients (weighted by the conditional probability) -> diagonal Fisher
                fish += exp_cond_prob * self.net.get_grads() ** 2
        fish /= (len(dataset.train_loader) * self.batch_size)
        if self.fish is None:
            self.fish = fish
        else:
            # online EWC: decayed running sum of Fishers
            self.fish = self.gamma * self.fish + fish
        self.checkpoint = self.net.get_params().data.clone()
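To make the intended usage concrete, here is a minimal driver sketch over a sequence of tasks, assuming tasks is a list of dataset objects each exposing a train_loader and num_epochs is chosen per task (both are assumptions for illustration):
ewc = OnlineEWC()
for task in tasks:                         # tasks arrive sequentially
    for epoch in range(num_epochs):
        for inputs, labels in task.train_loader:
            ewc.observe(inputs, labels)    # gradient step with the EWC penalty
    ewc.end_task(task)                     # consolidate: update Fisher and MAP checkpoint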