Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions tesseract_jax/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ def _unpack_hashable(obj: _Hashable) -> Any:


def apply_tesseract(
tesseract_client: Tesseract,
inputs: Any,
tesseract_client: Tesseract, inputs: Any, static_arg_names: Sequence[str] = ()
) -> Any:
"""Applies the given Tesseract object to the inputs.

Expand Down Expand Up @@ -367,6 +366,9 @@ def apply_tesseract(
Args:
tesseract_client: The Tesseract object to apply.
inputs: The inputs to apply to the Tesseract object.
static_arg_names: Names of input arguments that should be treated as static
(i.e., not traced). This is useful for arguments that affect control flow
or shape but are not arrays.

Returns:
The outputs of the Tesseract object after applying the inputs.
Expand Down Expand Up @@ -401,7 +403,18 @@ def apply_tesseract(
client = Jaxeract(tesseract_client)

flat_args, input_pytreedef = jax.tree.flatten(inputs)
is_static_mask = tuple(_is_static(arg) for arg in flat_args)
is_static_mask = tuple((_is_static(arg)) for arg in flat_args)
paths_and_values, _ = jax.tree.flatten_with_path(inputs)
paths_concatenated = [
"/".join(key.key for key in path) for path, _ in paths_and_values
]
is_static_mask_ = tuple(path in static_arg_names for path in paths_concatenated)

# or with the previous mask
is_static_mask = tuple(
a or b for a, b in zip(is_static_mask, is_static_mask_, strict=True)
)

array_args, static_args = split_args(flat_args, is_static_mask)
static_args = tuple(_make_hashable(arg) for arg in static_args)

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def get_tesseract_folders():
tesseract_folders = [
"univariate_tesseract",
"nested_tesseract",
"non_abstract_tesseract",
"vectoradd_tesseract",
# Add more as needed
]
return tesseract_folders
Expand Down Expand Up @@ -86,3 +88,4 @@ def served_tesseract():
served_univariate_tesseract_raw = make_tesseract_fixture("univariate_tesseract")
served_nested_tesseract_raw = make_tesseract_fixture("nested_tesseract")
served_non_abstract_tesseract = make_tesseract_fixture("non_abstract_tesseract")
served_vectoradd_tesseract = make_tesseract_fixture("vectoradd_tesseract")
43 changes: 43 additions & 0 deletions tests/test_endtoend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax.typing import ArrayLike
Expand Down Expand Up @@ -562,6 +563,48 @@ def f(a):
_assert_pytree_isequal(result, result_ref)


@pytest.mark.parametrize("use_jit", [True, False])
def test_tesseract_loss(served_vectoradd_tesseract, use_jit):
vectoradd_tess = Tesseract(served_vectoradd_tesseract)
a = np.array([1.0, 2.0, 3.0], dtype="float32")

# b = jax.lax.stop_gradient(b)

def loss_fn(a):
b = np.array([4.0, 5.0, 6.0], dtype="float32")

vectoradd_fn_a: jax.Callable = lambda a: apply_tesseract(
vectoradd_tess,
inputs=dict(
a=a,
b=b,
),
)

c = vectoradd_fn_a(a)["c"]
c = jax.lax.stop_gradient(c)

vectoradd_fn_b: jax.Callable = lambda a: apply_tesseract(
vectoradd_tess,
inputs=dict(
a=a,
b=c,
),
static_arg_names=["b"],
)

outputs = vectoradd_fn_b(a)

return jnp.sum((outputs["c"]) ** 2)

if use_jit:
loss_fn = jax.jit(loss_fn)

value_and_grad_fn = jax.value_and_grad(loss_fn)

assert value_and_grad_fn(a) is not None


def test_non_abstract_tesseract_vjp(served_non_abstract_tesseract):
non_abstract_tess = Tesseract(served_non_abstract_tesseract)

Expand Down
44 changes: 44 additions & 0 deletions tests/vectoradd_tesseract/tesseract_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


from typing import Any

from pydantic import BaseModel, Field
from tesseract_core.runtime import Array, Differentiable, Float32


class InputSchema(BaseModel):
a: Differentiable[Array[(None,), Float32]] = Field(description="Arbitrary vector a")
b: Array[(None,), Float32] = Field(description="Arbitrary vector b")


class OutputSchema(BaseModel):
c: Differentiable[Array[(None,), Float32]] = Field(
description="Vector s_a·a + s_b·b"
)


def apply(inputs: InputSchema) -> OutputSchema:
"""Adds two vectors `a` and `b`."""
return OutputSchema(
c=inputs.a + inputs.b,
)


def abstract_eval(abstract_inputs):
"""Abstract evaluation of the addition operation."""
return {
"c": abstract_inputs.a,
}


def vector_jacobian_product(
inputs: InputSchema,
vjp_inputs: set[str],
vjp_outputs: set[str],
cotangent_vector: dict[str, Any],
):
return {
"a": cotangent_vector["c"],
}
9 changes: 9 additions & 0 deletions tests/vectoradd_tesseract/tesseract_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: vectoradd_tesseracts
version: "2025-11-05"
description: |
Tesseract that adds two vectors. Uses jax internally.

build_config:
target_platform: "native"
# package_data: []
# custom_build_steps: []
1 change: 1 addition & 0 deletions tests/vectoradd_tesseract/tesseract_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
jax[cpu]
Loading