diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+
5
+ """
6
+ This module provides a collection of functions and classes for optical system design and analysis.
7
+ It includes modules for ray tracing, surface definitions, optimization, and more.
8
+ """
9
+
10
__all__ = [
    # --- submodules re-exported as attributes of the package ---
    "source",
    "transforms",
    "target_grid",
    "utils",
    "plotting",
    "basis_functions",
    "gaussian_smoother",
    "optimize",
    "integrators",
    "export",
    "render",
    "constraints",
    "spectrum",
    # --- from .intersection ---
    "cat_semi_functionals",
    "get_functional_param_args",
    "construct_surface_and_normal_func",
    "construct_surface_and_normal_func_with_params",
    "CustomAutogradRule_t",
    "get_ray_intersection_length",
    # --- from .surface ---
    "Plane",
    "Aspheric",
    "Bspline",
    "Legendre",
    "bspline_n_after_refinement",
    # --- from .element ---
    "OpticalSystem",
    "SequentialOpticalSystem",
    "OpticalElement",
    "OpticalSurface",
    "LensSurfaceTransmissionEnter",
    "LensSurfaceTransmissionLeave",
    "Lens",
    "Mirror",
    "Detector",
    "trace_to_detector",
    "get_unused_params_mask",
    "set_used_params_bounds_to_constant",
    "set_unused_params_to_zero",
    "set_unused_bspline_coeff_to_nearest",
    "FresnelVirtualLens",
    # --- from .config ---
    "set_tolerance",
    "get_tolerance",
    "set_max_iterations",
    "get_max_iterations",
    "restore_default_settings",
    "get_damping_factor",
    "set_damping_factor",
    "get_show_iteration_count",
    "set_show_iteration_count",
    # --- from .optimize ---
    "minimize",
    "make_parameter_from_input",
    # --- from .refractive_index ---
    "materials",
    "RefractiveIndex",
    # --- from .utils.autograd ---
    "grad",
]
80
+
81
import os

# Allow duplicate OpenMP runtimes to coexist. This is presumably a workaround
# for Intel OpenMP's duplicate-library abort ("OMP: Error #15") seen when torch
# and another library each ship their own runtime. It must run before torch is
# imported. `setdefault` keeps an explicit user setting (e.g. "FALSE") intact
# instead of silently overwriting it.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")

import torch

# The package computes in double precision. Note that this changes torch's
# process-wide default dtype, not only tensors created by this package.
torch.set_default_dtype(torch.float64)
85
+
86
+ from . import source
87
+ from . import transforms
88
+ from . import target_grid
89
+ from . import gaussian_smoother
90
+
91
+ from . import utils
92
+ #from . import refractive_index
93
+ from . import plotting
94
+ from . import basis_functions
95
+ from . import optimize
96
+ from . import integrators
97
+ from . import export
98
+ from . import render
99
+ from . import constraints
100
+ from . import spectrum
101
+
102
+ from .intersection import cat_semi_functionals,get_functional_param_args,\
103
+ construct_surface_and_normal_func,construct_surface_and_normal_func_with_params,\
104
+ CustomAutogradRule_t,get_ray_intersection_length
105
+
106
+ from .surface import Plane,Aspheric,Bspline,Legendre,bspline_n_after_refinement
107
+
108
+ from .element import OpticalSystem,SequentialOpticalSystem,OpticalElement,OpticalSurface,\
109
+ LensSurfaceTransmissionEnter,LensSurfaceTransmissionLeave,Lens,Mirror,Detector,trace_to_detector,\
110
+ get_unused_params_mask,set_used_params_bounds_to_constant,\
111
+ set_unused_params_to_zero,set_unused_bspline_coeff_to_nearest,\
112
+ FresnelVirtualLens
113
+
114
+ from .config import set_tolerance,get_tolerance,set_max_iterations,\
115
+ get_max_iterations,restore_default_settings,get_damping_factor,set_damping_factor,\
116
+ get_show_iteration_count,set_show_iteration_count
117
+ from .optimize import minimize,make_parameter_from_input
118
+
119
+ from .refractive_index import materials
120
+ from .refractive_index import RefractiveIndex
121
+ from .utils.autograd import grad
122
+
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
# Public submodules of ``diffinytrace.basis_functions``.
__all__ = [
    "bspline",
    "legendre",
    "zernike",
    "chebyshev",
]
10
+
11
+ from . import bspline
12
+ from . import legendre
13
+ from . import zernike
14
+ from . import chebyshev
@@ -0,0 +1,521 @@
1
+ r"""
2
+ B-spline surfaces for freeform geometry.
3
+
4
+ B-splines are popular for describing freeform surfaces because they allow
5
+ local changes to the geometry :cite:`THBray`. Their smoothness is controlled
6
+ by the spline degrees, which determine continuity and differentiability
7
+ :cite:`IGA`. A tensor-product B-spline surface is defined by two *knot vectors*,
8
+ a grid of univariate B-spline *basis functions*, and a bi-directional net of
9
+ *control points* :cite:`nurbs`. Below, we summarize the main components.
10
+
11
+ Notes:
12
+ **Knot vectors.**
13
+ A surface uses two (typically clamped) nondecreasing knot vectors
14
+ :math:`U` and :math:`V`:
15
+
16
+ .. math::
17
+
18
+ U = \{\underbrace{0,\dots,0}_{p+1},\, u_{p+1}, \dots, u_{n},\,
19
+ \underbrace{1,\dots,1}_{p+1}\}, \qquad
20
+ V = \{\underbrace{0,\dots,0}_{q+1},\, v_{q+1}, \dots, v_{m},\,
21
+ \underbrace{1,\dots,1}_{q+1}\}.
22
+
23
+ Here :math:`p` and :math:`q` are the degrees in the :math:`u`- and
24
+ :math:`v`-directions. A knot vector :math:`U=\{u_0,\dots,u_M\}` is a
25
+ nondecreasing sequence, i.e., :math:`u_i \le u_{i+1}`; each element is a
26
+ *knot*.
27
+
28
+ **Univariate B-spline basis (Cox–de Boor).**
29
+ In the :math:`u`-direction (analogously for :math:`v`), the basis
30
+ :math:`\{N_{i,p}\}` is defined recursively :cite:`nurbs`:
31
+
32
+ .. math::
33
+
34
+ N_{i,0}(u) =
35
+ \begin{cases}
36
+ 1, & u_i \le u < u_{i+1},\\
37
+ 0, & \text{otherwise},
38
+ \end{cases}
39
+
40
+ .. math::
41
+
42
+ N_{i,p}(u) =
43
+ \frac{u - u_i}{u_{i+p} - u_i}\, N_{i,p-1}(u) \;+\;
44
+ \frac{u_{i+p+1} - u}{u_{i+p+1} - u_{i+1}}\, N_{i+1,p-1}(u).
45
+
46
+ In the :math:`v`-direction, the basis :math:`\{M_{j,q}\}` is
47
+
48
+ .. math::
49
+
50
+ M_{j,0}(v) =
51
+ \begin{cases}
52
+ 1, & v_j \le v < v_{j+1},\\
53
+ 0, & \text{otherwise},
54
+ \end{cases}
55
+
56
+ .. math::
57
+
58
+ M_{j,q}(v) =
59
+ \frac{v - v_j}{v_{j+q} - v_j}\, M_{j,q-1}(v) \;+\;
60
+ \frac{v_{j+q+1} - v}{v_{j+q+1} - v_{j+1}}\, M_{j+1,q-1}(v).
61
+
62
+ **Control points.**
63
+ Control points :math:`\mathbf{P}_{i,j}` link the basis to geometry.
64
+ They can be scalars, 2D, or 3D vectors.
65
+
66
+ **Surface definition.**
67
+ The tensor-product B-spline surface is
68
+
69
+ .. math::
70
+ :label: eq-bspline-Z
71
+
72
+ Z(u,v) = \sum_{i=0}^{n} \sum_{j=0}^{m}
73
+ N_{i,p}(u)\, M_{j,q}(v)\, \mathbf{P}_{i,j}.
74
+
75
+ **Implementation details (this library).**
76
+ We use scalar control points :math:`\mathbf{P}_{i,j}` (height field),
77
+ clamped knot vectors with uniformly spaced interior knots, and :math:`u,v \in [0,1]`.
78
+ To couple an explicit surface to the ray tracer, we map physical
79
+ coordinates :math:`\hat{x}_1,\hat{x}_2` to the parametric domain via a
80
+ scale :math:`h`:
81
+
82
+ .. math::
83
+
84
+ S(\hat{x}_1,\hat{x}_2) =
85
+ Z\!\left(\frac{\hat{x}_1}{h},\, \frac{\hat{x}_2}{h}\right).
86
+
87
+ .. figure:: _static/bspline_plot1.png
88
+ :alt: Freeform lens with B-spline surface
89
+ :width: 60%
90
+ :align: center
91
+
92
+ Visualization of a freeform lens with a B-spline surface.
93
+
94
+ Examples:
95
+ Define a lens with a B-spline surface and plot it:
96
+
97
+ .. code-block:: python
98
+
99
+ import torch
100
+ import diffinytrace as dit
101
+
102
+ aperture_half = 30.0
103
+ aperture_radius = aperture_half
104
+ lens_thickness = 8.0
105
+ material = dit.materials["NBK7"]
106
+ transform = dit.transforms.Identity()
107
+
108
+ # degree [p, q] and control net size [n_u, n_v] (example values)
109
+ bspline = dit.Bspline(aperture_half, [3, 3], [8, 8])
110
+ plane = dit.Plane()
111
+
112
+ with torch.no_grad():
113
+ bspline.coeff.data = torch.randn_like(bspline.coeff.data) * 3.0
114
+
115
+ lens = dit.Lens(transform, lens_thickness, bspline, plane,
116
+ material, aperture_radius)
117
+
118
+ dit.plotting.system3D.plot(lens, zticks=[0, 5])
119
+
120
+ """
121
+
122
+ # Copyright (c) 2025 Martin Pflaum
123
+ # This file is part of the diffinytrace project, licensed under the MIT License.
124
+
125
# Public API of this module. Every entry must name a function defined below,
# otherwise ``from diffinytrace.basis_functions.bspline import *`` raises
# AttributeError. The previous entries "insert_knots1D" and "refine2D" did
# not match the actual function names.
__all__ = [
    "cox_de_boor_recursion",
    "basis_1D",
    "basis_2D",
    "surface_2D",
    "insert_knot_1D_single",
    "insert_knots_1D",
    "refine_2D",
]
134
+
135
+ import torch
136
+ from typing import Tuple,List,Callable,Optional,Union
137
+
138
+
139
def cox_de_boor_recursion(U: torch.Tensor, k: int, n: int, xis: torch.Tensor, k_curr: int) -> torch.Tensor:
    r"""Evaluate every B-spline basis function of degree ``k_curr`` (Cox-de Boor).

    The recursion is unrolled bottom-up. The piecewise-constant (degree 0)
    basis is built first, then raised one degree per iteration until
    ``k_curr`` is reached.

    Args:
        U (torch.Tensor): Knot vector; moved to the device of ``xis``.
        k (int): Order of the B-spline. Accepted for interface compatibility,
            not used by the computation.
        n (int): Number of control points. Accepted for interface
            compatibility, not used by the computation.
        xis (torch.Tensor): 1D tensor of evaluation points.
        k_curr (int): Degree to evaluate; must be non-negative.

    Returns:
        torch.Tensor: Basis values of shape ``[len(xis), len(U) - 1 - k_curr]``.
        Column ``i`` holds :math:`N_{i,k_{curr}}` evaluated at ``xis``.

    Raises:
        RuntimeError: If ``k_curr`` is negative.
    """
    U = U.to(xis.device)
    if k_curr < 0:
        raise RuntimeError("cox_de_boor_recursion wrong input. k_curr < 0")

    def ratio_or_one(numerator, denominator):
        # Cox-de Boor weight numerator/denominator, with 0/0 (repeated knots)
        # mapped to 1. The matching lower-degree basis is zero on an empty span,
        # except for the right-end convention below, which relies on this
        # value being 1 to propagate the endpoint value upwards.
        is_zero = denominator == 0
        safe_denominator = torch.where(is_zero, torch.ones_like(denominator), denominator)
        quotient = numerator / safe_denominator
        return torch.where(is_zero, torch.ones_like(quotient), quotient)

    column = xis[None].T  # [num_points, 1] so it broadcasts against the knots

    # Degree 0: indicator of the half-open knot span [U[i], U[i+1]).
    below_upper = (column < U[1:]).to(xis.device, dtype=xis.dtype)
    above_lower = (column >= U[:-1]).to(xis.device, dtype=xis.dtype)
    basis = below_upper * above_lower
    # The right domain end lies in no half-open span; attribute it to the last
    # span so the last basis function evaluates to 1 there.
    on_right_end = xis == U[U.shape[0] - 1]
    basis[on_right_end, -1] = 1.0

    for degree in range(1, k_curr + 1):
        lower_left = basis[:, :-1]
        lower_right = basis[:, 1:]
        weight_left = ratio_or_one(column - U[:-degree - 1],
                                   (U[degree:-1] - U[:-degree - 1]).reshape(1, -1))
        weight_right = ratio_or_one(-(column - U[degree + 1:]),
                                    (U[degree + 1:] - U[1:-degree]).reshape(1, -1))
        basis = weight_left * lower_left + weight_right * lower_right

    return basis
183
def basis_1D(points: torch.Tensor,
             U: torch.Tensor,
             k: int,
             n: int,
             val_range: tuple[float, float]) -> torch.Tensor:
    """Evaluate all 1D B-spline basis functions of order ``k`` at ``points``.

    The points are first mapped affinely from ``val_range`` onto the
    parametric interval [0, 1]. The Cox-de Boor recursion is then run up to
    degree ``k - 1``.

    Args:
        points (torch.Tensor): 1D tensor of evaluation points in the
            coordinates of ``val_range``.
        U (torch.Tensor): Knot vector; its first entry must be 0.0 and its last 1.0.
        k (int): Order of the B-spline (degree + 1).
        n (int): Number of control points (forwarded, not used for the result).
        val_range (tuple[float, float]): ``(min, max)`` of the physical
            interval that is mapped onto [0, 1].

    Returns:
        torch.Tensor: Basis values of shape ``[len(points), len(U) - k]``.

    Raises:
        RuntimeError: If ``U`` does not start at 0.0 or does not end at 1.0.

    Example:
        >>> import torch
        >>> import matplotlib.pyplot as plt
        >>> from diffinytrace.basis_functions import bspline
        >>> U = torch.tensor([0., 0.2, 0.4, 0.6, 0.8, 1])
        >>> xis = torch.linspace(0, 1, 100)
        >>> xN = bspline.basis_1D(xis, U, 3, 3, [0., 1.])  # order 3 -> 3 functions
        >>> for column in xN.T:
        ...     plt.plot(xis, column)
    """
    starts_at_zero = U[0] == 0.0
    ends_at_one = U[-1] == 1.0
    if not (starts_at_zero and ends_at_one):
        raise RuntimeError("Knots should always between 0.0 and 1.0 and also contain these values!")

    lower, upper = val_range[0], val_range[1]
    unit_points = (points - lower) / (upper - lower)  # now in the knot domain [0, 1]
    return cox_de_boor_recursion(U, k, n, unit_points, k - 1)
226
+
227
def basis_2D(points: torch.Tensor,
             Us: List[torch.Tensor],
             orders: List[int],
             ns: List[int],
             x_range: tuple,
             y_range: tuple) -> torch.Tensor:
    """Compute the tensor-product 2D B-spline basis at the given points.

    Args:
        points (torch.Tensor): Evaluation points of shape ``[num_points, 2]``.
        Us (list[torch.Tensor]): Knot vectors ``[U_x, U_y]``. A tuple is
            accepted too. The sequence is never modified; device moves are
            done on local references.
        orders (list[int]): B-spline orders in x and y direction.
        ns (list[int]): Number of control points in x and y direction.
        x_range (tuple): ``(min, max)`` of the physical x interval mapped to [0, 1].
        y_range (tuple): ``(min, max)`` of the physical y interval mapped to [0, 1].

    Returns:
        torch.Tensor: Tensor of shape ``[num_points, n_x_basis, n_y_basis]`` with
        ``N2D[p, i, j] = N_i(x_p) * M_j(y_p)``.

    Raises:
        RuntimeError: If ``points`` is not of shape ``[num_points, 2]``.

    Example:
        >>> import torch
        >>> import diffinytrace as dit
        >>> from diffinytrace.basis_functions.bspline import basis_2D
        >>> U1 = torch.tensor([0., 0.2, 0.4, 0.6, 0.8, 1])
        >>> side_points = 100
        >>> _x = torch.linspace(0, 1, side_points)
        >>> grid_y, grid_x = torch.meshgrid(_x, _x, indexing='ij')
        >>> points = torch.cat([grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)], dim=-1)
        >>> N2D = basis_2D(points, [U1, U1], [3, 3], [3, 3], (0, 1), (0, 1))
        >>> xi, yi = 0, 2
        >>> dit.plotting.quantity2D.plot(
        ...     N2D[:, xi, yi].reshape(side_points, side_points),
        ...     "basis fun", [0, 1], [0, 1], xlabel="x", ylabel="y")
    """
    if len(points.shape) != 2 or points.shape[1] != 2:
        raise RuntimeError("The points must be in local coordinates and of shape [#points,2]")
    device = points.device
    # Move the knots locally instead of writing them back into the caller's
    # ``Us``. `.to` returns the same tensor when it is already on ``device``.
    U_x = Us[0].to(device)
    U_y = Us[1].to(device)

    Ns1 = basis_1D(points[:, 0], U_x, orders[0], ns[0], x_range)
    Ns2 = basis_1D(points[:, 1], U_y, orders[1], ns[1], y_range)
    num_points = Ns1.shape[0]
    # Outer product per point: axis 1 indexes the x basis, axis 2 the y basis.
    return Ns1.reshape(num_points, -1, 1) * Ns2.reshape(num_points, 1, -1)
294
+
295
+
296
def _bspline_support_window(basis: torch.Tensor, order: int) -> torch.Tensor:
    """Return per-point column indices of a window covering all non-zero basis values.

    Args:
        basis (torch.Tensor): Dense basis values of shape ``[num_points, num_basis]``.
        order (int): B-spline order (degree + 1).

    Returns:
        torch.Tensor: int64 indices of shape ``[num_points, min(order + 1, num_basis)]``.
        Each row is a run of consecutive, in-range, non-repeating columns.
    """
    num_basis = basis.shape[1]
    width = min(order + 1, num_basis)
    # The non-zero basis functions at a point are consecutive (at most
    # ``order`` of them). argmax over the 0/1 mask returns the first one
    # (0 for an all-zero row, e.g. a point outside the domain).
    first_nonzero = (basis != 0).to(basis.dtype).argmax(dim=1)
    # Shift windows that would run past the last column back into range
    # instead of wrapping around, so no column is gathered twice.
    start = first_nonzero.clamp(max=num_basis - width)
    offsets = torch.arange(width, device=basis.device)
    return start[:, None] + offsets[None, :]


def surface_2D(points: torch.Tensor, Us: List[torch.Tensor], orders: List[int], ns: List[int], x_range: tuple, y_range: tuple, control_points: torch.Tensor) -> torch.Tensor:
    """Evaluate a tensor-product B-spline surface at the given points.

    Only the basis functions that can be non-zero at each point (a window of
    at most ``order + 1`` per direction) are combined with their control
    points.

    Args:
        points (torch.Tensor): Evaluation points of shape ``[num_points, 2]``.
        Us (List[torch.Tensor]): Knot vectors ``[U_x, U_y]``. The sequence is
            never modified; device moves are done on local references.
        orders (List[int]): Orders (degree + 1) in x and y direction.
        ns (List[int]): Number of control points ``[n_x, n_y]``.
        x_range (tuple): ``(min, max)`` of the physical x interval mapped to [0, 1].
        y_range (tuple): ``(min, max)`` of the physical y interval mapped to [0, 1].
        control_points (torch.Tensor): Control net of shape ``[n_x, n_y, ...]``
            or ``[n_x*n_y, ...]``.

    Returns:
        torch.Tensor: Shape ``[num_points]`` for a scalar control net (shape
        ``[n_x*n_y]`` or ``[n_x, n_y]``). Otherwise shape ``[num_points, d]``,
        with ``d`` the flattened trailing size.

    Raises:
        RuntimeError: If ``points`` is not of shape ``[num_points, 2]``.

    Example:
        >>> import torch
        >>> from diffinytrace.basis_functions import bspline
        >>> control_points = torch.randn((4, 4, 2))
        >>> U = torch.tensor([0., 0., 0., 0.5, 1., 1., 1.])  # clamped, order 3, 4 functions
        >>> points = torch.rand((100, 2))
        >>> surface = bspline.surface_2D(points, [U, U], [3, 3], [4, 4], (0.0, 1.0), (0.0, 1.0), control_points)
        >>> surface.shape
        torch.Size([100, 2])
    """
    if len(points.shape) != 2 or points.shape[1] != 2:
        raise RuntimeError("The points must be in local coordinates and of shape [#points,2]")
    device = points.device
    # Local references only: do not write moved tensors back into ``Us``.
    U_x = Us[0].to(device)
    U_y = Us[1].to(device)

    num_points = points.shape[0]
    # Dense basis values: [num_points, n1] and [num_points, n2].
    Ns1 = basis_1D(points[:, 0], U_x, orders[0], ns[0], x_range)
    Ns2 = basis_1D(points[:, 1], U_y, orders[1], ns[1], y_range)

    # A scalar control net ([n_x*n_y] or [n_x, n_y]) yields a 1D result.
    if len(control_points.shape) == 1:
        single_valued = control_points.shape[0] == ns[0] * ns[1]
    elif len(control_points.shape) == 2:
        single_valued = control_points.shape[0] == ns[0] and control_points.shape[1] == ns[1]
    else:
        single_valued = False
    control_points = control_points.reshape(ns[0], ns[1], -1)

    # Windows of basis indices per point. The previous modulo wrap-around
    # gathered index 0 twice for single-span (Bezier) knot vectors.
    idx1 = _bspline_support_window(Ns1, orders[0])
    idx2 = _bspline_support_window(Ns2, orders[1])

    basis_values_1D_1 = Ns1.gather(1, idx1)
    basis_values_1D_2 = Ns2.gather(1, idx2)
    # Per-point outer product of the windowed basis values, flattened.
    N2D = (basis_values_1D_1[:, :, None] * basis_values_1D_2[:, None, :]).reshape(num_points, -1)

    # Matching flat indices into the [n_x*n_y, d] control net.
    idx_flat = (idx1[:, :, None] * control_points.shape[1] + idx2[:, None, :]).reshape(num_points, -1)
    control_point_subset = control_points.reshape(ns[0] * ns[1], -1)[idx_flat]
    surface_points = torch.einsum('ik,ikl->il', N2D, control_point_subset)
    if single_valued:
        surface_points = surface_points[:, 0]
    return surface_points
396
+
397
+
398
+
399
def insert_knot_1D_single(U: torch.Tensor,
                          korder: int,
                          new_knot: torch.Tensor,
                          control_points: torch.Tensor,
                          dim: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Insert one knot into a 1D B-spline (Boehm's algorithm) without changing its shape.

    Args:
        U (torch.Tensor): Original knot vector (a sequence is converted to a tensor).
        korder (int): Order of the B-spline (degree + 1).
        new_knot (torch.Tensor or float): Knot value to insert.
        control_points (torch.Tensor): Control points; axis ``dim`` indexes them.
        dim (int, optional): Axis of ``control_points`` holding the control
            point index (default: 0).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(U_new, control_points_new)``. Both
        are one entry longer along the knot/control axis, and both use the
        dtype and device of ``control_points``.

    Example:
        >>> import torch
        >>> from diffinytrace.basis_functions import bspline
        >>> U = torch.tensor([0., 0., 1., 1.])           # linear (order 2), clamped
        >>> P = torch.tensor([[0.], [1.]])
        >>> bspline.insert_knot_1D_single(U, 2, torch.tensor(0.5), P)
        (tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), tensor([[0.0000], [0.5000], [1.0000]]))
    """
    target_device = control_points.device
    target_dtype = control_points.dtype
    swapped = dim != 0
    net = control_points.transpose(0, dim) if swapped else control_points
    knots = U if torch.is_tensor(U) else torch.tensor(U)

    # Number of knots strictly below the new one == its slot in the new vector.
    slot = (knots < new_knot).int().sum()

    refined_knots = torch.zeros((knots.shape[0] + 1), device=target_device, dtype=target_dtype)
    refined_knots[:slot] = knots[:slot]
    refined_knots[slot] = new_knot
    refined_knots[slot + 1:] = knots[slot:]

    refined_shape = [net.shape[0] + 1, *net.shape[1:]]
    refined_net = torch.zeros((refined_shape), device=target_device, dtype=target_dtype)

    # Control points korder..1 positions before the slot are blended. The
    # first of them gets alpha == 1 (its denominator equals its numerator),
    # i.e. it is copied unchanged.
    affected = torch.arange(slot - korder, slot)
    head = affected[0]
    tail = affected[-1]
    if head > 0:
        refined_net[:head] = net[:head]
    alpha = (new_knot - refined_knots[affected]) / (refined_knots[affected + korder] - refined_knots[affected])
    alpha = alpha.reshape(-1, *[1] * (net.dim() - 1))
    refined_net[affected] = alpha * net[affected] + (1.0 - alpha) * net[affected - 1]
    refined_net[tail + 1:] = net[tail:]

    if swapped:
        refined_net = refined_net.transpose(0, dim)
    return refined_knots, refined_net
465
+
466
+
467
def insert_knots_1D(U: torch.Tensor,
                    korder: int,
                    new_knot_list: List[float],
                    control_points: torch.Tensor,
                    dim: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Insert several knots one after another into a 1D B-spline.

    Args:
        U (torch.Tensor): Original knot vector.
        korder (int): Order of the B-spline (degree + 1).
        new_knot_list (Iterable): Knot values (list or 1D tensor) to insert in order.
        control_points (torch.Tensor): Control points; axis ``dim`` indexes them.
        dim (int, optional): Axis of ``control_points`` holding the control
            point index (default: 0).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(U_new, control_points_new)``.
        These are the inputs themselves when ``new_knot_list`` is empty.
    """
    knots, net = U, control_points
    for knot in new_knot_list:
        # Each insertion works on the result of the previous one.
        knots, net = insert_knot_1D_single(knots, korder, knot, net, dim=dim)
    return knots, net
488
+
489
+
490
def refine_2D(Us: List[torch.Tensor],
              orders: List[int],
              coeff: Optional[torch.Tensor] = None) -> Union[Tuple[List[torch.Tensor], torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Refine 2D B-spline knot vectors by inserting the midpoint of every knot span.

    The midpoints are taken between consecutive distinct knots, so
    non-uniform knot vectors are refined correctly too.

    Args:
        Us (list[torch.Tensor]): Knot vectors ``[U1, U2]`` for x and y direction.
        orders (list[int]): Orders ``[order_x, order_y]`` (degree + 1).
        coeff (torch.Tensor, optional): Control-point tensor of shape
            ``[n_x, n_y, ...]``. It is detached and then updated by knot
            insertion so the surface is unchanged.

    Returns:
        Tuple[list[torch.Tensor], torch.Tensor] or Tuple[torch.Tensor, torch.Tensor]:
        If ``coeff`` is given, ``([U1_new, U2_new], coeff_new)``. Otherwise
        ``(U1_new, U2_new)``, the refined knot vectors (previously the
        unrefined inputs were returned in this case).
    """
    def _span_midpoints(U):
        # Midpoints between consecutive distinct knots (empty if only one).
        distinct = torch.unique(U)
        return 0.5 * (distinct[1:] + distinct[:-1])

    U1, U2 = Us
    new_knots_U1 = _span_midpoints(U1)
    new_knots_U2 = _span_midpoints(U2)

    if coeff is not None:
        coeff = coeff.detach()
        U1, coeff = insert_knots_1D(U1, orders[0], new_knots_U1, coeff, dim=0)
        U2, coeff = insert_knots_1D(U2, orders[1], new_knots_U2, coeff, dim=1)
        return [U1, U2], coeff

    # Without coefficients only the knot vectors change. Inserting knots that
    # differ from all existing ones is equivalent to merging and sorting.
    U1 = torch.sort(torch.cat([U1, new_knots_U1])).values
    U2 = torch.sort(torch.cat([U2, new_knots_U2])).values
    return U1, U2
@@ -0,0 +1,3 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+