Skip to content

Commit 749fbc2

Browse files
marcorudolphflexyaugenst-flex
authored andcommitted
fix(tidy3d): FXC-4473-fix-gradients-for-box-in-geometry-group
1 parent 0b6817c commit 749fbc2

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
- Fix to `outer_dot` when frequencies stored in the data were not in increasing order. Previously, the result would be provided with re-sorted frequencies, which would not match the order of the original data.
2121
- Fixed bug where an extra spatial coordinate could appear in `complex_flux` and `ImpedanceCalculator` results.
2222
- Fixed normal for `Box` shape gradient computation to always point outward from boundary which is needed for correct PEC handling.
23+
- Fixed `Box` gradients within `GeometryGroup` where the group intersection boundaries were forwarded.
2324

2425
## [2.10.0rc3] - 2025-11-26
2526

tests/test_components/autograd/test_autograd.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import cProfile
66
import typing
77
import warnings
8+
from dataclasses import dataclass
89
from importlib import reload
910
from os.path import join
11+
from types import MethodType
1012

1113
import autograd as ag
1214
import autograd.numpy as anp
@@ -20,6 +22,7 @@
2022

2123
import tidy3d as td
2224
import tidy3d.web as web
25+
from tidy3d import Box, Geometry, GeometryGroup
2326
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
2427
from tidy3d.components.autograd.field_map import FieldMap
2528
from tidy3d.components.autograd.utils import is_tidy_box
@@ -3093,3 +3096,79 @@ def test_frequency_coordinate_alignment():
30933096
# Selecting non-existent frequency should fail
30943097
with pytest.raises(KeyError):
30953098
_slice_field_data(field_data_multi, np.array([1.5e14]))
3099+
3100+
3101+
def test_geometry_group_passes_intersected_bounds_to_children():
3102+
"""GeometryGroup should clip bounds_intersect for each child geometry."""
3103+
3104+
@dataclass
3105+
class SimpleDerivativeInfo:
3106+
paths: list[tuple]
3107+
bounds: tuple
3108+
bounds_intersect: tuple
3109+
simulation_bounds: tuple
3110+
interpolators: dict | None = None
3111+
3112+
def create_interpolators(self, dtype: float = float):
3113+
return self.interpolators or {}
3114+
3115+
def updated_copy(self, **kwargs):
3116+
data = {
3117+
"paths": self.paths,
3118+
"bounds": self.bounds,
3119+
"bounds_intersect": self.bounds_intersect,
3120+
"simulation_bounds": self.simulation_bounds,
3121+
"interpolators": self.interpolators,
3122+
}
3123+
data.update({k: v for k, v in kwargs.items() if k in data})
3124+
return SimpleDerivativeInfo(**data)
3125+
3126+
fully_inside_box = Box(center=(-1.0, 0.0, 0.0), size=(1.0, 1.0, 1.0))
3127+
big_box = Box(center=(0.0, 0.0, 0.0), size=(10.0, 10.0, 10.0))
3128+
3129+
def record_method(self, derivative_info):
3130+
object.__setattr__(self, "recorded_bounds_intersect", derivative_info.bounds_intersect)
3131+
return {derivative_info.paths[0]: 0.0}
3132+
3133+
boxes = (fully_inside_box, big_box)
3134+
3135+
for box in boxes:
3136+
object.__setattr__(box, "recorded_bounds_intersect", None)
3137+
object.__setattr__(box, "_compute_derivatives", MethodType(record_method, box))
3138+
group = GeometryGroup(geometries=boxes)
3139+
3140+
# case where group bounds bigger than sim bounds
3141+
sim_bounds = ((-5.0, -5.0, -5.0), (5.0, 5.0, 5.0))
3142+
3143+
deriv_info = SimpleDerivativeInfo(
3144+
paths=[("geom", idx, "dummy") for idx, _ in enumerate(boxes)],
3145+
bounds=group.bounds,
3146+
bounds_intersect=Geometry.bounds_intersection(group.bounds, sim_bounds),
3147+
simulation_bounds=sim_bounds,
3148+
interpolators={},
3149+
)
3150+
3151+
group._compute_derivatives(deriv_info)
3152+
3153+
assert (
3154+
object.__getattribute__(fully_inside_box, "recorded_bounds_intersect")
3155+
== fully_inside_box.bounds
3156+
)
3157+
assert object.__getattribute__(big_box, "recorded_bounds_intersect") == sim_bounds
3158+
3159+
# case where sim bounds bigger than group bounds
3160+
sim_bounds = ((-20.0, -20.0, -20.0), (20.0, 20.0, 20.0))
3161+
3162+
deriv_info = SimpleDerivativeInfo(
3163+
paths=[("geom", idx, "dummy") for idx, _ in enumerate(boxes)],
3164+
bounds=group.bounds,
3165+
bounds_intersect=Geometry.bounds_intersection(group.bounds, sim_bounds),
3166+
simulation_bounds=sim_bounds,
3167+
interpolators={},
3168+
)
3169+
3170+
group._compute_derivatives(deriv_info)
3171+
3172+
assert object.__getattribute__(big_box, "recorded_bounds_intersect") == group.bounds, (
3173+
f"got {object.__getattribute__(big_box, 'recorded_bounds_intersect')} and {group.bounds}"
3174+
)

tidy3d/components/geometry/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3594,6 +3594,9 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
35943594
geo_info = derivative_info.updated_copy(
35953595
paths=[tuple(geo_path)],
35963596
bounds=geo.bounds,
3597+
bounds_intersect=self.bounds_intersection(
3598+
geo.bounds, derivative_info.simulation_bounds
3599+
),
35973600
eps_approx=True,
35983601
deep=False,
35993602
interpolators=interpolators,

0 commit comments

Comments
 (0)