diffinytrace 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffinytrace/__init__.py +122 -0
- diffinytrace/basis_functions/__init__.py +14 -0
- diffinytrace/basis_functions/bspline.py +521 -0
- diffinytrace/basis_functions/chebyshev.py +3 -0
- diffinytrace/basis_functions/legendre.py +77 -0
- diffinytrace/basis_functions/zernike.py +235 -0
- diffinytrace/config.py +140 -0
- diffinytrace/constraints.py +54 -0
- diffinytrace/element.py +1660 -0
- diffinytrace/export/__init__.py +8 -0
- diffinytrace/export/cad.py +253 -0
- diffinytrace/gaussian_smoother.py +530 -0
- diffinytrace/hat_smoother.py +44 -0
- diffinytrace/integrators.py +452 -0
- diffinytrace/intersection.py +285 -0
- diffinytrace/optimize.py +808 -0
- diffinytrace/physical_object.py +150 -0
- diffinytrace/plotting/__init__.py +16 -0
- diffinytrace/plotting/core.py +92 -0
- diffinytrace/plotting/quantity2D.py +188 -0
- diffinytrace/plotting/system2D.py +220 -0
- diffinytrace/plotting/system3D.py +327 -0
- diffinytrace/plotting/wavelength.py +231 -0
- diffinytrace/refractive_index.py +101 -0
- diffinytrace/render.py +77 -0
- diffinytrace/source.py +661 -0
- diffinytrace/spectrum.py +79 -0
- diffinytrace/surface.py +468 -0
- diffinytrace/target_grid.py +399 -0
- diffinytrace/transforms.py +472 -0
- diffinytrace/utils/__init__.py +7 -0
- diffinytrace/utils/autograd.py +116 -0
- diffinytrace/utils/irradiance_importer.py +134 -0
- diffinytrace-2.1.dist-info/METADATA +26 -0
- diffinytrace-2.1.dist-info/RECORD +38 -0
- diffinytrace-2.1.dist-info/WHEEL +5 -0
- diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
- diffinytrace-2.1.dist-info/top_level.txt +1 -0
diffinytrace/__init__.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
This module provides a collection of functions and classes for optical system design and analysis.
|
|
7
|
+
It includes modules for ray tracing, surface definitions, optimization, and more.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
__all__ = [
    # --- submodules ---
    "source", "transforms", "target_grid", "utils", "plotting",
    "basis_functions", "gaussian_smoother", "optimize", "integrators",
    "export", "render", "constraints", "spectrum",
    # --- re-exported from .intersection ---
    "cat_semi_functionals", "get_functional_param_args",
    "construct_surface_and_normal_func",
    "construct_surface_and_normal_func_with_params",
    "CustomAutogradRule_t", "get_ray_intersection_length",
    # --- re-exported from .surface ---
    "Plane", "Aspheric", "Bspline", "Legendre", "bspline_n_after_refinement",
    # --- re-exported from .element ---
    "OpticalSystem", "SequentialOpticalSystem", "OpticalElement",
    "OpticalSurface", "LensSurfaceTransmissionEnter",
    "LensSurfaceTransmissionLeave", "Lens", "Mirror", "Detector",
    "trace_to_detector", "get_unused_params_mask",
    "set_used_params_bounds_to_constant", "set_unused_params_to_zero",
    "set_unused_bspline_coeff_to_nearest", "FresnelVirtualLens",
    # --- re-exported from .config ---
    "set_tolerance", "get_tolerance", "set_max_iterations",
    "get_max_iterations", "restore_default_settings",
    "get_damping_factor", "set_damping_factor",
    "get_show_iteration_count", "set_show_iteration_count",
    # --- re-exported from .optimize ---
    "minimize", "make_parameter_from_input",
    # --- re-exported from .refractive_index ---
    "materials", "RefractiveIndex",
    # --- re-exported from .utils.autograd ---
    "grad",
]
|
|
80
|
+
|
|
81
|
+
import os

# NOTE(review): KMP_DUPLICATE_LIB_OK is read by Intel's OpenMP runtime; it is
# presumably set here to avoid the abort when two OpenMP runtimes get loaded
# into one process - confirm. It must be set before torch is imported.
# setdefault() keeps an explicit user setting instead of overwriting it.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
import torch

# The library works in double precision. This changes torch's default dtype
# for the whole process, not only for diffinytrace.
torch.set_default_dtype(torch.float64)
|
|
85
|
+
|
|
86
|
+
from . import source
|
|
87
|
+
from . import transforms
|
|
88
|
+
from . import target_grid
|
|
89
|
+
from . import gaussian_smoother
|
|
90
|
+
|
|
91
|
+
from . import utils
|
|
92
|
+
#from . import refractive_index
|
|
93
|
+
from . import plotting
|
|
94
|
+
from . import basis_functions
|
|
95
|
+
from . import optimize
|
|
96
|
+
from . import integrators
|
|
97
|
+
from . import export
|
|
98
|
+
from . import render
|
|
99
|
+
from . import constraints
|
|
100
|
+
from . import spectrum
|
|
101
|
+
|
|
102
|
+
from .intersection import cat_semi_functionals,get_functional_param_args,\
|
|
103
|
+
construct_surface_and_normal_func,construct_surface_and_normal_func_with_params,\
|
|
104
|
+
CustomAutogradRule_t,get_ray_intersection_length
|
|
105
|
+
|
|
106
|
+
from .surface import Plane,Aspheric,Bspline,Legendre,bspline_n_after_refinement
|
|
107
|
+
|
|
108
|
+
from .element import OpticalSystem,SequentialOpticalSystem,OpticalElement,OpticalSurface,\
|
|
109
|
+
LensSurfaceTransmissionEnter,LensSurfaceTransmissionLeave,Lens,Mirror,Detector,trace_to_detector,\
|
|
110
|
+
get_unused_params_mask,set_used_params_bounds_to_constant,\
|
|
111
|
+
set_unused_params_to_zero,set_unused_bspline_coeff_to_nearest,\
|
|
112
|
+
FresnelVirtualLens
|
|
113
|
+
|
|
114
|
+
from .config import set_tolerance,get_tolerance,set_max_iterations,\
|
|
115
|
+
get_max_iterations,restore_default_settings,get_damping_factor,set_damping_factor,\
|
|
116
|
+
get_show_iteration_count,set_show_iteration_count
|
|
117
|
+
from .optimize import minimize,make_parameter_from_input
|
|
118
|
+
|
|
119
|
+
from .refractive_index import materials
|
|
120
|
+
from .refractive_index import RefractiveIndex
|
|
121
|
+
from .utils.autograd import grad
|
|
122
|
+
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
# Submodules of ``diffinytrace.basis_functions`` exposed by ``import *``.
__all__ = ["bspline", "legendre", "zernike", "chebyshev"]
|
|
10
|
+
|
|
11
|
+
from . import bspline
|
|
12
|
+
from . import legendre
|
|
13
|
+
from . import zernike
|
|
14
|
+
from . import chebyshev
|
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
B-spline surfaces for freeform geometry.
|
|
3
|
+
|
|
4
|
+
B-splines are popular for describing freeform surfaces because they allow
|
|
5
|
+
local changes to the geometry :cite:`THBray`. Their smoothness is controlled
|
|
6
|
+
by the spline degrees, which determine continuity and differentiability
|
|
7
|
+
:cite:`IGA`. A tensor-product B-spline surface is defined by two *knot vectors*,
|
|
8
|
+
a grid of univariate B-spline *basis functions*, and a bi-directional net of
|
|
9
|
+
*control points* :cite:`nurbs`. Below, we summarize the main components.
|
|
10
|
+
|
|
11
|
+
Notes:
|
|
12
|
+
**Knot vectors.**
|
|
13
|
+
A surface uses two (typically clamped) nondecreasing knot vectors
|
|
14
|
+
:math:`U` and :math:`V`:
|
|
15
|
+
|
|
16
|
+
.. math::
|
|
17
|
+
|
|
18
|
+
U = \{\underbrace{0,\dots,0}_{p+1},\, u_{p+1}, \dots, u_{n},\,
|
|
19
|
+
\underbrace{1,\dots,1}_{p+1}\}, \qquad
|
|
20
|
+
V = \{\underbrace{0,\dots,0}_{q+1},\, v_{q+1}, \dots, v_{m},\,
|
|
21
|
+
\underbrace{1,\dots,1}_{q+1}\}.
|
|
22
|
+
|
|
23
|
+
Here :math:`p` and :math:`q` are the degrees in the :math:`u`- and
|
|
24
|
+
:math:`v`-directions. A knot vector :math:`U=\{u_0,\dots,u_M\}` is a
|
|
25
|
+
nondecreasing sequence, i.e., :math:`u_i \le u_{i+1}`; each element is a
|
|
26
|
+
*knot*.
|
|
27
|
+
|
|
28
|
+
**Univariate B-spline basis (Cox–de Boor).**
|
|
29
|
+
In the :math:`u`-direction (analogously for :math:`v`), the basis
|
|
30
|
+
:math:`\{N_{i,p}\}` is defined recursively :cite:`nurbs`:
|
|
31
|
+
|
|
32
|
+
.. math::
|
|
33
|
+
|
|
34
|
+
N_{i,0}(u) =
|
|
35
|
+
\begin{cases}
|
|
36
|
+
1, & u_i \le u < u_{i+1},\\
|
|
37
|
+
0, & \text{otherwise},
|
|
38
|
+
\end{cases}
|
|
39
|
+
|
|
40
|
+
.. math::
|
|
41
|
+
|
|
42
|
+
N_{i,p}(u) =
|
|
43
|
+
\frac{u - u_i}{u_{i+p} - u_i}\, N_{i,p-1}(u) \;+\;
|
|
44
|
+
\frac{u_{i+p+1} - u}{u_{i+p+1} - u_{i+1}}\, N_{i+1,p-1}(u).
|
|
45
|
+
|
|
46
|
+
Here, as usual, a term whose denominator vanishes (which happens for repeated knots) is dropped. In the :math:`v`-direction, the basis :math:`\{M_{j,q}\}` is
|
|
47
|
+
|
|
48
|
+
.. math::
|
|
49
|
+
|
|
50
|
+
M_{j,0}(v) =
|
|
51
|
+
\begin{cases}
|
|
52
|
+
1, & v_j \le v < v_{j+1},\\
|
|
53
|
+
0, & \text{otherwise},
|
|
54
|
+
\end{cases}
|
|
55
|
+
|
|
56
|
+
.. math::
|
|
57
|
+
|
|
58
|
+
M_{j,q}(v) =
|
|
59
|
+
\frac{v - v_j}{v_{j+q} - v_j}\, M_{j,q-1}(v) \;+\;
|
|
60
|
+
\frac{v_{j+q+1} - v}{v_{j+q+1} - v_{j+1}}\, M_{j+1,q-1}(v).
|
|
61
|
+
|
|
62
|
+
**Control points.**
|
|
63
|
+
Control points :math:`\mathbf{P}_{i,j}` link the basis to geometry.
|
|
64
|
+
They can be scalars, 2D, or 3D vectors.
|
|
65
|
+
|
|
66
|
+
**Surface definition.**
|
|
67
|
+
The tensor-product B-spline surface is
|
|
68
|
+
|
|
69
|
+
.. math::
|
|
70
|
+
:label: eq-bspline-Z
|
|
71
|
+
|
|
72
|
+
Z(u,v) = \sum_{i=0}^{n} \sum_{j=0}^{m}
|
|
73
|
+
N_{i,p}(u)\, M_{j,q}(v)\, \mathbf{P}_{i,j}.
|
|
74
|
+
|
|
75
|
+
**Implementation details (this library).**
|
|
76
|
+
We use scalar control points :math:`\mathbf{P}_{i,j}` (height field),
|
|
77
|
+
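uniformly spaced, clamped knot vectors, and :math:`u,v \in [0,1]`. Note that the functions in this module are parameterized by the spline *order* :math:`k = p + 1` (degree plus one) rather than by the degree.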
|
|
78
|
+
To couple an explicit surface to the ray tracer, we map physical
|
|
79
|
+
coordinates :math:`\hat{x}_1,\hat{x}_2` to the parametric domain via a
|
|
80
|
+
scale :math:`h`:
|
|
81
|
+
|
|
82
|
+
.. math::
|
|
83
|
+
|
|
84
|
+
S(\hat{x}_1,\hat{x}_2) =
|
|
85
|
+
Z\!\left(\frac{\hat{x}_1}{h},\, \frac{\hat{x}_2}{h}\right).
|
|
86
|
+
|
|
87
|
+
.. figure:: _static/bspline_plot1.png
|
|
88
|
+
:alt: Freeform lens with B-spline surface
|
|
89
|
+
:width: 60%
|
|
90
|
+
:align: center
|
|
91
|
+
|
|
92
|
+
Visualization of a freeform lens with a B-spline surface.
|
|
93
|
+
|
|
94
|
+
Examples:
|
|
95
|
+
Define a lens with a B-spline surface and plot it:
|
|
96
|
+
|
|
97
|
+
.. code-block:: python
|
|
98
|
+
|
|
99
|
+
import torch
|
|
100
|
+
import diffinytrace as dit
|
|
101
|
+
|
|
102
|
+
aperture_half = 30.0
|
|
103
|
+
aperture_radius = aperture_half
|
|
104
|
+
lens_thickness = 8.0
|
|
105
|
+
material = dit.materials["NBK7"]
|
|
106
|
+
transform = dit.transforms.Identity()
|
|
107
|
+
|
|
108
|
+
# degree [p, q] and control net size [n_u, n_v] (example values)
|
|
109
|
+
bspline = dit.Bspline(aperture_half, [3, 3], [8, 8])
|
|
110
|
+
plane = dit.Plane()
|
|
111
|
+
|
|
112
|
+
with torch.no_grad():
|
|
113
|
+
bspline.coeff.data = torch.randn_like(bspline.coeff.data) * 3.0
|
|
114
|
+
|
|
115
|
+
lens = dit.Lens(transform, lens_thickness, bspline, plane,
|
|
116
|
+
material, aperture_radius)
|
|
117
|
+
|
|
118
|
+
dit.plotting.system3D.plot(lens, zticks=[0, 5])
|
|
119
|
+
|
|
120
|
+
"""
|
|
121
|
+
|
|
122
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
123
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
124
|
+
|
|
125
|
+
# Public API of this module. Every entry must name a function defined below:
# ``from diffinytrace.basis_functions.bspline import *`` raises AttributeError
# for any name that does not exist.
__all__ = [
    "cox_de_boor_recursion",
    "basis_1D",
    "basis_2D",
    "surface_2D",
    "insert_knot_1D_single",
    "insert_knots_1D",
    "refine_2D",
]
|
|
134
|
+
|
|
135
|
+
import torch
|
|
136
|
+
from typing import Tuple,List,Callable,Optional,Union
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _guarded_ratio(numerator, denominator):
    """Column-wise ``numerator / denominator`` where zero-denominator columns yield 1.

    Both inputs are modified in place: every column whose denominator is zero
    gets numerator and denominator set to 1 before dividing.
    """
    degenerate = (denominator == 0).reshape(-1)
    denominator[:, degenerate] = 1.
    numerator[:, degenerate] = 1.
    return numerator / denominator


def cox_de_boor_recursion(U: torch.Tensor, k: int, n: int, xis: torch.Tensor, k_curr: int) -> torch.Tensor:
    r"""Cox-de Boor recursion for B-spline basis functions.

    Builds the piecewise-constant (level 0) basis from the knot vector and then
    applies the Cox-de Boor step ``k_curr`` times, one level per step.

    Args:
        U (torch.Tensor): Knot vector (1-D). It is moved to ``xis.device``.
        k (int): Order of the B-spline. Not used by the computation.
        n (int): Number of control points. Not used by the computation.
        xis (torch.Tensor): Evaluation points (expected to be 1-D).
        k_curr (int): Recursion level, i.e. the degree of the returned basis.

    Returns:
        torch.Tensor: Basis values of shape ``(len(xis), len(U) - 1 - k_curr)``.

    Raises:
        RuntimeError: If ``k_curr < 0``.

    Notes:
        A point equal to the last knot switches on the last level-0 function,
        so the closed right end of the parameter range is covered. Ratios with
        a zero knot difference (repeated knots) are evaluated as exactly 1.
        The basis function they multiply vanishes there, except along that
        end-point chain. This lets the end-point value reach the last basis
        function for clamped knot vectors.
    """
    # TODO: `k` and `n` are unused and only kept for interface compatibility.
    knots = U.to(xis.device)
    if k_curr < 0:
        raise RuntimeError("cox_de_boor_recursion wrong input. k_curr < 0")

    x_col = xis[None].T  # column vector, broadcasts against the knot rows
    below_next = (x_col < knots[1:]).to(xis.device, dtype=xis.dtype)
    at_or_above = (x_col >= knots[:-1]).to(xis.device, dtype=xis.dtype)
    basis = below_next * at_or_above
    # Half-open intervals would exclude the last knot; include it explicitly.
    basis[xis == knots[knots.shape[0] - 1], -1] = 1.0

    for level in range(1, k_curr + 1):
        lower = basis[:, :-1]
        upper = basis[:, 1:]
        left = _guarded_ratio(x_col - knots[:-level - 1],
                              (knots[level:-1] - knots[:-level - 1]).reshape(1, -1))
        right = _guarded_ratio(-(x_col - knots[level + 1:]),
                               (knots[level + 1:] - knots[1:-level]).reshape(1, -1))
        basis = left * lower + right * upper

    return basis
|
|
182
|
+
|
|
183
|
+
def basis_1D(points:torch.Tensor,
             U:torch.Tensor,
             k:int,
             n:int,
             val_range:tuple[float,float])->torch.Tensor:
    """
    Evaluate all 1D B-spline basis functions of order ``k`` at ``points``.

    The points are first mapped linearly from ``val_range`` onto ``[0, 1]``.
    Then the Cox-de Boor recursion is run up to level ``k - 1`` (the degree).

    Args:
        points (torch.Tensor): Evaluation points in the coordinates of ``val_range``.
        U (torch.Tensor): Knot vector; its first entry must be 0.0 and its last 1.0.
        k (int): Order of the B-spline (degree + 1).
        n (int): Number of control points (passed through, not used for the result).
        val_range (tuple[float, float]): Interval that is mapped onto ``[0, 1]``.

    Returns:
        torch.Tensor: Basis values, shape ``(len(points), len(U) - k)``.

    Raises:
        RuntimeError: If the knot vector does not start at 0.0 or end at 1.0.

    Example:
        >>> import torch
        >>> import matplotlib.pyplot as plt
        >>> from diffinytrace.basis_functions import bspline
        >>> U = torch.tensor([0., 0.2, 0.4, 0.6, 0.8, 1])
        >>> k, n = 3, 3  # order 3 (quadratic), len(U) - k = 3 basis functions
        >>> xis = torch.linspace(0, 1, 100)
        >>> xN = bspline.basis_1D(xis, U, k, n, [0., 1.])
        >>> for column in xN.T:
        ...     plt.plot(xis, column)
        >>> plt.gca().set_aspect('equal')
    """
    # Validate the knot vector before touching the range.
    if not (U[0] == 0.0) or not (U[-1] == 1.0):
        raise RuntimeError("Knots should always between 0.0 and 1.0 and also contain these values!")
    lower, upper = val_range[0], val_range[1]
    unit_points = (points - lower) / (upper - lower)
    degree = k - 1
    return cox_de_boor_recursion(U, k, n, unit_points, degree)
|
|
226
|
+
|
|
227
|
+
def basis_2D(points:torch.Tensor,
             Us:List[torch.Tensor],
             orders:List[int],
             ns:List[int],
             x_range:tuple,
             y_range:tuple) -> torch.Tensor:
    """Evaluate the tensor-product 2D B-spline basis at the given points.

    Args:
        points (torch.Tensor): Evaluation points, shape ``[#points, 2]``.
        Us (list[torch.Tensor]): Knot vectors ``[U_x, U_y]``. NOTE: entries not
            on ``points.device`` are replaced in this list by moved copies.
        orders (list[int]): B-spline orders in x and y.
        ns (list[int]): Number of control points in x and y.
        x_range (tuple): Interval of x mapped onto the parameter range [0, 1].
        y_range (tuple): Interval of y mapped onto the parameter range [0, 1].

    Returns:
        torch.Tensor: ``N2D`` with ``N2D[p, i, j] = Nx_i(x_p) * Ny_j(y_p)``.

    Raises:
        RuntimeError: If ``points`` is not of shape ``[#points, 2]``.

    Example:
        >>> import torch
        >>> import diffinytrace as dit
        >>> from diffinytrace.basis_functions.bspline import basis_2D
        >>> U1 = torch.tensor([0., 0.2, 0.4, 0.6, 0.8, 1])
        >>> Us, ps, ns = [U1, U1], [3, 3], [3, 3]
        >>> side_points = 100
        >>> _x = torch.linspace(0, 1, side_points)
        >>> _y = torch.linspace(0, 1, side_points)
        >>> grid_y, grid_x = torch.meshgrid(_y, _x, indexing='ij')
        >>> points = torch.cat([grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)], dim=-1)
        >>> N2D = basis_2D(points, Us, ps, ns, torch.tensor([0, 1]), torch.tensor([0, 1]))
        >>> xi, yi = 0, 2
        >>> dit.plotting.quantity2D.plot(
        ...     N2D[:, yi, xi].reshape(side_points, side_points),
        ...     "basis fun",
        ...     [0, 1],
        ...     [0, 1],
        ...     xlabel="x",
        ...     ylabel="y",
        ... )
    """
    if points.dim() != 2 or points.shape[1] != 2:
        raise RuntimeError("The points must be in local coordinates and of shape [#points,2]")
    target = points.device
    for axis in (0, 1):
        if Us[axis].device != target:
            Us[axis] = Us[axis].to(target)

    # TODO: move evaluation into a shared core abstraction.
    basis_x = basis_1D(points[:, 0], Us[0], orders[0], ns[0], x_range)
    basis_y = basis_1D(points[:, 1], Us[1], orders[1], ns[1], y_range)
    # Per-point outer product of the two 1D bases.
    return basis_x[:, :, None] * basis_y[:, None, :]
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def surface_2D(points: torch.Tensor, Us: List[torch.Tensor], orders: List[int], ns: List[int], x_range: tuple, y_range: tuple, control_points: torch.Tensor) -> torch.Tensor:
    """
    Evaluate a tensor-product B-spline surface at the given points.

    Only a window of ``orders[d] + 1`` consecutive basis functions per
    direction is combined for each point. The window starts at
    ``floor(u / du)``, where ``du`` is the distance between the first two
    distinct knots, and indices wrap modulo the number of basis functions.
    This presumes uniformly spaced, clamped knot vectors.

    Args:
        points (torch.Tensor): Evaluation points, shape ``[num_points, 2]``.
        Us (List[torch.Tensor]): Knot vectors ``[U_x, U_y]``. NOTE: entries not
            on ``points.device`` are replaced in this list by moved copies.
        orders (List[int]): B-spline orders ``[order_x, order_y]``.
        ns (List[int]): Number of control points ``[n_x, n_y]``.
        x_range (tuple): ``(min, max)`` of x, mapped onto [0, 1].
        y_range (tuple): ``(min, max)`` of y, mapped onto [0, 1].
        control_points (torch.Tensor): Shape ``[n_x, n_y, ...]`` or ``[n_x*n_y, ...]``.

    Returns:
        torch.Tensor: Surface values. Shape is ``[num_points]`` when
        ``control_points`` is ``[n_x*n_y]`` or ``[n_x, n_y]``, otherwise
        ``[num_points, D]`` with ``D`` the flattened trailing size.

    Raises:
        RuntimeError: If ``points`` is not of shape ``[num_points, 2]``.

    Example:
        >>> import torch
        >>> from diffinytrace.basis_functions import bspline
        >>> n_x, n_y = 4, 4
        >>> k_x, k_y = 3, 3  # orders (quadratic)
        >>> control_points = torch.randn((n_x, n_y, 2))
        >>> # clamped, uniformly spaced knot vectors of length n + k
        >>> U_x = torch.cat([torch.zeros(k_x - 1), torch.linspace(0, 1, n_x - k_x + 2), torch.ones(k_x - 1)])
        >>> U_y = torch.cat([torch.zeros(k_y - 1), torch.linspace(0, 1, n_y - k_y + 2), torch.ones(k_y - 1)])
        >>> points = torch.rand((100, 2))
        >>> surface = bspline.surface_2D(points, [U_x, U_y], [k_x, k_y], [n_x, n_y], (0.0, 1.0), (0.0, 1.0), control_points)
    """
    if points.dim() != 2 or points.shape[1] != 2:
        raise RuntimeError("The points must be in local coordinates and of shape [#points,2]")
    target = points.device
    for axis in (0, 1):
        if Us[axis].device != target:
            Us[axis] = Us[axis].to(target)

    n_pts = points.shape[0]
    basis_x = basis_1D(points[:, 0], Us[0], orders[0], ns[0], x_range)
    basis_y = basis_1D(points[:, 1], Us[1], orders[1], ns[1], y_range)

    # Knot spacing per direction, taken from the first two distinct knots.
    knots_x = torch.unique(Us[0])
    step_x = knots_x[1] - knots_x[0]
    knots_y = torch.unique(Us[1])
    step_y = knots_y[1] - knots_y[0]

    # Same normalisation to [0, 1] that basis_1D applies.
    unit_x = (points[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    unit_y = (points[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    cols_x = basis_x.shape[1]
    cols_y = basis_y.shape[1]
    first_x = (unit_x / step_x).floor().to(torch.int32) % cols_x
    first_y = (unit_y / step_y).floor().to(torch.int32) % cols_y

    # One value per control point (flat or 2D grid) -> return a 1D result.
    if control_points.dim() == 1:
        scalar_output = control_points.shape[0] == ns[0] * ns[1]
    elif control_points.dim() == 2:
        scalar_output = control_points.shape[0] == ns[0] and control_points.shape[1] == ns[1]
    else:
        scalar_output = False

    grid = control_points.reshape(ns[0], ns[1], -1)

    # Window of orders[d] + 1 consecutive basis indices per point, wrapped around.
    window_x = (torch.arange(orders[0] + 1).unsqueeze(0).to(first_x.device) + first_x.unsqueeze(1)) % cols_x
    window_y = (torch.arange(orders[1] + 1).unsqueeze(0).to(first_y.device) + first_y.unsqueeze(1)) % cols_y

    local_x = basis_x.gather(1, window_x)
    local_y = basis_y.gather(1, window_y)
    weights = (local_x[:, :, None] * local_y[:, None, :]).reshape(n_pts, -1)

    # Row-major flat index of each (x, y) window pair into the control grid.
    flat_index = (window_x.reshape(n_pts, -1, 1) * grid.shape[1]
                  + window_y.reshape(n_pts, 1, -1)).reshape(n_pts, -1)
    selected = grid.reshape(ns[0] * ns[1], -1)[flat_index]
    values = torch.einsum('ik,ikl->il', weights, selected)
    if scalar_output:
        return values[:, 0]
    return values
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def insert_knot_1D_single(U: torch.Tensor,
                          korder: int,
                          new_knot: torch.Tensor,
                          control_points: torch.Tensor,
                          dim: int=0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Insert one knot into a 1D B-spline and update the control points (Boehm).

    The new knot is placed before all existing knots that are not smaller than
    it. The ``korder`` control points in front of that position are replaced by
    convex combinations of neighbouring old control points. The first affected
    point always gets weight 1.

    Args:
        U (torch.Tensor): Original knot vector (a sequence is converted to a tensor).
        korder (int): Order of the B-spline (degree + 1).
        new_knot (torch.Tensor or float): Knot value to insert.
        control_points (torch.Tensor): Control points, ``n`` entries along ``dim``.
        dim (int, optional): Axis of ``control_points`` that runs over the
            control points (default: 0).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(U_new, control_points_new)``. Both
        use the device and dtype of ``control_points``, and ``control_points_new``
        has one more entry along ``dim``.

    Example:
        >>> import torch
        >>> from diffinytrace.basis_functions import bspline
        >>> k, n = 4, 4  # order 4 = cubic B-spline with 4 control points
        >>> control_points = torch.randn((n, 2))
        >>> U = torch.tensor([0., 0., 0., 0., 1., 1., 1., 1.])
        >>> U_new, new_control_points = bspline.insert_knot_1D_single(U, k, torch.tensor(0.3), control_points)
        >>> xis = torch.linspace(0, 1, 1000)
        >>> out1 = bspline.basis_1D(xis, U, k, n, [0., 1.]) @ control_points
        >>> out2 = bspline.basis_1D(xis, U_new, k, n + 1, [0., 1.]) @ new_control_points
        >>> torch.allclose(out1, out2, atol=1e-6)  # the curve is unchanged
        True
    """
    device, dtype = control_points.device, control_points.dtype
    pts = control_points if dim == 0 else control_points.transpose(0, dim)
    knots = U if torch.is_tensor(U) else torch.tensor(U)

    # Number of knots strictly below the new one = its position in U_new.
    pos = (knots < new_knot).int().sum()
    knots_new = torch.zeros((knots.shape[0] + 1), device=device, dtype=dtype)
    knots_new[:pos] = knots[:pos]
    knots_new[pos] = new_knot
    knots_new[pos + 1:] = knots[pos:]

    out_shape = list(pts.shape)
    out_shape[0] += 1
    pts_new = torch.zeros((out_shape), device=device, dtype=dtype)

    affected = torch.arange(pos - korder, pos)
    first, last = affected[0], affected[-1]
    if first > 0:
        pts_new[:first] = pts[:first]
    # U_new[i + korder] equals the original U[i + korder - 1] (the new knot for
    # the first affected i), so these are the usual Boehm blending weights.
    alpha = (new_knot - knots_new[affected]) / (knots_new[affected + korder] - knots_new[affected])
    alpha = alpha.reshape(-1, *[1] * (pts.dim() - 1))
    pts_new[affected] = alpha * pts[affected] + (1.0 - alpha) * pts[affected - 1]
    pts_new[last + 1:] = pts[last:]

    if dim != 0:
        pts_new = pts_new.transpose(0, dim)
    return knots_new, pts_new
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def insert_knots_1D(U:torch.Tensor,
                    korder:int,
                    new_knot_list:List[float],
                    control_points:torch.Tensor,
                    dim:int=0)->Tuple[torch.Tensor, torch.Tensor]:
    """
    Insert several knots one after another via ``insert_knot_1D_single``.

    Args:
        U (torch.Tensor): Original knot vector.
        korder (int): Order of the B-spline (degree + 1).
        new_knot_list (Iterable): Knot values to insert, in iteration order.
        control_points (torch.Tensor): Control points.
        dim (int, optional): Axis of ``control_points`` that runs over the
            control points (default: 0).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(U_new, control_points_new)``. For
        an empty ``new_knot_list`` the inputs are returned unchanged.
    """
    knots, pts = U, control_points
    for knot in new_knot_list:
        knots, pts = insert_knot_1D_single(knots, korder, knot, pts, dim=dim)
    return knots, pts
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def refine_2D(Us:List[torch.Tensor],
              orders:List[int],
              coeff:Optional[torch.Tensor]=None) -> Union[Tuple[List[torch.Tensor], torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Refine 2D B-spline knot vectors by inserting a knot between each pair of
    neighbouring distinct knots.

    The new knots are evenly spaced from ``min + dU/2`` to ``max - dU/2``, where
    ``dU`` is the distance between the first two distinct knots. For uniformly
    spaced knots these are the midpoints.

    Args:
        Us (list[torch.Tensor]): Knot vectors ``[U1, U2]`` for x and y.
        orders (list[int]): Orders ``[order_x, order_y]``.
        coeff (torch.Tensor, optional): Control-point grid (axis 0 = x, axis 1 = y).
            It is detached and re-fitted by knot insertion so the surface is unchanged.

    Returns:
        Tuple[list[torch.Tensor], torch.Tensor] or Tuple[torch.Tensor, torch.Tensor]:
            If ``coeff`` is given, ``([U1_new, U2_new], coeff_new)``.
            Otherwise ``(U1_new, U2_new)`` as a tuple, with the dtype/device of
            the input knot vectors.
    """
    U1, U2 = Us
    U1_unique = torch.unique(U1)
    dU1 = U1_unique[1] - U1_unique[0]
    new_knots_U1 = torch.linspace(torch.min(U1_unique) + dU1 * 0.5, torch.max(U1_unique) - dU1 * 0.5, U1_unique.shape[0] - 1)
    U2_unique = torch.unique(U2)
    dU2 = U2_unique[1] - U2_unique[0]
    new_knots_U2 = torch.linspace(torch.min(U2_unique) + dU2 * 0.5, torch.max(U2_unique) - dU2 * 0.5, U2_unique.shape[0] - 1)

    if coeff is not None:
        coeff = coeff.detach()
        U1, coeff = insert_knots_1D(U1, orders[0], new_knots_U1, coeff, dim=0)
        U2, coeff = insert_knots_1D(U2, orders[1], new_knots_U2, coeff, dim=1)
        return [U1, U2], coeff

    # No coefficients to re-fit: merging the new knots into the sorted knot
    # vectors gives the same knot values as the knot-insertion path above.
    # (Previously the unrefined input knot vectors were returned here.)
    U1_new = torch.sort(torch.cat([U1, new_knots_U1.to(U1)])).values
    U2_new = torch.sort(torch.cat([U2, new_knots_U2.to(U2)])).values
    return U1_new, U2_new
|