
AtMan: Understanding Transformer Predictions Through Memory Efficient Attention Manipulation

Generative transformer models have become increasingly complex, with large numbers of parameters and the ability to process multiple input modalities. Current methods for explaining their predictions are resource-intensive. Most crucially, they require prohibitively large amounts of extra memory, since they rely on backpropagation, which allocates almost twice as much GPU memory as the forward pass. This makes it difficult, if not impossible, to use them in production. We present AtMan, which provides explanations of generative transformer models at almost no extra cost. Specifically, AtMan is a modality-agnostic perturbation method that manipulates the attention mechanisms of transformers to produce relevance maps for the input with respect to the output prediction. Instead of using backpropagation, AtMan applies a parallelizable token-based search method relying on cosine similarity neighborhoods in the embedding space. Our exhaustive experiments on text and image-text benchmarks demonstrate that AtMan outperforms current state-of-the-art gradient-based methods on several metrics while being computationally efficient. As such, AtMan is suitable for use in large model inference deployments.


1 Explainability Through Attention Maps

Generalizing beyond single-task solutions using large-scale transformer-based language models has gained increasing attention from the community. In particular, the switch to open-vocabulary predictions promises AI systems capable of adapting beyond previously seen training objectives. Arguably, transformers are the state-of-the-art method in Natural Language Processing (NLP) and Computer Vision (CV). Most recently, they have demonstrated remarkable performance in multi-modal settings, e.g., bridging CV capabilities with text understanding to solve Visual Question Answering (VQA) scenarios [9, 17, 29, 28]. The increasing adoption of transformers, however, also raises the need to better understand the reasons behind their otherwise black-box predictions. Unfortunately, the “scale is all you need” assumption of transformers results in severely large and complex architectures, making their training, inference deployment, and understanding a resource-intensive task that requires multiple enterprise-grade GPUs or even entire computing nodes, along with prolonged runtimes.

Figure 1: “What am I looking at?” The proposed explainability method AtMan visualizes the most important aspects of the given image while completing the sequence (displayed above the relevance maps). The generative multi-modal model MAGMA is prompted to describe the shown image with: “Image This is a painting of ”. (Best viewed in color.)

Most, if not all, explainable AI (XAI) methods for transformers (methods that make the decision-making processes and internal workings of AI models transparent and understandable to humans) work by propagating some form of gradients back through the model. This backpropagation accumulates information about how each input feature contributes to the output tokens [6, 1], utilizing the activations stored during the forward pass. Unfortunately, this leads to a significant memory overhead, which renders their deployment in production uneconomical, if not impossible. Often, half of the available GPU memory has to remain unused during inference, or an entirely separate deployment of the XAI pipeline is required.

Fortunately, another popular XAI idea, namely perturbation [18, 23], is much more memory-efficient. However, it has not yet proven beneficial for explaining the predictions of transformers, most likely because the immense number of required forward trials accumulates unreasonable computation time.

To tackle these issues and, in turn, scale explanations with the size of transformers, we propose to bridge relevance propagation and perturbations. In contrast to existing perturbation methods, which execute perturbations directly in the input space, we push them into the latent space, allowing, as we will show, state interpolation and token-based similarity measures. Specifically, inspired by [10] and backpropagation approaches, we introduce attention manipulations throughout the latent layers of the transformer during the forward pass as a method to steer model predictions. Our explanation method, called AtMan, then leverages these predictions to compute relevance values for transformer networks. Our experimental evidence demonstrates that AtMan significantly reduces the number of required perturbations, making them applicable at deployment time, and requires no additional memory compared to plain forward passes. In short, AtMan can scale with transformers. Our exhaustive experiments on text and image-text benchmarks also demonstrate that AtMan outperforms current state-of-the-art gradient-based methods while being computationally efficient. In fact, for the first time, AtMan allows one to study generative model predictions as visualized in Fig. 1. During sequence generation with large multi-modal models, AtMan is able to additionally highlight relevant features with respect to the input, providing novel insights into the generation process.

Contributions.

In summary, our contributions are: (i) An examination of the effects of token-based attention score manipulation on generative transformer models. (ii) The introduction of a novel and memory-efficient XAI perturbation method for large-scale transformer models, called AtMan, which reduces the number of required iterations to a tractable amount by correlating tokens in the embedding space. (iii) Exhaustive multi-modal evaluations of XAI methods on several text and image-text benchmarks and autoregressive (AR) transformers. We release the source code of the proposed method and all evaluation scripts at https://github.com/Mayukhdeb/atman-magma.

We proceed as follows. We start off by discussing related work. Then, we derive AtMan and explain its attention manipulation as a perturbation technique. Before concluding and discussing the benefits as well as limitations, we touch upon our experimental evaluation, showing that AtMan not only nullifies memory overhead but also outperforms competitors on several visual and textual reasoning benchmarks.

2 Related Work

Explainability in CV and NLP.

Explainability of AI systems is still an ambiguously defined term [7]. XAI methods are expected to show some level of relevance on the input with respect to the computed result of an algorithm. This task is usually tackled by constructing an input relevance map given the model’s prediction. The nature of relevance can be class-specific, e.g., depending on specific target instances of a task and showing a local solution [25, 26], or class-agnostic, i.e., depending only on the global behavior of the model [1, 3]. The granularity of the achieved explanation therefore depends on the chosen method, the model, and the actual evaluation benchmark.

Explainability in CV is usually evaluated by mapping the relevance maps to pixel level and regarding the evaluation as a weak segmentation task [24, 19, 26]. NLP explanations, on the other hand, are much more vaguely defined and usually mixed with more complex philosophical interpretations, such as labeling a given text with a certain sentiment category [7].

The majority of XAI methods can be divided into the classes of perturbation and gradient analysis. Perturbation methods treat the model as a black box and attempt to derive knowledge of the model’s behavior by studying changes in input-output pairs only. Gradient-based methods, on the other hand, execute a backpropagation step towards a target and aggregate the model’s parameter adaptations to derive insights.

Most of these XAI methods are not motivated by a specific discipline, e.g., neither by NLP nor by CV. They are generic enough to be applied to both disciplines, to some extent. However, architecture-specific XAI methods exist, such as GradCAM [24], which leverages the spatial input aggregation of convolutional neural networks in the deepest layers to increase efficiency.

Explainability in Transformers.

Due to their increasing size, transformers are particularly challenging for explainability methods, especially for architecture-agnostic ones. Transformers’ core components include an embedding layer followed by multiple layers of alternating attention and feed-forward blocks. The attention blocks map the input into separate “query”, “key”, and “value” matrices and are split into an array of “heads”. As with convolutions in CNNs, separate heads are believed to relate to specific learned features or tasks [12]. Further, the dimensions of the attention matrix match the input sequence length, which makes the attention mechanism particularly suited for deriving input explanations.

Consequently, most explainability adaptations to transformers focus on the attention mechanism. Abnar and Zuidema [1] assume that activations in attention layers are combined linearly and consider paths along the pairwise attention graph (attention rollout). However, while efficient, this approach often emphasizes irrelevant tokens, in particular due to its class-agnostic nature. Therefore, the authors also propose attention flow [1], which is infeasible to use due to the high computational demands of constructing graphs. More recently, Chefer et al. [6] proposed to aggregate backward gradients and LRP [19] throughout all layers and heads of the attention modules in order to derive explanation relevancy. Their method outperforms previous transformer-specific and unspecific XAI methods on several benchmarks and transformer models. It was further extended to multimodal transformers [5] by studying other variations of attention. However, the evaluated benchmarks only include classification tasks, despite transformers’ remarkable performance on open-vocabulary tasks, e.g., utilizing InstructGPT [20] or multimodal autoregressive transformers such as MAGMA [9], BLIP [17], and OFA [29].

Multimodal Transformers.

In contrast to these explainability studies evaluating models like DETR and ViT [4, 8], we study explainability on the generated text tokens of a language model, not on specifically trained classifiers. Due to the multimodality, the XAI method should produce output relevancy either on the input text or on the input image, as depicted in Fig. 1.

To this end, we study the explainability of multimodal transformer architectures such as MAGMA [9] (an open-source version can be found at https://github.com/Aleph-Alpha/magma). Specifically, to add the image modality, Eichenberg et al. [9] propose to fine-tune a frozen pre-trained language model by adding sequential adapters to the layers, leaving the attention mechanism untouched. MAGMA uses a CLIP [21] vision encoder to produce image embeddings. These embeddings are afterwards treated as equal to the text input tokens during model execution, in particular regardless of their modality. This methodology has shown competitive performance compared to single-task solutions (cf. [9]).

3 AtMan: Attention Manipulation

We formulate finding the best explainability estimator of a model as solving the following question: What is the most important part of the input, annotated by the explanator, to produce the model’s output? In the following, we derive our perturbation probe mathematically through studies of influence functions and embedding layer updates on autoregressive (AR) models [13, 2].

Then we show how attention manipulation on single tokens can be used in NLP tasks to steer the prediction of a model in directions found within the prompt. Finally, we derive our multi-modal XAI method AtMan by extending this concept to the cosine neighborhood in the embedding space.

3.1 Influence Functions as Explainability Estimators

Transformer-based language models are probability distribution estimators. They map from some input space $\mathcal{X}$ (e.g. text or image embeddings) to an output space $\mathcal{Y}$ (e.g. language token probabilities). Let $\mathcal{E}$ be the space of all explanations (i.e. binary labels) over $\mathcal{X}$. An explanator function can then be defined as

$$e: (\mathcal{Y}^{\mathcal{X}} \times \mathcal{X} \times \mathcal{Y}) \rightarrow \mathcal{E},$$

i.e. given a model, an input, and a target, derive a label on the input.

Figure 2: Transformer decoder architecture. The left-hand side shows the general components: The token embeddings pass through the transformer blocks to produce output logits, e.g., taken for the next-token prediction in the setting of a generative language model. The middle shows in detail a masked attention block, consisting of MatMul, Mask, and SoftMax steps. The right-hand side shows our proposed Attention Manipulation method. We multiply the modifier factors and the attention scores before applying the diagonal causal attention mask. Red hollow boxes indicate values of one, and green ones −∞. (Best viewed in color.)

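Given a sequence $w = [w_1, \dots, w_N]$, an AR language model assigns a probability $p(w)$ to that sequence by applying the factorization $p(w) = \prod_{t=1}^{N} p(w_t \mid w_{<t})$. The loss optimization during training can then be formalized by solving:

$$\max_\theta \log p_\theta(w) = \max_\theta \sum_{t=1}^{N} \log p_\theta(w_t \mid w_{<t}) \qquad (1)$$
$$= \max_\theta \sum_{t=1}^{N} \log \operatorname{softmax}\big(h_\theta(w_{<t})\, W_\theta^\top\big)_{\text{target}_t} \qquad (2)$$

Here $h_\theta$ denotes the model, $W_\theta$ the learned embedding matrix, and $\text{target}_t$ the vocabulary index of the $t$-th target token $w_t$. Eq. 1 is derived by integrating the cross-entropy loss, commonly used during language model training. Finally, $\mathcal{L}^{\text{target}}(w, \theta)$, the cross-entropy of the target tokens given $w$ (up to normalization, the negated sum in Eq. 2), denotes our loss function.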
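As a minimal sketch of the loss in Eqs. 1 and 2, the following plain-Python code computes the mean target cross-entropy from per-step next-token logits (the softmax argument in Eq. 2). It assumes those logits have already been produced by some model; the function names are hypothetical and not part of the AtMan release. We average over target tokens, matching the averaging described in Sec. 3.2, rather than summing.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def target_cross_entropy(step_logits: Sequence[Sequence[float]],
                         targets: Sequence[int]) -> float:
    """Mean cross-entropy of the target tokens (the negated, normalized Eq. 2).

    step_logits[t] are the next-token logits computed from the prefix w_<t;
    targets[t] is the vocabulary index that should be predicted at step t.
    """
    if len(step_logits) != len(targets):
        raise ValueError("need one logit vector per target token")
    nll = [-log_softmax(logits)[tgt] for logits, tgt in zip(step_logits, targets)]
    return sum(nll) / len(nll)


# Usage: with uniform logits over a 4-token vocabulary the loss equals log(4).
uniform = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
print(target_cross_entropy(uniform, [1, 2]))
```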

Perturbation methods study the influence of the input on the model’s predictions by adding small noise to the input and measuring the change in the prediction. We follow the results of [13, 2] to approximate the perturbation effect directly through the model’s parameters when executing leave-one-out experiments on the input. The influence function estimating the perturbation $z_\epsilon$ of an input $z$ is then derived as:

$$\mathcal{I}^{\text{target}}(z_\epsilon, z) = \mathcal{L}^{\text{target}}(z, \theta_{-z_\epsilon}) - \mathcal{L}^{\text{target}}(z, \theta) \qquad (3)$$

Here $\theta_{-z_\epsilon}$ denotes the set of model parameters in which $z_\epsilon$ would not have been seen during training. In the following, we further show how to approximate $\theta_{-z_\epsilon}$.

3.2 Single Token Attention Manipulation

The core idea of AtMan is to shift the perturbation space from the raw input space to the embedded token space. This allows us to reduce the dimensionality of possible perturbations down to a single scaling factor per token. Moreover, we do not manipulate the value matrix of attention blocks and thereby do not introduce the input-distribution shift otherwise inherent to obfuscation methods. By manipulating the attention entries at the positions of the corresponding input sequence tokens, we are able to interpolate the focus of the model’s prediction distribution, amplifying or suppressing concepts of the prompt. In the following, we show that this procedure indeed yields a well-performing XAI method.

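Attention was introduced in [27] as $O = \operatorname{softmax}(H) \cdot V$, where $\cdot$ denotes matrix multiplication. The pre-softmax query-key attention scores are defined as

$$H = Q \cdot K^\top / \sqrt{d_h}.$$

In the case of autoregression, a lower left triangular unit mask $M$ is applied to these scores as $H_M = H \circ M$, with $\circ$ the Hadamard product. The output of the self-attention module is $O \in \mathbb{R}^{h \times s \times d_h}$, the query matrix is $Q \in \mathbb{R}^{h \times s \times d_h}$, and $K, V \in \mathbb{R}^{h \times s \times d_h}$ are the key and value matrices. Moreover, $H, M, H_M \in \mathbb{R}^{h \times s \times s}$. The number of heads is denoted as $h$, and $d_h$ is the embedding dimension of the model. Finally, there are $s$ query and key tokens, which coincide here with the number of input-sequence tokens.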
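The attention computation above can be sketched for a single head in plain Python. This is an illustrative sketch, not the MAGMA implementation: we realize the causal mask by setting future key positions to −∞ before the softmax, which is the common way of implementing it and matches the −∞ entries shown in Fig. 2.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def softmax(row: Sequence[float]) -> List[float]:
    """Softmax that tolerates -inf entries (they receive zero weight)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    z = sum(exps)
    return [e / z for e in exps]


def causal_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Single-head causal self-attention over s tokens.

    Computes H = Q K^T / sqrt(d_h), masks future keys (k > q) with -inf,
    and returns O = softmax(H_M) V row by row.
    """
    s, d_h = len(Q), len(Q[0])
    out = []
    for q in range(s):
        scores = [
            sum(a * b for a, b in zip(Q[q], K[k])) / math.sqrt(d_h) if k <= q else float("-inf")
            for k in range(s)
        ]
        weights = softmax(scores)
        out.append([sum(weights[k] * V[k][j] for k in range(s)) for j in range(len(V[0]))])
    return out


# Usage: zero queries and keys give uniform weights over each causal prefix.
Z = [[0.0, 0.0]] * 3
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(causal_attention(Z, Z, V))
```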

Figure 3: Illustration of the proposed explainability method. First, we collect the original cross-entropy score of the target tokens (1). Then we iterate and suppress one token at a time, indicated by the red box, and track changes in the cross-entropy score of the target token (2). (Best viewed in color.)

The perturbation required by Sec. 3.1 can now be approximated through attention score manipulation as follows: Let $w$ be an input token sequence of length $N$, and let $i$ be a token index within this sequence to be perturbed by a factor $f$. For all layers and all heads $H_u$, we modify the pre-softmax query-key attention scores as:

$$\widetilde{H}_{u,*,*} = H_{u,*,*} \circ (\mathbb{1} - f^i) \qquad (4)$$

where $\mathbb{1} \in [0,1]^{s \times s}$ denotes the matrix containing only ones and $f^i$ the suppression factor matrix for token $i$. In this section, we set $f^i_{*,k} = f$ for $k = i$ and $f \in \mathbb{R}$, and $f^i_{*,k} = 0$ elsewhere. As depicted in Fig. 2, we thus only scale the $i$-th column of the attention scores $H$, by a factor $(1 - f)$, equally for all heads. Here we follow the common assumption that all relevant entropy of the input token is processed primarily at its own position within the attention module, due to the sequence-to-sequence nature of the transformer; a different variant of this approach is discussed in Appendix A.5. Let us denote this modification of the model by $\theta_{-i}$ and assume a fixed factor $f$, which we determined once through a parameter sweep.
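The following sketch applies Eq. 4 to the pre-softmax scores of one head: every entry in key column k is multiplied by one minus the suppression factor for that column, after which the causal mask and the softmax are applied. The toy scores and the factor values 0.9 and −0.5 are hypothetical choices for illustration; the paper fixes its factor through a parameter sweep.

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    """Softmax that tolerates -inf entries (they receive zero weight)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    z = sum(exps)
    return [e / z for e in exps]


def single_token_factors(s: int, i: int, f: float) -> List[float]:
    """Column factors for perturbing token i: f at column i, 0 elsewhere."""
    return [f if k == i else 0.0 for k in range(s)]


def manipulated_weights(H: Sequence[Sequence[float]],
                        col_factors: Sequence[float]) -> List[List[float]]:
    """Eq. 4 for one head: scale key column k of the pre-softmax scores by
    (1 - col_factors[k]), then apply the causal mask and the softmax."""
    s = len(H)
    return [
        softmax([H[q][k] * (1.0 - col_factors[k]) if k <= q else float("-inf")
                 for k in range(s)])
        for q in range(s)
    ]


# Hypothetical positive scores for 3 tokens (query rows x key columns).
H = [[1.0, 0.0, 0.0],
     [2.0, 1.0, 0.0],
     [2.0, 1.0, 1.0]]
base = manipulated_weights(H, single_token_factors(3, 0, 0.0))[2][0]
suppressed = manipulated_weights(H, single_token_factors(3, 0, 0.9))[2][0]
amplified = manipulated_weights(H, single_token_factors(3, 0, -0.5))[2][0]
print(f"weight of the last query on token 0: {suppressed:.3f} < {base:.3f} < {amplified:.3f}")
```

Because the sketch scales raw scores, the direction of the effect depends on their sign: for a negative score, a factor below one moves it toward zero and thus raises its attention weight. The example therefore uses positive scores.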

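We define, for a class label target, the explanation as the vector of the influence functions at all positions:

$$e(h_\theta, w, \text{target}) := \big(\mathcal{I}^{\text{target}}(w_1, w), \dots, \mathcal{I}^{\text{target}}(w_N, w)\big) \qquad (5)$$

with $\mathcal{I}^{\text{target}}$ derived from Eq. 2 and Eq. 3 as

$$\mathcal{I}^{\text{target}}(w_i, w) := \mathcal{L}^{\text{target}}(w, \theta_{-i}) - \mathcal{L}^{\text{target}}(w, \theta).$$

In words, we average the cross-entropy of the AR input sequence with respect to all target tokens and measure its change when suppressing token index $i$, relative to the unmodified model. The explanation is this difference vector over all possible sequence-position perturbations and thus requires $N$ iterations.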

Figure 4: Manipulating the attention scores of a single token (highlighted in blue) inside a transformer block steers the model’s prediction into a different contextual direction (amplifications highlighted in green, suppression in red). (Best viewed in color.)

Fig. 3 illustrates this algorithm. The original input prompt is the text “Ben likes to eat burgers and ”, for which we want to find the most influential input token for the completion with the target token “fries”. Initially, the model predicts the target token with the baseline cross-entropy score shown in Fig. 3. We now iterate through the input tokens, suppressing them one by one, and track the changes in the cross-entropy of the target token, as depicted in the right-most column. In this example, “burgers” is the most influential input token for completing the sentence with “fries”, as suppressing it yields the highest score.
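The loop of Eq. 5 can be sketched independently of any particular model. Below, `target_loss` is a hypothetical callable standing in for a forward pass with Eq. 4 applied in every layer and head; the toy linear loss and its per-token weights are invented for illustration and do not reproduce the scores in Fig. 3.

```python
from typing import Callable, List, Sequence


def atman_explanation(num_tokens: int,
                      target_loss: Callable[[Sequence[float]], float],
                      factors_for: Callable[[int], Sequence[float]]) -> List[float]:
    """Eq. 5: one perturbed forward pass per input position.

    target_loss(col_factors) stands in for a forward pass that applies Eq. 4
    with the given column factors in every layer and head and returns the
    mean target cross-entropy; all-zero factors give the unmodified model.
    """
    baseline = target_loss([0.0] * num_tokens)
    return [target_loss(factors_for(i)) - baseline for i in range(num_tokens)]


# Toy stand-in for a model (hypothetical, not MAGMA): the target loss rises
# linearly with how strongly each token is suppressed, weighted per token.
tokens = ["Ben", "likes", "to", "eat", "burgers", "and"]
toy_weights = [0.1, 0.05, 0.0, 0.3, 1.2, 0.02]


def toy_loss(col_factors: Sequence[float]) -> float:
    return 2.0 + sum(w * f for w, f in zip(toy_weights, col_factors))


def one_hot(i: int) -> List[float]:
    return [0.9 if k == i else 0.0 for k in range(len(tokens))]  # f = 0.9 is illustrative


scores = atman_explanation(len(tokens), toy_loss, one_hot)
print(max(zip(scores, tokens)))
```

Each position costs one forward pass, so a prompt of N tokens needs N perturbed passes plus one unmodified pass.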

Next, we give a more descriptive intuition about the effects of such modifications on the model’s generative nature.

Token attention suppression steers the model’s prediction.

Intuitively, for factors $0 < f < 1$, we call the modifications “suppression”, as we find the model’s output now relatively less influenced by the token at the position of the respective manipulated attention scores. Contrarily, $f < 0$ “amplifies” the influence of the manipulated input token on the output.

An example of the varying continuations when a single-token manipulation is applied can be seen in Fig. 4. We provide the model with a prompt in which the focus of the continuation largely depends on two tokens, namely “soccer” and “math”. We show how suppressing and amplifying them shifts the prediction distributions away from or towards those concepts. It is precisely this distribution shift that we measure and visualize as our explanation.

Figure 5: Correlated token suppression of AtMan enhances explainability in the image domain. i) Shows an input image along with three perturbation examples. In the first, we only suppress a single image token (blue); in the second, the same token with its relative cosine neighborhood (yellow); and in the third, a non-related token with its neighborhood. Below are depicted the changes in cross-entropy loss: the original score for the target token “panda” and the loss change under each perturbation. ii) Shows the label and the resulting explanations without and with Cosine Similarity (CS). (Best viewed in color.)

3.3 Correlated Token Attention Manipulation

Suppressing single tokens works well when the entire entropy responsible for producing the target token occurs only once. However, for inputs with redundant information, this approach would often fail. This issue is particularly prominent in CV, as information, e.g., about objects in an image, is often spread across several embeddings due to the splitting of the image into parts and the separate application of the embedding function. It is a common finding that cosine similarity in the embedding space, e.g., right after the embedding layer, is a good correlation distance estimator [16, 2]. We integrate this finding into AtMan in order to suppress all redundant information corresponding to a particular input token at once, which we refer to as correlated token suppression.

Fig. 5 summarizes correlated token suppression visually. For $s$ input tokens and embedding dimension $d$, the embedded tokens result in a matrix $T = (t_i) \in \mathbb{R}^{s \times d}$. The cosine similarity, in turn, is computed from the normalized embeddings $\widetilde{T} = (\tilde{t}_i)$, with $\tilde{t}_i = t_i / \lVert t_i \rVert$ for $i \in \{1, \dots, s\}$, as $S = \widetilde{T} \cdot \widetilde{T}^\top \in [-1, 1]^{s \times s}$. Note that the index $i$ denotes a column corresponding to the respective input token index. Intuitively, the column $S_{*,i}$ then contains the similarity scores of token $i$ to all (other) input tokens. To suppress the correlated neighborhood of a specific token with index $i$, we therefore adjust the suppression factor matrix for Eq. 4 as

$$f^i_{*,k} = \begin{cases} S_{i,k}, & \text{if } \kappa \le S_{i,k} \le 1,\\ 0, & \text{otherwise.} \end{cases} \qquad (6)$$

As we only want to suppress tokens, we restrict the range of factor values from below: the parameter $\kappa$ ensures this lower bound and, in particular, prevents a sign flip. We empirically fixed $\kappa$ through a parameter sweep (Appendix A.4).
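Eq. 6 can be sketched as follows: normalize the token embeddings, form the cosine-similarity matrix, and keep as suppression factors only the similarities to token i that lie between κ and 1. The embeddings and the value κ = 0.7 are hypothetical; the paper determines κ through a parameter sweep (Appendix A.4). The sketch assumes non-zero embedding vectors.

```python
import math
from typing import List, Sequence


def cosine_similarity_matrix(T: Sequence[Sequence[float]]) -> List[List[float]]:
    """S = T~ T~^T for row-normalized token embeddings T (s x d)."""
    normed = []
    for t in T:
        norm = math.sqrt(sum(x * x for x in t))
        normed.append([x / norm for x in t])
    # Clamp to [-1, 1] so rounding error cannot push a self-similarity above 1.
    return [[max(-1.0, min(1.0, sum(a * b for a, b in zip(u, v)))) for v in normed]
            for u in normed]


def correlated_factors(S: Sequence[Sequence[float]], i: int, kappa: float) -> List[float]:
    """Eq. 6: factor S[i][k] for key column k if kappa <= S[i][k] <= 1, else 0."""
    return [S[i][k] if kappa <= S[i][k] <= 1.0 else 0.0 for k in range(len(S))]


# Hypothetical 2-d embeddings for four tokens; kappa = 0.7 is illustrative only.
T = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
S = cosine_similarity_matrix(T)
print(correlated_factors(S, 0, kappa=0.7))
```

The token itself always receives a factor of one, and anti-correlated tokens (negative similarity) are zeroed out rather than turned into amplifying factors, which is the sign-flip protection that the lower bound provides. These factors take the place of the single-token factor matrix in Eq. 4.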

With that, we arrive at our final version of AtMan. As a final remark, note that this form of explanation is local, as target refers to our target class. We can, however, straightforwardly derive a global explanation by setting $\text{target} = y$, for $y$ a model completion of certain length to the input $w$. It could then be interpreted, rather abstractly, as the model’s general focus [3].

Figure 6: An example of a single instance of the SQuAD dataset with AtMan Explanations. An instance contains three questions for a given context, each with a labeled answer pointing to a fragment of the context. AtMan is used to explain, i.e., highlight, the corresponding fragments. It can be observed that the green example is fully, the blue partially, and the yellow not at all recovered according to the labels. However, the yellow highlight seems at least related to the label. (Best viewed in color.)

4 Empirical Evaluation

We ran empirical evaluations on text and image corpora to address the following questions: (Q1) Does AtMan achieve competitive results compared to previous XAI for transformers in the language as well as vision domain? (Q2) Does AtMan scale efficiently and, therefore, can be applied to current large-scale AR models?

To answer these questions, we conducted empirical studies on textual and visual XAI benchmarks and compared AtMan to standard approaches such as IxG [25], IG [26], GradCAM [24], and the transformer-specific XAI method of [6], called Chefer in the following. Note that all these methods utilize gradients and are, therefore, categorized as propagation methods, leading to memory inefficiency. We also applied existing perturbation methods such as LIME [23] and SHAP [18]. However, they failed due to the extremely large number of trials and, in turn, prohibitive computation time. We adopt common metrics, namely mean average precision (mAP) and mean average recall (mAR), and state their interquartile statistics in all experiments. Although its memory efficiency allows AtMan to be utilized on larger models, we ran the corresponding experiments on MAGMA-6B (available at https://github.com/aleph-alpha/magma) if not stated otherwise, in order to provide a comparison between XAI methods.

4.1 AtMan can do Language reasoning

Protocol.

Since with AtMan we aim to study large-scale generative models, we formulate XAI on generative tasks as described in Sec. 3.3. To this end, we used the Stanford Question Answering (QA) Dataset (SQuAD) [22]. The QA dataset is structured as follows: Given a single paragraph of information, there are multiple questions, each with a corresponding answer referring to a position in the paragraph. A visualization of an instance of this dataset can be found in Fig. 6. In total, SQuAD contains 536 unique paragraphs and 107,785 question/explanation pairs. The average context sequence length is tokens, and the average label (explanation) length is .

Figure 7: AtMan produces less noisy and more focused explanations than Chefer when prompted with multi-class weak segmentation. The three shown figures are prompted to explain the target classes above and below separately. It can be observed that both methods produce reasonable and even similar outputs, though the method of Chefer shows more sensitivity and more noise. In particular, in the last example, for the target “birthday”, Chefer highlights more details like the decoration. However, the same is also derived to some extent when just prompting “bear”. (Best viewed in color.)
Metric     IxG    IG     Chefer   AtMan
mAP        51.7   49.5   72.7     73.7
mAP (IQ)   61.4   49.5   77.5     81.8
mAR        91.8   87.1   96.6     93.4
mAR (IQ)   100    98.6   100      100
Table 1: AtMan outperforms XAI methods on the QA dataset SQuAD. Shown are mean average precision (mAP) and mean average recall (mAR), each also with interquartile (IQ) statistics (the higher, the better).

The model was prompted with the template “{Context} Q: {Question} A:”, and the explainability methods were executed to derive scores for the tokens inside the given context, cf. Fig. 6. If there were multiple tokens in the target label, we computed the average of the scores over the target tokens. Similar to weak segmentation tasks in computer vision, we regarded the annotated explanations as binary labels and determined precision and recall over all these target tokens.
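The evaluation protocol above can be sketched as follows, assuming standard ranking average precision: context tokens are ranked by their (target-averaged) relevance score, and precision is averaged over the ranks at which labeled tokens appear. The exact AP variant and the interquartile statistics used for Tab. 1 are not specified in this section, so this is one plausible reading rather than the paper's evaluation script; the scores and labels are hypothetical.

```python
from typing import List, Sequence, Tuple


def average_target_scores(per_target: Sequence[Sequence[float]]) -> List[float]:
    """Average explanation vectors computed separately for each target token."""
    n = len(per_target)
    return [sum(col) / n for col in zip(*per_target)]


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Ranking AP of token relevance scores against binary explanation labels."""
    order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    hits, precisions = 0, []
    for rank, j in enumerate(order, start=1):
        if labels[j]:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(
        instances: Sequence[Tuple[Sequence[float], Sequence[int]]]) -> float:
    """Mean of the per-instance AP values."""
    return sum(average_precision(s, l) for s, l in instances) / len(instances)


# Hypothetical instance: two target tokens, four context tokens, tokens 1 and 3 labeled.
scores = average_target_scores([[0.0, 1.0, 0.6, 0.3], [0.2, 0.8, 1.0, 0.1]])
print(average_precision(scores, [0, 1, 0, 1]))
```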

Results.

The results are shown in Tab. 1. AtMan outperforms all previous approaches in terms of mean average precision, both overall (73.7 vs. 72.7 for Chefer) and interquartile (81.8 vs. 77.5). On the interquartile mean average recall, AtMan ties with IxG and Chefer at 100, whereas on the overall mean average recall, Chefer slightly outperforms AtMan (96.6 vs. 93.4). Furthermore, it is noteworthy that the small average explanation length (as depicted in Fig. 6) results in high recall scores for all methods. Further details and some qualitative examples can be found in Appendix A.2.

Paragraph Chunking.

AtMan can naturally be lifted to the explanation of paragraphs. We ran experiments in which AtMan splits the input text into a few paragraphs at common delimiters and suppresses each resulting chunk as a whole, instead of evaluating token by token. This significantly decreases the total number of required forward passes and, in addition, produces “more human” text explanations than the otherwise heterogeneously highlighted word parts. Results are shown in Appendix A.8.
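A minimal sketch of paragraph chunking, assuming whitespace-separated tokens and sentence punctuation as delimiters (the tokenizer and delimiters actually used are not specified here): token indices are grouped into chunks, and each chunk is suppressed as a whole by giving all of its key columns the same factor. The function names and the factor 0.9 are hypothetical.

```python
from typing import List, Sequence


def chunk_token_groups(tokens: Sequence[str], delimiters: str = ".;:!?") -> List[List[int]]:
    """Group token indices into chunks, closing a chunk after a token ending in a delimiter."""
    groups: List[List[int]] = []
    current: List[int] = []
    for idx, tok in enumerate(tokens):
        current.append(idx)
        if tok and tok[-1] in delimiters:
            groups.append(current)
            current = []
    if current:  # trailing text without a final delimiter forms its own chunk
        groups.append(current)
    return groups


def chunk_factors(num_tokens: int, group: Sequence[int], f: float) -> List[float]:
    """Column factors that suppress every token of one chunk with the same factor f."""
    members = set(group)
    return [f if k in members else 0.0 for k in range(num_tokens)]


text = "Paris is the capital of France. It is known for the Eiffel Tower."
toks = text.split()
groups = chunk_token_groups(toks)
print(len(groups), "chunks instead of", len(toks), "tokens")
```

With this grouping, the number of perturbed forward passes equals the number of chunks rather than the number of tokens.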

4.2 AtMan can do Visual reasoning

Protocol.

Similar to language reasoning, we again perform XAI on generative models. We evaluated the OpenImages [15] dataset as a VQA task and generated open-vocabulary predictions with the autoregressive model. Specifically, the model is prompted with the template “{Image} This is a picture of ”, and the explainability methods are executed to derive scores for the pixels of the image with respect to the target label. If there are multiple tokens in the target label, we take the average of the generated scores over the target tokens. For evaluation, we consider the segmentation annotations of the dataset as ground-truth explanations. The segmentation subset contains 2.7M annotated images for 350 different classes. To ensure good performance of the large-scale model at hand and, in turn, adequate explanations from the XAI methods, we filtered the images for a minimum dimension in pixels and a maximal proportional deviation between width and height. Moreover, we randomly sample images per class to avoid overweighting classes. This filtering leads to a dataset of samples. The average context sequence length is tokens and the average label coverage is of the input image.

Quantitative Results.

The results are shown in Tab. 2. AtMan outperforms all other XAI approaches on the visual reasoning task for all metrics. Note how the transformer-specific XAI methods (AtMan, Chefer) in particular outperform the generic methods (GradCAM, IG, IxG) in recall. Moreover, while being memory-efficient (see next section), AtMan also generates more accurate explanations than Chefer. Owing to the memory efficiency of AtMan, we were able to evaluate an intermediate version of a 30B upscaling trial of MAGMA (cf. Tab. 2). Interestingly, the overall explanation performance slightly decreases compared to the 6B model variant. This could be attributed to the increased complexity of the model and, subsequently, of the explanation at hand. Hence, it should not be expected that the “human” alignment of the model’s explanations scales with model size.

Metric     IxG    IG     GradCAM   Chefer   AtMan (6B)   AtMan (30B)
mAP        38.0   46.1   56.7      49.9     65.5         61.2
mAP (IQ)   34.1   45.2   60.4      50.2     70.2         65.1
mAR        0.2    0.3    0.1       11.1     13.7         12.2
mAR (IQ)   0.1    0.1    0.1       10.1     19.7         14.5
Table 2: AtMan outperforms XAI methods on the VQA benchmark of OpenImages. Shown are mean average precision (mAP) and mean average recall (mAR), each also with interquartile (IQ) statistics (the higher, the better). XAI methods are evaluated on a 6B model, except for the last column (30B), in which case only AtMan succeeds in generating explanations.

Qualitative Illustration.

Fig. 21 shows several generated image explanations of AtMan and Chefer for different concepts. More examples of all methods can be found in Appendix A.7. We generally observe more noise in gradient-based methods, in particular around the edges. Note that, as VQA only changes the target tokens, we do not need to evaluate the prompt with AtMan more than once for different object classes.

In general, the results clearly provide an affirmative answer to (Q1): AtMan is competitive with previous XAI methods, including transformer-specific ones. Next, we analyze the computational efficiency of AtMan.

4.3 AtMan can do large scale

While AtMan shows competitive performance, it computes, unlike previous approaches, explanations at almost no extra memory cost. Fig. 8 illustrates the runtime and memory consumption on a single NVIDIA A100 80GB GPU. We evaluated the gradient-based transformer XAI method [6] and AtMan. The statistics vary in sequence lengths (colors) from 128 to 1024 tokens, and all experiments are executed with batch size 1 for better comparison.

One can observe that the memory consumption of AtMan is around that of the forward pass (Baseline; green) and increases only marginally with sequence length. In comparison, the method of [6], like other gradient-based methods, more than doubles the memory consumption and exceeds the memory limit. Therefore, it fails on larger sequence lengths.

Whereas the memory consumption of AtMan stays almost constant, the execution time increases significantly with sequence length when no further token aggregation is applied upfront. However, note that the exhaustive search loop of AtMan can be run in parallel to decrease its runtime. In particular, this can be achieved by increasing the batch size and, naturally, by pipeline-parallel execution (https://pytorch.org/docs/stable/pipeline.html). For instance, since large models beyond 100B parameters are distributed across nodes and thus many GPUs, the effective runtime is reduced by orders of magnitude to roughly the scale of a forward pass.
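Because the perturbed passes of Eq. 5 are independent, they can be stacked along the batch dimension. The sketch below assumes a hypothetical batched forward function `batch_loss` that accepts one factor vector per batch element; the unmodified pass is included as an all-zero row, so N tokens need ⌈(N + 1) / B⌉ batched passes for batch size B. The toy loss is invented for illustration.

```python
import math
from typing import Callable, List, Sequence, Tuple

FactorRows = List[List[float]]


def batched_atman_explanation(
        num_tokens: int,
        batch_loss: Callable[[FactorRows], List[float]],
        factors_for: Callable[[int], Sequence[float]],
        batch_size: int) -> Tuple[List[float], int]:
    """Run the unmodified pass plus the N perturbed passes of Eq. 5 in batches.

    batch_loss(rows) stands in for one batched forward pass: batch element b
    applies Eq. 4 with column factors rows[b] and yields its target loss.
    Returns the explanation vector and the number of batched passes used.
    """
    rows = [[0.0] * num_tokens] + [list(factors_for(i)) for i in range(num_tokens)]
    losses: List[float] = []
    for start in range(0, len(rows), batch_size):
        losses.extend(batch_loss(rows[start:start + batch_size]))
    baseline = losses[0]
    return [loss - baseline for loss in losses[1:]], math.ceil(len(rows) / batch_size)


# Hypothetical toy model: each batch element's loss is linear in its factors.
weights = [0.1, 0.05, 0.0, 0.3, 1.2, 0.02]


def toy_batch_loss(rows: FactorRows) -> List[float]:
    return [2.0 + sum(w * f for w, f in zip(weights, row)) for row in rows]


def factors(i: int) -> List[float]:
    return [0.9 if k == i else 0.0 for k in range(len(weights))]


explanation, passes = batched_atman_explanation(len(weights), toy_batch_loss, factors, batch_size=4)
print(passes, "batched passes")
```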

Overall, these results clearly provide an affirmative answer to (Q2): owing to its memory efficiency, AtMan can be applied to large-scale transformer-based models.

Figure 8: AtMan scales efficiently. Performance comparison of the explainability methods AtMan and Chefer et al. over various model sizes (x-axis) executed on a single 80GB memory GPU. Current gradient-based approaches do not scale; only AtMan can be utilized on large-scale models. Solid lines refer to the GPU memory consumption in GB (left y-axis). Dashed lines refer to the runtime in seconds (right y-axis). Colors indicate experiments on varying input sequence lengths. As a baseline (green), a plain forward pass with a sequence length of 1024 is measured. (Best viewed in color.)

5 Conclusion

We proposed AtMan, a modality-agnostic perturbation-based XAI method for transformer networks. In particular, AtMan reduces the complex issue of finding proper perturbations to a single scaling factor per token. As our experiments demonstrate, AtMan outperforms current approaches relying on gradient computation. AtMan is memory-efficient and requires forward passes only, enabling its utilization for deployed large models.

However, some limitations remain unresolved. Although AtMan reduces the overall noise in the generated explanations compared to gradient-based methods, undesirable artifacts still remain. It is unclear to what extent this is due to the method or to the underlying transformer architecture. Through AtMan's memory efficiency, one is able to evaluate whether models' explanatory capabilities scale with their size. When comparing performance scores, the question arises to what extent larger models produce explanations that are more difficult to understand. Consequently, scaling explainability with model size should be studied further. Besides this, our paper provides several avenues for future work, including explanatory studies of current generative models impacting our society. Furthermore, it could lay the foundation for instructing, and in turn improving, not only the predictive outcome of autoregressive models based on human feedback [20] but also their explanations [11].

Acknowledgments

This research has benefited from the Hessian Ministry of Higher Education, Research, Science and the Arts (HMWK) cluster projects “The Third Wave of AI” and hessian.AI as well as from the German Center for Artificial Intelligence (DFKI) project “SAINT”. Further, we thank Manuel Brack, Felix Friedrich, Marco Bellagente and Constantin Eichenberg for their valuable feedback.

References

  • [1] S. Abnar and W. H. Zuidema (2020) Quantifying attention flow in transformers. In Proceedings of the Annual Meeting of the Association for Computational Linguistics (ACL), pp. 4190–4197.
  • [2] D. Bis, M. Podkorytov, and X. Liu (2021) Too much in common: shifting of embeddings in transformer language models and its implications. In Proceedings of the Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pp. 5117–5130.
  • [3] N. Burkart and M. F. Huber (2021) A survey on the explainability of supervised machine learning. J. Artif. Intell. Res. 70, pp. 245–317.
  • [4] N. Carion, F. Massa, G. Synnaeve, N. Usunier, A. Kirillov, and S. Zagoruyko (2020) End-to-end object detection with transformers. In Proceedings of the European Conference on Computer Vision (ECCV), Lecture Notes in Computer Science, pp. 213–229.
  • [5] H. Chefer, S. Gur, and L. Wolf (2021) Generic attention-model explainability for interpreting bi-modal and encoder-decoder transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pp. 397–406.
  • [6] H. Chefer, S. Gur, and L. Wolf (2021) Transformer interpretability beyond attention visualization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 782–791.
  • [7] M. Danilevsky, K. Qian, R. Aharonov, Y. Katsis, B. Kawas, and P. Sen (2020) A survey of the state of explainable AI for natural language processing. In Proceedings of the Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics and the International Joint Conference on Natural Language Processing (AACL/IJCNLP), pp. 447–459.
  • [8] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, J. Uszkoreit, and N. Houlsby (2021) An image is worth 16x16 words: transformers for image recognition at scale. In International Conference on Learning Representations (ICLR).
  • [9] C. Eichenberg, S. Black, S. Weinbach, L. Parcalabescu, and A. Frank (2022) MAGMA – multimodal augmentation of generative models through adapter-based finetuning. In Findings of EMNLP.
  • [10] N. Elhage, N. Nanda, C. Olsson, T. Henighan, N. Joseph, B. Mann, A. Askell, Y. Bai, A. Chen, T. Conerly, N. DasSarma, D. Drain, D. Ganguli, Z. Hatfield-Dodds, D. Hernandez, A. Jones, J. Kernion, L. Lovitt, K. Ndousse, D. Amodei, T. Brown, J. Clark, J. Kaplan, S. McCandlish, and C. Olah (2021) A mathematical framework for transformer circuits. Transformer Circuits Thread. https://transformer-circuits.pub/2021/framework/index.html
  • [11] F. Friedrich, W. Stammer, P. Schramowski, and K. Kersting (2023) A typology to explore and guide explanatory interactive machine learning. Nature Machine Intelligence.
  • [12] J. Jo and S. Myaeng (2020) Roles and utilization of attention heads in transformer-based neural language models. In Proceedings of the Annual Meeting of the Association for Computational Linguistics (ACL), pp. 3404–3417.
  • [13] P. W. Koh and P. Liang (2017) Understanding black-box predictions via influence functions. In Proceedings of the International Conference on Machine Learning (ICML), Proceedings of Machine Learning Research, Vol. 70, pp. 1885–1894.
  • [14] N. Kokhlikyan, V. Miglani, M. Martin, E. Wang, B. Alsallakh, J. Reynolds, A. Melnikov, N. Kliushkina, C. Araya, S. Yan, and O. Reblitz-Richardson (2020) Captum: a unified and generic model interpretability library for PyTorch. arXiv:2009.07896.
  • [15] I. Krasin, T. Duerig, N. Alldrin, V. Ferrari, S. Abu-El-Haija, A. Kuznetsova, H. Rom, J. Uijlings, S. Popov, A. Veit, S. Belongie, V. Gomes, A. Gupta, C. Sun, G. Chechik, D. Cai, Z. Feng, D. Narayanan, and K. Murphy (2017) OpenImages: a public dataset for large-scale multi-label and multi-class image classification. Dataset available from https://github.com/openimages.
  • [16] R. R. Larson (2010) Introduction to information retrieval. J. Assoc. Inf. Sci. Technol. 61, pp. 852–853.
  • [17] J. Li, D. Li, C. Xiong, and S. C. H. Hoi (2022) BLIP: bootstrapping language-image pre-training for unified vision-language understanding and generation. In Proceedings of the International Conference on Machine Learning (ICML).
  • [18] S. M. Lundberg and S. Lee (2017) A unified approach to interpreting model predictions. In Proceedings of Advances in Neural Information Processing Systems: Annual Conference on Neural Information Processing Systems (NeurIPS), pp. 4765–4774.
  • [19] G. Montavon, A. Binder, S. Lapuschkin, W. Samek, and K. Müller (2019) Layer-wise relevance propagation: an overview. In Explainable AI: Interpreting, Explaining and Visualizing Deep Learning, Lecture Notes in Computer Science, Vol. 11700, pp. 193–209.
  • [20] L. Ouyang, J. Wu, X. Jiang, D. Almeida, C. Wainwright, P. Mishkin, C. Zhang, S. Agarwal, K. Slama, A. Gray, J. Schulman, J. Hilton, F. Kelton, L. Miller, M. Simens, A. Askell, P. Welinder, P. Christiano, J. Leike, and R. Lowe (2022) Training language models to follow instructions with human feedback. In Proceedings of Advances in Neural Information Processing Systems: Annual Conference on Neural Information Processing Systems (NeurIPS).
  • [21] A. Radford, J. W. Kim, C. Hallacy, A. Ramesh, G. Goh, S. Agarwal, G. Sastry, A. Askell, P. Mishkin, J. Clark, G. Krueger, and I. Sutskever (2021) Learning transferable visual models from natural language supervision. In Proceedings of the International Conference on Machine Learning (ICML), Proceedings of Machine Learning Research, pp. 8748–8763.
  • [22] P. Rajpurkar, J. Zhang, K. Lopyrev, and P. Liang (2016) SQuAD: 100,000+ questions for machine comprehension of text. In Proceedings of the Conference on Empirical Methods in Natural Language Processing (EMNLP), pp. 2383–2392.
  • [23] M. T. Ribeiro, S. Singh, and C. Guestrin (2016) "Why should I trust you?": explaining the predictions of any classifier. In Proceedings of the ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), pp. 1135–1144.
  • [24] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra (2020) Grad-CAM: visual explanations from deep networks via gradient-based localization. International Journal of Computer Vision 128.
  • [25] A. Shrikumar, P. Greenside, and A. Kundaje (2017) Learning important features through propagating activation differences. In Proceedings of the International Conference on Machine Learning (ICML), Proceedings of Machine Learning Research, Vol. 70, pp. 3145–3153.
  • [26] M. Sundararajan, A. Taly, and Q. Yan (2017) Axiomatic attribution for deep networks. In Proceedings of the International Conference on Machine Learning (ICML), pp. 3319–3328.
  • [27] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems: Annual Conference on Neural Information Processing Systems (NeurIPS), pp. 5998–6008.
  • [28] J. Wang, Z. Yang, X. Hu, L. Li, K. Lin, Z. Gan, Z. Liu, C. Liu, and L. Wang (2022) GIT: a generative image-to-text transformer for vision and language. Transactions on Machine Learning Research.
  • [29] P. Wang, A. Yang, R. Men, J. Lin, S. Bai, Z. Li, J. Ma, C. Zhou, J. Zhou, and H. Yang (2022) Unifying architectures, tasks, and modalities through a simple sequence-to-sequence learning framework. In Proceedings of the International Conference on Machine Learning (ICML).

Appendix A Appendix

a.1 Remarks on executed benchmarks

We executed all benchmarks faithfully and to the best of our knowledge. We selected a diverse set of methods to obtain a good overview of this field of research, in particular with regard to the scaling behavior on multi-modal transformers, since no such studies exist yet for autoregressive (AR) models to compare to. For all methods, there may still be improvements in quality as well as performance that we missed. However, we consider optimizing other methods for multi-modal AR transformer models a research direction of its own.

Chefer.

The integration of Chefer was straightforward. As can be seen in the visualizations, there are noticeable artifacts, particularly at the edges of images. In this work, the underlying transformer model was MAGMA, which is finetuned using sequential adapters. It is possible that this, or the multi-modal AR nature itself, causes these artifacts. We did not investigate further how the adapters should be integrated into the attribution accumulation of Chefer. Note also that AtMan often shows similar, though less severe, artifacts.

IxG, IG and GradCAM.

The methods IxG, IG, and (guided) GradCAM failed completely in terms of quality. They were the only ones operating at the pixel level, and thus also included the vision encoder in the backward pass (which is a requirement for GradCAM, as it can only be used to explain images). We did not further investigate or fine-tune the evaluation of any method. All methods are evaluated with the same metrics and thus give a reasonable performance comparison without additional customization or configuration.

Details on Results.

For a fair comparison, all experiments were executed on a single GPU, as scaling to more devices benefits all methods alike. We also want to highlight that we did not further optimize the methods for performance but rather adopted the repositories as they were. The memory inefficiency of gradient-based methods arises from the backward pass. The most memory-efficient representative is the single-layer attribution method IxG, which only computes the partial derivatives of the loss with respect to the input. Even this approach increases the memory requirement considerably and fails in the scaling experiments up to 34B.

In Fig. 8 we ran Chefer with a full backward pass. We then reduced this to the minimum number of gradients we found possible, by setting requires_grad=False for every tensor except the attention tensors, and plot the full scaling benchmark in Fig. 9. The key message remains the same. Given the IxG argument above, we do not see much potential for improving memory consumption further.

The methods IxG, IG and GradCAM are integrated using the library Captum [14]. We assume these implementations to be as efficient as possible. Integrated Gradients perturbs the input along a path from a baseline to the actual input and integrates the gradients along this path. The implementation at hand mostly runs out of memory (OOM). Finally, GradCAM is specialized for CNNs and therefore does not work for text-only inputs (or varying sequence lengths). It requires the fewest resources but also produces poor results without further investigation.
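To make the cost difference between the two gradient methods concrete, the sketch below implements IxG and a midpoint-rule approximation of Integrated Gradients in plain Python on a toy function f(x) = Σ w_j x_j², whose gradient is known in closed form. It is not Captum's implementation, and all names are illustrative. In a real network, each call to `grad` would be one full backward pass, so IxG needs a single backward pass while the cost of IG grows with `steps`.

```python
from typing import Callable, List

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def input_x_gradient(x: Vector, grad: GradFn) -> Vector:
    """IxG: elementwise product of the input with the gradient at the input."""
    return [xi * gi for xi, gi in zip(x, grad(x))]


def integrated_gradients(x: Vector, baseline: Vector, grad: GradFn,
                         steps: int = 50) -> Vector:
    """IG approximated by a midpoint Riemann sum over `steps` points on the
    straight path from `baseline` to `x` (one gradient call per point)."""
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for j, g in enumerate(grad(point)):
            avg_grad[j] += g / steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Toy model f(x) = sum_j w_j * x_j**2 with analytic gradient 2 * w_j * x_j.
w = [1.0, 2.0]


def f(x: Vector) -> float:
    return sum(wj * xj * xj for wj, xj in zip(w, x))


def grad_f(x: Vector) -> Vector:
    return [2.0 * wj * xj for wj, xj in zip(w, x)]


x = [3.0, 1.0]
ixg = input_x_gradient(x, grad_f)                 # [18.0, 4.0]
ig = integrated_gradients(x, [0.0, 0.0], grad_f)  # approx. [9.0, 2.0]
```

Here the IG attributions sum to f(x) − f(baseline) = 11, the completeness property discussed in [26], whereas the IxG attributions sum to 22.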

AtMan Parallelizability.

As a final remark on AtMan, we recall that the runtime measured for sequential execution can be drastically reduced due to its parallelizability, in particular as it only requires forward passes. For sequence length 1024, we measured 1024 iterations in order to explain each token. However, note that AtMan, in contrast to gradient methods, can also be applied to only parts or chunks of the sequence (cf. Sec. A.8). Moreover, all tokens to explain can be computed entirely in parallel. In a cluster deployment, they can be distributed among all available workers. On top of that, the runtime can be further divided by the available batch size and by true pipeline parallelism.
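The arithmetic behind this remark can be sketched as follows. Assuming, hypothetically, that the tokens to explain are distributed round-robin over the available workers and that each worker processes its share in batches, the number of sequential forward-pass rounds is the busiest worker's share divided by the batch size, rounded up. The names `assign_tokens` and `sequential_rounds` are illustrative, and pipeline parallelism is not modeled.

```python
import math
from typing import Dict, List


def assign_tokens(num_tokens: int, num_workers: int) -> Dict[int, List[int]]:
    """Round-robin assignment of token indices to workers (hypothetical scheme)."""
    plan: Dict[int, List[int]] = {w: [] for w in range(num_workers)}
    for token in range(num_tokens):
        plan[token % num_workers].append(token)
    return plan


def sequential_rounds(num_tokens: int, num_workers: int, batch_size: int) -> int:
    """Number of batched forward-pass rounds the busiest worker has to run."""
    per_worker = math.ceil(num_tokens / num_workers)
    return math.ceil(per_worker / batch_size)


print(sequential_rounds(1024, 1, 1))   # 1024
print(sequential_rounds(1024, 8, 16))  # 8
```

Under these assumptions, the 1024 sequential iterations of the single-GPU, batch-size-1 setting shrink to 8 rounds with 8 workers and batch size 16.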

Figure 9: Performance comparison of the explainability methods over various model sizes (x-axis) executed on a single 80GB memory GPU, with fixed batch size 1. Solid lines refer to the GPU memory consumption in GB (left y-axis). Dashed lines refer to the runtime in seconds (right y-axis). Colors indicate experiments on varying input sequence lengths. As a baseline (green), a plain forward pass with a sequence length of 1024 is measured. Note that GradCAM can only be applied to the vision domain and is therefore fixed to 144 tokens; even so, it already consumes as much memory as a forward pass of 1024 tokens. (Best viewed in color.)

a.2 Detailed SQuAD Evaluations

This section gives more detailed statistics on the scores presented in Tab. 1. Fig. 10 shows the histogram of the token lengths of all explanations. Fig. 11 shows the mAP score for all methods on the entire dataset, grouped by the number of questions occurring per instance.

Figure 10: Histogram of explanation token length of SQuAD.
Figure 11: mAP for all methods, grouped by number of question/answer pairs in SQuAD. (Best viewed in color.)

a.3 Detailed OpenImages Evaluations

This section gives more detailed statistics on the scores presented in Tab. 2. Fig. 12 is the histogram of the fraction of label coverage on all images. Fig. 13 and 14 are boxplots for all methods on the entire dataset, for mean average precision as well as recall.

Figure 12: Histogram of percentage of label coverage of the images.
Figure 13: mAP Boxplot for all methods of all images.
Figure 14: mAR Boxplot for all methods of all images.

a.4 Discussion of Cosine Embedding Similarity

We fixed the threshold parameter of Eq. 6 empirically by running a line sweep once on a randomly sampled subset of the OpenImages dataset, and kept it fixed throughout this work. In Fig. 15 and 16 we compare the mean average precision and recall scores on OpenImages for both variants, with and without correlated token suppression. Clearly, correlated token suppression outperforms single token suppression.
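The correlated suppression controlled by Eq. 6 can be illustrated with a small sketch. Eq. 6 itself is not reproduced in this section, so the form below is an assumption for illustration only: the explained token i is scaled by (1 − f), every other token whose embedding has cosine similarity s ≥ `threshold` with token i is scaled by (1 − s), and all remaining tokens keep the factor 1. The values in the usage line are hypothetical and not the value fixed by the line sweep.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two embedding vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def correlated_factors(embeddings: List[Sequence[float]], i: int,
                       f: float, threshold: float) -> List[float]:
    """Per-token scaling factors when token i is suppressed (assumed scheme)."""
    factors = []
    for k, emb in enumerate(embeddings):
        if k == i:
            factors.append(1.0 - f)  # the explained token itself
        else:
            s = cosine(embeddings[i], emb)
            # Only sufficiently similar neighbors are suppressed as well.
            factors.append(1.0 - s if s >= threshold else 1.0)
    return factors


emb = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]  # toy embeddings
factors = correlated_factors(emb, i=0, f=0.9, threshold=0.5)
```

In this toy case, the second token is almost parallel to the first and is therefore suppressed nearly completely, while the orthogonal third token is left untouched.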

Figure 15: Histogram of the mixed average precision on OpenImages for AtMan with (green) and without (orange) correlated suppression of tokens. (Best viewed in color.)
Figure 16: Histogram of the mixed average recall on OpenImages for AtMan with (green) and without (orange) correlated suppression of tokens. (Best viewed in color.)

Fig. 17 visually shows the effect on weak image segmentation when correlated suppression of tokens is activated, compared with single token suppression only. Notice how single token suppression only occasionally hits the label and often marks a token at the edge. This leads us to believe that entropy accumulates around such edges during layer-wise processing. This effect (on these images) completely vanishes with correlated suppression of tokens.

Figure 17: Example images showing the effect of correlated suppression as described in Correlated Token Attention Manipulation Sec. 3.3 and Single Token Attention Manipulation Sec. 3.2. (Best viewed in color.)

a.5 Variation Discussion of the method

Note that the results of Eq. 4 are directly passed to a softmax operation. The softmax of a vector z is defined as

$$\mathrm{softmax}(z)_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}.$$

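In particular, an entry $z_i = 0$ still contributes $\exp(0) = 1$ to the numerator and the normalization, whereas only an entry $z_i = -\infty$ contributes $\exp(-\infty) = 0$. So one might argue as follows: if we intend to suppress the entropy of a token completely, we do not want to multiply its score by a factor, which even for a factor of 0 leaves a contribution of 1, but rather subtract from it (in the limit, ∞). I.e., we propose the modification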

(7)

The only problem with Eq. 7 is that it skews the cosine neighborhood factors. While, in principle, we found it to work more naturally with hand-crafted factors, we did not obtain the best performance in combination with Eq. 6. In Fig. 18 and 19 we show evaluations analogous to Fig. 15 and 16. Interestingly, the mode without correlated tokens slightly improves, while the one with correlated tokens slightly decreases in scores, for both metrics.
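The softmax argument of this section can be checked numerically. The plain-Python sketch below compares, for one query with made-up pre-softmax scores, multiplying the score of token 0 by a factor of 0 with subtracting infinity from it. The scores are hypothetical, and the sketch does not reproduce the exact form of Eq. 7.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    """Numerically stable softmax (the maximum is subtracted before exp)."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.5]  # hypothetical pre-softmax scores of one query

# Multiplicative suppression with factor 0: the score becomes 0, but exp(0) = 1
# still enters the softmax, so token 0 keeps a non-negligible weight.
multiplied = softmax([0.0 * scores[0]] + scores[1:])

# Subtractive suppression: exp(-inf) = 0, so token 0 receives no weight at all.
subtracted = softmax([scores[0] - math.inf] + scores[1:])

print(round(multiplied[0], 3), subtracted[0])  # 0.186 0.0
```

A related property of the multiplicative form: for a negative score, a factor below 1 moves the score toward 0 and therefore increases, rather than decreases, that token's attention weight.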

Figure 18: Histogram of the mixed average precision on OpenImages for AtMan with and without correlated token suppression. (Best viewed in color.)
Figure 19: Histogram of the mixed average recall on OpenImages for AtMan with and without correlated token suppression. (Best viewed in color.)

a.6 Artifacts and failure modes

In Fig. 20 we repeat an experiment of [5] with the same examples: given a VQA text-image prompt that is supposed to be answered with “yes” or “no”, we derive the explanation on both the input question and the image at the same time. The results are very blurry; in particular, we noticed exceptionally high noise around the edges of the images for both methods. The methods sometimes seem to highlight the correct area, in particular in the giraffe and frisbee samples, but along with a lot of noise. Interestingly, the methods highlight different areas, in the questions as well as in the images. In general, we observe inconsistent behavior of the model's completions to this kind of prompt. It is therefore questionable what explainability methods produce in this setting at all. They might, however, indicate how to specifically address the shortcomings of the model itself.

Figure 20: AtMan and Chefer evaluated on true multimodal explanations. The underlying model is prompted to answer the displayed question with yes or no, the explainability method is asked to highlight the important aspects on the prompt image and question at the same time. (Best viewed in color.)

a.7 Qualitative comparison weak image segmentation

In Fig. 21 we give several examples for better comparison between the methods on the task of weak image segmentation. To generate the explanations, we prompt the model with “Image This is a picture of ” and extract the scores towards the next target tokens as described in Eq. 5 for AtMan. For multiple target tokens, these results are averaged. In the same fashion, but with an additional backpropagation towards the next target token, we derive the explanations for Chefer and the other gradient methods.
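A minimal sketch of how per-target scores could be combined, assuming, in line with the cross-entropy extraction described in Sec. A.8, that the relevance of a suppressed input token is the increase of the target cross-entropy it causes, averaged over all target tokens. The next-token distributions would come from a clean and a suppressed forward pass; here they are hypothetical toy values, and `relevance` is an illustrative name, not the paper's code.

```python
import math
from typing import List, Sequence


def cross_entropy(probs: Sequence[float], target: int) -> float:
    """Cross-entropy (negative log-likelihood) of one target token."""
    return -math.log(probs[target])


def relevance(clean: List[Sequence[float]], suppressed: List[Sequence[float]],
              targets: List[int]) -> float:
    """Assumed score of one suppressed input token: the increase of the target
    cross-entropy caused by the suppression, averaged over all target tokens
    (one next-token distribution per target position)."""
    diffs = [cross_entropy(s, t) - cross_entropy(c, t)
             for c, s, t in zip(clean, suppressed, targets)]
    return sum(diffs) / len(diffs)


# Hypothetical next-token distributions over a 2-token vocabulary for a
# two-token target; suppression only affects the first target position.
clean = [[0.8, 0.2], [0.6, 0.4]]
suppressed = [[0.4, 0.6], [0.6, 0.4]]
score = relevance(clean, suppressed, targets=[0, 0])  # approx. ln(2) / 2 ≈ 0.347
```

A positive score means that suppressing the token made the target less likely, i.e., the token supported the prediction.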

Figure 21: Weak image segmentation comparison of several images for all methods studied in this work. (Best viewed in color.)

a.8 Application to document q/a

In Fig. 22 we apply AtMan paragraph-wise to a larger context of around 500 tokens. The context is first split into chunks at the delimiter tokens “.”, “,”, “\n”, and “ and”. Then each chunk is evaluated iteratively by prompting in the form “ Q: A: ” and extracting the cross-entropy towards the target tokens while suppressing the entire chunk at once, as described in Sec. 3. It can be observed that the correct paragraphs are highlighted for the given questions and expected targets. In particular, one can observe the model's interpretation, such as the mapping of formats or of states to countries. Note in particular that it is not fooled by questions that are not answered by the text (last row).
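The chunking step can be sketched with a regular expression. This is a character-level approximation of the token-level split described above, an assumption on our part: it splits after “.”, “,” and “\n”, and before the word “and”, then drops whitespace-only pieces. Each resulting chunk would then be suppressed as a whole in one AtMan forward pass. Zero-width splits like these require Python 3.7 or later.

```python
import re
from typing import List

# Split points: after ".", ",", or a newline, and before the word " and".
_SPLIT = re.compile(r"(?<=[.,\n])|(?= and\b)")


def split_into_chunks(context: str) -> List[str]:
    """Return the non-empty chunks of `context`, each to be suppressed at once."""
    return [chunk for chunk in _SPLIT.split(context) if chunk.strip()]


chunks = split_into_chunks("Alice lives in Berlin, and Bob in Paris.\nThe end.")
# ['Alice lives in Berlin,', ' and Bob in Paris.', 'The end.']
```

The number of chunks equals the number of perturbed forward passes, which is why a context split into around 50 paragraphs requires around 50 AtMan forward passes.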

Figure 22: AtMan's capability to highlight information in a document q/a setting. The model is prompted with “Context Q: A: ” and asked to extract the answer (target) of the given Explanation. Here, AtMan is run paragraph-wise, as described in the text, and correctly highlights the paragraphs containing the information. All explanations were split into around 50 paragraphs (thus requiring 50 AtMan forward passes). Row 2 shows in particular that the model can interpret, i.e., convert, date-time formats. Row 3 shows that it can derive from world knowledge that Michigan is in the US. Row 4 shows that AtMan is robust against questions whose answer is not contained in the text. (Best viewed in color.)