mujoco-lidar 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
1
+ # Lazy import to avoid loading dependencies when not needed
2
+ # Import the wrapper class directly
3
+ from mujoco_lidar.lidar_wrapper import MjLidarWrapper
4
+ from typing import Any
5
+
6
__version__ = "0.2.5"

# Scan-generation names that ``__getattr__`` resolves lazily from the
# ``mujoco_lidar.scan_gen`` submodule. Kept in one place so ``__all__`` and
# the lazy lookup below cannot drift apart.
_SCAN_GEN_NAMES = (
    "LivoxGenerator",
    "generate_grid_scan_pattern",
    "create_lidar_single_line",
    "generate_HDL64",
    "generate_vlp32",
    "generate_os128",
    "generate_airy96",
)

__all__ = [
    "MjLidarWrapper",
    # Resolved lazily from mujoco_lidar.scan_gen_livox_ti (requires taichi).
    "LivoxGeneratorTi",
    # Resolved lazily from mujoco_lidar.scan_gen.
    *_SCAN_GEN_NAMES,
]


def __getattr__(name: str) -> Any:
    """Lazily import scan-generation helpers on first attribute access (PEP 562).

    ``LivoxGeneratorTi`` comes from ``mujoco_lidar.scan_gen_livox_ti``; every
    name in ``_SCAN_GEN_NAMES`` comes from ``mujoco_lidar.scan_gen``. The
    resolved object is stored in the module namespace, so later lookups find
    it directly and no longer go through this hook.

    Raises:
        AttributeError: If ``name`` is not one of the lazily provided names.
    """
    if name == "LivoxGeneratorTi":
        from mujoco_lidar.scan_gen_livox_ti import LivoxGeneratorTi as value
    elif name in _SCAN_GEN_NAMES:
        from mujoco_lidar import scan_gen
        value = getattr(scan_gen, name)
    else:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    # Cache so subsequent accesses bypass __getattr__ entirely.
    globals()[name] = value
    return value


def __dir__() -> list:
    """Include the lazily resolved public names in ``dir(mujoco_lidar)``."""
    return sorted(set(globals()) | set(__all__))
@@ -0,0 +1,240 @@
1
+ import mujoco
2
+ import numpy as np
3
+ from typing import Optional, Dict, Any, Tuple, Union
4
+
5
class MjLidarWrapper:
    """
    MuJoCo LiDAR wrapper that supports CPU, Taichi, and JAX backends.

    Args:
        mj_model (mujoco.MjModel): MuJoCo model object
        site_name (str): Name of the LiDAR site in the MuJoCo model
        backend (str): Computation backend, 'cpu', 'taichi', or 'jax'.
            'gpu' is accepted as an alias for 'taichi'. Default: 'taichi'
        cutoff_dist (float): Maximum ray tracing distance in meters. Passed to
            the CPU and Taichi backends; it is not forwarded to the JAX
            backend. Default: 100.0
        args (dict | None): Additional backend-specific arguments.
            Default: None (treated as an empty dict)

    Arguments read by every backend:
        geomgroup (np.ndarray | None): Geometry group filter. Default: None
            - None: Detect all geometries (skip group exclusion)
            - Otherwise an array of length mjNGROUP, where 1 means the group
              should be included.
        bodyexclude (int): Body ID to exclude from detection. Default: -1
            - -1: Don't exclude any body
            - >= 0: Exclude all geometries of the specified body

    Taichi Backend Arguments:
        max_candidates (int): Maximum number of BVH candidate nodes. Default: 64
            - Larger values: More accurate but slower
            - Smaller values: Faster but may miss collisions
            - Recommended: 16-32 (simple), 32-64 (medium), 64-128 (complex)
        ti_init_args (dict): Keyword arguments passed to taichi.init()
            (alongside arch=ti.gpu). Default: {}
            - device_memory_GB (float): GPU memory limit in GB
            - debug (bool): Enable debug mode
            - log_level (str): 'trace', 'debug', 'info', 'warn', 'error'

    JAX Backend Arguments:
        geom_ids (list | None): List of geometry IDs to include. Default: None (all)

    Raises:
        ValueError: If ``backend`` is not a supported backend name.
        ImportError: If the selected backend's dependencies cannot be imported.

    Examples:
        >>> # CPU backend with body exclusion
        >>> lidar = MjLidarWrapper(
        ...     mj_model=model,
        ...     site_name="lidar_site",
        ...     backend="cpu",
        ...     cutoff_dist=50.0,
        ...     args={
        ...         'bodyexclude': robot_body_id,
        ...         'geomgroup': np.array([1, 1, 1, 0, 0, 0], dtype=np.uint8)
        ...     }
        ... )

        >>> # GPU backend for complex scenes
        >>> lidar = MjLidarWrapper(
        ...     mj_model=model,
        ...     site_name="lidar_site",
        ...     backend="gpu",
        ...     cutoff_dist=100.0,
        ...     args={
        ...         'bodyexclude': robot_body_id,
        ...         'geomgroup': np.array([1, 1, 1, 0, 0, 0], dtype=np.uint8),
        ...         'max_candidates': 64,
        ...         'ti_init_args': {'device_memory_GB': 4.0}
        ...     }
        ... )
    """

    def __init__(self, mj_model: "mujoco.MjModel", site_name: str,
                 backend: str = "taichi", cutoff_dist: float = 100.0,
                 args: Optional[Dict[str, Any]] = None):
        # 'gpu' is an alias for the Taichi backend.
        if backend == "gpu":
            backend = "taichi"
        self.backend = backend
        self.mj_model = mj_model
        self.cutoff_dist = cutoff_dist
        # None instead of a `{}` default: a literal default dict would be one
        # object shared by every instance constructed without `args`.
        self.args: Dict[str, Any] = {} if args is None else args

        if backend == "taichi":
            self._init_taichi_backend()
        elif backend == "jax":
            self._init_jax_backend()
        elif backend == "cpu":
            self._init_cpu_backend()
        else:
            raise ValueError(f"Unsupported backend: {backend}, choose from 'cpu', 'taichi', or 'jax'")

        self.site_name = site_name
        # Homogeneous 4x4 world pose of the LiDAR site (rotation in [:3, :3],
        # translation in [:3, 3]); refreshed by update_sensor_pose().
        self._sensor_pose = np.eye(4, dtype=np.float32)
        # Only populated by the JAX backend's trace_rays().
        self._local_rays: Optional[np.ndarray] = None
        self._distances: Optional[np.ndarray] = None

    def _init_taichi_backend(self) -> None:
        """Import Taichi lazily, initialize it if needed, and build MjLidarTi."""
        # Keep the try narrow: only genuine import failures should be reported
        # as missing dependencies.
        try:
            from mujoco_lidar.core_ti.mjlidar_ti import MjLidarTi
            import taichi as ti
        except ImportError as e:
            raise ImportError(
                "Failed to import Taichi backend dependencies. "
                "Please install taichi: pip install taichi\n"
                f"Error: {e}"
            ) from e

        # NOTE(review): `_is_initialized` does not appear to be a public taichi
        # attribute; when it is absent this guard always calls ti.init(), which
        # re-initializes the Taichi runtime on every construction. Confirm
        # against the installed taichi version.
        if not getattr(ti, "_is_initialized", False):
            ti.init(arch=ti.gpu, **self.args.get('ti_init_args', {}))

        self._backend_instance: Any = MjLidarTi(
            self.mj_model,
            cutoff_dist=self.cutoff_dist,
            geomgroup=self.args.get('geomgroup', None),
            bodyexclude=self.args.get('bodyexclude', -1),
            max_candidates=self.args.get('max_candidates', 64),
        )

    def _init_jax_backend(self) -> None:
        """Import and build the JAX backend (MjLidarJax)."""
        try:
            from mujoco_lidar.core_jax.mjlidar_jax import MjLidarJax
        except ImportError as e:
            raise ImportError(
                "Failed to import JAX backend dependencies.\n"
                f"Error: {e}"
            ) from e

        # Pass mj_model directly; MjLidarJax extracts what it needs.
        # NOTE(review): unlike the CPU/Taichi backends, cutoff_dist is not
        # forwarded here — confirm whether MjLidarJax should receive it.
        self._backend_instance: Any = MjLidarJax(
            self.mj_model,
            geom_ids=self.args.get('geom_ids'),
            geomgroup=self.args.get('geomgroup', None),
            bodyexclude=self.args.get('bodyexclude', -1),
        )

    def _init_cpu_backend(self) -> None:
        """Import and build the CPU backend (MjLidarCPU)."""
        try:
            from mujoco_lidar.core_cpu.mjlidar_cpu import MjLidarCPU
        except ImportError as e:
            raise ImportError(
                "Failed to import CPU backend dependencies.\n"
                f"Error: {e}"
            ) from e

        self._backend_instance: Any = MjLidarCPU(
            self.mj_model,
            cutoff_dist=self.cutoff_dist,
            geomgroup=self.args.get('geomgroup', None),
            bodyexclude=self.args.get('bodyexclude', -1),
        )

    @property
    def sensor_position(self) -> np.ndarray:
        """Copy of the sensor's world position (3,) from the last pose update."""
        return self._sensor_pose[:3, 3].copy()

    @property
    def sensor_rotation(self) -> np.ndarray:
        """Copy of the sensor's world rotation matrix (3, 3) from the last pose update."""
        return self._sensor_pose[:3, :3].copy()

    def update_sensor_pose(self, mj_data: "mujoco.MjData",
                           site_name: Optional[str] = None) -> None:
        """Copy a site's world pose from ``mj_data`` into the internal 4x4 pose.

        Args:
            mj_data: MuJoCo data whose site kinematics (xpos/xmat) are read.
            site_name: Site to read. Defaults to the site given at construction.
        """
        if site_name is None:
            site_name = self.site_name
        if self.backend in ("cpu", "taichi", "jax"):
            site = mj_data.site(site_name)
            self._sensor_pose[:3, :3] = site.xmat.reshape(3, 3)
            self._sensor_pose[:3, 3] = site.xpos

    @staticmethod
    def _to_taichi_ndarray(ti: Any, values: Any) -> Any:
        """Convert a NumPy array to a float32 Taichi ndarray sized by its first
        dimension; any other object is returned unchanged."""
        if not isinstance(values, np.ndarray):
            return values
        converted = ti.ndarray(dtype=ti.f32, shape=values.shape[0])
        converted.from_numpy(values.astype(np.float32))
        return converted

    def trace_rays(self, mj_data: "mujoco.MjData", ray_theta: np.ndarray,
                   ray_phi: np.ndarray, site_name: Optional[str] = None) -> np.ndarray:
        """
        Trace rays from the LiDAR site and return the hit distances.

        The sensor pose is refreshed from ``mj_data`` first. For the Taichi
        backend, NumPy angle arrays are converted to float32 Taichi ndarrays;
        other inputs are passed to the backend unchanged.

        Args:
            mj_data: MuJoCo data providing the current site and geom poses.
            ray_theta: Per-ray angle array handed to the backend.
            ray_phi: Per-ray angle array handed to the backend.
            site_name: Site to trace from. Defaults to the construction site.

        Returns:
            The distances produced by the active backend.
        """
        target_site = self.site_name if site_name is None else site_name
        self.update_sensor_pose(mj_data, target_site)

        if self.backend == "jax":
            # One JIT call handles ray generation, transformation and rendering.
            site = mj_data.site(target_site)
            self._distances, self._local_rays = self._backend_instance.trace_rays(
                mj_data.geom_xpos,
                mj_data.geom_xmat,
                site.xpos,
                site.xmat.reshape(3, 3),
                ray_theta,
                ray_phi,
            )
            return self._distances

        self._backend_instance.update(mj_data)
        if self.backend == "taichi":
            import taichi as ti
            theta = self._to_taichi_ndarray(ti, ray_theta)
            phi = self._to_taichi_ndarray(ti, ray_phi)
            self._backend_instance.trace_rays(self._sensor_pose, theta, phi)
        else:
            self._backend_instance.trace_rays(self._sensor_pose, ray_theta, ray_phi)
        return self._backend_instance.get_distances()

    def get_hit_points(self) -> np.ndarray:
        """Return hit points from the last trace.

        For JAX these are ``distance * local_ray`` per ray (an empty (0, 3)
        array before the first trace); other backends delegate to their own
        ``get_hit_points``.
        """
        if self.backend == "jax":
            if self._distances is None or self._local_rays is None:
                return np.zeros((0, 3), dtype=np.float32)
            return np.asarray(self._distances[:, np.newaxis] * self._local_rays)
        return self._backend_instance.get_hit_points()

    def get_distances(self) -> np.ndarray:
        """Return distances from the last trace (empty array for JAX before any trace)."""
        if self.backend == "jax":
            if self._distances is None:
                return np.zeros(0, dtype=np.float32)
            return np.asarray(self._distances)
        return self._backend_instance.get_distances()