mujoco-react 10.0.0 → 10.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,87 @@
1
+ /**
2
+ * @license
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Helpers for applying policy action vectors to MuJoCo controls.
6
+ */
7
+
8
+ import type { MujocoData, MujocoModel, PolicyVector } from './types';
9
+
10
/**
 * Optional settings for `applyPolicyActionToControls`. Every field may be omitted.
 */
export interface ApplyPolicyActionToControlsOptions {
  /** Index of the first `data.ctrl` entry to write. Default: 0. */
  actuatorOffset?: number;
  /** Upper bound on the number of controls written. Default: the action's length. */
  actionSize?: number;
  /**
   * When enabled (the default), an action entry that is not a finite number is not
   * written; the current control value is left in place so a malformed policy
   * response cannot push NaN into the simulation.
   */
  skipInvalid?: boolean;
  /**
   * When enabled (the default), each value is limited to its actuator's
   * `model.actuator_ctrlrange` before being written, since learned policies are
   * generally not expected to command past actuator limits.
   */
  clamp?: boolean;
}
30
+
31
/**
 * Report returned by `applyPolicyActionToControls`.
 */
export interface ApplyPolicyActionToControlsResult {
  /**
   * Every value stored into `data.ctrl`, in write order, after the offset and
   * truncation were applied and (optionally) clamping.
   */
  applied: number[];
  /**
   * Indices of actuators whose control was left untouched because the matching
   * action entry was not finite while `skipInvalid` was enabled.
   */
  skipped: number[];
  /** The offset that was used: `options.actuatorOffset`, or 0 when it was omitted. */
  actuatorOffset: number;
}
43
+
44
/**
 * Limits `value` to the control range of actuator `actuatorIndex`.
 *
 * The range is the `[min, max]` pair at `model.actuator_ctrlrange[2 * actuatorIndex]`.
 * Missing bounds are treated as unbounded on that side.
 *
 * A degenerate range (`min >= max`, or a NaN bound) is also treated as unbounded.
 * MuJoCo leaves `ctrlrange` at its default `[0, 0]` for actuators without control
 * limits, and clamping to that pair would silently force every command to 0.
 *
 * @param model MuJoCo model providing `actuator_ctrlrange` (may be absent).
 * @param actuatorIndex Actuator whose range applies.
 * @param value Candidate control value.
 * @returns `value` clamped to the actuator's range, or unchanged when no usable
 *   range exists.
 */
export function clampPolicyActionValue(
  model: MujocoModel,
  actuatorIndex: number,
  value: number
): number {
  const ranges = model.actuator_ctrlrange;
  const min = ranges?.[actuatorIndex * 2] ?? -Infinity;
  const max = ranges?.[actuatorIndex * 2 + 1] ?? Infinity;
  // `!(min < max)` also catches NaN bounds, which would otherwise turn the result into NaN.
  if (!(min < max)) return value;
  return Math.max(min, Math.min(max, value));
}
54
+
55
/**
 * Writes a policy action vector into `data.ctrl`, starting at `options.actuatorOffset`.
 *
 * The number of controls written is the smallest of:
 * - `actionSize` (floored),
 * - the action length,
 * - the room left in `data.ctrl` after the offset,
 * - the room left among the model's `nu` actuators after the offset.
 *
 * Each entry is converted with `Number()`. When `skipInvalid` is on (the default),
 * non-finite entries are not written and their actuator index is reported in
 * `skipped`. Otherwise, when `clamp` is on (the default), values are limited via
 * `clampPolicyActionValue` before being stored.
 *
 * @param model MuJoCo model; provides `nu` and the control ranges.
 * @param data MuJoCo data whose `ctrl` array is written in place.
 * @param action Policy output; entry `i` targets actuator `actuatorOffset + i`.
 * @param options See `ApplyPolicyActionToControlsOptions`.
 * @returns The values written, the skipped actuator indices, and the offset used.
 * @throws RangeError when `actuatorOffset` is not a non-negative integer. Such an
 *   offset would address indices that do not exist in `data.ctrl`. If `ctrl` is a
 *   typed array, those writes are silently dropped while still being reported in
 *   `applied`.
 */
export function applyPolicyActionToControls(
  model: MujocoModel,
  data: MujocoData,
  action: PolicyVector,
  options: ApplyPolicyActionToControlsOptions = {}
): ApplyPolicyActionToControlsResult {
  const actuatorOffset = options.actuatorOffset ?? 0;
  if (!Number.isInteger(actuatorOffset) || actuatorOffset < 0) {
    throw new RangeError(
      `actuatorOffset must be a non-negative integer, received ${actuatorOffset}.`
    );
  }
  // Floor so a fractional size cannot let the loop write one extra control.
  const actionSize = Math.floor(options.actionSize ?? action.length);
  const shouldClamp = options.clamp ?? true;
  const shouldSkipInvalid = options.skipInvalid ?? true;
  // A NaN term (e.g. a missing `nu`) makes `count` NaN, so the loop below does not run.
  const count = Math.max(
    0,
    Math.min(actionSize, action.length, data.ctrl.length - actuatorOffset, model.nu - actuatorOffset)
  );
  const applied: number[] = [];
  const skipped: number[] = [];

  for (let index = 0; index < count; index += 1) {
    const actuatorIndex = actuatorOffset + index;
    const value = Number(action[index]);
    if (shouldSkipInvalid && !Number.isFinite(value)) {
      skipped.push(actuatorIndex);
      continue;
    }
    const nextValue = shouldClamp
      ? clampPolicyActionValue(model, actuatorIndex, value)
      : value;
    data.ctrl[actuatorIndex] = nextValue;
    applied.push(nextValue);
  }

  return { applied, skipped, actuatorOffset };
}
@@ -0,0 +1,172 @@
1
+ /**
2
+ * @license
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Named policy observation builders with layout and units metadata.
6
+ */
7
+
8
+ import type {
9
+ Bodies,
10
+ Geoms,
11
+ MujocoData,
12
+ MujocoModel,
13
+ ObservationOutput,
14
+ Sites,
15
+ } from './types';
16
+ import { findBodyByName, findGeomByName, findSiteByName } from './core/SceneLoader';
17
+
18
/**
 * What `readNamedObservation` does when a field's `read` yields null or undefined:
 * - `'skip'` leaves the field out of both the values and the layout;
 * - `'zeros'` writes `size` zeros in its place;
 * - `'throw'` raises an `Error` naming the field.
 */
export type NamedObservationMissing = 'skip' | 'zeros' | 'throw';

/** Argument object handed to every field's `read` callback. */
export interface NamedObservationInput {
  data: MujocoData;
  model: MujocoModel;
}

/** A single named slice of a flattened observation vector. */
export interface NamedObservationField {
  /** Label copied into the matching layout entry. */
  name: string;
  /**
   * Number of slots the field occupies; array results are truncated or
   * zero-padded to this length.
   */
  size: number;
  /** Optional free-form units label copied into the layout entry. */
  units?: string;
  /** Reads the current value; returning null or undefined marks the field as missing. */
  read: (input: NamedObservationInput) => number | ArrayLike<number> | null | undefined;
}

/** Where one field landed inside the flattened observation values. */
export interface NamedObservationLayoutItem {
  name: string;
  /** Offset of the field's first value. */
  start: number;
  size: number;
  units?: string;
}

/** Configuration for reading a named observation. */
export interface NamedObservationOptions {
  /** Fields read in order and concatenated. */
  fields: readonly NamedObservationField[];
  /** `'float64'` selects a Float64Array; any other value yields a Float32Array. */
  output?: ObservationOutput;
  /** Missing-field policy; `'skip'` when omitted. */
  missing?: NamedObservationMissing;
}

/** Flattened observation values plus the per-field layout describing them. */
export interface NamedObservationResult {
  values: Float32Array | Float64Array;
  layout: NamedObservationLayoutItem[];
}
49
+
50
/**
 * Appends one number to `target` for every index below `size`.
 *
 * A scalar occupies the first slot and the remaining slots are zero-filled.
 * An array-like is truncated or zero-padded to `size`, with each entry passed
 * through `Number()`.
 *
 * When `size <= 0` nothing is appended. Previously a scalar was still pushed in
 * that case, so the flattened values drifted from the `size` recorded in the
 * field's layout entry.
 */
function pushValues(target: number[], value: ArrayLike<number> | number, size: number) {
  for (let index = 0; index < size; index += 1) {
    if (typeof value === 'number') {
      target.push(index === 0 ? value : 0);
    } else {
      target.push(Number(value[index] ?? 0));
    }
  }
}
61
+
62
/**
 * Reads every configured field and concatenates the results into one typed array.
 *
 * Each field is read with a fresh `{ model, data }` object. A null or undefined
 * result is handled according to `options.missing` (default `'skip'`):
 * - `'skip'` drops the field entirely (no values, no layout entry);
 * - `'throw'` raises an `Error`;
 * - `'zeros'` writes `field.size` zeros.
 *
 * Every field that is not skipped gets a layout entry recording its start offset.
 *
 * @returns A Float64Array when `options.output === 'float64'`, otherwise a
 *   Float32Array, together with the layout.
 */
export function readNamedObservation(
  model: MujocoModel,
  data: MujocoData,
  options: NamedObservationOptions
): NamedObservationResult {
  const policy = options.missing ?? 'skip';
  const flat: number[] = [];
  const layout: NamedObservationLayoutItem[] = [];

  for (const field of options.fields) {
    const offset = flat.length;
    const raw = field.read({ model, data });

    if (raw != null) {
      pushValues(flat, raw, field.size);
    } else if (policy === 'throw') {
      throw new Error(`Unable to read named observation field "${field.name}".`);
    } else if (policy === 'skip') {
      continue;
    } else {
      for (let slot = 0; slot < field.size; slot += 1) flat.push(0);
    }

    layout.push({
      name: field.name,
      start: offset,
      size: field.size,
      units: field.units,
    });
  }

  const values = options.output === 'float64'
    ? Float64Array.from(flat)
    : Float32Array.from(flat);
  return { values, layout };
}
98
+
99
/**
 * Binds `options` once and returns a reusable reader.
 *
 * Each call of the returned function forwards to `readNamedObservation` with the
 * given model and data.
 */
export function createNamedObservationBuilder(options: NamedObservationOptions) {
  const build = (model: MujocoModel, data: MujocoData) =>
    readNamedObservation(model, data, options);
  return build;
}
104
+
105
/**
 * Builds a one-slot field that returns `select(data)[index]` each time it is read.
 * The array is looked up lazily, at read time, from the data passed to `read`.
 */
function scalarDataField(
  name: string,
  index: number,
  units: string,
  select: (data: MujocoData) => ArrayLike<number>
): NamedObservationField {
  return {
    name,
    size: 1,
    units,
    read: ({ data }) => select(data)[index],
  };
}

/** Observation field for `data.qpos[index]`; the units label defaults to `'qpos'`. */
export function qposField(name: string, index: number, units = 'qpos'): NamedObservationField {
  return scalarDataField(name, index, units, (data) => data.qpos);
}

/** Observation field for `data.qvel[index]`; the units label defaults to `'qvel'`. */
export function qvelField(name: string, index: number, units = 'qvel'): NamedObservationField {
  return scalarDataField(name, index, units, (data) => data.qvel);
}

/** Observation field for `data.ctrl[index]`; the units label defaults to `'ctrl'`. */
export function ctrlField(name: string, index: number, units = 'ctrl'): NamedObservationField {
  return scalarDataField(name, index, units, (data) => data.ctrl);
}
131
+
132
/**
 * Builds a three-slot field holding the world position of a named model element.
 *
 * The element id is resolved with `findId` on every read. A negative id reports
 * the field as missing (`null`). Otherwise the field returns a `subarray` view of
 * the three coordinates starting at `3 * id` in the array chosen by `positionsOf`.
 */
function worldPositionField(
  label: string,
  units: string,
  findId: (model: MujocoModel) => number,
  positionsOf: (data: MujocoData) => { subarray(begin: number, end: number): ArrayLike<number> }
): NamedObservationField {
  return {
    name: label,
    size: 3,
    units,
    read: ({ model, data }) => {
      const id = findId(model);
      if (id < 0) return null;
      const first = 3 * id;
      return positionsOf(data).subarray(first, first + 3);
    },
  };
}

/** World position of body `name`, read from `data.xpos`. */
export function bodyPositionField(name: Bodies, units = 'world_position'): NamedObservationField {
  return worldPositionField(
    `body:${name}:xpos`,
    units,
    (model) => findBodyByName(model, name),
    (data) => data.xpos
  );
}

/** World position of geom `name`, read from `data.geom_xpos`. */
export function geomPositionField(name: Geoms, units = 'world_position'): NamedObservationField {
  return worldPositionField(
    `geom:${name}:xpos`,
    units,
    (model) => findGeomByName(model, name),
    (data) => data.geom_xpos
  );
}

/** World position of site `name`, read from `data.site_xpos`. */
export function sitePositionField(name: Sites, units = 'world_position'): NamedObservationField {
  return worldPositionField(
    `site:${name}:xpos`,
    units,
    (model) => findSiteByName(model, name),
    (data) => data.site_xpos
  );
}
@@ -6,7 +6,7 @@
6
6
 
7
7
  import * as THREE from 'three';
8
8
  import { CapsuleGeometry } from './CapsuleGeometry';
9
- import { Reflector } from './Reflector';
9
+ import { getName } from '../core/SceneLoader';
10
10
  import { MujocoModel, MujocoModule } from '../types';
11
11
 
12
12
  /**
@@ -19,11 +19,64 @@ import { MujocoModel, MujocoModule } from '../types';
19
19
  */
20
20
  export class GeomBuilder {
21
21
  private mujoco: MujocoModule;
22
+ private textureCache = new Map<number, THREE.Texture>();
22
23
 
23
24
  constructor(mujoco: MujocoModule) {
24
25
  this.mujoco = mujoco;
25
26
  }
26
27
 
28
/**
 * Resolves the color texture of MuJoCo material `matId` as a Three.js texture.
 *
 * Returns null in any of these cases:
 * - the material id is negative;
 * - the model has no `mat_texid` / `tex_data` tables;
 * - no texture is bound to the material;
 * - the texture's width, height, channel count, or data offset is not a usable
 *   integer (including NaN from a missing entry);
 * - the texel range runs past the end of `tex_data`.
 *
 * Results are cached in `textureCache` keyed by MATERIAL id. `mat_texrepeat` is a
 * per-material value, so a cache keyed by texture id would hand every material
 * sharing a texture the tiling of whichever material was resolved first.
 *
 * Materials that share a texture still reuse its decoded pixels: an already-built
 * texture tagged with the same `userData.mujocoTexId` is cloned instead of decoded
 * again.
 */
private getMaterialTexture(mjModel: MujocoModel, matId: number): THREE.Texture | null {
  if (matId < 0 || !mjModel.mat_texid || !mjModel.tex_data) return null;

  const cached = this.textureCache.get(matId);
  if (cached) return cached;

  // mat_texid stores one texture id per (material, role). The role count is derived
  // from the table size, and the first bound role wins.
  const materialCount = Math.max(1, Math.floor(mjModel.mat_rgba.length / 4));
  const textureRoles = Math.max(1, Math.floor(mjModel.mat_texid.length / materialCount));
  let texId = -1;
  for (let role = 0; role < textureRoles; role += 1) {
    const candidate = mjModel.mat_texid[matId * textureRoles + role];
    if (candidate >= 0) {
      texId = candidate;
      break;
    }
  }
  if (texId < 0) return null;

  // Reuse pixels already decoded for another material bound to the same texture.
  let texture: THREE.Texture | null = null;
  for (const existing of this.textureCache.values()) {
    if (existing.userData.mujocoTexId === texId) {
      texture = existing.clone();
      break;
    }
  }

  if (!texture) {
    const width = Number(mjModel.tex_width[texId]);
    const height = Number(mjModel.tex_height[texId]);
    const channels = Number(mjModel.tex_nchannel[texId]);
    const offset = Number(mjModel.tex_adr[texId]);
    const texelCount = width * height;
    // Number(undefined) is NaN, and NaN passes `<= 0` checks, so require integers
    // explicitly. Also refuse ranges that would read past the end of tex_data.
    if (
      !Number.isInteger(width) || width <= 0 ||
      !Number.isInteger(height) || height <= 0 ||
      !Number.isInteger(channels) || channels <= 0 ||
      !Number.isInteger(offset) || offset < 0 ||
      offset + texelCount * channels > mjModel.tex_data.length
    ) {
      return null;
    }

    // Expand 1-4 channel texels to RGBA8. A single channel is replicated to gray,
    // and a missing alpha becomes opaque.
    const source = mjModel.tex_data.subarray(offset, offset + texelCount * channels);
    const rgba = new Uint8Array(texelCount * 4);
    for (let i = 0, j = 0; i < texelCount; i += 1, j += channels) {
      const r = source[j] ?? 255;
      const g = channels > 1 ? source[j + 1] : r;
      const b = channels > 2 ? source[j + 2] : r;
      const a = channels > 3 ? source[j + 3] : 255;
      const out = i * 4;
      rgba[out] = r;
      rgba[out + 1] = g;
      rgba[out + 2] = b;
      rgba[out + 3] = a;
    }

    const decoded = new THREE.DataTexture(rgba, width, height, THREE.RGBAFormat);
    decoded.colorSpace = THREE.LinearSRGBColorSpace;
    decoded.wrapS = THREE.RepeatWrapping;
    decoded.wrapT = THREE.RepeatWrapping;
    decoded.flipY = true;
    decoded.needsUpdate = true;
    texture = decoded;
  }

  texture.userData.mujocoTexId = texId;
  const repeatOffset = matId * 2;
  const repeatS = mjModel.mat_texrepeat?.[repeatOffset] ?? 1;
  const repeatT = mjModel.mat_texrepeat?.[repeatOffset + 1] ?? 1;
  texture.repeat.set(repeatS || 1, repeatT || 1);
  this.textureCache.set(matId, texture);
  return texture;
}
79
+
27
80
  /**
28
81
  * Creates a Three.js Object3D (usually a Mesh) for a specific geometry in the MuJoCo model.
29
82
  * Returns null if the geometry shouldn't be rendered (e.g., invisible collision triggers).
@@ -43,6 +96,7 @@ export class GeomBuilder {
43
96
  // Sometimes color is on the geom itself, sometimes it uses a shared material definition.
44
97
  const matId = mjModel.geom_matid[g];
45
98
  const color = new THREE.Color(0xffffff);
99
+ const map = this.getMaterialTexture(mjModel, matId);
46
100
  let opacity = 1.0;
47
101
 
48
102
  if (matId >= 0) {
@@ -65,8 +119,7 @@ export class GeomBuilder {
65
119
  const getVal = (v: unknown) => (v as { value: number })?.value ?? v;
66
120
 
67
121
  if (type === getVal(MG.mjGEOM_PLANE)) {
68
- // Planes are infinite in MuJoCo, but we need a finite mesh for Three.js.
69
- // Fallback reduced to 5m to match grid as requested.
122
+ // Planes are infinite in MuJoCo, but Three needs finite UVs for textured captures.
70
123
  geo = new THREE.PlaneGeometry(size[0] * 2 || 5, size[1] * 2 || 5);
71
124
  } else if (type === getVal(MG.mjGEOM_SPHERE)) {
72
125
  geo = new THREE.SphereGeometry(size[0], 24, 24);
@@ -100,28 +153,22 @@ export class GeomBuilder {
100
153
 
101
154
  // 5. Construct the final Mesh
102
155
  if (geo) {
103
- let mesh;
104
- // Special handling for the floor plane to make it shiny
105
- if (type === getVal(MG.mjGEOM_PLANE)) {
106
- mesh = new Reflector(geo, {
107
- clipBias: 0.003,
108
- textureWidth: 1024, textureHeight: 1024,
109
- color,
110
- mixStrength: 0.25
111
- });
112
- } else {
113
- // Standard physical material for everything else
114
- mesh = new THREE.Mesh(geo, new THREE.MeshStandardMaterial({
115
- color,
116
- transparent: opacity < 1,
117
- opacity,
118
- roughness: 0.6,
119
- metalness: 0.2
120
- }));
121
- // Enable shadows
122
- mesh.castShadow = true;
123
- mesh.receiveShadow = true;
156
+ const isPlane = type === getVal(MG.mjGEOM_PLANE);
157
+ const materialMap = isPlane && map ? map.clone() : map;
158
+ if (isPlane && materialMap) {
159
+ materialMap.repeat.multiplyScalar(2.5);
160
+ materialMap.needsUpdate = true;
124
161
  }
162
+ const mesh = new THREE.Mesh(geo, new THREE.MeshStandardMaterial({
163
+ color,
164
+ map: materialMap,
165
+ transparent: opacity < 1,
166
+ opacity,
167
+ roughness: 0.6,
168
+ metalness: 0
169
+ }));
170
+ mesh.castShadow = type !== getVal(MG.mjGEOM_PLANE);
171
+ mesh.receiveShadow = true;
125
172
 
126
173
  // Apply the local position offset and rotation specified in the MJCF XML
127
174
  mesh.position.set(pos[0], pos[1], pos[2]);
@@ -131,6 +178,8 @@ export class GeomBuilder {
131
178
  // Tag the mesh with its MuJoCo body and geom IDs for interaction (picking/dragging)
132
179
  mesh.userData.bodyID = mjModel.geom_bodyid[g];
133
180
  mesh.userData.geomID = g;
181
+ mesh.userData.geomGroup = mjModel.geom_group[g];
182
+ mesh.userData.geomName = getName(mjModel, mjModel.name_geomadr[g]);
134
183
 
135
184
  return mesh;
136
185
  }
@@ -41,6 +41,8 @@ export interface CameraFrameCaptureSession {
41
41
 
42
42
  export const CAMERA_FRAME_CAPTURE_RENDER_USER_DATA_KEY =
43
43
  'mujocoReactCameraFrameCaptureRender';
44
+ export const CAMERA_FRAME_CAPTURE_PRE_RENDER_USER_DATA_KEY =
45
+ 'mujocoReactCameraFrameCapturePreRender';
44
46
  export const CAPTURE_EXCLUDE_KEY =
45
47
  'mujoco.capture.exclude';
46
48
 
@@ -84,6 +86,8 @@ type VisibilityState = {
84
86
  visible: boolean;
85
87
  };
86
88
 
89
+ type CameraFrameCapturePreRender = () => void;
90
+
87
91
  function toVector3(
88
92
  value: CameraFrameCaptureVector3 | undefined,
89
93
  fallback: THREE.Vector3
@@ -260,6 +264,39 @@ function hideExcludedCaptureObjects(scene: THREE.Scene): VisibilityState[] {
260
264
  return hidden;
261
265
  }
262
266
 
267
/**
 * Hides geom meshes that the capture options filter out. The caller is expected
 * to undo this with `restoreObjectVisibility`.
 *
 * Only objects that are currently visible and carry a numeric `userData.geomGroup`
 * or a string `userData.geomName` are considered. Such an object is hidden when:
 * - its name is in `hiddenGeomNames`, or
 * - its group is in `hiddenGeomGroups`, or
 * - `visibleGeomGroups` is set and does not list its (numeric) group.
 *
 * @returns One entry per object hidden here, holding its previous visibility;
 *   empty when no filter option is set.
 */
function hideCaptureGeomGroups(
  scene: THREE.Scene,
  options: CameraFrameCaptureOptions
): VisibilityState[] {
  const { hiddenGeomGroups, visibleGeomGroups, hiddenGeomNames } = options;
  const hiddenGroups = hiddenGeomGroups ? new Set(hiddenGeomGroups) : null;
  const visibleGroups = visibleGeomGroups ? new Set(visibleGeomGroups) : null;
  const hiddenNames = hiddenGeomNames ? new Set(hiddenGeomNames) : null;

  const changed: VisibilityState[] = [];
  if (hiddenGroups === null && visibleGroups === null && hiddenNames === null) {
    return changed;
  }

  scene.traverse((node) => {
    if (!node.visible) return;
    const { geomGroup, geomName } = node.userData;
    const hasGroup = typeof geomGroup === 'number';
    if (!hasGroup && typeof geomName !== 'string') return;

    const excludedByName = hiddenNames?.has(geomName) ?? false;
    const excludedByGroup = hiddenGroups?.has(geomGroup) ?? false;
    const outsideVisibleGroups =
      hasGroup && visibleGroups !== null && !visibleGroups.has(geomGroup);

    if (excludedByName || excludedByGroup || outsideVisibleGroups) {
      changed.push({ object: node, visible: node.visible });
      node.visible = false;
    }
  });
  return changed;
}
299
+
263
300
  function restoreObjectVisibility(hidden: VisibilityState[]) {
264
301
  for (const { object, visible } of hidden) {
265
302
  object.visible = visible;
@@ -332,6 +369,17 @@ function getCaptureRenderer(
332
369
  return renderers[0] ?? null;
333
370
  }
334
371
 
372
/**
 * Calls every pre-render hook registered on the scene graph.
 *
 * A hook is any function stored at
 * `object.userData[CAMERA_FRAME_CAPTURE_PRE_RENDER_USER_DATA_KEY]`. Hooks are
 * gathered in one complete traversal and only invoked after it finishes, so a
 * hook that edits the scene graph does not run while the traversal is in progress.
 * Hooks are called with no arguments, in traversal order.
 */
function runCapturePreRenderHooks(scene: THREE.Scene) {
  const pending: CameraFrameCapturePreRender[] = [];
  scene.traverse(({ userData }) => {
    const hook: unknown = userData[CAMERA_FRAME_CAPTURE_PRE_RENDER_USER_DATA_KEY];
    if (typeof hook === 'function') {
      pending.push(hook as CameraFrameCapturePreRender);
    }
  });
  pending.forEach((hook) => hook());
}
+
335
383
  export function createCameraFrameCaptureSession(
336
384
  renderer: THREE.WebGLRenderer,
337
385
  scene: THREE.Scene,
@@ -382,8 +430,12 @@ export function createCameraFrameCaptureSession(
382
430
 
383
431
  function renderPreparedCapture(captureOptions: CameraFrameCaptureOptions) {
384
432
  const previousState = saveRendererState(renderer);
385
- const hidden = hideExcludedCaptureObjects(scene);
433
+ const hidden = [
434
+ ...hideExcludedCaptureObjects(scene),
435
+ ...hideCaptureGeomGroups(scene, captureOptions),
436
+ ];
386
437
 
438
+ runCapturePreRenderHooks(scene);
387
439
  scene.updateMatrixWorld(true);
388
440
  try {
389
441
  renderer.xr.enabled = false;
@@ -391,6 +443,14 @@ export function createCameraFrameCaptureSession(
391
443
  renderer.setViewport(0, 0, width, height);
392
444
  renderer.setScissor(0, 0, width, height);
393
445
  renderer.setScissorTest(false);
446
+ if (captureOptions.background !== undefined) {
447
+ renderer.setClearColor(
448
+ new THREE.Color(captureOptions.background),
449
+ captureOptions.backgroundAlpha ?? previousState.clearAlpha
450
+ );
451
+ } else if (captureOptions.backgroundAlpha !== undefined) {
452
+ renderer.setClearColor(previousState.clearColor, captureOptions.backgroundAlpha);
453
+ }
394
454
  renderer.clear();
395
455
  renderer.render(scene, camera);
396
456
  readRenderTargetToCanvas(
@@ -423,13 +483,25 @@ export function createCameraFrameCaptureSession(
423
483
 
424
484
  async function captureAsync(nextOptions: CameraFrameCaptureOptions = {}) {
425
485
  const captureOptions = resolveCaptureOptions(nextOptions);
486
+ runCapturePreRenderHooks(scene);
426
487
  scene.updateMatrixWorld(true);
427
488
  const captureRenderer = getCaptureRenderer(scene);
428
489
  if (captureRenderer) {
429
490
  const previousState = saveRendererState(renderer);
430
- const hidden = hideExcludedCaptureObjects(scene);
491
+ const hidden = [
492
+ ...hideExcludedCaptureObjects(scene),
493
+ ...hideCaptureGeomGroups(scene, captureOptions),
494
+ ];
431
495
  try {
432
496
  renderer.xr.enabled = false;
497
+ if (captureOptions.background !== undefined) {
498
+ renderer.setClearColor(
499
+ new THREE.Color(captureOptions.background),
500
+ captureOptions.backgroundAlpha ?? previousState.clearAlpha
501
+ );
502
+ } else if (captureOptions.backgroundAlpha !== undefined) {
503
+ renderer.setClearColor(previousState.clearColor, captureOptions.backgroundAlpha);
504
+ }
433
505
  const captureResult = await captureRenderer({
434
506
  renderer,
435
507
  scene,