From b1baf16c64017662a31b6305c831940fe5aa3f5d Mon Sep 17 00:00:00 2001
From: Zachary Garrett
Date: Sat, 4 Oct 2025 15:07:43 -0700
Subject: [PATCH] Ensure primitives propagate sharding annotations in abstract
 evaluation.

PiperOrigin-RevId: 815172986
---
 drjax/_src/primitives.py      | 58 ++++++++++++++++++++++++++++-------
 drjax/_src/primitives_test.py | 49 +++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+), 11 deletions(-)

diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py
index fdb01d8..fb342bd 100644
--- a/drjax/_src/primitives.py
+++ b/drjax/_src/primitives.py
@@ -56,6 +56,7 @@ def _register_broadcast_impls(
     broadcast_prim_fn: BroadcastType,
     broadcast_array_eval: BroadcastType,
     sum_prim_fn: AggType,
+    placement_str: str,
     n_elements: int,
 ) -> None:
   """Registers implementations for the broadcast primitive.
@@ -75,13 +76,37 @@
     sum_prim_fn: A callable which binds its arguments to the summation
       primitive from the placement inserted by this broadcast. Similar to
       `broadcast_prim_fn`.
+    placement_str: The name of the placement which this broadcast targets.
     n_elements: The number of elements present at the placement which this
       broadcast targets.
   """
 
-  def broadcast_abstract_eval(xs, *, mesh):
-    del mesh
-    return core.ShapedArray((n_elements,) + xs.shape, xs.dtype)
+  def broadcast_abstract_eval(
+      xs, *, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None
+  ):
+    # If no mesh was provided, we try to use the current abstract mesh.
+    if mesh is None:
+      abstract_mesh = jax.sharding.get_abstract_mesh()
+    else:
+      abstract_mesh = (
+          mesh.abstract_mesh if isinstance(mesh, jax.sharding.Mesh) else mesh
+      )
+    sharding_axis = (
+        placement_str
+        if impls._placement_axis_in_mesh(abstract_mesh, placement_str)  # pylint: disable=protected-access
+        else None
+    )
+    new_sharding = xs.sharding.update(
+        mesh=abstract_mesh,
+        spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec),
+    )
+    return core.ShapedArray(
+        shape=(n_elements,) + xs.shape,
+        dtype=xs.dtype,
+        weak_type=xs.weak_type,
+        sharding=new_sharding,
+        memory_space=xs.memory_space,
+    )
 
   # Abstract eval rule.
   broadcast_p.def_abstract_eval(broadcast_abstract_eval)
@@ -101,11 +126,11 @@ def broadcast_jvp(primals_in, tangents_in, mesh):
   ad.primitive_jvps[broadcast_p] = broadcast_jvp
 
   def broadcast_vjp(cotangents_out, primals_in, mesh):
-    del mesh
+    del mesh  # Unused.
     if isinstance(cotangents_out, jax.interpreters.ad.Zero):
       # We are differerentiating back through a broadcast; the incoming value,
       # therefore, has the right shape and dtype for the Zero we generate.
-      return (jax.interpreters.ad.Zero(primals_in.aval),)
+      return (jax.interpreters.ad.Zero(primals_in.aval.to_tangent_aval()),)
     # This implementation *must* use the sum_prim_fn, rather than the array
     # implementation of summation, to result in a reduce_sum in the Jaxpr.
     return (sum_prim_fn(cotangents_out),)
@@ -157,11 +182,21 @@ def _register_single_arg_agg_impls(
   """
 
   def agg_abstract_eval(xs):
-    return jax.tree_util.tree_map(
-        # We slice away the first dimension in doing the reduction; its gone!
-        lambda x: core.ShapedArray(x.shape[1:], x.dtype),
-        xs,
-    )
+
+    def aval_with_new_sharding(x):
+      # We slice away the first dimension in doing the reduction; it's gone!
+      new_sharding = x.sharding.update(
+          spec=jax.sharding.PartitionSpec(*x.sharding.spec[1:])
+      )
+      return core.ShapedArray(
+          shape=x.shape[1:],
+          dtype=x.dtype,
+          weak_type=x.weak_type,
+          sharding=new_sharding,
+          memory_space=x.memory_space,
+      )
+
+    return jax.tree.map(aval_with_new_sharding, xs)
 
   # Abstract eval rule
   agg_p.def_abstract_eval(agg_abstract_eval)
@@ -194,7 +229,7 @@ def agg_vjp(cotangents_out, primals_in):
     # generate. This is always correct if jax's symbolic Zero is a static
     # concept, depending on data flow in the program (rather than e.g. runtime
     # values).
-    return (jax.interpreters.ad.Zero(primals_in.aval),)
+    return (jax.interpreters.ad.Zero(primals_in.aval.to_tangent_aval()),)
     return (vjp_impl(cotangents_out),)
 
   ad.primitive_transposes[agg_p] = agg_vjp
@@ -257,6 +292,7 @@ def broadcast_array_eval(x, *, mesh):
       broadcast_prim_fn,
       broadcast_array_eval,
       sum_prim_fn,
+      placement_str,
       n_elements,
   )
 
diff --git a/drjax/_src/primitives_test.py b/drjax/_src/primitives_test.py
index daa1eca..76b2ed8 100644
--- a/drjax/_src/primitives_test.py
+++ b/drjax/_src/primitives_test.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Sequence +import functools + from absl.testing import absltest from absl.testing import parameterized import chex @@ -19,6 +22,8 @@ from drjax._src import primitives import jax from jax import numpy as jnp +from jax.sharding import AxisType # pylint: disable=g-importing-member +import numpy as np def _jaxpr_has_primitive(jaxpr, prim_name: str): @@ -32,6 +37,36 @@ def _jaxpr_has_primitive(jaxpr, prim_name: str): return False +def create_mesh( + axis_type: AxisType, +) -> jax.sharding.Mesh: + return jax.sharding.Mesh( + np.asarray(jax.devices()).reshape(2, 4), + axis_names=('clients', 'data'), + axis_types=(axis_type, axis_type), + ) + + +def run_in_mesh(mesh_axes_types: Sequence[AxisType]): + + def _decorator(fn): + + @functools.wraps(fn) + def _wrapped(self, *args, **kwargs): + with self.subTest('no_mesh'): + # Run once without a mesh, must not raise error. + fn(self, *args, **kwargs) + for mesh_axes_type in mesh_axes_types: + with self.subTest(f'{mesh_axes_type=}'): + mesh = create_mesh(mesh_axes_type) + with jax.set_mesh(mesh), mesh: + fn(self, *args, **kwargs) + + return _wrapped + + return _decorator + + class PrimitivesActingOnArraysTest(parameterized.TestCase): def setUp(self): @@ -44,6 +79,7 @@ def setUp(self): {'clients': self._n_clients}, ) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_clients_evaluation(self): fn = self._primdefs['broadcast_clients'] # Check that this function is callable. @@ -58,11 +94,13 @@ def test_broadcast_clients_evaluation(self): chex.assert_trees_all_close( jax.jacfwd(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) + # Also that it's reverse-diffable. 
chex.assert_trees_all_close( jax.jacrev(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_clients_closure_under_fad(self): fn = self._primdefs['broadcast_clients'] # Check that the forward and reverse-mode derivatives generate the expected @@ -72,6 +110,7 @@ def test_broadcast_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(jnp.array(1.0)) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'sum_from_clients')) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_sum_from_clients_evaluation(self): fn = self._primdefs['sum_from_clients'] clients_ones = jnp.ones(shape=[self._n_clients, 1]) @@ -92,6 +131,7 @@ def test_sum_from_clients_evaluation(self): jax.jacrev(fn)(clients_ones), jnp.ones(shape=[1, self._n_clients, 1]) ) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_and_sum_from_clients_eval(self): fn = self._primdefs['sum_from_clients'] @@ -111,6 +151,7 @@ def _broadcast_then_sum(x): jnp.array([[1.0 * self._n_clients]]), ) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_sum_from_clients_closure_under_fad(self): # Check that the forward and reverse-mode derivatives generate the expected # primitives. 
@@ -121,6 +162,7 @@ def test_sum_from_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(clients_ones) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'broadcast_clients')) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_mean_from_clients_eval(self): fn = self._primdefs['mean_from_clients'] clients_ones = jnp.ones(shape=[self._n_clients, 1]) @@ -134,6 +176,7 @@ def test_mean_from_clients_eval(self): 1 / self._n_clients * jnp.ones(shape=[1, self._n_clients, 1]), ) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_then_mean_from_clients_eval(self): fn = self._primdefs['mean_from_clients'] @@ -151,6 +194,7 @@ def _broadcast_then_sum(x): jnp.array([[1.0]]), ) + @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_mean_from_clients_closure_under_fad(self): # Check that the forward and reverse-mode derivatives generate the expected # primitives. @@ -203,5 +247,10 @@ def ignore_prim_result(x): ) +# This allows us to test sharding behavior across multiple devices. +def setUpModule(): + chex.set_n_cpu_devices(8) + + if __name__ == '__main__': absltest.main()