Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ trimesh
warp-lang
numpy-stl
pydantic
pytest
ruff
usd-core
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
extras_require={
"cuda": ["jax[cuda13]>=0.8.0"], # For CUDA installations
"tpu": ["jax[tpu]>=0.8.0"], # For TPU installations
"test": ["pytest>=8.0.0"],
},
python_requires=">=3.11",
dependency_links=["https://storage.googleapis.com/jax-releases/libtpu_releases.html"],
Expand Down
66 changes: 51 additions & 15 deletions xlb/operator/collision/smagorinsky_les_bgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,41 @@ def __init__(
self.smagorinsky_coef = smagorinsky_coef
super().__init__(velocity_set, precision_policy, compute_backend)

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray, omega):
fneq = f - feq

pi_neq = jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0))

if self.velocity_set.d == 3:
diag = pi_neq[(0, 3, 5), ...]
offdiag = pi_neq[(1, 2, 4), ...]
else:
diag = pi_neq[(0, 2), ...]
offdiag = pi_neq[(1,), ...]

strain = jnp.sum(diag * diag, axis=0) + self.compute_dtype(2.0) * jnp.sum(offdiag * offdiag, axis=0)

tau0 = self.compute_dtype(1.0) / self.compute_dtype(omega)
cs = self.compute_dtype(self.smagorinsky_coef)
tau = self.compute_dtype(0.5) * (
tau0 + jnp.sqrt(tau0 * tau0 + self.compute_dtype(36.0) * (cs * cs) * jnp.sqrt(strain))
)

omega_eff = self.compute_dtype(1.0) / tau
fout = f - omega_eff[None, ...] * fneq
return fout

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_d = self.velocity_set.d
_c = self.velocity_set.c
_cc = self.velocity_set.cc
_smagorinsky_coef = wp.constant(self.compute_dtype(self.smagorinsky_coef))
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2
_pi_vec = wp.vec(_pi_dim, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)

# Construct the functional
@wp.func
Expand Down Expand Up @@ -71,22 +100,29 @@ def functional(
# }

# Compute strain
strain = wp.float32(0.0)
for l in range(self.velocity_set.q):
# diagonal terms
if (_c[0, l] + _c[1, l] + _c[2, l]) == 1:
strain += fneq[l] * fneq[l]

# Off-diagonal terms
if (_c[0, l] + _c[1, l] + _c[2, l]) >= 2:
strain += 2.0 * fneq[l] * fneq[l]
pi_neq = _pi_vec()
for a in range(_pi_dim):
pi_neq[a] = self.compute_dtype(0.0)
for l in range(self.velocity_set.q):
pi_neq[a] += _cc[l, a] * fneq[l]

strain = self.compute_dtype(0.0)
if wp.static(_d == 3):
strain += pi_neq[0] * pi_neq[0] + pi_neq[3] * pi_neq[3] + pi_neq[5] * pi_neq[5]
strain += self.compute_dtype(2.0) * (pi_neq[1] * pi_neq[1] + pi_neq[2] * pi_neq[2] + pi_neq[4] * pi_neq[4])
else:
strain += pi_neq[0] * pi_neq[0] + pi_neq[2] * pi_neq[2]
strain += self.compute_dtype(2.0) * (pi_neq[1] * pi_neq[1])

# Compute the Smagorinsky model
_tau = self.compute_dtype(1.0 / omega)
tau = _tau + (0.5 * (wp.sqrt(_tau * _tau + 36.0 * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau))
_tau = self.compute_dtype(1.0) / self.compute_dtype(omega)
tau = _tau + (
self.compute_dtype(0.5)
* (wp.sqrt(_tau * _tau + self.compute_dtype(36.0) * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau)
)

# Compute the collision
fout = f - (1.0 / tau) * fneq
fout = f - (self.compute_dtype(1.0) / tau) * fneq
return fout

# Construct the warp kernel
Expand All @@ -109,7 +145,7 @@ def kernel(
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_feq[l] = feq[l, index[0], index[1], index[2]]
_u = self._warp_u_vec()
_u = _u_vec()
for l in range(_d):
_u[l] = u[l, index[0], index[1], index[2]]
_rho = rho[0, index[0], index[1], index[2]]
Expand All @@ -119,7 +155,7 @@ def kernel(

# Write the result
for l in range(self.velocity_set.q):
fout[l, index[0], index[1], index[2]] = _fout[l]
fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l])

return functional, kernel

Expand Down
10 changes: 10 additions & 0 deletions xlb/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,17 @@ def __call__(self, *args, callback=None, **kwargs):
method_candidates = [
(key, method) for key, method in self._backends.items() if key[0] == self.__class__.__name__ and key[1] == self.compute_backend
]
if not method_candidates:
supported = [key for key in self._backends.keys() if key[0] == self.__class__.__name__]
raise NotImplementedError(
f"No implementation found for operator {self.__class__.__name__} with backend {self.compute_backend}. "
f"Available implementations: {supported}"
)

bound_arguments = None
key = None
error = None
traceback_str = None
for key, backend_method in method_candidates:
try:
# This attempts to bind the provided args and kwargs to the compute_backend method's signature
Expand Down