58 changes: 47 additions & 11 deletions drjax/_src/primitives.py
@@ -56,6 +56,7 @@ def _register_broadcast_impls(
broadcast_prim_fn: BroadcastType,
broadcast_array_eval: BroadcastType,
sum_prim_fn: AggType,
placement_str: str,
n_elements: int,
) -> None:
"""Registers implementations for the broadcast primitive.
@@ -75,13 +76,37 @@ def _register_broadcast_impls(
sum_prim_fn: A callable which binds its arguments to the summation primitive
from the placement inserted by this broadcast. Similar to
`broadcast_prim_fn`.
placement_str: The name of the placement which this broadcast targets.
n_elements: The number of elements present at the placement which this
broadcast targets.
"""

def broadcast_abstract_eval(xs, *, mesh):
del mesh
return core.ShapedArray((n_elements,) + xs.shape, xs.dtype)
def broadcast_abstract_eval(
xs, *, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None
):
# If no mesh was provided, we try to use the current abstract mesh.
if mesh is None:
abstract_mesh = jax.sharding.get_abstract_mesh()
else:
abstract_mesh = (
mesh.abstract_mesh if isinstance(mesh, jax.sharding.Mesh) else mesh
)
sharding_axis = (
placement_str
if impls._placement_axis_in_mesh(abstract_mesh, placement_str) # pylint: disable=protected-access
else None
)
new_sharding = xs.sharding.update(
mesh=abstract_mesh,
spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec),
)
return core.ShapedArray(
shape=(n_elements,) + xs.shape,
dtype=xs.dtype,
weak_type=xs.weak_type,
sharding=new_sharding,
memory_space=xs.memory_space,
)

# Abstract eval rule.
broadcast_p.def_abstract_eval(broadcast_abstract_eval)
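
For intuition, a minimal sketch (not part of this diff; the 'clients' and 'data' axis names are assumptions) of the PartitionSpec manipulation the new abstract eval performs: the placement axis is prepended to the operand's existing spec, so the broadcast output gains a leading dimension sharded over the placement.

from jax.sharding import PartitionSpec

operand_spec = PartitionSpec('data')                 # sharding of the unplaced input
out_spec = PartitionSpec('clients', *operand_spec)   # placement axis prepended
assert out_spec == PartitionSpec('clients', 'data')  # leading dim is placed
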
@@ -101,11 +126,11 @@ def broadcast_jvp(primals_in, tangents_in, mesh):
ad.primitive_jvps[broadcast_p] = broadcast_jvp

def broadcast_vjp(cotangents_out, primals_in, mesh):
del mesh
del mesh # Unused.
if isinstance(cotangents_out, jax.interpreters.ad.Zero):
# We are differentiating back through a broadcast; the incoming value,
# therefore, has the right shape and dtype for the Zero we generate.
return (jax.interpreters.ad.Zero(primals_in.aval),)
return (jax.interpreters.ad.Zero.from_primal_value(primals_in),)
# This implementation *must* use the sum_prim_fn, rather than the array
# implementation of summation, to result in a reduce_sum in the Jaxpr.
return (sum_prim_fn(cotangents_out),)
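
The sum_prim_fn requirement above reflects the broadcast/sum transposition duality. A minimal sketch with stock JAX ops (an analogue, not the DrJAX primitives themselves): the transpose of a broadcast is a sum over the broadcast dimension.

import jax
import jax.numpy as jnp

broadcast = lambda x: jnp.broadcast_to(x, (4,))  # linear, like broadcast_clients
transposed = jax.linear_transpose(broadcast, jnp.float32(0.0))
print(jax.make_jaxpr(transposed)(jnp.ones(4)))   # jaxpr contains a reduce_sum
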
@@ -157,11 +182,21 @@ def _register_single_arg_agg_impls(
"""

def agg_abstract_eval(xs):
return jax.tree_util.tree_map(
# We slice away the first dimension in doing the reduction; its gone!
lambda x: core.ShapedArray(x.shape[1:], x.dtype),
xs,
)

def aval_with_new_sharding(x):
# We slice away the first dimension in doing the reduction; it's gone!
new_sharding = x.sharding.update(
spec=jax.sharding.PartitionSpec(*x.sharding.spec[1:])
)
return core.ShapedArray(
shape=x.shape[1:],
dtype=x.dtype,
weak_type=x.weak_type,
sharding=new_sharding,
memory_space=x.memory_space,
)

return jax.tree.map(aval_with_new_sharding, xs)

# Abstract eval rule
agg_p.def_abstract_eval(agg_abstract_eval)
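
In the dual direction, a minimal sketch (assumed axis names, not part of this diff) of the spec slicing above: the leading placement entry is dropped together with the leading array dimension, leaving the remaining sharding intact.

from jax.sharding import PartitionSpec

placed_spec = PartitionSpec('clients', 'data')
reduced_spec = PartitionSpec(*placed_spec[1:])  # drop the placement axis
assert reduced_spec == PartitionSpec('data')
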
@@ -194,7 +229,7 @@ def agg_vjp(cotangents_out, primals_in):
# generate. This is always correct if jax's symbolic Zero is a static
# concept, depending on data flow in the program (rather than e.g. runtime
# values).
return (jax.interpreters.ad.Zero(primals_in.aval),)
return (jax.interpreters.ad.Zero.from_primal_value(primals_in),)
return (vjp_impl(cotangents_out),)

ad.primitive_transposes[agg_p] = agg_vjp
@@ -257,6 +292,7 @@ def broadcast_array_eval(x, *, mesh):
broadcast_prim_fn,
broadcast_array_eval,
sum_prim_fn,
placement_str,
n_elements,
)

49 changes: 49 additions & 0 deletions drjax/_src/primitives_test.py
@@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
import functools

from absl.testing import absltest
from absl.testing import parameterized
import chex
from drjax._src import impls
from drjax._src import primitives
import jax
from jax import numpy as jnp
from jax.sharding import AxisType # pylint: disable=g-importing-member
import numpy as np


def _jaxpr_has_primitive(jaxpr, prim_name: str):
@@ -32,6 +37,36 @@ def _jaxpr_has_primitive(jaxpr, prim_name: str):
return False


def create_mesh(
axis_type: AxisType,
) -> jax.sharding.Mesh:
return jax.sharding.Mesh(
np.asarray(jax.devices()).reshape(2, 4),
axis_names=('clients', 'data'),
axis_types=(axis_type, axis_type),
)


def run_in_mesh(mesh_axes_types: Sequence[AxisType]):

def _decorator(fn):

@functools.wraps(fn)
def _wrapped(self, *args, **kwargs):
with self.subTest('no_mesh'):
# Run once without a mesh; this must not raise an error.
fn(self, *args, **kwargs)
for mesh_axes_type in mesh_axes_types:
with self.subTest(f'{mesh_axes_type=}'):
mesh = create_mesh(mesh_axes_type)
with jax.set_mesh(mesh), mesh:
fn(self, *args, **kwargs)

return _wrapped

return _decorator
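
A note on the two axis types these tests exercise: under AxisType.Auto the compiler is free to choose shardings, while under AxisType.Explicit shardings are part of array types and are propagated through abstract evaluation, which is exactly what the new abstract-eval rules in primitives.py must handle. A minimal sketch of one decorated run, using the create_mesh helper above:

mesh = create_mesh(AxisType.Explicit)
with jax.set_mesh(mesh), mesh:
  # Inside this scope the 'clients' mesh axis is visible to the primitives'
  # abstract eval, so broadcast outputs can be sharded along it.
  ...
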


class PrimitivesActingOnArraysTest(parameterized.TestCase):

def setUp(self):
@@ -44,6 +79,7 @@ def setUp(self):
{'clients': self._n_clients},
)

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_broadcast_clients_evaluation(self):
fn = self._primdefs['broadcast_clients']
# Check that this function is callable.
@@ -58,11 +94,13 @@ def test_broadcast_clients_evaluation(self):
chex.assert_trees_all_close(
jax.jacfwd(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients])
)

# Also that it's reverse-diffable.
chex.assert_trees_all_close(
jax.jacrev(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients])
)

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_broadcast_clients_closure_under_fad(self):
fn = self._primdefs['broadcast_clients']
# Check that the forward and reverse-mode derivatives generate the expected
@@ -72,6 +110,7 @@ def test_broadcast_clients_closure_under_fad(self):
rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(jnp.array(1.0))
self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'sum_from_clients'))

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_sum_from_clients_evaluation(self):
fn = self._primdefs['sum_from_clients']
clients_ones = jnp.ones(shape=[self._n_clients, 1])
@@ -92,6 +131,7 @@ def test_sum_from_clients_evaluation(self):
jax.jacrev(fn)(clients_ones), jnp.ones(shape=[1, self._n_clients, 1])
)

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_broadcast_and_sum_from_clients_eval(self):
fn = self._primdefs['sum_from_clients']

@@ -111,6 +151,7 @@ def _broadcast_then_sum(x):
jnp.array([[1.0 * self._n_clients]]),
)

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_sum_from_clients_closure_under_fad(self):
# Check that the forward and reverse-mode derivatives generate the expected
# primitives.
@@ -121,6 +162,7 @@ def test_sum_from_clients_closure_under_fad(self):
rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(clients_ones)
self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'broadcast_clients'))

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_mean_from_clients_eval(self):
fn = self._primdefs['mean_from_clients']
clients_ones = jnp.ones(shape=[self._n_clients, 1])
@@ -134,6 +176,7 @@ def test_mean_from_clients_eval(self):
1 / self._n_clients * jnp.ones(shape=[1, self._n_clients, 1]),
)

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_broadcast_then_mean_from_clients_eval(self):
fn = self._primdefs['mean_from_clients']

@@ -151,6 +194,7 @@ def _broadcast_then_sum(x):
jnp.array([[1.0]]),
)

@run_in_mesh((AxisType.Auto, AxisType.Explicit))
def test_mean_from_clients_closure_under_fad(self):
# Check that the forward and reverse-mode derivatives generate the expected
# primitives.
@@ -203,5 +247,10 @@ def ignore_prim_result(x):
)


# This allows us to test sharding behavior across multiple devices.
def setUpModule():
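  # chex requires that this run before the JAX backend is initialized;
  # setUpModule runs early enough in the test lifecycle to guarantee that.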
chex.set_n_cpu_devices(8)


if __name__ == '__main__':
absltest.main()