openpi 사용 절차와 핵심 코드 해설

이 문서는 openpi를 실제로 설치하고, checkpoint를 실행하고, 데이터를 변환하고, fine-tuning하고, 필요하면 PyTorch로 변환하는 과정을 step-by-step으로 설명한다. 핵심 실행 경로는 코드 조각 단위로 line-by-line 해설한다.

1. 설치와 환경 확인

1.1 전제

README 기준 전제는 다음이다.

  • OS: Ubuntu 22.04
  • GPU: NVIDIA GPU
  • Python: pyproject.toml 기준 >=3.11
  • dependency manager: uv
  • inference: 8GB 이상 VRAM
  • LoRA fine-tuning: 22.5GB 이상 VRAM
  • full fine-tuning: 70GB 이상 VRAM

현재 워크트리에서는 submodule이 초기화되지 않았다. 실제 ALOHA/LIBERO submodule이 필요한 예제를 쓰려면 먼저 실행한다.

git submodule update --init --recursive

1.2 uv 설치 후 dependency 설치

GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .

해설:

  • GIT_LFS_SKIP_SMUDGE=1: LeRobot dependency checkout 과정에서 Git LFS 대용량 파일을 바로 받지 않게 한다.
  • uv sync: pyproject.tomluv.lock 기준으로 .venv를 만든다.
  • uv pip install -e .: 현재 저장소를 editable package로 설치한다. 코드 수정 후 재설치 없이 import에 반영된다.

설치 확인:

uv run python -c "import openpi; print(openpi.__file__)"
uv run python -c "from openpi.training import config; print(config.get_config('debug').name)"

2. 빠른 checkpoint 추론

README의 핵심 API는 다음이다.

from openpi.training import config as _config
from openpi.policies import policy_config
from openpi.shared import download

config = _config.get_config("pi05_droid")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_droid")

policy = policy_config.create_trained_policy(config, checkpoint_dir)

example = {
    "observation/exterior_image_1_left": ...,
    "observation/wrist_image_left": ...,
    "prompt": "pick up the fork",
}
action_chunk = policy.infer(example)["actions"]

Line-by-line:

  1. from openpi.training import config as _config
    • config registry를 import한다. get_config()src/openpi/training/config.py_CONFIGS_DICT에서 이름으로 TrainConfig를 찾는다.
  2. from openpi.policies import policy_config
    • checkpoint를 실제 inference policy 객체로 바꾸는 factory를 import한다.
  3. from openpi.shared import download
    • gs://, local path, fsspec URL을 cache directory로 내려받는 helper다.
  4. config = _config.get_config("pi05_droid")
    • pi05_droid 설정을 고른다. 이 config는 Pi0Config(pi05=True, action_horizon=15)와 DROID input/output transform을 사용한다.
  5. checkpoint_dir = download.maybe_download(...)
    • Google Cloud Storage checkpoint를 ~/.cache/openpi 또는 OPENPI_DATA_HOME 아래로 받는다.
  6. policy = policy_config.create_trained_policy(config, checkpoint_dir)
    • checkpoint type을 JAX/PyTorch로 감지한다.
    • 모델 parameter를 로드한다.
    • norm stats를 checkpoint assets에서 로드한다.
    • repack, robot transform, normalization, model tokenization transform chain을 만든다.
  7. example = {...}
    • DROID config에서 기대하는 raw observation key다. DROID runtime에서는 observation/exterior_image_1_left, observation/wrist_image_left, joint/gripper state, prompt 등이 필요하다.
  8. action_chunk = policy.infer(example)["actions"]
    • 입력 dict가 transform chain을 지나 Observation으로 바뀐다.
    • 모델 sample_actions()가 action chunk를 생성한다.
    • output transform이 action을 robot command space로 되돌린다.

실제로 로봇 없이 호출 구조만 확인하려면 examples/simple_client를 쓰는 편이 안전하다. 먼저 server:

uv run scripts/serve_policy.py policy:checkpoint \
  --policy.config=pi05_droid \
  --policy.dir=gs://openpi-assets/checkpoints/pi05_droid

다른 터미널에서 client:

uv run examples/simple_client/main.py --env DROID

3. create_trained_policy() line-by-line

파일: src/openpi/policies/policy_config.py

핵심 코드:

repack_transforms = repack_transforms or transforms.Group()
checkpoint_dir = download.maybe_download(str(checkpoint_dir))

weight_path = os.path.join(checkpoint_dir, "model.safetensors")
is_pytorch = os.path.exists(weight_path)

if is_pytorch:
    model = train_config.model.load_pytorch(train_config, weight_path)
    model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
else:
    model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))

data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
if norm_stats is None:
    if data_config.asset_id is None:
        raise ValueError("Asset id is required to load norm stats.")
    norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)

return _policy.Policy(
    model,
    transforms=[
        *repack_transforms.inputs,
        transforms.InjectDefaultPrompt(default_prompt),
        *data_config.data_transforms.inputs,
        transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ],
    output_transforms=[
        *data_config.model_transforms.outputs,
        transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.data_transforms.outputs,
        *repack_transforms.outputs,
    ],
    sample_kwargs=sample_kwargs,
    metadata=train_config.policy_metadata,
    is_pytorch=is_pytorch,
    pytorch_device=pytorch_device if is_pytorch else None,
)

해설:

  • repack_transforms = ...: inference caller가 별도 key remapping을 넘기지 않으면 빈 transform group을 쓴다.
  • download.maybe_download: checkpoint path가 gs://이면 cache로 다운로드하고, local이면 그대로 Path로 정규화한다.
  • model.safetensors 존재 여부: PyTorch checkpoint 감지 규칙이다. 없으면 JAX Orbax checkpoint로 본다.
  • PyTorch branch: load_pytorch()가 model instance를 만들고 safetensors를 load한다. 이후 선택 parameter를 bfloat16 정책에 맞게 변환한다.
  • JAX branch: restore_params(checkpoint_dir / "params")로 Orbax parameter tree를 읽고 BaseModelConfig.load()로 model state에 주입한다.
  • data_config = train_config.data.create(...): data factory가 repack/data/model transform과 norm stats 설정을 만든다.
  • norm_stats: 명시로 넘기지 않으면 checkpoint의 assets/<asset_id>/norm_stats.json을 읽는다. inference와 training normalization을 반드시 맞추려는 설계다.
  • input transforms: raw dict에서 model-ready dict로 가는 순서다.
  • output output_transforms: model output에서 robot command로 돌아가는 순서다. input의 역변환이 앞쪽에 온다.
  • metadata: policy server가 client에 알려줄 reset pose 같은 정보다.

4. Policy.infer() line-by-line

파일: src/openpi/policies/policy.py

핵심 코드:

inputs = jax.tree.map(lambda x: x, obs)
inputs = self._input_transform(inputs)
if not self._is_pytorch_model:
    inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)
    self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng)
else:
    inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs)
    sample_rng_or_pytorch_device = self._pytorch_device

sample_kwargs = dict(self._sample_kwargs)
if noise is not None:
    noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)
    if noise.ndim == 2:
        noise = noise[None, ...]
    sample_kwargs["noise"] = noise

observation = _model.Observation.from_dict(inputs)
start_time = time.monotonic()
outputs = {
    "state": inputs["state"],
    "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs),
}
model_time = time.monotonic() - start_time
...
outputs = self._output_transform(outputs)
outputs["policy_timing"] = {"infer_ms": model_time * 1000}
return outputs

해설:

  • jax.tree.map(lambda x: x, obs): 입력 dict를 얕은 tree copy처럼 다뤄 transform이 원본을 직접 바꾸는 부작용을 줄인다.
  • _input_transform: create_trained_policy()에서 조립한 모든 input transform이 실행된다.
  • JAX branch:
    • 모든 leaf를 jnp.asarray로 바꾼다.
    • [np.newaxis, ...]로 batch dimension을 추가한다.
    • RNG를 split해 이번 inference sampling에 쓸 key를 만든다.
  • PyTorch branch:
    • numpy array로 강제 변환 후 torch tensor로 만들고 device로 옮긴다.
    • batch dimension을 추가한다.
    • PyTorch sample_actions()는 첫 번째 인자로 RNG 대신 device string을 받는 식으로 맞춰져 있다.
  • noise: flow matching 모델에서 deterministic 비교나 debugging을 위해 초기 noise를 직접 넣을 수 있다.
  • Observation.from_dict: dict를 typed observation dataclass로 변환한다.
  • _sample_actions: JAX면 nnx_utils.module_jit(model.sample_actions), PyTorch면 raw model.sample_actions.
  • output에는 stateactions가 들어간다. output transform 중 Unnormalize가 state/action stats를 참조하므로 state도 같이 둔다.
  • batch dimension을 제거한 뒤 _output_transform이 robot action space로 되돌린다.
  • policy_timing은 model sampling 시간만 ms 단위로 기록한다.

5. π0 flow matching 코드 흐름

파일: src/openpi/models/pi0.py

5.1 Attention mask

mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
cumsum = jnp.cumsum(mask_ar, axis=1)
attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
return jnp.logical_and(attn_mask, valid_mask)

해설:

  • mask_ar는 token group 경계를 표시한다. False면 이전 token과 같은 attention block, True면 새 causal block으로 본다.
  • cumsum은 각 token이 몇 번째 AR block에 속하는지 만든다.
  • attn_mask[i, q, k] = True 조건은 key token의 block id가 query token의 block id보다 작거나 같을 때다.
  • valid_mask는 padding token을 attention에서 제외한다.
  • prefix image/language는 full attention, suffix action token은 causal/prefix-LM attention을 만들 수 있다.

5.2 embed_prefix()

for name in obs.images:
    image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
    tokens.append(image_tokens)
    input_mask.append(einops.repeat(obs.image_masks[name], "b -> b s", s=image_tokens.shape[1]))
    ar_mask += [False] * image_tokens.shape[1]

if obs.tokenized_prompt is not None:
    tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
    tokens.append(tokenized_inputs)
    input_mask.append(obs.tokenized_prompt_mask)
    ar_mask += [False] * tokenized_inputs.shape[1]

해설:

  • 각 camera view를 SigLIP image token sequence로 바꾼다.
  • image mask는 camera가 실제로 존재하는지 나타낸다. 없는 wrist camera는 zero image + False mask로 들어갈 수 있다.
  • prefix token들은 모두 ar_mask=False라 서로 bidirectional/full attention이 가능하다.
  • prompt token은 Gemma embedder로 embedding한다.
  • 결과는 (prefix_tokens, prefix_input_mask, prefix_ar_mask)다.

5.3 embed_suffix()

π0:

state_token = self.state_proj(obs.state)[:, None, :]
...
action_tokens = self.action_in_proj(noisy_actions)
time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
action_time_tokens = self.action_time_mlp_in(action_time_tokens)
action_time_tokens = nnx.swish(action_time_tokens)
action_time_tokens = self.action_time_mlp_out(action_time_tokens)

π0.5:

action_tokens = self.action_in_proj(noisy_actions)
time_emb = posemb_sincos(...)
time_emb = self.time_mlp_in(time_emb)
time_emb = nnx.swish(time_emb)
time_emb = self.time_mlp_out(time_emb)
time_emb = nnx.swish(time_emb)
action_expert_tokens = action_tokens
adarms_cond = time_emb

해설:

  • π0는 state를 별도 continuous suffix token으로 넣는다.
  • π0.5는 state를 prompt token으로 넣을 수 있으므로 suffix state token을 생략한다.
  • action은 action_dim -> action_expert_width로 projection한다.
  • timestep은 scalar라 sinusoidal embedding으로 만든다.
  • π0는 action embedding과 time embedding을 concatenate해 MLP로 섞는다.
  • π0.5는 timestep을 AdaRMS conditioning으로 action expert transformer block에 주입한다.

5.4 compute_loss()

preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
observation = _model.preprocess_observation(preprocess_rng, observation, train=train)

batch_shape = actions.shape[:-2]
noise = jax.random.normal(noise_rng, actions.shape)
time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
time_expanded = time[..., None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions

해설:

  • RNG를 image augmentation, noise sampling, time sampling 용도로 나눈다.
  • training이면 image augmentation이 적용된다.
  • noise는 action chunk와 같은 shape다.
  • time은 batch마다 하나씩 뽑는다. 0.001..1.0 근처로 제한해 극단값 불안정을 줄인다.
  • x_t는 실제 action과 noise 사이의 interpolation point다.
  • u_t는 flow matching target vector다. 현재 convention은 t=1이 noise, t=0이 data라서 noise - actions가 target이다.
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
attn_mask = make_attn_mask(input_mask, ar_mask)
positions = jnp.cumsum(input_mask, axis=1) - 1
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
    [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
)
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
return jnp.mean(jnp.square(v_t - u_t), axis=-1)

해설:

  • prefix는 image/language, suffix는 state/action/time이다.
  • mask와 ar mask를 합쳐 전체 sequence attention mask를 만든다.
  • positions는 padding을 제외한 position id다.
  • PaliGemma.llm([prefix, suffix])는 두 expert token stream을 함께 attention하게 한다.
  • suffix output 중 마지막 action_horizon개가 action token output이다.
  • action_out_proj가 hidden width를 action dimension으로 되돌린다.
  • 반환 loss shape는 batch와 horizon 축을 유지하고 action dim만 평균낸다. train step에서 다시 전체 평균을 낸다.

5.5 sample_actions()

observation = _model.preprocess_observation(None, observation, train=False)
dt = -1.0 / num_steps
batch_size = observation.state.shape[0]
if noise is None:
    noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

해설:

  • inference preprocess에는 train augmentation이 없다.
  • dt는 negative다. t=1에서 t=0으로 간다.
  • 초기 action은 noise다. 같은 noise를 넣으면 deterministic 비교가 가능하다.
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
positions = jnp.cumsum(prefix_mask, axis=1) - 1
_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

해설:

  • 이미지와 prompt는 denoising step마다 바뀌지 않으므로 한 번만 forward한다.
  • 결과 KV cache를 저장해 suffix step에서 재사용한다.
  • 이 때문에 flow matching 모델은 autoregressive FAST보다 inference가 빠를 수 있다.
def step(carry):
    x_t, time = carry
    suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
        observation, x_t, jnp.broadcast_to(time, batch_size)
    )
    suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
    prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
    full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
    positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
    (prefix_out, suffix_out), _ = self.PaliGemma.llm(
        [None, suffix_tokens],
        mask=full_attn_mask,
        positions=positions,
        kv_cache=kv_cache,
        adarms_cond=[None, adarms_cond],
    )
    v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
    return x_t + dt * v_t, time + dt

해설:

  • 현재 noisy action x_t와 scalar time으로 suffix token을 만든다.
  • suffix 내부 attention mask와 prefix attention mask를 합친다.
  • prefix token은 이미 cache에 있으므로 [None, suffix_tokens]만 넣는다.
  • action vector field v_t를 예측한다.
  • Euler update로 action을 한 step denoise한다.
x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
return x_0

해설:

  • Python loop가 아니라 jax.lax.while_loop라 JIT compile 가능하다.
  • 최종 x_0가 normalized action chunk다.
  • policy output transform이 이를 unnormalize하고 robot action space로 바꾼다.

6. π0-FAST 코드 흐름

파일: src/openpi/models/pi0_fast.py

6.1 compute_loss()

observation = _model.preprocess_observation(
    rng, observation, train=train, image_keys=list(observation.images.keys())
)
input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
attn_mask = make_attn_mask(input_mask, ar_mask)

해설:

  • π0-FAST는 config마다 image key가 다를 수 있어 observation에 실제 들어온 key를 그대로 preprocess한다.
  • embed_inputs()는 image embedding과 token embedding을 합친다.
  • ar_mask는 FAST tokenizer가 만든 token_ar_mask를 포함한다.
targets = jax.nn.one_hot(
    observation.tokenized_prompt[:, 1:],
    self.PaliGemma.llm.module.vocab_size,
)
pre_logits, _, _ = self.PaliGemma.llm(
    embedded_prefix=input_token_embeddings[:, :-1],
    mask=attn_mask[:, :-1, :-1],
    return_prelogits=True,
)
logits, _ = self.PaliGemma.llm(pre_logits=pre_logits[:, -targets.shape[1] :])
logp = jax.nn.log_softmax(logits, axis=-1)

해설:

  • next-token prediction이므로 target은 input token을 한 칸 shift한 것이다.
  • 마지막 token은 다음 token을 예측할 필요가 없으므로 input에서는 제외한다.
  • vocab projection matmul이 크기 때문에 필요한 target 구간만 logits로 decode한다.
  • log_softmax로 CE 계산 준비를 한다.
loss_mask = observation.token_loss_mask[:, 1:]
token_pplx = jnp.sum(targets * logp, axis=-1)
return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)

해설:

  • token_loss_mask가 True인 위치, 즉 action postfix token에만 loss를 건다.
  • prompt prefix에는 loss를 걸지 않는다.
  • 각 sequence별 평균 CE loss를 반환한다.

6.2 sample_actions()

prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
    prefix_token_embeddings, prefix_mask, prefix_attn_mask
)

해설:

  • 입력 prompt 길이가 batch마다 달라 decoding start 위치가 다를 수 있다.
  • left_to_right_align가 sequence를 오른쪽 정렬해 KV cache에서 decode 위치를 맞춘다.
prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
prefix_logits, kv_cache, _ = self.PaliGemma.llm(
    embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
)
last_logit = prefix_logits[:, -1:]
output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))

해설:

  • prefix forward로 KV cache를 채운다.
  • attention mask를 future decode length만큼 padding해 cache capacity를 확보한다.
  • 마지막 prefix logit이 첫 action token을 예측한다.
token = jax.lax.cond(
    temperature > 0.0,
    lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
    lambda _: jnp.argmax(last_logit, axis=-1),
    operand=None,
)
output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)
has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
all_eos = jnp.all(has_eos)

해설:

  • temperature=0이면 greedy argmax다.
  • temperature>0이면 categorical sampling이다.
  • 생성 token을 output buffer의 현재 step 위치에 기록한다.
  • batch의 모든 sample이 EOS를 내면 early stop한다.
token_embedding = self.PaliGemma.llm(token, embed_only=True)
positions = prefill_len[:, None] + step + 1
last_logit, kv_cache, _ = self.PaliGemma.llm(
    embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
)

해설:

  • 방금 생성한 token을 embedding해 다음 token을 예측한다.
  • KV cache를 업데이트하며 autoregressive decoding을 진행한다.
  • 반환값은 raw generated tokens이고, ExtractFASTActions가 action으로 decode한다.

7. Fine-tuning step-by-step

7.1 데이터 형식을 먼저 고정한다

커스텀 데이터셋을 LeRobot 형식으로 만들 때 최소한 정해야 하는 것:

  • observation image key
  • state vector 의미와 dimension
  • action vector 의미와 dimension
  • action이 absolute인지 delta인지
  • action chunk horizon
  • prompt/task source

LIBERO 변환 예시는 다음 command다.

uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/libero/data

커스텀 변환 script를 만들 때는 examples/libero/convert_libero_data_to_lerobot.py를 복사해 dataset writer 부분만 바꾸는 것이 가장 쉽다.

7.2 새 data/policy transform을 만든다

가령 YourRobotInputs는 아래 책임을 가져야 한다.

@dataclasses.dataclass(frozen=True)
class YourRobotInputs(transforms.DataTransformFn):
    model_type: _model.ModelType

    def __call__(self, data: dict) -> dict:
        return {
            "images": {
                "base_0_rgb": data["observation/main_camera"],
                "left_wrist_0_rgb": data["observation/wrist_camera"],
                "right_wrist_0_rgb": np.zeros_like(data["observation/wrist_camera"]),
            },
            "image_masks": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                "right_wrist_0_rgb": np.False_,
            },
            "state": data["observation/state"],
            "prompt": data["prompt"],
        }

해설:

  • model이 기대하는 key 이름으로 image dict를 만든다.
  • 없는 camera는 zero image와 False mask로 채운다.
  • state는 normalization 전 실제 robot state여야 한다.
  • prompt는 string 또는 numpy scalar string으로 유지한다. 이후 tokenizer transform이 처리한다.

output transform은 모델 action을 실제 command로 되돌린다.

@dataclasses.dataclass(frozen=True)
class YourRobotOutputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        return {"actions": data["actions"][:, :7]}

해설:

  • 모델 action_dim이 32로 padding되어 있어도 실제 로봇은 앞 7차원만 쓸 수 있다.
  • gripper convention이나 joint sign flip이 있으면 여기서 처리한다.

7.3 DataConfigFactory를 추가한다

src/openpi/training/config.py에 다음 패턴을 추가한다.

@dataclasses.dataclass(frozen=True)
class LeRobotYourRobotDataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "observation/main_camera": "main_camera",
                        "observation/wrist_camera": "wrist_camera",
                        "observation/state": "state",
                        "actions": "actions",
                        "prompt": "prompt",
                    }
                )
            ]
        )
        data_transforms = _transforms.Group(
            inputs=[your_robot_policy.YourRobotInputs(model_type=model_config.model_type)],
            outputs=[your_robot_policy.YourRobotOutputs()],
        )
        model_transforms = ModelTransformFactory()(model_config)
        return dataclasses.replace(
            self.create_base_config(assets_dirs, model_config),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )

해설:

  • RepackTransform 왼쪽 key는 transform 이후 목표 key다.
  • 오른쪽 string은 LeRobot dataset item 안의 기존 key path다.
  • data_transforms는 training과 inference 둘 다 공유한다.
  • model_transforms는 π0/π0.5/π0-FAST에 맞춰 image resize, tokenize, pad 등을 자동 선택한다.
  • create_base_config()가 norm stats asset loading과 quantile norm 여부를 설정한다.

7.4 TrainConfig를 registry에 추가한다

TrainConfig(
    name="pi05_your_robot",
    model=pi0_config.Pi0Config(pi05=True, action_horizon=16),
    data=LeRobotYourRobotDataConfig(
        repo_id="your_hf_username/your_robot_dataset",
        base_config=DataConfig(prompt_from_task=True),
        assets=AssetsConfig(asset_id="your_robot"),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
    num_train_steps=30_000,
    batch_size=32,
)

해설:

  • name은 CLI에서 쓰는 고유 이름이다.
  • model에서 action_horizon, action_dim, pi05를 결정한다.
  • repo_id는 LeRobot dataset id다. local dataset이면 LeRobot cache path와 repo id convention을 맞춰야 한다.
  • asset_id는 norm stats 저장/로드 폴더명이다.
  • weight_loader는 base checkpoint parameter로 초기화한다.

7.5 norm stats 계산

uv run scripts/compute_norm_stats.py --config-name pi05_your_robot

흐름:

  1. config를 읽는다.
  2. dataset을 만든다.
  3. repack/data transforms까지 적용한다.
  4. string prompt 등 통계에 불필요한 값을 제거한다.
  5. state/actions running stats를 계산한다.
  6. assets/<config_name>/<asset_id>/norm_stats.json에 저장한다.

중요:

  • norm stats는 tokenization 이전, model padding 이전의 실제 state/action 공간을 기준으로 계산된다.
  • action dimension 중 거의 변하지 않는 값은 std나 quantile range가 너무 작아질 수 있다. divergence가 나면 norm_stats.jsonq01, q99, std를 확인한다.

7.6 JAX fine-tuning 실행

XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
uv run scripts/train.py pi05_your_robot --exp-name=my_experiment --overwrite

해설:

  • XLA_PYTHON_CLIENT_MEM_FRACTION=0.9: JAX가 GPU memory의 90%까지 preallocate할 수 있게 한다.
  • scripts/train.py pi05_your_robot: tyro가 config registry에서 해당 config를 선택한다.
  • --exp-name=my_experiment: checkpoint path가 checkpoints/pi05_your_robot/my_experiment가 된다.
  • --overwrite: 기존 같은 directory가 있으면 덮어쓴다. 재개하려면 --resume을 쓴다.

8. train_step() line-by-line

파일: scripts/train.py

model = nnx.merge(state.model_def, state.params)
model.train()
  • NNX model definition과 parameter state를 합쳐 실행 가능한 model을 만든다.
  • training mode로 바꾼다. dropout/augmentation behavior가 있다면 training mode를 따른다.
def loss_fn(model, rng, observation, actions):
    chunked_loss = model.compute_loss(rng, observation, actions, train=True)
    return jnp.mean(chunked_loss)
  • model별 loss를 호출한다.
  • π0/π0.5는 flow matching MSE, π0-FAST는 CE loss다.
  • 반환 loss를 batch/horizon/token 축 전체 평균으로 줄인다.
train_rng = jax.random.fold_in(rng, state.step)
observation, actions = batch
  • global RNG에 step을 fold-in해 step마다 다른 deterministic RNG를 만든다.
  • data loader는 이미 (Observation, Actions) tuple을 반환한다.
diff_state = nnx.DiffState(0, config.trainable_filter)
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
  • trainable_filter로 freeze되지 않은 parameter만 gradient 대상이 된다.
  • LoRA fine-tuning이면 LoRA parameter만 gradient를 받게 된다.
params = state.params.filter(config.trainable_filter)
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
new_params = optax.apply_updates(params, updates)
  • trainable parameter subset만 optimizer에 넣는다.
  • optimizer state를 갱신하고 update를 적용한다.
nnx.update(model, new_params)
new_params = nnx.state(model)
new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
  • model object에 새 trainable parameter를 반영한다.
  • 전체 model state를 다시 꺼낸다. freeze parameter도 포함된 full state다.
  • step과 optimizer state를 갱신한다.
if state.ema_decay is not None:
    new_state = dataclasses.replace(
        new_state,
        ema_params=jax.tree.map(
            lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new,
            state.ema_params,
            new_params,
        ),
    )
  • EMA가 켜져 있으면 inference용 smoothed parameter를 갱신한다.
  • LoRA config들은 보통 ema_decay=None으로 꺼둔다.
info = {
    "loss": loss,
    "grad_norm": optax.global_norm(grads),
    "param_norm": optax.global_norm(kernel_params),
}
return new_state, info
  • logging metric을 반환한다.
  • checkpoint 저장과 wandb logging은 main loop에서 처리한다.

9. Policy server와 client

9.1 server

uv run scripts/serve_policy.py policy:checkpoint \
  --policy.config=pi05_libero \
  --policy.dir=checkpoints/pi05_libero/my_experiment/20000

흐름:

  1. scripts/serve_policy.py가 tyro subcommand를 파싱한다.
  2. Checkpoint(config, dir)이면 해당 config/checkpoint로 policy를 만든다.
  3. WebsocketPolicyServer가 port 8000에서 대기한다.
  4. client가 연결되면 metadata를 먼저 보낸다.
  5. observation msgpack을 받을 때마다 policy.infer()를 호출한다.
  6. action dict와 server_timing을 msgpack으로 돌려준다.

9.2 client

from openpi_client import websocket_client_policy

policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
result = policy.infer(observation)
actions = result["actions"]

해설:

  • client package는 full JAX/Torch dependency를 요구하지 않는다.
  • observation은 numpy array와 primitive type 위주여야 msgpack serialization이 된다.
  • ActionChunkBroker(policy)를 감싸면 infer()가 매번 chunk 전체가 아니라 한 step action을 반환한다.

10. PyTorch 변환과 학습

10.1 transformers patch 적용

README 기준:

uv sync
uv pip show transformers
cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/

주의:

  • transformers 버전은 4.53.2여야 한다.
  • uv hardlink mode에서는 이 복사가 uv cache의 transformers에도 영향을 남길 수 있다.
  • 되돌리려면 uv cache clean transformers가 필요할 수 있다.

10.2 JAX checkpoint를 PyTorch로 변환

uv run examples/convert_jax_model_to_pytorch.py \
  --config_name pi05_droid \
  --checkpoint_dir /path/to/jax/checkpoint \
  --output_path /path/to/converted/pytorch/checkpoint

결과 directory에 model.safetensors가 생기면 create_trained_policy()가 PyTorch checkpoint로 자동 감지한다.

10.3 PyTorch inference

config = _config.get_config("pi05_droid")
checkpoint_dir = "/path/to/converted/pytorch/checkpoint"
policy = policy_config.create_trained_policy(config, checkpoint_dir)
action_chunk = policy.infer(example)["actions"]

JAX API와 같다. checkpoint directory만 다르다.

10.4 PyTorch fine-tuning

Single GPU:

uv run scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_test

Multi-GPU single node:

uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 \
  scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test

제약:

  • π0-FAST 미지원.
  • mixed precision training 미지원.
  • FSDP 미지원.
  • LoRA 미지원.
  • EMA 미지원.

11. 문제 해결 체크리스트

11.1 Missing norm stats

원인:

  • assets/<config_name>/<asset_id>/norm_stats.json이 없다.
  • checkpoint assets에 norm stats가 없거나 asset_id가 다르다.

해결:

uv run scripts/compute_norm_stats.py --config-name <config_name>

또는 AssetsConfig(asset_id=...)를 checkpoint asset과 맞춘다.

11.2 action dimension mismatch

확인 순서:

  1. TrainConfig.model.action_dim
  2. robot-specific Inputs가 만드는 state.shape[-1]
  3. dataset action shape
  4. PadStatesAndActions
  5. robot-specific Outputs가 slice하는 action dim

π0.5 base는 종종 32-dim action으로 학습되어 실제 로봇이 7/8/14 dim이어도 padding과 slice가 필요하다.

11.3 prompt가 안 들어가는 문제

확인 순서:

  1. dataset에 task가 있고 DataConfig(prompt_from_task=True)인지.
  2. inference observation에 prompt key가 있는지.
  3. InjectDefaultPrompt(default_prompt)를 쓰고 있는지.
  4. π0.5이면 discrete_state_input 때문에 tokenizer가 state도 필요로 하는지.

11.4 DROID RLDS loader 문제

확인:

  • uv sync --group rlds를 했는지.
  • Python 3.11 환경인지.
  • rlds_data_dirdroid directory의 parent인지.
  • num_workers=0인지. RLDS loader는 내부 multiprocessing/prefetch가 있어 config에서도 0으로 둔다.

11.5 PyTorch checkpoint가 JAX로 감지되는 문제

확인:

  • checkpoint directory 바로 아래에 model.safetensors가 있는지.
  • 파일명이 model.safetensors인지. 코드 주석에는 오타가 있지만 실제 감지는 이 이름이다.

11.6 src/openpi/models/vit.py import 문제

이 파일은 현재 메인 모델 path에서 쓰이지 않는다. openpi.models.resnet이 저장소에 없으므로 직접 import하는 연구 코드를 만들면 실패할 수 있다. SigLIP을 쓰는 현재 π0 path에는 영향이 없다.

12. 권장 연구 시작 순서

  1. uv run pytest src/openpi/transforms_test.py로 transform 기본 동작을 확인한다.
  2. GPU가 준비되어 있으면 examples/simple_client로 checkpoint server/client roundtrip을 확인한다.
  3. target robot의 raw observation/action schema를 문서화한다.
  4. YourRobotInputs/YourRobotOutputs를 먼저 작성하고 fake numpy observation으로 단위 테스트한다.
  5. LeRobot conversion script를 작성하고 2-3 episode만 변환해 shape/key를 확인한다.
  6. compute_norm_stats.py를 돌린다.
  7. debug 또는 작은 step 수로 training smoke test를 한다.
  8. base checkpoint에서 fine-tuning한다.
  9. policy server로 추론하고 action unnormalize/output transform이 실제 robot command와 맞는지 dry-run으로 확인한다.
  10. 실제 robot closed-loop 실행은 action clamp, emergency stop, open-loop horizon을 보수적으로 잡고 시작한다.