jaxsim 0.4.3.dev77__py3-none-any.whl → 0.4.3.dev80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/kin_dyn_parameters.py +2 -4
- jaxsim/api/link.py +1 -1
- jaxsim/integrators/common.py +8 -12
- jaxsim/integrators/variable_step.py +13 -15
- jaxsim/rbda/contacts/relaxed_rigid.py +1 -1
- jaxsim/utils/jaxsim_dataclass.py +1 -1
- {jaxsim-0.4.3.dev77.dist-info → jaxsim-0.4.3.dev80.dist-info}/METADATA +3 -3
- {jaxsim-0.4.3.dev77.dist-info → jaxsim-0.4.3.dev80.dist-info}/RECORD +12 -12
- {jaxsim-0.4.3.dev77.dist-info → jaxsim-0.4.3.dev80.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev77.dist-info → jaxsim-0.4.3.dev80.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev77.dist-info → jaxsim-0.4.3.dev80.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.dev77'
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, 'dev77')
|
15
|
+
__version__ = version = '0.4.3.dev80'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev80')
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -98,9 +98,7 @@ class KynDynParameters(JaxsimDataclass):
|
|
98
98
|
]
|
99
99
|
|
100
100
|
# Create a vectorized object of link parameters.
|
101
|
-
link_parameters = jax.
|
102
|
-
lambda *l: jnp.stack(l), *link_parameters_list
|
103
|
-
)
|
101
|
+
link_parameters = jax.tree.map(lambda *l: jnp.stack(l), *link_parameters_list)
|
104
102
|
|
105
103
|
# =================
|
106
104
|
# Joints properties
|
@@ -114,7 +112,7 @@ class KynDynParameters(JaxsimDataclass):
|
|
114
112
|
|
115
113
|
# Create a vectorized object of joint parameters.
|
116
114
|
joint_parameters = (
|
117
|
-
jax.
|
115
|
+
jax.tree.map(lambda *l: jnp.stack(l), *joint_parameters_list)
|
118
116
|
if len(ordered_joints) > 0
|
119
117
|
else JointParameters(
|
120
118
|
index=jnp.array([], dtype=int),
|
jaxsim/api/link.py
CHANGED
jaxsim/integrators/common.py
CHANGED
@@ -173,9 +173,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
173
173
|
|
174
174
|
# Make sure that all leafs of the dictionary are JAX arrays.
|
175
175
|
# Also, since these are dummy parameters, set them all to zero.
|
176
|
-
params_after_init = jax.
|
177
|
-
lambda l: jnp.zeros_like(l), integrator.params
|
178
|
-
)
|
176
|
+
params_after_init = jax.tree.map(lambda l: jnp.zeros_like(l), integrator.params)
|
179
177
|
|
180
178
|
# Mark the next step as first step after initialization.
|
181
179
|
params_after_init = params_after_init | {
|
@@ -290,7 +288,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
290
288
|
z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
|
291
289
|
|
292
290
|
# The next state is the batch element located at the configured index of solution.
|
293
|
-
next_state = jax.
|
291
|
+
next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
|
294
292
|
|
295
293
|
return next_state, aux_dict
|
296
294
|
|
@@ -327,7 +325,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
327
325
|
"""
|
328
326
|
|
329
327
|
op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
|
330
|
-
return jax.
|
328
|
+
return jax.tree.map(op, x0, k)
|
331
329
|
|
332
330
|
@classmethod
|
333
331
|
def post_process_state(
|
@@ -374,7 +372,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
374
372
|
f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
|
375
373
|
|
376
374
|
# Initialize the carry of the for loop with the stacked kᵢ vectors.
|
377
|
-
carry0 = jax.
|
375
|
+
carry0 = jax.tree.map(
|
378
376
|
lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0),
|
379
377
|
x0,
|
380
378
|
)
|
@@ -398,7 +396,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
398
396
|
|
399
397
|
# Compute ∑ⱼ aᵢⱼ kⱼ.
|
400
398
|
op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
|
401
|
-
sum_ak = jax.
|
399
|
+
sum_ak = jax.tree.map(op_sum_ak, K)
|
402
400
|
|
403
401
|
# Compute the next state for the kᵢ evaluation.
|
404
402
|
# Note that this is not a Δt integration since aᵢⱼ could be fractional.
|
@@ -419,7 +417,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
419
417
|
|
420
418
|
# Store the kᵢ derivative in K.
|
421
419
|
op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
|
422
|
-
K = jax.
|
420
|
+
K = jax.tree.map(op, K, ki)
|
423
421
|
|
424
422
|
carry = K
|
425
423
|
return carry, aux_dict
|
@@ -433,14 +431,12 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
433
431
|
|
434
432
|
# Update the FSAL property for the next iteration.
|
435
433
|
if self.has_fsal:
|
436
|
-
self.params["dxdt0"] = jax.
|
437
|
-
lambda l: l[self.index_of_fsal], K
|
438
|
-
)
|
434
|
+
self.params["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
|
439
435
|
|
440
436
|
# Compute the output state.
|
441
437
|
# Note that z contains as many new states as the rows of `b.T`.
|
442
438
|
op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k)
|
443
|
-
z = jax.
|
439
|
+
z = jax.tree.map(op, x0, K)
|
444
440
|
|
445
441
|
# Transform the final state of the integration.
|
446
442
|
# This allows to inject custom logic, if needed.
|
@@ -87,13 +87,13 @@ def estimate_step_size(
|
|
87
87
|
|
88
88
|
# Compute the scaling factors of the initial state and its derivative.
|
89
89
|
compute_scale = lambda x: atol + jnp.abs(x) * rtol
|
90
|
-
scale0 = jax.
|
91
|
-
scale1 = jax.
|
90
|
+
scale0 = jax.tree.map(compute_scale, x0)
|
91
|
+
scale1 = jax.tree.map(compute_scale, ẋ0)
|
92
92
|
|
93
93
|
# Scale the initial state and its derivative.
|
94
94
|
scale_pytree = lambda x, scale: jnp.abs(x) / scale
|
95
|
-
x0_scaled = jax.
|
96
|
-
ẋ0_scaled = jax.
|
95
|
+
x0_scaled = jax.tree.map(scale_pytree, x0, scale0)
|
96
|
+
ẋ0_scaled = jax.tree.map(scale_pytree, ẋ0, scale1)
|
97
97
|
|
98
98
|
# Get the maximum of the scaled pytrees.
|
99
99
|
d0 = jnp.linalg.norm(flatten(x0_scaled), ord=jnp.inf)
|
@@ -103,16 +103,16 @@ def estimate_step_size(
|
|
103
103
|
h0 = jnp.where(jnp.minimum(d0, d1) <= 1e-5, 1e-6, 0.01 * d0 / d1)
|
104
104
|
|
105
105
|
# Compute the next state (explicit Euler step) and its derivative.
|
106
|
-
x1 = jax.
|
106
|
+
x1 = jax.tree.map(lambda x0, ẋ0: x0 + h0 * ẋ0, x0, ẋ0)
|
107
107
|
ẋ1 = f(x1, t0 + h0)[0]
|
108
108
|
|
109
109
|
# Compute the scaling factor of the state derivatives.
|
110
110
|
compute_scale_2 = lambda ẋ0, ẋ1: atol + jnp.maximum(jnp.abs(ẋ0), jnp.abs(ẋ1)) * rtol
|
111
|
-
scale2 = jax.
|
111
|
+
scale2 = jax.tree.map(compute_scale_2, ẋ0, ẋ1)
|
112
112
|
|
113
113
|
# Scale the difference of the state derivatives.
|
114
114
|
scale_ẋ_difference = lambda ẋ0, ẋ1, scale: jnp.abs((ẋ0 - ẋ1) / scale)
|
115
|
-
ẋ_difference_scaled = jax.
|
115
|
+
ẋ_difference_scaled = jax.tree.map(scale_ẋ_difference, ẋ0, ẋ1, scale2)
|
116
116
|
|
117
117
|
# Get the maximum of the scaled derivatives difference.
|
118
118
|
d2 = jnp.linalg.norm(flatten(ẋ_difference_scaled), ord=jnp.inf) / h0
|
@@ -151,11 +151,11 @@ def compute_pytree_scale(
|
|
151
151
|
"""
|
152
152
|
|
153
153
|
# Consider a zero second pytree, if not given.
|
154
|
-
x2 = jax.
|
154
|
+
x2 = jax.tree.map(lambda l: jnp.zeros_like(l), x1) if x2 is None else x2
|
155
155
|
|
156
156
|
# Compute the scaling factors of the initial state and its derivative.
|
157
157
|
compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
|
158
|
-
scale = jax.
|
158
|
+
scale = jax.tree.map(compute_scale, x1, x2)
|
159
159
|
|
160
160
|
return scale
|
161
161
|
|
@@ -198,14 +198,14 @@ def local_error_estimation(
|
|
198
198
|
|
199
199
|
# Consider a zero estimated final state, if not given.
|
200
200
|
xf_estimate = (
|
201
|
-
jax.
|
201
|
+
jax.tree.map(lambda l: jnp.zeros_like(l), xf)
|
202
202
|
if xf_estimate is None
|
203
203
|
else xf_estimate
|
204
204
|
)
|
205
205
|
|
206
206
|
# Estimate the error.
|
207
207
|
estimate_error = lambda l, l̂, sc: jnp.abs(l - l̂) / sc
|
208
|
-
error_estimate = jax.
|
208
|
+
error_estimate = jax.tree.map(estimate_error, xf, xf_estimate, scale)
|
209
209
|
|
210
210
|
# Return the highest element of the error estimate.
|
211
211
|
return jnp.linalg.norm(flatten(error_estimate), ord=norm_ord)
|
@@ -359,10 +359,8 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
359
359
|
params_next = integrator.params
|
360
360
|
|
361
361
|
# Extract the high-order solution xf and the low-order estimate x̂f.
|
362
|
-
xf = jax.
|
363
|
-
x̂f = jax.
|
364
|
-
lambda l: l[self.row_index_of_solution_estimate], z
|
365
|
-
)
|
362
|
+
xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
|
363
|
+
x̂f = jax.tree.map(lambda l: l[self.row_index_of_solution_estimate], z)
|
366
364
|
|
367
365
|
# Calculate the local integration error.
|
368
366
|
local_error = local_error_estimation(
|
@@ -230,7 +230,7 @@ class RelaxedRigidContacts(ContactModel):
|
|
230
230
|
)
|
231
231
|
|
232
232
|
def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
|
233
|
-
x, y, z = jax.
|
233
|
+
x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
|
234
234
|
|
235
235
|
n̂ = self.terrain.normal(x=x, y=y).squeeze()
|
236
236
|
h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
|
jaxsim/utils/jaxsim_dataclass.py
CHANGED
@@ -298,7 +298,7 @@ class JaxsimDataclass(abc.ABC):
|
|
298
298
|
"""
|
299
299
|
|
300
300
|
# Make a copy calling tree_map.
|
301
|
-
obj = jax.
|
301
|
+
obj = jax.tree.map(lambda leaf: leaf, self)
|
302
302
|
|
303
303
|
# Make sure that the copied object and all the copied leaves have the same
|
304
304
|
# mutability of the original object.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.dev77
|
3
|
+
Version: 0.4.3.dev80
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -60,9 +60,9 @@ Requires-Python: >=3.10
|
|
60
60
|
Description-Content-Type: text/markdown
|
61
61
|
License-File: LICENSE
|
62
62
|
Requires-Dist: coloredlogs
|
63
|
-
Requires-Dist: jax>=0.4.
|
63
|
+
Requires-Dist: jax>=0.4.26
|
64
64
|
Requires-Dist: jaxopt>=0.8.0
|
65
|
-
Requires-Dist: jaxlib>=0.4.
|
65
|
+
Requires-Dist: jaxlib>=0.4.26
|
66
66
|
Requires-Dist: jaxlie>=1.3.0
|
67
67
|
Requires-Dist: jax-dataclasses>=1.4.0
|
68
68
|
Requires-Dist: pptree
|
@@ -1,5 +1,5 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=G5Qm6992nEqZe7NQkfmL8p-gHKaNDVraUdwnF4D-BbI,426
|
3
3
|
jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
@@ -10,16 +10,16 @@ jaxsim/api/contact.py,sha256=Ek1xSKB_VWjfqsqRYlK236ountKmGTl1M04cTYqHgsE,22142
|
|
10
10
|
jaxsim/api/data.py,sha256=QldUHniJqKrdNtAcXuRaS9UyeslJ0Rjvb17UA0Ca5Tw,29008
|
11
11
|
jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
|
12
12
|
jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
|
13
|
-
jaxsim/api/kin_dyn_parameters.py,sha256=
|
14
|
-
jaxsim/api/link.py,sha256=
|
13
|
+
jaxsim/api/kin_dyn_parameters.py,sha256=FrWymdta36THv5QFTzxorJtYiKTVDg7HqOcPTHa12VM,29327
|
14
|
+
jaxsim/api/link.py,sha256=LAA6ZMQXkWomXeptURBtc7z3_xDZ2BBnBMhVrohh0bE,18621
|
15
15
|
jaxsim/api/model.py,sha256=TLjgacgTXm-2YRGDA0Id9pe9nxIem28KoAls6Tdk9WM,66241
|
16
16
|
jaxsim/api/ode.py,sha256=ZshGdHptftku0yoUwBiBdd1iOqntH0vVEOjRHfL7Fao,13518
|
17
17
|
jaxsim/api/ode_data.py,sha256=7RSoBhfCJdP6P9InQbDwdBVpClPMMuetewI-6AWm-_0,20276
|
18
18
|
jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
|
-
jaxsim/integrators/common.py,sha256=
|
20
|
+
jaxsim/integrators/common.py,sha256=78MBs89GxsL0wU2yAexjvBZt3HEtfZoGVIN9f0a8yTc,20305
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
|
-
jaxsim/integrators/variable_step.py,sha256=
|
22
|
+
jaxsim/integrators/variable_step.py,sha256=cJD98q5BaiSKvp_KY_1KN3PZpAUJR3L8YRmLX5WPJJo,21114
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
24
24
|
jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
|
25
25
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
@@ -54,17 +54,17 @@ jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
|
54
54
|
jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
|
55
55
|
jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
56
|
jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
|
57
|
-
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=
|
57
|
+
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=deTC0M2a_RER7iwVpxLCfuSlgBLqkTmHQdOJ4169IR4,13646
|
58
58
|
jaxsim/rbda/contacts/rigid.py,sha256=zbSM0miwpgC1rp1d0RoQ1q8pYiKdIkHV8iZimeEPC94,15153
|
59
59
|
jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
|
60
60
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
61
61
|
jaxsim/terrain/terrain.py,sha256=xUQg47yGxIOcTkLPbnO3sruEGBhoCd16j1evTGlmNjI,5010
|
62
62
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
63
|
-
jaxsim/utils/jaxsim_dataclass.py,sha256=
|
63
|
+
jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
|
64
64
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
65
65
|
jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
|
66
|
-
jaxsim-0.4.3.
|
67
|
-
jaxsim-0.4.3.
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
70
|
-
jaxsim-0.4.3.
|
66
|
+
jaxsim-0.4.3.dev80.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
67
|
+
jaxsim-0.4.3.dev80.dist-info/METADATA,sha256=BLXHcGNmem2sMUaAvTy-6E1XsRF5JdPIrPzQkiSeFyQ,17276
|
68
|
+
jaxsim-0.4.3.dev80.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
69
|
+
jaxsim-0.4.3.dev80.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
70
|
+
jaxsim-0.4.3.dev80.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|