diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/configs.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/configs.py index 24f3c49ab..084377ab6 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/configs.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/configs.py @@ -54,6 +54,8 @@ class MeshConfig: degree *across* slices (Data Center Network). This typically contains a single entry for the data-parallel axis. Example: {'data': 2} + If None, an ordinary device mesh will be used, rather than a hybrid + device mesh (intended for multi-replica workloads) allow_split_physical_axes: If True, we will split physical axes if necessary to produce the desired device mesh. process_is_granule: If True, treat processes as the units of the @@ -61,6 +63,6 @@ class MeshConfig: """ mesh_axes: list[str] ici_parallelism: dict[str, int] = dataclasses.field(default_factory=dict) - dcn_parallelism: dict[str, int] = dataclasses.field(default_factory=dict) + dcn_parallelism: dict[str, int] | None = None allow_split_physical_axes: bool = False process_is_granule: bool = False diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py index 2e6c2087a..40ab7ca9a 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py @@ -77,12 +77,15 @@ class TestContext: path: The test directory path. options: The specific BenchmarkOptions for this test variant. mesh: The mesh used for sharding the checkpoint data. + repeat_index: The index of the repeat run, if this test is run multiple + times. """ pytree: Any path: epath.Path options: BenchmarkOptions # The specific options for this test variant. 
mesh: jax.sharding.Mesh | None = None + repeat_index: int | None = None @dataclasses.dataclass @@ -171,7 +174,11 @@ def run(self, repeat_index: int | None = None) -> TestResult: multihost.sync_global_processes("benchmark:setup_pytree") context = TestContext( - pytree=data, path=path, options=self.options, mesh=self.mesh + pytree=data, + path=path, + options=self.options, + mesh=self.mesh, + repeat_index=repeat_index, ) test_context_summary = self._build_test_context_summary(context) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py index f65fc7a6c..acbfe3c21 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py @@ -42,7 +42,19 @@ def create_mesh(config: configs.MeshConfig) -> jax.sharding.Mesh: num_devices = len(devices) # Convert the user-friendly dict maps into ordered lists based on mesh_axes ici_shape = [config.ici_parallelism.get(axis, 1) for axis in config.mesh_axes] - dcn_shape = [config.dcn_parallelism.get(axis, 1) for axis in config.mesh_axes] + + dcn_parallelism = config.dcn_parallelism + if dcn_parallelism is None: + logging.info('Creating ICI-only mesh.') + devices_array = mesh_utils.create_device_mesh(ici_shape, devices) + logging.info( + 'Creating mesh with axes: %s', + {axis: dim for axis, dim in zip(config.mesh_axes, devices_array.shape)}, + ) + return jax.sharding.Mesh(devices_array, config.mesh_axes) + else: + logging.info('Creating hybrid mesh.') + dcn_shape = [dcn_parallelism.get(axis, 1) for axis in config.mesh_axes] # --- Validation --- if config.process_is_granule: diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark.py new file mode 100644 index 000000000..1864c38ea --- /dev/null +++ 
b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark.py @@ -0,0 +1,173 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarks for orbax.checkpoint.PyTreeCheckpointHandler.""" + +from __future__ import annotations + +import dataclasses +import pprint +import time +from typing import Any + +from absl import logging +import jax +from jax.experimental import multihost_utils +import numpy as np +import orbax.checkpoint as ocp_v0 # pylint: disable=unused-import +from orbax.checkpoint import v1 as ocp +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core +from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib +import requests + + +SERVICE_URL = "http://service-dns/" + + +def _metrics_to_measure(options: LustreBenchmarkOptions) -> list[str]: + """Returns the list of metrics to measure.""" + del options + metrics = ["time", "rss", "io"] + return metrics + + +# ============================================================================== +# 1. Define the Options Dataclass for this specific benchmark +# ============================================================================== +@dataclasses.dataclass(frozen=True) +class LustreBenchmarkOptions(benchmarks_core.BenchmarkOptions): + """Configuration options for benchmarks targeting PyTreeCheckpointHandler. + + Each attribute can be a single value or a list of values to create + a parameter sweep. 
+ + Attributes: + use_ocdbt: Whether to use OCDBT for checkpointing. + """ + + use_ocdbt: bool = True + + def is_valid(self): + return True + + +class StorageServiceClient: + """Docstring.""" + + def __init__(self, service_url: str | None = None): + self._service_url = service_url or SERVICE_URL + + def resolve(self, execution_id: int, step: int) -> str: + """Resolves an asset path from the service.""" + start = time.time() + logging.info("Resolving ID-step: %s-%s.", execution_id, step) + payload = {"execution_id": execution_id, "step": step} + response = requests.post(f"{self._service_url}/resolve", json=payload) + logging.info("Response: %s", response.json()) + response.raise_for_status() + result = response.json()["path"] + end = time.time() + logging.info("Resolved %s in %s seconds.", result, end - start) + return result + + def finalize(self, execution_id: int, step: int) -> None: + """Finalizes an asset in the service.""" + start = time.time() + payload = {"execution_id": execution_id, "step": step} + response = requests.post(f"{self._service_url}/finalize", json=payload) + response.raise_for_status() + logging.info(response) + # assert response.json()["status"] == "ok" + end = time.time() + logging.info( + "Finalized %s %s in %s seconds.", execution_id, step, end - start + ) + + +def _get_xid() -> int: + """Returns the XID for this run.""" + xid = multihost_utils.broadcast_one_to_all( + np.asarray(int(time.time())) + ).item() + logging.info("XID: %s", xid) + return xid + + +# ============================================================================== +# 2. 
# ==============================================================================
# 2. Implement the Benchmark Generator
# ==============================================================================
@benchmarks_core.benchmark_options(LustreBenchmarkOptions)
class LustreBenchmark(benchmarks_core.BenchmarksGenerator):
  """Benchmarks save/restore on a service-resolved path and a direct path.

  Each run saves and restores the PyTree twice: once at a path obtained from
  `StorageServiceClient.resolve`, and once at `context.path / <step>`. Every
  phase is recorded under its own metric name.
  """

  def __init__(self, *args, **kwargs):
    """Forwards all arguments to the base class and sets up client and XID."""
    super().__init__(*args, **kwargs)
    self._client = StorageServiceClient()
    self._xid = _get_xid()

  def _clear_pytree(self, pytree: Any) -> Any:
    """Deletes every `jax.Array` leaf of `pytree` to free up memory.

    Args:
      pytree: The tree whose array leaves should be deleted.

    Returns:
      The tree produced by `jax.tree.map` over the per-leaf results.
    """

    def _delete_if_array(leaf):
      if isinstance(leaf, jax.Array):
        return leaf.delete()
      return None

    return jax.tree.map(_delete_if_array, pytree)

  def test_fn(
      self, context: benchmarks_core.TestContext
  ) -> benchmarks_core.TestResult:
    """Runs one save/restore cycle on each of the two target paths.

    Recorded phases, in order:
      * resolve_cache: ask the storage service for a path for (XID, step).
      * save_cache / finalize_cache / restore_cache: save to that path,
        finalize it with the service, then load it back.
      * save / restore: the same save and load against
        `context.path / str(step)`.

    Args:
      context: The test context containing the pytree, path, and options.
        `context.repeat_index` is used as the step, defaulting to 0.

    Returns:
      A `TestResult` wrapping the collected metrics.
    """
    logging.info(
        "JAX info: %s processes, %s devices, %s process index",
        jax.process_count(),
        jax.device_count(),
        jax.process_index(),
    )
    recorder = metric_lib.Metrics()
    state = context.pytree
    opts = context.options
    assert isinstance(opts, LustreBenchmarkOptions)

    logging.info("Benchmark options: %s", pprint.pformat(opts))

    tracked = _metrics_to_measure(opts)

    # A missing repeat index is treated as step 0.
    step = context.repeat_index if context.repeat_index else 0

    # --- Path handed out by the storage service. ---
    with recorder.measure("resolve_cache", tracked):
      service_path = self._client.resolve(self._xid, step)
    with recorder.measure("save_cache", tracked):
      ocp.save_pytree(service_path, state)
    with recorder.measure("finalize_cache", tracked):
      self._client.finalize(self._xid, step)
    with recorder.measure("restore_cache", tracked):
      restored = ocp.load_pytree(service_path, state)
    self._clear_pytree(restored)

    # --- Path inside the benchmark's own test directory. ---
    with recorder.measure("save", tracked):
      ocp.save_pytree(context.path / str(step), state)
    with recorder.measure("restore", tracked):
      restored = ocp.load_pytree(context.path / str(step), state)
    self._clear_pytree(restored)

    return benchmarks_core.TestResult(metrics=recorder)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from orbax.checkpoint._src.testing import multiprocess_test +from orbax.checkpoint._src.testing.benchmarks import lustre_benchmark + + +class LustreBenchmarkTest(multiprocess_test.MultiProcessTest): + + def test_xid(self): + xid = lustre_benchmark._get_xid() + self.assertIsInstance(xid, int) + logging.info('XID: %s', xid) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md index cd0403a8d..91429bf82 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md @@ -160,6 +160,7 @@ JAX on TPU. | `--jax-version` | `newest` | `newest`, `nightly`, or `x.y.z`. | **Debugging**. Use `nightly` to test bleeding-edge JAX features. | | `--device` | `tpu` | `tpu`, `gpu`, `cpu`. | **Multi-Hardware**. When testing on GPU or CP/Local validation. | | `--base-image` | `python:3.11...` | Base Docker Image. | **Advanced**. If you need custom drivers or non-standard OS libs. | +| `--no-cache` | `N/A` | Disable Docker build cache for all layers. | **Debugging**. Force rebuild of all layers from scratch. 
| --- diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh index fcb83de28..486b95983 100755 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh @@ -12,17 +12,22 @@ BRANCH="main" JAX_VERSION="newest" DEVICE="tpu" BASE_IMAGE="" +DOCKERFILE_PATH="" +NO_CACHE_FLAG="" function print_usage() { echo "Usage: $0 [OPTIONS]" - echo "Options:" + echo "Options:" echo " --project PROJECT_ID GCP Project ID" echo " --pr PR_NUMBER GitHub PR number" + echo " --image IMAGE_NAME Image name (default: orbax-benchmarks)" echo " --branch BRANCH GitHub branch (default: main)" echo " --jax-version VERSION JAX version: newest, nightly, or X.Y.Z (default: newest)" echo " --device DEVICE Device type: tpu, gpu, cpu (default: tpu)" echo " --base-image IMAGE Base Docker image (optional)" + echo " --dockerfile FILE Dockerfile path (optional)" echo " --tag TAG Image tag" + echo " --no-cache Disable Docker build cache" echo " --help Show this help" } @@ -32,11 +37,14 @@ while [[ $# -gt 0 ]]; do case $1 in --project) PROJECT_ID="$2"; shift 2 ;; --pr) PR_NUMBER="$2"; shift 2 ;; + --image) IMAGE_NAME="$2"; shift 2 ;; --branch) BRANCH="$2"; shift 2 ;; --jax-version) JAX_VERSION="$2"; shift 2 ;; --device) DEVICE="$2"; shift 2 ;; --base-image) BASE_IMAGE="$2"; shift 2 ;; + --dockerfile) DOCKERFILE_PATH="$2"; shift 2 ;; --tag) USER_TAG_FLAG="$2"; shift 2 ;; + --no-cache) NO_CACHE_FLAG="--no-cache"; shift 1 ;; --help) print_usage; exit 0 ;; *) echo "Unknown argument: $1"; print_usage; exit 1 ;; esac @@ -54,7 +62,9 @@ if [[ -z "$BASE_IMAGE" ]]; then fi SCRIPT_DIR="$(dirname "$(realpath "$0")")" -DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile" +if [[ -z "$DOCKERFILE_PATH" ]]; then + DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile" +fi if [[ ! 
-f "$DOCKERFILE_PATH" ]]; then # Fallback: check if we are running in the source dir @@ -110,15 +120,23 @@ done # Build with local Docker echo "Building with previously installed Docker..." -docker build \ - --build-arg BASE_IMAGE="${BASE_IMAGE}" \ - --build-arg BRANCH="${BRANCH}" \ - --build-arg JAX_VERSION="${JAX_VERSION}" \ - --build-arg DEVICE="${DEVICE}" \ - --build-arg PR_NUMBER="${PR_NUMBER}" \ - "${build_tag_args[@]}" \ - -f "${DOCKERFILE_PATH}" \ - . +declare -a build_args=() +if [[ -n "${NO_CACHE_FLAG}" ]]; then + build_args+=("${NO_CACHE_FLAG}") +fi +build_args+=( + "--build-arg" "BASE_IMAGE=${BASE_IMAGE}" + "--build-arg" "BRANCH=${BRANCH}" + "--build-arg" "JAX_VERSION=${JAX_VERSION}" + "--build-arg" "DEVICE=${DEVICE}" + "--build-arg" "PR_NUMBER=${PR_NUMBER}" +) +build_args+=("${build_tag_args[@]}") +build_args+=( + "-f" "${DOCKERFILE_PATH}" + "." +) +docker build "${build_args[@]}" echo "Pushing image to registry..." for t in "${tags[@]}"; do diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py index c2b3f66e1..48f8b4384 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py @@ -547,6 +547,7 @@ def construct_workload_command( f'python3 {benchmark_binary_path} ' f'--config_file={config_file} ' f'--output_directory={os.path.join(output_directory, run_id)} ' + '--v=1 ' '--alsologtostderr' ) @@ -567,6 +568,7 @@ def construct_xpk_command( f'--workload={workload_name}', f'--num-slices={_NUM_SLICES.value}', f'--priority={_PRIORITY.value}', + '--storage=test-service-lustre', ] if _TPU_TYPE.value is not None: