evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,463 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Transformations for benchmark functions.
|
|
3
|
+
|
|
4
|
+
This module provides transformations that wrap base functions to create
|
|
5
|
+
more challenging optimization problems, following CEC competition conventions.
|
|
6
|
+
|
|
7
|
+
Transformations:
|
|
8
|
+
- ShiftedFunction: f(x - o) where o is the shift vector
|
|
9
|
+
- RotatedFunction: f(R @ x) where R is a rotation matrix
|
|
10
|
+
- ShiftedRotatedFunction: f(R @ (x - o))
|
|
11
|
+
- ScaledFunction: f(lambda * x)
|
|
12
|
+
- AsymmetricFunction: Applies asymmetric transformation
|
|
13
|
+
- OscillatedFunction: Applies oscillation transformation
- BiasedFunction: f(x) + bias
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from typing import Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
from .base import BenchmarkFunction
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ShiftedFunction(BenchmarkFunction):
    """
    Shifted benchmark function.

    f_shift(x) = f(x - shift)

    Moves the optimal solution away from the origin, preventing
    algorithms from exploiting symmetry.
    """

    def __init__(
        self,
        base_function: BenchmarkFunction,
        shift: Optional[Tensor] = None,
        seed: Optional[int] = None,
    ):
        """
        Initialize shifted function.

        Args:
            base_function: Base function to shift.
            shift: Shift vector of shape [n_var] (anything accepted by
                ``torch.as_tensor``). If None, generated randomly.
            seed: Random seed for generating the shift vector. Draws come from
                a private ``torch.Generator`` (same values as
                ``torch.manual_seed(seed)`` would give), so the global torch
                RNG state is left untouched. Unused when ``shift`` is given.
        """
        super().__init__(
            n_var=base_function.n_var,
            xl=base_function.xl,
            xu=base_function.xu,
        )

        self.base_function = base_function
        self.name = f"shifted_{base_function.name}"
        self.optimal_value = base_function.optimal_value

        # Generate or use provided shift
        if shift is not None:
            self.shift = torch.as_tensor(shift)
        else:
            # Local generator: reproducible without clobbering global RNG state.
            generator = (
                torch.Generator().manual_seed(seed) if seed is not None else None
            )
            # Place the shift in the central 80% of the search box
            # (a 10% margin on either side).
            xl, xu = base_function.xl, base_function.xu
            range_val = xu - xl
            self.shift = xl + 0.1 * range_val + 0.8 * range_val * torch.rand(
                self.n_var, generator=generator
            )

        # f(x - shift) reaches the base optimum where x - shift == x*_base.
        self._optimal_x = self.shift + base_function.optimal_x

    def default_bounds(self) -> Tuple[float, float]:
        """Return the wrapped function's default bounds."""
        return self.base_function.default_bounds()

    def __call__(self, x: Tensor) -> Tensor:
        """Evaluate f(x - shift); the shift is moved to x's device/dtype."""
        shift = self.shift.to(x.device, x.dtype)
        return self.base_function(x - shift)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class RotatedFunction(BenchmarkFunction):
    """
    Rotated benchmark function.

    f_rot(x) = f(R @ x)

    Applies rotation to make separable functions non-separable,
    testing an algorithm's ability to handle variable interactions.
    """

    def __init__(
        self,
        base_function: BenchmarkFunction,
        rotation_matrix: Optional[Tensor] = None,
        seed: Optional[int] = None,
    ):
        """
        Initialize rotated function.

        Args:
            base_function: Base function to rotate.
            rotation_matrix: Orthogonal rotation matrix of shape [n_var, n_var]
                (anything accepted by ``torch.as_tensor``).
                If None, generated randomly.
            seed: Random seed for generating the rotation matrix. Draws come
                from a private ``torch.Generator`` (same values as
                ``torch.manual_seed(seed)`` would give), so the global torch
                RNG state is left untouched. Unused when a matrix is given.
        """
        super().__init__(
            n_var=base_function.n_var,
            xl=base_function.xl,
            xu=base_function.xu,
        )

        self.base_function = base_function
        self.name = f"rotated_{base_function.name}"
        self.optimal_value = base_function.optimal_value

        # Generate or use provided rotation matrix
        if rotation_matrix is not None:
            self.R = torch.as_tensor(rotation_matrix)
        else:
            generator = (
                torch.Generator().manual_seed(seed) if seed is not None else None
            )
            # Orthogonal matrix from the Q factor of a Gaussian matrix.
            A = torch.randn(self.n_var, self.n_var, generator=generator)
            Q, _ = torch.linalg.qr(A)
            self.R = Q

        # f(R @ x) is optimal where R @ x == x*_base. Because R is orthogonal,
        # that is x = R^T @ x*_base, or x*_base @ R in the row-vector
        # convention used by __call__. It equals x*_base only when the base
        # optimum is at the origin.
        base_opt = base_function.optimal_x
        dtype = torch.promote_types(base_opt.dtype, self.R.dtype)
        R = self.R.to(device=base_opt.device, dtype=dtype)
        self._optimal_x = base_opt.to(dtype) @ R

    def default_bounds(self) -> Tuple[float, float]:
        """Return the wrapped function's default bounds."""
        return self.base_function.default_bounds()

    def __call__(self, x: Tensor) -> Tensor:
        """Evaluate f(R @ x) for each row of x."""
        R = self.R.to(x.device, x.dtype)
        # x @ R.T is equivalent to R @ x for each row
        x_rotated = x @ R.T
        return self.base_function(x_rotated)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class ShiftedRotatedFunction(BenchmarkFunction):
    """
    Shifted and rotated benchmark function.

    f_sr(x) = f(R @ (x - shift))

    Combines both transformations for maximum difficulty.
    This is the standard transformation used in CEC benchmarks.
    """

    def __init__(
        self,
        base_function: BenchmarkFunction,
        shift: Optional[Tensor] = None,
        rotation_matrix: Optional[Tensor] = None,
        seed: Optional[int] = None,
    ):
        """
        Initialize shifted and rotated function.

        Args:
            base_function: Base function to transform.
            shift: Shift vector (anything accepted by ``torch.as_tensor``).
                If None, generated randomly.
            rotation_matrix: Orthogonal rotation matrix (anything accepted by
                ``torch.as_tensor``). If None, generated randomly.
            seed: Random seed for random generation. A single private
                ``torch.Generator`` feeds both draws (shift first, then
                rotation), matching what ``torch.manual_seed(seed)`` would
                produce, without touching the global torch RNG state.
        """
        super().__init__(
            n_var=base_function.n_var,
            xl=base_function.xl,
            xu=base_function.xu,
        )

        self.base_function = base_function
        self.name = f"shifted_rotated_{base_function.name}"
        self.optimal_value = base_function.optimal_value

        generator = torch.Generator().manual_seed(seed) if seed is not None else None

        # Generate or use provided shift
        if shift is not None:
            self.shift = torch.as_tensor(shift)
        else:
            # Central 80% of the search box (10% margin on either side).
            xl, xu = base_function.xl, base_function.xu
            range_val = xu - xl
            self.shift = xl + 0.1 * range_val + 0.8 * range_val * torch.rand(
                self.n_var, generator=generator
            )

        # Generate or use provided rotation matrix
        if rotation_matrix is not None:
            self.R = torch.as_tensor(rotation_matrix)
        else:
            A = torch.randn(self.n_var, self.n_var, generator=generator)
            Q, _ = torch.linalg.qr(A)
            self.R = Q

        # f(R @ (x - shift)) is optimal where R @ (x - shift) == x*_base,
        # i.e. x = shift + R^T @ x*_base (R orthogonal). In row-vector form
        # that is shift + x*_base @ R. It reduces to `shift` only when the
        # base optimum is at the origin.
        base_opt = base_function.optimal_x
        dtype = torch.promote_types(base_opt.dtype, self.R.dtype)
        R = self.R.to(device=base_opt.device, dtype=dtype)
        self._optimal_x = self.shift + base_opt.to(dtype) @ R

    def default_bounds(self) -> Tuple[float, float]:
        """Return the wrapped function's default bounds."""
        return self.base_function.default_bounds()

    def __call__(self, x: Tensor) -> Tensor:
        """Evaluate f(R @ (x - shift)) for each row of x."""
        shift = self.shift.to(x.device, x.dtype)
        R = self.R.to(x.device, x.dtype)

        x_shifted = x - shift
        # Row-wise R @ (x - shift).
        x_rotated = x_shifted @ R.T

        return self.base_function(x_rotated)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class ScaledFunction(BenchmarkFunction):
    """
    Scaled benchmark function.

    f_scaled(x) = f(scale * x)

    Stretches or compresses the search space coordinate-wise.
    """

    def __init__(
        self,
        base_function: BenchmarkFunction,
        scale: Union[float, Tensor] = 1.0,
    ):
        """
        Initialize scaled function.

        Args:
            base_function: Base function to scale.
            scale: Scale factor. A Python int/float is broadcast to a vector of
                length n_var; any other value is stored as given (e.g. a
                per-dimension tensor).
        """
        super().__init__(
            n_var=base_function.n_var,
            xl=base_function.xl,
            xu=base_function.xu,
        )

        self.base_function = base_function
        self.name = f"scaled_{base_function.name}"
        self.optimal_value = base_function.optimal_value

        is_python_scalar = isinstance(scale, (int, float))
        self.scale = (
            torch.full((self.n_var,), float(scale)) if is_python_scalar else scale
        )

        # f(scale * x) attains the base optimum where scale * x == x*_base.
        self._optimal_x = base_function.optimal_x / self.scale

    def default_bounds(self) -> Tuple[float, float]:
        """Return the wrapped function's default bounds."""
        return self.base_function.default_bounds()

    def __call__(self, x: Tensor) -> Tensor:
        """Evaluate f(scale * x); the scale is moved to x's device/dtype."""
        return self.base_function(x * self.scale.to(x.device, x.dtype))
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class AsymmetricFunction(BenchmarkFunction):
    """
    Asymmetric transformation wrapper.

    Applies T_asy transformation from CEC benchmarks:
        x_i = x_i^(1 + beta * i/(n-1) * sqrt(x_i)) for x_i > 0
    with 0-based i; non-positive coordinates are left unchanged.

    Breaks symmetry around the origin.
    """

    def __init__(
        self,
        base_function: BenchmarkFunction,
        beta: float = 0.5,
    ):
        """
        Initialize asymmetric function.

        Args:
            base_function: Base function to transform.
            beta: Asymmetry parameter (typically 0.5).
        """
        super().__init__(
            n_var=base_function.n_var,
            xl=base_function.xl,
            xu=base_function.xu,
        )

        self.base_function = base_function
        self.name = f"asymmetric_{base_function.name}"
        self.optimal_value = base_function.optimal_value
        self.beta = beta

        # T_asy leaves non-positive coordinates untouched, so the base optimum
        # stays optimal when it has no positive coordinates (e.g. the origin).
        # NOTE(review): for base optima with positive coordinates the true
        # optimum of the transformed function differs from this value.
        self._optimal_x = base_function.optimal_x.clone()

    def default_bounds(self) -> Tuple[float, float]:
        """Return the wrapped function's default bounds."""
        return self.base_function.default_bounds()

    def _transform(self, x: Tensor) -> Tensor:
        """
        Apply the asymmetric transformation T_asy along the last dimension.

        Positive entries become x_i^(1 + beta * i/(n-1) * sqrt(x_i));
        non-positive entries are returned unchanged.
        """
        n = x.shape[-1]
        i = torch.arange(n, device=x.device, dtype=x.dtype)
        # max(...) keeps n == 1 well-defined (weight 0, identity transform).
        weight = self.beta * i / max(n - 1, 1)

        positive_mask = x > 0
        # Substitute 1 for non-positive entries before sqrt/pow. torch.where
        # still backpropagates through the unselected branch (scaled by 0),
        # and pow of a negative base with a fractional exponent is NaN.
        # 0 * NaN = NaN would otherwise poison gradients for x <= 0.
        safe_x = torch.where(positive_mask, x, torch.ones_like(x))
        exponent = 1 + weight * torch.sqrt(safe_x)

        return torch.where(positive_mask, torch.pow(safe_x, exponent), x)

    def __call__(self, x: Tensor) -> Tensor:
        """Evaluate f(T_asy(x))."""
        x_transformed = self._transform(x)
        return self.base_function(x_transformed)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class OscillatedFunction(BenchmarkFunction):
    """
    Oscillation transformation wrapper.

    Applies T_osz transformation from CEC benchmarks element-wise:
        x_hat = log|x| (0 if x == 0)
        T_osz(x) = sign(x) * exp(x_hat + 0.049 * (sin(c1 * x_hat) + sin(c2 * x_hat)))
    with (c1, c2) = (10, 7.9) for x > 0 and (5.5, 3.1) otherwise.

    Creates local irregularities while preserving global structure.
    """

    def __init__(
        self,
        base_function: BenchmarkFunction,
    ):
        """
        Initialize oscillated function.

        Args:
            base_function: Base function to transform.
        """
        super().__init__(
            n_var=base_function.n_var,
            xl=base_function.xl,
            xu=base_function.xu,
        )

        self.base_function = base_function
        self.name = f"oscillated_{base_function.name}"
        self.optimal_value = base_function.optimal_value

        # T_osz(0) == 0, so an origin optimum is preserved.
        # NOTE(review): for non-zero base optima T_osz moves the optimum and
        # this value is then only approximate.
        self._optimal_x = base_function.optimal_x.clone()

    def default_bounds(self) -> Tuple[float, float]:
        """Return the wrapped function's default bounds."""
        return self.base_function.default_bounds()

    def _transform(self, x: Tensor) -> Tensor:
        """Apply oscillation transformation T_osz."""
        positive = x > 0
        nonzero = x != 0

        # log|x| for x != 0. Substituting |x| -> 1 at x == 0 gives exactly
        # log(1) == 0 without evaluating log(0) or adding an epsilon bias.
        x_hat = torch.log(torch.where(nonzero, x.abs(), torch.ones_like(x)))

        # Constants built with x's dtype/device. Literal CPU float32 tensors
        # would round 7.9/3.1 for float64 inputs and could mismatch devices.
        c1 = torch.where(positive, x.new_tensor(10.0), x.new_tensor(5.5))
        c2 = torch.where(positive, x.new_tensor(7.9), x.new_tensor(3.1))

        sign_x = torch.sign(x)

        return sign_x * torch.exp(
            x_hat + 0.049 * (
                torch.sin(c1 * x_hat) + torch.sin(c2 * x_hat)
            )
        )

    def __call__(self, x: Tensor) -> Tensor:
        """Evaluate f(T_osz(x))."""
        x_transformed = self._transform(x)
        return self.base_function(x_transformed)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
class BiasedFunction(BenchmarkFunction):
    """
    Biased benchmark function.

    f_biased(x) = f(x) + bias

    Adds a constant offset to every evaluation: the optimal function value
    moves by ``bias`` while the optimal location is unchanged.
    """

    def __init__(
        self,
        base_function: BenchmarkFunction,
        bias: float = 0.0,
    ):
        """
        Initialize biased function.

        Args:
            base_function: Base function.
            bias: Constant added to every output of ``base_function``.
        """
        super().__init__(
            n_var=base_function.n_var,
            xl=base_function.xl,
            xu=base_function.xu,
        )

        self.base_function = base_function
        self.name = f"biased_{base_function.name}"
        self.bias = bias
        self.optimal_value = base_function.optimal_value + bias

        # A constant offset never moves the argmin.
        self._optimal_x = base_function.optimal_x.clone()

    def default_bounds(self) -> Tuple[float, float]:
        """Return the wrapped function's default bounds."""
        return self.base_function.default_bounds()

    def __call__(self, x: Tensor) -> Tensor:
        """Evaluate f(x) + bias."""
        base_value = self.base_function(x)
        return base_value + self.bias
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
# =============================================================================
|
|
413
|
+
# UTILITY FUNCTIONS
|
|
414
|
+
# =============================================================================
|
|
415
|
+
|
|
416
|
+
def generate_rotation_matrix(n: int, seed: Optional[int] = None) -> Tensor:
    """
    Generate a random orthogonal rotation matrix.

    The matrix is the Q factor of the QR decomposition of an n x n
    standard-normal matrix.

    Args:
        n: Dimension of the matrix.
        seed: Random seed. When given, draws come from a private
            ``torch.Generator`` (same values as ``torch.manual_seed(seed)``
            would produce), so the global torch RNG state is not modified.
            When None, the global RNG is used.

    Returns:
        Orthogonal matrix of shape [n, n].
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    A = torch.randn(n, n, generator=generator)
    Q, _ = torch.linalg.qr(A)
    return Q
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def generate_shift_vector(
    n: int,
    xl: Union[float, Tensor],
    xu: Union[float, Tensor],
    margin: float = 0.1,
    seed: Optional[int] = None,
) -> Tensor:
    """
    Generate a random shift vector within bounds.

    Each coordinate is drawn uniformly from
    [xl + margin * (xu - xl), xu - margin * (xu - xl)).

    Args:
        n: Dimension of the vector.
        xl: Lower bounds (scalar or tensor of shape [n]).
        xu: Upper bounds (scalar or tensor of shape [n]).
        margin: Margin from bounds (fraction of range).
        seed: Random seed. When given, draws come from a private
            ``torch.Generator`` (same values as ``torch.manual_seed(seed)``
            would produce), so the global torch RNG state is not modified.
            When None, the global RNG is used.

    Returns:
        Shift vector of shape [n].
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    if isinstance(xl, (int, float)):
        xl = torch.full((n,), float(xl))
    if isinstance(xu, (int, float)):
        xu = torch.full((n,), float(xu))

    range_val = xu - xl
    return xl + margin * range_val + (1 - 2 * margin) * range_val * torch.rand(
        n, generator=generator
    )
|