diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,285 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
# Public API exported by ``from ... import *``.
__all__ = [
    "SemiFunctionalModule",
    "cat_semi_functionals",
    "get_functional_param_args",
    "construct_surface_and_normal_func",
    "construct_surface_and_normal_func_with_params",
    "CustomAutogradRule_t",
    "get_ray_intersection_length",
]
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from .utils.autograd import grad
17
+ from .config import get_max_iterations,get_tolerance,get_damping_factor,get_show_iteration_count
18
+ from typing import List,Tuple,Callable,Optional
19
+
20
class SemiFunctionalModule(nn.Module):
    r"""
    Abstract base class for semi-functional surface modules.

    Subclasses provide two things:

    * :meth:`functional` -- a static method ``functional(O, *params)`` that
      transforms its input ``O`` using explicitly passed parameters.
      :func:`cat_semi_functionals` chains several modules by feeding each
      module's output in as the next module's ``O``.
    * :meth:`get_functional_param_args` -- the list of parameters that must be
      passed to :meth:`functional` (in order, after ``O``). Its length decides
      how many entries of the flat parameter list each module consumes.

    Keeping ``functional`` static, with all parameters passed in explicitly,
    lets diffinytrace build the surface function on the fly and control how
    derivatives with respect to those parameters are taken.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def functional(O: torch.Tensor, *params):
        r"""
        Implicit surface description (static method, to be overridden).

        Diffinytrace constructs a function `s(R, p)` on the fly from these
        functionals, allowing better control over derivative calculations.

        Args:
            O (torch.Tensor): Input points (or the output of the previous module).
            *params: The parameters listed by :meth:`get_functional_param_args`.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("functional not implemented")

    def get_functional_param_args(self) -> List:
        """
        Return the parameters passed positionally to :meth:`functional` after ``O``.

        Returns:
            list: Parameters in the order :meth:`functional` expects them.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # The message previously referred to a non-existent ``params_list`` method.
        raise NotImplementedError("get_functional_param_args not implemented")
43
+
44
+
45
def cat_semi_functionals(functional_modules: List[SemiFunctionalModule]) -> Callable:
    r"""
    Chain a list of `SemiFunctionalModule`s into a single composite function.

    The returned ``f(O, *params)`` applies each module's ``functional()`` in
    order, feeding the output of one module in as ``O`` for the next. The flat
    ``params`` tuple is split into consecutive slices, one per module, whose
    lengths are ``len(module.get_functional_param_args())``.

    Args:
        functional_modules (list[SemiFunctionalModule]): Modules to chain, in
            application order. An empty list yields the identity on ``O``.

    Returns:
        Callable: A function ``f(O, *params)`` that applies all modules in
        sequence. Surplus trailing ``params`` are ignored. If too few are
        given, later modules receive shorter (possibly empty) slices.
    """
    # Resolve each stage's functional and parameter count once, at construction
    # time. The previous recursive version rebuilt the remaining chain (and
    # re-queried parameter counts) on every call.
    stages = [
        (module.functional, len(module.get_functional_param_args()))
        for module in functional_modules
    ]

    def fun_out(O, *params):
        offset = 0
        for functional, num_params in stages:
            O = functional(O, *params[offset:offset + num_params])
            offset += num_params
        return O

    return fun_out
68
+
69
def get_functional_param_args(semi_functional_module_list: List[SemiFunctionalModule]) -> List:
    r"""
    Flatten the functional parameters of several semi-functional modules.

    The result is ordered module by module, which is the same order in which
    a function built by `cat_semi_functionals` consumes its ``*params``.

    Args:
        semi_functional_module_list (list[SemiFunctionalModule]): Modules whose
            ``get_functional_param_args()`` results are concatenated.

    Returns:
        list[torch.nn.Parameter]: Flattened list of all parameters.
    """
    return [
        param
        for module in semi_functional_module_list
        for param in module.get_functional_param_args()
    ]
83
+
84
+
85
def construct_surface_and_normal_func(semi_functional_module_list: List[SemiFunctionalModule]) -> Callable:
    r"""
    Build a callable that evaluates the composed surface function and,
    optionally, its gradient with respect to the evaluation point ``O``.

    The surface is the composition (via `cat_semi_functionals`) of the given
    semi-functional modules:

    .. math::
        (O, p_1, ..., p_n) \mapsto ( s(O), \frac{\partial s}{\partial O} )

    Args:
        semi_functional_module_list (list[SemiFunctionalModule]): Modules to compose.

    Returns:
        Callable: ``s_dsd(O, *params, only_s=False)``. With ``only_s=True`` it
        returns only ``s``. Otherwise it returns the tuple ``(s, ds/dO)``.
        Autograd is enabled internally regardless of the caller's grad mode.

    Note:
        If ``O`` does not already require grad, ``s_dsd`` sets
        ``O.requires_grad = True`` on the caller's tensor in place.
    """
    surface = cat_semi_functionals(semi_functional_module_list)

    def s_dsd(O, *params, only_s=False):
        with torch.enable_grad():
            # ds/dO can only be taken if autograd tracks O.
            if not O.requires_grad:
                O.requires_grad = True
            value = surface(O, *params)
            if only_s:
                return value
            # A ones grad_outputs yields per-point gradients, assuming each
            # s_i depends only on its own row of O (TODO confirm for all modules).
            gradient = grad(value, inputs=O, grad_outputs=torch.ones_like(value))[0]
        return value, gradient

    return s_dsd
117
+
118
def construct_surface_and_normal_func_with_params(semi_functional_module_list: List[SemiFunctionalModule]) -> Tuple[Callable, List]:
    r"""
    Build the surface/gradient callable together with its flat parameter list.

    This is a convenience for optimization workflows that need both the
    callable and the parameters to pass to it.

    Args:
        semi_functional_module_list (list[SemiFunctionalModule]): Modules to compose.

    Returns:
        tuple:
            Callable: ``s_dsd(O, *params, only_s=False)`` computing the surface
            value and its gradient (see `construct_surface_and_normal_func`).
            list[torch.nn.Parameter]: Parameters in the order ``s_dsd`` expects them.
    """
    return (
        construct_surface_and_normal_func(semi_functional_module_list),
        get_functional_param_args(semi_functional_module_list),
    )
135
+
136
+
137
+
138
+
139
class CustomAutogradRule_t(torch.autograd.Function):
    r"""
    Custom PyTorch autograd rule for a ray-surface intersection length.

    The forward pass is the identity on an already-solved ``t`` (found without
    autograd, e.g. by a Newton iteration). The backward pass differentiates the
    implicit condition

    .. math::
        s(O + t D; p) = 0

    where `O` is the ray origin, `D` is the direction and `s` is the surface
    function. With :math:`R = O + tD` and :math:`g = \partial s / \partial R`,
    the implicit function theorem gives

    .. math::
        \frac{\partial t}{\partial O} = -\frac{g}{g \cdot D}, \quad
        \frac{\partial t}{\partial D} = -\frac{t\, g}{g \cdot D}, \quad
        \frac{\partial t}{\partial p} = -\frac{\partial s / \partial p}{g \cdot D}.

    This enables backpropagation through `t` with respect to `O`, `D` and the
    surface parameters.
    """

    @staticmethod
    def forward(ctx,
                O: torch.Tensor,
                D: torch.Tensor,
                surface_and_normal_func: Callable,
                t_detached: torch.Tensor, *param_args) -> torch.Tensor:
        """
        Save what backward needs and pass the precomputed `t` through unchanged.

        Args:
            O (torch.Tensor): Ray origin of shape (N, 3).
            D (torch.Tensor): Ray direction of shape (N, 3).
            surface_and_normal_func (Callable): Surface function returning (s, ds/dR).
            t_detached (torch.Tensor): Solved intersection length (detached).
            *param_args: Surface parameters (stored via ``save_for_backward``).

        Returns:
            torch.Tensor: ``t_detached`` itself.
        """
        # The callable is not a tensor, so it is kept on ctx directly.
        ctx.surface_and_normal_func = surface_and_normal_func
        ctx.save_for_backward(O, D, t_detached, *param_args)
        return t_detached

    @staticmethod
    def backward(ctx, grad_outputs: torch.Tensor) -> Tuple:
        """
        Gradients of `t` w.r.t. the ray origin, the ray direction and the surface parameters.

        Args:
            grad_outputs (torch.Tensor): Gradient of the loss w.r.t. output `t`.

        Returns:
            tuple: Gradients for the forward inputs (O, D, None, None, *param_args).
            The callable and ``t_detached`` receive no gradient.
        """
        O, D, t_detached, *param_args = ctx.saved_tensors
        surface_and_normal_func = ctx.surface_and_normal_func

        # Re-wrap t through this Function. Presumably this keeps the gradients
        # connected to t when backward is itself differentiated (create_graph=True).
        t = CustomAutogradRule_t.apply(O, D, surface_and_normal_func, t_detached, *param_args)
        R = O + t * D

        # Tensor parameters are cloned before re-evaluating the surface
        # (NOTE(review): intent of the clone not visible here — confirm).
        params = [p.clone() if torch.is_tensor(p) else p for p in param_args]

        s_val, normal = surface_and_normal_func(R, *params)
        normal_dot_D = torch.sum(normal * D, axis=-1)
        # Common factor -dL/dt / (ds/dR . D) shared by all three gradients.
        weight = -grad_outputs.reshape(-1) / normal_dot_D.reshape(-1)

        with torch.enable_grad():
            # The reshape must happen with grad enabled so it stays in the graph.
            flat_s = [s_val.reshape(-1)]
            grad_params = grad(flat_s, list(params), grad_outputs=weight, create_graph=True, retain_graph=True)

        grad_O = weight.reshape(-1, 1) * normal
        grad_D = grad_O * t.reshape(-1, 1)
        return grad_O, grad_D, None, None, *grad_params
217
+
218
def get_ray_intersection_length(O: torch.Tensor,
                                D: torch.Tensor,
                                surface_and_normal_func: Callable,
                                param_args: List,
                                t_init: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Solve for the intersection length `t` such that:

    .. math::
        s(O + t D) = 0

    using a damped Newton iteration ``t <- t - damping * s(R) / (ds/dR . D)``.

    The solve runs on detached tensors. The result is then wrapped in
    `CustomAutogradRule_t` so that gradients w.r.t. `O`, `D` and the surface
    parameters flow through the implicit function theorem.

    Args:
        O (torch.Tensor): Ray origins of shape (N, 3).
        D (torch.Tensor): Ray directions of shape (N, 3).
        surface_and_normal_func (Callable): A function returning (s, ds/dR).
        param_args (list): List of surface parameters.
        t_init (torch.Tensor, optional): Initial guess for `t` (N values).
            If None, starts from zero.

    Returns:
        torch.Tensor: Estimated intersection lengths `t` of shape (N, 1),
        with autograd support.

    Note:
        Convergence means ``max |s| < tolerance`` over all rays (the residual
        is checked before the step of that iteration is applied). On failure
        a message is printed; no exception is raised.
        Tolerance, iteration limit and damping come from the package config.
    """
    tolerance = get_tolerance()
    max_iter = get_max_iterations()
    damping = get_damping_factor()
    device = O.device
    dtype = O.dtype

    N = O.shape[0]

    # Start from the caller's guess when given, otherwise at the ray origin.
    if t_init is not None:
        t_detached = t_init.detach().reshape(N, 1)
    else:
        t_detached = torch.zeros((N, 1), device=device, dtype=dtype)

    O_detached = O.detach()
    D_detached = D.detach()

    converged = False
    smax_vals = []
    for k in range(max_iter):
        R_detached = O_detached + t_detached * D_detached
        s_val, dsdR_val = surface_and_normal_func(R_detached, *param_args)
        s_val, dsdR_val = s_val.detach(), dsdR_val.detach()

        # Damped Newton step along the ray.
        slope = torch.sum(dsdR_val * D_detached, dim=-1).reshape(-1, 1)
        t_detached = (t_detached - damping * s_val.reshape(-1, 1) / slope).detach()

        # The residual is signed: convergence requires a small MAGNITUDE.
        # (Previously ``s_val < tolerance`` accepted large negative residuals.)
        abs_residual = torch.abs(s_val)
        if abs_residual.numel() > 0:  # torch.max fails on an empty batch
            smax_vals.append(torch.max(abs_residual))
        if (abs_residual < tolerance).all():
            converged = True
            if get_show_iteration_count():
                print(f"Ray intersection with surface completed in {k} iterations.")
            break
    if not converged:
        print(f"Ray intersection FAILED to converge after {max_iter} iterations!\nThis is totally normal during optimization when a bad parameter set is chosen. " + "maximum svals are: " + str(smax_vals))

    t_out = CustomAutogradRule_t.apply(O, D, surface_and_normal_func, t_detached, *param_args)
    return t_out