type
status
date
slug
summary
tags
category
icon
password
📖
由于pi0与pi0.5在架构上的相似性,PI公司在其开源的openpi代码仓库中的部分设计采用了代码复用的形式,笔者在尝试归纳后进行统一的核心梳理。有关pi0.5的论文精读,会在接下来的学习中尽快完成。

pi0.py中的pi0模型类(class Pi0)

pi0.py是Pi0/Pi0.5决策模型的核心实现文件。它定义了整个多模态策略网络的“骨架”,负责将底层的视觉-语言Backbone(PaliGemma)与机器人动作的生成任务相结合。
该文件的主要结构和定义内容包括:
  • class Pi0:这是模型的主类。
    • __init__:初始化函数,负责组装所有组件,包括加载PaliGemma Backbone(llm和img),以及定义动作投影层。它包含了Pi0和Pi0.5的关键架构切换逻辑if config.pi05:),用于决定是使用传统的state_projaction_time_mlp,还是使用Pi0.5的time_mlp(为AdaRMS准备)。
    • embed_prefix:将观测(图像、文本)编码为“前缀”token序列。
    • embed_suffix:将带噪声的动作和扩散时间t编码为“后缀”token序列。这是 Pi0.5的AdaRMS交互的起点,它会(在 Pi0.5 模式下)生成并返回一个 adarms_cond 条件向量。
    • compute_loss实现"Conditional Flow Matching"(流匹配)的核心。它通过采样噪声和时间来定义一个直线流,计算出目标速度u_t = noise - actions,然后将模型预测v_t与之比较以计算均方误差损失。
    • sample_actions实现推理(动作采样)。它使用jax.lax.while_loop执行一个迭代去噪循环(欧拉法求解器),并利用KV缓存技术优化前缀(观测)的处理速度
    • 除此之外还有两个辅助函数:
    • make_attn_mask:创建核心的“前缀-LM”注意力掩码,控制前缀和后缀之间的信息流。
    • posemb_sincos:用于将标量的扩散时间t编码为高维向量

初始化(__init__

输入嵌入(embed_prefix & embed_suffix

损失计算(compute_loss

动作采样(sample_actions

让我们关注pi0.py是如何与Backbone交互的:
  • 混合输入 (Tokens):一个包含 [prefix_tokens, suffix_tokens] 的列表,gemma.py 将其视为 [PaliGemma_tokens, ActionExpert_tokens]
  • 条件 (Condition):一个 adarms_cond=[None, adarms_cond] 列表,其中第二个元素是 Pi0.5 独有的时间编码向量。
下面让我们深入gemma.py,寻找这两个结构是如何被消费的。

gemma.py

gemma.py定义了 Pi0 模型所依赖的Gemma Transformer Backbone(即pi0.py中的self.PaliGemma.llm)。这个文件不仅是标准的Transformer实现,它还被特别设计用来支持Pi0架构的两个关键特性:多专家输入(同时处理PaliGemma和ActionExpert的token)和Pi0.5的动态条件注入(AdaRMS)。gemma.py源文件地址放置在超链接中,由于篇幅原因,接下来的学习总结省略笔者的思考过程,只保留了相关资料说明部分。

class Module

Gemma Transformer的主模块。它的__call__方法是Pi0交互的入口点,被设计为接收一个Token序列的列表(embedded)和一个条件列表 (adarms_cond),并将它们广播到模型的所有层。
Module类接收embedded: Sequence[...](pi0.py传输的[prefix_tokens, suffix_tokens]列表)与adarms_cond: Sequence[...] | None = None(pi0.py传入的[None, adarms_cond]列表),并会将adarms_cond广播到每一层和最后的归一化层(final_norms)

class block

定义了单个 Transformer 层。它负责编排RMSNormAttentionFeedForward之间的交互,并将adarms_cond列表中的相应条件传递给它包含的 RMSNorm层。

class RMSNorm

它的__call__方法接收一个cond参数。(就是adarms_cond列表中的元素(时间向量或None
  • 如果condNone(Pi0模式),它执行标准RMSNorm
  • 如果cond不为None(Pi0.5模式,接收到时间向量),它会通过一个线性层动态生成scaleshiftgate 三个控制信号,并用它们来调整归一化过程 (normed_inputs * (1 + scale) + shift)。
根据是否有gate的返回,在_gated_residual(门控残差连接)中会执行不同的残差连接方式:
  • Pi0路径:执行标准残差连接x + y
  • Pi0.5路径:执行门控残差连接x + y * gate。会动态地缩放子层的输出,然后再加回去。

class Attention

注意力模块。它被设计为可以处理一个xs(Token 列表),并使用_name(..., i) 辅助函数为不同“专家”(PaliGemma 和 ActionExpert)使用不同的投影权重。它将所有专家的Q,K,V拼接起来进行统一的注意力计算,然后再将结果拆分。
 
【具身智能随想】从图灵论文到pi0 【Flow matching】再读流匹配算法
Loading...