type
status
date
slug
summary
tags
category
icon
password
由于pi0与pi0.5在架构上的相似性,PI公司在其开源的openpi代码仓库中的部分设计采用了代码复用的形式,笔者在尝试归纳后进行统一的核心梳理。有关pi0.5的论文精读,会在接下来的学习中尽快完成。
pi0.py中的pi0模型类(class Pi0)
pi0.py是Pi0/Pi0.5决策模型的核心实现文件。它定义了整个多模态策略网络的“骨架”,负责将底层的视觉-语言Backbone(PaliGemma)与机器人动作的生成任务相结合。
该文件的主要结构和定义内容包括:
class Pi0:这是模型的主类。__init__:初始化函数,负责组装所有组件,包括加载PaliGemma Backbone(llm和img),以及定义动作投影层。它包含了Pi0和Pi0.5的关键架构切换逻辑(if config.pi05:),用于决定是使用传统的state_proj和action_time_mlp,还是使用Pi0.5的time_mlp(为AdaRMS准备)。embed_prefix:将观测(图像、文本)编码为“前缀”token序列。embed_suffix:将带噪声的动作和扩散时间t编码为“后缀”token序列。这是 Pi0.5的AdaRMS交互的起点,它会(在 Pi0.5 模式下)生成并返回一个adarms_cond条件向量。compute_loss:实现"Conditional Flow Matching"(流匹配)的核心。它通过采样噪声和时间来定义一个直线流,计算出目标速度u_t = noise - actions,然后将模型预测v_t与之比较以计算均方误差损失。sample_actions:实现推理(动作采样)。它使用jax.lax.while_loop执行一个迭代去噪循环(欧拉法求解器),并利用KV缓存技术优化前缀(观测)的处理速度make_attn_mask:创建核心的“前缀-LM”注意力掩码,控制前缀和后缀之间的信息流。posemb_sincos:用于将标量的扩散时间t编码为高维向量
除此之外还有两个辅助函数:
初始化(__init__)
输入嵌入(embed_prefix & embed_suffix)
损失计算(compute_loss)
动作采样(sample_actions)
让我们关注pi0.py是如何与Backbone交互的:
- 混合输入 (Tokens):一个包含
[prefix_tokens, suffix_tokens]的列表,gemma.py将其视为[PaliGemma_tokens, ActionExpert_tokens]。
- 条件 (Condition):一个
adarms_cond=[None, adarms_cond]列表,其中第二个元素是 Pi0.5 独有的时间编码向量。
下面让我们深入gemma.py,寻找这两个结构是如何被消费的。
gemma.py
gemma.py定义了 Pi0 模型所依赖的Gemma Transformer Backbone(即
pi0.py中的self.PaliGemma.llm)。这个文件不仅是标准的Transformer实现,它还被特别设计用来支持Pi0架构的两个关键特性:多专家输入(同时处理PaliGemma和ActionExpert的token)和Pi0.5的动态条件注入(AdaRMS)。gemma.py源文件地址放置在超链接中,由于篇幅原因,接下来的学习总结省略笔者的思考过程,只保留了相关资料说明部分。class Module
Gemma Transformer的主模块。它的
__call__方法是Pi0交互的入口点,被设计为接收一个Token序列的列表(embedded)和一个条件列表 (adarms_cond),并将它们广播到模型的所有层。Module类接收
embedded: Sequence[...](pi0.py传输的[prefix_tokens, suffix_tokens]列表)与adarms_cond: Sequence[...] | None = None(pi0.py传入的[None, adarms_cond]列表),并会将adarms_cond广播到每一层和最后的归一化层(final_norms)class block
定义了单个 Transformer 层。它负责编排
RMSNorm、Attention和FeedForward之间的交互,并将adarms_cond列表中的相应条件传递给它包含的 RMSNorm层。class RMSNorm
它的
__call__方法接收一个cond参数。(就是adarms_cond列表中的元素(时间向量或None)- 如果
cond为None(Pi0模式),它执行标准RMSNorm
- 如果
cond不为None(Pi0.5模式,接收到时间向量),它会通过一个线性层动态生成scaleshiftgate三个控制信号,并用它们来调整归一化过程 (normed_inputs * (1 + scale) + shift)。
根据是否有gate的返回,在
_gated_residual(门控残差连接)中会执行不同的残差连接方式:- Pi0路径:执行标准残差连接
x + y。
- Pi0.5路径:执行门控残差连接
x + y * gate。会动态地缩放子层的输出,然后再加回去。
class Attention
注意力模块。它被设计为可以处理一个
xs(Token 列表),并使用_name(..., i) 辅助函数为不同“专家”(PaliGemma 和 ActionExpert)使用不同的投影权重。它将所有专家的Q,K,V拼接起来进行统一的注意力计算,然后再将结果拆分。- 作者:CreamGreen.
- 链接:www.creamgreen.com/article/297555f7-8779-80ee-912a-f9920dd2cd23
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章
