5. Valid Image Embedding Batching

Instead of embedding each observation image separately, compute the embeddings in a single batched call and split them afterward.

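This server is a shared GPU server used by multiple users, so it is not a dedicated benchmark environment with exclusive control of the whole system. For this profiling run, we therefore pinned execution to a single L40S and agreed that this GPU would be used exclusively during the experiments. CPU, memory, storage I/O, and OS background load can still be affected by other users' jobs, so the latency numbers that follow should be read as measurements for bottleneck analysis in a shared-server environment rather than as the server's absolute peak performance.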
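Since the shared parts of the machine add noise to any single measurement, one way to make the numbers easier to compare is to repeat each measurement and report the median together with the spread. The sketch below is only an illustration of that idea: time_repeated and summarize are hypothetical helper names, not part of openpi, and the warm-up and iteration counts are arbitrary assumptions. When timing GPU work, the timed callable would also need to synchronize before returning (for example with torch.cuda.synchronize()), because CUDA kernels are launched asynchronously.

```python
import statistics
import time


def time_repeated(fn, warmup=3, iters=20):
    """Call fn repeatedly and return per-call latencies in milliseconds.

    Hypothetical helper: the first `warmup` calls are discarded so that
    one-time costs (caching, lazy initialization) do not skew the samples.
    """
    for _ in range(warmup):
        fn()

    samples_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1e3)
    return samples_ms


def summarize(samples_ms):
    """Summarize latencies with statistics that tolerate occasional outliers."""
    ordered = sorted(samples_ms)
    return {
        "median_ms": statistics.median(ordered),
        "min_ms": ordered[0],
        "max_ms": ordered[-1],
    }


if __name__ == "__main__":
    # CPU-only placeholder workload, used here purely to exercise the helpers.
    samples = time_repeated(lambda: sum(range(10_000)))
    print(summarize(samples))
```

The median is less sensitive than the mean to a single slow run caused by another user's job, and the min/max range shows how noisy the environment was during the measurement.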


In the previous post, we removed the time wasted on embedding an image that is never used. The two remaining images, however, are still embedded one at a time:

image_0_embed ≈ 6.08 ms
image_1_embed ≈ 5.85 ms

In other words, the current approach computes embed_image(base_0_rgb) and embed_image(left_wrist_0_rgb) separately. The optimization applied this time computes the embeddings in one call, as in embed_image(cat([base_0_rgb, left_wrist_0_rgb], dim=0)), and splits the result later when it is needed. The goals are:

1. Reduce vision tower calls from 2 to 1
2. Mitigate small-batch underutilization
3. Reduce kernel launch / graph node overhead
4. Keep the prefix token order unchanged

The key change is to restructure the image processing block of embed_prefix() as follows:

pi0_pytorch.py: modify the image processing block in embed_prefix()
uv run python - <<'PY'
from pathlib import Path

path = Path("src/openpi/models_pytorch/pi0_pytorch.py")
text = path.read_text()

start = text.index("    def embed_prefix(")
end = text.index("    def embed_suffix(", start)

new_function = '''    def embed_prefix(
        self,
        images,
        img_masks,
        lang_tokens,
        lang_masks,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed image and language inputs for the PaliGemma prefix."""

        embs = []
        pad_masks = []
        att_masks = []

        # Collect image slots to be processed.
        image_items = []
        for image_idx, (img, img_mask) in enumerate(
            zip(images, img_masks, strict=True)
        ):
            # LIBERO pi0 uses a zero-filled, mask=False right-wrist slot.
            if (
                getattr(self, "libero_skip_right_wrist_image", False)
                and image_idx == 2
            ):
                continue

            image_items.append((image_idx, img, img_mask))

        def image_embed_func(img):
            return self.paligemma_with_expert.embed_image(img)

        use_batched_image_embed = (
            getattr(self, "libero_batch_valid_images", False)
            and not self.training
            and len(image_items) > 1
        )

        if use_batched_image_embed:
            batch_sizes = {int(img.shape[0]) for _, img, _ in image_items}
            if len(batch_sizes) != 1:
                raise ValueError(
                    f"All image slots must have the same batch size, got {batch_sizes}"
                )

            batch_size = next(iter(batch_sizes))

            # Preserve image-slot ordering:
            # [all samples of image_0, all samples of image_1, ...].
            batched_images = torch.cat(
                [img for _, img, _ in image_items],
                dim=0,
            )

            batched_image_embeddings = self._apply_checkpoint(
                image_embed_func,
                batched_images,
            )

            image_embeddings = torch.split(
                batched_image_embeddings,
                batch_size,
                dim=0,
            )

            for image_embedding, (_, _, image_mask) in zip(
                image_embeddings,
                image_items,
                strict=True,
            ):
                bsize, num_image_tokens = image_embedding.shape[:2]

                embs.append(image_embedding)
                pad_masks.append(
                    image_mask[:, None].expand(bsize, num_image_tokens)
                )
                att_masks += [0] * num_image_tokens

        else:
            for _, image, image_mask in image_items:
                image_embedding = self._apply_checkpoint(
                    image_embed_func,
                    image,
                )

                bsize, num_image_tokens = image_embedding.shape[:2]

                embs.append(image_embedding)
                pad_masks.append(
                    image_mask[:, None].expand(bsize, num_image_tokens)
                )
                att_masks += [0] * num_image_tokens

        # Language processing must remain unconditional and outside
        # both image-embedding branches.
        def language_embed_func(tokens):
            language_embedding = (
                self.paligemma_with_expert.embed_language_tokens(tokens)
            )
            embedding_dim = language_embedding.shape[-1]
            return language_embedding * math.sqrt(embedding_dim)

        language_embedding = self._apply_checkpoint(
            language_embed_func,
            lang_tokens,
        )

        embs.append(language_embedding)
        pad_masks.append(lang_masks)

        num_language_tokens = language_embedding.shape[1]
        att_masks += [0] * num_language_tokens

        prefix_embeddings = torch.cat(embs, dim=1)
        prefix_pad_masks = torch.cat(pad_masks, dim=1)

        prefix_attention_masks = torch.tensor(
            att_masks,
            dtype=torch.bool,
            device=prefix_pad_masks.device,
        )

        bsize = prefix_pad_masks.shape[0]
        prefix_attention_masks = prefix_attention_masks[None, :].expand(
            bsize,
            len(att_masks),
        )

        return (
            prefix_embeddings,
            prefix_pad_masks,
            prefix_attention_masks,
        )

'''

path.write_text(text[:start] + new_function + text[end:])
print("Replaced embed_prefix() in:", path)
PY
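
In the patched function, the batched path is taken only when libero_batch_valid_images is set on the model (it is read with getattr, so it defaults to False), the model is not in training mode, and more than one image slot remains. One way to sanity-check the change is to run embed_prefix() with the flag off and on and compare the three returned tensors. The helper below is a sketch under those assumptions: compare_prefix_paths is a hypothetical name, the tolerance is an arbitrary choice that would likely need loosening for bfloat16 weights, and model is expected to already be in eval mode.

```python
import torch


@torch.no_grad()
def compare_prefix_paths(model, images, img_masks, lang_tokens, lang_masks, atol=1e-5):
    """Compare embed_prefix() outputs with image batching disabled vs. enabled.

    Hypothetical helper: `model` is any object exposing embed_prefix() and
    honoring the `libero_batch_valid_images` attribute from the patch above.
    """
    previous = getattr(model, "libero_batch_valid_images", False)
    try:
        model.libero_batch_valid_images = False
        ref_embs, ref_pad, ref_att = model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )

        model.libero_batch_valid_images = True
        new_embs, new_pad, new_att = model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
    finally:
        # Restore whatever setting the caller had before the check.
        model.libero_batch_valid_images = previous

    # Compare in float32 so low-precision dtypes do not break allclose.
    ref_f, new_f = ref_embs.float(), new_embs.float()
    return {
        "embeddings_close": torch.allclose(ref_f, new_f, atol=atol),
        "max_abs_diff": (ref_f - new_f).abs().max().item(),
        "pad_masks_equal": torch.equal(ref_pad, new_pad),
        "att_masks_equal": torch.equal(ref_att, new_att),
    }
```

The pad and attention masks are expected to match exactly, since they depend only on shapes and the input masks, while the embeddings are compared with a tolerance because batching can change floating-point accumulation order.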