新开传奇私服

传奇私服发布网

当前位置:首页 > 互联网 IT业界 > 颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

admin 互联网 IT业界 25热度

  新智元报道

  编辑:LRS 好困

  Masked Diffusion Transformer V2 在 ImageNet benchmark 上实现了 1.58 的 FID score 的新 SoTA,并通过 mask modeling 表征学习策略大幅提升了 DiT 的训练速度。

  DiT 作为效果惊艳的 Sora 的核心技术之一,利用 Difffusion Transfomer 将生成模型扩展到更大的模型规模,从而实现高质量的图像生成。

  然而,更大的模型规模导致训练成本飙升。

  为此,来自 Sea AI Lab、南开大学、昆仑万维 2050 研究院的颜水成和程明明研究团队在 ICCV 2023 提出的 Masked Diffusion Transformer 利用 mask modeling 表征学习策略通过学习语义表征信息来大幅加速 Diffusion Transfomer 的训练速度,并实现 SoTA 的图像生成效果。

  论文地址:https://arxiv.org/abs/2303.14389

  GitHub 地址:https://github.com/sail-sg/MDT

  近日,Masked Diffusion Transformer V2 再次刷新 SoTA, 相比 DiT 的训练速度提升 10 倍以上,并实现了 ImageNet benchmark 上 1.58 的 FID score。

  最新版本的论文和代码均已开源。

  背景

  尽管以 DiT 为代表的扩散模型在图像生成领域取得了显著的成功,但研究者发现扩散模型往往难以高效地学习图像中物体各部分之间的语义关系,这一局限性导致了训练过程的低收敛效率。

  例如上图所示,DiT 在第 50k 次训练步骤时已经学会生成狗的毛发纹理,然后在第 200k 次训练步骤时才学会生成狗的一只眼睛和嘴巴,但是却漏生成了另一只眼睛。

  即使在第 300k 次训练步骤时,DiT 生成的狗的两只耳朵的相对位置也不是非常准确。

  这一训练学习过程揭示了扩散模型未能高效地学习到图像中物体各部分之间的语义关系,而只是独立地学习每个物体的语义信息。

  研究者推测这一现象的原因是扩散模型通过最小化每个像素的预测损失来学习真实图像数据的分布,这个过程忽略了图像中物体各部分之间的语义相对关系,因此导致模型的收敛速度缓慢。

  方法:Masked Diffusion Transformer

  受到上述观察的启发,研究者提出了 Masked Diffusion Transformer (MDT) 提高扩散模型的训练效率和生成质量。

  MDT 提出了一种针对 Diffusion Transformer 设计的 mask modeling 表征学习策略,以显式地增强 Diffusion Transformer 对上下文语义信息的学习能力,并增强图像中物体之间语义信息的关联学习。

  如上图所示,MDT 在保持扩散训练过程的同时引入 mask modeling 学习策略。通过 mask 部分加噪声的图像 token,MDT 利用一个非对称 Diffusion Transformer (Asymmetric Diffusion Transformer) 架构从未被 mask 的加噪声的图像 token 预测被 mask 部分的图像 token,从而同时实现 mask modeling 和扩散训练过程。

  在推理过程中,MDT 仍保持标准的扩散生成过程。MDT 的设计有助于 Diffusion Transformer 同时具有 mask modeling 表征学习带来的语义信息表达能力和扩散模型对图像细节的生成能力。

  具体而言,MDT 通过 VAE encoder 将图片映射到 latent 空间,并在 latent 空间中进行处理以节省计算成本。

  在训练过程中,MDT 首先 mask 掉部分加噪声后的图像 token,并将剩余的 token 送入 Asymmetric Diffusion Transformer 来预测去噪声后的全部图像 token。

  Asymmetric Diffusion Transformer 架构

  如上图所示,Asymmetric Diffusion Transformer 架构包含 encoder、side-interpolater(辅助插值器)和 decoder。

  在训练过程中,Encoder 只处理未被 mask 的 token;而在推理过程中,由于没有 mask 步骤,它会处理所有 token。

  因此,为了保证在训练或推理阶段,decoder 始终能处理所有的 token,研究者们提出了一个方案:在训练过程中,通过一个由 DiT block 组成的辅助插值器(如上图所示),从 encoder 的输出中插值预测出被 mask 的 token,并在推理阶段将其移除因而不增加任何推理开销。

  MDT 的 encoder 和 decoder 在标准的 DiT block 中插入全局和局部位置编码信息以帮助预测 mask 部分的 token。

  Asymmetric Diffusion Transformer V2

  如上图所示,MDTv2 通过引入了一个针对 Masked Diffusion 过程设计的更为高效的宏观网络结构,进一步优化了 diffusion 和 mask modeling 的学习过程。

  这包括在 encoder 中融合了U-Net 式的 long-shortcut,在 decoder 中集成了 dense input-shortcut。

  其中,dense input-shortcut 将添加噪后的被 mask 的 token 送入 decoder,保留了被 mask 的 token 对应的噪声信息,从而有助于 diffusion 过程的训练。

  此外,MDT 还引入了包括采用更快的 Adan 优化器、time-step 相关的损失权重,以及扩大掩码比率等更优的训练策略来进一步加速 Masked Diffusion 模型的训练过程。

  实验结果

  ImageNet 256 基准生成质量比较

  上表比较了不同模型尺寸下 MDT 与 DiT 在 ImageNet 256 基准下的性能对比。

  显而易见,MDT 在所有模型规模上都以较少的训练成本实现了更高的 FID 分数。

  MDT 的参数和推理成本与 DiT 基本一致,因为正如前文所介绍的,MDT 推理过程中仍保持与 DiT 一致的标准的 diffusion 过程。

  对于最大的 XL 模型,经过 400k 步骤训练的 MDTv2-XL/2,显著超过了经过 7000k 步骤训练的 DiT-XL/2,FID 分数提高了 1.92。在这一 setting 下,结果表明了 MDT 相对 DiT 有约 18 倍的训练加速。

  对于小型模型,MDTv2-S/2 仍然以显著更少的训练步骤实现了相比 DiT-S/2 显著更好的性能。例如同样训练 400k 步骤,MDTv2 以 39.50 的 FID 指标大幅领先 DiT 68.40 的 FID 指标。

  更重要的是,这一结果也超过更大模型 DiT-B/2 在 400k 训练步骤下的性能(39.50 vs 43.47)。

  ImageNet 256 基准 CFG 生成质量比较

  我们还在上表中比较了 MDT 与现有方法在 classifier-free guidance 下的图像生成性能。

  MDT 以 1.79 的 FID 分数超越了以前的 SOTA DiT 和其他方法。MDTv2 进一步提升了性能,以更少的训练步骤将图像生成的 SOTA FID 得分推至新低,达到 1.58。

  与 DiT 类似,我们在训练过程中没有观察到模型的 FID 分数在继续训练时出现饱和现象。

  MDT 在 PaperWithCode 的 leaderboard 上刷新 SoTA

  收敛速度比较

  上图比较了 ImageNet 256 基准下,8×A100 GPU 上 DiT-S/2 基线、MDT-S/2 和 MDTv2-S/2 在不同训练步骤/训练时间下的 FID 性能。

  得益于更优秀的上下文学习能力,MDT 在性能和生成速度上均超越了 DiT。MDTv2 的训练收敛速度相比 DiT 提升 10 倍以上。

  MDT 在训练步骤和训练时间方面大相比 DiT 约 3 倍的速度提升。MDTv2 进一步将训练速度相比于 MDT 提高了大约 5 倍。

  例如,MDTv2-S/2 仅需 13 小时(15k 步骤)就展示出比需要大约 100 小时(1500k 步骤)训练的 DiT-S/2 更好的性能,这揭示了上下文表征学习对于扩散模型更快的生成学习至关重要。

  总结&讨论

  MDT 通过在扩散训练过程中引入类似于 MAE 的 mask modeling 表征学习方案,能够利用图像物体的上下文信息重建不完整输入图像的完整信息,从而学习图像中语义部分之间的关联关系,进而提升图像生成的质量和学习速度。

  研究者认为,通过视觉表征学习增强对物理世界的语义理解,能够提升生成模型对物理世界的模拟效果。这正与 Sora 期待的通过生成模型构建物理世界模拟器的理念不谋而合。希望该工作能够激发更多关于统一表征学习和生成学习的工作。

  参考资料:

  https://arxiv.org/abs/2303.14389

更新时间 2024-05-05 08:06:22