Understanding DeepMind's Flamingo Visual Language Models
Research/Multimodal, 2024. 8. 15. 11:18
https://medium.com/@paluchasz/understanding-flamingo-visual-language-models-bea5eeb05268
Flamingo is a Visual Language Model and one of the earliest multimodal generative models. This article is a deep dive into what it is, how it works and how it is used.
Introduction
Flamingo was introduced in 2022 in the paper Flamingo: a Visual Language Model for Few-Shot Learning. It is a multimodal Language Model, meaning that it can take inputs in more than one modality and generate text like a standard Language Model. (Strictly, it is a family of models of different sizes, and the name Flamingo refers to the largest of them.) In particular, Flamingo can ingest a multimodal prompt containing images and/or videos interleaved with text.
Understanding this model is a building block for understanding the multimodal models released since then, for example Gemini. Gemini reveals few details, whereas Flamingo gives us a good picture of the architecture and components used. We can also guess that its successors use architectures similar to Flamingo's, extended to even more modalities.
Before we dive into the model’s details, let’s look at its capabilities. As shown in the paper, the model can reason over text as well as images given a few examples (so-called few-shot learning). The images and text are fully interleaved and keep a logical order in the input, as shown below.
The last example shows how Flamingo can be used with videos. The video can be split into frames (sampled at 1 FPS) and passed into the model as a sequence.
Flamingo is also capable of multi-image visual dialogue out of the box.
One could also use Flamingo for visual question answering from an image.
These are quite impressive capabilities, so let's dive into how it works!
Model Details Overview
Architecture
The Flamingo models have a few components. Some are pre-trained and kept frozen (meaning the model weights do not get updated during training), while others are trained from scratch.
The model takes interleaved visual/text data as input. The images are extracted from the text and replaced with a common token, e.g. <image>. The resulting text can then be passed into the plain Language Model component. The images are passed separately through a vision encoder, followed by the Perceiver Resampler described later. Together these convert each image into a fixed number of visual embeddings, which the Language Model can then “attend” to through a novel cross-attention mechanism.
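Here is a minimal Python sketch of this preprocessing step. It assumes the prompt arrives as a list of strings and images. The `Image` class and `split_interleaved` are hypothetical illustrations, not Flamingo's actual data pipeline. The real pipeline also adds special tokens such as <BOS> and <EOC>, which this sketch omits.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union

IMAGE_TOKEN = "<image>"


@dataclass
class Image:
    """Hypothetical placeholder for a decoded image; a real pipeline would hold pixels."""
    name: str


def split_interleaved(prompt: List[Union[str, Image]]) -> Tuple[str, List[Image]]:
    """Replace each image with IMAGE_TOKEN and collect the images in their original order."""
    text_parts: List[str] = []
    images: List[Image] = []
    for item in prompt:
        if isinstance(item, Image):
            text_parts.append(IMAGE_TOKEN)  # the language model only sees the placeholder
            images.append(item)             # the pixels go to the vision encoder instead
        else:
            text_parts.append(item)
    return " ".join(text_parts), images


text, images = split_interleaved(
    [Image("dog.jpg"), "My puppy sitting on the grass.", Image("cat.jpg"), "My cat looking very dignified."]
)
print(text)                          # <image> My puppy sitting on the grass. <image> My cat looking very dignified.
print([img.name for img in images])  # ['dog.jpg', 'cat.jpg']
```

The text and the images travel separately from this point on. The order of the `images` list is what later lets each <image> token be matched back to its picture.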
Here is the additional overview from the paper:
Flamingo models leverage two complementary pre-trained and frozen models: a vision model which can “perceive” visual scenes and a large LM which performs a basic form of reasoning. Novel architecture components are added in between these models to connect them in a way that preserves the knowledge they have accumulated during computationally intensive pre-training.
Training
Flamingo is trained on a carefully chosen mixture of complementary large-scale multimodal data coming only from the web, without using any data annotated for machine learning purposes. After this training, a Flamingo model can be directly adapted to vision tasks via simple few-shot learning without any task-specific tuning.
As described in the paper, Flamingo models the likelihood of text 𝑦 conditioned on interleaved images and videos 𝑥 as follows:
$$p(y \mid x) = \prod_{\ell=1}^{L} p(y_\ell \mid y_{<\ell}, x_{\le \ell})$$
where 𝑦ℓ is the ℓ-th language token of the input text, 𝑦<ℓ is the set of preceding tokens, 𝑥≤ℓ is the set of images/videos preceding token 𝑦ℓ in the interleaved sequence and 𝑝 is parametrized by a Flamingo model.
This is the standard auto-regressive procedure used in Language Models, where at each time step the model predicts a token based on the previously generated tokens, i.e. p(y|x) = p(y_1|y<1,x≤1) * p(y_2|y<2,x≤2) * … See the decoding strategies section in this article for more details.
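The factorisation can be sketched in a few lines of Python as a toy illustration. Two assumptions apply. First, `cond_prob` is a hypothetical stand-in for the network. Second, images are matched to <image> markers by order, so x≤ℓ is simply the set of images whose marker appears before token ℓ. Taking logs turns the product into a sum.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical interface standing in for the Flamingo network:
# cond_prob(prefix_tokens, visible_images, next_token) returns p(y_l | y_<l, x_<=l).
CondProb = Callable[[Sequence[str], Sequence[str], str], float]


def log_likelihood(tokens: List[str], images: List[str], cond_prob: CondProb) -> float:
    """log p(y|x) = sum over l of log p(y_l | y_<l, x_<=l), i.e. the log of the product."""
    total = 0.0
    for l, token in enumerate(tokens):
        prefix = tokens[:l]
        # x_<=l: the images whose <image> marker appears before position l.
        visible = images[: prefix.count("<image>")]
        total += math.log(cond_prob(prefix, visible, token))
    return total


# Toy "model" that assigns probability 0.5 to every token.
ll = log_likelihood(["<image>", "a", "cat"], ["img0"], lambda prefix, imgs, tok: 0.5)
print(-ll)  # the negative log-likelihood, equal to 3 * log(2)
```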
Training objective
Four different datasets, containing text paired or interleaved with images/videos, were used to train the model. Some were pre-existing and some were collected and assembled as part of this study. For each dataset the negative log-likelihood was minimised. The overall loss was a weighted sum of the per-dataset negative log-likelihoods:
$$\sum_{m=1}^{M} \lambda_m \, \mathbb{E}_{(x, y) \sim \mathcal{D}_m}\left[-\sum_{\ell=1}^{L} \log p(y_\ell \mid y_{<\ell}, x_{\le \ell})\right]$$
where 𝒟𝑚 and 𝜆𝑚 are the 𝑚-th dataset and its weighting, respectively. As stated in the paper, the datasets used for training are M3W, ALIGN, LTIP and VTP, with weights 𝜆𝑚 of 1.0, 0.2, 0.2 and 0.03 respectively. These weights proved key to performance. They were chosen empirically in a smaller-scale experiment and kept constant for the full training.
The largest model, with 80 billion parameters, was trained on 1536 chips for 15 days and sharded across 16 devices.
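Once each dataset's mean negative log-likelihood is known, the weighted objective is simple arithmetic. Below is a minimal sketch using the weights quoted above. `weighted_objective` is a hypothetical helper name, and the per-dataset loss values in the usage line are made up purely for illustration.

```python
from typing import Dict

# Dataset weights lambda_m as reported in the paper.
DATASET_WEIGHTS: Dict[str, float] = {"M3W": 1.0, "ALIGN": 0.2, "LTIP": 0.2, "VTP": 0.03}


def weighted_objective(per_dataset_nll: Dict[str, float],
                       weights: Dict[str, float] = DATASET_WEIGHTS) -> float:
    """sum_m lambda_m * NLL_m, where NLL_m is the mean negative log-likelihood on dataset m."""
    return sum(weights[name] * nll for name, nll in per_dataset_nll.items())


# Made-up per-dataset losses, only to show how the weights scale each term.
loss = weighted_objective({"M3W": 2.0, "ALIGN": 1.0, "LTIP": 1.0, "VTP": 10.0})
print(round(loss, 6))  # 2.7
```

Even a large video loss contributes little here (0.03 × 10 = 0.3), which shows how strongly the chosen lambdas shape what the model focuses on.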
Transformer Attention Mechanism Revisited
Before diving into the details of the individual components, let's do a quick recap of the Transformer attention mechanism, which we will need. For a more detailed introduction see this great post.
The Encoder
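Recall that the attention mechanism is described by the following formula:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
where d_k is the dimension of the keys (head_dim below).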
The Q, K, V represent the query, key, and value matrices respectively. The intuition is that for each of the input tokens we “query” (or attend to) all the other tokens that act as “keys”. By using the softmax we select the tokens that have the most importance (relation) to encode the current token and then retrieve a weighted sum of their corresponding values based on this importance.
In encoder self-attention, Q, K and V are all obtained via linear projections of the input tokens. A matrix of size (seq_length, embedding_dim) is multiplied by learnt weight matrices of size (embedding_dim, head_dim), which results in Q, K, V matrices of size (seq_length, head_dim). The product QK^T is then of size (seq_length, seq_length), and the final output is of size (seq_length, head_dim). The original Transformer does this with eight different sets of projection matrices (heads), each with a head_dim of 64, in what is called multi-head attention. Below is my rough handwritten cheat sheet, which makes this easier to visualise.
This mechanism has one important computational drawback: the score matrix QK^T has seq_length × seq_length entries, so the attention mechanism has quadratic complexity in seq_length.
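Here is a plain-Python sketch of single-head scaled dot-product attention that follows the shapes described above. It uses lists instead of a tensor library so it runs anywhere. The projection weights are arbitrary toy values, not learnt ones. Multi-head attention would repeat this with separate projections per head and concatenate the results.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(n, k) @ (k, m) -> (n, m) using plain lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head softmax(Q K^T / sqrt(head_dim)) V."""
    head_dim = len(q[0])
    scores = matmul(q, transpose(k))  # (len(q), len(k)): seq_length x seq_length in self-attention
    scaled = [[s / math.sqrt(head_dim) for s in row] for row in scores]
    return matmul(softmax_rows(scaled), v)  # (len(q), head_dim)


# Self-attention: Q, K, V are projections of the same input X (seq_length=4, embedding_dim=8).
X = [[0.1 * (i + j) for j in range(8)] for i in range(4)]
W_q = [[0.1 * ((i + j) % 3) for j in range(2)] for i in range(8)]  # (embedding_dim, head_dim=2)
W_k = [[0.1 * ((i * j) % 3) for j in range(2)] for i in range(8)]
W_v = [[0.1 * ((i + 2 * j) % 3) for j in range(2)] for i in range(8)]
out = attention(matmul(X, W_q), matmul(X, W_k), matmul(X, W_v))
print(len(out), len(out[0]))  # 4 2
```

Doubling seq_length doubles the rows of Q and K, so the intermediate `scores` matrix grows fourfold. That is the quadratic cost in practice.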
The Decoder
In the Transformer decoder there is also an additional cross-attention mechanism (though it was not called that at the time). In it, each of the output tokens being generated attends to all the tokens in the input sequence: the queries come from the output tokens, while the keys and values come from the input tokens. As such, for an input sequence of length N and an output sequence of length M, Q has shape (M, head_dim) while K and V have shape (N, head_dim). The cost therefore scales linearly in both the input and the output sequence lengths, with an overall complexity of O(NM), compared with O(N²) for self-attention.
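Running the same computation with queries and keys/values from different sequences makes the O(NM) cost visible. The outer loop runs once per output token (M times), and each iteration scores all N input tokens. This is a minimal single-head sketch with made-up inputs. The identical value rows are chosen only so that the result is easy to check.

```python
import math
from typing import List

Matrix = List[List[float]]


def cross_attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """queries: (M, head_dim) from output tokens; keys/values: (N, head_dim) from input tokens."""
    d = len(queries[0])
    out = []
    for q in queries:  # M iterations
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]  # N scores
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, values)) / total
                    for j in range(len(values[0]))])
    return out  # (M, head_dim); M * N scores were computed in total


M, N, head_dim = 3, 5, 4
queries = [[0.1 * (i + j) for j in range(head_dim)] for i in range(M)]
keys = [[0.2 * ((i * j) % 4) for j in range(head_dim)] for i in range(N)]
values = [[float(j) for j in range(head_dim)] for _ in range(N)]  # identical rows
out = cross_attention(queries, keys, values)
print(len(out), len(out[0]))  # 3 4
```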
The Bottleneck
The O(N²) complexity of self-attention proved to be a bottleneck for longer sequences (the original Transformer has a maximum sequence length of 512 tokens). Since 2017 there has been a lot of research into alternative mechanisms that retain the same level of accuracy but scale to much longer sequences. For example, the recent Mistral 7B LLM uses sliding window attention, in which each token only attends to tokens within a fixed-size window before it. Stacking Transformer layers then lets information from more distant tokens reach later layers of the network. Another example is the Large World Model, which can process 1h of video using ring attention, distributing long sequences across multiple devices.
The quadratic complexity in sequence length is even more pronounced when applying transformers to visual data: it takes far more pixels to describe an image than tokens to describe a piece of text. For instance, the original Vision Transformer splits images into patches rather than operating on individual pixels. One drawback of such methods is that they make assumptions about the input data, which introduces inductive biases.
In 2021, an alternative general architecture called the Perceiver was proposed. It handles various modalities and scales to hundreds of thousands of inputs.
In the Perceiver, K and V are projections of the input sequence (which the authors call a byte array), such as an image. Its length M is so large (for instance about 50K for 224 x 224 ImageNet images) that feeding it into a standard Transformer would be very computationally expensive. Instead, they use a fixed-size latent array of size N, where N ≪ M (for example 512), and the query matrix Q is a projection of this latent array. A cross-attention component uses the latent array to attend to the tokens in the byte array. Like the original Transformer decoder, this scales linearly with the size of the byte array, so its complexity is O(NM). The normal Transformer flow can then be applied to the latent array, including self-attention over the latents with complexity O(N²). This gives an overall complexity of O(NM) + O(N²) with N ≪ M. Importantly, because N is small, the authors can stack many Transformer layers and build very large networks on large-scale data, which would be infeasible if the complexity were O(M²). Working in a smaller latent space is a common computational trick that other models such as Stable Diffusion also use.
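A sliding-window mask can be sketched as below. The sketch assumes a causal window in which each position sees itself and the `window - 1` tokens before it; the exact boundary convention in Mistral's implementation may differ. `receptive_field` is a hypothetical helper showing how stacking layers extends the reach. With a 4096-token window and 32 layers (the configuration reported for Mistral 7B), it gives roughly 131K positions.

```python
from typing import List


def sliding_window_mask(seq_len: int, window: int) -> List[List[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Causal sliding window: position i sees itself and the (window - 1) tokens before it.
    """
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


def receptive_field(window: int, num_layers: int) -> int:
    """Positions (including itself) a token can indirectly depend on after stacking layers."""
    return num_layers * (window - 1) + 1


for row in sliding_window_mask(5, 3):
    print("".join("x" if visible else "." for visible in row))
print(receptive_field(window=3, num_layers=2))  # 5
```

Each row of the printed mask has at most three visible entries, so the per-layer cost is O(seq_len × window) rather than O(seq_len²).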
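Here is a quick worked example of the savings. It counts query-key dot products (the size of each score matrix) and ignores the head dimension, the number of heads and all other layers. It assumes M = 224 × 224 = 50,176 inputs, as in the ImageNet example, and N = 512 latents.

```python
def score_entries(num_queries: int, num_keys: int) -> int:
    """Number of query-key dot products, i.e. the size of one Q K^T score matrix."""
    return num_queries * num_keys


M = 224 * 224  # byte array length for a 224 x 224 image: 50,176
N = 512        # latent array size

full_self_attention = score_entries(M, M)                    # O(M^2)
perceiver_layer = score_entries(N, M) + score_entries(N, N)  # O(NM) + O(N^2)

print(f"{full_self_attention:,}")                    # 2,517,630,976
print(f"{perceiver_layer:,}")                        # 25,952,256
print(round(full_self_attention / perceiver_layer))  # 97
```

So one cross-attention plus one latent self-attention costs roughly 1% of a single full self-attention over the pixels. Each additional latent layer adds only N² = 262,144 scores.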
Now that we have the necessary components (Perceiver and cross-attention) we can dive into the Flamingo details.
Visual Components Architectures
To recap, Flamingo extracts the interleaved images from the text and encodes them separately using two different components as shown below:
The Vision Encoder
The vision encoder used is very similar to OpenAI’s CLIP model, which I cover in detail in another article. The authors experimented with CLIP as well but found their own version to give better results.
The vision encoder was trained from scratch together with a language encoder. Using these encoders, the image and the text of each pair are encoded separately, projected to a shared embedding space and L2-normalised. As in CLIP, training used contrastive learning with the multi-class N-pair loss, which maximises the similarity between paired embeddings and minimises the similarity between unpaired embeddings. Unlike CLIP, which opted for a Vision Transformer, Flamingo uses a Normalizer-Free ResNet. After pre-training, this vision encoder is frozen and not trained any further!
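The two-way contrastive objective can be sketched as below. Image i and text i form the positive pair, and every other pairing in the batch is a negative. A cross-entropy is applied in both the image-to-text and text-to-image directions. This is a minimal plain-Python sketch under two assumptions: the fixed `temperature` of 0.1 is an assumption (in practice it is commonly learnt), and the toy 2-D embeddings stand in for real encoder outputs.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cross_entropy(logits: Vector, target: int) -> float:
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def two_way_contrastive_loss(image_emb: List[Vector], text_emb: List[Vector],
                             temperature: float = 0.1) -> float:
    imgs = [l2_normalize(v) for v in image_emb]
    txts = [l2_normalize(v) for v in text_emb]
    n = len(imgs)
    # sim[i][j]: cosine similarity of image i and text j, divided by the temperature.
    sim = [[sum(a * b for a, b in zip(imgs[i], txts[j])) / temperature for j in range(n)]
           for i in range(n)]
    image_to_text = sum(cross_entropy(sim[i], i) for i in range(n)) / n
    text_to_image = sum(cross_entropy([sim[i][j] for i in range(n)], j) for j in range(n)) / n
    return image_to_text + text_to_image


aligned = two_way_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = two_way_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)  # True: matching pairs give a much lower loss
```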
The Perceiver Resampler
The Perceiver Resampler is very similar to the Perceiver discussed earlier. It takes a variable number of visual features and converts them into a fixed number of output tokens. The overview is shown in the figure below:
1. The visual features 𝑋𝑓 are extracted from the Vision Encoder. It is unclear from the paper exactly what is extracted, but based on the above diagram it would seem that a number of patches S get extracted from the image each of embedding dimension d, resulting in 𝑋𝑓 of shape [T, S, d], where T represents the time.
2. 𝑋𝑓 are obtained by first adding a learnt temporal position encoding to each feature within a given video frame (an image being considered as a single-frame video). This is analogous to the positional embeddings in the Transformer, as unlike in an RNN/LSTM the inputs are not processed sequentially and any ordering information is otherwise lost.
3. The 𝑋𝑓 are then flattened and concatenated into a matrix of shape [T*S, d].
4. 𝑋𝑓 is passed through numerous Perceiver Resampler blocks.
5. Just like in the Perceiver, the queries (Q) come from a projection of the learnt latents. Interestingly, however, the learnt latents and 𝑋𝑓 are concatenated before projecting to calculate the keys (K) and values (V). The authors explain that they found this to perform slightly better. These are then used in the attention mechanism, which scales linearly in the size of the input 𝑋𝑓.
6. Finally this is followed by two residual connections and a feed forward component just like in the Transformer.
7. The number of output tokens of the Perceiver Resampler is equal to the number of learnt latent queries.
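The steps above can be condensed into a single simplified block. The sketch departs from the paper in several ways: it uses identity projections instead of learnt Q/K/V matrices, a single attention head, and no feed-forward sublayer or layer norm. It is meant only to show the data flow, and why the output size depends on the number of latents R rather than on T or S.

```python
import math
from typing import List

Matrix = List[List[float]]


def attend(queries: Matrix, keys_values: Matrix) -> Matrix:
    """Single-head attention with identity projections (a simplification): K = V = keys_values."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, kv)) / math.sqrt(d) for kv in keys_values]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * kv[j] for wi, kv in zip(w, keys_values)) / z for j in range(d)])
    return out


def perceiver_resampler_block(x_f: List[Matrix], time_emb: Matrix, latents: Matrix) -> Matrix:
    """x_f: [T][S][d] visual features, time_emb: [T][d], latents: [R][d] -> [R][d]."""
    # Step 2: add the temporal embedding of frame t to every feature of that frame.
    with_time = [[[f + e for f, e in zip(feat, time_emb[t])] for feat in frame]
                 for t, frame in enumerate(x_f)]
    # Step 3: flatten [T, S, d] -> [T*S, d].
    flat = [feat for frame in with_time for feat in frame]
    # Step 5: queries come from the latents; keys/values from concat(X_f, latents).
    attended = attend(latents, flat + latents)
    # Step 6: residual connection (the feed-forward sublayer is omitted in this sketch).
    return [[a + l for a, l in zip(att, lat)] for att, lat in zip(attended, latents)]


T, S, d, R = 2, 3, 4, 5
x_f = [[[0.1 * (t + s + j) for j in range(d)] for s in range(S)] for t in range(T)]
time_emb = [[0.01 * t] * d for t in range(T)]
latents = [[0.05 * r] * d for r in range(R)]
out = perceiver_resampler_block(x_f, time_emb, latents)
print(len(out), len(out[0]))  # 5 4
```

Feeding in 8 frames instead of 2 still yields 5 output tokens. Only the length of `flat`, and hence the linear attention cost, grows.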
These output tokens are now ready to be used in the Language Model to generate text. Let’s see how that works!
The conditioning Language Model Components
Text generation is performed by a Transformer decoder, conditioned on the visual representations produced by the Perceiver Resampler. Flamingo interleaves pre-trained and frozen text-only LM blocks (i.e. a standard pre-trained Transformer) with blocks trained from scratch that cross-attend to the visual output from the Perceiver Resampler. The architecture can be visualised below
As described in the paper, a “gating” mechanism is used to improve training stability and final performance:
To ensure that at initialization, the conditioned model yields the same results as the original language model, we use a tanh-gating mechanism. This multiplies the output of a newly added layer by tanh(𝛼) before adding it to the input representation from the residual connection, where 𝛼 is a layer-specific learnable scalar initialized to 0. Thus, at initialization, the model output matches that of the pretrained LM, improving training stability and final performance.
Recall what the tanh function looks like:
So indeed, when 𝛼 is initialised to 0, tanh(𝛼) is 0 and the output y in the pseudo-code is left with only the language features. Below are two diagrams showing how the value of 𝛼 changes during training for the individual attention and feed-forward (FFW) layers.
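The gating itself is a one-liner. In this minimal sketch, `new_layer` is a hypothetical stand-in for a freshly added cross-attention or feed-forward layer. The real cross-attention layer also takes the visual tokens as input, which is omitted here.

```python
import math
from typing import Callable, List

Vector = List[float]


def gated_residual(x: Vector, layer: Callable[[Vector], Vector], alpha: float) -> Vector:
    """y = x + tanh(alpha) * layer(x), with alpha a layer-specific learnable scalar."""
    gate = math.tanh(alpha)
    return [xi + gate * yi for xi, yi in zip(x, layer(x))]


def new_layer(v: Vector) -> Vector:
    """Hypothetical stand-in for a newly initialised layer inserted into the frozen LM."""
    return [10.0 * vi for vi in v]


language_features = [1.0, 2.0]
print(gated_residual(language_features, new_layer, alpha=0.0))  # [1.0, 2.0]: gate closed
print(gated_residual(language_features, new_layer, alpha=1.0))  # new layer's output now mixed in
```

With alpha at 0 the frozen language model's features pass through unchanged. Training can then open the gate gradually, as the alpha curves in the diagrams illustrate.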
The Cross-attention Block
Below is an illustration of the cross-attention (called xattn-dense in the paper) block.
1. The images are extracted from the text (and replaced with the <image> token) and passed through the visual components as described before. New tokens <BOS> for “beginning of sequence” and <EOC> for “end of chunk” are added to indicate where the text starts and ends.
2. At a given text token, the model only cross-attends to the visual tokens corresponding to the last preceding image/video. 𝜑 indicates which image/video a text token can attend to, or is 0 when no image/video precedes it.
3. In practice, this selective cross-attention is achieved through masking — illustrated here with the dark blue entries (unmasked/visible) and light blue entries (masked).
For instance, the tokens “<image> My puppy sitting on the grass <EOC>” only attend to the image of the dog, while the tokens “<image> My cat looking very dignified. <EOC>” only attend to the image of the cat.
Although the model directly attends to only one image at a time, the dependency on all previous images is preserved through self-attention in the Language Model. Importantly, this single-image cross-attention scheme lets the model generalise seamlessly to any number of visual inputs, regardless of how many are used during training. For example, the authors used 5 images in training, but at inference the models can benefit from up to 32 images (few-shot examples).
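The masking scheme can be sketched as follows. The sketch rests on three assumptions. Each image contributes `tokens_per_image` visual tokens (the Perceiver Resampler's fixed output count). Images are numbered from 1 in order of appearance. The <image> token itself is assigned to the image it introduces. The helper names are hypothetical.

```python
from typing import List


def image_index_per_token(tokens: List[str], image_token: str = "<image>") -> List[int]:
    """phi for each token: 1-based index of the last <image> at or before it, or 0 if none."""
    phi, current = [], 0
    for tok in tokens:
        if tok == image_token:
            current += 1
        phi.append(current)
    return phi


def cross_attention_mask(tokens: List[str], tokens_per_image: int,
                         image_token: str = "<image>") -> List[List[bool]]:
    """mask[l][v] is True when text token l may attend to visual token v."""
    phi = image_index_per_token(tokens, image_token)
    num_visual = tokens.count(image_token) * tokens_per_image
    # Visual tokens of image i (1-based) occupy columns (i-1)*R .. i*R - 1.
    return [[p > 0 and v // tokens_per_image + 1 == p for v in range(num_visual)] for p in phi]


tokens = ["<BOS>", "<image>", "My", "puppy", "<EOC>", "<image>", "My", "cat", "<EOC>"]
print(image_index_per_token(tokens))  # [0, 1, 1, 1, 1, 2, 2, 2, 2]
mask = cross_attention_mask(tokens, tokens_per_image=2)
for tok, row in zip(tokens, mask):
    print(f"{tok:>8} " + "".join("x" if visible else "." for visible in row))
```

The printed grid mirrors the dark/light blue pattern in the illustration. The puppy caption sees only the first two visual tokens, and the cat caption sees only the last two.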
Experimental Results
It is important to reiterate that Flamingo was pre-trained on large amounts of textual and visual data in a self-supervised way (similar to other pre-trained language models) and was not fine-tuned on any specific machine learning dataset. To measure the model's ability to generalise to unseen examples, it was evaluated on various datasets such as COCO. Flamingo’s performance can be compared with state-of-the-art zero/few-shot models as well as with state-of-the-art models fine-tuned on those specific datasets. The results are shown below:
- Both image (I) and video (V) understanding datasets are used.
- Results are shown for 3 different Flamingo models (varying in size) with the Flamingo model (80B) achieving the best results.
- The Flamingo models are tested zero-shot (no examples shown to the model in the prompt) as well as few-shot, with either 4 or 32 example images shown in the prompt.
- The Flamingo model with 32 shots outperforms the fine-tuned SOTA models on several of these benchmarks. This is very significant, as during fine-tuning a model sees many more (thousands of) examples. It is not stated, but I guess that the numbers in brackets show how many examples those models were trained on, which is for instance 500K for COCO. So on some tasks, with just 32 examples (shots), Flamingo can compete with models trained on hundreds of thousands of examples!
Limitations
Flamingo suffers from the typical issues encountered in Language Models, such as hallucinations, outputting offensive language, propagating social biases and stereotypes, and leaking private information. Its ability to handle visual inputs also poses specific risks, such as gender and racial biases relating to the contents of the input images.
Below is a selected set of examples illustrating Hallucinations and their explanations.
Left: The model occasionally hallucinates by producing answers that seem likely given the text only, but are wrong given the image as additional input. Middle: Similar hallucinations can be provoked by adversarially prompting the model with an irrelevant question. Right: A more common pitfall arises when the model makes ungrounded guesses when the answer cannot be determined based on the inputs. Few-shot examples and more sophisticated prompt design may be used to mitigate these issues. More broadly, addressing these issues is an important research direction towards improving our models’ applications in open-ended visual dialogue settings.
Available Implementations
The Flamingo model was not open-sourced, and the model weights were not made available. However, to my knowledge there are some reproduced implementations, like this repo. Also, if you like using Hugging Face, they released an open-source version of Flamingo called IDEFICS (Image-aware Decoder Enhanced à la Flamingo with Interleaved Cross-attentionS). You can run it on a CPU, but inference on a single prompt will take almost half an hour!
Conclusions
Flamingo is a set of Vision Language Models of varying sizes. It takes interleaved text and images/videos as input and generates text as output. With just a few examples in the prompt (few-shot learning), it can outperform models fine-tuned on specific benchmarks.
In this post, we broke down all the different components in detail, including a Vision Encoder (to embed the visual input), a Perceiver Resampler (to map a variable visual input into a fixed number of tokens), and a cross-attention mechanism that can attend to the visual inputs inside the Language Model and generate text.