Causal Interpretation of Self-Attention in Pre-Trained Transformers
Research/NLP_Paper · 2024. 11. 11. 10:38
https://arxiv.org/pdf/2310.20307
(Oct 2023, NeurIPS)
※ Subject of a 2024 NLP class team project
Abstract
We propose a causal interpretation of self-attention in the Transformer neural network architecture. We interpret self-attention as a mechanism that estimates a structural equation model for a given input sequence of symbols (tokens). The structural equation model can be interpreted, in turn, as a causal structure over the input symbols under the specific context of the input sequence. Importantly, this interpretation remains valid in the presence of latent confounders. Following this interpretation, we estimate conditional independence relations between input symbols by calculating partial correlations between their corresponding representations in the deepest attention layer. This enables learning the causal structure over an input sequence using existing constraint-based algorithms. In this sense, existing pre-trained Transformers can be utilized for zero-shot causal-discovery. We demonstrate this method by providing causal explanations for the outcomes of Transformers in two tasks: sentiment classification (NLP) and recommendation.
1. Introduction
Causality plays an important role in many sciences such as epidemiology, social sciences, and finance (25; 34). Understanding the underlying causal mechanisms is crucial for tasks such as explaining a phenomenon, prediction, and decision making. Automated discovery of causal structures from observed data alone is an important problem in artificial intelligence. It is particularly challenging when latent confounders may exist. One family of algorithms (35; 7; 8; 46; 30), called constraint-based, recovers, in the large-sample limit, an equivalence class of the true underlying graph. However, these algorithms require a statistical test to be provided. The statistical test is often sensitive to small sample sizes, and in some cases it is not clear which statistical test suits the data. Moreover, these methods assume that data samples are generated from a single causal graph, and are designed to learn a single equivalence class. There are cases where data samples may be generated by different causal mechanisms, and it is not clear how to learn the correct causal graph for each data sample. For example, decision processes may differ among humans, so data collected about their actions may not be represented by a single causal graph. This is the case for recommender systems that are required to provide a personalized recommendation for a user based on her past actions. In addition, it is desirable that such automated systems provide a tangible explanation of why a specific recommendation was given.
Recently, deep neural networks based on the Transformer architecture (39) have been shown to achieve state-of-the-art accuracy in domains such as natural language processing (9; 3; 45), vision (10; 19), and recommender systems (36; 18). The Transformer is based on the attention mechanism, which calculates context-dependent weights for a given input (32). Given an input consisting of a sequence of symbols (tokens), the Transformer is able to capture complex dependencies, but it is unclear how these dependencies are represented in the Transformer or how to extract them.
In this paper we bridge structural causal models and the attention mechanism used in the Transformer architecture. We show that self-attention is a mechanism that estimates a linear structural equation model in the deepest layer for each input sequence independently, and that this model represents a causal structure over the symbols in the input sequence. In addition, we show that an equivalence class of the causal structure over the input can be learned solely from the Transformer's estimated attention matrix. This enables learning the causal structure over a single input sequence using existing constraint-based algorithms, and thus utilizing existing pre-trained Transformer models for zero-shot causal discovery. We demonstrate this method by providing causal explanations for the outcomes of Transformers in two tasks: sentiment classification (NLP) and recommendation.
2. Related Work
In recent years, several advances in causal reasoning using deep neural networks have been presented. One line of work is causal modeling with neural networks (42; 12), where the neural network architecture follows a given causal graph structure. The architecture is constructed either to model the joint distribution over observed variables (12) or to explicitly learn the deterministic functions of a structural causal model (SCM) (42). It was recently shown that a neural network architecture that explicitly learns the deterministic functions of an SCM can be used to answer interventional and counterfactual queries (43), that is, to perform inference at rungs 2 and 3 of the ladder of causation (26), also called layers 2 and 3 of the Pearl causal hierarchy (2). Nevertheless, this architecture requires knowing the causal graph structure before training, and it supports only a single causal graph.
In another line of work, an attention mechanism is added to the model before training, and a causal graph is inferred from the attention values at inference time (21; 17). Attention values are compared against a threshold, and high values are treated as an indication of a potential causal relation. Nauta et al. (21) validate these potential causal relations by permuting the values of the potential cause and measuring the effect. However, this approach is suitable only if a single causal model underlies the entire dataset.
3. Preliminaries
First, we provide preliminaries for structural causal models and the attention mechanism in the Transformer-based neural architecture. Throughout the paper we denote vectors and sets with bold italic upper-case letters (e.g., V), matrices with bold upper-case letters (e.g., A), random variables with italic upper-case letters (e.g., X), and models in calligraphic font (e.g., M). Different sets of letters are used when referring to structural causal models and attention (see Appendix A, Table 1 for a list).
3.1. Structural Causal Models
3.2. Self-Attention Mechanism
4. A Link between Pre-trained Self-Attention and a Causal Graph underlying an Input Sequence of Symbols
In this section we first describe self-attention as a mechanism that encodes correlations between symbols in an unknown structural causal model (Section 4.1). After establishing this relation, we present a method for recovering an equivalence class of the causal graph underlying the input sequence (Section 4.2). We also extend the results to the multi-head and multi-layer architecture (Section 4.3).
4.1. Self-Attention as Correlations between Nodes in a Structural Causal Model
We now describe how self-attention can be viewed as a mechanism that estimates the values of observed nodes of an SCM. We first show that the covariance over the outputs of self-attention is similar to the covariance over the observed nodes of an SCM. Specifically, we model relations between symbols in an input sequence using a linear-Gaussian SCM at the output of the attention layer. In an SCM, the values of the endogenous nodes, in matrix form, are X = GX + ΛU, which means X = (I − G)⁻¹ΛU, where I is the identity matrix and I − G is invertible because G is the weighted adjacency matrix of an acyclic graph.
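In a linear SCM X = GX + ΛU, solving for X gives X = (I − G)⁻¹ΛU, so the covariance of the endogenous nodes is C_X = (I − G)⁻¹ΛC_UΛᵀ(I − G)⁻ᵀ. The minimal sketch below (pure Python, helper names are ours) assumes the nodes are listed in topological order, so G is strictly lower-triangular and (I − G)⁻¹ is the finite series I + G + ... + Gⁿ⁻¹. It computes C_X for a three-node chain, showing that X1 and X3 are correlated even though there is no direct edge between them.

```python
# Sketch: covariance implied by a linear SCM X = G X + Λ U (pure Python).
# Assumption: nodes are in topological order, so G is strictly lower-triangular
# (G[i][j] != 0 means an edge j -> i) and therefore nilpotent.

def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def inv_i_minus_g(g):
    """(I - G)^-1 = I + G + G^2 + ... + G^(n-1), valid because G^n = 0."""
    n = len(g)
    result, power = identity(n), identity(n)
    for _ in range(n - 1):
        power = matmul(power, g)
        result = [[result[i][j] + power[i][j] for j in range(n)] for i in range(n)]
    return result

def scm_covariance(g, lam, c_u):
    """C_X = (I - G)^-1 Λ C_U Λ^T (I - G)^-T."""
    m = matmul(inv_i_minus_g(g), lam)
    return matmul(matmul(m, c_u), transpose(m))

# Chain X1 -> X2 -> X3 with unit edge weights, Λ = I and C_U = I.
G = [[0.0, 0.0, 0.0],
     [1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]
cov = scm_covariance(G, identity(3), identity(3))
print(cov)  # [[1.0, 1.0, 1.0], [1.0, 2.0, 2.0], [1.0, 2.0, 3.0]]
```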
4.2. Attention-based Causal-Discovery (ABCD)
In Section 4.1 we showed that self-attention learns to represent pairwise associations between symbols of an input sequence, for which there exists a linear-Gaussian SCM that has exactly the same pairwise associations between its nodes. However, can we recover the causal structure G of the SCM solely from the weights of a pre-trained attention layer? We now present the Attention-Based Causal-Discovery (ABCD) method for learning an equivalence class of the causal graph that underlies a given input sequence.
Causal discovery (causal structure learning) from observed data alone requires certain assumptions. Here we assume the causal Markov (24) and faithfulness (35) assumptions. Under these assumptions, constraint-based methods use tests of conditional independence (CI-tests) to learn the causal structure (27; 35; 8; 7; 30; 31; 22). A statistical CI-test decides whether two variables are statistically independent conditioned on a set of variables. Commonly, partial correlation is used for CI-testing between continuous, normally distributed variables with linear relations. This test requires only a pairwise correlation matrix (marginal dependence) for evaluating partial correlations (conditional dependence). We evaluate the correlation matrix from the attention matrix. From Equation 8 the covariance matrix of output embeddings is
C_Z = A C_V Aᵀ, where C_V is the covariance of the value vectors V,
and (pairwise) correlation coefficients are
ρ_ij = C_Z(i, j) / sqrt(C_Z(i, i) · C_Z(j, j)).
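Assuming the self-attention output is Z = AV (A the row-stochastic attention matrix, V the value vectors of the n symbols), a linear map gives the covariance across symbols C_Z = A C_V Aᵀ, and dividing by the diagonal yields pairwise correlations. The sketch below additionally assumes C_V = I (uncorrelated, unit-variance values); that simplification, the toy scores, and the function names are ours for illustration.

```python
import math

def softmax_rows(scores):
    """Row-wise softmax, as used to turn attention scores into an attention matrix."""
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention_correlations(a, c_v=None):
    """Correlation matrix of Z = A V, via C_Z = A C_V A^T.

    Assumption (ours): if c_v is not given, the value vectors are treated as
    uncorrelated with unit variance, i.e. C_V = I, so C_Z = A A^T.
    """
    n = len(a)
    if c_v is None:
        c_v = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    a_cv = [[sum(a[i][k] * c_v[k][j] for k in range(n)) for j in range(n)] for i in range(n)]
    c_z = [[sum(a_cv[i][k] * a[j][k] for k in range(n)) for j in range(n)] for i in range(n)]
    return [[c_z[i][j] / math.sqrt(c_z[i][i] * c_z[j][j]) for j in range(n)]
            for i in range(n)]

# Toy 3-token attention scores (hypothetical numbers, not from the paper).
A = softmax_rows([[2.0, 0.5, 0.1],
                  [0.3, 1.5, 0.2],
                  [0.1, 0.4, 1.8]])
rho = attention_correlations(A)
for row in rho:
    print([round(r, 3) for r in row])
```

Only one forward pass is needed to obtain A, so the whole correlation matrix for a sequence comes essentially for free.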
Unlike kernel-based CI tests (1; 11; 13; 14; 37; 48), we do not need to explicitly define or estimate the kernel, as it is readily available from a single forward pass of the input sequence through the Transformer. This has two implications. First, our CI-testing function is inherently learned during the training stage of the Transformer, thereby benefiting from the efficiency of learning complex models from large datasets. Second, since attention is computed uniquely for each input sequence, CI-testing is specific to that sequence. That is, conditional independence is tested under the specific context of the input sequence.
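Given a correlation matrix, a partial-correlation CI-test can be sketched with the standard recursive formula ρ_xy·Z∪{w} = (ρ_xy·Z − ρ_xw·Z ρ_yw·Z) / sqrt((1 − ρ_xw·Z²)(1 − ρ_yw·Z²)). The fixed threshold used to declare independence below is a hypothetical decision rule for illustration; the paper's actual test statistic and threshold are not given in this summary. The example reuses the chain X1 → X2 → X3, where X1 and X3 are dependent but become independent given X2.

```python
import math

def partial_corr(rho, x, y, cond):
    """Partial correlation of variables x and y given the index list `cond`,
    computed recursively from a correlation matrix `rho`."""
    if not cond:
        return rho[x][y]
    w, rest = cond[0], cond[1:]
    r_xy = partial_corr(rho, x, y, rest)
    r_xw = partial_corr(rho, x, w, rest)
    r_yw = partial_corr(rho, y, w, rest)
    return (r_xy - r_xw * r_yw) / math.sqrt((1.0 - r_xw ** 2) * (1.0 - r_yw ** 2))

def ci_test(rho, x, y, cond, threshold=0.05):
    """Hypothetical decision rule: report x _||_ y | cond when |partial corr| < threshold."""
    return abs(partial_corr(rho, x, y, list(cond))) < threshold

# Correlations of the chain X1 -> X2 -> X3 (covariance [[1,1,1],[1,2,2],[1,2,3]]).
cov = [[1.0, 1.0, 1.0], [1.0, 2.0, 2.0], [1.0, 2.0, 3.0]]
rho = [[cov[i][j] / math.sqrt(cov[i][i] * cov[j][j]) for j in range(3)] for i in range(3)]
print(ci_test(rho, 0, 2, []))   # False: X1 and X3 are marginally dependent
print(ci_test(rho, 0, 2, [1]))  # True: X1 and X3 are independent given X2
```

The recursion divides by (1 − ρ²), so perfectly correlated conditioning variables would need special handling.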
Finally, the learned causal graph represents an equivalence class in the form of a partial ancestral graph (PAG) (29; 47), which can also encode the presence of latent confounders. A PAG represents a set of causal graphs that cannot be refuted given the data. There are three types of edge-marks (at some node X): an arrowhead '--> X', a tail '--- X', and a circle '--o X', where the circle represents an edge-mark that cannot be determined from the data. Note that reasoning from a PAG is consistent with every member of the equivalence class it represents.
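A PAG is naturally stored as a matrix of edge-marks. The sketch below uses a hypothetical convention chosen for illustration (pag[i][j] holds the mark at node j on the edge between i and j); other tools may encode PAGs differently. It renders edges such as 'A o-> B', and a bidirected edge 'B <-> C', which in a PAG signals a latent confounder.

```python
from enum import Enum

class Mark(Enum):
    NONE = 0    # no edge between the two nodes
    TAIL = 1    # "---"
    ARROW = 2   # "-->" (arrowhead)
    CIRCLE = 3  # "--o" (cannot be determined from the data)

# Hypothetical encoding: pag[i][j] is the edge-mark at node j on the edge i .. j.
LEFT = {Mark.TAIL: "-", Mark.ARROW: "<", Mark.CIRCLE: "o"}
RIGHT = {Mark.TAIL: "-", Mark.ARROW: ">", Mark.CIRCLE: "o"}

def edge_str(pag, names, i, j):
    """Render the edge between nodes i and j, e.g. 'A o-> B', or None if absent."""
    at_i, at_j = pag[j][i], pag[i][j]
    if at_i is Mark.NONE or at_j is Mark.NONE:
        return None
    return f"{names[i]} {LEFT[at_i]}-{RIGHT[at_j]} {names[j]}"

names = ["A", "B", "C"]
pag = [[Mark.NONE] * 3 for _ in range(3)]
pag[0][1], pag[1][0] = Mark.ARROW, Mark.CIRCLE   # A o-> B
pag[1][2], pag[2][1] = Mark.ARROW, Mark.ARROW    # B <-> C (latent confounder)
for i in range(3):
    for j in range(i + 1, 3):
        edge = edge_str(pag, names, i, j)
        if edge:
            print(edge)
```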
What is the relation between the equivalence class learned by ABCD and the causal graph underlying the input sequence s = [s1, . . . , sn]?
4.3. Multiple Attention Layers and Multi-Head Attention as a Structural Causal Model
Commonly, the encoder module in the Transformer architecture (39) consists of multiple self-attention layers executed sequentially. A non-linear transformation is applied to the output embeddings of one attention layer before they are fed as input to the next attention layer. This non-linear transformation is applied to each symbol independently. In addition, each attention layer consists of multiple self-attention heads that process the input in parallel, each using a head-specific attention matrix. The output embedding of each symbol is then linearly combined across the heads. It is important to note that the embedding of each symbol is processed independently of the embeddings of other symbols. The only operation in which the embedding of one symbol is influenced by the embeddings of other symbols is the multiplication by the attention matrix.
Thus, a multi-layer, multi-head architecture can be viewed as a deep graphical model in which the exogenous nodes of an SCM are estimated by the previous attention layer. The different heads in an attention layer learn different graphs for the same input, and their output embeddings are linearly combined; this can be viewed as a mixture model. We recover the causal graph from the last (deepest) attention layer, and treat earlier layers as context estimation, which is encoded in the values of the exogenous nodes (Figure 1).
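The mixture view can be illustrated under a deliberate simplification: assume all heads share the same values V and are combined with scalar weights w_h. (In standard Transformers each head has its own value projection and heads are merged through an output projection matrix, so this is not the real layer.) Under that assumption, combining head outputs is identical to applying a single mixed attention matrix Σ_h w_h A_h. The matrices and weights below are hypothetical.

```python
def apply_attention(a, v):
    """Z = A V: row i of Z is a weighted sum of the value rows of all symbols."""
    return [[sum(a[i][k] * v[k][d] for k in range(len(v))) for d in range(len(v[0]))]
            for i in range(len(a))]

def multi_head(heads, weights, v):
    """Simplified multi-head layer: per-head outputs combined with scalar weights."""
    outs = [apply_attention(a, v) for a in heads]
    return [[sum(w * out[i][d] for w, out in zip(weights, outs)) for d in range(len(v[0]))]
            for i in range(len(v))]

def mixture_attention(heads, weights):
    """The single attention matrix sum_h w_h A_h equivalent to multi_head()."""
    n = len(heads[0])
    return [[sum(w * a[i][j] for w, a in zip(weights, heads)) for j in range(n)]
            for i in range(n)]

A1 = [[0.8, 0.2, 0.0], [0.1, 0.9, 0.0], [0.0, 0.3, 0.7]]
A2 = [[0.5, 0.0, 0.5], [0.0, 0.5, 0.5], [0.2, 0.2, 0.6]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
w = [0.7, 0.3]
z_heads = multi_head([A1, A2], w, V)
z_mix = apply_attention(mixture_attention([A1, A2], w), V)
print(all(abs(p - q) < 1e-12 for rp, rq in zip(z_heads, z_mix) for p, q in zip(rp, rq)))  # True
```

Note that symbols interact only inside apply_attention, matching the observation that the attention matrix is the sole point of cross-symbol influence.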
4.4. Limitations of ABCD
The presented method, ABCD, is susceptible to two sources of error affecting its accuracy. The first is prediction errors of the Transformer neural network, and the second is errors from the constraint-based causal discovery algorithm that uses the covariance matrix calculated in the Transformer. The presence of errors from the second source depends on the first. For a Transformer with perfect generalization accuracy, and when causal Markov and faithfulness assumptions are not violated, the second source of errors vanishes, and a perfect result is returned by the presented ABCD method.
From Assumption 1, the deepest attention layer is expected to produce embeddings for each input symbol, such that the input symbol can be correctly predicted from its corresponding embedding (one-to-one embedding). However, except for trivial cases, generalization accuracy is not perfect. Thus, the attention matrix of an input sequence, for which the Transformer fails to predict the desired outcome, might include errors. As a result, the values of correlation coefficients calculated from the attention matrix might include errors.
Constraint-based algorithms for causal discovery are generally proved to be sound and complete when a perfect CI-test is used. However, if a CI-test returns an erroneous result, the constraint-based algorithm might introduce additional errors. This notion is called stability (35) and is informally measured by the number of output errors as a function of the number of CI-test errors. In fact, constraint-based algorithms differ from one another in their ability to handle CI-test errors and minimize their effect. Thus, inaccurate correlation coefficients used for CI-testing might lead to erroneous independence relations, which in turn may lead to errors in the output graph.
We expect that as Transformer models become more accurate with larger training datasets, the accuracy of ABCD will increase. In this paper’s experiments we use existing pre-trained Transformers and common datasets, and use the ICD algorithm (30) that was shown to have state-of-the-art accuracy.
5. ABCD with Application To Explaining Predictions
One possible application of the proposed ABCD approach is to reason about predictions by generating causal explanations from pre-trained self-attention models such as BERT (9). Specifically, we seek an explaining set that consists of a subset of the input symbols that are claimed to be solely responsible for the prediction. That is, if the effect of the explaining set on the prediction is masked, a different (alternative) prediction is provided by the neural network.
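The definition of an explaining set can be made concrete with a brute-force search: among candidate symbols, find the smallest subset whose masking changes the prediction. This is not the CLEANN algorithm (described in the paper's Appendix B); in CLEANN the explanation is extracted from the learned causal graph, whereas here the candidate indices are simply passed in and every subset is tried by size. The classifier is a hypothetical toy stand-in.

```python
from itertools import combinations

def minimal_explaining_set(tokens, candidates, predict, mask="[MASK]", max_size=3):
    """Smallest subset of `candidates` (token indices) whose masking changes the
    prediction; returns (subset, alternative_prediction) or (None, original)."""
    original = predict(tokens)
    for size in range(1, max_size + 1):
        for subset in combinations(candidates, size):
            masked = [mask if i in subset else t for i, t in enumerate(tokens)]
            alternative = predict(masked)
            if alternative != original:
                return set(subset), alternative
    return None, original

def toy_predict(tokens):
    # Hypothetical stand-in for a classifier: more "bad" than "good" => negative.
    return "negative" if tokens.count("bad") > tokens.count("good") else "positive"

tokens = ["good", "plot", "but", "bad", "acting", "and", "bad", "music"]
print(minimal_explaining_set(tokens, [0, 3, 6], toy_predict))  # ({3}, 'positive')
```

Exhaustive search grows combinatorially with the number of candidates, which is one reason to restrict candidates using the causal graph.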
It was previously claimed that attention cannot be used for explanation (16); however, in a contradicting paper (41), it was shown that explainability is task dependent. A common approach is to use the attention matrix to learn about input-output relations for providing explanations (33; 5; 4). These often rely on the assumption that inputs having high attention values with the class token influence the output (44; 6). We claim that this assumption considers only marginal statistical dependence and ignores conditional independence and explaining-away that may arise due to latent confounders.
To this end, we propose CLEANN (CausaL Explanations from Attention in Neural Networks). A description of this algorithm is detailed in Appendix B, and an overview with application to explaining movie recommendation is given in Figure 4. See also Nisimov et al. (23) for more details.
6. Empirical Evaluation
In this section we demonstrate how a causal graph constructed from a self-attention matrix in a Transformer-based model can be used to explain which specific symbols in an input sequence are the causes of the Transformer output. We experiment on two tasks: sentiment classification, which classifies an input sequence, and recommendation, which generates a candidate list of recommended symbols (top-k) for the next item. For both experiments we compare our method against two baselines from (38): (1) the Pure-Attention algorithm (Pure-Attn.), which uses the attention weights directly in a hill-climbing search to suggest an explaining set; and (2) the Smart-Attention algorithm (Smart-Attn.), which adapts Pure-Attention so that a symbol is added to the explanation only if it reduces the gap between the original prediction's score and the second-ranked prediction's score. Implementation tools are available at https://github.com/IntelLabs/causality-lab.
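The two baselines can be sketched as a greedy search over tokens ordered by attention to the class token. This is our reading of the one-line descriptions above, not the implementation from (38): Pure-Attention adds tokens in descending attention order until the prediction flips, while the Smart variant keeps a token only if it shrinks the margin between the original label's score and the best competing score. The score function and attention values are hypothetical toys.

```python
def margin(scores, label):
    """Score of `label` minus the best competing score (< 0 means the prediction flipped)."""
    return scores[label] - max(v for k, v in scores.items() if k != label)

def masked(tokens, chosen, mask="[MASK]"):
    return [mask if i in chosen else t for i, t in enumerate(tokens)]

def attention_search(tokens, attn_to_cls, score, smart=False):
    """Greedy explaining-set search ordered by attention to the class token.

    smart=False: Pure-Attention-style, add tokens in descending attention order.
    smart=True:  Smart-Attention-style, add a token only if it shrinks the margin.
    """
    base = score(tokens)
    label = max(base, key=base.get)
    gap = margin(base, label)
    chosen = set()
    for i in sorted(range(len(tokens)), key=lambda k: attn_to_cls[k], reverse=True):
        trial = chosen | {i}
        new_gap = margin(score(masked(tokens, trial)), label)
        if smart and new_gap >= gap:
            continue  # token does not move the prediction toward flipping
        chosen, gap = trial, new_gap
        if gap < 0:
            return chosen  # prediction flipped
    return None

def toy_score(tokens):
    # Hypothetical class scores: one point per "bad" vs. per "good" (+0.5 bias).
    return {"negative": float(tokens.count("bad")), "positive": tokens.count("good") + 0.5}

tokens = ["good", "movie", "but", "bad", "bad"]
attn = [0.2, 0.4, 0.05, 0.3, 0.05]  # hypothetical attention to [cls]
print(attention_search(tokens, attn, toy_score))              # {1, 3}
print(attention_search(tokens, attn, toy_score, smart=True))  # {3}
```

In this toy case the pure variant picks up the high-attention but irrelevant token "movie", illustrating why attention magnitude alone can inflate the explaining set.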
6.1. Evaluation metrics
We evaluate the quality of inferred explanations using two metrics: one measures the minimality of the explanation, and the other measures how specific the explanation is to the prediction.
6.1.1. Minimal explaining set
An important requirement is that the explaining set be minimal in the number of symbols (40), for two reasons. (1) It is more complicated and less interpretable for humans to grasp the interactions and interplay in a set that contains many explaining symbols. (2) In the spirit of Occam's razor (when faced with a few possible explanations, the simpler one is the one most likely to be true), the explaining set should not include symbols that do not contribute to explaining the prediction. Including redundant symbols in the explaining set might result in wrong alternative predictions when the effect of this set is masked.
6.1.2. Influence of the explaining set on replacement prediction
The following metric is applicable only to tasks that produce multiple predictions, such as multiple recommendations by a recommender system (e.g., which movie to watch next), as opposed to binary classification (e.g., sentiment) of an input. Given an input sequence of symbols s = {s₁, ..., sₙ}, a neural network suggests s̃ₙ₊₁, the 1st symbol in the top-k candidate list. CLEANN finds the smallest explaining set within s that influenced the selection of s̃ₙ₊₁. Consequently, discarding this explaining set from the sequence should prevent s̃ₙ₊₁ from being selected again, and a new symbol (the replacement symbol) should be selected instead. Optimally, the explaining set should influence the rank of only that 1st symbol (which should be downgraded), but not the ranks of the other candidates in the top-k list. This requirement implies that the explaining set is unique and important for the isolation and counterfactual explanation of only that 1st symbol, whereas the other symbols in the original top-k list remain, for the most part, unaffected. It is therefore desirable that, after discarding the explaining set from the sequence, the new replacement symbol be one of the original (i.e., before discarding the explaining set) top-k ranked symbols, optimally the 2nd.
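The metric reduces to looking up where the replacement symbol sat in the original top-k list and tallying those positions over sessions. A minimal sketch, with hypothetical item IDs and sessions (not the paper's data); position 2 is the ideal outcome and None means the replacement was outside the original list.

```python
from collections import Counter

def replacement_position(original_topk, new_top1):
    """1-based position of the replacement item in the original top-k list
    (2 is the ideal outcome), or None if it was not in that list."""
    if new_top1 in original_topk:
        return original_topk.index(new_top1) + 1
    return None

# Hypothetical sessions: (original top-5, top-1 after removing the explaining set).
sessions = [
    (["m1", "m2", "m3", "m4", "m5"], "m2"),
    (["m7", "m8", "m9", "m1", "m2"], "m1"),
    (["m3", "m4", "m5", "m6", "m7"], "m9"),
]
positions = Counter(replacement_position(top, new) for top, new in sessions)
print(positions)
```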
6.2. CLEANN for Sentiment classification
We exemplify the explainability provided by our method on the task of sentiment classification of movie reviews from IMDB (20), using a pre-trained BERT model (9) that was fine-tuned for the task. The goal is to explain the classification of a review by finding the smallest explaining set of word tokens (symbols) within the review. Figure 2 exemplifies the explanation of an input review that was classified as negative by the sentiment classification model. We compare our method with two baselines on finding a minimal explaining set for the classification. Both baselines found the word 'bad' as a plausible explanation for the negative sentiment classification; however, each of them also found an additional explaining word ('pizza', 'cinema') that clearly has no influence on the sentiment in the context in which it appears in the review. In contrast, our method finds the words 'bad' and 'but' as explaining the negative-review classification. Indeed, the word 'but' may serve as a plausible explanation for the negative review in the second part of the sentence, since it negates the positive first part of the sentence (about the pizza). Figure 2(b) shows the corresponding attention map at the last attention layer of the model for the given input review, and Figure 2(c) presents the corresponding matrix representation of the learned graph (PAG). Despite high attention values between the class token [cls] and some of the word tokens in the review (the first row of the attention map), some of these words are found not to be directly connected to the class token in the causal graph produced by our method, and therefore may not be associated with influencing the outcome of the model.
Table 1 and Figure 3 show how the length of a review influences the explaining set size. CLEANN produces the smallest explaining sets on average, with statistical significance tested using the Wilcoxon signed-rank test at significance level α = 0.01. As the review length increases, the two baselines evidently increase their explaining set sizes correspondingly. Even if the additional explaining tokens carry negative or positive context, they are not the most essential and prominent tokens of this kind. In contrast, the explaining set size produced by our method is relatively stable, increasing only moderately, meaning that it keeps finding the most influential words in a sentence, which dominate the decision put forward by the model.
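The Wilcoxon signed-rank test compares paired samples, here explaining-set sizes of two methods on the same reviews. The pure-Python sketch below uses the large-sample normal approximation, drops zero differences, and assigns average ranks to ties without correcting the variance for ties, so it is a simplification; in practice a library routine such as `scipy.stats.wilcoxon` would be used. The set sizes are hypothetical.

```python
import math

def wilcoxon_signed_rank(x, y):
    """Paired Wilcoxon signed-rank test with a normal approximation.

    Zero differences are dropped; tied |differences| get average ranks, but no
    tie correction is applied to the variance (a simplification).
    Returns (W+, z, two-sided p-value).
    """
    d = [a - b for a, b in zip(x, y) if a != b]
    n = len(d)
    if n == 0:
        raise ValueError("all differences are zero")
    order = sorted(range(n), key=lambda i: abs(d[i]))
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        while j + 1 < n and abs(d[order[j + 1]]) == abs(d[order[i]]):
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1  # average of ranks i+1 .. j+1
        i = j + 1
    w_plus = sum(r for r, di in zip(ranks, d) if di > 0)
    mean = n * (n + 1) / 4
    sd = math.sqrt(n * (n + 1) * (2 * n + 1) / 24)
    z = (w_plus - mean) / sd
    p = math.erfc(abs(z) / math.sqrt(2))  # equals 2 * (1 - Phi(|z|))
    return w_plus, z, p

# Hypothetical per-review explaining-set sizes (baseline vs. a competing method).
baseline = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
method = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
w_plus, z, p = wilcoxon_signed_rank(baseline, method)
print(w_plus, round(z, 3), p < 0.01)
```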
6.3. CLEANN for recommendation system
We assume that the human decision process, for selecting which items to interact with, consists of multiple decision pathways that may diverge and merge over time. Moreover, they may be influenced by latent confounders along this process. Formally, we assume that the decision process can be modeled by a causal DAG consisting of observed and latent variables. Here, the observed variables are user-item interactions {s1, . . . , sn} in a session S, and latent variables {H1, H2, . . .} represent unmeasured influences on the user’s decision to interact with a specific item. Examples for such unmeasured influences are user intent and previous recommendation slates presented to the user.
We exemplify the explainability provided by our method on the task of recommendation, and suggest it as a means of understanding the complex human decision-making process when interacting with the recommender system. Using an imaginary session that includes the recommendation, we extract an explaining set from the causal graph. To validate this set, we remove it from the original session and feed the modified session into the recommender, resulting in an alternative recommendation that, in turn, can also be explained in a similar manner. Figure 4 provides an overview of our approach.
For empirical evaluation, we use the BERT4Rec recommender (36), pre-trained on the MovieLens 1M dataset (15) and estimate several measures to evaluate the quality of reasoned explanations.
Minimal explaining set:
Figure 5(a) compares the explaining set sizes for the various sessions produced by CLEANN and the baseline methods. The set sizes found by CLEANN are evidently smaller. Figure 5(b) shows the difference between the explaining set sizes found by the baseline method and by CLEANN, calculated for each session individually. Approximately 25% of the sessions have positive values, indicating smaller set sizes for CLEANN; zero values indicate equal sizes; and only 10% of the sessions have negative values, indicating smaller set sizes for Pure-Attention.
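The per-session comparison is a simple difference followed by counting signs. A minimal sketch with hypothetical set sizes for four sessions (not the paper's data):

```python
def difference_summary(baseline_sizes, cleann_sizes):
    """Fractions of sessions where (baseline size - CLEANN size) is >0, ==0, <0."""
    diffs = [b - c for b, c in zip(baseline_sizes, cleann_sizes)]
    n = len(diffs)
    return {
        "positive": sum(d > 0 for d in diffs) / n,  # CLEANN set is smaller
        "zero": sum(d == 0 for d in diffs) / n,     # equal sizes
        "negative": sum(d < 0 for d in diffs) / n,  # baseline set is smaller
    }

# Hypothetical set sizes for four sessions (not the paper's data).
print(difference_summary([3, 2, 4, 2], [2, 2, 2, 3]))
# {'positive': 0.5, 'zero': 0.25, 'negative': 0.25}
```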
Influence of the explaining set on replacement recommendation:
Figure 6(a) compares the distribution for positions of replacement recommendations, produced by CLEANN and the baseline methods, in the original top-5 recommendations list. It is evident that compared to the baseline methods, CLEANN recommends replacements that were ranked higher (lower position) in the original top-5 recommendations list. In a different view, Figure 6(b) shows the relative gain in the number of sessions for each position, achieved by CLEANN compared to the baseline methods. There is a trend line indicating higher gains for CLEANN at lower positions. That is, the replacements are more aligned with the original top-5 recommendations. CLEANN is able to isolate a minimal explaining set that influences mainly the 1st item from the original recommendations list.
7. Discussion
We presented a relation between the self-attention mechanism used in the Transformer neural architecture and the structural causal model often used to describe the data-generating mechanism.
One result of this relation is that, under certain assumptions, a causal graph can be learned for a single input sequence (ABCD). This can be viewed as utilizing pre-trained models for zero-shot causal discovery. An interesting insight is that the only source of errors in learning the causal structure is the estimation of the attention matrix, which is learned during pre-training of the Transformer. Since Transformer models have been shown in recent years to scale well with model and training data sizes, we expect the estimation of the attention matrix to become more accurate as the pre-training data increases, thereby improving the accuracy of causal discovery.
Another result is that the causal structure learned for an input sequence can be employed to reason about the prediction for that sequence by providing causal explanations (CLEANN). We expect learned causal graphs to be able to answer a myriad of causal queries (24), among them personalized queries that can allow a richer set of predictions. For example, in recommender systems, assuming that the human decision process consists of multiple pathways that merge and split, recommendations can be provided for each causal pathway independently once the independent pathways are identified.
The results of this paper contribute to the fields of causal inference and representation learning by providing a relation that bridges concepts from these two domains. For example, they may lead to 1) causal reasoning in applications for which large-scale pre-trained models or unlabeled data exist, and 2) architectural modifications to the Transformer that facilitate causal inference.