diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,399 @@
1
"""
Grid-based spatial aggregation for ray optics.

Classes:
    - Grid: 2D rectangular grid used for spatial aggregation and statistics.
    - GridSquare: Square Grid centred at the origin, for symmetric apertures.

Functions:
    - (none at top level)

Example:
    >>> grid = Grid([0, 1], [0, 1], 10, 10)
    >>> area = grid.get_area()
"""

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

__all__ = ["Grid", "GridSquare"]
23
+
24
+
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ from sklearn.neighbors import NearestNeighbors
29
+ import numpy as np
30
+
31
class Grid:
    """
    Represents a 2D grid over a rectangular area with aggregation and indexing utilities.

    Points are passed as tensors of shape ``(N, 2)`` in local coordinates, where
    column 0 is the x coordinate and column 1 is the y coordinate. Cell ``(yi, xi)``
    spans ``[x_min + xi*dx, x_min + (xi+1)*dx)`` in x and
    ``[y_min + yi*dy, y_min + (yi+1)*dy)`` in y. Its flat (row-major) index is
    ``k = yi * x_grid_size + xi``. Aggregation results are returned with shape
    ``(y_grid_size, x_grid_size)``.

    Args:
        y_range (tuple[float, float]): The range in y-direction, as (y_min, y_max).
        x_range (tuple[float, float]): The range in x-direction, as (x_min, x_max).
        y_grid_size (int): Number of grid cells in y-direction.
        x_grid_size (int): Number of grid cells in x-direction.
    """

    def __init__(self, y_range, x_range, y_grid_size, x_grid_size):
        super().__init__()
        self.y_range = np.array(y_range)
        self.x_range = np.array(x_range)

        self.x_grid_size = x_grid_size
        self.y_grid_size = y_grid_size
        # Edge lengths of a single cell.
        self.x_delta = (self.x_range[1] - self.x_range[0]) / x_grid_size
        self.y_delta = (self.y_range[1] - self.y_range[0]) / y_grid_size

    def get_area(self):
        r"""
        Computes the total area of the grid.

        Returns:
            float: Total area of the grid.

        .. math::
            A = (x_{max} - x_{min}) \cdot (y_{max} - y_{min})
        """
        return (self.x_range[1] - self.x_range[0]) * (self.y_range[1] - self.y_range[0])

    def get_pixel_area(self):
        r"""
        Returns the area of a single pixel/grid cell.

        Returns:
            float: Area of a single grid cell.

        .. math::
            A_{pixel} = \Delta x \cdot \Delta y
        """
        return self.x_delta * self.y_delta

    def get_yi_xi(self, local_points, round_to_bounds=True):
        r"""
        Converts 2D local coordinates to integer grid indices.

        The returned indices are always clamped into ``[0, size - 1]``, so they can
        be used for indexing directly. ``round_to_bounds`` only controls whether the
        in-bounds mask is returned as well.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2) with (x, y) columns.
                It is detached first, so no gradient flows through the indices.
            round_to_bounds (bool): If True, return only the clamped indices. If False,
                also return a boolean mask that is True for points lying inside the grid.

        Returns:
            ``(yi, xi)`` (int64 tensors of shape (N,)) if ``round_to_bounds``,
            otherwise ``((yi, xi), valid)``.

        Raises:
            RuntimeError: If ``local_points`` is not of shape (N, 2).
        """
        if len(local_points.shape) != 2 or local_points.shape[1] != 2:
            raise RuntimeError("The local_points must be in local coordinates and of shape [#points,2]")
        local_points = local_points.detach()

        ref_x = (local_points[:, 0] - self.x_range[0]) / self.x_delta
        ref_y = (local_points[:, 1] - self.y_range[0]) / self.y_delta

        xi = torch.floor(ref_x).long()
        yi = torch.floor(ref_y).long()

        valid = (xi >= 0) & (xi < self.x_grid_size) & (yi >= 0) & (yi < self.y_grid_size)

        # Clamping is a no-op for in-bounds indices, so apply it unconditionally.
        # This avoids the host syncs of a data-dependent `.any()` check.
        yi = torch.clamp(yi, min=0, max=self.y_grid_size - 1)
        xi = torch.clamp(xi, min=0, max=self.x_grid_size - 1)

        if round_to_bounds:
            return (yi, xi)
        return (yi, xi), valid

    def get_k(self, local_points, round_to_bounds=True):
        r"""
        Maps local coordinates to flattened grid indices ``k = yi * x_grid_size + xi``.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2).
            round_to_bounds (bool): If True, return only the (clamped) indices;
                otherwise also return the in-bounds mask.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
                - If `round_to_bounds` is True: int64 tensor of shape (N,).
                - Otherwise: Tuple (indices, validity_mask).
        """
        if round_to_bounds:
            yi, xi = self.get_yi_xi(local_points, round_to_bounds=True)
            return (yi * self.x_grid_size + xi).long()
        (yi, xi), valid = self.get_yi_xi(local_points, round_to_bounds=False)
        return (yi * self.x_grid_size + xi).long(), valid

    def map_matrix_to_ray(self, local_points, old_matrix):
        r"""
        Samples a per-cell matrix at the cells containing the given points.

        Points outside the grid use the nearest edge cell (indices are clamped).

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            old_matrix (torch.Tensor): Matrix of shape (H, W, ...) with
                H == y_grid_size and W == x_grid_size.

        Returns:
            torch.Tensor: Values of shape (N, ...).
        """
        k = self.get_k(local_points)
        # Flatten only the two grid axes; trailing (channel) dims must be kept
        # intact so that row k is exactly cell k.
        return old_matrix.reshape(-1, *old_matrix.shape[2:])[k]

    def _flat_accumulator(self, old_matrix, fill, device, dtype):
        """Return a flat (H*W,) accumulator: ``old_matrix`` viewed as 1-D, or a new tensor filled with ``fill``."""
        if old_matrix is None:
            return torch.full((self.y_grid_size * self.x_grid_size,), fill, device=device, dtype=dtype)
        # Accepts both flat (H*W,) and (H, W) inputs. reshape returns a view when
        # possible, in which case old_matrix itself is updated in place.
        return old_matrix.reshape(-1)

    def _cells_and_values(self, local_points, values, round_to_bounds):
        """Return flat cell indices and matching values; out-of-grid points are dropped unless ``round_to_bounds``."""
        if round_to_bounds:
            return self.get_k(local_points, True), values
        k, valid = self.get_k(local_points, False)
        return k[valid], values[valid]

    def sum(self,
            local_points: torch.Tensor,
            values: torch.Tensor,
            old_matrix=None,
            round_to_bounds: bool = False):
        """
        Sums values over the grid based on point locations.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,); cast to the output dtype.
            old_matrix (torch.Tensor or None): Previous result, flat (H*W,) or (H, W),
                to accumulate into. It may be updated in place.
            round_to_bounds (bool): If True, points outside the grid are added to the
                nearest edge cell; if False, they are ignored.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W).
        """
        out = self._flat_accumulator(old_matrix, 0.0, local_points.device, local_points.dtype)
        k, values = self._cells_and_values(local_points, values, round_to_bounds)
        out.scatter_add_(0, k, values.to(out.dtype))
        return out.reshape(self.y_grid_size, self.x_grid_size)

    def prod(self, local_points, values, old_matrix=None, round_to_bounds=False):
        """
        Multiplies values over the grid based on point locations.

        Cells that receive no points keep their initial value (1, or the value
        from ``old_matrix``).

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,); cast to the output dtype.
            old_matrix (torch.Tensor or None): Previous result, flat (H*W,) or (H, W),
                to accumulate into. It may be updated in place.
            round_to_bounds (bool): If True, points outside the grid go to the nearest
                edge cell; if False, they are ignored.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W).
        """
        out = self._flat_accumulator(old_matrix, 1.0, local_points.device, local_points.dtype)
        k, values = self._cells_and_values(local_points, values, round_to_bounds)
        out.scatter_reduce_(0, k, values.to(out.dtype), reduce='prod')
        return out.reshape(self.y_grid_size, self.x_grid_size)

    def mean(self, local_points, values, old_matrix=None, round_to_bounds=False):
        """
        Computes the mean of values over the grid based on point locations.

        The mean uses ``include_self=False``: for every cell that receives points,
        the result is the mean of those points' values only. Any previous value
        (0, or ``old_matrix``) is replaced, not averaged in. Cells without points
        keep their previous value.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,); cast to the output dtype.
            old_matrix (torch.Tensor or None): Initial matrix, flat (H*W,) or (H, W).
                It may be updated in place.
            round_to_bounds (bool): If True, points outside the grid go to the nearest
                edge cell; if False, they are ignored.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W).
        """
        out = self._flat_accumulator(old_matrix, 0.0, local_points.device, local_points.dtype)
        k, values = self._cells_and_values(local_points, values, round_to_bounds)
        out.scatter_reduce_(0, k, values.to(out.dtype), reduce='mean', include_self=False)
        return out.reshape(self.y_grid_size, self.x_grid_size)

    def __get_args(self, M, b, v):
        """
        For every flat cell, return the index into ``v`` of a value equal to ``M`` at that cell.

        Cells for which no value of ``v`` matches stay -1. This includes empty cells
        and cells whose extremum came from a pre-existing ``old_matrix``. With ties,
        ``scatter_`` keeps one of the tied indices; which one is unspecified.
        """
        device = v.device
        M_arg = torch.full((self.y_grid_size * self.x_grid_size,), -1, dtype=torch.long, device=device)
        mask = v == M[b]
        # Built on v's device so that masking and scattering work for CUDA inputs too.
        indices = torch.arange(v.shape[0], device=device)
        M_arg.scatter_(0, b[mask], indices[mask])
        return M_arg

    def _extremum(self, local_points, values, old_matrix, return_args, reduce, fill):
        """Shared implementation of :meth:`min` / :meth:`max` (``reduce`` is 'amin' or 'amax')."""
        out = self._flat_accumulator(old_matrix, fill, local_points.device, local_points.dtype)
        # Points outside the grid are clamped to the nearest edge cell.
        k = self.get_k(local_points)
        values = values.to(out.dtype)
        out.scatter_reduce_(0, k, values, reduce=reduce)
        shape = (self.y_grid_size, self.x_grid_size)
        if return_args:
            out_args = self.__get_args(out, k, values)
            return out.reshape(shape), out_args.reshape(shape)
        return out.reshape(shape)

    def min(self, local_points, values, old_matrix=None, return_args=False):
        """
        Finds the minimum value for each grid cell based on local points.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2); out-of-grid points are
                clamped to the nearest edge cell.
            values (torch.Tensor): Values of shape (N,); cast to the output dtype.
            old_matrix (torch.Tensor or None): Previous result, flat (H*W,) or (H, W),
                to accumulate into. It may be updated in place. Defaults to +inf everywhere.
            return_args (bool): If True, also return, per cell, the index of the
                point attaining the minimum (-1 if none).

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: (H, W) minima, optionally with indices.
        """
        return self._extremum(local_points, values, old_matrix, return_args, 'amin', float('inf'))

    def max(self, local_points, values, old_matrix=None, return_args=False):
        """
        Finds the maximum value for each grid cell based on local points.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2); out-of-grid points are
                clamped to the nearest edge cell.
            values (torch.Tensor): Values of shape (N,); cast to the output dtype.
            old_matrix (torch.Tensor or None): Previous result, flat (H*W,) or (H, W),
                to accumulate into. It may be updated in place. Defaults to -inf everywhere.
            return_args (bool): If True, also return, per cell, the index of the
                point attaining the maximum (-1 if none).

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: (H, W) maxima, optionally with indices.
        """
        return self._extremum(local_points, values, old_matrix, return_args, 'amax', float('-inf'))

    def __get_x_middle(self, device=None, dtype=None):
        """x coordinates of the cell centers; computed on ``device``/``dtype`` when either is given."""
        if device is None and dtype is None:
            # Original formulation, kept so that get_x_middle() output is unchanged.
            return self.x_delta * 0.5 + torch.arange(0, self.x_grid_size) * self.x_delta + self.x_range[0]
        idx = torch.arange(self.x_grid_size, device=device, dtype=dtype)
        return (idx + 0.5) * float(self.x_delta) + float(self.x_range[0])

    def __get_y_middle(self, device=None, dtype=None):
        """y coordinates of the cell centers; computed on ``device``/``dtype`` when either is given."""
        if device is None and dtype is None:
            # Original formulation, kept so that get_y_middle() output is unchanged.
            return self.y_delta * 0.5 + torch.arange(0, self.y_grid_size) * self.y_delta + self.y_range[0]
        idx = torch.arange(self.y_grid_size, device=device, dtype=dtype)
        return (idx + 0.5) * float(self.y_delta) + float(self.y_range[0])

    def get_y_middle(self):
        """Return the y coordinates of the cell centers, shape (y_grid_size,)."""
        return self.__get_y_middle()

    def get_x_middle(self):
        """Return the x coordinates of the cell centers, shape (x_grid_size,)."""
        return self.__get_x_middle()

    def nearest(self, local_points, return_args=False):
        """
        For each cell, finds the smallest squared L2 distance between the cell center
        and the points assigned to that cell.

        Points outside the grid are clamped to the nearest edge cell. Cells without
        points get +inf.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2).
            return_args (bool): If True, also return, per cell, the index of the closest point.

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: (H, W) minimum squared
            distances, optionally with indices.
        """
        # Centers are built on the points' device/dtype so indexing and arithmetic
        # below do not mix devices or silently change precision.
        x_middle = self.__get_x_middle(local_points.device, local_points.dtype)
        y_middle = self.__get_y_middle(local_points.device, local_points.dtype)

        yi, xi = self.get_yi_xi(local_points)

        xdiff = (x_middle[xi] - local_points[:, 0]) ** 2.0
        ydiff = (y_middle[yi] - local_points[:, 1]) ** 2.0
        return self.min(local_points, xdiff + ydiff, return_args=return_args)

    def get_pixel_centers(self):
        """
        Returns the 2D center coordinates of each grid cell.

        Returns:
            torch.Tensor: Tensor of shape (H, W, 2); the last axis holds (x, y).
        """
        x_middle = self.__get_x_middle()
        y_middle = self.__get_y_middle()

        grid_y, grid_x = torch.meshgrid(y_middle, x_middle, indexing='ij')
        return torch.stack([grid_x, grid_y], dim=-1)

    def get_nearest_ray(self, local_points):
        """
        Finds the index of the nearest ray for each grid cell using `sklearn.neighbors.NearestNeighbors`.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2) representing sampled rays (N >= 1).

        Returns:
            torch.Tensor: Tensor of shape (H, W) with ray indices, on ``local_points.device``.
        """
        device = local_points.device
        with torch.no_grad():
            # sklearn operates on host NumPy arrays, so move the points off any accelerator.
            ray_points = local_points.detach().cpu().numpy()
            centers = self.get_pixel_centers().reshape(-1, 2).numpy()

        nn_model = NearestNeighbors(n_neighbors=1, algorithm='kd_tree')
        nn_model.fit(ray_points)
        _, indices = nn_model.kneighbors(centers)
        return torch.tensor(indices.reshape(self.y_grid_size, self.x_grid_size), device=device)
384
+
385
+
386
class GridSquare(Grid):
    """
    Square grid centred at the origin.

    Both axes span ``[-aperture_radius, aperture_radius]`` and are divided into
    the same number of cells.

    Args:
        aperture_radius (float): Half-width of the square domain.
        grid_size (int): Number of grid cells along each axis.
    """

    def __init__(self, aperture_radius, grid_size):
        lower, upper = -aperture_radius, aperture_radius
        super().__init__(
            y_range=[lower, upper],
            x_range=[lower, upper],
            y_grid_size=grid_size,
            x_grid_size=grid_size,
        )
398
+
399
+