dexcontrol 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dexcontrol might be problematic. Click here for more details.

Files changed (72)
  1. dexcontrol/__init__.py +45 -0
  2. dexcontrol/apps/dualsense_teleop_base.py +371 -0
  3. dexcontrol/config/__init__.py +14 -0
  4. dexcontrol/config/core/__init__.py +22 -0
  5. dexcontrol/config/core/arm.py +32 -0
  6. dexcontrol/config/core/chassis.py +22 -0
  7. dexcontrol/config/core/hand.py +42 -0
  8. dexcontrol/config/core/head.py +33 -0
  9. dexcontrol/config/core/misc.py +37 -0
  10. dexcontrol/config/core/torso.py +36 -0
  11. dexcontrol/config/sensors/__init__.py +4 -0
  12. dexcontrol/config/sensors/cameras/__init__.py +7 -0
  13. dexcontrol/config/sensors/cameras/gemini_camera.py +16 -0
  14. dexcontrol/config/sensors/cameras/rgb_camera.py +15 -0
  15. dexcontrol/config/sensors/imu/__init__.py +6 -0
  16. dexcontrol/config/sensors/imu/gemini_imu.py +15 -0
  17. dexcontrol/config/sensors/imu/nine_axis_imu.py +15 -0
  18. dexcontrol/config/sensors/lidar/__init__.py +6 -0
  19. dexcontrol/config/sensors/lidar/rplidar.py +15 -0
  20. dexcontrol/config/sensors/ultrasonic/__init__.py +6 -0
  21. dexcontrol/config/sensors/ultrasonic/ultrasonic.py +15 -0
  22. dexcontrol/config/sensors/vega_sensors.py +65 -0
  23. dexcontrol/config/vega.py +203 -0
  24. dexcontrol/core/__init__.py +0 -0
  25. dexcontrol/core/arm.py +324 -0
  26. dexcontrol/core/chassis.py +628 -0
  27. dexcontrol/core/component.py +834 -0
  28. dexcontrol/core/hand.py +170 -0
  29. dexcontrol/core/head.py +232 -0
  30. dexcontrol/core/misc.py +514 -0
  31. dexcontrol/core/torso.py +198 -0
  32. dexcontrol/proto/dexcontrol_msg_pb2.py +69 -0
  33. dexcontrol/proto/dexcontrol_msg_pb2.pyi +168 -0
  34. dexcontrol/proto/dexcontrol_query_pb2.py +73 -0
  35. dexcontrol/proto/dexcontrol_query_pb2.pyi +134 -0
  36. dexcontrol/robot.py +1091 -0
  37. dexcontrol/sensors/__init__.py +40 -0
  38. dexcontrol/sensors/camera/__init__.py +18 -0
  39. dexcontrol/sensors/camera/gemini_camera.py +139 -0
  40. dexcontrol/sensors/camera/rgb_camera.py +98 -0
  41. dexcontrol/sensors/imu/__init__.py +22 -0
  42. dexcontrol/sensors/imu/gemini_imu.py +139 -0
  43. dexcontrol/sensors/imu/nine_axis_imu.py +149 -0
  44. dexcontrol/sensors/lidar/__init__.py +3 -0
  45. dexcontrol/sensors/lidar/rplidar.py +164 -0
  46. dexcontrol/sensors/manager.py +185 -0
  47. dexcontrol/sensors/ultrasonic.py +110 -0
  48. dexcontrol/utils/__init__.py +15 -0
  49. dexcontrol/utils/constants.py +12 -0
  50. dexcontrol/utils/io_utils.py +26 -0
  51. dexcontrol/utils/motion_utils.py +194 -0
  52. dexcontrol/utils/os_utils.py +39 -0
  53. dexcontrol/utils/pb_utils.py +103 -0
  54. dexcontrol/utils/rate_limiter.py +167 -0
  55. dexcontrol/utils/reset_orbbec_camera_usb.py +98 -0
  56. dexcontrol/utils/subscribers/__init__.py +44 -0
  57. dexcontrol/utils/subscribers/base.py +260 -0
  58. dexcontrol/utils/subscribers/camera.py +328 -0
  59. dexcontrol/utils/subscribers/decoders.py +83 -0
  60. dexcontrol/utils/subscribers/generic.py +105 -0
  61. dexcontrol/utils/subscribers/imu.py +170 -0
  62. dexcontrol/utils/subscribers/lidar.py +195 -0
  63. dexcontrol/utils/subscribers/protobuf.py +106 -0
  64. dexcontrol/utils/timer.py +136 -0
  65. dexcontrol/utils/trajectory_utils.py +40 -0
  66. dexcontrol/utils/viz_utils.py +86 -0
  67. dexcontrol-0.2.1.dist-info/METADATA +369 -0
  68. dexcontrol-0.2.1.dist-info/RECORD +72 -0
  69. dexcontrol-0.2.1.dist-info/WHEEL +5 -0
  70. dexcontrol-0.2.1.dist-info/licenses/LICENSE +188 -0
  71. dexcontrol-0.2.1.dist-info/licenses/NOTICE +13 -0
  72. dexcontrol-0.2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,834 @@
1
+ # Copyright (c) 2025 Dexmate CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 with Commons Clause License
4
+ # Condition v1.0 [see LICENSE for details].
5
+
6
+ """Base module for robot components with Zenoh communication.
7
+
8
+ This module provides base classes for robot components that communicate via Zenoh.
9
+ It includes RobotComponent for state-only components and RobotJointComponent for
10
+ components that also support control commands.
11
+ """
12
+
13
+ import time
14
+ from typing import Any, Final, Mapping, TypeVar
15
+
16
+ import numpy as np
17
+ import zenoh
18
+ from google.protobuf.message import Message
19
+ from jaxtyping import Float
20
+
21
+ from dexcontrol.utils.os_utils import resolve_key_name
22
+ from dexcontrol.utils.subscribers import ProtobufZenohSubscriber
23
+
24
# Type variable bound to protobuf Message subclasses; used to annotate the
# state-message class parameters accepted by the component classes below.
M = TypeVar("M", bound=Message)
26
+
27
+
28
class RobotComponent:
    """Base class for robot components with state interface.

    A component represents a physical part of the robot whose state is received
    over Zenoh. It subscribes to a state topic and exposes the most recently
    received state message.

    Attributes:
        _state_message_type: Protobuf message class for component state.
        _zenoh_session: Active Zenoh session for communication.
        _subscriber: Zenoh subscriber for state updates.
    """

    def __init__(
        self,
        state_sub_topic: str,
        zenoh_session: zenoh.Session,
        state_message_type: type[M],
    ) -> None:
        """Initializes RobotComponent.

        Args:
            state_sub_topic: Topic to subscribe to for state updates.
            zenoh_session: Active Zenoh session for communication.
            state_message_type: Protobuf message class for component state.
        """
        self._state_message_type = state_message_type
        self._zenoh_session: Final[zenoh.Session] = zenoh_session
        self._init_subscriber(state_sub_topic, state_message_type, zenoh_session)

    def _init_subscriber(
        self,
        state_sub_topic: str,
        state_message_type: type[M],
        zenoh_session: zenoh.Session,
    ) -> None:
        """Initialize the Zenoh subscriber for state updates.

        The subscriber is named after the concrete class and created with FPS
        tracking disabled.

        Args:
            state_sub_topic: Topic to subscribe to for state updates.
            state_message_type: Protobuf message class for component state.
            zenoh_session: Active Zenoh session for communication.
        """
        self._subscriber = ProtobufZenohSubscriber(
            topic=state_sub_topic,
            zenoh_session=zenoh_session,
            message_type=state_message_type,
            name=self.__class__.__name__,
            enable_fps_tracking=False,
        )

    def _get_state(self) -> M:
        """Gets the current state of the component.

        Returns:
            The latest parsed protobuf state message held by the subscriber.

        Raises:
            RuntimeError: If no state data is available.
        """
        state = self._subscriber.get_latest_data()
        if state is None:
            raise RuntimeError("No state data available")
        return state

    def wait_for_active(self, timeout: float = 5.0) -> bool:
        """Waits for the component to start receiving state updates.

        Args:
            timeout: Maximum time to wait in seconds.

        Returns:
            True if component becomes active, False if timeout is reached.
        """
        return self._subscriber.wait_for_active(timeout)

    def is_active(self) -> bool:
        """Check if component is receiving state updates.

        Returns:
            True if component is active, False otherwise.
        """
        return self._subscriber.is_active()

    def shutdown(self) -> None:
        """Cleans up Zenoh resources.

        Calls the component's ``stop`` method first (if it defines a callable
        one), then shuts down the state subscriber. The subscriber is shut
        down even if ``stop`` raises; that exception is then re-raised.
        """
        try:
            # Stop any ongoing operations if the component has a stop method.
            stop = getattr(self, "stop", None)
            if callable(stop):
                stop()
        finally:
            # Always release the subscriber, even when stopping failed, so a
            # faulty stop() cannot leak the Zenoh subscription.
            subscriber = getattr(self, "_subscriber", None)
            if subscriber:
                subscriber.shutdown()
123
+
124
+
125
class RobotJointComponent(RobotComponent):
    """Base class for robot components with both state and control interfaces.

    Extends RobotComponent with joint-level state queries (positions,
    velocities, currents, error codes), position commands published over
    Zenoh, and a pool of named predefined poses.

    Attributes:
        _publisher: Zenoh publisher for control commands.
        _joint_name: List of joint names for this component, or None.
        _pose_pool: Dictionary mapping pose names to float32 joint arrays, or
            None when no pose pool was provided.
    """

    @staticmethod
    def _convert_pose_pool_to_arrays(
        pose_pool: Mapping[str, list[float] | np.ndarray] | None = None,
    ) -> dict[str, np.ndarray] | None:
        """Convert pose pool values to numpy arrays.

        Args:
            pose_pool: Dictionary mapping pose names to lists or arrays of joint values.

        Returns:
            Dictionary mapping pose names to float32 numpy arrays, or None if
            input is None.
        """
        if pose_pool is None:
            return None

        return {
            name: np.array(pose, dtype=np.float32) for name, pose in pose_pool.items()
        }

    def __init__(
        self,
        state_sub_topic: str,
        control_pub_topic: str,
        state_message_type: type[M],
        zenoh_session: zenoh.Session,
        joint_name: list[str] | None = None,
        pose_pool: Mapping[str, list[float] | np.ndarray] | None = None,
    ) -> None:
        """Initializes RobotJointComponent.

        Args:
            state_sub_topic: Topic to subscribe to for state updates.
            control_pub_topic: Topic to publish control commands; it is passed
                through resolve_key_name before the publisher is declared.
            state_message_type: Protobuf message class for component state.
            zenoh_session: Active Zenoh session for communication.
            joint_name: List of joint names for this component.
            pose_pool: Dictionary of predefined poses for this component.
        """
        super().__init__(state_sub_topic, zenoh_session, state_message_type)

        resolved_topic: Final[str] = resolve_key_name(control_pub_topic)
        self._publisher: Final[zenoh.Publisher] = zenoh_session.declare_publisher(
            resolved_topic
        )
        self._joint_name: list[str] | None = joint_name
        self._pose_pool: dict[str, np.ndarray] | None = (
            self._convert_pose_pool_to_arrays(pose_pool)
        )

    def _publish_control(self, control_msg: Message) -> None:
        """Serializes and publishes a control command message.

        Args:
            control_msg: Protobuf control message to publish.
        """
        msg_bytes = control_msg.SerializeToString()
        self._publisher.put(msg_bytes)

    def shutdown(self) -> None:
        """Cleans up all Zenoh resources.

        Runs the base-class shutdown (stop + subscriber shutdown) and then
        undeclares the control publisher. The publisher is undeclared even if
        the base-class shutdown raises; that exception is then re-raised.
        """
        try:
            super().shutdown()
        finally:
            publisher = getattr(self, "_publisher", None)
            if publisher:
                publisher.undeclare()

    @property
    def joint_name(self) -> list[str]:
        """Gets the joint names of the component.

        Returns:
            A copy of the list of joint names.

        Raises:
            ValueError: If joint names are not available.
        """
        if self._joint_name is None:
            raise ValueError("Joint names not available for this component")
        return self._joint_name.copy()

    def _get_pool_pose(self, pose_name: str) -> np.ndarray:
        """Look up a pose in the pose pool, validating pool and name.

        Args:
            pose_name: Name of the pose to look up.

        Returns:
            The stored pose array itself (not a copy); callers that hand it
            to code which may mutate it must copy it first.

        Raises:
            ValueError: If pose pool is not available or pose name is invalid.
        """
        if self._pose_pool is None:
            raise ValueError("Pose pool not available for this component.")
        if pose_name not in self._pose_pool:
            raise ValueError(
                f"Invalid pose name: {pose_name}. "
                f"Available poses: {list(self._pose_pool.keys())}"
            )
        return self._pose_pool[pose_name]

    def get_predefined_pose(self, pose_name: str) -> np.ndarray:
        """Gets a predefined pose from the pose pool.

        Args:
            pose_name: Name of the pose to get.

        Returns:
            A new float64 array with the joint positions for the requested pose.

        Raises:
            ValueError: If pose pool is not available or pose name is invalid.
        """
        # np.array copies, so the caller can never mutate the pool entry.
        return np.array(self._get_pool_pose(pose_name), dtype=float)

    def get_joint_name(self) -> list[str]:
        """Gets the joint names of the component.

        Returns:
            List of joint names.

        Raises:
            ValueError: If joint names are not available.
        """
        return self.joint_name

    def get_joint_pos(
        self, joint_id: list[int] | int | None = None
    ) -> Float[np.ndarray, " N"]:
        """Gets the current positions of all joints in the component.

        The returned array contains joint positions in the same order as joint_id.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Float32 array of joint positions in component-specific units
            (radians for revolute joints and meters for prismatic joints). For a
            single integer joint_id, the corresponding scalar value.

        Raises:
            ValueError: If joint positions are not available for this component.
        """
        state = self._get_state()
        if not hasattr(state, "joint_pos"):
            raise ValueError("Joint positions are not available for this component.")
        joint_pos = np.array(state.joint_pos, dtype=np.float32)
        return self._extract_joint_info(joint_pos, joint_id=joint_id)

    def get_joint_pos_dict(
        self, joint_id: list[int] | int | None = None
    ) -> dict[str, float]:
        """Gets the current positions of all joints in the component as a dictionary.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Dictionary mapping joint names to position values.

        Raises:
            ValueError: If joint positions are not available for this component.
        """
        values = self.get_joint_pos(joint_id)
        return self._convert_to_dict(values, joint_id)

    def get_joint_vel(
        self, joint_id: list[int] | int | None = None
    ) -> Float[np.ndarray, " N"]:
        """Gets the current velocities of all joints in the component.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Float32 array of joint velocities in component-specific units
            (radians/s for revolute joints and meters/s for prismatic joints).
            For a single integer joint_id, the corresponding scalar value.

        Raises:
            ValueError: If joint velocities are not available for this component.
        """
        state = self._get_state()
        if not hasattr(state, "joint_vel"):
            raise ValueError("Joint velocities are not available for this component.")
        joint_vel = np.array(state.joint_vel, dtype=np.float32)
        return self._extract_joint_info(joint_vel, joint_id=joint_id)

    def get_joint_vel_dict(
        self, joint_id: list[int] | int | None = None
    ) -> dict[str, float]:
        """Gets the current velocities of all joints in the component as a dictionary.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Dictionary mapping joint names to velocity values.

        Raises:
            ValueError: If joint velocities are not available for this component.
        """
        values = self.get_joint_vel(joint_id)
        return self._convert_to_dict(values, joint_id)

    def get_joint_current(
        self, joint_id: list[int] | int | None = None
    ) -> Float[np.ndarray, " N"]:
        """Gets the current of all joints in the component.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Float32 array of joint currents in component-specific units
            (amperes). For a single integer joint_id, the corresponding scalar.

        Raises:
            ValueError: If joint currents are not available for this component.
        """
        state = self._get_state()
        if not hasattr(state, "joint_cur"):
            raise ValueError("Joint currents are not available for this component.")
        joint_cur = np.array(state.joint_cur, dtype=np.float32)
        return self._extract_joint_info(joint_cur, joint_id=joint_id)

    def get_joint_current_dict(
        self, joint_id: list[int] | int | None = None
    ) -> dict[str, float]:
        """Gets the current of all joints in the component as a dictionary.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Dictionary mapping joint names to current values.

        Raises:
            ValueError: If joint currents are not available for this component.
        """
        values = self.get_joint_current(joint_id)
        return self._convert_to_dict(values, joint_id)

    def get_joint_err(self, joint_id: list[int] | int | None = None) -> np.ndarray:
        """Gets current joint error codes.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Uint32 array of joint error codes. For a single integer joint_id,
            the corresponding scalar code.

        Raises:
            ValueError: If joint error codes are not available for this component.
        """
        state = self._get_state()
        if not hasattr(state, "joint_err"):
            raise ValueError("Joint error codes are not available for this component.")
        joint_err = np.array(state.joint_err, dtype=np.uint32)
        return self._extract_joint_info(joint_err, joint_id=joint_id)

    def get_joint_err_dict(
        self, joint_id: list[int] | int | None = None
    ) -> dict[str, int]:
        """Gets current joint error codes as a dictionary.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Dictionary mapping joint names to integer error code values.

        Raises:
            ValueError: If joint error codes are not available for this component.
        """
        values = self.get_joint_err(joint_id)
        return self._convert_to_dict(values, joint_id)

    def get_joint_state(self, joint_id: list[int] | int | None = None) -> np.ndarray:
        """Gets current joint states including positions, velocities and currents.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Float32 array with one row per joint; columns are
            [position, velocity, current] when the state message has a
            ``joint_cur`` field, otherwise [position, velocity]. For a single
            integer joint_id, that joint's row.

        Raises:
            ValueError: If joint positions or velocities are not available.
        """
        state = self._get_state()
        if not hasattr(state, "joint_pos") or not hasattr(state, "joint_vel"):
            raise ValueError(
                "Joint positions or velocities are not available for this component."
            )

        joint_pos = np.array(state.joint_pos, dtype=np.float32)
        joint_vel = np.array(state.joint_vel, dtype=np.float32)

        if hasattr(state, "joint_cur"):
            # Currents are available: stack them as the third column.
            joint_cur = np.array(state.joint_cur, dtype=np.float32)
            joint_state = np.stack([joint_pos, joint_vel, joint_cur], axis=1)
        else:
            joint_state = np.stack([joint_pos, joint_vel], axis=1)

        return self._extract_joint_info(joint_state, joint_id=joint_id)

    def get_joint_state_dict(
        self, joint_id: list[int] | int | None = None
    ) -> dict[str, Float[np.ndarray, "3"]]:
        """Gets current joint states including positions, velocities and currents as a dictionary.

        Args:
            joint_id: Optional ID(s) of specific joints to query.

        Returns:
            Dictionary mapping joint names to arrays of [position, velocity, current].

        Raises:
            ValueError: If joint positions or velocities are not available.
        """
        values = self.get_joint_state(joint_id)
        return self._convert_to_dict(values, joint_id)

    def _convert_joint_cmd_to_array(
        self,
        joint_cmd: Float[np.ndarray, " N"] | list[float] | dict[str, float],
        clip_value: float | None = None,
    ) -> np.ndarray:
        """Convert joint command to numpy array format.

        Args:
            joint_cmd: Joint command as either:
                - List (or other sequence) of joint values [j1, j2, ..., jN]
                - Numpy array with shape (N,)
                - Dictionary mapping joint names to values; joints not in the
                  dictionary keep their current positions.
            clip_value: Optional value to clip the output array between [-clip_value, clip_value].

        Returns:
            Joint command as a new numpy array (never aliases the input).
        """
        if isinstance(joint_cmd, dict):
            cmd = self._convert_dict_to_array(joint_cmd)
        else:
            # np.array always copies, covering lists, tuples and ndarrays alike.
            cmd = np.array(joint_cmd, dtype=np.float32)

        if clip_value is not None:
            cmd = np.clip(cmd, -clip_value, clip_value)

        return cmd

    def _resolve_relative_joint_cmd(
        self, joint_cmd: Float[np.ndarray, " N"] | list[float] | dict[str, float]
    ) -> Float[np.ndarray, " N"] | dict[str, float]:
        """Resolve relative joint command by adding current joint positions.

        Args:
            joint_cmd: Relative joint command as list, numpy array, or dictionary.

        Returns:
            Absolute joint command: a dictionary for dictionary input, otherwise
            a numpy array.
        """
        if isinstance(joint_cmd, dict):
            current_pos = self.get_joint_pos_dict()
            return {name: current_pos[name] + pos for name, pos in joint_cmd.items()}

        return self.get_joint_pos() + self._convert_joint_cmd_to_array(joint_cmd)

    def _extract_joint_info(
        self, joint_info: np.ndarray, joint_id: list[int] | int | None = None
    ) -> np.ndarray:
        """Extract the joint information of the component as a numpy array.

        Args:
            joint_info: Array of joint information, one entry (or row) per joint.
            joint_id: Optional ID(s) of specific joints to extract. A single int
                selects one entry (dropping the joint axis); a sequence selects
                the listed entries in order; an empty sequence yields an empty
                selection.

        Returns:
            Array of joint information.

        Raises:
            ValueError: If an invalid joint ID is provided.
        """
        if joint_id is None:
            return joint_info

        num_joints = len(joint_info)
        if isinstance(joint_id, int):
            if joint_id >= num_joints:
                raise ValueError(
                    f"Invalid joint ID: {joint_id}. Must be less than {num_joints}"
                )
            return joint_info[joint_id]

        # A sequence of IDs. Convert to a list so numpy performs integer-array
        # indexing (a tuple would be treated as a multi-axis index), and skip
        # max() on an empty selection, which would otherwise raise.
        ids = list(joint_id)
        if ids and max(ids) >= num_joints:
            raise ValueError(
                f"Invalid joint ID in {joint_id}. Must be less than {num_joints}"
            )
        return joint_info[ids]

    def _convert_to_dict(
        self, values: np.ndarray, joint_id: list[int] | int | None = None
    ) -> dict[str, Any]:
        """Convert a numpy array of joint values to a dictionary of joint names and values.

        Args:
            values: Joint values as produced by _extract_joint_info for the same
                joint_id (one entry or row per selected joint; for a single int
                joint_id, that joint's scalar or row).
            joint_id: Optional ID(s) of specific joints for the output.

        Returns:
            Dictionary of joint names and values. Scalar-per-joint values are
            converted to Python scalars (float for float arrays, int for integer
            arrays such as error codes); row-per-joint values stay numpy arrays.

        Raises:
            ValueError: If joint names are not available for this component.
        """
        if self._joint_name is None:
            raise ValueError("Joint names not available for this component.")

        values = np.asarray(values)
        if joint_id is None:
            joint_ids = list(range(len(self._joint_name)))
        elif isinstance(joint_id, int):
            joint_ids = [joint_id]
            # A single int ID drops the joint axis in _extract_joint_info
            # (scalar for per-joint values, 1-D row for stacked state); restore
            # it so both branches below see exactly one entry per joint.
            values = values[np.newaxis, ...]
        else:
            joint_ids = list(joint_id)

        if values.ndim == 1:
            return {
                self._joint_name[idx]: value.item()
                for idx, value in zip(joint_ids, values)
            }
        return {self._joint_name[idx]: values[i] for i, idx in enumerate(joint_ids)}

    def _get_joint_index(self, joint_name: list[str] | str) -> list[int] | int:
        """Get the indices of the specified joints.

        Args:
            joint_name: Name(s) of the joints to get indices for.

        Returns:
            List of indices or single index corresponding to the requested joints.

        Raises:
            ValueError: If joint names are not available or if an invalid joint name is provided.
        """
        if self._joint_name is None:
            raise ValueError("Joint names not available for this component.")

        known_names = self._joint_name

        def index_of(name: str) -> int:
            try:
                return known_names.index(name)
            except ValueError:
                # Replace list.index's generic message with an actionable one.
                raise ValueError(
                    f"Invalid joint name: {name}. Available joints: {known_names}"
                ) from None

        if isinstance(joint_name, str):
            return index_of(joint_name)
        return [index_of(name) for name in joint_name]

    def _convert_dict_to_array(
        self, joint_pos_dict: dict[str, float]
    ) -> Float[np.ndarray, " N"]:
        """Convert joint position dictionary to array format.

        Joints that are not named in the dictionary keep their current
        positions.

        Args:
            joint_pos_dict: Dictionary mapping joint names to position values.

        Returns:
            Array of joint positions in the correct order.

        Raises:
            ValueError: If joint_pos_dict contains invalid joint names.
        """
        current_joint_pos = self.get_joint_pos().copy()
        target_joint_indices = self._get_joint_index(list(joint_pos_dict.keys()))
        current_joint_pos[target_joint_indices] = list(joint_pos_dict.values())
        return current_joint_pos

    def set_joint_pos(
        self,
        joint_pos: Float[np.ndarray, " N"] | list[float] | dict[str, float],
        relative: bool = False,
        wait_time: float = 0.0,
        wait_kwargs: dict[str, float] | None = None,
    ) -> None:
        """Send joint position control commands.

        Args:
            joint_pos: Joint positions as either:
                - List (or other sequence) of joint values [j1, j2, ..., jN]
                - Numpy array with shape (N,)
                - Dictionary mapping joint names to position values
            relative: If True, the joint positions are relative to the current position.
            wait_time: Time to wait after sending command in seconds.
            wait_kwargs: Optional parameters for trajectory generation (not used
                by this base implementation).

        Raises:
            ValueError: If joint_pos dictionary contains invalid joint names.
        """
        if relative:
            joint_pos = self._resolve_relative_joint_cmd(joint_pos)

        # Lists, tuples, dicts and other sequences become a float32 array;
        # numpy arrays are passed through unchanged.
        if not isinstance(joint_pos, np.ndarray):
            joint_pos = self._convert_joint_cmd_to_array(joint_pos)

        self._send_position_command(joint_pos)

        if wait_time > 0.0:
            time.sleep(wait_time)

    def _send_position_command(self, joint_pos: Float[np.ndarray, " N"]) -> None:
        """Send joint position command to the component.

        This method should be overridden by child classes to implement
        component-specific command message creation and publishing.

        Args:
            joint_pos: Joint positions as numpy array.

        Raises:
            NotImplementedError: If child class does not implement this method.
        """
        raise NotImplementedError("Child class must implement _send_position_command")

    def go_to_pose(
        self,
        pose_name: str,
        wait_time: float = 3.0,
    ) -> None:
        """Move the component to a predefined pose.

        Args:
            pose_name: Name of the pose to move to.
            wait_time: Time to wait for the component to reach the pose.

        Raises:
            ValueError: If pose pool is not available or if an invalid pose name is provided.
        """
        # Copy so a subclass's _send_position_command cannot mutate the pool.
        pose = self._get_pool_pose(pose_name).copy()
        self.set_joint_pos(pose, wait_time=wait_time)

    def is_joint_pos_reached(
        self,
        joint_pos: np.ndarray | dict[str, float],
        tolerance: float = 0.05,
        joint_id: list[int] | int | None = None,
    ) -> bool:
        """Check if the robot's current joint positions are within a certain tolerance of the target positions.

        Args:
            joint_pos: Target joint positions.
            tolerance: Tolerance for joint position check.
            joint_id: Optional specific joint indices to check.

        Returns:
            True if all specified joint positions are within tolerance, False otherwise.
        """
        if isinstance(joint_pos, dict):
            current_pos = self.get_joint_pos_dict()
            return self._check_dict_positions_reached(
                joint_pos, current_pos, tolerance, joint_id
            )

        current_pos = self.get_joint_pos()
        return self._check_array_positions_reached(
            joint_pos, current_pos, tolerance, joint_id
        )

    def _check_dict_positions_reached(
        self,
        target_pos: dict[str, float],
        current_pos: dict[str, float],
        tolerance: float,
        joint_id: list[int] | int | None,
    ) -> bool:
        """Check if dictionary-based joint positions are reached.

        Indices beyond the joint list, and joints absent from target_pos, are
        skipped (a single out-of-range index counts as reached; a single index
        whose joint is absent from target_pos counts as not reached).

        Args:
            target_pos: Target joint positions as dictionary.
            current_pos: Current joint positions as dictionary.
            tolerance: Tolerance for position check.
            joint_id: Optional specific joint indices to check.

        Returns:
            True if positions are within tolerance, False otherwise.

        Raises:
            ValueError: If joint_id is given but joint names are not available.
        """
        if joint_id is None:
            return all(
                abs(current_pos[name] - pos) <= tolerance
                for name, pos in target_pos.items()
            )

        if self._joint_name is None:
            raise ValueError("Joint names not available for this component")

        if isinstance(joint_id, int):
            if joint_id >= len(self._joint_name):
                return True  # Invalid index, consider it reached
            name = self._joint_name[joint_id]
            return (
                name in target_pos
                and abs(current_pos[name] - target_pos[name]) <= tolerance
            )

        valid_names = [
            self._joint_name[idx]
            for idx in joint_id
            if idx < len(self._joint_name) and self._joint_name[idx] in target_pos
        ]
        return all(
            abs(current_pos[name] - target_pos[name]) <= tolerance
            for name in valid_names
        )

    def _check_array_positions_reached(
        self,
        target_pos: np.ndarray,
        current_pos: np.ndarray,
        tolerance: float,
        joint_id: list[int] | int | None,
    ) -> bool:
        """Check if array-based joint positions are reached.

        Indices outside either array are skipped (a single out-of-range index
        counts as reached). With no joint_id, only the overlapping prefix of
        the two arrays is compared.

        Args:
            target_pos: Target joint positions as numpy array (or sequence).
            current_pos: Current joint positions as numpy array.
            tolerance: Tolerance for position check.
            joint_id: Optional specific joint indices to check.

        Returns:
            True if positions are within tolerance, False otherwise.
        """
        if joint_id is None:
            min_len = min(len(current_pos), len(target_pos))
            return bool(
                np.all(
                    np.abs(current_pos[:min_len] - target_pos[:min_len]) <= tolerance
                )
            )

        if isinstance(joint_id, int):
            if joint_id >= len(current_pos) or joint_id >= len(target_pos):
                return True  # Invalid index, consider it reached
            return bool(abs(current_pos[joint_id] - target_pos[joint_id]) <= tolerance)

        # Index element by element: target_pos may be a sequence type (e.g. an
        # OmegaConf ListConfig) that does not support list-based indexing.
        if len(current_pos) == 0 or len(target_pos) == 0:
            return True

        for idx in joint_id:
            if idx < len(current_pos) and idx < len(target_pos):
                if abs(current_pos[idx] - target_pos[idx]) > tolerance:
                    return False
        return True

    def is_pose_reached(
        self,
        pose_name: str,
        tolerance: float = 0.05,
        joint_id: list[int] | int | None = None,
    ) -> bool:
        """Check if the robot's current joint positions are within a certain tolerance of the target pose.

        Args:
            pose_name: Name of the pose to check against.
            tolerance: Tolerance for joint position check.
            joint_id: Optional specific joint indices to check.

        Returns:
            True if all specified joint positions are within tolerance, False otherwise.

        Raises:
            ValueError: If pose pool is not available or pose name is invalid.
        """
        pose = self._get_pool_pose(pose_name)
        return self.is_joint_pos_reached(pose, tolerance=tolerance, joint_id=joint_id)