physlink 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1068 @@
1
+ """DreamerV3 adapter for PhysLink."""
2
+
3
+ import contextlib
4
+ from collections.abc import Generator
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from physlink.core._types import TrajectoryBatch, TrajectoryBuffer
8
+
9
+ if TYPE_CHECKING:
10
+ from physlink.core._types import AdaptationRun
11
+ from physlink.core.adapter import BaseAdapter
12
+ from physlink.core.exceptions import ConfigurationError
13
+ from physlink.core.spaces import ActionSpace, ObservationSpace
14
+ from physlink.core.validation import ComplianceReport
15
+
16
# Minimum space sizes enforced by DreamerV3Adapter.__init__.
MIN_OBS_DIMS: int = 4  # DreamerV3 requires >= 4 observation dimensions
MIN_ACT_DIMS: int = 1  # at least 1 action dimension required

# Training-health heuristic (DreamerV3Adapter._compute_health): the mean of
# the last _HEALTH_WINDOW losses is compared with a baseline taken from the
# first _HEALTH_BASELINE_STEPS losses; exceeding _ANOMALY_MULTIPLIER times the
# baseline is reported as an anomaly.
_HEALTH_WINDOW: int = 50
_HEALTH_BASELINE_STEPS: int = 10
_ANOMALY_MULTIPLIER: float = 2.0

# Pipeline stages listed in the debug panel, in display order.
_STAGE_NAMES: tuple[str, ...] = tuple(
    "data_loading world_model_update actor_update critic_update".split()
)

_VIZ_SEQ_LEN: int = 50  # max steps used for triptych inference
31
+
32
+
33
class _DebugPanel:
    """Rich renderable listing every pipeline stage with its latest status.

    Each name in ``_STAGE_NAMES`` starts as ``"waiting..."``. When rendered,
    ``"OK"`` is shown in bold green, ``"waiting..."`` dimmed, and any other
    status string (fit() passes exception class names) in bold red.
    """

    def __init__(self) -> None:
        self.stages: dict[str, str] = dict.fromkeys(_STAGE_NAMES, "waiting...")

    def update_all(self, statuses: dict[str, str]) -> None:
        """Merge *statuses* into the stage map (keys not yet present are added)."""
        self.stages.update(statuses)

    def __rich__(self) -> Any:  # noqa: ANN401
        """Build the ``rich`` table displayed next to the progress bar."""
        from rich.table import Table

        # Markup for the two well-known states; everything else is an error.
        known_cells = {
            "OK": "[bold green]OK[/bold green]",
            "waiting...": "[dim]waiting...[/dim]",
        }

        table = Table(
            title="[dim]Debug Hooks Panel[/dim]",
            show_header=True,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Stage", style="dim", no_wrap=True)
        table.add_column("Status", no_wrap=True)

        for stage, state in self.stages.items():
            cell = known_cells.get(state, f"[bold red]{state}[/bold red]")
            table.add_row(stage.replace("_", " "), cell)
        return table
61
+
62
+
63
def _new_progress(steps: int) -> tuple[Any, Any]:
    """Create the (not yet started) adaptation ``Progress`` and its task.

    Shared by ``_build_progress_bar`` and ``_build_debug_layout`` so both
    displays render identical columns: spinner, description, bar, M/N count,
    time remaining, steps/s throughput and the task's ``health`` field.

    Args:
        steps: Total number of steps; used as the task's ``total``.

    Returns:
        ``(progress, task_id)``; the task is created with ``health="OK"``.
    """
    from rich.progress import (
        BarColumn,
        MofNCompleteColumn,
        Progress,
        ProgressColumn,
        SpinnerColumn,
        TextColumn,
        TimeRemainingColumn,
    )
    from rich.text import Text

    class _StepsPerSecColumn(ProgressColumn):
        # Throughput column; task.speed may be None when no estimate exists yet.
        def render(self, task: Any) -> Text:  # noqa: ANN401
            if task.speed is None:
                return Text("? step/s", style="dim")
            return Text(f"{task.speed:.1f} step/s", style="cyan")

    class _HealthColumn(ProgressColumn):
        # Renders the custom "health" task field: green for "OK", red otherwise.
        def render(self, task: Any) -> Text:  # noqa: ANN401
            health = task.fields.get("health", "OK")
            style = "bold green" if health == "OK" else "bold red"
            return Text(health, style=style)

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        TextColumn("•"),
        _StepsPerSecColumn(),
        TextColumn("•"),
        _HealthColumn(),
    )
    task_id = progress.add_task(
        "[cyan]DreamerV3 adaptation",
        total=steps,
        health="OK",
    )
    return progress, task_id


@contextlib.contextmanager
def _build_progress_bar(
    steps: int,
) -> Generator[tuple[Any, Any], None, None]:
    """Context manager yielding (progress, task_id) for the adaptation loop.

    The progress display is started on entry and stopped on exit.
    """
    progress, task_id = _new_progress(steps)
    with progress:
        yield progress, task_id


@contextlib.contextmanager
def _build_debug_layout(
    steps: int,
    panel: _DebugPanel,
) -> Generator[tuple[Any, Any], None, None]:
    """Context manager yielding (progress, task_id) with debug panel alongside.

    The progress bar and *panel* are rendered together in a single
    ``rich.live.Live`` display (4 refreshes per second) for the duration of
    the ``with`` body.
    """
    from rich.console import Group
    from rich.live import Live

    progress, task_id = _new_progress(steps)
    with Live(Group(progress, panel), refresh_per_second=4):
        yield progress, task_id
162
+
163
+
164
def _save_checkpoint(
    model: Any,  # noqa: ANN401
    actor: Any,  # noqa: ANN401
    critic: Any,  # noqa: ANN401
    step: int,
    checkpoint_dir: str,
) -> str:
    """Save world-model, actor and critic weights into one safetensors file.

    The three ``state_dict()``s are flattened into a single mapping whose keys
    are prefixed with ``model.``, ``actor.`` and ``critic.``. The file header
    metadata records ``physlink_version``, ``adapter_class``, a UTC
    ``timestamp`` and ``checkpoint_step`` (all as strings).

    The data is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. An interrupted save therefore cannot leave a truncated file
    at the final checkpoint path.

    Args:
        model: World model module (anything exposing ``state_dict()``).
        actor: Actor module.
        critic: Critic module.
        step: Training step; used in the filename and the metadata.
        checkpoint_dir: Output directory; created if it does not exist.

    Returns:
        ``os.path.join(checkpoint_dir, f"checkpoint_step_{step}.safetensors")``.
    """
    import datetime
    import os

    from safetensors.torch import save_file

    import physlink

    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_step_{step}.safetensors")

    tensors: dict[str, Any] = {}
    for prefix, module in (("model", model), ("actor", actor), ("critic", critic)):
        tensors.update({f"{prefix}.{k}": v for k, v in module.state_dict().items()})

    metadata = {
        "physlink_version": physlink.__version__,
        "adapter_class": "DreamerV3Adapter",
        "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "checkpoint_step": str(step),
    }

    tmp_path = f"{path}.tmp"
    try:
        save_file(tensors, tmp_path, metadata=metadata)
        # Atomic rename: readers see either the old file or the complete new one.
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise

    print(f"[physlink] Checkpoint saved: {os.path.abspath(path)}")
    return path
194
+
195
+
196
def _check_checkpoint_metadata(path: str) -> dict[str, str]:
    """Validate a checkpoint's header metadata before its weights are loaded.

    Opens *path* with ``safetensors.safe_open`` on CPU and inspects only
    ``f.metadata()``. The metadata must contain ``physlink_version``, whose
    major.minor must equal that of the installed ``physlink.__version__``.

    Args:
        path: Path to a ``.safetensors`` checkpoint written by ``_save_checkpoint``.

    Returns:
        The checkpoint's metadata dict.

    Raises:
        CheckpointCorruptError: If the file cannot be opened as safetensors, or
            its metadata is missing or lacks ``physlink_version``.
        CheckpointVersionError: If the checkpoint's major.minor version differs
            from the installed physlink version.
    """
    from safetensors import safe_open

    import physlink
    from physlink.core.exceptions import CheckpointCorruptError, CheckpointVersionError

    try:
        with safe_open(path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
    except Exception as exc:
        # Chain explicitly so the underlying I/O / parse error stays the cause.
        raise CheckpointCorruptError(
            f"Cannot open checkpoint: {path}\n"
            f" Got: {type(exc).__name__}: {exc}\n"
            " Expected: valid safetensors file\n"
            " Fix: re-run adapter.fit() to generate a fresh checkpoint."
        ) from exc

    if metadata is None or "physlink_version" not in metadata:
        raise CheckpointCorruptError(
            f"Checkpoint metadata missing or incomplete: {path}\n"
            f" Got: metadata={metadata!r}\n"
            " Expected: metadata dict with key 'physlink_version'\n"
            " Fix: re-run adapter.fit() to generate a fresh checkpoint."
        )

    checkpoint_version = metadata["physlink_version"]
    current_version = physlink.__version__
    # Compatibility is decided on the first two dot-separated components only.
    if checkpoint_version.split(".")[:2] != current_version.split(".")[:2]:
        raise CheckpointVersionError(
            f"Checkpoint version incompatible: {path}\n"
            f" Got: checkpoint saved with physlink=={checkpoint_version}\n"
            f" Expected: compatible version (same major.minor as {current_version})\n"
            " Fix: re-run adapter.fit() to generate a fresh checkpoint.",
            checkpoint_version=checkpoint_version,
            current_version=current_version,
        )

    return metadata
236
+
237
+
238
def _share_panel(export_path: str) -> None:
    """Trigger the Colab share panel: copy notebook URL to clipboard.

    When ``google.colab`` can be imported, runs a small Javascript snippet via
    IPython that copies the current page URL to the clipboard, then prints the
    export path. Otherwise prints a fallback message pointing at the export
    directory. Any exception raised while doing so is reported by its class
    name only and never propagates.

    Args:
        export_path: Absolute path to the export directory. Shown in fallback
            message so collaborators know where to find the artifacts.

    Example:
        >>> _share_panel("./physlink_export")
        [physlink] Share panel: URL copy is only available in Google Colab.
        ...
    """
    try:
        import google.colab  # noqa: F401
    except ImportError:
        running_in_colab = False
    else:
        running_in_colab = True

    try:
        if not running_in_colab:
            print(
                "[physlink] Share panel: URL copy is only available in Google Colab.\n"
                f" To share your results, send the export directory: {export_path}"
            )
            return

        from IPython.display import Javascript, display

        copy_url_js = (
            "navigator.clipboard.writeText(window.location.href)"
            ".then(() => console.log('[physlink] Notebook URL copied.'));"
        )
        display(Javascript(copy_url_js))
        print("[physlink] Share panel: notebook URL copied to clipboard.")
        print(f"[physlink] Export path for collaborators: {export_path}")
    except Exception as exc:
        # Best effort: sharing must never break the caller.
        print(f"[physlink] Share panel unavailable: {type(exc).__name__}")
275
+
276
+
277
+ class DreamerV3Adapter(BaseAdapter):
278
+ """DreamerV3 adapter for physical simulation reinforcement learning.
279
+
280
+ Validates space compatibility at construction time. Training, visualization,
281
+ and export are deferred to fit() / visualize() / export() respectively.
282
+ No model weights are loaded and no GPU is required at construction.
283
+
284
+ Args:
285
+ obs_space: Observation space with dims >= 4.
286
+ act_space: Action space with dims >= 1.
287
+
288
+ Raises:
289
+ ConfigurationError: If obs_space.dims < 4 or act_space.dims < 1.
290
+
291
+ Example:
292
+ >>> from physlink import DreamerV3Adapter, ObservationSpace, ActionSpace
293
+ >>> obs = ObservationSpace.from_proprioception(joints=7, include_velocity=True)
294
+ >>> act = ActionSpace.continuous(dims=7, bounds=[(-1.0, 1.0)] * 7)
295
+ >>> adapter = DreamerV3Adapter(obs, act)
296
+ >>> adapter.obs_space.dims
297
+ 14
298
+ """
299
+
300
+ def __init__(self, obs_space: ObservationSpace, act_space: ActionSpace) -> None:
301
+ if obs_space.dims < MIN_OBS_DIMS:
302
+ raise ConfigurationError(
303
+ f"DreamerV3Adapter: incompatible obs_space.\n"
304
+ f" Got: obs_space.dims={obs_space.dims}\n"
305
+ f" Expected: obs_space.dims >= {MIN_OBS_DIMS} (DreamerV3 minimum)\n"
306
+ f" Fix: construct ObservationSpace with joints >= {MIN_OBS_DIMS}, "
307
+ f"or use include_velocity=True to double the dimension count."
308
+ )
309
+ if act_space.dims < MIN_ACT_DIMS:
310
+ raise ConfigurationError(
311
+ f"DreamerV3Adapter: incompatible act_space.\n"
312
+ f" Got: act_space.dims={act_space.dims}\n"
313
+ f" Expected: act_space.dims >= {MIN_ACT_DIMS}\n"
314
+ f" Fix: construct ActionSpace with dims >= 1."
315
+ )
316
+ super().__init__(obs_space, act_space)
317
+ self._model: Any | None = None
318
+ self._actor: Any | None = None
319
+ self._critic: Any | None = None
320
+ self._loss_history: list[float] = []
321
+ self._baseline_loss: float | None = None
322
+ self._fit_elapsed_seconds: float | None = None
323
+ self._triptych_path: str | None = None
324
+ self._last_checkpoint_path: str | None = None
325
+ self._invariants: list = []
326
+ self._invariant_residuals: dict[str, list[float]] = {}
327
+ self._soft_penalty_per_step: float = 0.0
328
+
329
+ def _initialize_model(self, device: Any) -> None: # noqa: ANN401
330
+ import torch.nn as nn
331
+
332
+ obs_dims = self.obs_space.dims
333
+ act_dims = self.act_space.dims
334
+ hidden = 256
335
+ latent = 256
336
+
337
+ class _WorldModel(nn.Module):
338
+ def __init__(self) -> None:
339
+ super().__init__()
340
+ self.encoder = nn.Sequential(
341
+ nn.Linear(obs_dims, hidden), nn.ELU(),
342
+ nn.Linear(hidden, hidden), nn.ELU(),
343
+ )
344
+ self.gru = nn.GRUCell(hidden + act_dims, hidden)
345
+ self.posterior = nn.Sequential(
346
+ nn.Linear(hidden + hidden, hidden), nn.ELU(),
347
+ nn.Linear(hidden, latent * 2),
348
+ )
349
+ self.prior = nn.Sequential(
350
+ nn.Linear(hidden, hidden), nn.ELU(),
351
+ nn.Linear(hidden, latent * 2),
352
+ )
353
+ self.decoder = nn.Sequential(
354
+ nn.Linear(hidden + latent, hidden), nn.ELU(),
355
+ nn.Linear(hidden, obs_dims),
356
+ )
357
+ self.reward_head = nn.Sequential(
358
+ nn.Linear(hidden + latent, hidden), nn.ELU(),
359
+ nn.Linear(hidden, 1),
360
+ )
361
+
362
+ class _Actor(nn.Module):
363
+ def __init__(self) -> None:
364
+ super().__init__()
365
+ self.net = nn.Sequential(
366
+ nn.Linear(hidden + latent, hidden), nn.ELU(),
367
+ nn.Linear(hidden, hidden), nn.ELU(),
368
+ nn.Linear(hidden, act_dims * 2),
369
+ )
370
+
371
+ class _Critic(nn.Module):
372
+ def __init__(self) -> None:
373
+ super().__init__()
374
+ self.net = nn.Sequential(
375
+ nn.Linear(hidden + latent, hidden), nn.ELU(),
376
+ nn.Linear(hidden, hidden), nn.ELU(),
377
+ nn.Linear(hidden, 1),
378
+ )
379
+
380
+ self._model = _WorldModel().to(device)
381
+ self._actor = _Actor().to(device)
382
+ self._critic = _Critic().to(device)
383
+
384
+ def _reset_training_state(self) -> None:
385
+ """Reset all mutable training state for a fresh fit() run (NFR-09)."""
386
+ self._loss_history = []
387
+ self._baseline_loss = None
388
+ self._invariant_residuals = {}
389
+ self._soft_penalty_per_step = 0.0
390
+
391
+ def _apply_invariants(self, trajectories: TrajectoryBatch) -> TrajectoryBatch:
392
+ """Apply registered invariants: filter hard-mode violations, compute soft penalty."""
393
+ if not self._invariants:
394
+ return trajectories
395
+
396
+ from physlink.core.exceptions import ValidationError
397
+
398
+ data = trajectories.data
399
+ for inv in self._invariants:
400
+ self._invariant_residuals[inv.name] = []
401
+
402
+ hard_mask: list[bool] = [True] * len(data)
403
+
404
+ for inv in self._invariants:
405
+ for idx, traj in enumerate(data):
406
+ try:
407
+ residual = float(inv.fn(traj))
408
+ except Exception as exc:
409
+ print(
410
+ f"[physlink] Invariant '{inv.name}' failed on trajectory {idx}: "
411
+ f"{type(exc).__name__} — treating residual as 0.0"
412
+ )
413
+ residual = 0.0
414
+ self._invariant_residuals[inv.name].append(residual)
415
+
416
+ if inv.mode == "hard" and residual > inv.tolerance:
417
+ hard_mask[idx] = False
418
+ print(
419
+ f"[physlink] Invariant '{inv.name}' rejected trajectory {idx}: "
420
+ f"residual={residual:.4f} > tolerance={inv.tolerance}"
421
+ )
422
+
423
+ filtered = [d for d, keep in zip(data, hard_mask) if keep]
424
+ if not filtered:
425
+ raise ValidationError(
426
+ f"register_invariant (hard mode): all {len(data)} trajectories rejected.\n"
427
+ f" Got: 0 trajectories remaining after hard-mode invariant filtering\n"
428
+ f" Expected: at least 1 trajectory passing all hard-mode invariants\n"
429
+ f" Fix: lower tolerance, fix the invariant function, "
430
+ f"or switch to mode='soft'."
431
+ )
432
+
433
+ soft_surplus = 0.0
434
+ for inv in self._invariants:
435
+ if inv.mode == "soft":
436
+ for r in self._invariant_residuals[inv.name]:
437
+ if r > inv.tolerance:
438
+ soft_surplus += r - inv.tolerance
439
+ self._soft_penalty_per_step = soft_surplus / max(len(data), 1)
440
+
441
+ return TrajectoryBatch(data=filtered)
442
+
443
+ def compliance_report(self) -> ComplianceReport:
444
+ """Return a ComplianceReport summarizing invariant compliance from the last fit().
445
+
446
+ Reads ``_invariants`` and ``_invariant_residuals`` stored on the adapter.
447
+ Pure computation — no side effects, safe to call multiple times.
448
+
449
+ Returns:
450
+ ComplianceReport with per-invariant summary and violation details.
451
+ Empty report (no entries) if no invariants are registered.
452
+ Zero-trajectory report if fit() has not yet been called.
453
+
454
+ Example:
455
+ >>> register_invariant(adapter, "mass", fn, tolerance=0.01)
456
+ >>> adapter.fit(trajectories, steps=100)
457
+ >>> report = adapter.compliance_report()
458
+ >>> print(report.summary())
459
+ mass: PASS (max_residual=0.0042, threshold=0.0100, violations=0/10)
460
+ """
461
+ stats: list[dict[str, Any]] = []
462
+ violation_list: list[dict[str, Any]] = []
463
+
464
+ for inv in self._invariants:
465
+ residuals = self._invariant_residuals.get(inv.name, [])
466
+ max_residual = max(residuals) if residuals else 0.0
467
+ violation_count = sum(1 for r in residuals if r > inv.tolerance)
468
+ total = len(residuals)
469
+
470
+ stats.append({
471
+ "name": inv.name,
472
+ "max_residual": max_residual,
473
+ "threshold": inv.tolerance,
474
+ "violation_count": violation_count,
475
+ "total": total,
476
+ })
477
+
478
+ for idx, residual in enumerate(residuals):
479
+ if residual > inv.tolerance:
480
+ violation_list.append({
481
+ "invariant_name": inv.name,
482
+ "trajectory_idx": idx,
483
+ "residual": residual,
484
+ "possible_cause": (
485
+ f"Residual {residual:.4f} exceeds tolerance {inv.tolerance:.4f}."
486
+ ),
487
+ })
488
+
489
+ return ComplianceReport(
490
+ _stats=stats,
491
+ _violation_list=violation_list,
492
+ _residuals_by_invariant={
493
+ inv.name: list(self._invariant_residuals.get(inv.name, []))
494
+ for inv in self._invariants
495
+ },
496
+ )
497
+
498
+ def load_checkpoint(self, path: str) -> None:
499
+ """Load model weights from a safetensors checkpoint.
500
+
501
+ Reads checkpoint metadata before loading weights for early detection
502
+ of version incompatibility or file corruption.
503
+
504
+ Args:
505
+ path: Path to the .safetensors checkpoint file to load.
506
+
507
+ Raises:
508
+ CheckpointCorruptError: If the file is malformed, unreadable, or
509
+ missing required metadata.
510
+ CheckpointVersionError: If physlink_version in the checkpoint
511
+ metadata is incompatible with the installed version
512
+ (different major.minor component).
513
+
514
+ Example:
515
+ >>> adapter = DreamerV3Adapter(obs, act)
516
+ >>> adapter.load_checkpoint("./physlink_checkpoints/checkpoint_step_1000.safetensors")
517
+ """
518
+ _check_checkpoint_metadata(path)
519
+
520
+ import os
521
+
522
+ import torch
523
+ from safetensors.torch import load_file
524
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
525
+ if self._model is None:
526
+ self._initialize_model(device)
527
+ self._model.to(device)
528
+ self._actor.to(device)
529
+ self._critic.to(device)
530
+ state_dict_all = load_file(path, device="cpu")
531
+ model_sd = {
532
+ k[len("model."):]: v for k, v in state_dict_all.items() if k.startswith("model.")
533
+ }
534
+ actor_sd = {
535
+ k[len("actor."):]: v for k, v in state_dict_all.items() if k.startswith("actor.")
536
+ }
537
+ critic_sd = {
538
+ k[len("critic."):]: v for k, v in state_dict_all.items() if k.startswith("critic.")
539
+ }
540
+ self._model.load_state_dict(model_sd)
541
+ self._actor.load_state_dict(actor_sd)
542
+ self._critic.load_state_dict(critic_sd)
543
+ print(f"[physlink] Checkpoint loaded: {os.path.abspath(path)}")
544
+
545
+ def _compute_health(self, loss: float) -> str:
546
+ self._loss_history.append(loss)
547
+ if len(self._loss_history) > _HEALTH_WINDOW:
548
+ self._loss_history = self._loss_history[-_HEALTH_WINDOW:]
549
+
550
+ if self._baseline_loss is None and len(self._loss_history) >= _HEALTH_BASELINE_STEPS:
551
+ self._baseline_loss = (
552
+ sum(self._loss_history[:_HEALTH_BASELINE_STEPS]) / _HEALTH_BASELINE_STEPS
553
+ )
554
+
555
+ if self._baseline_loss is None or self._baseline_loss <= 0:
556
+ return "OK"
557
+
558
+ current_avg = sum(self._loss_history) / len(self._loss_history)
559
+ return "ANOMALY" if current_avg > _ANOMALY_MULTIPLIER * self._baseline_loss else "OK"
560
+
561
+ def _training_step(self, batch: Any, device: Any) -> Any: # noqa: ANN401
562
+ import torch
563
+ import torch.nn as nn
564
+
565
+ obs_all, act_all = batch # pre-processed tensors: (N, obs_dims), (N, act_dims)
566
+ n = obs_all.shape[0]
567
+
568
+ batch_size = 16
569
+ seq_len = min(50, max(1, n))
570
+ max_start = max(0, n - seq_len)
571
+
572
+ starts = torch.randint(0, max_start + 1, (batch_size,)).tolist()
573
+ obs_seq = torch.stack([obs_all[s: s + seq_len] for s in starts])
574
+ act_seq = torch.stack([act_all[s: s + seq_len] for s in starts])
575
+
576
+ b_size, t_steps, obs_d = obs_seq.shape
577
+ gru_hidden = self._model.gru.hidden_size
578
+
579
+ with torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
580
+ h_state = torch.zeros(b_size, gru_hidden, device=device)
581
+ latents: list[Any] = []
582
+ kl_losses: list[Any] = []
583
+ recon_losses: list[Any] = []
584
+
585
+ for t in range(t_steps):
586
+ obs_t = obs_seq[:, t]
587
+ act_t = act_seq[:, t]
588
+
589
+ encoded = self._model.encoder(obs_t)
590
+ gru_input = torch.cat([encoded, act_t], dim=-1)
591
+ h_state = self._model.gru(gru_input, h_state)
592
+
593
+ post_params = self._model.posterior(torch.cat([h_state, encoded], dim=-1))
594
+ post_mean, post_log_std = post_params.chunk(2, dim=-1)
595
+ post_std = post_log_std.clamp(-5, 2).exp()
596
+ z = post_mean + post_std * torch.randn_like(post_std)
597
+ latents.append(z)
598
+
599
+ prior_params = self._model.prior(h_state)
600
+ prior_mean, prior_log_std = prior_params.chunk(2, dim=-1)
601
+ prior_std = prior_log_std.clamp(-5, 2).exp().clamp(min=1e-8)
602
+
603
+ kl = 0.5 * (
604
+ (post_mean - prior_mean).pow(2) / prior_std.pow(2)
605
+ + (post_std / prior_std).pow(2)
606
+ - 1
607
+ - 2 * (post_std / prior_std).log()
608
+ ).sum(-1).mean()
609
+ kl_losses.append(kl)
610
+
611
+ recon = self._model.decoder(torch.cat([h_state, z], dim=-1))
612
+ recon_losses.append(nn.functional.mse_loss(recon, obs_t))
613
+
614
+ wm_loss = torch.stack(recon_losses).mean() + 0.1 * torch.stack(kl_losses).mean()
615
+
616
+ # Imagination rollout
617
+ imagine_horizon = 15
618
+ hidden_i = h_state.detach()
619
+ latent_i = latents[-1].detach()
620
+
621
+ imagined_values: list[Any] = []
622
+ imagined_rewards: list[Any] = []
623
+
624
+ for _ in range(imagine_horizon):
625
+ actor_input = torch.cat([hidden_i, latent_i], dim=-1)
626
+ actor_params = self._actor.net(actor_input)
627
+ act_mean, act_log_std = actor_params.chunk(2, dim=-1)
628
+ act_i = torch.tanh(
629
+ act_mean + act_log_std.clamp(-5, 2).exp() * torch.randn_like(act_mean)
630
+ )
631
+
632
+ enc_i = self._model.encoder(torch.zeros(b_size, obs_d, device=device))
633
+ gru_in = torch.cat([enc_i, act_i], dim=-1)
634
+ hidden_i = self._model.gru(gru_in, hidden_i)
635
+
636
+ prior_p = self._model.prior(hidden_i)
637
+ latent_i, _ = prior_p.chunk(2, dim=-1)
638
+
639
+ critic_in = torch.cat([hidden_i, latent_i], dim=-1)
640
+ imagined_values.append(self._critic.net(critic_in))
641
+ imagined_rewards.append(
642
+ self._model.reward_head(torch.cat([hidden_i, latent_i], dim=-1))
643
+ )
644
+
645
+ # λ-returns (simplified)
646
+ returns = imagined_values[-1].detach()
647
+ for v, r in zip(reversed(imagined_values[:-1]), reversed(imagined_rewards[:-1])):
648
+ returns = r + 0.99 * (0.95 * v + 0.05 * returns)
649
+
650
+ actor_loss = -returns.mean()
651
+
652
+ critic_in = torch.cat([h_state.detach(), latents[-1].detach()], dim=-1)
653
+ critic_val = self._critic.net(critic_in)
654
+ critic_loss = nn.functional.mse_loss(critic_val, returns.detach())
655
+
656
+ total_loss = wm_loss + actor_loss + critic_loss + self._soft_penalty_per_step
657
+
658
+ return total_loss
659
+
660
+ def fit(
661
+ self,
662
+ trajectories: list[dict[str, Any]] | TrajectoryBatch | TrajectoryBuffer,
663
+ steps: int,
664
+ checkpoint_interval_steps: int = 1000,
665
+ debug_hooks: bool = False,
666
+ checkpoint_dir: str = "physlink_checkpoints",
667
+ ) -> "AdaptationRun":
668
+ """Run the DreamerV3 adaptation loop with a live progress bar.
669
+
670
+ Adapts the DreamerV3 world model to the provided trajectory data over
671
+ ``steps`` gradient updates. Displays a rich progress bar in Colab output
672
+ with step count, ETA, prediction health (OK/ANOMALY), and throughput.
673
+
674
+ Calling fit() multiple times is safe: each call resets optimizer state
675
+ and training history for a fresh run (NFR-09 idempotence).
676
+
677
+ Args:
678
+ trajectories: Trajectory dataset. ``list[dict]`` and ``TrajectoryBuffer``
679
+ are silently converted to ``TrajectoryBatch``. Each dict must contain
680
+ at minimum "obs" and "action" keys with numpy-compatible values.
681
+ steps: Total gradient steps to run. Must be > 0.
682
+ checkpoint_interval_steps: Interval (in steps) between checkpoint
683
+ saves. A checkpoint file is written every this many steps. Must
684
+ be > 0.
685
+ debug_hooks: When True, displays a debug panel alongside the progress
686
+ bar showing pipeline stage statuses (data_loading, world_model_update,
687
+ actor_update, critic_update). Each stage shows OK or a diagnostic
688
+ status. Defaults to False (opt-in, not default).
689
+ checkpoint_dir: Directory where checkpoint files are written. Defaults
690
+ to "physlink_checkpoints" relative to the current working directory.
691
+
692
+ Returns:
693
+ AdaptationRun capturing config, step count, checkpoint paths, and elapsed time.
694
+
695
+ Raises:
696
+ ValidationError: If steps <= 0 or checkpoint_interval_steps <= 0.
697
+
698
+ Example:
699
+ >>> from physlink import DreamerV3Adapter, ObservationSpace, ActionSpace
700
+ >>> obs = ObservationSpace.from_proprioception(joints=7)
701
+ >>> act = ActionSpace.continuous(dims=7, bounds=[(-1.0, 1.0)] * 7)
702
+ >>> adapter = DreamerV3Adapter(obs, act)
703
+ >>> trajectories = [{"obs": [0.1] * 7, "action": [0.0] * 7}] * 100
704
+ >>> run = adapter.fit(trajectories, steps=10, debug_hooks=True)
705
+ """
706
+ import time
707
+
708
+ from physlink.core._types import AdaptationConfig, AdaptationRun
709
+ from physlink.core.exceptions import ValidationError
710
+
711
+ if isinstance(steps, bool) or not isinstance(steps, int) or steps <= 0:
712
+ raise ValidationError(
713
+ f"DreamerV3Adapter.fit: invalid steps.\n"
714
+ f" Got: steps={steps}\n"
715
+ f" Expected: steps > 0\n"
716
+ f" Fix: provide a positive integer, e.g. steps=10000."
717
+ )
718
+ if (
719
+ isinstance(checkpoint_interval_steps, bool)
720
+ or not isinstance(checkpoint_interval_steps, int)
721
+ or checkpoint_interval_steps <= 0
722
+ ):
723
+ raise ValidationError(
724
+ f"DreamerV3Adapter.fit: invalid checkpoint_interval_steps.\n"
725
+ f" Got: checkpoint_interval_steps={checkpoint_interval_steps}\n"
726
+ f" Expected: checkpoint_interval_steps > 0\n"
727
+ f" Fix: provide a positive integer, e.g. checkpoint_interval_steps=1000."
728
+ )
729
+
730
+ self._reset_training_state()
731
+
732
+ if isinstance(trajectories, TrajectoryBuffer):
733
+ trajectories = trajectories.to_batch()
734
+ if isinstance(trajectories, list):
735
+ trajectories = TrajectoryBatch.from_list(trajectories)
736
+
737
+ trajectories = self._apply_invariants(trajectories)
738
+
739
+ import torch
740
+
741
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
742
+
743
+ if self._model is None:
744
+ self._initialize_model(device)
745
+
746
+ # Pre-process trajectory data to tensors once
747
+ raw_data = trajectories.data
748
+ obs_all = torch.tensor(
749
+ [d["obs"] for d in raw_data], dtype=torch.float32, device=device
750
+ )
751
+ act_raw = torch.tensor(
752
+ [d["action"] for d in raw_data], dtype=torch.float32, device=device
753
+ )
754
+
755
+ # Align action dims to model's act_dims via zero-padding or truncation
756
+ model_act_dims = self.act_space.dims
757
+ if act_raw.shape[-1] < model_act_dims:
758
+ pad = torch.zeros(
759
+ act_raw.shape[0], model_act_dims - act_raw.shape[-1], device=device
760
+ )
761
+ act_all = torch.cat([act_raw, pad], dim=-1)
762
+ elif act_raw.shape[-1] > model_act_dims:
763
+ act_all = act_raw[:, :model_act_dims]
764
+ else:
765
+ act_all = act_raw
766
+
767
+ tensor_batch = (obs_all, act_all)
768
+
769
+ all_params = (
770
+ list(self._model.parameters())
771
+ + list(self._actor.parameters())
772
+ + list(self._critic.parameters())
773
+ )
774
+ optimizer = torch.optim.Adam(all_params, lr=3e-4)
775
+ scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))
776
+
777
+ _fit_start_time = time.monotonic()
778
+ _run_checkpoint_paths: list[str] = []
779
+
780
+ if debug_hooks:
781
+ debug_panel = _DebugPanel()
782
+ with _build_debug_layout(steps, debug_panel) as (progress, task_id):
783
+ for step_idx in range(steps):
784
+ stage_statuses = {name: "OK" for name in _STAGE_NAMES}
785
+ optimizer.zero_grad(set_to_none=True)
786
+ try:
787
+ loss = self._training_step(tensor_batch, device)
788
+ except Exception as exc:
789
+ for name in ("world_model_update", "actor_update", "critic_update"):
790
+ stage_statuses[name] = type(exc).__name__
791
+ debug_panel.update_all(stage_statuses)
792
+ raise
793
+ scaler.scale(loss).backward()
794
+ scaler.unscale_(optimizer)
795
+ torch.nn.utils.clip_grad_norm_(all_params, max_norm=100.0)
796
+ scaler.step(optimizer)
797
+ scaler.update()
798
+ debug_panel.update_all(stage_statuses)
799
+ progress.update(
800
+ task_id, advance=1, health=self._compute_health(loss.item())
801
+ )
802
+ completed = step_idx + 1
803
+ if completed % checkpoint_interval_steps == 0:
804
+ _ckpt = _save_checkpoint(
805
+ self._model, self._actor, self._critic,
806
+ completed, checkpoint_dir,
807
+ )
808
+ self._last_checkpoint_path = _ckpt
809
+ _run_checkpoint_paths.append(_ckpt)
810
+ else:
811
+ with _build_progress_bar(steps) as (progress, task_id):
812
+ for step_idx in range(steps):
813
+ optimizer.zero_grad(set_to_none=True)
814
+ loss = self._training_step(tensor_batch, device)
815
+ scaler.scale(loss).backward()
816
+ scaler.unscale_(optimizer)
817
+ torch.nn.utils.clip_grad_norm_(all_params, max_norm=100.0)
818
+ scaler.step(optimizer)
819
+ scaler.update()
820
+ progress.update(
821
+ task_id, advance=1, health=self._compute_health(loss.item())
822
+ )
823
+ completed = step_idx + 1
824
+ if completed % checkpoint_interval_steps == 0:
825
+ _ckpt = _save_checkpoint(
826
+ self._model, self._actor, self._critic,
827
+ completed, checkpoint_dir,
828
+ )
829
+ self._last_checkpoint_path = _ckpt
830
+ _run_checkpoint_paths.append(_ckpt)
831
+
832
+ self._fit_elapsed_seconds = time.monotonic() - _fit_start_time
833
+
834
+ _config = AdaptationConfig(
835
+ obs_space=self.obs_space,
836
+ act_space=self.act_space,
837
+ steps=steps,
838
+ checkpoint_interval_steps=checkpoint_interval_steps,
839
+ checkpoint_dir=checkpoint_dir,
840
+ )
841
+ _run = AdaptationRun(
842
+ config=_config,
843
+ current_step=completed,
844
+ checkpoint_paths=_run_checkpoint_paths,
845
+ elapsed_seconds=self._fit_elapsed_seconds or 0.0,
846
+ )
847
+ return _run
848
+
849
+ def explain(self) -> dict[str, Any]:
850
+ """Return a metadata dict describing this adapter's space configuration.
851
+
852
+ Returns:
853
+ A JSON-serializable dict with keys: type, obs_space, act_space.
854
+
855
+ Example:
856
+ >>> adapter = DreamerV3Adapter(obs, act)
857
+ >>> info = adapter.explain()
858
+ >>> info["type"]
859
+ 'DreamerV3Adapter'
860
+ """
861
+ return {
862
+ "type": "DreamerV3Adapter",
863
+ "obs_space": self.obs_space.explain(),
864
+ "act_space": self.act_space.explain(),
865
+ }
866
+
867
+ def visualize(
868
+ self,
869
+ trajectories: list[dict[str, Any]] | TrajectoryBatch | TrajectoryBuffer,
870
+ output_path: str = "physlink_triptych.gif",
871
+ ) -> str:
872
+ """Produce a triptych GIF comparing Imagination, Real, and Difference panels.
873
+
874
+ Runs a single inference pass through the trained world model to produce
875
+ reconstructed (Imagination) observations, then renders them alongside the
876
+ real observations and the absolute difference as a 3-panel GIF.
877
+
878
+ Prints a "Friday afternoon window" callout comparing elapsed adaptation
879
+ time to the documented from-scratch baseline.
880
+
881
+ Args:
882
+ trajectories: Trajectory dataset to visualize. Uses the first trajectory
883
+ for the panel rendering. ``list[dict]`` and ``TrajectoryBuffer`` are
884
+ silently converted to ``TrajectoryBatch``. Each dict must contain at
885
+ minimum an "obs" key.
886
+ output_path: File path for the output GIF. Defaults to
887
+ "physlink_triptych.gif" in the current working directory.
888
+
889
+ Returns:
890
+ Absolute path to the saved GIF file.
891
+
892
+ Raises:
893
+ AdapterError: If the model has not been initialized via fit() or
894
+ load_checkpoint().
895
+
896
+ Example:
897
+ >>> adapter = DreamerV3Adapter(obs, act)
898
+ >>> adapter.fit(trajectories, steps=1000)
899
+ >>> path = adapter.visualize(trajectories)
900
+ >>> print(path) # absolute path to physlink_triptych.gif
901
+ """
902
+ from physlink.core.exceptions import AdapterError
903
+
904
+ if self._model is None:
905
+ raise AdapterError(
906
+ "DreamerV3Adapter.visualize: model not initialized.\n"
907
+ " Got: self._model is None\n"
908
+ " Expected: model weights loaded via fit() or load_checkpoint()\n"
909
+ " Fix: call adapter.fit(trajectories, steps=N) before visualize()."
910
+ )
911
+
912
+ if isinstance(trajectories, TrajectoryBuffer):
913
+ trajectories = trajectories.to_batch()
914
+ if isinstance(trajectories, list):
915
+ trajectories = TrajectoryBatch.from_list(trajectories)
916
+
917
+ import numpy as np
918
+ import torch
919
+
920
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
921
+ self._model.to(device)
922
+
923
+ obs_raw = [d["obs"] for d in trajectories.data[:_VIZ_SEQ_LEN]]
924
+ obs_seq = torch.tensor(obs_raw, dtype=torch.float32, device=device) # (T, obs_dims)
925
+
926
+ with torch.no_grad():
927
+ h_state = torch.zeros(1, self._model.gru.hidden_size, device=device)
928
+ imagination_frames = []
929
+ for t in range(obs_seq.shape[0]):
930
+ obs_t = obs_seq[t : t + 1] # shape (1, obs_dims)
931
+ act_t = torch.zeros(1, self.act_space.dims, device=device)
932
+ encoded = self._model.encoder(obs_t)
933
+ gru_input = torch.cat([encoded, act_t], dim=-1)
934
+ h_state = self._model.gru(gru_input, h_state)
935
+ post_params = self._model.posterior(torch.cat([h_state, encoded], dim=-1))
936
+ post_mean, _ = post_params.chunk(2, dim=-1)
937
+ recon = self._model.decoder(torch.cat([h_state, post_mean], dim=-1))
938
+ imagination_frames.append(recon.squeeze(0).cpu().numpy())
939
+
940
+ imagination_np = np.stack(imagination_frames) # (T, obs_dims)
941
+ real_np = obs_seq.cpu().numpy() # (T, obs_dims)
942
+
943
+ from physlink.utils.visualization import (
944
+ _FROM_SCRATCH_BASELINE_LABEL,
945
+ _FROM_SCRATCH_BASELINE_SECONDS,
946
+ render_triptych,
947
+ )
948
+
949
+ gif_path = render_triptych(imagination_np, real_np, output_path)
950
+ self._triptych_path = gif_path
951
+ print(f"[physlink] Triptych saved: {gif_path}")
952
+
953
+ elapsed = self._fit_elapsed_seconds
954
+ if elapsed is not None:
955
+ elapsed_min = elapsed / 60
956
+ baseline_hours = _FROM_SCRATCH_BASELINE_SECONDS / 3600
957
+ speedup = _FROM_SCRATCH_BASELINE_SECONDS / max(elapsed, 1.0)
958
+ print(
959
+ f"[physlink] ⏱ Adaptation complete in {elapsed_min:.1f} min\n"
960
+ f" vs. from-scratch baseline ({_FROM_SCRATCH_BASELINE_LABEL}): "
961
+ f"{baseline_hours:.0f}h\n"
962
+ f" Speedup: ~{speedup:.0f}x"
963
+ )
964
+ else:
965
+ print(
966
+ "[physlink] ⏱ Adaptation time not available "
967
+ "(call fit() before visualize() to see the Friday afternoon window callout)"
968
+ )
969
+
970
+ return gif_path
971
+
972
+ def export(self, path: str) -> dict[str, str]:
973
+ """Export a complete artifact bundle to the specified directory.
974
+
975
+ Copies the triptych GIF, writes a YAML configuration file, and writes
976
+ a human-readable summary. Calls the share panel to copy the Colab
977
+ notebook URL to the clipboard (Colab only; graceful fallback elsewhere).
978
+
979
+ Args:
980
+ path: Directory path for the exported artifacts. Created if it does
981
+ not exist. Existing files in the directory are overwritten.
982
+
983
+ Returns:
984
+ dict with keys ``gif``, ``config``, ``summary`` mapping to the
985
+ absolute paths of the respective exported files.
986
+
987
+ Raises:
988
+ AdapterError: If ``visualize()`` has not been called (no triptych
989
+ available to export).
990
+
991
+ Example:
992
+ >>> adapter.fit(trajectories, steps=1000)
993
+ >>> adapter.visualize(trajectories)
994
+ >>> artifacts = adapter.export("./physlink_export")
995
+ >>> artifacts["config"] # absolute path to config.yaml
996
+ '/abs/path/physlink_export/config.yaml'
997
+ """
998
+ import datetime
999
+ import os
1000
+ import shutil
1001
+
1002
+ import yaml
1003
+
1004
+ from physlink.core.exceptions import AdapterError
1005
+
1006
+ if self._triptych_path is None:
1007
+ raise AdapterError(
1008
+ "DreamerV3Adapter.export: no triptych available.\n"
1009
+ " Got: self._triptych_path is None\n"
1010
+ " Expected: visualize() called before export()\n"
1011
+ " Fix: call adapter.visualize(trajectories) before adapter.export(path)."
1012
+ )
1013
+
1014
+ os.makedirs(path, exist_ok=True)
1015
+
1016
+ gif_dest = os.path.join(path, "triptych.gif")
1017
+ shutil.copy2(self._triptych_path, gif_dest)
1018
+
1019
+ config = {
1020
+ "obs_space": self.obs_space.explain(),
1021
+ "act_space": self.act_space.explain(),
1022
+ "checkpoint_path": self._last_checkpoint_path,
1023
+ }
1024
+ yaml_path = os.path.join(path, "config.yaml")
1025
+ with open(yaml_path, "w", encoding="utf-8") as f:
1026
+ yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
1027
+
1028
+ elapsed_min = (
1029
+ self._fit_elapsed_seconds / 60.0
1030
+ if self._fit_elapsed_seconds is not None
1031
+ else None
1032
+ )
1033
+ elapsed_str = f"{elapsed_min:.1f} min" if elapsed_min is not None else "N/A"
1034
+ timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
1035
+ summary_lines = [
1036
+ "physlink Export Summary",
1037
+ "=======================",
1038
+ "Adapter: DreamerV3Adapter",
1039
+ f"obs_dims: {self.obs_space.dims}",
1040
+ f"act_dims: {self.act_space.dims}",
1041
+ f"Fit elapsed: {elapsed_str}",
1042
+ f"Triptych GIF: {os.path.abspath(self._triptych_path)}",
1043
+ f"Checkpoint: {self._last_checkpoint_path or 'N/A'}",
1044
+ f"Exported at: {timestamp}",
1045
+ ]
1046
+ summary_path = os.path.join(path, "summary.txt")
1047
+ with open(summary_path, "w", encoding="utf-8") as f:
1048
+ f.write("\n".join(summary_lines) + "\n")
1049
+
1050
+ print(f"[physlink] Export complete: {os.path.abspath(path)}")
1051
+ print(f"[physlink] GIF: {os.path.abspath(gif_dest)}")
1052
+ print(f"[physlink] Config: {os.path.abspath(yaml_path)}")
1053
+ print(f"[physlink] Summary: {os.path.abspath(summary_path)}")
1054
+
1055
+ _share_panel(os.path.abspath(path))
1056
+
1057
+ return {
1058
+ "gif": os.path.abspath(gif_dest),
1059
+ "config": os.path.abspath(yaml_path),
1060
+ "summary": os.path.abspath(summary_path),
1061
+ }
1062
+
1063
+ def __repr__(self) -> str:
1064
+ return (
1065
+ f"DreamerV3Adapter("
1066
+ f"obs_dims={self.obs_space.dims}, "
1067
+ f"act_dims={self.act_space.dims})"
1068
+ )