gradia 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gradia/viz/server.py CHANGED
@@ -1,3 +1,4 @@
1
+ from contextlib import asynccontextmanager
1
2
  from fastapi import FastAPI, Request
2
3
  from fastapi.staticfiles import StaticFiles
3
4
  from fastapi.templating import Jinja2Templates
@@ -6,15 +7,13 @@ import uvicorn
6
7
  import json
7
8
  import threading
8
9
  from pathlib import Path
9
- from typing import Dict, Any
10
+ from typing import Dict, Any, Optional
10
11
 
11
12
  from ..trainer.engine import Trainer
12
13
 
13
14
  import psutil
14
15
  import time
15
16
 
16
- app = FastAPI()
17
-
18
17
  # Global State (Injected by CLI)
19
18
  SCENARIO = None
20
19
  CONFIG_MGR = None
@@ -24,20 +23,13 @@ TRAINER = None
24
23
  TRAINING_THREAD = None
25
24
  SYSTEM_THREAD = None
26
25
 
27
- # Mounts
26
+ # Base directory for templates and static files
28
27
  BASE_DIR = Path(__file__).resolve().parent
29
28
 
30
- app.mount("/static", StaticFiles(directory=BASE_DIR / "static"), name="static")
31
- # Mount assets if they exist outside static, or ensure user put them in static. Assuming viz/assets
32
- assets_path = BASE_DIR / "assets"
33
- if assets_path.exists():
34
- app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
35
-
36
29
  templates = Jinja2Templates(directory=BASE_DIR / "templates")
37
30
 
38
31
  from ..trainer.callbacks import log_lock
39
32
 
40
- # ... imports ...
41
33
  import os
42
34
 
43
35
  # System Monitor
@@ -61,27 +53,47 @@ def system_monitor_loop():
61
53
  f.flush()
62
54
  os.fsync(f.fileno())
63
55
 
64
- # Start System Monitor on import/startup (or when server starts)
65
- @app.on_event("startup")
66
- async def startup_event():
56
+
57
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook wired into ``FastAPI(lifespan=...)``.

    Startup: build a daemon thread that runs ``system_monitor_loop``, publish
    it through the module-level ``SYSTEM_THREAD`` and start it.
    Shutdown: no work is performed after the ``yield``.
    """
    global SYSTEM_THREAD
    monitor_thread = threading.Thread(target=system_monitor_loop, daemon=True)
    SYSTEM_THREAD = monitor_thread
    monitor_thread.start()

    yield
    # Nothing to tear down: the monitor is a daemon thread, and Python does
    # not wait for daemon threads when the process exits.
67
+
68
+
69
app = FastAPI(lifespan=lifespan)

# Front-end files bundled next to this module.
app.mount("/static", StaticFiles(directory=BASE_DIR / "static"), name="static")

# Optional extra assets directory. StaticFiles validates its directory at
# construction time (check_dir defaults to True) and raises if the path is not
# an existing directory, so check is_dir() rather than exists(). A stray file
# named "assets" would otherwise make importing this module fail.
assets_path = BASE_DIR / "assets"
if assets_path.is_dir():
    app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
70
76
 
71
77
 
72
78
def _render_run_page(request: Request, template_name: str):
    """Render a page that requires an active training run.

    Redirects to ``/configure`` while no ``TRAINER`` exists; otherwise renders
    *template_name* with the current ``SCENARIO`` in the template context.
    """
    if TRAINER is None:
        return RedirectResponse("/configure")
    return templates.TemplateResponse(request, template_name, {"scenario": SCENARIO})


@app.get("/")
async def read_root(request: Request):
    """Root page: renders index.html for the active run."""
    return _render_run_page(request, "index.html")


@app.get("/timeline")
async def timeline_page(request: Request):
    """Learning Timeline view (v2.0): renders timeline.html for the active run."""
    return _render_run_page(request, "timeline.html")
77
90
 
78
91
  @app.get("/configure")
79
92
  async def configure_page(request: Request):
80
93
  if SCENARIO is None:
81
94
  return "System not initialized correctly from CLI."
82
95
 
83
- return templates.TemplateResponse("configure.html", {
84
- "request": request,
96
+ return templates.TemplateResponse(request, "configure.html", {
85
97
  "scenario": SCENARIO,
86
98
  "features": SCENARIO.features,
87
99
  "default_config": DEFAULT_CONFIG
@@ -221,6 +233,130 @@ async def evaluate_model():
221
233
  except Exception as e:
222
234
  return JSONResponse({"error": str(e)}, status_code=500)
223
235
 
236
+
237
+ # ============================================================================
238
+ # v2.0.0 Learning Timeline API Endpoints
239
+ # ============================================================================
240
+
241
@app.get("/api/timeline/events")
async def get_timeline_events(epoch: Optional[int] = None, sample_id: Optional[int] = None):
    """Return learning-timeline events, optionally filtered.

    Query params:
        epoch: Filter by specific epoch.
        sample_id: Filter by specific sample.
        Both are passed straight to ``timeline_logger.get_events``; ``None``
        (parameter omitted) presumably means "no filter" — confirm in the logger.

    Responds 400 when there is no trainer or timeline logging is disabled,
    and 500 with the exception message if fetching/serialising events fails.
    """
    if TRAINER is None or not TRAINER.enable_timeline:
        return JSONResponse({"error": "Timeline not available"}, status_code=400)

    try:
        events = TRAINER.timeline_logger.get_events(epoch=epoch, sample_id=sample_id)
        return JSONResponse(content=[event.to_dict() for event in events])
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
258
+
259
+
260
@app.get("/api/timeline/summaries")
async def get_timeline_summaries():
    """Return every per-epoch summary held by the timeline logger.

    Each summary is serialised via its ``to_dict()``. Responds 400 when there
    is no trainer or timeline logging is off, and 500 with the exception
    message if fetching or serialising the summaries fails.
    """
    timeline_enabled = TRAINER is not None and TRAINER.enable_timeline
    if not timeline_enabled:
        return JSONResponse({"error": "Timeline not available"}, status_code=400)

    try:
        epoch_summaries = TRAINER.timeline_logger.get_summaries()
        payload = [summary.to_dict() for summary in epoch_summaries]
        return JSONResponse(content=payload)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
271
+
272
+
273
@app.get("/api/timeline/sample/{sample_id}")
async def get_sample_timeline(sample_id: int):
    """Return the full timeline plus stability stats for one tracked sample.

    Responds 400 when there is no trainer or timeline logging is disabled,
    404 when the sample is not tracked, and 500 with the exception message
    on any other failure.
    """
    if TRAINER is None or not TRAINER.enable_timeline:
        return JSONResponse({"error": "Timeline not available"}, status_code=400)

    try:
        tracker = TRAINER.sample_tracker
        timeline = TRAINER.get_sample_timeline(sample_id)
        # An empty timeline alone does not mean "not tracked": a tracked sample
        # can have no recorded events yet. Only 404 when the tracker does not
        # know the sample either.
        if not timeline and sample_id not in tracker.tracked_indices:
            return JSONResponse({"error": "Sample not tracked"}, status_code=404)

        # Also include stability analysis (state may be None; defaults below).
        state = tracker.get_sample_state(sample_id)
        result = {
            "sample_id": sample_id,
            "true_label": str(state.true_label) if state else None,
            "flip_count": state.flip_count if state else 0,
            "stability_class": state.stability_class if state else "unknown",
            "first_correct_epoch": state.first_correct_epoch if state else None,
            # Normalise a missing timeline to an empty list for clients.
            "events": timeline or [],
        }
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
297
+
298
+
299
@app.get("/api/timeline/insights")
async def get_timeline_insights():
    """Aggregate per-sample learning behaviour for the timeline dashboard.

    Payload keys:
        tracked_samples / total_tracked: the tracker's tracked indices and their count.
        top_flipping: entries for ``tracker.get_top_flipping_samples(20)``.
        late_learners: entries for ``tracker.get_late_learners()``.
        never_correct: entries for ``tracker.get_never_correct()``.
        stability_distribution: samples per stability class.

    Responds 400 when there is no trainer or timeline logging is disabled,
    and 500 with the exception message if building the payload fails.
    """
    if TRAINER is None or not TRAINER.enable_timeline:
        return JSONResponse({"error": "Timeline not available"}, status_code=400)

    # Per-list serialisers; labels/predictions are stringified for JSON.
    def as_flipper(sample):
        return {
            "sample_id": sample.sample_id,
            "flip_count": sample.flip_count,
            "true_label": str(sample.true_label),
            "current_prediction": str(sample.current_prediction),
            "stability_class": sample.stability_class,
        }

    def as_late_learner(sample):
        return {
            "sample_id": sample.sample_id,
            "true_label": str(sample.true_label),
            "first_correct_epoch": sample.first_correct_epoch,
        }

    def as_never_correct(sample):
        return {
            "sample_id": sample.sample_id,
            "true_label": str(sample.true_label),
            "current_prediction": str(sample.current_prediction),
        }

    try:
        tracker = TRAINER.sample_tracker

        tracked_samples = list(tracker.tracked_indices)
        total_tracked = len(tracker.tracked_indices)
        top_flipping = list(map(as_flipper, tracker.get_top_flipping_samples(20)))
        late_learners = list(map(as_late_learner, tracker.get_late_learners()))
        never_correct = list(map(as_never_correct, tracker.get_never_correct()))
        distribution = _compute_stability_distribution(tracker)

        return JSONResponse(content={
            "tracked_samples": tracked_samples,
            "total_tracked": total_tracked,
            "top_flipping": top_flipping,
            "late_learners": late_learners,
            "never_correct": never_correct,
            "stability_distribution": distribution,
        })
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
342
+
343
+
344
def _compute_stability_distribution(tracker) -> Dict[str, int]:
    """Count the tracker's samples per stability class.

    All five known classes are always present in the result (0 when unseen),
    in a fixed order. Any other ``stability_class`` value is ignored.
    """
    buckets: Dict[str, int] = dict.fromkeys(
        ("stable_correct", "stable_wrong", "unstable", "late_learner", "unknown"),
        0,
    )
    for sample_state in tracker.sample_states.values():
        bucket = sample_state.stability_class
        if bucket not in buckets:
            continue
        buckets[bucket] += 1
    return buckets
358
+
359
+
224
360
  def start_server(run_dir: str, port: int = 8000):
225
361
  global RUN_DIR
226
362
  RUN_DIR = Path(run_dir).resolve()
@@ -0,0 +1,419 @@
1
/* timeline.css - Styles for Learning Timeline v2.0 */

/* Timeline Overview */

/* Scroll wrapper: content wider than the card scrolls horizontally. */
.timeline-container {
    min-height: 200px;
    padding: 20px 0;
    overflow-x: auto;
}

/* Centered, italic placeholder text (loading state). */
.timeline-loading {
    display: flex;
    align-items: center;
    justify-content: center;
    height: 150px;
    color: var(--text-secondary);
    font-style: italic;
}

/* Vertical stack of timeline rows. */
.timeline-grid {
    display: flex;
    flex-direction: column;
    gap: 4px;
    min-width: 100%;
}

.timeline-row {
    display: flex;
    align-items: center;
    gap: 8px;
    height: 24px;
}

/* Row label column. Its width (60px) plus the .timeline-row gap (8px) equals
   the 68px left padding of .epoch-axis, which keeps the axis labels over the
   cells. Change them together. */
.timeline-sample-id {
    width: 60px;
    font-size: 0.75rem;
    color: var(--text-secondary);
    text-align: right;
    flex-shrink: 0;
    cursor: pointer;
    transition: color 0.2s;
}

.timeline-sample-id:hover {
    color: var(--accent);
}

/* Cell strip; its 2px gap matches the .epoch-axis gap. */
.timeline-epochs {
    display: flex;
    gap: 2px;
    flex-grow: 1;
}

/* One square per epoch, 20px to match .epoch-label. */
.epoch-cell {
    width: 20px;
    height: 20px;
    /* border-box keeps the 2px .flip border inside the 20px square. Under the
       default content-box, flip cells grew to 24px and drifted off the axis. */
    box-sizing: border-box;
    /* Empty flex items may otherwise shrink below 20px when the strip is
       squeezed. */
    flex-shrink: 0;
    border-radius: 3px;
    cursor: pointer;
    transition: transform 0.1s, box-shadow 0.1s;
    position: relative; /* needed for z-index on :hover */
}

.epoch-cell:hover {
    transform: scale(1.3);
    box-shadow: 0 0 8px rgba(255,255,255,0.3);
    z-index: 10;
}

.epoch-cell.correct {
    background: var(--success);
}

.epoch-cell.wrong {
    background: var(--error);
}

/* Gold outline marks a cell flagged as a flip. */
.epoch-cell.flip {
    border: 2px solid #ffd700;
}

/* Stability Colors */
.stable-correct { background: #238636; }
.stable-wrong { background: #da3633; }
.unstable { background: #d29922; }
.late-learner { background: #58a6ff; }
.unknown { background: #484f58; }
86
+
87
/* ---- Legend: row of swatch + caption pairs under the timeline ---- */
.timeline-legend {
    margin-top: 15px;
    padding: 15px 0 0 0;
    border-top: 1px solid var(--border);
    display: flex;
    gap: 20px;
}

.legend-item {
    color: var(--text-secondary);
    font-size: 0.8rem;
    display: flex;
    align-items: center;
    gap: 6px;
}

/* Small colour swatch; colour supplied by a stability/state class. */
.legend-color { width: 12px; height: 12px; border-radius: 2px; }

/* ---- Card header tweaks ---- */
.card-header { margin-bottom: 15px; }

.card-header h3 { font-size: 1.1rem; margin: 0; }

/* Secondary caption rendered on its own line under the card title. */
.card-subtitle {
    display: block;
    margin-top: 4px;
    color: var(--text-secondary);
    font-size: 0.8rem;
}
126
+
127
/* ---- Instability panel: vertical stack of sample sections ---- */
#instability-panel {
    display: flex;
    flex-direction: column;
    gap: 20px;
}

/* Section heading with a divider underneath. */
.instability-section h4 {
    margin: 0 0 10px;
    padding-bottom: 5px;
    border-bottom: 1px solid var(--border);
    color: var(--text-secondary);
    font-size: 0.9rem;
}

/* Scrollable list, capped at 150px tall. */
.sample-list {
    max-height: 150px;
    overflow-y: auto;
    display: flex;
    flex-direction: column;
    gap: 6px;
}

/* Clickable row: id on one side, metadata/badge on the other. */
.sample-item {
    padding: 8px 12px;
    border-radius: 4px;
    background: var(--bg-tertiary);
    display: flex;
    align-items: center;
    justify-content: space-between;
    cursor: pointer;
    transition: background 0.2s;
}

.sample-item:hover { background: var(--bg-hover); }

.sample-item .sample-id { font-weight: 600; font-family: monospace; }

.sample-item .sample-meta { color: var(--text-secondary); font-size: 0.8rem; }

/* Pill-shaped flip counter. */
.sample-item .flip-badge {
    padding: 2px 6px;
    border-radius: 10px;
    background: var(--warning);
    color: black;
    font-size: 0.7rem;
    font-weight: 600;
}

.empty-state {
    padding: 10px;
    color: var(--text-secondary);
    font-size: 0.85rem;
    font-style: italic;
}

/* ---- Context strip: label/value rows ---- */
.context-item {
    padding: 4px 0;
    display: flex;
    justify-content: space-between;
}

.context-label { font-size: 0.8rem; color: var(--text-secondary); }

.context-value { font-size: 0.85rem; font-weight: 600; }
207
+
208
/* Navigation Buttons */
.btn-nav {
    padding: 10px 12px;
    border: 1px solid var(--border);
    border-radius: 6px;
    background: var(--bg-card);
    color: var(--text-primary);
    font-size: 0.85rem;
    text-align: left;
    cursor: pointer;
    /* Animate only what the hover/focus/active states change, not `all`. */
    transition: background 0.2s, border-color 0.2s, color 0.2s;
}

/* Keyboard focus gets the same highlight as mouse hover. */
.btn-nav:hover,
.btn-nav:focus-visible {
    background: var(--bg-hover);
    border-color: var(--accent);
}

/* Currently selected view. Declared after :hover/:focus-visible, which have
   the same specificity, so .active wins. */
.btn-nav.active {
    background: rgba(88, 166, 255, 0.1);
    border-color: var(--accent);
    color: var(--accent);
}
231
+
232
/* ---- Modal: full-viewport dimmed overlay with a centred panel ---- */
.modal {
    z-index: 1000;
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    display: flex;
    justify-content: center;
    align-items: center;
    background: rgba(0, 0, 0, 0.8);
}

/* Panel: 90% wide up to 900px, scrolls vertically past 85vh. */
.modal-content {
    width: 90%;
    max-width: 900px;
    max-height: 85vh;
    overflow-y: auto;
    border: 1px solid var(--border);
    border-radius: 12px;
    background: var(--bg-card);
}

/* Title on one side, close button on the other. */
.modal-header {
    padding: 20px;
    border-bottom: 1px solid var(--border);
    display: flex;
    align-items: center;
    justify-content: space-between;
}

.modal-header h3 { margin: 0; }

/* Borderless close control that turns red on hover. */
.modal-close {
    padding: 5px 10px;
    border: none;
    background: none;
    color: var(--text-secondary);
    font-size: 1.5rem;
    cursor: pointer;
    transition: color 0.2s;
}

.modal-close:hover { color: var(--error); }

.modal-body { padding: 20px; }
285
+
286
/* ---- Inspector header: four equal stat tiles ---- */
.inspector-header {
    margin-bottom: 25px;
    display: grid;
    grid-template-columns: repeat(4, 1fr);
    gap: 15px;
}

.inspector-stat {
    padding: 15px;
    border-radius: 8px;
    background: var(--bg-tertiary);
    text-align: center;
}

.stat-label {
    display: block;
    margin-bottom: 5px;
    color: var(--text-secondary);
    font-size: 0.75rem;
}

.stat-value {
    display: block;
    color: var(--accent);
    font-size: 1.2rem;
    font-weight: 700;
}

/* ---- Inspector charts: two equal columns ---- */
.inspector-charts {
    margin-bottom: 25px;
    display: grid;
    grid-template-columns: repeat(2, 1fr);
    gap: 20px;
}

.inspector-charts .chart-container {
    padding: 15px;
    border-radius: 8px;
    background: var(--bg-tertiary);
}

.inspector-charts h4 {
    margin: 0 0 10px;
    color: var(--text-secondary);
    font-size: 0.9rem;
}

/* ---- Event history ---- */
.inspector-events h4 {
    margin: 0 0 10px;
    color: var(--text-secondary);
    font-size: 0.9rem;
}

/* Scrollable event list, capped at 200px tall. */
.event-list {
    max-height: 200px;
    overflow-y: auto;
    display: flex;
    flex-direction: column;
    gap: 8px;
}

/* One row per event: fixed 60/80/80/80px columns, then a flexible last column. */
.event-item {
    padding: 10px;
    border-radius: 4px;
    background: var(--bg-tertiary);
    display: grid;
    grid-template-columns: 60px repeat(3, 80px) 1fr;
    align-items: center;
    gap: 10px;
    font-size: 0.8rem;
}

/* Left accent bar coloured by outcome. */
.event-item.correct { border-left: 3px solid var(--success); }

.event-item.wrong { border-left: 3px solid var(--error); }

.event-item .epoch-badge { font-weight: 600; }

.event-item .flip-marker { color: gold; font-weight: 700; }
377
+
378
/* Epoch Axis */
/* 68px = .timeline-sample-id width (60px) + .timeline-row gap (8px), and the
   2px gap matches .timeline-epochs, so each label sits over its cell column. */
.epoch-axis {
    display: flex;
    gap: 2px;
    padding-left: 68px;
    margin-bottom: 5px;
}

/* 20px wide to match .epoch-cell. */
.epoch-label {
    width: 20px;
    /* Without this, labels shrink toward their text width once the epochs
       overflow the axis width, while the cells keep 20px, so the columns
       drift apart. */
    flex-shrink: 0;
    text-align: center;
    font-size: 0.65rem;
    color: var(--text-secondary);
}

/* Tooltip: absolutely positioned and ignores pointer events, so it never
   blocks hover on the cell underneath. */
.epoch-tooltip {
    position: absolute;
    background: var(--bg-card);
    border: 1px solid var(--border);
    border-radius: 4px;
    padding: 8px 12px;
    font-size: 0.8rem;
    z-index: 100;
    pointer-events: none;
    white-space: nowrap;
}

/* Responsive: narrow screens get fewer inspector columns and a wrapping legend. */
@media (max-width: 768px) {
    .inspector-header {
        grid-template-columns: repeat(2, 1fr);
    }

    .inspector-charts {
        grid-template-columns: 1fr;
    }

    .timeline-legend {
        flex-wrap: wrap;
    }
}