  • Progressive Distillation for Fast Sampling of Diffusion Models
    Research/Diffusion 2025. 1. 23. 18:13

    https://arxiv.org/pdf/2202.00512

    https://github.com/google-research/google-research/tree/master/diffusion_distillation


    Abstract

    Diffusion models have recently shown great promise for generative modeling, outperforming GANs on perceptual quality and autoregressive models at density estimation. A remaining downside is their slow sampling time: generating high quality samples takes many hundreds or thousands of model evaluations. Here we make two contributions to help eliminate this downside: First, we present new parameterizations of diffusion models that provide increased stability when using few sampling steps. Second, we present a method to distill a trained deterministic diffusion sampler, using many steps, into a new diffusion model that takes half as many sampling steps. We then keep progressively applying this distillation procedure to our model, halving the number of required sampling steps each time. On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally, we show that the full progressive distillation procedure does not take more time than it takes to train the original model, thus representing an efficient solution for generative modeling using diffusion at both train and test time.


    1. Introduction

    Diffusion models (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020) are an emerging class of generative models that has recently delivered impressive results on many standard generative modeling benchmarks. These models have achieved ImageNet generation results outperforming BigGAN-deep and VQ-VAE-2 in terms of FID score and classification accuracy score (Ho et al., 2021; Dhariwal & Nichol, 2021), and they have achieved likelihoods outperforming autoregressive image models (Kingma et al., 2021; Song et al., 2021b). They have also succeeded in image super-resolution (Saharia et al., 2021; Li et al., 2021) and image inpainting (Song et al., 2021c), and there have been promising results in shape generation (Cai et al., 2020), graph generation (Niu et al., 2020), and text generation (Hoogeboom et al., 2021; Austin et al., 2021).

     

    A major barrier remains to practical adoption of diffusion models: sampling speed. While sampling can be accomplished in relatively few steps in strongly conditioned settings, such as text-to-speech (Chen et al., 2021) and image super-resolution (Saharia et al., 2021), or when guiding the sampler using an auxiliary classifier (Dhariwal & Nichol, 2021), the situation is substantially different in settings in which there is less conditioning information available. Examples of such settings are unconditional and standard class-conditional image generation, which currently require hundreds or thousands of steps using network evaluations that are not amenable to the caching optimizations of other types of generative models (Ramachandran et al., 2017).

     

    In this paper, we reduce the sampling time of diffusion models by orders of magnitude in unconditional and class-conditional image generation, which represent the setting in which diffusion models have been slowest in previous work. We present a procedure to distill the behavior of an N-step DDIM sampler (Song et al., 2021a) for a pretrained diffusion model into a new model with N/2 steps, with little degradation in sample quality. In what we call progressive distillation, we repeat this distillation procedure to produce models that generate in as few as 4 steps, still maintaining sample quality competitive with state-of-the-art models using thousands of steps.


    2. Background on Diffusion Models


    3. Progressive Distillation

    To make diffusion models more efficient at sampling time, we propose progressive distillation: an algorithm that iteratively halves the number of required sampling steps by distilling a slow teacher diffusion model into a faster student model. Our implementation of progressive distillation stays very close to the implementation for training the original diffusion model, as described by e.g. Ho et al. (2020). Algorithm 1 and Algorithm 2 present diffusion model training and progressive distillation side-by-side, with the relative changes in progressive distillation highlighted in green.
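
    Since Algorithm 1 and Algorithm 2 are not reproduced in this post, the following minimal Python sketch outlines the outer loop described above. It is an illustration, not the authors' code: train_student_one_round is a hypothetical placeholder for one round of distillation training (the real JAX implementation is in the linked repository).

```python
import copy

def train_student_one_round(student, teacher, num_student_steps, num_updates):
    """Hypothetical placeholder: run `num_updates` optimizer steps of the
    distillation loss (one student DDIM step matching two teacher DDIM steps)."""
    return student

def progressive_distillation(teacher, start_steps=1024, end_steps=4,
                             updates_per_round=50_000):
    """Outer loop sketch: repeatedly halve the number of sampling steps."""
    N = start_steps                       # sampling steps taken by the current teacher
    while N > end_steps:
        student = copy.deepcopy(teacher)  # student starts as an exact copy of the teacher
        student = train_student_one_round(student, teacher,
                                          num_student_steps=N // 2,
                                          num_updates=updates_per_round)
        teacher, N = student, N // 2      # the student becomes the next teacher
    return teacher, N
```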

     

    We start the progressive distillation procedure with a teacher diffusion model that is obtained by training in the standard way. At every iteration of progressive distillation, we then initialize the student model with a copy of the teacher, using both the same parameters and same model definition. Like in standard training, we then sample data from the training set and add noise to it, before forming the training loss by applying the student denoising model to this noisy data z_t. The main difference in progressive distillation is in how we set the target for the denoising model: instead of the original data x, we have the student model denoise towards a target x̃ that makes a single student DDIM step match 2 teacher DDIM steps. We calculate this target value by running 2 DDIM sampling steps using the teacher, starting from z_t and ending at z_{t−1/N}, with N being the number of student sampling steps. By inverting a single step of DDIM, we then calculate the value the student model would need to predict in order to move from z_t to z_{t−1/N} in a single step, as we show in detail in Appendix G. The resulting target value x̃(z_t) is fully determined given the teacher model and starting point z_t, which allows the student model to make a sharp prediction when evaluated at z_t. In contrast, the original data point x is not fully determined given z_t, since multiple different data points x can produce the same noisy data z_t: this means that the original denoising model is predicting a weighted average of possible x values, which produces a blurry prediction. By making sharper predictions, the student model can make faster progress during sampling.
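
    As a concrete sketch of the target computation described above (assuming the cosine schedule α_t = cos(0.5πt), σ_t = sin(0.5πt) used in the experiments, and a teacher that outputs an x-prediction via the hypothetical callable teacher_x_pred), two teacher DDIM steps are run from z_t to z_{t−1/N}, and the single student DDIM step is then inverted algebraically to obtain x̃:

```python
import numpy as np

def alpha_sigma(t):
    """Cosine schedule: alpha_t = cos(0.5*pi*t), sigma_t = sqrt(1 - alpha_t**2)."""
    a = np.cos(0.5 * np.pi * t)
    return a, np.sqrt(1.0 - a ** 2)

def ddim_step(x_pred, z_t, t, s):
    """One deterministic DDIM step from time t to time s < t, given an x-prediction."""
    a_t, s_t = alpha_sigma(t)
    a_s, s_s = alpha_sigma(s)
    eps_pred = (z_t - a_t * x_pred) / s_t        # implied epsilon-prediction
    return a_s * x_pred + s_s * eps_pred

def distillation_target(teacher_x_pred, z_t, t, N):
    """Target x~(z_t): run 2 teacher DDIM steps (t -> t-0.5/N -> t-1/N), then
    solve the single-step student DDIM update for the x-prediction it needs."""
    t_mid, t_end = t - 0.5 / N, t - 1.0 / N
    z_mid = ddim_step(teacher_x_pred(z_t, t), z_t, t, t_mid)
    z_end = ddim_step(teacher_x_pred(z_mid, t_mid), z_mid, t_mid, t_end)
    a_t, s_t = alpha_sigma(t)
    a_e, s_e = alpha_sigma(t_end)
    # Invert z_end = a_e * x + (s_e / s_t) * (z_t - a_t * x) for x:
    return (z_end - (s_e / s_t) * z_t) / (a_e - (s_e / s_t) * a_t)
```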

     

    After running distillation to learn a student model taking N sampling steps, we can repeat the procedure with N/2 steps: The student model then becomes the new teacher, and a new student model is initialized by making a copy of this model.

     

    Unlike our procedure for training the original model, we always run progressive distillation in discrete time: we sample this discrete time such that the highest time index corresponds to a signal-to-noise ratio of zero, i.e. α_1 = 0, which exactly matches the distribution of input noise z_1 ∼ N(0, I) that is used at test time. We found this to work slightly better than starting from a non-zero signal-to-noise ratio as used by e.g. Ho et al. (2020), both for training the original model as well as when performing progressive distillation.


    4. Diffusion Model Parameterization and Training Loss


    5. Experiments

    In this section we empirically validate the progressive distillation algorithm proposed in Section 3, as well as the parameterizations and loss weightings considered in Section 4. We consider various image generation benchmarks, with resolution varying from 32 × 32 to 128 × 128. All experiments use the cosine schedule α_t = cos(0.5πt), and all models use a U-Net architecture similar to that introduced by Ho et al. (2020), but with BigGAN-style up- and downsampling (Brock et al., 2019), as used in the diffusion modeling setting by Nichol & Dhariwal (2021); Song et al. (2021c). Our training setup closely matches the open source code by Ho et al. (2020). Exact details are given in Appendix E.
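
    As a small illustration of this setup (an assumption-labeled sketch, not the training code), the snippet below implements the cosine schedule and the forward process z_t = α_t x + σ_t ε, and checks that α_1 = 0, i.e. that the highest time index has zero signal-to-noise ratio as discussed in Section 3:

```python
import numpy as np

def alpha_sigma(t):
    """Cosine schedule: alpha_t = cos(0.5*pi*t), sigma_t = sin(0.5*pi*t),
    so that alpha_t**2 + sigma_t**2 = 1 (variance preserving)."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def add_noise(x, t, rng):
    """Forward process: z_t = alpha_t * x + sigma_t * eps, with eps ~ N(0, I)."""
    a, s = alpha_sigma(t)
    eps = rng.standard_normal(x.shape)
    return a * x + s * eps, eps

# At t = 1 the schedule gives alpha_1 = 0 (up to floating point), so z_1 is pure
# noise, matching the test-time input z_1 ~ N(0, I).
print(alpha_sigma(1.0))  # approximately (0.0, 1.0)
```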


    5.1. Model Parameterization and Training Loss

    As explained in Section 4, the standard method of having our model predict ε, and minimizing mean squared error in the ε-space (Ho et al., 2020), is not appropriate for use with progressive distillation. We therefore proposed various alternative parameterizations of the denoising diffusion model that are stable under the progressive distillation procedure, as well as various weighting functions for the reconstruction error in x-space. Here, we perform a complete ablation experiment of all parameterizations and loss weightings considered in Section 4. For computational efficiency, and for comparisons to established methods in the literature, we use unconditional CIFAR-10 as the benchmark. We measure performance of undistilled models trained from scratch, to avoid introducing too many factors of variation into our analysis.

     

    Table 1 lists the results of the ablation study. Overall results are fairly close across different parameterizations and loss weights. All proposed stable model specifications achieve excellent performance, with the exception of the combination of outputting ε with the neural network and weighting the loss with the truncated SNR, which we find to be unstable. Predicting x directly, predicting v, or predicting the combination (ε, x) can thus all be recommended for specifying diffusion models. Of these, predicting v is the most stable option, as it has the unique property of making DDIM step-sizes independent of the SNR (see Appendix D), but predicting x gives slightly better empirical results in this ablation study.
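
    For reference, the v-parameterization mentioned above uses v = α_t ε − σ_t x, from which the x- and ε-predictions can be recovered as x = α_t z_t − σ_t v and ε = σ_t z_t + α_t v (Appendix D). Below is a small numerical self-check of these identities, assuming the cosine schedule:

```python
import numpy as np

def alpha_sigma(t):
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def v_from_x_eps(x, eps, t):
    """v-target: v = alpha_t * eps - sigma_t * x."""
    a, s = alpha_sigma(t)
    return a * eps - s * x

def x_eps_from_v(z_t, v, t):
    """Recover x- and eps-predictions from a v-prediction."""
    a, s = alpha_sigma(t)
    return a * z_t - s * v, s * z_t + a * v

# Numerical consistency check of the identities above.
rng = np.random.default_rng(0)
x, eps, t = rng.standard_normal(8), rng.standard_normal(8), 0.3
a, s = alpha_sigma(t)
z_t = a * x + s * eps
x_rec, eps_rec = x_eps_from_v(z_t, v_from_x_eps(x, eps, t), t)
assert np.allclose(x_rec, x) and np.allclose(eps_rec, eps)
```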


    5.2. Progressive Distillation

    We evaluate our proposed progressive distillation algorithm on 4 data sets: CIFAR-10, 64 × 64 downsampled ImageNet, 128 × 128 LSUN bedrooms, and 128 × 128 LSUN Church-Outdoor. For each data set we start by training a baseline model, after which we start the progressive distillation procedure. For CIFAR-10 we start progressive distillation from a teacher model taking 8192 steps. For the bigger data sets we start at 1024 steps. At every iteration of distillation we train for 50 thousand parameter updates, except for the distillation to 2 and 1 sampling steps, for which we use 100 thousand updates. We report FID results obtained after each iteration of the algorithm. Using these settings, the computational cost of progressive distillation to 4 sampling steps is comparable to or less than that of training the original model. In Appendix I we show that this computational cost can be reduced much further still, at a small cost in performance.
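
    As a back-of-the-envelope illustration of this schedule (a rough count only; it says nothing about the cost of training the original model), the number of parameter updates spent on distillation can be tallied as follows:

```python
def distillation_updates(start_steps, updates_normal=50_000, updates_final=100_000):
    """Count parameter updates for progressive distillation down to 1 step:
    50k updates per halving, except 100k for the rounds ending at 2 and 1 steps."""
    total, N = 0, start_steps
    while N > 1:
        N //= 2
        total += updates_final if N in (1, 2) else updates_normal
    return total

# CIFAR-10 starts at 8192 steps: 11 halvings reach 4 steps (11 * 50k = 550k updates),
# plus 100k each for the rounds down to 2 and 1 steps.
print(distillation_updates(8192))  # 750000
print(distillation_updates(1024))  # 600000
```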

     

    In Figure 4 we plot the resulting FID scores (Heusel et al., 2017) obtained for each number of sampling steps. We compare against the undistilled DDIM sampler, as well as to a highly optimized stochastic baseline sampler. For all four data sets, progressive distillation produces near-optimal results down to 4 or 8 sampling steps. At 2 or 1 sampling steps, the sample quality degrades more quickly. In contrast, the quality of the DDIM and stochastic samplers degrades very sharply once the number of sampling steps drops below 128. Overall, we conclude that progressive distillation is an attractive solution for computational budgets of 128 or fewer sampling steps. Although our distillation procedure is designed for use with the DDIM sampler, the resulting distilled models can in principle also be used with stochastic sampling: we investigate this in Appendix F, and find that it achieves performance falling between that of the distilled DDIM sampler and the undistilled stochastic sampler.

     

    Table 2 shows some of our results on CIFAR-10, and compares against other fast sampling methods in the literature: Our method compares favorably and attains higher sampling quality in fewer steps than most of the alternative methods. Figure 3 shows some random samples from our model obtained at different phases of the distillation process. Additional samples are provided in Appendix H.


    6. Related Work on Fast Sampling

    Our proposed method is closest to the work of Luhman & Luhman (2021), who perform distillation of DDIM teacher models into one-step student models. A possible downside of their method is that it requires constructing a large data set by running the original model at its full number of sampling steps: their cost of distillation thus scales linearly with this number of steps, which can be prohibitive. In contrast, our method never needs to run the original model at the full number of sampling steps: at every iteration of progressive distillation, the number of model evaluations is independent of the number of teacher sampling steps, allowing our method to scale up to large numbers of teacher steps at a logarithmic cost in total distillation time.

     

    DDIM (Song et al., 2021a) was originally shown to be effective for few-step sampling, as was the probability flow sampler (Song et al., 2021c). Jolicoeur-Martineau et al. (2021) study fast SDE integrators for reverse diffusion processes, and Tzen & Raginsky (2019b) study unbiased samplers which may be useful for fast, high quality sampling as well.

     

    Other work on fast sampling can be viewed as manual or automated methods to adjust samplers or diffusion processes for fast generation. Nichol & Dhariwal (2021); Kong & Ping (2021) describe methods to adjust a discrete time diffusion model trained on many timesteps into models that can sample in few timesteps. Watson et al. (2021) describe a dynamic programming algorithm to reduce the number of timesteps for a diffusion model in a way that is optimal for log likelihood. Chen et al. (2021); Saharia et al. (2021); Ho et al. (2021) train diffusion models over continuous noise levels and tune samplers post training by adjusting the noise levels of a few-step discrete time reverse diffusion process. Their method is effective in highly conditioned settings such as text-to-speech and image super-resolution. San-Roman et al. (2021) train a new network to estimate the noise level of noisy data and show how to use this estimate to speed up sampling.

     

    Alternative specifications of the diffusion model can also lend themselves to fast sampling, such as modified forward and reverse processes (Nachmani et al., 2021; Lam et al., 2021) and training diffusion models in latent space (Vahdat et al., 2021).


    7. Discussion

    We have presented progressive distillation, a method to drastically reduce the number of sampling steps required for high quality generation of images, and potentially other data, using diffusion models with deterministic samplers like DDIM (Song et al., 2021a). By making these models cheaper to run at test time, we hope to increase their usefulness for practical applications, for which running time and computational requirements often represent important constraints.

     

    In the current work we limited ourselves to setups where the student model has the same architecture and number of parameters as the teacher model: in future work we hope to relax this constraint and explore settings where the student model is smaller, potentially enabling further gains in test-time computational requirements. In addition, we hope to move past the generation of images and also explore progressive distillation of diffusion models for other data modalities, such as audio (Chen et al., 2021).

     

    In addition to the proposed distillation procedure, some of our progress was realized through different parameterizations of the diffusion model and its training loss. We expect to see more progress in this direction as the community further explores this model class.


    A. Probability Flow ODE in terms of Log-SNR


    B. DDIM is an Integrator of the Probability Flow ODE


    C. Evaluation of Integrators of the Probability Flow ODE


    D. Expression of DDIM in Angular Parameterization


    G. Derivation of the Distillation Target


     
