Description
When using torchax to extract a JAX function from the (Distil)BERT model using Hugging Face transformers and then attempting to export it using jax.export, the process fails with a jax.errors.ConcretizationTypeError.
This happens because the (Distil)BERT model from transformers contains data-dependent control flow (e.g., checking whether an attention_mask is all ones) to optimize execution. The library tries to skip these checks when it detects that it is being traced or compiled by PyTorch (via torch.jit.is_tracing() or is_torchdynamo_compiling()).
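For context, a simplified paraphrase of the guard involved (not the exact transformers source; the real code lives in modeling_attn_mask_utils.py) looks roughly like this:

```python
import torch

# Simplified paraphrase of the guard pattern (not the exact transformers code):
# the all-ones mask check is only skipped when a PyTorch tracer/compiler is active.
def prepare_mask_sketch(mask):
    is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy)
    if not is_tracing and torch.all(mask == 1):
        # Fast path: an all-ones mask carries no information, so drop it.
        return None
    # Otherwise keep (and later expand) the mask as usual.
    return mask
```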
However, when tracing with JAX via torchax, these PyTorch-specific flags are False. As a result, transformers executes the data-dependent checks, which fail because JAX tracers are abstract and cannot be concretized to booleans for if statements.
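The failure mode can be reproduced in plain JAX, independent of torchax or transformers. A minimal sketch (newer JAX versions raise the TracerBoolConversionError subclass of ConcretizationTypeError):

```python
import jax
import jax.numpy as jnp

@jax.jit
def check_mask(mask):
    # Inside jit, `mask` is an abstract tracer. Branching on a traced boolean
    # forces concretization, which is exactly what the transformers mask check does.
    if jnp.all(mask == 1):
        return mask
    return mask + 1

# Raises jax.errors.ConcretizationTypeError
# (TracerBoolConversionError in newer JAX versions).
check_mask(jnp.ones((4,), dtype=jnp.int32))
```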
Reproduction
Here is a standalone script reproducing the issue with distilbert-base-uncased:
```python
import torch
import jax
import torchax as tx
from transformers import AutoModel, AutoTokenizer
import numpy as np


def repro():
    # Load a simple transformer model
    model_name = "distilbert-base-uncased"
    model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Create inputs
    text = "Hello, world!"
    inputs = tokenizer(text, return_tensors="pt")

    # We need to pass attention_mask to trigger the specific code path in transformers
    model_inputs = (inputs.input_ids, inputs.attention_mask)

    # Extract JAX function using torchax
    model.eval()
    weights, jax_func = tx.extract_jax(model)

    @jax.jit
    def wrapped_weights_func(inputs):
        out = jax_func(weights, inputs)
        return out

    print("Attempting to export with jax.export...")

    # Convert inputs to numpy for JAX
    jax_inputs = tuple([input.detach().numpy() for input in model_inputs])

    # This fails with ConcretizationTypeError
    jax.export.export(wrapped_weights_func)(jax_inputs)


if __name__ == "__main__":
    repro()
```

Error Trace
```
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[]
This occurred in the item() method of jax.Array
...
  File ".../transformers/modeling_attn_mask_utils.py", line 454, in _prepare_4d_attention_mask_for_sdpa
    if not is_tracing and torch.all(mask == 1):
```
Potential Solution
I am not sure whether this should be fixed in torchax or in transformers, but I wanted to provide a signal that these models are failing.
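As a possible stopgap on the caller side (untested, and assuming the all-ones mask check is the only data-dependent branch the model hits), forcing transformers onto its "tracing" code path while JAX traces might sidestep the error. This sketch reuses wrapped_weights_func and jax_inputs from the repro above:

```python
# Hypothetical workaround sketch, untested: pretend PyTorch tracing is active so
# transformers skips its data-dependent attention-mask check during JAX tracing.
from unittest.mock import patch

import torch

with patch.object(torch.jit, "is_tracing", lambda: True):
    exported = jax.export.export(wrapped_weights_func)(jax_inputs)
```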