openpi 사용 절차와 핵심 코드 해설
이 문서는 openpi를 실제로 설치하고, checkpoint를 실행하고, 데이터를 변환하고, fine-tuning하고, 필요하면 PyTorch로 변환하는 과정을 step-by-step으로 설명한다. 핵심 실행 경로는 코드 조각 단위로 line-by-line 해설한다.
1. 설치와 환경 확인
1.1 전제
README 기준 전제는 다음이다.
- OS: Ubuntu 22.04
- GPU: NVIDIA GPU
- Python:
pyproject.toml기준>=3.11 - dependency manager:
uv - inference: 8GB 이상 VRAM
- LoRA fine-tuning: 22.5GB 이상 VRAM
- full fine-tuning: 70GB 이상 VRAM
현재 워크트리에서는 submodule이 초기화되지 않았다. 실제 ALOHA/LIBERO submodule이 필요한 예제를 쓰려면 먼저 실행한다.
git submodule update --init --recursive
1.2 uv 설치 후 dependency 설치
GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
해설:
GIT_LFS_SKIP_SMUDGE=1: LeRobot dependency checkout 과정에서 Git LFS 대용량 파일을 바로 받지 않게 한다.uv sync:pyproject.toml과uv.lock기준으로.venv를 만든다.uv pip install -e .: 현재 저장소를 editable package로 설치한다. 코드 수정 후 재설치 없이 import에 반영된다.
설치 확인:
uv run python -c "import openpi; print(openpi.__file__)"
uv run python -c "from openpi.training import config; print(config.get_config('debug').name)"
2. 빠른 checkpoint 추론
README의 핵심 API는 다음이다.
from openpi.training import config as _config
from openpi.policies import policy_config
from openpi.shared import download
config = _config.get_config("pi05_droid")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_droid")
policy = policy_config.create_trained_policy(config, checkpoint_dir)
example = {
"observation/exterior_image_1_left": ...,
"observation/wrist_image_left": ...,
"prompt": "pick up the fork",
}
action_chunk = policy.infer(example)["actions"]
Line-by-line:
from openpi.training import config as _config- config registry를 import한다.
get_config()는src/openpi/training/config.py의_CONFIGS_DICT에서 이름으로TrainConfig를 찾는다.
- config registry를 import한다.
from openpi.policies import policy_config- checkpoint를 실제 inference policy 객체로 바꾸는 factory를 import한다.
from openpi.shared import downloadgs://, local path, fsspec URL을 cache directory로 내려받는 helper다.
config = _config.get_config("pi05_droid")pi05_droid설정을 고른다. 이 config는Pi0Config(pi05=True, action_horizon=15)와 DROID input/output transform을 사용한다.
checkpoint_dir = download.maybe_download(...)- Google Cloud Storage checkpoint를
~/.cache/openpi또는OPENPI_DATA_HOME아래로 받는다.
- Google Cloud Storage checkpoint를
policy = policy_config.create_trained_policy(config, checkpoint_dir)- checkpoint type을 JAX/PyTorch로 감지한다.
- 모델 parameter를 로드한다.
- norm stats를 checkpoint assets에서 로드한다.
- repack, robot transform, normalization, model tokenization transform chain을 만든다.
example = {...}- DROID config에서 기대하는 raw observation key다. DROID runtime에서는
observation/exterior_image_1_left,observation/wrist_image_left, joint/gripper state, prompt 등이 필요하다.
- DROID config에서 기대하는 raw observation key다. DROID runtime에서는
action_chunk = policy.infer(example)["actions"]- 입력 dict가 transform chain을 지나
Observation으로 바뀐다. - 모델
sample_actions()가 action chunk를 생성한다. - output transform이 action을 robot command space로 되돌린다.
- 입력 dict가 transform chain을 지나
실제로 로봇 없이 호출 구조만 확인하려면 examples/simple_client를 쓰는 편이 안전하다. 먼저 server:
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=pi05_droid \
--policy.dir=gs://openpi-assets/checkpoints/pi05_droid
다른 터미널에서 client:
uv run examples/simple_client/main.py --env DROID
3. create_trained_policy() line-by-line
파일: src/openpi/policies/policy_config.py
핵심 코드:
repack_transforms = repack_transforms or transforms.Group()
checkpoint_dir = download.maybe_download(str(checkpoint_dir))
weight_path = os.path.join(checkpoint_dir, "model.safetensors")
is_pytorch = os.path.exists(weight_path)
if is_pytorch:
model = train_config.model.load_pytorch(train_config, weight_path)
model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
else:
model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
if norm_stats is None:
if data_config.asset_id is None:
raise ValueError("Asset id is required to load norm stats.")
norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
return _policy.Policy(
model,
transforms=[
*repack_transforms.inputs,
transforms.InjectDefaultPrompt(default_prompt),
*data_config.data_transforms.inputs,
transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
*data_config.model_transforms.inputs,
],
output_transforms=[
*data_config.model_transforms.outputs,
transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
*data_config.data_transforms.outputs,
*repack_transforms.outputs,
],
sample_kwargs=sample_kwargs,
metadata=train_config.policy_metadata,
is_pytorch=is_pytorch,
pytorch_device=pytorch_device if is_pytorch else None,
)
해설:
repack_transforms = ...: inference caller가 별도 key remapping을 넘기지 않으면 빈 transform group을 쓴다.download.maybe_download: checkpoint path가gs://이면 cache로 다운로드하고, local이면 그대로 Path로 정규화한다.model.safetensors존재 여부: PyTorch checkpoint 감지 규칙이다. 없으면 JAX Orbax checkpoint로 본다.- PyTorch branch:
load_pytorch()가 model instance를 만들고 safetensors를 load한다. 이후 선택 parameter를 bfloat16 정책에 맞게 변환한다. - JAX branch:
restore_params(checkpoint_dir / "params")로 Orbax parameter tree를 읽고BaseModelConfig.load()로 model state에 주입한다. data_config = train_config.data.create(...): data factory가 repack/data/model transform과 norm stats 설정을 만든다.norm_stats: 명시로 넘기지 않으면 checkpoint의assets/<asset_id>/norm_stats.json을 읽는다. inference와 training normalization을 반드시 맞추려는 설계다.- input
transforms: raw dict에서 model-ready dict로 가는 순서다. - output
output_transforms: model output에서 robot command로 돌아가는 순서다. input의 역변환이 앞쪽에 온다. metadata: policy server가 client에 알려줄 reset pose 같은 정보다.
4. Policy.infer() line-by-line
파일: src/openpi/policies/policy.py
핵심 코드:
inputs = jax.tree.map(lambda x: x, obs)
inputs = self._input_transform(inputs)
if not self._is_pytorch_model:
inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)
self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng)
else:
inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs)
sample_rng_or_pytorch_device = self._pytorch_device
sample_kwargs = dict(self._sample_kwargs)
if noise is not None:
noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)
if noise.ndim == 2:
noise = noise[None, ...]
sample_kwargs["noise"] = noise
observation = _model.Observation.from_dict(inputs)
start_time = time.monotonic()
outputs = {
"state": inputs["state"],
"actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs),
}
model_time = time.monotonic() - start_time
...
outputs = self._output_transform(outputs)
outputs["policy_timing"] = {"infer_ms": model_time * 1000}
return outputs
해설:
jax.tree.map(lambda x: x, obs): 입력 dict를 얕은 tree copy처럼 다뤄 transform이 원본을 직접 바꾸는 부작용을 줄인다._input_transform:create_trained_policy()에서 조립한 모든 input transform이 실행된다.- JAX branch:
- 모든 leaf를
jnp.asarray로 바꾼다. [np.newaxis, ...]로 batch dimension을 추가한다.- RNG를 split해 이번 inference sampling에 쓸 key를 만든다.
- 모든 leaf를
- PyTorch branch:
- numpy array로 강제 변환 후 torch tensor로 만들고 device로 옮긴다.
- batch dimension을 추가한다.
- PyTorch
sample_actions()는 첫 번째 인자로 RNG 대신 device string을 받는 식으로 맞춰져 있다.
noise: flow matching 모델에서 deterministic 비교나 debugging을 위해 초기 noise를 직접 넣을 수 있다.Observation.from_dict: dict를 typed observation dataclass로 변환한다._sample_actions: JAX면nnx_utils.module_jit(model.sample_actions), PyTorch면 rawmodel.sample_actions.- output에는
state와actions가 들어간다. output transform 중Unnormalize가 state/action stats를 참조하므로 state도 같이 둔다. - batch dimension을 제거한 뒤
_output_transform이 robot action space로 되돌린다. policy_timing은 model sampling 시간만 ms 단위로 기록한다.
5. π0 flow matching 코드 흐름
파일: src/openpi/models/pi0.py
5.1 Attention mask
mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
cumsum = jnp.cumsum(mask_ar, axis=1)
attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
return jnp.logical_and(attn_mask, valid_mask)
해설:
mask_ar는 token group 경계를 표시한다.False면 이전 token과 같은 attention block,True면 새 causal block으로 본다.cumsum은 각 token이 몇 번째 AR block에 속하는지 만든다.attn_mask[i, q, k] = True조건은 key token의 block id가 query token의 block id보다 작거나 같을 때다.valid_mask는 padding token을 attention에서 제외한다.- prefix image/language는 full attention, suffix action token은 causal/prefix-LM attention을 만들 수 있다.
5.2 embed_prefix()
for name in obs.images:
image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
tokens.append(image_tokens)
input_mask.append(einops.repeat(obs.image_masks[name], "b -> b s", s=image_tokens.shape[1]))
ar_mask += [False] * image_tokens.shape[1]
if obs.tokenized_prompt is not None:
tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
tokens.append(tokenized_inputs)
input_mask.append(obs.tokenized_prompt_mask)
ar_mask += [False] * tokenized_inputs.shape[1]
해설:
- 각 camera view를 SigLIP image token sequence로 바꾼다.
- image mask는 camera가 실제로 존재하는지 나타낸다. 없는 wrist camera는 zero image +
Falsemask로 들어갈 수 있다. - prefix token들은 모두
ar_mask=False라 서로 bidirectional/full attention이 가능하다. - prompt token은 Gemma embedder로 embedding한다.
- 결과는
(prefix_tokens, prefix_input_mask, prefix_ar_mask)다.
5.3 embed_suffix()
π0:
state_token = self.state_proj(obs.state)[:, None, :]
...
action_tokens = self.action_in_proj(noisy_actions)
time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
action_time_tokens = self.action_time_mlp_in(action_time_tokens)
action_time_tokens = nnx.swish(action_time_tokens)
action_time_tokens = self.action_time_mlp_out(action_time_tokens)
π0.5:
action_tokens = self.action_in_proj(noisy_actions)
time_emb = posemb_sincos(...)
time_emb = self.time_mlp_in(time_emb)
time_emb = nnx.swish(time_emb)
time_emb = self.time_mlp_out(time_emb)
time_emb = nnx.swish(time_emb)
action_expert_tokens = action_tokens
adarms_cond = time_emb
해설:
- π0는 state를 별도 continuous suffix token으로 넣는다.
- π0.5는 state를 prompt token으로 넣을 수 있으므로 suffix state token을 생략한다.
- action은
action_dim -> action_expert_width로 projection한다. - timestep은 scalar라 sinusoidal embedding으로 만든다.
- π0는 action embedding과 time embedding을 concatenate해 MLP로 섞는다.
- π0.5는 timestep을 AdaRMS conditioning으로 action expert transformer block에 주입한다.
5.4 compute_loss()
preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
batch_shape = actions.shape[:-2]
noise = jax.random.normal(noise_rng, actions.shape)
time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
time_expanded = time[..., None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
해설:
- RNG를 image augmentation, noise sampling, time sampling 용도로 나눈다.
- training이면 image augmentation이 적용된다.
noise는 action chunk와 같은 shape다.time은 batch마다 하나씩 뽑는다.0.001..1.0근처로 제한해 극단값 불안정을 줄인다.x_t는 실제 action과 noise 사이의 interpolation point다.u_t는 flow matching target vector다. 현재 convention은t=1이 noise,t=0이 data라서noise - actions가 target이다.
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
attn_mask = make_attn_mask(input_mask, ar_mask)
positions = jnp.cumsum(input_mask, axis=1) - 1
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
[prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
)
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
return jnp.mean(jnp.square(v_t - u_t), axis=-1)
해설:
- prefix는 image/language, suffix는 state/action/time이다.
- mask와 ar mask를 합쳐 전체 sequence attention mask를 만든다.
positions는 padding을 제외한 position id다.PaliGemma.llm([prefix, suffix])는 두 expert token stream을 함께 attention하게 한다.- suffix output 중 마지막
action_horizon개가 action token output이다. action_out_proj가 hidden width를 action dimension으로 되돌린다.- 반환 loss shape는 batch와 horizon 축을 유지하고 action dim만 평균낸다. train step에서 다시 전체 평균을 낸다.
5.5 sample_actions()
observation = _model.preprocess_observation(None, observation, train=False)
dt = -1.0 / num_steps
batch_size = observation.state.shape[0]
if noise is None:
noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
해설:
- inference preprocess에는 train augmentation이 없다.
dt는 negative다.t=1에서t=0으로 간다.- 초기 action은 noise다. 같은
noise를 넣으면 deterministic 비교가 가능하다.
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
positions = jnp.cumsum(prefix_mask, axis=1) - 1
_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)
해설:
- 이미지와 prompt는 denoising step마다 바뀌지 않으므로 한 번만 forward한다.
- 결과 KV cache를 저장해 suffix step에서 재사용한다.
- 이 때문에 flow matching 모델은 autoregressive FAST보다 inference가 빠를 수 있다.
def step(carry):
x_t, time = carry
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
observation, x_t, jnp.broadcast_to(time, batch_size)
)
suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
[None, suffix_tokens],
mask=full_attn_mask,
positions=positions,
kv_cache=kv_cache,
adarms_cond=[None, adarms_cond],
)
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
return x_t + dt * v_t, time + dt
해설:
- 현재 noisy action
x_t와 scalartime으로 suffix token을 만든다. - suffix 내부 attention mask와 prefix attention mask를 합친다.
- prefix token은 이미 cache에 있으므로
[None, suffix_tokens]만 넣는다. - action vector field
v_t를 예측한다. - Euler update로 action을 한 step denoise한다.
x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
return x_0
해설:
- Python loop가 아니라
jax.lax.while_loop라 JIT compile 가능하다. - 최종
x_0가 normalized action chunk다. - policy output transform이 이를 unnormalize하고 robot action space로 바꾼다.
6. π0-FAST 코드 흐름
파일: src/openpi/models/pi0_fast.py
6.1 compute_loss()
observation = _model.preprocess_observation(
rng, observation, train=train, image_keys=list(observation.images.keys())
)
input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
attn_mask = make_attn_mask(input_mask, ar_mask)
해설:
- π0-FAST는 config마다 image key가 다를 수 있어 observation에 실제 들어온 key를 그대로 preprocess한다.
embed_inputs()는 image embedding과 token embedding을 합친다.ar_mask는 FAST tokenizer가 만든 token_ar_mask를 포함한다.
targets = jax.nn.one_hot(
observation.tokenized_prompt[:, 1:],
self.PaliGemma.llm.module.vocab_size,
)
pre_logits, _, _ = self.PaliGemma.llm(
embedded_prefix=input_token_embeddings[:, :-1],
mask=attn_mask[:, :-1, :-1],
return_prelogits=True,
)
logits, _ = self.PaliGemma.llm(pre_logits=pre_logits[:, -targets.shape[1] :])
logp = jax.nn.log_softmax(logits, axis=-1)
해설:
- next-token prediction이므로 target은 input token을 한 칸 shift한 것이다.
- 마지막 token은 다음 token을 예측할 필요가 없으므로 input에서는 제외한다.
- vocab projection matmul이 크기 때문에 필요한 target 구간만 logits로 decode한다.
log_softmax로 CE 계산 준비를 한다.
loss_mask = observation.token_loss_mask[:, 1:]
token_pplx = jnp.sum(targets * logp, axis=-1)
return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)
해설:
token_loss_mask가 True인 위치, 즉 action postfix token에만 loss를 건다.- prompt prefix에는 loss를 걸지 않는다.
- 각 sequence별 평균 CE loss를 반환한다.
6.2 sample_actions()
prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
prefix_token_embeddings, prefix_mask, prefix_attn_mask
)
해설:
- 입력 prompt 길이가 batch마다 달라 decoding start 위치가 다를 수 있다.
left_to_right_align가 sequence를 오른쪽 정렬해 KV cache에서 decode 위치를 맞춘다.
prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
prefix_logits, kv_cache, _ = self.PaliGemma.llm(
embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
)
last_logit = prefix_logits[:, -1:]
output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))
해설:
- prefix forward로 KV cache를 채운다.
- attention mask를 future decode length만큼 padding해 cache capacity를 확보한다.
- 마지막 prefix logit이 첫 action token을 예측한다.
token = jax.lax.cond(
temperature > 0.0,
lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
lambda _: jnp.argmax(last_logit, axis=-1),
operand=None,
)
output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)
has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
all_eos = jnp.all(has_eos)
해설:
temperature=0이면 greedy argmax다.temperature>0이면 categorical sampling이다.- 생성 token을 output buffer의 현재 step 위치에 기록한다.
- batch의 모든 sample이 EOS를 내면 early stop한다.
token_embedding = self.PaliGemma.llm(token, embed_only=True)
positions = prefill_len[:, None] + step + 1
last_logit, kv_cache, _ = self.PaliGemma.llm(
embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
)
해설:
- 방금 생성한 token을 embedding해 다음 token을 예측한다.
- KV cache를 업데이트하며 autoregressive decoding을 진행한다.
- 반환값은 raw generated tokens이고,
ExtractFASTActions가 action으로 decode한다.
7. Fine-tuning step-by-step
7.1 데이터 형식을 먼저 고정한다
커스텀 데이터셋을 LeRobot 형식으로 만들 때 최소한 정해야 하는 것:
- observation image key
- state vector 의미와 dimension
- action vector 의미와 dimension
- action이 absolute인지 delta인지
- action chunk horizon
- prompt/task source
LIBERO 변환 예시는 다음 command다.
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/libero/data
커스텀 변환 script를 만들 때는 examples/libero/convert_libero_data_to_lerobot.py를 복사해 dataset writer 부분만 바꾸는 것이 가장 쉽다.
7.2 새 data/policy transform을 만든다
가령 YourRobotInputs는 아래 책임을 가져야 한다.
@dataclasses.dataclass(frozen=True)
class YourRobotInputs(transforms.DataTransformFn):
model_type: _model.ModelType
def __call__(self, data: dict) -> dict:
return {
"images": {
"base_0_rgb": data["observation/main_camera"],
"left_wrist_0_rgb": data["observation/wrist_camera"],
"right_wrist_0_rgb": np.zeros_like(data["observation/wrist_camera"]),
},
"image_masks": {
"base_0_rgb": np.True_,
"left_wrist_0_rgb": np.True_,
"right_wrist_0_rgb": np.False_,
},
"state": data["observation/state"],
"prompt": data["prompt"],
}
해설:
- model이 기대하는 key 이름으로 image dict를 만든다.
- 없는 camera는 zero image와
Falsemask로 채운다. state는 normalization 전 실제 robot state여야 한다.prompt는 string 또는 numpy scalar string으로 유지한다. 이후 tokenizer transform이 처리한다.
output transform은 모델 action을 실제 command로 되돌린다.
@dataclasses.dataclass(frozen=True)
class YourRobotOutputs(transforms.DataTransformFn):
def __call__(self, data: dict) -> dict:
return {"actions": data["actions"][:, :7]}
해설:
- 모델 action_dim이 32로 padding되어 있어도 실제 로봇은 앞 7차원만 쓸 수 있다.
- gripper convention이나 joint sign flip이 있으면 여기서 처리한다.
7.3 DataConfigFactory를 추가한다
src/openpi/training/config.py에 다음 패턴을 추가한다.
@dataclasses.dataclass(frozen=True)
class LeRobotYourRobotDataConfig(DataConfigFactory):
@override
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
repack_transform = _transforms.Group(
inputs=[
_transforms.RepackTransform(
{
"observation/main_camera": "main_camera",
"observation/wrist_camera": "wrist_camera",
"observation/state": "state",
"actions": "actions",
"prompt": "prompt",
}
)
]
)
data_transforms = _transforms.Group(
inputs=[your_robot_policy.YourRobotInputs(model_type=model_config.model_type)],
outputs=[your_robot_policy.YourRobotOutputs()],
)
model_transforms = ModelTransformFactory()(model_config)
return dataclasses.replace(
self.create_base_config(assets_dirs, model_config),
repack_transforms=repack_transform,
data_transforms=data_transforms,
model_transforms=model_transforms,
)
해설:
RepackTransform왼쪽 key는 transform 이후 목표 key다.- 오른쪽 string은 LeRobot dataset item 안의 기존 key path다.
data_transforms는 training과 inference 둘 다 공유한다.model_transforms는 π0/π0.5/π0-FAST에 맞춰 image resize, tokenize, pad 등을 자동 선택한다.create_base_config()가 norm stats asset loading과 quantile norm 여부를 설정한다.
7.4 TrainConfig를 registry에 추가한다
TrainConfig(
name="pi05_your_robot",
model=pi0_config.Pi0Config(pi05=True, action_horizon=16),
data=LeRobotYourRobotDataConfig(
repo_id="your_hf_username/your_robot_dataset",
base_config=DataConfig(prompt_from_task=True),
assets=AssetsConfig(asset_id="your_robot"),
),
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
num_train_steps=30_000,
batch_size=32,
)
해설:
name은 CLI에서 쓰는 고유 이름이다.model에서action_horizon,action_dim,pi05를 결정한다.repo_id는 LeRobot dataset id다. local dataset이면 LeRobot cache path와 repo id convention을 맞춰야 한다.asset_id는 norm stats 저장/로드 폴더명이다.weight_loader는 base checkpoint parameter로 초기화한다.
7.5 norm stats 계산
uv run scripts/compute_norm_stats.py --config-name pi05_your_robot
흐름:
- config를 읽는다.
- dataset을 만든다.
- repack/data transforms까지 적용한다.
- string prompt 등 통계에 불필요한 값을 제거한다.
- state/actions running stats를 계산한다.
assets/<config_name>/<asset_id>/norm_stats.json에 저장한다.
중요:
- norm stats는 tokenization 이전, model padding 이전의 실제 state/action 공간을 기준으로 계산된다.
- action dimension 중 거의 변하지 않는 값은
std나 quantile range가 너무 작아질 수 있다. divergence가 나면norm_stats.json의q01,q99,std를 확인한다.
7.6 JAX fine-tuning 실행
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
uv run scripts/train.py pi05_your_robot --exp-name=my_experiment --overwrite
해설:
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9: JAX가 GPU memory의 90%까지 preallocate할 수 있게 한다.scripts/train.py pi05_your_robot: tyro가 config registry에서 해당 config를 선택한다.--exp-name=my_experiment: checkpoint path가checkpoints/pi05_your_robot/my_experiment가 된다.--overwrite: 기존 같은 directory가 있으면 덮어쓴다. 재개하려면--resume을 쓴다.
8. train_step() line-by-line
파일: scripts/train.py
model = nnx.merge(state.model_def, state.params)
model.train()
- NNX model definition과 parameter state를 합쳐 실행 가능한 model을 만든다.
- training mode로 바꾼다. dropout/augmentation behavior가 있다면 training mode를 따른다.
def loss_fn(model, rng, observation, actions):
chunked_loss = model.compute_loss(rng, observation, actions, train=True)
return jnp.mean(chunked_loss)
- model별 loss를 호출한다.
- π0/π0.5는 flow matching MSE, π0-FAST는 CE loss다.
- 반환 loss를 batch/horizon/token 축 전체 평균으로 줄인다.
train_rng = jax.random.fold_in(rng, state.step)
observation, actions = batch
- global RNG에 step을 fold-in해 step마다 다른 deterministic RNG를 만든다.
- data loader는 이미
(Observation, Actions)tuple을 반환한다.
diff_state = nnx.DiffState(0, config.trainable_filter)
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
trainable_filter로 freeze되지 않은 parameter만 gradient 대상이 된다.- LoRA fine-tuning이면 LoRA parameter만 gradient를 받게 된다.
params = state.params.filter(config.trainable_filter)
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
new_params = optax.apply_updates(params, updates)
- trainable parameter subset만 optimizer에 넣는다.
- optimizer state를 갱신하고 update를 적용한다.
nnx.update(model, new_params)
new_params = nnx.state(model)
new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
- model object에 새 trainable parameter를 반영한다.
- 전체 model state를 다시 꺼낸다. freeze parameter도 포함된 full state다.
- step과 optimizer state를 갱신한다.
if state.ema_decay is not None:
new_state = dataclasses.replace(
new_state,
ema_params=jax.tree.map(
lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new,
state.ema_params,
new_params,
),
)
- EMA가 켜져 있으면 inference용 smoothed parameter를 갱신한다.
- LoRA config들은 보통
ema_decay=None으로 꺼둔다.
info = {
"loss": loss,
"grad_norm": optax.global_norm(grads),
"param_norm": optax.global_norm(kernel_params),
}
return new_state, info
- logging metric을 반환한다.
- checkpoint 저장과 wandb logging은 main loop에서 처리한다.
9. Policy server와 client
9.1 server
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=pi05_libero \
--policy.dir=checkpoints/pi05_libero/my_experiment/20000
흐름:
scripts/serve_policy.py가 tyro subcommand를 파싱한다.Checkpoint(config, dir)이면 해당 config/checkpoint로 policy를 만든다.WebsocketPolicyServer가 port 8000에서 대기한다.- client가 연결되면 metadata를 먼저 보낸다.
- observation msgpack을 받을 때마다
policy.infer()를 호출한다. - action dict와
server_timing을 msgpack으로 돌려준다.
9.2 client
from openpi_client import websocket_client_policy
policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
result = policy.infer(observation)
actions = result["actions"]
해설:
- client package는 full JAX/Torch dependency를 요구하지 않는다.
- observation은 numpy array와 primitive type 위주여야 msgpack serialization이 된다.
ActionChunkBroker(policy)를 감싸면infer()가 매번 chunk 전체가 아니라 한 step action을 반환한다.
10. PyTorch 변환과 학습
10.1 transformers patch 적용
README 기준:
uv sync
uv pip show transformers
cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
주의:
- transformers 버전은 4.53.2여야 한다.
- uv hardlink mode에서는 이 복사가 uv cache의 transformers에도 영향을 남길 수 있다.
- 되돌리려면
uv cache clean transformers가 필요할 수 있다.
10.2 JAX checkpoint를 PyTorch로 변환
uv run examples/convert_jax_model_to_pytorch.py \
--config_name pi05_droid \
--checkpoint_dir /path/to/jax/checkpoint \
--output_path /path/to/converted/pytorch/checkpoint
결과 directory에 model.safetensors가 생기면 create_trained_policy()가 PyTorch checkpoint로 자동 감지한다.
10.3 PyTorch inference
config = _config.get_config("pi05_droid")
checkpoint_dir = "/path/to/converted/pytorch/checkpoint"
policy = policy_config.create_trained_policy(config, checkpoint_dir)
action_chunk = policy.infer(example)["actions"]
JAX API와 같다. checkpoint directory만 다르다.
10.4 PyTorch fine-tuning
Single GPU:
uv run scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_test
Multi-GPU single node:
uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 \
scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
제약:
- π0-FAST 미지원.
- mixed precision training 미지원.
- FSDP 미지원.
- LoRA 미지원.
- EMA 미지원.
11. 문제 해결 체크리스트
11.1 Missing norm stats
원인:
assets/<config_name>/<asset_id>/norm_stats.json이 없다.- checkpoint assets에 norm stats가 없거나
asset_id가 다르다.
해결:
uv run scripts/compute_norm_stats.py --config-name <config_name>
또는 AssetsConfig(asset_id=...)를 checkpoint asset과 맞춘다.
11.2 action dimension mismatch
확인 순서:
TrainConfig.model.action_dim- robot-specific
Inputs가 만드는state.shape[-1] - dataset action shape
PadStatesAndActions- robot-specific
Outputs가 slice하는 action dim
π0.5 base는 종종 32-dim action으로 학습되어 실제 로봇이 7/8/14 dim이어도 padding과 slice가 필요하다.
11.3 prompt가 안 들어가는 문제
확인 순서:
- dataset에
task가 있고DataConfig(prompt_from_task=True)인지. - inference observation에
promptkey가 있는지. InjectDefaultPrompt(default_prompt)를 쓰고 있는지.- π0.5이면
discrete_state_input때문에 tokenizer가 state도 필요로 하는지.
11.4 DROID RLDS loader 문제
확인:
uv sync --group rlds를 했는지.- Python 3.11 환경인지.
rlds_data_dir가droiddirectory의 parent인지.num_workers=0인지. RLDS loader는 내부 multiprocessing/prefetch가 있어 config에서도 0으로 둔다.
11.5 PyTorch checkpoint가 JAX로 감지되는 문제
확인:
- checkpoint directory 바로 아래에
model.safetensors가 있는지. - 파일명이
model.safetensors인지. 코드 주석에는 오타가 있지만 실제 감지는 이 이름이다.
11.6 src/openpi/models/vit.py import 문제
이 파일은 현재 메인 모델 path에서 쓰이지 않는다. openpi.models.resnet이 저장소에 없으므로 직접 import하는 연구 코드를 만들면 실패할 수 있다. SigLIP을 쓰는 현재 π0 path에는 영향이 없다.
12. 권장 연구 시작 순서
uv run pytest src/openpi/transforms_test.py로 transform 기본 동작을 확인한다.- GPU가 준비되어 있으면
examples/simple_client로 checkpoint server/client roundtrip을 확인한다. - target robot의 raw observation/action schema를 문서화한다.
YourRobotInputs/YourRobotOutputs를 먼저 작성하고 fake numpy observation으로 단위 테스트한다.- LeRobot conversion script를 작성하고 2-3 episode만 변환해 shape/key를 확인한다.
compute_norm_stats.py를 돌린다.debug또는 작은 step 수로 training smoke test를 한다.- base checkpoint에서 fine-tuning한다.
- policy server로 추론하고 action unnormalize/output transform이 실제 robot command와 맞는지 dry-run으로 확인한다.
- 실제 robot closed-loop 실행은 action clamp, emergency stop, open-loop horizon을 보수적으로 잡고 시작한다.