diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,808 @@
1
+ r"""
2
+ Optimization Utilities for PyTorch-SciPy Integration
3
+ ====================================================
4
+
5
+ This submodule provides a set of tools for constrained and unconstrained optimization of PyTorch models using SciPy optimizers. It bridges the gap between SciPy’s powerful optimization routines and PyTorch’s autograd system, enabling flexible and efficient hybrid optimization workflows.
6
+
7
+ Key Features:
8
+ -------------
9
+ - Seamless wrapping of PyTorch-based objective functions for use with SciPy.
10
+ - Automatic gradient computation using PyTorch’s autograd.
11
+ - Support for parameter bounds, including custom mask-based bounds.
12
+ - Caching and reuse of recent function/gradient evaluations.
13
+ - Integration with SciPy's `minimize`.
14
+ - Optional tracking of optimization history (function values and gradient norms).
15
+ - Utility functions for flattening/unpacking tensor parameters.
16
+ - Conversion of PyTorch parameters to SciPy-compatible formats with bounds.
17
+ - Support for custom constraints and callback functions.
18
+
19
+ Optimization Constraints in Optical Systems
20
+ -------------------------------------------
21
+
22
+ When using optimization procedures to attain parameters of an optical system, it is important to have constraints that ensure that the optical system can be manufactured. The following demonstrates the implementation of different types of constraints in our library, with a specific focus on the positive air spacing and minimum glass thickness constraints.
23
+
24
+ Constrained optimization problems can often be expressed as a *nonlinear program*, which is defined as follows (see :cite:`italiens`):
25
+
26
+ .. math::
27
+
28
+ \min_{p} \quad m(p)
29
+
30
+ .. math::
31
+
32
+ \text{subject to} \quad \hat{g}_i(p) \leq 0, \quad i = 1, \ldots, N_1,
33
+
34
+ .. math::
35
+
36
+ \text{subject to} \quad \hat{h}_j(p) = 0, \quad j = 1, \ldots, N_2,
37
+
38
+ where:
39
+ - :math:`p \in \mathbb{R}^n` is the vector of parameters.
40
+ - :math:`m: \mathbb{R}^n \to \mathbb{R}` is the nonlinear objective (merit) function.
41
+ - :math:`\hat{g}_i: \mathbb{R}^n \to \mathbb{R}` are the inequality constraint functions.
42
+ - :math:`\hat{h}_j: \mathbb{R}^n \to \mathbb{R}` are the equality constraint functions.
43
+
44
+ For this type of problem, multiple numerical schemes are available in the Python library *SciPy*. Some optimization schemes also require derivative information for functions that describe constraints. For example, Sequential Least Squares Programming (SLSQP) uses the derivatives of the constraint functions :math:`\hat{g}_i` and :math:`\hat{h}_j` to find local minima.
45
+
46
+ By combining the libraries PyTorch and SciPy, we leverage the strengths of two sophisticated and established libraries:
47
+
48
+ 1. **PyTorch**: Efficiently calculates the derivatives of the merit function :math:`m` and the constraint functions :math:`\hat{g}_i` and :math:`\hat{h}_j` using automatic differentiation. Additionally, it allows evaluation of these functions and their derivatives on a graphics card, providing significant speedups.
49
+ 2. **SciPy**: Provides well-tested traditional algorithms to find local minima. While PyTorch also has a wide variety of optimization algorithms, its main application is stochastic gradient descent in deep learning, which may not be the best choice for optimizing optical systems.
50
+
51
+ Types of Constraints
52
+ --------------------
53
+
54
+ In our library, we implemented three ways to define constraints:
55
+
56
+ 1. **Bounds**
57
+ Most numerical schemes in SciPy support box constraints (bounds), which define a minimum and a maximum value for each parameter. A bound on a single parameter :math:`p_k` can be interpreted as an inequality constraint of the form :math:`\hat{g}_i(p) = p_k - C_i \leq 0` (upper bound) or :math:`\hat{g}_i(p) = C_i - p_k \leq 0` (lower bound), where :math:`C_i \in \mathbb{R}`. This is particularly useful for distance transformations, where it ensures that a distance never falls below a given minimum (e.g. 0). For example:
58
+
59
+ >>> import diffinytrace as dit
60
+ >>> import torch
61
+ >>> dist_transform = dit.transforms.Distance(10.)
62
+ >>> dist_transform.distance.bounds = torch.tensor([5.0, torch.inf])
63
+
64
+ Here, **torch.inf** indicates that the distance can be arbitrarily large, with no upper bound.
65
+
66
+ 2. **Constant Variables**
67
+ If a specific parameter should be fixed, PyTorch allows disabling gradient computation for that parameter. For example:
68
+
69
+ >>> import diffinytrace as dit
70
+ >>> distance_transform = dit.transforms.Distance(10.)
71
+ >>> distance_transform.distance.requires_grad = False
72
+
73
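+ Note: While it is easy to fix whole parameters, ``requires_grad`` applies to an entire tensor, so gradient computation cannot be disabled for individual entries of a multi-valued parameter. For instance, in the case of a B-spline surface, it is not possible to disable gradient computation for individual B-spline coefficients this way. Individual entries can instead be fixed through bounds, e.g. with :func:`set_bounds_from_params_mask`, which sets equal lower and upper bounds for entries whose mask is ``False`` (pass the chosen attribute name as ``bounds_attr_name`` to :func:`minimize`).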
74
+
75
+ 3. **Arbitrary Constraint Functions**
76
+ Our library also supports defining nonlinear inequality constraint functions :math:`\hat{g}_i` and equality constraint functions :math:`\hat{h}_j`. Some local optimization methods require derivative information for these nonlinear constraint functions. To efficiently evaluate these derivatives, we use automatic differentiation. This is achieved by defining the constraint functions :math:`\hat{g}_i` and :math:`\hat{h}_j` with PyTorch and calculating their derivatives with respect to the parameters of the optical system. This approach eliminates the need for finite differences, which could significantly slow down the optimization procedure. In :func:`minimize`, such constraint functions are currently supported only with the SLSQP method.
77
+ """
78
+
79
+ # Copyright (c) 2025 Martin Pflaum
80
+ # This file is part of the diffinytrace project, licensed under the MIT License.
81
+
82
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    # parameter and bounds construction
    "make_bounds_from_param",
    "make_parameter_from_input",
    # flattening / unflattening of parameter tensors
    "pack_tensors",
    "unpack_tensors",
    "apply_vec_to_params",
    # SciPy bridging helpers
    "set_full_if_nan",
    "ParameterFunHelper",
    "create_fun_and_gradient",
    "remove_bounds",
    "get_bounds",
    "get_scipy_constraint",
    "create_callback",
    # main entry point
    "minimize",
    # bounds manipulation
    "copy_bounds_to_attr_name",
    "set_bounds_from_params_mask",
]
99
+
100
+ import scipy
101
+ import scipy.optimize
102
+ from .utils.autograd import grad
103
+ import torch
104
+ import numpy as np
105
+ import torch.nn as nn
106
+ import copy
107
+ from typing import Callable, List, Tuple, Optional
108
+
109
def make_bounds_from_param(param):
    """
    Builds unbounded box constraints ``(-inf, inf)`` matching ``param`` element-wise.

    Args:
        param (torch.Tensor): Tensor whose elements need bounds.

    Returns:
        torch.Tensor: Tensor of shape ``(*param.shape, 2)`` on the same device and
        with the same dtype as ``param``. ``[..., 0]`` holds the lower bounds (-inf)
        and ``[..., 1]`` the upper bounds (+inf).
    """
    lower = torch.full(tuple(param.shape), -torch.inf, device=param.device, dtype=param.dtype)
    upper = torch.full_like(lower, torch.inf)
    # Stacking along a new trailing axis gives the (..., 2) lower/upper layout.
    return torch.stack((lower, upper), dim=-1)
128
+
129
+
130
def make_parameter_from_input(input, bounds=None, dtype=None, device=None, bounds_attr_name="bounds"):
    """
    Converts ``input`` to a ``torch.nn.Parameter`` and attaches bounds as an attribute.

    Args:
        input (array-like or torch.Tensor): Input data.
        bounds (optional): Bounds to attach (typically a tensor of shape
            ``(*input.shape, 2)``). If None, unbounded ``(-inf, inf)`` bounds are
            created with :func:`make_bounds_from_param`.
        dtype (torch.dtype, optional): Desired data type. If None and the input is
            neither floating point nor complex (e.g. a Python ``int``), it is promoted
            to ``torch.get_default_dtype()``, because only floating point and complex
            tensors can require gradients.
        device (torch.device, optional): Device to store the parameter on.
        bounds_attr_name (str): Attribute name used to store the bounds.

    Returns:
        torch.nn.Parameter: The parameter, with the bounds attached as attribute
        ``bounds_attr_name``.
    """
    if not torch.is_tensor(input):
        input = torch.tensor(input, dtype=dtype, device=device)

    # Move/cast only when explicitly requested.
    if dtype is not None or device is not None:
        input = input.to(device=device, dtype=dtype)

    # nn.Parameter requires grad by default, which fails for integer/bool tensors.
    # Promote those, unless the caller fixed the dtype or passed a ready Parameter.
    if (dtype is None
            and not isinstance(input, torch.nn.Parameter)
            and not (input.is_floating_point() or input.is_complex())):
        input = input.to(torch.get_default_dtype())

    if not isinstance(input, torch.nn.Parameter):
        input = torch.nn.Parameter(input)

    if bounds is None:
        bounds = make_bounds_from_param(input)
    setattr(input, bounds_attr_name, bounds)
    return input
160
+
161
def pack_tensors(tensor_list: List[torch.Tensor]) -> torch.Tensor:
    """
    Flattens one tensor, or a sequence of tensors, into a single 1D tensor.

    For a sequence, the flattened tensors are concatenated in order. This is the
    inverse of :func:`unpack_tensors`.

    Args:
        tensor_list (torch.Tensor or sequence of torch.Tensor): Input tensor(s).

    Returns:
        torch.Tensor: A 1D tensor (a 0-dim input yields one element).
    """
    if not torch.is_tensor(tensor_list):
        flattened = [torch.flatten(tensor) for tensor in tensor_list]
        return torch.cat(flattened)
    return tensor_list.reshape(-1)
174
+
175
def unpack_tensors(packed_tensor: torch.Tensor, shapes: List[Tuple[int]]) -> List[torch.Tensor]:
    """
    Unpacks a 1D tensor into a list of tensors with the given shapes.

    This is the inverse of :func:`pack_tensors`. Consecutive chunks of
    ``packed_tensor`` are reshaped to ``shapes[0]``, ``shapes[1]``, ... in order.

    Args:
        packed_tensor (torch.Tensor): The flat tensor to unpack.
        shapes (list of tuple or torch.Size): Target shapes. ``()`` denotes a scalar
            (0-dim) tensor and consumes one element. Shapes containing a zero-length
            dimension consume no elements.

    Returns:
        list of torch.Tensor: Tensors with exactly the requested shapes (``reshape``
        returns views of ``packed_tensor`` where possible).
    """
    unpacked_tensors = []
    start = 0
    for shape in shapes:
        shape = tuple(shape)
        # numel() is 1 for the scalar shape () and 0 if any dimension is 0.
        size = torch.Size(shape).numel()
        # reshape(tuple) also handles () correctly (reshape(*()) would not).
        unpacked_tensors.append(packed_tensor[start:start + size].reshape(shape))
        start += size
    return unpacked_tensors
198
+
199
def apply_vec_to_params(vec: np.ndarray, params: list[torch.nn.Parameter], device=None, dtype=None):
    """
    Updates PyTorch parameters with values from a flattened NumPy vector.

    This function is used in optimization workflows to update parameter values
    during SciPy optimization. It takes a flat vector of parameter values and
    distributes them back to the original parameter tensors, preserving each
    parameter's shape (including 0-dim parameters).

    Args:
        vec (np.ndarray): A 1D NumPy array containing new parameter values.
            The length must match the total number of elements across all parameters.
        params (list[torch.nn.Parameter]): List of PyTorch parameters to update.
            Each parameter is filled from the corresponding portion of `vec`.
        device (torch.device, optional): Target device for the parameters.
            If None, uses the device of the first parameter. Defaults to None.
        dtype (torch.dtype, optional): Target data type for the parameters.
            If None, uses the dtype of the first parameter. Defaults to None.

    Raises:
        RuntimeError: If `vec` is not a NumPy array, or if a parameter's slice of
            `vec` cannot be reshaped to the parameter's shape.

    Example:
        >>> import torch
        >>> import numpy as np
        >>> import diffinytrace as dit
        >>>
        >>> params = [
        ...     torch.nn.Parameter(torch.ones((2, 2)) * 0.25),
        ...     torch.nn.Parameter(torch.ones(3)),
        ... ]
        >>> vec = dit.optimize.pack_tensors(params).detach().cpu().numpy()
        >>> len(vec)  # 2*2 + 3
        7
        >>> dit.optimize.apply_vec_to_params(vec * 2.0, params)
        >>> # params[0] is now filled with 0.5 and params[1] with 2.0

    Note:
        - This function modifies parameters in-place using `param.data = ...`
        - The function uses `torch.no_grad()` to avoid building computation graphs
        - Parameter shapes are preserved during the update process
        - Commonly used with `pack_tensors()` and `unpack_tensors()` for optimization
    """
    if not isinstance(vec, np.ndarray):
        raise RuntimeError("vec should be a numpy vector")
    params = list(params)
    if dtype is None:
        dtype = params[0].dtype
    if device is None:
        device = params[0].device
    # torch.tensor copies `vec`, so later in-place changes SciPy makes to its
    # iterate array cannot leak into the parameters.
    flat = torch.tensor(vec, device=device, dtype=dtype)
    unpacked_params = unpack_tensors(flat, [param.shape for param in params])
    with torch.no_grad():
        for param, value in zip(params, unpacked_params):
            # Reshape to the parameter's own shape so that e.g. 0-dim parameters stay 0-dim.
            param.data = value.reshape(param.shape)
262
+
263
def set_full_if_nan(input: np.ndarray, fill_value: float) -> np.ndarray:
    """
    Replaces the whole array with ``fill_value`` if it contains any NaN.

    This is not an element-wise replacement. A single NaN entry makes every entry
    of the returned array equal to ``fill_value``. NaN-free input is returned as the
    same object.

    Args:
        input (np.ndarray): A NumPy array (0-dim arrays are allowed).
        fill_value (float): Value used for the whole array if a NaN is present.

    Returns:
        np.ndarray: ``input`` itself if it has no NaN. Otherwise ``np.array(fill_value)``
        for 0-dim input, or ``np.full_like(input, fill_value)`` (same shape and dtype).

    Raises:
        RuntimeError: If ``input`` is not a NumPy array.
    """
    if not isinstance(input, np.ndarray):
        raise RuntimeError("set_full_if_nan,input should be a numpy vector")

    if not np.isnan(input).any():
        return input
    if input.ndim == 0:
        return np.array(fill_value)
    return np.full_like(input, fill_value)
288
+
289
class ParameterFunHelper:
    """
    Helper class for evaluating PyTorch functions and gradients in SciPy optimization.

    This class bridges PyTorch's automatic differentiation with SciPy's optimization
    routines by providing function and gradient evaluations in NumPy format.
    It caches the most recent evaluation points to avoid redundant computations and
    replaces NaN results with a fallback value.

    Args:
        orginal_fun (Callable): Function without arguments that evaluates the objective
            (or constraint) from the current values of ``params`` and returns a
            PyTorch tensor. The misspelled argument name is kept for backward
            compatibility.
        params (Iterable[torch.nn.Parameter]): Parameters to optimize over. Their order
            defines the layout of the flat vectors exchanged with SciPy.
        nan_fallback (float, optional): Value used if the function value or gradient
            contains NaN. Defaults to float("inf").

    Attributes:
        orginal_fun (Callable): The wrapped function (also available as ``original_fun``).
        params (List[torch.nn.Parameter]): Parameters for optimization.
        nan_fallback (float): Fallback value for NaN handling.
        last_x_fun_numpy (np.ndarray): Input of the last successful function evaluation.
        last_fun_val_numpy (np.ndarray): Last function value in NumPy format (NaN-replaced).
        last_fun_val_torch (torch.Tensor): Last function value as PyTorch tensor. It keeps
            the autograd graph until the gradient at that point is computed.
        last_x_grad_numpy (np.ndarray): Input of the last successful gradient evaluation.
        last_grad_val_numpy (np.ndarray): Last gradient in NumPy format (NaN-replaced).

    Example:
        >>> import torch
        >>> import diffinytrace as dit
        >>> import numpy as np
        >>>
        >>> # Define parameters and objective function
        >>> params = [torch.nn.Parameter(torch.randn(5))]
        >>> def objective():
        ...     return torch.sum(params[0]**2)
        >>>
        >>> # Create helper for SciPy optimization
        >>> helper = dit.optimize.ParameterFunHelper(objective, params)
        >>>
        >>> x0 = np.ones((5,)) * 3.
        >>> fun_val = helper.fun(x0)                # 5 * 3**2 = 45
        >>> grad_val = helper.jac(x0)               # 2 * 3 = 6 per entry
        >>> fun_val, grad_val = helper.fun_jac(x0)  # both served from the cache
        >>> print(fun_val, grad_val)                # 45.0 [6. 6. 6. 6. 6.]

    Note:
        - Cached results are reused when SciPy requests the same point again
          (compared with ``np.array_equal``, i.e. same shape and values).
        - If a function value or gradient contains any NaN, the returned array is
          filled entirely with ``nan_fallback`` (see ``set_full_if_nan``).
        - Parameters are set to ``x`` whenever a new point is evaluated. On a cache
          hit they are left unchanged.
    """

    def __init__(self, orginal_fun, params, nan_fallback=float("inf")):
        self.last_x_fun_numpy = None
        self.last_fun_val_numpy = None
        self.last_fun_val_torch = None

        self.last_x_grad_numpy = None
        self.last_grad_val_numpy = None
        self.orginal_fun = orginal_fun

        self.params = list(params)
        self.nan_fallback = nan_fallback

    @property
    def original_fun(self):
        """Correctly spelled alias of ``orginal_fun``."""
        return self.orginal_fun

    @original_fun.setter
    def original_fun(self, value):
        self.orginal_fun = value

    @staticmethod
    def _is_cached(x, cached_x):
        """Returns True if ``x`` equals the cached point ``cached_x`` (same shape and values)."""
        return cached_x is not None and np.array_equal(x, cached_x)

    def fun(self, x):
        """
        Evaluates the wrapped function at ``x``.

        Args:
            x (np.ndarray): Flat parameter vector.

        Returns:
            np.ndarray: Function value (0-dim for a scalar function) with NaNs replaced.
        """
        if self._is_cached(x, self.last_x_fun_numpy):
            return self.last_fun_val_numpy

        device = self.params[0].device
        dtype = self.params[0].dtype
        apply_vec_to_params(x, self.params, device, dtype)
        fun_val = self.orginal_fun()
        fun_val_numpy = set_full_if_nan(fun_val.detach().cpu().numpy(), self.nan_fallback)

        # Update the cache only after a successful evaluation, so an exception raised
        # by the wrapped function cannot leave a stale value cached for `x`.
        self.last_fun_val_torch = fun_val
        self.last_fun_val_numpy = fun_val_numpy
        self.last_x_fun_numpy = copy.deepcopy(x)
        return fun_val_numpy

    def jac(self, x):
        """
        Computes the gradient of the wrapped function at ``x`` via autograd.

        Args:
            x (np.ndarray): Flat parameter vector.

        Returns:
            np.ndarray: Flat gradient (same layout as ``x``) with NaNs replaced.
        """
        if self._is_cached(x, self.last_x_grad_numpy):
            return self.last_grad_val_numpy

        # Ensures last_fun_val_torch holds the (still differentiable) value at x.
        self.fun(x)
        dp = grad(self.last_fun_val_torch, inputs=self.params, materialize_grads=True,
                  create_graph=False, retain_graph=False)
        grad_numpy = set_full_if_nan(pack_tensors(dp).detach().cpu().numpy(), self.nan_fallback)

        self.last_grad_val_numpy = grad_numpy
        self.last_x_grad_numpy = copy.deepcopy(x)
        return grad_numpy

    def fun_jac(self, x):
        """
        Evaluates both function value and gradient at once.

        Args:
            x (np.ndarray): Flat parameter vector.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Function value and gradient.
        """
        fun_val_numpy = self.fun(x)
        grad_val_numpy = self.jac(x)
        return fun_val_numpy, grad_val_numpy

    # NOTE: Hessian(-vector) products are not provided; SciPy methods that need them
    # must fall back to their own approximations.
436
+
437
+
438
def create_fun_and_gradient(merit_fun, params, nan_fallback, device, dtype):
    """
    Builds a SciPy-style ``fun(x) -> (value, gradient)`` callable from a PyTorch merit function.

    Unlike :class:`ParameterFunHelper`, the returned callable performs no caching.
    Every call writes ``x`` into ``params``, runs a forward pass and backpropagates.

    Args:
        merit_fun (Callable): Function without arguments returning the merit value tensor.
        params (list): List of `torch.nn.Parameter` objects.
        nan_fallback (float): Value used for the whole output if it contains NaN.
        device (torch.device): Target device for the parameter update.
        dtype (torch.dtype): Target dtype for the parameter update.

    Returns:
        Callable: Function mapping a flat NumPy vector to ``(value, gradient)`` NumPy arrays.
    """
    def fun_and_gradient(input):
        apply_vec_to_params(input, params, device, dtype)
        merit = merit_fun()
        merit_grads = grad(merit, inputs=params, materialize_grads=True,
                           create_graph=False, retain_graph=False)

        flat_grad = pack_tensors([g.detach().cpu() for g in merit_grads]).numpy()
        grad_numpy = set_full_if_nan(flat_grad, nan_fallback)
        merit_numpy = set_full_if_nan(merit.detach().cpu().numpy(), nan_fallback)
        return merit_numpy, grad_numpy

    return fun_and_gradient
468
+
469
+
470
def remove_bounds(params, bounds_attr_name) -> None:
    """
    Clears the bounds attribute of every parameter that has one.

    The attribute is not deleted; it is reset to ``None``. Parameters without the
    attribute are left untouched.

    Args:
        params (list): List of torch.nn.Parameter objects.
        bounds_attr_name (str): Attribute name of the bounds to clear.
    """
    for param in params:
        if not hasattr(param, bounds_attr_name):
            continue
        setattr(param, bounds_attr_name, None)
481
+
482
def get_bounds(params, bounds_attr_name="bounds"):
    """
    Extracts and concatenates the bounds of all parameters into one SciPy bounds array.

    Parameters whose bounds attribute is missing or ``None`` (e.g. after
    :func:`remove_bounds`) get unbounded defaults, which are also stored on the
    parameter. Bounds given as lists, tuples or NumPy arrays are converted with the
    default torch dtype. Each parameter's bounds are reshaped to ``(-1, 2)``
    (lower, upper), in the same element order as :func:`pack_tensors`.

    Args:
        params (list): List of torch.nn.Parameter objects.
        bounds_attr_name (str): Name of the attribute storing the bounds.

    Returns:
        np.ndarray: Array of shape ``(N, 2)`` with all bounds, where N is the total
        number of parameter elements.
    """
    collected = []
    for param in params:
        bounds = getattr(param, bounds_attr_name, None)
        if bounds is None:
            bounds = make_bounds_from_param(param)
            setattr(param, bounds_attr_name, bounds)
        if isinstance(bounds, (list, tuple)):
            bounds = torch.tensor(np.array(bounds), dtype=torch.get_default_dtype())
        elif isinstance(bounds, np.ndarray):
            bounds = torch.tensor(bounds, dtype=torch.get_default_dtype())
        # Move each entry to the CPU before concatenating so bounds stored on
        # different devices can be combined.
        collected.append(bounds.detach().cpu().reshape(-1, 2))

    if not collected:
        return np.empty((0, 2))
    return np.array(torch.cat(collected, dim=0))
511
+
512
def get_scipy_constraint(constraint, params, nan_fallback):
    """
    Converts a constraint object into a SciPy constraint dictionary.

    The constraint's function is wrapped in a :class:`ParameterFunHelper`, so SciPy
    receives NumPy-based ``fun``/``jac`` callables whose gradients come from autograd.

    Args:
        constraint (Constraint): Object exposing ``fun`` (callable without arguments
            returning a tensor) and ``type`` (the SciPy constraint type string; SciPy
            expects ``'eq'`` or ``'ineq'``).
        params (list): List of parameters for the optimization.
        nan_fallback (float): Fallback value for NaNs.

    Returns:
        dict: ``{'type': ..., 'fun': ..., 'jac': ...}`` as accepted by
        ``scipy.optimize.minimize``.
    """
    helper = ParameterFunHelper(constraint.fun, params, nan_fallback)
    # Flag attribute kept for compatibility; it is not read anywhere in this module.
    helper.constraint = True
    return dict(type=constraint.type, fun=helper.fun, jac=helper.jac)
529
+
530
+
531
def create_callback(callback_fun, params, device, dtype):
    """
    Adapts an argument-less PyTorch callback to SciPy's ``callback(xk)`` convention.

    Before ``callback_fun`` runs, the parameters are synced to SciPy's current
    iterate. The callback therefore observes exactly the values SciPy reports.

    Args:
        callback_fun (Callable): A function taking no arguments.
        params (list): List of parameters to update before calling.
        device (torch.device): Device of the parameters.
        dtype (torch.dtype): Data type of the parameters.

    Returns:
        Callable: A one-argument callback for SciPy optimizers that returns whatever
        ``callback_fun`` returns.
    """
    def synced_callback(input):
        apply_vec_to_params(input, params, device, dtype)
        return callback_fun()

    return synced_callback
548
+
549
# NOTE: A disabled prototype of a global optimizer built on
# ``scipy.optimize.dual_annealing`` (``global_dual_annealing``) used to be kept
# here as an unreachable module-level string literal, together with a stale
# ``nlopt==2.6.2`` note. It had no runtime effect and has been removed; restore
# it from version control if global optimization support is revived.
625
+
626
+
627
def minimize(fun,
             params,
             constraints: Optional[List] = None,
             method=None,
             tol: float = 1e-9,
             callback: Optional[Callable] = lambda: None,
             options: Optional[dict] = None,
             nan_fallback: float = float("inf"),
             bounds_attr_name: str = "bounds",
             save_history: bool = False,
             call_before_minimize: bool = False) -> dict:
    """
    Minimizes a PyTorch objective with ``scipy.optimize.minimize``, supporting bounds and constraints.

    The parameters are packed into one flat vector. The value and gradient of ``fun``
    are computed with PyTorch autograd and handed to SciPy together (``jac=True``).
    After the optimization, the parameters are set to the final iterate ``result["x"]``.

    Args:
        fun (Callable): Function without arguments returning the merit value computed
            from the current parameter values.
        params (torch.nn.Parameter or iterable of torch.nn.Parameter): Parameters to
            optimize. Parameters with ``requires_grad=False`` are skipped (constants).
        constraints (Constraint or list of Constraint, optional): Nonlinear constraints.
            Only supported with ``method='SLSQP'``. Defaults to no constraints.
        method (str, optional): SciPy optimization method. Defaults to ``'L-BFGS-B'``
            without constraints and ``'SLSQP'`` with constraints.
        tol (float): Tolerance for convergence.
        callback (Callable, optional): Function without arguments called after each
            iteration. The parameters are synced to the current iterate first.
            ``None`` disables the user callback.
        options (dict, optional): Optimizer options passed to SciPy.
        nan_fallback (float): Value used if the function value or gradient contains NaN.
        bounds_attr_name (str): Name of the parameter attribute holding the bounds.
        save_history (bool): If True, records the function value and gradient norm at
            every callback invocation and stores them under ``result["history"]``.
        call_before_minimize (bool): If True, evaluates the objective (and the callback,
            if any) once at the initial point before optimizing.

    Returns:
        dict: The fields of SciPy's ``OptimizeResult`` as a plain dict, plus
        ``"history"`` if ``save_history`` is True.

    Raises:
        RuntimeError: If constraints are given for a method other than SLSQP, or if
            no parameter requires gradients.
    """
    from .constraints import Constraint

    if constraints is None:
        constraints = []
    elif isinstance(constraints, Constraint):
        constraints = [constraints]

    if method is None:
        method = 'SLSQP' if len(constraints) > 0 else 'L-BFGS-B'

    if method != 'SLSQP' and len(constraints) > 0:
        raise RuntimeError("Only for method SLSQP constraints are supported!")

    if isinstance(params, torch.nn.Parameter):
        params = [params]

    # Parameters that do not require gradients are treated as constants.
    params = [param for param in params if param.requires_grad]

    if len(params) == 0:
        raise RuntimeError("Params is either an empty list or no parameter provided requires_grad!")

    scipy_constraints = [get_scipy_constraint(constraint, params, nan_fallback) for constraint in constraints]

    device = params[0].device
    dtype = params[0].dtype

    bounds_numpy = get_bounds(params, bounds_attr_name)
    # np.array copies, so SciPy's x0 never shares memory with a torch tensor.
    initial_params = np.array(pack_tensors([param.detach().cpu() for param in params]))

    param_helper_main = ParameterFunHelper(fun, params, nan_fallback)
    fun_and_gradient = param_helper_main.fun_jac

    history = {"fun_vals": [], "fun_grads_norm": []}
    user_callback = None if callback is None else create_callback(callback, params, device, dtype)

    if save_history:
        def scipy_callback(input):
            if user_callback is not None:
                user_callback(input)
            # Usually served from the helper's cache (SciPy just evaluated this point).
            merit_val, merit_grad = fun_and_gradient(input)
            history["fun_vals"].append(merit_val)
            history["fun_grads_norm"].append(np.linalg.norm(merit_grad))
    else:
        scipy_callback = user_callback

    if call_before_minimize:
        fun_and_gradient(initial_params)
        # The callback may legitimately be None here.
        if scipy_callback is not None:
            scipy_callback(initial_params)

    result = scipy.optimize.minimize(
        fun=fun_and_gradient,
        x0=initial_params,
        jac=True,  # fun_and_gradient returns (value, gradient)
        bounds=bounds_numpy,
        method=method,
        tol=tol,
        callback=scipy_callback,
        options=options,
        constraints=scipy_constraints,
    )
    apply_vec_to_params(result["x"], params, device, dtype)
    result = dict(result)

    if save_history:
        if len(history["fun_vals"]) > 0:
            history["fun_vals"] = np.array(history["fun_vals"])
            history["fun_grads_norm"] = np.array(history["fun_grads_norm"])
        result["history"] = history
    return result
738
+
739
+
740
def copy_bounds_to_attr_name(params, bounds_attr_name_new, bounds_attr_name_old="bounds", replace_existing_once=True):
    """
    Copies the bounds stored under one attribute name to another attribute name.

    The copy is independent of the source. Tensors are cloned; lists and NumPy
    arrays are deep-copied. If the source attribute is missing or ``None``,
    unbounded ``(-inf, inf)`` bounds are created instead.

    Args:
        params (torch.nn.Parameter or list): Parameter(s) to process.
        bounds_attr_name_new (str): Attribute name to write the copy to.
        bounds_attr_name_old (str): Attribute name to read the bounds from.
        replace_existing_once (bool): If True (default), ``bounds_attr_name_new`` is
            always (over)written. If False, parameters that already have
            ``bounds_attr_name_new`` are left untouched.
    """
    if isinstance(params, nn.Parameter):
        params = [params]
    for param in list(params):
        if not replace_existing_once and hasattr(param, bounds_attr_name_new):
            continue
        bounds = getattr(param, bounds_attr_name_old, None)
        if bounds is None:
            bounds = make_bounds_from_param(param)
        elif torch.is_tensor(bounds):
            bounds = bounds.clone()
        else:
            # e.g. lists or NumPy arrays, which have no .clone()
            bounds = copy.deepcopy(bounds)
        setattr(param, bounds_attr_name_new, bounds)
766
+
767
+
768
+ def set_bounds_from_params_mask(params,mask:list|torch.Tensor,bounds_attr_name_new,bounds_attr_name_old="bounds"):
769
+ """
770
+ Sets bounds for parameters based on a mask. Parameters with `mask=False`
771
+ get fixed bounds (equal lower and upper bounds).
772
+
773
+ Args:
774
+ params (list): List of parameters.
775
+ mask (list or torch.Tensor): Mask specifying which elements are free.
776
+ bounds_attr_name_new (str): Attribute name to store new bounds.
777
+ bounds_attr_name_old (str): Attribute name to read old bounds from.
778
+ """
779
+ def set_new_bounds_from_param_mask(param,mask,bounds_attr_name_new,bounds_attr_name_old="bounds"):
780
+ bounds = None
781
+ if hasattr(param,bounds_attr_name_old):
782
+ bounds = getattr(param,bounds_attr_name_old)
783
+ else:
784
+ bounds = make_bounds_from_param(param)
785
+ bounds = bounds.clone()
786
+ bounds_shape = bounds.shape
787
+ mask = mask.reshape(-1)
788
+ bounds = bounds.reshape(-1,2)
789
+ data = param.data.clone()
790
+ data = data.reshape(-1)
791
+ #print("shapes",mask.shape,(mask==False).shape,bounds.shape)
792
+ mask_false = mask==False
793
+ bounds[mask_false,0] = data[mask_false]
794
+ bounds[mask_false,1] = data[mask_false]
795
+ bounds = bounds.reshape(*bounds_shape)
796
+ setattr(param,bounds_attr_name_new,bounds)
797
+
798
+
799
+ if isinstance(params,nn.Parameter):
800
+ params = [params]
801
+ params = [param for param in params]
802
+ if isinstance(mask,(np.ndarray)) or torch.is_tensor(mask):
803
+ mask = [mask]
804
+
805
+ for k in range(len(params)):
806
+ set_new_bounds_from_param_mask(params[k],mask[k],bounds_attr_name_new=bounds_attr_name_new,bounds_attr_name_old=bounds_attr_name_old)
807
+
808
+