mujoco-react 10.0.0 → 10.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{chunk-QTCAVQS6.js → chunk-FEKBKHEN.js} +56 -5
- package/dist/chunk-FEKBKHEN.js.map +1 -0
- package/dist/index.d.ts +271 -19
- package/dist/index.js +1459 -407
- package/dist/index.js.map +1 -1
- package/dist/spark.d.ts +1 -1
- package/dist/spark.js +1 -1
- package/dist/{types-BaSMqJHT.d.ts → types-BHBNJubg.d.ts} +133 -2
- package/package.json +1 -1
- package/src/components/SceneRenderer.tsx +11 -4
- package/src/core/GenericIK.ts +12 -1
- package/src/core/MujocoSimProvider.tsx +67 -6
- package/src/core/SceneLoader.ts +8 -2
- package/src/hooks/useContactHistory.ts +155 -0
- package/src/hooks/useControlWriter.ts +176 -0
- package/src/hooks/useNamedObservation.ts +42 -0
- package/src/hooks/usePolicy.ts +133 -10
- package/src/hooks/usePolicyCameraFrames.ts +162 -0
- package/src/hooks/usePose.ts +119 -0
- package/src/hooks/useRemotePolicy.ts +329 -0
- package/src/index.ts +81 -0
- package/src/policyCameraFrames.ts +213 -0
- package/src/policyControls.ts +87 -0
- package/src/policyObservation.ts +172 -0
- package/src/rendering/GeomBuilder.ts +73 -24
- package/src/rendering/cameraFrameCapture.ts +74 -2
- package/src/types.ts +151 -1
- package/dist/chunk-QTCAVQS6.js.map +0 -1
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Cooperative actuator/control ownership for policies, IK, teleop, and replay.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { useCallback, useEffect, useMemo, useRef } from 'react';
|
|
9
|
+
import type { RefObject } from 'react';
|
|
10
|
+
import { useMujocoContext } from '../core/MujocoSimProvider';
|
|
11
|
+
import type { ControlGroupInfo, ControlGroupSelector, MujocoData, MujocoModel } from '../types';
|
|
12
|
+
|
|
13
|
+
/** One actuator that a control writer asked for while another writer already held it. */
export interface ControlWriterConflict {
  /** Entry of the requested group's `ctrlAdr` that is already claimed. */
  actuatorIndex: number;
  /** Owner label stored on the existing claim. */
  owner: string;
  /** Owner label of the writer whose claim was rejected. */
  requestedBy: string;
}
|
|
18
|
+
|
|
19
|
+
/** Configuration accepted by `useControlWriter`. */
export interface ControlWriterOptions {
  /** Label recorded on every claim and reported in conflicts. */
  owner: string;
  /**
   * Control group to claim. When present it is resolved with
   * `api.resolveControlGroup`; when omitted `api.getControlMap()` is used.
   */
  selector?: ControlGroupSelector;
  /** When `false` no claim is made and unforced writes are refused. Defaults to `true`. */
  enabled?: boolean;
  /** Log a `console.warn` when the claim is rejected. Defaults to `true`. */
  warnOnConflict?: boolean;
  /**
   * When `true` (the default) an existing claim whose owner label equals
   * `owner` is not treated as a conflict.
   */
  allowSameOwner?: boolean;
  /** Invoked with the list of conflicts whenever the claim is rejected. */
  onConflict?: (conflicts: ControlWriterConflict[]) => void;
}
|
|
27
|
+
|
|
28
|
+
/** Per-call options for `ControlWriterHandle.write`. */
export interface ControlWriterWriteOptions {
  /** Skip the `canWrite()` check and write regardless of ownership state. */
  force?: boolean;
}
|
|
31
|
+
|
|
32
|
+
/** Handle returned by `useControlWriter`. */
export interface ControlWriterHandle {
  /** Owner label this handle was created with. */
  owner: string;
  /** Currently resolved control group, or `null` when none is resolved. */
  group: RefObject<ControlGroupInfo | null>;
  /** Conflicts recorded by the most recent claim attempt (empty when none). */
  conflicts: RefObject<ControlWriterConflict[]>;
  /** Whether an unforced `write` would currently be accepted. */
  canWrite: () => boolean;
  /**
   * Reads the group's controls via `group.readCtrl(data)`; returns an empty
   * `Float64Array` when the data or the group is unavailable.
   */
  read: () => Float64Array;
  /**
   * Writes `values` via `group.writeCtrl(data, values)`.
   * Returns `false` (and writes nothing) when the data or group is unavailable,
   * or when the write is not forced and `canWrite()` is `false`.
   */
  write: (values: ArrayLike<number>, options?: ControlWriterWriteOptions) => boolean;
  /** Drops this writer's claims and clears its recorded conflicts. */
  release: () => void;
}
|
|
41
|
+
|
|
42
|
+
/** Registry entry describing which writer currently holds an actuator. */
interface ControlWriterClaim {
  /** Owner label of the claiming writer. */
  owner: string;
  /** Unique per-hook-instance identity used to tell writers with the same label apart. */
  token: symbol;
}
|
|
46
|
+
|
|
47
|
+
/** Per-model claim tables, keyed weakly by model; each table maps an actuator index to its claim. */
const claimsByModel = new WeakMap<MujocoModel, Map<number, ControlWriterClaim>>();

/**
 * Returns the claim table for `model`, creating and registering an empty one
 * the first time a model is seen.
 */
function getClaims(model: MujocoModel) {
  const existing = claimsByModel.get(model);
  if (existing !== undefined) return existing;

  const created = new Map<number, ControlWriterClaim>();
  claimsByModel.set(model, created);
  return created;
}
|
|
57
|
+
|
|
58
|
+
/**
 * Removes every claim on `model` that was made with `token`.
 * Does nothing when `model` is `null` or has no claim table.
 */
function releaseClaims(model: MujocoModel | null, token: symbol) {
  const claims = model ? claimsByModel.get(model) : undefined;
  if (!claims) return;

  // Deleting the entry currently being visited is safe for Map iteration.
  claims.forEach((claim, actuatorIndex) => {
    if (claim.token === token) claims.delete(actuatorIndex);
  });
}
|
|
66
|
+
|
|
67
|
+
/**
 * Claims cooperative ownership of a group of actuators for `owner` and returns
 * a memoized handle for reading and writing their controls.
 *
 * While `enabled` is true and the simulation status is `'ready'`, an effect
 * resolves the control group (`api.resolveControlGroup(selector)` when a
 * selector is given, otherwise `api.getControlMap()`) and tries to claim every
 * entry of `group.ctrlAdr` in the per-model claim registry. If any entry is
 * held by a different writer the claim is rejected as a whole: the conflicts
 * are stored on `conflicts`, `onConflict` is invoked and, when
 * `warnOnConflict` is set, a warning is logged. Claims are dropped whenever the
 * effect re-runs, on cleanup, or through `release()`.
 *
 * @param options Owner label, optional selector and conflict-handling flags.
 * @returns Handle exposing the resolved group, recorded conflicts and guarded
 *   `read` / `write` helpers.
 */
export function useControlWriter(options: ControlWriterOptions): ControlWriterHandle {
  const {
    owner,
    selector,
    enabled = true,
    warnOnConflict = true,
    allowSameOwner = true,
    onConflict,
  } = options;
  const { api, mjModelRef, mjDataRef, status } = useMujocoContext();
  // Unique identity of this hook instance; distinguishes writers sharing an owner label.
  const tokenRef = useRef(Symbol(owner));
  // Model on which this writer currently holds claims (null when it holds none).
  const claimedModelRef = useRef<MujocoModel | null>(null);
  const groupRef = useRef<ControlGroupInfo | null>(null);
  const conflictsRef = useRef<ControlWriterConflict[]>([]);
  // Keep the latest callback without making it an effect dependency.
  const onConflictRef = useRef(onConflict);
  onConflictRef.current = onConflict;

  const release = useCallback(() => {
    releaseClaims(claimedModelRef.current, tokenRef.current);
    claimedModelRef.current = null;
    conflictsRef.current = [];
  }, []);

  useEffect(() => {
    // Always start from a clean slate so stale claims never outlive a dependency change.
    release();
    if (!enabled || status !== 'ready') {
      groupRef.current = null;
      return;
    }

    const model = mjModelRef.current;
    if (!model) {
      groupRef.current = null;
      return;
    }

    const group = selector ? api.resolveControlGroup(selector) : api.getControlMap();
    groupRef.current = group;
    if (!group) return;

    const claims = getClaims(model);
    const conflicts: ControlWriterConflict[] = [];
    for (const actuatorIndex of group.ctrlAdr) {
      const existing = claims.get(actuatorIndex);
      if (
        existing &&
        existing.token !== tokenRef.current &&
        (!allowSameOwner || existing.owner !== owner)
      ) {
        conflicts.push({
          actuatorIndex,
          owner: existing.owner,
          requestedBy: owner,
        });
      }
    }

    conflictsRef.current = conflicts;
    if (conflicts.length > 0) {
      // All-or-nothing: a single conflicting actuator rejects the whole claim.
      onConflictRef.current?.(conflicts);
      if (warnOnConflict) {
        console.warn(
          `[mujoco-react] Control writer "${owner}" conflicts with existing writer(s): ${conflicts
            .map((conflict) => `${conflict.actuatorIndex}:${conflict.owner}`)
            .join(', ')}`
        );
      }
      return;
    }

    for (const actuatorIndex of group.ctrlAdr) {
      claims.set(actuatorIndex, { owner, token: tokenRef.current });
    }
    claimedModelRef.current = model;

    return release;
  }, [allowSameOwner, api, enabled, mjModelRef, owner, release, selector, status, warnOnConflict]);

  const canWrite = useCallback(() => (
    enabled &&
    groupRef.current !== null &&
    // Require an actual claim: `release()` clears the conflict list but keeps the
    // group, so without this check a released (or previously rejected) writer
    // would be allowed to write again.
    claimedModelRef.current !== null &&
    conflictsRef.current.length === 0
  ), [enabled]);

  const read = useCallback(() => {
    const data = mjDataRef.current;
    const group = groupRef.current;
    if (!data || !group) return new Float64Array(0);
    return group.readCtrl(data);
  }, [mjDataRef]);

  const write = useCallback((values: ArrayLike<number>, writeOptions: ControlWriterWriteOptions = {}) => {
    const data: MujocoData | null = mjDataRef.current;
    const group = groupRef.current;
    if (!data || !group) return false;
    // `force` bypasses the ownership guard but still needs data and a resolved group.
    if (!writeOptions.force && !canWrite()) return false;
    group.writeCtrl(data, values);
    return true;
  }, [canWrite, mjDataRef]);

  return useMemo(() => ({
    owner,
    group: groupRef,
    conflicts: conflictsRef,
    canWrite,
    read,
    write,
    release,
  }), [canWrite, owner, read, release, write]);
}
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Stable React handle for named policy observations.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { useMemo, useRef } from 'react';
|
|
9
|
+
import { useMujocoContext } from '../core/MujocoSimProvider';
|
|
10
|
+
import { readNamedObservation } from '../policyObservation';
|
|
11
|
+
import type {
|
|
12
|
+
NamedObservationOptions,
|
|
13
|
+
NamedObservationResult,
|
|
14
|
+
} from '../policyObservation';
|
|
15
|
+
|
|
16
|
+
/**
 * Placeholder result handed out while the model or data is unavailable.
 * It is a single module-level instance, so every such read returns the same object.
 */
const EMPTY_NAMED_OBSERVATION: NamedObservationResult = {
  values: new Float32Array(0),
  layout: [],
};
|
|
20
|
+
|
|
21
|
+
/** Handle returned by `useNamedObservation`. */
export interface NamedObservationHandle {
  /** Reads the full named observation (values plus layout). */
  read: () => NamedObservationResult;
  /** Reads the observation and returns only its `values` vector. */
  readValues: () => Float32Array | Float64Array;
}
|
|
25
|
+
|
|
26
|
+
/**
 * Returns a memoized handle for reading a named policy observation from the
 * current MuJoCo model and data.
 *
 * The latest `options` are kept in a ref, so the handle identity only depends
 * on the model/data refs while each read still uses the most recent options.
 * Both methods are plain closures and remain valid when destructured from the
 * handle (`const { readValues } = useNamedObservation(...)`).
 *
 * @param options Observation description forwarded to `readNamedObservation`.
 * @returns Handle with `read` (full result) and `readValues` (values only).
 */
export function useNamedObservation(options: NamedObservationOptions): NamedObservationHandle {
  const { mjModelRef, mjDataRef } = useMujocoContext();
  const optionsRef = useRef(options);
  optionsRef.current = options;

  return useMemo(() => {
    const read = (): NamedObservationResult => {
      const model = mjModelRef.current;
      const data = mjDataRef.current;
      // Nothing to observe until both the model and the data exist.
      if (!model || !data) return EMPTY_NAMED_OBSERVATION;
      return readNamedObservation(model, data, optionsRef.current);
    };

    // Calls `read` through the closure rather than `this`, so it does not
    // break when the method is detached from the handle object.
    const readValues = () => read().values;

    return { read, readValues };
  }, [mjDataRef, mjModelRef]);
}
|
package/src/hooks/usePolicy.ts
CHANGED
|
@@ -5,9 +5,47 @@
|
|
|
5
5
|
* usePolicy — policy decimation loop hook (spec 10.1)
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
import { useRef } from 'react';
|
|
8
|
+
import { useCallback, useMemo, useRef } from 'react';
|
|
9
9
|
import { useBeforePhysicsStep } from '../core/MujocoSimProvider';
|
|
10
|
-
import type { PolicyConfig } from '../types';
|
|
10
|
+
import type { PolicyAPI, PolicyConfig, PolicyInferenceOutput, PolicyVector } from '../types';
|
|
11
|
+
|
|
12
|
+
/** A queued action together with the observation it was inferred from. */
type PendingPolicyAction = {
  /** Action vector waiting to be applied. */
  action: PolicyVector;
  /** Observation that was passed to inference when this action was produced. */
  observation: PolicyVector;
};
|
|
16
|
+
|
|
17
|
+
/**
 * Type guard for thenables: true for any non-null object that exposes a
 * callable `then` member.
 */
function isPromiseLike(value: unknown): value is Promise<PolicyInferenceOutput> {
  if (typeof value !== 'object' || value === null) return false;
  if (!('then' in value)) return false;

  const { then } = value as { then?: unknown };
  return typeof then === 'function';
}
|
|
25
|
+
|
|
26
|
+
/**
 * Type guard distinguishing a chunk of actions from a single action vector:
 * true for a non-empty plain array whose first element is itself an array or
 * a typed-array view.
 */
function isPolicyActionChunk(value: PolicyInferenceOutput): value is readonly PolicyVector[] {
  if (!Array.isArray(value) || value.length === 0) return false;

  const first: unknown = value[0];
  return Array.isArray(first) || ArrayBuffer.isView(first);
}
|
|
33
|
+
|
|
34
|
+
/**
 * Normalizes an inference result to a list of actions: a chunk is copied into
 * a new array, a single action is wrapped in a one-element array.
 */
function toPolicyActions(output: PolicyInferenceOutput): PolicyVector[] {
  if (isPolicyActionChunk(output)) {
    return Array.from(output);
  }
  return [output];
}
|
|
37
|
+
|
|
38
|
+
/**
 * Adds `actions` to `queue` in place, pairing each one with the observation it
 * was inferred from.
 *
 * Entries are pushed one at a time rather than spread into a single `push`
 * call, so arbitrarily large chunks cannot exceed the engine's argument limit.
 *
 * @param queue Pending-action queue to mutate.
 * @param actions Actions to enqueue, in application order.
 * @param observation Observation associated with every enqueued action.
 * @param strategy `'replace'` empties the queue first; any other value appends.
 */
function enqueuePolicyActions(
  queue: PendingPolicyAction[],
  actions: PolicyVector[],
  observation: PolicyVector,
  strategy: PolicyConfig['queueStrategy']
) {
  if (strategy === 'replace') {
    queue.length = 0;
  }
  for (const action of actions) {
    queue.push({ action, observation });
  }
}
|
|
11
49
|
|
|
12
50
|
/**
|
|
13
51
|
* Framework-agnostic policy execution hook.
|
|
@@ -19,15 +57,33 @@ import type { PolicyConfig } from '../types';
|
|
|
19
57
|
* @param config Policy configuration
|
|
20
58
|
* @returns { step, isRunning } control handles
|
|
21
59
|
*/
|
|
22
|
-
export function usePolicy(config: PolicyConfig) {
|
|
60
|
+
export function usePolicy(config: PolicyConfig): PolicyAPI {
|
|
23
61
|
const lastActionTimeRef = useRef(0);
|
|
24
62
|
const lastObservationRef = useRef<ReturnType<PolicyConfig['onObservation']> | null>(null);
|
|
25
63
|
const lastActionRef = useRef<Float32Array | Float64Array | number[] | null>(null);
|
|
64
|
+
const actionQueueRef = useRef<PendingPolicyAction[]>([]);
|
|
65
|
+
const inFlightRef = useRef(false);
|
|
66
|
+
const lastErrorRef = useRef<unknown>(null);
|
|
67
|
+
const epochRef = useRef(0);
|
|
26
68
|
const isRunningRef = useRef(config.enabled ?? true);
|
|
27
69
|
const configRef = useRef(config);
|
|
28
70
|
configRef.current = config;
|
|
29
71
|
isRunningRef.current = config.enabled ?? isRunningRef.current;
|
|
30
72
|
|
|
73
|
+
const clearQueue = useCallback(() => {
|
|
74
|
+
epochRef.current += 1;
|
|
75
|
+
actionQueueRef.current.splice(0, actionQueueRef.current.length);
|
|
76
|
+
inFlightRef.current = false;
|
|
77
|
+
lastErrorRef.current = null;
|
|
78
|
+
}, []);
|
|
79
|
+
|
|
80
|
+
const reset = useCallback(() => {
|
|
81
|
+
clearQueue();
|
|
82
|
+
lastActionTimeRef.current = 0;
|
|
83
|
+
lastObservationRef.current = null;
|
|
84
|
+
lastActionRef.current = null;
|
|
85
|
+
}, [clearQueue]);
|
|
86
|
+
|
|
31
87
|
useBeforePhysicsStep(({ model, data }) => {
|
|
32
88
|
if (!isRunningRef.current) return;
|
|
33
89
|
|
|
@@ -37,24 +93,91 @@ export function usePolicy(config: PolicyConfig) {
|
|
|
37
93
|
|
|
38
94
|
// Check if it's time for a new action
|
|
39
95
|
if (data.time - lastActionTimeRef.current >= interval) {
|
|
96
|
+
const queuedAction = actionQueueRef.current.shift();
|
|
97
|
+
if (queuedAction) {
|
|
98
|
+
cfg.onAction({
|
|
99
|
+
action: queuedAction.action,
|
|
100
|
+
observation: queuedAction.observation,
|
|
101
|
+
model,
|
|
102
|
+
data,
|
|
103
|
+
});
|
|
104
|
+
lastActionTimeRef.current = data.time;
|
|
105
|
+
lastActionRef.current = queuedAction.action;
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
const prefetchThreshold = cfg.prefetchThreshold ?? 0;
|
|
109
|
+
const shouldInfer = !inFlightRef.current && (!queuedAction || actionQueueRef.current.length <= prefetchThreshold);
|
|
110
|
+
if (!shouldInfer) return;
|
|
111
|
+
|
|
40
112
|
// Build observation
|
|
41
113
|
const observation = cfg.onObservation({ model, data });
|
|
42
|
-
const
|
|
114
|
+
const result = cfg.infer ? cfg.infer({ observation, model, data }) : observation;
|
|
43
115
|
|
|
44
|
-
|
|
45
|
-
|
|
116
|
+
if (isPromiseLike(result)) {
|
|
117
|
+
const epoch = epochRef.current;
|
|
118
|
+
inFlightRef.current = true;
|
|
119
|
+
result
|
|
120
|
+
.then((output) => {
|
|
121
|
+
if (epoch !== epochRef.current) return;
|
|
122
|
+
enqueuePolicyActions(
|
|
123
|
+
actionQueueRef.current,
|
|
124
|
+
toPolicyActions(output),
|
|
125
|
+
observation,
|
|
126
|
+
cfg.queueStrategy ?? 'append'
|
|
127
|
+
);
|
|
128
|
+
lastErrorRef.current = null;
|
|
129
|
+
})
|
|
130
|
+
.catch((error: unknown) => {
|
|
131
|
+
if (epoch !== epochRef.current) return;
|
|
132
|
+
lastErrorRef.current = error;
|
|
133
|
+
cfg.onError?.(error);
|
|
134
|
+
})
|
|
135
|
+
.finally(() => {
|
|
136
|
+
if (epoch !== epochRef.current) return;
|
|
137
|
+
inFlightRef.current = false;
|
|
138
|
+
});
|
|
139
|
+
} else {
|
|
140
|
+
const actions = toPolicyActions(result);
|
|
141
|
+
if (queuedAction) {
|
|
142
|
+
enqueuePolicyActions(
|
|
143
|
+
actionQueueRef.current,
|
|
144
|
+
actions,
|
|
145
|
+
observation,
|
|
146
|
+
cfg.queueStrategy ?? 'append'
|
|
147
|
+
);
|
|
148
|
+
} else {
|
|
149
|
+
const [action, ...queuedActions] = actions;
|
|
150
|
+
if (!action) return;
|
|
151
|
+
enqueuePolicyActions(
|
|
152
|
+
actionQueueRef.current,
|
|
153
|
+
queuedActions,
|
|
154
|
+
observation,
|
|
155
|
+
cfg.queueStrategy ?? 'append'
|
|
156
|
+
);
|
|
157
|
+
// Apply action. If `infer` is omitted, this preserves the legacy inline-controller path.
|
|
158
|
+
cfg.onAction({ action, observation, model, data });
|
|
159
|
+
lastActionRef.current = action;
|
|
160
|
+
}
|
|
161
|
+
}
|
|
46
162
|
|
|
47
163
|
lastActionTimeRef.current = data.time;
|
|
48
164
|
lastObservationRef.current = observation;
|
|
49
|
-
lastActionRef.current = action;
|
|
50
165
|
}
|
|
51
166
|
});
|
|
52
167
|
|
|
53
|
-
return {
|
|
168
|
+
return useMemo(() => ({
|
|
54
169
|
get isRunning() { return isRunningRef.current; },
|
|
55
170
|
start: () => { isRunningRef.current = true; },
|
|
56
|
-
stop: () => {
|
|
171
|
+
stop: () => {
|
|
172
|
+
isRunningRef.current = false;
|
|
173
|
+
if (configRef.current.clearQueueOnStop) reset();
|
|
174
|
+
},
|
|
175
|
+
clearQueue,
|
|
176
|
+
reset,
|
|
177
|
+
get inFlight() { return inFlightRef.current; },
|
|
178
|
+
get queuedActions() { return actionQueueRef.current.length; },
|
|
57
179
|
get lastObservation() { return lastObservationRef.current; },
|
|
58
180
|
get lastAction() { return lastActionRef.current; },
|
|
59
|
-
|
|
181
|
+
get lastError() { return lastErrorRef.current; },
|
|
182
|
+
}), [clearQueue, reset]);
|
|
60
183
|
}
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* React wrapper for capturing policy image payloads from Three/MuJoCo cameras.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { useCallback, useRef, useState } from 'react';
|
|
9
|
+
import { useMujoco } from '../core/MujocoSimProvider';
|
|
10
|
+
import {
|
|
11
|
+
capturePolicyCameraFrames,
|
|
12
|
+
capturePolicyCameraFramesFromMountedStreams,
|
|
13
|
+
} from '../policyCameraFrames';
|
|
14
|
+
import type {
|
|
15
|
+
CreatePolicyCameraFrameCapturePlanOptions,
|
|
16
|
+
PolicyCameraFrameCapturePlan,
|
|
17
|
+
} from '../policyCameraFrames';
|
|
18
|
+
import type {
|
|
19
|
+
FrameCaptureStatus,
|
|
20
|
+
PolicyCameraFrameCaptureAPI,
|
|
21
|
+
PolicyCameraFrameCaptureOptions,
|
|
22
|
+
PolicyCameraFrameCaptureResult,
|
|
23
|
+
} from '../types';
|
|
24
|
+
|
|
25
|
+
/**
 * Options for capturing from mounted streams: the capture-plan options with
 * the `cameras`, `sites` and `bodies` fields removed.
 */
export type MountedPolicyCameraFrameCaptureOptions = Omit<
  CreatePolicyCameraFrameCapturePlanOptions,
  'cameras' | 'sites' | 'bodies'
>;
|
|
29
|
+
|
|
30
|
+
/** State and actions returned by `usePolicyCameraFramesFromMountedStreams`. */
export interface MountedPolicyCameraFrameCaptureAPI {
  /** Current capture status. */
  status: FrameCaptureStatus;
  /** Error from the most recent failed capture, or `null`. */
  error: Error | null;
  /** Convenience flag: `true` while `status` is `'capturing'`. */
  isCapturing: boolean;
  /**
   * Starts a capture. `options` are merged over the hook's default options.
   * Resolves with the capture result including the plan that was used;
   * rejects when the scene is not ready or the capture fails.
   */
  capture: (
    options?: Partial<MountedPolicyCameraFrameCaptureOptions>
  ) => Promise<PolicyCameraFrameCaptureResult & { plan: PolicyCameraFrameCapturePlan }>;
  /** Returns `status` to `'idle'` and clears `error`. */
  reset: () => void;
}
|
|
39
|
+
|
|
40
|
+
/**
 * Combines per-call `options` with the hook's `defaultOptions`.
 *
 * Top-level fields from `options` override the defaults. `cameraKeys` falls
 * back to the default when the override is null or undefined, and the nested
 * `aliases`, `defaults` and `streamOptions` objects are merged key by key
 * (always yielding an object, even when neither side provides one).
 */
function mergePolicyCameraFrameCaptureOptions(
  defaultOptions: MountedPolicyCameraFrameCaptureOptions,
  options: Partial<MountedPolicyCameraFrameCaptureOptions>
): MountedPolicyCameraFrameCaptureOptions {
  const cameraKeys = options.cameraKeys ?? defaultOptions.cameraKeys;
  const aliases = { ...defaultOptions.aliases, ...options.aliases };
  const defaults = { ...defaultOptions.defaults, ...options.defaults };
  const streamOptions = { ...defaultOptions.streamOptions, ...options.streamOptions };

  return {
    ...defaultOptions,
    ...options,
    cameraKeys,
    aliases,
    defaults,
    streamOptions,
  };
}
|
|
62
|
+
|
|
63
|
+
/**
 * React wrapper around `capturePolicyCameraFrames` that tracks capture status
 * and the last error.
 *
 * The latest `defaultOptions` are read through a ref at call time, so
 * `capture` keeps a stable identity even when the caller passes a fresh
 * options object on every render. This makes it safe to list `capture` in
 * effect dependency arrays.
 *
 * @param defaultOptions Options used for every capture unless overridden per call.
 * @returns Status, error, `isCapturing` flag, and the `capture` / `reset` actions.
 */
export function usePolicyCameraFrames(
  defaultOptions: PolicyCameraFrameCaptureOptions
): PolicyCameraFrameCaptureAPI {
  const mujoco = useMujoco();
  const [status, setStatus] = useState<FrameCaptureStatus>('idle');
  const [error, setError] = useState<Error | null>(null);
  // Latest-value ref: keeps `capture` stable while still using current defaults.
  const defaultOptionsRef = useRef(defaultOptions);
  defaultOptionsRef.current = defaultOptions;

  const reset = useCallback(() => {
    setStatus('idle');
    setError(null);
  }, []);

  const capture = useCallback(
    async (options: Partial<PolicyCameraFrameCaptureOptions> = {}) => {
      if (!mujoco.api) {
        throw new Error('MuJoCo scene is not ready for policy camera capture.');
      }

      setStatus('capturing');
      setError(null);

      try {
        const defaults = defaultOptionsRef.current;
        const result = await capturePolicyCameraFrames(mujoco.api, {
          ...defaults,
          ...options,
          // An explicitly undefined `streams` override must not erase the default.
          streams: options.streams ?? defaults.streams,
        });
        setStatus('captured');
        return result;
      } catch (nextError) {
        // Normalize non-Error throwables so `error` always holds an Error.
        const error =
          nextError instanceof Error
            ? nextError
            : new Error('Unable to capture policy camera frames.');
        setError(error);
        setStatus('error');
        throw error;
      }
    },
    [mujoco.api]
  );

  return {
    status,
    error,
    isCapturing: status === 'capturing',
    capture,
    reset,
  };
}
|
|
113
|
+
|
|
114
|
+
/**
 * React wrapper around `capturePolicyCameraFramesFromMountedStreams` that
 * tracks capture status and the last error.
 *
 * The latest `defaultOptions` are read through a ref at call time, so
 * `capture` keeps a stable identity even when the caller passes a fresh
 * options object on every render. This makes it safe to list `capture` in
 * effect dependency arrays.
 *
 * @param defaultOptions Options used for every capture; per-call options are
 *   merged over them with `mergePolicyCameraFrameCaptureOptions`.
 * @returns Status, error, `isCapturing` flag, and the `capture` / `reset` actions.
 */
export function usePolicyCameraFramesFromMountedStreams(
  defaultOptions: MountedPolicyCameraFrameCaptureOptions
): MountedPolicyCameraFrameCaptureAPI {
  const mujoco = useMujoco();
  const [status, setStatus] = useState<FrameCaptureStatus>('idle');
  const [error, setError] = useState<Error | null>(null);
  // Latest-value ref: keeps `capture` stable while still using current defaults.
  const defaultOptionsRef = useRef(defaultOptions);
  defaultOptionsRef.current = defaultOptions;

  const reset = useCallback(() => {
    setStatus('idle');
    setError(null);
  }, []);

  const capture = useCallback(
    async (options: Partial<MountedPolicyCameraFrameCaptureOptions> = {}) => {
      if (!mujoco.api) {
        throw new Error('MuJoCo scene is not ready for mounted policy camera capture.');
      }

      setStatus('capturing');
      setError(null);

      try {
        const result = await capturePolicyCameraFramesFromMountedStreams(
          mujoco.api,
          mergePolicyCameraFrameCaptureOptions(defaultOptionsRef.current, options)
        );
        setStatus('captured');
        return result;
      } catch (nextError) {
        // Normalize non-Error throwables so `error` always holds an Error.
        const error =
          nextError instanceof Error
            ? nextError
            : new Error('Unable to capture mounted policy camera frames.');
        setError(error);
        setStatus('error');
        throw error;
      }
    },
    [mujoco.api]
  );

  return {
    status,
    error,
    isCapturing: status === 'capturing',
    capture,
    reset,
  };
}
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @license
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Ref-based world pose hooks for named MuJoCo bodies, geoms, and sites.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { useEffect, useRef } from 'react';
|
|
9
|
+
import type { RefObject } from 'react';
|
|
10
|
+
import * as THREE from 'three';
|
|
11
|
+
import { useAfterPhysicsStep, useMujocoContext } from '../core/MujocoSimProvider';
|
|
12
|
+
import { findBodyByName, findGeomByName, findSiteByName } from '../core/SceneLoader';
|
|
13
|
+
import type { Bodies, Geoms, Sites } from '../types';
|
|
14
|
+
|
|
15
|
+
/** Kinds of named MuJoCo resource whose world pose can be tracked. */
export type PoseResourceKind = 'body' | 'geom' | 'site';
|
|
16
|
+
|
|
17
|
+
/** Ref-based pose readout for a named resource; refs are updated without re-rendering. */
export interface PoseReadout {
  /** Resolved resource id, or `-1` while unresolved. */
  id: RefObject<number>;
  /** `true` once the name has resolved to a non-negative id. */
  found: RefObject<boolean>;
  /** Position vector, overwritten in place after each physics step while the id is resolved. */
  position: RefObject<THREE.Vector3>;
  /** Orientation quaternion, overwritten in place after each physics step while the id is resolved. */
  quaternion: RefObject<THREE.Quaternion>;
}
|
|
23
|
+
|
|
24
|
+
/** Module-level scratch matrix reused across calls to avoid per-call allocation. */
const _matrix = new THREE.Matrix4();

/**
 * Writes into `target` the rotation described by nine consecutive values of
 * `values` starting at `offset`.
 *
 * The values are handed to `Matrix4.set` three at a time as the first three
 * rows of the matrix (fourth column zero, last row `0 0 0 1`).
 */
function quaternionFromMatrixArray(
  target: THREE.Quaternion,
  values: ArrayLike<number>,
  offset: number
) {
  const row0 = offset;
  const row1 = offset + 3;
  const row2 = offset + 6;

  _matrix.set(
    values[row0], values[row0 + 1], values[row0 + 2], 0,
    values[row1], values[row1 + 1], values[row1 + 2], 0,
    values[row2], values[row2 + 1], values[row2 + 2], 0,
    0, 0, 0, 1
  );
  target.setFromRotationMatrix(_matrix);
}
|
|
39
|
+
|
|
40
|
+
/**
 * Copies four consecutive values of `values` starting at `offset` into
 * `target`, reordering them from (w, x, y, z) storage to the (x, y, z, w)
 * argument order of `target.set`.
 *
 * Missing entries fall back to the identity rotation (x = y = z = 0, w = 1).
 */
function quaternionFromMujocoQuat(
  target: THREE.Quaternion,
  values: ArrayLike<number>,
  offset: number
) {
  const w = values[offset] ?? 1;
  const x = values[offset + 1] ?? 0;
  const y = values[offset + 2] ?? 0;
  const z = values[offset + 3] ?? 0;
  target.set(x, y, z, w);
}
|
|
52
|
+
|
|
53
|
+
/**
 * Shared implementation behind the body/geom/site pose hooks.
 *
 * An effect resolves `name` to an id with the matching `find*ByName` helper
 * once the model exists and the status is `'ready'`; otherwise the id is `-1`.
 * After every physics step, while the id is non-negative, the position and
 * quaternion refs are overwritten in place from the simulation data.
 *
 * @param kind Which resource table to look the name up in.
 * @param name Resource name to resolve.
 * @returns Refs holding the resolved id, a found flag, and the latest pose.
 */
function useNamedPose(kind: PoseResourceKind, name: string): PoseReadout {
  const { mjModelRef, status } = useMujocoContext();
  const idRef = useRef(-1);
  const foundRef = useRef(false);
  const positionRef = useRef(new THREE.Vector3());
  const quaternionRef = useRef(new THREE.Quaternion());

  useEffect(() => {
    const model = mjModelRef.current;
    let resolvedId = -1;

    if (model && status === 'ready') {
      switch (kind) {
        case 'body':
          resolvedId = findBodyByName(model, name);
          break;
        case 'geom':
          resolvedId = findGeomByName(model, name);
          break;
        default:
          resolvedId = findSiteByName(model, name);
          break;
      }
    }

    idRef.current = resolvedId;
    foundRef.current = resolvedId >= 0;
  }, [kind, name, status, mjModelRef]);

  useAfterPhysicsStep(({ data }) => {
    const id = idRef.current;
    if (id < 0) return;

    const position = positionRef.current;
    const quaternion = quaternionRef.current;
    const p = id * 3;
    const m = id * 9;

    switch (kind) {
      case 'body':
        position.set(data.xpos[p], data.xpos[p + 1], data.xpos[p + 2]);
        // Use the rotation matrix when the data exposes one; otherwise fall back to the quaternion array.
        if (data.xmat) {
          quaternionFromMatrixArray(quaternion, data.xmat, m);
        } else {
          quaternionFromMujocoQuat(quaternion, data.xquat, id * 4);
        }
        break;
      case 'geom':
        position.set(data.geom_xpos[p], data.geom_xpos[p + 1], data.geom_xpos[p + 2]);
        quaternionFromMatrixArray(quaternion, data.geom_xmat, m);
        break;
      default:
        position.set(data.site_xpos[p], data.site_xpos[p + 1], data.site_xpos[p + 2]);
        quaternionFromMatrixArray(quaternion, data.site_xmat, m);
        break;
    }
  });

  return {
    id: idRef,
    found: foundRef,
    position: positionRef,
    quaternion: quaternionRef,
  };
}
|
|
108
|
+
|
|
109
|
+
/**
 * Tracks the pose of the body called `name`.
 *
 * @param name Body name to resolve.
 * @returns Ref-based pose readout for that body.
 */
export function useBodyPose(name: Bodies): PoseReadout {
  const readout = useNamedPose('body', name);
  return readout;
}
|
|
112
|
+
|
|
113
|
+
/**
 * Tracks the pose of the geom called `name`.
 *
 * @param name Geom name to resolve.
 * @returns Ref-based pose readout for that geom.
 */
export function useGeomPose(name: Geoms): PoseReadout {
  const readout = useNamedPose('geom', name);
  return readout;
}
|
|
116
|
+
|
|
117
|
+
/**
 * Tracks the pose of the site called `name`.
 *
 * @param name Site name to resolve.
 * @returns Ref-based pose readout for that site.
 */
export function useSitePose(name: Sites): PoseReadout {
  const readout = useNamedPose('site', name);
  return readout;
}
|