jaxsim 0.4.3.dev133__py3-none-any.whl → 0.4.3.dev139__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/integrators/variable_step.py +12 -7
- jaxsim/rbda/contacts/relaxed_rigid.py +60 -12
- {jaxsim-0.4.3.dev133.dist-info → jaxsim-0.4.3.dev139.dist-info}/METADATA +2 -2
- {jaxsim-0.4.3.dev133.dist-info → jaxsim-0.4.3.dev139.dist-info}/RECORD +8 -8
- {jaxsim-0.4.3.dev133.dist-info → jaxsim-0.4.3.dev139.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev133.dist-info → jaxsim-0.4.3.dev139.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev133.dist-info → jaxsim-0.4.3.dev139.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.dev133'
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, 'dev133')
|
15
|
+
__version__ = version = '0.4.3.dev139'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev139')
|
@@ -262,7 +262,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
262
262
|
**kwargs,
|
263
263
|
)
|
264
264
|
|
265
|
-
def __call__(
|
265
|
+
def __call__(
|
266
|
+
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
267
|
+
) -> tuple[NextState, dict[str, Any]]:
|
266
268
|
|
267
269
|
# This method is called differently in three stages:
|
268
270
|
#
|
@@ -294,14 +296,17 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
294
296
|
# In Stage 3, dt0 is taken from the previous step. If the integrator supports
|
295
297
|
# FSAL, dxdt0 is taken from the previous step. Otherwise, it is computed by
|
296
298
|
# evaluating the dynamics.
|
297
|
-
self.params["dt0"], self.params["dxdt0"] = jax.lax.cond(
|
299
|
+
self.params["dt0"], self.params["dxdt0"], aux_dict = jax.lax.cond(
|
298
300
|
pred=jnp.logical_or("dt0" not in self.params, integrator_first_step),
|
299
|
-
true_fun=lambda params:
|
300
|
-
|
301
|
+
true_fun=lambda params: (
|
302
|
+
*estimate_step_size(
|
303
|
+
x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
|
304
|
+
),
|
305
|
+
self.params.get("dxdt0", f(x0, t0))[1],
|
301
306
|
),
|
302
307
|
false_fun=lambda params: (
|
303
308
|
params.get("dt0", jnp.array(0).astype(float)),
|
304
|
-
self.params.get("dxdt0", f(x0, t0)
|
309
|
+
*self.params.get("dxdt0", f(x0, t0)),
|
305
310
|
),
|
306
311
|
operand=self.params,
|
307
312
|
)
|
@@ -355,7 +360,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
355
360
|
# The output z contains multiple solutions (depending on the rows of b.T).
|
356
361
|
with self.editable(validate=True) as integrator:
|
357
362
|
integrator.params = params
|
358
|
-
z = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
|
363
|
+
z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
|
359
364
|
params_next = integrator.params
|
360
365
|
|
361
366
|
# Extract the high-order solution xf and the low-order estimate x̂f.
|
@@ -481,7 +486,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
481
486
|
with self.mutable_context(mutability=Mutability.MUTABLE):
|
482
487
|
self.params = params_tf
|
483
488
|
|
484
|
-
return xf
|
489
|
+
return xf, aux_dict
|
485
490
|
|
486
491
|
@property
|
487
492
|
def order_of_solution(self) -> int:
|
@@ -1,12 +1,13 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
+
from collections.abc import Callable
|
4
5
|
from typing import Any
|
5
6
|
|
6
7
|
import jax
|
7
8
|
import jax.numpy as jnp
|
8
9
|
import jax_dataclasses
|
9
|
-
import jaxopt
|
10
|
+
import optax
|
10
11
|
|
11
12
|
import jaxsim.api as js
|
12
13
|
import jaxsim.typing as jtp
|
@@ -297,24 +298,71 @@ class RelaxedRigidContacts(ContactModel):
|
|
297
298
|
A = G + R
|
298
299
|
b = CW_al_free_WC - a_ref
|
299
300
|
|
300
|
-
objective = lambda x: jnp.sum(jnp.square(A @ x + b))
|
301
|
+
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))
|
301
302
|
|
302
|
-
|
303
|
-
|
304
|
-
fun
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
303
|
+
def run_optimization(
|
304
|
+
init_params: jtp.Array,
|
305
|
+
fun: Callable,
|
306
|
+
opt: optax.GradientTransformation,
|
307
|
+
maxiter: jtp.Int,
|
308
|
+
tol: jtp.Float,
|
309
|
+
**kwargs,
|
310
|
+
):
|
311
|
+
value_and_grad_fn = optax.value_and_grad_from_state(fun)
|
312
|
+
|
313
|
+
def step(carry):
|
314
|
+
params, state = carry
|
315
|
+
value, grad = value_and_grad_fn(
|
316
|
+
params,
|
317
|
+
state=state,
|
318
|
+
A=A,
|
319
|
+
b=b,
|
320
|
+
)
|
321
|
+
updates, state = opt.update(
|
322
|
+
updates=grad,
|
323
|
+
state=state,
|
324
|
+
params=params,
|
325
|
+
value=value,
|
326
|
+
grad=grad,
|
327
|
+
value_fn=fun,
|
328
|
+
A=A,
|
329
|
+
b=b,
|
330
|
+
)
|
331
|
+
params = optax.apply_updates(params, updates)
|
332
|
+
return params, state
|
333
|
+
|
334
|
+
def continuing_criterion(carry):
|
335
|
+
_, state = carry
|
336
|
+
iter_num = optax.tree_utils.tree_get(state, "count")
|
337
|
+
grad = optax.tree_utils.tree_get(state, "grad")
|
338
|
+
err = optax.tree_utils.tree_l2_norm(grad)
|
339
|
+
return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))
|
340
|
+
|
341
|
+
init_carry = (init_params, opt.init(init_params))
|
342
|
+
final_params, final_state = jax.lax.while_loop(
|
343
|
+
continuing_criterion, step, init_carry
|
344
|
+
)
|
345
|
+
return final_params, final_state
|
311
346
|
|
312
347
|
init_params = (
|
313
348
|
K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
|
314
349
|
+ D[:, jnp.newaxis] * velocity
|
315
350
|
).flatten()
|
316
351
|
|
317
|
-
|
352
|
+
# Compute the 3D linear force in C[W] frame
|
353
|
+
CW_f_Ci, _ = run_optimization(
|
354
|
+
init_params=init_params,
|
355
|
+
A=A,
|
356
|
+
b=b,
|
357
|
+
maxiter=self.parameters.max_iterations,
|
358
|
+
opt=optax.lbfgs(
|
359
|
+
memory_size=10,
|
360
|
+
),
|
361
|
+
fun=objective,
|
362
|
+
tol=self.parameters.tolerance,
|
363
|
+
)
|
364
|
+
|
365
|
+
CW_f_Ci = CW_f_Ci.reshape((-1, 3))
|
318
366
|
|
319
367
|
def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
|
320
368
|
W_Xf_CW = Adjoint.from_transform(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.dev133
|
3
|
+
Version: 0.4.3.dev139
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -61,11 +61,11 @@ Description-Content-Type: text/markdown
|
|
61
61
|
License-File: LICENSE
|
62
62
|
Requires-Dist: coloredlogs
|
63
63
|
Requires-Dist: jax>=0.4.26
|
64
|
-
Requires-Dist: jaxopt>=0.8.0
|
65
64
|
Requires-Dist: jaxlib>=0.4.26
|
66
65
|
Requires-Dist: jaxlie>=1.3.0
|
67
66
|
Requires-Dist: jax-dataclasses>=1.4.0
|
68
67
|
Requires-Dist: pptree
|
68
|
+
Requires-Dist: optax>=0.2.3
|
69
69
|
Requires-Dist: qpax
|
70
70
|
Requires-Dist: rod>=0.3.3
|
71
71
|
Requires-Dist: typing-extensions; python_version < "3.12"
|
@@ -1,5 +1,5 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=3AziyYsOidZySopkrxUFcVGUEMK5LFxmckSLgq6qtds,428
|
3
3
|
jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
@@ -19,7 +19,7 @@ jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,2077
|
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
20
|
jaxsim/integrators/common.py,sha256=78MBs89GxsL0wU2yAexjvBZt3HEtfZoGVIN9f0a8yTc,20305
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
|
-
jaxsim/integrators/variable_step.py,sha256=
|
22
|
+
jaxsim/integrators/variable_step.py,sha256=1VoSU3GeFcGEuP2dgZQ83sTkI5Xe-IThqKlRoVtwGSE,21270
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
24
24
|
jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
|
25
25
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
@@ -54,7 +54,7 @@ jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
|
54
54
|
jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
|
55
55
|
jaxsim/rbda/contacts/__init__.py,sha256=0UnO9ZR3BwdjQa276jOFbPi90pporr32LSc0qa9UUm4,369
|
56
56
|
jaxsim/rbda/contacts/common.py,sha256=-eM8d1kvJ2E_2_kAgZJk4s3x8vDZHNSyOAinwPmRmEk,3469
|
57
|
-
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=
|
57
|
+
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=8kytUPYUmYXVrPEoHbCduFp5KOmOFPK4Vmqv3KhDqy8,15738
|
58
58
|
jaxsim/rbda/contacts/rigid.py,sha256=6cU8kM8LMjEFbt8dtSg5nnz_uh4aD50sKw_svCzYUms,15633
|
59
59
|
jaxsim/rbda/contacts/soft.py,sha256=NzzCYw5rvK8Fx_qH3fiMzPgey-KoxmRe9xkF3fluidE,18866
|
60
60
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
@@ -63,8 +63,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
|
63
63
|
jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
|
64
64
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
65
65
|
jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
|
66
|
-
jaxsim-0.4.3.
|
67
|
-
jaxsim-0.4.3.
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
70
|
-
jaxsim-0.4.3.
|
66
|
+
jaxsim-0.4.3.dev139.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
67
|
+
jaxsim-0.4.3.dev139.dist-info/METADATA,sha256=uWmLplGU4SQyTl8iOGoW-isZe14lCh-mTkNXk1EThJM,17276
|
68
|
+
jaxsim-0.4.3.dev139.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
69
|
+
jaxsim-0.4.3.dev139.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
70
|
+
jaxsim-0.4.3.dev139.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|