robot-keyframe-kit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ """Robot Keyframe Kit - A generalizable Viser-based keyframe editor for MuJoCo robots."""
2
+
3
+ from .config import EditorConfig
4
+ from .editor import ViserKeyframeEditor
5
+ from .keyframe import Keyframe
6
+
7
+ __version__ = "0.1.0"
8
+ __all__ = ["ViserKeyframeEditor", "EditorConfig", "Keyframe"]
9
+
10
+
11
+
@@ -0,0 +1,242 @@
1
+ """Configuration dataclass for the keyframe editor."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from dataclasses import dataclass, field
7
+ from typing import Dict, List, Optional
8
+
9
+ try:
10
+ import yaml
11
+ except ImportError:
12
+ yaml = None
13
+
14
+
15
@dataclass
class EditorConfig:
    """Configuration options for the ViserKeyframeEditor.

    Attributes:
        root_body: Name of the root body used for ground alignment
            (e.g., "torso", "base_link"). If None, it is auto-detected as the
            first named body that is a direct child of the world body.
        end_effector_sites: List of site names for end-effector
            tracking/alignment. If None, auto-detection is attempted.
        mirror_pairs: Dictionary mapping left joint names to right joint names
            for mirroring. If None, auto-detection based on the "left"/"right"
            naming convention is attempted.
        mirror_signs: Dictionary of joint names to sign multipliers for
            mirroring. Positive 1 means same direction, -1 means opposite.
        dt: Timestep for trajectory playback (seconds).
        save_dir: Directory to save keyframe data files.
        name: Optional name for this robot/project (used in save filenames).
        n_frames: Number of physics substeps per control step.
        physics_dt: Physics timestep.
        kp: Position (proportional) gain for PD trajectory playback.
        kd: Velocity (derivative) gain for PD trajectory playback.
        show_com: Whether to show the center-of-mass marker.
        show_grid: Whether to show the ground grid.
    """

    root_body: Optional[str] = None
    end_effector_sites: Optional[List[str]] = None
    mirror_pairs: Optional[Dict[str, str]] = None
    mirror_signs: Optional[Dict[str, int]] = None
    dt: float = 0.02
    save_dir: str = "keyframes"
    name: str = "robot"

    # Physics simulation settings
    n_frames: int = 20  # Number of physics substeps per control step
    physics_dt: float = 0.001  # Physics timestep

    # PD control gains for trajectory playback
    # Higher values = stiffer joints. Adjust based on robot's actuator setup.
    kp: float = 50.0  # Position gain (proportional)
    kd: float = 2.0  # Velocity gain (derivative)

    # UI settings
    show_com: bool = True  # Show center of mass marker
    show_grid: bool = True  # Show ground grid

    @classmethod
    def from_yaml(cls, path: str) -> "EditorConfig":
        """Load configuration from a YAML file.

        Top-level keys map directly onto fields. ``end_effectors`` is accepted
        as an alias for ``end_effector_sites``. Values found in the nested
        ``physics`` and ``ui`` mappings override top-level values of the same
        name. Unknown keys are ignored.

        Args:
            path: Path to the YAML configuration file.

        Returns:
            EditorConfig instance loaded from the file.

        Raises:
            ImportError: If PyYAML is not installed.
            FileNotFoundError: If the config file doesn't exist.
            ValueError: If the top level of the YAML document is not a mapping.
        """
        if yaml is None:
            raise ImportError(
                "PyYAML is required to load YAML config files. "
                "Install with: pip install pyyaml"
            )

        if not os.path.exists(path):
            raise FileNotFoundError(f"Config file not found: {path}")

        # Explicit UTF-8 so the result does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty file parses to None; treat it as "use all defaults".
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Config file must contain a YAML mapping at the top level: {path}"
            )

        direct_keys = (
            "name", "root_body", "dt", "save_dir", "kp", "kd",
            "n_frames", "physics_dt", "show_com", "show_grid",
            "mirror_pairs", "mirror_signs",
        )
        mapped_data = {key: data[key] for key in direct_keys if key in data}

        # "end_effectors" is the key written by to_yaml and takes precedence
        # over the field-named spelling when both are present.
        if "end_effectors" in data:
            mapped_data["end_effector_sites"] = data["end_effectors"]
        elif "end_effector_sites" in data:
            mapped_data["end_effector_sites"] = data["end_effector_sites"]

        # Nested sections override the top-level values gathered above.
        nested_sections = {
            "physics": ("kp", "kd", "dt", "n_frames", "physics_dt"),
            "ui": ("show_com", "show_grid"),
        }
        for section_name, section_keys in nested_sections.items():
            section = data.get(section_name, {})
            if not isinstance(section, dict):
                continue
            for key in section_keys:
                if key in section:
                    mapped_data[key] = section[key]

        return cls(**mapped_data)

    def to_yaml(self, path: str) -> None:
        """Save configuration to a YAML file.

        All fields are written so that ``from_yaml`` restores an equivalent
        configuration. Top-level entries whose value is None are omitted.

        Args:
            path: Path where to save the YAML configuration file.

        Raises:
            ImportError: If PyYAML is not installed.
        """
        if yaml is None:
            raise ImportError(
                "PyYAML is required to save YAML config files. "
                "Install with: pip install pyyaml"
            )

        data = {
            "name": self.name,
            "root_body": self.root_body,
            "end_effectors": self.end_effector_sites,
            "mirror_pairs": self.mirror_pairs,
            "mirror_signs": self.mirror_signs,
            "dt": self.dt,
            "save_dir": self.save_dir,
            "physics": {
                "kp": self.kp,
                "kd": self.kd,
                "dt": self.dt,
                # Previously omitted, which reset them to defaults on reload.
                "n_frames": self.n_frames,
                "physics_dt": self.physics_dt,
            },
            "ui": {
                "show_com": self.show_com,
                "show_grid": self.show_grid,
            },
        }

        # Remove None values
        data = {k: v for k, v in data.items() if v is not None}

        # A bare filename has no directory component; fall back to the cwd.
        target_dir = os.path.dirname(path) or "."
        os.makedirs(target_dir, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    @staticmethod
    def _detect_mirror_pairs(joint_names: List[str]) -> Dict[str, str]:
        """Pair left-side joint names with their right-side counterparts.

        A joint is considered left-side when its lower-cased name contains
        "left" or "_l_". The "left" -> "right" spelling is tried first, then
        "_l_" -> "_r_". A candidate is accepted only if it differs from the
        original name and exists in ``joint_names``, so a joint is never
        paired with itself.

        Args:
            joint_names: All joint names of the model.

        Returns:
            Mapping from left joint name to right joint name.
        """
        known = set(joint_names)
        pairs: Dict[str, str] = {}
        for joint_name in joint_names:
            lowered = joint_name.lower()
            if "left" not in lowered and "_l_" not in lowered:
                continue
            candidates = (
                joint_name.replace("left", "right")
                .replace("Left", "Right")
                .replace("LEFT", "RIGHT"),
                joint_name.replace("_l_", "_r_").replace("_L_", "_R_"),
            )
            for candidate in candidates:
                # An unchanged candidate means that naming rule did not apply;
                # accepting it would map the joint onto itself.
                if candidate != joint_name and candidate in known:
                    pairs[joint_name] = candidate
                    break
        return pairs

    @classmethod
    def generate_from_model(cls, xml_path: str, name: Optional[str] = None) -> "EditorConfig":
        """Generate a configuration from a MuJoCo model by auto-detecting settings.

        This creates a config with auto-detected values that can be manually edited.

        Args:
            xml_path: Path to the MuJoCo XML file.
            name: Optional name for the robot. If None, inferred from XML filename.

        Returns:
            EditorConfig instance with auto-detected values.
        """
        # Imported lazily so the module can be used without MuJoCo installed.
        import mujoco

        model = mujoco.MjModel.from_xml_path(xml_path)

        # Infer name from XML path if not provided
        if name is None:
            name = os.path.splitext(os.path.basename(xml_path))[0]

        # Auto-detect root body: first named direct child of the world body
        # (body id 0 is the world itself and is skipped).
        root_body = None
        for body_id in range(1, model.nbody):
            if model.body_parentid[body_id] != 0:
                continue
            root_body = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_BODY, body_id)
            if root_body and root_body != "world":
                break

        # Leaf bodies are those that never appear as another body's parent.
        parent_ids = set(model.body_parentid)
        leaf_body_ids = [bid for bid in range(model.nbody) if bid not in parent_ids]

        ee_keywords = ["foot", "hand", "calf", "leg", "lleg", "ankle", "toe", "gripper"]

        # First try named sites attached to leaf bodies.
        end_effectors = []
        for body_id in leaf_body_ids:
            for site_id in range(model.nsite):
                if model.site_bodyid[site_id] != body_id:
                    continue
                site_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_SITE, site_id)
                if site_name:
                    end_effectors.append(site_name)

        # Fallback: leaf bodies whose name contains an end-effector keyword.
        if not end_effectors:
            for body_id in leaf_body_ids:
                body_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_BODY, body_id)
                if not body_name or body_name == "world":
                    continue
                body_lower = body_name.lower()
                if any(kw in body_lower for kw in ee_keywords):
                    end_effectors.append(body_name)

        # Auto-detect mirror pairs (left/right joints) among the named joints.
        joint_names = []
        for jnt_id in range(model.njnt):
            jnt_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, jnt_id)
            if jnt_name:
                joint_names.append(jnt_name)
        mirror_pairs = cls._detect_mirror_pairs(joint_names)

        return cls(
            name=name,
            root_body=root_body,
            end_effector_sites=end_effectors if end_effectors else None,
            mirror_pairs=mirror_pairs if mirror_pairs else None,
            mirror_signs=None,  # Will be auto-computed by editor
        )
241
+
242
+