  • [Flamingo] The architecture behind modern visual language modeling
    Research/Multimodal 2024. 9. 24. 08:39

    https://towardsdatascience.com/flamingo-intuitively-and-exhaustively-explained-bf745611238b


    In this article we’ll discuss Flamingo, a landmark paper in “multimodal modeling”.

     

    First we’ll define “multimodal models” as a class of machine learning models capable of understanding numerous types of data. We’ll then briefly explore landmark papers in image classification and text generation, then describe how Flamingo combined these technologies to achieve state of the art performance in use cases containing both images and text.

     

    By the end of this article you’ll have a thorough understanding of how Flamingo achieved state-of-the-art performance, paving the way for today’s advanced A.I. systems like GPT-4 and Google Gemini.

     

    Flamingo conversing about both text and images. Blocks in pink are generated by the Flamingo model. From the Flamingo paper.


    Multimodal Modeling Before Flamingo

    “Vision-Language modeling” is what most people think of when they think of “multimodal modeling”. Before we get into the nuts and bolts, let’s define these two ideas:

     

    1. Multimodal modeling is an umbrella term for any machine learning model that deals with multiple “modalities”. You can think of a modality as a type of data; things like text, images, tables, and audio are all considered by data scientists as different “modalities”. Multimodal models are models that can somehow work with multiple modalities.

     

    2. Vision-language modeling is probably the most popular form of multimodal modeling: machine learning models that can, in some way, perform tasks requiring the simultaneous understanding of both images and text.

     

    In reality, vision-language modeling is a catchall term for a broad class of problems:

     

    1. Visual Question Answering: Given an image and a textual question about that image, generate a response

     

    2. Captioning: Given an image, describe the content of the image textually

     

    3. Visual Dialogue: Hold a coherent and organic conversation that contains both images and text

     

    4. Image Classification: Given an image, categorize that image into one of a fixed set of predefined textual classes.

     

    Before Flamingo, highly specific models were state-of-the-art on highly specific multimodal tasks:

    • KAT was state of the art on the OKVQA dataset, a visual question answering dataset
    • a Good Embedding Is All You Need? was state of the art on VQAv2, a different visual question answering dataset
    • SimVLM was state of the art on COCO, an image captioning dataset
    • VIOLET was state of the art on MSVDQA, a visual question answering dataset focused on video

    One of the big claims from the CLIP paper (a landmark paper we’ll cover in the next section) was that this highly specialized performance doesn’t scale well to real-world situations. If you have a model that’s good at one dataset, but can’t do a similar task on a similar dataset, is it really a good model?

     

    CLIP tackled this problem before Flamingo, but CLIP-style models only work for classifying images. The idea behind Flamingo was to bridge the gap between CLIP-style models, which have a robust understanding of the content of images, and the textual understanding and generative ability of language models, creating a system that could converse about both text and images robustly and flexibly.


    The Precursors to Flamingo

    Before we talk about Flamingo itself, it’s critical to understand the two key modeling strategies it inherits: CLIP and decoder-only transformers.


    CLIP

    The whole idea of CLIP was to create a general-purpose image classifier that could be used in a variety of cases without any further training. To achieve this, the authors of the CLIP paper used a strategy called “contrastive learning”.

    CLIP predicting highly specific labels for classification tasks it was never directly trained on.  Source

    Contrastive learning is a subtle re-framing of the problem of image classification. Instead of learning “This label goes with this image, and this label goes with that image”, contrastive learning says “This label is closer in similarity to this image, and this label is closer in similarity to that image”. This subtle change in thinking opened up a whole new approach to representing images that has been widely used ever since.

    The objective of CLIP, to move things that are similar to each other closer together, and different things further apart. From my article on  CLIP

    CLIP employs two components to build this idea of “closeness”, an image encoder and a text encoder. These both learn, in unison through the training process, to jointly align pairs of images and text in a high-dimensional space. By putting similar images and text in similar spots, and different images and text in different spots, CLIP style models learn to make decisions about which text belongs with which image.

    An example of CLIP after it’s trained, with 2-dimensional embeddings for demonstrative purposes. The image encoder summarizes images into vectors, and the text encoder summarizes text into vectors. Notice how, once the encoders are trained, positive pairs end up close together. From my article on  CLIP

    The important part for us is the image encoder. For CLIP to be successful, it has to train the image encoder to understand things like dogs, clothes, and skateboards within an image so that those images can be placed in the right spot with the right text. In other words, CLIP image encoders are really good at distilling an image down into its general meaning, a quality Flamingo employs to achieve visual-language modeling.
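    To make the “closeness” idea concrete, here’s a minimal sketch of CLIP-style zero-shot classification. The encoders are stand-ins (fixed random vectors rather than real trained networks, and the labels are made up); the point is only how cosine similarity in a shared embedding space picks the label that “belongs” with an image.

    import numpy as np

    rng = np.random.default_rng(0)

    # Stand-ins for trained CLIP encoders: in reality these are deep networks,
    # here we just use fixed random vectors so the sketch runs end to end.
    image_embedding = rng.normal(size=128)              # one encoded image -> [d]
    label_embeddings = rng.normal(size=(3, 128))        # three encoded captions -> [n, d]
    labels = ["a photo of a dog", "a photo of a cat", "a photo of a skateboard"]

    def cosine(a, b):
        # "Closeness" in the shared embedding space, measured with cosine similarity.
        return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

    similarities = [cosine(image_embedding, t) for t in label_embeddings]
    print(labels[int(np.argmax(similarities))])          # the closest label wins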

     

    https://towardsdatascience.com/clip-intuitively-and-exhaustively-explained-1d02c07dbf40


    Decoder-Only Transformers

    So that was CLIP, the technology Flamingo uses to understand images. Decoder-only transformers are what Flamingo uses to understand text.

     

    From a super high level, you can think of language models as a big stack of blocks. The purpose of each of these blocks is to refine a representation of the input text, block by block, and then use that representation to predict what text should follow the input.

    The entire GPT-1 model. 12 decoder blocks stacked on top of each other. From  my article on GPT . We don’t have to worry about why these are called decoder blocks in this article, but you can learn all about it in my  article on GPT . image  source

     

    The thing that makes each of these blocks within the transformer special is their use of “attention”. Attention is a form of modeling where the representations of multiple words in the input are combined together to create an abstract and highly contextualized representation.

    The idea of attention, in a nutshell. The mechanism mathematically combines the vectors for different words, creating a matrix which encodes a deeper meaning of the entire input. From  my article on transformers .

    This is done by feeding the attention mechanism three inputs: a “Query”, “Key”, and “Value”. Don’t get hung up on the names; we’ll build a more intuitive understanding of all this later. For now, I just want to share the high-level workings of the attention mechanism.

     

    The query and key, which are derived from the input to the attention mechanism, get multiplied together to create what I like to call the “attention matrix”.

    Two inputs, called the “key” and “query” get multiplied together to create the attention matrix. The key is usually transposed (rotated) to make the matrix multiplication work out right.  from my article on transformers

    Then, the attention matrix is used as a filter to transform the value matrix into the final output.

    The attention matrix acts like a filter, which transforms an input called the “value” into the final result.  from my article on transformers

    So, in essence, the attention mechanism uses some inputs to filter other inputs. 
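    Here’s a minimal sketch of that two-step process in NumPy. One detail the description above glosses over: real transformers also scale the scores and normalize each row of the attention matrix with a softmax so it acts like a set of weights.

    import numpy as np

    def attention(query, key, value):
        # Step 1: multiply query and key to build the attention matrix.
        scores = query @ key.T / np.sqrt(key.shape[-1])                 # [n_q, n_k]
        weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
        # Step 2: use the attention matrix to filter the value.
        return weights @ value                                           # [n_q, d]

    # Toy input: 4 tokens, each a vector of length 6 (the "internal dimension").
    x = np.random.default_rng(0).normal(size=(4, 6))
    print(attention(x, x, x).shape)   # self-attention: Q, K, V all come from the same input -> (4, 6)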

     

    The end goal of a decoder-only transformer is to use the attention mechanisms, within the numerous blocks of the model, to understand the input, and figure out what the next output should be. By being able to guess the next output well, language models can construct an output by guessing one word at a time.

    A decoder-only language model generating output one word at a time. From  my article on GPT
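    As a toy illustration of that one-word-at-a-time loop, the sketch below uses a canned stand-in for a real language model; the structure of the loop (predict, append, repeat) is the part that matters.

    # A stand-in "language model" that returns a canned continuation; a real
    # decoder-only transformer would score every word in its vocabulary instead.
    canned = ["am", "a", "language", "model", "<end>"]

    def predict_next_word(words_so_far):
        return canned[min(len(words_so_far) - 1, len(canned) - 1)]

    words = ["I"]
    while words[-1] != "<end>":
        words.append(predict_next_word(words))   # guess the next word and feed it back in
    print(" ".join(words[:-1]))                   # "I am a language model"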


    Flamingo in a Nutshell

    Now that we understand the essence of CLIP style image encoding (which turns an image into a vector that conveys the image's general meaning) and transformer style language models (which use attention to iteratively output words), we can start digging into Flamingo.

     

    At its highest level, Flamingo consists of four key components:

     

    1. A Vision Encoder, which re-represents images into their general meaning.

     

    2. A Perceiver Resampler, a system that combines the information from a variable number of images into a fixed number of features (allowing the model to understand things like video or a series of images)

     

    3. A Language Model, a pre-trained decoder-style transformer like GPT-3 or LLaMA. The Flamingo paper used Chinchilla.

     

    4. Gated Cross Attention, allowing Flamingo to slowly learn to inject image information into the language model throughout the training process.

    The Flamingo Architecture. We’ll explore this diagram, component by component, in future sections.  Source

    Flamingo uses these systems to understand an arbitrary input sequence of images and text to generate textual output. Let’s break down each of these components, one by one, to build a complete understanding of how Flamingo functions.


    The Vision Encoder

    Flamingo uses a CLIP style image encoder (the one we previously discussed) to encode images. This is a common strategy used in numerous multimodal architectures that have to do with images.

     

    The idea is, instead of Flamingo needing to learn about images from scratch, it can employ the high quality summarizations from a pre-trained CLIP image encoder. Thus, Flamingo offloads a lot of the work of understanding images, and instead only has to reason about image distillations.

    A conceptual diagram of how CLIP summarizes an image (which contains a lot of difficult to interpret information) and distills it into an abstract, information dense representation which is easy for language models to interpret. Image courtesy of my trusty writing companion.

    Flamingo doesn’t actually use CLIP, but the NFNet F6 model. For our purposes, the only conceptual difference is that NFNet produces summarizations about sub-regions of an image rather than the whole image. This makes it easier for Flamingo to understand subtleties within the image. NFNet also does a lot of other cool stuff (I might cover it in a future article), but for our purposes this is more of a line item. Conceptually, NFNet is like a fancy version of CLIP.

    A conceptual diagram of NFNet in action, which summarizes regions of an image to allow for more nuanced information extraction.


    The Perceiver Resampler

    In wanting to create a flexible and robust multimodal model, the authors of Flamingo created a system that was good at handling both images and video. Video data is a difficult type of data to do machine learning on; there’s a lot of information in even small video files, and extracting the right information efficiently can be computationally expensive and difficult.

    An example of flamingo inferring information from a series of images. When handling video, Flamingo simply treats the video as a series of images at one frame per second. video clips  Source .

    Flamingo addresses the problems of video with the “Perceiver Resampler”. The perceiver resampler can be thought of as a summarization system that compresses an arbitrarily long video down into a fixed set of descriptive tokens. It’s not conceptually difficult, but there are a lot of moving parts. Let’s look at it from a high level, then at a more nuts-and-bolts lower level.


    The Perceiver Resampler - High Level

    Conceptually, you can think of the perceiver resampler as a filter; it takes in a fixed set of predefined tokens and uses input images extracted from video to filter those tokens. Regardless of the number of images in the input, the same fixed number of tokens comes out of the output.

    A conceptual diagram of a perceiver resampler in action. The perceiver resampler uses images to filter a fixed number of tokens. Image  Source .

    From a high level, the perceiver resampler fits into the greater Flamingo architecture in the following way:

     

    1. The images are extracted from the prompt. In their place, an <image> token is placed in the text so that the model knows where each image appeared.

     

    2. The output from the perceiver resampler is used to incrementally filter the internal state of the LLM throughout various layers, ultimately allowing the LLM to converse about the images.

    A conceptual diagram of how Flamingo interweaves the content from the perceiver resampler into a language model. The LLM Blocks are from a pre-trained, existing LLM. Flamingo injects “cross attention” blocks, which introduce information about the content of the image through successive layers of the LLM. video clips  Source .

    We’ll cover how cross-attention interweaves the image representation into the LLM in later sections. For now, let’s zoom into the perceiver resampler and see how it works.


    The Perceiver Resampler - Nuts and Bolts

    To understand the perceiver resampler in more detail, let’s work through its components step-by-step.

    The Perceiver Resampler.  Source .

     

    First, the input image, or sequence of images, is passed through a vision encoder. This summarizes the content of the image in a way that’s easy to interpret for ML systems. This is our NFNet image encoder we discussed in a previous section.

    The perceiver resampler summarizing images with our image encoder.  Source .

     

    The attention mechanism (which the image encodings are ultimately fed into) tends to shuffle inputs around and thus lose track of where a particular piece of information was in an input sequence. As a result, it’s customary to add a time encoding, which embeds the time of an input into the value of the input itself.

    The usage of the vision encoder within the perceiver resampler.  Source .

     

    Flamingo uses a learned time vector for each frame in the input. During training, Flamingo has spots for 8 input frames total. These are added to the features extracted from the vision encoder.

    A conceptual diagram of adding time encodings to the image encoding. There is a vector for each time slot; t=0, t=1, and so on. All the vectors from an image encoding get a corresponding time vector added to them (so, at t=0, all encoding vectors have the t=0 time vector added to them); thus the output vectors contain information about the image, and the time of the image within a sequence. This image only contains a sequence of three images, but eight are supported in Flamingo.

     

    I was a bit surprised by this. I thought the whole point of Flamingo was to be generalizable; limiting video input to eight frames seemed like kind of a silly design choice. Apparently, though, the model is robust to interpolating between time embeddings to fit in more frames as necessary. So, if you want to add more frames, just make new time tokens by interpolating between the eight trained ones.

     

    Although our model was trained with a fixed number of 8 frames, at inference time, we input 30 frames at 3 FPS. This is achieved by linearly interpolating the learnt temporal position embedding of the Perceiver Resampler at inference time. — The Flamingo Paper.
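    As a rough sketch of what that looks like (my own reconstruction, not DeepMind’s code, which is proprietary): one learned vector per trained frame slot is added to every spatial feature of its frame, and linear interpolation stretches those eight vectors to cover more frames at inference time.

    import numpy as np

    d, trained_frames = 6, 8
    rng = np.random.default_rng(0)
    time_embeddings = rng.normal(size=(trained_frames, d))   # one learned vector per training frame slot

    def temporal_embeddings(num_frames):
        # Linearly interpolate the 8 learned vectors to cover any number of frames.
        positions = np.linspace(0, trained_frames - 1, num_frames)
        out = np.empty((num_frames, d))
        for j in range(d):
            out[:, j] = np.interp(positions, np.arange(trained_frames), time_embeddings[:, j])
        return out

    frames = rng.normal(size=(30, 5, d))                      # 30 frames, 5 spatial grids each, d features
    frames = frames + temporal_embeddings(30)[:, None, :]     # every grid in a frame gets that frame's time vector
    print(frames.shape)                                        # (30, 5, 6)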

     

    Another quick note: you might think “Hey, if we need to add information about time, then why not location? If attention mixes up our input, wouldn’t it be useful to say something like ‘this information came from the top right of an image’?” The answer is absolutely, but we don’t have to add positional information because our vision encoder already does:

     

    Note that we only use temporal encodings and no explicit spatial grid position encodings; we did not observe improvements from the latter. The rationale behind this is likely that CNNs, such as our NFNet encoder, are known to implicitly include spatial information — The Flamingo Paper.

     

    Now that we have extracted features from each image, and we’ve added all necessary information about time, we can use attention to filter out the right information from the image via the learned tokens.

    The attention mechanism, which uses learned tokens (learned latent queries) to extract the right information from the encoded images.  Source .

     

    Let’s follow some data through the attention mechanism to get a thorough idea of how it functions.

    Step 1) Flattening

    The features extracted from the images are of shape [T, S, d], where T is the number of images, S is the number of spatial grids, and d is the length of the feature vectors. In most machine learning contexts, “tokens” are vectors of some length. d is sometimes referred to as the “internal dimension”, as it’s the size of the token vectors within the model (note: in this article d is depicted as length 6, but in reality the internal dimension of modern models is very large, on the order of hundreds or thousands, so these vectors are usually much, much longer than depicted).

     

    Before passing through attention, these tokens get flattened along the space and time dimensions, so T and S become a single dimension of length T * S, resulting in a two-dimensional matrix of shape [T * S, d].

    The flattening process. In reality, our image data is a 3D tensor of shape [time, space, d]. This data gets flattened across time and space.
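    A minimal sketch of that flatten, using the same toy sizes as the diagrams (d = 6):

    import numpy as np

    T, S, d = 8, 5, 6                                  # frames, spatial grids per frame, feature length
    features = np.arange(T * S * d, dtype=float).reshape(T, S, d)   # output of the vision encoder

    flattened = features.reshape(T * S, d)             # fold time and space into one axis of length T*S
    print(flattened.shape)                             # (40, 6)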

    Keep in mind, while this might appear like we’re shuffling the image data around:

     

    1. This operation is done consistently across successive runs, so while flattening mixes up the order of space and time, it does so in the same way for any given input.

     

    2. We don’t really need the order to be preserved. Recall that spatial information is automatically encoded by the image encoder, and time information was explicitly added to each token. As long as things are done consistently, the order of the inputs doesn’t really matter.

    Step 2) Creating the Key and Value

    Now that the information from our images is properly processed, it’s time to pass them into the attention mechanism. Recall that, from a high level, the whole idea of the perceiver resampler is to use images to filter a fixed number of tokens.

    A conceptual diagram of a perceiver resampler in action. The perceiver resampler uses images to filter a fixed number of tokens, thus encoding an arbitrary sequence of images into a fixed size. Image  Source .

     

    This “filtering” is done inside the attention mechanism, which is essentially just matrix multiplication.

    Recall that an [I,m] shaped matrix can be multiplied by a [m,n] shaped matrix to result in an [I,n] shaped matrix. This is because, in matrix multiplication, the rows of the first matrix are multiplied by the columns in the second. By controlling the shape of the “A” and “B” matrices, one can control the shape of the “C” matrix.  Source .

     

    The whole idea of this attention mechanism is that a fixed number of tokens are extracted from a variable sequence of images, so it makes sense that some clever matrix manipulation would be required to get everything working right. That’s why there are a bunch of arrows before the attention mechanism in the perceiver resampler.

    Before the attention mechanism, data gets moved around so that the right values get put in the right spots. This is done for a variety of reasons, but chiefly to make the matrix multiplication work out well.  Source .

     

    For Flamingo, the flattened features from the image, labeled X_f, are concatenated with a set of learned tokens X of shape [R, d] to construct the “key” and “value” inputs. In this case, R represents an arbitrary, fixed number of tokens. The “query” is simply the learned tokens X.

     

    Recall that d is the internal dimension of the model.

    Some shuffling around of data happens before the attention mechanism constructs the query, key, and value.

     

    One question you might be wondering is “Why are the learned tokens appended to the image information?” I think this is so the perceiver resampler can more intimately control the information that comes out of the attention mechanism. Keep that in mind as we go through the next section.

    Step 3) Running Through the Attention Mechanism

    Now that we have our Query, Key, and Value defined, we can run them through attention. Attention is just two matrix multiplication operations. First, the Query and Key are multiplied together to construct the attention matrix.

    Then the attention matrix is multiplied by the value to construct the final output.

    Et voilà, the attention mechanism in the perceiver resampler has now extracted information from all images into a fixed number of output tokens of shape [R, d], where R is the arbitrary, fixed number of learned tokens and d is the, also arbitrary, internal model dimension.

    A conceptual diagram of a perceiver resampler in action. The perceiver resampler uses images to filter a fixed number of tokens, thus encoding an arbitrary sequence of images into a fixed size. Image  Source .
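    To make the shape arithmetic explicit, here’s a bare-bones sketch of those two multiplications (ignoring the softmax, scaling, and learned projections that the real mechanism uses):

    import numpy as np

    R, d, TS = 3, 6, 40                           # learned tokens, internal dimension, flattened image tokens
    rng = np.random.default_rng(0)
    query = rng.normal(size=(R, d))               # the learned tokens X
    key = value = rng.normal(size=(TS + R, d))    # X_f concatenated with X

    attention_matrix = query @ key.T              # [R, TS + R]
    output = attention_matrix @ value             # [R, d]: a fixed number of tokens, however many images came in
    print(attention_matrix.shape, output.shape)   # (3, 46) (3, 6)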

    Step 4) Constructing the Final Output

    After the attention mechanism, a “skip” connection is applied. Basically, the attention mechanism is great at extracting subtle features, but it mixes everything around like crazy, so it’s useful to let some of the older, simpler structure from before the attention mechanism remain present. Thus, the learned tokens from before the attention mechanism are added to the output of the attention mechanism.

    After the attention mechanism, a skip connection is applied. That’s that “+” after the attention mechanism.  Source .

     

    The output of the skip connection is passed through a feed forward network, which is your classic prototypical neural network. Another skip connection is applied for the same reason as before.

     

    Generally, with machine learning systems, it helps to do the same thing in a few passes so that complex operations can be done incrementally over numerous layers. In the perceiver resampler, this information extraction using attention, skip, and feedforward is done a few times. The output tokens from one iteration are fed right back into the input of another; hence the X num_layers in the diagram above.
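    Putting the pieces together, here’s a rough sketch of that repeated attention / skip / feed-forward structure (a simplification of my own: real Flamingo also uses multi-head attention, layer norms, and learned projections):

    import numpy as np

    R, d, num_layers = 3, 6, 2
    rng = np.random.default_rng(0)
    x_f = rng.normal(size=(40, d))            # flattened, time-encoded image features
    x = rng.normal(size=(R, d))               # learned latent tokens

    def attention(q, k, v):
        scores = q @ k.T / np.sqrt(k.shape[-1])
        weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
        return weights @ v

    def feed_forward(h):                       # stand-in for the learned feed-forward network
        return np.tanh(h)

    for _ in range(num_layers):
        kv = np.concatenate([x_f, x], axis=0)  # key/value: image features plus the latent tokens
        x = x + attention(x, kv, kv)           # attention, then a skip connection
        x = x + feed_forward(x)                # feed-forward, then another skip connection
    print(x.shape)                              # (R, d), regardless of how many frames went in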

     

    Alright, so that’s the perceiver resampler. Using a set of learned tokens (which are just vectors of numbers that can change throughout the training process), multiple stacks of attention allow the perceiver resampler to extract information from a sequence of input images. In the next section we’ll discuss how that information is fed into a language model so that the language model can converse about images.


    Combining Visual and Textual Information With Gated Cross Attention

    Kudos for hanging in there, hopefully you’re finding this as fascinating as I am. It’s been about six months since you started reading this article; recall that the whole point is to get a pre-trained language model to be able to understand and converse about images.

    The high level flamingo architecture, allowing language models to converse about interwoven images and text. Notice how gated cross attention is interwoven between pre-trained language modeling (LM) blocks.  Source .

     

    Now that we’ve used the perceiver resampler to extract a fixed amount of information from our images (or image sequences), we need to feed that information into the language model. To do that, we’ll use gated cross attention.

     

    The idea of gated cross attention is to incrementally inject visual information throughout our pre-trained language model. At every level (or every few levels), gated cross attention tries to introduce the right visual information so that the model can perform best.

    Zooming into the gated cross attention mechanism within a single flamingo layer.  Source .

     

    The visual data from the perceiver resampler is used to construct the “key” and “value”, and tokens from within the language model are used to generate the “query”. By feeding both image and language data into cross attention, both language information and image information are used to create an abstract representation of both modalities (We’ll zoom into the specifics of cross attention in the next section).

     

    The language model, if it were a person, would be freaking out at this point. The language model has learned to carefully represent language throughout successive highly abstract and subtle representations. Now we’re throwing a big old wrench in that careful equation.

     

    To mitigate this issue, Flamingo employs “tanh gating”. At every point Flamingo injects cross attention, it employs a tanh gate with a learnable value alpha. This alpha value, which is used as the input to the tanh function, is initially set to 0, and tanh(alpha=0) is 0.

    The tanh function, which approaches 1 for large positive inputs and -1 for large negative inputs.  Source .

    The result of the gating function is multiplied by the cross attention output, effectively acting like a knob that dials in how much visual information is introduced to the model. Through training, the model adjusts alpha, and thus learns to slowly dial in more or less visual information from the cross attention mechanism.

     

    After the tanh gate, there’s a skip connection denoted as “+”. Recall that skip connections just take data from before some operation, and add it to the result of an operation.

    Notice the skip connections and tanh gating being used after cross attention (and for that matter, also a new learnable feed forward network).  Source .

     

    So, between tanh gating and the skip connection, the cross attention mechanism that introduces visual information looks something like this:

    result = (tanh(alpha) * cross_attention(vis_info, lang_info)) + lang_info

    Because alpha is initialized to 0 at the beginning of training, the value of the tanh gate is also 0. Thus, the cross attention mechanism contributes no information at the beginning of training, and the final output is the same as if we just used the language model without Flamingo at all.

    result = (0 * cross_attention(vis_info, lang_info)) + lang_info = just lang_info

    If we didn’t do this, we would be injecting image information straight away at the beginning of training, and would heavily confuse the language model with a bunch of bad information, as a brand new perceiver resampler has no idea how to represent image data in a way that’s useful for the language model. Thus, tanh gating allows Flamingo to gently introduce image information into the model.

     

    After the gated cross attention there’s a gated feed forward layer, which is gated for the same reason. You can think of this as a neural network that transforms the combined language and image data to be copacetic with future layers of the language model.
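    Putting the gate, the cross attention, the feed-forward layer, and the skip connections together, a single gated block looks roughly like the sketch below. This is a simplified reconstruction (real Flamingo uses learned key/query/value projections and multi-head attention), but it shows the key property: with the alphas at their initial value of 0, the block passes the language model’s activations through untouched.

    import numpy as np

    d = 6
    rng = np.random.default_rng(0)
    alpha_attn, alpha_ffw = 0.0, 0.0              # learnable gate values, initialized to 0

    def cross_attention(lang_tokens, vis_tokens):
        # Queries come from the language model; keys and values come from the visual tokens.
        scores = lang_tokens @ vis_tokens.T / np.sqrt(d)
        weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
        return weights @ vis_tokens

    def feed_forward(h):                           # stand-in for the learned feed-forward network
        return np.tanh(h)

    def gated_xattn_block(lang_tokens, vis_tokens):
        lang_tokens = lang_tokens + np.tanh(alpha_attn) * cross_attention(lang_tokens, vis_tokens)
        lang_tokens = lang_tokens + np.tanh(alpha_ffw) * feed_forward(lang_tokens)
        return lang_tokens

    lang = rng.normal(size=(4, d))                 # token representations inside the LLM
    vis = rng.normal(size=(3, d))                  # output of the perceiver resampler
    print(np.allclose(gated_xattn_block(lang, vis), lang))   # True: the gates start closed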

     

    So, from an information flow perspective, that’s it! You now understand how Flamingo models are able to converse about images by:

     

    1. extracting images from the input prompt

     

    2. using the perceiver resampler to reformat images or image sequences into a fixed number of tokens

     

    3. using cross attention to inject image information into the language model, so the language model can converse about both the input text and images.

     

    There are a few other details that Flamingo employs to make the flow of information actually work. Now that we have a complete picture of Flamingo from a high level, we can zoom into some specific details.


    An Intro To Masking 

    One of the things that makes Flamingo unique is the way it handles masking. For those unfamiliar with the general topic, masking is an important component of attention that allows researchers to hide certain information in certain parts of a machine learning model.

     

    You might be wondering why it would be useful to hide information from a model. Wouldn’t a machine learning model generally be better if it had access to all information? Generally that’s the case, but masking is a frequent exception to that general rule.

     

    Language models, for instance, use a “causal mask”. The idea of language models is that they predict the next word in an input sequence, but when they’re trained, language models like GPT actually predict every next word in a sequence.

    A demonstration of a language model predicting every next word in an output sequence during training. When it’s wrong, the parameters are updated to train the model. From my article on  speculative sampling , which goes a bit more in depth on this topic.

     

    If a language model could just look at what the next word in a sequence should be, this form of training would be trivial; the model would just copy the next word at a given output. By using causal masking, however, the model is unable to see future words when predicting a given word.

    A conceptual diagram of what the causal mask does for a model when training a language model. It hides future words from a model, meaning the model has to actually predict the next word in a sequence. From my article on  speculative sampling , which goes a bit more in depth on this topic.

     

    The causal mask itself exists within the self attention mechanism of the language model. Recall that modern language models are, essentially, a big stack of blocks of self attention. The causal mask gets applied to the attention mechanism via the actual attention matrix. Recall that the “query” and “key” inputs get multiplied together to construct the “attention matrix”, and the “attention matrix” is multiplied by the “value” to create the final output.

    A conceptual diagram of the query and key being multiplied to create the attention matrix. Notice how there are specific values in the attention matrix where one word in the input sequence corresponds to another word in the input sequence.

     

    This attention matrix, being the result of a matrix multiplication, is a complex combination of every input in the “query” with every input in the “key”. The idea of causal masking is that when a “query” interacts with a “key” that corresponds to a word in the future, the causal mask sets the value of the attention matrix at that location to zero, thus hiding from the model any information about future words.

    A conceptual diagram of our causal mask applied to attention. Notice how certain blocks in the attention matrix are black; these denote values that were set to zero via the causal mask. This general shape, of a diagonal matrix where the upper right region is masked, is the prototypical “causal self attention mask”.

     

    Keep in mind, this is to facilitate making the training objective of predicting every next word non-trivial. If every attention block in a language model uses a causal mask, information from future words can never bleed into the information of previous words.
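    A small sketch of that causal mask in action. One hedged detail: the article describes masked entries as being set to zero, while implementations typically set them to negative infinity before the softmax, which drives the corresponding attention weights to zero; the effect is the same.

    import numpy as np

    n = 4
    scores = np.random.default_rng(0).normal(size=(n, n))   # raw query-key attention matrix

    mask = np.tril(np.ones((n, n), dtype=bool))              # position i may only look at positions <= i
    masked_scores = np.where(mask, scores, -np.inf)          # hide future words before the softmax
    weights = np.exp(masked_scores) / np.exp(masked_scores).sum(axis=-1, keepdims=True)

    print(np.triu(weights, k=1).max())                        # 0.0: no word attends to a future word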

     

    Causal masking is relevant to Flamingo in that the language model Flamingo uses was trained with a causal mask. However, Flamingo itself employs a slightly different masking strategy for media.


    Arbitrary Cross Attention of Media Via Masking

    The idea of Flamingo is that you can arbitrarily converse with the model about images and text. You might have a lot of images, you might have very few, and images might appear at all sorts of locations within a conversation. Machine learning models like to learn patterns, and they get confused when they encounter examples that don’t respect patterns they’ve learned. How can we get Flamingo to be good at understanding the diversity of conversations it might have about images?

    A few examples of how text and images might be interleaved in a conversation with Flamingo.  Source .

     

    The Flamingo paper addresses this problem by using a different masking system. Recall from the previous section that Flamingo employs cross attention to inject image data from the perceiver resampler into the language model.

    Recall that Flamingo uses cross attention to inject image data into the language model.  Source .

     

    Within the cross attention mechanism, each word in the input sequence is only allowed to attend to the immediately preceding image (or image sequence).

    A diagram of how certain images are allowed to attend to certain pieces of text via masking. The dark blue color represents where a particular image can attend to particular text.  Source .

     

    Figure 7: Interleaved visual data and text support. Given text interleaved with images/videos, e.g. coming from a webpage, we first process the text by inserting <image> tags at the locations of the visual data in the text as well as special tokens (<BOS> for “beginning of sequence” or <EOC> for “end of chunk”). Images are processed independently by the Vision Encoder and Perceiver Resampler to extract visual tokens. At a given text token, the model only cross-attends to the visual tokens corresponding to the last preceding image/video. 𝜑 indicates which image/video a text token can attend or 0 when no image/video is preceding. In practice, this selective cross-attention is achieved through masking — illustrated here with the dark blue entries (unmasked/visible) and light blue entries (masked).

     — The Flamingo Paper.

     

    Exactly how this is done is left as a bit of a mystery; the Flamingo paper inherits an unfortunate trend in ML research: transitioning from public academia to private corporate incentives.

     

    The code and the data are proprietary. — The Flamingo Paper.

     

    Flamingo was written by a bunch of PhD Googlers at DeepMind; it can’t be too hard to guess what they did. Jokes aside, we know a lot about how Flamingo handles attention, so we can probably approximate what was done behind the scenes.

     

    Recall that, with our causal self attention mask in the language model, we mask in a way that only exposes a word to itself and previous words.

    We can apply the same general logic to imagine what our media mask might look like in Flamingo. Based on how they define the query and the key, the actual attention calculation would look something like this:

    A conceptual diagram of the calculation of cross attention within Flamingo. Two important notes: 1) Recall that the perceiver resampler defines a fixed number of tokens for a particular image or image sequence. Here I’m arbitrarily defining that as 3, but this is a hyperparameter. 2) Each image could be a single image or a sequence of images.

     

    Then, presumably, the mask that Flamingo employs would simply mask out any relationships that didn’t obey the “immediately preceding image” rule.

    A conceptual diagram with Flamingo’s media mask defined over cross attention.

     

    The Flamingo paper talks a lot about “Φ” (phi). In Flamingo, phi is simply a function that programmatically keeps track of which text belongs with which image, so that this mask can be correctly defined for a given input. Naturally, with different input sequences, Flamingo’s cross attention mask would change.
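    Here’s a hedged reconstruction of how a phi-like bookkeeping function could be turned into that cross-attention mask (again, my own sketch, since the actual code is proprietary). The toy prompt has two text tokens following image 1 and two following image 2, and the resampler is assumed to emit 3 tokens per image.

    import numpy as np

    # phi: for each text token, the index of the most recent preceding image (0 = none yet).
    phi = np.array([1, 1, 2, 2])

    num_images, tokens_per_image = 2, 3
    # Which image each visual token came from: [1, 1, 1, 2, 2, 2]
    visual_token_image = np.repeat(np.arange(1, num_images + 1), tokens_per_image)

    # mask[i, j] is True when text token i may cross-attend to visual token j.
    mask = phi[:, None] == visual_token_image[None, :]
    print(mask.astype(int))
    # [[1 1 1 0 0 0]
    #  [1 1 1 0 0 0]
    #  [0 0 0 1 1 1]
    #  [0 0 0 1 1 1]]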

     

    The whole reason Flamingo does this is to improve Flamingo's ability to generally understand interleaved text and images. There are a few theoretical ideas behind this:

     

    1. By restricting the text to attend only to the immediately preceding image, the model ensures that the textual analysis is directly relevant to the most recent visual context. This helps maintain a tight coupling between what is seen in the image and what is being described or asked in the text.

     

    2. Allowing text to attend to all previous images would significantly increase the computational complexity and memory requirements of the model.

     

    3. Focusing on the immediate image-text pair helps the model better learn the relationships between visual content and textual information. This can improve the model’s ability to generalize from seen examples to unseen examples.

     

    4. By narrowing the scope of attention, the model simplifies the learning task, which can help in reducing the risk of overfitting to irrelevant details or patterns. This focus can lead to a model that is better at understanding and responding to the specific nuances of the image-text interaction.


    Conclusion

    And that’s it.

     

    In this article we defined multimodal modeling and focused in on visual question answering. We described the previous state of the art, where a bunch of disparate models were good at a bunch of disparate tasks. We then discussed how Flamingo joined two existing technologies, CLIP models and transformer decoder models, combining two highly performant and generalized modeling strategies.

     

    We covered how Flamingo is, essentially, a set of tools designed to bridge CLIP and decoder models. First, we covered how the perceiver resampler takes CLIP-style embeddings for image sequences and compresses them down to a fixed-size representation. Then, we discussed how Flamingo employs tanh-gated cross attention to inject that information into the language model. Finally, we went over some specifics, like Flamingo’s masking strategy, to get an idea of some of the important subtleties Flamingo employs to maintain cutting-edge performance.


     
