type
status
date
slug
summary
tags
category
icon
password
📖
“Be water,my friend.” ——李小龙
 
由于之前对流匹配算法的掌握有所欠缺,在这里重新梳理一下有关Flow matching与Conditional Flow Matching的内容框架。 参考论文:Flow Matching for Generative Modeling

流匹配思想的引入与优势

在flow matching提出之前,扩散模型在实践中已经取得了相当的成功,但其局限性在于:
  1. 路径限制:只能使用特定的简单扩散过程,限制了概率路径的多样性。
  1. 训练时间长:由于需要模拟扩散过程,训练时间较长。
让我们看看Flow Matching是怎么解决的:
痛点1:训练路径限制。 扩散模型的训练过程(前向加噪)被一个固定的、复杂的随机微分方程 (SDE) 锁死了,通常是模拟高斯噪声。你只能沿着这条路走,灵活性很差。 Flow Matching说:“我们为什么非要走SDE这条弯路?我们可以任意定义一条从噪声(比如高斯噪声)到数据(比如真实动作)的路径。”论文中提到,最简单的直线路径就非常好用!这也就是我们所说的直线流(Rectified Flow)。
痛点2:训练依赖模拟,耗时难以接受。 在扩散模型中,为了得到训练数据,必须从出发,一步步模拟加噪过程来得到。 而在Flow Matching中,因为我们自己定义了路径(比如直线),我们不需要模拟
你只需要: 1. 采样一个真实数据(比如actions)。 2. 采样一个纯噪声(比如noise)。 3. 采样一个时间(比如time)。 4. 用公式直接算出
这个“无需模拟” (simulation-free) 的特性让训练更简单、更稳定、收敛更快。既然这个算法在解决这种问题这么有效,那么它的实现机理是怎样的呢?

矢量场回归与损失函数

流匹配的核心思想是在给定目标概率路径的情况下,通过回归矢量场来匹配生成的概率路径,而无需模拟整个过程。具体来说,我们定义一个损失函数,将模型的矢量场与目标矢量场进行匹配:
其中: :模型参数。 :时间参数,均匀分布在[0,1]上。 :目标概率密度路径,在时间上的分布。 :模型学习的矢量场。 :目标矢量场,生成目标概率路径。
想法很好,但是我们该怎么解决原本未知的呢?又或者,真的能求吗?真的有必要求吗?

条件流匹配(Conditional Flow Matching

事实上,直接计算和采样目标概率路径和目标矢量场在实践中是不现实的,我们通常无法显式地获得它们。那么有没有办法让我们不需要知道这两者就能够解决我们刚刚提到的流匹配问题呢?
让我们引入两个新概念:
  • 条件概率路径:对于给定的样本,定义条件概率路径,使得在时,为简单的已知分布(如标准正态分布),在时,为集中在附近的分布。
  • 条件矢量场:对应于条件概率路径的生成矢量场
通过对条件概率路径和条件矢量场进行边际化,可以得到目标的边际概率路径和边际矢量场。由于计算过程只依靠了单个数据样本,这使得计算和采样成为可能。
当模型在成千上万的Batch上重复梯度下降计算过程时,为了在所有采样的上都表现良好(即最小化期望的CFM损失)它将被迫学习到那个“平均”的、“边际”的速度场,这种现象就叫做边际化。
既然已经有了条件流的理论支持,我们不妨把问题考虑的再简单一点:“既然我们可以任意定义路径,为什么不选一个最简单的?”
因此,我们直接假定真实数据和噪声之间走的是直线路径:。如果我们始终取,那么速度恰好就等于!只要让我们的模型在训练中趋近,就能解决原本无法解决的矢量场采样与逼近问题。

复得返自然——流匹配模型的推理过程

训练过程中,我们从走到了,在推理过程中,我们就需要倒过来,即从走到,我们无法在过程中模拟无限小的时间,但是我们可以用一个合理的步长(steps)来模拟这一连续的过程(此时)。之后的操作如下:
输入噪声(在时)调用模型,获得(在时的速度)。 使用 计算()。 调用模型,获得(在$t=0.9$时的速度)。 使用 计算()。 ...重复这个过程steps次... 最终得到,这就是我们需要的输出。

总结

这周的学习重点是深入Pi0模型背后的“流匹配”(Flow Matching)思想。
一开始,我只是看pi0.pygemma.py的代码,感觉非常困惑。我最不理解的就是compute_loss里的那行核心代码:u_t = noise - actions。为什么一个叫u_t(显然代表速度, velocity)的东西,可以用两个“位置”(noiseactions)相减得到?
为了搞懂这个,我深入研读了"Flow Matching for Generative Modeling"这篇论文的理论。我最大的“啊哈”时刻,是搞懂了“直线流”(Rectified Flow)的概念。论文的天才之处在于,它没有用传统扩散模型那种复杂的随机路径,而是定义了一条极其简单的直线路径
总的来说,这周的学习让我豁然开朗。Pi0的架构本质上就是实现了一个以为条件的"直线流"ODE求解器。而Pi0.5的升级,比如AdaRMS,则是让这个求解器在每一步(每个时间)都能动态调整自己的内部参数(如RMSNorm),从而实现更强的泛化能力。对"流匹配"的理解,是串起pi0.pygemma.py所有模块的关键。
在上周,由于看到大量的数学计算公式和推导,我想当然的认为流匹配是一种很高深莫测的算法,但经过仔细阅读,这个算法简单到我都有点读到发笑,不由得感叹数学的迷惑性。(当然,笔者并不轻视每一种大道至简的理论体系)接下来的学习过程中,我会铭记这段学习历程,尽量不再犯这种经验主义错误。
文档撰写方面,这周的学习笔记缺少了图例作为文字叙述的补充。这既有代码与公式的抽象性因素影响,也与笔者对文字表述能力的依赖性有关,但这种方式可能会导致自己的学习笔记只有自己捋的顺,没办法作为其他同学的参考资料,对日后写论文也没有太多帮助,在这点上我还要继续努力改正。
接下来,我想精读一下Pi0.5的论文,了解RMSNorm等升级点具体是如何量化地提升了模型性能,其中涉及的原理又是什么。此外,对于VLM中的具体数据流向细节以及编码机制等,我还需要结合论文和代码做进一步学习。

补充——实机training实践:基于mnist数据集的手写数字生成

 
为了加深理解,笔者在本机运行环境下尝试进行了基于mnist数据集的手写数字生成的模型的搭建和训练,以下是训练的详细数据

运行环境

实验环境配置

  • 操作系统 (OS): Windows 11
  • 运行环境: 适用于 Linux 的 Windows 子系统 (WSL 2)
  • Linux 发行版: Ubuntu
  • 环境管理器: Miniconda / Anaconda
  • Conda 环境名: flowmatching_env
  • 编程语言: Python 3.10
  • 核心计算框架:
    • PyTorch: 2.5.1
    • Torchvision: 0.20.1
  • 核心依赖库:
    • torchdiffeq:0.2.5
    • matplotlib:3.10.7
  • 硬件加速:
    • GPU: NVIDIA 4070 Laptop
    • CUDA (PyTorch 内置): 11.8
    • CUDA (WSL 系统级): 11.5

实验代码

 
代码中模型的架构与训练逻辑与上文提到的相符,代码中的损失函数: loss = F.mse_loss(vt_pred, images - noise) 即为真实目标速度(images - noise)和模型预测速度(vt_pred)的均方误差。

实验结果:

训练运行时,依照每10个Epoch输出一次训练结果:
epoch=0,模型的输出效果很差
epoch=0,模型的输出效果很差
 
epoch=10时的训练结果
epoch=10时的训练结果
除了少数数字外,大部分的手写数字在训练轮数为10时,已经可辨识,Loss由0.37降至0.17。
epoch=40训练结果
epoch=40训练结果
由于性能原因和参数设计等因素,在epoch=40后,后续的训练Loss趋于0.16左右,变化幅度较小,故笔者提前终止了训练过程。可见在epoch=40时,模型对于数字“2”、“4”以外的数字信息执行手写数字生成的效果较好。这也符合人们手写数字的习惯:数据集里2、4的写法规律性较低,故模型需要更多的轮数在vt_pred中学习到数字图片特征。

参考文献

arXiv.orgarXiv.orgFlow Matching for Generative Modeling
zhuanlan.zhihu.com
極東晝寢愛好家極東晝寢愛好家笔记|扩散模型(一八)Flow Matching 理论详解
zhuanlan.zhihu.com
zhuanlan.zhihu.com
【PI0/PI0.5】代码核心架构解读【PI0.5】重要知识点分析
Loading...