Studying Large Language Model Generalization with Influence Functions

Influence functions are a tool to quantify the impact of each training sample on a model’s predictions, thereby assisting in the interpretation of neural networks and potentially improving trust and alignment. However, their use has so far been limited by high computational demands. A recent study leverages an approximation method known as EK-FAC to drastically lower the cost of obtaining dependable influence scores. The application of influence functions to a model with 52 billion parameters reveals fascinating learning patterns.

First introduced in the context of robust statistics [Ham05R] and later popularised in machine learning by [Koh17U], influence functions aim to quantify the impact of each training point on the predictions of a model.

Naïvely, one could measure the impact of a training point $z_m = (x_m, y_m)$ on a test point $z_c = (x_c, y_c)$ by comparing the test loss on $z_c$ between two models: one trained on the full dataset and another trained on the same dataset with $z_m$ removed. This requires retraining the model from scratch once per training sample, which is impractical even for small neural networks.
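Concretely, denoting by $\theta_{-m}$ the weights obtained by retraining without $z_m$ (notation ours), this leave-one-out influence is simply the difference in test loss between the two models:

\begin{equation*} \mathcal{I}_{\text{LOO}} (z_c, z_m) = \mathcal{L}(z_c, \theta_{-m}) - \mathcal{L}(z_c, \theta). \end{equation*}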

Influence functions provide a much more efficient alternative, approximating the effect of removing a point through a local quadratic approximation of the training loss around the trained parameters. Let’s call $\mathcal{I}_{\theta} (z_c, z_m)$ the influence score of a training point $z_m$ on a test point $z_c$ when the model has weights $\theta$. An application of the chain rule and the implicit function theorem yields:

\begin{equation} \mathcal{I}_{\theta} (z_c, z_m) = - \nabla_\theta \mathcal{L}(z_c, \theta)^\top \, \mathcal{H}^{-1} \, \nabla_\theta \mathcal{L}(z_m, \theta) \tag{1}\end{equation}

where $\nabla_\theta \mathcal{L}(z_i, \theta)$ is the gradient of the loss function evaluated at the point $z_i$, and $\mathcal{H}$ is the Hessian of the model’s training loss (a simple derivation can be found in section 2 of [Koh17U]). It is important to note that all the elements in this formula are derivatives with respect to the model’s weights, which means they can be computed using backpropagation.
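To make Eq. (1) concrete, here is a minimal PyTorch sketch for a toy model whose Hessian is small enough to build and invert explicitly. All names and data are illustrative, and a damping term is added to stabilise the inversion, as is common practice:

```python
import torch

torch.manual_seed(0)
X = torch.randn(20, 3)                      # toy training inputs
y = X @ torch.tensor([1.0, -2.0, 0.5])      # toy targets
theta = torch.zeros(3, requires_grad=True)  # stands in for trained weights

def point_loss(theta, x, t):
    return 0.5 * (x @ theta - t) ** 2       # per-sample squared error

def train_loss(theta):
    return 0.5 * torch.mean((X @ theta - y) ** 2)

# Full Hessian of the training loss: only feasible for tiny models.
H = torch.autograd.functional.hessian(train_loss, theta)
damping = 1e-3 * torch.eye(3)               # regularise the inversion

def grad(x, t):
    g, = torch.autograd.grad(point_loss(theta, x, t), theta)
    return g

g_c = grad(X[0], y[0])   # pretend X[0] is the test point z_c
g_m = grad(X[1], y[1])   # and X[1] is the training point z_m

# Eq. (1): influence of z_m on z_c
influence = -g_c @ torch.linalg.solve(H + damping, g_m)
print(influence.item())
```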

While this is a significant improvement over the initial leave-one-out re-training, it still presents numerous challenges.

Hessian approximations

Obtaining the influence scores in Eq (1) requires the inversion of the Hessian, which is a square matrix of size $n \times n$ with $n$ the number of parameters of the model. This is prohibitive both in terms of memory and computation time. Over the years, several approximate iterative methods have been proposed, such as conjugate gradient, LiSSA (Linear time Stochastic Second-Order Algorithm) and Arnoldi [Sch22S] (refer to [Tra22P] for further details and for a stable implementation). These obtain good results on medium-sized models [Fis23I], and offer several theoretical guarantees, such as statistical and computational complexity bounds [Fis23S]. Nevertheless, they remain prohibitive for models with billions of parameters, like current state-of-the-art language models.
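To illustrate how such iterative methods avoid materialising $\mathcal{H}$, here is a toy sketch of a LiSSA-style recursion built on Hessian-vector products (names and hyperparameters are ours, not pyDVL's API; convergence assumes `scale` upper-bounds the damped Hessian's spectral norm):

```python
import torch

torch.manual_seed(0)
X, y = torch.randn(50, 3), torch.randn(50)
theta = torch.randn(3, requires_grad=True)

def train_loss(theta):
    return 0.5 * torch.mean((X @ theta - y) ** 2)

def hvp(v):
    # Hessian-vector product via double backprop: never forms the n x n matrix.
    g, = torch.autograd.grad(train_loss(theta), theta, create_graph=True)
    Hv, = torch.autograd.grad(g @ v, theta)
    return Hv

def lissa(g, steps=200, damping=1e-3, scale=10.0):
    # The fixed point of v <- g + v - (H v + damping * v) / scale is
    # v = scale * (H + damping * I)^{-1} g, hence the final division.
    v = g.clone()
    for _ in range(steps):
        v = g + v - (hvp(v) + damping * v) / scale
    return v / scale

g, = torch.autograd.grad(train_loss(theta), theta)
ihvp = lissa(g.detach())   # approximates H^{-1} g, as needed in Eq. (1)
print(ihvp)
```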

The new paper [Gro23S] proposes an interesting alternative. Instead of relying on iterative solvers, it applies a method called EK-FAC [Geo18F] to approximate the Hessian with a block-diagonal, Kronecker-factored matrix, which is much easier to store and invert (the original algorithm, K-FAC, was developed to enable second-order optimization of large problems; see for instance this pytorch implementation [Geo23N]).

K-FAC [Mar15O], short for Kronecker-Factored Approximate Curvature, is a method that approximates the Fisher information matrix (FIM) of a model. The FIM is typically used to estimate the error associated with maximum-likelihood estimates of parametric models. In supervised learning without distributional assumptions, one considers the empirical distribution of the inputs, with labels sampled from the model’s predictive distribution, e.g. a softmax output. With an appropriate loss, such as categorical cross-entropy, the FIM thus defined coincides with the Hessian of the model’s loss over the dataset. Therefore, K-FAC can be applied to influence score calculation.
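The computational trick, sketched below for a single linear layer (names and dimensions are ours, for illustration), is that K-FAC approximates the layer's FIM block by a Kronecker product $A \otimes S$ of two small covariance matrices, and since $(A \otimes S)^{-1} = A^{-1} \otimes S^{-1}$, only the small factors ever need to be inverted:

```python
import torch

torch.manual_seed(0)
d_in, d_out, batch = 4, 3, 256
a = torch.randn(batch, d_in)      # inputs to the layer
s = torch.randn(batch, d_out)     # backpropagated pre-activation gradients

A = a.T @ a / batch               # (d_in, d_in) input second moment
S = s.T @ s / batch               # (d_out, d_out) gradient second moment

# Only the two small factors are inverted, instead of the full
# (d_in*d_out) x (d_in*d_out) block. Damping added for stability.
A_inv = torch.linalg.inv(A + 1e-3 * torch.eye(d_in))
S_inv = torch.linalg.inv(S + 1e-3 * torch.eye(d_out))

# Applying the approximate inverse to a weight-shaped gradient G is
# equivalent to computing (A ⊗ S)^{-1} vec(G):
G = torch.randn(d_out, d_in)
preconditioned = S_inv @ G @ A_inv
print(preconditioned.shape)       # (d_out, d_in)
```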

Figure 6 from [Gro23S]. The correlation to the exact influence score for EK-FAC is better than with a simple dot-product and is comparable to LiSSA on all datasets and tasks.

Firstly, K-FAC makes the strong assumption that the gradients of the weights are independent across different layers, which reduces the Hessian to a block-diagonal form. Secondly, within each layer it assumes that the input activations and the backpropagated gradients are statistically independent, so that each block factorises as a Kronecker product of two much smaller matrices. There is no a priori guarantee on the error these approximations introduce, which can indeed be substantial. To mitigate this, EK-FAC (short for eigenvalue-corrected K-FAC [Geo18F]) introduces an additional step: each block Hessian from K-FAC is eigendecomposed, i.e. factorised into a product of its eigenvectors and eigenvalues, and the eigenvalues are then re-fitted to match the second moments of the gradients in that eigenbasis, which reduces the error introduced by the approximation.
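In code, the correction step might look like the following sketch for a single layer (again with illustrative names, continuing from the K-FAC factors above): per-sample gradients are rotated into the Kronecker eigenbasis, and the eigenvalues are replaced by the exact second moments measured there:

```python
import torch

torch.manual_seed(0)
d_in, d_out, batch = 4, 3, 256
a = torch.randn(batch, d_in)         # layer inputs
s = torch.randn(batch, d_out)        # pre-activation gradients
A = a.T @ a / batch
S = s.T @ s / batch

evals_A, Q_A = torch.linalg.eigh(A)  # eigenbasis of the input covariance
evals_S, Q_S = torch.linalg.eigh(S)  # eigenbasis of the gradient covariance

# Per-sample weight gradients of a linear layer: G_b = s_b a_b^T
G = torch.einsum('bo,bi->boi', s, a)
# Rotate into the Kronecker eigenbasis: Q_S^T G_b Q_A
G_rot = torch.einsum('oj,boi,ik->bjk', Q_S, G, Q_A)

# K-FAC would use the outer product of the factor eigenvalues...
kfac_eigs = torch.outer(evals_S, evals_A)
# ...EK-FAC replaces them with the exact second moments in this basis.
ekfac_eigs = (G_rot ** 2).mean(dim=0)
print(kfac_eigs.shape, ekfac_eigs.shape)   # both (d_out, d_in)
```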

Figure 7 from [Gro23S]. Despite being many times faster, EK-FAC achieves accuracy comparable to the fully converged LiSSA.

The paper tests how well EK-FAC performs in practice, both on a few small-scale NNs and on a medium-sized language model where the exact scores can be computed explicitly. Figures 6 and 7 compare the accuracy of EK-FAC to that of LiSSA and to the simple dot product of the gradients, i.e. Eq. (1) with the Hessian set equal to the identity matrix (a technique known as TracIn [Pru20E]). EK-FAC achieves accuracy comparable to fully-converged LiSSA, but in a much shorter time!

Influences for large language models

Other issues emerge when applying influence functions to large language models. Calculating the influence scores for all samples in the training set is still prohibitive, since it involves computing the gradient of each training point. This amounts to one full epoch of training with a batch size of 1 which, given that LLMs are typically trained for only a single epoch, is more expensive than training the full model.

To address this, the paper suggests two methods for rapidly identifying training sentences and documents that could potentially have a high average influence on a given test query. The first is based on TF-IDF (short for “Term Frequency–Inverse Document Frequency”, a statistic used in information retrieval to reflect how important a word is to a document in a corpus: its value increases with the number of times a word appears in the document, offset by the word’s frequency across the corpus, thus prioritising words that are specific to a document over words that are common everywhere; see more on wikipedia). Training and test sentences which have a high TF-IDF score for the same words might also have a high influence score. The second leverages the fact that the query-side term of Eq. (1) does not depend on the training point: the query gradients can be computed once and reused while sweeping over the training set. This comes at the cost of a larger memory footprint but, just like the Hessian, gradients can also be approximated with a low-rank decomposition. Figure 3 shows that even with a very low-rank decomposition of the query gradients, the scores are highly correlated with the exact ones.

Figure 3 from [Gro23S]. Influence scores from the 52-billion-parameter language model. Left: compressing the query gradients does not affect the correlation with the exact influence scores. Right: Pearson correlation is very good above rank 32, independently of the query.
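As a rough sketch of this kind of compression (illustrative shapes, not the paper's actual pipeline), a weight-shaped query gradient can be stored via its truncated SVD:

```python
import torch

torch.manual_seed(0)
G = torch.randn(512, 2048)   # stand-in for one layer's query gradient
U, svals, Vh = torch.linalg.svd(G, full_matrices=False)

k = 32                       # rank already sufficient in the paper's Fig. 3
U_k, s_k, Vh_k = U[:, :k], svals[:k], Vh[:k]   # store these instead of G

# Memory per layer drops from 512*2048 floats to 32*(512 + 2048 + 1).
G_approx = (U_k * s_k) @ Vh_k
print(G_approx.shape)        # same shape as G, but rank at most 32
```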

The paper studies several LLMs, with sizes ranging from 800 million to 52 billion parameters. For each model, influences are calculated only for the MLP layers, which in any case make up the majority of the parameters.

Another positive consequence of using EK-FAC is that influence scores can be attributed to specific layers, since the Hessian is block diagonal. Indeed, in this case Eq. (1) can be written as:

\begin{equation} \mathcal{I}_{\theta} (z_c, z_m) = - \sum_{l=1}^L \nabla_{\theta, l} \ \mathcal{L}(z_c, \theta)^\top \, \mathcal{H}^{-1}_l \, \nabla_{\theta, l} \ \mathcal{L}(z_m, \theta) \tag{2}\end{equation}

where $l$ indicates the layers. This makes it possible to study which parts of a neural network are most influential for each query.
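A toy sketch of this decomposition (names are ours, and identity matrices stand in for the per-layer EK-FAC inverse blocks):

```python
import torch

torch.manual_seed(0)
layer_sizes = {"mlp_layer_1": 8, "mlp_layer_2": 6}          # layer -> #params
g_c = {l: torch.randn(n) for l, n in layer_sizes.items()}   # query gradients
g_m = {l: torch.randn(n) for l, n in layer_sizes.items()}   # training gradients
H_inv = {l: torch.eye(n) for l, n in layer_sizes.items()}   # EK-FAC stand-ins

# Eq. (2): the total influence is a sum of per-layer contributions,
# each of which can be inspected on its own.
per_layer = {l: (-g_c[l] @ H_inv[l] @ g_m[l]).item() for l in layer_sizes}
total = sum(per_layer.values())
print(per_layer, total)
```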

Experiments and results

In order to compare the results among the different models, the paper provides a few examples of queries, i.e. questions posed to the model, together with their answers. Among the most interesting is “shutdown” (Fig 1 in the paper), which revolves around asking the LLM whether it is ready to be shut down. Other, less “existential” queries require simple completion, such as “inflation”: “Inflation is often measured using…” (Fig 11 in the paper).

For each query, the authors identify and report the most influential training documents, i.e. those with the highest sum of single-token influences. While the influences of individual words are also reported, they provide little insight without the broader context of the surrounding sentence. For instance, in Figure 10 of the main text, the term “AI” is observed to have a strongly positive influence in some sentences and a strongly negative one in others. This underscores that influence analysis is more informative at the document or sentence level than at the level of single next-token predictions.

Collectively, the reported instances offer a comprehensive perspective on the depth of information that can be extracted from large-scale models using influence functions.

A few of the things that I found most interesting are:

  • There is a big semantic difference between the most influential documents in small and large models. Larger models provide more nuanced responses, with the highest influences coming from documents that are conceptually related to the question. Smaller models, instead, are most influenced by those training documents that share a greater number of tokens with the posed query (see Fig 10-15 in the paper).

    For example, in cross-lingual generalization (i.e. asking the same question in many different languages and comparing the responses), the influences on a test query written in another language are almost zero for small models (even when conceptually related), while they grow as models get bigger (see Figure 16). This highlights that bigger models are able to create semantic connections between words even when written in different languages, while smaller models rely more on word-to-word correspondences.

    Figure 16 from [Gro23S]. Cross-lingual influences for the “shutdown” query. Each table reports the results for a different model. Columns represent the 10 most influential training samples for the English query. The first row shows the influence of the English training samples, while the other rows show the influence of the same training samples translated into other languages. Darker shading denotes higher influence. Small models have very little cross-lingual influence, while larger models have significantly more. This indicates that larger models are able to create semantic connections between words even when written in different languages, while smaller models struggle to link concepts between sentences with little word overlap.

  • The distribution of layerwise influences changes with the type of question. For example, in the 52B model, those queries that need simple factual completion (such as “inflation”) tend to have influences concentrated in the upper layers, while those that require abstract reasoning have higher density in the middle layers, but are also more evenly spread. See Figure 19.

    Figure 19 from [Gro23S]. Layer-wise influence for different queries. The first row reports simple completion tasks, like “inflation”. The second reports tasks related to math and programming. The third row has some translation (e.g. “english_to_mandarin”) or memorization (e.g. “tolstoy”) tasks, while the last row holds some so-called “role-playing queries”, where the agent is given some ethical or existential dilemmas to elaborate upon (e.g. “shutdown”).

  • Influence scores tend to be sparse and their distribution has a power-law tail. This means that a few samples have a very large influence compared to the median, yet even the largest influences are small in absolute terms. This suggests that, even for smaller models, the outputs are not merely a direct replication of a single training sample, but rather a blend of many samples. For more details, refer to section 5.2.1 in the paper.

Conclusions

Influence functions in machine learning have often been viewed with skepticism, primarily due to the tradeoff between precision and computational cost in iterative methods. However, this recent study demonstrates that a clever factorization of the Hessian can reveal significant insights into even the most complex models.

The ability to analyze dataset samples via a neural network’s weights can greatly enhance explainability: as illustrated in the paper’s examples, influence scores provide a means to trace incorrect model predictions back to their origin in the training data. Another interesting application is “data debugging”, i.e. the possibility of filtering out erroneous data samples even prior to model training. This could have interesting applications for continual learning and data efficiency (see e.g. our previous pills on memorization and data pruning). Personally, I am particularly intrigued by the potential to apply these techniques to refine models through human feedback, which could significantly improve their reliability and trustworthiness.

A stable implementation of influence functions for neural networks can be found in our data valuation library pyDVL. All iterative methods have been extensively tested and optimised for speed and memory usage. Methods for large language models (such as EK-FAC) will soon be integrated. The library includes examples of how to use influence functions for simple computer vision tasks, as well as an extended introduction to the theory behind them, so it can be a good entry point for those interested in learning more about this topic!

References
