mujoco-react 6.0.1 → 7.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,242 @@
1
+ /**
2
+ * @license
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ */
5
+
6
+ import { useCallback, useEffect, useMemo, useRef } from 'react';
7
+ import { useFrame } from '@react-three/fiber';
8
+ import * as THREE from 'three';
9
+ import { createControllerHook } from '../core/createController';
10
+ import { useMujocoContext, useBeforePhysicsStep } from '../core/MujocoSimProvider';
11
+ import { GenericIK } from '../core/GenericIK';
12
+ import { findSiteByName } from '../core/SceneLoader';
13
+ import type { IkConfig, IkContextValue, IKSolveFn, MujocoData } from '../types';
14
+
15
+ // Scratch Matrix4 reused by syncGizmoToSite so each call avoids allocating one.
16
+ const _syncMat4 = new THREE.Matrix4();
17
+
18
/**
 * Copy a MuJoCo site's current world pose onto a three.js group.
 *
 * Position comes from the site's 3-float slot in `data.site_xpos`. Orientation
 * comes from its 9-float slot in `data.site_xmat`. Those nine values are handed
 * to `Matrix4.set` (which takes row-major arguments) in storage order, and the
 * resulting rotation is converted into `target.quaternion`.
 *
 * @param data   Simulation data holding `site_xpos` / `site_xmat`.
 * @param siteId Index of the site; -1 means "no resolved site" and is a no-op.
 * @param target Group whose `position` and `quaternion` are overwritten.
 */
function syncGizmoToSite(data: MujocoData, siteId: number, target: THREE.Group) {
  if (siteId !== -1) {
    const posBase = siteId * 3;
    const rotBase = siteId * 9;
    const p = data.site_xpos.subarray(posBase, posBase + 3);
    const r = data.site_xmat.subarray(rotBase, rotBase + 9);

    target.position.set(p[0], p[1], p[2]);

    // Embed the 3x3 rotation in the upper-left of a homogeneous 4x4.
    _syncMat4.set(
      r[0], r[1], r[2], 0,
      r[3], r[4], r[5], 0,
      r[6], r[7], r[8], 0,
      0, 0, 0, 1,
    );
    target.quaternion.setFromRotationMatrix(_syncMat4);
  }
}
31
+
32
/**
 * Inverse-kinematics controller hook.
 *
 * Owns an IK target group (the "gizmo") and resolves `config.siteName` to a
 * MuJoCo site id once the model is ready. While IK is enabled, it solves for
 * joint targets before every physics step and writes them into the first
 * `config.numJoints` entries of `data.ctrl`. The solver is either the
 * user-supplied `config.ikSolveFn` or the built-in `GenericIK`.
 *
 * Returns the controller handle, or `null` when no config is provided.
 */
export const useIkController = createControllerHook<IkConfig, IkContextValue>(
  { name: 'useIkController', defaultConfig: { damping: 0.01, maxIterations: 50 } },
  function useIkControllerImpl(config) {
    const { mjModelRef, mjDataRef, mujocoRef, resetCallbacks, status } =
      useMujocoContext();

    // Primitive views of `config`, used as effect dependencies. Depending on
    // the config object itself would re-run the effects whenever its identity
    // changes (e.g. an inline literal), and the site-resolution effect would
    // then snap the gizmo back onto the site, discarding drags / moveTarget.
    const siteName = config?.siteName;
    const hasConfig = config != null;

    // All IK state lives here
    const ikEnabledRef = useRef(false);
    const ikCalculatingRef = useRef(false);
    // Create the target group once instead of allocating (and discarding) a
    // new THREE.Group on every render via `useRef(new THREE.Group())`.
    const initialTarget = useMemo(() => new THREE.Group(), []);
    const ikTargetRef = useRef<THREE.Group>(initialTarget);
    const siteIdRef = useRef(-1);
    // Built lazily on the first built-in solve (see ikSolveFn). This avoids
    // constructing a solver on every render and ensures the MuJoCo module is
    // loaded by the time the solver captures it.
    const genericIkRef = useRef<GenericIK | null>(null);
    const needsInitialSync = useRef(true);

    // State for the eased gizmo animation driven by moveTarget(duration > 0).
    const gizmoAnimRef = useRef({
      active: false,
      startPos: new THREE.Vector3(),
      endPos: new THREE.Vector3(),
      startRot: new THREE.Quaternion(),
      endRot: new THREE.Quaternion(),
      startTime: 0,
      duration: 1000,
    });

    // Resolve the site id when the model becomes ready or the site name changes,
    // then snap the gizmo onto the site.
    useEffect(() => {
      const model = mjModelRef.current;
      if (siteName === undefined || !model || status !== 'ready') {
        siteIdRef.current = -1;
        return;
      }
      siteIdRef.current = findSiteByName(model, siteName);
      const data = mjDataRef.current;
      if (data && ikTargetRef.current) {
        syncGizmoToSite(data, siteIdRef.current, ikTargetRef.current);
      }
    }, [siteName, status, mjModelRef, mjDataRef]);

    // IK solve function: the custom solver if configured, otherwise GenericIK.
    const ikSolveFn = useCallback(
      (pos: THREE.Vector3, quat: THREE.Quaternion, currentQ: number[]): number[] | null => {
        if (!config) return null;
        if (config.ikSolveFn) return config.ikSolveFn(pos, quat, currentQ);
        const model = mjModelRef.current;
        const data = mjDataRef.current;
        if (!model || !data || siteIdRef.current === -1) return null;
        let solver = genericIkRef.current;
        if (solver === null) {
          solver = new GenericIK(mujocoRef.current);
          genericIkRef.current = solver;
        }
        return solver.solve(
          model, data, siteIdRef.current, config.numJoints,
          pos, quat, currentQ,
          { damping: config.damping, maxIterations: config.maxIterations },
        );
      },
      [config, mjModelRef, mjDataRef, mujocoRef],
    );
    // Stable indirection so the physics callback and solveIK always use the
    // latest solver without being re-created.
    const ikSolveFnRef = useRef<IKSolveFn>(ikSolveFn);
    ikSolveFnRef.current = ikSolveFn;

    // Gizmo animation + one-time initial sync
    useFrame(() => {
      if (!config) return;

      if (needsInitialSync.current && siteIdRef.current !== -1) {
        const data = mjDataRef.current;
        if (data && ikTargetRef.current) {
          syncGizmoToSite(data, siteIdRef.current, ikTargetRef.current);
          needsInitialSync.current = false;
        }
      }

      const ga = gizmoAnimRef.current;
      const target = ikTargetRef.current;
      if (!ga.active || !target) return;

      const now = performance.now();
      const elapsed = now - ga.startTime;
      const t = Math.min(elapsed / ga.duration, 1.0);
      const ease = 1 - Math.pow(1 - t, 3); // cubic ease-out
      target.position.lerpVectors(ga.startPos, ga.endPos, ease);
      target.quaternion.slerpQuaternions(ga.startRot, ga.endRot, ease);
      if (t >= 1.0) ga.active = false;
    });

    // IK solve in physics loop
    useBeforePhysicsStep((_model, data) => {
      if (!config || !ikEnabledRef.current) {
        ikCalculatingRef.current = false;
        return;
      }
      const target = ikTargetRef.current;
      if (!target) return;

      ikCalculatingRef.current = true;
      const numJoints = config.numJoints;
      const currentQ: number[] = [];
      for (let i = 0; i < numJoints; i++) currentQ.push(data.qpos[i]);
      const solution = ikSolveFnRef.current(target.position, target.quaternion, currentQ);
      if (solution) {
        // A custom ikSolveFn may return fewer values than numJoints; never
        // write `undefined` (which becomes NaN) into ctrl for missing entries.
        const count = Math.min(numJoints, solution.length);
        for (let i = 0; i < count; i++) data.ctrl[i] = solution[i];
      }
    });

    // On simulation reset: snap the gizmo to the site, stop any animation,
    // disable IK and request a fresh initial sync.
    useEffect(() => {
      if (!hasConfig) return;
      const cb = () => {
        const data = mjDataRef.current;
        if (data && ikTargetRef.current) {
          syncGizmoToSite(data, siteIdRef.current, ikTargetRef.current);
        }
        gizmoAnimRef.current.active = false;
        ikEnabledRef.current = false;
        needsInitialSync.current = true;
      };
      resetCallbacks.current.add(cb);
      return () => { resetCallbacks.current.delete(cb); };
    }, [resetCallbacks, mjDataRef, hasConfig]);

    // --- API methods ---

    // Enabling IK snaps the gizmo onto the site (unless an animation is
    // running) so the arm does not jump to a stale target.
    const setIkEnabled = useCallback(
      (enabled: boolean) => {
        ikEnabledRef.current = enabled;
        const data = mjDataRef.current;
        if (enabled && data && !gizmoAnimRef.current.active && ikTargetRef.current) {
          syncGizmoToSite(data, siteIdRef.current, ikTargetRef.current);
        }
      },
      [mjDataRef],
    );

    const syncTargetToSiteApi = useCallback(() => {
      const data = mjDataRef.current;
      const target = ikTargetRef.current;
      if (data && target) syncGizmoToSite(data, siteIdRef.current, target);
    }, [mjDataRef]);

    const solveIK = useCallback(
      (pos: THREE.Vector3, quat: THREE.Quaternion, currentQ: number[]): number[] | null => {
        return ikSolveFnRef.current(pos, quat, currentQ);
      },
      [],
    );

    // Move the target to `pos` with a fixed Euler(PI, 0, 0) orientation.
    // `duration` is in milliseconds; 0 jumps immediately.
    const moveTarget = useCallback(
      (pos: THREE.Vector3, duration = 0) => {
        if (!ikEnabledRef.current) setIkEnabled(true);
        const target = ikTargetRef.current;
        if (!target) return;

        const targetPos = pos.clone();
        const targetRot = new THREE.Quaternion().setFromEuler(
          new THREE.Euler(Math.PI, 0, 0),
        );

        if (duration > 0) {
          const ga = gizmoAnimRef.current;
          ga.active = true;
          ga.startPos.copy(target.position);
          ga.endPos.copy(targetPos);
          ga.startRot.copy(target.quaternion);
          ga.endRot.copy(targetRot);
          ga.startTime = performance.now();
          ga.duration = duration;
        } else {
          gizmoAnimRef.current.active = false;
          target.position.copy(targetPos);
          target.quaternion.copy(targetRot);
        }
      },
      [setIkEnabled],
    );

    // Snapshot of the target pose; null while IK is not being calculated.
    const getGizmoStats = useCallback(
      (): { pos: THREE.Vector3; rot: THREE.Euler } | null => {
        const target = ikTargetRef.current;
        if (!ikCalculatingRef.current || !target) return null;
        return {
          pos: target.position.clone(),
          rot: new THREE.Euler().setFromQuaternion(target.quaternion),
        };
      },
      [],
    );

    const contextValue = useMemo<IkContextValue>(
      () => ({
        ikEnabledRef,
        ikCalculatingRef,
        ikTargetRef,
        siteIdRef,
        setIkEnabled,
        moveTarget,
        syncTargetToSite: syncTargetToSiteApi,
        solveIK,
        getGizmoStats,
      }),
      [setIkEnabled, moveTarget, syncTargetToSiteApi, solveIK, getGizmoStats],
    );

    if (!config) return null;

    return contextValue;
  },
);
@@ -6,7 +6,7 @@
6
6
  */
7
7
 
8
8
  import { useEffect, useRef } from 'react';
9
- import { useMujoco, useAfterPhysicsStep } from '../core/MujocoSimProvider';
9
+ import { useMujocoContext, useAfterPhysicsStep } from '../core/MujocoSimProvider';
10
10
  import { getName } from '../core/SceneLoader';
11
11
  import type { JointStateResult } from '../types';
12
12
 
@@ -19,7 +19,7 @@ import type { JointStateResult } from '../types';
19
19
  * For free joints, position is pos+quat (7), velocity is lin+ang vel (6).
20
20
  */
21
21
  export function useJointState(name: string): JointStateResult {
22
- const { mjModelRef, mjDataRef, status } = useMujoco();
22
+ const { mjModelRef, mjDataRef, status } = useMujocoContext();
23
23
  const jointIdRef = useRef(-1);
24
24
  const qposAdrRef = useRef(0);
25
25
  const dofAdrRef = useRef(0);
@@ -6,7 +6,7 @@
6
6
  */
7
7
 
8
8
  import { useEffect, useRef } from 'react';
9
- import { useMujoco, useBeforePhysicsStep } from '../core/MujocoSimProvider';
9
+ import { useMujocoContext, useBeforePhysicsStep } from '../core/MujocoSimProvider';
10
10
  import { findActuatorByName } from '../core/SceneLoader';
11
11
  import type { KeyboardTeleopConfig } from '../types';
12
12
 
@@ -19,7 +19,7 @@ import type { KeyboardTeleopConfig } from '../types';
19
19
  * - `set`: Set actuator to a fixed value while key is held
20
20
  */
21
21
  export function useKeyboardTeleop(config: KeyboardTeleopConfig) {
22
- const { mjModelRef, mjDataRef, status } = useMujoco();
22
+ const { mjModelRef, mjDataRef, status } = useMujocoContext();
23
23
  const pressedRef = useRef(new Set<string>());
24
24
  const toggleStateRef = useRef(new Map<string, boolean>());
25
25
  const enabledRef = useRef(config.enabled ?? true);
@@ -6,7 +6,7 @@
6
6
  */
7
7
 
8
8
  import { useRef } from 'react';
9
- import { useMujoco, useBeforePhysicsStep } from '../core/MujocoSimProvider';
9
+ import { useMujocoContext, useBeforePhysicsStep } from '../core/MujocoSimProvider';
10
10
  import type { PolicyConfig } from '../types';
11
11
 
12
12
  /**
@@ -20,7 +20,7 @@ import type { PolicyConfig } from '../types';
20
20
  * @returns { step, isRunning } control handles
21
21
  */
22
22
  export function usePolicy(config: PolicyConfig) {
23
- const { mjModelRef } = useMujoco();
23
+ const { mjModelRef } = useMujocoContext();
24
24
  const lastActionTimeRef = useRef(0);
25
25
  const lastActionRef = useRef<Float32Array | Float64Array | number[] | null>(null);
26
26
  const isRunningRef = useRef(true);
@@ -10,10 +10,10 @@
10
10
  import { useEffect, useRef } from 'react';
11
11
  import * as THREE from 'three';
12
12
  import { useThree } from '@react-three/fiber';
13
- import { useMujoco } from '../core/MujocoSimProvider';
13
+ import { useMujocoContext } from '../core/MujocoSimProvider';
14
14
 
15
15
  export function useSceneLights(intensity = 1.0) {
16
- const { mjModelRef, status } = useMujoco();
16
+ const { mjModelRef, status } = useMujocoContext();
17
17
  const { scene } = useThree();
18
18
  const lightsRef = useRef<THREE.Light[]>([]);
19
19
  const targetsRef = useRef<THREE.Object3D[]>([]);
@@ -6,7 +6,7 @@
6
6
  */
7
7
 
8
8
  import { useEffect, useRef, useMemo } from 'react';
9
- import { useMujoco, useAfterPhysicsStep } from '../core/MujocoSimProvider';
9
+ import { useMujocoContext, useAfterPhysicsStep } from '../core/MujocoSimProvider';
10
10
  import { getName } from '../core/SceneLoader';
11
11
  import type { SensorInfo, SensorResult } from '../types';
12
12
 
@@ -15,7 +15,7 @@ import type { SensorInfo, SensorResult } from '../types';
15
15
  * updated every physics frame without causing React re-renders.
16
16
  */
17
17
  export function useSensor(name: string): SensorResult {
18
- const { mjModelRef, mjDataRef, status } = useMujoco();
18
+ const { mjModelRef, mjDataRef, status } = useMujocoContext();
19
19
  const sensorIdRef = useRef(-1);
20
20
  const sensorAdrRef = useRef(0);
21
21
  const sensorDimRef = useRef(0);
@@ -55,7 +55,7 @@ export function useSensor(name: string): SensorResult {
55
55
  * Returns a stable array recomputed only when the model changes.
56
56
  */
57
57
  export function useSensors(): SensorInfo[] {
58
- const { mjModelRef, status } = useMujoco();
58
+ const { mjModelRef, status } = useMujocoContext();
59
59
 
60
60
  return useMemo(() => {
61
61
  const model = mjModelRef.current;
@@ -6,7 +6,7 @@
6
6
  import { useEffect, useRef } from 'react';
7
7
  import { useFrame } from '@react-three/fiber';
8
8
  import * as THREE from 'three';
9
- import { useMujoco } from '../core/MujocoSimProvider';
9
+ import { useMujocoContext } from '../core/MujocoSimProvider';
10
10
  import { findSiteByName } from '../core/SceneLoader';
11
11
  import type { SitePositionResult } from '../types';
12
12
 
@@ -18,7 +18,7 @@ const _mat4 = new THREE.Matrix4();
18
18
  * Refs are updated every frame without triggering React re-renders.
19
19
  */
20
20
  export function useSitePosition(siteName: string): SitePositionResult {
21
- const { mjModelRef, mjDataRef, status } = useMujoco();
21
+ const { mjModelRef, mjDataRef, status } = useMujocoContext();
22
22
  const siteIdRef = useRef(-1);
23
23
  const positionRef = useRef(new THREE.Vector3());
24
24
  const quaternionRef = useRef(new THREE.Quaternion());
@@ -7,7 +7,7 @@
7
7
 
8
8
  import { useCallback, useRef } from 'react';
9
9
  import { useFrame } from '@react-three/fiber';
10
- import { useMujoco } from '../core/MujocoSimProvider';
10
+ import { useMujocoContext } from '../core/MujocoSimProvider';
11
11
 
12
12
  interface TrajectoryPlayerOptions {
13
13
  fps?: number;
@@ -24,7 +24,7 @@ export function useTrajectoryPlayer(
24
24
  trajectory: number[][],
25
25
  options: TrajectoryPlayerOptions = {},
26
26
  ) {
27
- const { mjModelRef, mjDataRef, mujocoRef, pausedRef } = useMujoco();
27
+ const { mjModelRef, mjDataRef, mujocoRef, pausedRef } = useMujocoContext();
28
28
  const fps = options.fps ?? 30;
29
29
  const loop = options.loop ?? false;
30
30
 
@@ -6,7 +6,7 @@
6
6
  */
7
7
 
8
8
  import { useCallback, useRef } from 'react';
9
- import { useMujoco, useAfterPhysicsStep } from '../core/MujocoSimProvider';
9
+ import { useMujocoContext, useAfterPhysicsStep } from '../core/MujocoSimProvider';
10
10
  import type { TrajectoryFrame } from '../types';
11
11
 
12
12
  interface RecorderOptions {
@@ -17,7 +17,7 @@ interface RecorderOptions {
17
17
  * Record simulation trajectories for analysis, replay, or training data.
18
18
  */
19
19
  export function useTrajectoryRecorder(options: RecorderOptions = {}) {
20
- const { mjModelRef } = useMujoco();
20
+ const { mjModelRef } = useMujocoContext();
21
21
  const recordingRef = useRef(false);
22
22
  const framesRef = useRef<TrajectoryFrame[]>([]);
23
23
  const fields = options.fields ?? ['qpos'];
package/src/index.ts CHANGED
@@ -26,12 +26,11 @@ export {
26
26
  export { createController } from './core/createController';
27
27
  export type { ControllerOptions, ControllerComponent } from './core/createController';
28
28
 
29
- // IK controller plugin
30
- export { IkController } from './components/IkController';
31
- export { useIk } from './core/IkContext';
32
- export type { IkContextValue } from './core/IkContext';
29
+ // IK controller hook
30
+ export { useIkController } from './hooks/useIkController';
33
31
 
34
32
  // Components
33
+ export { Body } from './components/Body';
35
34
  export { IkGizmo } from './components/IkGizmo';
36
35
  export { ContactMarkers } from './components/ContactMarkers';
37
36
  export { DragInteraction } from './components/DragInteraction';
@@ -74,6 +73,7 @@ export type {
74
73
  PhysicsConfig,
75
74
  // IK
76
75
  IkConfig,
76
+ IkContextValue,
77
77
  IKSolveFn,
78
78
  // Callbacks
79
79
  PhysicsStepCallback,
@@ -101,6 +101,7 @@ export type {
101
101
  // Policy
102
102
  PolicyConfig,
103
103
  // Component props
104
+ BodyProps,
104
105
  IkGizmoProps,
105
106
  DragInteractionProps,
106
107
  DebugProps,
package/src/types.ts CHANGED
@@ -3,6 +3,8 @@
3
3
  * SPDX-License-Identifier: Apache-2.0
4
4
  */
5
5
 
6
+ import type React from 'react';
7
+ import type { ReactNode } from 'react';
6
8
  import type { CanvasProps } from '@react-three/fiber';
7
9
  import * as THREE from 'three';
8
10
 
@@ -309,6 +311,18 @@ export interface IkConfig {
309
311
  maxIterations?: number;
310
312
  }
311
313
 
314
/**
 * Handle produced by `useIkController` (which yields `null` when it has no
 * config). `IkGizmo` consumes it through its `controller` prop.
 */
export interface IkContextValue {
  // ---- Live state (refs) ----

  /** Resolved MuJoCo site id for the configured site name, or -1 when unresolved. */
  siteIdRef: React.RefObject<number>;
  /** Group whose `position` / `quaternion` are used as the IK target pose. */
  ikTargetRef: React.RefObject<THREE.Group>;
  /** Whether IK is solved before each physics step and written to `ctrl`. */
  ikEnabledRef: React.RefObject<boolean>;
  /** True while the physics-step hook is actively running IK; false when disabled. */
  ikCalculatingRef: React.RefObject<boolean>;

  // ---- Imperative API ----

  /** Turn IK on or off; in `useIkController`, enabling snaps the target onto the site. */
  setIkEnabled: (enabled: boolean) => void;
  /**
   * Move the IK target to `pos`, enabling IK if necessary. In `useIkController`,
   * `duration` is in milliseconds (0 = jump immediately) and the orientation is
   * set to a fixed Euler(PI, 0, 0).
   */
  moveTarget: (pos: THREE.Vector3, duration?: number) => void;
  /** Snap the target onto the tracked site's current pose (no-op if unresolved). */
  syncTargetToSite: () => void;
  /** Run the configured IK solver; returns joint values, or null if it cannot solve. */
  solveIK: (pos: THREE.Vector3, quat: THREE.Quaternion, currentQ: number[]) => number[] | null;
  /** Copy of the target's position and Euler rotation, or null while IK is not calculating. */
  getGizmoStats: () => { pos: THREE.Vector3; rot: THREE.Euler } | null;
}
325
+
312
326
  export interface SceneMarker {
313
327
  id: number;
314
328
  position: THREE.Vector3;
@@ -482,6 +496,7 @@ export interface DebugProps {
482
496
  // ---- Component Props ----
483
497
 
484
498
  export interface IkGizmoProps {
499
+ controller: IkContextValue;
485
500
  siteName?: string;
486
501
  scale?: number;
487
502
  onDrag?: (position: THREE.Vector3, quaternion: THREE.Quaternion) => void;
@@ -517,6 +532,21 @@ export interface ContactListenerProps {
517
532
  onContactExit?: (info: ContactInfo) => void;
518
533
  }
519
534
 
535
/**
 * Props for the `Body` component.
 *
 * NOTE(review): the `Body` component is not visible here; the per-field notes
 * below describe the declared types and hedge any MuJoCo mapping — confirm
 * against the component implementation.
 */
export interface BodyProps {
  // ---- Identity & geometry ----

  /** Body name (presumably used as the MJCF body/geom name — confirm). */
  name: string;
  /** Primitive geom shape. */
  type: 'box' | 'sphere' | 'cylinder';
  /** Three size values; presumably forwarded as the MuJoCo geom `size` — confirm per-shape meaning. */
  size: [number, number, number];
  /** Optional [x, y, z] placement. */
  position?: [number, number, number];

  // ---- Appearance & inertia ----

  /** Optional [r, g, b, a] color. */
  rgba?: [number, number, number, number];
  /** Optional mass. */
  mass?: number;
  /** Whether the body gets a free joint (presumably making it dynamic — confirm). */
  freejoint?: boolean;

  // ---- Contact parameters (string-valued, presumably raw MJCF attribute text) ----

  friction?: string;
  solref?: string;
  solimp?: string;
  condim?: number;

  /** Nested React content. */
  children?: ReactNode;
}
549
+
520
550
  // ---- Public API (spec: full surface) ----
521
551
 
522
552
  export interface MujocoSimAPI {