mujoco-react 10.0.0 → 10.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{chunk-QTCAVQS6.js → chunk-FEKBKHEN.js} +56 -5
- package/dist/chunk-FEKBKHEN.js.map +1 -0
- package/dist/index.d.ts +271 -19
- package/dist/index.js +1459 -407
- package/dist/index.js.map +1 -1
- package/dist/spark.d.ts +1 -1
- package/dist/spark.js +1 -1
- package/dist/{types-BaSMqJHT.d.ts → types-BHBNJubg.d.ts} +133 -2
- package/package.json +1 -1
- package/src/components/SceneRenderer.tsx +11 -4
- package/src/core/GenericIK.ts +12 -1
- package/src/core/MujocoSimProvider.tsx +67 -6
- package/src/core/SceneLoader.ts +8 -2
- package/src/hooks/useContactHistory.ts +155 -0
- package/src/hooks/useControlWriter.ts +176 -0
- package/src/hooks/useNamedObservation.ts +42 -0
- package/src/hooks/usePolicy.ts +133 -10
- package/src/hooks/usePolicyCameraFrames.ts +162 -0
- package/src/hooks/usePose.ts +119 -0
- package/src/hooks/useRemotePolicy.ts +329 -0
- package/src/index.ts +81 -0
- package/src/policyCameraFrames.ts +213 -0
- package/src/policyControls.ts +87 -0
- package/src/policyObservation.ts +172 -0
- package/src/rendering/GeomBuilder.ts +73 -24
- package/src/rendering/cameraFrameCapture.ts +74 -2
- package/src/types.ts +151 -1
- package/dist/chunk-QTCAVQS6.js.map +0 -1
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* useRemotePolicy — HTTP JSON inference wrapper around usePolicy.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { useMemo, useRef } from 'react';
|
|
9
|
+
import { usePolicy } from './usePolicy';
|
|
10
|
+
import type {
|
|
11
|
+
PolicyInferenceOutput,
|
|
12
|
+
PolicyVector,
|
|
13
|
+
RemotePolicyAPI,
|
|
14
|
+
RemotePolicyConfig,
|
|
15
|
+
RemotePolicyRequestInput,
|
|
16
|
+
RemotePolicyResponseInfo,
|
|
17
|
+
RemotePolicyStatus,
|
|
18
|
+
} from '../types';
|
|
19
|
+
|
|
20
|
+
/**
 * Current timestamp in milliseconds.
 *
 * Uses `performance.now()` when the global exists; otherwise falls back to
 * `Date.now()`.
 */
function now() {
  if (typeof performance === 'undefined') {
    return Date.now();
  }
  return performance.now();
}
|
|
23
|
+
|
|
24
|
+
/**
 * Whether `error` represents a cancellation.
 *
 * Accepts a `DOMException` (when that global exists) or any `Error` whose
 * `name` is `'AbortError'`. Non-error values are never treated as aborts.
 */
function isAbortError(error: unknown) {
  const domExceptionAvailable = typeof DOMException !== 'undefined';
  if (domExceptionAvailable && error instanceof DOMException) {
    if (error.name === 'AbortError') {
      return true;
    }
  }
  return error instanceof Error && error.name === 'AbortError';
}
|
|
32
|
+
|
|
33
|
+
/**
 * Build an error whose `name` is `'AbortError'`.
 *
 * Prefers a `DOMException` when that global exists, and falls back to a plain
 * `Error` with an overridden `name` otherwise.
 */
function createAbortError(message: string) {
  if (typeof DOMException === 'undefined') {
    return Object.assign(new Error(message), { name: 'AbortError' });
  }
  return new DOMException(message, 'AbortError');
}
|
|
41
|
+
|
|
42
|
+
/**
 * Abort `controller` unless it is missing or already aborted.
 *
 * `reason` is forwarded only when it was actually provided; otherwise
 * `abort()` is called with no argument so the platform picks the reason.
 */
function abortController(controller: AbortController | null, reason?: unknown) {
  if (controller == null || controller.signal.aborted) {
    return;
  }
  if (reason === undefined) {
    controller.abort();
    return;
  }
  controller.abort(reason);
}
|
|
50
|
+
|
|
51
|
+
/**
 * Return a signal that aborts when either `localSignal` or `externalSignal`
 * aborts, carrying the reason of whichever side aborted.
 *
 * - No external signal: `localSignal` itself is returned.
 * - External signal already aborted: a fresh, already-aborted signal is
 *   returned.
 * - Otherwise `AbortSignal.any` is used when available, with a
 *   listener-based fallback for runtimes that lack it.
 *
 * NOTE(review): the fallback never removes the listener it adds to
 * `externalSignal`. A long-lived external signal therefore keeps one closure
 * per request alive until it aborts. Consider returning a cleanup hook.
 */
function createMergedAbortSignal(
  localSignal: AbortSignal,
  externalSignal: AbortSignal | undefined
) {
  if (externalSignal == null) {
    return localSignal;
  }
  const external = externalSignal;

  if (external.aborted) {
    const alreadyAborted = new AbortController();
    abortController(alreadyAborted, external.reason);
    return alreadyAborted.signal;
  }

  const hasNativeAny =
    typeof AbortSignal !== 'undefined' && typeof AbortSignal.any === 'function';
  if (hasNativeAny) {
    return AbortSignal.any([localSignal, external]);
  }

  const merged = new AbortController();
  // Each listener forwards the reason of the signal it is attached to.
  const forwardFrom = (source: AbortSignal) => () => abortController(merged, source.reason);
  localSignal.addEventListener('abort', forwardFrom(localSignal), { once: true });
  external.addEventListener('abort', forwardFrom(external), { once: true });
  return merged.signal;
}
|
|
72
|
+
|
|
73
|
+
/**
 * Copy a policy vector into a plain `number[]`, coercing each entry with
 * `Number` so the result is JSON-serializable.
 */
function vectorToArray(vector: PolicyVector) {
  const toNumber = (entry: unknown) => Number(entry);
  return Array.from(vector, toNumber);
}
|
|
76
|
+
|
|
77
|
+
/**
 * Whether `value` is an array whose entries are all plain arrays or
 * `ArrayBuffer` views.
 *
 * An empty array passes vacuously; callers check length separately.
 */
function isPolicyVectorArray(value: unknown): value is PolicyVector[] {
  if (!Array.isArray(value)) {
    return false;
  }
  return !value.some((entry) => !Array.isArray(entry) && !ArrayBuffer.isView(entry));
}
|
|
83
|
+
|
|
84
|
+
/** Whether `value` is a plain array or an `ArrayBuffer` view (e.g. a typed array). */
function isPolicyVector(value: unknown): value is PolicyVector {
  if (Array.isArray(value)) {
    return true;
  }
  return ArrayBuffer.isView(value);
}
|
|
87
|
+
|
|
88
|
+
/**
 * Default request payload.
 *
 * Contains:
 * - the observation as a plain number array, sent under both `observation`
 *   and `state`;
 * - the simulation time from `data.time`;
 * - the request's `reset` flag.
 */
function defaultBuildRemotePolicyRequest(input: RemotePolicyRequestInput) {
  const values = vectorToArray(input.observation);
  const payload = {
    observation: values,
    state: values,
    time: input.data.time,
    reset: input.reset,
  };
  return payload;
}
|
|
97
|
+
|
|
98
|
+
/**
 * Default response reader: read the body as text and parse it as JSON.
 *
 * @param response The fetch response to consume.
 * @returns
 *   - `null` for an empty or whitespace-only body;
 *   - the parsed JSON value otherwise;
 *   - the raw text for a non-OK response whose body is not JSON (e.g. an
 *     HTML error page from a proxy).
 *
 * Returning the raw text lets the caller report the HTTP status instead of a
 * JSON SyntaxError.
 * @throws The `JSON.parse` SyntaxError when an OK response is not valid JSON.
 */
async function defaultReadRemotePolicyResponse(response: Response): Promise<unknown> {
  const text = await response.text();
  if (text.trim().length === 0) return null;
  try {
    return JSON.parse(text) as unknown;
  } catch (error) {
    // Non-2xx bodies are often plain text/HTML; keep them so the HTTP status
    // is what surfaces, rather than a parse failure.
    if (!response.ok) return text;
    throw error;
  }
}
|
|
103
|
+
|
|
104
|
+
/**
 * Default response parser. Accepted shapes, checked in order:
 *
 * 1. An object with a non-empty string `error`, which is thrown.
 * 2. An object with a non-empty `actions` array of vectors, returned as a chunk.
 * 3. An object with an `action` vector, returned as a single action.
 * 4. A bare non-empty array of vectors, returned as a chunk.
 * 5. A bare vector, returned as a single action.
 *
 * @throws Error when none of the shapes match.
 */
function defaultParseRemotePolicyResponse(responseBody: unknown): PolicyInferenceOutput {
  const isObjectBody = responseBody !== null && typeof responseBody === 'object';

  if (isObjectBody) {
    const payload = responseBody as {
      action?: unknown;
      actions?: unknown;
      error?: unknown;
    };
    const reportedError = payload.error;
    if (typeof reportedError === 'string' && reportedError !== '') {
      throw new Error(reportedError);
    }
    const chunk = payload.actions;
    if (isPolicyVectorArray(chunk) && chunk.length > 0) {
      return chunk;
    }
    const single = payload.action;
    if (isPolicyVector(single)) {
      return single;
    }
  }

  if (isPolicyVectorArray(responseBody) && responseBody.length > 0) {
    return responseBody;
  }
  if (isPolicyVector(responseBody)) {
    return responseBody;
  }
  throw new Error('Remote policy response must include `action` or `actions`.');
}
|
|
135
|
+
|
|
136
|
+
/**
 * Build the error thrown for a non-OK HTTP response.
 *
 * When the parsed body is an object with an `error` field, that detail is
 * appended to the message:
 * - a string is appended as-is;
 * - an object's string `message` is preferred, falling back to its JSON form;
 * - `null`, `undefined` or an empty detail adds no suffix (previously these
 *   produced ": null" / "[object Object]").
 *
 * @param response The failed response; only `status` is read.
 * @param responseBody The body as returned by the response reader.
 * @returns An `Error` whose message starts with
 *   `Remote policy request failed with HTTP <status>`.
 */
function createHttpError(response: Response, responseBody: unknown) {
  const describeDetail = (value: unknown): string => {
    if (value === null || value === undefined) return '';
    if (typeof value === 'string') return value;
    if (typeof value === 'object') {
      const message = (value as { message?: unknown }).message;
      if (typeof message === 'string' && message.length > 0) return message;
      try {
        return JSON.stringify(value);
      } catch {
        // Cyclic or otherwise unserializable detail: fall back to default coercion.
        return String(value);
      }
    }
    return String(value);
  };

  const detail =
    responseBody !== null && typeof responseBody === 'object' && 'error' in responseBody
      ? describeDetail((responseBody as { error?: unknown }).error)
      : '';
  const suffix = detail.length > 0 ? `: ${detail}` : '';
  return new Error(`Remote policy request failed with HTTP ${response.status}${suffix}`);
}
|
|
143
|
+
|
|
144
|
+
/**
 * Run a policy whose inference step lives behind an HTTP JSON endpoint.
 *
 * The hook keeps `usePolicy` responsible for timing, queueing, and pause/reset
 * behavior. This wrapper only builds requests, parses responses, and exposes
 * request metadata for HUDs and debugging.
 *
 * Each inference:
 * 1. Aborts any still-pending request as superseded.
 * 2. Builds the body (`buildRequest` or the default JSON body) and sends it
 *    with `fetcher` (default `fetch`). GET/HEAD requests are sent without a
 *    body.
 * 3. Reads the response (`readResponse` or text→JSON) and turns non-OK
 *    statuses into errors.
 * 4. Parses actions (`parseResponse` or the default `action`/`actions`
 *    parser).
 *
 * Only the latest request of the current reset epoch updates the exposed
 * metadata (`remoteStatus`, `lastHttpStatus`, ...). A superseded or reset
 * request that settles late therefore cannot overwrite the state of the
 * request that replaced it.
 *
 * @param config Remote endpoint and request hooks, plus `usePolicy` options.
 *   Everything is forwarded to `usePolicy` except `infer`, which this hook
 *   supplies.
 * @returns The underlying policy controls extended with remote-request
 *   metadata getters and an `abort` helper.
 */
export function useRemotePolicy(config: RemotePolicyConfig): RemotePolicyAPI {
  const configRef = useRef(config);
  configRef.current = config;
  const requestCountRef = useRef(0);
  const responseCountRef = useRef(0);
  const remoteStatusRef = useRef<RemotePolicyStatus>('idle');
  const lastRequestBodyRef = useRef<unknown>(null);
  const lastResponseBodyRef = useRef<unknown>(null);
  const lastHttpStatusRef = useRef<number | null>(null);
  const lastRequestMsRef = useRef<number | null>(null);
  const abortControllerRef = useRef<AbortController | null>(null);
  // Bumped by resets so requests started before a reset can detect it.
  const remoteEpochRef = useRef(0);

  const policy = usePolicy({
    ...config,
    infer: async ({ observation, model, data }) => {
      type FetchInit = NonNullable<Parameters<typeof fetch>[1]>;
      const cfg = configRef.current;

      abortController(abortControllerRef.current, createAbortError('Remote policy request was superseded.'));
      const controller = new AbortController();
      abortControllerRef.current = controller;
      const remoteEpoch = remoteEpochRef.current;
      const signal = createMergedAbortSignal(controller.signal, cfg.signal);

      // True while this request is the one the exposed metadata describes:
      // same reset epoch and not replaced by a newer request.
      const isCurrent = () =>
        remoteEpoch === remoteEpochRef.current && abortControllerRef.current === controller;
      // Called after each await: bail out if aborted, superseded, or reset.
      const ensureLive = () => {
        signal.throwIfAborted();
        if (remoteEpoch !== remoteEpochRef.current) {
          throw createAbortError('Remote policy request was reset.');
        }
      };

      const requestIndex = requestCountRef.current;
      const requestInput: RemotePolicyRequestInput = {
        observation,
        model,
        data,
        reset: requestIndex === 0,
        requestIndex,
        signal,
      };
      requestCountRef.current += 1;

      // Everything after the controller is registered runs inside try/finally.
      // Early aborts, buildRequest/onRequest errors, and fetch failures then all
      // update the status and release the controller.
      try {
        const requestStartedAt = now();
        const body = await (
          cfg.buildRequest?.(requestInput) ?? defaultBuildRemotePolicyRequest(requestInput)
        );
        ensureLive();
        lastRequestBodyRef.current = body;
        remoteStatusRef.current = 'requesting';
        cfg.onRequest?.({
          ...requestInput,
          body,
          requestStartedAt,
        });

        const baseInit: FetchInit = { ...cfg.requestInit };
        const method = cfg.method ?? baseInit.method ?? 'POST';
        const upperMethod = method.toUpperCase();
        // Fetch rejects GET/HEAD requests that carry a body.
        const sendsBody = upperMethod !== 'GET' && upperMethod !== 'HEAD';

        // Merge `requestInit.headers` first, then `headers`, so explicit config wins.
        const headers = new Headers(baseInit.headers);
        new Headers(cfg.headers).forEach((value, name) => {
          headers.set(name, value);
        });
        if (sendsBody && !headers.has('content-type')) {
          headers.set('content-type', 'application/json');
        }

        const init: FetchInit = {
          ...baseInit,
          method,
          headers,
          signal,
          body: sendsBody ? (typeof body === 'string' ? body : JSON.stringify(body)) : undefined,
        };
        // Only override when configured so `requestInit.credentials` is not
        // clobbered by an explicit `undefined`.
        if (cfg.credentials !== undefined) {
          init.credentials = cfg.credentials;
        }

        const fetcher = cfg.fetcher ?? fetch;
        const response = await fetcher(String(cfg.endpoint), init);
        if (isCurrent()) {
          lastHttpStatusRef.current = response.status;
        }

        const responseBody: unknown = await (
          cfg.readResponse?.(response) ?? defaultReadRemotePolicyResponse(response)
        );
        ensureLive();
        lastResponseBodyRef.current = responseBody;

        if (!response.ok) {
          throw createHttpError(response, responseBody);
        }

        const responseFinishedAt = now();
        const info: RemotePolicyResponseInfo = {
          ...requestInput,
          body,
          requestStartedAt,
          response,
          responseBody,
          responseFinishedAt,
          requestMs: responseFinishedAt - requestStartedAt,
        };
        if (isCurrent()) {
          lastRequestMsRef.current = info.requestMs;
          responseCountRef.current += 1;
        }
        cfg.onResponse?.(info);

        const output = await (
          cfg.parseResponse?.(responseBody, info) ??
          defaultParseRemotePolicyResponse(responseBody)
        );
        if (isCurrent()) {
          remoteStatusRef.current = 'ready';
        }
        return output;
      } catch (error) {
        // Stale (superseded/reset) requests must not overwrite the live status.
        if (isCurrent()) {
          remoteStatusRef.current = isAbortError(error) || signal.aborted ? 'aborted' : 'error';
        }
        throw error;
      } finally {
        if (abortControllerRef.current === controller) {
          abortControllerRef.current = null;
        }
      }
    },
  });

  return useMemo(() => {
    // Abort the in-flight request (if any) and mark the status aborted right away.
    const abort = (reason?: unknown) => {
      abortController(abortControllerRef.current, reason);
      if (abortControllerRef.current) {
        remoteStatusRef.current = 'aborted';
      }
    };

    // Start a new epoch: cancel the in-flight request and clear all metadata.
    // The epoch bump comes first so the cancelled request sees itself as stale.
    const resetRemoteState = () => {
      remoteEpochRef.current += 1;
      abort(createAbortError('Remote policy request was reset.'));
      requestCountRef.current = 0;
      responseCountRef.current = 0;
      remoteStatusRef.current = 'idle';
      lastRequestBodyRef.current = null;
      lastResponseBodyRef.current = null;
      lastHttpStatusRef.current = null;
      lastRequestMsRef.current = null;
    };

    return {
      get isRunning() { return policy.isRunning; },
      start: policy.start,
      stop: () => {
        if (configRef.current.abortOnStop ?? true) {
          abort(createAbortError('Remote policy request was stopped.'));
        }
        policy.stop();
        // NOTE: resetRemoteState also aborts any in-flight request, even when
        // `abortOnStop` is false.
        if (configRef.current.clearQueueOnStop) {
          resetRemoteState();
        }
      },
      clearQueue: policy.clearQueue,
      abort,
      reset: () => {
        resetRemoteState();
        policy.reset();
      },
      get inFlight() { return policy.inFlight; },
      get queuedActions() { return policy.queuedActions; },
      get lastObservation() { return policy.lastObservation; },
      get lastAction() { return policy.lastAction; },
      get lastError() { return policy.lastError; },
      get remoteStatus() { return remoteStatusRef.current; },
      get requestCount() { return requestCountRef.current; },
      get responseCount() { return responseCountRef.current; },
      get lastRequestBody() { return lastRequestBodyRef.current; },
      get lastResponseBody() { return lastResponseBodyRef.current; },
      get lastHttpStatus() { return lastHttpStatusRef.current; },
      get lastRequestMs() { return lastRequestMsRef.current; },
    };
  }, [policy]);
}
|
package/src/index.ts
CHANGED
|
@@ -88,12 +88,30 @@ export { useGravityCompensation } from './hooks/useGravityCompensation';
|
|
|
88
88
|
export { useSensor, useSensors } from './hooks/useSensor';
|
|
89
89
|
export { useJointState } from './hooks/useJointState';
|
|
90
90
|
export { useBodyState } from './hooks/useBodyState';
|
|
91
|
+
export { useBodyPose, useGeomPose, useSitePose } from './hooks/usePose';
|
|
92
|
+
export type { PoseReadout, PoseResourceKind } from './hooks/usePose';
|
|
91
93
|
export { useCtrl } from './hooks/useCtrl';
|
|
94
|
+
export { useControlWriter } from './hooks/useControlWriter';
|
|
95
|
+
export type {
|
|
96
|
+
ControlWriterConflict,
|
|
97
|
+
ControlWriterHandle,
|
|
98
|
+
ControlWriterOptions,
|
|
99
|
+
ControlWriterWriteOptions,
|
|
100
|
+
} from './hooks/useControlWriter';
|
|
92
101
|
export { useContacts, useContactEvents } from './hooks/useContacts';
|
|
102
|
+
export { useContactHistory } from './hooks/useContactHistory';
|
|
103
|
+
export type {
|
|
104
|
+
ContactHistoryEntry,
|
|
105
|
+
ContactHistoryHandle,
|
|
106
|
+
ContactHistoryOptions,
|
|
107
|
+
} from './hooks/useContactHistory';
|
|
93
108
|
export { useKeyboardTeleop } from './hooks/useKeyboardTeleop';
|
|
94
109
|
export { useKeyboardIkTarget } from './hooks/useKeyboardIkTarget';
|
|
95
110
|
export { usePolicy } from './hooks/usePolicy';
|
|
111
|
+
export { useRemotePolicy } from './hooks/useRemotePolicy';
|
|
96
112
|
export { useObservation } from './hooks/useObservation';
|
|
113
|
+
export { useNamedObservation } from './hooks/useNamedObservation';
|
|
114
|
+
export type { NamedObservationHandle } from './hooks/useNamedObservation';
|
|
97
115
|
export { useTrajectoryPlayer } from './hooks/useTrajectoryPlayer';
|
|
98
116
|
export { useTrajectoryRecorder } from './hooks/useTrajectoryRecorder';
|
|
99
117
|
export { useGamepad } from './hooks/useGamepad';
|
|
@@ -104,6 +122,14 @@ export {
|
|
|
104
122
|
useFrameCapture,
|
|
105
123
|
} from './hooks/useFrameCapture';
|
|
106
124
|
export { useCameraFrameCapture } from './hooks/useCameraFrameCapture';
|
|
125
|
+
export {
|
|
126
|
+
usePolicyCameraFrames,
|
|
127
|
+
usePolicyCameraFramesFromMountedStreams,
|
|
128
|
+
} from './hooks/usePolicyCameraFrames';
|
|
129
|
+
export type {
|
|
130
|
+
MountedPolicyCameraFrameCaptureAPI,
|
|
131
|
+
MountedPolicyCameraFrameCaptureOptions,
|
|
132
|
+
} from './hooks/usePolicyCameraFrames';
|
|
107
133
|
export { useCameraSequenceRecorder } from './hooks/useCameraSequenceRecorder';
|
|
108
134
|
export { useMountedCameraSequenceRecorder } from './hooks/useMountedCameraSequenceRecorder';
|
|
109
135
|
export type {
|
|
@@ -114,12 +140,53 @@ export type {
|
|
|
114
140
|
MountedCameraSequenceRecordResult,
|
|
115
141
|
} from './hooks/useMountedCameraSequenceRecorder';
|
|
116
142
|
export {
|
|
143
|
+
CAMERA_FRAME_CAPTURE_RENDER_USER_DATA_KEY,
|
|
144
|
+
CAMERA_FRAME_CAPTURE_PRE_RENDER_USER_DATA_KEY,
|
|
117
145
|
CAPTURE_EXCLUDE_KEY,
|
|
118
146
|
captureCameraFrame,
|
|
119
147
|
captureCameraFrameBlob,
|
|
120
148
|
createCameraFrameCaptureSession,
|
|
121
149
|
renderCameraFrameToCanvas,
|
|
122
150
|
} from './rendering/cameraFrameCapture';
|
|
151
|
+
export {
|
|
152
|
+
capturePolicyCameraFrames,
|
|
153
|
+
capturePolicyCameraFramesFromMountedStreams,
|
|
154
|
+
createPolicyCameraFrameCapturePlan,
|
|
155
|
+
createPolicyCameraFrameCapturePlanFromApi,
|
|
156
|
+
} from './policyCameraFrames';
|
|
157
|
+
export {
|
|
158
|
+
applyPolicyActionToControls,
|
|
159
|
+
clampPolicyActionValue,
|
|
160
|
+
} from './policyControls';
|
|
161
|
+
export {
|
|
162
|
+
bodyPositionField,
|
|
163
|
+
createNamedObservationBuilder,
|
|
164
|
+
ctrlField,
|
|
165
|
+
geomPositionField,
|
|
166
|
+
qposField,
|
|
167
|
+
qvelField,
|
|
168
|
+
readNamedObservation,
|
|
169
|
+
sitePositionField,
|
|
170
|
+
} from './policyObservation';
|
|
171
|
+
export type {
|
|
172
|
+
CreatePolicyCameraFrameCapturePlanOptions,
|
|
173
|
+
PolicyCameraFrameCapturePlan,
|
|
174
|
+
PolicyCameraFrameCaptureTarget,
|
|
175
|
+
PolicyCameraFramePlanTarget,
|
|
176
|
+
PolicyCameraFrameStreamOptions,
|
|
177
|
+
} from './policyCameraFrames';
|
|
178
|
+
export type {
|
|
179
|
+
ApplyPolicyActionToControlsOptions,
|
|
180
|
+
ApplyPolicyActionToControlsResult,
|
|
181
|
+
} from './policyControls';
|
|
182
|
+
export type {
|
|
183
|
+
NamedObservationField,
|
|
184
|
+
NamedObservationInput,
|
|
185
|
+
NamedObservationLayoutItem,
|
|
186
|
+
NamedObservationMissing,
|
|
187
|
+
NamedObservationOptions,
|
|
188
|
+
NamedObservationResult,
|
|
189
|
+
} from './policyObservation';
|
|
123
190
|
export {
|
|
124
191
|
createMountedCameraFrameSequenceManifest,
|
|
125
192
|
createMountedCameraFrameSequenceReadiness,
|
|
@@ -216,11 +283,25 @@ export type {
|
|
|
216
283
|
KeyboardIkTargetBinding,
|
|
217
284
|
KeyboardIkTargetConfig,
|
|
218
285
|
// Policy
|
|
286
|
+
PolicyAPI,
|
|
219
287
|
PolicyConfig,
|
|
220
288
|
PolicyVector,
|
|
289
|
+
PolicyActionChunk,
|
|
290
|
+
PolicyInferenceOutput,
|
|
221
291
|
PolicyObservationInput,
|
|
222
292
|
PolicyInferenceInput,
|
|
293
|
+
PolicyInferenceResult,
|
|
223
294
|
PolicyActionInput,
|
|
295
|
+
RemotePolicyAPI,
|
|
296
|
+
RemotePolicyConfig,
|
|
297
|
+
RemotePolicyRequestInput,
|
|
298
|
+
RemotePolicyRequestInfo,
|
|
299
|
+
RemotePolicyResponseInfo,
|
|
300
|
+
RemotePolicyStatus,
|
|
301
|
+
PolicyCameraFrameCaptureAPI,
|
|
302
|
+
PolicyCameraFrameCaptureOptions,
|
|
303
|
+
PolicyCameraFrameCaptureResult,
|
|
304
|
+
PolicyCameraFrameStream,
|
|
224
305
|
// Observations
|
|
225
306
|
ObservationConfig,
|
|
226
307
|
ObservationHandle,
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Helpers for turning Three/MuJoCo camera captures into policy image payloads.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import type {
|
|
9
|
+
CameraFrameCaptureResult,
|
|
10
|
+
MujocoSimAPI,
|
|
11
|
+
PolicyCameraFrameCaptureOptions,
|
|
12
|
+
PolicyCameraFrameCaptureResult,
|
|
13
|
+
PolicyCameraFrameStream,
|
|
14
|
+
} from './types';
|
|
15
|
+
import {
|
|
16
|
+
createMountedCameraFrameSequencePlan,
|
|
17
|
+
type MountedCameraFrameSequenceCameraOptions,
|
|
18
|
+
type MountedCameraFrameSequenceDefaults,
|
|
19
|
+
type MountedCameraFrameSequencePlan,
|
|
20
|
+
type ResolveMountedCameraFrameSourceOptions,
|
|
21
|
+
} from './rendering/cameraFrameSource';
|
|
22
|
+
|
|
23
|
+
/** Capture surface needed to grab policy frames: only `captureCameraFrame` from the sim API. */
export type PolicyCameraFrameCaptureTarget = Pick<MujocoSimAPI, 'captureCameraFrame'>;

/** Scene-listing surface (bodies, cameras, sites) used to resolve mounted camera sources when planning. */
export type PolicyCameraFramePlanTarget = Pick<
  MujocoSimAPI,
  'getBodies' | 'getCameras' | 'getSites'
>;

/**
 * Per-key stream overrides accepted by `createPolicyCameraFrameCapturePlan`.
 *
 * Combines any `PolicyCameraFrameStream` field except `key` (all optional),
 * the mounted-camera sequence options, and optional payload aliases.
 */
export type PolicyCameraFrameStreamOptions =
  Partial<Omit<PolicyCameraFrameStream, 'key'>>
  & MountedCameraFrameSequenceCameraOptions
  & {
    /** Additional policy payload keys that should receive this stream's data URL. */
    aliases?: readonly string[];
  };
|
|
32
|
+
|
|
33
|
+
/**
 * Options for `createPolicyCameraFrameCapturePlan`.
 *
 * Inherits the mounted-source resolution options, but re-declares the scene
 * lists (`bodies`, `cameras`, `sites`) as optional. This lets helpers such as
 * `createPolicyCameraFrameCapturePlanFromApi` supply them.
 */
export interface CreatePolicyCameraFrameCapturePlanOptions
  extends Omit<ResolveMountedCameraFrameSourceOptions, 'bodies' | 'cameras' | 'sites'> {
  /** Policy camera keys to plan, in stream order. */
  cameraKeys: readonly string[];
  /** Scene cameras available for mounted-source resolution. */
  cameras?: ResolveMountedCameraFrameSourceOptions['cameras'];
  /** Scene sites available for mounted-source resolution. */
  sites?: ResolveMountedCameraFrameSourceOptions['sites'];
  /** Scene bodies available for mounted-source resolution. */
  bodies?: ResolveMountedCameraFrameSourceOptions['bodies'];
  /** Shared defaults: passed to mounted resolution and merged beneath explicit stream overrides. */
  defaults?: MountedCameraFrameSequenceDefaults;
  /** Per-key overrides; an explicit camera/position/quaternion/source bypasses mounted resolution. */
  streamOptions?: Record<string, PolicyCameraFrameStreamOptions>;
  /** Copied onto the plan; decides whether `observation.images.*` payload keys are added at capture time. */
  includeObservationImageAliases?: boolean;
  /** When true, planning throws if any key is left unresolved. */
  requireAll?: boolean;
}

/** Result of `createPolicyCameraFrameCapturePlan`; usable directly as `capturePolicyCameraFrames` options. */
export interface PolicyCameraFrameCapturePlan
  extends PolicyCameraFrameCaptureOptions {
  /** Copy of the requested camera keys. */
  cameraKeys: string[];
  /** Resolved streams in `cameraKeys` order; unresolved keys are omitted. */
  streams: PolicyCameraFrameStream[];
  /** The raw plan produced by mounted-source resolution. */
  mountedPlan: MountedCameraFrameSequencePlan;
  /** Keys resolved neither by an explicit stream source nor by the mounted plan. */
  missingKeys: string[];
}
|
|
55
|
+
|
|
56
|
+
/**
 * Write `frame.dataUrl` into `images` under several keys:
 * - the stream key and each of its aliases;
 * - when `includeObservationImageAliases` is set, also
 *   `observation.images.<key>` and `observation.images.<alias>`.
 *
 * Duplicate keys are written once.
 */
function addPolicyImageAliases(
  images: Record<string, string>,
  stream: PolicyCameraFrameStream,
  frame: CameraFrameCaptureResult,
  includeObservationImageAliases: boolean
) {
  const baseNames = [stream.key, ...(stream.aliases ?? [])];
  const allNames = includeObservationImageAliases
    ? [...baseNames, ...baseNames.map((name) => `observation.images.${name}`)]
    : baseNames;
  for (const name of new Set(allNames)) {
    images[name] = frame.dataUrl;
  }
}
|
|
75
|
+
|
|
76
|
+
/** Summary label `<key>:<source kind>` for one captured frame. */
function describeFrameSource(key: string, frame: CameraFrameCaptureResult) {
  const { kind } = frame.source;
  return `${key}:${kind}`;
}
|
|
79
|
+
|
|
80
|
+
/**
 * Whether per-stream options pin a capture source themselves.
 *
 * Any truthy `camera`, `position`, `quaternion`, or `source` counts; such
 * streams skip mounted-source resolution.
 */
function hasExplicitPolicyCameraSource(
  options: PolicyCameraFrameStreamOptions | undefined
) {
  if (options == null) {
    return false;
  }
  const { camera, position, quaternion, source } = options;
  return Boolean(camera) || Boolean(position) || Boolean(quaternion) || Boolean(source);
}
|
|
90
|
+
|
|
91
|
+
/**
 * Resolve policy camera keys into capture streams.
 *
 * For each key, in order:
 * - If its stream options pin a source explicitly, the stream is built from
 *   `defaults` overlaid with those options.
 * - Otherwise the entry from the mounted plan (created from the remaining
 *   scene/source options) is used.
 * - Keys that resolve neither way are listed in `missingKeys`.
 *
 * @throws Error when `requireAll` is set and any key stays unresolved.
 */
export function createPolicyCameraFrameCapturePlan(
  options: CreatePolicyCameraFrameCapturePlanOptions
): PolicyCameraFrameCapturePlan {
  const {
    cameraKeys,
    defaults,
    streamOptions,
    includeObservationImageAliases,
    requireAll,
    ...sourceOptions
  } = options;

  const mountedPlan = createMountedCameraFrameSequencePlan(cameraKeys, {
    ...sourceOptions,
    defaults,
    cameraOptions: streamOptions as
      | Record<string, MountedCameraFrameSequenceCameraOptions>
      | undefined,
  });

  const unresolved = new Set(mountedPlan.missingKeys);
  const streams: PolicyCameraFrameStream[] = [];

  for (const key of cameraKeys) {
    const overrides = streamOptions?.[key];
    const aliases = overrides?.aliases;

    if (hasExplicitPolicyCameraSource(overrides)) {
      // An explicitly pinned source wins over mounted resolution.
      unresolved.delete(key);
      streams.push({ ...defaults, ...overrides, key, aliases });
    } else {
      const mounted = mountedPlan.cameras.find((camera) => camera.key === key);
      if (mounted) {
        const { key: _discardedKey, ...captureOptions } = mounted;
        streams.push({ ...captureOptions, key, aliases });
      }
    }
  }

  const plan: PolicyCameraFrameCapturePlan = {
    cameraKeys: [...cameraKeys],
    streams,
    includeObservationImageAliases,
    mountedPlan,
    missingKeys: [...unresolved],
  };

  const missingCount = plan.missingKeys.length;
  if (requireAll && missingCount > 0) {
    const plural = missingCount === 1 ? '' : 's';
    throw new Error(
      `Unable to resolve policy camera stream${plural} for ${plan.missingKeys.join(', ')}.`
    );
  }

  return plan;
}
|
|
153
|
+
|
|
154
|
+
/**
 * Build a capture plan using the scene lists reported by the sim API.
 *
 * Queries `getCameras()`, `getSites()`, and `getBodies()`, then delegates to
 * `createPolicyCameraFrameCapturePlan`.
 */
export function createPolicyCameraFrameCapturePlanFromApi(
  api: PolicyCameraFramePlanTarget,
  options: Omit<CreatePolicyCameraFrameCapturePlanOptions, 'cameras' | 'sites' | 'bodies'>
): PolicyCameraFrameCapturePlan {
  const cameras = api.getCameras();
  const sites = api.getSites();
  const bodies = api.getBodies();
  return createPolicyCameraFrameCapturePlan({ ...options, cameras, sites, bodies });
}
|
|
165
|
+
|
|
166
|
+
/**
 * Capture every stream in `options.streams` concurrently and assemble the
 * policy payload.
 *
 * @returns
 *   - `frames`: captured frames keyed by stream key;
 *   - `images`: data URLs under each key and its aliases (plus
 *     `observation.images.*` keys unless `includeObservationImageAliases` is
 *     false);
 *   - `sourceSummary`: a readable summary of the sources;
 *   - `capturedAt`: the `Date.now()` timestamp.
 */
export async function capturePolicyCameraFrames(
  target: PolicyCameraFrameCaptureTarget,
  options: PolicyCameraFrameCaptureOptions
): Promise<PolicyCameraFrameCaptureResult> {
  const includeObservationImageAliases =
    options.includeObservationImageAliases ?? true;

  // Captures run in parallel; Promise.all keeps results in stream order.
  const captured = await Promise.all(
    options.streams.map(async (stream) => {
      const { key, aliases, ...captureOptions } = stream;
      return { key, aliases, frame: await target.captureCameraFrame(captureOptions) };
    })
  );

  const frames: Record<string, CameraFrameCaptureResult> = {};
  const images: Record<string, string> = {};
  const summaryParts: string[] = [];

  for (const { key, aliases, frame } of captured) {
    frames[key] = frame;
    addPolicyImageAliases(images, { key, aliases }, frame, includeObservationImageAliases);
    summaryParts.push(describeFrameSource(key, frame));
  }

  const sourceSummary =
    summaryParts.length === 0 ? 'not used by policy' : summaryParts.join(' + ');

  return {
    frames,
    images,
    sourceSummary,
    capturedAt: Date.now(),
  };
}
|
|
205
|
+
|
|
206
|
+
/**
 * Plan streams from the sim API's scene lists, capture them, and return the
 * capture result together with the plan that produced it.
 */
export async function capturePolicyCameraFramesFromMountedStreams(
  target: PolicyCameraFrameCaptureTarget & PolicyCameraFramePlanTarget,
  options: Omit<CreatePolicyCameraFrameCapturePlanOptions, 'cameras' | 'sites' | 'bodies'>
): Promise<PolicyCameraFrameCaptureResult & { plan: PolicyCameraFrameCapturePlan }> {
  const plan = createPolicyCameraFrameCapturePlanFromApi(target, options);
  const captured = await capturePolicyCameraFrames(target, plan);
  return Object.assign({}, captured, { plan });
}
|