OpenPI 핵심 코드 Deep Dive
대상: Physical-Intelligence/openpi 핵심 코드 경로 목표: 실제 수정/디버깅/성능 최적화를 할 때 반드시 이해해야 하는 파일들을 “논리적 line-by-line” 방식으로 설명한다. 주의: 아래 설명은 원본 코드를 그대로 대량 복제하지 않고, 함수/클래스/블록 단위로 실제 동작을 풀어쓴 것이다.
0. 전체 호출 그래프
OpenPI에서 inference를 한 번 호출하면 가장 중요한 call graph는 다음이다.
user / robot runtime
↓
WebsocketClientPolicy.infer(obs)
↓ msgpack over websocket
WebsocketPolicyServer._handler
↓
Policy.infer(obs)
↓
input transforms
├─ RepackTransform
├─ InjectDefaultPrompt
├─ Aloha/Droid/Libero Inputs
├─ Normalize
└─ model transforms: Resize/Tokenize/Pad or FAST tokenization
↓
Observation.from_dict(inputs)
↓
model.sample_actions(...)
├─ Pi0.sample_actions # π0/π0.5
└─ Pi0FAST.sample_actions # π0-FAST
↓
output transforms
├─ FAST token decode, if needed
├─ Unnormalize
├─ Aloha/Droid/Libero Outputs
└─ optional repack output
↓
action dict
Training call graph는 다음이다.
scripts/train.py main(config)
↓
config.data.create(...)
↓
data_loader.create_data_loader(...)
↓
raw dataset sample
↓
repack → data transforms → Normalize → model transforms
↓
Observation.from_dict(batch), actions
↓
train_step
↓
model.compute_loss(...)
├─ Pi0.compute_loss # flow matching MSE
└─ Pi0FAST.compute_loss # next-token CE
↓
gradient / optimizer / EMA / checkpoint
1. src/openpi/models/model.py
이 파일은 OpenPI의 공통 type contract다. 이 파일을 이해하지 못하면 policy/data/model 연결을 안정적으로 수정할 수 없다.
1.1 Imports
의미별로 보면 다음을 가져온다.
| import 그룹 | 목적 |
|---|---|
abc, dataclasses, enum, typing | abstract base class, config dataclass, enum type 정의 |
jax, jax.numpy, flax.nnx, flax.struct | JAX/Flax NNX model definition |
numpy, torch, safetensors | numpy/torch/PyTorch checkpoint handling |
orbax.checkpoint | JAX checkpoint restore |
image_tools, array_typing | image resize/pad, typed array utility |
성능 관점
model.py는 직접 heavy computation을 많이 하지 않지만, dtype/shape/image preprocessing이 여기서 결정된다. 특히 Observation.from_dict와 preprocess_observation은 model-only latency 측정에서 포함/제외를 명확히 해야 한다.
1.2 ModelType
ModelType
PI0
PI0_FAST
PI05
이 enum은 transform factory, config, model loading에서 branch key로 쓰인다.
수정 포인트
새 model family를 추가하려면 반드시 다음을 함께 수정해야 한다.
1. ModelType enum
2. BaseModelConfig subclass
3. ModelTransformFactory branch
4. policy_config / checkpoint config
5. training config registry
1.3 IMAGE_KEYS, IMAGE_RESOLUTION
OpenPI 기본 이미지 contract:
IMAGE_KEYS = (
base_0_rgb,
left_wrist_0_rgb,
right_wrist_0_rgb,
)
IMAGE_RESOLUTION = (224, 224)
의미
- model checkpoint는 이 camera slot 구조를 기대한다.
- 실제 robot에 camera가 부족하면 zero image + mask로 slot을 채워야 한다.
preprocess_observation은 기본적으로 이 key들이 모두 있는지 검사한다.
자주 생기는 버그
ValueError: images dict missing keys
원인:
- policy input transform에서
right_wrist_0_rgb를 만들지 않음 - DROID/ALOHA/LIBERO key mapping이 틀림
- image dict의 nesting이 잘못됨
1.4 Observation dataclass
Observation은 model에 들어가는 structured input이다.
필드별 의미:
| field | shape 개념 | 설명 |
|---|---|---|
images | [*batch, H, W, C] | RGB image, float32 [-1, 1] |
image_masks | [*batch] | 해당 view가 유효하면 True |
state | [*batch, state_dim] | proprioceptive state |
tokenized_prompt | [*batch, max_len] | π0/π0.5/FAST prompt tokens |
tokenized_prompt_mask | [*batch, max_len] | prompt token valid mask |
token_ar_mask | [*batch, max_len] | FAST autoregressive mask |
token_loss_mask | [*batch, max_len] | FAST CE loss mask |
Observation.from_dict(data) 논리적 line-by-line
1단계: prompt token/mask 동시성 검사
if tokenized_prompt 존재 여부 != tokenized_prompt_mask 존재 여부:
raise ValueError
이유:
- token IDs만 있고 mask가 없으면 padding token을 구분할 수 없다.
- mask만 있고 token IDs가 없으면 embedding할 수 없다.
2단계: image dtype 변환
for each image:
if numpy uint8:
image = image.astype(float32) / 255 * 2 - 1
elif torch uint8:
image = image.to(float32).permute(...) / 255 * 2 - 1
의미:
- client는 bandwidth를 줄이려고 uint8 image를 보낼 수 있다.
- model은
[-1, 1]float image를 기대한다. - PyTorch image는 channel order issue가 있으므로 permute가 필요할 수 있다.
3단계: dataclass 생성
return Observation(
images=data["image"],
image_masks=data["image_mask"],
state=data["state"],
tokenized_prompt=data.get(...),
...
)
수정/디버깅 팁
policy.infer 직전에 다음을 출력하면 schema mismatch를 빠르게 잡을 수 있다.
print(inputs.keys())
print(inputs["image"].keys())
print({k: v.shape for k, v in inputs["image"].items()})
print(inputs["state"].shape)
1.5 preprocess_observation
이 함수는 모델 forward 전에 observation을 정리한다.
논리적 line-by-line
1단계: image key 검사
if required image_keys not subset of observation.images:
raise ValueError
이 검사가 초기에 실패하면, 문제는 model이 아니라 policy/data transform이다.
2단계: batch shape 계산
batch_shape = observation.state.shape[:-1]
state의 마지막 dimension은 state/action dimension이고, 그 앞은 batch dimension이다.
3단계: 각 image 처리
for key in image_keys:
image = observation.images[key]
if image resolution != target:
resize_with_pad
resize_with_pad는 aspect ratio 보존 + padding일 가능성이 크다. robot camera image를 왜곡하지 않으려는 목적이다.
4단계: train augmentation
훈련 시:
image = image / 2 + 0.5 # [-1,1] -> [0,1]
if not wrist camera:
RandomCrop
Resize
Rotate
ColorJitter
image = image * 2 - 1 # [0,1] -> [-1,1]
의미:
- base/external camera에는 spatial augmentation을 더 적용한다.
- wrist camera는 작은 시야/정밀 조작에 중요하므로 crop/rotate를 다르게 다룰 수 있다.
5단계: image mask 채우기
if key not in observation.image_masks:
mask = all True
else:
mask = provided mask
mask가 없으면 view를 유효하다고 가정한다.
성능 관점
- train augmentation은 CPU/JAX image op 비용이 있다.
- inference에서는 resize가 남아 있을 수 있다.
- robot side에서 미리 224×224로 resize하면 server preprocessing 비용과 network traffic을 줄일 수 있다.
1.6 BaseModelConfig
공통 config fields:
action_dim
action_horizon
max_token_len
load(params)
논리:
1. create()의 eval_shape로 model 구조를 만든다.
2. model graphdef와 state를 분리한다.
3. checkpoint params와 현재 state tree를 intersect한다.
4. shape/dtype consistency를 검사한다.
5. state에 params를 replace한다.
6. graphdef + state를 merge해서 model 반환.
이 구조는 checkpoint compatibility를 잡는 데 중요하다.
load_pytorch(train_config, weight_path)
논리:
1. PI0Pytorch(config=train_config.model) 생성
2. safetensors.torch.load_model로 weight load
3. PyTorch model 반환
π0-FAST는 PyTorch 미지원이므로 이 경로는 π0/π0.5 중심이다.
1.7 restore_params
JAX/Orbax checkpoint restore 함수다.
논리:
1. params_path를 pathlib/gs path로 정리
2. sharding이 없으면 모든 device에 replicated sharding 생성
3. Orbax PyTreeCheckpointer metadata 읽기
4. params item restore
5. NNX State가 붙인 value suffix가 있으면 제거
6. pure dict 반환
자주 생기는 문제
- checkpoint path를
checkpoint_dir로 줘야 하는지checkpoint_dir/params로 줘야 하는지 헷갈림 CheckpointWeightLoader는 보통.../params를 가리킨다.create_trained_policy는checkpoint_dir / "params"를 restore한다.
2. src/openpi/transforms.py
이 파일은 OpenPI의 데이터 변환 DSL이다.
2.1 DataTransformFn
모든 transform은 다음 함수를 구현한다.
__call__(data: dict) -> dict
입력 data는 unbatched sample일 수도 있고, 일부 loader path에서는 batched sample일 수도 있다. 대부분의 transform은 unbatched sample 기준으로 설계되어 있다.
2.2 Group
Group(inputs=..., outputs=...)
input transform과 output transform을 한 쌍으로 관리한다.
push
new_inputs = old_inputs + inputs
new_outputs = outputs + old_outputs
왜 output을 앞에 붙이나?
input transform의 역변환은 output에서 반대 순서로 적용되어야 하기 때문이다.
예:
input transforms: [A, B, C]
output transforms: [C^-1, B^-1, A^-1]
2.3 CompositeTransform / compose
compose는 transform sequence를 하나의 transform으로 묶는다.
for transform in transforms:
data = transform(data)
return data
이 구조 때문에 transform 중 하나가 in-place로 data를 바꾸면 뒤 transform은 바뀐 dict를 받는다. 디버깅할 때는 transform 사이사이에 print wrapper를 넣으면 좋다.
2.4 RepackTransform
목적
raw dataset/environment dict를 새 nested dict로 재구성한다.
동작
1. input nested dict를 flatten한다.
2. self.structure에 적힌 old key path를 찾는다.
3. 새 nested dict 구조로 값을 복사한다.
예
raw:
observation.images.top
observation.state
action
repack structure:
images/cam_high <- observation.images.top
state <- observation.state
actions <- action
디버깅
KeyError가 나오면 대부분 old key path가 틀렸다. dataset sample의 flatten key를 먼저 출력해야 한다.
2.5 InjectDefaultPrompt
동작:
if self.prompt is not None and "prompt" not in data:
data["prompt"] = np.asarray(self.prompt)
주의:
- 이미 prompt가 있으면 덮어쓰지 않는다.
- prompt가 없고 default_prompt도 없으면 tokenizer에서 error가 난다.
2.6 Normalize
동작:
if norm_stats is None:
return data
else:
apply_tree(data, norm_stats, normalization_fn)
z-score:
(x - mean) / (std + 1e-6)
quantile:
(x - q01) / (q99 - q01 + 1e-6) * 2 - 1
중요 디테일
stats.mean[..., : x.shape[-1]]처럼 현재 x dimension까지만 stats를 자른다. 이는 model action_dim이 실제 robot action_dim보다 클 때 padding/partial dimension을 다루기 위한 설계다.
2.7 Unnormalize
Normalize의 역변환이다.
z-score inverse:
x * (std + 1e-6) + mean
quantile inverse:
(x + 1) / 2 * (q99 - q01 + 1e-6) + q01
중요 디테일
Unnormalize는 strict mode로 norm_stats key가 output data에 없으면 error를 낼 수 있다. 이는 model output action을 robot command로 바꾸기 전에 stats가 제대로 적용되도록 보장한다.
2.8 ResizeImages
동작:
data["image"] = {k: resize_with_pad(v, height, width) for each image}
추론 시 client/server 어느 쪽에서 resize할지 결정해야 한다.
| resize 위치 | 장점 | 단점 |
|---|---|---|
| client | network payload 감소 | robot CPU 사용 |
| server | robot side 단순 | network payload 증가 |
2.9 DeltaActions
동작:
selected dims:
actions -= state
여기서 action shape는 보통:
[action_horizon, action_dim]
state shape는:
[action_dim] or state_dim
np.expand_dims로 state를 horizon axis에 broadcast한다.
예
현재 joint position이:
state[:6] = [1.0, 0.5, ...]
absolute target action이:
actions[0,:6] = [1.1, 0.7, ...]
DeltaActions 후:
actions[0,:6] = [0.1, 0.2, ...]
2.10 AbsoluteActions
DeltaActions의 역변환이다.
selected dims:
actions += state
추론 output이 delta라면 실제 robot에 보내기 전에 absolute target으로 바꿔야 할 수 있다.
2.11 TokenizePrompt
π0/π0.5용.
동작:
1. data에서 prompt를 pop한다.
2. prompt가 없으면 ValueError.
3. discrete_state_input이면 state도 tokenizer에 넘긴다.
4. tokenizer.tokenize(prompt, state)
5. tokenized_prompt, tokenized_prompt_mask를 data에 추가한다.
2.12 TokenizeFASTInputs
π0-FAST용.
동작:
1. prompt를 pop한다.
2. state와 optional actions를 가져온다.
3. FASTTokenizer.tokenize(prompt, state, actions)
4. tokenized_prompt, tokenized_prompt_mask, token_ar_mask, token_loss_mask 추가
training에서는 actions가 있으므로 action target tokens가 포함된다. inference에서는 action이 없으므로 prefix만 tokenization된다.
2.13 ExtractFASTActions
π0-FAST model output은 처음에는 token IDs다.
동작:
1. data["actions"]를 token sequence로 해석한다.
2. tokenizer.extract_actions(tokens, action_horizon, action_dim)
3. continuous action chunk로 복원한다.
이 transform이 빠지면 robot에는 token ID sequence가 action처럼 전달되는 치명적 버그가 생긴다.
2.14 PadStatesAndActions
동작:
state를 model_action_dim까지 zero padding
actions가 있으면 actions도 model_action_dim까지 zero padding
모델 checkpoint는 action_dim=32 같은 큰 dimension을 기대할 수 있지만, 실제 robot은 7/8/14 dim만 쓸 수 있다. padding은 이 차이를 맞춘다.
2.15 make_bool_mask
예:
make_bool_mask(6, -1)
→ True True True True True True False
양수는 True 개수, 음수는 False 개수다. delta action mask를 만들 때 자주 쓴다.
3. src/openpi/policies/policy_config.py
이 파일은 checkpoint와 config를 실제 callable Policy로 바꾸는 bridge다.
3.1 create_trained_policy 인자
| 인자 | 의미 |
|---|---|
train_config | model/data/training config |
checkpoint_dir | checkpoint root path 또는 gs path |
repack_transforms | 추론 환경에 추가로 적용할 repack |
sample_kwargs | sample_actions에 넘길 추가 인자, 예: num_steps |
default_prompt | prompt 누락 시 주입 |
norm_stats | 직접 줄 norm stats. 없으면 checkpoint assets에서 load |
pytorch_device | PyTorch model device |
3.2 논리적 line-by-line
1단계: repack transform default
repack_transforms = repack_transforms or Group()
추가 repack이 없으면 빈 transform group이다.
2단계: checkpoint download/cache
checkpoint_dir = download.maybe_download(str(checkpoint_dir))
gs://... path면 local cache로 내려받는다.
3단계: PyTorch checkpoint 감지
weight_path = checkpoint_dir / model.safetensors
is_pytorch = exists(weight_path)
4단계: model load
PyTorch:
model = train_config.model.load_pytorch(train_config, weight_path)
model.paligemma_with_expert.to_bfloat16_for_selected_params(...)
JAX:
params = restore_params(checkpoint_dir / params, dtype=bfloat16)
model = train_config.model.load(params)
5단계: data config 생성
data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
이때 repack/data/model transforms와 norm stats 설정이 결정된다.
6단계: norm stats load
if norm_stats is None:
norm_stats = load_norm_stats(checkpoint_dir / assets, data_config.asset_id)
중요: inference에서는 config assets가 아니라 checkpoint assets를 우선 사용한다. 이는 training 때 저장된 stats와 inference stats가 동일해야 하기 때문이다.
7단계: Policy 생성
input transforms:
repack_transforms.inputs
InjectDefaultPrompt(default_prompt)
data_config.data_transforms.inputs
Normalize(norm_stats)
data_config.model_transforms.inputs
output transforms:
data_config.model_transforms.outputs
Unnormalize(norm_stats)
data_config.data_transforms.outputs
repack_transforms.outputs
4. src/openpi/policies/policy.py
4.1 Policy.__init__
논리
1. model 저장
2. input transforms compose
3. output transforms compose
4. sample_kwargs 저장
5. metadata 저장
6. PyTorch model이면 device 이동 + eval mode
7. JAX model이면 sample_actions를 jit compile wrapper로 감쌈
8. JAX rng 초기화
성능 포인트
- JAX는 첫 호출에 compile overhead가 크다.
- PyTorch는
eval()이 중요하다. dropout/batchnorm류가 있으면 inference behavior가 달라질 수 있다. sample_kwargs로num_steps,temperature,max_decoding_steps같은 latency knob를 주입할 수 있다.
4.2 Policy.infer
1단계: input copy
inputs = tree.map(lambda x: x, obs)
얕은 copy에 가깝다. transform이 in-place로 바꿀 수 있으므로 원본 obs를 보호하려는 의도다.
2단계: input transform
inputs = self._input_transform(inputs)
여기서 raw robot observation이 model-ready dict로 바뀐다.
3단계: framework별 batch/device 처리
JAX:
leaf → jnp.asarray(leaf)[None, ...]
rng split
PyTorch:
leaf → torch.from_numpy(np.array(leaf)).to(device)[None, ...]
4단계: optional fixed noise
if noise is not None:
if noise.ndim == 2:
noise = noise[None, ...]
sample_kwargs["noise"] = noise
이것은 profiling에 매우 중요하다. fixed noise를 쓰면 stochastic variance를 줄이고 model path를 deterministic하게 비교할 수 있다.
5단계: Observation 생성
observation = Observation.from_dict(inputs)
여기서 image uint8→float conversion이 발생할 수 있다.
6단계: model.sample_actions timing
start = time.monotonic()
actions = self._sample_actions(...)
model_time = time.monotonic() - start
주의: JAX는 asynchronous execution이 있을 수 있으므로 정확한 latency 측정에는 block_until_ready가 필요할 수 있다. 이 코드의 timing은 일반 server timing으로는 유용하지만, microbenchmark에서는 별도 sync 측정이 필요하다.
7단계: output 변환
outputs = {
"state": inputs["state"],
"actions": actions,
}
batch dimension 제거 후 numpy로 바꾼다.
outputs = self._output_transform(outputs)
여기서 unnormalize, FAST decode, robot-specific output conversion이 일어난다.
5. src/openpi/models/pi0.py
π0/π0.5의 핵심 구현이다.
5.1 파일 역할
이 파일은 다음을 구현한다.
make_attn_mask
posemb_sincos
Pi0 class
__init__
embed_prefix
embed_suffix
compute_loss
sample_actions
5.2 make_attn_mask
입력
input_mask: bool[B, N]
mask_ar: bool/int[B or 1, N]
동작
1. mask_ar를 input_mask shape으로 broadcast
2. token axis cumulative sum 계산
3. query token i가 key token j를 볼 수 있는지 계산
4. padding token을 valid_mask로 제거
왜 필요한가?
π0는 image/prompt prefix와 action suffix를 하나의 transformer sequence에 넣는다. 그런데 모든 token이 모든 token을 보면 안 된다.
원하는 구조:
image/prompt prefix:
서로 bidirectional attend 가능
action suffix:
prefix attend 가능
action block 내부 attend 가능 또는 제한된 block causal
mask_ar는 이 구조를 compact하게 표현한다.
5.3 posemb_sincos
입력
pos: timestep
embedding_dim: hidden dimension
min_period, max_period
동작
1. embedding_dim이 even인지 검사
2. log-spaced periods 생성
3. pos / period 계산
4. sin과 cos를 concat
의미
diffusion/flow model에서 timestep을 transformer가 이해할 수 있는 vector로 바꾼다.
5.4 Pi0.__init__
1단계: base init
super().__init__(action_dim, action_horizon, max_token_len)
모든 model은 action_dim/action_horizon/max_token_len을 공통으로 갖는다.
2단계: π0.5 flag 저장
self.pi05 = config.pi05
이 flag가 이후 action/time conditioning 방식을 바꾼다.
3단계: Gemma/PaliGemma config 준비
paligemma_config = get_config(config.paligemma_variant)
action_expert_config = get_config(config.action_expert_variant)
OpenPI는 VLM backbone과 action expert를 함께 구성한다.
4단계: LLM module 생성
llm = ToNNX(Gemma Module(...))
llm.lazy_init(...)
Flax NNX bridge를 사용한다. Gemma 구현을 NNX로 직접 다시 쓰지 않고 bridge하는 구조다.
π0.5이면 action expert 쪽에 AdaRMS를 켠다.
5단계: image encoder 생성
img = ToNNX(SigLIP Module(... variant="So400m/14" ...))
img.lazy_init(fake image)
SigLIP vision encoder가 image를 token embedding sequence로 바꾼다.
6단계: action projection 정의
action_in_proj : action_dim → hidden_width
action_out_proj : hidden_width → action_dim
continuous action vector와 transformer hidden space를 연결한다.
7단계: π0.5 vs π0 분기
π0.5:
time_mlp_in
time_mlp_out
π0:
state_proj
action_time_mlp_in
action_time_mlp_out
의미:
- π0는 state token을 별도로 suffix에 넣고 action+time을 MLP로 섞는다.
- π0.5는 AdaRMS condition으로 time 정보를 넣는다.
5.5 embed_prefix
목적
변하지 않는 조건부 context를 embedding한다.
단계별 설명
1단계: 빈 list 준비
tokens = []
input_mask = []
ar_mask = []
2단계: image별 embedding
for name in obs.images:
image_tokens = SigLIP(image)
tokens.append(image_tokens)
input_mask.append(repeat image mask across image token length)
ar_mask.append(False block)
image token은 서로 attend 가능해야 하므로 ar_mask는 false block이다.
3단계: prompt token embedding
if tokenized_prompt exists:
prompt_embeddings = LLM.embed(token IDs)
tokens.append(prompt_embeddings)
input_mask.append(prompt_mask)
ar_mask.append(False block)
prompt도 prefix이므로 image와 함께 full attention block으로 둔다.
4단계: concat
return concat(tokens), concat(input_mask), concat(ar_mask)
5.6 embed_suffix
목적
현재 denoising step에서 변하는 state/action/timestep token을 만든다.
π0 path
1단계: state token
state_embedding = state_proj(obs.state)
state를 transformer hidden으로 project한다.
2단계: action token
action_embedding = action_in_proj(noisy_actions)
noisy action chunk [B,H,A]가 hidden token [B,H,D]가 된다.
3단계: time embedding
time_embedding = posemb_sincos(timestep)
time_embedding = repeat across action_horizon
4단계: action+time fusion
concat(action_embedding, time_embedding)
→ action_time_mlp_in
→ activation
→ action_time_mlp_out
5단계: masks
state/action suffix가 prefix 뒤에 붙는다. state와 action block의 attention relation을 ar_mask로 정의한다.
π0.5 path
1단계: action token
action_embedding = action_in_proj(noisy_actions)
2단계: time condition
time_embedding = posemb_sincos(timestep)
adarms_cond = time_mlp(time_embedding)
3단계: LLM/action expert forward에서 AdaRMS condition으로 사용
π0.5는 token에 time을 직접 concat하기보다 normalization modulation 쪽으로 넣는 구조다.
5.7 compute_loss
전체 목적
clean action chunk actions와 random noise 사이를 잇는 vector field를 학습한다.
단계별 설명
1단계: rng split
rng_preprocess, rng_noise, rng_time = split(rng)
각 random source를 분리한다.
2단계: observation preprocessing
observation = preprocess_observation(rng_preprocess, observation, train=train)
train이면 augmentation이 적용된다.
3단계: noise sampling
noise = normal(shape=actions.shape)
4단계: timestep sampling
time = Beta(1.5, 1)
time = scaled to [0.001, 1]
5단계: interpolation
x_t = time * noise + (1 - time) * actions
t=1은 noise, t=0은 action이다.
6단계: target vector field
u_t = noise - actions
이는 x_t를 time에 대해 미분한 값과 연결된다.
7단계: prefix/suffix embedding
prefix = embed_prefix(observation)
suffix = embed_suffix(observation, x_t, time)
8단계: attention mask 생성
input_mask = concat(prefix_mask, suffix_mask)
ar_mask = concat(prefix_ar, suffix_ar)
attn_mask = make_attn_mask(input_mask, ar_mask)
9단계: LLM/action expert forward
outputs = PaliGemma.llm(embedded tokens, mask, ...)
10단계: action vector prediction
v_t = action_out_proj(suffix_outputs)
11단계: MSE loss
loss = mean((v_t - u_t)^2 over action_dim)
5.8 sample_actions
전체 목적
noise에서 시작해 learned vector field를 따라 action chunk를 생성한다.
단계별 설명
1단계: observation preprocessing
observation = preprocess_observation(None, observation, train=False)
2단계: initial noise
if noise is None:
x_t = normal([B, action_horizon, action_dim])
else:
x_t = noise
fixed noise를 넣으면 deterministic profiling 가능.
3단계: prefix prefill
prefix_tokens, prefix_mask, prefix_ar = embed_prefix(observation)
prefix_attn_mask = make_attn_mask(...)
PaliGemma.llm(..., decode=True) → kv_cache
prefix는 denoising loop 동안 변하지 않으므로 cache한다.
4단계: integration loop
dt = -1 / num_steps
time = 1.0
while time >= 0:
suffix = embed_suffix(observation, x_t, time)
suffix forward with prefix kv_cache
v_t = action_out_proj(...)
x_t = x_t + dt * v_t
time = time + dt
5단계: return action chunk
return x_t
성능 핵심
num_steps번 suffix forward가 반복된다. 따라서 optimization 목표는:
1. prefix prefill을 최대한 재사용
2. suffix step latency 감소
3. denoising step 수 감소/distillation
4. action_horizon과 replanning 주기 최적화
6. src/openpi/models/pi0_fast.py
π0-FAST는 action을 token으로 생성한다.
6.1 Pi0FASTConfig
주요 field:
dtype
action_dim
action_horizon
max_token_len
fast_model_tokenizer
fast_model_tokenizer_kwargs
model_type은 PI0_FAST를 반환한다.
6.2 inputs_spec
π0-FAST observation에는 다음이 필요하다.
images
image_masks
state
tokenized_prompt
tokenized_prompt_mask
token_ar_mask
token_loss_mask
π0와 다르게 token loss mask가 필수다. next-token prediction loss를 어디에 걸지 알아야 하기 때문이다.
6.3 left_to_right_align
목적
variable-length prefix를 right-aligned 형태로 정렬한다.
왜 필요한가?
- autoregressive decoding에서 KV cache 위치/position을 맞추기 위해서다.
- batch마다 prompt/token length가 다르면 decode position 계산이 꼬일 수 있다.
6.4 put_along_last_axis
JAX에는 numpy의 put_along_axis(axis=-1)에 해당하는 기능이 부족할 수 있어서 one-hot/einsum으로 구현한다.
목적:
output_tokens[batch, step] = generated_token
6.5 Pi0FAST.__init__
π0와 비슷하지만 action flow head가 없다.
1. PaliGemma/Gemma fast module 생성
2. SigLIP image encoder 생성
3. self.PaliGemma = {llm, img}
continuous action projection 대신 token logits head는 LLM vocabulary head가 담당한다.
6.6 embed_inputs
단계
1. image tokens 생성
2. image masks 생성
3. image ar_mask는 false block
4. tokenized prompt/action input embedding 생성
5. token mask와 token_ar_mask 추가
6. concat해서 반환
여기서 tokenized prompt는 사실 prompt만이 아니라 FAST tokenizer가 구성한 state/action token sequence까지 포함할 수 있다.
6.7 compute_loss
단계별 설명
1단계: observation preprocess
image resize/augmentation 등.
2단계: input embedding
input_token_embeddings, input_mask, ar_mask = embed_inputs(observation)
3단계: attention mask
attn_mask = make_attn_mask(input_mask, ar_mask)
4단계: next-token targets
targets = one_hot(tokenized_prompt[:, 1:])
각 input token은 다음 token을 예측한다.
5단계: last input 제외
forward input = embedded sequence[:, :-1]
마지막 token은 다음 token target이 없으므로 input에서 제외된다.
6단계: prelogits 계산
LLM forward로 hidden/prelogits를 얻는다.
7단계: target token 구간만 logits decode
vocab projection은 비용이 크므로 target 구간에 대해서만 logits를 계산한다.
8단계: CE loss
logp = log_softmax(logits)
loss_mask = token_loss_mask[:, 1:]
loss = -sum(targets * logp * loss_mask) / sum(loss_mask)
6.8 sample_actions
단계별 설명
1단계: preprocess
observation = preprocess_observation(...)
2단계: prefix embedding
prefix_token_embeddings, prefix_mask, prefix_ar_mask = embed_inputs(observation)
추론 시 이 token sequence는 prompt/state prefix다.
3단계: left-to-right align
KV cache를 위한 정렬.
4단계: prefix prefill
prefix_logits, kv_cache = llm(... decode=True)
last_logit = prefix_logits[:, -1:]
last logit이 첫 generated token distribution이다.
5단계: decoding loop
while not all_eos and step < max_decoding_steps:
if temperature > 0:
token = categorical(last_logit / temperature)
else:
token = argmax(last_logit)
output_tokens[step] = token
token_embedding = llm.embed(token)
last_logit, kv_cache = llm(one token, kv_cache=cache)
6단계: return output tokens
Pi0FAST.sample_actions 자체는 generated tokens를 반환한다. 최종 action chunk 복원은 output transform ExtractFASTActions가 담당한다.
7. scripts/train.py
7.1 init_logging
로그 format을 짧게 바꾼다. 실험 로그에서 step/loss를 보기 쉽게 한다.
7.2 init_wandb
논리:
if wandb disabled:
wandb.init(mode="disabled")
else if resume:
checkpoint_dir/wandb_id.txt에서 run id 읽고 resume
else:
새 run 생성하고 wandb_id.txt 저장
실험 reproducibility를 위해 checkpoint directory와 wandb run id를 묶는다.
7.3 _load_weights_and_validate
1. loader.load(params_shape)
2. expected/got tree shape/dtype 검사
3. 실제 loaded param만 반환
checkpoint compatibility 문제를 early detect한다.
7.4 init_train_state
단계
1. optimizer 생성
2. init(rng, partial_params) 함수 정의
3. model 생성
4. partial checkpoint params가 있으면 state에 merge
5. params 추출
6. frozen params를 bf16으로 변환
7. TrainState 생성
8. eval_shape로 train_state_shape 계산
9. FSDP sharding 계산
10. resume이면 shape/sharding만 반환
11. weight_loader로 partial params 로드
12. jitted init으로 actual train_state 생성
중요한 설계
partial checkpoint loading을 지원하므로 PaliGemma base만 로드하거나, LoRA missing weight를 현재 init weight로 채울 수 있다.
7.5 train_step
단계
1. model = merge(model_def, params)
2. model.train()
3. loss_fn 정의
4. train_rng = fold_in(rng, state.step)
5. diff_state = trainable_filter
6. loss, grads = value_and_grad(loss_fn)
7. trainable params만 optimizer update
8. model에 new params update
9. full state 추출
10. EMA update
11. loss/grad_norm/param_norm 반환
성능/안정성 포인트
fold_in을 사용해 step별 rng를 안정적으로 만든다.- frozen params는 gradient 계산 대상에서 제외된다.
- EMA가 있으면 inference checkpoint에 EMA params가 쓰일 수 있다.
7.6 main
단계
1. logging init
2. batch_size % device_count 검사
3. JAX compilation cache 설정
4. rng split
5. mesh/sharding 생성
6. checkpoint manager 생성
7. wandb init
8. data_loader 생성
9. 첫 batch shape/info logging
10. 첫 batch camera image wandb logging
11. train_state init/restore
12. ptrain_step = jax.jit(...)
13. loop:
- train step
- log interval마다 metrics reduce/log
- next batch
- save interval마다 checkpoint
14. checkpoint async save 완료 대기
training debug 추천
처음부터 full training하지 말고 다음 순서로 한다.
# 1. norm stats 확인
uv run scripts/compute_norm_stats.py --config-name <config>
# 2. 매우 짧은 debug config 또는 num_train_steps 수정
uv run scripts/train.py <config> --exp-name=debug --overwrite
# 3. 첫 batch image wandb logging 확인
# 4. loss가 finite인지 확인
# 5. checkpoint 생성 확인
8. scripts/serve_policy.py
8.1 EnvMode
지원 환경:
ALOHA
ALOHA_SIM
DROID
LIBERO
8.2 Checkpoint dataclass
config: training config name
dir: checkpoint directory
예:
config = pi05_droid
dir = gs://openpi-assets/checkpoints/pi05_droid
8.3 DEFAULT_CHECKPOINT
env별 기본 policy를 mapping한다.
DROID → pi05_droid
LIBERO → pi05_libero
...
팀 실험에서는 default를 쓰기보다 명시적으로 policy:checkpoint를 지정하는 것이 reproducibility에 좋다.
8.4 create_policy
if args.policy is Checkpoint:
create_trained_policy(get_config(config), dir)
else:
create_default_policy(env)
8.5 main
1. policy 생성
2. metadata 저장
3. record 옵션이면 PolicyRecorder wrapper 적용
4. host/ip logging
5. WebsocketPolicyServer 생성
6. serve_forever
--record는 디버깅에 유용하다. policy input/output을 policy_records에 저장해 replay/분석할 수 있다.
9. src/openpi/serving/websocket_policy_server.py
9.1 WebsocketPolicyServer.__init__
policy, host, port, metadata를 저장한다.
9.2 serve_forever / run
async websocket server를 시작한다.
설정:
compression=None
max_size=None
process_request=_health_check
압축을 끄는 것은 latency와 CPU overhead 측면에서 합리적일 수 있다. image payload가 크면 압축이 네트워크를 줄일 수 있지만 CPU latency를 늘릴 수 있다. robot real-time에서는 압축 off가 더 안정적인 경우가 많다.
9.3 _handler
단계
1. connection open logging
2. msgpack_numpy.Packer 생성
3. metadata를 client에 최초 전송
4. loop:
a. obs 수신
b. unpack
c. policy.infer(obs)
d. server_timing 추가
e. pack해서 client에 전송
f. total time 기록
5. connection closed 처리
6. exception이면 traceback 전송 후 close
timing 해석
server_timing["infer_ms"]:policy.infer에 걸린 시간prev_total_ms: 이전 request의 수신→추론→전송 포함 total
10. packages/openpi-client/.../websocket_client_policy.py
10.1 __init__
1. host가 ws로 시작하면 그대로 사용
2. 아니면 ws://host 구성
3. port가 있으면 :port 추가
4. Packer 생성
5. server connect retry
6. metadata 수신
10.2 _wait_for_server
server가 아직 안 떠 있으면 5초 간격으로 retry한다. robot runtime에서 server를 먼저 띄우지 않아도 client가 기다릴 수 있다.
10.3 infer
1. obs를 msgpack pack
2. websocket send
3. response recv
4. response가 string이면 server error로 간주
5. bytes면 unpack해서 action dict 반환
성능 포인트
msgpack serialization/deserialization은 e2e latency에 포함된다. 큰 image를 raw로 보내면 이 비용이 커진다. client-side resize/uint8 conversion이 중요한 이유다.
11. src/openpi/training/data_loader.py
11.1 핵심 역할
이 파일은 raw dataset을 training batch로 만든다.
LeRobot/RLDS dataset
→ repack/data/model transforms
→ Normalize
→ Observation + Actions batch
→ JAX sharding 또는 PyTorch tensor
11.2 TransformedDataset
random access dataset wrapper.
__getitem__(index):
return transform(dataset[index])
11.3 IterableTransformedDataset
RLDS처럼 iterable dataset에 transform을 적용한다. batched sample이면 batch를 sample 단위로 쪼개 transform한 뒤 다시 stack한다.
이것은 transform들이 대부분 unbatched sample 기준으로 작성되어 있기 때문이다.
11.4 FakeDataset
model input spec을 기반으로 random fake sample을 만든다. config/debug/test에서 유용하다.
11.5 create_torch_dataset
if repo_id == fake:
FakeDataset
else:
LeRobotDatasetMetadata(repo_id)
LeRobotDataset(repo_id, delta_timestamps=...)
if prompt_from_task:
PromptFromLeRobotTask 적용
delta_timestamps는 action horizon만큼 future action sequence를 가져오기 위한 핵심이다.
11.6 transform_dataset
raw dataset sample
→ repack_transforms.inputs
→ data_transforms.inputs
→ Normalize
→ model_transforms.inputs
norm_stats가 없으면 error를 낸다. 단 skip_norm_stats=True면 건너뛸 수 있다.
11.7 create_data_loader
if data_config.rlds_data_dir is not None:
create_rlds_data_loader
else:
create_torch_data_loader
DROID full dataset path만 RLDS를 쓴다고 보면 된다.
11.8 TorchDataLoader
PyTorch DataLoader를 감싼 wrapper다.
중요 옵션:
| 옵션 | 의미 |
|---|---|
local_batch_size | process/device별 batch |
sharding | JAX sharded array 생성용 |
shuffle | dataset shuffle |
num_workers | CPU worker process 수 |
framework | jax or pytorch |
JAX framework이면 batch를 jax.make_array_from_process_local_data로 sharded array로 만든다. PyTorch framework이면 torch tensor로 반환한다.
12. src/openpi/training/checkpoints.py
12.1 initialize_checkpoint_dir
if checkpoint_dir exists:
if overwrite: delete and recreate
elif resume: resuming=True
else: raise FileExistsError
CheckpointManager 생성
empty checkpoint인데 resume이면 resume 취소
12.2 save_state
1. save_assets callback 정의
- norm_stats를 assets/asset_id에 저장
2. _split_params(state)
- EMA params가 있으면 inference params로 사용
3. items = assets + train_state + params
4. checkpoint_manager.save(step, items)
12.3 _split_params
if ema_params exists:
params = ema_params
train_state.ema_params = None
else:
params = state.params
train_state.params = {}
이 설계 때문에 inference checkpoint의 params는 EMA를 사용할 수 있다.
13. src/openpi/training/weight_loaders.py
13.1 NoOpWeightLoader
아무 weight도 로드하지 않는다. debug/fake training에 유용하다.
13.2 CheckpointWeightLoader
1. params_path에서 checkpoint restore
2. missing LoRA weights는 현재 initialized params에서 merge
3. merged params 반환
LoRA fine-tuning에서 base checkpoint에는 LoRA parameter가 없으므로 missing LoRA weight를 init된 값으로 채운다.
13.3 PaliGemmaWeightLoader
official PaliGemma checkpoint를 로드하고, OpenPI action expert 등 extra weight는 유지한다.
이것은 π0 base pretraining 전 초기화나 특정 실험에서 중요하다.
14. 수정할 때 가장 먼저 봐야 하는 파일 조합
14.1 새 robot 붙이기
src/openpi/policies/<new_robot>_policy.py
src/openpi/training/config.py
src/openpi/transforms.py
examples/ur5/README.md
14.2 inference latency 줄이기
src/openpi/models/pi0.py
src/openpi/models/pi0_fast.py
src/openpi/policies/policy.py
src/openpi/serving/websocket_policy_server.py
packages/openpi-client/src/openpi_client/websocket_client_policy.py
14.3 PyTorch path 최적화
src/openpi/models_pytorch/pi0_pytorch.py
src/openpi/models_pytorch/gemma_pytorch.py
src/openpi/models_pytorch/preprocessing_pytorch.py
src/openpi/models_pytorch/transformers_replace/*
14.4 training 안정화
src/openpi/training/config.py
src/openpi/training/data_loader.py
src/openpi/training/optimizer.py
src/openpi/training/checkpoints.py
scripts/compute_norm_stats.py
scripts/train.py
15. 핵심 invariants
OpenPI를 수정할 때 반드시 유지해야 하는 invariants다.
15.1 shape invariants
image: [B, 224, 224, 3]
state: [B, action_dim] or padded to action_dim
actions: [B, action_horizon, action_dim]
prompt: [B, max_token_len]
15.2 value range invariants
image float: [-1, 1]
uint8 image: [0, 255] before Observation.from_dict
normalized state/action: config norm mode에 따름
unnormalized action: robot command convention에 맞아야 함
15.3 transform order invariants
input:
repack → robot input transform → normalize → model transform
output:
model output transform → unnormalize → robot output transform → repack output
15.4 checkpoint invariants
JAX checkpoint:
checkpoint_dir/params
checkpoint_dir/assets/<asset_id>
PyTorch checkpoint:
checkpoint_dir/model.safetensors
checkpoint_dir/assets/<asset_id>
15.5 policy invariants
policy.infer(obs) returns dict with:
actions
policy_timing
optional:
state
server_timing, if through websocket server
16. 디버깅 print template
아래 template은 transform chain debugging에 유용하다.
def describe_tree(x, prefix=""):
if isinstance(x, dict):
for k, v in x.items():
describe_tree(v, prefix + "/" + str(k))
else:
shape = getattr(x, "shape", None)
dtype = getattr(x, "dtype", None)
print(prefix, shape, dtype)
policy 내부에서 단계별로:
print("raw obs")
describe_tree(obs)
inputs = self._input_transform(obs)
print("after input transforms")
describe_tree(inputs)
observation = Observation.from_dict(inputs)
print("Observation images", {k: v.shape for k, v in observation.images.items()})
print("Observation state", observation.state.shape)
주의: 실제 robot loop에서는 print가 latency를 크게 늘리므로 debug run에서만 사용한다.
17. 최종 mental model
OpenPI 코드의 본질은 다음 세 문장으로 압축된다.
model.py는 모델이 기대하는 표준 tensor contract를 정의한다.transforms.py + policies/* + training/config.py는 각 robot/dataset을 그 contract에 맞춘다.pi0.py / pi0_fast.py는 contract에 맞게 들어온 observation을 action chunk로 생성한다.
따라서 새 연구를 시작할 때 코드를 읽는 순서는 다음이 가장 효율적이다.
model.py
→ transforms.py
→ policy_config.py
→ policy.py
→ target policy file: droid/aloha/libero
→ training/config.py
→ pi0.py or pi0_fast.py
→ train.py / serve_policy.py
이 순서를 따르면 “모델부터 읽다가 데이터 schema에서 막히는 문제”를 피할 수 있다.