imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,424 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Callable, NamedTuple, Optional
|
3
|
+
|
4
|
+
import haiku as hk
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
from tree_utils import PyTree
|
8
|
+
|
9
|
+
import ring
|
10
|
+
from ring import maths
|
11
|
+
from ring.algorithms._random import random_angle_over_time
|
12
|
+
|
13
|
+
# Parameters of one drawn function: an arbitrary pytree (alias of the
# `PyTree` type imported from `tree_utils`).
Params = PyTree
|
14
|
+
|
15
|
+
|
16
|
+
class DrawnFnPair(NamedTuple):
    """Pair of functions that together describe one randomly drawn function.

    `init` draws the parameters of a fresh function from a PRNG key, and
    `apply` evaluates the function described by such parameters at `q`.
    """

    # PRNG key -> parameters (a pytree)
    init: Callable[[jax.Array], Params]
    # (parameters, q) -> value of the drawn function at `q`
    apply: Callable[[Params, jax.Array], jax.Array]
|
21
|
+
|
22
|
+
|
23
|
+
# (flexions, min, max) -> DrawnFnPair
# `flexions` is the grid of flexion angles the drawn function is defined on;
# `min` / `max` describe the target value range of the drawn function
# (how strictly it is enforced depends on the concrete factory).
DrawnFnPairFactory = Callable[[jax.Array, float, float], DrawnFnPair]
|
25
|
+
|
26
|
+
|
27
|
+
def deg2rad(deg: float):
    """Convert an angle from degrees to radians.

    Uses the full double-precision value of pi (`jnp.pi` is a plain Python
    float) rather than a truncated literal, so e.g. `deg2rad(180.0) == pi`.
    """
    return (deg / 180.0) * jnp.pi
|
29
|
+
|
30
|
+
|
31
|
+
def GP_DrawFnPair(
    length_scale: float = 1.4, large_abs_values_of_gps: float = 0.25
) -> DrawnFnPairFactory:
    """Build a factory of Gaussian-process based `DrawnFnPair`s.

    For a flexion grid `xs` and a target range `[mn, mx]`, `init` draws one
    GP sample over `xs` (RBF kernel with `length_scale`) and maps it onto
    `[mn, mx]` with the "minmax" method of `restrict`. The normalization range
    is widened to at least `[-large_abs_values_of_gps, large_abs_values_of_gps]`,
    so a sample with small amplitude only covers part of `[mn, mx]`.
    `apply` linearly interpolates the stored sample at `q` (`jnp.interp`).
    """

    def factory(xs, mn, mx):
        def init(key):
            drawn_ys = _gp_draw_and_rom(
                key=key,
                xs=xs,
                ys=None,
                length_scale=length_scale,
                mn=mn,
                mx=mx,
                amin=-large_abs_values_of_gps,
                amax=large_abs_values_of_gps,
            )
            return {"xs": xs, "ys": drawn_ys}

        def apply(params, q):
            grid, values = params["xs"], params["ys"]
            return jnp.interp(q, grid, values)

        return DrawnFnPair(init=init, apply=apply)

    return factory
|
57
|
+
|
58
|
+
|
59
|
+
@dataclass
class SuntayConfig:
    """Settings consumed by `register_suntay`.

    Rotational limits are in radians (built with `deg2rad`). Every degree of
    freedom other than flexion is described by a `<dof>_min` / `<dof>_max`
    pair plus a `<dof>_factory` that builds the random function mapping the
    flexion angle onto that degree of freedom. Positional limits are
    presumably in meters -- TODO confirm.
    """

    # Flexion: range of the joint's single generalized coordinate. The same
    # interval spans the grid all other functions are defined on, and drawn
    # flexion trajectories are restricted to it using the method below (one
    # of the method names accepted by `restrict`).
    flexion_rot_min: float = deg2rad(-5.0)
    flexion_rot_max: float = deg2rad(95.0)
    flexion_rot_restrict_method: str = "minmax"

    # Translation stored as joint parameter `ys_S1`.
    flexion_pos_min: float = -0.015
    flexion_pos_max: float = 0.015
    flexion_pos_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Abduction angle (`ys_beta`); `register_suntay` negates it and adds pi/2.
    abduction_rot_min: float = deg2rad(-4)
    abduction_rot_max: float = deg2rad(4)
    abduction_rot_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Translation stored as joint parameter `ys_S2`.
    abduction_pos_min: float = -0.015
    abduction_pos_max: float = 0.015
    abduction_pos_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # External rotation angle `gamma` (`ys_gamma`).
    external_rot_min: float = deg2rad(-10)
    external_rot_max: float = deg2rad(10)
    external_rot_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Translation stored as joint parameter `ys_S3`.
    external_pos_min: float = -0.06
    external_pos_max: float = 0.0
    external_pos_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Number of grid points over the flexion range.
    num_points: int = 50
    # When set, replaces the MotionConfig used to draw flexion trajectories.
    mconfig: Optional[ring.MotionConfig] = None
|
87
|
+
|
88
|
+
|
89
|
+
def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
    """Register a knee joint type following the Grood & Suntay description.

    Ref to 'E.S. Grood and W.J. Suntay' paper.

    The registered joint has one generalized coordinate, the flexion angle
    `alpha` (q width 1, qd width 0). All other model quantities are random
    functions of the flexion angle:

    - translations `S1`, `S2`, `S3` from `flexion_pos_*`, `abduction_pos_*`
      and `external_pos_*`,
    - `beta = pi/2 - abduction` from `abduction_rot_*`,
    - `gamma` from `external_rot_*`.

    Each function is built by the matching `*_factory` of `sconfig` on a grid
    of `sconfig.num_points` flexion angles in
    `[sconfig.flexion_rot_min, sconfig.flexion_rot_max]`. The parameters drawn
    for it are stored as joint parameters under `ys_S1`, `ys_beta`, `ys_S2`,
    `ys_gamma` and `ys_S3`.

    Args:
        sconfig: Ranges, function factories and motion settings.
        name: Joint type name; an existing joint type of that name is
            overwritten.
    """
    flexion_xs = jnp.linspace(
        sconfig.flexion_rot_min, sconfig.flexion_rot_max, num=sconfig.num_points
    )

    # Build one DrawnFnPair per non-flexion quantity. The insertion order
    # matters: `_init_joint_params_suntay` splits the PRNG key in this order.
    draw_fn_pairs = {}
    for config_name, params_name in zip(
        [
            "flexion_pos",
            "abduction_rot",
            "abduction_pos",
            "external_rot",
            "external_pos",
        ],
        ["ys_S1", "ys_beta", "ys_S2", "ys_gamma", "ys_S3"],
    ):
        factory = getattr(sconfig, config_name + "_factory")
        mn = getattr(sconfig, config_name + "_min")
        mx = getattr(sconfig, config_name + "_max")
        draw_fn_pairs[params_name] = factory(flexion_xs, mn, mx)

    def _suntay_rotation_matrix_R_transpose_eq26(alpha, beta, gamma):
        """Transpose of the rotation matrix R of eq. (26) of the paper."""
        sin_alp, sin_bet, sin_gam = jnp.sin(alpha), jnp.sin(beta), jnp.sin(gamma)
        cos_alp, cos_bet, cos_gam = jnp.cos(alpha), jnp.cos(beta), jnp.cos(gamma)
        return jnp.array(
            [
                [cos_gam * sin_bet, sin_gam * sin_bet, cos_bet],
                [
                    -cos_alp * sin_gam - cos_gam * sin_alp * cos_bet,
                    cos_alp * cos_gam - sin_gam * sin_alp * cos_bet,
                    sin_bet * sin_alp,
                ],
                [
                    sin_alp * sin_gam - cos_gam * cos_alp * cos_bet,
                    -cos_gam * sin_alp - cos_alp * cos_bet * sin_gam,
                    cos_alp * sin_bet,
                ],
            ]
        ).T

    def _suntay_translation_vector_H_eq9(alpha, beta, S):
        """Translation H of eq. (9), computed as `U @ S` with U of eq. (10)."""
        sin_alp, sin_bet = jnp.sin(alpha), jnp.sin(beta)
        cos_alp, cos_bet = jnp.cos(alpha), jnp.cos(beta)
        # eq (10)
        U = jnp.array(
            [
                [1, 0, cos_bet],
                [0, cos_alp, sin_alp * sin_bet],
                [0, -sin_alp, cos_alp * sin_bet],
            ]
        )
        return U @ S

    def _alpha_beta_gamma_S(q_flexion, params):
        """Evaluate alpha, beta, gamma and S = (S1, S2, S3) for one flexion.

        `q_flexion` must have shape (1,); `params` holds the drawn-function
        parameters keyed like `draw_fn_pairs`.
        """
        assert q_flexion.shape == (1,)

        # (1,) -> (,)
        q_flexion = q_flexion[0]

        S = jnp.stack(
            [
                draw_fn_pairs[key].apply(params[key], q_flexion)
                for key in ("ys_S1", "ys_S2", "ys_S3")
            ]
        )
        # table 2 of suntay paper
        alpha = q_flexion
        # note the minus sign, because in config we specify `abduction` not `adduction`
        adduction = -draw_fn_pairs["ys_beta"].apply(params["ys_beta"], q_flexion)
        beta = jnp.pi / 2 + adduction
        gamma = draw_fn_pairs["ys_gamma"].apply(params["ys_gamma"], q_flexion)
        return alpha, beta, gamma, S

    def _utils_find_suntay_joint(sys: ring.System) -> str:
        """Return the name of the unique link of type `name` in `sys`.

        Raises:
            Exception: If no such link or more than one such link exists.
        """
        suntay_link_name = None
        for link_name, link_type in zip(sys.link_names, sys.link_types):
            if link_type == name:
                if suntay_link_name is not None:
                    raise Exception(
                        f"multiple links of type `{name}` found, link_names "
                        f"are [{suntay_link_name}, {link_name}]"
                    )
                suntay_link_name = link_name

        if suntay_link_name is None:
            raise Exception(
                f"no link with type `{name}` found, link_types are {sys.link_types}"
            )
        return suntay_link_name

    def _utils_Q_S_H_alpha_beta_gamma(sys: ring.System, qs: jax.Array):
        """Evaluate (Q, S, H, alpha, beta, gamma) for every timestep of `qs`.

        `qs` has shape (timesteps, sys.q_size()). Every returned array carries
        a leading timestep axis (computed with `jax.vmap`). `Q` is computed
        from S and beta; presumably these are the clinical translations of
        the paper -- TODO confirm.
        """
        # qs.shape = (timesteps, q_size)
        assert qs.ndim == 2
        assert qs.shape[-1] == sys.q_size()

        suntay_link_name = _utils_find_suntay_joint(sys)

        # Select this link's entry along the leading axis of every leaf of the
        # joint-parameter pytree. `jax.tree_util.tree_map` replaces the
        # deprecated `jax.tree_map` alias.
        link_idx = sys.idx_map("l")[suntay_link_name]
        params = jax.tree_util.tree_map(
            lambda arr: arr[link_idx],
            sys.links.joint_params[name],
        )
        # shape = (timesteps, 1)
        q_flexion = qs[:, sys.idx_map("q")[suntay_link_name]]

        @jax.vmap
        def _Q_S_H_alpha_beta_gamma_from_q_flexion(q_flexion):
            alpha, beta, gamma, S = _alpha_beta_gamma_S(q_flexion, params)
            cos_bet = jnp.cos(beta)
            Q = jnp.array([S[0] + S[2] * cos_bet, S[1], -S[2] - S[0] * cos_bet])
            # translation from femur to tibia
            H = _suntay_translation_vector_H_eq9(alpha, beta, S)
            return Q, S, H, alpha, beta, gamma

        return _Q_S_H_alpha_beta_gamma_from_q_flexion(q_flexion)

    def _transform_suntay(q_flexion, params):
        """Joint transform (femur -> tibia) for a flexion angle of shape (1,)."""
        alpha, beta, gamma, S = _alpha_beta_gamma_S(q_flexion, params)

        # rotation from femur to tibia
        R_T = _suntay_rotation_matrix_R_transpose_eq26(alpha, beta, gamma)
        q_fem_tib = maths.quat_from_3x3(R_T)
        # translation from femur to tibia
        H = _suntay_translation_vector_H_eq9(alpha, beta, S)

        return ring.Transform.create(pos=H, rot=q_fem_tib)

    def _init_joint_params_suntay(key):
        """Draw the parameters of every drawn function, one key split each."""
        params = dict()
        for params_name, draw_fn_pair in draw_fn_pairs.items():
            key, consume = jax.random.split(key)
            params[params_name] = draw_fn_pair.init(consume)

        return params

    def _draw_flexion_angle(
        mconfig: ring.MotionConfig,
        key_t: jax.Array,
        key_value: jax.Array,
        dt: float,
        _: jax.Array,
    ) -> jax.Array:
        """Draw a random flexion-angle trajectory for this joint.

        `sconfig.mconfig`, when set, takes precedence over the passed
        `mconfig`. The last argument is unused. The trajectory is restricted
        to `[sconfig.flexion_rot_min, sconfig.flexion_rot_max]` using
        `sconfig.flexion_rot_restrict_method`.
        """
        key_value, consume = jax.random.split(key_value)

        if sconfig.mconfig is not None:
            mconfig = sconfig.mconfig

        ANG_0 = jax.random.uniform(
            consume, minval=mconfig.ang0_min, maxval=mconfig.ang0_max
        )
        # `random_angle_over_time` always returns wrapped angles, thus it would be
        # inconsistent to allow an initial value that is not wrapped
        ANG_0 = maths.wrap_to_pi(ANG_0)
        qs_flexion = random_angle_over_time(
            key_t,
            key_value,
            ANG_0,
            mconfig.dang_min,
            mconfig.dang_max,
            mconfig.delta_ang_min,
            mconfig.delta_ang_max,
            mconfig.t_min,
            mconfig.t_max,
            mconfig.T,
            dt,
            # NOTE(review): the meaning of this positional `5` is defined by
            # `random_angle_over_time`'s signature -- confirm there.
            5,
            mconfig.randomized_interpolation_angle,
            mconfig.range_of_motion_hinge,
            mconfig.range_of_motion_hinge_method,
            mconfig.cdf_bins_min,
            mconfig.cdf_bins_max,
            mconfig.interpolation_method,
        )
        # The drawn angles are wrapped, so [-pi, pi] is passed as their
        # actual range when mapping them onto the flexion range.
        return restrict(
            qs_flexion,
            sconfig.flexion_rot_min,
            sconfig.flexion_rot_max,
            -jnp.pi,
            jnp.pi,
            method=sconfig.flexion_rot_restrict_method,
        )

    joint_model = ring.JointModel(
        transform=_transform_suntay,
        rcmg_draw_fn=_draw_flexion_angle,
        init_joint_params=_init_joint_params_suntay,
        utilities=dict(
            Q_S_H_alpha_beta_gamma=_utils_Q_S_H_alpha_beta_gamma,
            find_suntay_joint=_utils_find_suntay_joint,
        ),
    )
    ring.register_new_joint_type(name, joint_model, 1, qd_width=0, overwrite=True)
|
280
|
+
|
281
|
+
|
282
|
+
def MLP_DrawnFnPair(
    center: bool = False, flexion_center: Optional[float] = None
) -> DrawnFnPairFactory:
    """Build a factory of MLP-based `DrawnFnPair`s.

    The drawn function is a randomly initialized haiku MLP (layer sizes
    [10, 5, 1], tanh activations) of the flexion angle. Its input is rescaled
    from `[min(xs), max(xs)]` to `[-0.5, 0.5]`. Its output is squashed by a
    sigmoid and mapped onto `(mn, mx)`.

    Args:
        center: If True, the value at `flexion_center` is subtracted, so the
            drawn function is zero at that flexion angle.
        flexion_center: Flexion angle used for centering. If None, the
            midpoint of `xs` is used, determined separately on every call of
            the returned factory.
    """

    def factory(xs, mn, mx):
        flexion_mn = jnp.min(xs)
        flexion_mx = jnp.max(xs)

        # Resolve the centering angle locally. Writing it back into the
        # enclosing scope (as the former `nonlocal` did) made the first call's
        # midpoint stick for every later call, even one with a different `xs`.
        # `jnp.asarray` lets a plain float be indexed in `_apply`.
        if flexion_center is None:
            center_q = (flexion_mn + flexion_mx) / 2
        else:
            center_q = jnp.asarray(flexion_center)

        @hk.without_apply_rng
        @hk.transform
        def mlp(x):
            # normalize the x input; [0, 1]
            x = _shift(x, flexion_mn, flexion_mx)
            # center the x input; [-0.5, 0.5]
            x = x - 0.5
            net = hk.nets.MLP(
                [10, 5, 1],
                activation=jnp.tanh,
                w_init=hk.initializers.RandomNormal(),
            )
            return net(x)

        example_q = jnp.zeros((1,))

        def init(key):
            return mlp.init(key, example_q)

        def _apply(params, q):
            # scalar -> (1,), the input shape used at `init`
            q = jnp.asarray(q)[None]
            return jnp.squeeze(_shift_inv(jax.nn.sigmoid(mlp.apply(params, q)), mn, mx))

        if center:

            def apply(params, q):
                return _apply(params, q) - _apply(params, center_q)

        else:
            apply = _apply

        return DrawnFnPair(init, apply)

    return factory
|
329
|
+
|
330
|
+
|
331
|
+
def _gp_draw_and_rom(key, xs, ys, length_scale, mn, mx, amin, amax):
    """Draw a GP sample over `xs` and map it onto `[mn, mx]`.

    The "minmax" normalization range is widened to include `[amin, amax]`.
    When a mean `ys` is given, those bounds are first offset by `min(ys)` and
    `max(ys)` respectively.
    """
    sample = _gp_draw(key, xs, ys, length_scale)
    if ys is None:
        return restrict(sample, mn, mx, amin, amax)
    amin += jnp.min(ys)
    amax += jnp.max(ys)
    return restrict(sample, mn, mx, amin, amax)
|
337
|
+
|
338
|
+
|
339
|
+
def _gp_draw(key, xs, ys=None, length: float = 1.0, noise=0.0, method="svd", **kwargs):
    """Draw one Gaussian-process sample evaluated at the points `xs`.

    The covariance uses an RBF kernel with the given `length` plus `noise`
    on the diagonal. The mean is `ys`, or zeros if `ys` is None. `method` and
    `kwargs` are forwarded to `jax.random.multivariate_normal`.
    """

    def kernel(x1, x2):
        return _rbf_kernel(x1, x2, length=length)

    mean = jnp.zeros_like(xs) if ys is None else ys
    covariance = _gp_K(kernel, xs, noise)
    return jax.random.multivariate_normal(
        key=key, mean=mean, cov=covariance, method=method, **kwargs
    )
|
346
|
+
|
347
|
+
|
348
|
+
def _gp_K(kernel, xs, noise: float):
|
349
|
+
assert xs.ndim == 1
|
350
|
+
N = len(xs)
|
351
|
+
xs = xs[:, None]
|
352
|
+
|
353
|
+
K = jax.vmap(lambda x1: jax.vmap(lambda x2: kernel(x1, x2))(xs))(xs)
|
354
|
+
assert K.shape == (N, N, 1)
|
355
|
+
return K[..., 0] + jnp.eye(N) * noise
|
356
|
+
|
357
|
+
|
358
|
+
def _rbf_kernel(x1: float, x2: float, length: float):
|
359
|
+
return jnp.exp(-((x1 - x2) ** 2) / (2 * length**2))
|
360
|
+
|
361
|
+
|
362
|
+
def _shift(ys, min, max):
|
363
|
+
return (ys - min) / (max - min)
|
364
|
+
|
365
|
+
|
366
|
+
def _shift_inv(ys, min, max):
|
367
|
+
return (ys * (max - min)) + min
|
368
|
+
|
369
|
+
|
370
|
+
def _normalize(ys, amin=None, amax=None):
|
371
|
+
if amin is None:
|
372
|
+
amin = jnp.min(ys)
|
373
|
+
else:
|
374
|
+
amin = jnp.min(jnp.array([amin, jnp.min(ys)]))
|
375
|
+
if amax is None:
|
376
|
+
amax = jnp.max(ys)
|
377
|
+
else:
|
378
|
+
amax = jnp.max(jnp.array([amax, jnp.max(ys)]))
|
379
|
+
return _shift(ys, amin, amax)
|
380
|
+
|
381
|
+
|
382
|
+
def _smoothclamp(x, mi, mx):
|
383
|
+
return mi + (mx - mi) * (
|
384
|
+
lambda t: jnp.where(t < 0, 0, jnp.where(t <= 1, 3 * t**2 - 2 * t**3, 1))
|
385
|
+
)((x - mi) / (mx - mi))
|
386
|
+
|
387
|
+
|
388
|
+
def _sigmoidclamp(x, mi, mx):
|
389
|
+
return mi + (mx - mi) * (lambda t: (1 + 200 ** (-t + 0.5)) ** (-1))(
|
390
|
+
(x - mi) / (mx - mi)
|
391
|
+
)
|
392
|
+
|
393
|
+
|
394
|
+
def restrict(
    ys,
    min: float,
    max: float,
    actual_min=None,
    actual_max=None,
    method: str = "minmax",
    method_kwargs=None,
):
    """Map the values `ys` into the interval `[min, max]`.

    Args:
        ys: Values to restrict.
        min: Lower end of the target interval.
        max: Upper end of the target interval.
        actual_min: Used only by "minmax" and "sigmoid". The range `ys` is
            normalized from is widened to include this value (see
            `_normalize`); None means `min(ys)`.
        actual_max: Same as `actual_min`, for the upper end.
        method: One of
            - "minmax": affine map of the (possibly widened) range of `ys`
              onto `[min, max]`;
            - "clip": hard clipping (`jnp.clip`);
            - "smoothclamp": smoothstep clamping (`_smoothclamp`);
            - "sigmoidclamp": logistic clamping (`_sigmoidclamp`);
            - "sigmoid": normalize to [0, 1], spread to
              `[-stepness, stepness]`, apply a sigmoid, map onto `[min, max]`.
        method_kwargs: Extra options; "sigmoid" reads `stepness`
            (default 3.0). None means no options.

    Raises:
        NotImplementedError: If `method` is not one of the names above.
    """
    # Avoid a shared mutable default argument.
    if method_kwargs is None:
        method_kwargs = {}

    if method == "minmax":
        # scale to [0, 1]
        ys = _normalize(ys, actual_min, actual_max)
        # scale to [min, max]
        return _shift_inv(ys, min, max)
    elif method == "clip":
        return jnp.clip(ys, min, max)
    elif method == "smoothclamp":
        return _smoothclamp(ys, min, max)
    elif method == "sigmoidclamp":
        return _sigmoidclamp(ys, min, max)
    elif method == "sigmoid":
        # scale to [0, 1]
        ys = _normalize(ys, actual_min, actual_max)
        # scale to [-stepness, stepness]; `stepness` controls the steepness
        stepness = method_kwargs.get("stepness", 3.0)
        ys = _shift_inv(ys, -stepness, stepness)
        # scale to [0, 1]
        ys = jax.nn.sigmoid(ys)
        return _shift_inv(ys, min, max)
    else:
        raise NotImplementedError(f"unknown restrict method `{method}`")
|