diffinytrace 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffinytrace/__init__.py +122 -0
- diffinytrace/basis_functions/__init__.py +14 -0
- diffinytrace/basis_functions/bspline.py +521 -0
- diffinytrace/basis_functions/chebyshev.py +3 -0
- diffinytrace/basis_functions/legendre.py +77 -0
- diffinytrace/basis_functions/zernike.py +235 -0
- diffinytrace/config.py +140 -0
- diffinytrace/constraints.py +54 -0
- diffinytrace/element.py +1660 -0
- diffinytrace/export/__init__.py +8 -0
- diffinytrace/export/cad.py +253 -0
- diffinytrace/gaussian_smoother.py +530 -0
- diffinytrace/hat_smoother.py +44 -0
- diffinytrace/integrators.py +452 -0
- diffinytrace/intersection.py +285 -0
- diffinytrace/optimize.py +808 -0
- diffinytrace/physical_object.py +150 -0
- diffinytrace/plotting/__init__.py +16 -0
- diffinytrace/plotting/core.py +92 -0
- diffinytrace/plotting/quantity2D.py +188 -0
- diffinytrace/plotting/system2D.py +220 -0
- diffinytrace/plotting/system3D.py +327 -0
- diffinytrace/plotting/wavelength.py +231 -0
- diffinytrace/refractive_index.py +101 -0
- diffinytrace/render.py +77 -0
- diffinytrace/source.py +661 -0
- diffinytrace/spectrum.py +79 -0
- diffinytrace/surface.py +468 -0
- diffinytrace/target_grid.py +399 -0
- diffinytrace/transforms.py +472 -0
- diffinytrace/utils/__init__.py +7 -0
- diffinytrace/utils/autograd.py +116 -0
- diffinytrace/utils/irradiance_importer.py +134 -0
- diffinytrace-2.1.dist-info/METADATA +26 -0
- diffinytrace-2.1.dist-info/RECORD +38 -0
- diffinytrace-2.1.dist-info/WHEEL +5 -0
- diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
- diffinytrace-2.1.dist-info/top_level.txt +1 -0
diffinytrace/optimize.py
ADDED
@@ -0,0 +1,808 @@
r"""
Optimization Utilities for PyTorch-SciPy Integration
====================================================

This submodule provides tools for constrained and unconstrained optimization of PyTorch models using SciPy optimizers. It connects SciPy’s optimization routines with PyTorch’s autograd system. Objectives and constraints are written and differentiated in PyTorch, while SciPy drives the optimization.

Key Features:
-------------
- Wrapping of PyTorch-based objective functions for use with SciPy.
- Automatic gradient computation using PyTorch’s autograd.
- Support for parameter bounds, including custom mask-based bounds.
- Caching and reuse of the most recent function/gradient evaluations.
- Integration with SciPy's `minimize`.
- Optional tracking of optimization history (function values and gradient norms).
- Utility functions for flattening/unpacking tensor parameters.
- Conversion of PyTorch parameters to SciPy-compatible formats with bounds.
- Support for custom constraints and callback functions.

Optimization Constraints in Optical Systems
-------------------------------------------

When optimization is used to determine the parameters of an optical system, constraints are needed to ensure that the resulting system can be manufactured. Examples include positive air spacings and a minimum glass thickness. The following describes the types of constraints implemented in our library.

Constrained optimization problems can often be expressed as a *nonlinear program*, which is defined as follows (see :cite:`italiens`):

.. math::

    \min_{p} \quad m(p)

.. math::

    \text{subject to} \quad \hat{g}_i(p) \leq 0, \quad i = 1, \ldots, N_1,

.. math::

    \text{subject to} \quad \hat{h}_j(p) = 0, \quad j = 1, \ldots, N_2.

where:

- :math:`p \in \mathbb{R}^n` is the vector of parameters.
- :math:`m: \mathbb{R}^n \to \mathbb{R}` is the nonlinear objective (merit) function.
- :math:`\hat{g}_i: \mathbb{R}^n \to \mathbb{R}` are the inequality constraint functions.
- :math:`\hat{h}_j: \mathbb{R}^n \to \mathbb{R}` are the equality constraint functions.

For this type of problem, the Python library *SciPy* provides several numerical schemes. Some of these schemes also require the derivatives of the constraint functions. For example, Sequential Least Squares Programming (SLSQP) uses the derivatives of :math:`\hat{g}_i` and :math:`\hat{h}_j` to find local minima.

By combining PyTorch and SciPy, we leverage the strengths of two established libraries:

1. **PyTorch**: Efficiently calculates the derivatives of the merit function :math:`m` and of the constraint functions :math:`\hat{g}_i` and :math:`\hat{h}_j` using automatic differentiation. It can also evaluate these functions and their derivatives on a GPU, which can provide significant speedups.
2. **SciPy**: Provides well-tested classical algorithms for finding local minima. PyTorch also offers a wide variety of optimizers. However, these are designed mainly for stochastic gradient descent in deep learning, which may not be the best choice for optimizing optical systems.

Types of Constraints
--------------------

Our library implements three ways to define constraints:

1. **Bounds**

   Most numerical schemes in SciPy support bound constraints, which define a minimum and a maximum value for each parameter. A bound on a single entry :math:`p_k` of :math:`p` can be interpreted as a constraint of the form :math:`\hat{g}_i(p) = p_k - C_i` (upper bound) or :math:`\hat{g}_i(p) = C_i - p_k` (lower bound), where :math:`C_i \in \mathbb{R}`. This is particularly useful for distance transformations. Here we can ensure that a distance never becomes negative or, as in the following example, never drops below 5:

   >>> import diffinytrace as dit
   >>> import torch
   >>> dist_transform = dit.transforms.Distance(10.)
   >>> dist_transform.distance.bounds = torch.tensor([5.0, torch.inf])

   Here, **torch.inf** indicates that the distance can be arbitrarily large, that is, it has no upper bound.

2. **Constant Variables**

   If a specific parameter should be kept fixed, gradient computation can be disabled for that parameter in PyTorch. For example:

   >>> import diffinytrace as dit
   >>> distance_transform = dit.transforms.Distance(10.)
   >>> distance_transform.distance.requires_grad = False

   Note: It is easy to keep an entire parameter constant. However, gradient computation cannot be disabled for individual entries of a parameter tensor that holds multiple values. For instance, it is not possible to disable gradient computation for individual coefficients of a B-spline surface.

3. **Arbitrary Constraint Functions**

   Our library also supports nonlinear inequality constraint functions :math:`\hat{g}_i` and equality constraint functions :math:`\hat{h}_j`. Some local optimization methods require the derivatives of these nonlinear constraint functions. To evaluate these derivatives efficiently, the constraint functions are defined with PyTorch. They are then differentiated with respect to the parameters of the optical system using automatic differentiation. This avoids finite differences, which could significantly slow down the optimization procedure.
"""

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

__all__ = [
    "make_bounds_from_param",
    "make_parameter_from_input",
    "pack_tensors",
    "unpack_tensors",
    "apply_vec_to_params",
    "set_full_if_nan",
    "ParameterFunHelper",
    "create_fun_and_gradient",
    "remove_bounds",
    "get_bounds",
    "get_scipy_constraint",
    "create_callback",
    "minimize",
    "copy_bounds_to_attr_name",
    "set_bounds_from_params_mask"
]

import scipy
import scipy.optimize
from .utils.autograd import grad
import torch
import numpy as np
import torch.nn as nn
import copy
from typing import Callable, List, Tuple, Optional

def make_bounds_from_param(param):
    """
    Creates default bounds (-∞, ∞) for each element of the input tensor.

    This function returns a tensor of shape `param.shape + [2]`, where the last
    dimension holds the lower and upper bound of each element in `param`.

    Args:
        param (torch.Tensor): A tensor for which bounds should be created.

    Returns:
        torch.Tensor: A tensor of shape `param.shape + [2]` where
        `[..., 0] = -inf` (lower bounds) and `[..., 1] = inf` (upper bounds),
        with the same dtype and device as `param`.
    """
    bounds = torch.zeros(list(param.shape)+[2],device=param.device,dtype=param.dtype)
    bounds[...,0] = -torch.inf
    bounds[...,1] = torch.inf
    return bounds


def make_parameter_from_input(input,bounds=None, dtype=None, device=None,bounds_attr_name="bounds"):
    """
    Converts input to a `torch.nn.Parameter` and attaches bounds as an attribute.

    Args:
        input (array-like or torch.Tensor): Input data.
        bounds (torch.Tensor, optional): Bounds to attach to the parameter. If None,
            default bounds (-inf, inf) are created with `make_bounds_from_param`.
        dtype (torch.dtype, optional): Desired tensor data type.
        device (torch.device, optional): Device to store the parameter on.
        bounds_attr_name (str): Attribute name used to store bounds.

    Returns:
        torch.nn.Parameter: The parameter with bounds attached as an attribute.
    """
    if not torch.is_tensor(input):
        input = torch.tensor(input, dtype=dtype, device=device)

    # Move the tensor to the requested dtype and/or device, if any were given
    if dtype is not None or device is not None:
        input = input.to(device=device, dtype=dtype)

    # If the input is not already a Parameter, convert it to one
    if not isinstance(input, torch.nn.Parameter):
        input = torch.nn.Parameter(input)

    if bounds is None:
        bounds = make_bounds_from_param(input)
    #input.bounds = bounds
    setattr(input,bounds_attr_name,bounds)
    return input
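Both the bounds mechanism and the "Constant Variables" option rely on ordinary attributes of a `torch.nn.Parameter`. The following is a minimal standalone sketch in plain PyTorch that does not call diffinytrace. The names `distance` and `coeffs` are illustrative, and the `[..., 0]` lower / `[..., 1]` upper layout follows `make_bounds_from_param`. The sketch attaches bounds, freezes one parameter via `requires_grad`, and flattens the bounds of the remaining trainable parameters into an `(N, 2)` array.

```python
import torch

# Hypothetical parameters: a scalar distance and a vector of surface coefficients.
distance = torch.nn.Parameter(torch.tensor(10.0))
coeffs = torch.nn.Parameter(torch.zeros(3))

# nn.Parameter accepts extra Python attributes, so bounds can ride along.
# Layout: param.shape + (2,), with [..., 0] = lower and [..., 1] = upper bound.
distance.bounds = torch.tensor([5.0, torch.inf])
coeffs.bounds = torch.stack([torch.full((3,), -1.0), torch.full((3,), 1.0)], dim=-1)

# Freezing a whole parameter excludes it from the optimization variables.
coeffs.requires_grad_(False)

params = [distance, coeffs]
trainable = [p for p in params if p.requires_grad]

# One (lower, upper) row per scalar entry of every trainable parameter.
bounds = torch.cat([p.bounds.reshape(-1, 2) for p in trainable], dim=0)
print(bounds)
```

Freezing works only at the level of a whole tensor, which matches the note on B-spline coefficients in the module docstring.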

def pack_tensors(tensor_list:List[torch.Tensor]) -> torch.Tensor:
    """
    Flattens and concatenates a list of tensors into a single 1D tensor.

    Args:
        tensor_list (list of torch.Tensor or torch.Tensor): Input tensor(s).
            A single tensor is simply flattened.

    Returns:
        torch.Tensor: A 1D tensor.
    """
    if torch.is_tensor(tensor_list):
        return tensor_list.reshape(-1)
    return torch.cat([t.reshape(-1) for t in tensor_list])

def unpack_tensors(packed_tensor: torch.Tensor, shapes: List[Tuple[int]]) -> List[torch.Tensor]:
    """
    Unpacks a 1D tensor into a list of tensors with specified shapes.

    Args:
        packed_tensor (torch.Tensor): The flat tensor to unpack.
        shapes (list of tuple): Target shapes for unpacked tensors.

    Returns:
        list of torch.Tensor: Unpacked tensors with the given shapes. An empty
        (scalar) shape consumes one element and is returned as a 1-element
        1D tensor.
    """
    unpacked_tensors = []
    start = 0
    for shape in shapes:
        size = torch.prod(torch.tensor(shape)).item() # Calculate the size of the tensor
        size = int(max(size,1))
        # Reshape the portion of packed_tensor to the original shape
        tensor = packed_tensor[start:start + size]
        if len(shape) > 0: # Only reshape if shape is not scalar
            tensor = tensor.reshape(*shape)
        unpacked_tensors.append(tensor)
        start += size # Move to the next start index
    return unpacked_tensors
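As a compact alternative for the unpacking step, `torch.split` can slice the flat tensor using the per-shape element counts. The following is a hedged sketch; the helpers `pack` and `unpack` are hypothetical and are not diffinytrace functions. Because `math.prod(())` is 1, a 0-d shape consumes exactly one element and is restored as a true 0-d tensor.

```python
import math
import torch

def pack(tensors):
    # Flatten every tensor and concatenate the results into one 1D tensor.
    return torch.cat([t.reshape(-1) for t in tensors])

def unpack(flat, shapes):
    # Element count per shape; math.prod of an empty shape is 1.
    sizes = [math.prod(s) for s in shapes]
    chunks = torch.split(flat, sizes)
    return [c.reshape(s) for c, s in zip(chunks, shapes)]

tensors = [torch.arange(4.0).reshape(2, 2), torch.tensor(7.0), torch.ones(3)]
shapes = [t.shape for t in tensors]

flat = pack(tensors)
restored = unpack(flat, shapes)
print(flat.shape, [r.shape for r in restored])
```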

def apply_vec_to_params(vec: np.ndarray, params: list[torch.nn.Parameter], device=None, dtype=None):
    """
    Updates PyTorch parameters with values from a flattened NumPy vector.

    This function is used in optimization workflows to update parameter values
    during SciPy optimization. It takes a flat vector of parameter values and
    distributes them back to the original parameter tensors, preserving their
    original shapes.

    Args:
        vec (np.ndarray): A 1D NumPy array containing new parameter values.
            The length must match the total number of elements across all parameters.
        params (list[torch.nn.Parameter]): List of PyTorch parameters to update.
            Each parameter will be reshaped from the corresponding portion of `vec`.
        device (torch.device, optional): Target device for the parameters.
            If None, uses the device of the first parameter. Defaults to None.
        dtype (torch.dtype, optional): Target data type for the parameters.
            If None, uses the dtype of the first parameter. Defaults to None.

    Raises:
        RuntimeError: If `vec` is not a NumPy array.

    Example:
        >>> import torch
        >>> import numpy as np
        >>> import diffinytrace as dit
        >>>
        >>> # Create some parameters
        >>> params = [
        ...     torch.nn.Parameter(torch.ones((2,2))*0.25),
        ...     torch.nn.Parameter(torch.ones(3))
        ... ]
        >>> # Flatten parameters to create a vector
        >>> vec = dit.optimize.pack_tensors(params).detach().cpu().numpy()
        >>> print(f"Vector length: {len(vec)}")  # Should be 2*2 + 3 = 7
        >>> print(params)
        >>>
        >>> # Modify the vector
        >>> vec_new = vec * 2.0
        >>> # Update parameters with new values
        >>> dit.optimize.apply_vec_to_params(vec_new, params)
        >>>
        >>> # Parameters are now updated with doubled values
        >>> print(params)

    Note:
        - This function modifies parameters in place using `param.data = ...`
        - The function uses `torch.no_grad()` to avoid building computation graphs
        - Parameter shapes are restored by `unpack_tensors()`; 0-d parameters
          receive a 1-element tensor
        - Commonly used with `pack_tensors()` and `unpack_tensors()` for optimization
    """
    if not isinstance(vec, np.ndarray):
        raise RuntimeError("vec should be a numpy vector")
    params = [elem for elem in params]
    if dtype is None:
        dtype = params[0].dtype
    if device is None:
        device = params[0].device
    unpacked_params = unpack_tensors(torch.tensor(vec,device=device,dtype=dtype), [elem.shape for elem in params])
    with torch.no_grad():
        for k,param in enumerate(params):
            param.data = unpacked_params[k]
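PyTorch itself ships `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters`, which perform a similar flatten and write-back round trip. The following is a minimal sketch of the same workflow as the docstring example above using those utilities instead of the diffinytrace helpers. Unlike `apply_vec_to_params`, `vector_to_parameters` expects a torch tensor, so the NumPy vector is converted explicitly to the parameters' dtype and device.

```python
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

params = [torch.nn.Parameter(torch.ones(2, 2) * 0.25), torch.nn.Parameter(torch.ones(3))]

# Flatten to NumPy, which is what a SciPy optimizer works with.
vec = parameters_to_vector(params).detach().cpu().numpy()

# Convert the (modified) NumPy vector back to a tensor with matching dtype/device.
vec_new = torch.as_tensor(vec * 2.0, dtype=params[0].dtype, device=params[0].device)

# Copy the values back into the parameters, preserving their shapes.
vector_to_parameters(vec_new, params)
print(params[0])
```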

def set_full_if_nan(input:np.ndarray, fill_value: float)->np.ndarray:
    """
    Replaces `input` with `fill_value` if it contains any NaN.

    Entries are not patched individually. If at least one entry is NaN, the
    whole array is replaced by an array of the same shape filled with
    `fill_value`.

    Args:
        input (np.ndarray): A NumPy array (0-d arrays are allowed).
        fill_value (float): Value to use in place of NaNs.

    Returns:
        np.ndarray: `input` unchanged if it contains no NaN, otherwise an array
        of the same shape filled with `fill_value`.

    Raises:
        RuntimeError: If `input` is not a NumPy array.
    """
    if not isinstance(input, np.ndarray):
        raise RuntimeError("set_full_if_nan,input should be a numpy vector")

    if len(input.shape) == 0:
        if np.isnan(input):
            return np.array(fill_value)
        else:
            return input
    else:
        if np.isnan(input).any():
            input = np.full_like(input, fill_value)
            return input
        else:
            return input
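`set_full_if_nan` replaces the entire array as soon as any entry is NaN, rather than patching individual entries. The following NumPy-only sketch contrasts that rule with element-wise replacement via `np.nan_to_num`. The rule is re-implemented here as the hypothetical `fill_whole_if_nan`, and the gradient values are made up.

```python
import numpy as np

def fill_whole_if_nan(arr, fill_value):
    # Whole-array rule: a single NaN replaces every entry with fill_value.
    return np.full_like(arr, fill_value) if np.isnan(arr).any() else arr

grad = np.array([1.0, np.nan, 3.0])

whole = fill_whole_if_nan(grad, np.inf)            # every entry becomes inf
elementwise = np.nan_to_num(grad, nan=np.inf)      # only the NaN entry becomes inf
print(whole, elementwise)
```

With the whole-array rule, a single NaN in a gradient turns every component into the fallback value. As a result, the optimizer never receives a partially valid gradient.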

class ParameterFunHelper():
    """
    Helper class for evaluating PyTorch functions and gradients in SciPy optimization.

    This class bridges PyTorch's automatic differentiation with SciPy's optimization
    routines by providing function and gradient evaluations in NumPy format.
    It caches the most recent evaluations to avoid redundant computations and
    substitutes a fallback value for outputs that contain NaN.

    Args:
        original_fun (Callable): PyTorch function to be optimized. It is called without
            arguments and should return a scalar tensor.
        params (List[torch.nn.Parameter]): List of PyTorch parameters to optimize over.
        nan_fallback (float, optional): Value used if NaN is detected in the function
            or gradient evaluation. Defaults to float("inf").

    Attributes:
        original_fun (Callable): The objective function being optimized.
        params (List[torch.nn.Parameter]): Parameters for optimization.
        nan_fallback (float): Fallback value for NaN handling.
        last_x_fun_numpy (np.ndarray): Cache of last input for function evaluation.
        last_fun_val_numpy (np.ndarray): Cache of last function value in NumPy format.
        last_fun_val_torch (torch.Tensor): Cache of last function value as PyTorch tensor.
        last_x_grad_numpy (np.ndarray): Cache of last input for gradient evaluation.
        last_grad_val_numpy (np.ndarray): Cache of last gradient in NumPy format.

    Example:
        >>> import torch
        >>> import diffinytrace as dit
        >>> import numpy as np
        >>>
        >>> # Define parameters and objective function
        >>> params = [torch.nn.Parameter(torch.randn(5))]
        >>> def objective():
        ...     return torch.sum(params[0]**2)
        >>>
        >>> # Create helper for SciPy optimization
        >>> helper = dit.optimize.ParameterFunHelper(objective, params)
        >>>
        >>> # Use with SciPy
        >>> x0 = np.ones((5,))*3.
        >>> fun_val = helper.fun(x0)  # Evaluate function: 5*3^2 = 45
        >>> grad_val = helper.jac(x0)  # Evaluate gradient: 2*3 = 6 per entry
        >>> fun_val, grad_val = helper.fun_jac(x0)  # Evaluate both
        >>>
        >>> print(fun_val, grad_val)  # 45.0 [6. 6. 6. 6. 6.]

    Note:
        - Function and gradient evaluations are cached to avoid redundant computations
          when SciPy requests the same point multiple times.
        - If a function value or gradient contains a NaN, the entire output is
          replaced by `nan_fallback`.
        - Evaluating at a new point writes that point into `params` in place.
    """
    def __init__(self,orginal_fun,params,nan_fallback = float("inf")):
        self.last_x_fun_numpy = None
        self.last_fun_val_numpy = None
        self.last_fun_val_torch = None

        self.last_x_grad_numpy = None
        self.last_grad_val_numpy = None
        self.orginal_fun = orginal_fun

        self.params = [param for param in params]
        self.nan_fallback = nan_fallback

    def fun(self,x):
        """
        Evaluates the objective function at a given input.

        Args:
            x (np.ndarray): Flat input array.

        Returns:
            np.ndarray: The function value as a 0-d array, replaced by
            `nan_fallback` if it is NaN.
        """
        if not self.last_x_fun_numpy is None:
            if (x == self.last_x_fun_numpy).all():
                out = self.last_fun_val_numpy
                out = set_full_if_nan(out,self.nan_fallback)
                return out

        device = self.params[0].device
        dtype = self.params[0].dtype
        apply_vec_to_params(x,self.params,device,dtype)
        self.last_x_fun_numpy = copy.deepcopy(x)
        fun_val = self.orginal_fun()
        self.last_fun_val_torch = fun_val
        self.last_fun_val_numpy = set_full_if_nan(fun_val.detach().cpu().numpy(),self.nan_fallback)
        out = self.last_fun_val_numpy
        out = set_full_if_nan(out,self.nan_fallback)
        return out

    def jac(self,x):
        """
        Computes the gradient of the objective function at input x.

        Args:
            x (np.ndarray): Flat input array.

        Returns:
            np.ndarray: Flat gradient, replaced entirely by `nan_fallback`
            if it contains NaN.
        """
        if not self.last_x_grad_numpy is None:
            if (x == self.last_x_grad_numpy).all():
                out = self.last_grad_val_numpy
                out = set_full_if_nan(out,self.nan_fallback)
                return out

        self.fun(x)
        self.last_x_grad_numpy = copy.deepcopy(x)
        dp = grad(self.last_fun_val_torch,inputs=self.params,materialize_grads=True,create_graph=False,retain_graph=False)
        dp = pack_tensors(dp)
        dp_numpy = dp.detach().cpu().numpy()

        self.last_grad_val_numpy = set_full_if_nan(dp_numpy,self.nan_fallback)

        out = dp_numpy
        out = set_full_if_nan(out,self.nan_fallback)
        return out

    def fun_jac(self,x):
        """
        Evaluates both function value and gradient at once.

        Args:
            x (np.ndarray): Flat input array.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Function value and gradient.
        """
        fun_val_numpy = self.fun(x)
        grad_val_numpy = self.jac(x)
        return fun_val_numpy,grad_val_numpy
    """
    def hess(self,x,v):
        if not self.calc_hess:
            raise("ParameterFunHelper: calc_hess was initialized with False!")
        device = self.last_grad_val_torch.device
        dtype = self.last_grad_val_torch.dtype
        self.grad(x)
        v_torch = torch.tensor(v,device=device,dtype=dtype)
        Hv = grad(self.last_grad_val_torch,inputs=self.params,grad_outputs=v_torch,materialize_grads=True,create_graph=False,retain_graph=True)
        Hv_packed = pack_tensors(Hv)
        out = Hv_packed.detach().cpu().numpy()
        out = set_full_if_nan(out,self.nan_fallback)
        print("hess out ",out)
        return out"""


def create_fun_and_gradient(merit_fun,params,nan_fallback,device,dtype):
    """
    Wraps a PyTorch merit function and returns a callable that evaluates both
    the function and its gradient in NumPy format.

    Args:
        merit_fun (Callable): PyTorch function to optimize.
        params (list): List of `torch.nn.Parameter` objects.
        nan_fallback (float): Value to use if NaNs are encountered.
        device (torch.device): Target device.
        dtype (torch.dtype): Target dtype.

    Returns:
        Callable: Function that returns (value, gradient) as NumPy arrays.
    """
    def fun_and_gradient(input):
        apply_vec_to_params(input,params,device,dtype)
        merit_val = merit_fun()
        dmdp = grad(merit_val,inputs=params,materialize_grads=True,create_graph=False,retain_graph=False)

        out_merit_val = merit_val.detach().cpu()
        out_dmdp = [elem.detach().cpu() for elem in dmdp]
        out_dmdp = pack_tensors(out_dmdp)

        out_dmdp = set_full_if_nan(out_dmdp.numpy(),nan_fallback)
        out_merit_val = set_full_if_nan(out_merit_val.numpy(),nan_fallback)

        #print("merit_val: ",out_merit_val)
        return out_merit_val,out_dmdp
    return fun_and_gradient
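When `jac=True` is passed, SciPy's `minimize` accepts a single callable that returns `(value, gradient)`. This is the same return convention as the `fun_and_gradient` closure above. The following is a minimal sketch of that pattern in plain PyTorch and SciPy. The least-squares merit and the `target` vector are hypothetical.

```python
import numpy as np
import torch
from scipy.optimize import minimize

p = torch.nn.Parameter(torch.zeros(3, dtype=torch.float64))
target = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)

def fun_and_gradient(x):
    # Write SciPy's current point into the parameter without recording history.
    with torch.no_grad():
        p.copy_(torch.as_tensor(x, dtype=p.dtype))
    merit = torch.sum((p - target) ** 2)
    (grad,) = torch.autograd.grad(merit, p)
    # SciPy expects a float value and a float64 NumPy gradient.
    return merit.item(), grad.cpu().numpy()

# jac=True: one call per point yields both the value and the gradient.
res = minimize(fun_and_gradient, np.zeros(3), jac=True, method="L-BFGS-B")
print(res.x)
```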


def remove_bounds(params,bounds_attr_name) -> None:
    """
    Resets the bounds attribute to None on every parameter that has it.

    Args:
        params (list): List of torch.nn.Parameter objects.
        bounds_attr_name (str): Attribute name of bounds to remove.
    """
    for elem in params:
        if hasattr(elem,bounds_attr_name):
            setattr(elem,bounds_attr_name,None)

def get_bounds(params,bounds_attr_name="bounds"):
    """
    Extracts and concatenates bounds for all parameters.

    Parameters without the bounds attribute receive default bounds (-inf, inf),
    which are also stored on the parameter. Bounds given as lists or NumPy
    arrays are converted to tensors of the default dtype.

    Args:
        params (list): List of torch.nn.Parameter objects.
        bounds_attr_name (str): Name of attribute storing bounds.

    Returns:
        np.ndarray: Array of shape (N, 2), where N is the total number of
        parameter elements. Column 0 holds the lower bounds and column 1 the
        upper bounds.
    """
    out = []

    for elem in params:
        if not hasattr(elem,bounds_attr_name):
            bounds = make_bounds_from_param(elem)
            setattr(elem,bounds_attr_name,bounds)
        tmp = getattr(elem,bounds_attr_name)
        if isinstance(tmp,list):
            tmp = torch.tensor(np.array(tmp),dtype=torch.get_default_dtype())
        if isinstance(tmp,np.ndarray):
            tmp = torch.tensor(tmp,dtype=torch.get_default_dtype())

        out += [tmp]
    out = torch.cat([t.reshape(-1,2) for t in out],dim=0)
    out = out.detach().cpu()
    #print("out",out)
    out = np.array(out)
    return out
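An `(N, 2)` bounds array of this kind maps directly onto SciPy's bound handling, where `-inf` and `inf` stand for a missing bound. The following is a minimal sketch with hypothetical bounds and merit function, not taken from diffinytrace. It uses `scipy.optimize.Bounds`, which takes the lower and upper columns as separate arrays.

```python
import numpy as np
from scipy.optimize import Bounds, minimize

# Hypothetical (N, 2) bounds: [lower, upper] per scalar parameter.
bounds_np = np.array([[5.0, np.inf], [-np.inf, np.inf], [-1.0, 1.0]])
bounds = Bounds(bounds_np[:, 0], bounds_np[:, 1])

def merit(x):
    # Unconstrained minimum (0, 0, 3) violates the first and third bounds.
    return x[0] ** 2 + x[1] ** 2 + (x[2] - 3.0) ** 2

res = minimize(merit, x0=np.array([6.0, 1.0, 0.0]), method="L-BFGS-B", bounds=bounds)
print(res.x)  # expected: approximately [5, 0, 1]
```

`minimize` also accepts bounds as a sequence of `(min, max)` pairs.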

def get_scipy_constraint(constraint,params,nan_fallback):
    """
    Converts a constraint into SciPy-compatible format.

    Args:
        constraint (Constraint): A custom constraint object. Its `fun` is called
            without arguments and its `type` is passed through to SciPy.
        params (list): List of parameters for the optimization.
        nan_fallback (float): Fallback value for NaNs.

    Returns:
        dict: A dictionary with the keys 'type', 'fun', and 'jac', as expected by
        SciPy's `minimize`. Note that SciPy interprets 'ineq' constraints as
        `fun(x) >= 0`.
    """
    param_fun_helper = ParameterFunHelper(constraint.fun,params,nan_fallback)
    param_fun_helper.constraint=True

    scipy_data = {'type': constraint.type,'fun':param_fun_helper.fun,'jac':param_fun_helper.jac}
    return scipy_data
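The module docstring describes obtaining constraint derivatives from PyTorch instead of finite differences. The following is a minimal standalone sketch of a SciPy constraint dictionary whose `jac` comes from `torch.autograd.grad`. The unit-circle constraint `g_torch` and the merit function are hypothetical. Because SciPy's `'ineq'` convention is `fun(x) >= 0`, the constraint written as g(p) <= 0 is negated.

```python
import numpy as np
import torch
from scipy.optimize import minimize

def g_torch(p):
    # Hypothetical constraint in the form g(p) <= 0: stay inside the unit circle.
    return torch.sum(p ** 2) - 1.0

def ineq_fun(x):
    # SciPy's 'ineq' convention is fun(x) >= 0, hence the sign flip.
    return -g_torch(torch.as_tensor(x, dtype=torch.float64)).item()

def ineq_jac(x):
    # Derivative of -g via autograd instead of finite differences.
    p = torch.tensor(x, dtype=torch.float64, requires_grad=True)
    (dg,) = torch.autograd.grad(g_torch(p), p)
    return -dg.numpy()

constraint = {"type": "ineq", "fun": ineq_fun, "jac": ineq_jac}

# Hypothetical merit: squared distance to (2, 2), a point outside the circle.
res = minimize(lambda x: float(np.sum((x - 2.0) ** 2)), x0=np.zeros(2),
               jac=lambda x: 2.0 * (x - 2.0), method="SLSQP",
               constraints=[constraint])
print(res.x)  # expected: approximately [0.7071, 0.7071]
```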


def create_callback(callback_fun,params,device,dtype):
    """
    Wraps a PyTorch callback function for use in SciPy.

    Args:
        callback_fun (Callable): A function taking no arguments.
        params (list): List of parameters to update before calling.
        device (torch.device): Device of the parameters.
        dtype (torch.dtype): Data type of the parameters.

    Returns:
        Callable: A callback that takes the current iterate from SciPy, writes it
        into `params`, and then calls `callback_fun`.
    """
    def call_back(input):
        apply_vec_to_params(input,params,device,dtype)
        return callback_fun()
    return call_back
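`create_callback` adapts a zero-argument callback to SciPy's `callback(xk)` convention by first writing the current iterate into the parameters. The following is a minimal standalone sketch of the same adapter pattern in plain PyTorch and SciPy; the names `make_callback` and `record` are hypothetical. It records the parameter trajectory of a simple quadratic minimization.

```python
import numpy as np
import torch
from scipy.optimize import minimize

p = torch.nn.Parameter(torch.zeros(2, dtype=torch.float64))
history = []

def record():
    # Zero-argument user callback: reads the current parameter values.
    history.append(p.detach().clone())

def make_callback(user_callback):
    # SciPy passes the current iterate xk; copy it into the parameter first.
    def callback(xk):
        with torch.no_grad():
            p.copy_(torch.as_tensor(xk, dtype=p.dtype))
        user_callback()
    return callback

res = minimize(lambda x: float(np.sum((x - 3.0) ** 2)), np.zeros(2),
               jac=lambda x: 2.0 * (x - 3.0), method="L-BFGS-B",
               callback=make_callback(record))
print(len(history), res.x)
```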
|
|
548
|
+
|
|
549
|
+
#nlopt==2.6.2
|
|
550
|
+
"""
|
|
551
|
+
def global_dual_annealing(fun,
|
|
552
|
+
params,
|
|
553
|
+
constraints=[],
|
|
554
|
+
annealing_maxiter=1000,
|
|
555
|
+
annealing_initial_temp=5230.0,
|
|
556
|
+
annealing_restart_temp_ratio=2e-05,
|
|
557
|
+
annealing_visit=2.62,
|
|
558
|
+
annealing_accept=-5.0,
|
|
559
|
+
annealing_maxfun=10000000.0,
|
|
560
|
+
bounds_attr_name="bounds",
|
|
561
|
+
local_tol=1e-6,
|
|
562
|
+
local_method=None):
|
|
563
|
+
nan_fallback = annealing_maxfun
|
|
564
|
+
|
|
565
|
+
from .constraints import Constraint
|
|
566
|
+
|
|
567
|
+
if isinstance(constraints,Constraint):
|
|
568
|
+
constraints = [constraints]
|
|
569
|
+
|
|
570
|
+
if local_method is None:
|
|
571
|
+
if len(constraints) == 0:
|
|
572
|
+
local_method = 'L-BFGS-B'
|
|
573
|
+
else:
|
|
574
|
+
local_method = 'SLSQP'
|
|
575
|
+
|
|
576
|
+
if (not local_method == 'SLSQP') and (len(constraints)>0):
|
|
577
|
+
raise RuntimeError("Only for method SLSQP constraints are supported!")
|
|
578
|
+
|
|
579
|
+
if isinstance(params, torch.nn.Parameter):
|
|
580
|
+
params = [params]
|
|
581
|
+
|
|
582
|
+
params = [param for param in params if param.requires_grad]
|
|
583
|
+
|
|
584
|
+
if len(params) == 0:
|
|
585
|
+
raise RuntimeError("Params is either an empty list or no parameter provided requires_grad!")
|
|
586
|
+
|
|
587
|
+
constraints = [get_scipy_constraint(constraint,params,nan_fallback) for constraint in constraints]
|
|
588
|
+
|
|
589
|
+
device = params[0].device
|
|
590
|
+
dtype = params[0].dtype
|
|
591
|
+
|
|
592
|
+
bounds_numpy = get_bounds(params,bounds_attr_name)
|
|
593
|
+
if np.isinf(bounds_numpy).any():
|
|
594
|
+
raise RuntimeError("All bounds need to be non inf!")
|
|
595
|
+
param_helper_main = ParameterFunHelper(fun,params,nan_fallback)
|
|
596
|
+
#fun_helper = ParameterFunHelper(fun,params,False,nan_fallback)
|
|
597
|
+
|
|
598
|
+
minimizer_kwargs = dict(
|
|
599
|
+
#func=param_helper_main.fun,
|
|
600
|
+
jac=param_helper_main.jac,
|
|
601
|
+
constraints=constraints,
|
|
602
|
+
tol=local_tol,
|
|
603
|
+
method=local_method)
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
initial_params = pack_tensors([param.cpu().detach() for param in params]) # Pack the initial params
|
|
608
|
+
|
|
609
|
+
result = scipy.optimize.dual_annealing(
|
|
610
|
+
func=param_helper_main.fun,
|
|
611
|
+
x0=initial_params,
|
|
612
|
+
bounds=bounds_numpy,
|
|
613
|
+
maxiter = annealing_maxiter,
|
|
614
|
+
initial_temp=annealing_initial_temp,
|
|
615
|
+
restart_temp_ratio = annealing_restart_temp_ratio,
|
|
616
|
+
visit = annealing_visit,
|
|
617
|
+
accept = annealing_accept,
|
|
618
|
+
maxfun=annealing_maxfun,
|
|
619
|
+
minimizer_kwargs=minimizer_kwargs)
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
apply_vec_to_params(result["x"],[p for p in params],device,dtype)
|
|
623
|
+
return result
|
|
624
|
+
"""


def minimize(fun,
             params,
             constraints:Optional[List]=None,
             method=None,
             tol:float=1e-9,
             callback:Optional[Callable]=lambda:None,
             options:Optional[dict]=None,
             nan_fallback:float=float("inf"),
             bounds_attr_name:str="bounds",
             save_history:bool=False,
             call_before_minimize:bool=False)->dict:
    """
    Minimizes a function using SciPy's `minimize`, supporting bounds and constraints.

    Args:
        fun (Callable): Objective function.
        params (list or torch.nn.Parameter): Parameters to optimize. Only parameters
            with `requires_grad=True` are optimized.
        constraints (list or Constraint): Constraint or list of constraints. Constraints
            are only supported with method 'SLSQP'.
        method (str): SciPy optimization method (e.g., 'L-BFGS-B'). If None,
            'L-BFGS-B' is used without constraints and 'SLSQP' with constraints.
        tol (float): Tolerance for convergence.
        callback (Callable): Optional callback function.
        options (dict): Optimizer options passed to `scipy.optimize.minimize`.
        nan_fallback (float): Value to use if the function returns NaN.
        bounds_attr_name (str): Name of the parameter attribute from which bounds are read.
        save_history (bool): If True, records the function value and gradient norm
            each time the callback is invoked.
        call_before_minimize (bool): If True, evaluates the function (and the callback)
            once at the initial parameters before optimization starts.

    Returns:
        dict: Optimization results from `scipy.optimize.minimize`. If `save_history`
            is True, the key "history" holds "fun_vals" and "fun_grads_norm"
            (as NumPy arrays once at least one entry has been recorded).
    """
    from .constraints import Constraint

    if constraints is None:
        constraints = []
    if isinstance(constraints,Constraint):
        constraints = [constraints]

    if method is None:
        if len(constraints) == 0:
            method = 'L-BFGS-B'
        else:
            method = 'SLSQP'

    if (method != 'SLSQP') and (len(constraints) > 0):
        raise RuntimeError("Constraints are only supported with method 'SLSQP'!")

    if isinstance(params, torch.nn.Parameter):
        params = [params]

    params = [param for param in params if param.requires_grad]

    if len(params) == 0:
        raise RuntimeError("No parameters to optimize: params is empty or none of them requires_grad!")

    constraints = [get_scipy_constraint(constraint,params,nan_fallback) for constraint in constraints]

    device = params[0].device
    dtype = params[0].dtype

    bounds_numpy = get_bounds(params,bounds_attr_name)

    initial_params = pack_tensors([param.cpu().detach() for param in params])  # Pack the initial params
    param_helper_main = ParameterFunHelper(fun,params,nan_fallback)
    #fun_helper = ParameterFunHelper(fun,params,False,nan_fallback)

    history = {"fun_vals":[],"fun_grads_norm":[]}

    fun_and_gradient = param_helper_main.fun_jac  # Returns (value, gradient), as required by jac=True
    if save_history:

        def callback_history(x):
            # Re-evaluates the objective and its gradient at the current iterate
            out_merit_val,out_dmdp = fun_and_gradient(x)
            history["fun_vals"] += [out_merit_val]
            history["fun_grads_norm"] += [np.linalg.norm(out_dmdp)]

        if callback is None:
            callback = callback_history
        else:
            callback_tmp = create_callback(callback,params,device,dtype)
            def combined_callback(x):
                callback_tmp(x)
                callback_history(x)
            callback = combined_callback
    elif callback is not None:
        callback = create_callback(callback,params,device,dtype)

    initial_params = np.array(initial_params)
    if call_before_minimize:
        fun_and_gradient(initial_params)
        if callback is not None:
            callback(initial_params)

    result = scipy.optimize.minimize(
        fun=fun_and_gradient,
        x0=initial_params,
        jac=True,  # fun_and_gradient returns both the value and the gradient
        bounds=bounds_numpy,
        method=method,
        tol=tol,
        callback=callback,
        options=options,
        constraints=constraints,
        #hessp=fun_helper.hess
    )
    apply_vec_to_params(result["x"],[p for p in params],device,dtype)
    result = dict(result)  # OptimizeResult -> plain dict
    if len(history["fun_vals"]) > 0:
        history["fun_vals"] = np.array(history["fun_vals"])
        history["fun_grads_norm"] = np.array(history["fun_grads_norm"])

    if save_history:
        result["history"] = history
    return result
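

# Illustrative sketch, not part of diffinytrace's API: the SciPy call pattern
# that `minimize` wraps, written with NumPy/SciPy only so it runs without torch.
# Assumptions: the parameters are already packed into the flat vector `x0`, and
# the hypothetical `fun_and_grad(x)` returns (value, gradient), which is what
# `jac=True` requires. The default method choice mirrors `minimize`:
# L-BFGS-B without constraints, SLSQP with constraints.
def _sketch_scipy_minimize(fun_and_grad, x0, bounds=None, constraints=(), method=None, tol=1e-9):
    import numpy as np
    import scipy.optimize

    constraints = list(constraints)
    if method is None:
        method = "L-BFGS-B" if len(constraints) == 0 else "SLSQP"
    if method != "SLSQP" and len(constraints) > 0:
        raise RuntimeError("Constraints are only supported with method 'SLSQP'!")
    return scipy.optimize.minimize(
        fun=fun_and_grad,
        x0=np.asarray(x0, dtype=float),
        jac=True,  # fun_and_grad returns (value, gradient)
        bounds=bounds,
        method=method,
        tol=tol,
        constraints=constraints,
    )

# Example (not executed): minimizing (x0 - 3)^2 + (x1 - 3)^2 with x0 limited
# to [0, 1] stops at the bound x0 = 1 and reaches x1 = 3:
#     fg = lambda x: (float(((x - 3.0) ** 2).sum()), 2.0 * (x - 3.0))
#     _sketch_scipy_minimize(fg, [0.0, 0.0], bounds=[(0.0, 1.0), (0.0, 5.0)]).x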


def copy_bounds_to_attr_name(params,bounds_attr_name_new,bounds_attr_name_old="bounds",replace_existing_once=True):
    """
    Copies bounds from one attribute name to another.

    If a parameter has no attribute `bounds_attr_name_old`, default bounds are
    created with `make_bounds_from_param`. The bounds are cloned, so the old and
    new attributes can be modified independently.

    Args:
        params (list or nn.Parameter): Parameter or list of parameters.
        bounds_attr_name_new (str): Attribute name to copy the bounds to.
        bounds_attr_name_old (str): Attribute name to copy the bounds from.
        replace_existing_once (bool): If True, existing bounds stored under
            `bounds_attr_name_new` are overwritten. If False, parameters that
            already have `bounds_attr_name_new` are left unchanged.
    """
    def copy_bounds(param,bounds_attr_name_new,bounds_attr_name_old="bounds"):
        if hasattr(param,bounds_attr_name_old):
            bounds = getattr(param,bounds_attr_name_old)
        else:
            bounds = make_bounds_from_param(param)
        bounds = bounds.clone()
        setattr(param,bounds_attr_name_new,bounds)

    if isinstance(params,nn.Parameter):
        params = [params]
    params = list(params)
    for param in params:
        if (not replace_existing_once) and hasattr(param,bounds_attr_name_new):
            continue
        copy_bounds(param,bounds_attr_name_new,bounds_attr_name_old)


def set_bounds_from_params_mask(params,mask:list|torch.Tensor,bounds_attr_name_new,bounds_attr_name_old="bounds"):
    """
    Sets bounds for parameters based on a mask. Entries with `mask=False` are
    fixed at their current value (lower and upper bounds are both set to that
    value); entries with `mask=True` keep their old bounds.

    Args:
        params (list or nn.Parameter): Parameter or list of parameters.
        mask (list, np.ndarray or torch.Tensor): Mask, or list of masks (one per
            parameter), with one entry per parameter element; True marks free entries.
        bounds_attr_name_new (str): Attribute name to store the new bounds.
        bounds_attr_name_old (str): Attribute name to read the old bounds from.
    """
    def set_new_bounds_from_param_mask(param,mask,bounds_attr_name_new,bounds_attr_name_old="bounds"):
        if hasattr(param,bounds_attr_name_old):
            bounds = getattr(param,bounds_attr_name_old)
        else:
            bounds = make_bounds_from_param(param)
        bounds = bounds.clone()
        bounds_shape = bounds.shape
        mask = mask.reshape(-1)
        bounds = bounds.reshape(-1,2)  # One (lower, upper) row per element
        data = param.data.clone()
        data = data.reshape(-1)
        mask_false = mask==False  # Entries to freeze at their current value
        bounds[mask_false,0] = data[mask_false]
        bounds[mask_false,1] = data[mask_false]
        bounds = bounds.reshape(*bounds_shape)
        setattr(param,bounds_attr_name_new,bounds)

    if isinstance(params,nn.Parameter):
        params = [params]
    params = list(params)
    if isinstance(mask,np.ndarray) or torch.is_tensor(mask):
        mask = [mask]

    for k in range(len(params)):
        set_new_bounds_from_param_mask(params[k],mask[k],bounds_attr_name_new=bounds_attr_name_new,bounds_attr_name_old=bounds_attr_name_old)
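

# Illustrative NumPy-only sketch, not part of diffinytrace's API, of the idea
# behind `set_bounds_from_params_mask`: entries whose mask is False get
# lower == upper == current value, so a bounded optimizer such as L-BFGS-B
# cannot move them. Assumptions: bounds use the same (..., 2) layout as above,
# and missing bounds default to (-inf, inf); the real `make_bounds_from_param`
# may produce different defaults. `_sketch_freeze_by_mask` is a hypothetical name.
def _sketch_freeze_by_mask(values, mask, bounds=None):
    import numpy as np

    values = np.asarray(values, dtype=float)
    if bounds is None:
        # Assumed default: unbounded (-inf, inf) for every entry
        bounds = np.stack([np.full(values.shape, -np.inf), np.full(values.shape, np.inf)], axis=-1)
    flat_bounds = np.array(bounds, dtype=float).reshape(-1, 2)  # Copy, like .clone()
    flat_values = values.reshape(-1)
    fixed = ~np.asarray(mask, dtype=bool).reshape(-1)  # Entries to freeze
    flat_bounds[fixed, 0] = flat_values[fixed]
    flat_bounds[fixed, 1] = flat_values[fixed]
    return flat_bounds.reshape(values.shape + (2,))

# Example (not executed): freezing the middle entry and running L-BFGS-B
# leaves it at its starting value 2.0 while the free entries move towards 3.0:
#     b = _sketch_freeze_by_mask([0.5, 2.0, -1.0], [True, False, True])
#     scipy.optimize.minimize(lambda x: ((x - 3.0) ** 2).sum(), [0.5, 2.0, -1.0],
#                             method="L-BFGS-B", bounds=b).x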
|