diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,79 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "Spectrum",
6
+ "VisibleSunlight_am15g"
7
+ ]
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ from .plotting.wavelength import PlotableWavelength
13
+ from typing import Callable,Union,List,Tuple
14
+
15
class Spectrum(nn.Module, PlotableWavelength):
    """
    A spectrum defined as a function of wavelength on a bounded interval.

    The wrapped function is evaluated on the requested wavelengths and the
    result is set to zero wherever the wavelength lies outside ``bounds``.
    """

    def __init__(self, func: Callable[[torch.Tensor], torch.Tensor], bounds: Tuple[float, float]):
        """
        Initialize the Spectrum class.

        Args:
            func (callable): Function mapping a wavelength tensor to spectrum
                values. It may return a tensor (same shape as the input or a
                0-d tensor), a ``numpy.ndarray`` or a Python/numpy number.
            bounds (tuple): ``(min, max)`` wavelength; values outside this
                closed interval are forced to zero.
        """
        nn.Module.__init__(self)
        PlotableWavelength.__init__(self, bounds, "Intensity [1]")
        self.func = func
        self.bounds = bounds

    def forward(self, wl: torch.Tensor) -> torch.Tensor:
        """
        Calculate the spectrum for given wavelengths.

        Args:
            wl (torch.Tensor or float): Wavelength(s) in μm.

        Returns:
            torch.Tensor: Spectrum values broadcast to the shape of ``wl``;
            entries whose wavelength lies outside ``self.bounds`` are zero.
            If ``func`` returns a type other than a number, ndarray or tensor,
            that value is returned unchanged.
        """
        if not torch.is_tensor(wl):
            wl = torch.tensor(wl)

        vmin, vmax = self.bounds
        out = self.func(wl)

        # Normalize numpy scalars to Python numbers so they take the constant path.
        if isinstance(out, np.generic):
            out = out.item()

        if isinstance(out, (int, float)):
            # Constant spectrum: broadcast to wl so the bounds below apply to it too.
            out = out * torch.ones_like(wl)
        elif isinstance(out, np.ndarray):
            out = torch.tensor(out, device=wl.device, dtype=wl.dtype)

        if not torch.is_tensor(out):
            return out

        if out.dim() == 0:
            # Broadcast before masking; boolean indexing of a 0-d tensor with a
            # non-scalar mask would fail.
            out = out * torch.ones_like(wl)

        # Out-of-place masking: never mutates a tensor owned by `func` (which may
        # even be `wl` itself) and keeps the result differentiable.
        outside = (wl < vmin) | (wl > vmax)
        return torch.where(outside, torch.zeros_like(out), out)
62
+
63
class VisibleSunlight_am15g(Spectrum):
    """
    AM 1.5 G reference solar spectrum, restricted to visible wavelengths.

    Spectral values come from :func:`pvlib.spectrum.get_am15g`, which expects
    wavelengths in nm; this library uses μm, hence the factor 1000 below.
    The spectrum is limited to the interval 0.360 μm to 0.780 μm.
    """

    def __init__(self):
        # Imported lazily so pvlib is only required when this spectrum is used.
        from pvlib.spectrum import get_am15g

        def func(wl):
            device = wl.device
            dtype = wl.dtype
            shape = wl.shape
            # get_am15g wraps its result in a pandas Series indexed by the
            # wavelengths, which accepts only 1-D input. Flatten here and restore
            # the shape below so scalar and multi-dimensional tensors also work.
            wl_nm = wl.detach().cpu().numpy().reshape(-1) * 1000.
            out = np.asarray(get_am15g(wl_nm))
            return torch.tensor(out, device=device, dtype=dtype).reshape(shape)

        super().__init__(func, [0.360, 0.780])
@@ -0,0 +1,468 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "Surface",
6
+ "Plane",
7
+ "Aspheric",
8
+ "Bspline",
9
+ "Legendre",
10
+ # "Zernike" # Uncomment if you implement Zernike
11
+ ]
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ #from diffinytrace.basis_functions import *
17
+ from .transforms import SemiFunctionalModule,Transform
18
+ from .optimize import make_parameter_from_input
19
+ from . import basis_functions
20
+ from typing import Tuple,List,Callable,Optional,Union
21
+
22
+
23
class Surface(SemiFunctionalModule):
    r"""
    Base class of all optical surfaces.

    While we all have an intuitive idea of what curves and surfaces are, we need a
    mathematically accurate definition from which we can proceed to illustrate how
    different types of algorithms are implemented. In the following, we introduce
    three common ways of describing curves and surfaces.

    1. **Parametric Equations**

       Parametric curves are functions that map a single variable :math:`\theta`
       (the parameter) to a vector in :math:`\mathbb{R}^2`. Such curves are
       therefore referred to as parametrized or parametrically defined curves
       (see :cite:`implicit_surfaces`). The variable :math:`\theta` is an element
       of the *parametric domain* of the parametric curve (see :cite:`IGA`). For
       example, a circle can be described with the *parametric domain*
       :math:`[0, 2\pi]` and the function :math:`f: [0, 2\pi] \to \mathbb{R}^2`,

       .. math::

           f(\theta) =
           \begin{bmatrix}
           \cos \theta \\
           \sin \theta
           \end{bmatrix}.

       Similarly, parametric surfaces can be described by a function that maps
       from a two-dimensional *parametric domain* to :math:`\mathbb{R}^3`
       (see :cite:`implicit_surfaces`).

    2. **Explicit Equations**

       Curves and surfaces can also be expressed using explicit equations. When
       describing a curve with an explicit equation, a function
       :math:`f: \mathbb{R} \to \mathbb{R}` of the form :math:`y = f(x)` assigns a
       unique value of :math:`y` to each :math:`x \in \mathbb{R}`; the graph of
       :math:`f` is the curve. Unfortunately, not every curve or surface can be
       described this way. For example, only one half of the unit circle can be
       represented at a time, using explicit equations such as:

       .. math::

           y = \sqrt{1 - x^2} \quad \text{or} \quad y = -\sqrt{1 - x^2}.

       Similarly, three-dimensional surfaces can be described explicitly using
       functions of the form :math:`y = f(x_1, x_2)`, which assign a unique
       :math:`y`-value to each pair of :math:`(x_1, x_2)`-coordinates
       (see :cite:`implicit_surfaces`).

    3. **Implicit Equations**

       A planar curve is defined implicitly when it is described as the set of
       solutions of an equation in two variables, typically written as
       :math:`f(y_1, y_2) = 0`. For example, the equation

       .. math::

           y_1^2 + y_2^2 - 1 = 0

       represents the unit circle in :math:`\mathbb{R}^2`. Similarly, an implicit
       surface can be expressed by an equation of the form
       (see :cite:`implicit_surfaces`)

       .. math::

           f(y_1, y_2, y_3) = 0.

    **Optical Surfaces**

    In our ray tracer we use a less general description of surfaces. We call
    surfaces relevant for ray tracing *optical surfaces*. Every *optical surface*
    is composed of an *explicit surface* :math:`\hat{S}: \mathbb{R}^2 \to \mathbb{R}`
    and an affine transformation matrix :math:`M \in \mathbb{R}^{4 \times 4}`.
    Below we state the *implicit surface* description used by the ray tracer
    itself and the *parametric surface* description used for plots and
    constrained optimization.

    1. **Implicit Surface Description**

       Surfaces are described implicitly by the equation :math:`s(\hat{y}) = 0`.
       The function :math:`s` is composed of the explicit description
       :math:`\hat{S}(\hat{x}_1, \hat{x}_2)` and the affine transformation matrix
       :math:`M`. A point :math:`\hat{y}` is first mapped into the local frame of
       the surface,

       .. math::

           \begin{bmatrix} \hat{x} \\ 1 \end{bmatrix} = M^{-1} \begin{bmatrix} \hat{y} \\ 1 \end{bmatrix},

       and then

       .. math::

           s(\hat{y}) = \hat{S}(\hat{x}_1(\hat{y}), \hat{x}_2(\hat{y})) - \hat{x}_3(\hat{y}).

       This description allows ray-surface intersections to be computed
       efficiently. Typically, :math:`M^{-1}` is not formed explicitly in the
       implementation; the transformation is applied directly instead.

    2. **Parametric Surface Description**

       Here, surfaces are defined by parameterizing coordinates. For optical
       surfaces, the surface is again described as a composition of the explicit
       description :math:`\hat{S}` and the transformation matrix :math:`M`:

       .. math::

           \begin{bmatrix} S(\hat{x}_1, \hat{x}_2) \\ 1 \end{bmatrix} = M \begin{bmatrix} \hat{x}_1 \\ \hat{x}_2 \\ \hat{S}(\hat{x}_1, \hat{x}_2) \\ 1 \end{bmatrix}

       In this library the *parametric domains* are defined by the lenses or
       target surfaces (detectors). For example, for a round lens the
       *parametric domain* is the disc determined by the aperture radius. This
       description is typically used for plotting, but it is also useful for
       constrained optimization.

    Subclasses implement :meth:`functional` (the implicit function evaluated on
    local points), :meth:`get_functional_param_args` (the extra arguments passed
    to :meth:`functional`) and :meth:`explicit` (the surface height evaluated on
    local ``(x, y)`` positions).

    Examples:
        >>> import diffinytrace as dit
        >>> aperture_radius = 30.
        >>> lens_thickness = 8.
        >>> material = dit.materials["NBK7"]
        >>> transform = dit.transforms.Identity()
        >>> asphere = dit.Aspheric(1./40., 0.0, [-0.00001])
        >>> plane = dit.Plane()
        >>> lens = dit.Lens(transform, lens_thickness,
        ...                 asphere, plane,
        ...                 material, aperture_radius)
        >>> dit.plotting.system2D.plot(lens)
    """

    @staticmethod
    def functional(O, *params_list):
        """Evaluate the implicit surface function at local points ``O``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("functional not implemented")

    def get_functional_param_args(self):
        """Return the extra positional arguments passed to :meth:`functional`.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("params_list not implemented")

    def explicit(self, local_pos):
        """Evaluate the surface height at local ``(x, y)`` positions.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("explicit function not implemented")
110
+
111
class Plane(Surface):
    """
    Flat surface given by ``z = 0`` in its local coordinate frame.

    :meth:`functional` returns the local z-coordinate of each point, which
    vanishes exactly on the plane; :meth:`explicit` therefore yields zero
    heights for every ``(x, y)``.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def functional(O):
        # NOTE(review): the other surfaces in this module return S(x, y) - z,
        # whereas this returns +z. Both vanish on the plane, but the sign (and
        # therefore the gradient direction) differs — confirm this is intended.
        return O[:, -1]

    def get_functional_param_args(self):
        # A plane has no shape parameters.
        return []

    def explicit(self, local_pos):
        """Return the plane height (always zero) at local ``(x, y)`` positions.

        Args:
            local_pos: tensor whose last dimension must be 2.

        Raises:
            RuntimeError: if the last dimension of ``local_pos`` is not 2.
        """
        if local_pos.shape[-1] != 2:
            raise RuntimeError("local_pos needs to be of shape [:,2]")

        points = torch.zeros(
            (local_pos.shape[0], 3),
            device=local_pos.device,
            dtype=local_pos.dtype,
        )
        # Copy x and y; the z column stays zero.
        for axis in (0, 1):
            points[:, axis] = local_pos[:, axis]

        return self.functional(points, *self.get_functional_param_args())
140
+
141
class Aspheric(Surface):
    r"""
    Aspheric surface; the implementation follows
    https://en.wikipedia.org/wiki/Aspheric_lens.

    The surface is parameterized as an implicit function :math:`f(x, y, z) = 0`,
    decomposed as

    .. math::

        f(x, y, z) = g(x, y) + h(z),

    where :math:`g` and :math:`h` are explicit functions:

    .. math::

        r^2 = x^2 + y^2

    .. math::

        g(x, y) = \frac{c \, r^2}{1 + \sqrt{1 - (1 + k) \, c^2 r^2}}
        + a_0 r^4 + a_1 r^6 + \cdots

    .. math::

        h(z) = -z

    Args:
        curvature (float or torch.Tensor): Surface curvature :math:`c`, i.e. one
            over the radius of curvature. Converted to the default dtype.
        conic_coeff (float, torch.Tensor or None): Conic coefficient :math:`k`.
            When ``None``, :math:`k = 0` is used and it is not optimized
            (``requires_grad = False``).
        aspheric_param (sequence, torch.Tensor or None): Aspheric coefficients
            :math:`a_0, a_1, \ldots` multiplying :math:`r^4, r^6, \ldots`. When
            ``None``, no polynomial terms are added (pure conic surface).
    """

    def __init__(self, curvature, conic_coeff=None, aspheric_param=None):
        super().__init__()
        if not torch.is_tensor(curvature):
            curvature = torch.tensor(curvature)
        curvature = curvature.to(torch.get_default_dtype())

        # Without an explicit conic coefficient, k defaults to 0 and is frozen.
        conic_coeff_requires_grad = True
        if conic_coeff is None:
            conic_coeff_requires_grad = False
            conic_coeff = torch.tensor(0.0)

        self.curvature = make_parameter_from_input(curvature)
        self.conic_coeff = make_parameter_from_input(conic_coeff)
        self.conic_coeff.requires_grad = conic_coeff_requires_grad

        self.aspheric_param = None
        if aspheric_param is not None:
            self.aspheric_param = make_parameter_from_input(aspheric_param)

    @staticmethod
    def g(x: torch.Tensor,
          y: torch.Tensor,
          curvature: torch.Tensor,
          conic_coeff: torch.Tensor,
          aspheric_param: Optional[torch.Tensor]) -> torch.Tensor:
        """Sag :math:`g(x, y)` of the surface at local ``(x, y)``."""
        return Aspheric._g(x**2 + y**2, curvature, conic_coeff, aspheric_param)

    @staticmethod
    def h(z: torch.Tensor,
          curvature: torch.Tensor,
          conic_coeff: torch.Tensor,
          aspheric_param: Optional[torch.Tensor]) -> torch.Tensor:
        """The :math:`h(z) = -z` term of the implicit function (other args unused)."""
        return -z

    @staticmethod
    def _g(r2: torch.Tensor,
           curvature: torch.Tensor,
           conic_coeff: torch.Tensor,
           aspheric_param: Optional[torch.Tensor]) -> torch.Tensor:
        """Sag as a function of the squared radial distance ``r2``."""
        c_r2 = r2 * curvature
        # Conic sag c*r^2 / (1 + sqrt(1 - (1 + k) * c^2 * r^2)); torch.sqrt gives
        # NaN where the argument is negative (outside the conic's domain).
        total_surface = c_r2 / (1 + torch.sqrt(1 - (1 + conic_coeff) * c_r2 * curvature))
        higher_surface = 0.
        if aspheric_param is not None:
            # reshape(-1) lets a single 0-d coefficient act as a_0 instead of
            # failing in len(); for a 1-D tensor it is a no-op view (grad kept).
            coeffs = torch.as_tensor(aspheric_param).reshape(-1)
            for i, a_i in enumerate(coeffs):
                higher_surface = higher_surface + a_i * r2 ** (i + 2)
        return total_surface + higher_surface

    @staticmethod
    def functional(O: torch.Tensor,
                   curvature: torch.Tensor,
                   conic_coeff: torch.Tensor,
                   aspheric_param: Optional[torch.Tensor]) -> torch.Tensor:
        """Implicit function ``g(x, y) + h(z) = g(x, y) - z`` at local points ``O``.

        ``O[:, 0]``, ``O[:, 1]`` and ``O[:, 2]`` are read as x, y and z.
        """
        x = O[:, 0]
        y = O[:, 1]
        z = O[:, 2]

        return Aspheric.g(x, y, curvature, conic_coeff, aspheric_param) + Aspheric.h(z, curvature, conic_coeff, aspheric_param)

    def get_functional_param_args(self) -> List[torch.Tensor]:
        """Arguments passed to :meth:`functional` after the points."""
        return [self.curvature, self.conic_coeff, self.aspheric_param]

    def explicit(self, local_pos: torch.Tensor) -> torch.Tensor:
        """Surface height :math:`g(x, y)` at local ``(x, y)`` positions.

        Raises:
            RuntimeError: if the last dimension of ``local_pos`` is not 2.
        """
        if local_pos.shape[-1] != 2:
            raise RuntimeError("local_pos needs to be of shape [:,2]")
        x = local_pos[:, 0]
        y = local_pos[:, 1]

        # With z = 0 the implicit function reduces to g(x, y).
        O_new = torch.zeros((local_pos.shape[0], 3), device=local_pos.device, dtype=local_pos.dtype)
        O_new[:, 0] = x
        O_new[:, 1] = y

        return self.functional(O_new, *self.get_functional_param_args())
248
+
249
+ """
250
+ class Zernike(Surface):
251
+ #TODO reimplement!
252
+ def __init__(self,aperture_radius,max_radial_degree):
253
+ super().__init__()
254
+
255
+ self.max_radial_degree = max_radial_degree
256
+ self.coeff = make_parameter_from_input(torch.zeros((basis_functions.zernike.get_num_coeffs(max_radial_degree))))
257
+ self.aperture_radius = torch.tensor(aperture_radius)
258
+
259
+ def refine(self):
260
+ #TODO move to a parent class
261
+ with torch.no_grad():
262
+ coeff = make_parameter_from_input(torch.zeros((basis_functions.zernike.get_num_coeffs(self.max_radial_degree+1))))
263
+ coeff.data[:self.coeff.shape[0]] = self.coeff.data.detach()
264
+ self.coeff = coeff
265
+ self.max_radial_degree = self.max_radial_degree+1
266
+
267
+
268
+ @staticmethod
269
+ def functional(O,coeffs,aperture_radius):
270
+ points = O[:,[0,1]]/aperture_radius
271
+ max_n = basis_functions.zernike.get_radial_degree(coeffs.shape[0])
272
+ zernike_surface = basis_functions.zernike.basis_function(max_n, points)@coeffs
273
+
274
+ return zernike_surface-O[:,-1]
275
+
276
+ def get_functional_param_args(self):
277
+ return [self.coeff,self.aperture_radius]
278
+
279
+ def explicit(self,local_pos):
280
+ if local_pos.shape[-1] != 2:
281
+ raise RuntimeError("local_pos needs to be of shape [:,2]")
282
+ x = local_pos[:,0]
283
+ y = local_pos[:,1]
284
+
285
+ O_new = torch.zeros((local_pos.shape[0],3),device=local_pos.device,dtype=local_pos.dtype)
286
+ O_new[:,0] = x
287
+ O_new[:,1] = y
288
+
289
+ return self.functional(O_new,*self.get_functional_param_args())
290
+
291
+ """
292
+
293
def bspline_n_after_refinement(n, k):
    """Number of control points of a B-spline after one refinement step.

    With ``n`` control points and (presumably) order ``k`` there are
    ``n - k + 1`` knot spans. Halving every span gives ``2 * (n - k + 1)`` spans,
    i.e. ``2 * n + 1 - k`` control points.

    Args:
        n: number of control points before refinement.
        k: B-spline order (degree + 1) — assumed from the formula; confirm with callers.

    Returns:
        ``2 * n + 1 - k``.
    """
    doubled = 2 * n
    return doubled + 1 - k
295
+
296
class Bspline(Surface):
    """
    B-spline surface ``z = S(x, y)`` over the square ``[-R, R] x [-R, R]``
    (``R = aperture_radius``).

    Each axis uses a clamped uniform knot vector, and the control-point
    heights ``coeff`` form a grid of shape ``ns``. :meth:`functional` returns
    ``S(x, y) - z`` for local points ``O``.

    Args:
        aperture_radius: half width of the square domain.
        orders: B-spline order (order = degree + 1) along each of the two axes.
        ns: number of control points along each of the two axes.
    """

    def __init__(self, aperture_radius: float, orders: List[int], ns: List[int]):
        super().__init__()
        # Copy the inputs: refine() updates self.ns, and that must never leak
        # into the caller's list (or into other surfaces built from it).
        orders = list(orders)
        ns = list(ns)

        Us = []
        for axis in (0, 1):
            order, n = orders[axis], ns[axis]
            # Clamped uniform knot vector with n + order entries: the end knots
            # 0 and 1 have multiplicity `order`.
            knots = [0.] * (order - 1) + list(np.linspace(0., 1., n - order + 2)) + [1.0] * (order - 1)
            Us.append(torch.tensor(knots, dtype=torch.get_default_dtype()))

        self.Us = Us
        self.ns = ns
        self.orders = orders
        self.coeff = make_parameter_from_input(torch.zeros(self.ns))
        self.aperture_radius = torch.tensor(aperture_radius)

    def get_CAD_coeff(self, affine_transform: Transform) -> np.ndarray:
        """
        Get the control points of the surface in the frame of ``affine_transform``.

        The control-point heights ``coeff`` are placed on a regular grid over
        ``[-R, R] x [-R, R]``, with x along axis 0 and y along axis 1 of
        ``coeff``. The points are then mapped through the affine matrix.

        Args:
            affine_transform (torch.Tensor): 4x4 affine transformation matrix.

        Returns:
            numpy.ndarray: Control points of shape ``(ns[0], ns[1], 3)``.
        """
        M = affine_transform.detach().cpu()
        dtype = M.dtype
        coeff = self.coeff.detach().cpu()
        n0, n1 = coeff.shape[0], coeff.shape[1]

        # Size each coordinate vector by its own axis of `coeff`. The previous
        # version sized them by the opposite axes and failed for non-square grids;
        # for square grids the result is unchanged.
        xs = torch.linspace(-self.aperture_radius, self.aperture_radius, n0)
        ys = torch.linspace(-self.aperture_radius, self.aperture_radius, n1)

        # NOTE(review): uniformly spaced x/y control positions make x(u) linear
        # only for degree-1 splines. For higher degrees the CAD surface matches
        # S(x, y) exactly only if the Greville abscissae of the knot vectors are
        # used (assuming surface_2D maps x linearly onto the knot parameter).
        v = torch.zeros((n0, n1, 4))
        v[:, :, 0] = xs[:, None]
        v[:, :, 1] = ys[None, :]
        v[:, :, 2] = coeff
        v[:, :, 3] = 1.0
        v = v.reshape(-1, 4).to(dtype=dtype)

        # Row-vector form of M @ p for every point p.
        out = (v @ M.T)[:, :3]
        return out.reshape(n0, n1, 3).detach().cpu().numpy()

    def get_CAD_face(self, affine_transform):
        """
        Build a CAD face for the surface in the frame of ``affine_transform``.

        Args:
            affine_transform (torch.Tensor): 4x4 affine transformation matrix.

        Returns:
            The face object created by ``export.cad.makeBsplineFace``.
        """
        # Imported lazily so the CAD backend is only needed for export.
        from .export.cad import makeBsplineFace
        control_points = self.get_CAD_coeff(affine_transform)
        U1, U2 = self.Us
        u_order, v_order = self.orders
        u_degree = u_order - 1
        v_degree = v_order - 1
        return makeBsplineFace(control_points, U1, U2, u_degree, v_degree)

    def refine(self):
        """
        Refine the B-spline surface via ``basis_functions.bspline.refine_2D``.

        The knot vectors and control points are replaced by the refined ones,
        and ``self.ns`` is updated to the new control-point grid shape.
        """
        Us, coeff = basis_functions.bspline.refine_2D(self.Us, self.orders, self.coeff)
        self.Us = Us
        with torch.no_grad():
            self.coeff = make_parameter_from_input(coeff.data.detach())

        # Rebind rather than mutate in place, so a list shared with a caller (or
        # a tuple passed at construction) is never modified.
        self.ns = [self.coeff.shape[0], self.coeff.shape[1]]

    def functional(self, O, coeff, aperture_radius):
        """Implicit function ``S(x, y) - z`` at local points ``O`` (x, y, z columns)."""
        points = O[:, [0, 1]]
        x_range = [-aperture_radius, aperture_radius]
        y_range = [-aperture_radius, aperture_radius]

        _zsurface = basis_functions.bspline.surface_2D(points, self.Us, self.orders, self.ns, x_range, y_range, coeff)
        return _zsurface - O[:, -1]

    def get_functional_param_args(self) -> List[torch.Tensor]:
        """Arguments passed to :meth:`functional` after the points."""
        return [self.coeff, self.aperture_radius]

    def explicit(self, local_pos: torch.Tensor) -> torch.Tensor:
        """
        Surface height ``S(x, y)`` at local ``(x, y)`` positions.

        Args:
            local_pos (torch.Tensor): Local positions with last dimension 2.

        Returns:
            torch.Tensor: One surface height per input position.

        Raises:
            RuntimeError: if the last dimension of ``local_pos`` is not 2.
        """
        if local_pos.shape[-1] != 2:
            raise RuntimeError("local_pos needs to be of shape [:,2]")
        device = local_pos.device
        dtype = local_pos.dtype

        x = local_pos[:, 0]
        y = local_pos[:, 1]

        # With z = 0 the implicit function reduces to S(x, y).
        O_new = torch.zeros((local_pos.shape[0], 3), device=device, dtype=dtype)
        O_new[:, 0] = x
        O_new[:, 1] = y

        return self.functional(O_new, *self.get_functional_param_args())
418
+
419
+
420
class Legendre(Surface):
    """
    Surface whose height is a linear combination of 2-D Legendre basis
    functions evaluated at ``(x, y) / aperture_radius``.

    This class is a work in progress.

    Args:
        aperture_radius: radius used to normalize the local x/y coordinates.
        degree: degree passed to ``basis_functions.legendre``.
    """

    def __init__(self, aperture_radius: float, degree: int):
        super().__init__()

        self.degree = degree
        self.coeff = make_parameter_from_input(torch.zeros(basis_functions.legendre.get_num_coeff(degree)))
        self.aperture_radius = torch.tensor(aperture_radius)

    def refine(self):
        """
        Increase the degree by one while keeping the current coefficients.

        The existing coefficients are copied into the leading entries of the
        larger coefficient vector; the new entries start at zero.
        """
        # TODO move to a parent class
        # NOTE(review): copying into the leading entries assumes basis_2D orders
        # its terms so that the lower-degree basis is a prefix of the next one.
        with torch.no_grad():
            num_coeff = basis_functions.legendre.get_num_coeff(self.degree + 1)
            # Allocate on the current coefficients' device/dtype, so that refining
            # a model that was moved (e.g. .to("cuda") or .double()) does not
            # silently fall back to CPU / the default dtype.
            coeff = make_parameter_from_input(
                torch.zeros(num_coeff, dtype=self.coeff.dtype, device=self.coeff.device)
            )
            coeff.data[:self.coeff.shape[0]] = self.coeff.data.detach()
            self.coeff = coeff
        self.degree = self.degree + 1

    def functional(self, O: torch.Tensor, coeffs: torch.Tensor, aperture_radius: Union[float, torch.Tensor]) -> torch.Tensor:
        """Implicit function ``S(x, y) - z`` at local points ``O`` (x, y, z columns)."""
        points = O[:, [0, 1]] / aperture_radius
        z = basis_functions.legendre.basis_2D(self.degree, points[:, 0], points[:, 1]) @ coeffs

        return z - O[:, -1]

    def get_functional_param_args(self):
        """Arguments passed to :meth:`functional` after the points."""
        return [self.coeff, self.aperture_radius]

    def explicit(self, local_pos: torch.Tensor) -> torch.Tensor:
        """Surface height ``S(x, y)`` at local ``(x, y)`` positions.

        Raises:
            RuntimeError: if the last dimension of ``local_pos`` is not 2.
        """
        if local_pos.shape[-1] != 2:
            raise RuntimeError("local_pos needs to be of shape [:,2]")
        device = local_pos.device
        dtype = local_pos.dtype

        x = local_pos[:, 0]
        y = local_pos[:, 1]

        # With z = 0 the implicit function reduces to S(x, y).
        O_new = torch.zeros((local_pos.shape[0], 3), device=device, dtype=dtype)
        O_new[:, 0] = x
        O_new[:, 1] = y

        return self.functional(O_new, *self.get_functional_param_args())
465
+
466
+
467
+
468
+