jaxsim 0.5.1.dev48__py3-none-any.whl → 0.5.1.dev60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/com.py +8 -0
- jaxsim/api/common.py +17 -1
- jaxsim/api/contact.py +9 -0
- jaxsim/api/data.py +16 -0
- jaxsim/api/frame.py +7 -0
- jaxsim/api/joint.py +4 -0
- jaxsim/api/model.py +22 -1
- jaxsim/api/ode.py +3 -0
- jaxsim/api/ode_data.py +3 -217
- jaxsim/api/references.py +44 -43
- jaxsim/rbda/contacts/visco_elastic.py +1 -1
- {jaxsim-0.5.1.dev48.dist-info → jaxsim-0.5.1.dev60.dist-info}/METADATA +1 -1
- {jaxsim-0.5.1.dev48.dist-info → jaxsim-0.5.1.dev60.dist-info}/RECORD +17 -17
- {jaxsim-0.5.1.dev48.dist-info → jaxsim-0.5.1.dev60.dist-info}/LICENSE +0 -0
- {jaxsim-0.5.1.dev48.dist-info → jaxsim-0.5.1.dev60.dist-info}/WHEEL +0 -0
- {jaxsim-0.5.1.dev48.dist-info → jaxsim-0.5.1.dev60.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.5.1.dev48'
|
16
|
-
__version_tuple__ = version_tuple = (0, 5, 1, 'dev48')
|
15
|
+
__version__ = version = '0.5.1.dev60'
|
16
|
+
__version_tuple__ = version_tuple = (0, 5, 1, 'dev60')
|
jaxsim/api/com.py
CHANGED
@@ -8,6 +8,7 @@ import jaxsim.typing as jtp
|
|
8
8
|
from .common import VelRepr
|
9
9
|
|
10
10
|
|
11
|
+
@js.common.named_scope
|
11
12
|
@jax.jit
|
12
13
|
def com_position(
|
13
14
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -44,6 +45,7 @@ def com_position(
|
|
44
45
|
return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
|
45
46
|
|
46
47
|
|
48
|
+
@js.common.named_scope
|
47
49
|
@jax.jit
|
48
50
|
def com_linear_velocity(
|
49
51
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -74,6 +76,7 @@ def com_linear_velocity(
|
|
74
76
|
return G_vl_WG
|
75
77
|
|
76
78
|
|
79
|
+
@js.common.named_scope
|
77
80
|
@jax.jit
|
78
81
|
def centroidal_momentum(
|
79
82
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -101,6 +104,7 @@ def centroidal_momentum(
|
|
101
104
|
return G_J @ ν
|
102
105
|
|
103
106
|
|
107
|
+
@js.common.named_scope
|
104
108
|
@jax.jit
|
105
109
|
def centroidal_momentum_jacobian(
|
106
110
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -149,6 +153,7 @@ def centroidal_momentum_jacobian(
|
|
149
153
|
return G_Xf_B @ B_Jh
|
150
154
|
|
151
155
|
|
156
|
+
@js.common.named_scope
|
152
157
|
@jax.jit
|
153
158
|
def locked_centroidal_spatial_inertia(
|
154
159
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -186,6 +191,7 @@ def locked_centroidal_spatial_inertia(
|
|
186
191
|
return G_Xf_B @ B_Mbb_B @ B_Xv_G
|
187
192
|
|
188
193
|
|
194
|
+
@js.common.named_scope
|
189
195
|
@jax.jit
|
190
196
|
def average_centroidal_velocity(
|
191
197
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -213,6 +219,7 @@ def average_centroidal_velocity(
|
|
213
219
|
return G_J @ ν
|
214
220
|
|
215
221
|
|
222
|
+
@js.common.named_scope
|
216
223
|
@jax.jit
|
217
224
|
def average_centroidal_velocity_jacobian(
|
218
225
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -239,6 +246,7 @@ def average_centroidal_velocity_jacobian(
|
|
239
246
|
return jnp.linalg.inv(G_Mbb) @ G_J
|
240
247
|
|
241
248
|
|
249
|
+
@js.common.named_scope
|
242
250
|
@jax.jit
|
243
251
|
def bias_acceleration(
|
244
252
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
jaxsim/api/common.py
CHANGED
@@ -3,7 +3,8 @@ import contextlib
|
|
3
3
|
import dataclasses
|
4
4
|
import enum
|
5
5
|
import functools
|
6
|
-
from collections.abc import Iterator
|
6
|
+
from collections.abc import Callable, Iterator
|
7
|
+
from typing import ParamSpec, TypeVar
|
7
8
|
|
8
9
|
import jax
|
9
10
|
import jax.numpy as jnp
|
@@ -20,6 +21,21 @@ except ImportError:
|
|
20
21
|
from typing_extensions import Self
|
21
22
|
|
22
23
|
|
24
|
+
_P = ParamSpec("_P")
|
25
|
+
_R = TypeVar("_R")
|
26
|
+
|
27
|
+
|
28
|
+
def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
|
29
|
+
"""Applies a JAX named scope to a function for improved profiling and clarity."""
|
30
|
+
|
31
|
+
@functools.wraps(fn)
|
32
|
+
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
33
|
+
with jax.named_scope(name or fn.__name__):
|
34
|
+
return fn(*args, **kwargs)
|
35
|
+
|
36
|
+
return wrapper
|
37
|
+
|
38
|
+
|
23
39
|
@enum.unique
|
24
40
|
class VelRepr(enum.IntEnum):
|
25
41
|
"""
|
jaxsim/api/contact.py
CHANGED
@@ -16,6 +16,7 @@ from jaxsim.rbda import contacts
|
|
16
16
|
from .common import VelRepr
|
17
17
|
|
18
18
|
|
19
|
+
@js.common.named_scope
|
19
20
|
@jax.jit
|
20
21
|
def collidable_point_kinematics(
|
21
22
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -52,6 +53,7 @@ def collidable_point_kinematics(
|
|
52
53
|
return W_p_Ci, W_ṗ_Ci
|
53
54
|
|
54
55
|
|
56
|
+
@js.common.named_scope
|
55
57
|
@jax.jit
|
56
58
|
def collidable_point_positions(
|
57
59
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -72,6 +74,7 @@ def collidable_point_positions(
|
|
72
74
|
return W_p_Ci
|
73
75
|
|
74
76
|
|
77
|
+
@js.common.named_scope
|
75
78
|
@jax.jit
|
76
79
|
def collidable_point_velocities(
|
77
80
|
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
@@ -92,6 +95,7 @@ def collidable_point_velocities(
|
|
92
95
|
return W_ṗ_Ci
|
93
96
|
|
94
97
|
|
98
|
+
@js.common.named_scope
|
95
99
|
@jax.jit
|
96
100
|
def collidable_point_forces(
|
97
101
|
model: js.model.JaxSimModel,
|
@@ -129,6 +133,7 @@ def collidable_point_forces(
|
|
129
133
|
return f_Ci
|
130
134
|
|
131
135
|
|
136
|
+
@js.common.named_scope
|
132
137
|
@jax.jit
|
133
138
|
def collidable_point_dynamics(
|
134
139
|
model: js.model.JaxSimModel,
|
@@ -227,6 +232,7 @@ def collidable_point_dynamics(
|
|
227
232
|
return f_Ci, aux_data
|
228
233
|
|
229
234
|
|
235
|
+
@js.common.named_scope
|
230
236
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
231
237
|
def in_contact(
|
232
238
|
model: js.model.JaxSimModel,
|
@@ -424,6 +430,7 @@ def estimate_good_contact_parameters(
|
|
424
430
|
return parameters
|
425
431
|
|
426
432
|
|
433
|
+
@js.common.named_scope
|
427
434
|
@jax.jit
|
428
435
|
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
|
429
436
|
r"""
|
@@ -469,6 +476,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
|
|
469
476
|
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
|
470
477
|
|
471
478
|
|
479
|
+
@js.common.named_scope
|
472
480
|
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
473
481
|
def jacobian(
|
474
482
|
model: js.model.JaxSimModel,
|
@@ -561,6 +569,7 @@ def jacobian(
|
|
561
569
|
return O_J_WC
|
562
570
|
|
563
571
|
|
572
|
+
@js.common.named_scope
|
564
573
|
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
565
574
|
def jacobian_derivative(
|
566
575
|
model: js.model.JaxSimModel,
|
jaxsim/api/data.py
CHANGED
@@ -253,6 +253,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
253
253
|
|
254
254
|
return -self.gravity[2]
|
255
255
|
|
256
|
+
@js.common.named_scope
|
256
257
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
257
258
|
def joint_positions(
|
258
259
|
self,
|
@@ -300,6 +301,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
300
301
|
|
301
302
|
return self.state.physics_model.joint_positions[joint_idxs]
|
302
303
|
|
304
|
+
@js.common.named_scope
|
303
305
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
304
306
|
def joint_velocities(
|
305
307
|
self,
|
@@ -347,6 +349,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
347
349
|
|
348
350
|
return self.state.physics_model.joint_velocities[joint_idxs]
|
349
351
|
|
352
|
+
@js.common.named_scope
|
350
353
|
@jax.jit
|
351
354
|
def base_position(self) -> jtp.Vector:
|
352
355
|
"""
|
@@ -358,6 +361,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
358
361
|
|
359
362
|
return self.state.physics_model.base_position.squeeze()
|
360
363
|
|
364
|
+
@js.common.named_scope
|
361
365
|
@functools.partial(jax.jit, static_argnames=["dcm"])
|
362
366
|
def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
|
363
367
|
"""
|
@@ -386,6 +390,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
386
390
|
float
|
387
391
|
)
|
388
392
|
|
393
|
+
@js.common.named_scope
|
389
394
|
@jax.jit
|
390
395
|
def base_transform(self) -> jtp.Matrix:
|
391
396
|
"""
|
@@ -405,6 +410,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
405
410
|
]
|
406
411
|
)
|
407
412
|
|
413
|
+
@js.common.named_scope
|
408
414
|
@jax.jit
|
409
415
|
def base_velocity(self) -> jtp.Vector:
|
410
416
|
"""
|
@@ -434,6 +440,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
434
440
|
.astype(float)
|
435
441
|
)
|
436
442
|
|
443
|
+
@js.common.named_scope
|
437
444
|
@jax.jit
|
438
445
|
def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
|
439
446
|
r"""
|
@@ -446,6 +453,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
446
453
|
|
447
454
|
return self.base_transform(), self.joint_positions()
|
448
455
|
|
456
|
+
@js.common.named_scope
|
449
457
|
@jax.jit
|
450
458
|
def generalized_velocity(self) -> jtp.Vector:
|
451
459
|
r"""
|
@@ -466,6 +474,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
466
474
|
# Store quantities
|
467
475
|
# ================
|
468
476
|
|
477
|
+
@js.common.named_scope
|
469
478
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
470
479
|
def reset_joint_positions(
|
471
480
|
self,
|
@@ -514,6 +523,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
514
523
|
s=self.state.physics_model.joint_positions.at[joint_idxs].set(positions)
|
515
524
|
)
|
516
525
|
|
526
|
+
@js.common.named_scope
|
517
527
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
518
528
|
def reset_joint_velocities(
|
519
529
|
self,
|
@@ -562,6 +572,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
562
572
|
ṡ=self.state.physics_model.joint_velocities.at[joint_idxs].set(velocities)
|
563
573
|
)
|
564
574
|
|
575
|
+
@js.common.named_scope
|
565
576
|
@jax.jit
|
566
577
|
def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
|
567
578
|
"""
|
@@ -585,6 +596,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
585
596
|
),
|
586
597
|
)
|
587
598
|
|
599
|
+
@js.common.named_scope
|
588
600
|
@jax.jit
|
589
601
|
def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
|
590
602
|
"""
|
@@ -612,6 +624,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
612
624
|
),
|
613
625
|
)
|
614
626
|
|
627
|
+
@js.common.named_scope
|
615
628
|
@jax.jit
|
616
629
|
def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
|
617
630
|
"""
|
@@ -634,6 +647,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
634
647
|
base_quaternion=W_Q_B
|
635
648
|
)
|
636
649
|
|
650
|
+
@js.common.named_scope
|
637
651
|
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
638
652
|
def reset_base_linear_velocity(
|
639
653
|
self,
|
@@ -665,6 +679,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
665
679
|
velocity_representation=velocity_representation,
|
666
680
|
)
|
667
681
|
|
682
|
+
@js.common.named_scope
|
668
683
|
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
669
684
|
def reset_base_angular_velocity(
|
670
685
|
self,
|
@@ -696,6 +711,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
696
711
|
velocity_representation=velocity_representation,
|
697
712
|
)
|
698
713
|
|
714
|
+
@js.common.named_scope
|
699
715
|
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
700
716
|
def reset_base_velocity(
|
701
717
|
self,
|
jaxsim/api/frame.py
CHANGED
@@ -16,6 +16,7 @@ from .common import VelRepr
|
|
16
16
|
# =======================
|
17
17
|
|
18
18
|
|
19
|
+
@js.common.named_scope
|
19
20
|
@jax.jit
|
20
21
|
def idx_of_parent_link(
|
21
22
|
model: js.model.JaxSimModel, *, frame_index: jtp.IntLike
|
@@ -45,6 +46,7 @@ def idx_of_parent_link(
|
|
45
46
|
]
|
46
47
|
|
47
48
|
|
49
|
+
@js.common.named_scope
|
48
50
|
@functools.partial(jax.jit, static_argnames="frame_name")
|
49
51
|
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
|
50
52
|
"""
|
@@ -97,6 +99,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
|
|
97
99
|
]
|
98
100
|
|
99
101
|
|
102
|
+
@js.common.named_scope
|
100
103
|
@functools.partial(jax.jit, static_argnames=["frame_names"])
|
101
104
|
def names_to_idxs(
|
102
105
|
model: js.model.JaxSimModel, *, frame_names: Sequence[str]
|
@@ -139,6 +142,7 @@ def idxs_to_names(
|
|
139
142
|
# ==========
|
140
143
|
|
141
144
|
|
145
|
+
@js.common.named_scope
|
142
146
|
@jax.jit
|
143
147
|
def transform(
|
144
148
|
model: js.model.JaxSimModel,
|
@@ -180,6 +184,7 @@ def transform(
|
|
180
184
|
return W_H_L @ L_H_F
|
181
185
|
|
182
186
|
|
187
|
+
@js.common.named_scope
|
183
188
|
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
184
189
|
def velocity(
|
185
190
|
model: js.model.JaxSimModel,
|
@@ -230,6 +235,7 @@ def velocity(
|
|
230
235
|
return O_J_WF_I @ I_ν
|
231
236
|
|
232
237
|
|
238
|
+
@js.common.named_scope
|
233
239
|
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
234
240
|
def jacobian(
|
235
241
|
model: js.model.JaxSimModel,
|
@@ -309,6 +315,7 @@ def jacobian(
|
|
309
315
|
return O_J_WL_I
|
310
316
|
|
311
317
|
|
318
|
+
@js.common.named_scope
|
312
319
|
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
313
320
|
def jacobian_derivative(
|
314
321
|
model: js.model.JaxSimModel,
|
jaxsim/api/joint.py
CHANGED
@@ -13,6 +13,7 @@ from jaxsim import exceptions
|
|
13
13
|
# =======================
|
14
14
|
|
15
15
|
|
16
|
+
@js.common.named_scope
|
16
17
|
@functools.partial(jax.jit, static_argnames="joint_name")
|
17
18
|
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
|
18
19
|
"""
|
@@ -61,6 +62,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
|
|
61
62
|
return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]
|
62
63
|
|
63
64
|
|
65
|
+
@js.common.named_scope
|
64
66
|
@functools.partial(jax.jit, static_argnames="joint_names")
|
65
67
|
def names_to_idxs(
|
66
68
|
model: js.model.JaxSimModel, *, joint_names: Sequence[str]
|
@@ -141,6 +143,7 @@ def position_limit(
|
|
141
143
|
return s_min.astype(float), s_max.astype(float)
|
142
144
|
|
143
145
|
|
146
|
+
@js.common.named_scope
|
144
147
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
145
148
|
def position_limits(
|
146
149
|
model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
|
@@ -176,6 +179,7 @@ def position_limits(
|
|
176
179
|
# ======================
|
177
180
|
|
178
181
|
|
182
|
+
@js.common.named_scope
|
179
183
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
180
184
|
def random_joint_positions(
|
181
185
|
model: js.model.JaxSimModel,
|
jaxsim/api/model.py
CHANGED
@@ -491,6 +491,7 @@ def reduce(
|
|
491
491
|
# ===================
|
492
492
|
|
493
493
|
|
494
|
+
@js.common.named_scope
|
494
495
|
@jax.jit
|
495
496
|
def total_mass(model: JaxSimModel) -> jtp.Float:
|
496
497
|
"""
|
@@ -506,6 +507,7 @@ def total_mass(model: JaxSimModel) -> jtp.Float:
|
|
506
507
|
return model.kin_dyn_parameters.link_parameters.mass.sum().astype(float)
|
507
508
|
|
508
509
|
|
510
|
+
@js.common.named_scope
|
509
511
|
@jax.jit
|
510
512
|
def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
|
511
513
|
"""
|
@@ -528,6 +530,7 @@ def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
|
|
528
530
|
# ==============================
|
529
531
|
|
530
532
|
|
533
|
+
@js.common.named_scope
|
531
534
|
@jax.jit
|
532
535
|
def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
|
533
536
|
"""
|
@@ -915,6 +918,7 @@ def forward_dynamics(
|
|
915
918
|
)
|
916
919
|
|
917
920
|
|
921
|
+
@js.common.named_scope
|
918
922
|
@jax.jit
|
919
923
|
def forward_dynamics_aba(
|
920
924
|
model: JaxSimModel,
|
@@ -1059,6 +1063,7 @@ def forward_dynamics_aba(
|
|
1059
1063
|
return C_v̇_WB.astype(float), s̈.astype(float)
|
1060
1064
|
|
1061
1065
|
|
1066
|
+
@js.common.named_scope
|
1062
1067
|
@jax.jit
|
1063
1068
|
def forward_dynamics_crb(
|
1064
1069
|
model: JaxSimModel,
|
@@ -1150,6 +1155,7 @@ def forward_dynamics_crb(
|
|
1150
1155
|
return v̇_WB, s̈
|
1151
1156
|
|
1152
1157
|
|
1158
|
+
@js.common.named_scope
|
1153
1159
|
@jax.jit
|
1154
1160
|
def free_floating_mass_matrix(
|
1155
1161
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
@@ -1195,6 +1201,7 @@ def free_floating_mass_matrix(
|
|
1195
1201
|
raise ValueError(data.velocity_representation)
|
1196
1202
|
|
1197
1203
|
|
1204
|
+
@js.common.named_scope
|
1198
1205
|
@jax.jit
|
1199
1206
|
def free_floating_coriolis_matrix(
|
1200
1207
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
@@ -1311,6 +1318,7 @@ def free_floating_coriolis_matrix(
|
|
1311
1318
|
raise ValueError(data.velocity_representation)
|
1312
1319
|
|
1313
1320
|
|
1321
|
+
@js.common.named_scope
|
1314
1322
|
@jax.jit
|
1315
1323
|
def inverse_dynamics(
|
1316
1324
|
model: JaxSimModel,
|
@@ -1466,6 +1474,7 @@ def inverse_dynamics(
|
|
1466
1474
|
return f_B.astype(float), τ.astype(float)
|
1467
1475
|
|
1468
1476
|
|
1477
|
+
@js.common.named_scope
|
1469
1478
|
@jax.jit
|
1470
1479
|
def free_floating_gravity_forces(
|
1471
1480
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
@@ -1515,6 +1524,7 @@ def free_floating_gravity_forces(
|
|
1515
1524
|
).astype(float)
|
1516
1525
|
|
1517
1526
|
|
1527
|
+
@js.common.named_scope
|
1518
1528
|
@jax.jit
|
1519
1529
|
def free_floating_bias_forces(
|
1520
1530
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
@@ -1584,6 +1594,7 @@ def free_floating_bias_forces(
|
|
1584
1594
|
# ==========================
|
1585
1595
|
|
1586
1596
|
|
1597
|
+
@js.common.named_scope
|
1587
1598
|
@jax.jit
|
1588
1599
|
def locked_spatial_inertia(
|
1589
1600
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
@@ -1602,6 +1613,7 @@ def locked_spatial_inertia(
|
|
1602
1613
|
return total_momentum_jacobian(model=model, data=data)[:, 0:6]
|
1603
1614
|
|
1604
1615
|
|
1616
|
+
@js.common.named_scope
|
1605
1617
|
@jax.jit
|
1606
1618
|
def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
|
1607
1619
|
"""
|
@@ -1690,6 +1702,7 @@ def total_momentum_jacobian(
|
|
1690
1702
|
raise ValueError(output_vel_repr)
|
1691
1703
|
|
1692
1704
|
|
1705
|
+
@js.common.named_scope
|
1693
1706
|
@jax.jit
|
1694
1707
|
def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
|
1695
1708
|
"""
|
@@ -1778,6 +1791,7 @@ def average_velocity_jacobian(
|
|
1778
1791
|
# ========================
|
1779
1792
|
|
1780
1793
|
|
1794
|
+
@js.common.named_scope
|
1781
1795
|
@jax.jit
|
1782
1796
|
def link_bias_accelerations(
|
1783
1797
|
model: JaxSimModel,
|
@@ -1987,6 +2001,7 @@ def link_bias_accelerations(
|
|
1987
2001
|
return O_v̇_WL
|
1988
2002
|
|
1989
2003
|
|
2004
|
+
@js.common.named_scope
|
1990
2005
|
@jax.jit
|
1991
2006
|
def link_contact_forces(
|
1992
2007
|
model: js.model.JaxSimModel,
|
@@ -2062,6 +2077,7 @@ def link_contact_forces(
|
|
2062
2077
|
# ======
|
2063
2078
|
|
2064
2079
|
|
2080
|
+
@js.common.named_scope
|
2065
2081
|
@jax.jit
|
2066
2082
|
def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
|
2067
2083
|
"""
|
@@ -2081,6 +2097,7 @@ def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.
|
|
2081
2097
|
return (K + U).astype(float)
|
2082
2098
|
|
2083
2099
|
|
2100
|
+
@js.common.named_scope
|
2084
2101
|
@jax.jit
|
2085
2102
|
def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
|
2086
2103
|
"""
|
@@ -2102,6 +2119,7 @@ def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Flo
|
|
2102
2119
|
return K.squeeze().astype(float)
|
2103
2120
|
|
2104
2121
|
|
2122
|
+
@js.common.named_scope
|
2105
2123
|
@jax.jit
|
2106
2124
|
def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
|
2107
2125
|
"""
|
@@ -2128,6 +2146,7 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
|
|
2128
2146
|
# ==========
|
2129
2147
|
|
2130
2148
|
|
2149
|
+
@js.common.named_scope
|
2131
2150
|
@jax.jit
|
2132
2151
|
def step(
|
2133
2152
|
model: JaxSimModel,
|
@@ -2198,7 +2217,9 @@ def step(
|
|
2198
2217
|
isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts)
|
2199
2218
|
& (
|
2200
2219
|
~jnp.allclose(dt, model.time_step)
|
2201
|
-
| ~
|
2220
|
+
| ~int(
|
2221
|
+
isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
|
2222
|
+
)
|
2202
2223
|
)
|
2203
2224
|
),
|
2204
2225
|
msg=msg.format(module, name),
|
jaxsim/api/ode.py
CHANGED
@@ -85,6 +85,7 @@ def wrap_system_dynamics_for_integration(
|
|
85
85
|
# ==================================
|
86
86
|
|
87
87
|
|
88
|
+
@js.common.named_scope
|
88
89
|
@jax.jit
|
89
90
|
def system_velocity_dynamics(
|
90
91
|
model: js.model.JaxSimModel,
|
@@ -331,6 +332,7 @@ def system_acceleration(
|
|
331
332
|
return v̇_WB, s̈
|
332
333
|
|
333
334
|
|
335
|
+
@js.common.named_scope
|
334
336
|
@jax.jit
|
335
337
|
def system_position_dynamics(
|
336
338
|
model: js.model.JaxSimModel,
|
@@ -370,6 +372,7 @@ def system_position_dynamics(
|
|
370
372
|
return W_ṗ_B, W_Q̇_B, ṡ
|
371
373
|
|
372
374
|
|
375
|
+
@js.common.named_scope
|
373
376
|
@jax.jit
|
374
377
|
def system_dynamics(
|
375
378
|
model: js.model.JaxSimModel,
|
jaxsim/api/ode_data.py
CHANGED
@@ -10,108 +10,14 @@ import jaxsim.api as js
|
|
10
10
|
import jaxsim.typing as jtp
|
11
11
|
from jaxsim.utils import JaxsimDataclass
|
12
12
|
|
13
|
-
#
|
14
|
-
# Define the
|
15
|
-
#
|
13
|
+
# ===================================================================
|
14
|
+
# Define the state of the ODE system defining the integrated dynamics
|
15
|
+
# ===================================================================
|
16
16
|
|
17
17
|
# Note: the ODE system is the combination of the floating-base dynamics and the
|
18
18
|
# soft-contacts dynamics.
|
19
19
|
|
20
20
|
|
21
|
-
@jax_dataclasses.pytree_dataclass
|
22
|
-
class ODEInput(JaxsimDataclass):
|
23
|
-
"""
|
24
|
-
The input to the ODE system.
|
25
|
-
|
26
|
-
Attributes:
|
27
|
-
physics_model: The input to the physics model.
|
28
|
-
"""
|
29
|
-
|
30
|
-
physics_model: PhysicsModelInput
|
31
|
-
|
32
|
-
@staticmethod
|
33
|
-
def build_from_jaxsim_model(
|
34
|
-
model: js.model.JaxSimModel | None = None,
|
35
|
-
link_forces: jtp.MatrixLike | None = None,
|
36
|
-
joint_force_references: jtp.VectorLike | None = None,
|
37
|
-
) -> ODEInput:
|
38
|
-
"""
|
39
|
-
Build an `ODEInput` from a `JaxSimModel`.
|
40
|
-
|
41
|
-
Args:
|
42
|
-
model: The `JaxSimModel` associated with the ODE input.
|
43
|
-
link_forces: The matrix of external forces applied to the links.
|
44
|
-
joint_force_references: The vector of joint force references.
|
45
|
-
|
46
|
-
Returns:
|
47
|
-
The `ODEInput` built from the `JaxSimModel`.
|
48
|
-
|
49
|
-
Note:
|
50
|
-
If any of the input components are not provided, they are built from the
|
51
|
-
`JaxSimModel` and initialized to zero.
|
52
|
-
"""
|
53
|
-
|
54
|
-
return ODEInput.build(
|
55
|
-
physics_model_input=PhysicsModelInput.build_from_jaxsim_model(
|
56
|
-
model=model,
|
57
|
-
link_forces=link_forces,
|
58
|
-
joint_force_references=joint_force_references,
|
59
|
-
),
|
60
|
-
model=model,
|
61
|
-
)
|
62
|
-
|
63
|
-
@staticmethod
|
64
|
-
def build(
|
65
|
-
physics_model_input: PhysicsModelInput | None = None,
|
66
|
-
model: js.model.JaxSimModel | None = None,
|
67
|
-
) -> ODEInput:
|
68
|
-
"""
|
69
|
-
Build an `ODEInput` from a `PhysicsModelInput`.
|
70
|
-
|
71
|
-
Args:
|
72
|
-
physics_model_input: The `PhysicsModelInput` associated with the ODE input.
|
73
|
-
model: The `JaxSimModel` associated with the ODE input.
|
74
|
-
|
75
|
-
Returns:
|
76
|
-
A `ODEInput` instance.
|
77
|
-
"""
|
78
|
-
|
79
|
-
physics_model_input = (
|
80
|
-
physics_model_input
|
81
|
-
if physics_model_input is not None
|
82
|
-
else PhysicsModelInput.zero(model=model)
|
83
|
-
)
|
84
|
-
|
85
|
-
return ODEInput(physics_model=physics_model_input)
|
86
|
-
|
87
|
-
@staticmethod
|
88
|
-
def zero(model: js.model.JaxSimModel) -> ODEInput:
|
89
|
-
"""
|
90
|
-
Build a zero `ODEInput` from a `JaxSimModel`.
|
91
|
-
|
92
|
-
Args:
|
93
|
-
model: The `JaxSimModel` associated with the ODE input.
|
94
|
-
|
95
|
-
Returns:
|
96
|
-
A zero `ODEInput` instance.
|
97
|
-
"""
|
98
|
-
|
99
|
-
return ODEInput.build(model=model)
|
100
|
-
|
101
|
-
def valid(self, model: js.model.JaxSimModel) -> bool:
|
102
|
-
"""
|
103
|
-
Check if the `ODEInput` is valid for a given `JaxSimModel`.
|
104
|
-
|
105
|
-
Args:
|
106
|
-
model: The `JaxSimModel` to validate the `ODEInput` against.
|
107
|
-
|
108
|
-
Returns:
|
109
|
-
`True` if the ODE input is valid for the given model, `False` otherwise.
|
110
|
-
"""
|
111
|
-
|
112
|
-
return self.physics_model.valid(model=model)
|
113
|
-
|
114
|
-
|
115
21
|
@jax_dataclasses.pytree_dataclass
|
116
22
|
class ODEState(JaxsimDataclass):
|
117
23
|
"""
|
@@ -493,123 +399,3 @@ class PhysicsModelState(JaxsimDataclass):
|
|
493
399
|
return False
|
494
400
|
|
495
401
|
return True
|
496
|
-
|
497
|
-
|
498
|
-
@jax_dataclasses.pytree_dataclass
|
499
|
-
class PhysicsModelInput(JaxsimDataclass):
|
500
|
-
"""
|
501
|
-
Class storing the inputs of the physics model dynamics.
|
502
|
-
|
503
|
-
Attributes:
|
504
|
-
tau: The vector of joint forces.
|
505
|
-
f_ext: The matrix of external forces applied to the links.
|
506
|
-
"""
|
507
|
-
|
508
|
-
tau: jtp.Vector
|
509
|
-
f_ext: jtp.Matrix
|
510
|
-
|
511
|
-
@staticmethod
|
512
|
-
def build_from_jaxsim_model(
|
513
|
-
model: js.model.JaxSimModel | None = None,
|
514
|
-
link_forces: jtp.MatrixLike | None = None,
|
515
|
-
joint_force_references: jtp.VectorLike | None = None,
|
516
|
-
) -> PhysicsModelInput:
|
517
|
-
"""
|
518
|
-
Build a `PhysicsModelInput` from a `JaxSimModel`.
|
519
|
-
|
520
|
-
Args:
|
521
|
-
model: The `JaxSimModel` associated with the input.
|
522
|
-
link_forces: The matrix of external forces applied to the links.
|
523
|
-
joint_force_references: The vector of joint force references.
|
524
|
-
|
525
|
-
Returns:
|
526
|
-
A `PhysicsModelInput` instance.
|
527
|
-
|
528
|
-
Note:
|
529
|
-
If any of the input components are not provided, they are built from the
|
530
|
-
`JaxSimModel` and initialized to zero.
|
531
|
-
"""
|
532
|
-
|
533
|
-
return PhysicsModelInput.build(
|
534
|
-
joint_force_references=joint_force_references,
|
535
|
-
link_forces=link_forces,
|
536
|
-
number_of_dofs=model.dofs(),
|
537
|
-
number_of_links=model.number_of_links(),
|
538
|
-
)
|
539
|
-
|
540
|
-
@staticmethod
|
541
|
-
def build(
|
542
|
-
link_forces: jtp.MatrixLike | None = None,
|
543
|
-
joint_force_references: jtp.VectorLike | None = None,
|
544
|
-
number_of_dofs: jtp.Int | None = None,
|
545
|
-
number_of_links: jtp.Int | None = None,
|
546
|
-
) -> PhysicsModelInput:
|
547
|
-
"""
|
548
|
-
Build a `PhysicsModelInput`.
|
549
|
-
|
550
|
-
Args:
|
551
|
-
link_forces: The matrix of external forces applied to the links.
|
552
|
-
joint_force_references: The vector of joint force references.
|
553
|
-
number_of_dofs: The number of degrees of freedom of the model.
|
554
|
-
number_of_links: The number of links of the model.
|
555
|
-
|
556
|
-
Returns:
|
557
|
-
A `PhysicsModelInput` instance.
|
558
|
-
"""
|
559
|
-
|
560
|
-
joint_force_references = jnp.atleast_1d(
|
561
|
-
jnp.array(joint_force_references, dtype=float).squeeze()
|
562
|
-
if joint_force_references is not None
|
563
|
-
else jnp.zeros(number_of_dofs)
|
564
|
-
).astype(float)
|
565
|
-
|
566
|
-
link_forces = jnp.atleast_2d(
|
567
|
-
jnp.array(link_forces, dtype=float).squeeze()
|
568
|
-
if link_forces is not None
|
569
|
-
else jnp.zeros(shape=(number_of_links, 6))
|
570
|
-
).astype(float)
|
571
|
-
|
572
|
-
return PhysicsModelInput(
|
573
|
-
tau=joint_force_references,
|
574
|
-
f_ext=link_forces,
|
575
|
-
)
|
576
|
-
|
577
|
-
@staticmethod
|
578
|
-
def zero(model: js.model.JaxSimModel) -> PhysicsModelInput:
|
579
|
-
"""
|
580
|
-
Build a `PhysicsModelInput` with all components initialized to zero.
|
581
|
-
|
582
|
-
Args:
|
583
|
-
model: The `JaxSimModel` associated with the input.
|
584
|
-
|
585
|
-
Returns:
|
586
|
-
A `PhysicsModelInput` instance.
|
587
|
-
"""
|
588
|
-
|
589
|
-
return PhysicsModelInput.build_from_jaxsim_model(model=model)
|
590
|
-
|
591
|
-
def valid(self, model: js.model.JaxSimModel) -> bool:
|
592
|
-
"""
|
593
|
-
Check if the `PhysicsModelInput` is valid for a given `JaxSimModel`.
|
594
|
-
|
595
|
-
Args:
|
596
|
-
model: The `JaxSimModel` to validate the `PhysicsModelInput` against.
|
597
|
-
|
598
|
-
Returns:
|
599
|
-
`True` if the `PhysicsModelInput` is valid for the given model,
|
600
|
-
`False` otherwise.
|
601
|
-
"""
|
602
|
-
|
603
|
-
shape = self.tau.shape
|
604
|
-
expected_shape = (model.dofs(),)
|
605
|
-
|
606
|
-
if shape != expected_shape:
|
607
|
-
return False
|
608
|
-
|
609
|
-
shape = self.f_ext.shape
|
610
|
-
expected_shape = (model.number_of_links(), 6)
|
611
|
-
|
612
|
-
if shape != expected_shape:
|
613
|
-
return False
|
614
|
-
|
615
|
-
return True
|
jaxsim/api/references.py
CHANGED
@@ -12,7 +12,6 @@ from jaxsim import exceptions
|
|
12
12
|
from jaxsim.utils.tracing import not_tracing
|
13
13
|
|
14
14
|
from .common import VelRepr
|
15
|
-
from .ode_data import ODEInput
|
16
15
|
|
17
16
|
try:
|
18
17
|
from typing import Self
|
@@ -24,9 +23,14 @@ except ImportError:
|
|
24
23
|
class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
25
24
|
"""
|
26
25
|
Class containing the references for a `JaxSimModel` object.
|
26
|
+
|
27
|
+
Attributes:
|
28
|
+
_link_forces: The link 6D forces in inertial-fixed representation.
|
29
|
+
_joint_force_references: The joint force references.
|
27
30
|
"""
|
28
31
|
|
29
|
-
|
32
|
+
_link_forces: jtp.Matrix
|
33
|
+
_joint_force_references: jtp.Vector
|
30
34
|
|
31
35
|
@staticmethod
|
32
36
|
def zero(
|
@@ -94,17 +98,21 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
94
98
|
velocity_representation = (
|
95
99
|
velocity_representation
|
96
100
|
if velocity_representation is not None
|
97
|
-
else (
|
98
|
-
data.velocity_representation if data is not None else VelRepr.Inertial
|
99
|
-
)
|
101
|
+
else getattr(data, "velocity_representation", VelRepr.Inertial)
|
100
102
|
)
|
101
103
|
|
102
104
|
# Create a zero references object.
|
103
105
|
references = JaxSimModelReferences(
|
104
|
-
|
106
|
+
_link_forces=f_L,
|
107
|
+
_joint_force_references=joint_force_references,
|
105
108
|
velocity_representation=velocity_representation,
|
106
109
|
)
|
107
110
|
|
111
|
+
# If the velocity representation is inertial-fixed, we can return
|
112
|
+
# the references directly, as we store the link forces in this frame.
|
113
|
+
if velocity_representation is VelRepr.Inertial:
|
114
|
+
return references
|
115
|
+
|
108
116
|
# Store the joint force references.
|
109
117
|
references = references.set_joint_force_references(
|
110
118
|
forces=joint_force_references,
|
@@ -135,17 +143,27 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
135
143
|
`False` otherwise.
|
136
144
|
"""
|
137
145
|
|
138
|
-
|
146
|
+
if model is None:
|
147
|
+
return True
|
148
|
+
|
149
|
+
shape = self._joint_force_references.shape
|
150
|
+
expected_shape = (model.dofs(),)
|
151
|
+
|
152
|
+
if shape != expected_shape:
|
153
|
+
return False
|
154
|
+
|
155
|
+
shape = self._link_forces.shape
|
156
|
+
expected_shape = (model.number_of_links(), 6)
|
139
157
|
|
140
|
-
if
|
141
|
-
|
158
|
+
if shape != expected_shape:
|
159
|
+
return False
|
142
160
|
|
143
|
-
return
|
161
|
+
return True
|
144
162
|
|
145
163
|
# ==================
|
146
164
|
# Extract quantities
|
147
165
|
# ==================
|
148
|
-
|
166
|
+
@js.common.named_scope
|
149
167
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
150
168
|
def link_forces(
|
151
169
|
self,
|
@@ -178,7 +196,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
178
196
|
e.g. to the contact model and other kinematic constraints.
|
179
197
|
"""
|
180
198
|
|
181
|
-
W_f_L = self.
|
199
|
+
W_f_L = self._link_forces
|
182
200
|
|
183
201
|
# Return all link forces in inertial-fixed representation using the implicit
|
184
202
|
# serialization.
|
@@ -190,7 +208,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
190
208
|
if link_names is not None:
|
191
209
|
raise ValueError("Link names cannot be provided without a model")
|
192
210
|
|
193
|
-
return
|
211
|
+
return W_f_L
|
194
212
|
|
195
213
|
# If we have the model, we can extract the link names, if not provided.
|
196
214
|
link_idxs = (
|
@@ -207,7 +225,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
207
225
|
msg = "Missing model data to use a representation different from {}"
|
208
226
|
raise ValueError(msg.format(VelRepr.Inertial.name))
|
209
227
|
|
210
|
-
if not_tracing(self.
|
228
|
+
if not_tracing(self._link_forces) and not data.valid(model=model):
|
211
229
|
raise ValueError("The provided data is not valid for the model")
|
212
230
|
|
213
231
|
# Helper function to convert a single 6D force to the active representation
|
@@ -264,9 +282,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
264
282
|
if joint_names is not None:
|
265
283
|
raise ValueError("Joint names cannot be provided without a model")
|
266
284
|
|
267
|
-
return self.
|
285
|
+
return self._joint_force_references
|
268
286
|
|
269
|
-
if not_tracing(self.
|
287
|
+
if not_tracing(self._joint_force_references) and not self.valid(model=model):
|
270
288
|
msg = "The actuation object is not compatible with the provided model"
|
271
289
|
raise ValueError(msg)
|
272
290
|
|
@@ -277,13 +295,13 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
277
295
|
)
|
278
296
|
|
279
297
|
return jnp.atleast_1d(
|
280
|
-
self.
|
298
|
+
self._joint_force_references[joint_idxs].squeeze()
|
281
299
|
).astype(float)
|
282
300
|
|
283
301
|
# ================
|
284
302
|
# Store quantities
|
285
303
|
# ================
|
286
|
-
|
304
|
+
@js.common.named_scope
|
287
305
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
288
306
|
def set_joint_force_references(
|
289
307
|
self,
|
@@ -310,11 +328,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
310
328
|
def replace(forces: jtp.Vector) -> JaxSimModelReferences:
|
311
329
|
return self.replace(
|
312
330
|
validate=True,
|
313
|
-
|
314
|
-
physics_model=self.input.physics_model.replace(
|
315
|
-
tau=jnp.atleast_1d(forces.squeeze()).astype(float)
|
316
|
-
)
|
317
|
-
),
|
331
|
+
_joint_force_references=jnp.atleast_1d(forces.squeeze()).astype(float),
|
318
332
|
)
|
319
333
|
|
320
334
|
if model is None:
|
@@ -330,8 +344,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
330
344
|
else jnp.arange(model.number_of_joints())
|
331
345
|
)
|
332
346
|
|
333
|
-
return replace(forces=self.
|
347
|
+
return replace(forces=self._joint_force_references.at[joint_idxs].set(forces))
|
334
348
|
|
349
|
+
@js.common.named_scope
|
335
350
|
@functools.partial(jax.jit, static_argnames=["link_names", "additive"])
|
336
351
|
def apply_link_forces(
|
337
352
|
self,
|
@@ -370,11 +385,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
370
385
|
def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
|
371
386
|
return self.replace(
|
372
387
|
validate=True,
|
373
|
-
|
374
|
-
physics_model=self.input.physics_model.replace(
|
375
|
-
f_ext=jnp.atleast_2d(forces.squeeze()).astype(float)
|
376
|
-
)
|
377
|
-
),
|
388
|
+
_link_forces=jnp.atleast_2d(forces.squeeze()).astype(float),
|
378
389
|
)
|
379
390
|
|
380
391
|
# In this case, we allow only to set the inertial 6D forces to all links
|
@@ -389,11 +400,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
389
400
|
|
390
401
|
W_f_L = f_L
|
391
402
|
|
392
|
-
W_f0_L = (
|
393
|
-
jnp.zeros_like(W_f_L)
|
394
|
-
if not additive
|
395
|
-
else self.input.physics_model.f_ext
|
396
|
-
)
|
403
|
+
W_f0_L = jnp.zeros_like(W_f_L) if not additive else self._link_forces
|
397
404
|
|
398
405
|
return replace(forces=W_f0_L + W_f_L)
|
399
406
|
|
@@ -410,18 +417,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
410
417
|
|
411
418
|
# Compute the bias depending on whether we either set or add the link forces.
|
412
419
|
W_f0_L = (
|
413
|
-
jnp.zeros_like(f_L)
|
414
|
-
if not additive
|
415
|
-
else self.input.physics_model.f_ext[link_idxs, :]
|
420
|
+
jnp.zeros_like(f_L) if not additive else self._link_forces[link_idxs, :]
|
416
421
|
)
|
417
422
|
|
418
423
|
# If inertial-fixed representation, we can directly store the link forces.
|
419
424
|
if self.velocity_representation is VelRepr.Inertial:
|
420
425
|
W_f_L = f_L
|
421
426
|
return replace(
|
422
|
-
forces=self.
|
423
|
-
W_f0_L + W_f_L
|
424
|
-
)
|
427
|
+
forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)
|
425
428
|
)
|
426
429
|
|
427
430
|
if data is None:
|
@@ -450,9 +453,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
450
453
|
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
451
454
|
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
|
452
455
|
|
453
|
-
return replace(
|
454
|
-
forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
|
455
|
-
)
|
456
|
+
return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
|
456
457
|
|
457
458
|
def apply_frame_forces(
|
458
459
|
self,
|
@@ -851,7 +851,7 @@ class ViscoElasticContacts(common.ContactModel):
|
|
851
851
|
W_f̅_L = (
|
852
852
|
jnp.array(average_link_contact_forces_inertial)
|
853
853
|
if average_link_contact_forces_inertial is not None
|
854
|
-
else jnp.zeros_like(references.
|
854
|
+
else jnp.zeros_like(references._link_forces)
|
855
855
|
).astype(float)
|
856
856
|
|
857
857
|
LW_f̿_L = (
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.5.1.
|
3
|
+
Version: 0.5.1.dev60
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
|
6
6
|
Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
|
@@ -1,21 +1,21 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=SbvsohyIl1wURDKT1I6oE7EjNu_dZVzu8-rPXwguX6s,426
|
3
3
|
jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
6
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
7
|
-
jaxsim/api/com.py,sha256=
|
8
|
-
jaxsim/api/common.py,sha256=
|
9
|
-
jaxsim/api/contact.py,sha256=
|
10
|
-
jaxsim/api/data.py,sha256=
|
11
|
-
jaxsim/api/frame.py,sha256=
|
12
|
-
jaxsim/api/joint.py,sha256=
|
7
|
+
jaxsim/api/com.py,sha256=eZEF2nAwe8FdEcTz9v--VUn6mTSRjI5nGtRpeXXW_AA,13826
|
8
|
+
jaxsim/api/common.py,sha256=SvEOGxCKOxKLLVHaNp1sFkBX0sku3-wH0-HUlYVWCDk,7090
|
9
|
+
jaxsim/api/contact.py,sha256=GFPLolO4feVUSWrLbHwvV64bD14ESzGFhoiCTJ_Im2Q,25421
|
10
|
+
jaxsim/api/data.py,sha256=hz8g-P0o7XoDYyUFPx6yA8QlgHmjfFf2_OdiwcRV6W8,30292
|
11
|
+
jaxsim/api/frame.py,sha256=7V6ih6cUe5cBKp8GKD0mZ7Zm-e3lj7qiffoAcVgZ76o,14595
|
12
|
+
jaxsim/api/joint.py,sha256=P4j_5caFs_YkV_jHI_98hko0Q5-CmWAYCqHgr8eGILY,7457
|
13
13
|
jaxsim/api/kin_dyn_parameters.py,sha256=wnto0nzzEJ_M8tH2PUdldEyxQwQdsStYUoQFu696uuw,29897
|
14
14
|
jaxsim/api/link.py,sha256=nHjffhNdi_xGkteMsqdb_hC9mdV9rNw7k3pl89Uhw_8,12798
|
15
|
-
jaxsim/api/model.py,sha256=
|
16
|
-
jaxsim/api/ode.py,sha256=
|
17
|
-
jaxsim/api/ode_data.py,sha256=
|
18
|
-
jaxsim/api/references.py,sha256=
|
15
|
+
jaxsim/api/model.py,sha256=ccS7oIpE6LKFEVG4dvXktD_zGQOoieKULANGjY8TZUs,80139
|
16
|
+
jaxsim/api/ode.py,sha256=3pSKJaZ6NrtADuulUUNmp_828PDXELgENPjhQ9UeJCc,15429
|
17
|
+
jaxsim/api/ode_data.py,sha256=ggF1AVaLW5QuXrfpNsFs-voVcW6gZkxK2Xe9GiDmou0,13755
|
18
|
+
jaxsim/api/references.py,sha256=YkdZhRv8NoBC94qvpwn1w9_alVuxrfiZV5w5NHQIt-g,20737
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
20
|
jaxsim/integrators/common.py,sha256=ohISUnUWTaNHt2kweg1JyzwYGZgIH_wc-01qJWJsO80,18281
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
@@ -59,15 +59,15 @@ jaxsim/rbda/contacts/common.py,sha256=BjwZMCkzd1ZOdZW7_Zt09Cl5j2JUHXM5Q8ao_qS6e6
|
|
59
59
|
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=PgwKfProN5sLXJsSov3nIidHHMVpJqIp7eIv6_bPGjs,20345
|
60
60
|
jaxsim/rbda/contacts/rigid.py,sha256=X-PE6PmZqlKoZTY6JhYBSW-vom-rq2uBKmBUNQeQHCg,15991
|
61
61
|
jaxsim/rbda/contacts/soft.py,sha256=sIWT4NUJmoVR5T1Fo0ExdPfzL_gPfiPiB-9CFuotE_s,15567
|
62
|
-
jaxsim/rbda/contacts/visco_elastic.py,sha256=
|
62
|
+
jaxsim/rbda/contacts/visco_elastic.py,sha256=QhyJHjDowyBTAhoSdZcCIkOqzp__gMXhLON-qYyMgQc,39886
|
63
63
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
64
64
|
jaxsim/terrain/terrain.py,sha256=_G1QS3zWycj089R8fTP5s2VjcZpEdJxREjXZJ-oXIvc,5248
|
65
65
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
66
66
|
jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
|
67
67
|
jaxsim/utils/tracing.py,sha256=eEY28MZW0Lm_jJNt1NkFqZz0ek01tvhR46OXZYCo7tc,532
|
68
68
|
jaxsim/utils/wrappers.py,sha256=ZY7olSORzZRvSzkdeNLj8yjwUIAt9L0Douwl7wItjpk,4008
|
69
|
-
jaxsim-0.5.1.
|
70
|
-
jaxsim-0.5.1.
|
71
|
-
jaxsim-0.5.1.
|
72
|
-
jaxsim-0.5.1.
|
73
|
-
jaxsim-0.5.1.
|
69
|
+
jaxsim-0.5.1.dev60.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
70
|
+
jaxsim-0.5.1.dev60.dist-info/METADATA,sha256=_jyXanq6vSSkWmVG8N8_5WrIPq-VQx4Skaq4xb-lvFw,17937
|
71
|
+
jaxsim-0.5.1.dev60.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
72
|
+
jaxsim-0.5.1.dev60.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
73
|
+
jaxsim-0.5.1.dev60.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|