DALLĀ·E mini

Generate images from a text prompt in this interactive report: DALLĀ·E on a smaller architecture.
Generated with DALLĀ·E mini from the prompt "logo of an armchair in the shape of an avocado"

Introduction

As part of the FLAX/JAX community week organized by šŸ¤— Hugging Face and the Google Cloud team, we worked on reproducing the results of OpenAI's DALLĀ·E with a smaller architecture. DALLĀ·E can generate new images from any text prompt.
We show that we can achieve impressive results (albeit of lower quality) while being limited to much smaller hardware resources. Our model is 27 times smaller than the original DALLĀ·E and was trained on a single TPU v3-8 for only 3 days.
By simplifying the architecture and reducing memory requirements, as well as leveraging available open-source code and pre-trained models, we were able to meet a tight timeline.
DALLĀ·E mini project timeline
Our code and our interactive demo are available for experimenting with any prompt!
Demo of DALLĀ·E mini

Datasets

We used 3 datasets for our model:
  • Conceptual Captions Dataset, which contains 3 million image/caption pairs.
  • Conceptual 12M, which contains 12 million image/caption pairs.
  • The OpenAI subset of YFCC100M, which contains about 15 million images and which we further sub-sampled to 2 million images due to storage limitations. We used both the title and description as the caption and removed HTML tags, new lines, and extra spaces (see the cleaning sketch below).
For fine-tuning our image encoder, we used only a subset of 2 million images.
We used all the images we had (about 15 million) for training our Seq2Seq model.
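
A minimal sketch of the kind of caption cleaning described above. The exact rules we applied are not reproduced here; this is only an illustration:

```python
import re

def clean_caption(title: str, description: str) -> str:
    """Join title and description, strip HTML tags, new lines and extra spaces."""
    text = f"{title} {description}"
    text = re.sub(r"<[^>]+>", " ", text)  # drop HTML tags
    text = re.sub(r"\s+", " ", text)      # collapse new lines and repeated spaces
    return text.strip()

print(clean_caption("A sunset", "Golden hour over the <b>ocean</b>\non the coast"))
# -> "A sunset Golden hour over the ocean on the coast"
```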

Model Architecture

Overview

During training, images and descriptions are both available and pass through the system as follows (a simplified code sketch follows the diagram):
  • Images are encoded through a VQGAN encoder, which turns images into a sequence of tokens.
  • Descriptions are encoded through a BART encoder.
  • The encoded image tokens are fed through the BART decoder, an auto-regressive model that attends to the BART encoder output and whose goal is to predict the next image token.
  • Loss is the softmax cross-entropy between the model prediction logits and the actual image encodings from the VQGAN.
Training pipeline of DALLĀ·E mini
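
The training step can be summarized with the following minimal sketch. This is an illustration rather than our actual training code: vqgan_encode, bart_encode, bart_decode, shift_right and BOS are hypothetical stand-ins for the corresponding pre-trained modules and helpers.

```python
import jax
import optax

def train_step(params, batch):
    # batch["caption_ids"]: tokenized descriptions, batch["image"]: raw images
    image_tokens = vqgan_encode(batch["image"])  # (B, 256) discrete VQGAN codes

    def loss_fn(params):
        text_hidden = bart_encode(params, batch["caption_ids"])
        # teacher forcing: the decoder sees <BOS> followed by the image tokens shifted right
        decoder_inputs = shift_right(image_tokens, bos_token_id=BOS)
        logits = bart_decode(params, decoder_inputs, text_hidden)  # (B, 256, 16384)
        # softmax cross-entropy between prediction logits and the actual VQGAN encodings
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, image_tokens)
        return loss.mean()

    loss, grads = jax.value_and_grad(loss_fn)(params)
    return loss, grads
```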
At inference time, we only have captions available and want to generate images (sketched in code below the diagram):
  • The caption is encoded through the BART encoder.
  • A <BOS> token (special token identifying the "Beginning Of Sequence") is fed through the BART decoder.
  • Image tokens are sampled sequentially based on the decoder's predicted distribution over the next token.
  • Sequences of image tokens are decoded through the VQGAN decoder.
  • CLIP is used to select the best generated images.
Inference pipeline of DALLĀ·E mini
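
A minimal sketch of the generation loop described above. Again, bart_encode, bart_decode_step, vqgan_decode, clip_score and BOS are hypothetical stand-ins for the actual modules; real sampling is batched and jit-compiled.

```python
import jax
import jax.numpy as jnp

def generate_image_tokens(params, caption_ids, rng):
    text_hidden = bart_encode(params, caption_ids)
    tokens = [BOS]
    for _ in range(256):  # 16 x 16 image tokens
        logits = bart_decode_step(params, jnp.array(tokens), text_hidden)
        rng, key = jax.random.split(rng)
        tokens.append(int(jax.random.categorical(key, logits)))  # sample next token
    return jnp.array(tokens[1:])  # drop <BOS>

def generate_and_rank(params, caption, caption_ids, rng, n_samples=128, top_k=8):
    images, scores = [], []
    for _ in range(n_samples):
        rng, key = jax.random.split(rng)
        image = vqgan_decode(generate_image_tokens(params, caption_ids, key))
        images.append(image)
        scores.append(clip_score(image, caption))  # CLIP score against the description
    best = jnp.argsort(-jnp.array(scores))[:top_k]  # keep the best images per CLIP
    return [images[int(i)] for i in best]
```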

Image Encoder/Decoder

For encoding & decoding images, we use a VQGAN.
The goal of the VQGAN is to encode an image into a sequence of discrete tokens that can be used in transformer models, which have proved to be very efficient in NLP.
Source: Taming Transformers for High-Resolution Image Synthesis
Using a sequence of raw pixel values directly would make the discrete space (and the sequence length) far too large, making it extremely difficult to train a model and to satisfy the memory requirements of self-attention layers.
The VQGAN learns a codebook of image patches using a combination of a perceptual loss and a GAN discriminator loss. The encoder outputs the indices of the closest codebook entries.
Once the image is encoded into a sequence of tokens, it can be used in any transformer model.
In our model, we encode images into 16 x 16 = 256 discrete tokens from a vocabulary of size 16,384, using a reduction factor f=16 (4 blocks, each dividing width & height by 2). Decoded images are then 256 x 256 pixels (16 tokens x 16 pixels per side).
For more details and better understanding of the VQGAN, please refer to Taming Transformers for High-Resolution Image Synthesis.
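
To make the shapes concrete, here is a toy illustration of the vector-quantization step (the embedding dimension and the random values are purely illustrative and not taken from the actual VQGAN):

```python
import jax
import jax.numpy as jnp

vocab_size = 16384   # VQGAN codebook size
grid = 256 // 16     # reduction factor f=16 -> 16 x 16 grid of tokens
embed_dim = 256      # codebook embedding dimension (assumed)

codebook = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, embed_dim))
patches = jax.random.normal(jax.random.PRNGKey(1), (grid * grid, embed_dim))

# nearest codebook entry for each patch embedding (squared Euclidean distance)
distances = ((patches ** 2).sum(-1, keepdims=True)
             + (codebook ** 2).sum(-1)
             - 2 * patches @ codebook.T)
image_tokens = jnp.argmin(distances, axis=-1)  # shape (256,), values in [0, 16384)
print(image_tokens.shape)
```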

Seq2Seq model

A seq2seq model transforms a sequence of tokens into another sequence of tokens and is typically used in NLP for tasks such as translation, summarization or conversational modeling.
The same idea can be transferred to computer vision once images have been encoded into discrete tokens.
Source: BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
Our model uses BART, where the input corresponds to the description and the output is the corresponding image encoded by the VQGAN.
We only had to make a few adjustments to the original architecture (illustrated in the sketch below):
  • Create independent embedding layers for the decoder and encoder (they can often be shared when inputs and outputs are of the same type).
  • Adjust the shape of the decoder inputs and outputs to the VQGAN vocabulary size (not needed for the intermediate embedding layers).
  • Force the generated sequence to 256 tokens (not counting the special tokens <BOS> and <EOS>, which mark the beginning and end of sequences).
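
A minimal flax.linen sketch of these adjustments. The transformer stacks are replaced by single Dense layers for brevity; only the independent embeddings and the output head sized to the image vocabulary reflect the actual changes:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class TextToImageSeq2Seq(nn.Module):
    """Illustrative only: the BART stacks are stubbed out with Dense layers."""
    text_vocab: int = 50264   # BART text vocabulary
    image_vocab: int = 16384  # VQGAN codebook size (plus special tokens in practice)
    d_model: int = 1024

    @nn.compact
    def __call__(self, caption_ids, image_token_inputs):
        # independent embedding layers for the encoder (text) and decoder (image tokens)
        text_embed = nn.Embed(self.text_vocab, self.d_model)(caption_ids)
        image_embed = nn.Embed(self.image_vocab, self.d_model)(image_token_inputs)

        # crude placeholders for the BART encoder / decoder stacks
        encoder_hidden = nn.Dense(self.d_model, name="encoder_stack")(text_embed)
        decoder_hidden = nn.Dense(self.d_model, name="decoder_stack")(
            image_embed + encoder_hidden.mean(axis=1, keepdims=True))

        # the output head projects to the VQGAN vocabulary, not the text vocabulary
        return nn.Dense(self.image_vocab, name="lm_head")(decoder_hidden)

model = TextToImageSeq2Seq()
params = model.init(jax.random.PRNGKey(0),
                    jnp.ones((1, 64), jnp.int32), jnp.ones((1, 256), jnp.int32))
```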

CLIP

CLIP is a neural network trained to associate images with text.
It is trained using contrastive learning, which consists of maximizing the cosine similarity (the dot product of normalized embeddings) between the image and text embeddings of matching pairs, and minimizing it between non-associated pairs.
Source: Learning Transferable Visual Models From Natural Language Supervision
When generating images, we randomly sample image tokens from the model's logits distribution, which leads to diverse samples but of unequal quality.
CLIP lets us select the best generated samples by giving a score to the generated images against their input description. We directly use the pre-trained version from OpenAI in our inference pipeline.
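
A minimal sketch of scoring and selecting candidates with the pre-trained CLIP from šŸ¤— transformers (candidate_images is assumed to be a list of PIL images coming out of the VQGAN decoder):

```python
import numpy as np
from transformers import CLIPProcessor, FlaxCLIPModel

model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def rank_with_clip(prompt, candidate_images, top_k=8):
    # score every candidate image against the input description
    inputs = processor(text=[prompt], images=candidate_images,
                       return_tensors="np", padding=True)
    scores = np.asarray(model(**inputs).logits_per_image)[:, 0]  # one score per image
    best = np.argsort(-scores)[:top_k]  # keep the highest-scoring images
    return [candidate_images[i] for i in best]
```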

How does it compare to OpenAI DALLĀ·E?

We are grateful for the research and pre-trained models published by OpenAI which were essential in building our model.
Not all the details of DALLĀ·E are public knowledge, but here are what we consider to be the main differences:
  • DALLĀ·E uses a 12 billion parameter version of GPT-3. In comparison, our model is 27 times smaller, with about 0.4 billion parameters.
  • We heavily leverage pre-trained models (VQGAN, BART encoder and CLIP) while OpenAI had to train all their models from scratch. Our model architecture takes into account pre-trained models available and their efficiency.
  • DALLĀ·E encodes images using a larger number of tokens (1024 vs 256) from a smaller vocabulary (8192 vs 16384). DALLĀ·E uses a VQVAE while we use a VQGAN.
  • DALLĀ·E encodes text using fewer tokens (at most 256 vs 1024) and a smaller vocabulary (16,384 vs 50,264).
  • DALLĀ·E reads text and images as a single stream of data while we split them between the Seq2Seq encoder and decoder. This also lets us use independent vocabularies for text and images.
  • DALLĀ·E reads the text through an auto-regressive model while we use a bidirectional encoder.
  • DALLĀ·E was trained on 250 million pairs of image and text while we used only 15 million pairs.
These differences have led to efficient training that can be performed on a single TPU v3-8 in 3 days.
Since we automatically checkpoint and log our model every hour, we could use preemptible TPU instances ($2.40/h at the time of this report), meaning a training cost of less than $200 for our model. This does not include our experimentation on TPUs and hyper-parameter search, which would add about $1,000 in our case (TPU resources were actually provided for free as part of this project).
Images generated by DALLĀ·E are still of much higher quality than our model's, but it is interesting to observe that we can train a reasonably good model with few resources.

Training the model

Training the VQGAN

We started with a pre-trained checkpoint fine-tuned on ImageNet with a reduction factor f=16 and a vocabulary size of 16,384.
While extremely efficient at encoding a large range of images, the pre-trained checkpoint was not good at encoding people and faces (they are not frequent in ImageNet), so we decided to fine-tune it for about 20h on a cloud instance with 2 x RTX A6000 GPUs.
Input images and their VQGAN reconstructions at fine-tuning steps 1, 19 and 1704.
The quality of generated faces did not improve much, probably due to mode collapse. It would be worthwhile to retrain the VQGAN from scratch in the future.
Once the model was trained, we converted our PyTorch model to JAX for the next phase.
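
Converting the checkpoint mainly consists of turning the PyTorch state dict into numpy arrays and re-mapping parameter names and kernel layouts to the Flax module tree. A generic, simplified sketch (the checkpoint path and the name mapping are illustrative; the real mapping depends on both implementations):

```python
import numpy as np
import torch
from flax.traverse_util import unflatten_dict

state_dict = torch.load("vqgan_checkpoint.pt", map_location="cpu")  # path is illustrative

flax_params = {}
for name, tensor in state_dict.items():
    array = np.asarray(tensor.numpy())
    if array.ndim == 4:    # conv kernels: (out, in, h, w) -> (h, w, in, out)
        array = array.transpose(2, 3, 1, 0)
    elif array.ndim == 2:  # linear kernels: (out, in) -> (in, out)
        array = array.T
    # e.g. "encoder.conv.weight" -> ("encoder", "conv", "kernel"); the real mapping is more involved
    flax_params[tuple(name.replace("weight", "kernel").split("."))] = array

flax_params = unflatten_dict(flax_params)
```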

Training DALLĀ·E mini

The model is programmed in JAX to take full advantage of the TPUs.
We pre-encoded all our images with the image encoder for faster data loading.
We quickly settled on a few parameters that seemed to work well (see the optimizer sketch below):
  • Batch size per TPU per step: 56, to max out the memory available per TPU chip.
  • Gradient accumulation: 8 steps, for an effective batch size of 56 x 8 TPU chips x 8 steps = 3,584 images per update.
  • Optimizer: Adafactor for its memory efficiency, which lets us use a higher batch size.
  • Learning rate: 2,000 warmup steps followed by a linear decay.
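
A minimal optax sketch of this optimizer setup (the total number of steps is illustrative; this is not our exact training configuration):

```python
import optax

warmup_steps = 2_000
total_steps = 25_000  # illustrative
peak_lr = 0.005

# 2,000 warmup steps followed by a linear decay to zero
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, peak_lr, warmup_steps),
        optax.linear_schedule(peak_lr, 0.0, total_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)

# Adafactor for memory efficiency, wrapped with 8-step gradient accumulation
optimizer = optax.MultiSteps(optax.adafactor(learning_rate=schedule),
                             every_k_schedule=8)
```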
We dedicated half a day to finding a good learning rate for our model by launching a hyper-parameter search.
Hyper-parameter sweep (5igzdwfv): eval/loss against learning rates ranging from 1e-4 to 4e-3.
After our preliminary search, we experimented with a few different learning rates over a longer period until we finally settled on 0.005.
Training curves for the candidate learning rates: learning rate (log scale), eval/loss and train/loss against train/step.
Training could have continued longer, as the evaluation loss was still improving steadily, but the project was ending (as was the availability of the TPU VM).
Validation loss in the final 24h (run resumed from checkpoint)

Results

Sample predictions

For each prompt, we generate 128 images and select the best 8 images with CLIP.
Note: the Unreal Engine trick does not seem to affect our model. It is possible that our dataset did not have such image-text pairs and that including them would affect the predictions.

Evolution of predictions during training

We can clearly see how the quality of generated images improved as the model trained.
šŸ’” Visualize different examples by clicking on āš™ļø at the top left of the panel and changing the index.

How do our results compare with OpenAI's DALLĀ·E?

The model fails on several prompts published for OpenAI's DALLĀ·E.
It is interesting to note that OpenAI often uses very long repeating prompts such as:
a storefront that has the word 'openai' written on it. a storefront that has the word 'openai' written on it. a storefront that has the word 'openai' written on it. openai storefront.
This may be due to their prompts being defined as the concatenation of image titles and descriptions.
We did not notice any significant impact in using longer prompts in our model.

How do our results compare to DALLE-pytorch?

The best open-source version of DALLĀ·E that we were aware of when developing our model was lucidrains/DALLE-pytorch.
It offers many different models, allows for plenty of customization (model size, image encoder, custom attention heads, etc.), and seems to have been trained on datasets similar to ours. It has shown impressive results, especially when trained on smaller datasets.
For this comparison, we use checkpoint 16L_64HD_8H_512I_128T_cc12m_cc3m_3E.pt, which is the currently recommended model, and select the top 8 predictions out of 128 according to CLIP, following our inference pipeline.
Both models can generate impressive results, especially on landscapes.
Overall, DALLĀ·E mini seems to be able to produce more relevant images, of a slightly better quality, and with more details.
šŸ’” Visualize different examples by clicking on āš™ļø at the top left of the panel and changing the index.

How do our results compare to "Generator + CLIP"

There are several models available which consist of a generator coupled with CLIP to create images (such as "VQGAN + CLIP").
These models take a completely different approach: each image prediction is the result of an optimization process in which we iterate over the latent space of the generator (the image encoding space) to directly maximize the CLIP score between the generated image and the description.
An interesting aspect of this method is that we can start the iteration either from a random image or from a pre-selected image. It can also be used at any image resolution, constrained only by GPU memory and optimization time.
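
A minimal sketch of that optimization loop. vqgan_decode, clip_image_embed, clip_text_embed and latent_shape are hypothetical stand-ins for the actual models and their latent dimensions:

```python
import jax
import jax.numpy as jnp
import optax

def clip_loss(latents, text_embedding):
    image = vqgan_decode(latents)              # differentiable generator
    image_embedding = clip_image_embed(image)
    # negative cosine similarity between the generated image and the description
    cosine = jnp.dot(image_embedding, text_embedding) / (
        jnp.linalg.norm(image_embedding) * jnp.linalg.norm(text_embedding))
    return -cosine

def generate(text, steps=300, lr=0.1, seed=0):
    text_embedding = clip_text_embed(text)
    latents = jax.random.normal(jax.random.PRNGKey(seed), latent_shape)  # random start
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(latents)
    for _ in range(steps):
        grads = jax.grad(clip_loss)(latents, text_embedding)
        updates, opt_state = optimizer.update(grads, opt_state)
        latents = optax.apply_updates(latents, updates)
    return vqgan_decode(latents)
```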
Sample predictions using "VQGAN + CLIP"
This technique is slower and mostly used for generating artistic images, which can be unrealistic but of a higher resolution.

Limitations and biases

During our experiments, we observed several limitations:
  • Watermarks are often present on generated samples.
  • Faces and people in general are not generated properly.
  • Animals are usually unrealistic.
  • It is hard to predict where the model excels or falls short. For example, the model is great at generating "a logo of an armchair in the shape of an avocado" but cannot produce anything relevant for "a logo of a computer" (in this case we need to adjust the prompt to "an illustration of a computer"). Reformulating matters! The goal is to write a description similar to what could have been seen during training; good prompt engineering will lead to the best results.
  • The model has only been trained with English descriptions and will not perform well in other languages. This could potentially be addressed by adding a translation service or model to our inference pipeline, but it needs to be evaluated in more detail.
Overall, it is difficult to investigate the model's biases in much detail due to the low quality of generated people and faces, but it is nevertheless clear that biases are present:
  • Occupations associated with higher levels of education (such as engineers, doctors or scientists) or heavy physical labor (such as in the construction industry) are mostly represented by white men. In contrast, nurses, secretaries or assistants are typically women, often white as well.
  • Most of the people generated are white. It is only in specific examples, such as athletes, that we see other races, and even then most of them remain under-represented.
  • The dataset is limited to pictures with English descriptions, preventing text and images from non-English-speaking cultures from being represented.
According to the Conceptual 12M paper:
We study the context in which several sensitive terms related to gender, age, race, ethnicity appear such as ā€œblackā€, ā€œwhiteā€, ā€œasianā€, ā€œafricanā€, ā€œamericanā€, ā€œindianā€, ā€œman/menā€, ā€œwoman/womenā€, ā€œboyā€, ā€œgirlā€, ā€œyoungā€, ā€œoldā€, etc. We do not observe any large biases in the distribution of these terms, either in terms of co-occurrence between sensitive term pairs or co-occurrence with other tokens. Furthermore, we check the distribution of web domains and, similar to visual concepts, we find this to be diverse and long-tail: >100K with >40K contributing >10 samples. We take our preliminary study as a positive indication of no severe biases stemming from particular domains or communities.
Since this dataset represents only 70% of all the data we used, it is possible that bias was introduced by:
  • the other datasets
  • the model itself
  • our training pipeline
  • our inference pipeline
  • the pre-trained models we used (mainly the BART encoder, or CLIP during scoring)
  • a combination of all the above, including potentially undetected bias from the preliminary study done with Conceptual 12M
Since we are releasing a public demo, we will be able to collect feedback from users and better understand our model's limitations and biases. The next step will be to find ways to mitigate them.

Looking forward

Some improvements can be made on the model:
  • Dataset
    • We can use a larger dataset; we did not use all the images we had available.
    • We need to better filter the dataset: duplicates, low-quality images, watermarks, bad descriptions, etc. Neural networks could be helpful for these tasks.
  • Text processing
    • We can improve how we pre-process title and description of images and concatenate them based on their quality.
    • We can test different types of tokenizers & encoders.
    • We can try to normalize the text: all lower case (though capitalization probably helps identify names and places), no punctuation, filtering allowed characters.
  • Image Encoder/Decoder
    • Our model is limited by the quality of the Image Encoder/Decoder.
    • We can train the VQGAN from scratch, which could limit some of the mode collapse inherited from a pre-trained model (though not necessarily avoid it completely).
    • We can explore training more efficient models such as VQGAN with Gumbel-Quantization as explored in CompVis/taming-transformers.
  • Model
    • We can scale up the model.
    • We can train longer and leverage more hardware resources.
    • We can try to generate the image in a different sequence (for example starting from the center).
  • Inference
    • We didn't test the model on conditional tasks where an image is also provided.
  • Limitations & Biases
    • We need to review all the limitations and biases and research ways to address them.

Resources


References


Authors


Acknowledgements
