mouse-core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mouse/__init__.py +12 -0
- mouse/data/__init__.py +0 -0
- mouse/data/augment.py +536 -0
- mouse/data/batch.py +256 -0
- mouse/data/dataset_store.py +390 -0
- mouse/data/hub.py +296 -0
- mouse/losses/__init__.py +18 -0
- mouse/losses/base.py +72 -0
- mouse/losses/dqn.py +117 -0
- mouse/losses/sp.py +218 -0
- mouse/losses/sv.py +59 -0
- mouse/losses/vec_dqn.py +99 -0
- mouse/models/__init__.py +19 -0
- mouse/models/backbone/__init__.py +87 -0
- mouse/models/backbone/llama.py +157 -0
- mouse/models/backbone/none.py +47 -0
- mouse/models/backbone/qwen3.py +171 -0
- mouse/models/base.py +644 -0
- mouse/models/embedding/__init__.py +0 -0
- mouse/models/embedding/embedding.py +666 -0
- mouse/models/embedding/encoding.py +90 -0
- mouse/models/embedding/linear.py +122 -0
- mouse/models/heads/__init__.py +15 -0
- mouse/models/heads/base.py +127 -0
- mouse/models/heads/dqn.py +43 -0
- mouse/models/heads/swiglu.py +68 -0
- mouse/models/heads/vec_dqn.py +116 -0
- mouse_core-0.1.0.dist-info/METADATA +54 -0
- mouse_core-0.1.0.dist-info/RECORD +32 -0
- mouse_core-0.1.0.dist-info/WHEEL +5 -0
- mouse_core-0.1.0.dist-info/licenses/LICENSE +674 -0
- mouse_core-0.1.0.dist-info/top_level.txt +1 -0
mouse/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from mouse.models import Model, load_model
|
|
2
|
+
from mouse.models.heads import BaseHead, BaseHeadWithTarget
|
|
3
|
+
from mouse.losses import LossConfig, LossFunction
|
|
4
|
+
|
|
5
|
+
# Public, package-level API: names re-exported from ``mouse.models``,
# ``mouse.models.heads`` and ``mouse.losses`` by the imports at the top of this module.
__all__ = [
    "Model",
    "load_model",
    "BaseHead",
    "BaseHeadWithTarget",
    "LossConfig",
    "LossFunction",
]
|
mouse/data/__init__.py
ADDED
|
File without changes
|
mouse/data/augment.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
"""Training-time augmentations for offline RL batches.
|
|
2
|
+
|
|
3
|
+
When any augmentation is enabled, :class:`TokenAugmenter` clones the step stream,
|
|
4
|
+
applies transforms to the copy, and returns it; the input is left unchanged.
|
|
5
|
+
If nothing is enabled, the input is returned as-is (no clone).
|
|
6
|
+
|
|
7
|
+
**mask_prob** (see :class:`AugmentMaskProbConfig`): per-field Bernoulli mask
|
|
8
|
+
with the given probability on each step. Masked steps have their corresponding
|
|
9
|
+
field(s) zeroed (or set to -1 for time). ``PREDICTION`` and ``COMPUTE`` rows are never masked.
|
|
10
|
+
Probabilities ``<= 0`` for a field skip that field; if every entry is zero, masking is
|
|
11
|
+
skipped entirely.
|
|
12
|
+
Masks are **not** snapshotted: a new random draw runs on every :meth:`TokenAugmenter.__call__`.
|
|
13
|
+
|
|
14
|
+
**permute_** flags: random choices are drawn **per sequence** (per batch row): the same
action mapping and ``done`` permutation (over ``{0, 1, 2}``) apply at every step of that
row, while different batch rows may use different permutations. **scale_** / **shift_**
use :class:`AugmentScalarSpec`: ``low`` / ``high`` for a uniform draw on ``[low, high)``
once per batch (configs use this for all continuous augmentations), or Gaussian
``mean`` / ``std``; ``low == high`` or ``std: 0`` fixes the scalar at that value, which
is a no-op when it equals the identity (1 for scale, 0 for shift).
|
|
20
|
+
|
|
21
|
+
Note: ``permute_tokens`` is not applicable without an explicit token stream and is silently
|
|
22
|
+
ignored.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from dataclasses import dataclass, field
|
|
28
|
+
from typing import Literal, cast
|
|
29
|
+
|
|
30
|
+
import torch
|
|
31
|
+
from tensordict import TensorDict
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ---------------------------------------------------------------------------
|
|
35
|
+
# Config dataclasses
|
|
36
|
+
# ---------------------------------------------------------------------------
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(frozen=True)
class AugmentScalarSpec:
    """Distribution of one scalar augmentation parameter, sampled once per batch.

    Two modes are supported:

    * Uniform: set both ``low`` and ``high``; the value is drawn from ``[low, high)``
      (reversed bounds are swapped when sampling). ``low == high`` pins the value.
      ``mean`` must still be given but is not used in this mode.
    * Gaussian: leave ``low`` and ``high`` unset; the value is ``mean + std * N(0, 1)``.
      ``std == 0`` pins the value at ``mean``.

    Raises:
        ValueError: if exactly one of ``low`` / ``high`` is provided.
    """

    mean: float
    std: float = 0.0
    low: float | None = None
    high: float | None = None

    def __post_init__(self) -> None:
        # Uniform mode needs both bounds; a single bound is ambiguous.
        has_low = self.low is not None
        has_high = self.high is not None
        if has_low ^ has_high:
            raise ValueError("AugmentScalarSpec: set both low and high for uniform, or neither for Gaussian.")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _augment_scalar_active(spec: AugmentScalarSpec, identity_mean: float) -> bool:
    """Return whether sampling ``spec`` can produce something other than ``identity_mean``.

    Gaussian specs (``low``/``high`` unset) are active when ``std`` is non-zero or ``mean``
    differs from ``identity_mean``. Uniform specs are active unless the range collapses to
    the single value ``identity_mean``.
    """
    uniform = spec.low is not None and spec.high is not None
    if not uniform:
        return spec.std != 0.0 or spec.mean != identity_mean
    lo, hi = float(spec.low), float(spec.high)
    # A non-degenerate range is always active; a degenerate one only if it is off-identity.
    return lo != hi or lo != identity_mean
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass(frozen=True)
class AugmentMaskProbConfig:
    """Bernoulli masking probability for each step field (MLM-style input corruption).

    Every ``(batch row, step)`` position of a field is masked independently with the
    field's probability; a probability ``<= 0`` disables masking for that field. Masked
    positions are overwritten with zero (``-1`` for ``time``).

    NOTE(review): earlier docs state that ``PREDICTION`` and ``COMPUTE`` rows are never
    masked, but the step-stream masking applies no row-type filter; confirm the step
    stream carries no such rows.
    """

    action: float = 0.0  # masked actions become 0
    reward: float = 0.0  # masked rewards become 0
    done: float = 0.0  # masked done codes become 0
    obs_continuous: float = 0.0  # masked steps get an all-zero feature vector
    obs_discrete: float = 0.0  # masked ids become 0
    obs_image: float = 0.0  # masked steps get all-zero pixels
    time: float = 0.0  # masked times become -1 ("not available")

    def any_positive(self) -> bool:
        """Return True if at least one field has a strictly positive mask probability."""
        probabilities = (
            self.action,
            self.reward,
            self.done,
            self.obs_continuous,
            self.obs_discrete,
            self.obs_image,
            self.time,
        )
        for probability in probabilities:
            if probability > 0.0:
                return True
        return False
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclass(frozen=True)
class AugmentTokensConfig:
    """Optional training-time augmentations for step-stream batches (applied to a copy).

    Consumed by :class:`TokenAugmenter` in this module. ``enabled: false`` turns every
    augmentation off regardless of the other settings. ``mask_prob`` enables MLM-style
    random masking per field. The ``scale_*`` / ``shift_*`` specs are sampled once per
    batch and count as inactive when they can only produce the identity (1 for scale,
    0 for shift).

    NOTE(review): earlier docs state PREDICTION and COMPUTE tokens are never modified;
    no row-type filtering is visible in this module, so confirm the step stream carries
    no such rows.

    Raises:
        ValueError: if ``permute_action`` is not ``False``, ``"input"``, ``"target"`` or
            ``"both"``.
    """

    enabled: bool = True  # master switch; False disables every augmentation below
    permute_tokens: bool = False  # counted by any_enabled(), but the step-stream augmenter applies no token shuffle
    scale_reward: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(1.0, 0.0))
    shift_reward: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(0.0, 0.0))
    scale_obs: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(1.0, 0.0))
    shift_obs: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(0.0, 0.0))
    scale_obs_image: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(1.0, 0.0))
    shift_obs_image: AugmentScalarSpec = field(default_factory=lambda: AugmentScalarSpec(0.0, 0.0))
    permute_obs_discrete: bool = False  # remap OBS_DISCRETE ids per row; q_star / actions untouched (unsafe for semantic categoricals)
    permute_action: Literal[False, "input", "target", "both"] = False  # permute action ids ("input"), q_star columns ("target"), or both
    permute_done: bool = False  # per-row random permutation of done codes {0, 1, 2}
    mask_prob: AugmentMaskProbConfig = field(default_factory=AugmentMaskProbConfig)

    def __post_init__(self) -> None:
        mode = self.permute_action
        valid_modes = ("input", "target", "both")
        if mode is not False and mode not in valid_modes:
            raise ValueError(
                f"AugmentTokensConfig.permute_action must be false or one of: 'input', 'target', 'both'; got {mode!r}."
            )

    def permute_action_enabled(self) -> bool:
        """Return True when an action-permutation mode is configured."""
        return bool(self.permute_action)

    def permute_action_mode(self) -> Literal["input", "target", "both"]:
        """Return the configured action-permutation mode (``"both"`` when disabled)."""
        mode = self.permute_action
        return "both" if mode is False else cast(Literal["input", "target", "both"], mode)

    def any_enabled(self) -> bool:
        """Return True if at least one augmentation would run for this config."""
        if not self.enabled:
            return False
        permutation_flags = (
            self.permute_tokens,
            self.permute_obs_discrete,
            self.permute_action_enabled(),
            self.permute_done,
        )
        if any(permutation_flags):
            return True
        # (spec, identity value) pairs: scales are neutral at 1, shifts at 0.
        scalar_specs = (
            (self.scale_reward, 1.0),
            (self.shift_reward, 0.0),
            (self.scale_obs, 1.0),
            (self.shift_obs, 0.0),
            (self.scale_obs_image, 1.0),
            (self.shift_obs_image, 0.0),
        )
        if any(_augment_scalar_active(spec=spec, identity_mean=ident) for spec, ident in scalar_specs):
            return True
        return bool(self.mask_prob.any_positive())
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# ---------------------------------------------------------------------------
|
|
156
|
+
# Augmentation functions
|
|
157
|
+
# ---------------------------------------------------------------------------
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _sample_scalar(spec: AugmentScalarSpec, generator: torch.Generator) -> float:
    """Draw one value from ``spec`` using ``generator`` (on the generator's device).

    Uniform mode (``low`` and ``high`` both set) returns ``low + (high - low) * U`` with
    ``U ~ U[0, 1)``, swapping reversed bounds first. Gaussian mode returns
    ``mean + std * N(0, 1)``, or exactly ``mean`` (consuming no randomness) when
    ``std == 0``.
    """
    if spec.low is None or spec.high is None:
        if spec.std == 0.0:
            return float(spec.mean)
        noise = torch.randn((), device=generator.device, generator=generator)
        return float(noise * spec.std + spec.mean)

    lo, hi = float(spec.low), float(spec.high)
    if hi < lo:
        lo, hi = hi, lo
    draw = torch.rand((), device=generator.device, generator=generator)
    return float(lo + (hi - lo) * draw)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _inverse_action_perm_rows(perm: torch.Tensor) -> torch.Tensor:
    """Invert a batch of permutations row by row.

    Given ``perm[b, old] = new``, return ``inv`` with ``inv[b, new] = old``. The result has
    the same dtype and device as ``perm``. Rows are assumed to be permutations of
    ``0..A-1``; positions no row entry points at are left uninitialised.
    NOTE: the scatter source is int64, so ``perm`` is expected to be int64 as well.
    """
    num_rows, num_actions = int(perm.shape[0]), int(perm.shape[1])
    positions = torch.arange(num_actions, device=perm.device, dtype=torch.long)
    source = positions.unsqueeze(0).expand(num_rows, -1)
    inverse = perm.new_empty((num_rows, num_actions))
    inverse.scatter_(1, perm.to(dtype=torch.long), source)
    return inverse
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@torch.no_grad()
def apply_permute_action_augmentation(
    step_stream: TensorDict,
    perm: torch.Tensor,
    apply_to_input: bool,
    apply_to_target: bool,
) -> None:
    """Relabel actions with a per-row permutation, mutating ``step_stream`` in place.

    Args:
        step_stream: batch with ``"action"`` (``[B, S]`` action ids) and, optionally,
            ``"q_star"`` (``[B, S, A]``, one value per action).
        perm: ``[B, max_num_actions]``; row ``b`` maps old action id ``a`` to ``perm[b, a]``.
        apply_to_input: replace each ``action[b, s]`` with ``perm[b, action[b, s]]``.
        apply_to_target: if ``"q_star"`` is present, reorder its last axis so the value
            formerly stored for action ``a`` moves to position ``perm[b, a]``.
    """
    if apply_to_input:
        actions = step_stream["action"]
        relabelled = perm.gather(1, actions.long())
        step_stream["action"].copy_(relabelled)

    if not apply_to_target or "q_star" not in step_stream.keys():
        return

    q_values = step_stream["q_star"]
    rows, steps, num_actions = q_values.shape
    # new position j reads the old column inv[b, j]
    lookup = _inverse_action_perm_rows(perm).unsqueeze(1).expand(rows, steps, num_actions)
    step_stream["q_star"].copy_(q_values.gather(2, lookup))
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@torch.no_grad()
def apply_permute_done_augmentation(
    step_stream: TensorDict,
    perm: torch.Tensor,
) -> None:
    """Relabel ``done`` codes with a per-row permutation, mutating in place.

    ``done`` codes are 0 (not done), 1 (terminal) and 2 (truncated). ``perm`` has shape
    ``[B, 3]``; row ``b`` maps code ``d`` to ``perm[b, d]``. The field keeps its dtype.
    """
    codes = step_stream["done"].long()
    step_stream["done"].copy_(perm.gather(1, codes))
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@torch.no_grad()
def apply_reward_scale_shift(
    step_stream: TensorDict,
    scale: float,
    shift: float,
) -> None:
    """Apply ``reward <- reward * scale + shift`` in place.

    The identity ``(scale=1, shift=0)`` returns immediately without touching the batch.
    """
    is_identity = scale == 1.0 and shift == 0.0
    if not is_identity:
        rewards = step_stream["reward"]
        transformed = rewards * scale + shift
        step_stream["reward"].copy_(transformed)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
@torch.no_grad()
def apply_obs_continuous_scale_shift(
    step_stream: TensorDict,
    scale: float,
    shift: float,
) -> None:
    """Apply ``obs_continuous <- obs_continuous * scale + shift`` in place.

    The arithmetic is done in float64 and written back in the field's own dtype. No-op
    for the identity ``(1, 0)`` or when ``"obs_continuous"`` is absent.
    """
    if (scale == 1.0 and shift == 0.0) or "obs_continuous" not in step_stream.keys():
        return
    features = step_stream["obs_continuous"]
    transformed = features.double() * scale + shift
    step_stream["obs_continuous"].copy_(transformed)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@torch.no_grad()
def apply_obs_image_scale_shift(
    step_stream: TensorDict,
    scale: float,
    shift: float,
) -> None:
    """Apply ``pixel <- clamp(round(pixel * scale + shift), 0, 255)`` in place.

    No-op for the identity ``(1, 0)`` or when ``"obs_image"`` is absent.
    """
    if (scale == 1.0 and shift == 0.0) or "obs_image" not in step_stream.keys():
        return
    pixels = step_stream["obs_image"]
    adjusted = pixels.float() * scale + shift
    clipped = adjusted.round().clamp(0, 255)
    step_stream["obs_image"].copy_(clipped.to(torch.int64))
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@torch.no_grad()
def apply_permute_obs_discrete_augmentation(
    step_stream: TensorDict,
    perm: torch.Tensor,
) -> None:
    """``v → perm[b, v]`` for OBS_DISCRETE values of batch row ``b`` (in-place).

    Args:
        step_stream: batch whose ``"obs_discrete"`` field has the batch on dim 0 and at
            least two dims (``[B, S]``, or ``[B, S, ...]`` with several discrete features
            per step), holding ids in ``[0, max_num_obs_discrete)``. No-op when the field
            is absent.
        perm: ``[B, max_num_obs_discrete]``; row ``b`` uses ``perm[b]``.
    """
    if "obs_discrete" not in step_stream.keys():
        return
    obs = step_stream["obs_discrete"]  # [B, S] (or [B, S, ...])
    # Gather along perm's value axis: out[b, i] = perm[b, ids[b, i]]. Flattening every
    # non-batch axis lets a single gather handle any rank; gather only requires the
    # index's batch size to match perm's, so S may exceed max_num_obs_discrete.
    flat_ids = obs.long().flatten(start_dim=1)
    relabelled = torch.gather(perm, dim=1, index=flat_ids).reshape(obs.shape)
    step_stream["obs_discrete"].copy_(relabelled)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@torch.no_grad()
def apply_field_masks(
    step_stream: TensorDict,
    mask_prob: AugmentMaskProbConfig,
    generator: torch.Generator,
) -> None:
    """Sample Bernoulli masks per step and overwrite masked fields (in-place).

    For each field in the fixed order action, reward, done, obs_continuous, obs_discrete,
    obs_image, time, a ``[B, S]`` mask (``B, S`` taken from ``step_stream["action"]``) is
    drawn with that field's probability and masked positions are set to ``0`` (``-1`` for
    ``time``). Fields with probability ``<= 0`` draw nothing. The optional fields
    (obs_*, time) are skipped when absent, but their masks are still drawn so the
    generator stream does not depend on which keys are present. The ``[B, S]`` mask is
    broadcast over any trailing per-step axes (feature vectors, image pixels, ...).

    Args:
        step_stream: batch whose fields have leading dims ``[B, S]``.
        mask_prob: per-field mask probabilities.
        generator: RNG for all draws; draws are made on the batch's device.
    """
    if not mask_prob.any_positive():
        return

    action = step_stream["action"]
    dev = cast(torch.device, action.device)
    B, S = int(action.shape[0]), int(action.shape[1])

    def _bernoulli_mask(prob: float) -> torch.Tensor | None:
        # None means "field disabled": no random numbers are consumed for it.
        if prob <= 0.0:
            return None
        rand = torch.rand((B, S), device=dev, generator=generator)
        return rand < prob

    def _mask_field(key: str, mask: torch.Tensor | None, fill: int, optional: bool) -> None:
        # Overwrite masked positions of step_stream[key] with `fill`; optional keys may be absent.
        if mask is None or not mask.any():
            return
        if optional and key not in step_stream.keys():
            return
        values = step_stream[key]
        # [B, S] -> [B, S, 1, ..., 1] so the mask broadcasts over trailing axes of any rank.
        extra_dims = max(values.dim() - mask.dim(), 0)
        mask_b = mask.reshape(*mask.shape, *([1] * extra_dims))
        step_stream[key].copy_(torch.where(mask_b, torch.full_like(values, fill), values))

    # Draw order below is part of the contract: it fixes how the generator stream is consumed.
    _mask_field("action", _bernoulli_mask(mask_prob.action), 0, optional=False)
    _mask_field("reward", _bernoulli_mask(mask_prob.reward), 0, optional=False)
    _mask_field("done", _bernoulli_mask(mask_prob.done), 0, optional=False)
    _mask_field("obs_continuous", _bernoulli_mask(mask_prob.obs_continuous), 0, optional=True)
    _mask_field("obs_discrete", _bernoulli_mask(mask_prob.obs_discrete), 0, optional=True)
    _mask_field("obs_image", _bernoulli_mask(mask_prob.obs_image), 0, optional=True)
    # -1 means "not available"
    _mask_field("time", _bernoulli_mask(mask_prob.time), -1, optional=True)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
@dataclass
class AugmentSnapshot:
    """Fixed permutations and scalar draws for one batch, reused across multiple ``__call__``.

    Built by ``TokenAugmenter.update_augmentations``. Each optional field is ``None`` when
    the corresponding augmentation is inactive for the config.
    """

    batch_size: int  # batch size B the snapshot was drawn for; checked against each batch
    device: torch.device  # device of the batch the snapshot was drawn for; checked against each batch
    perm_action: torch.Tensor | None  # [B, max_num_actions]: one action permutation per row
    perm_done: torch.Tensor | None  # [B, 3]: one permutation of done codes {0, 1, 2} per row
    r_scale: float | None  # reward multiplier
    r_shift: float | None  # reward offset
    o_scale: float | None  # obs_continuous multiplier
    o_shift: float | None  # obs_continuous offset
    im_scale: float | None  # obs_image multiplier
    im_shift: float | None  # obs_image offset
    perm_obs_discrete: torch.Tensor | None  # [B, max_num_obs_discrete]: one id permutation per row
|
|
351
|
+
perm_obs_discrete: torch.Tensor | None
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
class TokenAugmenter:
    """Applies ``AugmentTokensConfig`` to a step TensorDict batch.

    Per training batch:

    1. :meth:`update_augmentations` draws the per-batch permutations and scalar
       parameters and stores them in :attr:`snapshot` (required whenever any
       augmentation is enabled).
    2. :meth:`__call__` returns an augmented clone of ``step_stream`` using that
       snapshot. ``mask_prob`` masks are drawn anew on every call, so repeated calls
       with one snapshot share permutations/scalars but not masks.

    All random draws use a single ``torch.Generator``.
    """

    def __init__(
        self,
        augment: AugmentTokensConfig,
        max_num_actions: int,
        max_num_obs_discrete: int,
        device: torch.device,
        generator: torch.Generator | None = None,
    ) -> None:
        """Store the config and the RNG; no snapshot exists until ``update_augmentations``.

        Args:
            augment: augmentation settings.
            max_num_actions: length of each sampled action permutation.
            max_num_obs_discrete: length of each sampled OBS_DISCRETE permutation.
            device: device of the default generator (used only when ``generator`` is None).
            generator: RNG for all draws. NOTE: draws pass the batch's device together
                with this generator, so it is expected to live on that device.

        Raises:
            TypeError: if ``augment`` is not an ``AugmentTokensConfig``.
        """
        if not isinstance(augment, AugmentTokensConfig):
            raise TypeError(f"augment must be AugmentTokensConfig, got {type(augment).__name__}")
        self._augment = augment
        self._max_num_actions = int(max_num_actions)
        self._max_num_obs_discrete = int(max_num_obs_discrete)
        self._generator = generator if generator is not None else torch.Generator(device=device)
        self._snapshot: AugmentSnapshot | None = None

    @property
    def augment(self) -> AugmentTokensConfig:
        """The augmentation config (a frozen dataclass)."""
        return self._augment

    @property
    def snapshot(self) -> AugmentSnapshot | None:
        """Draws from the last :meth:`update_augmentations`, or ``None``."""
        return self._snapshot

    def clear_augmentations(self) -> None:
        """Drop the stored snapshot."""
        self._snapshot = None

    @torch.no_grad()
    def update_augmentations(self, step_stream: TensorDict) -> None:
        """Sample permutations and scalar parameters for this batch and store them.

        Batch size and device are read from ``step_stream["action"]``. When no
        augmentation is enabled, the snapshot is cleared instead.
        """
        augment = self._augment
        if not augment.any_enabled():
            # __call__ returns its input untouched in this case, so no snapshot is needed.
            self._snapshot = None
            return

        action = step_stream["action"]
        B = int(action.shape[0])
        dev = cast(torch.device, action.device)
        g = self._generator

        # The order of the draws below (action perms, done perms, obs_discrete perms, then
        # reward/obs/image scalars) determines how the generator stream is consumed;
        # reordering changes the results for a given seed.

        # One independent permutation per batch row.
        perm_action: torch.Tensor | None = None
        if augment.permute_action_enabled():
            perm_action = torch.stack(
                [torch.randperm(self._max_num_actions, device=dev, generator=g) for _ in range(B)],
                dim=0,
            )

        # done codes take 3 values: 0 (not done), 1 (terminal), 2 (truncated).
        perm_done: torch.Tensor | None = None
        if augment.permute_done:
            perm_done = torch.stack(
                [torch.randperm(3, device=dev, generator=g) for _ in range(B)],
                dim=0,
            )

        perm_obs_discrete: torch.Tensor | None = None
        if augment.permute_obs_discrete:
            perm_obs_discrete = torch.stack(
                [torch.randperm(self._max_num_obs_discrete, device=dev, generator=g) for _ in range(B)],
                dim=0,
            )

        # Scalars are drawn once for the whole batch; inactive specs stay None.
        r_scale: float | None = None
        r_shift: float | None = None
        if _augment_scalar_active(augment.scale_reward, 1.0):
            r_scale = _sample_scalar(augment.scale_reward, g)
        if _augment_scalar_active(augment.shift_reward, 0.0):
            r_shift = _sample_scalar(augment.shift_reward, g)

        o_scale: float | None = None
        o_shift: float | None = None
        if _augment_scalar_active(augment.scale_obs, 1.0):
            o_scale = _sample_scalar(augment.scale_obs, g)
        if _augment_scalar_active(augment.shift_obs, 0.0):
            o_shift = _sample_scalar(augment.shift_obs, g)

        im_scale: float | None = None
        im_shift: float | None = None
        if _augment_scalar_active(augment.scale_obs_image, 1.0):
            im_scale = _sample_scalar(augment.scale_obs_image, g)
        if _augment_scalar_active(augment.shift_obs_image, 0.0):
            im_shift = _sample_scalar(augment.shift_obs_image, g)

        self._snapshot = AugmentSnapshot(
            batch_size=B,
            device=dev,
            perm_action=perm_action,
            perm_done=perm_done,
            r_scale=r_scale,
            r_shift=r_shift,
            o_scale=o_scale,
            o_shift=o_shift,
            im_scale=im_scale,
            im_shift=im_shift,
            perm_obs_discrete=perm_obs_discrete,
        )

    def _assert_snapshot_matches(self, step_stream: TensorDict) -> AugmentSnapshot:
        """Return the stored snapshot after checking it was drawn for this batch.

        Raises:
            RuntimeError: if no snapshot exists.
            ValueError: if batch size or device differ from the snapshot's.
        """
        snap = self._snapshot
        if snap is None:
            raise RuntimeError("TokenAugmenter has no snapshot; call update_augmentations first.")
        B = int(step_stream["action"].shape[0])
        dev = cast(torch.device, step_stream["action"].device)
        if B != snap.batch_size or dev != snap.device:
            raise ValueError(
                f"Batch mismatch: got B={B}, device={dev!r}; snapshot expects "
                f"batch_size={snap.batch_size}, device={snap.device!r}."
            )
        return snap

    @torch.no_grad()
    def __call__(
        self,
        step_stream: TensorDict,
    ) -> TensorDict:
        """Augment a training batch; returns a new TensorDict when augmentation runs.

        If no augmentation is enabled, ``step_stream`` itself is returned. Otherwise it
        is cloned and transformed in this order: masks, action permutation, done
        permutation, OBS_DISCRETE permutation, reward scale, reward shift, obs scale, obs
        shift, image scale, image shift. The input is never modified.

        Requires ``step_stream`` shape ``[B, S]`` per field.
        Call :meth:`update_augmentations` with the same batch first.
        Permutations/scalars use :attr:`snapshot`; ``mask_prob`` is drawn fresh here.

        Raises:
            RuntimeError / ValueError: from the snapshot check (missing or mismatched).
        """
        augment = self._augment
        if not augment.any_enabled():
            return step_stream

        step_stream = step_stream.clone()
        snap = self._assert_snapshot_matches(step_stream)

        # MLM-style masks first (corrupt inputs before permute/scale)
        if augment.mask_prob.any_positive():
            apply_field_masks(step_stream=step_stream, mask_prob=augment.mask_prob, generator=self._generator)

        # The asserts below hold because the (frozen) config gates the snapshot fields in
        # update_augmentations with the same conditions.
        if augment.permute_action_enabled():
            assert snap.perm_action is not None
            mode = augment.permute_action_mode()
            apply_permute_action_augmentation(
                step_stream=step_stream,
                perm=snap.perm_action,
                apply_to_input=mode in ("input", "both"),
                apply_to_target=mode in ("target", "both"),
            )

        if augment.permute_done:
            assert snap.perm_done is not None
            apply_permute_done_augmentation(step_stream=step_stream, perm=snap.perm_done)

        if augment.permute_obs_discrete:
            assert snap.perm_obs_discrete is not None
            apply_permute_obs_discrete_augmentation(step_stream=step_stream, perm=snap.perm_obs_discrete)

        # Scale and shift are applied as separate passes: x -> x * scale, then x -> x + shift.
        if _augment_scalar_active(spec=augment.scale_reward, identity_mean=1.0):
            assert snap.r_scale is not None
            apply_reward_scale_shift(step_stream=step_stream, scale=snap.r_scale, shift=0.0)

        if _augment_scalar_active(spec=augment.shift_reward, identity_mean=0.0):
            assert snap.r_shift is not None
            apply_reward_scale_shift(step_stream=step_stream, scale=1.0, shift=snap.r_shift)

        if _augment_scalar_active(spec=augment.scale_obs, identity_mean=1.0):
            assert snap.o_scale is not None
            apply_obs_continuous_scale_shift(step_stream=step_stream, scale=snap.o_scale, shift=0.0)

        if _augment_scalar_active(spec=augment.shift_obs, identity_mean=0.0):
            assert snap.o_shift is not None
            apply_obs_continuous_scale_shift(step_stream=step_stream, scale=1.0, shift=snap.o_shift)

        if _augment_scalar_active(spec=augment.scale_obs_image, identity_mean=1.0):
            assert snap.im_scale is not None
            apply_obs_image_scale_shift(step_stream=step_stream, scale=snap.im_scale, shift=0.0)

        if _augment_scalar_active(spec=augment.shift_obs_image, identity_mean=0.0):
            assert snap.im_shift is not None
            apply_obs_image_scale_shift(step_stream=step_stream, scale=1.0, shift=snap.im_shift)

        return step_stream
|