
Autoregressive Decoding Accelerated by up to 64x: Google Proposes MaskGIT, a New Model for Image Synthesis

Report from Machine Heart

Machine Heart Editorial Department

Researchers from Google Research have proposed MaskGIT, a new image synthesis model that uses a bidirectional transformer decoder and delivers dramatic improvements in both quality and speed.

Generative transformers are rapidly gaining popularity for synthesizing high-fidelity, high-resolution images. But the best generative transformer models to date still treat an image as a sequence of tokens and decode it in raster-scan order (i.e., line by line). This strategy is neither optimal nor efficient.

Recently, researchers from Google Research proposed MaskGIT, a new image synthesis model that uses a bidirectional transformer decoder. During training, MaskGIT learns to predict randomly masked tokens by attending to tokens in all directions. At inference time, the model begins by generating all tokens of the image simultaneously, then refines the image iteratively, conditioning on the previous generation. Experiments show that MaskGIT significantly outperforms the SOTA transformer model on the ImageNet dataset and accelerates autoregressive decoding by up to 64x.

Paper address: https://arxiv.org/abs/2202.04200

In addition, the study shows that MaskGIT can be easily extended to a variety of image editing tasks, such as inpainting, extrapolation, and image manipulation.

Related research

The earlier VQVAE model proposed generating images in two stages in a latent space.

The first stage, called tokenization, attempts to compress the image into a discrete latent space and consists of three main parts:

An encoder E that learns to tokenize an image x ∈ R^{H×W×3} into latent embeddings E(x);

A codebook e for nearest-neighbor lookup, which quantizes the embeddings into visual tokens;

A decoder G that reconstructs the image from the visual tokens.
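To make the codebook lookup concrete, here is a minimal NumPy sketch of nearest-neighbor quantization; the shapes, the function name quantize, and the toy sizes are illustrative assumptions, not values taken from the paper.

```python
import numpy as np

def quantize(z_e: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Nearest-neighbor lookup: map each embedding to a codebook index.

    z_e:      (N, D) encoder embeddings E(x), flattened over spatial positions.
    codebook: (K, D) learned codebook entries.
    Returns an (N,) array of integer visual tokens.
    """
    # Squared Euclidean distance from every embedding to every codebook entry.
    dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

# Toy usage: quantize 16 random embeddings against a 1024-entry codebook.
rng = np.random.default_rng(0)
tokens = quantize(rng.normal(size=(16, 64)), rng.normal(size=(1024, 64)))
print(tokens.shape)  # -> (16,)
```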

The second stage first uses a deep autoregressive model to predict a latent prior over the visual tokens, and then uses the first-stage decoder to map the token sequence back to image pixels.

This two-stage paradigm is very effective, so several widely used methods follow it, such as DALL-E and VQGAN. Among them, VQGAN adds adversarial and perceptual losses in the first stage to improve image fidelity.

MaskGIT

The methods above, while following the two-stage paradigm, still rely on an autoregressive model, so the decoding time of the second stage grows in proportion to the length of the token sequence. The goal of this study is to design a new image synthesis paradigm that uses parallel decoding and bidirectional generation: it follows the two-stage scheme above and improves the second stage. The first stage adopts the same setup as the VQGAN model, leaving potential improvements to the tokenization step for future work; for the second stage, the researchers propose to learn a bidirectional transformer via Masked Visual Token Modeling (MVTM).


MVTM in training

Let Y = [y_i] (i = 1, …, N) denote the latent tokens obtained by feeding an image into the VQ encoder, where N is the length of the reshaped token matrix, and let M = [m_i] (i = 1, …, N) denote the corresponding binary mask. During training, the study samples a subset of tokens and replaces them with a special [MASK] token: if m_i = 1, the token y_i is replaced with [MASK]; if m_i = 0, y_i is kept.

The sampling process is parameterized by a mask scheduling function γ(r) ∈ (0, 1] and proceeds as follows: first sample a ratio r from 0 to 1, then uniformly select ⌈γ(r)·N⌉ tokens in Y to place masks, where N is the length. Mask scheduling significantly affects the quality of image generation.
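As a concrete illustration, here is a minimal sketch of this training-time sampling, using the cosine schedule γ(r) = cos(πr/2) that the paper adopts as its default; the token values, mask_id, and helper names are assumptions made for the example.

```python
import math
import numpy as np

def cosine_schedule(r: float) -> float:
    """The paper's default mask scheduling function: gamma(r) = cos(pi * r / 2)."""
    return math.cos(math.pi * r / 2)

def mask_for_training(tokens: np.ndarray, mask_id: int, rng) -> np.ndarray:
    """Sample r ~ U[0, 1), then replace ceil(gamma(r) * N) uniformly chosen
    tokens with the [MASK] token (the positions where m_i = 1)."""
    n = len(tokens)
    r = rng.random()
    num_masked = math.ceil(cosine_schedule(r) * n)
    positions = rng.choice(n, size=num_masked, replace=False)
    masked = tokens.copy()
    masked[positions] = mask_id
    return masked

# Toy usage: mask a sequence of 256 tokens; index 1024 stands in for [MASK].
rng = np.random.default_rng(0)
y = rng.integers(0, 1024, size=256)
print((mask_for_training(y, mask_id=1024, rng=rng) == 1024).sum())
```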

Iterative decoding

In autoregressive decoding, tokens are generated sequentially, each conditioned on previously generated output. This process cannot be parallelized, and because the token length of an image is typically much greater than that of text, decoding is very slow. The study proposes a new decoding method in which all tokens of the image are generated simultaneously in parallel, which is feasible thanks to the bidirectional self-attention of MVTM.

In theory, the model could infer all tokens and generate the entire image in a single pass, but this is challenging because it is inconsistent with the training task. To generate an image at inference time, the study starts from a blank canvas with all tokens masked out. The iterative decoding method proposed in this study runs the following steps at each iteration t:

1. Predict. Given the current masked tokens, the model predicts, in parallel, probabilities over the codebook for every masked position.

2. Sample. At each masked position, a token is sampled from the predicted distribution, and its prediction probability is used as a confidence score; positions decoded in earlier iterations keep their tokens.

3. Mask schedule. The number of tokens to re-mask is computed from the mask scheduling function as n = ⌈γ(t/T)·N⌉, where T is the total number of iterations.

4. Mask. The n tokens with the lowest confidence are masked again, while the rest remain fixed for the next iteration.
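Putting the four steps together, here is a minimal sketch of the decoding loop, assuming a model callable that returns per-position probability distributions over the vocabulary; the function names, the iteration indexing, and the final-step handling are illustrative assumptions rather than the paper's implementation.

```python
import math
import numpy as np

def cosine_schedule(r: float) -> float:
    # Same default schedule as in training: gamma(r) = cos(pi * r / 2).
    return math.cos(math.pi * r / 2)

def decode(model, num_tokens: int, mask_id: int, T: int, rng) -> np.ndarray:
    """Iterative parallel decoding: start all-[MASK], refine over T steps."""
    tokens = np.full(num_tokens, mask_id)           # blank canvas
    for t in range(T):
        probs = model(tokens)                       # 1. Predict: (N, K) probabilities
        sampled = np.array([rng.choice(len(p), p=p) for p in probs])
        scores = probs[np.arange(num_tokens), sampled]  # 2. Sample + confidence
        keep = tokens != mask_id                    # already-decoded tokens stay fixed
        sampled[keep], scores[keep] = tokens[keep], np.inf
        tokens = sampled
        if t + 1 < T:                               # 3. Schedule: how many to re-mask
            n_mask = math.ceil(cosine_schedule((t + 1) / T) * num_tokens)
            tokens[np.argsort(scores)[:n_mask]] = mask_id  # 4. Mask least confident
    return tokens

# Toy usage: a uniform "model" over a 10-token vocabulary; index 10 is [MASK].
rng = np.random.default_rng(0)
uniform_model = lambda toks: np.full((len(toks), 10), 0.1)
print(decode(uniform_model, num_tokens=16, mask_id=10, T=8, rng=rng))
```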

Mask design

The researchers found that mask design significantly affects the quality of image generation. The method models the masking process with a mask scheduling function γ(r), which computes the mask ratio for the given latent tokens. During inference, the input r ∈ [0, 1] represents the progress of decoding; during training, the study randomly samples a ratio r in [0, 1) to simulate various decoding scenarios.
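For intuition, a quick sketch comparing a few candidate schedules: γ should decrease from near 1 (everything masked at the start) toward 0 (nothing masked at the end). Cosine is the paper's default, while the other forms here are illustrative stand-ins for the function families compared in the ablation.

```python
import numpy as np

# Candidate mask scheduling functions gamma(r): each maps decoding progress
# r in [0, 1] to a mask ratio, decreasing from ~1 down to 0.
schedules = {
    "linear": lambda r: 1.0 - r,
    "cosine": lambda r: np.cos(np.pi * r / 2),
    "square": lambda r: 1.0 - r ** 2,
}

for name, gamma in schedules.items():
    r = np.linspace(0.0, 1.0, 5)
    print(f"{name:6s}", np.round(gamma(r), 3))  # each row decreases 1 -> 0
```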

Experiments

The study provided an experimental evaluation of MaskGIT's image generation in terms of quality, efficiency, and flexibility.

Class-conditional image synthesis

The study evaluated the MaskGIT model on the class-conditional image synthesis task on ImageNet 256×256 and ImageNet 512×512, with the main results shown in Table 1 below.


Quality. On ImageNet 256×256, without using any special sampling strategy, MaskGIT is significantly superior to VQGAN in both FID and IS.

Speed. The study assesses model speed by the number of steps (forward passes) each model needs to generate a sample. As shown in Table 1, among all non-GAN models, MaskGIT requires the fewest steps at both resolutions.

To further confirm the speed difference between MaskGIT and autoregressive models, the study compared the runtime of the decoding processes of MaskGIT and VQGAN. As shown in Figure 4 below, MaskGIT accelerates decoding over VQGAN by 30-64x, and the speedup becomes more pronounced as the image resolution (and thus the input token length) increases.


Diversity. In addition to sample quality, the study uses Classification Accuracy Score (CAS) and Precision/Recall as two metrics for assessing sample diversity. MaskGIT's samples are more diverse than BigGAN's, with greater variety in lighting, pose, scale, and context, as shown in Figure 5 below.


Image editing applications

The study demonstrates direct applications of MaskGIT to three image editing tasks: class-conditional image editing, image inpainting, and outpainting. All three can be straightforwardly converted into tasks MaskGIT handles by treating each as a constraint on the initial binary mask M used in its iterative decoding.

The study shows that MaskGIT produces very good results on all three applications without modifying the architecture or any task-specific training. Moreover, MaskGIT achieves performance comparable to dedicated models on inpainting and extrapolation.

For class-conditional image editing, the study defines a new task to demonstrate MaskGIT's flexibility: the model regenerates content inside a bounding box for a given class while preserving the context, i.e., the content outside the box. Autoregressive methods are infeasible here because they would violate their prediction order.

For MaskGIT, however, the task is straightforward: the region inside the bounding box is simply treated as the initial mask input to the iterative decoding algorithm. Figure 6 below shows some example results.
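A minimal sketch of how such an initial mask could be built, assuming a flattened 16×16 token map; the grid size, box coordinates, and helper name editing_mask are assumptions made for illustration.

```python
import numpy as np

def editing_mask(grid: int, box) -> np.ndarray:
    """Initial binary mask for editing: True inside the bounding box.

    box = (top, left, bottom, right) in token coordinates (illustrative).
    Masked positions get regenerated by iterative decoding; everything
    outside the box is kept fixed as context.
    """
    top, left, bottom, right = box
    m = np.zeros((grid, grid), dtype=bool)
    m[top:bottom, left:right] = True
    return m.reshape(-1)  # flatten to match the token sequence

# Mask a 6x8 token region of a 16x16 token map: 48 of 256 tokens regenerate.
mask = editing_mask(16, (4, 2, 10, 10))
print(mask.sum(), "of", mask.size)
```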


Table 2 compares quantitative results for several methods. MaskGIT beats DeepFill and HiFill by a significant margin in both FID and IS, while scoring close to CoModGAN, a SOTA inpainting method.


As shown in Figure 7 below, MaskGIT is also able to synthesize diverse results given the same input and different seeds.


Ablation experiments

To verify the utility of the new design, the study conducted ablation experiments with its default settings on ImageNet 256×256. A key design in MaskGIT is the mask scheduling function used for both training and iterative decoding, with results shown in Table 3 and Figure 8 below.


It is worth noting, as Figure 8 shows, that under the same settings more iterations are not necessarily better: as the number of iterations T increases, every function (except the logarithmic one, which performs poorly throughout) reaches a "sweet spot" where the model's performance peaks before deteriorating again.
