diffinytrace 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffinytrace/__init__.py +122 -0
- diffinytrace/basis_functions/__init__.py +14 -0
- diffinytrace/basis_functions/bspline.py +521 -0
- diffinytrace/basis_functions/chebyshev.py +3 -0
- diffinytrace/basis_functions/legendre.py +77 -0
- diffinytrace/basis_functions/zernike.py +235 -0
- diffinytrace/config.py +140 -0
- diffinytrace/constraints.py +54 -0
- diffinytrace/element.py +1660 -0
- diffinytrace/export/__init__.py +8 -0
- diffinytrace/export/cad.py +253 -0
- diffinytrace/gaussian_smoother.py +530 -0
- diffinytrace/hat_smoother.py +44 -0
- diffinytrace/integrators.py +452 -0
- diffinytrace/intersection.py +285 -0
- diffinytrace/optimize.py +808 -0
- diffinytrace/physical_object.py +150 -0
- diffinytrace/plotting/__init__.py +16 -0
- diffinytrace/plotting/core.py +92 -0
- diffinytrace/plotting/quantity2D.py +188 -0
- diffinytrace/plotting/system2D.py +220 -0
- diffinytrace/plotting/system3D.py +327 -0
- diffinytrace/plotting/wavelength.py +231 -0
- diffinytrace/refractive_index.py +101 -0
- diffinytrace/render.py +77 -0
- diffinytrace/source.py +661 -0
- diffinytrace/spectrum.py +79 -0
- diffinytrace/surface.py +468 -0
- diffinytrace/target_grid.py +399 -0
- diffinytrace/transforms.py +472 -0
- diffinytrace/utils/__init__.py +7 -0
- diffinytrace/utils/autograd.py +116 -0
- diffinytrace/utils/irradiance_importer.py +134 -0
- diffinytrace-2.1.dist-info/METADATA +26 -0
- diffinytrace-2.1.dist-info/RECORD +38 -0
- diffinytrace-2.1.dist-info/WHEEL +5 -0
- diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
- diffinytrace-2.1.dist-info/top_level.txt +1 -0
diffinytrace/element.py
ADDED
|
@@ -0,0 +1,1660 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "is_valid_square_circle",
    "OpticalSystem",
    "SequentialOpticalSystem",
    "OpticalElement",
    "OpticalSurface",
    "LensSurfaceTransmissionEnter",
    "LensSurfaceTransmissionLeave",
    "LensSurfaceSide",
    "Lens",
    "Mirror",
    "Detector",
    "trace_to_detector",
    "set_unused_params_to_zero",
    "get_unused_params_mask",
    "set_used_params_bounds_to_constant",
    "FresnelOpticalSurface",
    "FresnelVirtualLensSurfaceTransmissionEnter",
    "FresnelVirtualLensSurfaceTransmissionLeave",
    "FresnelVirtualLens",
    "compute_reflected_directions",
    "get_refracted_directions",
    "set_unused_bspline_coeff_to_nearest",
]
|
|
28
|
+
|
|
29
|
+
from typing import List,Dict,Tuple
|
|
30
|
+
import torch
|
|
31
|
+
import torch.nn as nn
|
|
32
|
+
from .plotting import Plotable
|
|
33
|
+
from .refractive_index import materials
|
|
34
|
+
#from .refractive_index import RefractiveIndex
|
|
35
|
+
from .intersection import construct_surface_and_normal_func_with_params,get_ray_intersection_length
|
|
36
|
+
from .optimize import make_parameter_from_input
|
|
37
|
+
from . import transforms
|
|
38
|
+
from .integrators import Disc,Cube
|
|
39
|
+
import numpy as np
|
|
40
|
+
import cadquery as cq
|
|
41
|
+
from .physical_object import PhysicalObject,PhysicalSurface
|
|
42
|
+
from . utils.autograd import grad
|
|
43
|
+
from . optimize import minimize,remove_bounds,set_bounds_from_params_mask
|
|
44
|
+
from . transforms import Transform
|
|
45
|
+
"""
|
|
46
|
+
import numpy as np
|
|
47
|
+
color_pallete = np.array([(218.0, 232, 252),
|
|
48
|
+
(108, 142, 191),
|
|
49
|
+
(248, 206, 204),
|
|
50
|
+
(184, 84, 80),
|
|
51
|
+
(213, 232, 212),
|
|
52
|
+
(130, 179, 102),
|
|
53
|
+
(255, 230, 204),
|
|
54
|
+
(215, 155, 0),
|
|
55
|
+
(225, 213, 231),
|
|
56
|
+
(150, 115, 166),
|
|
57
|
+
(255, 242, 204),
|
|
58
|
+
(214, 182, 86)])/255.0
|
|
59
|
+
import matplotlib.colors as mcolors
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def is_valid_square_circle(transform:Transform,
                           O:torch.Tensor,
                           aperture_radius:float,
                           is_square:bool)->torch.Tensor:
    r"""
    Checks whether points lie within a circular or square aperture after transformation.

    The points are first mapped into the local frame of ``transform``. Only the local
    x and y coordinates take part in the test. The local z coordinate (e.g. the sag of
    a curved surface at the hit point) does not influence the result.

    Args:
        transform (Transform): Transformation object to convert global to local coordinates.
        O (torch.Tensor): Points in global coordinates of shape (N, 3).
        aperture_radius (float or torch.Tensor): Radius of the circular aperture, or half-width
            of the square aperture. Its absolute value is used.
        is_square (bool): If True, aperture is square; if False, circular.

    Returns:
        torch.Tensor: Boolean tensor of shape (N,) indicating whether each point lies within the aperture.

    Note:
        For a square, checks if \( |x| < r \) and \( |y| < r \). For a circle, checks if \( \sqrt{x^2 + y^2} < r \).
        The mask is computed under ``torch.no_grad`` from detached inputs, so no gradient flows through it.
    """
    with torch.no_grad():
        # as_tensor accepts both Python numbers and tensors without the copy-construct warning.
        radius = torch.abs(torch.as_tensor(aperture_radius))
        O_local = transform.to_local_pos(O.detach())
        if is_square:
            return (torch.abs(O_local[:, 0]) < radius) & (torch.abs(O_local[:, 1]) < radius)
        # Radial distance in the local x-y plane only; including z would clip valid hits
        # near the rim of curved surfaces.
        return torch.norm(O_local[:, :2], dim=1) < radius
|
|
89
|
+
|
|
90
|
+
class OpticalSystem(nn.Module,Plotable):
    """
    Container base class for optical systems assembled from named optical modules.

    The modules (e.g. lenses, mirrors, detectors) are held in an ``nn.ModuleDict``
    so PyTorch registers them. They are handed to the plotting machinery through
    :meth:`get_plotable_childs`.

    Attributes:
        modules_dict (nn.ModuleDict): Named optical modules, in insertion order.
    """
    def __init__(self, modules_dict:Dict):
        # Both bases are initialised explicitly, nn.Module first.
        nn.Module.__init__(self)
        Plotable.__init__(self)
        self.modules_dict = nn.ModuleDict(modules_dict)

    def forward(self):
        """The bare container cannot propagate anything; subclasses define forward."""
        raise RuntimeError("OpticalSystem: forward is not implemented")

    def get_plotable_childs(self):
        """Return a list of ``[module, name]`` pairs, one per contained module."""
        return [[child, child_name] for child_name, child in self.modules_dict.items()]

    def get_plot_points_2D(self,resolution:int):
        """The container has no geometry of its own; only its children are drawn."""
        return []

    def get_plot_points_3D(self,resolution:int):
        """The container has no geometry of its own; only its children are drawn."""
        return []
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class SequentialOpticalSystem(OpticalSystem):
    """
    Optical system whose modules are applied one after another in a given order.

    Typical use is a chain such as source -> lens -> detector.

    Attributes:
        n_func_enviroment (Callable): Refractive index of the surrounding medium,
            called with the wavelengths. Defaults to ``materials["AIR"]``.
    """
    def __init__(self,modules_dict:Dict, n_func_enviroment=materials["AIR"]):
        OpticalSystem.__init__(self,modules_dict)
        # Callable of wavelength; forwarded to ray sources in forward().
        self.n_func_enviroment = n_func_enviroment

    def forward(self,x,mapping_sequence:List[str]):
        """
        Propagates rays through the modules named in ``mapping_sequence``.

        A ``RaySource`` module is called as ``module(x, n_func_enviroment)``.
        Every other module receives the previous output unpacked, ``module(*x)``.

        Args:
            x (Any): Input for the first module (e.g. sampling data for a source).
            mapping_sequence (list[str]): Module names in propagation order.

        Returns:
            Any: Output of the last module in the sequence (``x`` itself if the
            sequence is empty).
        """
        state = x
        for module_name in mapping_sequence:
            # Function-scope import, presumably to avoid a circular import with .source.
            from .source import RaySource
            module = self.modules_dict[module_name]
            if not isinstance(module, RaySource):
                state = module(*state)
                continue
            state = module(state, self.n_func_enviroment)
        return state
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class OpticalElement(PhysicalObject,Plotable):
    """
    Abstract base class for optical elements like lenses, mirrors, and detectors.

    Defines the interface for geometric transformation (:meth:`get_transform`)
    and ray propagation (:meth:`forward`). Subclasses must override both.
    """

    def __init__(self,fill_color="white", outline_color="black",is_volume=False):
        """
        Args:
            fill_color (str): Fill colour used for plotting. Defaults to "white".
            outline_color (str): Outline colour used for plotting. Defaults to "black".
            is_volume (bool): Forwarded unchanged to ``Plotable``. Defaults to False.
        """
        PhysicalObject.__init__(self)
        Plotable.__init__(self,fill_color=fill_color,outline_color=outline_color,is_volume=is_volume)

    def forward(self,O2:torch.Tensor, D2:torch.Tensor, wl:torch.Tensor, n_func_enviroment, meta_data):
        """
        Propagates rays through the optical element.

        Args:
            O2 (torch.Tensor): Ray origins.
            D2 (torch.Tensor): Ray directions.
            wl (torch.Tensor): Wavelengths.
            n_func_enviroment (Callable): Function returning environmental refractive index.
            meta_data (dict): Dictionary with path length and validity information.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        # The message names the method that actually has to be overridden.
        raise NotImplementedError("forward not implemented")

    def get_transform(self):
        """
        Returns the transformation associated with the element.

        Returns:
            Transform: The local-to-global transformation object.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        raise NotImplementedError("get_transform not implemented")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class OpticalSurface(OpticalElement,PhysicalSurface):
    """
    Represents a surface in 3D space with a defined aperture and transformation.

    Supports both square and circular apertures, and provides methods for parametric sampling,
    CAD conversion, ray intersection, and plotting. Parametric coordinates are the local
    (x, y) coordinates of the surface. The local height z comes from ``surface.explicit``.

    Attributes:
        surface (object): Object with a method `explicit(parametric_pos)` returning z-values.
        aperture_radius (float): Radius of the circular aperture, or half-width of the square one.
        is_square (bool): Whether the aperture is square-shaped.
        transform (Transform): Local-to-global transformation.
        integrator (Integrator): Integration object (Disc or Cube) for parametric sampling.
    """
    def __init__(self,
                 transform:Transform,
                 surface,
                 aperture_radius:float,
                 is_square:bool=False,
                 fill_color:str="white",
                 outline_color:str="black"):
        """
        Initializes the optical surface.

        Args:
            transform (Transform): Local-to-global transformation of the surface.
            surface (object): Surface object with an `explicit()` method for height computation.
            aperture_radius (float): Radius or half-width of the aperture.
            is_square (bool, optional): If True, aperture is square. Defaults to False.
            fill_color (str, optional): Color for plotting. Defaults to "white".
            outline_color (str, optional): Outline color. Defaults to "black".
        """
        OpticalElement.__init__(self,fill_color=fill_color,outline_color=outline_color)
        PhysicalSurface.__init__(self)
        self.surface = surface
        self.aperture_radius = aperture_radius
        self.is_square = is_square
        self.transform = transform

        # The integrator spans the parametric (local x, y) domain of the aperture.
        if is_square:
            self.integrator = Cube([[-aperture_radius,aperture_radius],[-aperture_radius,aperture_radius]])
        else:
            self.integrator = Disc(aperture_radius)

    def get_constraint_funs_leq_zero(self):
        """
        Returns constraint functions used for integration and optimization over the surface.

        Returns:
            list[Callable]: List of functions f(param_pos) <= 0 indicating valid parametric regions.
            For a circular aperture this is ``|param_pos| - aperture_radius``.

        Raises:
            RuntimeError: If `is_square` is True (not yet implemented).
        """
        aperture_radius = self.aperture_radius
        if self.is_square:
            raise RuntimeError("get_constraint_funs_leq_zero is not implemented for square apertures")

        def out_func(parametric_pos):
            # Negative inside the disc, zero on its rim, positive outside.
            return torch.linalg.norm(parametric_pos,dim=-1)-aperture_radius
        return [out_func]

    def _surface_grid(self,resolution:int,project_to_aperture:bool):
        """
        Evaluates the surface on a ``resolution x resolution`` grid over [-r, r]^2 in parameter space.

        Args:
            resolution (int): Number of samples along each parametric axis.
            project_to_aperture (bool): If True, grid points outside the circle of radius r are
                moved radially onto its rim before evaluation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Global (x, y, z) grids of shape
            (resolution, resolution), "ij" indexed.
        """
        axis = torch.linspace(-self.aperture_radius,self.aperture_radius,resolution)
        # "ij" is the legacy default of torch.meshgrid; passing it explicitly silences the warning.
        grid_x, grid_y = torch.meshgrid(axis,axis,indexing="ij")
        x = grid_x.reshape(-1)
        y = grid_y.reshape(-1)

        if project_to_aperture:
            r = torch.sqrt(x*x+y*y)
            # torch.where avoids the 0/0 NaN at the centre and leaves points exactly on the rim in
            # place (the previous mask formula scaled them by 0, collapsing them onto the origin).
            scale = torch.where(r > self.aperture_radius, self.aperture_radius/r, torch.ones_like(r))
            x = x*scale
            y = y*scale

        with torch.no_grad():
            points = self.parametric_surface(torch.stack((x,y),dim=1))

        shape = (resolution,resolution)
        return points[:,0].reshape(shape), points[:,1].reshape(shape), points[:,2].reshape(shape)

    def get_plot_points_2D(self,resolution:int)->List[Tuple[torch.Tensor]]:
        """
        Returns a 2D slice through the surface (z-y plane, local x = 0) for plotting.

        Args:
            resolution (int): Number of sample points along the y-axis.

        Returns:
            List[Tuple[torch.Tensor]]: List of (z, y) coordinate tuples in global coordinates.
        """
        y = torch.linspace(-self.aperture_radius,self.aperture_radius,resolution)
        O = torch.zeros((resolution,2))
        O[:,1] = y  # local x stays 0

        with torch.no_grad():
            points = self.parametric_surface(O)
        return [(points[:,2],points[:,1])]

    def get_plot_points_3D(self,resolution:int)->List[Tuple[torch.Tensor]]:
        """
        Returns a 3D grid of surface points for visualization.

        For circular apertures, grid points outside the aperture are projected radially
        onto its rim so the mesh covers exactly the disc.

        Args:
            resolution (int): Grid resolution in x and y.

        Returns:
            List[Tuple[torch.Tensor]]: List of (x, y, z) meshgrids as torch tensors.
        """
        return [self._surface_grid(resolution,not self.is_square)]

    def get_CAD_points(self,resolution:int)->List[Tuple[torch.Tensor]]:
        """
        Generates a 3D surface point grid for CAD conversion.

        The grid always covers the full square [-r, r]^2 in parameter space, also for
        circular apertures.

        Args:
            resolution (int): Sampling resolution.

        Returns:
            Tuple[torch.Tensor]: (x, y, z) coordinate grids for CAD modeling.
        """
        #TODO maybe implement this also for affine transforms in surface class itself
        return self._surface_grid(resolution,False)

    def get_CAD_face(self,resolution:int,tol:float=0.001,smoothing = None,minDeg: int = 1,maxDeg: int = 3):
        """
        Converts the surface into a CAD face.

        If the surface object provides its own ``get_CAD_face`` and it returns a face,
        that face is used. Otherwise a B-spline face is fitted through sampled points.

        Args:
            resolution (int): Sampling resolution.
            tol (float, optional): Approximation tolerance. Defaults to 0.001.
            smoothing (Optional[int]): Smoothing value for fitting.
            minDeg (int): Minimum degree of the spline.
            maxDeg (int): Maximum degree of the spline.

        Returns:
            cadquery.Face: CAD face object.
        """
        if hasattr(self.surface,"get_CAD_face"):
            affine_transform = self.transform.get_transformation_matrix()
            out = self.surface.get_CAD_face(affine_transform)
            if out is not None:
                return out
            print("get_CAD_face returned NONE fallback to fit method")

        # Fallback: fit a spline face through a (resolution x resolution) grid of surface points.
        grid = torch.stack(self.get_CAD_points(resolution),dim=-1)
        surface1_points = [[cq.Vector(elem[0],elem[1],elem[2]) for elem in row] for row in grid]
        return cq.Face.makeSplineApprox(surface1_points,smoothing=smoothing,minDeg=minDeg,maxDeg=maxDeg,tol=tol)

    def parametric_sample(self,num_points:int,method:str="sobol")-> tuple[torch.Tensor, torch.Tensor]:
        """
        Samples parametric positions on the aperture using the integrator.

        Args:
            num_points (int): Number of sample points.
            method (str): Sampling method, passed through to the integrator.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Sampled positions and integration weights.
        """
        return self.integrator.sample(num_points,method)

    def parametric_surface(self,parametric_pos:torch.Tensor)->torch.Tensor:
        """
        Maps 2D parametric coordinates to 3D global coordinates using the surface height and transform.

        Args:
            parametric_pos (torch.Tensor): 2D parametric positions of shape (N, 2).

        Returns:
            torch.Tensor: 3D positions of shape (N, 3) in global space.

        Raises:
            RuntimeError: If input does not have shape [..., 2].
        """
        if parametric_pos.shape[-1] !=2:
            raise RuntimeError("positions must be in local coordinates [:,2]")
        device = parametric_pos.device
        dtype = parametric_pos.dtype

        z = self.surface.explicit(parametric_pos)
        x = parametric_pos[:,0]
        y = parametric_pos[:,1]

        # Homogeneous local coordinates (x, y, z, 1).
        v = torch.zeros((x.shape[0],4),device=device,dtype=dtype)
        v[:,0] = x
        v[:,1] = y
        v[:,2] = z
        v[:,3] = 1.0

        M = self.transform.get_transformation_matrix(device=device,dtype=dtype)
        if (M.dtype != dtype) or M.device != device:
            M = M.to(device=device,dtype=dtype)
        # Points are stored as rows, hence v @ M^T rather than M @ v.
        Mv = v@M.T
        return Mv[:,[0,1,2]]

    def get_surface_and_normal_func_with_params(self):
        """
        Constructs a callable for surface position and normal computation with parameter tracking.

        Returns:
            Tuple[Callable, List]: Callable computes (position, normal), and the list contains parameters to be optimized.
        """
        surface_and_normal,param_args = construct_surface_and_normal_func_with_params([self.transform,self.surface])
        return surface_and_normal,param_args

    def get_ray_intersect_length(self,O,D)->torch.Tensor:
        """
        Computes intersection length along ray until hitting the surface.

        Args:
            O (torch.Tensor): Ray origins of shape (N, 3).
            D (torch.Tensor): Ray directions of shape (N, 3).

        Returns:
            torch.Tensor: Intersection distances t such that O + t*D lies on the surface.
        """
        surface_and_normal,param_args = self.get_surface_and_normal_func_with_params()
        # Initial guess for the solver: distance from each ray origin to the surface's local origin.
        global_pos_approx = self.get_transform().to_global_pos(torch.zeros_like(O))
        t_init = torch.linalg.norm((global_pos_approx.detach()-O.detach()),dim=-1)
        return get_ray_intersection_length(O,D,surface_and_normal,param_args,t_init)

    def get_new_is_valid(self,O,valid)->torch.Tensor:
        """
        Combines a validity mask with the aperture test at the intersection points.

        Args:
            O (torch.Tensor): Ray intersection points (global coordinates).
            valid (torch.Tensor): Previous boolean validity mask.

        Returns:
            torch.Tensor: Boolean mask, True where the ray was valid and hit inside the aperture.
        """
        valid = valid.float()*is_valid_square_circle(self.transform,O,self.aperture_radius,self.is_square).float()
        return valid==1.0

    def get_transform(self)->transforms.Transform:
        """
        Returns the transformation associated with the surface.

        Returns:
            Transform: The local-to-global transformation object.
        """
        return self.transform
|
|
519
|
+
|
|
520
|
+
def get_refracted_directions(D: torch.Tensor, N: torch.Tensor, n1: torch.Tensor|float, n2: torch.Tensor|float) -> torch.Tensor:
|
|
521
|
+
r"""
|
|
522
|
+
Computes refracted ray directions using Snell's law.
|
|
523
|
+
|
|
524
|
+
At material interfaces, the transmitted direction :math:`\mathbf{D'}` is computed based on the surface normal
|
|
525
|
+
:math:`\mathbf{N} = \nabla s / \|\nabla s\|` and the incident direction :math:`\mathbf{D}`, using Snell's law (see :cite:`do`):
|
|
526
|
+
|
|
527
|
+
.. math::
|
|
528
|
+
|
|
529
|
+
\mathbf{D'} = \mathbf{N} \sqrt{1 - (1 - \cos^2 \psi_i) \eta^2} + \eta (\mathbf{D} - \mathbf{N} \cos \psi_i),
|
|
530
|
+
|
|
531
|
+
where :math:`\cos \psi_i = \mathbf{D} \cdot \mathbf{N}` and :math:`\eta = n / n'` is the ratio of the refractive indices
|
|
532
|
+
of the two materials.
|
|
533
|
+
|
|
534
|
+
Args:
|
|
535
|
+
D (torch.Tensor): Incident directions of shape (M, 3), normalized.
|
|
536
|
+
N (torch.Tensor): Surface normals at points of incidence, shape (M, 3).
|
|
537
|
+
n1 (float or torch.Tensor): Refractive index of the incident medium.
|
|
538
|
+
n2 (float or torch.Tensor): Refractive index of the transmission medium.
|
|
539
|
+
|
|
540
|
+
Returns:
|
|
541
|
+
torch.Tensor: Refracted directions of shape (M, 3).
|
|
542
|
+
|
|
543
|
+
"""
|
|
544
|
+
# Ensure the input tensors are normalized (unit vectors)
|
|
545
|
+
D = torch.nn.functional.normalize(D, dim=1)
|
|
546
|
+
N = torch.nn.functional.normalize(N, dim=1)
|
|
547
|
+
|
|
548
|
+
minus_DTN = -torch.sum(D * N, dim=1, keepdim=False)
|
|
549
|
+
|
|
550
|
+
constant_surface_dir = torch.sign(minus_DTN)#this is supposed to be positive
|
|
551
|
+
minus_DTN = minus_DTN*constant_surface_dir#Saves every thing
|
|
552
|
+
N = N*(constant_surface_dir.reshape(-1,1))
|
|
553
|
+
|
|
554
|
+
cos_theta_1 = minus_DTN #positive
|
|
555
|
+
|
|
556
|
+
n1_divi_n2 = (n1/n2)
|
|
557
|
+
sin_theta_2_squared = (n1_divi_n2**2.0)*(1.-cos_theta_1**2.)
|
|
558
|
+
cos_theta_2 = torch.sqrt(1.0-sin_theta_2_squared)
|
|
559
|
+
#N has specific sign
|
|
560
|
+
out = (n1_divi_n2.reshape(-1,1))*D+((n1_divi_n2*cos_theta_1-cos_theta_2).reshape(-1,1))*N
|
|
561
|
+
return out
|
|
562
|
+
|
|
563
|
+
class LensSurfaceTransmissionEnter(OpticalSurface):
    """
    Entry face of a lens: rays pass from the environment into the lens material.

    Args:
        transform (Transform): Local-to-global transformation of the surface.
        surface: Surface object providing the height profile.
        aperture_radius (float): Radius (or half-width for square apertures).
        n_func (Callable): Refractive index of the lens material, called with the wavelengths.
        is_square (bool, optional): Square instead of circular aperture. Defaults to False.
    """
    def __init__(self,transform:Transform,surface,aperture_radius:float,n_func,is_square:bool=False):
        super().__init__(transform,surface,aperture_radius,is_square,'#dae8fc',"#6c8ebf")
        self.n_func = n_func

    def forward(self, O1, D1, wl, n_func_enviroment,meta_data):
        """
        Refracts rays at the entry face of the lens.

        Each ray is intersected with the surface and flagged invalid if the hit lies outside
        the aperture. It is then refracted from the environment index into the lens index.
        The traversed segment is added to ``PL`` (geometric length) and to ``OPL``
        (length times environment index), and the detached hit points are appended to
        ``ray_paths``.

        Args:
            O1 (torch.Tensor): Ray origins.
            D1 (torch.Tensor): Ray directions.
            wl (torch.Tensor): Wavelengths.
            n_func_enviroment: Function returning environmental refractive index.
            meta_data (dict): Ray metadata with keys "PL", "OPL", "ray_paths" and "valid".

        Returns:
            Tuple: Hit points, refracted directions, wavelengths, environment function, and metadata.
        """
        path_len = meta_data["PL"]
        optical_path_len = meta_data["OPL"]
        history = meta_data["ray_paths"]
        still_valid = meta_data["valid"]

        surf_fn, surf_params = self.get_surface_and_normal_func_with_params()
        step = self.get_ray_intersect_length(O1,D1)
        hit = O1+step*D1
        still_valid = self.get_new_is_valid(hit,still_valid)

        _, normal = surf_fn(hit,*surf_params)

        n_outside = n_func_enviroment(wl)
        n_inside = self.n_func(wl)
        refracted = get_refracted_directions(D1, normal, n_outside, n_inside)

        # The segment O1 -> hit lies in the environment, so OPL uses the environment index.
        path_len += step.reshape(-1)
        optical_path_len += step.reshape(-1)*n_outside
        history += [hit.detach()]

        meta_data["PL"] = path_len
        meta_data["OPL"] = optical_path_len
        meta_data["ray_paths"] = history
        meta_data["valid"] = still_valid
        return hit,refracted,wl,n_func_enviroment,meta_data
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
class LensSurfaceTransmissionLeave(OpticalSurface):
    """
    Exit face of a lens: rays pass from the lens material back into the environment.

    Args:
        transform (Transform): Local-to-global transformation of the surface.
        surface: Surface object providing the height profile.
        aperture_radius (float): Radius (or half-width for square apertures).
        n_func (Callable): Refractive index of the lens material, called with the wavelengths.
        is_square (bool, optional): Square instead of circular aperture. Defaults to False.
    """
    def __init__(self,transform:Transform,surface,aperture_radius:float,n_func,is_square:bool=False):
        super().__init__(transform,surface,aperture_radius,is_square,'#dae8fc',"#6c8ebf")
        self.n_func = n_func

    def forward(self, O2, D2, wl, n_func_enviroment, meta_data):
        """
        Refracts rays at the exit face of the lens.

        Each ray is intersected with the surface and flagged invalid if the hit lies outside
        the aperture. It is then refracted from the lens index back into the environment index.
        The traversed segment is added to ``PL`` (geometric length) and to ``OPL``
        (length times lens index), and the detached hit points are appended to ``ray_paths``.

        Args:
            O2 (torch.Tensor): Ray origins.
            D2 (torch.Tensor): Ray directions.
            wl (torch.Tensor): Wavelengths.
            n_func_enviroment: Function returning environmental refractive index.
            meta_data (dict): Ray metadata with keys "PL", "OPL", "ray_paths" and "valid".

        Returns:
            Tuple: Hit points, refracted directions, wavelengths, environment function, and metadata.
        """
        path_len = meta_data["PL"]
        optical_path_len = meta_data["OPL"]
        history = meta_data["ray_paths"]
        still_valid = meta_data["valid"]

        surf_fn, surf_params = self.get_surface_and_normal_func_with_params()

        step = self.get_ray_intersect_length(O2,D2)
        hit = O2+step*D2
        still_valid = self.get_new_is_valid(hit,still_valid)

        n_outside = n_func_enviroment(wl)
        n_inside = self.n_func(wl)

        _, normal = surf_fn(hit,*surf_params)
        refracted = get_refracted_directions(D2, normal, n_inside, n_outside)

        # The segment O2 -> hit lies inside the lens, so OPL uses the lens index.
        path_len += step.reshape(-1)
        optical_path_len += step.reshape(-1)*n_inside
        history += [hit.detach()]

        meta_data["PL"] = path_len
        meta_data["OPL"] = optical_path_len
        meta_data["ray_paths"] = history
        meta_data["valid"] = still_valid
        return hit,refracted,wl,n_func_enviroment,meta_data
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
class LensSurfaceSide(PhysicalSurface,Plotable):
|
|
644
|
+
"""
|
|
645
|
+
Non-optical surface connecting two curved lens surfaces for visualization.
|
|
646
|
+
|
|
647
|
+
Used to render the full 3D body of the lens.
|
|
648
|
+
|
|
649
|
+
Attributes:
|
|
650
|
+
surface1 (PhysicalSurface): First lens surface.
|
|
651
|
+
surface2 (PhysicalSurface): Second lens surface.
|
|
652
|
+
aperture_radius (float): Radius or half-width of aperture.
|
|
653
|
+
is_square (bool): Whether aperture is square.
|
|
654
|
+
"""
|
|
655
|
+
def __init__(self,surface1:PhysicalSurface,surface2:PhysicalSurface,aperture_radius,is_square:bool):
    """
    Args:
        surface1 (PhysicalSurface): Surface reached at parametric height 0.
        surface2 (PhysicalSurface): Surface reached at parametric height 1.
        aperture_radius (float): Radius of the rim that connects both surfaces.
        is_square (bool): Whether the aperture is square (stored only).
    """
    Plotable.__init__(self,'#dae8fc','#dae8fc')
    PhysicalSurface.__init__(self)
    self.aperture_radius = aperture_radius
    self.surface1 = surface1
    self.surface2 = surface2
    self.is_square = is_square

    # Parameter space is the unit square: (fraction of a full turn, height fraction).
    self.integrator = Cube([[0.0,1.0],[0.0,1.0]])
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def parametric_sample(self, num_points:int, method:str="sobol"):
    """
    Samples parametric positions on the lens side surface.

    Args:
        num_points (int): Number of sample points.
        method (str): Sampling method ("sobol", "monte_carlo", "midpoint").

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Sampled positions and integration weights.

    Raises:
        RuntimeError: If ``method`` is not one of the supported methods.
    """
    # Membership test: the chained `!=` checks joined with `or` were always true and rejected every method.
    if method not in ("sobol", "monte_carlo", "midpoint"):
        raise RuntimeError("Only sobol,monte_carlo and midpoint supported for LensSurfaceSide parametric_sample")

    return self.integrator.sample(num_points, method)
|
|
683
|
+
|
|
684
|
+
def parametric_surface(self, parametric_pos:torch.Tensor) -> torch.Tensor:
    """
    Maps (t, h) parametric coordinates on the lens rim to 3D global coordinates.

    ``t`` is the fraction of a full turn around a circle of radius ``aperture_radius``
    in the parametric (local x, y) plane. This is always a circle, even if ``is_square``
    is set. ``h`` linearly interpolates between the point on ``surface1`` (h = 0) and
    the point on ``surface2`` (h = 1) at that rim position.

    Args:
        parametric_pos (torch.Tensor): Parametric positions of shape (N, 2), columns (t, h).

    Returns:
        torch.Tensor: 3D positions of shape (N, 3).
    """
    angle = 2 * torch.pi * parametric_pos[:,0]
    # Column vector so the (N,) heights broadcast against (N, 3) positions row by row.
    height = parametric_pos[:,1].reshape(-1,1)

    rim_x = torch.cos(angle)*self.aperture_radius
    rim_y = torch.sin(angle)*self.aperture_radius
    rim_local = torch.zeros((rim_x.shape[0],2),device=rim_x.device,dtype=rim_x.dtype)
    rim_local[:,0] = rim_x
    rim_local[:,1] = rim_y

    lower = self.surface1.parametric_surface(rim_local)
    upper = self.surface2.parametric_surface(rim_local)
    return lower+(upper-lower)*height
|
|
720
|
+
|
|
721
|
+
def get_plot_points_2D(self, resolution:int) -> List[Tuple[torch.Tensor]]:
    """
    Returns the two straight rim segments of the lens side in the z-y plane.

    Both surfaces are probed at local (x, y) = (0, -aperture_radius) and
    (0, +aperture_radius). Each returned segment joins the surface1 point to
    the surface2 point for one of those two probes.

    Args:
        resolution (int): Unused; each segment consists of its two end points.

    Returns:
        List[Tuple[torch.Tensor]]: ``[(z, y), (z, y)]``. The first tuple comes
        from the -aperture_radius probe and the second from the +aperture_radius
        probe. Each tensor holds [surface1 value, surface2 value].
    """
    radius = self.aperture_radius

    def edge_points(surface):
        # Probe points on the local y-axis (x = 0) at both aperture extremes.
        probe_y = torch.tensor([-radius, radius])
        probe = torch.zeros((2, 2))
        probe[:, 0] = torch.zeros_like(probe_y)
        probe[:, 1] = probe_y
        with torch.no_grad():
            pts = surface.parametric_surface(probe)
        return pts[:, 2], pts[:, 1]

    z_low, y_low = edge_points(self.surface1)
    z_high, y_high = edge_points(self.surface2)

    segments = []
    for idx in (0, 1):
        seg_z = torch.zeros((2))
        seg_y = torch.zeros((2))
        seg_z[0], seg_z[1] = z_low[idx], z_high[idx]
        seg_y[0], seg_y[1] = y_low[idx], y_high[idx]
        segments.append((seg_z, seg_y))
    return segments
|
|
762
|
+
|
|
763
|
+
def get_plot_points_3D(self, resolution:int) -> List[Tuple[torch.Tensor]]:
    """
    Returns strips of points tracing the lens side for 3D visualization.

    Round aperture: a single strip following the aperture circle. Square
    aperture: four strips along the edges x = -r, y = -r, x = +r, y = +r
    (in that order). Each strip is a tuple of (x, y, z) tensors of shape
    (M, 2). Column 0 is evaluated on ``surface1`` and column 1 on ``surface2``.
    All returned tensors are detached and moved to the CPU.

    Args:
        resolution (int): Points per edge for a square aperture. A round
            aperture uses ``4 * resolution`` points around the circle.

    Returns:
        List[Tuple[torch.Tensor]]: List of (x, y, z) strips.
    """
    radius = self.aperture_radius

    def strip(xs, ys):
        # Evaluate both lens surfaces along the same boundary curve.
        probe = torch.zeros((ys.shape[0], 2))
        probe[:, 0] = xs
        probe[:, 1] = ys
        bottom = self.surface1.parametric_surface(probe)
        top = self.surface2.parametric_surface(probe)
        coords = []
        for axis in range(3):
            pair = torch.zeros((xs.shape[0], 2))
            pair[:, 0] = bottom[:, axis]
            pair[:, 1] = top[:, axis]
            coords.append(pair.detach().cpu())
        return tuple(coords)

    with torch.no_grad():
        if not self.is_square:
            theta = 2 * torch.pi * torch.linspace(0., 1.0, resolution * 4)
            return [strip(torch.cos(theta) * radius, torch.sin(theta) * radius)]

        strips = []
        for edge in (-radius, radius):
            along = torch.linspace(-radius, radius, resolution)
            fixed = torch.ones_like(along) * edge
            strips.append(strip(fixed, along))
            strips.append(strip(along, fixed))
        return strips
|
|
818
|
+
|
|
819
|
+
def get_plotly_color_scale(self):
    """
    Returns the plotly color scale for the side surface.

    A square aperture is drawn as four strips by get_plot_points_3D, so in that
    case the base Plotable color scale is concatenated four times.

    Returns:
        List: Color scale values.
    """
    if not self.is_square:
        return Plotable.get_plotly_color_scale(self)

    combined = []
    for _ in range(4):
        combined.extend(Plotable.get_plotly_color_scale(self))
    return combined
|
|
833
|
+
|
|
834
|
+
class Lens(OpticalElement):
    r"""
    Represents a transmissive lens consisting of two refractive surfaces.

    The lens is modeled as a sequence of:
      - Entry surface (refraction from external medium into the lens)
      - Exit surface (refraction from lens into external medium)
      - Side surface (purely for visualization)

    In our implementation, lenses consist of two *explicit surfaces*, a transformation matrix :math:`M`, a lens thickness,
    an *aperture radius*, and a material. When the lens is initialized, one can also optionally specify whether the lens
    is round or square. If the keyword **is_square** is not specified, the lens will default to being round.

    Example:
        Below is an example of initializing a square lens:

        >>> import diffinytrace as dit
        >>> aperture_half = 30.
        >>> lens_thickness = 8.
        >>> material = dit.materials["NBK7"]
        >>> transform = dit.transforms.Identity()
        >>> bspline = dit.Bspline(aperture_half, [3, 3], [8, 8])
        >>> plane = dit.Plane()
        >>> lens = dit.Lens(transform, lens_thickness,
        ...                 bspline, plane,
        ...                 material, aperture_half, is_square=True)

    Attributes:
        n_func (Callable): Function mapping wavelength to refractive index of the lens material.
        _transform1 (Transform): Transform for the first surface.
        _transform2 (Transform): Transform for the second surface, a Distance of
            ``lens_thickness`` relative to ``_transform1``.
        lens_thickness (torch.nn.Parameter): Learnable thickness of the lens.
        surface1 (LensSurfaceTransmissionEnter): Entry surface.
        surface2 (LensSurfaceTransmissionLeave): Exit surface.
        lens_surface_side (LensSurfaceSide): Side surface (for rendering).
        aperture_radius (float): Radius (or half-width) of aperture.
        is_square (bool): Whether the aperture is square.
    """
    def __init__(self,transform:Transform,lens_thickness:float,surface1,surface2,n_func,aperture_radius:float,is_square=False):
        OpticalElement.__init__(self,'#dae8fc',"#6c8ebf",True)

        self.n_func = n_func
        self._transform1 = transform
        self.lens_thickness = make_parameter_from_input(lens_thickness)
        # The exit surface is placed lens_thickness away from the entry surface.
        # Its distance gets bounds [0, inf), i.e. a non-negative thickness.
        self._transform2 = transforms.Distance(self.lens_thickness,parent_transform=self._transform1)
        self._transform2.distance.bounds=torch.tensor([0.,torch.inf])

        self.surface1 = LensSurfaceTransmissionEnter(self._transform1,surface1,aperture_radius,n_func,is_square)
        self.surface2 = LensSurfaceTransmissionLeave(self._transform2,surface2,aperture_radius,n_func,is_square)
        self.lens_surface_side = LensSurfaceSide(self.surface1,self.surface2,aperture_radius,is_square)

        self.aperture_radius = aperture_radius
        self.is_square = is_square

    def get_plot_points_2D(self, resolution:int) -> List[Tuple[torch.Tensor]]:
        """
        Returns the lens outline in the z-y plane as four segments.

        Order: entry-surface curve, second rim segment (reversed), exit-surface
        curve (reversed), first rim segment.
        NOTE(review): the reversals presumably make consecutive segments join
        into one continuous outline; confirm against the surfaces' point order.

        Args:
            resolution (int): Number of sample points.

        Returns:
            List[Tuple[torch.Tensor]]: List of (z, y) coordinate tuples.
        """
        def inverse_points(points):
            # Reverse both coordinate sequences. torch.flip also works for GPU
            # tensors and tensors requiring grad, where a NumPy round trip fails.
            z, y = points
            return (torch.as_tensor(z).flip(0), torch.as_tensor(y).flip(0))

        psurface1 = self.surface1.get_plot_points_2D(resolution)
        psurface2 = self.surface2.get_plot_points_2D(resolution)
        psurfaceCy = self.lens_surface_side.get_plot_points_2D(resolution)

        return [
            psurface1[0],
            inverse_points(psurfaceCy[1]),
            inverse_points(psurface2[0]),
            psurfaceCy[0],
        ]

    def get_plot_points_3D(self, resolution:int) -> List[Tuple[torch.Tensor]]:
        """
        Returns 3D grid of lens surface points for visualization.

        Args:
            resolution (int): Grid resolution.

        Returns:
            List[Tuple[torch.Tensor]]: Entry surface, exit surface and side
            surface point sets, concatenated in that order.
        """
        out = []
        out += self.surface1.get_plot_points_3D(resolution)
        out += self.surface2.get_plot_points_3D(resolution)
        out += self.lens_surface_side.get_plot_points_3D(resolution)
        return out

    def get_plotly_color_scale(self) -> List:
        """
        Returns color scale for plotly visualization.

        Returns:
            List: Color scales of entry, exit and side surface, in the same order
            as get_plot_points_3D.
        """
        out = []
        out += self.surface1.get_plotly_color_scale()
        out += self.surface2.get_plotly_color_scale()
        out += self.lens_surface_side.get_plotly_color_scale()
        return out

    def get_plotable_childs(self) -> List:
        """
        Returns plotable child elements.

        Returns:
            List: Always empty; the lens plots its surfaces itself.
        """
        return []

    def forward(self,O1:torch.Tensor,D1:torch.Tensor,wl:torch.Tensor,n_func_enviroment,meta_data):
        """
        Simulates light passing through the lens (entry surface, then exit surface).

        Args:
            O1 (torch.Tensor): Ray origin positions.
            D1 (torch.Tensor): Ray directions.
            wl (torch.Tensor): Wavelengths.
            n_func_enviroment (Callable): Function returning external medium refractive index.
            meta_data (dict): Ray metadata (PL, OPL, paths, valid).

        Returns:
            Tuple: Updated ray origins, directions, wavelengths, environment function and metadata.
        """
        out = self.surface1(O1,D1,wl,n_func_enviroment,meta_data)
        return self.surface2(*out)

    def get_transform(self):
        """
        Returns the transformation of the lens exit surface.

        Returns:
            Transform: The transformation object.
        """
        return self.surface2.transform
|
|
990
|
+
|
|
991
|
+
|
|
992
|
+
def compute_reflected_directions(D:torch.Tensor, N:torch.Tensor) -> torch.Tensor:
    r"""
    Computes reflected ray directions with the law of reflection,
    :math:`R = D - 2 (D \cdot N) N`.

    Both inputs are re-normalized row-wise first, so unnormalized directions
    and normals are accepted.

    Args:
        D (torch.Tensor): Incident directions of shape (M, 3).
        N (torch.Tensor): Surface normals at the points of incidence, shape (M, 3).

    Returns:
        torch.Tensor: Reflected directions of shape (M, 3).
    """
    d_unit = torch.nn.functional.normalize(D, dim=1)
    n_unit = torch.nn.functional.normalize(N, dim=1)

    # Signed projection of the incident direction onto the normal, kept (M, 1) for broadcasting.
    projection = (d_unit * n_unit).sum(dim=1, keepdim=True)
    return d_unit - 2 * projection * n_unit
|
|
1013
|
+
|
|
1014
|
+
class Mirror(OpticalSurface):
    """
    Reflective optical surface that redirects rays by the law of reflection.

    Rendered with fill color '#fff2cc' and line color '#d6b656' (a gold tone).

    Args:
        transform: Placement of the mirror surface.
        surface: Explicit surface shape.
        aperture_radius: Radius or half-width of the aperture.
        is_square (bool): Whether the aperture is square (default False).
    """
    def __init__(self,transform,surface,aperture_radius,is_square=False):
        fill_color, line_color = '#fff2cc', '#d6b656'
        super().__init__(transform, surface, aperture_radius, is_square, fill_color, line_color)

    def forward(self,O1,D1,wl,n_func_enviroment,meta_data):
        """
        Intersects rays with the mirror and reflects them.

        Args:
            O1 (torch.Tensor): Ray origins.
            D1 (torch.Tensor): Ray directions.
            wl (torch.Tensor): Wavelengths.
            n_func_enviroment: Function returning environmental refractive index.
            meta_data (dict): Ray metadata with keys "PL", "OPL", "ray_paths", "valid".

        Returns:
            Tuple: Hit points, reflected directions, wavelengths, environment
            function and updated metadata.
        """
        path_len = meta_data["PL"]
        optical_len = meta_data["OPL"]
        paths = meta_data["ray_paths"]
        is_valid = meta_data["valid"]

        surf_and_normal, surf_params = self.get_surface_and_normal_func_with_params()

        # Advance each ray to its intersection with the mirror.
        step = self.get_ray_intersect_length(O1, D1)
        hit = O1 + step * D1
        is_valid = self.get_new_is_valid(hit, is_valid)

        _, normals = surf_and_normal(hit, *surf_params)
        reflected = compute_reflected_directions(D1, normals)

        # The traversed segment lies in the surrounding medium.
        n_medium = n_func_enviroment(wl)
        step_flat = step.reshape(-1)
        path_len += step_flat
        optical_len += step_flat * n_medium
        paths += [hit.detach()]

        meta_data["PL"] = path_len
        meta_data["OPL"] = optical_len
        meta_data["ray_paths"] = paths
        meta_data["valid"] = is_valid
        return hit, reflected, wl, n_func_enviroment, meta_data
|
|
1060
|
+
|
|
1061
|
+
|
|
1062
|
+
class Detector(OpticalSurface):
    r"""
    Represents a terminal optical element that collects ray data.

    Detectors consist of an *explicit surface*, a transformation matrix :math:`M`, and an *aperture radius*.
    The detector class represents a target surface used to track the rays that hit it. When the detector is initialized,
    one can also optionally specify whether the detector is round or square. If the keyword **is_square** is not specified,
    the detector defaults to being square.

    Example:
        Below is an example of how to initialize a detector:

        >>> import diffinytrace as dit
        >>> aperture_half = 30.
        >>> transform = dit.transforms.Identity()
        >>> plane = dit.Plane()
        >>> detector = dit.Detector(transform, plane,
        ...                         aperture_half, is_square=False)
    """
    def __init__(self,transform,surface,aperture_radius,is_square=True):
        super().__init__(transform,surface,aperture_radius,is_square,'#d5e8d4','#82b366')

    def forward(self,O1,D1,wl,n_func_enviroment,meta_data):
        r"""
        Moves rays onto the detector surface without altering their direction.

        Args:
            O1 (torch.Tensor): Ray origin.
            D1 (torch.Tensor): Ray direction.
            wl (torch.Tensor): Wavelength.
            n_func_enviroment (Callable): Function for surrounding medium.
            meta_data (dict): Ray tracing metadata with keys "PL", "OPL", "ray_paths", "valid".

        Returns:
            Tuple: Hit points, unchanged directions, wavelengths, environment
            function and updated metadata.
        """
        PL, OPL, ray_paths, valid = meta_data["PL"],meta_data["OPL"],meta_data["ray_paths"],meta_data["valid"]

        t1 = self.get_ray_intersect_length(O1,D1)
        O2 = O1+t1*D1
        valid = self.get_new_is_valid(O2,valid)
        # The detector only records hits; directions pass through unchanged.
        D2 = D1

        # The final segment runs through the surrounding medium.
        n_enviroment = n_func_enviroment(wl)
        PL+=t1.reshape(-1)
        OPL+=t1.reshape(-1)*n_enviroment

        ray_paths+=[O2.detach()]
        meta_data["PL"],meta_data["OPL"],meta_data["ray_paths"], meta_data["valid"] = PL, OPL, ray_paths, valid
        return O2,D2,wl,n_func_enviroment,meta_data
|
|
1114
|
+
|
|
1115
|
+
|
|
1116
|
+
|
|
1117
|
+
def trace_to_detector(optical_system:SequentialOpticalSystem,
|
|
1118
|
+
sequence:List,
|
|
1119
|
+
source,
|
|
1120
|
+
detector:Detector,
|
|
1121
|
+
num_rays:int,
|
|
1122
|
+
device=torch.get_default_device(),
|
|
1123
|
+
method_ray_tracing:str="sobol_pow2"):
|
|
1124
|
+
r"""
|
|
1125
|
+
Traces rays through a system to a detector and returns the impact coordinates.
|
|
1126
|
+
|
|
1127
|
+
Args:
|
|
1128
|
+
optical_system (SequentialOpticalSystem): Ray-tracing pipeline.
|
|
1129
|
+
sequence (list[str]): Ordered names of system modules.
|
|
1130
|
+
source: Source object with `.sample()` method.
|
|
1131
|
+
detector (Detector): Final surface to collect rays.
|
|
1132
|
+
num_rays (int): Number of rays to simulate.
|
|
1133
|
+
device: Torch device (CPU/GPU).
|
|
1134
|
+
method_ray_tracing (str): Sampling method for source rays.
|
|
1135
|
+
|
|
1136
|
+
Returns:
|
|
1137
|
+
Tuple[torch.Tensor]: (input samples, weights, detector plane hits, wavelengths)
|
|
1138
|
+
"""
|
|
1139
|
+
def g_mapping(x):
|
|
1140
|
+
O,D,wl,_,_ = optical_system(x,sequence)
|
|
1141
|
+
O_local = detector.to_local_pos(O)
|
|
1142
|
+
return O_local[:,[0,1]],O,wl
|
|
1143
|
+
x,weights = source.sample(num_rays,method_ray_tracing)
|
|
1144
|
+
x = x.to(device)
|
|
1145
|
+
weights = weights.to(device)
|
|
1146
|
+
y,O,wl = g_mapping(x)
|
|
1147
|
+
return x,weights,y,wl
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
|
|
1151
|
+
def set_unused_params_to_zero(optical_system:SequentialOpticalSystem,
|
|
1152
|
+
sequence,
|
|
1153
|
+
source,
|
|
1154
|
+
params,
|
|
1155
|
+
num_rays=200000,
|
|
1156
|
+
method_ray_tracing="sobol"):
|
|
1157
|
+
"""
|
|
1158
|
+
Sets unused parameters (those with zero gradient across ray paths) to zero.
|
|
1159
|
+
|
|
1160
|
+
Args:
|
|
1161
|
+
optical_system (SequentialOpticalSystem): Full system.
|
|
1162
|
+
sequence (list): Ordered module names.
|
|
1163
|
+
source: Ray source.
|
|
1164
|
+
params (list[torch.nn.Parameter] or torch.nn.Parameter): Parameters to clean.
|
|
1165
|
+
num_rays (int): Ray sample count.
|
|
1166
|
+
method_ray_tracing (str): Sampling method.
|
|
1167
|
+
"""
|
|
1168
|
+
if isinstance(params,nn.Parameter):
|
|
1169
|
+
params = [params]
|
|
1170
|
+
params = [param for param in params]
|
|
1171
|
+
device = params[0].device
|
|
1172
|
+
dtype = params[0].dtype
|
|
1173
|
+
|
|
1174
|
+
x,weights = source.sample(num_rays,method_ray_tracing)
|
|
1175
|
+
|
|
1176
|
+
x = x.to(device=device,dtype=dtype)
|
|
1177
|
+
weights = weights.to(device=device,dtype=dtype)
|
|
1178
|
+
O,D,wl,_,_ = optical_system(x,sequence)
|
|
1179
|
+
dOdp = grad(O,params,torch.randn_like(O),create_graph=False,materialize_grads=True)
|
|
1180
|
+
dDdp = grad(D,params,torch.randn_like(D),create_graph=False,materialize_grads=True)
|
|
1181
|
+
dwldp = grad(wl,params,torch.randn_like(wl),create_graph=False,materialize_grads=True)
|
|
1182
|
+
|
|
1183
|
+
for k in range(len(params)):
|
|
1184
|
+
with torch.no_grad():
|
|
1185
|
+
dp_zero = (dOdp[k]==0.0).float()*(dDdp[k]==0.0).float()*(dwldp[k]==0.0).float()
|
|
1186
|
+
dp_zero = dp_zero==1.0
|
|
1187
|
+
param = params[k]
|
|
1188
|
+
param.data[dp_zero] = 0.0
|
|
1189
|
+
|
|
1190
|
+
|
|
1191
|
+
def get_unused_params_mask(optical_system:SequentialOpticalSystem,
|
|
1192
|
+
sequence:List[str],
|
|
1193
|
+
source,
|
|
1194
|
+
params,
|
|
1195
|
+
num_rays:int=100000,
|
|
1196
|
+
method_ray_tracing="sobol")->List[torch.BoolTensor]:
|
|
1197
|
+
"""
|
|
1198
|
+
Returns a boolean mask identifying which parameters are unused in the ray tracing process.
|
|
1199
|
+
|
|
1200
|
+
Args:
|
|
1201
|
+
optical_system (SequentialOpticalSystem): Full system.
|
|
1202
|
+
sequence (list): Ordered module names.
|
|
1203
|
+
source: Ray source.
|
|
1204
|
+
params (list[torch.nn.Parameter]): Parameter list.
|
|
1205
|
+
num_rays (int): Number of rays to test.
|
|
1206
|
+
method_ray_tracing (str): Sampling method.
|
|
1207
|
+
|
|
1208
|
+
Returns:
|
|
1209
|
+
list[torch.BoolTensor]: Masks of the same shape as each parameter.
|
|
1210
|
+
"""
|
|
1211
|
+
if isinstance(params,nn.Parameter):
|
|
1212
|
+
params = [params]
|
|
1213
|
+
params = [param for param in params]
|
|
1214
|
+
device= params[0].device
|
|
1215
|
+
x,weights = source.sample(num_rays,method_ray_tracing)
|
|
1216
|
+
|
|
1217
|
+
x = x.to(device)
|
|
1218
|
+
weights = weights.to(device)
|
|
1219
|
+
O,D,wl,_,_ = optical_system(x,sequence)
|
|
1220
|
+
dOdp = grad(O,params,torch.randn_like(O),create_graph=False,materialize_grads=True)
|
|
1221
|
+
dDdp = grad(D,params,torch.randn_like(D),create_graph=False,materialize_grads=True)
|
|
1222
|
+
dwldp = grad(wl,params,torch.randn_like(wl),create_graph=False,materialize_grads=True)
|
|
1223
|
+
|
|
1224
|
+
#print("len(params)",len(params))
|
|
1225
|
+
out = []
|
|
1226
|
+
for k in range(len(params)):
|
|
1227
|
+
with torch.no_grad():
|
|
1228
|
+
dp_zero = (dOdp[k]==0.0).float()*(dDdp[k]==0.0).float()*(dwldp[k]==0.0).float()
|
|
1229
|
+
dp_zero = dp_zero==1.0
|
|
1230
|
+
#print(dp_zero)
|
|
1231
|
+
out+=[dp_zero]
|
|
1232
|
+
return out
|
|
1233
|
+
|
|
1234
|
+
|
|
1235
|
+
def set_used_params_bounds_to_constant(optical_system,sequence,source,params,bounds_attr_name_new,bounds_attr_name_old="bounds",num_rays=100000,method_ray_tracing="sobol"):
    """
    Writes new bounds for ``params`` based on which entries the traced rays do not use.

    The unused-entry masks come from ``get_unused_params_mask``. They are passed,
    together with the attribute names, to ``set_bounds_from_params_mask``.
    Per the original docstring, the intent is to lock unused entries by copying
    their current value into the bounds.
    NOTE(review): the function name says "used" while the mask marks *unused*
    entries; confirm the intended semantics against set_bounds_from_params_mask.

    Args:
        optical_system: Full system.
        sequence: Ordered module names.
        source: Ray source.
        params: A parameter tensor or an iterable of parameters (generators are
            materialized so both calls below see the same parameters).
        bounds_attr_name_new (str): Name of the new bounds attribute to write.
        bounds_attr_name_old (str): Name of the original bounds attribute.
        num_rays (int): Number of rays used to detect unused entries.
        method_ray_tracing (str): Sampling method.
    """
    # A one-shot iterable (e.g. module.parameters()) would otherwise be exhausted
    # by the mask computation before set_bounds_from_params_mask sees it.
    if not isinstance(params, torch.Tensor):
        params = list(params)
    mask = get_unused_params_mask(optical_system,sequence,source,params,num_rays,method_ray_tracing)
    set_bounds_from_params_mask(params,mask,bounds_attr_name_new,bounds_attr_name_old)
|
|
1245
|
+
|
|
1246
|
+
|
|
1247
|
+
|
|
1248
|
+
class FresnelOpticalSurface(OpticalSurface):
    """
    Optical surface that can supply "virtual" normals built from derivative surfaces.

    Args:
        transform: Placement of the surface.
        surface: Explicit surface (passed on to OpticalSurface).
        aperture_radius: Radius or half-width of the aperture.
        surface_derivative_x: Object combined with ``transform`` (like a
            surface) to produce the x-component of the virtual normal.
        surface_derivative_y: Same, for the y-component.
        is_square (bool): Whether the aperture is square (default False).
    """
    def __init__(self,transform,surface,aperture_radius,surface_derivative_x,surface_derivative_y,is_square=False):
        super().__init__(transform, surface, aperture_radius, is_square, '#dae8fc', "#6c8ebf")
        self.surface_derivative_x = surface_derivative_x
        self.surface_derivative_y = surface_derivative_y

    def get_virtual_normals(self,O):
        """
        Computes un-normalized virtual surface normals ``(dx, dy, 1)`` at ``O``.

        ``dx`` and ``dy`` are the first outputs of the surface-and-normal
        functions built from ``self.transform`` with the x and y derivative
        surfaces respectively.
        NOTE(review): the slopes are used without a sign flip; confirm this
        matches the normal convention expected by the refraction routine.

        Args:
            O (torch.Tensor): Positions.

        Returns:
            torch.Tensor: Virtual normals with one row per position, 3 columns.
        """
        derivative_surfaces = (self.surface_derivative_x, self.surface_derivative_y)
        builders = [construct_surface_and_normal_func_with_params([self.transform, deriv])
                    for deriv in derivative_surfaces]
        columns = []
        for func, func_params in builders:
            value, _ = func(O, *func_params)
            columns.append(value.reshape(-1, 1))
        columns.append(torch.ones_like(columns[0]))
        return torch.cat(columns, dim=1)
|
|
1273
|
+
|
|
1274
|
+
|
|
1275
|
+
class FresnelVirtualLensSurfaceTransmissionEnter(FresnelOpticalSurface):
    """
    Fresnel entry surface: refracts from the surrounding medium into the lens
    material using virtual normals.

    Args:
        n_func: Function mapping wavelength to the lens material refractive index.
        Other arguments are forwarded to FresnelOpticalSurface.
    """
    def __init__(self,transform,surface,aperture_radius,n_func,surface_derivative_x,surface_derivative_y,is_square=False):
        super().__init__(transform, surface, aperture_radius,
                         surface_derivative_x, surface_derivative_y, is_square)
        self.n_func = n_func

    def forward(self, O1, D1, wl, n_func_enviroment,meta_data):
        """
        Propagates rays to the Fresnel lens entry surface and refracts them inward.

        Args:
            O1 (torch.Tensor): Ray origins.
            D1 (torch.Tensor): Ray directions.
            wl (torch.Tensor): Wavelengths.
            n_func_enviroment: Function returning environmental refractive index.
            meta_data (dict): Ray metadata with keys "PL", "OPL", "ray_paths", "valid".

        Returns:
            Tuple: Hit points, refracted directions, wavelengths, environment
            function and updated metadata.
        """
        path_len = meta_data["PL"]
        optical_len = meta_data["OPL"]
        paths = meta_data["ray_paths"]
        is_valid = meta_data["valid"]

        step = self.get_ray_intersect_length(O1, D1)
        hit = O1 + step * D1
        is_valid = self.get_new_is_valid(hit, is_valid)

        normals = self.get_virtual_normals(hit)

        n_outside = n_func_enviroment(wl)
        n_inside = self.n_func(wl)
        refracted = get_refracted_directions(D1, normals, n_outside, n_inside)

        # The segment up to the entry surface runs through the surrounding medium.
        step_flat = step.reshape(-1)
        path_len += step_flat
        optical_len += step_flat * n_outside
        paths += [hit.detach()]

        meta_data["PL"] = path_len
        meta_data["OPL"] = optical_len
        meta_data["ray_paths"] = paths
        meta_data["valid"] = is_valid
        return hit, refracted, wl, n_func_enviroment, meta_data
|
|
1311
|
+
|
|
1312
|
+
class FresnelVirtualLensSurfaceTransmissionLeave(FresnelOpticalSurface):
    """
    Fresnel exit surface: refracts from the lens material back into the
    surrounding medium using virtual normals.

    Args:
        n_func: Function mapping wavelength to the lens material refractive index.
        Other arguments are forwarded to FresnelOpticalSurface.
    """
    def __init__(self,transform,surface,aperture_radius,n_func,surface_derivative_x,surface_derivative_y,is_square=False):
        super().__init__(transform, surface, aperture_radius,
                         surface_derivative_x, surface_derivative_y, is_square)
        self.n_func = n_func

    def forward(self, O2, D2, wl, n_func_enviroment, meta_data):
        """
        Propagates rays to the Fresnel lens exit surface and refracts them outward.

        Args:
            O2 (torch.Tensor): Ray origins (inside the lens).
            D2 (torch.Tensor): Ray directions.
            wl (torch.Tensor): Wavelengths.
            n_func_enviroment: Function returning environmental refractive index.
            meta_data (dict): Ray metadata with keys "PL", "OPL", "ray_paths", "valid".

        Returns:
            Tuple: Hit points, refracted directions, wavelengths, environment
            function and updated metadata.
        """
        path_len = meta_data["PL"]
        optical_len = meta_data["OPL"]
        paths = meta_data["ray_paths"]
        is_valid = meta_data["valid"]

        step = self.get_ray_intersect_length(O2, D2)
        hit = O2 + step * D2
        is_valid = self.get_new_is_valid(hit, is_valid)

        n_outside = n_func_enviroment(wl)
        n_inside = self.n_func(wl)

        normals = self.get_virtual_normals(hit)
        refracted = get_refracted_directions(D2, normals, n_inside, n_outside)

        # This segment runs through the lens material, so OPL uses n_inside.
        step_flat = step.reshape(-1)
        path_len += step_flat
        optical_len += step_flat * n_inside
        paths += [hit.detach()]

        meta_data["PL"] = path_len
        meta_data["OPL"] = optical_len
        meta_data["ray_paths"] = paths
        meta_data["valid"] = is_valid
        return hit, refracted, wl, n_func_enviroment, meta_data
|
|
1350
|
+
|
|
1351
|
+
class FresnelVirtualLens(OpticalElement):
|
|
1352
|
+
"""
|
|
1353
|
+
"""
|
|
1354
|
+
def __init__(self,transform,lens_thickness,surface1,surface2,n_func,aperture_radius,surface1_derivative_x=None,surface1_derivative_y=None,surface2_derivative_x=None,surface2_derivative_y=None,is_square=False):
|
|
1355
|
+
OpticalElement.__init__(self,'#dae8fc',"#6c8ebf",True)
|
|
1356
|
+
|
|
1357
|
+
self.n_func = n_func
|
|
1358
|
+
self._transform1 = transform
|
|
1359
|
+
self.lens_thickness = make_parameter_from_input(lens_thickness)
|
|
1360
|
+
self._transform2 = transforms.Distance(self.lens_thickness,parent_transform=self._transform1)
|
|
1361
|
+
self._transform2.distance.bounds=torch.tensor([0.,torch.inf])
|
|
1362
|
+
|
|
1363
|
+
if (not surface1_derivative_x is None) or (not surface1_derivative_y is None):
|
|
1364
|
+
if (surface1_derivative_x is None or surface1_derivative_y is None):
|
|
1365
|
+
raise RuntimeError("if surface1_derivative_x is defined also surface1_derivative_y must be defined and the other way around!")
|
|
1366
|
+
self.surface1 = FresnelVirtualLensSurfaceTransmissionEnter(self._transform2,surface2,aperture_radius,n_func,surface1_derivative_x,surface1_derivative_y,is_square)
|
|
1367
|
+
else:
|
|
1368
|
+
self.surface1 = LensSurfaceTransmissionEnter(self._transform1,surface1,aperture_radius,n_func,is_square)
|
|
1369
|
+
|
|
1370
|
+
if (not surface2_derivative_x is None) or (not surface2_derivative_y is None):
|
|
1371
|
+
if (surface2_derivative_x is None or surface2_derivative_y is None):
|
|
1372
|
+
raise RuntimeError("if surface2_derivative_x is defined also surface2_derivative_y must be defined and the other way around!")
|
|
1373
|
+
self.surface2 = FresnelVirtualLensSurfaceTransmissionLeave(self._transform2,surface2,aperture_radius,n_func,surface2_derivative_x,surface2_derivative_y,is_square)
|
|
1374
|
+
|
|
1375
|
+
else:
|
|
1376
|
+
self.surface2 = LensSurfaceTransmissionLeave(self._transform2,surface2,aperture_radius,n_func,is_square)
|
|
1377
|
+
self.lens_surface_side = LensSurfaceSide(self.surface1,self.surface2,aperture_radius,is_square)
|
|
1378
|
+
|
|
1379
|
+
self.aperture_radius = aperture_radius
|
|
1380
|
+
self.is_square = is_square
|
|
1381
|
+
|
|
1382
|
+
def get_plot_points_2D(self,resolution:int):
|
|
1383
|
+
def inverse_points(input):
|
|
1384
|
+
z,y = input
|
|
1385
|
+
z = torch.tensor(np.array(np.array(z)[::-1]))
|
|
1386
|
+
y = torch.tensor(np.array(np.array(y)[::-1]))
|
|
1387
|
+
return (z,y)
|
|
1388
|
+
|
|
1389
|
+
psurface1 = self.surface1.get_plot_points_2D(resolution)
|
|
1390
|
+
psurface2 = self.surface2.get_plot_points_2D(resolution)
|
|
1391
|
+
psurfaceCy = self.lens_surface_side.get_plot_points_2D(resolution)
|
|
1392
|
+
|
|
1393
|
+
#return psurface1+psurface2+psurfaceCy
|
|
1394
|
+
|
|
1395
|
+
out = [None for k in range(4)]
|
|
1396
|
+
out[0] = psurface1[0]
|
|
1397
|
+
out[1] = inverse_points(psurfaceCy[1])
|
|
1398
|
+
out[2] = inverse_points(psurface2[0])
|
|
1399
|
+
out[3] = psurfaceCy[0]
|
|
1400
|
+
|
|
1401
|
+
"""
|
|
1402
|
+
out = []
|
|
1403
|
+
out += self.surface1.get_plot_points_2D(resolution)
|
|
1404
|
+
out += self.surface2.get_plot_points_2D(resolution)
|
|
1405
|
+
out += self.cylinder_surface.get_plot_points_2D(resolution)
|
|
1406
|
+
|
|
1407
|
+
return out
|
|
1408
|
+
|
|
1409
|
+
"""
|
|
1410
|
+
return out
|
|
1411
|
+
|
|
1412
|
+
def get_plot_points_3D(self,resolution):
|
|
1413
|
+
"""
|
|
1414
|
+
Returns 3D grid of Fresnel lens surface points for visualization.
|
|
1415
|
+
|
|
1416
|
+
Args:
|
|
1417
|
+
resolution (int): Grid resolution.
|
|
1418
|
+
|
|
1419
|
+
Returns:
|
|
1420
|
+
List[Tuple[torch.Tensor]]: List of (x, y, z) meshgrids.
|
|
1421
|
+
"""
|
|
1422
|
+
out = []
|
|
1423
|
+
out += self.surface1.get_plot_points_3D(resolution)
|
|
1424
|
+
out += self.surface2.get_plot_points_3D(resolution)
|
|
1425
|
+
out += self.lens_surface_side.get_plot_points_3D(resolution)
|
|
1426
|
+
return out
|
|
1427
|
+
|
|
1428
|
+
def get_plotly_color_scale(self):
    """
    Returns the plotly color scale entries of the lens parts.

    Returns:
        List: Entries from surface1, then surface2, then the lens side
        surface, concatenated into one flat list.
    """
    return [
        entry
        for part in (self.surface1, self.surface2, self.lens_surface_side)
        for entry in part.get_plotly_color_scale()
    ]
|
|
1440
|
+
|
|
1441
|
+
def get_plotable_childs(self)->List:
    """
    Returns plotable child elements.

    This element reports no children. Presumably this is because its
    surfaces are already drawn via its own ``get_plot_points_2D`` and
    ``get_plot_points_3D``.

    Returns:
        List: A new, empty list on every call.
    """
    no_children = list()
    return no_children
|
|
1449
|
+
|
|
1450
|
+
def forward(self,
            O1:torch.Tensor,
            D1:torch.Tensor,
            wl:torch.Tensor,
            n_func_enviroment,
            meta_data)->torch.Tensor:
    """
    Traces rays through the lens: first surface1, then surface2.

    Whatever surface1 returns is unpacked positionally as the arguments of
    surface2.

    Args:
        O1 (torch.Tensor): Ray origin positions.
        D1 (torch.Tensor): Ray directions.
        wl (torch.Tensor): Wavelengths.
        n_func_enviroment: Function returning external medium refractive index.
        meta_data (dict): Ray metadata.

    Returns:
        Tuple: Whatever surface2 returns (updated ray origins, directions, etc.).
        NOTE(review): the annotation says torch.Tensor, but because surface1's
        result is unpacked into surface2's arguments the result is most likely
        a tuple — confirm and fix the annotation.
    """
    after_entry_surface = self.surface1(O1, D1, wl, n_func_enviroment, meta_data)
    return self.surface2(*after_entry_surface)
|
|
1471
|
+
|
|
1472
|
+
def get_transform(self) -> Transform:
    """
    Returns the transformation of the lens's exit surface (surface2).

    Returns:
        Transform: ``self.surface2.transform``, returned as-is (no copy).
    """
    exit_surface = self.surface2
    return exit_surface.transform
|
|
1480
|
+
|
|
1481
|
+
|
|
1482
|
+
|
|
1483
|
+
|
|
1484
|
+
|
|
1485
|
+
# NOTE(review): the triple-quoted string opened on the next line, up to its
# closing quotes, is a disabled draft of `smooth_optical_surface_with_unused_params`
# and `smooth_lens_with_unused_params`. It is a bare string expression, so
# neither function is defined or executed.
# The draft calls helpers not visible in this section
# (set_used_params_bounds_to_constant, remove_bounds, minimize); confirm they
# exist before re-enabling it. Issues to address if it is revived:
#   - `constraints=[]` is a shared mutable default argument;
#   - the result of the first `minimize` call is overwritten by the second;
#   - `remove_bounds` is skipped if `minimize` raises (no try/finally);
#   - `smooth_lens_with_unused_params` returns nothing.
"""
|
|
1486
|
+
def smooth_optical_surface_with_unused_params(optical_surface:OpticalSurface,\
|
|
1487
|
+
optical_system:OpticalSystem,\
|
|
1488
|
+
sequence,\
|
|
1489
|
+
source,\
|
|
1490
|
+
params,\
|
|
1491
|
+
bounds_attr_name_old="bounds",\
|
|
1492
|
+
num_rays=100000,\
|
|
1493
|
+
method_ray_tracing="sobol",\
|
|
1494
|
+
num_points_surface=[701,701],\
|
|
1495
|
+
method_surface="simpson",\
|
|
1496
|
+
constraints=[],\
|
|
1497
|
+
minimization_method=None,\
|
|
1498
|
+
tol=1e-9):
|
|
1499
|
+
if isinstance(params,nn.Parameter):
|
|
1500
|
+
params = [params]
|
|
1501
|
+
|
|
1502
|
+
params = [param for param in params]
|
|
1503
|
+
|
|
1504
|
+
device = params[0].device
|
|
1505
|
+
dtype = params[0].dtype
|
|
1506
|
+
bounds_attr_name_new = "__used_bounds"
|
|
1507
|
+
set_used_params_bounds_to_constant(optical_system,sequence,source,params,bounds_attr_name_new,bounds_attr_name_old,num_rays,method_ray_tracing)
|
|
1508
|
+
def smoothness_func1():
|
|
1509
|
+
parametric_pos,weights = optical_surface.parametric_sample(num_points_surface,method_surface)
|
|
1510
|
+
weights = weights.to(device=device,dtype=dtype)
|
|
1511
|
+
parametric_pos = parametric_pos.detach()
|
|
1512
|
+
parametric_pos = parametric_pos.to(device=device,dtype=dtype)
|
|
1513
|
+
|
|
1514
|
+
parametric_pos.requires_grad = True
|
|
1515
|
+
tmp = optical_surface.parametric_surface(parametric_pos)
|
|
1516
|
+
tmp = optical_surface.to_local_pos(tmp)
|
|
1517
|
+
dzdx, = grad(tmp[:,2],parametric_pos,torch.ones_like(tmp[:,2]),create_graph=True,retain_graph=True)
|
|
1518
|
+
|
|
1519
|
+
ddzdx1dx, = grad(dzdx[:,0],parametric_pos,torch.ones_like(dzdx[:,0]),create_graph=True,retain_graph=True)
|
|
1520
|
+
ddzdx2dx, = grad(dzdx[:,1],parametric_pos,torch.ones_like(dzdx[:,1]),create_graph=True,retain_graph=True)
|
|
1521
|
+
|
|
1522
|
+
|
|
1523
|
+
|
|
1524
|
+
smoothness = torch.sum((torch.abs(ddzdx1dx)+torch.abs(ddzdx2dx))*weights.reshape(-1,1))
|
|
1525
|
+
return smoothness
|
|
1526
|
+
|
|
1527
|
+
def smoothness_func2():
|
|
1528
|
+
parametric_pos,weights = optical_surface.parametric_sample(num_points_surface,method_surface)
|
|
1529
|
+
weights = weights.to(device=device,dtype=dtype)
|
|
1530
|
+
parametric_pos = parametric_pos.detach()
|
|
1531
|
+
parametric_pos = parametric_pos.to(device=device,dtype=dtype)
|
|
1532
|
+
|
|
1533
|
+
parametric_pos.requires_grad = True
|
|
1534
|
+
tmp = optical_surface.parametric_surface(parametric_pos)
|
|
1535
|
+
tmp = optical_surface.to_local_pos(tmp)
|
|
1536
|
+
dzdpos, = grad(tmp[:,2],parametric_pos,torch.ones_like(tmp[:,2]),create_graph=True,retain_graph=True)
|
|
1537
|
+
|
|
1538
|
+
smoothness = torch.sum(torch.abs(dzdpos)*weights.reshape(-1,1))
|
|
1539
|
+
return smoothness
|
|
1540
|
+
|
|
1541
|
+
|
|
1542
|
+
out = minimize(smoothness_func1, params, constraints, minimization_method, tol=tol,bounds_attr_name="__used_bounds")
|
|
1543
|
+
out = minimize(smoothness_func2, params, constraints, minimization_method, tol=tol,bounds_attr_name="__used_bounds")
|
|
1544
|
+
remove_bounds(params,bounds_attr_name_new)
|
|
1545
|
+
return out
|
|
1546
|
+
|
|
1547
|
+
def smooth_lens_with_unused_params(lens:Lens,\
|
|
1548
|
+
optical_system:OpticalSystem,\
|
|
1549
|
+
sequence,\
|
|
1550
|
+
source,\
|
|
1551
|
+
params,\
|
|
1552
|
+
bounds_attr_name_old="bounds",\
|
|
1553
|
+
num_rays=100000,\
|
|
1554
|
+
method_ray_tracing="sobol",\
|
|
1555
|
+
num_points_surface=[701,701],\
|
|
1556
|
+
method_surface="simpson",\
|
|
1557
|
+
constraints=[],\
|
|
1558
|
+
minimization_method=None,\
|
|
1559
|
+
tol=1e-9):
|
|
1560
|
+
def run_on_surface(optical_surface):
|
|
1561
|
+
smooth_optical_surface_with_unused_params(optical_surface,\
|
|
1562
|
+
optical_system,\
|
|
1563
|
+
sequence,\
|
|
1564
|
+
source,\
|
|
1565
|
+
params,\
|
|
1566
|
+
bounds_attr_name_old,\
|
|
1567
|
+
num_rays,\
|
|
1568
|
+
method_ray_tracing,\
|
|
1569
|
+
num_points_surface,\
|
|
1570
|
+
method_surface,\
|
|
1571
|
+
constraints,\
|
|
1572
|
+
minimization_method,\
|
|
1573
|
+
tol)
|
|
1574
|
+
run_on_surface(lens.surface1)
|
|
1575
|
+
run_on_surface(lens.surface2)
|
|
1576
|
+
|
|
1577
|
+
|
|
1578
|
+
|
|
1579
|
+
"""
|
|
1580
|
+
def _fill_unused_from_nearest_used(values, unused):
    """
    Overwrites the unused entries of a 2D grid with the value of the nearest used entry.

    Performs a multi-source, level-by-level breadth-first fill over the
    8-connected grid (edge and corner neighbours). In each round, every
    still-unset cell that touches a cell settled in an *earlier* round copies
    that neighbour's value. Each unused cell therefore ends up with the value
    of a used cell the fewest grid steps away. Edge neighbours are checked
    before diagonal ones, so they win ties.

    Args:
        values (torch.Tensor): 2D tensor, modified in place.
        unused (torch.Tensor): Boolean tensor with the same shape as ``values``;
            True marks entries that must be filled.

    Raises:
        RuntimeError: If every entry is unused, so there is nothing to copy from.
    """
    n_rows, n_cols = unused.shape
    unused_rows = unused.tolist()
    settled = [[not flag for flag in row] for row in unused_rows]
    pending = [(y, x)
               for y, row in enumerate(unused_rows)
               for x, flag in enumerate(row) if flag]
    if len(pending) == n_rows * n_cols:
        raise RuntimeError("all coeffs seem to be unused maybe try more rays?")

    # Edge neighbours first (closer, so preferred on ties), then diagonals.
    # Using all 8 keeps the grid connected. Diagonal-only moves preserve the
    # parity of (y + x) and could leave cells unreachable forever.
    offsets = ((1, 0), (-1, 0), (0, 1), (0, -1),
               (1, 1), (-1, 1), (-1, -1), (1, -1))

    while pending:
        print("number of unset coefficients: ", len(pending))
        dst_y, dst_x, src_y, src_x = [], [], [], []
        still_pending = []
        for y, x in pending:
            for dy, dx in offsets:
                ny, nx = y + dy, x + dx
                if 0 <= ny < n_rows and 0 <= nx < n_cols and settled[ny][nx]:
                    dst_y.append(y)
                    dst_x.append(x)
                    src_y.append(ny)
                    src_x.append(nx)
                    break
            else:
                still_pending.append((y, x))

        if not dst_y:
            # Cannot happen on a rectangular grid with at least one used cell.
            # Guards against an infinite loop if the neighbourhood is changed.
            raise RuntimeError("could not reach all unused coefficients from a used one")

        # Settle only after the scan, so this round reads only the previous
        # round's state. That is what makes the fill truly nearest-first.
        for y, x in zip(dst_y, dst_x):
            settled[y][x] = True
        values[dst_y, dst_x] = values[src_y, src_x]
        pending = still_pending


def set_unused_bspline_coeff_to_nearest(optical_system,
                                        sequence:list[str],
                                        source,
                                        bspline_surface,
                                        num_rays=100000,
                                        method_ray_tracing="sobol"):
    """
    Fills only the unused B-spline coefficients with the value of the nearest used coefficient.

    This function identifies B-spline coefficients that have no influence on the
    ray paths (as reported by ``get_unused_params_mask``). It overwrites only
    those, copying the value of the used coefficient the fewest grid steps away
    (edge and diagonal neighbours both count as one step). Used coefficients
    remain unchanged.

    This is useful for geometry that is simple to manufacture without
    tampering with the overall optical performance.

    Args:
        optical_system (SequentialOpticalSystem): The optical system used for tracing.
        sequence (list[str]): Ordered list of module names for ray propagation.
        source: Ray source with a `.sample()` method.
        bspline_surface: Surface object with a 2D `.coeff` tensor, which is
            modified in place.
        num_rays (int, optional): Number of rays used to detect unused coefficients. Default is 100000.
        method_ray_tracing (str, optional): Sampling method (e.g., "sobol"). Default is "sobol".

    Raises:
        RuntimeError: If all coefficients are unused — likely due to insufficient ray coverage.
    """
    coeff = bspline_surface.coeff
    mask = get_unused_params_mask(optical_system, sequence, source, [coeff],
                                  num_rays=num_rays,
                                  method_ray_tracing=method_ray_tracing)
    # One mask per entry of the params list; bring ours onto the coeff grid.
    unused = mask[0].reshape(*coeff.shape).bool()
    with torch.no_grad():
        _fill_unused_from_nearest_used(coeff.data, unused)
|
|
1659
|
+
|
|
1660
|
+
|