imt-ring 1.3.3__py3-none-any.whl → 1.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.4.dist-info}/METADATA +1 -1
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.4.dist-info}/RECORD +5 -5
- ring/algorithms/custom_joints/suntay.py +9 -5
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.4.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.4.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,7 @@ ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18
|
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=CN0q6G2bjmufNMr7eAjKwIKwXpFp9qFjqrJCey-xIYE,15858
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
|
17
17
|
ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
|
18
18
|
ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
|
@@ -77,7 +77,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
77
77
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
78
78
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
79
79
|
ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
|
80
|
-
imt_ring-1.3.
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
80
|
+
imt_ring-1.3.4.dist-info/METADATA,sha256=oMI6d91TaCBndSOQJjoVBzqsYDg3oGtAxDzSmSamWmg,3104
|
81
|
+
imt_ring-1.3.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
82
|
+
imt_ring-1.3.4.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
83
|
+
imt_ring-1.3.4.dist-info/RECORD,,
|
@@ -295,14 +295,15 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
295
295
|
|
296
296
|
def Polynomial_DrawnFnPair(
|
297
297
|
order: int = 2,
|
298
|
-
val: float = 2.0,
|
299
298
|
center: bool = False,
|
300
299
|
flexion_center: Optional[float] = None,
|
300
|
+
include_bias: bool = True,
|
301
301
|
) -> DrawnFnPairFactory:
|
302
|
-
assert
|
302
|
+
assert not (order == 0 and not include_bias)
|
303
303
|
|
304
304
|
# because 0-th order is also counted
|
305
305
|
order += 1
|
306
|
+
powers = jnp.arange(order) if include_bias else jnp.arange(1, order)
|
306
307
|
|
307
308
|
def factory(xs, mn, mx):
|
308
309
|
nonlocal flexion_center
|
@@ -311,7 +312,7 @@ def Polynomial_DrawnFnPair(
|
|
311
312
|
flexion_mx = jnp.max(xs)
|
312
313
|
|
313
314
|
def _apply_poly_factors(poly_factors, q):
|
314
|
-
return poly_factors @ jnp.power(q,
|
315
|
+
return poly_factors @ jnp.power(q, powers)
|
315
316
|
|
316
317
|
if flexion_center is None:
|
317
318
|
flexion_center = (flexion_mn + flexion_mx) / 2
|
@@ -319,9 +320,9 @@ def Polynomial_DrawnFnPair(
|
|
319
320
|
flexion_center = jnp.array(flexion_center)
|
320
321
|
|
321
322
|
def init(key):
|
322
|
-
c1, c2 = jax.random.split(key)
|
323
|
+
c1, c2, c3 = jax.random.split(key, 3)
|
323
324
|
poly_factors = jax.random.uniform(
|
324
|
-
c1, shape=(
|
325
|
+
c1, shape=(len(powers),), minval=-1.0, maxval=1.0
|
325
326
|
)
|
326
327
|
q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
|
327
328
|
values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
|
@@ -329,6 +330,9 @@ def Polynomial_DrawnFnPair(
|
|
329
330
|
)
|
330
331
|
eps = 1e-6
|
331
332
|
amin, amax = jnp.min(values), jnp.max(values) + eps
|
333
|
+
delta = amax - amin
|
334
|
+
scale_delta = jnp.clip(jax.random.normal(c3) + 0.5, 1.0)
|
335
|
+
amax = amin + delta * scale_delta
|
332
336
|
return amin, amax, poly_factors, q0
|
333
337
|
|
334
338
|
def _apply(params, q):
|
File without changes
|
File without changes
|