mujoco-react 10.4.0 → 10.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/onnx.ts ADDED
@@ -0,0 +1,126 @@
1
+ /**
2
+ * @license
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Optional ONNX Runtime Web helpers for browser policy demos.
6
+ *
7
+ * This entry point is exported as `mujoco-react/onnx` so the main package does
8
+ * not import or bundle `onnxruntime-web`.
9
+ */
10
+
11
+ import type * as ort from 'onnxruntime-web';
12
+ import type { PolicyActionChunk } from './types';
13
+
14
/**
 * Element type of a policy tensor as written in the manifest.
 *
 * The well-known names are spelled out for editor completion. The trailing
 * `(string & {})` keeps the alias open to any other dtype string without
 * collapsing the literal members into plain `string`.
 */
export type OnnxPolicyDtype = 'float32' | 'float64' | 'int32' | 'int64' | 'bool' | (string & {});
15
+
16
+ export interface OnnxPolicyTensorSpec {
17
+ name: string;
18
+ shape: number[];
19
+ dtype: OnnxPolicyDtype;
20
+ }
21
+
22
+ export interface OnnxPolicyImageSpec {
23
+ width: number;
24
+ height: number;
25
+ channels?: number;
26
+ layout?: 'CHW' | 'HWC' | string;
27
+ range?: readonly [number, number];
28
+ }
29
+
30
+ export interface OnnxPolicyManifest {
31
+ model: string;
32
+ variants?: Record<string, string>;
33
+ fps?: number;
34
+ joints?: string[];
35
+ cameras?: string[];
36
+ image?: OnnxPolicyImageSpec;
37
+ chunk_size?: number;
38
+ n_action_steps?: number;
39
+ inputs: OnnxPolicyTensorSpec[];
40
+ output: OnnxPolicyTensorSpec & {
41
+ units?: string;
42
+ };
43
+ [key: string]: unknown;
44
+ }
45
+
46
+ export interface LoadOnnxPolicyManifestResult<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest> {
47
+ manifest: TManifest;
48
+ manifestUrl: URL;
49
+ modelUrl: URL;
50
+ }
51
+
52
+ export interface CreateOnnxPolicySessionOptions<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest> {
53
+ manifestUrl: string | URL;
54
+ variant?: string;
55
+ runtime: typeof ort;
56
+ sessionOptions?: ort.InferenceSession.SessionOptions;
57
+ fetcher?: typeof fetch;
58
+ readManifest?: (response: Response) => Promise<TManifest>;
59
+ }
60
+
61
+ export interface OnnxPolicySession<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest>
62
+ extends LoadOnnxPolicyManifestResult<TManifest> {
63
+ session: ort.InferenceSession;
64
+ }
65
+
66
/**
 * Coerce a string or `URL` into a `URL` instance.
 *
 * An existing `URL` is handed back untouched; a string is parsed relative to
 * `base`, which defaults to the current page location when one exists.
 */
function asUrl(value: string | URL, base = globalThis.location?.href) {
  if (value instanceof URL) {
    return value;
  }
  return new URL(value, base);
}
69
+
70
/**
 * Pick the model path for the requested variant, falling back to the
 * manifest's default `model`.
 *
 * Only own, non-empty string entries of `manifest.variants` are honoured, so a
 * variant name that happens to match an inherited `Object.prototype` member
 * (for example `"constructor"` or `"toString"`) falls back to the default
 * model instead of returning a non-string value.
 *
 * @param manifest - Parsed policy manifest.
 * @param variant - Optional variant key to look up in `manifest.variants`.
 * @returns The variant's model path when present, otherwise `manifest.model`.
 */
function resolveModelPath(manifest: OnnxPolicyManifest, variant: string | undefined) {
  const variants = manifest.variants;
  if (variant && variants && Object.prototype.hasOwnProperty.call(variants, variant)) {
    const candidate: unknown = variants[variant];
    // Empty strings keep falling back to the default model, as before.
    if (typeof candidate === 'string' && candidate) return candidate;
  }
  return manifest.model;
}
74
+
75
/**
 * Fetch and parse an ONNX policy manifest, then resolve the model URL it
 * points at.
 *
 * @param manifestUrlInput - Manifest location. Strings are resolved against
 *   the page location when one is available.
 * @param options - `fetcher` replaces the global `fetch`, `readManifest`
 *   replaces the default `response.json()` parsing, and `variant` selects an
 *   entry from `manifest.variants`.
 * @returns The parsed manifest, the manifest URL, and the model URL resolved
 *   relative to the manifest URL.
 * @throws Error when the manifest response is not OK, when the parsed manifest
 *   is not an object, or when it yields no usable model path.
 */
export async function loadOnnxPolicyManifest<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest>(
  manifestUrlInput: string | URL,
  options: Pick<CreateOnnxPolicySessionOptions<TManifest>, 'variant' | 'fetcher' | 'readManifest'> = {}
): Promise<LoadOnnxPolicyManifestResult<TManifest>> {
  const fetcher = options.fetcher ?? fetch;
  const manifestUrl = asUrl(manifestUrlInput);
  const response = await fetcher(manifestUrl);
  if (!response.ok) {
    throw new Error(`Unable to load ONNX policy manifest from ${manifestUrl.href} (${response.status}).`);
  }
  const manifest = options.readManifest
    ? await options.readManifest(response)
    : ((await response.json()) as TManifest);
  // The manifest comes from the network, so its static type is only a claim.
  const parsed: unknown = manifest;
  if (parsed === null || typeof parsed !== 'object') {
    throw new Error(`ONNX policy manifest at ${manifestUrl.href} is not a JSON object.`);
  }
  // Without this check a missing `model` would resolve to ".../undefined" and
  // an empty one to the manifest URL itself, failing later with a confusing
  // model-load error.
  const modelPath: unknown = resolveModelPath(manifest, options.variant);
  if (typeof modelPath !== 'string' || modelPath.length === 0) {
    throw new Error(`ONNX policy manifest at ${manifestUrl.href} does not declare a model path.`);
  }
  const modelUrl = asUrl(modelPath, manifestUrl.href);
  return { manifest, manifestUrl, modelUrl };
}
92
+
93
/**
 * Load a policy manifest, download the model it resolves to, and open an
 * inference session on the downloaded bytes using the caller-supplied runtime.
 *
 * @param options - Manifest location, the ONNX runtime module to use, and
 *   optional variant, session options, fetch replacement and manifest reader.
 * @returns The manifest resolution result extended with the created session.
 * @throws Error when the model response is not OK. Manifest loading errors
 *   propagate unchanged.
 */
export async function createOnnxPolicySession<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest>(
  options: CreateOnnxPolicySessionOptions<TManifest>
): Promise<OnnxPolicySession<TManifest>> {
  const fetchModel = options.fetcher ?? fetch;
  const manifestResult = await loadOnnxPolicyManifest(options.manifestUrl, options);
  const { modelUrl } = manifestResult;

  const modelResponse = await fetchModel(modelUrl);
  if (!modelResponse.ok) {
    throw new Error(`Unable to load ONNX policy model from ${modelUrl.href} (${modelResponse.status}).`);
  }

  const modelBuffer = await modelResponse.arrayBuffer();
  const { InferenceSession } = options.runtime;
  const session = await InferenceSession.create(modelBuffer, options.sessionOptions);
  return { ...manifestResult, session };
}
109
+
110
/**
 * Split a flat ONNX output tensor into consecutive per-step action vectors.
 *
 * @param tensor - Output tensor. Its data is read in order and every element
 *   is converted with `Number`, so bigint elements become plain numbers.
 * @param actionSize - Number of values per action. Defaults to the tensor's
 *   last dimension, or 1 when the tensor has no dimensions.
 * @param maxActions - Optional cap on the number of actions returned. It is
 *   floored and clamped at 0.
 * @returns The actions in order. Trailing values that do not fill a whole
 *   action are dropped. The result is empty when `actionSize` is not positive.
 */
export function onnxTensorToPolicyActionChunk(
  tensor: ort.Tensor,
  actionSize = tensor.dims.at(-1) ?? 1,
  maxActions?: number
): PolicyActionChunk {
  // A zero action size with non-empty data would make the action count
  // Infinity and the loop below would never terminate. Negative and NaN sizes
  // already produced no actions, so every non-positive size is treated alike.
  if (!(actionSize > 0)) return [];

  const rawData = Array.from(tensor.data as ArrayLike<number>, (value) => Number(value));
  const actionCount = Math.floor(rawData.length / actionSize);
  const cappedActionCount = maxActions === undefined
    ? actionCount
    : Math.max(0, Math.min(actionCount, Math.floor(maxActions)));

  const actions: number[][] = [];
  for (let actionIndex = 0; actionIndex < cappedActionCount; actionIndex += 1) {
    const start = actionIndex * actionSize;
    actions.push(rawData.slice(start, start + actionSize));
  }
  return actions;
}
@@ -0,0 +1,150 @@
1
+ /**
2
+ * @license
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Helpers for turning browser camera captures into policy image tensors.
6
+ */
7
+
8
+ export type PolicyImageTensorLayout = 'CHW' | 'HWC';
9
+ export type PolicyImageTensorRange = readonly [number, number];
10
+ /**
11
+ * Row order of a raw pixel buffer. WebGL `readRenderTargetPixels` returns rows
12
+ * bottom-to-top (`'bottom-left'`); `ImageData` is top-to-bottom (`'top-left'`).
13
+ */
14
+ export type PolicyImageTensorSourceOrigin = 'top-left' | 'bottom-left';
15
+
16
+ export interface PolicyImageTensorOptions {
17
+ width: number;
18
+ height: number;
19
+ channels?: 3 | 4;
20
+ layout?: PolicyImageTensorLayout;
21
+ range?: PolicyImageTensorRange;
22
+ }
23
+
24
+ export interface PolicyImageTensorPixelOptions extends PolicyImageTensorOptions {
25
+ /** Row order of the source buffer. Defaults to `'top-left'`. */
26
+ sourceOrigin?: PolicyImageTensorSourceOrigin;
27
+ /** Mirror horizontally while reading. */
28
+ flipX?: boolean;
29
+ }
30
+
31
+ export interface PolicyImageTensorResult {
32
+ data: Float32Array;
33
+ shape: [number, number, number];
34
+ width: number;
35
+ height: number;
36
+ channels: 3 | 4;
37
+ layout: PolicyImageTensorLayout;
38
+ range: PolicyImageTensorRange;
39
+ }
40
+
41
/**
 * Fill in the optional tensor settings with their defaults: 3 channels, `CHW`
 * layout and a `[0, 1]` value range.
 *
 * An option that is present but `undefined` is treated as not provided.
 * Spreading `options` over the defaults would instead copy the `undefined`
 * and erase the default, which produces a zero-length tensor downstream.
 * Any other properties on `options` are carried through unchanged.
 */
function resolveTensorOptions(options: PolicyImageTensorOptions): Required<PolicyImageTensorOptions> {
  return {
    ...options,
    channels: options.channels ?? 3,
    layout: options.layout ?? 'CHW',
    range: options.range ?? [0, 1],
  };
}
49
+
50
/**
 * Map one 8-bit channel value (0–255) linearly onto `range`.
 * A `[0, 255]` range returns the value unchanged.
 */
function normalizeChannel(value: number, range: PolicyImageTensorRange) {
  const lower = range[0];
  const upper = range[1];
  const keepsByteScale = lower === 0 && upper === 255;
  return keepsByteScale ? value : lower + (value / 255) * (upper - lower);
}
55
+
56
/**
 * Build a policy image tensor directly from a raw RGBA buffer (4 bytes per
 * pixel), with no canvas involved.
 *
 * Rows are read bottom-to-top when `sourceOrigin` is `'bottom-left'` (the
 * order of the `Uint8Array` filled by `readRenderTargetPixels`) and
 * top-to-bottom otherwise. Columns are mirrored when `flipX` is set. Only the
 * first `channels` bytes of every pixel are copied, each normalized to `range`.
 *
 * @throws Error when the buffer holds fewer than `width * height * 4` bytes.
 */
export function pixelsToPolicyImageTensor(
  pixels: Uint8Array | Uint8ClampedArray,
  options: PolicyImageTensorPixelOptions
): PolicyImageTensorResult {
  const { width, height, channels, layout, range } = resolveTensorOptions(options);

  const requiredBytes = width * height * 4;
  if (pixels.length < requiredBytes) {
    throw new Error(
      `Pixel buffer of length ${pixels.length} is too small for ${width}x${height} RGBA data (${requiredBytes} bytes).`
    );
  }

  const readRowsBottomUp = options.sourceOrigin === 'bottom-left';
  const mirrorColumns = options.flipX ?? false;
  const planeSize = width * height;
  const data = new Float32Array(planeSize * channels);
  const channelMajor = layout === 'CHW';

  for (let row = 0; row < height; row += 1) {
    const sourceRow = readRowsBottomUp ? height - row - 1 : row;
    for (let column = 0; column < width; column += 1) {
      const sourceColumn = mirrorColumns ? width - column - 1 : column;
      // The source is always RGBA, so its stride is 4 whatever `channels` is.
      const sourceOffset = (sourceRow * width + sourceColumn) * 4;
      const targetPixel = row * width + column;
      for (let channel = 0; channel < channels; channel += 1) {
        const normalized = normalizeChannel(pixels[sourceOffset + channel], range);
        const targetOffset = channelMajor
          ? channel * planeSize + targetPixel
          : targetPixel * channels + channel;
        data[targetOffset] = normalized;
      }
    }
  }

  const shape: [number, number, number] = channelMajor
    ? [channels, height, width]
    : [height, width, channels];
  return { data, shape, width, height, channels, layout, range };
}
107
+
108
/**
 * Convert an `ImageData` (top-to-bottom rows) into a policy image tensor.
 *
 * @throws Error when the `ImageData` dimensions differ from the requested
 *   tensor width and height.
 */
export function imageDataToPolicyImageTensor(
  imageData: ImageData,
  options: PolicyImageTensorOptions
): PolicyImageTensorResult {
  const tensorOptions = resolveTensorOptions(options);
  const sizeMatches =
    imageData.width === tensorOptions.width && imageData.height === tensorOptions.height;
  if (!sizeMatches) {
    throw new Error(
      `ImageData size ${imageData.width}x${imageData.height} does not match tensor size ${tensorOptions.width}x${tensorOptions.height}.`
    );
  }
  const pixelOptions = { ...tensorOptions, sourceOrigin: 'top-left' as const };
  return pixelsToPolicyImageTensor(imageData.data, pixelOptions);
}
123
+
124
/**
 * Create an image element for `dataUrl` and resolve with it once the browser
 * has finished decoding. Rejects if `decode()` rejects.
 */
async function decodeImageSource(dataUrl: string) {
  const element = new Image();
  element.decoding = 'async';
  element.src = dataUrl;
  return element.decode().then(() => element);
}
131
+
132
/**
 * Decode an image data URL, draw it scaled to the requested tensor size on a
 * scratch canvas, and convert the resulting pixels into a policy image tensor.
 *
 * @throws Error when a 2D canvas context cannot be created.
 */
export async function dataUrlToPolicyImageTensor(
  dataUrl: string,
  options: PolicyImageTensorOptions
): Promise<PolicyImageTensorResult> {
  const tensorOptions = resolveTensorOptions(options);
  const { width, height } = tensorOptions;
  const decoded = await decodeImageSource(dataUrl);

  const scratch = document.createElement('canvas');
  scratch.width = width;
  scratch.height = height;

  const context2d = scratch.getContext('2d', { willReadFrequently: true });
  if (!context2d) {
    throw new Error('Unable to create a 2D canvas context for policy image tensor conversion.');
  }

  context2d.drawImage(decoded, 0, 0, width, height);
  const imageData = context2d.getImageData(0, 0, width, height);
  return imageDataToPolicyImageTensor(imageData, tensorOptions);
}
@@ -13,6 +13,29 @@ import type {
13
13
  CameraFrameCaptureSource,
14
14
  CameraFrameCaptureVector3,
15
15
  } from '../types';
16
+ import {
17
+ pixelsToPolicyImageTensor,
18
+ type PolicyImageTensorOptions,
19
+ type PolicyImageTensorResult,
20
+ } from '../policyImageTensors';
21
+
22
+ /** Options for capturing a camera frame straight into a policy image tensor. */
23
+ export type CameraFrameCaptureTensorOptions = CameraFrameCaptureOptions &
24
+ Pick<PolicyImageTensorOptions, 'channels' | 'layout' | 'range'>;
25
+
26
+ export interface CameraFramePixelsResult {
27
+ /** Raw RGBA pixels, bottom-left origin (reused buffer — consume before the next capture). */
28
+ pixels: Uint8Array;
29
+ camera: THREE.Camera;
30
+ width: number;
31
+ height: number;
32
+ source: CameraFrameCaptureSource;
33
+ }
34
+
35
+ export interface CameraFrameTensorResult extends PolicyImageTensorResult {
36
+ camera: THREE.Camera;
37
+ source: CameraFrameCaptureSource;
38
+ }
16
39
 
17
40
  export interface CameraFrameCaptureSession {
18
41
  readonly width: number;
@@ -36,6 +59,14 @@ export interface CameraFrameCaptureSession {
36
59
  options?: CameraFrameCaptureOptions
37
60
  ): Promise<CameraFrameCaptureResult>;
38
61
  captureBlob(options?: CameraFrameCaptureOptions): Promise<CameraFrameCaptureBlobResult>;
62
+ /**
63
+ * Render and read raw RGBA pixels without any canvas/PNG round-trip. The
64
+ * returned buffer is reused between calls — copy or convert it before the
65
+ * next capture.
66
+ */
67
+ capturePixels(options?: CameraFrameCaptureOptions): CameraFramePixelsResult;
68
+ /** Render straight into a normalized policy image tensor (no canvas/PNG encode). */
69
+ captureTensor(options?: CameraFrameCaptureTensorOptions): CameraFrameTensorResult;
39
70
  dispose(): void;
40
71
  }
41
72
 
@@ -250,7 +281,7 @@ function applyProjectionMatrix(
250
281
  camera.projectionMatrixInverse.copy(camera.projectionMatrix).invert();
251
282
  }
252
283
 
253
- function createCaptureCamera(
284
+ export function createCaptureCamera(
254
285
  options: CameraFrameCaptureOptions,
255
286
  fallbackCamera: THREE.Camera,
256
287
  width: number,
@@ -290,7 +321,7 @@ function getCaptureDimensions(
290
321
  return { width, height };
291
322
  }
292
323
 
293
- function prepareCaptureCamera(
324
+ export function prepareCaptureCamera(
294
325
  camera: THREE.Camera,
295
326
  options: CameraFrameCaptureOptions,
296
327
  fallbackCamera: THREE.Camera,
@@ -646,7 +677,10 @@ export function createCameraFrameCaptureSession(
646
677
  return captureOptions;
647
678
  }
648
679
 
649
- function renderPreparedCapture(captureOptions: CameraFrameCaptureOptions) {
680
+ function renderCaptureToTarget(
681
+ captureOptions: CameraFrameCaptureOptions,
682
+ readback: () => void
683
+ ) {
650
684
  const previousState = saveRendererState(sessionRenderer);
651
685
  const previousSceneState = applyCaptureVisualOverrides(
652
686
  sessionRenderer,
@@ -676,6 +710,16 @@ export function createCameraFrameCaptureSession(
676
710
  }
677
711
  sessionRenderer.clear();
678
712
  sessionRenderer.render(scene, camera);
713
+ readback();
714
+ } finally {
715
+ restoreObjectVisibility(hidden);
716
+ if (previousSceneState) restoreSceneVisualState(scene, previousSceneState);
717
+ restoreRendererState(sessionRenderer, previousState);
718
+ }
719
+ }
720
+
721
+ function renderPreparedCapture(captureOptions: CameraFrameCaptureOptions) {
722
+ renderCaptureToTarget(captureOptions, () => {
679
723
  readRenderTargetToCanvas(
680
724
  sessionRenderer,
681
725
  target,
@@ -688,24 +732,50 @@ export function createCameraFrameCaptureSession(
688
732
  sessionRenderer.outputColorSpace,
689
733
  captureOptions.flipX ?? false
690
734
  );
691
- return {
692
- canvas,
693
- camera,
694
- width,
695
- height,
696
- source: getCameraFrameCaptureSource(captureOptions),
697
- };
698
- } finally {
699
- restoreObjectVisibility(hidden);
700
- if (previousSceneState) restoreSceneVisualState(scene, previousSceneState);
701
- restoreRendererState(sessionRenderer, previousState);
702
- }
735
+ });
736
+ return {
737
+ canvas,
738
+ camera,
739
+ width,
740
+ height,
741
+ source: getCameraFrameCaptureSource(captureOptions),
742
+ };
703
743
  }
704
744
 
705
745
  function capture(nextOptions: CameraFrameCaptureOptions = {}) {
706
746
  return renderPreparedCapture(resolveCaptureOptions(nextOptions));
707
747
  }
708
748
 
749
/**
 * Render the capture and read the render target's RGBA pixels into the
 * session's shared `pixels` buffer. The same buffer is returned on every
 * call, so callers must consume it before the next capture.
 */
function capturePixels(nextOptions: CameraFrameCaptureOptions = {}): CameraFramePixelsResult {
  const resolvedOptions = resolveCaptureOptions(nextOptions);
  const readTargetIntoBuffer = () => {
    sessionRenderer.readRenderTargetPixels(target, 0, 0, width, height, pixels);
  };
  renderCaptureToTarget(resolvedOptions, readTargetIntoBuffer);

  const source = getCameraFrameCaptureSource(resolvedOptions);
  return { pixels, camera, width, height, source };
}
762
+
763
/**
 * Render the capture and convert the raw pixels straight into a normalized
 * policy image tensor.
 *
 * Only the tensor options the caller actually set are forwarded, so an omitted
 * `channels`, `layout` or `range` never reaches the converter as an explicit
 * `undefined` and the converter's own defaults apply.
 */
function captureTensor(
  nextOptions: CameraFrameCaptureTensorOptions = {}
): CameraFrameTensorResult {
  const result = capturePixels(nextOptions);
  const { channels, layout, range } = nextOptions;
  const tensor = pixelsToPolicyImageTensor(result.pixels, {
    width: result.width,
    height: result.height,
    ...(channels === undefined ? {} : { channels }),
    ...(layout === undefined ? {} : { layout }),
    ...(range === undefined ? {} : { range }),
    // The render-target readback is bottom-left origin.
    sourceOrigin: 'bottom-left',
    // NOTE(review): flipX is taken from the per-call options only; confirm
    // whether a session-level flipX default should also apply here.
    flipX: nextOptions.flipX,
  });
  return { ...tensor, camera: result.camera, source: result.source };
}
778
+
709
779
  async function captureAsync(nextOptions: CameraFrameCaptureOptions = {}) {
710
780
  const captureOptions = resolveCaptureOptions(nextOptions);
711
781
  runCapturePreRenderHooks(scene);
@@ -779,6 +849,8 @@ export function createCameraFrameCaptureSession(
779
849
  height,
780
850
  capture,
781
851
  captureAsync,
852
+ capturePixels,
853
+ captureTensor,
782
854
  captureDataUrl(nextOptions = {}) {
783
855
  const type = nextOptions.type ?? options.type ?? 'image/png';
784
856
  const result = capture(nextOptions);
@@ -889,3 +961,28 @@ export async function captureCameraFrameBlob(
889
961
  session.dispose();
890
962
  }
891
963
  }
964
+
965
/**
 * Capture a single camera frame directly as a policy image tensor, without a
 * canvas/PNG round-trip.
 *
 * A temporary capture session is created for the call and disposed afterwards,
 * even when the capture throws. For repeated captures (live inference,
 * recording), create one session with {@link createCameraFrameCaptureSession}
 * and call `session.captureTensor()` so the render target and buffers are
 * reused.
 */
export function captureCameraFrameTensor(
  renderer: THREE.WebGLRenderer,
  scene: THREE.Scene,
  fallbackCamera: THREE.Camera,
  options: CameraFrameCaptureTensorOptions = {}
): CameraFrameTensorResult {
  const oneShotSession = createCameraFrameCaptureSession(renderer, scene, fallbackCamera, options);
  let tensorResult: CameraFrameTensorResult;
  try {
    tensorResult = oneShotSession.captureTensor(options);
  } finally {
    oneShotSession.dispose();
  }
  return tensorResult;
}
+ }
package/src/types.ts CHANGED
@@ -7,6 +7,11 @@ import type React from 'react';
7
7
  import type { ReactNode } from 'react';
8
8
  import type { CanvasProps, ThreeElements } from '@react-three/fiber';
9
9
  import * as THREE from 'three';
10
+ import type {
11
+ CameraFrameCaptureSession,
12
+ CameraFrameCaptureTensorOptions,
13
+ CameraFrameTensorResult,
14
+ } from './rendering/cameraFrameCapture';
10
15
 
11
16
  // ---- Register (type-safe named resources) ----
12
17
 
@@ -459,6 +464,10 @@ export interface SceneObject {
459
464
  solref?: string;
460
465
  solimp?: string;
461
466
  condim?: number;
467
+ /** MuJoCo geom contact type bitmask. Defaults to 1 for generated objects. */
468
+ contype?: number;
469
+ /** MuJoCo geom contact affinity bitmask. Defaults to 1 for generated objects. */
470
+ conaffinity?: number;
462
471
  /** MuJoCo geom group. Group 3 is conventionally used for collision-only helper geoms. */
463
472
  group?: number;
464
473
  }
@@ -527,6 +536,12 @@ export interface IkConfig {
527
536
  * starting at index 0. Prefer inferred IK or `joints`/`actuators`.
528
537
  */
529
538
  numJoints?: number;
539
+ /**
540
+ * Optional solve-space joint limits in the same order as the resolved joints.
541
+ * Use this when MJCF limits are intentionally broad or when a setup/calibration
542
+ * tool should stay within a narrower envelope.
543
+ */
544
+ jointLimits?: ReadonlyArray<readonly [number, number] | null | undefined>;
530
545
  /** Custom IK solver. When omitted, uses built-in Damped Least-Squares solver. */
531
546
  ikSolveFn?: IKSolveFn;
532
547
  /** DLS damping. Default: 0.01. */
@@ -1458,6 +1473,19 @@ export interface MujocoSimAPI {
1458
1473
  captureFrameBlob(options?: MujocoFrameCaptureOptions): Promise<FrameCaptureBlobResult>;
1459
1474
  captureCameraFrame(options?: CameraFrameCaptureOptions): Promise<CameraFrameCaptureResult>;
1460
1475
  captureCameraFrameBlob(options?: CameraFrameCaptureOptions): Promise<CameraFrameCaptureBlobResult>;
1476
+ /** Capture a camera frame straight into a policy image tensor (no canvas/PNG encode). */
1477
+ captureCameraFrameTensor(options?: CameraFrameCaptureTensorOptions): CameraFrameTensorResult;
1478
+ /**
1479
+ * Create a reusable offscreen capture session bound to this scene. Reuse it
1480
+ * for live inference/recording so the render target and buffers persist
1481
+ * across frames; call `session.captureTensor()` / `capturePixels()` each step.
1482
+ */
1483
+ createCameraFrameCaptureSession(options?: CameraFrameCaptureOptions): CameraFrameCaptureSession;
1484
+ /**
1485
+ * Resolve a named MuJoCo camera/site/body into concrete capture options with
1486
+ * the current world pose. Useful for re-aiming a persistent session each step.
1487
+ */
1488
+ resolveCameraCaptureOptions(options?: CameraFrameCaptureOptions): CameraFrameCaptureOptions;
1461
1489
  recordCameraSequence(options: CameraFrameSequenceOptions): Promise<CameraFrameSequenceResult>;
1462
1490
  project2DTo3D(
1463
1491
  x: number,