mujoco-react 8.10.0 → 9.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +81 -44
  2. package/dist/chunk-33CV6HSV.js +400 -0
  3. package/dist/chunk-33CV6HSV.js.map +1 -0
  4. package/dist/index.d.ts +92 -24
  5. package/dist/index.js +338 -54
  6. package/dist/index.js.map +1 -1
  7. package/dist/spark.d.ts +24 -3
  8. package/dist/spark.js +91 -6
  9. package/dist/spark.js.map +1 -1
  10. package/dist/{types-FFW7ykBu.d.ts → types-izZlUweI.d.ts} +109 -16
  11. package/package.json +1 -1
  12. package/src/components/Body.tsx +3 -1
  13. package/src/components/DragInteraction.tsx +1 -1
  14. package/src/components/IkGizmo.tsx +2 -2
  15. package/src/components/SceneRenderer.tsx +1 -1
  16. package/src/components/TrajectoryPlayer.tsx +4 -1
  17. package/src/components/VisualScenario.tsx +343 -6
  18. package/src/core/MujocoCanvas.tsx +8 -1
  19. package/src/core/MujocoPhysics.tsx +10 -4
  20. package/src/core/MujocoSimProvider.tsx +15 -12
  21. package/src/core/SceneLoader.ts +182 -3
  22. package/src/core/createController.tsx +2 -2
  23. package/src/hooks/useBodyState.ts +1 -1
  24. package/src/hooks/useContacts.ts +1 -1
  25. package/src/hooks/useCtrlNoise.ts +1 -1
  26. package/src/hooks/useFrameCapture.ts +206 -0
  27. package/src/hooks/useGamepad.ts +1 -1
  28. package/src/hooks/useGravityCompensation.ts +1 -1
  29. package/src/hooks/useIkController.ts +22 -13
  30. package/src/hooks/useJointState.ts +1 -1
  31. package/src/hooks/useKeyboardTeleop.ts +1 -1
  32. package/src/hooks/usePolicy.ts +13 -9
  33. package/src/hooks/useSensor.ts +1 -1
  34. package/src/hooks/useTrajectoryPlayer.ts +4 -4
  35. package/src/hooks/useTrajectoryRecorder.ts +1 -1
  36. package/src/index.ts +35 -0
  37. package/src/spark.tsx +138 -4
  38. package/src/types.ts +128 -21
  39. package/dist/chunk-KGFRKPLS.js +0 -186
  40. package/dist/chunk-KGFRKPLS.js.map +0 -1
@@ -477,8 +477,9 @@ function sceneObjectToXml(obj: SceneObject): string {
477
477
  const solref = obj.solref ? ` solref="${obj.solref}"` : '';
478
478
  const solimp = obj.solimp ? ` solimp="${obj.solimp}"` : '';
479
479
  const condim = obj.condim ? ` condim="${obj.condim}"` : '';
480
+ const group = obj.group !== undefined ? ` group="${obj.group}"` : '';
480
481
  // Always set contype/conaffinity=1 so objects collide regardless of model defaults
481
- return `<body name="${obj.name}" pos="${pos}">${joint}<geom type="${obj.type}" size="${size}" rgba="${rgba}" contype="1" conaffinity="1"${mass}${friction}${solref}${solimp}${condim}/></body>`;
482
+ return `<body name="${obj.name}" pos="${pos}">${joint}<geom type="${obj.type}" size="${size}" rgba="${rgba}" contype="1" conaffinity="1"${mass}${friction}${solref}${solimp}${condim}${group}/></body>`;
482
483
  }
483
484
 
484
485
  /** Create virtual directory structure for a file path. */
@@ -527,6 +528,22 @@ function localFilePath(file: LocalMujocoFile): string {
527
528
  return normalizeVfsPath(file.webkitRelativePath || file.name);
528
529
  }
529
530
 
531
/**
 * Directory part of a VFS path, keeping the trailing slash.
 *
 * `'a/b/scene.xml'` → `'a/b/'`; a bare file name (no slash) yields `''`.
 */
function dirname(path: string): string {
  const normalized = normalizeVfsPath(path);
  // lastIndexOf returns -1 when there is no slash, so `end` becomes 0 and the slice is empty.
  const end = normalized.lastIndexOf('/') + 1;
  return normalized.slice(0, end);
}

/**
 * Express `targetPath` relative to the directory `fromDir`.
 *
 * Leading segments shared by both paths are dropped, every remaining segment of `fromDir` becomes
 * a `..`, and the rest of `targetPath` is appended. Returns `'.'` when nothing is left.
 */
function relativeVfsPath(fromDir: string, targetPath: string): string {
  const toSegments = (p: string): string[] => normalizeVfsPath(p).split('/').filter(Boolean);
  const fromSegments = toSegments(fromDir);
  const targetSegments = toSegments(targetPath);

  // Length of the common leading run of segments.
  const limit = Math.min(fromSegments.length, targetSegments.length);
  let common = 0;
  while (common < limit && fromSegments[common] === targetSegments[common]) common++;

  const climbs = fromSegments.slice(common).map(() => '..');
  const relative = climbs.concat(targetSegments.slice(common)).join('/');
  return relative === '' ? '.' : relative;
}
546
+
530
547
  function inferSceneFile(files: readonly LocalMujocoFile[], options?: LoadFromFilesOptions): string {
531
548
  if (options?.sceneFile) return normalizeVfsPath(options.sceneFile);
532
549
 
@@ -551,6 +568,7 @@ export function createSceneConfigFromFiles(
551
568
  src: '',
552
569
  sceneFile: inferSceneFile(fileArray, options),
553
570
  files: fileArray,
571
+ environmentFiles: options.environmentFiles?.map(normalizeVfsPath),
554
572
  homeJoints: options.homeJoints,
555
573
  xmlPatches: options.xmlPatches,
556
574
  sceneObjects: options.sceneObjects,
@@ -558,6 +576,137 @@ export function createSceneConfigFromFiles(
558
576
  };
559
577
  }
560
578
 
579
/**
 * Top-level MJCF sections whose children are copied, in this order, from every environment file
 * into the scene document.
 *
 * NOTE(review): environment keyframes carry qpos/qvel vectors authored for the environment on its
 * own, while MuJoCo sizes keyframes for the combined model — confirm merging `keyframe` is intended.
 */
const ENVIRONMENT_MERGE_SECTIONS = [
  'asset',
  'worldbody',
  'contact',
  'equality',
  'tendon',
  'sensor',
  'keyframe',
  'custom',
  'extension',
] as const;

/** MJCF attributes that hold file paths; cube/skybox textures use the six per-face variants. */
const FILE_REFERENCE_ATTRIBUTES = [
  'file',
  'filename',
  'fileright',
  'fileleft',
  'fileup',
  'filedown',
  'filefront',
  'fileback',
] as const;

/** Selector matching any element that carries at least one file-path attribute. */
const FILE_REFERENCE_SELECTOR = FILE_REFERENCE_ATTRIBUTES.map((attr) => `[${attr}]`).join(', ');

/** URL with an RFC 3986 scheme (`https://`, `package://`, `s3://`, ...). */
const URL_WITH_SCHEME = /^[a-z][a-z0-9+.-]*:\/\//i;

/** First direct child of `parent` whose tag name matches `tagName` case-insensitively, or `null`. */
function directChild(parent: Element, tagName: string): Element | null {
  const lower = tagName.toLowerCase();
  for (const child of Array.from(parent.children)) {
    if (child.tagName.toLowerCase() === lower) return child;
  }
  return null;
}

/**
 * Return the top-level `<tagName>` section of `doc`, creating it when missing.
 *
 * A newly created `<asset>` is inserted before `<worldbody>` when the document has one; every
 * other new section is appended to the root element.
 */
function ensureTopLevelSection(doc: XMLDocument, tagName: string): Element {
  const root = doc.documentElement;
  const existing = directChild(root, tagName);
  if (existing) return existing;

  const section = doc.createElement(tagName);
  const worldbody = tagName === 'asset' ? directChild(root, 'worldbody') : null;
  if (worldbody) root.insertBefore(section, worldbody);
  else root.appendChild(section);
  return section;
}

/**
 * Mesh and texture directories declared on the document's top-level `<compiler>`.
 * `assetdir` is the fallback for either; both are `''` when unset.
 *
 * NOTE(review): a `<compiler>` that only arrives through an `<include>`d file is not visible here.
 */
function readCompilerDirs(doc: Document) {
  const compiler = directChild(doc.documentElement, 'compiler');
  const assetDir = compiler?.getAttribute('assetdir') || '';
  return {
    meshDir: compiler?.getAttribute('meshdir') || assetDir,
    textureDir: compiler?.getAttribute('texturedir') || assetDir,
  };
}

/** True for paths that must not be rebased: scheme-qualified URLs and absolute paths. */
function isExternalPath(path: string): boolean {
  return URL_WITH_SCHEME.test(path) || path.startsWith('/');
}

/** `dir` with exactly one trailing slash appended when it is non-empty. */
function withTrailingSlash(dir: string): string {
  return !dir || dir.endsWith('/') ? dir : dir + '/';
}

/**
 * Compiler directory (with trailing slash) that MuJoCo prepends to a file attribute on `el`.
 *
 * MuJoCo resolves meshes, height fields and skins through `meshdir` and textures through
 * `texturedir`; other elements (e.g. `<include>`) get no prefix.
 */
function fileReferencePrefix(el: Element, compilerDirs: ReturnType<typeof readCompilerDirs>): string {
  switch (el.tagName.toLowerCase()) {
    case 'mesh':
    case 'hfield':
    case 'skin':
      return withTrailingSlash(compilerDirs.meshDir);
    case 'texture':
      return withTrailingSlash(compilerDirs.textureDir);
    default:
      return '';
  }
}

/**
 * Rebase every file reference in `node` (the node itself AND its descendants) so that paths
 * written against `sourceFile` still resolve once the node lives inside `targetFile`.
 *
 * Each path is resolved the way MuJoCo would for the standalone source file (source directory +
 * source compiler dir + value). It is then rewritten relative to the location MuJoCo will resolve
 * it from in the target: target directory + the target document's compiler dir. The target
 * document is taken from `node.ownerDocument`, so `node` must already be imported into it.
 * External paths (scheme URLs, absolute paths) are left unchanged.
 *
 * @param node element already imported into the target document
 * @param sourceFile VFS path of the file `node` came from
 * @param targetFile VFS path of the file `node` is merged into
 * @param sourceDoc parsed source document (read for its `<compiler>` directories)
 */
function rewriteFileReferencesForMerge(node: Element, sourceFile: string, targetFile: string, sourceDoc: XMLDocument) {
  const sourceDir = dirname(sourceFile);
  const targetDir = dirname(targetFile);
  const sourceCompilerDirs = readCompilerDirs(sourceDoc);
  const targetDoc = node.ownerDocument;
  const targetCompilerDirs = targetDoc ? readCompilerDirs(targetDoc) : { meshDir: '', textureDir: '' };

  // querySelectorAll only returns descendants; top-level <asset> entries (<mesh>, <texture>, ...)
  // are the merged node itself, so it has to be visited explicitly.
  const elements = [node, ...Array.from(node.querySelectorAll(FILE_REFERENCE_SELECTOR))];
  for (const el of elements) {
    for (const attr of FILE_REFERENCE_ATTRIBUTES) {
      const value = el.getAttribute(attr);
      if (!value || isExternalPath(value)) continue;

      const sourceRelativePath = normalizeVfsPath(fileReferencePrefix(el, sourceCompilerDirs) + value);
      const resolvedPath = normalizeVfsPath(sourceDir + sourceRelativePath);
      // MuJoCo will prepend the target's own meshdir/texturedir, so relativize against it.
      const targetBaseDir = targetDir + fileReferencePrefix(el, targetCompilerDirs);
      el.setAttribute(attr, relativeVfsPath(targetBaseDir, resolvedPath));
    }
  }
}

/** True when DOMParser reported a parse error, which it signals with a `<parsererror>` element. */
function hasParseError(doc: XMLDocument): boolean {
  return doc.getElementsByTagName('parsererror').length > 0;
}
653
+
654
/**
 * Splice the configured environment MJCF files into the scene XML.
 *
 * For each entry of `config.environmentFiles` (normalized), the children of the environment's
 * top-level merge sections are deep-copied into the matching top-level section of the scene
 * (creating it when absent), with their file references rebased onto `config.sceneFile`.
 * Environments missing from `environmentXmlByPath` or failing to parse are skipped with a warning.
 *
 * @returns the serialized composed scene, or `sceneXml` untouched when no environments are
 *          configured or the scene itself fails to parse
 */
function composeEnvironmentXml(
  sceneXml: string,
  config: SceneConfig,
  parser: DOMParser,
  environmentXmlByPath: Map<string, string>
): string {
  const requestedFiles = (config.environmentFiles ?? []).map(normalizeVfsPath);
  if (requestedFiles.length === 0) return sceneXml;

  const sceneDoc = parser.parseFromString(sceneXml, 'text/xml');
  if (hasParseError(sceneDoc)) {
    console.warn(`Could not compose environments: failed to parse ${config.sceneFile}`);
    return sceneXml;
  }

  // Copy every non-empty merge section of one parsed environment into the scene document.
  const mergeInto = (environmentFile: string, environmentDoc: Document): void => {
    for (const sectionName of ENVIRONMENT_MERGE_SECTIONS) {
      const sourceSection = directChild(environmentDoc.documentElement, sectionName);
      if (sourceSection === null || sourceSection.children.length === 0) continue;

      const destination = ensureTopLevelSection(sceneDoc, sectionName);
      Array.from(sourceSection.children).forEach((child) => {
        const copy = sceneDoc.importNode(child, true) as Element;
        rewriteFileReferencesForMerge(copy, environmentFile, config.sceneFile, environmentDoc);
        destination.appendChild(copy);
      });
    }
  };

  for (const environmentFile of requestedFiles) {
    const environmentXml = environmentXmlByPath.get(environmentFile);
    if (!environmentXml) {
      console.warn(`Environment XML not found: ${environmentFile}`);
      continue;
    }

    const environmentDoc = parser.parseFromString(environmentXml, 'text/xml');
    if (hasParseError(environmentDoc)) {
      console.warn(`Skipping environment XML with parse errors: ${environmentFile}`);
      continue;
    }

    mergeInto(environmentFile, environmentDoc);
  }

  return new XMLSerializer().serializeToString(sceneDoc);
}
697
+
698
/**
 * Find the already-loaded text of a file the user referenced by a configured path.
 *
 * Resolution order, most to least specific:
 *  1. exact match on the normalized path;
 *  2. a loaded path ending in `/<configured path>` (configured path is relative to some parent,
 *     e.g. the root folder of a dropped directory);
 *  3. a loaded top-level file whose name equals the configured path's basename.
 * Within steps 2 and 3 the first hit in `textByPath` iteration order wins. An empty file is a
 * valid match.
 *
 * @returns the file text, or `undefined` when nothing matches
 */
function findTextByConfiguredPath(textByPath: Map<string, string>, configuredPath: string): string | undefined {
  const normalized = normalizeVfsPath(configuredPath);
  if (textByPath.has(normalized)) return textByPath.get(normalized);

  // A proper suffix match must beat a bare-basename match regardless of Map insertion order,
  // so the two are checked in separate passes.
  const suffix = '/' + normalized;
  for (const [path, text] of textByPath) {
    if (path.endsWith(suffix)) return text;
  }

  const basename = normalized.split('/').pop();
  for (const [path, text] of textByPath) {
    if (path === basename) return text;
  }
  return undefined;
}
709
+
561
710
  function applyXmlPatches(text: string, fname: string, config: SceneConfig): string {
562
711
  let result = text;
563
712
  for (const patch of config.xmlPatches ?? []) {
@@ -627,10 +776,25 @@ async function loadSceneFromFiles(
627
776
  if (isModelTextFile(path)) {
628
777
  const text = applyXmlPatches(await file.text(), path, config);
629
778
  textByPath.set(path, text);
630
- mujoco.FS.writeFile(`/working/${path}`, text);
631
779
  } else {
632
780
  mujoco.FS.writeFile(`/working/${path}`, new Uint8Array(await file.arrayBuffer()));
781
+ written.add(path);
633
782
  }
783
+ }
784
+
785
+ const environmentXmlByPath = new Map<string, string>();
786
+ for (const environmentFile of config.environmentFiles?.map(normalizeVfsPath) ?? []) {
787
+ const environmentXml = findTextByConfiguredPath(textByPath, environmentFile);
788
+ if (environmentXml) environmentXmlByPath.set(environmentFile, environmentXml);
789
+ }
790
+
791
+ for (const [path, text] of textByPath) {
792
+ const composedText = path === config.sceneFile
793
+ ? composeEnvironmentXml(text, config, parser, environmentXmlByPath)
794
+ : text;
795
+ textByPath.set(path, composedText);
796
+ ensureDir(mujoco, path);
797
+ mujoco.FS.writeFile(`/working/${path}`, composedText);
634
798
  written.add(path);
635
799
  }
636
800
 
@@ -689,6 +853,18 @@ export async function loadScene(
689
853
 
690
854
  const baseUrl = config.src.endsWith('/') ? config.src : config.src + '/';
691
855
 
856
+ const environmentXmlByPath = new Map<string, string>();
857
+ const environmentFiles = config.environmentFiles?.map(normalizeVfsPath) ?? [];
858
+ for (const environmentFile of environmentFiles) {
859
+ onProgress?.(`Downloading ${environmentFile}...`);
860
+ const res = await fetch(baseUrl + environmentFile);
861
+ if (!res.ok) {
862
+ console.warn(`Failed to fetch environment XML ${environmentFile}: ${res.status} ${res.statusText}`);
863
+ continue;
864
+ }
865
+ environmentXmlByPath.set(environmentFile, applyXmlPatches(await res.text(), environmentFile, config));
866
+ }
867
+
692
868
  const downloaded = new Set<string>();
693
869
  const xmlQueue: string[] = [config.sceneFile];
694
870
  const assetFiles: string[] = [];
@@ -714,7 +890,10 @@ export async function loadScene(
714
890
  continue;
715
891
  }
716
892
 
717
- const text = applyXmlPatches(await res.text(), fname, config);
893
+ const patchedText = applyXmlPatches(await res.text(), fname, config);
894
+ const text = fname === config.sceneFile
895
+ ? composeEnvironmentXml(patchedText, config, parser, environmentXmlByPath)
896
+ : patchedText;
718
897
 
719
898
  ensureDir(mujoco, fname);
720
899
  mujoco.FS.writeFile(`/working/${fname}`, text);
@@ -43,7 +43,7 @@ export type ControllerComponent<TConfig> = React.FC<{
43
43
  * const MyController = createController<{ speed: number }>(
44
44
  * { name: 'my-controller', defaultConfig: { speed: 1.0 } },
45
45
  * function MyControllerImpl({ config }) {
46
- * useBeforePhysicsStep((_model, data) => {
46
+ * useBeforePhysicsStep(({ data }) => {
47
47
  * data.ctrl[0] = config.speed;
48
48
  * });
49
49
  * return null;
@@ -100,7 +100,7 @@ export function createController<TConfig>(
100
100
  * { name: 'useMyController', defaultConfig: { gain: 1.0 } },
101
101
  * function useMyControllerImpl(config) {
102
102
  * // config is MyConfig | null — hooks must be called unconditionally
103
- * useBeforePhysicsStep((_model, data) => {
103
+ * useBeforePhysicsStep(({ data }) => {
104
104
  * if (!config) return;
105
105
  * data.ctrl[0] = config.gain * Math.sin(data.time);
106
106
  * });
@@ -29,7 +29,7 @@ export function useBodyState(name: Bodies): BodyStateResult {
29
29
  bodyIdRef.current = findBodyByName(model, name);
30
30
  }, [name, status, mjModelRef]);
31
31
 
32
- useAfterPhysicsStep((_model, data) => {
32
+ useAfterPhysicsStep(({ data }) => {
33
33
  const bid = bodyIdRef.current;
34
34
  if (bid < 0) return;
35
35
 
@@ -60,7 +60,7 @@ export function useContacts(
60
60
  bodyResolvedRef.current = true;
61
61
  }, [bodyName, status, mjModelRef]);
62
62
 
63
- useAfterPhysicsStep((model, data) => {
63
+ useAfterPhysicsStep(({ model, data }) => {
64
64
  // Resolve body id lazily once model exists, to avoid missing the first ready frame.
65
65
  if (bodyName && !bodyResolvedRef.current) {
66
66
  bodyIdRef.current = findBodyByName(model, bodyName);
@@ -30,7 +30,7 @@ export function useCtrlNoise(config: CtrlNoiseConfig = {}) {
30
30
  configRef.current = config;
31
31
  const noiseRef = useRef<Float64Array | null>(null);
32
32
 
33
- useBeforePhysicsStep((_model, data) => {
33
+ useBeforePhysicsStep(({ data }) => {
34
34
  const cfg = configRef.current;
35
35
  if (cfg.enabled === false) return;
36
36
 
@@ -0,0 +1,206 @@
1
+ /**
2
+ * @license
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * useFrameCapture — still-frame capture for canvas-backed MuJoCo/R3F scenes.
6
+ */
7
+
8
import { useCallback, useRef, useState } from 'react';
import type React from 'react';
10
+
11
/** Lifecycle of the most recent capture started through `useFrameCapture`. */
export type FrameCaptureStatus = 'idle' | 'capturing' | 'captured' | 'error';

/**
 * Where a frame is captured from: a `<canvas>` itself, or any element containing one (its first
 * descendant `<canvas>` is used). `null`/`undefined` are accepted so an unmounted target can be
 * passed straight through; capturing from them fails with an error.
 */
export type FrameCaptureTarget = HTMLCanvasElement | HTMLElement | null | undefined;

/** React ref whose `.current` is read at capture time. */
export type FrameCaptureTargetRef = React.RefObject<HTMLCanvasElement | HTMLElement | null>;

/** Options shared by `captureFrame`, `captureFrameBlob`, and `useFrameCapture`. */
export interface FrameCaptureOptions {
  /** Element or ref to capture from. */
  target?: FrameCaptureTarget | FrameCaptureTargetRef;
  /** Image MIME type handed to `toDataURL` / `toBlob`; `'image/png'` when omitted. */
  type?: string;
  /** Forwarded unchanged as the `quality` argument of `toDataURL` / `toBlob`. */
  quality?: number;
  /** Wait for one `requestAnimationFrame` tick before reading pixels; `true` when omitted. */
  waitForAnimationFrame?: boolean;
}

/** Result of a data-URL capture. */
export interface FrameCaptureResult {
  /** Canvas the pixels were read from. */
  canvas: HTMLCanvasElement;
  /** Encoded image as a `data:` URL. */
  dataUrl: string;
  /** MIME type that was requested. */
  type: string;
}

/** Result of a Blob capture. */
export interface FrameCaptureBlobResult {
  /** Canvas the pixels were read from. */
  canvas: HTMLCanvasElement;
  /** Encoded image. */
  blob: Blob;
  /** MIME type that was requested. */
  type: string;
}

/** Value returned by `useFrameCapture`. */
export interface FrameCaptureAPI {
  /** State of the latest capture. */
  status: FrameCaptureStatus;
  /** Error from the latest failed capture, cleared when a new capture starts or on `reset`. */
  error: Error | null;
  /** Shorthand for `status === 'capturing'`. */
  isCapturing: boolean;
  /** Capture as a data URL; per-call options override the hook defaults. */
  capture: (options?: FrameCaptureOptions) => Promise<FrameCaptureResult>;
  /** Capture as a Blob; per-call options override the hook defaults. */
  captureBlob: (options?: FrameCaptureOptions) => Promise<FrameCaptureBlobResult>;
  /** Return to `'idle'` and clear `error`. */
  reset: () => void;
}
51
+
52
/** True when `target` is a ref object (anything object-like exposing `current`). */
function isTargetRef(
  target: FrameCaptureOptions['target']
): target is FrameCaptureTargetRef {
  if (!target || typeof target !== 'object') return false;
  return 'current' in target;
}

/**
 * Turn a capture target (element, ref, or nothing) into the canvas to read from.
 *
 * A canvas is returned as-is; any other element is searched for its first descendant `<canvas>`.
 *
 * @throws Error when the target (or the ref's `current`) is empty, or contains no canvas
 */
function resolveCanvasTarget(
  target: FrameCaptureOptions['target']
): HTMLCanvasElement {
  const element = isTargetRef(target) ? target.current : target;
  if (element == null) {
    throw new Error('No frame capture target is available.');
  }

  const canvas = element instanceof HTMLCanvasElement ? element : element.querySelector('canvas');
  if (canvas === null) {
    throw new Error('Frame capture target does not contain a canvas.');
  }
  return canvas;
}

/** Promise that settles inside the next `requestAnimationFrame` callback. */
function waitForNextAnimationFrame(): Promise<void> {
  return new Promise<void>((resolve) => {
    requestAnimationFrame(() => resolve());
  });
}
83
+
84
/**
 * Capture the current canvas frame as a data URL.
 *
 * The canvas is resolved from `options.target` first; by default the capture then waits one
 * animation frame before encoding. For WebGL scenes, create the renderer with
 * `preserveDrawingBuffer: true` when you need deterministic captures after the frame has
 * presented.
 *
 * @returns the canvas, its `toDataURL` output, and the MIME type used
 * @throws (as a rejection) when no canvas can be resolved from the target
 */
export async function captureFrame(
  options: FrameCaptureOptions
): Promise<FrameCaptureResult> {
  const mimeType = options.type ?? 'image/png';
  const canvas = resolveCanvasTarget(options.target);

  const shouldWait = options.waitForAnimationFrame ?? true;
  if (shouldWait) await waitForNextAnimationFrame();

  const dataUrl = canvas.toDataURL(mimeType, options.quality);
  return { canvas, dataUrl, type: mimeType };
}

/** Promise wrapper over the callback-based `canvas.toBlob`; rejects if no Blob is produced. */
function canvasToBlob(
  canvas: HTMLCanvasElement,
  mimeType: string,
  quality: number | undefined
): Promise<Blob> {
  return new Promise<Blob>((resolve, reject) => {
    canvas.toBlob(
      (encoded) => {
        if (encoded === null) {
          reject(new Error('Canvas frame capture did not produce a Blob.'));
          return;
        }
        resolve(encoded);
      },
      mimeType,
      quality
    );
  });
}

/**
 * Capture the current canvas frame as a Blob.
 *
 * Same target resolution and animation-frame wait as `captureFrame`.
 *
 * @returns the canvas, the encoded Blob, and the MIME type used
 * @throws (as a rejection) when no canvas can be resolved or the canvas yields no Blob
 */
export async function captureFrameBlob(
  options: FrameCaptureOptions
): Promise<FrameCaptureBlobResult> {
  const mimeType = options.type ?? 'image/png';
  const canvas = resolveCanvasTarget(options.target);

  const shouldWait = options.waitForAnimationFrame ?? true;
  if (shouldWait) await waitForNextAnimationFrame();

  const blob = await canvasToBlob(canvas, mimeType, options.quality);
  return { canvas, blob, type: mimeType };
}
136
+
137
/**
 * Run one capture while mirroring its progress into the hook's `status`/`error` state.
 *
 * Sets `'capturing'` (and clears the error) before starting, then `'captured'` on success or
 * `'error'` on failure. Non-`Error` rejections are wrapped so `error` is always an `Error`; the
 * (possibly wrapped) error is re-thrown to the caller.
 */
async function runTrackedCapture<TResult>(
  perform: () => Promise<TResult>,
  setStatus: (status: FrameCaptureStatus) => void,
  setError: (error: Error | null) => void
): Promise<TResult> {
  setStatus('capturing');
  setError(null);

  try {
    const result = await perform();
    setStatus('captured');
    return result;
  } catch (cause) {
    const captureError =
      cause instanceof Error
        ? cause
        : new Error('Unable to capture the current canvas frame.');
    setError(captureError);
    setStatus('error');
    throw captureError;
  }
}

/**
 * React state wrapper around `captureFrame` and `captureFrameBlob`.
 *
 * `defaultOptions` are spread underneath the per-call options of every capture. They are read
 * through a ref at call time. Passing a fresh object literal each render (including the implicit
 * `{}` default) therefore does not change the identity of `capture` / `captureBlob`: both
 * callbacks, like `reset`, are stable for the component's lifetime and safe to use as effect
 * dependencies.
 *
 * @param defaultOptions options applied to every capture unless overridden per call
 * @returns capture status, the latest error, and the capture/reset callbacks
 */
export function useFrameCapture(
  defaultOptions: FrameCaptureOptions = {}
): FrameCaptureAPI {
  const [status, setStatus] = useState<FrameCaptureStatus>('idle');
  const [error, setError] = useState<Error | null>(null);

  // Latest defaults, refreshed every render and read lazily (same pattern as `configRef`).
  const defaultOptionsRef = useRef(defaultOptions);
  defaultOptionsRef.current = defaultOptions;

  const reset = useCallback(() => {
    setStatus('idle');
    setError(null);
  }, []);

  const capture = useCallback(
    (options: FrameCaptureOptions = {}) =>
      runTrackedCapture(
        () => captureFrame({ ...defaultOptionsRef.current, ...options }),
        setStatus,
        setError
      ),
    []
  );

  const captureBlob = useCallback(
    (options: FrameCaptureOptions = {}) =>
      runTrackedCapture(
        () => captureFrameBlob({ ...defaultOptionsRef.current, ...options }),
        setStatus,
        setError
      ),
    []
  );

  return {
    status,
    error,
    isCapturing: status === 'capturing',
    capture,
    captureBlob,
    reset,
  };
}
@@ -50,7 +50,7 @@ export function useGamepad(config: GamepadConfig) {
50
50
  }
51
51
  }, [config.axes, config.buttons, status, mjModelRef]);
52
52
 
53
- useBeforePhysicsStep((_model, data) => {
53
+ useBeforePhysicsStep(({ data }) => {
54
54
  const cfg = configRef.current;
55
55
  if (cfg.enabled === false) return;
56
56
 
@@ -13,7 +13,7 @@ import { useBeforePhysicsStep } from '../core/MujocoSimProvider';
13
13
  * hook (and DragInteraction) compose correctly — both add to a clean slate.
14
14
  */
15
15
  export function useGravityCompensation(enabled = true): void {
16
- useBeforePhysicsStep((model, data) => {
16
+ useBeforePhysicsStep(({ model, data }) => {
17
17
  if (!enabled) return;
18
18
  for (let i = 0; i < model.nv; i++) {
19
19
  data.qfrc_applied[i] += data.qfrc_bias[i];
@@ -10,7 +10,7 @@ import { createControllerHook } from '../core/createController';
10
10
  import { useMujocoContext, useBeforePhysicsStep } from '../core/MujocoSimProvider';
11
11
  import { GenericIK } from '../core/GenericIK';
12
12
  import { createContiguousControlGroup, findSiteByName, resolveControlGroup } from '../core/SceneLoader';
13
- import type { ControlGroupInfo, IkConfig, IkContextValue, IKSolveFn, MujocoData } from '../types';
13
+ import type { ControlGroupInfo, IkConfig, IkContextValue, IKSolveFn, IkSolveInput, MujocoData } from '../types';
14
14
 
15
15
  // Preallocated temp for syncGizmoToSite
16
16
  const _syncMat4 = new THREE.Matrix4();
@@ -84,16 +84,16 @@ export const useIkController = createControllerHook<IkConfig, IkContextValue>(
84
84
 
85
85
  // IK solve function
86
86
  const ikSolveFn = useCallback(
87
- (pos: THREE.Vector3, quat: THREE.Quaternion, currentQ: number[]): number[] | null => {
87
+ ({ position, quaternion, currentQ, context }: IkSolveInput): number[] | null => {
88
88
  if (!config) return null;
89
- if (config.ikSolveFn) return config.ikSolveFn(pos, quat, currentQ);
89
+ if (config.ikSolveFn) return config.ikSolveFn({ position, quaternion, currentQ, context });
90
90
  const model = mjModelRef.current;
91
91
  const data = mjDataRef.current;
92
92
  const controlGroup = controlGroupRef.current;
93
93
  if (!model || !data || !controlGroup || siteIdRef.current === -1) return null;
94
94
  return genericIkRef.current.solve(
95
95
  model, data, siteIdRef.current, controlGroup.qposAdr,
96
- pos, quat, currentQ,
96
+ position, quaternion, currentQ,
97
97
  { damping: config.damping, maxIterations: config.maxIterations },
98
98
  );
99
99
  },
@@ -128,7 +128,7 @@ export const useIkController = createControllerHook<IkConfig, IkContextValue>(
128
128
  });
129
129
 
130
130
  // IK solve in physics loop
131
- useBeforePhysicsStep((model, data) => {
131
+ useBeforePhysicsStep(({ model, data }) => {
132
132
  if (!config || !ikEnabledRef.current) {
133
133
  ikCalculatingRef.current = false;
134
134
  return;
@@ -142,13 +142,22 @@ export const useIkController = createControllerHook<IkConfig, IkContextValue>(
142
142
 
143
143
  const currentQ = Array.from(controlGroup.readQpos(data));
144
144
  const solution = config.ikSolveFn
145
- ? config.ikSolveFn(target.position, target.quaternion, currentQ, {
146
- model,
147
- data,
148
- siteId: siteIdRef.current,
149
- controlGroup,
145
+ ? config.ikSolveFn({
146
+ position: target.position,
147
+ quaternion: target.quaternion,
148
+ currentQ,
149
+ context: {
150
+ model,
151
+ data,
152
+ siteId: siteIdRef.current,
153
+ controlGroup,
154
+ },
150
155
  })
151
- : ikSolveFnRef.current(target.position, target.quaternion, currentQ);
156
+ : ikSolveFnRef.current({
157
+ position: target.position,
158
+ quaternion: target.quaternion,
159
+ currentQ,
160
+ });
152
161
  if (solution) {
153
162
  controlGroup.writeCtrl(data, solution);
154
163
  }
@@ -192,8 +201,8 @@ export const useIkController = createControllerHook<IkConfig, IkContextValue>(
192
201
  }, [mjDataRef]);
193
202
 
194
203
  const solveIK = useCallback(
195
- (pos: THREE.Vector3, quat: THREE.Quaternion, currentQ: number[]): number[] | null => {
196
- return ikSolveFnRef.current(pos, quat, currentQ);
204
+ (input: IkSolveInput): number[] | null => {
205
+ return ikSolveFnRef.current(input);
197
206
  },
198
207
  [],
199
208
  );
@@ -59,7 +59,7 @@ export function useJointState(name: Joints): JointStateResult {
59
59
  jointIdRef.current = -1;
60
60
  }, [name, status, mjModelRef]);
61
61
 
62
- useAfterPhysicsStep((_model, data) => {
62
+ useAfterPhysicsStep(({ data }) => {
63
63
  if (jointIdRef.current < 0) return;
64
64
  const qa = qposAdrRef.current;
65
65
  const da = dofAdrRef.current;
@@ -70,7 +70,7 @@ export function useKeyboardTeleop(config: KeyboardTeleopConfig) {
70
70
  }, []);
71
71
 
72
72
  // Apply bindings each physics frame
73
- useBeforePhysicsStep((_model, data) => {
73
+ useBeforePhysicsStep(({ data }) => {
74
74
  if (!enabledRef.current) return;
75
75
  const bindings = bindingsRef.current;
76
76
  const cache = actuatorCacheRef.current;
@@ -6,7 +6,7 @@
6
6
  */
7
7
 
8
8
  import { useRef } from 'react';
9
- import { useMujocoContext, useBeforePhysicsStep } from '../core/MujocoSimProvider';
9
+ import { useBeforePhysicsStep } from '../core/MujocoSimProvider';
10
10
  import type { PolicyConfig } from '../types';
11
11
 
12
12
  /**
@@ -20,14 +20,15 @@ import type { PolicyConfig } from '../types';
20
20
  * @returns { step, isRunning } control handles
21
21
  */
22
22
  export function usePolicy(config: PolicyConfig) {
23
- const { mjModelRef } = useMujocoContext();
24
23
  const lastActionTimeRef = useRef(0);
24
+ const lastObservationRef = useRef<ReturnType<PolicyConfig['onObservation']> | null>(null);
25
25
  const lastActionRef = useRef<Float32Array | Float64Array | number[] | null>(null);
26
- const isRunningRef = useRef(true);
26
+ const isRunningRef = useRef(config.enabled ?? true);
27
27
  const configRef = useRef(config);
28
28
  configRef.current = config;
29
+ isRunningRef.current = config.enabled ?? isRunningRef.current;
29
30
 
30
- useBeforePhysicsStep((model, data) => {
31
+ useBeforePhysicsStep(({ model, data }) => {
31
32
  if (!isRunningRef.current) return;
32
33
 
33
34
  const cfg = configRef.current;
@@ -37,13 +38,15 @@ export function usePolicy(config: PolicyConfig) {
37
38
  // Check if it's time for a new action
38
39
  if (data.time - lastActionTimeRef.current >= interval) {
39
40
  // Build observation
40
- const obs = cfg.onObservation(model, data);
41
+ const observation = cfg.onObservation({ model, data });
42
+ const action = cfg.infer ? cfg.infer({ observation, model, data }) : observation;
41
43
 
42
- // Apply action (consumer does inference inline or uses cached result)
43
- cfg.onAction(obs, model, data);
44
+ // Apply action. If `infer` is omitted, this preserves the legacy inline-controller path.
45
+ cfg.onAction({ action, observation, model, data });
44
46
 
45
47
  lastActionTimeRef.current = data.time;
46
- lastActionRef.current = obs;
48
+ lastObservationRef.current = observation;
49
+ lastActionRef.current = action;
47
50
  }
48
51
  });
49
52
 
@@ -51,6 +54,7 @@ export function usePolicy(config: PolicyConfig) {
51
54
  get isRunning() { return isRunningRef.current; },
52
55
  start: () => { isRunningRef.current = true; },
53
56
  stop: () => { isRunningRef.current = false; },
54
- get lastObservation() { return lastActionRef.current; },
57
+ get lastObservation() { return lastObservationRef.current; },
58
+ get lastAction() { return lastActionRef.current; },
55
59
  };
56
60
  }
@@ -39,7 +39,7 @@ export function useSensor(name: Sensors): SensorHandle {
39
39
  }, [name, status, mjModelRef]);
40
40
 
41
41
  // Update every frame after physics step
42
- useAfterPhysicsStep((_model, data) => {
42
+ useAfterPhysicsStep(({ data }) => {
43
43
  if (sensorIdRef.current < 0) return;
44
44
  const adr = sensorAdrRef.current;
45
45
  const dim = sensorDimRef.current;