From 4a099e0a02735b79af59590b814d947d635c50f3 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 9 Jan 2026 10:54:17 -0500 Subject: [PATCH] Fixed Smagorinsky implementation and added jax version --- requirements.txt | 1 + setup.py | 1 + xlb/operator/collision/smagorinsky_les_bgk.py | 66 ++++++++++++++----- xlb/operator/operator.py | 10 +++ 4 files changed, 63 insertions(+), 15 deletions(-) diff --git a/requirements.txt b/requirements.txt index d13ef830..0693dd31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ trimesh warp-lang numpy-stl pydantic +pytest ruff usd-core \ No newline at end of file diff --git a/setup.py b/setup.py index 5d376151..c213d006 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ extras_require={ "cuda": ["jax[cuda13]>=0.8.0"], # For CUDA installations "tpu": ["jax[tpu]>=0.8.0"], # For TPU installations + "test": ["pytest>=8.0.0"], }, python_requires=">=3.11", dependency_links=["https://storage.googleapis.com/jax-releases/libtpu_releases.html"], diff --git a/xlb/operator/collision/smagorinsky_les_bgk.py b/xlb/operator/collision/smagorinsky_les_bgk.py index deb86dea..92ae23d2 100644 --- a/xlb/operator/collision/smagorinsky_les_bgk.py +++ b/xlb/operator/collision/smagorinsky_les_bgk.py @@ -26,12 +26,41 @@ def __init__( self.smagorinsky_coef = smagorinsky_coef super().__init__(velocity_set, precision_policy, compute_backend) + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0,)) + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray, omega): + fneq = f - feq + + pi_neq = jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0)) + + if self.velocity_set.d == 3: + diag = pi_neq[(0, 3, 5), ...] + offdiag = pi_neq[(1, 2, 4), ...] + else: + diag = pi_neq[(0, 2), ...] + offdiag = pi_neq[(1,), ...] + + strain = jnp.sum(diag * diag, axis=0) + self.compute_dtype(2.0) * jnp.sum(offdiag * offdiag, axis=0) + + tau0 = self.compute_dtype(1.0) / self.compute_dtype(omega) + cs = self.compute_dtype(self.smagorinsky_coef) + tau = self.compute_dtype(0.5) * ( + tau0 + jnp.sqrt(tau0 * tau0 + self.compute_dtype(36.0) * (cs * cs) * jnp.sqrt(strain)) + ) + + omega_eff = self.compute_dtype(1.0) / tau + fout = f - omega_eff[None, ...] * fneq + return fout + def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _d = self.velocity_set.d - _c = self.velocity_set.c + _cc = self.velocity_set.cc _smagorinsky_coef = wp.constant(self.compute_dtype(self.smagorinsky_coef)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 + _pi_vec = wp.vec(_pi_dim, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) # Construct the functional @wp.func @@ -71,22 +100,29 @@ def functional( # } # Compute strain - strain = wp.float32(0.0) - for l in range(self.velocity_set.q): - # diagonal terms - if (_c[0, l] + _c[1, l] + _c[2, l]) == 1: - strain += fneq[l] * fneq[l] - - # Off-diagonal terms - if (_c[0, l] + _c[1, l] + _c[2, l]) >= 2: - strain += 2.0 * fneq[l] * fneq[l] + pi_neq = _pi_vec() + for a in range(_pi_dim): + pi_neq[a] = self.compute_dtype(0.0) + for l in range(self.velocity_set.q): + pi_neq[a] += _cc[l, a] * fneq[l] + + strain = self.compute_dtype(0.0) + if wp.static(_d == 3): + strain += pi_neq[0] * pi_neq[0] + pi_neq[3] * pi_neq[3] + pi_neq[5] * pi_neq[5] + strain += self.compute_dtype(2.0) * (pi_neq[1] * pi_neq[1] + pi_neq[2] * pi_neq[2] + pi_neq[4] * pi_neq[4]) + else: + strain += pi_neq[0] * pi_neq[0] + pi_neq[2] * pi_neq[2] + strain += self.compute_dtype(2.0) * (pi_neq[1] * pi_neq[1]) # Compute the Smagorinsky model - _tau = self.compute_dtype(1.0 / omega) - tau = _tau + (0.5 * (wp.sqrt(_tau * _tau + 36.0 * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau)) + _tau = self.compute_dtype(1.0) / self.compute_dtype(omega) + tau = _tau + ( + self.compute_dtype(0.5) + * (wp.sqrt(_tau * _tau + self.compute_dtype(36.0) * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau) + ) # Compute the collision - fout = f - (1.0 / tau) * fneq + fout = f - (self.compute_dtype(1.0) / tau) * fneq return fout # Construct the warp kernel @@ -109,7 +145,7 @@ def kernel( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = self._warp_u_vec() + _u = _u_vec() for l in range(_d): _u[l] = u[l, index[0], index[1], index[2]] _rho = rho[0, index[0], index[1], index[2]] @@ -119,7 +155,7 @@ def kernel( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) return functional, kernel diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index ad4d3e61..fcbd07b0 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -55,7 +55,17 @@ def __call__(self, *args, callback=None, **kwargs): method_candidates = [ (key, method) for key, method in self._backends.items() if key[0] == self.__class__.__name__ and key[1] == self.compute_backend ] + if not method_candidates: + supported = [key for key in self._backends.keys() if key[0] == self.__class__.__name__] + raise NotImplementedError( + f"No implementation found for operator {self.__class__.__name__} with backend {self.compute_backend}. " + f"Available implementations: {supported}" + ) + bound_arguments = None + key = None + error = None + traceback_str = None for key, backend_method in method_candidates: try: # This attempts to bind the provided args and kwargs to the compute_backend method's signature