diffinytrace 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffinytrace/__init__.py +122 -0
- diffinytrace/basis_functions/__init__.py +14 -0
- diffinytrace/basis_functions/bspline.py +521 -0
- diffinytrace/basis_functions/chebyshev.py +3 -0
- diffinytrace/basis_functions/legendre.py +77 -0
- diffinytrace/basis_functions/zernike.py +235 -0
- diffinytrace/config.py +140 -0
- diffinytrace/constraints.py +54 -0
- diffinytrace/element.py +1660 -0
- diffinytrace/export/__init__.py +8 -0
- diffinytrace/export/cad.py +253 -0
- diffinytrace/gaussian_smoother.py +530 -0
- diffinytrace/hat_smoother.py +44 -0
- diffinytrace/integrators.py +452 -0
- diffinytrace/intersection.py +285 -0
- diffinytrace/optimize.py +808 -0
- diffinytrace/physical_object.py +150 -0
- diffinytrace/plotting/__init__.py +16 -0
- diffinytrace/plotting/core.py +92 -0
- diffinytrace/plotting/quantity2D.py +188 -0
- diffinytrace/plotting/system2D.py +220 -0
- diffinytrace/plotting/system3D.py +327 -0
- diffinytrace/plotting/wavelength.py +231 -0
- diffinytrace/refractive_index.py +101 -0
- diffinytrace/render.py +77 -0
- diffinytrace/source.py +661 -0
- diffinytrace/spectrum.py +79 -0
- diffinytrace/surface.py +468 -0
- diffinytrace/target_grid.py +399 -0
- diffinytrace/transforms.py +472 -0
- diffinytrace/utils/__init__.py +7 -0
- diffinytrace/utils/autograd.py +116 -0
- diffinytrace/utils/irradiance_importer.py +134 -0
- diffinytrace-2.1.dist-info/METADATA +26 -0
- diffinytrace-2.1.dist-info/RECORD +38 -0
- diffinytrace-2.1.dist-info/WHEEL +5 -0
- diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
- diffinytrace-2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
# Names exported by ``from ... import *``; each is defined in this module.
__all__ = [
    "precompute_legendre_polynomials",
    "basis_2D",
    "get_num_coeff",
]
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
def precompute_legendre_polynomials(x: torch.Tensor, degree: int) -> list[torch.Tensor]:
    """
    Evaluate the Legendre polynomials P_0 ... P_degree at ``x``.

    Uses Bonnet's three-term recurrence
    ``n * P_n(x) = (2n - 1) * x * P_{n-1}(x) - (n - 1) * P_{n-2}(x)``.

    Args:
        x (torch.Tensor): Points at which the polynomials are evaluated.
        degree (int): Highest polynomial degree to compute. Any value below 1
            yields only ``[P_0(x)]``.

    Returns:
        list of torch.Tensor: ``[P_0(x), P_1(x), ..., P_degree(x)]``, each shaped
        like ``x``. ``P_1(x)`` is the input tensor ``x`` itself (not a copy).
    """
    polys = [torch.ones_like(x)]  # P_0(x) = 1
    if degree < 1:
        return polys

    polys.append(x)  # P_1(x) = x
    for order in range(2, degree + 1):
        prev, prev2 = polys[order - 1], polys[order - 2]
        polys.append(((2 * order - 1) * x * prev - (order - 1) * prev2) / order)
    return polys
|
|
32
|
+
|
|
33
|
+
def basis_2D(points: torch.Tensor, degree: int) -> torch.Tensor:
    """
    Evaluate the 2D total-degree Legendre basis at a set of points.

    Each basis function is a product ``P_i(x) * P_j(y)`` of 1D Legendre
    polynomials with ``i + j <= degree``. Functions are ordered by total degree
    ``d = i + j`` (ascending) and, within one total degree, by ascending ``i``::

        P0(x)P0(y), P0(x)P1(y), P1(x)P0(y), P0(x)P2(y), P1(x)P1(y), P2(x)P0(y), ...

    Args:
        points (torch.Tensor): Tensor of shape (N, 2); column 0 holds the
            x-coordinates and column 1 the y-coordinates.
        degree (int): Maximum total degree ``i + j`` of the basis functions.

    Returns:
        torch.Tensor: Tensor of shape (N, (degree + 1) * (degree + 2) // 2) whose
        column k is the k-th basis function evaluated at every point.

    Raises:
        ValueError: If ``degree`` is negative.
    """
    if degree < 0:
        raise ValueError(f"degree must be non-negative, got {degree}.")

    x = points[:, 0]
    y = points[:, 1]

    # The 1D polynomials are computed once per axis and reused for every product.
    Px = precompute_legendre_polynomials(x, degree)
    Py = precompute_legendre_polynomials(y, degree)

    basis = [
        Px[i] * Py[total - i]
        for total in range(degree + 1)
        for i in range(total + 1)
    ]

    # Stack along dim=1 so each row corresponds to one input point.
    return torch.stack(basis, dim=1)
|
|
65
|
+
|
|
66
|
+
def get_num_coeff(degree: int) -> int:
    """
    Count the 2D Legendre basis functions with total degree at most ``degree``.

    This is the triangular number ``(degree + 1) * (degree + 2) / 2``, i.e. the
    number of index pairs (i, j) with i, j >= 0 and i + j <= degree.

    Args:
        degree (int): Maximum total degree.

    Returns:
        int: Number of coefficients.
    """
    rows = degree + 1
    return rows * (rows + 1) // 2
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Zernike polynomial basis functions for optical wavefront representation.
|
|
3
|
+
|
|
4
|
+
This module provides functions to compute Zernike polynomials, which are commonly used
|
|
5
|
+
in optics for describing wavefront aberrations over a circular aperture. The polynomials
|
|
6
|
+
are orthogonal over the unit disk and are indexed by radial order (n) and azimuthal
|
|
7
|
+
frequency (m).
|
|
8
|
+
|
|
9
|
+
.. figure:: _static/zernike_plot1.png
|
|
10
|
+
:alt: Zernike polynomials visualization
|
|
11
|
+
:width: 60%
|
|
12
|
+
:align: center
|
|
13
|
+
|
|
14
|
+
Visualization of Zernike polynomials organized by radial order (rows) and azimuthal frequency (columns).
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
Example:
|
|
18
|
+
Basic usage for computing and visualizing Zernike polynomials:
|
|
19
|
+
|
|
20
|
+
>>> import torch
|
|
21
|
+
>>> import numpy as np
|
|
22
|
+
>>> import matplotlib.pyplot as plt
|
|
23
|
+
>>> import diffinytrace.basis_functions.zernike as zernike
|
|
24
|
+
>>>
|
|
25
|
+
>>> # Create unit circle grid
|
|
26
|
+
>>> grid_size = 256
|
|
27
|
+
>>> x = torch.linspace(-1, 1, grid_size)
|
|
28
|
+
>>> y = torch.linspace(-1, 1, grid_size)
|
|
29
|
+
>>> X, Y = torch.meshgrid(x, y, indexing='ij')
|
|
30
|
+
>>>
|
|
31
|
+
>>> # Create mask for unit circle
|
|
32
|
+
>>> mask = (X**2 + Y**2) <= 1.0
|
|
33
|
+
>>> x_points = X[mask]
|
|
34
|
+
>>> y_points = Y[mask]
|
|
35
|
+
>>> points = torch.stack([x_points, y_points], dim=1)
|
|
36
|
+
>>>
|
|
37
|
+
>>> # Evaluate Zernike polynomials
|
|
38
|
+
>>> max_n = 6 # Maximum radial degree
|
|
39
|
+
>>> basis_values = zernike.basis_2D(points, max_n)
|
|
40
|
+
>>>
|
|
41
|
+
>>> # Group basis functions by radial degree
|
|
42
|
+
>>> basis_by_degree = {}
|
|
43
|
+
>>> for basis_idx in range(basis_values.shape[1]):
|
|
44
|
+
... radial_order = zernike.get_radial_order(basis_idx)
|
|
45
|
+
... if radial_order not in basis_by_degree:
|
|
46
|
+
... basis_by_degree[radial_order] = []
|
|
47
|
+
... basis_by_degree[radial_order].append(basis_idx)
|
|
48
|
+
>>>
|
|
49
|
+
>>> # Visualize the polynomials
|
|
50
|
+
>>> max_cols = max(len(indices) for indices in basis_by_degree.values())
|
|
51
|
+
>>> num_rows = len(basis_by_degree)
|
|
52
|
+
>>> fig, axes = plt.subplots(num_rows, max_cols, figsize=(3*max_cols, 3*num_rows))
|
|
53
|
+
>>>
|
|
54
|
+
>>> for row_idx, (radial_order, basis_indices) in enumerate(sorted(basis_by_degree.items())):
|
|
55
|
+
... for col_idx, basis_idx in enumerate(basis_indices):
|
|
56
|
+
... # Create 2D array with NaN outside unit circle
|
|
57
|
+
... tmp = torch.full((grid_size, grid_size), float('nan'))
|
|
58
|
+
... tmp[mask] = basis_values[:, basis_idx]
|
|
59
|
+
...
|
|
60
|
+
... # Plot
|
|
61
|
+
... ax = axes[row_idx, col_idx]
|
|
62
|
+
... im = ax.imshow(tmp.numpy(), extent=[-1, 1, -1, 1],
|
|
63
|
+
... origin='lower', cmap='jet', vmin=-1, vmax=1)
|
|
64
|
+
... azimuthal = zernike.get_azimuthal_frequency(basis_idx)
|
|
65
|
+
... ax.set_title(f"$Z^{{{azimuthal}}}_{{{radial_order}}}$", fontsize=25)
|
|
66
|
+
... ax.set_xticks([])
|
|
67
|
+
... ax.set_yticks([])
|
|
68
|
+
... ax.set_aspect('equal')
|
|
69
|
+
>>>
|
|
70
|
+
>>> plt.tight_layout()
|
|
71
|
+
>>> plt.show()
|
|
72
|
+
|
|
73
|
+
Notes:
|
|
74
|
+
- Zernike polynomials are only defined for points within the unit circle (r ≤ 1)
|
|
75
|
+
- Radial order n determines the number of radial variations
|
|
76
|
+
- Azimuthal frequency m determines the angular variations and symmetry
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
80
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
81
|
+
|
|
82
|
+
# Names exported by ``from ... import *``; every entry must be defined in this
# module, otherwise the star-import raises AttributeError.
__all__ = [
    "basis_2D",
    "get_num_basis",
    "get_radial_order",
    "get_azimuthal_frequency",
]
|
|
88
|
+
|
|
89
|
+
import torch
|
|
90
|
+
import math
|
|
91
|
+
|
|
92
|
+
def __zernike_calc(n:int, m:int, r_powers: list) -> torch.Tensor:
    """
    Evaluate the Zernike radial polynomial R_n^|m|(r) divided by r^|m|.

    Computes
    ``sum_k (-1)^k (n-k)! / (k! ((n+|m|)/2-k)! ((n-|m|)/2-k)!) * r^(n-2k-|m|)``
    for k = 0 .. (n-|m|)//2. The factor r^|m| is left to the caller.

    Args:
        n (int): Radial order.
        m (int): Azimuthal frequency; only ``abs(m)`` is used.
        r_powers (list): ``r_powers[p]`` must hold r**p for each even p that is
            needed (up to ``n - abs(m)``); odd entries are never read.

    Returns:
        torch.Tensor: The radial sum, shaped like ``r_powers[0]``.

    Raises:
        RuntimeError: If an odd power of r would be required (``n - abs(m)`` odd).
            A second guard against negative exponents is purely defensive.
    """
    m_abs = abs(m)
    half_sum = (n + m_abs) // 2
    half_diff = (n - m_abs) // 2
    result = torch.zeros_like(r_powers[0])

    for k in range(half_diff + 1):
        magnitude = math.factorial(n - k) / (
            math.factorial(k) * math.factorial(half_sum - k) * math.factorial(half_diff - k))
        coef = -magnitude if k % 2 == 1 else magnitude
        exponent = n - 2 * k - m_abs

        if exponent < 0:
            raise RuntimeError("Potential zero division!")
        # Only even powers are precomputed by the caller.
        if exponent % 2 == 1:
            raise RuntimeError("tried to acces odd power idx!")

        result += coef * r_powers[exponent]

    return result
|
|
113
|
+
|
|
114
|
+
def basis_2D(points: torch.Tensor, max_radial_order: int) -> torch.Tensor:
    """
    Compute Zernike polynomials for a given set of points.

    Basis functions are ordered by radial order n = 0 .. max_radial_order and,
    within each n, by azimuthal frequency m = -n, -n + 2, ..., n. The angular
    factor is Re((y + i x)^m) for m >= 0 and Im((y + i x)^|m|) for m < 0, so the
    polar angle is measured from the +y axis towards +x.

    Args:
        points (torch.Tensor): Tensor of shape (N, 2) containing the x and y
            coordinates of the points.
        max_radial_order (int): Maximum radial order (inclusive).

    Returns:
        torch.Tensor: Tensor of shape
        (N, (max_radial_order + 1) * (max_radial_order + 2) // 2) containing the
        Zernike polynomial values.

    Raises:
        ValueError: If ``max_radial_order`` is negative.
    """
    if max_radial_order < 0:
        raise ValueError(
            f"max_radial_order must be non-negative, got {max_radial_order}.")

    x = points[:, 0]
    y = points[:, 1]
    r2 = x**2 + y**2

    # n - |m| is always even, so only even powers of r are needed. They are
    # taken from r^2 directly (no square root); odd slots stay None.
    r_powers = [r2**(p / 2.0) if p % 2 == 0 else None
                for p in range(max_radial_order + 1)]

    # re_terms[k] + i * im_terms[k] == (y + i x)^k, built with the real recurrence
    # (c + i s)(y + i x) = (c y - s x) + i (c x + s y). This replaces one complex
    # tensor power per (n, m) pair with a single pass over k, and needs no
    # complex dtype.
    re_terms = [torch.ones_like(x)]
    im_terms = [torch.zeros_like(x)]
    for _ in range(max_radial_order):
        c, s = re_terms[-1], im_terms[-1]
        re_terms.append(c * y - s * x)
        im_terms.append(c * x + s * y)

    zernike_polynomials = []
    for n in range(max_radial_order + 1):
        for m in range(-n, n + 1, 2):  # m must have the same parity as n
            angular = re_terms[m] if m >= 0 else im_terms[-m]
            zernike_polynomials.append(angular * __zernike_calc(n, m, r_powers))

    # One column per basis function, one row per point.
    return torch.stack(zernike_polynomials, dim=1)
|
|
159
|
+
|
|
160
|
+
def get_num_basis(max_radial_order:int) -> int:
    """
    Calculate the number of basis functions for Zernike polynomials up to a given radial order.

    Radial order n contributes n + 1 functions (m = -n, -n + 2, ..., n), so the
    total is (max_radial_order + 1) * (max_radial_order + 2) / 2.

    Args:
        max_radial_order (int): Maximum radial order (inclusive).

    Returns:
        int: Number of basis functions.
    """
    n = max_radial_order + 1
    # n * (n + 1) is always even, so floor division is exact and avoids the
    # float round trip that loses precision for large orders.
    return n * (n + 1) // 2
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def get_radial_order(basis_idx:int) -> int:
    """
    Calculate the radial degree from the basis function index.

    Basis functions are laid out row by row: radial order n occupies the n + 1
    consecutive indices starting at n * (n + 1) / 2. The radial order is the
    largest n with n * (n + 1) / 2 <= basis_idx.

    Args:
        basis_idx (int): Zero-based index of the basis function.

    Returns:
        int: Radial degree.

    Raises:
        ValueError: If ``basis_idx`` is negative.
    """
    if basis_idx < 0:
        raise ValueError(f"basis_idx must be non-negative, got {basis_idx}.")
    idx = int(basis_idx)
    # Largest integer n with n * (n + 1) / 2 <= idx; isqrt keeps this exact
    # for arbitrarily large indices.
    return (math.isqrt(8 * idx + 1) - 1) // 2
|
|
198
|
+
|
|
199
|
+
def get_azimuthal_frequency(basis_idx:int) -> int:
    """
    Calculate the azimuthal frequency from the basis function index.

    Within radial order n the basis functions are ordered by
    m = -n, -n + 2, ..., n - 2, n, so the function at position ``p`` of its row
    (p = 0 .. n) has m = -n + 2 * p.

    Args:
        basis_idx (int): Zero-based index of the basis function.

    Returns:
        int: Azimuthal frequency (m value).

    Raises:
        ValueError: If ``basis_idx`` is negative.
    """
    if basis_idx < 0:
        raise ValueError(f"basis_idx must be non-negative, got {basis_idx}.")
    n = get_radial_order(basis_idx)
    # Row n starts after the n * (n + 1) / 2 functions of rows 0 .. n - 1.
    pos_in_row = basis_idx - n * (n + 1) // 2
    return -n + 2 * pos_in_row
|
diffinytrace/config.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration module for diffinytrace.
|
|
3
|
+
|
|
4
|
+
This module provides global configuration options for controlling ray intersection
|
|
5
|
+
solver behavior, such as tolerance, maximum iterations, damping factor, and whether
|
|
6
|
+
to display iteration counts. These settings can be adjusted at runtime to tune
|
|
7
|
+
performance and accuracy.
|
|
8
|
+
|
|
9
|
+
Example:
|
|
10
|
+
>>> import diffinytrace.config as config
|
|
11
|
+
>>> config.set_tolerance(1e-8)
|
|
12
|
+
>>> config.set_show_iteration_count(True)
|
|
13
|
+
>>> config.restore_default_settings()
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
17
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
18
|
+
|
|
19
|
+
# Public accessor API of the configuration module.
__all__ = [
    "set_show_iteration_count",
    "get_show_iteration_count",
    "set_tolerance",
    "get_tolerance",
    "set_max_iterations",
    "get_max_iterations",
    "set_damping_factor",
    "get_damping_factor",
    "restore_default_settings",
]
|
|
30
|
+
|
|
31
|
+
# Module-level solver settings (initial/default values).

# Convergence tolerance for the ray intersection solver.
tolerance: float = 1e-6
# Upper bound on solver iterations.
max_iterations: int = 100
# Newton step damping factor; 1.0 means undamped steps.
damping_factor: float = 1.0
# Whether the iteration count is shown for each intersection.
show_iteration_count: bool = False
|
|
36
|
+
|
|
37
|
+
def set_show_iteration_count(flag):
    """
    Turn reporting of the per-intersection iteration count on or off.

    Args:
        flag (bool): True to report the number of iterations, False to disable it.
    """
    # Rebinds the module-level setting.
    global show_iteration_count
    show_iteration_count = flag
|
|
46
|
+
|
|
47
|
+
def get_show_iteration_count():
    """
    Report whether per-intersection iteration counts are enabled.

    Returns:
        bool: Current value of the ``show_iteration_count`` setting.
    """
    # Reading a module global needs no ``global`` declaration.
    return show_iteration_count
|
|
55
|
+
|
|
56
|
+
def set_tolerance(new_tolerance):
    """
    Set the tolerance for ray intersection calculations.

    Args:
        new_tolerance (float): The new tolerance value (must be > 0).

    Raises:
        ValueError: If `new_tolerance` is not greater than 0 (NaN included).
    """
    global tolerance
    # ``not (x > 0)`` rather than ``x <= 0``: every comparison with NaN is
    # False, so the old form silently accepted NaN.
    if not (new_tolerance > 0):
        raise ValueError("Tolerance must be greater than 0.")
    tolerance = new_tolerance
|
|
70
|
+
|
|
71
|
+
def get_tolerance():
    """
    Return the tolerance currently used for ray intersection calculations.

    Returns:
        float: Current value of the ``tolerance`` setting.
    """
    # Reading a module global needs no ``global`` declaration.
    return tolerance
|
|
79
|
+
|
|
80
|
+
def set_max_iterations(new_max_iterations):
    """
    Set the maximum number of iterations for the ray intersection solver.

    Args:
        new_max_iterations (int): The new maximum number of iterations (must be > 0).

    Raises:
        ValueError: If `new_max_iterations` is not greater than 0 (NaN included).
    """
    global max_iterations
    # ``not (x > 0)`` rather than ``x <= 0`` so that NaN is rejected too.
    if not (new_max_iterations > 0):
        raise ValueError("Maximum iterations must be greater than 0.")
    max_iterations = new_max_iterations
|
|
94
|
+
|
|
95
|
+
def get_max_iterations():
    """
    Return the iteration cap of the ray intersection solver.

    Returns:
        int: Current value of the ``max_iterations`` setting.
    """
    # Read-only access; the ``global`` statement is only needed for rebinding.
    return max_iterations
|
|
104
|
+
|
|
105
|
+
def set_damping_factor(new_damping_factor):
    """
    Set the damping factor applied to Newton steps in ray intersections.

    Args:
        new_damping_factor (float): New damping factor, required to lie in (0, 1].

    Raises:
        ValueError: If `new_damping_factor` is outside (0, 1]; NaN is rejected
            as well because the chained comparison is False for it.
    """
    global damping_factor
    if not 0 < new_damping_factor <= 1:
        raise ValueError("Damping factor must be between 0 and 1.")
    damping_factor = new_damping_factor
|
|
120
|
+
|
|
121
|
+
def get_damping_factor():
    """
    Return the damping factor currently applied to Newton steps.

    Returns:
        float: Current value of the ``damping_factor`` setting.
    """
    # Reading a module global needs no ``global`` declaration.
    return damping_factor
|
|
129
|
+
|
|
130
|
+
def restore_default_settings():
    """
    Reset every ray-tracer configuration setting to its default value.

    Defaults: tolerance 1e-6, max_iterations 100, damping_factor 1.0 (undamped)
    and show_iteration_count False.
    """
    global tolerance, max_iterations, damping_factor, show_iteration_count
    tolerance, max_iterations, damping_factor, show_iteration_count = (
        1e-6, 100, 1.0, False)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
# Constraint classes exported by ``from ... import *``.
__all__ = [
    "Constraint",
    "EqualZero",
    "GEQZero",
    "LEQZero",
]
|
|
5
|
+
|
|
6
|
+
from .physical_object import PhysicalSurface
|
|
7
|
+
import torch
|
|
8
|
+
from . integrators import Cube
|
|
9
|
+
from .optimize import minimize
|
|
10
|
+
from .utils.autograd import grad
|
|
11
|
+
#from .element import OpticalSurface
|
|
12
|
+
|
|
13
|
+
class Constraint:
    """
    Base class for optimization constraints.

    Holds the constraint callable together with its kind; the subclasses in
    this module use the kinds 'eq' and 'ineq'.

    Attributes:
        fun (Callable): Function defining the constraint.
        type (str): Type of constraint ('eq' or 'ineq').
    """

    def __init__(self, fun, type):
        # ``type`` shadows the builtin inside this method, but the parameter
        # name is part of the public keyword interface and is kept.
        self.fun, self.type = fun, type
|
|
24
|
+
|
|
25
|
+
class EqualZero(Constraint):
    """
    Equality constraint requiring ``fun() == 0``.

    Args:
        fun (Callable): The constraint function, stored unchanged as ``self.fun``
            with kind 'eq'.
    """

    def __init__(self, fun):
        super().__init__(fun=fun, type='eq')
|
|
34
|
+
|
|
35
|
+
class GEQZero(Constraint):
    """
    Inequality constraint requiring ``fun() >= 0``.

    Args:
        fun (Callable): The constraint function, stored unchanged as ``self.fun``
            with kind 'ineq'.
    """

    def __init__(self, fun):
        super().__init__(fun=fun, type='ineq')
|
|
44
|
+
|
|
45
|
+
class LEQZero(Constraint):
    """
    Inequality constraint enforcing `fun() <= 0`.

    Stored as the 'ineq' constraint ``-fun(...) >= 0``.

    Args:
        fun (Callable): The constraint function.
    """
    def __init__(self, fun):
        # Forward all arguments so the negated callable keeps ``fun``'s call
        # signature. EqualZero and GEQZero store ``fun`` unchanged, so whatever
        # calling convention the optimizer uses must work here as well; the
        # previous ``lambda: -fun()`` broke for any ``fun`` taking arguments.
        super().__init__(lambda *args, **kwargs: -fun(*args, **kwargs), 'ineq')
|
|
54
|
+
|