OpenPI 핵심 코드 Deep Dive

대상: Physical-Intelligence/openpi 핵심 코드 경로 목표: 실제 수정/디버깅/성능 최적화를 할 때 반드시 이해해야 하는 파일들을 “논리적 line-by-line” 방식으로 설명한다. 주의: 아래 설명은 원본 코드를 그대로 대량 복제하지 않고, 함수/클래스/블록 단위로 실제 동작을 풀어쓴 것이다.


0. 전체 호출 그래프

OpenPI에서 inference를 한 번 호출하면 가장 중요한 call graph는 다음이다.

user / robot runtime
  ↓
WebsocketClientPolicy.infer(obs)
  ↓ msgpack over websocket
WebsocketPolicyServer._handler
  ↓
Policy.infer(obs)
  ↓
input transforms
  ├─ RepackTransform
  ├─ InjectDefaultPrompt
  ├─ Aloha/Droid/Libero Inputs
  ├─ Normalize
  └─ model transforms: Resize/Tokenize/Pad or FAST tokenization
  ↓
Observation.from_dict(inputs)
  ↓
model.sample_actions(...)
  ├─ Pi0.sample_actions       # π0/π0.5
  └─ Pi0FAST.sample_actions   # π0-FAST
  ↓
output transforms
  ├─ FAST token decode, if needed
  ├─ Unnormalize
  ├─ Aloha/Droid/Libero Outputs
  └─ optional repack output
  ↓
action dict

Training call graph는 다음이다.

scripts/train.py main(config)
  ↓
config.data.create(...)
  ↓
data_loader.create_data_loader(...)
  ↓
raw dataset sample
  ↓
repack → data transforms → Normalize → model transforms
  ↓
Observation.from_dict(batch), actions
  ↓
train_step
  ↓
model.compute_loss(...)
  ├─ Pi0.compute_loss       # flow matching MSE
  └─ Pi0FAST.compute_loss   # next-token CE
  ↓
gradient / optimizer / EMA / checkpoint

1. src/openpi/models/model.py

이 파일은 OpenPI의 공통 type contract다. 이 파일을 이해하지 못하면 policy/data/model 연결을 안정적으로 수정할 수 없다.

1.1 Imports

의미별로 보면 다음을 가져온다.

import 그룹 목적
abc, dataclasses, enum, typing abstract base class, config dataclass, enum type 정의
jax, jax.numpy, flax.nnx, flax.struct JAX/Flax NNX model definition
numpy, torch, safetensors numpy/torch/PyTorch checkpoint handling
orbax.checkpoint JAX checkpoint restore
image_tools, array_typing image resize/pad, typed array utility

성능 관점

model.py는 직접 heavy computation을 많이 하지 않지만, dtype/shape/image preprocessing이 여기서 결정된다. 특히 Observation.from_dictpreprocess_observation은 model-only latency 측정에서 포함/제외를 명확히 해야 한다.


1.2 ModelType

ModelType
  PI0
  PI0_FAST
  PI05

이 enum은 transform factory, config, model loading에서 branch key로 쓰인다.

수정 포인트

새 model family를 추가하려면 반드시 다음을 함께 수정해야 한다.

1. ModelType enum
2. BaseModelConfig subclass
3. ModelTransformFactory branch
4. policy_config / checkpoint config
5. training config registry

1.3 IMAGE_KEYS, IMAGE_RESOLUTION

OpenPI 기본 이미지 contract:

IMAGE_KEYS = (
  base_0_rgb,
  left_wrist_0_rgb,
  right_wrist_0_rgb,
)
IMAGE_RESOLUTION = (224, 224)

의미

  • model checkpoint는 이 camera slot 구조를 기대한다.
  • 실제 robot에 camera가 부족하면 zero image + mask로 slot을 채워야 한다.
  • preprocess_observation은 기본적으로 이 key들이 모두 있는지 검사한다.

자주 생기는 버그

ValueError: images dict missing keys

원인:

  • policy input transform에서 right_wrist_0_rgb를 만들지 않음
  • DROID/ALOHA/LIBERO key mapping이 틀림
  • image dict의 nesting이 잘못됨

1.4 Observation dataclass

Observation은 model에 들어가는 structured input이다.

필드별 의미:

field shape 개념 설명
images [*batch, H, W, C] RGB image, float32 [-1, 1]
image_masks [*batch] 해당 view가 유효하면 True
state [*batch, state_dim] proprioceptive state
tokenized_prompt [*batch, max_len] π0/π0.5/FAST prompt tokens
tokenized_prompt_mask [*batch, max_len] prompt token valid mask
token_ar_mask [*batch, max_len] FAST autoregressive mask
token_loss_mask [*batch, max_len] FAST CE loss mask

Observation.from_dict(data) 논리적 line-by-line

1단계: prompt token/mask 동시성 검사

if tokenized_prompt 존재 여부 != tokenized_prompt_mask 존재 여부:
    raise ValueError

이유:

  • token IDs만 있고 mask가 없으면 padding token을 구분할 수 없다.
  • mask만 있고 token IDs가 없으면 embedding할 수 없다.

2단계: image dtype 변환

for each image:
  if numpy uint8:
    image = image.astype(float32) / 255 * 2 - 1
  elif torch uint8:
    image = image.to(float32).permute(...) / 255 * 2 - 1

의미:

  • client는 bandwidth를 줄이려고 uint8 image를 보낼 수 있다.
  • model은 [-1, 1] float image를 기대한다.
  • PyTorch image는 channel order issue가 있으므로 permute가 필요할 수 있다.

3단계: dataclass 생성

return Observation(
  images=data["image"],
  image_masks=data["image_mask"],
  state=data["state"],
  tokenized_prompt=data.get(...),
  ...
)

수정/디버깅 팁

policy.infer 직전에 다음을 출력하면 schema mismatch를 빠르게 잡을 수 있다.

print(inputs.keys())
print(inputs["image"].keys())
print({k: v.shape for k, v in inputs["image"].items()})
print(inputs["state"].shape)

1.5 preprocess_observation

이 함수는 모델 forward 전에 observation을 정리한다.

논리적 line-by-line

1단계: image key 검사

if required image_keys not subset of observation.images:
    raise ValueError

이 검사가 초기에 실패하면, 문제는 model이 아니라 policy/data transform이다.

2단계: batch shape 계산

batch_shape = observation.state.shape[:-1]

state의 마지막 dimension은 state/action dimension이고, 그 앞은 batch dimension이다.

3단계: 각 image 처리

for key in image_keys:
    image = observation.images[key]
    if image resolution != target:
        resize_with_pad

resize_with_pad는 aspect ratio 보존 + padding일 가능성이 크다. robot camera image를 왜곡하지 않으려는 목적이다.

4단계: train augmentation

훈련 시:

image = image / 2 + 0.5  # [-1,1] -> [0,1]
if not wrist camera:
    RandomCrop
    Resize
    Rotate
ColorJitter
image = image * 2 - 1  # [0,1] -> [-1,1]

의미:

  • base/external camera에는 spatial augmentation을 더 적용한다.
  • wrist camera는 작은 시야/정밀 조작에 중요하므로 crop/rotate를 다르게 다룰 수 있다.

5단계: image mask 채우기

if key not in observation.image_masks:
    mask = all True
else:
    mask = provided mask

mask가 없으면 view를 유효하다고 가정한다.

성능 관점

  • train augmentation은 CPU/JAX image op 비용이 있다.
  • inference에서는 resize가 남아 있을 수 있다.
  • robot side에서 미리 224×224로 resize하면 server preprocessing 비용과 network traffic을 줄일 수 있다.

1.6 BaseModelConfig

공통 config fields:

action_dim
action_horizon
max_token_len

load(params)

논리:

1. create()의 eval_shape로 model 구조를 만든다.
2. model graphdef와 state를 분리한다.
3. checkpoint params와 현재 state tree를 intersect한다.
4. shape/dtype consistency를 검사한다.
5. state에 params를 replace한다.
6. graphdef + state를 merge해서 model 반환.

이 구조는 checkpoint compatibility를 잡는 데 중요하다.

load_pytorch(train_config, weight_path)

논리:

1. PI0Pytorch(config=train_config.model) 생성
2. safetensors.torch.load_model로 weight load
3. PyTorch model 반환

π0-FAST는 PyTorch 미지원이므로 이 경로는 π0/π0.5 중심이다.


1.7 restore_params

JAX/Orbax checkpoint restore 함수다.

논리:

1. params_path를 pathlib/gs path로 정리
2. sharding이 없으면 모든 device에 replicated sharding 생성
3. Orbax PyTreeCheckpointer metadata 읽기
4. params item restore
5. NNX State가 붙인 value suffix가 있으면 제거
6. pure dict 반환

자주 생기는 문제

  • checkpoint path를 checkpoint_dir로 줘야 하는지 checkpoint_dir/params로 줘야 하는지 헷갈림
  • CheckpointWeightLoader는 보통 .../params를 가리킨다.
  • create_trained_policycheckpoint_dir / "params"를 restore한다.

2. src/openpi/transforms.py

이 파일은 OpenPI의 데이터 변환 DSL이다.


2.1 DataTransformFn

모든 transform은 다음 함수를 구현한다.

__call__(data: dict) -> dict

입력 data는 unbatched sample일 수도 있고, 일부 loader path에서는 batched sample일 수도 있다. 대부분의 transform은 unbatched sample 기준으로 설계되어 있다.


2.2 Group

Group(inputs=..., outputs=...)

input transform과 output transform을 한 쌍으로 관리한다.

push

new_inputs  = old_inputs + inputs
new_outputs = outputs + old_outputs

왜 output을 앞에 붙이나?

input transform의 역변환은 output에서 반대 순서로 적용되어야 하기 때문이다.

예:

input transforms:  [A, B, C]
output transforms: [C^-1, B^-1, A^-1]

2.3 CompositeTransform / compose

compose는 transform sequence를 하나의 transform으로 묶는다.

for transform in transforms:
    data = transform(data)
return data

이 구조 때문에 transform 중 하나가 in-place로 data를 바꾸면 뒤 transform은 바뀐 dict를 받는다. 디버깅할 때는 transform 사이사이에 print wrapper를 넣으면 좋다.


2.4 RepackTransform

목적

raw dataset/environment dict를 새 nested dict로 재구성한다.

동작

1. input nested dict를 flatten한다.
2. self.structure에 적힌 old key path를 찾는다.
3. 새 nested dict 구조로 값을 복사한다.

raw:

observation.images.top
observation.state
action

repack structure:

images/cam_high <- observation.images.top
state           <- observation.state
actions         <- action

디버깅

KeyError가 나오면 대부분 old key path가 틀렸다. dataset sample의 flatten key를 먼저 출력해야 한다.


2.5 InjectDefaultPrompt

동작:

if self.prompt is not None and "prompt" not in data:
    data["prompt"] = np.asarray(self.prompt)

주의:

  • 이미 prompt가 있으면 덮어쓰지 않는다.
  • prompt가 없고 default_prompt도 없으면 tokenizer에서 error가 난다.

2.6 Normalize

동작:

if norm_stats is None:
    return data
else:
    apply_tree(data, norm_stats, normalization_fn)

z-score:

(x - mean) / (std + 1e-6)

quantile:

(x - q01) / (q99 - q01 + 1e-6) * 2 - 1

중요 디테일

stats.mean[..., : x.shape[-1]]처럼 현재 x dimension까지만 stats를 자른다. 이는 model action_dim이 실제 robot action_dim보다 클 때 padding/partial dimension을 다루기 위한 설계다.


2.7 Unnormalize

Normalize의 역변환이다.

z-score inverse:

x * (std + 1e-6) + mean

quantile inverse:

(x + 1) / 2 * (q99 - q01 + 1e-6) + q01

중요 디테일

Unnormalize는 strict mode로 norm_stats key가 output data에 없으면 error를 낼 수 있다. 이는 model output action을 robot command로 바꾸기 전에 stats가 제대로 적용되도록 보장한다.


2.8 ResizeImages

동작:

data["image"] = {k: resize_with_pad(v, height, width) for each image}

추론 시 client/server 어느 쪽에서 resize할지 결정해야 한다.

resize 위치 장점 단점
client network payload 감소 robot CPU 사용
server robot side 단순 network payload 증가

2.9 DeltaActions

동작:

selected dims:
  actions -= state

여기서 action shape는 보통:

[action_horizon, action_dim]

state shape는:

[action_dim] or state_dim

np.expand_dims로 state를 horizon axis에 broadcast한다.

현재 joint position이:

state[:6] = [1.0, 0.5, ...]

absolute target action이:

actions[0,:6] = [1.1, 0.7, ...]

DeltaActions 후:

actions[0,:6] = [0.1, 0.2, ...]

2.10 AbsoluteActions

DeltaActions의 역변환이다.

selected dims:
  actions += state

추론 output이 delta라면 실제 robot에 보내기 전에 absolute target으로 바꿔야 할 수 있다.


2.11 TokenizePrompt

π0/π0.5용.

동작:

1. data에서 prompt를 pop한다.
2. prompt가 없으면 ValueError.
3. discrete_state_input이면 state도 tokenizer에 넘긴다.
4. tokenizer.tokenize(prompt, state)
5. tokenized_prompt, tokenized_prompt_mask를 data에 추가한다.

2.12 TokenizeFASTInputs

π0-FAST용.

동작:

1. prompt를 pop한다.
2. state와 optional actions를 가져온다.
3. FASTTokenizer.tokenize(prompt, state, actions)
4. tokenized_prompt, tokenized_prompt_mask, token_ar_mask, token_loss_mask 추가

training에서는 actions가 있으므로 action target tokens가 포함된다. inference에서는 action이 없으므로 prefix만 tokenization된다.


2.13 ExtractFASTActions

π0-FAST model output은 처음에는 token IDs다.

동작:

1. data["actions"]를 token sequence로 해석한다.
2. tokenizer.extract_actions(tokens, action_horizon, action_dim)
3. continuous action chunk로 복원한다.

이 transform이 빠지면 robot에는 token ID sequence가 action처럼 전달되는 치명적 버그가 생긴다.


2.14 PadStatesAndActions

동작:

state를 model_action_dim까지 zero padding
actions가 있으면 actions도 model_action_dim까지 zero padding

모델 checkpoint는 action_dim=32 같은 큰 dimension을 기대할 수 있지만, 실제 robot은 7/8/14 dim만 쓸 수 있다. padding은 이 차이를 맞춘다.


2.15 make_bool_mask

예:

make_bool_mask(6, -1)
→ True True True True True True False

양수는 True 개수, 음수는 False 개수다. delta action mask를 만들 때 자주 쓴다.


3. src/openpi/policies/policy_config.py

이 파일은 checkpoint와 config를 실제 callable Policy로 바꾸는 bridge다.

3.1 create_trained_policy 인자

인자 의미
train_config model/data/training config
checkpoint_dir checkpoint root path 또는 gs path
repack_transforms 추론 환경에 추가로 적용할 repack
sample_kwargs sample_actions에 넘길 추가 인자, 예: num_steps
default_prompt prompt 누락 시 주입
norm_stats 직접 줄 norm stats. 없으면 checkpoint assets에서 load
pytorch_device PyTorch model device

3.2 논리적 line-by-line

1단계: repack transform default

repack_transforms = repack_transforms or Group()

추가 repack이 없으면 빈 transform group이다.

2단계: checkpoint download/cache

checkpoint_dir = download.maybe_download(str(checkpoint_dir))

gs://... path면 local cache로 내려받는다.

3단계: PyTorch checkpoint 감지

weight_path = checkpoint_dir / model.safetensors
is_pytorch = exists(weight_path)

4단계: model load

PyTorch:

model = train_config.model.load_pytorch(train_config, weight_path)
model.paligemma_with_expert.to_bfloat16_for_selected_params(...)

JAX:

params = restore_params(checkpoint_dir / params, dtype=bfloat16)
model = train_config.model.load(params)

5단계: data config 생성

data_config = train_config.data.create(train_config.assets_dirs, train_config.model)

이때 repack/data/model transforms와 norm stats 설정이 결정된다.

6단계: norm stats load

if norm_stats is None:
    norm_stats = load_norm_stats(checkpoint_dir / assets, data_config.asset_id)

중요: inference에서는 config assets가 아니라 checkpoint assets를 우선 사용한다. 이는 training 때 저장된 stats와 inference stats가 동일해야 하기 때문이다.

7단계: Policy 생성

input transforms:

repack_transforms.inputs
InjectDefaultPrompt(default_prompt)
data_config.data_transforms.inputs
Normalize(norm_stats)
data_config.model_transforms.inputs

output transforms:

data_config.model_transforms.outputs
Unnormalize(norm_stats)
data_config.data_transforms.outputs
repack_transforms.outputs

4. src/openpi/policies/policy.py

4.1 Policy.__init__

논리

1. model 저장
2. input transforms compose
3. output transforms compose
4. sample_kwargs 저장
5. metadata 저장
6. PyTorch model이면 device 이동 + eval mode
7. JAX model이면 sample_actions를 jit compile wrapper로 감쌈
8. JAX rng 초기화

성능 포인트

  • JAX는 첫 호출에 compile overhead가 크다.
  • PyTorch는 eval()이 중요하다. dropout/batchnorm류가 있으면 inference behavior가 달라질 수 있다.
  • sample_kwargsnum_steps, temperature, max_decoding_steps 같은 latency knob를 주입할 수 있다.

4.2 Policy.infer

1단계: input copy

inputs = tree.map(lambda x: x, obs)

얕은 copy에 가깝다. transform이 in-place로 바꿀 수 있으므로 원본 obs를 보호하려는 의도다.

2단계: input transform

inputs = self._input_transform(inputs)

여기서 raw robot observation이 model-ready dict로 바뀐다.

3단계: framework별 batch/device 처리

JAX:

leaf → jnp.asarray(leaf)[None, ...]
rng split

PyTorch:

leaf → torch.from_numpy(np.array(leaf)).to(device)[None, ...]

4단계: optional fixed noise

if noise is not None:
    if noise.ndim == 2:
        noise = noise[None, ...]
    sample_kwargs["noise"] = noise

이것은 profiling에 매우 중요하다. fixed noise를 쓰면 stochastic variance를 줄이고 model path를 deterministic하게 비교할 수 있다.

5단계: Observation 생성

observation = Observation.from_dict(inputs)

여기서 image uint8→float conversion이 발생할 수 있다.

6단계: model.sample_actions timing

start = time.monotonic()
actions = self._sample_actions(...)
model_time = time.monotonic() - start

주의: JAX는 asynchronous execution이 있을 수 있으므로 정확한 latency 측정에는 block_until_ready가 필요할 수 있다. 이 코드의 timing은 일반 server timing으로는 유용하지만, microbenchmark에서는 별도 sync 측정이 필요하다.

7단계: output 변환

outputs = {
  "state": inputs["state"],
  "actions": actions,
}

batch dimension 제거 후 numpy로 바꾼다.

outputs = self._output_transform(outputs)

여기서 unnormalize, FAST decode, robot-specific output conversion이 일어난다.


5. src/openpi/models/pi0.py

π0/π0.5의 핵심 구현이다.

5.1 파일 역할

이 파일은 다음을 구현한다.

make_attn_mask
posemb_sincos
Pi0 class
  __init__
  embed_prefix
  embed_suffix
  compute_loss
  sample_actions

5.2 make_attn_mask

입력

input_mask: bool[B, N]
mask_ar: bool/int[B or 1, N]

동작

1. mask_ar를 input_mask shape으로 broadcast
2. token axis cumulative sum 계산
3. query token i가 key token j를 볼 수 있는지 계산
4. padding token을 valid_mask로 제거

왜 필요한가?

π0는 image/prompt prefix와 action suffix를 하나의 transformer sequence에 넣는다. 그런데 모든 token이 모든 token을 보면 안 된다.

원하는 구조:

image/prompt prefix:
  서로 bidirectional attend 가능

action suffix:
  prefix attend 가능
  action block 내부 attend 가능 또는 제한된 block causal

mask_ar는 이 구조를 compact하게 표현한다.


5.3 posemb_sincos

입력

pos: timestep
embedding_dim: hidden dimension
min_period, max_period

동작

1. embedding_dim이 even인지 검사
2. log-spaced periods 생성
3. pos / period 계산
4. sin과 cos를 concat

의미

diffusion/flow model에서 timestep을 transformer가 이해할 수 있는 vector로 바꾼다.


5.4 Pi0.__init__

1단계: base init

super().__init__(action_dim, action_horizon, max_token_len)

모든 model은 action_dim/action_horizon/max_token_len을 공통으로 갖는다.

2단계: π0.5 flag 저장

self.pi05 = config.pi05

이 flag가 이후 action/time conditioning 방식을 바꾼다.

3단계: Gemma/PaliGemma config 준비

paligemma_config = get_config(config.paligemma_variant)
action_expert_config = get_config(config.action_expert_variant)

OpenPI는 VLM backbone과 action expert를 함께 구성한다.

4단계: LLM module 생성

llm = ToNNX(Gemma Module(...))
llm.lazy_init(...)

Flax NNX bridge를 사용한다. Gemma 구현을 NNX로 직접 다시 쓰지 않고 bridge하는 구조다.

π0.5이면 action expert 쪽에 AdaRMS를 켠다.

5단계: image encoder 생성

img = ToNNX(SigLIP Module(... variant="So400m/14" ...))
img.lazy_init(fake image)

SigLIP vision encoder가 image를 token embedding sequence로 바꾼다.

6단계: action projection 정의

action_in_proj  : action_dim → hidden_width
action_out_proj : hidden_width → action_dim

continuous action vector와 transformer hidden space를 연결한다.

7단계: π0.5 vs π0 분기

π0.5:

time_mlp_in
time_mlp_out

π0:

state_proj
action_time_mlp_in
action_time_mlp_out

의미:

  • π0는 state token을 별도로 suffix에 넣고 action+time을 MLP로 섞는다.
  • π0.5는 AdaRMS condition으로 time 정보를 넣는다.

5.5 embed_prefix

목적

변하지 않는 조건부 context를 embedding한다.

단계별 설명

1단계: 빈 list 준비

tokens = []
input_mask = []
ar_mask = []

2단계: image별 embedding

for name in obs.images:
    image_tokens = SigLIP(image)
    tokens.append(image_tokens)
    input_mask.append(repeat image mask across image token length)
    ar_mask.append(False block)

image token은 서로 attend 가능해야 하므로 ar_mask는 false block이다.

3단계: prompt token embedding

if tokenized_prompt exists:
    prompt_embeddings = LLM.embed(token IDs)
    tokens.append(prompt_embeddings)
    input_mask.append(prompt_mask)
    ar_mask.append(False block)

prompt도 prefix이므로 image와 함께 full attention block으로 둔다.

4단계: concat

return concat(tokens), concat(input_mask), concat(ar_mask)

5.6 embed_suffix

목적

현재 denoising step에서 변하는 state/action/timestep token을 만든다.

π0 path

1단계: state token

state_embedding = state_proj(obs.state)

state를 transformer hidden으로 project한다.

2단계: action token

action_embedding = action_in_proj(noisy_actions)

noisy action chunk [B,H,A]가 hidden token [B,H,D]가 된다.

3단계: time embedding

time_embedding = posemb_sincos(timestep)
time_embedding = repeat across action_horizon

4단계: action+time fusion

concat(action_embedding, time_embedding)
→ action_time_mlp_in
→ activation
→ action_time_mlp_out

5단계: masks

state/action suffix가 prefix 뒤에 붙는다. state와 action block의 attention relation을 ar_mask로 정의한다.

π0.5 path

1단계: action token

action_embedding = action_in_proj(noisy_actions)

2단계: time condition

time_embedding = posemb_sincos(timestep)
adarms_cond = time_mlp(time_embedding)

3단계: LLM/action expert forward에서 AdaRMS condition으로 사용

π0.5는 token에 time을 직접 concat하기보다 normalization modulation 쪽으로 넣는 구조다.


5.7 compute_loss

전체 목적

clean action chunk actions와 random noise 사이를 잇는 vector field를 학습한다.

단계별 설명

1단계: rng split

rng_preprocess, rng_noise, rng_time = split(rng)

각 random source를 분리한다.

2단계: observation preprocessing

observation = preprocess_observation(rng_preprocess, observation, train=train)

train이면 augmentation이 적용된다.

3단계: noise sampling

noise = normal(shape=actions.shape)

4단계: timestep sampling

time = Beta(1.5, 1)
time = scaled to [0.001, 1]

5단계: interpolation

x_t = time * noise + (1 - time) * actions

t=1은 noise, t=0은 action이다.

6단계: target vector field

u_t = noise - actions

이는 x_t를 time에 대해 미분한 값과 연결된다.

7단계: prefix/suffix embedding

prefix = embed_prefix(observation)
suffix = embed_suffix(observation, x_t, time)

8단계: attention mask 생성

input_mask = concat(prefix_mask, suffix_mask)
ar_mask = concat(prefix_ar, suffix_ar)
attn_mask = make_attn_mask(input_mask, ar_mask)

9단계: LLM/action expert forward

outputs = PaliGemma.llm(embedded tokens, mask, ...)

10단계: action vector prediction

v_t = action_out_proj(suffix_outputs)

11단계: MSE loss

loss = mean((v_t - u_t)^2 over action_dim)

5.8 sample_actions

전체 목적

noise에서 시작해 learned vector field를 따라 action chunk를 생성한다.

단계별 설명

1단계: observation preprocessing

observation = preprocess_observation(None, observation, train=False)

2단계: initial noise

if noise is None:
    x_t = normal([B, action_horizon, action_dim])
else:
    x_t = noise

fixed noise를 넣으면 deterministic profiling 가능.

3단계: prefix prefill

prefix_tokens, prefix_mask, prefix_ar = embed_prefix(observation)
prefix_attn_mask = make_attn_mask(...)
PaliGemma.llm(..., decode=True) → kv_cache

prefix는 denoising loop 동안 변하지 않으므로 cache한다.

4단계: integration loop

dt = -1 / num_steps
time = 1.0
while time >= 0:
    suffix = embed_suffix(observation, x_t, time)
    suffix forward with prefix kv_cache
    v_t = action_out_proj(...)
    x_t = x_t + dt * v_t
    time = time + dt

5단계: return action chunk

return x_t

성능 핵심

num_steps번 suffix forward가 반복된다. 따라서 optimization 목표는:

1. prefix prefill을 최대한 재사용
2. suffix step latency 감소
3. denoising step 수 감소/distillation
4. action_horizon과 replanning 주기 최적화

6. src/openpi/models/pi0_fast.py

π0-FAST는 action을 token으로 생성한다.

6.1 Pi0FASTConfig

주요 field:

dtype
action_dim
action_horizon
max_token_len
fast_model_tokenizer
fast_model_tokenizer_kwargs

model_typePI0_FAST를 반환한다.


6.2 inputs_spec

π0-FAST observation에는 다음이 필요하다.

images
image_masks
state
tokenized_prompt
tokenized_prompt_mask
token_ar_mask
token_loss_mask

π0와 다르게 token loss mask가 필수다. next-token prediction loss를 어디에 걸지 알아야 하기 때문이다.


6.3 left_to_right_align

목적

variable-length prefix를 right-aligned 형태로 정렬한다.

왜 필요한가?

  • autoregressive decoding에서 KV cache 위치/position을 맞추기 위해서다.
  • batch마다 prompt/token length가 다르면 decode position 계산이 꼬일 수 있다.

6.4 put_along_last_axis

JAX에는 numpy의 put_along_axis(axis=-1)에 해당하는 기능이 부족할 수 있어서 one-hot/einsum으로 구현한다.

목적:

output_tokens[batch, step] = generated_token

6.5 Pi0FAST.__init__

π0와 비슷하지만 action flow head가 없다.

1. PaliGemma/Gemma fast module 생성
2. SigLIP image encoder 생성
3. self.PaliGemma = {llm, img}

continuous action projection 대신 token logits head는 LLM vocabulary head가 담당한다.


6.6 embed_inputs

단계

1. image tokens 생성
2. image masks 생성
3. image ar_mask는 false block
4. tokenized prompt/action input embedding 생성
5. token mask와 token_ar_mask 추가
6. concat해서 반환

여기서 tokenized prompt는 사실 prompt만이 아니라 FAST tokenizer가 구성한 state/action token sequence까지 포함할 수 있다.


6.7 compute_loss

단계별 설명

1단계: observation preprocess

image resize/augmentation 등.

2단계: input embedding

input_token_embeddings, input_mask, ar_mask = embed_inputs(observation)

3단계: attention mask

attn_mask = make_attn_mask(input_mask, ar_mask)

4단계: next-token targets

targets = one_hot(tokenized_prompt[:, 1:])

각 input token은 다음 token을 예측한다.

5단계: last input 제외

forward input = embedded sequence[:, :-1]

마지막 token은 다음 token target이 없으므로 input에서 제외된다.

6단계: prelogits 계산

LLM forward로 hidden/prelogits를 얻는다.

7단계: target token 구간만 logits decode

vocab projection은 비용이 크므로 target 구간에 대해서만 logits를 계산한다.

8단계: CE loss

logp = log_softmax(logits)
loss_mask = token_loss_mask[:, 1:]
loss = -sum(targets * logp * loss_mask) / sum(loss_mask)

6.8 sample_actions

단계별 설명

1단계: preprocess

observation = preprocess_observation(...)

2단계: prefix embedding

prefix_token_embeddings, prefix_mask, prefix_ar_mask = embed_inputs(observation)

추론 시 이 token sequence는 prompt/state prefix다.

3단계: left-to-right align

KV cache를 위한 정렬.

4단계: prefix prefill

prefix_logits, kv_cache = llm(... decode=True)
last_logit = prefix_logits[:, -1:]

last logit이 첫 generated token distribution이다.

5단계: decoding loop

while not all_eos and step < max_decoding_steps:
    if temperature > 0:
        token = categorical(last_logit / temperature)
    else:
        token = argmax(last_logit)
    output_tokens[step] = token
    token_embedding = llm.embed(token)
    last_logit, kv_cache = llm(one token, kv_cache=cache)

6단계: return output tokens

Pi0FAST.sample_actions 자체는 generated tokens를 반환한다. 최종 action chunk 복원은 output transform ExtractFASTActions가 담당한다.


7. scripts/train.py

7.1 init_logging

로그 format을 짧게 바꾼다. 실험 로그에서 step/loss를 보기 쉽게 한다.


7.2 init_wandb

논리:

if wandb disabled:
    wandb.init(mode="disabled")
else if resume:
    checkpoint_dir/wandb_id.txt에서 run id 읽고 resume
else:
    새 run 생성하고 wandb_id.txt 저장

실험 reproducibility를 위해 checkpoint directory와 wandb run id를 묶는다.


7.3 _load_weights_and_validate

1. loader.load(params_shape)
2. expected/got tree shape/dtype 검사
3. 실제 loaded param만 반환

checkpoint compatibility 문제를 early detect한다.


7.4 init_train_state

단계

1. optimizer 생성
2. init(rng, partial_params) 함수 정의
3. model 생성
4. partial checkpoint params가 있으면 state에 merge
5. params 추출
6. frozen params를 bf16으로 변환
7. TrainState 생성
8. eval_shape로 train_state_shape 계산
9. FSDP sharding 계산
10. resume이면 shape/sharding만 반환
11. weight_loader로 partial params 로드
12. jitted init으로 actual train_state 생성

중요한 설계

partial checkpoint loading을 지원하므로 PaliGemma base만 로드하거나, LoRA missing weight를 현재 init weight로 채울 수 있다.


7.5 train_step

단계

1. model = merge(model_def, params)
2. model.train()
3. loss_fn 정의
4. train_rng = fold_in(rng, state.step)
5. diff_state = trainable_filter
6. loss, grads = value_and_grad(loss_fn)
7. trainable params만 optimizer update
8. model에 new params update
9. full state 추출
10. EMA update
11. loss/grad_norm/param_norm 반환

성능/안정성 포인트

  • fold_in을 사용해 step별 rng를 안정적으로 만든다.
  • frozen params는 gradient 계산 대상에서 제외된다.
  • EMA가 있으면 inference checkpoint에 EMA params가 쓰일 수 있다.

7.6 main

단계

1. logging init
2. batch_size % device_count 검사
3. JAX compilation cache 설정
4. rng split
5. mesh/sharding 생성
6. checkpoint manager 생성
7. wandb init
8. data_loader 생성
9. 첫 batch shape/info logging
10. 첫 batch camera image wandb logging
11. train_state init/restore
12. ptrain_step = jax.jit(...)
13. loop:
      - train step
      - log interval마다 metrics reduce/log
      - next batch
      - save interval마다 checkpoint
14. checkpoint async save 완료 대기

training debug 추천

처음부터 full training하지 말고 다음 순서로 한다.

# 1. norm stats 확인
uv run scripts/compute_norm_stats.py --config-name <config>

# 2. 매우 짧은 debug config 또는 num_train_steps 수정
uv run scripts/train.py <config> --exp-name=debug --overwrite

# 3. 첫 batch image wandb logging 확인
# 4. loss가 finite인지 확인
# 5. checkpoint 생성 확인

8. scripts/serve_policy.py

8.1 EnvMode

지원 환경:

ALOHA
ALOHA_SIM
DROID
LIBERO

8.2 Checkpoint dataclass

config: training config name
dir: checkpoint directory

예:

config = pi05_droid
dir = gs://openpi-assets/checkpoints/pi05_droid

8.3 DEFAULT_CHECKPOINT

env별 기본 policy를 mapping한다.

DROID  → pi05_droid
LIBERO → pi05_libero
...

팀 실험에서는 default를 쓰기보다 명시적으로 policy:checkpoint를 지정하는 것이 reproducibility에 좋다.


8.4 create_policy

if args.policy is Checkpoint:
    create_trained_policy(get_config(config), dir)
else:
    create_default_policy(env)

8.5 main

1. policy 생성
2. metadata 저장
3. record 옵션이면 PolicyRecorder wrapper 적용
4. host/ip logging
5. WebsocketPolicyServer 생성
6. serve_forever

--record는 디버깅에 유용하다. policy input/output을 policy_records에 저장해 replay/분석할 수 있다.


9. src/openpi/serving/websocket_policy_server.py

9.1 WebsocketPolicyServer.__init__

policy, host, port, metadata를 저장한다.


9.2 serve_forever / run

async websocket server를 시작한다.

설정:

compression=None
max_size=None
process_request=_health_check

압축을 끄는 것은 latency와 CPU overhead 측면에서 합리적일 수 있다. image payload가 크면 압축이 네트워크를 줄일 수 있지만 CPU latency를 늘릴 수 있다. robot real-time에서는 압축 off가 더 안정적인 경우가 많다.


9.3 _handler

단계

1. connection open logging
2. msgpack_numpy.Packer 생성
3. metadata를 client에 최초 전송
4. loop:
    a. obs 수신
    b. unpack
    c. policy.infer(obs)
    d. server_timing 추가
    e. pack해서 client에 전송
    f. total time 기록
5. connection closed 처리
6. exception이면 traceback 전송 후 close

timing 해석

  • server_timing["infer_ms"]: policy.infer에 걸린 시간
  • prev_total_ms: 이전 request의 수신→추론→전송 포함 total

10. packages/openpi-client/.../websocket_client_policy.py

10.1 __init__

1. host가 ws로 시작하면 그대로 사용
2. 아니면 ws://host 구성
3. port가 있으면 :port 추가
4. Packer 생성
5. server connect retry
6. metadata 수신

10.2 _wait_for_server

server가 아직 안 떠 있으면 5초 간격으로 retry한다. robot runtime에서 server를 먼저 띄우지 않아도 client가 기다릴 수 있다.


10.3 infer

1. obs를 msgpack pack
2. websocket send
3. response recv
4. response가 string이면 server error로 간주
5. bytes면 unpack해서 action dict 반환

성능 포인트

msgpack serialization/deserialization은 e2e latency에 포함된다. 큰 image를 raw로 보내면 이 비용이 커진다. client-side resize/uint8 conversion이 중요한 이유다.


11. src/openpi/training/data_loader.py

11.1 핵심 역할

이 파일은 raw dataset을 training batch로 만든다.

LeRobot/RLDS dataset
→ repack/data/model transforms
→ Normalize
→ Observation + Actions batch
→ JAX sharding 또는 PyTorch tensor

11.2 TransformedDataset

random access dataset wrapper.

__getitem__(index):
    return transform(dataset[index])

11.3 IterableTransformedDataset

RLDS처럼 iterable dataset에 transform을 적용한다. batched sample이면 batch를 sample 단위로 쪼개 transform한 뒤 다시 stack한다.

이것은 transform들이 대부분 unbatched sample 기준으로 작성되어 있기 때문이다.


11.4 FakeDataset

model input spec을 기반으로 random fake sample을 만든다. config/debug/test에서 유용하다.


11.5 create_torch_dataset

if repo_id == fake:
    FakeDataset
else:
    LeRobotDatasetMetadata(repo_id)
    LeRobotDataset(repo_id, delta_timestamps=...)
    if prompt_from_task:
        PromptFromLeRobotTask 적용

delta_timestamps는 action horizon만큼 future action sequence를 가져오기 위한 핵심이다.


11.6 transform_dataset

raw dataset sample
→ repack_transforms.inputs
→ data_transforms.inputs
→ Normalize
→ model_transforms.inputs

norm_stats가 없으면 error를 낸다. 단 skip_norm_stats=True면 건너뛸 수 있다.


11.7 create_data_loader

if data_config.rlds_data_dir is not None:
    create_rlds_data_loader
else:
    create_torch_data_loader

DROID full dataset path만 RLDS를 쓴다고 보면 된다.


11.8 TorchDataLoader

PyTorch DataLoader를 감싼 wrapper다.

중요 옵션:

옵션 의미
local_batch_size process/device별 batch
sharding JAX sharded array 생성용
shuffle dataset shuffle
num_workers CPU worker process 수
framework jax or pytorch

JAX framework이면 batch를 jax.make_array_from_process_local_data로 sharded array로 만든다. PyTorch framework이면 torch tensor로 반환한다.


12. src/openpi/training/checkpoints.py

12.1 initialize_checkpoint_dir

if checkpoint_dir exists:
    if overwrite: delete and recreate
    elif resume: resuming=True
    else: raise FileExistsError
CheckpointManager 생성
empty checkpoint인데 resume이면 resume 취소

12.2 save_state

1. save_assets callback 정의
   - norm_stats를 assets/asset_id에 저장
2. _split_params(state)
   - EMA params가 있으면 inference params로 사용
3. items = assets + train_state + params
4. checkpoint_manager.save(step, items)

12.3 _split_params

if ema_params exists:
    params = ema_params
    train_state.ema_params = None
else:
    params = state.params
    train_state.params = {}

이 설계 때문에 inference checkpoint의 params는 EMA를 사용할 수 있다.


13. src/openpi/training/weight_loaders.py

13.1 NoOpWeightLoader

아무 weight도 로드하지 않는다. debug/fake training에 유용하다.

13.2 CheckpointWeightLoader

1. params_path에서 checkpoint restore
2. missing LoRA weights는 현재 initialized params에서 merge
3. merged params 반환

LoRA fine-tuning에서 base checkpoint에는 LoRA parameter가 없으므로 missing LoRA weight를 init된 값으로 채운다.

13.3 PaliGemmaWeightLoader

official PaliGemma checkpoint를 로드하고, OpenPI action expert 등 extra weight는 유지한다.

이것은 π0 base pretraining 전 초기화나 특정 실험에서 중요하다.


14. 수정할 때 가장 먼저 봐야 하는 파일 조합

14.1 새 robot 붙이기

src/openpi/policies/<new_robot>_policy.py
src/openpi/training/config.py
src/openpi/transforms.py
examples/ur5/README.md

14.2 inference latency 줄이기

src/openpi/models/pi0.py
src/openpi/models/pi0_fast.py
src/openpi/policies/policy.py
src/openpi/serving/websocket_policy_server.py
packages/openpi-client/src/openpi_client/websocket_client_policy.py

14.3 PyTorch path 최적화

src/openpi/models_pytorch/pi0_pytorch.py
src/openpi/models_pytorch/gemma_pytorch.py
src/openpi/models_pytorch/preprocessing_pytorch.py
src/openpi/models_pytorch/transformers_replace/*

14.4 training 안정화

src/openpi/training/config.py
src/openpi/training/data_loader.py
src/openpi/training/optimizer.py
src/openpi/training/checkpoints.py
scripts/compute_norm_stats.py
scripts/train.py

15. 핵심 invariants

OpenPI를 수정할 때 반드시 유지해야 하는 invariants다.

15.1 shape invariants

image:   [B, 224, 224, 3]
state:   [B, action_dim] or padded to action_dim
actions: [B, action_horizon, action_dim]
prompt:  [B, max_token_len]

15.2 value range invariants

image float: [-1, 1]
uint8 image: [0, 255] before Observation.from_dict
normalized state/action: config norm mode에 따름
unnormalized action: robot command convention에 맞아야 함

15.3 transform order invariants

input:
  repack → robot input transform → normalize → model transform

output:
  model output transform → unnormalize → robot output transform → repack output

15.4 checkpoint invariants

JAX checkpoint:
  checkpoint_dir/params
  checkpoint_dir/assets/<asset_id>

PyTorch checkpoint:
  checkpoint_dir/model.safetensors
  checkpoint_dir/assets/<asset_id>

15.5 policy invariants

policy.infer(obs) returns dict with:
  actions
  policy_timing
optional:
  state
  server_timing, if through websocket server

16. 디버깅 print template

아래 template은 transform chain debugging에 유용하다.

def describe_tree(x, prefix=""):
    if isinstance(x, dict):
        for k, v in x.items():
            describe_tree(v, prefix + "/" + str(k))
    else:
        shape = getattr(x, "shape", None)
        dtype = getattr(x, "dtype", None)
        print(prefix, shape, dtype)

policy 내부에서 단계별로:

print("raw obs")
describe_tree(obs)

inputs = self._input_transform(obs)
print("after input transforms")
describe_tree(inputs)

observation = Observation.from_dict(inputs)
print("Observation images", {k: v.shape for k, v in observation.images.items()})
print("Observation state", observation.state.shape)

주의: 실제 robot loop에서는 print가 latency를 크게 늘리므로 debug run에서만 사용한다.


17. 최종 mental model

OpenPI 코드의 본질은 다음 세 문장으로 압축된다.

  1. model.py는 모델이 기대하는 표준 tensor contract를 정의한다.
  2. transforms.py + policies/* + training/config.py는 각 robot/dataset을 그 contract에 맞춘다.
  3. pi0.py / pi0_fast.py는 contract에 맞게 들어온 observation을 action chunk로 생성한다.

따라서 새 연구를 시작할 때 코드를 읽는 순서는 다음이 가장 효율적이다.

model.py
→ transforms.py
→ policy_config.py
→ policy.py
→ target policy file: droid/aloha/libero
→ training/config.py
→ pi0.py or pi0_fast.py
→ train.py / serve_policy.py

이 순서를 따르면 “모델부터 읽다가 데이터 schema에서 막히는 문제”를 피할 수 있다.