diffinytrace 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffinytrace/__init__.py +122 -0
- diffinytrace/basis_functions/__init__.py +14 -0
- diffinytrace/basis_functions/bspline.py +521 -0
- diffinytrace/basis_functions/chebyshev.py +3 -0
- diffinytrace/basis_functions/legendre.py +77 -0
- diffinytrace/basis_functions/zernike.py +235 -0
- diffinytrace/config.py +140 -0
- diffinytrace/constraints.py +54 -0
- diffinytrace/element.py +1660 -0
- diffinytrace/export/__init__.py +8 -0
- diffinytrace/export/cad.py +253 -0
- diffinytrace/gaussian_smoother.py +530 -0
- diffinytrace/hat_smoother.py +44 -0
- diffinytrace/integrators.py +452 -0
- diffinytrace/intersection.py +285 -0
- diffinytrace/optimize.py +808 -0
- diffinytrace/physical_object.py +150 -0
- diffinytrace/plotting/__init__.py +16 -0
- diffinytrace/plotting/core.py +92 -0
- diffinytrace/plotting/quantity2D.py +188 -0
- diffinytrace/plotting/system2D.py +220 -0
- diffinytrace/plotting/system3D.py +327 -0
- diffinytrace/plotting/wavelength.py +231 -0
- diffinytrace/refractive_index.py +101 -0
- diffinytrace/render.py +77 -0
- diffinytrace/source.py +661 -0
- diffinytrace/spectrum.py +79 -0
- diffinytrace/surface.py +468 -0
- diffinytrace/target_grid.py +399 -0
- diffinytrace/transforms.py +472 -0
- diffinytrace/utils/__init__.py +7 -0
- diffinytrace/utils/autograd.py +116 -0
- diffinytrace/utils/irradiance_importer.py +134 -0
- diffinytrace-2.1.dist-info/METADATA +26 -0
- diffinytrace-2.1.dist-info/RECORD +38 -0
- diffinytrace-2.1.dist-info/WHEEL +5 -0
- diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
- diffinytrace-2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
Grid-based spatial aggregation for ray optics.

Classes:
    - Grid: Rectangular 2D grid that maps points to cells and aggregates
      per-point values (sum, product, mean, min, max) into per-cell maps.
    - GridSquare: Square Grid centred at the origin, e.g. for symmetric apertures.

Example:
    >>> grid = Grid([0, 1], [0, 1], 10, 10)
    >>> area = grid.get_area()
"""

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

__all__ = [
    "Grid",
    "GridSquare",
]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
import torch
|
|
27
|
+
import torch.nn as nn
|
|
28
|
+
from sklearn.neighbors import NearestNeighbors
|
|
29
|
+
import numpy as np
|
|
30
|
+
|
|
31
|
+
class Grid():
    """
    Represents a 2D grid over a rectangular area with aggregation and indexing utilities.

    The area is divided into ``y_grid_size`` x ``x_grid_size`` equally sized cells.
    A point ``(x, y)`` falls into cell ``(yi, xi)`` with
    ``xi = floor((x - x_min) / dx)`` and ``yi = floor((y - y_min) / dy)``; the
    flattened cell index is ``k = yi * x_grid_size + xi``. Point tensors have
    shape ``(N, 2)`` with columns ``(x, y)``. Per-cell results are returned with
    shape ``(y_grid_size, x_grid_size, ...)``, written ``(H, W, ...)`` below.

    Args:
        y_range (tuple[float, float]): The range in y-direction, as (y_min, y_max).
        x_range (tuple[float, float]): The range in x-direction, as (x_min, x_max).
        y_grid_size (int): Number of grid cells in y-direction.
        x_grid_size (int): Number of grid cells in x-direction.

    Raises:
        ValueError: If ``y_grid_size`` or ``x_grid_size`` is not positive.
    """
    def __init__(self, y_range, x_range, y_grid_size, x_grid_size):
        super().__init__()
        if x_grid_size <= 0 or y_grid_size <= 0:
            raise ValueError(
                f"Grid sizes must be positive, got y_grid_size={y_grid_size}, x_grid_size={x_grid_size}"
            )
        self.y_range = np.array(y_range)
        self.x_range = np.array(x_range)

        self.x_grid_size = x_grid_size
        self.y_grid_size = y_grid_size
        # Edge length of a single cell along each axis.
        self.x_delta = (self.x_range[1] - self.x_range[0]) / x_grid_size
        self.y_delta = (self.y_range[1] - self.y_range[0]) / y_grid_size

    def get_area(self):
        r"""
        Computes the total area of the grid.

        Returns:
            float: Total area of the grid.

        .. math::
            A = (x_{max} - x_{min}) \cdot (y_{max} - y_{min})
        """
        return (self.x_range[1] - self.x_range[0]) * (self.y_range[1] - self.y_range[0])

    def get_pixel_area(self):
        r"""
        Returns the area of a single pixel/grid cell.

        Returns:
            float: Area of a single grid cell.

        .. math::
            A_{pixel} = \Delta x \cdot \Delta y
        """
        return self.x_delta * self.y_delta

    def get_yi_xi(self, local_points, round_to_bounds=True):
        r"""
        Converts 2D local coordinates to integer grid indices.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2) with columns (x, y).
                It is detached, so no gradient flows through the indices.
            round_to_bounds (bool): Only selects the return format. In both modes the
                indices of out-of-grid points are clamped to the nearest border cell.
                If False, a validity mask is returned as well.

        Returns:
            - If ``round_to_bounds`` is True: tuple ``(yi, xi)`` of long tensors of shape (N,).
            - Otherwise: ``((yi, xi), valid)``, where ``valid`` is a bool tensor of shape
              (N,) that is True for points inside the grid.

        Raises:
            RuntimeError: If ``local_points`` is not of shape (N, 2).
        """
        if len(local_points.shape) != 2 or local_points.shape[1] != 2:
            raise RuntimeError("The local_points must be in local coordinates and of shape [#points,2]")
        local_points = local_points.detach()

        # Fractional cell coordinates measured from (x_min, y_min).
        ref_x = (local_points[:, 0] - self.x_range[0]) / self.x_delta
        ref_y = (local_points[:, 1] - self.y_range[0]) / self.y_delta

        xi = torch.floor(ref_x).long()
        yi = torch.floor(ref_y).long()

        valid = (xi >= 0) & (xi < self.x_grid_size) & (yi >= 0) & (yi < self.y_grid_size)

        # Clamping leaves in-range indices unchanged, so it can be applied unconditionally.
        yi = torch.clamp(yi, min=0, max=(self.y_grid_size - 1))
        xi = torch.clamp(xi, min=0, max=(self.x_grid_size - 1))

        if round_to_bounds:
            return (yi, xi)
        return (yi, xi), valid

    def get_k(self, local_points, round_to_bounds=True):
        r"""
        Maps local coordinates to flattened grid indices ``k = yi * x_grid_size + xi``.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2).
            round_to_bounds (bool): Selects the return format (see ``get_yi_xi``).

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
                - If `round_to_bounds` is True: Tensor of shape (N,).
                - Otherwise: Tuple (indices, validity_mask).
        """
        if round_to_bounds:
            yi, xi = self.get_yi_xi(local_points, round_to_bounds=round_to_bounds)
            return (yi * self.x_grid_size + xi).long()
        (yi, xi), valid = self.get_yi_xi(local_points, round_to_bounds=round_to_bounds)
        k = (yi * self.x_grid_size + xi).long()
        return k, valid

    def map_matrix_to_ray(self, local_points, old_matrix):
        r"""
        Samples a matrix defined on the grid at the cell of each local point.

        Points outside the grid read the value of the nearest border cell.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            old_matrix (torch.Tensor): Matrix of shape (H, W, ...).

        Returns:
            torch.Tensor: Sampled values of shape (N, ...).
        """
        k = self.get_k(local_points)
        n_cells = self.y_grid_size * self.x_grid_size
        # Flatten only the two grid axes so trailing channel axes stay aligned with k.
        return old_matrix.reshape(n_cells, *old_matrix.shape[2:])[k]

    def _reshape_to_grid(self, flat):
        """Reshape a flat (H*W, ...) tensor to (H, W, ...)."""
        return flat.reshape(self.y_grid_size, self.x_grid_size, *flat.shape[1:])

    def _scatter(self, local_points, values, old_matrix, fill, reduce, round_to_bounds, include_self=True):
        """Scatter per-point ``values`` into a flat (H*W, ...) accumulator.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,) or (N, ...).
            old_matrix (torch.Tensor or None): Flat (H*W, ...) or grid-shaped (H, W, ...)
                accumulator. If it is contiguous, it is reshaped as a view and so updated in place.
            fill (float): Initial cell value when ``old_matrix`` is None.
            reduce (str): ``"sum"`` or a ``torch.Tensor.scatter_reduce_`` reduction.
            round_to_bounds (bool): If True, out-of-grid points go to the nearest border
                cell; otherwise they are dropped.
            include_self (bool): Passed to ``scatter_reduce_``.

        Returns:
            Tuple ``(out, k, values)``: the flat accumulator and the cell indices and
            values that were actually scattered.

        Raises:
            ValueError: If ``values`` does not have one leading entry per point.
        """
        if round_to_bounds:
            k = self.get_k(local_points, round_to_bounds=True)
        else:
            k, valid = self.get_k(local_points, round_to_bounds=False)
        if values.dim() == 0 or values.shape[0] != k.shape[0]:
            raise ValueError(
                f"values must have one entry per point: got shape {tuple(values.shape)} for {k.shape[0]} points"
            )
        if not round_to_bounds:
            k, values = k[valid], values[valid]

        n_cells = self.y_grid_size * self.x_grid_size
        trailing = tuple(values.shape[1:])
        if old_matrix is None:
            out = torch.full((n_cells, *trailing), fill, device=local_points.device, dtype=values.dtype)
        else:
            out = old_matrix.reshape(n_cells, *trailing)

        # scatter_* requires an index with the same rank as the source tensor.
        index = k.reshape(-1, *([1] * len(trailing))).expand_as(values)
        if reduce == "sum":
            out.scatter_add_(0, index, values)
        else:
            out.scatter_reduce_(0, index, values, reduce=reduce, include_self=include_self)
        return out, k, values

    def sum(self,
            local_points: torch.Tensor,
            values: torch.Tensor,
            old_matrix=None,
            round_to_bounds: bool = False):
        """
        Sums values over the grid based on point locations.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,) or (N, ...).
            old_matrix (torch.Tensor or None): Previous result, flat or (H, W, ...), to add to.
                If contiguous, it is updated in place.
            round_to_bounds (bool): If True, out-of-grid points are added to the nearest
                border cell; if False, they are ignored.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W, ...) with the dtype of ``values``.

        Raises:
            ValueError: If ``values`` does not have one entry per point.
        """
        out, _, _ = self._scatter(local_points, values, old_matrix, 0.0, "sum", round_to_bounds)
        return self._reshape_to_grid(out)

    def prod(self, local_points, values, old_matrix=None, round_to_bounds=False):
        """
        Multiplies values over the grid based on point locations.

        Cells that receive no value are 1 (or keep their ``old_matrix`` value).

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,) or (N, ...).
            old_matrix (torch.Tensor or None): Previous result, flat or (H, W, ...), to multiply into.
                If contiguous, it is updated in place.
            round_to_bounds (bool): If True, out-of-grid points go to the nearest border
                cell; if False, they are ignored.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W, ...).

        Raises:
            ValueError: If ``values`` does not have one entry per point.
        """
        out, _, _ = self._scatter(local_points, values, old_matrix, 1.0, "prod", round_to_bounds)
        return self._reshape_to_grid(out)

    def mean(self, local_points, values, old_matrix=None, round_to_bounds=False):
        """
        Computes the mean of values over the grid based on point locations.

        Cells that receive no value are 0 (or keep their ``old_matrix`` value).
        Because ``include_self=False`` is used, a cell that does receive values is
        overwritten with the mean of the new values. Its ``old_matrix`` entry is not
        averaged in.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,) or (N, ...).
            old_matrix (torch.Tensor or None): Flat or (H, W, ...) matrix supplying values
                for cells without points. If contiguous, it is updated in place.
            round_to_bounds (bool): If True, out-of-grid points go to the nearest border
                cell; if False, they are ignored.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W, ...).

        Raises:
            ValueError: If ``values`` does not have one entry per point.
        """
        out, _, _ = self._scatter(local_points, values, old_matrix, 0.0, "mean", round_to_bounds,
                                  include_self=False)
        return self._reshape_to_grid(out)

    def __get_args(self, M, b, v):
        """Return, per flat cell, the index of a point whose value equals the reduced value in ``M``.

        Args:
            M (torch.Tensor): Flat reduced values of shape (H*W, ...).
            b (torch.Tensor): Cell index of each point, shape (N,).
            v (torch.Tensor): Point values, shape (N, ...).

        Returns:
            torch.Tensor: Long tensor shaped like ``M``. An entry is -1 when no point value
            matches, i.e. no point landed there or the extremum came from an earlier
            accumulation. Among tied points the largest point index is chosen, so the
            result is deterministic.
        """
        trailing = [1] * (v.dim() - 1)
        b_full = b.reshape(-1, *trailing).expand_as(v)
        point_idx = torch.arange(v.shape[0], device=v.device).reshape(-1, *trailing).expand_as(v)
        mask = v == M.gather(0, b_full)
        # -1 is below every real index, so non-matching entries never win the amax.
        candidates = point_idx.masked_fill(~mask, -1)
        M_arg = torch.full(M.shape, -1, dtype=torch.long, device=v.device)
        M_arg.scatter_reduce_(0, b_full, candidates, reduce="amax", include_self=True)
        return M_arg

    def min(self, local_points, values, old_matrix=None, return_args=False):
        """Finds the minimum value for each grid cell based on local points.

        Out-of-grid points are assigned to the nearest border cell. Cells that receive
        no value hold ``+inf`` (or keep their ``old_matrix`` value); the initial fill
        therefore assumes floating-point ``values``.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,) or (N, ...).
            old_matrix (torch.Tensor or None): Previous result, flat or (H, W, ...), to combine with.
                If contiguous, it is updated in place.
            return_args (bool): If True, also return the index of the minimising point per
                cell (-1 where none).

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Minimum values of shape (H, W, ...),
            optionally with indices of the same shape.

        Raises:
            ValueError: If ``values`` does not have one entry per point.
        """
        out, k, values = self._scatter(local_points, values, old_matrix, float('inf'), "amin",
                                       round_to_bounds=True)
        if return_args:
            out_args = self.__get_args(out, k, values)
            return self._reshape_to_grid(out), self._reshape_to_grid(out_args)
        return self._reshape_to_grid(out)

    def max(self, local_points, values, old_matrix=None, return_args=False):
        """
        Finds the maximum value for each grid cell based on local points.

        Out-of-grid points are assigned to the nearest border cell. Cells that receive
        no value hold ``-inf`` (or keep their ``old_matrix`` value); the initial fill
        therefore assumes floating-point ``values``.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,) or (N, ...).
            old_matrix (torch.Tensor or None): Previous result, flat or (H, W, ...), to combine with.
                If contiguous, it is updated in place.
            return_args (bool): If True, also return the index of the maximising point per
                cell (-1 where none).

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Maximum values of shape (H, W, ...),
            optionally with indices of the same shape.

        Raises:
            ValueError: If ``values`` does not have one entry per point.
        """
        out, k, values = self._scatter(local_points, values, old_matrix, float('-inf'), "amax",
                                       round_to_bounds=True)
        if return_args:
            out_args = self.__get_args(out, k, values)
            return self._reshape_to_grid(out), self._reshape_to_grid(out_args)
        return self._reshape_to_grid(out)

    def __get_x_middle(self):
        # Cell-centre x coordinates; torch.arange without a device argument yields a CPU tensor.
        x_middle = self.x_delta * 0.5 + torch.arange(0, self.x_grid_size) * self.x_delta + self.x_range[0]
        return x_middle

    def __get_y_middle(self):
        # Cell-centre y coordinates; torch.arange without a device argument yields a CPU tensor.
        y_middle = self.y_delta * 0.5 + torch.arange(0, self.y_grid_size) * self.y_delta + self.y_range[0]
        return y_middle

    def get_y_middle(self):
        """Return the y coordinates of the cell centres as a 1D CPU tensor of length ``y_grid_size``."""
        return self.__get_y_middle()

    def get_x_middle(self):
        """Return the x coordinates of the cell centres as a 1D CPU tensor of length ``x_grid_size``."""
        return self.__get_x_middle()

    def nearest(self, local_points, return_args=False):
        """
        For each grid cell, finds the point closest to the cell centre among the points in that cell.

        Distances are squared L2 distances. Out-of-grid points are assigned to the
        nearest border cell. Cells without any point hold ``+inf``.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2).
            return_args (bool): If True, also return the index of the closest point per cell
                (-1 where none).

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Minimum squared distances of
            shape (H, W), optionally with indices.
        """
        # The centres are built on the CPU. Move them next to the points so that
        # indexing with the (device-resident) xi/yi also works for GPU tensors.
        x_middle = self.__get_x_middle().to(local_points.device)
        y_middle = self.__get_y_middle().to(local_points.device)

        yi, xi = self.get_yi_xi(local_points)

        xdiff = (x_middle[xi] - local_points[:, 0]) ** 2.0
        ydiff = (y_middle[yi] - local_points[:, 1]) ** 2.0
        l2diff = xdiff + ydiff
        return self.min(local_points, l2diff, return_args=return_args)

    def get_pixel_centers(self):
        """
        Returns the 2D center coordinates of each grid cell.

        Returns:
            torch.Tensor: CPU tensor of shape (H, W, 2); the last axis holds (x, y).
        """
        x_middle = self.__get_x_middle()
        y_middle = self.__get_y_middle()

        grid_y, grid_x = torch.meshgrid(y_middle, x_middle, indexing='ij')
        return torch.stack([grid_x, grid_y], dim=-1)

    def get_nearest_ray(self, local_points):
        """
        For each grid cell, finds the index of the point closest to the cell centre.

        The search uses ``sklearn.neighbors.NearestNeighbors`` with a k-d tree over all
        points, so it is not limited to points inside the cell. No gradients flow.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2) representing sampled rays.

        Returns:
            torch.Tensor: Integer tensor of shape (H, W) with point indices, on the device
            of ``local_points``.
        """
        device = local_points.device
        # sklearn operates on NumPy arrays, which must live in host memory.
        ray_positions = local_points.detach().cpu().numpy()
        pixel_centers = self.get_pixel_centers().reshape(-1, 2).numpy()

        nn_model = NearestNeighbors(n_neighbors=1, algorithm='kd_tree')
        nn_model.fit(ray_positions)
        _, indices = nn_model.kneighbors(pixel_centers)
        return torch.tensor(indices.reshape(self.y_grid_size, self.x_grid_size), device=device)
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
class GridSquare(Grid):
    """
    Square grid centred at the origin.

    Both axes span ``[-aperture_radius, aperture_radius]`` and are split into the
    same number of cells.

    Args:
        aperture_radius (float): Half-width of the square domain.
        grid_size (int): Number of grid cells along each axis.
    """
    def __init__(self, aperture_radius, grid_size):
        half_width = aperture_radius
        y_bounds = [-half_width, half_width]
        x_bounds = [-half_width, half_width]
        super().__init__(y_bounds, x_bounds, grid_size, grid_size)
|
|
398
|
+
|
|
399
|
+
|