alignscope 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alignscope/sdk.py ADDED
@@ -0,0 +1,606 @@
1
+ """
2
+ AlignScope — Core SDK Tracker
3
+
4
+ The AlignScopeTracker is the heart of the SDK. It:
5
+ 1. Accepts raw MARL data via .log() calls
6
+ 2. Normalizes data from any framework format into the standard schema
7
+ 3. Runs the metrics engine and anomaly detector in real-time
8
+ 4. Sends results to the dashboard server via WebSocket
9
+ 5. Optionally forwards metrics to W&B and/or MLflow
10
+ """
11
+
12
+ import json
13
+ import time
14
+ import random
15
+ import asyncio
16
+ import threading
17
+ from collections import deque
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+
20
+ from alignscope.metrics import AlignmentMetrics
21
+ from alignscope.detector import DefectionDetector
22
+
23
+
24
class AlignScopeTracker:
    """
    Core tracker that bridges training code → dashboard.

    Handles data normalization, metric computation, and WebSocket streaming.
    Thread-safe — can be called from training threads: every public logging
    entry point serializes on an internal lock, and network I/O happens on a
    single background daemon thread fed through a queue.
    """

    def __init__(
        self,
        project: str = "default",
        server_url: str = "ws://localhost:8000/ws/sdk",
        preset: Optional[str] = None,
        paradigm: Optional[dict] = None,
        metrics: Optional[list] = None,
        events: Optional[list] = None,
        topology: Optional[dict] = None,
        config: Optional[dict] = None,
        forward_wandb: bool = True,
        forward_mlflow: bool = True,
    ):
        """
        Create a tracker and print the startup banner.

        Args:
            project: Display name of the tracked project.
            server_url: WebSocket endpoint of the dashboard server.
            preset: Optional preset name ("zero-sum", "mean-field",
                "cooperative"); selects default metric/event descriptors
                when none are passed explicitly.
            paradigm: Paradigm descriptor; derived from ``preset`` when omitted.
            metrics: Metric descriptors shown by the dashboard.
            events: Event descriptors shown by the dashboard.
            topology: Topology descriptor (node label, edge types, grouping).
            config: Explicit config sent on connect; when empty a config is
                derived from the agents seen so far.
            forward_wandb: Try to forward metrics to W&B.
            forward_mlflow: Try to forward metrics to MLflow.
        """
        self.project = project
        self.server_url = server_url

        self.preset = preset
        self.paradigm = paradigm or {
            "environment": preset if preset else "cooperative",
            "learning": "decentralized",
        }
        self.metrics_config = metrics or []
        self.events_config = events or []

        # Presets only fill in descriptors the caller did not supply.
        if preset == "zero-sum":
            if not self.metrics_config:
                self.metrics_config = [
                    {"id": "win_rate", "label": "Win Rate", "type": "scalar"},
                    {"id": "exploitability", "label": "Exploitability", "type": "scalar"},
                    {"id": "nash_gap", "label": "Nash Gap", "type": "scalar"},
                ]
            if not self.events_config:
                self.events_config = [
                    {"id": "policy_shift", "label": "Policy Shift", "color": "#dc4a4a"},
                    {"id": "role_reversal", "label": "Role Reversal", "color": "#e8925a"},
                ]
        elif preset == "mean-field":
            if not self.metrics_config:
                self.metrics_config = [
                    {"id": "population_distribution", "label": "Pop. Dist.", "type": "scalar"},
                    {"id": "mean_reward", "label": "Mean Reward", "type": "scalar"},
                    {"id": "field_entropy", "label": "Field Entropy", "type": "scalar"},
                ]
        elif preset == "cooperative" or not preset:
            if not self.metrics_config:
                self.metrics_config = [
                    {"id": "role_stability", "label": "Role Stability", "type": "scalar"},
                    {"id": "coalitions", "label": "Coalitions", "type": "scalar"},
                    {"id": "defectors", "label": "Defectors", "type": "scalar"},
                    {"id": "score", "label": "Score", "type": "scalar"},
                ]
            if not self.events_config:
                self.events_config = [
                    {"id": "defection", "label": "Defection", "color": "#dc4a4a"},
                    {"id": "reciprocity_drop", "label": "Reciprocity ↓", "color": "#c97b3a"},
                    {"id": "stability_drop", "label": "Stability ↓", "color": "#b06ec7"},
                    {"id": "coalition_fragmentation", "label": "Coalition ↓", "color": "#d4a843"},
                ]

        self.topology_config = topology or {
            "nodeLabel": "agent",
            "edgeTypes": [
                {"id": "comm", "label": "Communication", "style": "dashed"},
                {"id": "collab", "label": "Collaboration", "style": "solid"},
            ],
            "groupBy": "team",
        }

        self.config = config or {}

        # Metric engines
        self._metrics = AlignmentMetrics()
        self._detector = DefectionDetector()

        # State
        self._step = 0
        self._agents_cache: Dict[Union[int, str], dict] = {}
        self._help_tracker: Dict[tuple, int] = {}
        self._role_history: Dict[Union[int, str], list] = {}
        self._started = False
        self._lock = threading.Lock()

        # WebSocket connection — single persistent daemon thread
        self._ws = None
        self._ws_connected = False  # set once a connect attempt succeeds
        self._ws_thread: Optional[threading.Thread] = None
        self._ws_thread_started = False
        self._send_queue: deque = deque()
        self._ws_stop_event = threading.Event()

        # Integration bridges
        self._wandb_bridge = None
        self._mlflow_bridge = None

        if forward_wandb:
            self._init_wandb_bridge()
        if forward_mlflow:
            self._init_mlflow_bridge()

        self._print_banner()

    def _dashboard_url(self) -> str:
        """Derive the dashboard HTTP URL from the configured WebSocket URL.

        ``ws://host:port/...`` maps to ``http://host:port`` and ``wss://`` to
        ``https://``. Falls back to ``http://localhost:8000`` when the server
        URL has no host part.
        """
        from urllib.parse import urlsplit

        fallback = "http://localhost:8000"
        try:
            parts = urlsplit(self.server_url)
        except (TypeError, ValueError, AttributeError):
            return fallback
        if not parts.netloc or not isinstance(parts.netloc, str):
            return fallback
        scheme = "https" if parts.scheme in ("wss", "https") else "http"
        return f"{scheme}://{parts.netloc}"

    def _print_banner(self):
        """Print a nice startup message like W&B does."""
        dashboard = self._dashboard_url()
        try:
            from rich.console import Console
            from rich.panel import Panel
            console = Console()
            console.print(Panel.fit(
                f"[bold cyan]AlignScope[/] v0.1.0\n"
                f"Project: [bold]{self.project}[/]\n"
                f"Dashboard: [link]{dashboard}[/link]",
                title="AlignScope",
                border_style="cyan",
            ))
        except Exception:
            # rich is optional (and may fail on exotic terminals): plain text.
            print(f"AlignScope v0.1.0 | Project: {self.project}")
            print(f" Dashboard: {dashboard}")

    def log(
        self,
        step: int,
        agents: Optional[Union[list, dict]] = None,
        obs: Any = None,
        actions: Any = None,
        rewards: Any = None,
        **kwargs,
    ) -> None:
        """
        Log one step of multi-agent data.

        Accepts data in multiple formats:
        - Raw dicts (standard schema)
        - NumPy arrays (auto-converted)
        - Framework-specific formats (auto-detected and normalized)

        The step is normalized, run through the metrics engine and the
        anomaly detector, queued for the WebSocket sender, and forwarded to
        any active integrations. ``obs`` is accepted but not used by the
        normalization code in this class.
        """
        with self._lock:
            self._step = step
            normalized = self._normalize_data(step, agents, obs, actions, rewards, **kwargs)

            # Compute metrics
            tick_metrics = self._metrics.update(normalized)

            # Detect anomalies
            events = self._detector.analyze(normalized, tick_metrics)

            # Build payload
            payload = {
                "type": "tick",
                "data": {
                    "tick": step,
                    "agents": normalized["agents"],
                    "objectives": normalized.get("objectives", []),
                    "team_scores": normalized.get("team_scores", {}),
                    "metrics": {
                        "agent_metrics": tick_metrics["agent_metrics"],
                        "pair_metrics": tick_metrics["pair_metrics"],
                        # JSON object keys must be strings.
                        "team_metrics": {
                            str(k): v for k, v in tick_metrics["team_metrics"].items()
                        },
                        "overall_alignment_score": tick_metrics["overall_alignment_score"],
                    },
                    "relationships": self._build_relationships(),
                    "events": [
                        {
                            "tick": e["tick"],
                            "type": e["type"],
                            "agent_id": e.get("agent_id"),
                            "team": e.get("team"),
                            "severity": e.get("severity", 0.5),
                            "description": e["description"],
                        }
                        for e in events
                    ],
                },
            }

            # Queue for WebSocket send
            self._send_queue.append(payload)
            self._flush_queue()

            # Forward to integrations
            self._forward_metrics(step, tick_metrics, events)

    def report(self, tick: int, agent: Union[str, int], metrics: dict) -> None:
        """Dynamically report custom metrics for an agent."""
        with self._lock:
            payload = {
                "type": "tick",
                "data": {
                    "tick": tick,
                    "agents": [],
                    "metrics": {
                        "agent_metrics": {str(agent): metrics},
                        "team_metrics": {},
                    },
                    "team_scores": {},
                    "relationships": [],
                    "events": [],
                },
            }
            self._send_queue.append(payload)
            self._flush_queue()

    def event(self, tick: int, type: str, agent: Union[str, int], detail: str, severity: float = 0.5) -> None:
        """Dynamically report custom events.

        The agent's team is looked up in the cache of previously logged
        agents and is ``None`` when the agent has not been seen.
        """
        with self._lock:
            team = self._agents_cache.get(agent, {}).get("team")
            payload = {
                "type": "tick",
                "data": {
                    "tick": tick,
                    "agents": [],
                    "metrics": None,
                    "team_scores": {},
                    "relationships": [],
                    "events": [
                        {
                            "tick": tick,
                            "type": type,
                            "agent_id": agent,
                            "team": team,
                            "severity": severity,
                            "description": detail,
                        }
                    ],
                },
            }
            self._send_queue.append(payload)
            self._flush_queue()

    def _normalize_data(
        self,
        step: int,
        agents: Any,
        obs: Any,
        actions: Any,
        rewards: Any,
        **kwargs,
    ) -> dict:
        """
        Normalize data from any format into AlignScope standard schema.

        Handles:
        - List of dicts (standard schema) → copied
        - Dict of {agent_id: data} → convert to list
        - NumPy arrays → convert to lists
        - None values → infer from cache

        Caller-owned dicts are never modified: every agent entry in the
        result is a fresh dict, so payloads that are still queued for sending
        cannot be rewritten by a later step.
        """
        normalized_agents = self._normalize_agents(agents)
        self._fill_agent_defaults(normalized_agents)
        normalized_actions = self._normalize_actions(step, actions, normalized_agents)
        self._attach_rewards(rewards, normalized_agents)

        return {
            "tick": step,
            "step": step,
            "agents": normalized_agents,
            "actions": normalized_actions,
            "defection_events": kwargs.get("defection_events", []),
            "objectives": kwargs.get("objectives", []),
            "team_scores": kwargs.get("team_scores", {}),
        }

    @staticmethod
    def _ndarray_rows(value: Any) -> list:
        """Return the first-axis items of a NumPy array, or [] for anything else."""
        try:
            import numpy as np
        except ImportError:
            return []
        if not isinstance(value, np.ndarray) or value.ndim == 0:
            return []
        return list(value)

    def _normalize_agents(self, agents: Any) -> List[dict]:
        """Convert any supported ``agents`` input into a list of fresh agent dicts."""
        if agents is None:
            # Copy cached state so earlier, still-queued payloads keep their values.
            return [dict(cached) for cached in self._agents_cache.values()]
        if isinstance(agents, list):
            if agents and isinstance(agents[0], dict):
                # Standard schema; stray non-dict items get a minimal entry
                # instead of crashing later on.
                return [
                    dict(item) if isinstance(item, dict) else self._make_agent_entry(i, item)
                    for i, item in enumerate(agents)
                ]
            # List of raw values — create minimal agent entries
            return [self._make_agent_entry(i, item) for i, item in enumerate(agents)]
        if isinstance(agents, dict):
            # Dict of {agent_id: data}; _make_agent_entry copies dict values.
            return [self._make_agent_entry(aid, data) for aid, data in agents.items()]
        return [
            self._make_agent_entry(i, row)
            for i, row in enumerate(self._ndarray_rows(agents))
        ]

    def _fill_agent_defaults(self, normalized_agents: List[dict]) -> None:
        """Fill missing agent fields, record role history and refresh the cache."""
        import math

        total_agents = len(normalized_agents)
        radius = 100
        for i, agent in enumerate(normalized_agents):
            agent.setdefault("agent_id", str(i))
            agent.setdefault("team", 0)
            agent.setdefault("role", "agent")

            # Deterministic circular layout fallback
            angle = 2 * math.pi * i / max(1, total_agents)
            agent.setdefault("x", round(math.cos(angle) * radius, 2))
            agent.setdefault("y", round(math.sin(angle) * radius, 2))
            agent.setdefault("resources", 0)
            agent.setdefault("hearts", 0)
            agent.setdefault("energy", 0)
            agent.setdefault("is_defector", False)
            agent.setdefault("coalition_id", agent.get("team", 0))

            # Track roles for stability metrics
            aid = agent["agent_id"]
            self._role_history.setdefault(aid, []).append(agent["role"])

            # Update cache
            self._agents_cache[aid] = agent

    @staticmethod
    def _action_entry(step: int, agent_id: Any, action: str) -> dict:
        """Build one standard-schema action record."""
        return {
            "tick": step,
            "agent_id": agent_id,
            "action": action,
            "target_id": None,
            "detail": "",
        }

    def _normalize_actions(self, step: int, actions: Any, normalized_agents: List[dict]) -> list:
        """Convert any supported ``actions`` input into standard action records."""
        if actions is None:
            return []
        if isinstance(actions, list) and actions and isinstance(actions[0], dict):
            # Already in standard format
            return actions
        if isinstance(actions, dict):
            # {agent_id: action} format
            return [
                self._action_entry(step, aid, act if isinstance(act, str) else str(act))
                for aid, act in actions.items()
            ]
        # Positional formats: the i-th action belongs to the i-th agent.
        raw_actions = actions if isinstance(actions, list) else self._ndarray_rows(actions)
        return [
            self._action_entry(
                step,
                normalized_agents[i]["agent_id"] if i < len(normalized_agents) else i,
                str(act),
            )
            for i, act in enumerate(raw_actions)
        ]

    def _attach_rewards(self, rewards: Any, normalized_agents: List[dict]) -> None:
        """Store each agent's reward in its ``energy`` field."""
        if rewards is None:
            return
        if isinstance(rewards, dict):
            # Index once instead of scanning all agents per reward.
            by_id: Dict[Any, list] = {}
            for agent in normalized_agents:
                by_id.setdefault(agent["agent_id"], []).append(agent)
            for aid, r in rewards.items():
                matches = by_id.get(aid)
                if not matches:
                    continue
                value = r if isinstance(r, (int, float)) else float(r)
                for agent in matches:
                    agent["energy"] = value
        elif isinstance(rewards, (list, tuple)):
            for agent, r in zip(normalized_agents, rewards):
                agent["energy"] = float(r)
        else:
            try:
                for agent, r in zip(normalized_agents, self._ndarray_rows(rewards)):
                    agent["energy"] = float(r)
            except (TypeError, ValueError):
                # Best effort: rows that are not scalar rewards are skipped.
                pass

    def _make_agent_entry(self, agent_id: Any, data: Any) -> dict:
        """Create a minimal agent dict from raw data.

        Dict input is shallow-copied (never mutated) and given ``agent_id``
        if it lacks one; any other input yields an all-defaults entry.
        """
        if isinstance(data, dict):
            entry = dict(data)
            entry.setdefault("agent_id", agent_id)
            return entry
        return {
            "agent_id": agent_id,
            "team": 0,
            "role": "agent",
            "x": 0,
            "y": 0,
            "resources": 0,
            "hearts": 0,
            "energy": 0,
            "is_defector": False,
            "coalition_id": 0,
        }

    def _build_relationships(self) -> List[dict]:
        """Build relationship edges from the metrics engine's help matrix.

        One edge is emitted per (source, target) pair with a positive count;
        ``weight`` is the help count in both directions and ``reciprocity``
        the ratio of the smaller to the larger direction.
        """
        help_matrix = self._metrics.help_matrix
        edges = []
        for (a, b), count in help_matrix.items():
            if count <= 0:
                continue
            reverse = help_matrix.get((b, a), 0)
            # count > 0 here, so the denominator is never zero.
            reciprocity = min(count, reverse) / max(count, reverse)
            # Determine same_team from agents cache
            team_a = self._agents_cache.get(a, {}).get("team")
            team_b = self._agents_cache.get(b, {}).get("team")
            edges.append({
                "source": a,
                "target": b,
                "weight": count + reverse,
                "reciprocity": round(reciprocity, 3),
                "same_team": team_a is not None and team_b is not None and team_a == team_b,
            })
        return edges

    def _flush_queue(self):
        """Ensure the persistent WebSocket daemon thread is running."""
        if self._ws_thread_started:
            return
        self._ws_thread_started = True
        self._ws_stop_event.clear()
        self._ws_thread = threading.Thread(
            target=self._ws_sender_loop,
            daemon=True,
            name="alignscope-ws",
        )
        self._ws_thread.start()

    def _ws_sender_loop(self):
        """
        Single persistent background thread. Handles connection,
        reconnection with exponential backoff, and queue draining.
        Never spawns additional threads.

        The backoff restarts from its initial value after any attempt that
        managed to connect, and the thread gives up immediately when the
        ``websockets`` package cannot be imported.
        """
        initial_backoff = 0.5
        max_backoff = 30.0
        backoff = initial_backoff

        while not self._ws_stop_event.is_set():
            self._ws_connected = False
            try:
                asyncio.run(self._ws_send_async())
            except ImportError as e:
                # Retrying cannot fix a missing dependency.
                print(f"[AlignScope] WebSocket streaming disabled: {e}")
                break
            except Exception as e:
                print(f"[AlignScope] WS disconnected: {type(e).__name__}: {e}")

            if self._ws_stop_event.is_set():
                break

            if self._ws_connected:
                # The link worked before dropping: start the backoff over.
                backoff = initial_backoff

            # Exponential backoff before reconnect
            print(f"[AlignScope] Reconnecting in {backoff:.1f}s...")
            self._ws_stop_event.wait(timeout=backoff)
            backoff = min(backoff * 2, max_backoff)

        print("[AlignScope] WS sender thread exiting.")

    def _build_auto_config(self) -> dict:
        """Auto-derive config from seen agents when none is explicitly provided.

        Team ids may be of any hashable type: integer ids keep their
        ``id % palette`` colour, other ids are coloured by sort position.
        """
        teams = {}
        roles = set()
        # Snapshot first: the training thread may add agents concurrently.
        for agent in list(self._agents_cache.values()):
            tid = agent.get("team", 0)
            if tid not in teams:
                teams[tid] = {"id": tid, "name": f"Team {tid}", "size": 0, "color": ""}
            teams[tid]["size"] += 1
            roles.add(agent.get("role", "agent"))

        try:
            ordered_ids = sorted(teams)
        except TypeError:
            # Mixed id types (e.g. int and str) are not mutually orderable.
            ordered_ids = sorted(teams, key=lambda tid: (type(tid).__name__, str(tid)))

        team_colors = ['#6d9eeb', '#e8925a', '#4abe7d', '#b06ec7']
        team_list = []
        for position, tid in enumerate(ordered_ids):
            t = teams[tid]
            palette_index = tid if isinstance(tid, int) else position
            t["color"] = team_colors[palette_index % len(team_colors)]
            team_list.append(t)

        return {
            "num_agents": len(self._agents_cache),
            "teams": team_list,
            "roles": sorted(roles, key=str),
            "num_objectives": 0,
            "max_ticks": 500,
            "preset": self.preset,
            "paradigm": self.paradigm,
            "metrics": self.metrics_config,
            "events": self.events_config,
            "topology": self.topology_config,
        }

    async def _ws_send_async(self):
        """
        Async WebSocket sender. Connects, sends config, then drains queue.
        Raises on disconnect so the caller loop can reconnect.

        A payload whose send fails is put back at the front of the queue so
        it is retried after reconnecting. Once the stop event is set the
        remaining queue is still drained before returning.
        """
        import websockets
        print(f"[AlignScope] Connecting to {self.server_url}...")
        async with websockets.connect(self.server_url) as ws:
            self._ws = ws
            self._ws_connected = True
            print(f"[AlignScope] Connected to {self.server_url}")
            try:
                # Send config on every (re)connect
                config_to_send = self.config or self._build_auto_config()
                if config_to_send:
                    await ws.send(json.dumps({"type": "config", "data": config_to_send}))

                sent_count = 0
                while True:
                    try:
                        payload = self._send_queue.popleft()
                    except IndexError:
                        if self._ws_stop_event.is_set():
                            break
                        await asyncio.sleep(0.01)
                        continue

                    try:
                        message = json.dumps(payload)
                    except (TypeError, ValueError) as e:
                        # Retrying cannot fix an unserializable payload; drop it.
                        print(f"[AlignScope] Dropping unserializable payload: {e}")
                        continue

                    try:
                        await ws.send(message)
                    except Exception:
                        self._send_queue.appendleft(payload)
                        raise

                    sent_count += 1
                    if sent_count % 50 == 0:
                        print(f"[AlignScope] Sent {sent_count} ticks ({len(self._send_queue)} queued)")
            finally:
                self._ws = None

    def _init_wandb_bridge(self):
        """Initialize W&B forwarding if wandb is installed."""
        try:
            from alignscope.integrations.wandb_bridge import WandbBridge
            self._wandb_bridge = WandbBridge()
        except Exception:
            # Optional integration: any failure simply leaves it disabled.
            pass

    def _init_mlflow_bridge(self):
        """Initialize MLflow forwarding if mlflow is installed."""
        try:
            from alignscope.integrations.mlflow_bridge import MlflowBridge
            self._mlflow_bridge = MlflowBridge()
        except Exception:
            # Optional integration: any failure simply leaves it disabled.
            pass

    def _forward_metrics(self, step: int, metrics: dict, events: list):
        """Forward alignment metrics to all active integrations."""
        for bridge in (self._wandb_bridge, self._mlflow_bridge):
            if not bridge:
                continue
            try:
                bridge.log(step, metrics, events)
            except Exception:
                # Best effort: a failing integration must not break training.
                pass

    def reset(self):
        """Reset metrics state between episodes. Call from env.reset()."""
        with self._lock:
            self._metrics.reset()
            self._detector = DefectionDetector()
            self._step = 0
            self._role_history.clear()
            # Keep _agents_cache and _send_queue intact for continuity

    def finish(self):
        """Finalize the tracking session.

        Hands the detector summary to the active integrations, then stops
        the sender thread (waiting up to two seconds for it to drain the
        queue) and discards whatever is still unsent.
        """
        summary = self._detector.get_summary()
        for bridge in (self._wandb_bridge, self._mlflow_bridge):
            if not bridge:
                continue
            try:
                bridge.finish(summary)
            except Exception:
                # Best effort: keep shutting down even if an integration fails.
                pass
        # Signal the persistent WS thread to stop
        self._ws_stop_event.set()
        if self._ws_thread and self._ws_thread.is_alive():
            self._ws_thread.join(timeout=2.0)
        self._ws_thread_started = False
        self._send_queue.clear()