gradia 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gradia/__init__.py +38 -1
- gradia/cli/main.py +1 -1
- gradia/core/config.py +71 -13
- gradia/core/migration.py +324 -0
- gradia/events/__init__.py +17 -0
- gradia/events/logger.py +215 -0
- gradia/events/models.py +170 -0
- gradia/events/tracker.py +337 -0
- gradia/trainer/engine.py +175 -3
- gradia/viz/server.py +153 -17
- gradia/viz/static/css/timeline.css +419 -0
- gradia/viz/static/js/timeline.js +471 -0
- gradia/viz/templates/configure.html +1 -1
- gradia/viz/templates/index.html +11 -9
- gradia/viz/templates/timeline.html +195 -0
- gradia-2.0.0.dist-info/METADATA +394 -0
- gradia-2.0.0.dist-info/RECORD +30 -0
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/WHEEL +1 -1
- gradia-1.0.0.dist-info/METADATA +0 -143
- gradia-1.0.0.dist-info/RECORD +0 -22
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/entry_points.txt +0 -0
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/top_level.txt +0 -0
gradia/events/logger.py
ADDED
@@ -0,0 +1,215 @@
+"""
+Timeline Logger for Learning Timeline (v2.0.0)
+
+Handles persistent storage of learning events, compatible with existing
+EventLogger infrastructure while extending for timeline data.
+"""
+
+from typing import List, Dict, Any, Optional
+from pathlib import Path
+import json
+import time
+import threading
+import os
+
+from .models import LearningEvent, EpochSummary
+
+
+class TimelineLogger:
+    """
+    Logs learning events to structured files for timeline visualization.
+
+    Storage format:
+    - timeline_events.jsonl: Raw LearningEvents (append-only)
+    - timeline_summary.jsonl: EpochSummaries
+    - timeline_state.json: Tracker state for resumption
+
+    Thread-safe via shared lock with existing EventLogger.
+    """
+
+    def __init__(self, log_dir: str, lock: Optional[threading.Lock] = None):
+        """
+        Initialize timeline logger.
+
+        Args:
+            log_dir: Directory for log files
+            lock: Optional shared lock (uses global if not provided)
+        """
+        self.log_dir = Path(log_dir)
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+
+        # Use shared lock from callbacks module for thread safety
+        if lock is None:
+            from ..trainer.callbacks import log_lock
+            self._lock = log_lock
+        else:
+            self._lock = lock
+
+        # File paths
+        self.events_path = self.log_dir / "timeline_events.jsonl"
+        self.summary_path = self.log_dir / "timeline_summary.jsonl"
+        self.state_path = self.log_dir / "timeline_state.json"
+
+        # In-memory buffer for batch writes
+        self._event_buffer: List[LearningEvent] = []
+        self._buffer_size = 50  # Flush every N events
+
+    def log_events(self, events: List[LearningEvent], flush: bool = False):
+        """
+        Log a batch of learning events.
+
+        Args:
+            events: List of LearningEvents to log
+            flush: Force immediate write to disk
+        """
+        self._event_buffer.extend(events)
+
+        if flush or len(self._event_buffer) >= self._buffer_size:
+            self._flush_events()
+
+    def log_event(self, event: LearningEvent):
+        """Log a single learning event."""
+        self.log_events([event])
+
+    def log_summary(self, summary: EpochSummary):
+        """Log an epoch summary."""
+        with self._lock:
+            with open(self.summary_path, "a") as f:
+                f.write(json.dumps(summary.to_dict()) + "\n")
+                f.flush()
+                os.fsync(f.fileno())
+
+    def save_tracker_state(self, tracker_data: Dict[str, Any]):
+        """
+        Save tracker state for run resumption/replay.
+
+        Args:
+            tracker_data: Serialized SampleTracker state
+        """
+        with self._lock:
+            with open(self.state_path, "w") as f:
+                json.dump(tracker_data, f, indent=2, default=str)
+                f.flush()
+                os.fsync(f.fileno())
+
+    def load_tracker_state(self) -> Optional[Dict[str, Any]]:
+        """Load saved tracker state if exists."""
+        if not self.state_path.exists():
+            return None
+
+        with self._lock:
+            with open(self.state_path, "r") as f:
+                return json.load(f)
+
+    def get_events(
+        self,
+        epoch: Optional[int] = None,
+        sample_id: Optional[int] = None
+    ) -> List[LearningEvent]:
+        """
+        Read events from log file with optional filtering.
+
+        Args:
+            epoch: Filter by epoch number
+            sample_id: Filter by sample ID
+        """
+        events = []
+
+        if not self.events_path.exists():
+            return events
+
+        with self._lock:
+            with open(self.events_path, "r") as f:
+                for line in f:
+                    if not line.strip():
+                        continue
+                    try:
+                        data = json.loads(line)
+                        event = LearningEvent.from_dict(data)
+
+                        # Apply filters
+                        if epoch is not None and event.epoch != epoch:
+                            continue
+                        if sample_id is not None and event.sample_id != sample_id:
+                            continue
+
+                        events.append(event)
+                    except (json.JSONDecodeError, KeyError):
+                        continue
+
+        return events
+
+    def get_summaries(self) -> List[EpochSummary]:
+        """Read all epoch summaries."""
+        summaries = []
+
+        if not self.summary_path.exists():
+            return summaries
+
+        with self._lock:
+            with open(self.summary_path, "r") as f:
+                for line in f:
+                    if not line.strip():
+                        continue
+                    try:
+                        data = json.loads(line)
+                        summary = EpochSummary(**data)
+                        summaries.append(summary)
+                    except (json.JSONDecodeError, KeyError, TypeError):
+                        continue
+
+        return summaries
+
+    def get_sample_timeline(self, sample_id: int) -> List[LearningEvent]:
+        """Get full timeline for a specific sample."""
+        return self.get_events(sample_id=sample_id)
+
+    def get_latest_epoch(self) -> int:
+        """Get the most recent epoch number logged."""
+        summaries = self.get_summaries()
+        if not summaries:
+            return 0
+        return max(s.epoch for s in summaries)
+
+    def clear(self):
+        """Clear all timeline logs (for new run)."""
+        self._flush_events()  # Flush buffer first
+
+        with self._lock:
+            for path in [self.events_path, self.summary_path, self.state_path]:
+                if path.exists():
+                    path.unlink()
+
+    def _flush_events(self):
+        """Write buffered events to disk."""
+        if not self._event_buffer:
+            return
+
+        with self._lock:
+            with open(self.events_path, "a") as f:
+                for event in self._event_buffer:
+                    f.write(json.dumps(event.to_dict(), default=str) + "\n")
+                f.flush()
+                os.fsync(f.fileno())
+
+        self._event_buffer.clear()
+
+    def finalize(self):
+        """Ensure all buffered data is written."""
+        self._flush_events()
+
+    def __del__(self):
+        """Flush on deletion."""
+        try:
+            self._flush_events()
+        except Exception:
+            pass
+
+
+def create_timeline_logger(run_dir: str) -> TimelineLogger:
+    """
+    Factory function to create a TimelineLogger.
+
+    Convenience function that handles path resolution.
+    """
+    return TimelineLogger(run_dir)
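Illustrative usage sketch (not part of the published diff): one way the TimelineLogger API added above might be exercised on its own, assuming the module paths shown in this release. The run directory, run id, and sample values are placeholders; passing an explicit threading.Lock avoids the fallback import of log_lock from gradia.trainer.callbacks in __init__.

import threading

from gradia.events.logger import TimelineLogger
from gradia.events.models import LearningEvent

# Explicit lock so __init__ does not reach into gradia.trainer.callbacks
# for the shared log_lock.
logger = TimelineLogger("runs/demo", lock=threading.Lock())

event = LearningEvent(
    run_id="demo-run",
    epoch=1,
    sample_id=42,
    true_label=1,
    predicted_label=0,
    confidence=0.61,
    correct=False,
)
logger.log_event(event)   # buffered; written after 50 events or on flush=True
logger.finalize()         # force the buffer out to timeline_events.jsonl

# Read back the per-sample history written so far.
print(logger.get_sample_timeline(sample_id=42))

Each flush appends JSONL and calls fsync, so a completed flush is durable at the cost of one sync per batch.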
gradia/events/models.py
ADDED
@@ -0,0 +1,170 @@
+"""
+Event Models for Learning Timeline (v2.0.0)
+
+Defines the LearningEvent contract that all timeline visuals consume.
+"""
+
+from dataclasses import dataclass, field, asdict
+from typing import Optional, Any, Dict, List
+from enum import Enum
+import time
+
+
+class EventType(str, Enum):
+    """Types of events in the learning timeline."""
+    SAMPLE_PREDICTION = "sample_prediction"
+    EPOCH_SUMMARY = "epoch_summary"
+    FLIP_DETECTED = "flip_detected"
+    STABILITY_CHANGE = "stability_change"
+
+
+@dataclass
+class LearningEvent:
+    """
+    Core event model for sample-level prediction tracking.
+
+    This is the internal contract between training logic and visualization.
+    All timeline visuals consume LearningEvents, not training internals.
+
+    Attributes:
+        run_id: Unique identifier for the training run
+        epoch: Current epoch number (1-indexed)
+        sample_id: Index or identifier of the tracked sample
+        true_label: Ground truth label for the sample
+        predicted_label: Model's prediction for this sample at this epoch
+        confidence: Prediction confidence/probability (0.0 to 1.0)
+        correct: Whether prediction matches true label
+        timestamp: Unix timestamp when event was recorded
+        margin: Optional decision margin (distance from decision boundary)
+        probabilities: Optional full probability distribution across classes
+        metadata: Optional additional context
+    """
+    run_id: str
+    epoch: int
+    sample_id: int
+    true_label: Any
+    predicted_label: Any
+    confidence: float
+    correct: bool
+    timestamp: float = field(default_factory=time.time)
+    margin: Optional[float] = None
+    probabilities: Optional[List[float]] = None
+    metadata: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for JSON serialization."""
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "LearningEvent":
+        """Reconstruct from dictionary."""
+        return cls(**data)
+
+
+@dataclass
+class SampleState:
+    """
+    Tracks the learning state of a single sample across epochs.
+
+    Used to compute stability metrics and detect flips.
+    """
+    sample_id: int
+    true_label: Any
+    history: List[LearningEvent] = field(default_factory=list)
+
+    @property
+    def flip_count(self) -> int:
+        """Count how many times the prediction changed."""
+        if len(self.history) < 2:
+            return 0
+        flips = 0
+        for i in range(1, len(self.history)):
+            if self.history[i].predicted_label != self.history[i-1].predicted_label:
+                flips += 1
+        return flips
+
+    @property
+    def is_stable(self) -> bool:
+        """Sample is stable if no flips in last 3 epochs."""
+        if len(self.history) < 3:
+            return False
+        last_3 = self.history[-3:]
+        return all(e.predicted_label == last_3[0].predicted_label for e in last_3)
+
+    @property
+    def stability_class(self) -> str:
+        """
+        Classify sample stability for visualization.
+
+        Returns:
+            'stable_correct': Consistently correct
+            'stable_wrong': Consistently wrong
+            'unstable': Predictions keep changing
+            'late_learner': Recently became correct
+        """
+        if not self.history:
+            return "unknown"
+
+        recent = self.history[-3:] if len(self.history) >= 3 else self.history
+        all_correct = all(e.correct for e in recent)
+        all_wrong = all(not e.correct for e in recent)
+
+        if all_correct and self.is_stable:
+            return "stable_correct"
+        elif all_wrong and self.is_stable:
+            return "stable_wrong"
+        elif self.flip_count > 2:
+            return "unstable"
+        elif len(self.history) >= 3 and self.history[-1].correct and not self.history[-3].correct:
+            return "late_learner"
+        else:
+            return "unstable"
+
+    @property
+    def first_correct_epoch(self) -> Optional[int]:
+        """Return epoch when sample was first correctly classified."""
+        for event in self.history:
+            if event.correct:
+                return event.epoch
+        return None
+
+    @property
+    def current_prediction(self) -> Optional[Any]:
+        """Most recent prediction."""
+        return self.history[-1].predicted_label if self.history else None
+
+    @property
+    def current_confidence(self) -> Optional[float]:
+        """Most recent confidence score."""
+        return self.history[-1].confidence if self.history else None
+
+    def add_event(self, event: LearningEvent):
+        """Record a new prediction event."""
+        self.history.append(event)
+
+
+@dataclass
+class EpochSummary:
+    """
+    Aggregated summary of sample-level events for one epoch.
+
+    Used for the Timeline Overview block.
+    """
+    run_id: str
+    epoch: int
+    timestamp: float
+    total_tracked: int
+    correct_count: int
+    flip_count: int
+    stable_correct: int
+    stable_wrong: int
+    unstable: int
+    late_learners: int
+
+    @property
+    def accuracy(self) -> float:
+        """Tracked sample accuracy."""
+        return self.correct_count / self.total_tracked if self.total_tracked > 0 else 0.0
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
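Illustrative sketch (not part of the published diff): how SampleState derives the flip and stability metrics consumed by the timeline from a sequence of LearningEvents. The run id, labels, and confidence value below are invented for the example.

from gradia.events.models import LearningEvent, SampleState

# A sample mispredicted for two epochs, then correct for three.
state = SampleState(sample_id=7, true_label="cat")
for epoch, pred in enumerate(["dog", "dog", "cat", "cat", "cat"], start=1):
    state.add_event(LearningEvent(
        run_id="demo-run",
        epoch=epoch,
        sample_id=7,
        true_label="cat",
        predicted_label=pred,
        confidence=0.8,
        correct=(pred == "cat"),
    ))

print(state.flip_count)           # 1: the dog -> cat change at epoch 3
print(state.is_stable)            # True: same prediction over the last 3 epochs
print(state.stability_class)      # "stable_correct"
print(state.first_correct_epoch)  # 3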