mujoco-react 10.3.0 → 10.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +75 -137
- package/dist/{chunk-6AZEFI6A.js → chunk-KHZ5U36J.js} +157 -16
- package/dist/chunk-KHZ5U36J.js.map +1 -0
- package/dist/index.d.ts +180 -49
- package/dist/index.js +627 -19
- package/dist/index.js.map +1 -1
- package/dist/onnx.d.ts +65 -0
- package/dist/onnx.js +58 -0
- package/dist/onnx.js.map +1 -0
- package/dist/spark.d.ts +1 -1
- package/dist/spark.js +1 -1
- package/dist/{types-BOhNDICK.d.ts → types-CViUme8D.d.ts} +157 -1
- package/package.json +14 -3
- package/src/components/CameraView.tsx +245 -0
- package/src/components/Debug.tsx +174 -3
- package/src/core/GenericIK.ts +16 -4
- package/src/core/MujocoSimProvider.tsx +37 -1
- package/src/core/SceneLoader.ts +3 -2
- package/src/hooks/useCameraStream.ts +115 -0
- package/src/hooks/useControlGroup.ts +0 -0
- package/src/hooks/useIkController.ts +3 -0
- package/src/hooks/usePolicyCameraTensors.ts +215 -0
- package/src/index.ts +45 -0
- package/src/onnx.ts +126 -0
- package/src/policyImageTensors.ts +150 -0
- package/src/rendering/cameraFrameCapture.ts +112 -15
- package/src/types.ts +45 -0
- package/dist/chunk-6AZEFI6A.js.map +0 -1
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Capture policy observation tensors directly from Three/MuJoCo cameras,
|
|
6
|
+
* skipping the data-URL/PNG round-trip. Sessions are created once per camera
|
|
7
|
+
* and reused every step, so live inference and dataset recording read straight
|
|
8
|
+
* from the GPU into Float32 tensors.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
|
12
|
+
import { useMujoco } from '../core/MujocoSimProvider';
|
|
13
|
+
import { createPolicyCameraFrameCapturePlanFromApi } from '../policyCameraFrames';
|
|
14
|
+
import type { MountedPolicyCameraFrameCaptureOptions } from './usePolicyCameraFrames';
|
|
15
|
+
import type {
|
|
16
|
+
CameraFrameCaptureSession,
|
|
17
|
+
CameraFrameCaptureTensorOptions,
|
|
18
|
+
CameraFrameTensorResult,
|
|
19
|
+
} from '../rendering/cameraFrameCapture';
|
|
20
|
+
import type { FrameCaptureStatus } from '../types';
|
|
21
|
+
|
|
22
|
+
export interface PolicyCameraTensorStream extends CameraFrameCaptureTensorOptions {
|
|
23
|
+
/** Payload key this stream's tensor is stored under. */
|
|
24
|
+
key: string;
|
|
25
|
+
/** Additional payload keys that should reference the same tensor. */
|
|
26
|
+
aliases?: readonly string[];
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
export interface PolicyCameraTensorsOptions {
|
|
30
|
+
streams: PolicyCameraTensorStream[];
|
|
31
|
+
/** Also expose tensors under `observation.images.<key>` aliases. Defaults to `false`. */
|
|
32
|
+
includeObservationImageAliases?: boolean;
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
export interface PolicyCameraTensorsResult {
|
|
36
|
+
tensors: Record<string, CameraFrameTensorResult>;
|
|
37
|
+
sourceSummary: string;
|
|
38
|
+
capturedAt: number;
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
export interface PolicyCameraTensorsAPI {
|
|
42
|
+
status: FrameCaptureStatus;
|
|
43
|
+
error: Error | null;
|
|
44
|
+
isCapturing: boolean;
|
|
45
|
+
/** Synchronously render and convert every stream into a policy image tensor. */
|
|
46
|
+
capture: () => PolicyCameraTensorsResult;
|
|
47
|
+
reset: () => void;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
export type MountedPolicyCameraTensorOptions = MountedPolicyCameraFrameCaptureOptions & {
|
|
51
|
+
tensor?: Pick<
|
|
52
|
+
CameraFrameCaptureTensorOptions,
|
|
53
|
+
'width' | 'height' | 'channels' | 'layout' | 'range'
|
|
54
|
+
>;
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
type SessionEntry = {
|
|
58
|
+
session: CameraFrameCaptureSession;
|
|
59
|
+
signature: string;
|
|
60
|
+
};
|
|
61
|
+
|
|
62
|
+
/**
 * Build the cache signature for a stream's capture session.
 *
 * The signature is the JSON encoding of the stream's `width`, `height`,
 * `channels`, `renderIsolation` (treated as `false` when unset), `cameraName`,
 * `siteName` and `bodyName`, in that order. Fields that are `undefined` are
 * omitted by `JSON.stringify`.
 *
 * @param stream Stream whose session-defining fields are serialized.
 * @returns A string that changes whenever one of those fields changes.
 */
function sessionSignature(stream: PolicyCameraTensorStream): string {
  const { width, height, channels, cameraName, siteName, bodyName } = stream;
  const renderIsolation = stream.renderIsolation ?? false;
  return JSON.stringify({
    width,
    height,
    channels,
    renderIsolation,
    cameraName,
    siteName,
    bodyName,
  });
}
|
|
73
|
+
|
|
74
|
+
/**
 * Store one captured tensor under every payload key a stream asks for.
 *
 * The tensor is written to `tensors` under `stream.key` and each entry of
 * `stream.aliases`. When `includeObservationImageAliases` is true it is also
 * written under `observation.images.<name>` for each of those names. All keys
 * reference the same tensor object.
 *
 * @param tensors Payload record that is mutated in place.
 * @param stream Stream supplying the key and optional aliases.
 * @param tensor Captured tensor to store.
 * @param includeObservationImageAliases Whether to add the prefixed aliases.
 */
function addTensorAliases(
  tensors: Record<string, CameraFrameTensorResult>,
  stream: PolicyCameraTensorStream,
  tensor: CameraFrameTensorResult,
  includeObservationImageAliases: boolean
) {
  const baseNames = [stream.key, ...(stream.aliases ?? [])];
  // A Set drops duplicate names while preserving first-insertion order.
  const payloadKeys = new Set<string>(baseNames);
  if (includeObservationImageAliases) {
    baseNames.forEach((baseName) => {
      payloadKeys.add(`observation.images.${baseName}`);
    });
  }
  payloadKeys.forEach((payloadKey) => {
    tensors[payloadKey] = tensor;
  });
}
|
|
88
|
+
|
|
89
|
+
/**
 * React hook that captures policy image tensors for a list of camera streams.
 *
 * One capture session is cached per stream key and reused across `capture()`
 * calls. A session is recreated when the stream's signature changes, and
 * disposed when its key is no longer requested or the component unmounts.
 *
 * @param options Streams to capture, and whether to add
 *   `observation.images.<key>` aliases to the payload.
 * @returns The capture status, the last capture error, a synchronous
 *   `capture()` and a `reset()` that returns the status to `'idle'` and clears
 *   the error. `capture()` throws when the MuJoCo API is not available or when
 *   any stream fails to capture.
 */
export function usePolicyCameraTensors(
  options: PolicyCameraTensorsOptions
): PolicyCameraTensorsAPI {
  const mujoco = useMujoco();
  const [status, setStatus] = useState<FrameCaptureStatus>('idle');
  const [error, setError] = useState<Error | null>(null);
  // Stream key -> cached capture session; lives for the component's lifetime.
  const sessionsRef = useRef<Map<string, SessionEntry>>(new Map());

  const disposeSessions = useCallback(() => {
    const cache = sessionsRef.current;
    for (const cached of cache.values()) cached.session.dispose();
    cache.clear();
  }, []);

  // No setup work: the effect exists only so sessions are released on unmount.
  useEffect(() => {
    return disposeSessions;
  }, [disposeSessions]);

  const reset = useCallback(() => {
    setStatus('idle');
    setError(null);
  }, []);

  const capture = useCallback((): PolicyCameraTensorsResult => {
    const api = mujoco.api;
    if (!api) {
      throw new Error('MuJoCo scene is not ready for policy camera tensor capture.');
    }

    setStatus('capturing');
    setError(null);

    try {
      const cache = sessionsRef.current;
      const requestedKeys = new Set<string>();
      const tensors: Record<string, CameraFrameTensorResult> = {};
      const summaryParts: string[] = [];

      for (const stream of options.streams) {
        requestedKeys.add(stream.key);

        // Tensor formatting fields always come from the stream itself.
        const captureOptions: CameraFrameCaptureTensorOptions = {
          ...api.resolveCameraCaptureOptions(stream),
          channels: stream.channels,
          layout: stream.layout,
          range: stream.range,
        };

        const signature = sessionSignature(stream);
        const cached = cache.get(stream.key);
        let entry: SessionEntry;
        if (cached !== undefined && cached.signature === signature) {
          entry = cached;
        } else {
          // Missing or stale: release the old session before replacing it.
          cached?.session.dispose();
          entry = {
            session: api.createCameraFrameCaptureSession(captureOptions),
            signature,
          };
          cache.set(stream.key, entry);
        }

        const tensor = entry.session.captureTensor(captureOptions);
        addTensorAliases(
          tensors,
          stream,
          tensor,
          options.includeObservationImageAliases ?? false
        );
        summaryParts.push(`${stream.key}:${tensor.source.kind}`);
      }

      // Drop sessions for streams that are no longer requested.
      for (const [cachedKey, stale] of [...cache.entries()]) {
        if (requestedKeys.has(cachedKey)) continue;
        stale.session.dispose();
        cache.delete(cachedKey);
      }

      setStatus('captured');
      return {
        tensors,
        sourceSummary: summaryParts.join(' + ') || 'not used by policy',
        capturedAt: Date.now(),
      };
    } catch (thrown) {
      let captureError: Error;
      if (thrown instanceof Error) {
        captureError = thrown;
      } else {
        captureError = new Error('Unable to capture policy camera tensors.');
      }
      setError(captureError);
      setStatus('error');
      throw captureError;
    }
  }, [mujoco.api, options.includeObservationImageAliases, options.streams]);

  const isCapturing = status === 'capturing';
  return { status, error, isCapturing, capture, reset };
}
|
|
187
|
+
|
|
188
|
+
/**
 * Variant of `usePolicyCameraTensors` whose stream list comes from a policy
 * camera capture plan instead of being passed in directly.
 *
 * The plan is produced by `createPolicyCameraFrameCapturePlanFromApi(api,
 * options)`. Each planned stream is overlaid with `options.tensor`, keeping
 * the plan's `key` and `aliases`. While the MuJoCo API is unavailable the
 * stream list is empty.
 *
 * @param options Mounted-camera capture options plus optional tensor overrides.
 * @returns The same API object as `usePolicyCameraTensors`.
 */
export function usePolicyCameraTensorsFromMountedStreams(
  options: MountedPolicyCameraTensorOptions
): PolicyCameraTensorsAPI {
  const mujoco = useMujoco();
  const api = mujoco.api;
  const tensorOverrides = options.tensor;

  const mountedOptions = useMemo<PolicyCameraTensorsOptions>(() => {
    if (!api) {
      const includeObservationImageAliases = options.includeObservationImageAliases ?? false;
      return { streams: [], includeObservationImageAliases };
    }

    const plan = createPolicyCameraFrameCapturePlanFromApi(api, options);
    return {
      streams: plan.streams.map((planned) => {
        const { key, aliases, ...cameraOptions } = planned;
        // Tensor overrides win over the planned camera options.
        return { ...cameraOptions, ...tensorOverrides, key, aliases };
      }),
      includeObservationImageAliases: plan.includeObservationImageAliases ?? false,
    };
  }, [api, options, tensorOverrides]);

  return usePolicyCameraTensors(mountedOptions);
}
|
package/src/index.ts
CHANGED
|
@@ -34,6 +34,8 @@ export type { ControllerOptions, ControllerComponent } from './core/createContro
|
|
|
34
34
|
|
|
35
35
|
// IK controller hook
|
|
36
36
|
export { useIkController } from './hooks/useIkController';
|
|
37
|
+
export { GenericIK } from './core/GenericIK';
|
|
38
|
+
export type { GenericIKOptions } from './core/GenericIK';
|
|
37
39
|
|
|
38
40
|
// Components
|
|
39
41
|
export { Body } from './components/Body';
|
|
@@ -74,6 +76,10 @@ export type {
|
|
|
74
76
|
SplatCollisionProxyPreviewVector3,
|
|
75
77
|
UseSplatCollisionProxyGeomsOptions,
|
|
76
78
|
} from './components/SplatCollisionProxyPreview';
|
|
79
|
+
export { CameraView, useCameraViewport } from './components/CameraView';
|
|
80
|
+
export type { CameraViewProps, CameraViewportOptions } from './components/CameraView';
|
|
81
|
+
export { useCameraStream } from './hooks/useCameraStream';
|
|
82
|
+
export type { CameraStreamOptions } from './hooks/useCameraStream';
|
|
77
83
|
export { Debug } from './components/Debug';
|
|
78
84
|
export { TendonRenderer } from './components/TendonRenderer';
|
|
79
85
|
export { FlexRenderer } from './components/FlexRenderer';
|
|
@@ -91,6 +97,13 @@ export { useBodyState } from './hooks/useBodyState';
|
|
|
91
97
|
export { useBodyPose, useGeomPose, useSitePose } from './hooks/usePose';
|
|
92
98
|
export type { PoseReadout, PoseResourceKind } from './hooks/usePose';
|
|
93
99
|
export { useCtrl } from './hooks/useCtrl';
|
|
100
|
+
export { controlGroup, useControlGroup } from './hooks/useControlGroup';
|
|
101
|
+
export type {
|
|
102
|
+
ControlGroup,
|
|
103
|
+
ControlGroupHandle,
|
|
104
|
+
ControlGroupSetOptions,
|
|
105
|
+
UseControlGroupOptions,
|
|
106
|
+
} from './hooks/useControlGroup';
|
|
94
107
|
export { useControlWriter } from './hooks/useControlWriter';
|
|
95
108
|
export type {
|
|
96
109
|
ControlWriterConflict,
|
|
@@ -130,6 +143,17 @@ export type {
|
|
|
130
143
|
MountedPolicyCameraFrameCaptureAPI,
|
|
131
144
|
MountedPolicyCameraFrameCaptureOptions,
|
|
132
145
|
} from './hooks/usePolicyCameraFrames';
|
|
146
|
+
export {
|
|
147
|
+
usePolicyCameraTensors,
|
|
148
|
+
usePolicyCameraTensorsFromMountedStreams,
|
|
149
|
+
} from './hooks/usePolicyCameraTensors';
|
|
150
|
+
export type {
|
|
151
|
+
MountedPolicyCameraTensorOptions,
|
|
152
|
+
PolicyCameraTensorsAPI,
|
|
153
|
+
PolicyCameraTensorsOptions,
|
|
154
|
+
PolicyCameraTensorsResult,
|
|
155
|
+
PolicyCameraTensorStream,
|
|
156
|
+
} from './hooks/usePolicyCameraTensors';
|
|
133
157
|
export { useCameraSequenceRecorder } from './hooks/useCameraSequenceRecorder';
|
|
134
158
|
export { useMountedCameraSequenceRecorder } from './hooks/useMountedCameraSequenceRecorder';
|
|
135
159
|
export type {
|
|
@@ -145,9 +169,16 @@ export {
|
|
|
145
169
|
CAPTURE_EXCLUDE_KEY,
|
|
146
170
|
captureCameraFrame,
|
|
147
171
|
captureCameraFrameBlob,
|
|
172
|
+
captureCameraFrameTensor,
|
|
148
173
|
createCameraFrameCaptureSession,
|
|
149
174
|
renderCameraFrameToCanvas,
|
|
150
175
|
} from './rendering/cameraFrameCapture';
|
|
176
|
+
export type {
|
|
177
|
+
CameraFrameCaptureSession,
|
|
178
|
+
CameraFrameCaptureTensorOptions,
|
|
179
|
+
CameraFramePixelsResult,
|
|
180
|
+
CameraFrameTensorResult,
|
|
181
|
+
} from './rendering/cameraFrameCapture';
|
|
151
182
|
export {
|
|
152
183
|
imagePointToNdc,
|
|
153
184
|
projectImagePointTo3D,
|
|
@@ -172,6 +203,11 @@ export {
|
|
|
172
203
|
readNamedObservation,
|
|
173
204
|
sitePositionField,
|
|
174
205
|
} from './policyObservation';
|
|
206
|
+
export {
|
|
207
|
+
dataUrlToPolicyImageTensor,
|
|
208
|
+
imageDataToPolicyImageTensor,
|
|
209
|
+
pixelsToPolicyImageTensor,
|
|
210
|
+
} from './policyImageTensors';
|
|
175
211
|
export type {
|
|
176
212
|
CreatePolicyCameraFrameCapturePlanOptions,
|
|
177
213
|
PolicyCameraFrameCapturePlan,
|
|
@@ -191,6 +227,14 @@ export type {
|
|
|
191
227
|
NamedObservationOptions,
|
|
192
228
|
NamedObservationResult,
|
|
193
229
|
} from './policyObservation';
|
|
230
|
+
export type {
|
|
231
|
+
PolicyImageTensorLayout,
|
|
232
|
+
PolicyImageTensorOptions,
|
|
233
|
+
PolicyImageTensorPixelOptions,
|
|
234
|
+
PolicyImageTensorRange,
|
|
235
|
+
PolicyImageTensorResult,
|
|
236
|
+
PolicyImageTensorSourceOrigin,
|
|
237
|
+
} from './policyImageTensors';
|
|
194
238
|
export {
|
|
195
239
|
createMountedCameraFrameSequenceManifest,
|
|
196
240
|
createMountedCameraFrameSequenceReadiness,
|
|
@@ -320,6 +364,7 @@ export type {
|
|
|
320
364
|
IkGizmoProps,
|
|
321
365
|
IkGizmoDragInput,
|
|
322
366
|
DragInteractionProps,
|
|
367
|
+
DebugVirtualCamera,
|
|
323
368
|
DebugProps,
|
|
324
369
|
SceneLightsProps,
|
|
325
370
|
ScenarioLightingPreset,
|
package/src/onnx.ts
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Optional ONNX Runtime Web helpers for browser policy demos.
|
|
6
|
+
*
|
|
7
|
+
* This entry point is exported as `mujoco-react/onnx` so the main package does
|
|
8
|
+
* not import or bundle `onnxruntime-web`.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import type * as ort from 'onnxruntime-web';
|
|
12
|
+
import type { PolicyActionChunk } from './types';
|
|
13
|
+
|
|
14
|
+
export type OnnxPolicyDtype = 'float32' | 'float64' | 'int32' | 'int64' | 'bool' | string;
|
|
15
|
+
|
|
16
|
+
export interface OnnxPolicyTensorSpec {
|
|
17
|
+
name: string;
|
|
18
|
+
shape: number[];
|
|
19
|
+
dtype: OnnxPolicyDtype;
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
export interface OnnxPolicyImageSpec {
|
|
23
|
+
width: number;
|
|
24
|
+
height: number;
|
|
25
|
+
channels?: number;
|
|
26
|
+
layout?: 'CHW' | 'HWC' | string;
|
|
27
|
+
range?: readonly [number, number];
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
export interface OnnxPolicyManifest {
|
|
31
|
+
model: string;
|
|
32
|
+
variants?: Record<string, string>;
|
|
33
|
+
fps?: number;
|
|
34
|
+
joints?: string[];
|
|
35
|
+
cameras?: string[];
|
|
36
|
+
image?: OnnxPolicyImageSpec;
|
|
37
|
+
chunk_size?: number;
|
|
38
|
+
n_action_steps?: number;
|
|
39
|
+
inputs: OnnxPolicyTensorSpec[];
|
|
40
|
+
output: OnnxPolicyTensorSpec & {
|
|
41
|
+
units?: string;
|
|
42
|
+
};
|
|
43
|
+
[key: string]: unknown;
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
export interface LoadOnnxPolicyManifestResult<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest> {
|
|
47
|
+
manifest: TManifest;
|
|
48
|
+
manifestUrl: URL;
|
|
49
|
+
modelUrl: URL;
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
export interface CreateOnnxPolicySessionOptions<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest> {
|
|
53
|
+
manifestUrl: string | URL;
|
|
54
|
+
variant?: string;
|
|
55
|
+
runtime: typeof ort;
|
|
56
|
+
sessionOptions?: ort.InferenceSession.SessionOptions;
|
|
57
|
+
fetcher?: typeof fetch;
|
|
58
|
+
readManifest?: (response: Response) => Promise<TManifest>;
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
export interface OnnxPolicySession<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest>
|
|
62
|
+
extends LoadOnnxPolicyManifestResult<TManifest> {
|
|
63
|
+
session: ort.InferenceSession;
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
/**
 * Coerce a string or `URL` into a `URL`.
 *
 * @param value A `URL` (returned as-is) or a string to parse.
 * @param base Base used to resolve a string `value`; defaults to
 *   `globalThis.location?.href`.
 * @returns The given `URL` instance, or a new `URL` parsed from the string.
 */
function asUrl(value: string | URL, base = globalThis.location?.href) {
  if (value instanceof URL) {
    return value;
  }
  return new URL(value, base);
}
|
|
69
|
+
|
|
70
|
+
/**
 * Pick the model path for the requested variant.
 *
 * A variant is honored only when `manifest.variants` has it as an own property
 * with a non-empty string value. Looking it up with a plain index would also
 * match inherited members such as `constructor` or `__proto__` and return a
 * non-string value as the model path.
 *
 * @param manifest Manifest supplying `model` and optional `variants`.
 * @param variant Requested variant name, if any.
 * @returns The variant's path when available, otherwise `manifest.model`.
 */
function resolveModelPath(manifest: OnnxPolicyManifest, variant: string | undefined) {
  const variants = manifest.variants;
  if (variant && variants && Object.prototype.hasOwnProperty.call(variants, variant)) {
    const variantPath = variants[variant];
    if (typeof variantPath === 'string' && variantPath) return variantPath;
  }
  return manifest.model;
}
|
|
74
|
+
|
|
75
|
+
/**
 * Fetch an ONNX policy manifest and resolve the URL of the model it points to.
 *
 * The model path (the requested variant, or `manifest.model`) is resolved
 * relative to the manifest URL.
 *
 * @param manifestUrlInput Manifest location; strings are resolved by `asUrl`.
 * @param options Optional `variant`, custom `fetcher` and custom
 *   `readManifest` parser (defaults to `response.json()`).
 * @returns The parsed manifest, the manifest URL and the resolved model URL.
 * @throws Error when the manifest request is not OK, when the parsed manifest
 *   is not an object, or when it yields no non-empty string model path.
 */
export async function loadOnnxPolicyManifest<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest>(
  manifestUrlInput: string | URL,
  options: Pick<CreateOnnxPolicySessionOptions<TManifest>, 'variant' | 'fetcher' | 'readManifest'> = {}
): Promise<LoadOnnxPolicyManifestResult<TManifest>> {
  const fetcher = options.fetcher ?? fetch;
  const manifestUrl = asUrl(manifestUrlInput);
  const response = await fetcher(manifestUrl);
  if (!response.ok) {
    throw new Error(`Unable to load ONNX policy manifest from ${manifestUrl.href} (${response.status}).`);
  }
  const manifest = options.readManifest
    ? await options.readManifest(response)
    : ((await response.json()) as TManifest);

  // The manifest is external data: a JSON `null` or scalar would otherwise
  // surface as an opaque TypeError when its fields are read.
  if (manifest === null || typeof manifest !== 'object') {
    throw new Error(`ONNX policy manifest at ${manifestUrl.href} is not a JSON object.`);
  }

  // `new URL(undefined, base)` resolves to ".../undefined" and `new URL('', base)`
  // resolves to the manifest itself, so reject a missing or empty path here
  // instead of failing later with a misleading model-fetch error.
  const modelPath: unknown = resolveModelPath(manifest, options.variant);
  if (typeof modelPath !== 'string' || modelPath.length === 0) {
    throw new Error(`ONNX policy manifest at ${manifestUrl.href} does not declare a model path.`);
  }

  const modelUrl = asUrl(modelPath, manifestUrl.href);
  return { manifest, manifestUrl, modelUrl };
}
|
|
92
|
+
|
|
93
|
+
/**
 * Load a policy manifest, download its model and create an inference session.
 *
 * @param options Manifest URL, the ONNX runtime module to use, and optional
 *   variant, session options, fetcher and manifest parser.
 * @returns The manifest, manifest URL and model URL from
 *   `loadOnnxPolicyManifest`, plus the created `session`.
 * @throws Error when the model request is not OK; manifest-loading and
 *   session-creation failures propagate unchanged.
 */
export async function createOnnxPolicySession<TManifest extends OnnxPolicyManifest = OnnxPolicyManifest>(
  options: CreateOnnxPolicySessionOptions<TManifest>
): Promise<OnnxPolicySession<TManifest>> {
  const fetchModel = options.fetcher ?? fetch;
  const policy = await loadOnnxPolicyManifest(options.manifestUrl, options);

  const modelResponse = await fetchModel(policy.modelUrl);
  if (!modelResponse.ok) {
    throw new Error(`Unable to load ONNX policy model from ${policy.modelUrl.href} (${modelResponse.status}).`);
  }

  // The session is created from the downloaded bytes, not from the URL.
  const modelBytes = await modelResponse.arrayBuffer();
  const session = await options.runtime.InferenceSession.create(modelBytes, options.sessionOptions);

  return { ...policy, session };
}
|
|
109
|
+
|
|
110
|
+
/**
 * Split a flat ONNX output tensor into a list of per-step action vectors.
 *
 * Tensor values are converted with `Number(...)`, so bigint-backed tensors
 * (e.g. int64) become plain numbers. Trailing values that do not fill a whole
 * action are dropped.
 *
 * @param tensor Output tensor whose `data` is read in order.
 * @param actionSize Values per action; defaults to the last tensor dimension,
 *   or 1 when the tensor has no dimensions.
 * @param maxActions Optional cap on the number of actions returned; it is
 *   floored and clamped to `[0, available actions]`.
 * @returns One `number[]` per action. The result is empty when `actionSize`
 *   is not a number of at least 1.
 */
export function onnxTensorToPolicyActionChunk(
  tensor: ort.Tensor,
  actionSize = tensor.dims.at(-1) ?? 1,
  maxActions?: number
): PolicyActionChunk {
  // An action size of 0 would make the action count Infinity and the loop
  // below would never terminate. Negative and NaN sizes already produced an
  // empty chunk, so every size below 1 now yields that same result.
  if (!(actionSize >= 1)) return [];

  const rawData = Array.from(tensor.data as ArrayLike<number>, (value) => Number(value));
  const actionCount = Math.floor(rawData.length / actionSize);
  const cappedActionCount = maxActions === undefined
    ? actionCount
    : Math.max(0, Math.min(actionCount, Math.floor(maxActions)));

  const actions: number[][] = [];
  for (let actionIndex = 0; actionIndex < cappedActionCount; actionIndex += 1) {
    const start = actionIndex * actionSize;
    actions.push(rawData.slice(start, start + actionSize));
  }
  return actions;
}
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Helpers for turning browser camera captures into policy image tensors.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
export type PolicyImageTensorLayout = 'CHW' | 'HWC';
|
|
9
|
+
export type PolicyImageTensorRange = readonly [number, number];
|
|
10
|
+
/**
|
|
11
|
+
* Row order of a raw pixel buffer. WebGL `readRenderTargetPixels` returns rows
|
|
12
|
+
* bottom-to-top (`'bottom-left'`); `ImageData` is top-to-bottom (`'top-left'`).
|
|
13
|
+
*/
|
|
14
|
+
export type PolicyImageTensorSourceOrigin = 'top-left' | 'bottom-left';
|
|
15
|
+
|
|
16
|
+
export interface PolicyImageTensorOptions {
|
|
17
|
+
width: number;
|
|
18
|
+
height: number;
|
|
19
|
+
channels?: 3 | 4;
|
|
20
|
+
layout?: PolicyImageTensorLayout;
|
|
21
|
+
range?: PolicyImageTensorRange;
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
export interface PolicyImageTensorPixelOptions extends PolicyImageTensorOptions {
|
|
25
|
+
/** Row order of the source buffer. Defaults to `'top-left'`. */
|
|
26
|
+
sourceOrigin?: PolicyImageTensorSourceOrigin;
|
|
27
|
+
/** Mirror horizontally while reading. */
|
|
28
|
+
flipX?: boolean;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
export interface PolicyImageTensorResult {
|
|
32
|
+
data: Float32Array;
|
|
33
|
+
shape: [number, number, number];
|
|
34
|
+
width: number;
|
|
35
|
+
height: number;
|
|
36
|
+
channels: 3 | 4;
|
|
37
|
+
layout: PolicyImageTensorLayout;
|
|
38
|
+
range: PolicyImageTensorRange;
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
/**
 * Fill in the tensor defaults: 3 channels, `'CHW'` layout and range `[0, 1]`.
 *
 * Defaults are applied with `??` after spreading `options`, so a property that
 * is present but `undefined` still receives its default. Spreading `options`
 * last would let such a property overwrite the default with `undefined`. Any
 * other properties on `options` are passed through unchanged.
 *
 * @param options Caller-supplied tensor options; `width` and `height` are
 *   required.
 * @returns The options with `channels`, `layout` and `range` always defined.
 */
function resolveTensorOptions(options: PolicyImageTensorOptions): Required<PolicyImageTensorOptions> {
  return {
    ...options,
    channels: options.channels ?? 3,
    layout: options.layout ?? 'CHW',
    range: options.range ?? [0, 1],
  };
}
|
|
49
|
+
|
|
50
|
+
/**
 * Map an 8-bit channel value into the target range.
 *
 * @param value Channel value on the 0–255 scale.
 * @param range Target `[min, max]`; `[0, 255]` returns `value` unchanged.
 * @returns `min + (value / 255) * (max - min)`, or `value` for `[0, 255]`.
 */
function normalizeChannel(value: number, range: PolicyImageTensorRange) {
  const lower = range[0];
  const upper = range[1];
  // Byte range is the identity mapping, so skip the arithmetic entirely.
  const isByteRange = lower === 0 && upper === 255;
  return isByteRange ? value : lower + (value / 255) * (upper - lower);
}
|
|
55
|
+
|
|
56
|
+
/**
 * Convert a raw RGBA pixel buffer (4 bytes per pixel) into a policy image
 * tensor without going through canvas encoding.
 *
 * Rows are read bottom-to-top when `sourceOrigin` is `'bottom-left'` (the row
 * order of WebGL `readRenderTargetPixels`), and columns are mirrored when
 * `flipX` is set. Only the first `channels` bytes of each pixel are used, and
 * each is normalized into `range`.
 *
 * @param pixels Source buffer holding at least `width * height * 4` bytes.
 * @param options Tensor size and format, plus source row order and mirroring.
 * @returns The Float32 tensor data with its shape (`[C, H, W]` for `'CHW'`,
 *   otherwise `[H, W, C]`) and the resolved format fields.
 * @throws Error when `pixels` is shorter than `width * height * 4`.
 */
export function pixelsToPolicyImageTensor(
  pixels: Uint8Array | Uint8ClampedArray,
  options: PolicyImageTensorPixelOptions
): PolicyImageTensorResult {
  const { width, height, channels, layout, range } = resolveTensorOptions(options);

  const requiredBytes = width * height * 4;
  if (pixels.length < requiredBytes) {
    throw new Error(
      `Pixel buffer of length ${pixels.length} is too small for ${width}x${height} RGBA data (${requiredBytes} bytes).`
    );
  }

  const readBottomUp = options.sourceOrigin === 'bottom-left';
  const mirrorColumns = options.flipX ?? false;
  const channelFirst = layout === 'CHW';
  const planeSize = width * height;
  const data = new Float32Array(planeSize * channels);

  for (let row = 0; row < height; row += 1) {
    const sourceRow = readBottomUp ? height - row - 1 : row;
    for (let column = 0; column < width; column += 1) {
      const sourceColumn = mirrorColumns ? width - column - 1 : column;
      // The source is always 4 bytes per pixel, whatever `channels` is.
      const sourceOffset = (sourceRow * width + sourceColumn) * 4;
      const pixelIndex = row * width + column;
      for (let channel = 0; channel < channels; channel += 1) {
        const targetIndex = channelFirst
          ? channel * planeSize + pixelIndex
          : pixelIndex * channels + channel;
        data[targetIndex] = normalizeChannel(pixels[sourceOffset + channel], range);
      }
    }
  }

  const shape: [number, number, number] = channelFirst
    ? [channels, height, width]
    : [height, width, channels];

  return { data, shape, width, height, channels, layout, range };
}
|
|
107
|
+
|
|
108
|
+
/**
 * Convert an `ImageData` into a policy image tensor.
 *
 * The `ImageData` must already have the requested tensor size; no resampling
 * is performed. Its pixels are read as top-to-bottom rows.
 *
 * @param imageData Source pixels.
 * @param options Tensor size and format.
 * @returns The tensor produced by `pixelsToPolicyImageTensor`.
 * @throws Error when the `ImageData` size differs from `options`.
 */
export function imageDataToPolicyImageTensor(
  imageData: ImageData,
  options: PolicyImageTensorOptions
): PolicyImageTensorResult {
  const resolved = resolveTensorOptions(options);

  const sizeMatches =
    imageData.width === resolved.width && imageData.height === resolved.height;
  if (!sizeMatches) {
    throw new Error(
      `ImageData size ${imageData.width}x${imageData.height} does not match tensor size ${resolved.width}x${resolved.height}.`
    );
  }

  const pixelOptions: PolicyImageTensorPixelOptions = {
    ...resolved,
    sourceOrigin: 'top-left',
  };
  return pixelsToPolicyImageTensor(imageData.data, pixelOptions);
}
|
|
123
|
+
|
|
124
|
+
/**
 * Load a data URL into an `Image` element and wait until it is decoded.
 *
 * @param dataUrl Image source assigned to the element's `src`.
 * @returns The decoded image element; rejects if `decode()` rejects.
 */
async function decodeImageSource(dataUrl: string) {
  const element = new Image();
  // Request asynchronous decoding before the source is assigned.
  element.decoding = 'async';
  element.src = dataUrl;
  await element.decode();
  return element;
}
|
|
131
|
+
|
|
132
|
+
/**
 * Decode an image data URL and convert it into a policy image tensor.
 *
 * The image is drawn onto an off-DOM 2D canvas sized to the requested tensor
 * dimensions (scaling it to fit), and the canvas pixels are then converted.
 *
 * @param dataUrl Image source to decode.
 * @param options Tensor size and format.
 * @returns The tensor for the scaled image.
 * @throws Error when no 2D canvas context can be created.
 */
export async function dataUrlToPolicyImageTensor(
  dataUrl: string,
  options: PolicyImageTensorOptions
): Promise<PolicyImageTensorResult> {
  const resolved = resolveTensorOptions(options);
  const { width, height } = resolved;
  const image = await decodeImageSource(dataUrl);

  const canvas = document.createElement('canvas');
  canvas.width = width;
  canvas.height = height;

  const context = canvas.getContext('2d', { willReadFrequently: true });
  if (!context) {
    throw new Error('Unable to create a 2D canvas context for policy image tensor conversion.');
  }

  // Drawing with explicit destination size rescales the image to the tensor.
  context.drawImage(image, 0, 0, width, height);
  const scaledPixels = context.getImageData(0, 0, width, height);
  return imageDataToPolicyImageTensor(scaledPixels, resolved);
}
|