diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "precompute_legendre_polynomials",
6
+ "basis_2D",
7
+ "get_num_coeff"
8
+ ]
9
+
10
+ import torch
11
+
12
def precompute_legendre_polynomials(x: torch.Tensor, degree: int) -> list[torch.Tensor]:
    """
    Evaluate the Legendre polynomials P_0 ... P_degree at the given coordinates.

    The values are built with the three-term recurrence
    ``n * P_n(x) = (2n - 1) * x * P_{n-1}(x) - (n - 1) * P_{n-2}(x)``.

    Args:
        x (torch.Tensor): Coordinates at which the polynomials are evaluated.
        degree (int): Highest polynomial degree to evaluate.

    Returns:
        list of torch.Tensor: ``[P_0(x), P_1(x), ..., P_degree(x)]``, each entry
        shaped like ``x``. For ``degree < 1`` only ``P_0`` is returned.
    """
    polynomials = [torch.ones_like(x)]  # P_0(x) = 1
    if degree < 1:
        return polynomials

    polynomials.append(x)  # P_1(x) = x (the input tensor itself, not a copy)
    for order in range(2, degree + 1):
        previous, before_previous = polynomials[-1], polynomials[-2]
        polynomials.append(
            ((2 * order - 1) * x * previous - (order - 1) * before_previous) / order
        )
    return polynomials
32
+
33
def basis_2D(points: torch.Tensor, degree: int) -> torch.Tensor:
    """
    Evaluate the 2D Legendre product basis up to a given total degree.

    Each basis function is a product ``P_i(x) * P_j(y)`` with ``i + j <= degree``.
    Columns are ordered by total degree ``i + j`` (ascending) and, within one
    total degree, by ``i`` (ascending).

    Args:
        points (torch.Tensor): Tensor indexed as ``points[:, 0]`` (x-coordinates)
            and ``points[:, 1]`` (y-coordinates).
        degree (int): Maximum total degree of the basis functions.

    Returns:
        torch.Tensor: Tensor of shape ``(N, (degree + 1) * (degree + 2) // 2)``
        where ``N = points.shape[0]``; column ``k`` holds the k-th basis function
        evaluated at every point.
    """
    x = points[:, 0]
    y = points[:, 1]

    # 1D Legendre polynomials are evaluated once per axis and reused in every product.
    Px = precompute_legendre_polynomials(x, degree)
    Py = precompute_legendre_polynomials(y, degree)

    # Walk the total degree directly; this yields the same column order as
    # grouping all (i, j) pairs by i + j, without the intermediate dictionary.
    columns = []
    for total in range(degree + 1):
        for i in range(total + 1):
            columns.append(Px[i] * Py[total - i])

    # Basis functions go along dim=1 so each row corresponds to one input point.
    return torch.stack(columns, dim=1)
65
+
66
def get_num_coeff(degree: int) -> int:
    """
    Count the 2D Legendre basis functions with total degree up to ``degree``.

    This is the triangular number ``(degree + 1) * (degree + 2) / 2``.

    Args:
        degree (int): Maximum total degree of the basis.

    Returns:
        int: Number of coefficients.
    """
    first_factor = degree + 1
    second_factor = degree + 2
    return first_factor * second_factor // 2
@@ -0,0 +1,235 @@
1
+ """
2
+ Zernike polynomial basis functions for optical wavefront representation.
3
+
4
+ This module provides functions to compute Zernike polynomials, which are commonly used
5
+ in optics for describing wavefront aberrations over a circular aperture. The polynomials
6
+ are orthogonal over the unit disk and are indexed by radial order (n) and azimuthal
7
+ frequency (m).
8
+
9
+ .. figure:: _static/zernike_plot1.png
10
+ :alt: Zernike polynomials visualization
11
+ :width: 60%
12
+ :align: center
13
+
14
+ Visualization of Zernike polynomials organized by radial order (rows) and azimuthal frequency (columns).
15
+
16
+
17
+ Example:
18
+ Basic usage for computing and visualizing Zernike polynomials:
19
+
20
+ >>> import torch
21
+ >>> import numpy as np
22
+ >>> import matplotlib.pyplot as plt
23
+ >>> import diffinytrace.basis_functions.zernike as zernike
24
+ >>>
25
+ >>> # Create unit circle grid
26
+ >>> grid_size = 256
27
+ >>> x = torch.linspace(-1, 1, grid_size)
28
+ >>> y = torch.linspace(-1, 1, grid_size)
29
+ >>> X, Y = torch.meshgrid(x, y, indexing='ij')
30
+ >>>
31
+ >>> # Create mask for unit circle
32
+ >>> mask = (X**2 + Y**2) <= 1.0
33
+ >>> x_points = X[mask]
34
+ >>> y_points = Y[mask]
35
+ >>> points = torch.stack([x_points, y_points], dim=1)
36
+ >>>
37
+ >>> # Evaluate Zernike polynomials
38
+ >>> max_n = 6 # Maximum radial degree
39
+ >>> basis_values = zernike.basis_2D(points, max_n)
40
+ >>>
41
+ >>> # Group basis functions by radial degree
42
+ >>> basis_by_degree = {}
43
+ >>> for basis_idx in range(basis_values.shape[1]):
44
+ ... radial_order = zernike.get_radial_order(basis_idx)
45
+ ... if radial_order not in basis_by_degree:
46
+ ... basis_by_degree[radial_order] = []
47
+ ... basis_by_degree[radial_order].append(basis_idx)
48
+ >>>
49
+ >>> # Visualize the polynomials
50
+ >>> max_cols = max(len(indices) for indices in basis_by_degree.values())
51
+ >>> num_rows = len(basis_by_degree)
52
+ >>> fig, axes = plt.subplots(num_rows, max_cols, figsize=(3*max_cols, 3*num_rows))
53
+ >>>
54
+ >>> for row_idx, (radial_order, basis_indices) in enumerate(sorted(basis_by_degree.items())):
55
+ ... for col_idx, basis_idx in enumerate(basis_indices):
56
+ ... # Create 2D array with NaN outside unit circle
57
+ ... tmp = torch.full((grid_size, grid_size), float('nan'))
58
+ ... tmp[mask] = basis_values[:, basis_idx]
59
+ ...
60
+ ... # Plot
61
+ ... ax = axes[row_idx, col_idx]
62
+ ... im = ax.imshow(tmp.numpy(), extent=[-1, 1, -1, 1],
63
+ ... origin='lower', cmap='jet', vmin=-1, vmax=1)
64
+ ... azimuthal = zernike.get_azimuthal_frequency(basis_idx)
65
+ ... ax.set_title(f"$Z^{{{azimuthal}}}_{{{radial_order}}}$", fontsize=25)
66
+ ... ax.set_xticks([])
67
+ ... ax.set_yticks([])
68
+ ... ax.set_aspect('equal')
69
+ >>>
70
+ >>> plt.tight_layout()
71
+ >>> plt.show()
72
+
73
+ Notes:
74
+ - Zernike polynomials are only defined for points within the unit circle (r ≤ 1)
75
+ - Radial order n determines the number of radial variations
76
+ - Azimuthal frequency m determines the angular variations and symmetry
77
+ """
78
+
79
+ # Copyright (c) 2025 Martin Pflaum
80
+ # This file is part of the diffinytrace project, licensed under the MIT License.
81
+
82
# Public API of this module. Every name listed here must be defined in the
# module, otherwise `from ... import *` raises AttributeError; the evaluation
# entry point defined in this file is `basis_2D`.
__all__ = [
    "basis_2D",
    "get_num_basis",
    "get_radial_order",
    "get_azimuthal_frequency",
]
88
+
89
+ import torch
90
+ import math
91
+
92
def __zernike_calc(n: int, m: int, r_powers: list) -> torch.Tensor:
    """
    Evaluate the radial part of a Zernike polynomial with the factor r^|m| removed.

    Computes ``sum_k (-1)^k (n-k)! / (k! ((n+|m|)/2-k)! ((n-|m|)/2-k)!) * r^(n-2k-|m|)``
    for ``k = 0 .. (n-|m|)//2``. The missing ``r^|m|`` is expected to be supplied
    by the caller's angular multiplier.

    Args:
        n (int): Radial order.
        m (int): Azimuthal frequency; only its absolute value is used.
        r_powers (list): ``r_powers[p]`` is the tensor ``r^p`` for even ``p``
            (odd entries are never read).

    Returns:
        torch.Tensor: Radial values, shaped like ``r_powers[0]``.

    Raises:
        RuntimeError: If a negative or odd power of r would be required.
    """
    m = abs(m)
    radial_sum = torch.zeros_like(r_powers[0])

    for k in range((n - m) // 2 + 1):
        coef = math.factorial(n - k) / (
            math.factorial(k)
            * math.factorial((n + m) // 2 - k)
            * math.factorial((n - m) // 2 - k)
        )
        if k % 2 == 1:  # alternating sign (-1)^k
            coef = -coef

        power_idx = n - 2 * k - m
        if power_idx < 0:
            raise RuntimeError("Potential zero division!")
        if power_idx % 2 == 1:
            raise RuntimeError("tried to acces odd power idx!")

        radial_sum += coef * r_powers[power_idx]

    return radial_sum


def basis_2D(points: torch.Tensor, max_radial_order: int) -> torch.Tensor:
    """
    Compute Zernike polynomials for a given set of points.

    Columns are ordered by radial order ``n = 0 .. max_radial_order`` and, within
    one radial order, by azimuthal frequency ``m = -n, -n + 2, ..., n``. For
    ``m >= 0`` the angular factor is the real part of ``(y + 1j*x)**|m|``, for
    ``m < 0`` it is the imaginary part.

    Args:
        points (torch.Tensor): Tensor of shape (N, 2) containing the x and y
            coordinates of the points.
        max_radial_order (int): Maximum radial order.

    Returns:
        torch.Tensor: Tensor of shape (N, num_coeffs) containing the Zernike
        polynomial values, with
        ``num_coeffs = (max_radial_order + 1) * (max_radial_order + 2) // 2``.
    """
    x = points[:, 0]
    y = points[:, 1]

    # Only even powers of r are needed, so they are derived from r^2 and no
    # square root is taken; odd slots stay None and are never read.
    r2 = x**2 + y**2
    r_powers = [
        r2 ** (power / 2.0) if power % 2 == 0 else None
        for power in range(max_radial_order + 1)
    ]

    # (y + 1j*x)**k depends only on k = |m|, so it is evaluated once per k
    # instead of once per (n, m) pair.
    z = y + 1j * x
    harmonics = [z**k for k in range(max_radial_order + 1)]

    # The radial factor depends only on (n, |m|); +m and -m share it.
    radial_cache = {}

    zernike_polynomials = []
    for n in range(max_radial_order + 1):
        for m in range(-n, n + 1, 2):  # m must have the same parity as n
            key = (n, abs(m))
            if key not in radial_cache:
                radial_cache[key] = __zernike_calc(n, m, r_powers)

            harmonic = harmonics[abs(m)]
            multiplier = torch.real(harmonic) if m >= 0 else torch.imag(harmonic)
            zernike_polynomials.append(multiplier * radial_cache[key])

    # Stack all Zernike polynomials into a single tensor, one column per polynomial.
    return torch.stack(zernike_polynomials, dim=1)
159
+
160
def get_num_basis(max_radial_order: int) -> int:
    """
    Calculate the number of basis functions for Zernike polynomials up to a given radial order.

    Radial order ``n`` contributes ``n + 1`` azimuthal frequencies, so the total is
    ``(max_radial_order + 1) * (max_radial_order + 2) / 2``.

    Args:
        max_radial_order (int): Maximum radial order.

    Returns:
        int: Number of coefficients.
    """
    n = max_radial_order + 1
    # Floor division keeps the arithmetic exact for integer input (the product of
    # two consecutive integers is always even); int() preserves the int return
    # type if a float order is passed.
    return int(n * (n + 1) // 2)
174
+
175
+
176
def get_radial_order(basis_idx: int) -> int:
    """
    Map a flat basis-function index to its radial order.

    Row ``n`` of the basis holds ``n + 1`` functions; rows are consumed one by
    one until the remaining index falls inside the current row.

    Args:
        basis_idx (int): Flat index of the basis function.

    Returns:
        int: Radial order (row) containing that index.
    """
    radial_order = 0
    remaining = basis_idx
    # Row `radial_order` contains radial_order + 1 entries.
    while remaining >= radial_order + 1:
        remaining -= radial_order + 1
        radial_order += 1
    return radial_order
198
+
199
def get_azimuthal_frequency(basis_idx: int) -> int:
    """
    Map a flat basis-function index to its azimuthal frequency.

    Within radial order ``n`` the frequencies run ``-n, -n + 2, ..., n``, so the
    entry at position ``p`` of that row has ``m = -n + 2 * p``.

    Args:
        basis_idx (int): Flat index of the basis function.

    Returns:
        int: Azimuthal frequency (m value).
    """
    radial_order = get_radial_order(basis_idx)

    # Rows 0 .. radial_order - 1 hold 1 + 2 + ... + radial_order entries in total;
    # subtracting them leaves the position inside the current row.
    entries_before_row = radial_order * (radial_order + 1) // 2
    position_in_row = basis_idx - entries_before_row

    return -radial_order + position_in_row * 2
diffinytrace/config.py ADDED
@@ -0,0 +1,140 @@
1
+ """
2
+ Configuration module for diffinytrace.
3
+
4
+ This module provides global configuration options for controlling ray intersection
5
+ solver behavior, such as tolerance, maximum iterations, damping factor, and whether
6
+ to display iteration counts. These settings can be adjusted at runtime to tune
7
+ performance and accuracy.
8
+
9
+ Example:
10
+ >>> import diffinytrace.config as config
11
+ >>> config.set_tolerance(1e-8)
12
+ >>> config.set_show_iteration_count(True)
13
+ >>> config.restore_default_settings()
14
+ """
15
+
16
+ # Copyright (c) 2025 Martin Pflaum
17
+ # This file is part of the diffinytrace project, licensed under the MIT License.
18
+
19
+ __all__ = [
20
+ "set_show_iteration_count",
21
+ "get_show_iteration_count",
22
+ "set_tolerance",
23
+ "get_tolerance",
24
+ "set_max_iterations",
25
+ "get_max_iterations",
26
+ "set_damping_factor",
27
+ "get_damping_factor",
28
+ "restore_default_settings"
29
+ ]
30
+
31
+ # Default settings for ray intersection parameters
32
# Module-level state holding the current solver settings. The set_*/get_*
# functions in this module write and read these names, and
# restore_default_settings() assigns the same literal values again.
tolerance:float = 1e-6  # Default tolerance for ray intersection
max_iterations:int = 100  # Default maximum number of iterations for the solver
damping_factor:float = 1.0  # Default damping factor for the Newton method (1.0 means no damping)
show_iteration_count:bool = False  # Default setting to not show the number of iterations
36
+
37
def set_show_iteration_count(flag: bool) -> None:
    """
    Enable or disable reporting of the iteration count for each intersection.

    The value is stored as-is in the module-level ``show_iteration_count``.

    Args:
        flag (bool): True to show the number of iterations, False to hide it.
    """
    global show_iteration_count
    show_iteration_count = flag
46
+
47
def get_show_iteration_count() -> bool:
    """
    Report whether the iteration count should be shown.

    Returns:
        bool: Current value of the module-level ``show_iteration_count``.
    """
    current_flag = show_iteration_count
    return current_flag
55
+
56
def set_tolerance(new_tolerance: float) -> None:
    """
    Set the tolerance for ray intersection calculations.

    Args:
        new_tolerance (float): The new tolerance value (must be > 0).

    Raises:
        ValueError: If `new_tolerance` is not greater than 0. This includes NaN,
            which is neither greater than nor less than 0.
    """
    global tolerance
    # `not (x > 0)` instead of `x <= 0`: every comparison with NaN is False, so
    # the latter would silently accept NaN as a tolerance.
    if not new_tolerance > 0:
        raise ValueError("Tolerance must be greater than 0.")
    tolerance = new_tolerance
70
+
71
def get_tolerance() -> float:
    """
    Return the tolerance currently used for ray intersection calculations.

    Returns:
        float: Current value of the module-level ``tolerance``.
    """
    current_tolerance = tolerance
    return current_tolerance
79
+
80
def set_max_iterations(new_max_iterations: int) -> None:
    """
    Change the iteration limit of the ray intersection solver.

    Args:
        new_max_iterations (int): New iteration limit (must be > 0).

    Raises:
        ValueError: If `new_max_iterations` is less than or equal to 0.
    """
    global max_iterations
    non_positive = new_max_iterations <= 0
    if non_positive:
        raise ValueError("Maximum iterations must be greater than 0.")
    max_iterations = new_max_iterations
94
+
95
def get_max_iterations() -> int:
    """
    Return the iteration limit of the ray intersection solver.

    Returns:
        int: Current value of the module-level ``max_iterations``.
    """
    # Reading a module-level name needs no `global` declaration.
    return max_iterations
104
+
105
def set_damping_factor(new_damping_factor: float) -> None:
    """
    Change the damping factor of the Newton method used in ray intersections.

    Args:
        new_damping_factor (float): New damping factor; must satisfy
            ``0 < new_damping_factor <= 1``.

    Raises:
        ValueError: If `new_damping_factor` lies outside the interval (0, 1].
    """
    global damping_factor
    # Guard clause: reject anything outside (0, 1] before touching the setting.
    if not 0 < new_damping_factor <= 1:
        raise ValueError("Damping factor must be between 0 and 1.")
    damping_factor = new_damping_factor
120
+
121
def get_damping_factor() -> float:
    """
    Return the damping factor currently used by the Newton method.

    Returns:
        float: Current value of the module-level ``damping_factor``.
    """
    current_damping = damping_factor
    return current_damping
129
+
130
def restore_default_settings() -> None:
    """
    Put every configuration parameter of the ray tracer back to its default.

    After the call: tolerance is 1e-6, max_iterations is 100, damping_factor is
    1.0 (no damping) and show_iteration_count is False.
    """
    global tolerance, max_iterations, damping_factor, show_iteration_count
    tolerance, max_iterations = 1e-6, 100
    damping_factor, show_iteration_count = 1.0, False
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = ["Constraint", "EqualZero", "GEQZero", "LEQZero"]
5
+
6
+ from .physical_object import PhysicalSurface
7
+ import torch
8
+ from . integrators import Cube
9
+ from .optimize import minimize
10
+ from .utils.autograd import grad
11
+ #from .element import OpticalSurface
12
+
13
class Constraint:
    """
    Common base for optimization constraints.

    Attributes:
        fun (Callable): Function defining the constraint.
        type (str): Kind of constraint ('eq' or 'ineq').
    """

    def __init__(self, fun, type):
        # Both arguments are stored unchanged; no validation happens here.
        self.fun, self.type = fun, type
24
+
25
class EqualZero(Constraint):
    """
    Equality constraint requiring `fun() == 0`.

    Args:
        fun (Callable): The constraint function.
    """

    def __init__(self, fun):
        # Stored with constraint type 'eq'.
        super().__init__(fun=fun, type='eq')
34
+
35
class GEQZero(Constraint):
    """
    Inequality constraint requiring `fun() >= 0`.

    Args:
        fun (Callable): The constraint function.
    """

    def __init__(self, fun):
        # Stored with constraint type 'ineq'; the function is passed through unchanged.
        super().__init__(fun=fun, type='ineq')
44
+
45
class LEQZero(Constraint):
    """
    Inequality constraint requiring `fun() <= 0`.

    Args:
        fun (Callable): The constraint function.
    """

    def __init__(self, fun):
        def negated():
            # fun() <= 0 is stored as -fun() under the same 'ineq' type that
            # GEQZero uses for `fun() >= 0`.
            return -fun()

        super().__init__(negated, 'ineq')
54
+