|
5 | 5 | import cProfile |
6 | 6 | import typing |
7 | 7 | import warnings |
| 8 | +from dataclasses import dataclass |
8 | 9 | from importlib import reload |
9 | 10 | from os.path import join |
| 11 | +from types import MethodType |
10 | 12 |
|
11 | 13 | import autograd as ag |
12 | 14 | import autograd.numpy as anp |
|
20 | 22 |
|
21 | 23 | import tidy3d as td |
22 | 24 | import tidy3d.web as web |
| 25 | +from tidy3d import Box, Geometry, GeometryGroup |
23 | 26 | from tidy3d.components.autograd.derivative_utils import DerivativeInfo |
24 | 27 | from tidy3d.components.autograd.field_map import FieldMap |
25 | 28 | from tidy3d.components.autograd.utils import is_tidy_box |
@@ -3093,3 +3096,79 @@ def test_frequency_coordinate_alignment(): |
3093 | 3096 | # Selecting non-existent frequency should fail |
3094 | 3097 | with pytest.raises(KeyError): |
3095 | 3098 | _slice_field_data(field_data_multi, np.array([1.5e14])) |
| 3099 | + |
| 3100 | + |
| 3101 | +def test_geometry_group_passes_intersected_bounds_to_children(): |
| 3102 | + """GeometryGroup should clip bounds_intersect for each child geometry.""" |
| 3103 | + |
| 3104 | + @dataclass |
| 3105 | + class SimpleDerivativeInfo: |
| 3106 | + paths: list[tuple] |
| 3107 | + bounds: tuple |
| 3108 | + bounds_intersect: tuple |
| 3109 | + simulation_bounds: tuple |
| 3110 | + interpolators: dict | None = None |
| 3111 | + |
| 3112 | + def create_interpolators(self, dtype: float = float): |
| 3113 | + return self.interpolators or {} |
| 3114 | + |
| 3115 | + def updated_copy(self, **kwargs): |
| 3116 | + data = { |
| 3117 | + "paths": self.paths, |
| 3118 | + "bounds": self.bounds, |
| 3119 | + "bounds_intersect": self.bounds_intersect, |
| 3120 | + "simulation_bounds": self.simulation_bounds, |
| 3121 | + "interpolators": self.interpolators, |
| 3122 | + } |
| 3123 | + data.update({k: v for k, v in kwargs.items() if k in data}) |
| 3124 | + return SimpleDerivativeInfo(**data) |
| 3125 | + |
| 3126 | + fully_inside_box = Box(center=(-1.0, 0.0, 0.0), size=(1.0, 1.0, 1.0)) |
| 3127 | + big_box = Box(center=(0.0, 0.0, 0.0), size=(10.0, 10.0, 10.0)) |
| 3128 | + |
| 3129 | + def record_method(self, derivative_info): |
| 3130 | + object.__setattr__(self, "recorded_bounds_intersect", derivative_info.bounds_intersect) |
| 3131 | + return {derivative_info.paths[0]: 0.0} |
| 3132 | + |
| 3133 | + boxes = (fully_inside_box, big_box) |
| 3134 | + |
| 3135 | + for box in boxes: |
| 3136 | + object.__setattr__(box, "recorded_bounds_intersect", None) |
| 3137 | + object.__setattr__(box, "_compute_derivatives", MethodType(record_method, box)) |
| 3138 | + group = GeometryGroup(geometries=boxes) |
| 3139 | + |
| 3140 | + # case where group bounds bigger than sim bounds |
| 3141 | + sim_bounds = ((-5.0, -5.0, -5.0), (5.0, 5.0, 5.0)) |
| 3142 | + |
| 3143 | + deriv_info = SimpleDerivativeInfo( |
| 3144 | + paths=[("geom", idx, "dummy") for idx, _ in enumerate(boxes)], |
| 3145 | + bounds=group.bounds, |
| 3146 | + bounds_intersect=Geometry.bounds_intersection(group.bounds, sim_bounds), |
| 3147 | + simulation_bounds=sim_bounds, |
| 3148 | + interpolators={}, |
| 3149 | + ) |
| 3150 | + |
| 3151 | + group._compute_derivatives(deriv_info) |
| 3152 | + |
| 3153 | + assert ( |
| 3154 | + object.__getattribute__(fully_inside_box, "recorded_bounds_intersect") |
| 3155 | + == fully_inside_box.bounds |
| 3156 | + ) |
| 3157 | + assert object.__getattribute__(big_box, "recorded_bounds_intersect") == sim_bounds |
| 3158 | + |
| 3159 | + # case where sim bounds bigger than group bounds |
| 3160 | + sim_bounds = ((-20.0, -20.0, -20.0), (20.0, 20.0, 20.0)) |
| 3161 | + |
| 3162 | + deriv_info = SimpleDerivativeInfo( |
| 3163 | + paths=[("geom", idx, "dummy") for idx, _ in enumerate(boxes)], |
| 3164 | + bounds=group.bounds, |
| 3165 | + bounds_intersect=Geometry.bounds_intersection(group.bounds, sim_bounds), |
| 3166 | + simulation_bounds=sim_bounds, |
| 3167 | + interpolators={}, |
| 3168 | + ) |
| 3169 | + |
| 3170 | + group._compute_derivatives(deriv_info) |
| 3171 | + |
| 3172 | + assert object.__getattribute__(big_box, "recorded_bounds_intersect") == group.bounds, ( |
| 3173 | + f"got {object.__getattribute__(big_box, 'recorded_bounds_intersect')} and {group.bounds}" |
| 3174 | + ) |
0 commit comments