physlink 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- physlink/__init__.py +19 -0
- physlink/adapters/__init__.py +0 -0
- physlink/adapters/dreamer.py +1068 -0
- physlink/core/__init__.py +3 -0
- physlink/core/_types.py +302 -0
- physlink/core/adapter.py +57 -0
- physlink/core/exceptions.py +146 -0
- physlink/core/spaces.py +276 -0
- physlink/core/validation.py +302 -0
- physlink/utils/__init__.py +0 -0
- physlink/utils/diagnostics.py +276 -0
- physlink/utils/visualization.py +75 -0
- physlink-0.1.2.dist-info/METADATA +117 -0
- physlink-0.1.2.dist-info/RECORD +17 -0
- physlink-0.1.2.dist-info/WHEEL +5 -0
- physlink-0.1.2.dist-info/licenses/LICENSE +21 -0
- physlink-0.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1068 @@
|
|
|
1
|
+
"""DreamerV3 adapter for PhysLink."""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from physlink.core._types import TrajectoryBatch, TrajectoryBuffer
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from physlink.core._types import AdaptationRun
|
|
11
|
+
from physlink.core.adapter import BaseAdapter
|
|
12
|
+
from physlink.core.exceptions import ConfigurationError
|
|
13
|
+
from physlink.core.spaces import ActionSpace, ObservationSpace
|
|
14
|
+
from physlink.core.validation import ComplianceReport
|
|
15
|
+
|
|
16
|
+
# Space-size floors enforced when a DreamerV3Adapter is constructed.
MIN_OBS_DIMS: int = 4  # smallest observation width DreamerV3 accepts
MIN_ACT_DIMS: int = 1  # an action space needs at least one dimension

# Prediction-health monitor used during fit(): a rolling window of recent
# losses, how many initial losses are averaged into the baseline, and how far
# above that baseline the window mean may rise before flagging an anomaly.
_HEALTH_WINDOW: int = 50
_HEALTH_BASELINE_STEPS: int = 10
_ANOMALY_MULTIPLIER: float = 2.0

# Pipeline stages listed (in this order) by the optional debug panel.
_STAGE_NAMES: tuple[str, ...] = (
    "data_loading", "world_model_update",
    "actor_update", "critic_update",
)

# Upper bound on the number of steps used for triptych inference.
_VIZ_SEQ_LEN: int = 50
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class _DebugPanel:
    """Rich renderable showing the status of each training-pipeline stage.

    Every stage named in ``_STAGE_NAMES`` starts as ``"waiting..."``. Callers
    push new statuses with :meth:`update_all`. The panel is rendered through
    :meth:`__rich__` whenever the display containing it refreshes.
    """

    def __init__(self) -> None:
        # One entry per known stage, all initially pending.
        self.stages: dict[str, str] = {name: "waiting..." for name in _STAGE_NAMES}

    def update_all(self, statuses: dict[str, str]) -> None:
        """Merge ``statuses`` (stage name -> status text) into the panel."""
        self.stages.update(statuses)

    def __rich__(self) -> Any:  # noqa: ANN401
        """Build a two-column rich ``Table`` (Stage, Status) for display.

        ``"OK"`` renders bold green and ``"waiting..."`` renders dim. Any other
        status renders bold red. Stage labels have underscores replaced with
        spaces. Free-form text (labels and error statuses) is markup-escaped,
        so square brackets inside it are shown literally instead of being
        parsed as rich style tags.
        """
        from rich.markup import escape
        from rich.table import Table

        table = Table(
            title="[dim]Debug Hooks Panel[/dim]",
            show_header=True,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Stage", style="dim", no_wrap=True)
        table.add_column("Status", no_wrap=True)
        for name, status in self.stages.items():
            label = escape(name.replace("_", " "))
            if status == "OK":
                cell = "[bold green]OK[/bold green]"
            elif status == "waiting...":
                cell = "[dim]waiting...[/dim]"
            else:
                # Status text may come from exception messages; escape it so
                # e.g. "[cuda error]" is not interpreted as a style tag.
                cell = f"[bold red]{escape(status)}[/bold red]"
            table.add_row(label, cell)
        return table
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@contextlib.contextmanager
def _build_progress_bar(
    steps: int,
) -> Generator[tuple[Any, Any], None, None]:
    """Context manager yielding (progress, task_id) for the adaptation loop.

    Opens a rich ``Progress`` display with these columns: spinner,
    description, bar, M/N count, time remaining, steps per second, and a
    health cell. It registers a single task with ``total=steps`` and
    ``health="OK"``. The display lives for the duration of the ``with`` block.
    """
    from rich.progress import (
        BarColumn,
        MofNCompleteColumn,
        Progress,
        ProgressColumn,
        SpinnerColumn,
        TextColumn,
        TimeRemainingColumn,
    )
    from rich.text import Text

    class _StepsPerSecColumn(ProgressColumn):
        # Throughput cell; shows a placeholder until rich has a speed estimate.
        def render(self, task: Any) -> Text:  # noqa: ANN401
            speed = task.speed
            if speed is not None:
                return Text(f"{speed:.1f} step/s", style="cyan")
            return Text("? step/s", style="dim")

    class _HealthColumn(ProgressColumn):
        # Health cell driven by the task's custom "health" field.
        def render(self, task: Any) -> Text:  # noqa: ANN401
            health = task.fields.get("health", "OK")
            return Text(health, style="bold green" if health == "OK" else "bold red")

    bullet = "•"
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn(bullet),
        TimeRemainingColumn(),
        TextColumn(bullet),
        _StepsPerSecColumn(),
        TextColumn(bullet),
        _HealthColumn(),
    )
    with Progress(*columns) as progress:
        task_id = progress.add_task("[cyan]DreamerV3 adaptation", total=steps, health="OK")
        yield progress, task_id
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@contextlib.contextmanager
def _build_debug_layout(
    steps: int,
    panel: _DebugPanel,
) -> Generator[tuple[Any, Any], None, None]:
    """Context manager yielding (progress, task_id) with debug panel alongside.

    Builds a rich ``Progress``, which is never started on its own, holding one
    task with ``total=steps`` and ``health="OK"``. The progress bar is stacked
    above ``panel`` inside a single ``Live`` display refreshed 4 times per
    second. That display is active for the duration of the ``with`` block.
    """
    from rich.console import Group
    from rich.live import Live
    from rich.progress import (
        BarColumn,
        MofNCompleteColumn,
        Progress,
        ProgressColumn,
        SpinnerColumn,
        TextColumn,
        TimeRemainingColumn,
    )
    from rich.text import Text

    class _StepsPerSecColumn(ProgressColumn):
        # Throughput cell: "? step/s" until rich has a speed estimate.
        def render(self, task: Any) -> Text:  # noqa: ANN401
            rate = task.speed
            if rate is None:
                return Text("? step/s", style="dim")
            return Text(f"{rate:.1f} step/s", style="cyan")

    class _HealthColumn(ProgressColumn):
        # Health cell driven by the task's custom "health" field.
        def render(self, task: Any) -> Text:  # noqa: ANN401
            status = task.fields.get("health", "OK")
            colour = "bold green" if status == "OK" else "bold red"
            return Text(status, style=colour)

    def separator() -> Any:  # noqa: ANN401
        # Bullet column placed between the informational cells.
        return TextColumn("•")

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        separator(),
        TimeRemainingColumn(),
        separator(),
        _StepsPerSecColumn(),
        separator(),
        _HealthColumn(),
    )
    task_id = progress.add_task("[cyan]DreamerV3 adaptation", total=steps, health="OK")

    layout = Group(progress, panel)
    with Live(layout, refresh_per_second=4):
        yield progress, task_id
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _save_checkpoint(
    model: Any,  # noqa: ANN401
    actor: Any,  # noqa: ANN401
    critic: Any,  # noqa: ANN401
    step: int,
    checkpoint_dir: str,
) -> str:
    """Write model/actor/critic weights for ``step`` to a safetensors file.

    Each network's ``state_dict()`` entries are stored under a ``model.``,
    ``actor.`` or ``critic.`` key prefix. String metadata records the physlink
    version, the adapter class, a UTC timestamp and the step.

    The file is first written to ``<path>.tmp`` in the same directory and then
    moved into place with ``os.replace``. An interrupted save (e.g. a dropped
    notebook runtime) therefore never leaves a truncated file under the final
    ``checkpoint_step_<step>.safetensors`` name. If writing fails, the
    temporary file is removed (best effort) and the error is re-raised.

    Args:
        model: World-model module (anything exposing ``state_dict()``).
        actor: Actor module.
        critic: Critic module.
        step: Training step recorded in the file name and metadata.
        checkpoint_dir: Directory to write into; created if missing.

    Returns:
        ``os.path.join(checkpoint_dir, filename)``, which is not made absolute.
    """
    import datetime
    import os

    from safetensors.torch import save_file

    import physlink

    os.makedirs(checkpoint_dir, exist_ok=True)
    filename = f"checkpoint_step_{step}.safetensors"
    path = os.path.join(checkpoint_dir, filename)
    tmp_path = path + ".tmp"

    tensors: dict[str, Any] = {}
    for prefix, module in (("model", model), ("actor", actor), ("critic", critic)):
        tensors.update({f"{prefix}.{k}": v for k, v in module.state_dict().items()})
    metadata = {
        "physlink_version": physlink.__version__,
        "adapter_class": "DreamerV3Adapter",
        "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "checkpoint_step": str(step),
    }

    try:
        save_file(tensors, tmp_path, metadata=metadata)
        # Same-directory rename: readers see either the old file or the complete new one.
        os.replace(tmp_path, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise

    print(f"[physlink] Checkpoint saved: {os.path.abspath(path)}")
    return path
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _check_checkpoint_metadata(path: str) -> dict[str, str]:
    """Read and validate the metadata of a safetensors checkpoint.

    Opens the file with ``safe_open`` and requests only its metadata. No
    tensors are read here.

    Args:
        path: Path to the ``.safetensors`` checkpoint.

    Returns:
        The checkpoint's metadata dict.

    Raises:
        CheckpointCorruptError: If the file cannot be opened as a safetensors
            checkpoint (the original exception is chained as ``__cause__``),
            or if its metadata is missing or lacks ``physlink_version``.
        CheckpointVersionError: If the first two dot-separated components
            (major.minor) of the checkpoint's ``physlink_version`` differ from
            the installed ``physlink.__version__``.
    """
    from safetensors import safe_open

    import physlink
    from physlink.core.exceptions import CheckpointCorruptError, CheckpointVersionError

    try:
        with safe_open(path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
    except Exception as exc:
        # Chain the original error so the real cause (missing file, bad
        # header, ...) stays visible in the traceback.
        raise CheckpointCorruptError(
            f"Cannot open checkpoint: {path}\n"
            f" Got: {type(exc).__name__}: {exc}\n"
            f" Expected: valid safetensors file\n"
            f" Fix: re-run adapter.fit() to generate a fresh checkpoint."
        ) from exc

    if metadata is None or "physlink_version" not in metadata:
        raise CheckpointCorruptError(
            f"Checkpoint metadata missing or incomplete: {path}\n"
            f" Got: metadata={metadata!r}\n"
            f" Expected: metadata dict with key 'physlink_version'\n"
            f" Fix: re-run adapter.fit() to generate a fresh checkpoint."
        )

    checkpoint_version = metadata["physlink_version"]
    current_version = physlink.__version__
    # Compatibility rule: major.minor must match; patch differences are allowed.
    if checkpoint_version.split(".")[:2] != current_version.split(".")[:2]:
        raise CheckpointVersionError(
            f"Checkpoint version incompatible: {path}\n"
            f" Got: checkpoint saved with physlink=={checkpoint_version}\n"
            f" Expected: compatible version (same major.minor as {current_version})\n"
            f" Fix: re-run adapter.fit() to generate a fresh checkpoint.",
            checkpoint_version=checkpoint_version,
            current_version=current_version,
        )

    return metadata
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _share_panel(export_path: str) -> None:
    """Trigger the Colab share panel: copy notebook URL to clipboard.

    In Google Colab, this injects a Javascript snippet that asks the browser
    to copy ``window.location.href`` to the clipboard. The copy happens
    asynchronously in the browser, so its outcome is reported only in the
    browser console: success via ``console.log``, failure via
    ``console.warn``. The Python-side messages are printed regardless.
    Outside Colab, a fallback message pointing at ``export_path`` is printed.
    Any exception raised while doing this is caught and reported by type name
    only, so sharing is best-effort and never raises.

    Args:
        export_path: Absolute path to the export directory. Shown in fallback
            message so collaborators know where to find the artifacts.

    Example:
        >>> _share_panel("./physlink_export")
        [physlink] Share panel: URL copy is only available in Google Colab.
        ...
    """
    try:
        import google.colab  # noqa: F401
        in_colab = True
    except ImportError:
        in_colab = False

    try:
        if in_colab:
            from IPython.display import Javascript, display

            # navigator.clipboard.writeText returns a Promise that rejects when
            # clipboard access is refused (no permission, document not focused).
            # The .catch logs that failure instead of leaving an unhandled
            # rejection.
            # NOTE(review): the snippet executes inside the cell's output frame;
            # confirm window.location.href there is the notebook URL.
            display(Javascript(
                "navigator.clipboard.writeText(window.location.href)"
                ".then(() => console.log('[physlink] Notebook URL copied.'))"
                ".catch((err) => console.warn('[physlink] Notebook URL copy failed:', err));"
            ))
            print("[physlink] Share panel: notebook URL copied to clipboard.")
            print(f"[physlink] Export path for collaborators: {export_path}")
        else:
            print(
                "[physlink] Share panel: URL copy is only available in Google Colab.\n"
                f" To share your results, send the export directory: {export_path}"
            )
    except Exception as exc:
        print(f"[physlink] Share panel unavailable: {type(exc).__name__}")
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class DreamerV3Adapter(BaseAdapter):
|
|
278
|
+
"""DreamerV3 adapter for physical simulation reinforcement learning.
|
|
279
|
+
|
|
280
|
+
Validates space compatibility at construction time. Training, visualization,
|
|
281
|
+
and export are deferred to fit() / visualize() / export() respectively.
|
|
282
|
+
No model weights are loaded and no GPU is required at construction.
|
|
283
|
+
|
|
284
|
+
Args:
|
|
285
|
+
obs_space: Observation space with dims >= 4.
|
|
286
|
+
act_space: Action space with dims >= 1.
|
|
287
|
+
|
|
288
|
+
Raises:
|
|
289
|
+
ConfigurationError: If obs_space.dims < 4 or act_space.dims < 1.
|
|
290
|
+
|
|
291
|
+
Example:
|
|
292
|
+
>>> from physlink import DreamerV3Adapter, ObservationSpace, ActionSpace
|
|
293
|
+
>>> obs = ObservationSpace.from_proprioception(joints=7, include_velocity=True)
|
|
294
|
+
>>> act = ActionSpace.continuous(dims=7, bounds=[(-1.0, 1.0)] * 7)
|
|
295
|
+
>>> adapter = DreamerV3Adapter(obs, act)
|
|
296
|
+
>>> adapter.obs_space.dims
|
|
297
|
+
14
|
|
298
|
+
"""
|
|
299
|
+
|
|
300
|
+
def __init__(self, obs_space: ObservationSpace, act_space: ActionSpace) -> None:
    """Validate the spaces and initialise empty, not-yet-trained adapter state.

    No networks are built here. The model, actor and critic attributes stay
    ``None`` until they are created later.

    Raises:
        ConfigurationError: If ``obs_space.dims < MIN_OBS_DIMS`` or
            ``act_space.dims < MIN_ACT_DIMS``.
    """
    # Reject spaces that are too small before touching any state.
    if obs_space.dims < MIN_OBS_DIMS:
        message = (
            f"DreamerV3Adapter: incompatible obs_space.\n"
            f" Got: obs_space.dims={obs_space.dims}\n"
            f" Expected: obs_space.dims >= {MIN_OBS_DIMS} (DreamerV3 minimum)\n"
            f" Fix: construct ObservationSpace with joints >= {MIN_OBS_DIMS}, "
            f"or use include_velocity=True to double the dimension count."
        )
        raise ConfigurationError(message)
    if act_space.dims < MIN_ACT_DIMS:
        message = (
            f"DreamerV3Adapter: incompatible act_space.\n"
            f" Got: act_space.dims={act_space.dims}\n"
            f" Expected: act_space.dims >= {MIN_ACT_DIMS}\n"
            f" Fix: construct ActionSpace with dims >= 1."
        )
        raise ConfigurationError(message)

    super().__init__(obs_space, act_space)

    # Networks, created lazily by _initialize_model.
    self._model: Any | None = None
    self._actor: Any | None = None
    self._critic: Any | None = None

    # Training-health and run bookkeeping; all start empty.
    self._loss_history: list[float] = []
    self._baseline_loss: float | None = None
    self._fit_elapsed_seconds: float | None = None
    self._triptych_path: str | None = None
    self._last_checkpoint_path: str | None = None

    # Registered invariants, their per-trajectory residuals, and the soft-mode penalty.
    self._invariants: list[Any] = []
    self._invariant_residuals: dict[str, list[float]] = {}
    self._soft_penalty_per_step: float = 0.0
|
|
328
|
+
|
|
329
|
+
def _initialize_model(self, device: Any) -> None:  # noqa: ANN401
    """Build the world model, actor and critic networks and move them to ``device``.

    Input and output sizes come from ``self.obs_space.dims`` and
    ``self.act_space.dims``. The GRU hidden width and the latent width are
    both fixed at 256. Results are stored on ``self._model``,
    ``self._actor`` and ``self._critic``.

    Submodule names and ``nn.Sequential`` positions become state-dict keys
    (e.g. ``encoder.0.weight``), which checkpoints are saved and loaded by,
    so they must stay stable.
    """
    import torch.nn as nn

    obs_dims = self.obs_space.dims
    act_dims = self.act_space.dims
    hidden = 256  # GRU state and MLP width
    latent = 256  # stochastic latent width

    def mlp(*widths: int, trailing_elu: bool = False) -> nn.Sequential:
        # Linear layers between consecutive widths, with ELU after every layer
        # except the last (and after the last as well when trailing_elu).
        layers: list[nn.Module] = []
        final = len(widths) - 2
        for pos, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if trailing_elu or pos < final:
                layers.append(nn.ELU())
        return nn.Sequential(*layers)

    class _WorldModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Creation order fixes both the state-dict layout and the order in
            # which parameters draw from the RNG during initialisation.
            self.encoder = mlp(obs_dims, hidden, hidden, trailing_elu=True)
            self.gru = nn.GRUCell(hidden + act_dims, hidden)
            self.posterior = mlp(hidden + hidden, hidden, latent * 2)
            self.prior = mlp(hidden, hidden, latent * 2)
            self.decoder = mlp(hidden + latent, hidden, obs_dims)
            self.reward_head = mlp(hidden + latent, hidden, 1)

    class _Actor(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Outputs mean and log-std for every action dimension.
            self.net = mlp(hidden + latent, hidden, hidden, act_dims * 2)

    class _Critic(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = mlp(hidden + latent, hidden, hidden, 1)

    self._model = _WorldModel().to(device)
    self._actor = _Actor().to(device)
    self._critic = _Critic().to(device)
|
|
383
|
+
|
|
384
|
+
def _reset_training_state(self) -> None:
    """Reset all mutable training state for a fresh fit() run (NFR-09).

    Clears the loss history and health baseline, drops recorded invariant
    residuals, and zeroes the soft-mode invariant penalty. Network weights
    and the registered invariants themselves are not touched.
    """
    self._loss_history, self._baseline_loss = [], None
    self._invariant_residuals, self._soft_penalty_per_step = {}, 0.0
|
|
390
|
+
|
|
391
|
+
def _apply_invariants(self, trajectories: TrajectoryBatch) -> TrajectoryBatch:
    """Apply registered invariants: filter hard-mode violations, compute soft penalty.

    Every registered invariant is evaluated on every trajectory. Residuals are
    recorded in ``self._invariant_residuals[name]``, which is reset at the
    start of each call. If an invariant function raises, the failure is
    printed and the residual is treated as 0.0.

    * ``mode == "hard"``: trajectories whose residual exceeds the tolerance
      are dropped, and each rejection is printed.
    * ``mode == "soft"``: the summed excess (residual - tolerance) over all
      trajectories is divided by the pre-filter trajectory count and stored
      in ``self._soft_penalty_per_step``.

    Returns:
        The input unchanged when no invariants are registered. Otherwise, a
        new ``TrajectoryBatch`` containing only the surviving trajectories.

    Raises:
        ValidationError: If hard-mode filtering removes every trajectory.
    """
    if not self._invariants:
        return trajectories

    from physlink.core.exceptions import ValidationError

    data = trajectories.data
    residuals = self._invariant_residuals
    for inv in self._invariants:
        residuals[inv.name] = []

    keep = [True] * len(data)
    for inv in self._invariants:
        is_hard = inv.mode == "hard"
        for idx, traj in enumerate(data):
            try:
                residual = float(inv.fn(traj))
            except Exception as exc:
                # A broken invariant must not abort training; record a neutral residual.
                print(
                    f"[physlink] Invariant '{inv.name}' failed on trajectory {idx}: "
                    f"{type(exc).__name__} — treating residual as 0.0"
                )
                residual = 0.0
            residuals[inv.name].append(residual)

            if is_hard and residual > inv.tolerance:
                keep[idx] = False
                print(
                    f"[physlink] Invariant '{inv.name}' rejected trajectory {idx}: "
                    f"residual={residual:.4f} > tolerance={inv.tolerance}"
                )

    filtered = [traj for traj, ok in zip(data, keep) if ok]
    if not filtered:
        raise ValidationError(
            f"register_invariant (hard mode): all {len(data)} trajectories rejected.\n"
            f" Got: 0 trajectories remaining after hard-mode invariant filtering\n"
            f" Expected: at least 1 trajectory passing all hard-mode invariants\n"
            f" Fix: lower tolerance, fix the invariant function, "
            f"or switch to mode='soft'."
        )

    soft_surplus = sum(
        r - inv.tolerance
        for inv in self._invariants
        if inv.mode == "soft"
        for r in residuals[inv.name]
        if r > inv.tolerance
    )
    self._soft_penalty_per_step = soft_surplus / max(len(data), 1)

    return TrajectoryBatch(data=filtered)
|
|
442
|
+
|
|
443
|
+
def compliance_report(self) -> ComplianceReport:
    """Return a ComplianceReport summarizing invariant compliance from the last fit().

    Reads ``_invariants`` and ``_invariant_residuals`` stored on the adapter.
    This is pure computation with no side effects, so it is safe to call
    multiple times.

    Returns:
        ComplianceReport with per-invariant summary and violation details.
        Empty report (no entries) if no invariants are registered.
        Zero-trajectory report if fit() has not yet been called.

    Example:
        >>> register_invariant(adapter, "mass", fn, tolerance=0.01)
        >>> adapter.fit(trajectories, steps=100)
        >>> report = adapter.compliance_report()
        >>> print(report.summary())
        mass: PASS (max_residual=0.0042, threshold=0.0100, violations=0/10)
    """
    stats: list[dict[str, Any]] = []
    violation_list: list[dict[str, Any]] = []

    for inv in self._invariants:
        name, tol = inv.name, inv.tolerance
        residuals = self._invariant_residuals.get(name, [])
        # (trajectory index, residual) for every residual strictly above tolerance.
        offenders = [(idx, r) for idx, r in enumerate(residuals) if r > tol]

        stats.append({
            "name": name,
            "max_residual": max(residuals, default=0.0),
            "threshold": tol,
            "violation_count": len(offenders),
            "total": len(residuals),
        })
        violation_list.extend(
            {
                "invariant_name": name,
                "trajectory_idx": idx,
                "residual": r,
                "possible_cause": f"Residual {r:.4f} exceeds tolerance {tol:.4f}.",
            }
            for idx, r in offenders
        )

    residuals_by_invariant = {
        inv.name: list(self._invariant_residuals.get(inv.name, []))
        for inv in self._invariants
    }
    return ComplianceReport(
        _stats=stats,
        _violation_list=violation_list,
        _residuals_by_invariant=residuals_by_invariant,
    )
|
|
497
|
+
|
|
498
|
+
def load_checkpoint(self, path: str) -> None:
    """Load model weights from a safetensors checkpoint.

    Reads checkpoint metadata before loading weights for early detection
    of version incompatibility or file corruption. Builds the networks first
    if they do not exist yet. All three networks are then moved to CUDA when
    available (otherwise CPU), and the ``model.``/``actor.``/``critic.``
    prefixed tensors are loaded into the matching network.

    Args:
        path: Path to the .safetensors checkpoint file to load.

    Raises:
        CheckpointCorruptError: If the file is malformed, unreadable, or
            missing required metadata.
        CheckpointVersionError: If physlink_version in the checkpoint
            metadata is incompatible with the installed version
            (different major.minor component).

    Example:
        >>> adapter = DreamerV3Adapter(obs, act)
        >>> adapter.load_checkpoint("./physlink_checkpoints/checkpoint_step_1000.safetensors")
    """
    # Fail fast on bad metadata before importing torch or touching weights.
    _check_checkpoint_metadata(path)

    import os

    import torch
    from safetensors.torch import load_file

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if self._model is None:
        self._initialize_model(device)
    networks = {"model": self._model, "actor": self._actor, "critic": self._critic}
    for net in networks.values():
        net.to(device)

    flat = load_file(path, device="cpu")
    for prefix, net in networks.items():
        marker = prefix + "."
        net.load_state_dict(
            {key.removeprefix(marker): tensor for key, tensor in flat.items() if key.startswith(marker)}
        )
    print(f"[physlink] Checkpoint loaded: {os.path.abspath(path)}")
|
|
544
|
+
|
|
545
|
+
def _compute_health(self, loss: float) -> str:
    """Record ``loss`` and classify recent training health as "OK" or "ANOMALY".

    Keeps the last ``_HEALTH_WINDOW`` losses. The first time at least
    ``_HEALTH_BASELINE_STEPS`` losses are in the window, the mean of the
    oldest ``_HEALTH_BASELINE_STEPS`` becomes the frozen baseline. From then
    on, the result is "ANOMALY" when the window mean exceeds
    ``_ANOMALY_MULTIPLIER`` times a positive baseline.

    Non-finite values are treated specially. NaN compares False against
    every threshold, so without this check a diverged run would read as "OK":

    * If the window mean is NaN or +/-inf, the result is "ANOMALY", even
      before a baseline exists.
    * A non-finite baseline candidate is not frozen. The baseline is retried
      on later calls once the window's oldest entries are finite.

    Args:
        loss: Scalar training loss for the current step.

    Returns:
        "OK" or "ANOMALY".
    """
    import math

    self._loss_history.append(loss)
    if len(self._loss_history) > _HEALTH_WINDOW:
        self._loss_history = self._loss_history[-_HEALTH_WINDOW:]

    if self._baseline_loss is None and len(self._loss_history) >= _HEALTH_BASELINE_STEPS:
        candidate = sum(self._loss_history[:_HEALTH_BASELINE_STEPS]) / _HEALTH_BASELINE_STEPS
        if math.isfinite(candidate):
            self._baseline_loss = candidate

    current_avg = sum(self._loss_history) / len(self._loss_history)
    if not math.isfinite(current_avg):
        return "ANOMALY"

    if self._baseline_loss is None or self._baseline_loss <= 0:
        return "OK"

    return "ANOMALY" if current_avg > _ANOMALY_MULTIPLIER * self._baseline_loss else "OK"
|
|
560
|
+
|
|
561
|
+
def _training_step(self, batch: Any, device: Any) -> Any:  # noqa: ANN401
    """Compute one combined world-model + actor + critic loss on random windows.

    Samples 16 random contiguous windows of up to 50 steps from the
    pre-processed tensors and rolls the world model over them (reconstruction
    + KL loss). It then imagines 15 steps ahead with the actor, forms
    λ-returns, and returns the sum of the world-model, actor and critic
    losses plus the constant soft-invariant penalty. This method does not
    call ``backward()``.

    Args:
        batch: ``(obs_all, act_all)`` tensors, both sliced along dim 0 into
            windows (comment in the code: shapes ``(N, obs_dims)`` and
            ``(N, act_dims)``).
        device: ``torch.device`` the networks live on. Autocast is enabled
            only when ``device.type == "cuda"``.

    Returns:
        Scalar loss tensor.
    """
    import torch
    import torch.nn as nn

    obs_all, act_all = batch  # pre-processed tensors: (N, obs_dims), (N, act_dims)
    n = obs_all.shape[0]

    batch_size = 16
    seq_len = min(50, max(1, n))
    max_start = max(0, n - seq_len)

    # Discount factor and λ for the imagined λ-returns below.
    discount = 0.99
    lambda_ = 0.95

    starts = torch.randint(0, max_start + 1, (batch_size,)).tolist()
    obs_seq = torch.stack([obs_all[s: s + seq_len] for s in starts])
    act_seq = torch.stack([act_all[s: s + seq_len] for s in starts])

    b_size, t_steps, obs_d = obs_seq.shape
    gru_hidden = self._model.gru.hidden_size

    with torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
        h_state = torch.zeros(b_size, gru_hidden, device=device)
        latents: list[Any] = []
        kl_losses: list[Any] = []
        recon_losses: list[Any] = []

        for t in range(t_steps):
            obs_t = obs_seq[:, t]
            act_t = act_seq[:, t]

            encoded = self._model.encoder(obs_t)
            gru_input = torch.cat([encoded, act_t], dim=-1)
            h_state = self._model.gru(gru_input, h_state)

            post_params = self._model.posterior(torch.cat([h_state, encoded], dim=-1))
            post_mean, post_log_std = post_params.chunk(2, dim=-1)
            post_std = post_log_std.clamp(-5, 2).exp()
            z = post_mean + post_std * torch.randn_like(post_std)
            latents.append(z)

            prior_params = self._model.prior(h_state)
            prior_mean, prior_log_std = prior_params.chunk(2, dim=-1)
            prior_std = prior_log_std.clamp(-5, 2).exp().clamp(min=1e-8)

            # Closed-form KL(posterior || prior) between diagonal Gaussians.
            kl = 0.5 * (
                (post_mean - prior_mean).pow(2) / prior_std.pow(2)
                + (post_std / prior_std).pow(2)
                - 1
                - 2 * (post_std / prior_std).log()
            ).sum(-1).mean()
            kl_losses.append(kl)

            recon = self._model.decoder(torch.cat([h_state, z], dim=-1))
            recon_losses.append(nn.functional.mse_loss(recon, obs_t))

        wm_loss = torch.stack(recon_losses).mean() + 0.1 * torch.stack(kl_losses).mean()

        # Imagination rollout from the last posterior state (detached).
        # NOTE(review): the encoder is fed all-zero observations while imagining.
        imagine_horizon = 15
        hidden_i = h_state.detach()
        latent_i = latents[-1].detach()

        imagined_values: list[Any] = []
        imagined_rewards: list[Any] = []

        for _ in range(imagine_horizon):
            actor_input = torch.cat([hidden_i, latent_i], dim=-1)
            actor_params = self._actor.net(actor_input)
            act_mean, act_log_std = actor_params.chunk(2, dim=-1)
            act_i = torch.tanh(
                act_mean + act_log_std.clamp(-5, 2).exp() * torch.randn_like(act_mean)
            )

            enc_i = self._model.encoder(torch.zeros(b_size, obs_d, device=device))
            gru_in = torch.cat([enc_i, act_i], dim=-1)
            hidden_i = self._model.gru(gru_in, hidden_i)

            prior_p = self._model.prior(hidden_i)
            latent_i, _ = prior_p.chunk(2, dim=-1)

            critic_in = torch.cat([hidden_i, latent_i], dim=-1)
            imagined_values.append(self._critic.net(critic_in))
            imagined_rewards.append(
                self._model.reward_head(torch.cat([hidden_i, latent_i], dim=-1))
            )

        # λ-returns (simplified): R = r + γ·((1 − λ)·v + λ·R), seeded with the
        # detached value of the final imagined state. The previous code
        # weighted the terms the other way round (0.95·v + 0.05·R), which is
        # effectively λ = 0.05.
        returns = imagined_values[-1].detach()
        for v, r in zip(reversed(imagined_values[:-1]), reversed(imagined_rewards[:-1])):
            returns = r + discount * ((1.0 - lambda_) * v + lambda_ * returns)

        # NOTE(review): actor_loss also back-propagates into the critic and the
        # world model (reward head, GRU, prior). fit() optimises all three
        # networks with a single Adam instance, so those gradients are applied.
        actor_loss = -returns.mean()

        critic_in = torch.cat([h_state.detach(), latents[-1].detach()], dim=-1)
        critic_val = self._critic.net(critic_in)
        critic_loss = nn.functional.mse_loss(critic_val, returns.detach())

        # The soft-invariant penalty is a constant offset (no gradient).
        total_loss = wm_loss + actor_loss + critic_loss + self._soft_penalty_per_step

    return total_loss
|
|
659
|
+
|
|
660
|
+
def fit(
|
|
661
|
+
self,
|
|
662
|
+
trajectories: list[dict[str, Any]] | TrajectoryBatch | TrajectoryBuffer,
|
|
663
|
+
steps: int,
|
|
664
|
+
checkpoint_interval_steps: int = 1000,
|
|
665
|
+
debug_hooks: bool = False,
|
|
666
|
+
checkpoint_dir: str = "physlink_checkpoints",
|
|
667
|
+
) -> "AdaptationRun":
|
|
668
|
+
"""Run the DreamerV3 adaptation loop with a live progress bar.
|
|
669
|
+
|
|
670
|
+
Adapts the DreamerV3 world model to the provided trajectory data over
|
|
671
|
+
``steps`` gradient updates. Displays a rich progress bar in Colab output
|
|
672
|
+
with step count, ETA, prediction health (OK/ANOMALY), and throughput.
|
|
673
|
+
|
|
674
|
+
Calling fit() multiple times is safe: each call resets optimizer state
|
|
675
|
+
and training history for a fresh run (NFR-09 idempotence).
|
|
676
|
+
|
|
677
|
+
Args:
|
|
678
|
+
trajectories: Trajectory dataset. ``list[dict]`` and ``TrajectoryBuffer``
|
|
679
|
+
are silently converted to ``TrajectoryBatch``. Each dict must contain
|
|
680
|
+
at minimum "obs" and "action" keys with numpy-compatible values.
|
|
681
|
+
steps: Total gradient steps to run. Must be > 0.
|
|
682
|
+
checkpoint_interval_steps: Interval (in steps) between checkpoint
|
|
683
|
+
saves. A checkpoint file is written every this many steps. Must
|
|
684
|
+
be > 0.
|
|
685
|
+
debug_hooks: When True, displays a debug panel alongside the progress
|
|
686
|
+
bar showing pipeline stage statuses (data_loading, world_model_update,
|
|
687
|
+
actor_update, critic_update). Each stage shows OK or a diagnostic
|
|
688
|
+
status. Defaults to False (opt-in, not default).
|
|
689
|
+
checkpoint_dir: Directory where checkpoint files are written. Defaults
|
|
690
|
+
to "physlink_checkpoints" relative to the current working directory.
|
|
691
|
+
|
|
692
|
+
Returns:
|
|
693
|
+
AdaptationRun capturing config, step count, checkpoint paths, and elapsed time.
|
|
694
|
+
|
|
695
|
+
Raises:
|
|
696
|
+
ValidationError: If steps <= 0 or checkpoint_interval_steps <= 0.
|
|
697
|
+
|
|
698
|
+
Example:
|
|
699
|
+
>>> from physlink import DreamerV3Adapter, ObservationSpace, ActionSpace
|
|
700
|
+
>>> obs = ObservationSpace.from_proprioception(joints=7)
|
|
701
|
+
>>> act = ActionSpace.continuous(dims=7, bounds=[(-1.0, 1.0)] * 7)
|
|
702
|
+
>>> adapter = DreamerV3Adapter(obs, act)
|
|
703
|
+
>>> trajectories = [{"obs": [0.1] * 7, "action": [0.0] * 7}] * 100
|
|
704
|
+
>>> run = adapter.fit(trajectories, steps=10, debug_hooks=True)
|
|
705
|
+
"""
|
|
706
|
+
import time
|
|
707
|
+
|
|
708
|
+
from physlink.core._types import AdaptationConfig, AdaptationRun
|
|
709
|
+
from physlink.core.exceptions import ValidationError
|
|
710
|
+
|
|
711
|
+
if isinstance(steps, bool) or not isinstance(steps, int) or steps <= 0:
|
|
712
|
+
raise ValidationError(
|
|
713
|
+
f"DreamerV3Adapter.fit: invalid steps.\n"
|
|
714
|
+
f" Got: steps={steps}\n"
|
|
715
|
+
f" Expected: steps > 0\n"
|
|
716
|
+
f" Fix: provide a positive integer, e.g. steps=10000."
|
|
717
|
+
)
|
|
718
|
+
if (
|
|
719
|
+
isinstance(checkpoint_interval_steps, bool)
|
|
720
|
+
or not isinstance(checkpoint_interval_steps, int)
|
|
721
|
+
or checkpoint_interval_steps <= 0
|
|
722
|
+
):
|
|
723
|
+
raise ValidationError(
|
|
724
|
+
f"DreamerV3Adapter.fit: invalid checkpoint_interval_steps.\n"
|
|
725
|
+
f" Got: checkpoint_interval_steps={checkpoint_interval_steps}\n"
|
|
726
|
+
f" Expected: checkpoint_interval_steps > 0\n"
|
|
727
|
+
f" Fix: provide a positive integer, e.g. checkpoint_interval_steps=1000."
|
|
728
|
+
)
|
|
729
|
+
|
|
730
|
+
self._reset_training_state()
|
|
731
|
+
|
|
732
|
+
if isinstance(trajectories, TrajectoryBuffer):
|
|
733
|
+
trajectories = trajectories.to_batch()
|
|
734
|
+
if isinstance(trajectories, list):
|
|
735
|
+
trajectories = TrajectoryBatch.from_list(trajectories)
|
|
736
|
+
|
|
737
|
+
trajectories = self._apply_invariants(trajectories)
|
|
738
|
+
|
|
739
|
+
import torch
|
|
740
|
+
|
|
741
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
742
|
+
|
|
743
|
+
if self._model is None:
|
|
744
|
+
self._initialize_model(device)
|
|
745
|
+
|
|
746
|
+
# Pre-process trajectory data to tensors once
|
|
747
|
+
raw_data = trajectories.data
|
|
748
|
+
obs_all = torch.tensor(
|
|
749
|
+
[d["obs"] for d in raw_data], dtype=torch.float32, device=device
|
|
750
|
+
)
|
|
751
|
+
act_raw = torch.tensor(
|
|
752
|
+
[d["action"] for d in raw_data], dtype=torch.float32, device=device
|
|
753
|
+
)
|
|
754
|
+
|
|
755
|
+
# Align action dims to model's act_dims via zero-padding or truncation
|
|
756
|
+
model_act_dims = self.act_space.dims
|
|
757
|
+
if act_raw.shape[-1] < model_act_dims:
|
|
758
|
+
pad = torch.zeros(
|
|
759
|
+
act_raw.shape[0], model_act_dims - act_raw.shape[-1], device=device
|
|
760
|
+
)
|
|
761
|
+
act_all = torch.cat([act_raw, pad], dim=-1)
|
|
762
|
+
elif act_raw.shape[-1] > model_act_dims:
|
|
763
|
+
act_all = act_raw[:, :model_act_dims]
|
|
764
|
+
else:
|
|
765
|
+
act_all = act_raw
|
|
766
|
+
|
|
767
|
+
tensor_batch = (obs_all, act_all)
|
|
768
|
+
|
|
769
|
+
all_params = (
|
|
770
|
+
list(self._model.parameters())
|
|
771
|
+
+ list(self._actor.parameters())
|
|
772
|
+
+ list(self._critic.parameters())
|
|
773
|
+
)
|
|
774
|
+
optimizer = torch.optim.Adam(all_params, lr=3e-4)
|
|
775
|
+
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))
|
|
776
|
+
|
|
777
|
+
_fit_start_time = time.monotonic()
|
|
778
|
+
_run_checkpoint_paths: list[str] = []
|
|
779
|
+
|
|
780
|
+
if debug_hooks:
|
|
781
|
+
debug_panel = _DebugPanel()
|
|
782
|
+
with _build_debug_layout(steps, debug_panel) as (progress, task_id):
|
|
783
|
+
for step_idx in range(steps):
|
|
784
|
+
stage_statuses = {name: "OK" for name in _STAGE_NAMES}
|
|
785
|
+
optimizer.zero_grad(set_to_none=True)
|
|
786
|
+
try:
|
|
787
|
+
loss = self._training_step(tensor_batch, device)
|
|
788
|
+
except Exception as exc:
|
|
789
|
+
for name in ("world_model_update", "actor_update", "critic_update"):
|
|
790
|
+
stage_statuses[name] = type(exc).__name__
|
|
791
|
+
debug_panel.update_all(stage_statuses)
|
|
792
|
+
raise
|
|
793
|
+
scaler.scale(loss).backward()
|
|
794
|
+
scaler.unscale_(optimizer)
|
|
795
|
+
torch.nn.utils.clip_grad_norm_(all_params, max_norm=100.0)
|
|
796
|
+
scaler.step(optimizer)
|
|
797
|
+
scaler.update()
|
|
798
|
+
debug_panel.update_all(stage_statuses)
|
|
799
|
+
progress.update(
|
|
800
|
+
task_id, advance=1, health=self._compute_health(loss.item())
|
|
801
|
+
)
|
|
802
|
+
completed = step_idx + 1
|
|
803
|
+
if completed % checkpoint_interval_steps == 0:
|
|
804
|
+
_ckpt = _save_checkpoint(
|
|
805
|
+
self._model, self._actor, self._critic,
|
|
806
|
+
completed, checkpoint_dir,
|
|
807
|
+
)
|
|
808
|
+
self._last_checkpoint_path = _ckpt
|
|
809
|
+
_run_checkpoint_paths.append(_ckpt)
|
|
810
|
+
else:
|
|
811
|
+
with _build_progress_bar(steps) as (progress, task_id):
|
|
812
|
+
for step_idx in range(steps):
|
|
813
|
+
optimizer.zero_grad(set_to_none=True)
|
|
814
|
+
loss = self._training_step(tensor_batch, device)
|
|
815
|
+
scaler.scale(loss).backward()
|
|
816
|
+
scaler.unscale_(optimizer)
|
|
817
|
+
torch.nn.utils.clip_grad_norm_(all_params, max_norm=100.0)
|
|
818
|
+
scaler.step(optimizer)
|
|
819
|
+
scaler.update()
|
|
820
|
+
progress.update(
|
|
821
|
+
task_id, advance=1, health=self._compute_health(loss.item())
|
|
822
|
+
)
|
|
823
|
+
completed = step_idx + 1
|
|
824
|
+
if completed % checkpoint_interval_steps == 0:
|
|
825
|
+
_ckpt = _save_checkpoint(
|
|
826
|
+
self._model, self._actor, self._critic,
|
|
827
|
+
completed, checkpoint_dir,
|
|
828
|
+
)
|
|
829
|
+
self._last_checkpoint_path = _ckpt
|
|
830
|
+
_run_checkpoint_paths.append(_ckpt)
|
|
831
|
+
|
|
832
|
+
self._fit_elapsed_seconds = time.monotonic() - _fit_start_time
|
|
833
|
+
|
|
834
|
+
_config = AdaptationConfig(
|
|
835
|
+
obs_space=self.obs_space,
|
|
836
|
+
act_space=self.act_space,
|
|
837
|
+
steps=steps,
|
|
838
|
+
checkpoint_interval_steps=checkpoint_interval_steps,
|
|
839
|
+
checkpoint_dir=checkpoint_dir,
|
|
840
|
+
)
|
|
841
|
+
_run = AdaptationRun(
|
|
842
|
+
config=_config,
|
|
843
|
+
current_step=completed,
|
|
844
|
+
checkpoint_paths=_run_checkpoint_paths,
|
|
845
|
+
elapsed_seconds=self._fit_elapsed_seconds or 0.0,
|
|
846
|
+
)
|
|
847
|
+
return _run
|
|
848
|
+
|
|
849
|
+
def explain(self) -> dict[str, Any]:
    """Describe this adapter's space configuration as a plain dictionary.

    The observation and action spaces each contribute their own
    ``explain()`` output, nested under their respective keys.

    Returns:
        A dict with the keys ``type`` (always the string
        ``'DreamerV3Adapter'``), ``obs_space`` and ``act_space``, in that
        order.

    Example:
        >>> adapter = DreamerV3Adapter(obs, act)
        >>> adapter.explain()["type"]
        'DreamerV3Adapter'
    """
    summary: dict[str, Any] = {"type": "DreamerV3Adapter"}
    # Observation space is described first, then the action space, so the
    # key order of the result matches the constructor argument order.
    summary["obs_space"] = self.obs_space.explain()
    summary["act_space"] = self.act_space.explain()
    return summary
|
|
866
|
+
|
|
867
|
+
def visualize(
    self,
    trajectories: list[dict[str, Any]] | TrajectoryBatch | TrajectoryBuffer,
    output_path: str = "physlink_triptych.gif",
) -> str:
    """Produce a triptych GIF comparing Imagination, Real, and Difference panels.

    Runs a single no-grad pass of the world model (encoder -> GRU ->
    posterior -> decoder) over the leading steps of ``trajectories`` (at most
    ``_VIZ_SEQ_LEN`` step dicts) to produce reconstructed ("Imagination")
    observations. Those and the real observations are then passed to
    ``render_triptych``, which renders the 3-panel GIF.

    Afterwards prints a "Friday afternoon window" callout comparing the
    elapsed time of the last ``fit()`` to the documented from-scratch
    baseline, or a hint to call ``fit()`` first when no fit time is recorded.

    Args:
        trajectories: Trajectory dataset to visualize. ``list[dict]`` and
            ``TrajectoryBuffer`` are silently converted to
            ``TrajectoryBatch``. Only the "obs" key of each step dict is
            read; recorded actions are ignored and zeros are fed instead.
        output_path: File path for the output GIF. Defaults to
            "physlink_triptych.gif" in the current working directory.

    Returns:
        Path to the saved GIF as returned by ``render_triptych``. It is also
        remembered on the adapter so that ``export()`` can copy it.

    Raises:
        AdapterError: If the model has not been initialized via fit() or
            load_checkpoint().
        ValidationError: If ``trajectories`` contains no steps.

    Example:
        >>> adapter = DreamerV3Adapter(obs, act)
        >>> adapter.fit(trajectories, steps=1000)
        >>> path = adapter.visualize(trajectories)
        >>> print(path)  # path to physlink_triptych.gif
    """
    from physlink.core.exceptions import AdapterError, ValidationError

    if self._model is None:
        raise AdapterError(
            "DreamerV3Adapter.visualize: model not initialized.\n"
            " Got: self._model is None\n"
            " Expected: model weights loaded via fit() or load_checkpoint()\n"
            " Fix: call adapter.fit(trajectories, steps=N) before visualize()."
        )

    if isinstance(trajectories, TrajectoryBuffer):
        trajectories = trajectories.to_batch()
    if isinstance(trajectories, list):
        trajectories = TrajectoryBatch.from_list(trajectories)

    obs_raw = [d["obs"] for d in trajectories.data[:_VIZ_SEQ_LEN]]
    if not obs_raw:
        # Without this guard the per-step loop below never runs and
        # np.stack([]) fails with an opaque "need at least one array" error.
        raise ValidationError(
            "DreamerV3Adapter.visualize: no trajectory steps to visualize.\n"
            " Got: an empty trajectory dataset\n"
            " Expected: at least one step dict with an \"obs\" key\n"
            " Fix: pass the same non-empty trajectories that were used for fit()."
        )

    import numpy as np
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self._model.to(device)

    # presumably (T, obs_dims): one row per step dict
    obs_seq = torch.tensor(obs_raw, dtype=torch.float32, device=device)

    with torch.no_grad():
        h_state = torch.zeros(1, self._model.gru.hidden_size, device=device)
        imagination_frames = []
        for t in range(obs_seq.shape[0]):
            obs_t = obs_seq[t : t + 1]  # keep a leading batch dim of 1
            # Recorded actions are not used; the GRU is driven with zeros.
            act_t = torch.zeros(1, self.act_space.dims, device=device)
            encoded = self._model.encoder(obs_t)
            gru_input = torch.cat([encoded, act_t], dim=-1)
            h_state = self._model.gru(gru_input, h_state)
            post_params = self._model.posterior(torch.cat([h_state, encoded], dim=-1))
            # The posterior head's output is split in half; only the first
            # half (used as the mean) feeds the decoder.
            post_mean, _ = post_params.chunk(2, dim=-1)
            recon = self._model.decoder(torch.cat([h_state, post_mean], dim=-1))
            imagination_frames.append(recon.squeeze(0).cpu().numpy())

    imagination_np = np.stack(imagination_frames)  # one row per visualized step
    real_np = obs_seq.cpu().numpy()

    from physlink.utils.visualization import (
        _FROM_SCRATCH_BASELINE_LABEL,
        _FROM_SCRATCH_BASELINE_SECONDS,
        render_triptych,
    )

    gif_path = render_triptych(imagination_np, real_np, output_path)
    # Remembered so export() can copy the GIF into its artifact bundle.
    self._triptych_path = gif_path
    print(f"[physlink] Triptych saved: {gif_path}")

    elapsed = self._fit_elapsed_seconds
    if elapsed is not None:
        elapsed_min = elapsed / 60
        baseline_hours = _FROM_SCRATCH_BASELINE_SECONDS / 3600
        # Clamp to 1s so a near-instant fit cannot divide by ~0.
        speedup = _FROM_SCRATCH_BASELINE_SECONDS / max(elapsed, 1.0)
        print(
            f"[physlink] ⏱ Adaptation complete in {elapsed_min:.1f} min\n"
            f" vs. from-scratch baseline ({_FROM_SCRATCH_BASELINE_LABEL}): "
            f"{baseline_hours:.0f}h\n"
            f" Speedup: ~{speedup:.0f}x"
        )
    else:
        print(
            "[physlink] ⏱ Adaptation time not available "
            "(call fit() before visualize() to see the Friday afternoon window callout)"
        )

    return gif_path
|
|
971
|
+
|
|
972
|
+
def export(self, path: str) -> dict[str, str]:
    """Export a complete artifact bundle to the specified directory.

    Copies the triptych GIF produced by ``visualize()`` to
    ``<path>/triptych.gif``. Writes the observation/action space descriptions
    and the last checkpoint path to ``<path>/config.yaml``, and writes a
    human-readable ``<path>/summary.txt``. Prints the exported locations, then
    calls ``_share_panel`` with the absolute export directory. According to
    the original docs, ``_share_panel`` copies the Colab notebook URL to the
    clipboard in Colab and falls back gracefully elsewhere.

    Args:
        path: Directory path for the exported artifacts. Created if it does
            not exist. Existing files with the same names are overwritten.
            The triptych may already live at ``<path>/triptych.gif``; in that
            case it is left in place rather than copied onto itself.

    Returns:
        dict with keys ``gif``, ``config``, ``summary`` mapping to the
        absolute paths of the respective exported files.

    Raises:
        AdapterError: If ``visualize()`` has not been called (no triptych
            available to export).

    Example:
        >>> adapter.fit(trajectories, steps=1000)
        >>> adapter.visualize(trajectories)
        >>> artifacts = adapter.export("./physlink_export")
        >>> artifacts["config"]  # absolute path to config.yaml
        '/abs/path/physlink_export/config.yaml'
    """
    import datetime
    import os
    import shutil

    import yaml

    from physlink.core.exceptions import AdapterError

    if self._triptych_path is None:
        raise AdapterError(
            "DreamerV3Adapter.export: no triptych available.\n"
            " Got: self._triptych_path is None\n"
            " Expected: visualize() called before export()\n"
            " Fix: call adapter.visualize(trajectories) before adapter.export(path)."
        )

    os.makedirs(path, exist_ok=True)
    export_dir = os.path.abspath(path)

    gif_dest = os.path.join(path, "triptych.gif")
    try:
        shutil.copy2(self._triptych_path, gif_dest)
    except shutil.SameFileError:
        # visualize(output_path=...) already wrote the GIF to this exact
        # location, so the artifact is in place and there is nothing to copy.
        pass

    config = {
        "obs_space": self.obs_space.explain(),
        "act_space": self.act_space.explain(),
        "checkpoint_path": self._last_checkpoint_path,
    }
    yaml_path = os.path.join(path, "config.yaml")
    with open(yaml_path, "w", encoding="utf-8") as f:
        # NOTE(review): yaml.dump (not safe_dump) will emit python-specific tags
        # if explain() ever returns tuples or custom objects; confirm the
        # space explain() outputs are plain dict/list/scalar data.
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

    elapsed_seconds = self._fit_elapsed_seconds
    elapsed_str = (
        f"{elapsed_seconds / 60.0:.1f} min" if elapsed_seconds is not None else "N/A"
    )
    timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
    summary_lines = [
        "physlink Export Summary",
        "=======================",
        "Adapter: DreamerV3Adapter",
        f"obs_dims: {self.obs_space.dims}",
        f"act_dims: {self.act_space.dims}",
        f"Fit elapsed: {elapsed_str}",
        f"Triptych GIF: {os.path.abspath(self._triptych_path)}",
        f"Checkpoint: {self._last_checkpoint_path or 'N/A'}",
        f"Exported at: {timestamp}",
    ]
    summary_path = os.path.join(path, "summary.txt")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("\n".join(summary_lines) + "\n")

    gif_abs = os.path.abspath(gif_dest)
    yaml_abs = os.path.abspath(yaml_path)
    summary_abs = os.path.abspath(summary_path)

    print(f"[physlink] Export complete: {export_dir}")
    print(f"[physlink] GIF: {gif_abs}")
    print(f"[physlink] Config: {yaml_abs}")
    print(f"[physlink] Summary: {summary_abs}")

    _share_panel(export_dir)

    return {
        "gif": gif_abs,
        "config": yaml_abs,
        "summary": summary_abs,
    }
|
|
1062
|
+
|
|
1063
|
+
def __repr__(self) -> str:
    """Return a compact developer-facing summary of the space dimensions."""
    obs_dims = self.obs_space.dims
    act_dims = self.act_space.dims
    return f"DreamerV3Adapter(obs_dims={obs_dims}, act_dims={act_dims})"
|