mouse-core 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mouse/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from mouse.models import Model, load_model
2
+ from mouse.models.heads import BaseHead, BaseHeadWithTarget
3
+ from mouse.losses import LossConfig, LossFunction
4
+
5
+ # Public API of the ``mouse`` package: every name listed here is imported at the
+ # top of this module, so ``from mouse import *`` re-exports exactly these six names.
+ __all__ = [
6
+ "Model",
7
+ "load_model",
8
+ "BaseHead",
9
+ "BaseHeadWithTarget",
10
+ "LossConfig",
11
+ "LossFunction",
12
+ ]
mouse/data/__init__.py ADDED
File without changes
mouse/data/augment.py ADDED
@@ -0,0 +1,536 @@
1
+ """Training-time augmentations for offline RL batches.
2
+
3
+ When any augmentation is enabled, :class:`TokenAugmenter` clones the step stream,
4
+ applies transforms to the copy, and returns it; the input is left unchanged.
5
+ If nothing is enabled, the input is returned as-is (no clone).
6
+
7
+ **mask_prob** (see :class:`AugmentMaskProbConfig`): per-field Bernoulli mask
8
+ with the given probability on each step. Masked steps have their corresponding
9
+ field(s) zeroed (or set to -1 for time). ``PREDICTION`` and ``COMPUTE`` rows are never masked.
10
+ Probabilities ``<= 0`` for a field skip that field; if every entry is zero, masking is
11
+ skipped entirely.
12
+ Masks are **not** snapshotted: a new random draw runs on every :meth:`TokenAugmenter.__call__`.
13
+
14
+ **permute_** flags: random choices are **per sequence** (per batch row): the same
15
+ action mapping and done permutation apply at every step in that row; different batch rows may
16
+ use different permutations. **scale_** / **shift_** use :class:`AugmentScalarSpec`:
17
+ ``low`` / ``high`` for uniform on ``[low, high)`` per batch (configs use this for all
18
+ continuous aug), or Gaussian ``mean`` / ``std``; ``low == high`` or ``std: 0`` fixes
19
+ the scalar at identity.
20
+
21
+ Note: ``permute_tokens`` is not applicable without an explicit token stream and is silently
22
+ ignored.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from dataclasses import dataclass, field
28
+ from typing import Literal, cast
29
+
30
+ import torch
31
+ from tensordict import TensorDict
32
+
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Config dataclasses
36
+ # ---------------------------------------------------------------------------
37
+
38
+
39
@dataclass(frozen=True)
class AugmentScalarSpec:
    """Describe how a single augmentation scalar is drawn.

    Two modes are supported:

    * **Uniform** - set both ``low`` and ``high``; ``low == high`` pins the
      scalar to that value.
    * **Gaussian** - leave ``low`` and ``high`` unset; the scalar is
      ``mean + std * N(0, 1)`` and ``std == 0`` pins it to ``mean``.

    Raises:
        ValueError: if exactly one of ``low`` / ``high`` is provided.
    """

    mean: float
    std: float = 0.0
    low: float | None = None
    high: float | None = None

    def __post_init__(self) -> None:
        # The uniform bounds are only meaningful as a pair.
        has_low = self.low is not None
        has_high = self.high is not None
        if has_low != has_high:
            raise ValueError("AugmentScalarSpec: set both low and high for uniform, or neither for Gaussian.")
56
+
57
+
58
def _augment_scalar_active(spec: AugmentScalarSpec, identity_mean: float) -> bool:
    """Return whether ``spec`` can produce anything other than ``identity_mean``.

    A Gaussian spec is inactive only when ``std == 0`` and ``mean`` equals the
    identity. A uniform spec is inactive only when both bounds coincide with
    the identity.
    """
    if spec.low is None or spec.high is None:
        # Gaussian mode.
        return not (spec.std == 0.0 and spec.mean == identity_mean)

    first, second = float(spec.low), float(spec.high)
    if first != second:
        # A non-degenerate interval always counts as active, whatever the bound order.
        return True
    return first != identity_mean
67
+
68
+
69
@dataclass(frozen=True)
class AugmentMaskProbConfig:
    """Per-field Bernoulli masking probabilities (MLM-style).

    Each attribute is the probability with which the field of the same name is
    masked at a given step. A value ``<= 0`` disables masking for that field.
    """

    action: float = 0.0
    reward: float = 0.0
    done: float = 0.0
    obs_continuous: float = 0.0
    obs_discrete: float = 0.0
    obs_image: float = 0.0
    time: float = 0.0

    def any_positive(self) -> bool:
        """Return ``True`` when at least one field has a probability above zero."""
        probabilities = (
            self.action,
            self.reward,
            self.done,
            self.obs_continuous,
            self.obs_discrete,
            self.obs_image,
            self.time,
        )
        for prob in probabilities:
            if prob > 0.0:
                return True
        return False
98
+
99
+
100
@dataclass(frozen=True)
class AugmentTokensConfig:
    """Configuration of the optional training-time augmentations.

    Consumed by :class:`TokenAugmenter`. ``enabled`` is the master switch:
    when it is false :meth:`any_enabled` reports ``False`` regardless of the
    other settings. ``mask_prob`` turns on MLM-style random masking per field.

    Raises:
        ValueError: if ``permute_action`` is neither ``False`` nor one of
            ``"input"``, ``"target"``, ``"both"``.
    """

    # Master switch - false disables all augmentations regardless of other settings.
    enabled: bool = True
    # Counted by ``any_enabled`` but not acted on by ``TokenAugmenter`` (no explicit token stream).
    permute_tokens: bool = False
    # Scalar augmentations; the defaults are the identity (scale 1, shift 0).
    scale_reward: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(mean=1.0, std=0.0))
    shift_reward: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(mean=0.0, std=0.0))
    scale_obs: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(mean=1.0, std=0.0))
    shift_obs: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(mean=0.0, std=0.0))
    scale_obs_image: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(mean=1.0, std=0.0))
    shift_obs_image: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(mean=0.0, std=0.0))
    # Remap OBS_DISCRETE ids only (q_star / actions unchanged - unsafe for semantic categoricals).
    permute_obs_discrete: bool = False
    # Action permutation mode: off, or applied to the input actions, the q_star target, or both.
    permute_action: Literal[False, "input", "target", "both"] = False
    # Random per-row permutation of the done codes (TokenAugmenter samples it over three codes).
    permute_done: bool = False
    mask_prob: AugmentMaskProbConfig = field(default_factory=AugmentMaskProbConfig)

    def __post_init__(self) -> None:
        # Only the literal ``False`` means "off"; anything else must be a known mode.
        mode = self.permute_action
        if mode is not False and mode not in ("input", "target", "both"):
            raise ValueError(
                f"AugmentTokensConfig.permute_action must be false or one of: 'input', 'target', 'both'; got {mode!r}."
            )

    def permute_action_enabled(self) -> bool:
        """Return whether an action permutation mode is selected."""
        return bool(self.permute_action)

    def permute_action_mode(self) -> Literal["input", "target", "both"]:
        """Return the selected action permutation mode (``"both"`` when it is off)."""
        mode = self.permute_action
        return "both" if mode is False else cast(Literal["input", "target", "both"], mode)

    def any_enabled(self) -> bool:
        """Return whether any augmentation is switched on under the master switch."""
        if not self.enabled:
            return False

        if self.permute_tokens or self.permute_obs_discrete or self.permute_action_enabled() or self.permute_done:
            return True

        # Each scalar spec is compared with its own identity value (1 for scales, 0 for shifts).
        scalar_specs = (
            (self.scale_reward, 1.0),
            (self.shift_reward, 0.0),
            (self.scale_obs, 1.0),
            (self.shift_obs, 0.0),
            (self.scale_obs_image, 1.0),
            (self.shift_obs_image, 0.0),
        )
        for spec, identity in scalar_specs:
            if _augment_scalar_active(spec=spec, identity_mean=identity):
                return True

        return self.mask_prob.any_positive()
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Augmentation functions
157
+ # ---------------------------------------------------------------------------
158
+
159
+
160
def _sample_scalar(spec: AugmentScalarSpec, generator: torch.Generator) -> float:
    """Draw one Python float according to ``spec`` using ``generator``.

    Uniform mode (both bounds set) draws from ``torch.rand`` and rescales it to
    the interval, swapping the bounds if they are given in reverse. Gaussian
    mode returns ``mean`` directly when ``std == 0`` (no random draw) and
    ``mean + std * randn`` otherwise.
    """
    if spec.low is None or spec.high is None:
        if spec.std == 0.0:
            return float(spec.mean)
        noise = torch.randn((), device=generator.device, generator=generator)
        return float(noise * spec.std + spec.mean)

    lower, upper = float(spec.low), float(spec.high)
    if lower > upper:
        lower, upper = upper, lower
    fraction = torch.rand((), device=generator.device, generator=generator)
    return float(lower + (upper - lower) * fraction)
171
+
172
+
173
def _inverse_action_perm_rows(perm: torch.Tensor) -> torch.Tensor:
    """Invert each row of a batch of permutations.

    ``perm[b, old] = new`` becomes ``out[b, new] = old``.

    Args:
        perm: integer tensor of shape ``[B, A]`` whose rows are permutations
            of ``0..A-1``.

    Returns:
        Tensor of shape ``[B, A]`` with the same dtype and device as ``perm``.
    """
    num_rows, num_actions = int(perm.shape[0]), int(perm.shape[1])
    out = torch.empty((num_rows, num_actions), device=perm.device, dtype=perm.dtype)
    # ``scatter_`` requires ``src`` to have the destination dtype, so the column
    # ids are created in ``perm.dtype``; only the index itself must be int64.
    cols = torch.arange(num_actions, device=perm.device, dtype=perm.dtype).unsqueeze(0).expand(num_rows, -1)
    out.scatter_(1, perm.to(dtype=torch.long), cols)
    return out
180
+
181
+
182
@torch.no_grad()
def apply_permute_action_augmentation(
    step_stream: TensorDict,
    perm: torch.Tensor,
    apply_to_input: bool,
    apply_to_target: bool,
) -> None:
    """Relabel actions ``a -> perm[b, a]`` for every batch row ``b`` (in-place).

    Args:
        step_stream: batch holding ``"action"`` (indexed as ``[B, S]``) and
            optionally ``"q_star"`` (unpacked as ``[B, S, A]``).
        perm: ``[B, max_num_actions]`` lookup table, one permutation per row.
        apply_to_input: rewrite the ``"action"`` field.
        apply_to_target: reorder the last axis of ``"q_star"`` (when present)
            so that the value of old action ``a`` ends up at index ``perm[a]``.
    """
    if apply_to_input:
        old_actions = step_stream["action"]
        relabeled = perm.gather(1, old_actions.long())
        step_stream["action"].copy_(relabeled)

    if not apply_to_target or "q_star" not in step_stream.keys():
        return

    # New slot ``j`` must read the value of the old action ``inverse[j]``.
    inverse = _inverse_action_perm_rows(perm)
    q_values = step_stream["q_star"]
    num_rows, num_steps, num_actions = q_values.shape
    gather_index = inverse.unsqueeze(1).expand(num_rows, num_steps, num_actions)
    step_stream["q_star"].copy_(q_values.gather(2, gather_index))
203
+
204
+
205
@torch.no_grad()
def apply_permute_done_augmentation(
    step_stream: TensorDict,
    perm: torch.Tensor,
) -> None:
    """Relabel done codes ``d -> perm[b, d]`` for every batch row ``b`` (in-place).

    Args:
        step_stream: batch holding ``"done"``, indexed as ``[B, S]``.
        perm: ``[B, 3]`` lookup table over the done codes
            (0 = not done, 1 = terminal, 2 = truncated); row ``b`` uses ``perm[b]``.
    """
    codes = step_stream["done"].long()
    relabeled = perm.gather(1, codes)
    step_stream["done"].copy_(relabeled)
217
+
218
+
219
@torch.no_grad()
def apply_reward_scale_shift(
    step_stream: TensorDict,
    scale: float,
    shift: float,
) -> None:
    """Replace ``reward`` with ``reward * scale + shift`` (in-place).

    Nothing is touched when the transform is the identity
    (``scale == 1.0`` and ``shift == 0.0``).
    """
    if scale != 1.0 or shift != 0.0:
        rewards = step_stream["reward"]
        transformed = rewards * scale + shift
        step_stream["reward"].copy_(transformed)
230
+
231
+
232
@torch.no_grad()
def apply_obs_continuous_scale_shift(
    step_stream: TensorDict,
    scale: float,
    shift: float,
) -> None:
    """Replace ``obs_continuous`` with ``obs * scale + shift`` (in-place).

    A no-op when the transform is the identity or the batch has no
    ``"obs_continuous"`` entry. The arithmetic runs in float64 and the result
    is copied back into the existing tensor, which keeps its own dtype.
    """
    is_identity = scale == 1.0 and shift == 0.0
    if is_identity or "obs_continuous" not in step_stream.keys():
        return
    values = step_stream["obs_continuous"]
    transformed = values.double() * scale + shift
    step_stream["obs_continuous"].copy_(transformed)
245
+
246
+
247
@torch.no_grad()
def apply_obs_image_scale_shift(
    step_stream: TensorDict,
    scale: float,
    shift: float,
) -> None:
    """Replace ``obs_image`` with ``round(pixel * scale + shift)`` clamped to 0..255 (in-place).

    A no-op when the transform is the identity or the batch has no
    ``"obs_image"`` entry. The arithmetic runs in float32; the rounded,
    clamped result is copied back into the existing tensor.
    """
    is_identity = scale == 1.0 and shift == 0.0
    if is_identity or "obs_image" not in step_stream.keys():
        return
    pixels = step_stream["obs_image"]
    adjusted = pixels.float() * scale + shift
    quantized = adjusted.round().clamp(0, 255).to(torch.int64)
    step_stream["obs_image"].copy_(quantized)
260
+
261
+
262
@torch.no_grad()
def apply_permute_obs_discrete_augmentation(
    step_stream: TensorDict,
    perm: torch.Tensor,
) -> None:
    """Relabel discrete observations ``v -> perm[b, v]`` for every batch row ``b`` (in-place).

    A no-op when the batch has no ``"obs_discrete"`` entry.

    Args:
        step_stream: batch holding ``"obs_discrete"`` with the batch axis
            first, e.g. ``[B, S]`` or ``[B, S, D]``.
        perm: ``[B, max_num_obs_discrete]`` lookup table; row ``b`` uses
            ``perm[b]``. Every observed value must be a valid index into it.
    """
    if "obs_discrete" not in step_stream.keys():
        return
    obs = step_stream["obs_discrete"]
    num_rows = int(obs.shape[0])
    # Look every value up in its own row's table: flatten all non-batch axes so
    # ``gather`` indexes ``perm`` along the value axis, then restore the shape.
    # (Expanding ``perm`` to the shape of ``obs`` cannot work: ``perm`` is
    # ``[B, N]`` while ``obs`` is ``[B, S, ...]``.)
    flat_index = obs.long().reshape(num_rows, -1)
    remapped = torch.gather(perm, dim=1, index=flat_index).reshape(obs.shape)
    step_stream["obs_discrete"].copy_(remapped)
276
+
277
+
278
@torch.no_grad()
def apply_field_masks(
    step_stream: TensorDict,
    mask_prob: AugmentMaskProbConfig,
    generator: torch.Generator,
) -> None:
    """Sample one Bernoulli mask per field and overwrite the masked steps (in-place).

    For every field whose probability is positive a fresh ``[B, S]`` mask is
    drawn (``B``/``S`` come from ``step_stream["action"]``). Masked steps are
    set to ``0``, except ``time`` which is set to ``-1`` ("not available").
    ``action``, ``reward`` and ``done`` are read unconditionally when their
    mask hits anything; the remaining fields are skipped when the batch does
    not contain them. The mask is broadcast over all axes after ``[B, S]``,
    so payloads of any rank (``[B, S]``, ``[B, S, D]``, ``[B, S, H, W, C]``)
    are supported.

    Args:
        step_stream: batch to modify.
        mask_prob: per-field masking probabilities.
        generator: random generator used for all draws.
    """
    if not mask_prob.any_positive():
        return

    reference = step_stream["action"]
    dev = cast(torch.device, reference.device)
    num_rows, num_steps = int(reference.shape[0]), int(reference.shape[1])

    # (field, probability, fill value, must exist). The order fixes the order in
    # which random numbers are consumed, so it must not be changed.
    plan = (
        ("action", mask_prob.action, 0, True),
        ("reward", mask_prob.reward, 0, True),
        ("done", mask_prob.done, 0, True),
        ("obs_continuous", mask_prob.obs_continuous, 0, False),
        ("obs_discrete", mask_prob.obs_discrete, 0, False),
        ("obs_image", mask_prob.obs_image, 0, False),
        ("time", mask_prob.time, -1, False),  # -1 means "not available"
    )

    for name, prob, fill, required in plan:
        if prob <= 0.0:
            continue
        # The draw happens before the presence check so the random stream does
        # not depend on which optional fields a batch carries.
        mask = torch.rand((num_rows, num_steps), device=dev, generator=generator) < prob
        if not mask.any():
            continue
        if not required and name not in step_stream.keys():
            continue
        values = step_stream[name]
        # Append singleton axes so the per-step mask broadcasts over the payload.
        trailing = [1] * (values.dim() - 2)
        step_mask = mask.reshape(num_rows, num_steps, *trailing)
        step_stream[name].copy_(torch.where(step_mask, torch.full_like(values, fill), values))
335
+
336
+
337
@dataclass
class AugmentSnapshot:
    """Permutations and scalar draws sampled once for a batch.

    Built by ``TokenAugmenter.update_augmentations`` and reused by every
    subsequent ``TokenAugmenter.__call__`` until it is replaced or cleared.
    A field is ``None`` when the corresponding augmentation was not active at
    sampling time.
    """

    # Batch size and device of the batch the snapshot was sampled for.
    batch_size: int
    device: torch.device
    # Per-row lookup tables for action and done relabeling.
    perm_action: torch.Tensor | None
    perm_done: torch.Tensor | None
    # Reward scale / shift.
    r_scale: float | None
    r_shift: float | None
    # Continuous-observation scale / shift.
    o_scale: float | None
    o_shift: float | None
    # Image-observation scale / shift.
    im_scale: float | None
    im_shift: float | None
    # Per-row lookup table for discrete-observation relabeling.
    perm_obs_discrete: torch.Tensor | None
352
+
353
+
354
class TokenAugmenter:
    """Apply an :class:`AugmentTokensConfig` to a batch of steps.

    Usage: call :meth:`update_augmentations` with a batch to sample and store
    the permutations and scalars, then call the augmenter itself to obtain an
    augmented copy. Permutations and scalars come from the stored snapshot;
    ``mask_prob`` masks are drawn anew on every call. When no augmentation is
    enabled the input is returned unchanged and no snapshot is needed.
    """

    def __init__(
        self,
        augment: AugmentTokensConfig,
        max_num_actions: int,
        max_num_obs_discrete: int,
        device: torch.device,
        generator: torch.Generator | None = None,
    ) -> None:
        """Store the configuration and the random generator.

        Args:
            augment: augmentation configuration.
            max_num_actions: width of the sampled action permutations.
            max_num_obs_discrete: width of the sampled discrete-observation permutations.
            device: device for the generator created when ``generator`` is ``None``.
            generator: optional generator to use for all random draws.

        Raises:
            TypeError: if ``augment`` is not an :class:`AugmentTokensConfig`.
        """
        if not isinstance(augment, AugmentTokensConfig):
            raise TypeError(f"augment must be AugmentTokensConfig, got {type(augment).__name__}")
        self._augment = augment
        self._max_num_actions = int(max_num_actions)
        self._max_num_obs_discrete = int(max_num_obs_discrete)
        if generator is None:
            generator = torch.Generator(device=device)
        self._generator = generator
        self._snapshot: AugmentSnapshot | None = None

    @property
    def augment(self) -> AugmentTokensConfig:
        """The configuration passed at construction."""
        return self._augment

    @property
    def snapshot(self) -> AugmentSnapshot | None:
        """The currently stored snapshot, or ``None`` if there is none."""
        return self._snapshot

    def clear_augmentations(self) -> None:
        """Drop the stored snapshot."""
        self._snapshot = None

    @torch.no_grad()
    def update_augmentations(self, step_stream: TensorDict) -> None:
        """Sample permutations and scalars for ``step_stream`` and store them.

        The batch size and device are read from ``step_stream["action"]``.
        When nothing is enabled the stored snapshot is cleared instead.
        """
        cfg = self._augment
        if not cfg.any_enabled():
            self._snapshot = None
            return

        actions = step_stream["action"]
        batch = int(actions.shape[0])
        device = cast(torch.device, actions.device)
        rng = self._generator

        def row_perms(width: int) -> torch.Tensor:
            # One independent permutation of ``0..width-1`` per batch row.
            rows = [torch.randperm(width, device=device, generator=rng) for _ in range(batch)]
            return torch.stack(rows, dim=0)

        def draw(spec: AugmentScalarSpec, identity: float) -> float | None:
            # Inactive specs consume no randomness and are stored as ``None``.
            if _augment_scalar_active(spec, identity):
                return _sample_scalar(spec, rng)
            return None

        # The statements below consume the generator in a fixed order; keep it.
        perm_action = row_perms(self._max_num_actions) if cfg.permute_action_enabled() else None
        perm_done = row_perms(3) if cfg.permute_done else None
        perm_obs_discrete = row_perms(self._max_num_obs_discrete) if cfg.permute_obs_discrete else None

        r_scale = draw(cfg.scale_reward, 1.0)
        r_shift = draw(cfg.shift_reward, 0.0)
        o_scale = draw(cfg.scale_obs, 1.0)
        o_shift = draw(cfg.shift_obs, 0.0)
        im_scale = draw(cfg.scale_obs_image, 1.0)
        im_shift = draw(cfg.shift_obs_image, 0.0)

        self._snapshot = AugmentSnapshot(
            batch_size=batch,
            device=device,
            perm_action=perm_action,
            perm_done=perm_done,
            r_scale=r_scale,
            r_shift=r_shift,
            o_scale=o_scale,
            o_shift=o_shift,
            im_scale=im_scale,
            im_shift=im_shift,
            perm_obs_discrete=perm_obs_discrete,
        )

    def _assert_snapshot_matches(self, step_stream: TensorDict) -> AugmentSnapshot:
        """Return the stored snapshot after checking it fits ``step_stream``.

        Raises:
            RuntimeError: if no snapshot is stored.
            ValueError: if the batch size or device of ``step_stream["action"]``
                differs from the snapshot's.
        """
        if self._snapshot is None:
            raise RuntimeError("TokenAugmenter has no snapshot; call update_augmentations first.")
        snap = self._snapshot
        actions = step_stream["action"]
        B = int(actions.shape[0])
        dev = cast(torch.device, actions.device)
        matches = B == snap.batch_size and dev == snap.device
        if not matches:
            raise ValueError(
                f"Batch mismatch: got B={B}, device={dev!r}; snapshot expects "
                f"batch_size={snap.batch_size}, device={snap.device!r}."
            )
        return snap

    @torch.no_grad()
    def __call__(
        self,
        step_stream: TensorDict,
    ) -> TensorDict:
        """Return an augmented clone of ``step_stream`` (or the input if nothing is enabled).

        Requires a snapshot from :meth:`update_augmentations` sampled for a
        batch of the same size and device. Order of operations: field masks,
        action permutation, done permutation, discrete-observation permutation,
        then scale followed by shift for reward, continuous observations and
        image observations.
        """
        cfg = self._augment
        if not cfg.any_enabled():
            return step_stream

        step_stream = step_stream.clone()
        snap = self._assert_snapshot_matches(step_stream)

        # MLM-style masks first (corrupt inputs before permute/scale).
        if cfg.mask_prob.any_positive():
            apply_field_masks(step_stream=step_stream, mask_prob=cfg.mask_prob, generator=self._generator)

        if cfg.permute_action_enabled():
            assert snap.perm_action is not None
            mode = cfg.permute_action_mode()
            apply_permute_action_augmentation(
                step_stream=step_stream,
                perm=snap.perm_action,
                apply_to_input=mode in ("input", "both"),
                apply_to_target=mode in ("target", "both"),
            )

        if cfg.permute_done:
            assert snap.perm_done is not None
            apply_permute_done_augmentation(step_stream=step_stream, perm=snap.perm_done)

        if cfg.permute_obs_discrete:
            assert snap.perm_obs_discrete is not None
            apply_permute_obs_discrete_augmentation(step_stream=step_stream, perm=snap.perm_obs_discrete)

        # (spec, identity value, sampled value, transform, is_scale); scale and
        # shift are applied as two separate passes, scale first.
        scalar_steps = (
            (cfg.scale_reward, 1.0, snap.r_scale, apply_reward_scale_shift, True),
            (cfg.shift_reward, 0.0, snap.r_shift, apply_reward_scale_shift, False),
            (cfg.scale_obs, 1.0, snap.o_scale, apply_obs_continuous_scale_shift, True),
            (cfg.shift_obs, 0.0, snap.o_shift, apply_obs_continuous_scale_shift, False),
            (cfg.scale_obs_image, 1.0, snap.im_scale, apply_obs_image_scale_shift, True),
            (cfg.shift_obs_image, 0.0, snap.im_shift, apply_obs_image_scale_shift, False),
        )
        for spec, identity, value, transform, is_scale in scalar_steps:
            if not _augment_scalar_active(spec=spec, identity_mean=identity):
                continue
            assert value is not None
            if is_scale:
                transform(step_stream=step_stream, scale=value, shift=0.0)
            else:
                transform(step_stream=step_stream, scale=1.0, shift=value)

        return step_stream