
jax.export fails for (Distil)BERT transformers models due to missing tracing signals #56

@kasper0406

Description

When using torchax to extract a JAX function from a Hugging Face transformers (Distil)BERT model and then exporting it with jax.export, the export fails with a jax.errors.ConcretizationTypeError.

This happens because the (Distil)BERT model in transformers contains data-dependent control flow (e.g., checking whether the attention_mask is all ones) to optimize execution. The library tries to skip these checks when it detects that it is being traced or compiled by PyTorch (via torch.jit.is_tracing() or is_torchdynamo_compiling()).

However, when tracing with JAX via torchax, these PyTorch-specific flags are False. As a result, transformers executes the data-dependent checks, which fail because JAX tracers are abstract and cannot be concretized to booleans for if statements.
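The failure mode can be reproduced in plain JAX without torchax or transformers. The sketch below uses a hypothetical `maybe_drop_mask` function that mimics the "return None if the mask is all ones" shortcut: it works eagerly, but under `jax.jit` the Python `if` tries to turn an abstract `bool[]` tracer into a concrete boolean and raises a `ConcretizationTypeError` (in recent JAX versions, its subclass `TracerBoolConversionError`).

```python
import jax
import jax.numpy as jnp


def maybe_drop_mask(mask):
    """Hypothetical stand-in for the transformers shortcut: drop an all-ones mask."""
    # A Python `if` forces the condition to a concrete bool.
    if jnp.all(mask == 1):
        return None
    return mask


mask = jnp.ones((1, 4), dtype=jnp.int32)

# Eager execution: the condition is a concrete array, so the branch works.
eager_result = maybe_drop_mask(mask)

# Traced execution: jnp.all(...) is an abstract tracer with shape bool[].
try:
    jax.jit(maybe_drop_mask)(mask)
    traced_failed = False
except jax.errors.ConcretizationTypeError:
    traced_failed = True

print("eager result:", eager_result)
print("jit raised ConcretizationTypeError:", traced_failed)
```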

Reproduction

Here is a standalone script reproducing the issue with distilbert-base-uncased:

import torch
import jax
import torchax as tx
from transformers import AutoModel, AutoTokenizer

def repro():
    # Load a simple transformer model
    model_name = "distilbert-base-uncased"
    model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Create inputs
    text = "Hello, world!"
    inputs = tokenizer(text, return_tensors="pt")
    
    # We need to pass attention_mask to trigger the specific code path in transformers
    model_inputs = (inputs.input_ids, inputs.attention_mask)

    # Extract JAX function using torchax
    model.eval()
    weights, jax_func = tx.extract_jax(model)

    @jax.jit
    def wrapped_weights_func(inputs):
        out = jax_func(weights, inputs)
        return out

    print("Attempting to export with jax.export...")
    # Convert inputs to numpy for JAX
    jax_inputs = tuple(t.detach().numpy() for t in model_inputs)
    
    # This fails with ConcretizationTypeError
    jax.export.export(wrapped_weights_func)(jax_inputs)

if __name__ == "__main__":
    repro()

Error Trace

jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[]
This occurred in the item() method of jax.Array
...
File ".../transformers/modeling_attn_mask_utils.py", line 454, in _prepare_4d_attention_mask_for_sdpa
    if not is_tracing and torch.all(mask == 1):

Potential Solution

I am not sure whether this should be fixed in torchax or in transformers, but I wanted to flag that these models currently fail to export.
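Wherever the fix lands, one trace-safe direction is to skip the all-ones shortcut while tracing and always materialize the expanded mask, since that path has no data-dependent branch. The function below is a hypothetical pure-JAX sketch of that idea, not the transformers implementation; the name `expand_attention_mask` and the `[batch, 1, tgt_len, src_len]` additive-mask layout are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def expand_attention_mask(mask, tgt_len, dtype=jnp.float32):
    """Hypothetical trace-safe mask expansion with no all-ones shortcut.

    mask: [batch, src_len], 1 = attend, 0 = padding.
    Returns an additive mask of shape [batch, 1, tgt_len, src_len].
    """
    batch, src_len = mask.shape  # shapes are static even under jit
    # Broadcast the 2D padding mask over the head and target-length axes.
    expanded = jnp.broadcast_to(mask[:, None, None, :], (batch, 1, tgt_len, src_len))
    # Kept positions become 0; padded positions get the most negative finite value.
    return jnp.where(expanded == 1, 0.0, jnp.finfo(dtype).min).astype(dtype)


# tgt_len determines an output shape, so it must be static under jit.
expand_jit = jax.jit(expand_attention_mask, static_argnums=1)

mask = jnp.array([[1, 1, 1, 0]], dtype=jnp.int32)
out = expand_jit(mask, 4)
print(out.shape)  # (1, 1, 4, 4)
```

The trade-off is that an all-ones mask is no longer dropped, so the attention kernel cannot take its mask-free fast path; the result is numerically the same, only potentially slower.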
