@@ -54,13 +54,15 @@ class MeshConfig:
degree *across* slices (Data Center Network). This typically contains a
single entry for the data-parallel axis.
Example: {'data': 2}
If None, an ordinary device mesh is used rather than a hybrid device mesh
(hybrid meshes are intended for multi-replica workloads).
allow_split_physical_axes: If True, we will split physical axes if
necessary to produce the desired device mesh.
process_is_granule: If True, treat processes as the units of the
slower/outer network.
"""
mesh_axes: list[str]
ici_parallelism: dict[str, int] = dataclasses.field(default_factory=dict)
dcn_parallelism: dict[str, int] = dataclasses.field(default_factory=dict)
dcn_parallelism: dict[str, int] | None = None
allow_split_physical_axes: bool = False
process_is_granule: bool = False
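
For reference, a minimal sketch of how the new optional field might be used; the axis names and parallelism counts below are illustrative, not taken from this PR:

# Ordinary (ICI-only) mesh: dcn_parallelism is left as None.
ici_only = MeshConfig(
    mesh_axes=['data', 'model'],
    ici_parallelism={'data': 2, 'model': 4},
    dcn_parallelism=None,
)

# Hybrid mesh for a multi-replica workload: the data-parallel axis spans slices.
hybrid = MeshConfig(
    mesh_axes=['data', 'model'],
    ici_parallelism={'model': 4},
    dcn_parallelism={'data': 2},
)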
@@ -77,12 +77,15 @@ class TestContext:
path: The test directory path.
options: The specific BenchmarkOptions for this test variant.
mesh: The mesh used for sharding the checkpoint data.
repeat_index: The index of the repeat run, if this test is run multiple
times.
"""

pytree: Any
path: epath.Path
options: BenchmarkOptions # The specific options for this test variant.
mesh: jax.sharding.Mesh | None = None
repeat_index: int | None = None


@dataclasses.dataclass
@@ -171,7 +174,11 @@ def run(self, repeat_index: int | None = None) -> TestResult:
multihost.sync_global_processes("benchmark:setup_pytree")

context = TestContext(
pytree=data, path=path, options=self.options, mesh=self.mesh
pytree=data,
path=path,
options=self.options,
mesh=self.mesh,
repeat_index=repeat_index,
)

test_context_summary = self._build_test_context_summary(context)
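As a rough illustration of the new field, a context for the second repeat of a test variant could be built as follows; the pytree, path, and options values are placeholders, and BenchmarkOptions is assumed to be default-constructible:

import numpy as np
from etils import epath

context = TestContext(
    pytree={'params': np.zeros(1024)},      # any checkpointable pytree
    path=epath.Path('/tmp/benchmark_run'),  # test directory for this variant
    options=BenchmarkOptions(),             # placeholder options
    mesh=None,                              # no sharding in this sketch
    repeat_index=1,                         # second run of the same variant
)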
@@ -42,7 +42,19 @@ def create_mesh(config: configs.MeshConfig) -> jax.sharding.Mesh:
num_devices = len(devices)
# Convert the user-friendly dict maps into ordered lists based on mesh_axes
ici_shape = [config.ici_parallelism.get(axis, 1) for axis in config.mesh_axes]
dcn_shape = [config.dcn_parallelism.get(axis, 1) for axis in config.mesh_axes]

dcn_parallelism = config.dcn_parallelism
if dcn_parallelism is None:
logging.info('Creating ICI-only mesh.')
devices_array = mesh_utils.create_device_mesh(ici_shape, devices)
logging.info(
'Creating mesh with axes: %s',
{axis: dim for axis, dim in zip(config.mesh_axes, devices_array.shape)},
)
return jax.sharding.Mesh(devices_array, config.mesh_axes)
else:
logging.info('Creating hybrid mesh.')
dcn_shape = [dcn_parallelism.get(axis, 1) for axis in config.mesh_axes]

# --- Validation ---
if config.process_is_granule:
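A sketch of exercising both branches of create_mesh; the configs module path is taken from the signature above, while axis names and counts are assumptions:

# ICI-only path: dcn_parallelism=None skips the hybrid-mesh construction.
mesh = create_mesh(configs.MeshConfig(
    mesh_axes=['data'],
    ici_parallelism={'data': len(jax.devices())},
    dcn_parallelism=None,
))

# Hybrid path: the 'data' axis spans slices over DCN.
hybrid_mesh = create_mesh(configs.MeshConfig(
    mesh_axes=['data', 'model'],
    ici_parallelism={'model': 4},
    dcn_parallelism={'data': 2},
))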
@@ -0,0 +1,173 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmarks for orbax.checkpoint.PyTreeCheckpointHandler."""

from __future__ import annotations

import dataclasses
import pprint
import time
from typing import Any

from absl import logging
import jax
from jax.experimental import multihost_utils
import numpy as np
import orbax.checkpoint as ocp_v0 # pylint: disable=unused-import
from orbax.checkpoint import v1 as ocp
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
import requests


SERVICE_URL = "http://service-dns/"


def _metrics_to_measure(options: LustreBenchmarkOptions) -> list[str]:
"""Returns the list of metrics to measure."""
del options
metrics = ["time", "rss", "io"]
return metrics


# ==============================================================================
# 1. Define the Options Dataclass for this specific benchmark
# ==============================================================================
@dataclasses.dataclass(frozen=True)
class LustreBenchmarkOptions(benchmarks_core.BenchmarkOptions):
"""Configuration options for benchmarks targeting PyTreeCheckpointHandler.

Each attribute can be a single value or a list of values to create
a parameter sweep.

Attributes:
use_ocdbt: Whether to use OCDBT for checkpointing.
"""

use_ocdbt: bool = True

def is_valid(self):
return True


class StorageServiceClient:
"""Docstring."""

def __init__(self, service_url: str | None = None):
self._service_url = service_url or SERVICE_URL

def resolve(self, execution_id: int, step: int) -> str:
"""Resolves an asset path from the service."""
start = time.time()
logging.info("Resolving ID-step: %s-%s.", execution_id, step)
payload = {"execution_id": execution_id, "step": step}
response = requests.post(f"{self._service_url}/resolve", json=payload)
logging.info("Response: %s", response.json())
response.raise_for_status()
result = response.json()["path"]
end = time.time()
logging.info("Resolved %s in %s seconds.", result, end - start)
return result

def finalize(self, execution_id: int, step: int) -> None:
"""Finalizes an asset in the service."""
start = time.time()
payload = {"execution_id": execution_id, "step": step}
response = requests.post(f"{self._service_url}/finalize", json=payload)
response.raise_for_status()
logging.info(response)
# assert response.json()["status"] == "ok"
end = time.time()
logging.info(
"Finalized %s %s in %s seconds.", execution_id, step, end - start
)


def _get_xid() -> int:
"""Returns the XID for this run."""
xid = multihost_utils.broadcast_one_to_all(
np.asarray(int(time.time()))
).item()
logging.info("XID: %s", xid)
return xid


# ==============================================================================
# 2. Implement the Benchmark Generator
# ==============================================================================
@benchmarks_core.benchmark_options(LustreBenchmarkOptions)
class LustreBenchmark(benchmarks_core.BenchmarksGenerator):
"""Docstring."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._client = StorageServiceClient()
self._xid = _get_xid()

def _clear_pytree(self, pytree: Any) -> Any:
"""Clears the pytree to free up memory."""
return jax.tree.map(
lambda x: x.delete() if isinstance(x, jax.Array) else None, pytree
)

def test_fn(
self, context: benchmarks_core.TestContext
) -> benchmarks_core.TestResult:
"""The core test logic for a single save/restore cycle.

This function is called for each combination of options generated by the
framework. It uses the `context.options` to configure the handler
dynamically for each run.

Args:
context: The test context containing the pytree, path, and options.

Returns:
The test result containing the metrics.
"""
logging.info(
"JAX info: %s processes, %s devices, %s process index",
jax.process_count(),
jax.device_count(),
jax.process_index(),
)
metrics = metric_lib.Metrics()
pytree = context.pytree
options = context.options
assert isinstance(options, LustreBenchmarkOptions)

logging.info("Benchmark options: %s", pprint.pformat(options))

metrics_to_measure = _metrics_to_measure(options)

step = context.repeat_index or 0

with metrics.measure("resolve_cache", metrics_to_measure):
resolved_path = self._client.resolve(self._xid, step)
with metrics.measure("save_cache", metrics_to_measure):
ocp.save_pytree(resolved_path, pytree)
with metrics.measure("finalize_cache", metrics_to_measure):
self._client.finalize(self._xid, step)
with metrics.measure("restore_cache", metrics_to_measure):
restored_pytree = ocp.load_pytree(resolved_path, pytree)
self._clear_pytree(restored_pytree)

with metrics.measure("save", metrics_to_measure):
ocp.save_pytree(context.path / str(step), pytree)
with metrics.measure("restore", metrics_to_measure):
restored_pytree = ocp.load_pytree(context.path / str(step), pytree)
self._clear_pytree(restored_pytree)

return benchmarks_core.TestResult(metrics=metrics)
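
For context, StorageServiceClient assumes a service that answers POST /resolve with a JSON body containing a "path" and POST /finalize with a JSON status. A hypothetical minimal server satisfying that contract might look like the following; Flask and the path scheme are arbitrary choices for illustration, not part of this PR:

from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route("/resolve", methods=["POST"])
def resolve():
  body = request.get_json()
  # Map (execution_id, step) to a writable path; this scheme is made up.
  return jsonify({"path": f"/lustre/cache/{body['execution_id']}/{body['step']}"})

@app.route("/finalize", methods=["POST"])
def finalize():
  body = request.get_json()
  del body  # A real service would mark the (execution_id, step) asset complete.
  return jsonify({"status": "ok"})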
@@ -0,0 +1,30 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from orbax.checkpoint._src.testing import multiprocess_test
from orbax.checkpoint._src.testing.benchmarks import lustre_benchmark


class LustreBenchmarkTest(multiprocess_test.MultiProcessTest):

def test_xid(self):
xid = lustre_benchmark._get_xid()
self.assertIsInstance(xid, int)
logging.info('XID: %s', xid)


if __name__ == '__main__':
multiprocess_test.main()
@@ -160,6 +160,7 @@ JAX on TPU.
| `--jax-version` | `newest` | `newest`, `nightly`, or `x.y.z`. | **Debugging**. Use `nightly` to test bleeding-edge JAX features. |
| `--device` | `tpu` | `tpu`, `gpu`, `cpu`. | **Multi-Hardware**. When testing on GPU or CP/Local validation. |
| `--base-image` | `python:3.11...` | Base Docker Image. | **Advanced**. If you need custom drivers or non-standard OS libs. |
| `--no-cache` | `N/A` | Disable Docker build cache for all layers. | **Debugging**. Force rebuild of all layers from scratch. |

---
<!-- LINT.ThenChange(build_image.sh:build_image_flags) -->
@@ -12,17 +12,22 @@ BRANCH="main"
JAX_VERSION="newest"
DEVICE="tpu"
BASE_IMAGE=""
DOCKERFILE_PATH=""
NO_CACHE_FLAG=""

function print_usage() {
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo "Options:"+
echo " --project PROJECT_ID GCP Project ID"
echo " --pr PR_NUMBER GitHub PR number"
echo " --image IMAGE_NAME Image name (default: orbax-benchmarks)"
echo " --branch BRANCH GitHub branch (default: main)"
echo " --jax-version VERSION JAX version: newest, nightly, or X.Y.Z (default: newest)"
echo " --device DEVICE Device type: tpu, gpu, cpu (default: tpu)"
echo " --base-image IMAGE Base Docker image (optional)"
echo " --dockerfile FILE Dockerfile path (optional)"
echo " --tag TAG Image tag"
echo " --no-cache Disable Docker build cache"
echo " --help Show this help"
}

@@ -32,11 +37,14 @@ while [[ $# -gt 0 ]]; do
case $1 in
--project) PROJECT_ID="$2"; shift 2 ;;
--pr) PR_NUMBER="$2"; shift 2 ;;
--image) IMAGE_NAME="$2"; shift 2 ;;
--branch) BRANCH="$2"; shift 2 ;;
--jax-version) JAX_VERSION="$2"; shift 2 ;;
--device) DEVICE="$2"; shift 2 ;;
--base-image) BASE_IMAGE="$2"; shift 2 ;;
--dockerfile) DOCKERFILE_PATH="$2"; shift 2 ;;
--tag) USER_TAG_FLAG="$2"; shift 2 ;;
--no-cache) NO_CACHE_FLAG="--no-cache"; shift 1 ;;
--help) print_usage; exit 0 ;;
*) echo "Unknown argument: $1"; print_usage; exit 1 ;;
esac
@@ -54,7 +62,9 @@ if [[ -z "$BASE_IMAGE" ]]; then
fi

SCRIPT_DIR="$(dirname "$(realpath "$0")")"
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile"
if [[ -z "$DOCKERFILE_PATH" ]]; then
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile"
fi

if [[ ! -f "$DOCKERFILE_PATH" ]]; then
# Fallback: check if we are running in the source dir
@@ -110,15 +120,23 @@ done

# Build with local Docker
echo "Building with previously installed Docker..."
docker build \
--build-arg BASE_IMAGE="${BASE_IMAGE}" \
--build-arg BRANCH="${BRANCH}" \
--build-arg JAX_VERSION="${JAX_VERSION}" \
--build-arg DEVICE="${DEVICE}" \
--build-arg PR_NUMBER="${PR_NUMBER}" \
"${build_tag_args[@]}" \
-f "${DOCKERFILE_PATH}" \
.
declare -a build_args=()
if [[ -n "${NO_CACHE_FLAG}" ]]; then
build_args+=("${NO_CACHE_FLAG}")
fi
build_args+=(
"--build-arg" "BASE_IMAGE=${BASE_IMAGE}"
"--build-arg" "BRANCH=${BRANCH}"
"--build-arg" "JAX_VERSION=${JAX_VERSION}"
"--build-arg" "DEVICE=${DEVICE}"
"--build-arg" "PR_NUMBER=${PR_NUMBER}"
)
build_args+=("${build_tag_args[@]}")
build_args+=(
"-f" "${DOCKERFILE_PATH}"
"."
)
docker build "${build_args[@]}"

echo "Pushing image to registry..."
for t in "${tags[@]}"; do
@@ -547,6 +547,7 @@ def construct_workload_command(
f'python3 {benchmark_binary_path} '
f'--config_file={config_file} '
f'--output_directory={os.path.join(output_directory, run_id)} '
'--v=1 '
'--alsologtostderr'
)

@@ -567,6 +568,7 @@ def construct_xpk_command(
f'--workload={workload_name}',
f'--num-slices={_NUM_SLICES.value}',
f'--priority={_PRIORITY.value}',
'--storage=test-service-lustre',
]

if _TPU_TYPE.value is not None: