gradia 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gradia/__init__.py +38 -1
- gradia/cli/main.py +1 -1
- gradia/core/config.py +71 -13
- gradia/core/migration.py +324 -0
- gradia/events/__init__.py +17 -0
- gradia/events/logger.py +215 -0
- gradia/events/models.py +170 -0
- gradia/events/tracker.py +337 -0
- gradia/trainer/engine.py +175 -3
- gradia/viz/server.py +153 -17
- gradia/viz/static/css/timeline.css +419 -0
- gradia/viz/static/js/timeline.js +471 -0
- gradia/viz/templates/configure.html +1 -1
- gradia/viz/templates/index.html +11 -9
- gradia/viz/templates/timeline.html +195 -0
- gradia-2.0.0.dist-info/METADATA +394 -0
- gradia-2.0.0.dist-info/RECORD +30 -0
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/WHEEL +1 -1
- gradia-1.0.0.dist-info/METADATA +0 -143
- gradia-1.0.0.dist-info/RECORD +0 -22
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/entry_points.txt +0 -0
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {gradia-1.0.0.dist-info → gradia-2.0.0.dist-info}/top_level.txt +0 -0
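The headline addition in 2.0.0 is the new gradia/events package (logger, models, tracker), which backs the Learning Timeline feature threaded through the trainer and the viz server. Judging only from the imports visible in the diffs below, its surface can be pulled in as sketched here; the package-root re-export of SampleTracker and TimelineLogger is confirmed by engine.py's own import, while anything further about gradia/events/__init__.py is an assumption:

    # Import surface implied by the diffs below (v2.0.0).
    from gradia.events import SampleTracker, TimelineLogger  # re-exported, per engine.py
    from gradia.events.models import LearningEvent, SampleState, EpochSummary  # per tracker.py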
gradia/events/tracker.py
ADDED
@@ -0,0 +1,337 @@
+"""
+Sample Tracker for Learning Timeline (v2.0.0)
+
+Manages which samples to track and maintains their state across epochs.
+Implements deterministic, bounded sampling strategy.
+"""
+
+from typing import List, Dict, Any, Optional, Set
+import numpy as np
+from dataclasses import dataclass, field
+
+from .models import LearningEvent, SampleState, EpochSummary
+
+
+@dataclass
+class SampleTracker:
+    """
+    Tracks a bounded subset of samples throughout training.
+
+    Strategy:
+    - Auto-select hard samples (near decision boundary)
+    - Include user-selected samples if specified
+    - Cap at max_samples for performance
+    - Use deterministic seeding for reproducibility
+
+    Attributes:
+        max_samples: Maximum samples to track (default 100)
+        seed: Random seed for reproducibility
+        tracked_indices: Set of sample indices being tracked
+        sample_states: State history for each tracked sample
+    """
+    max_samples: int = 100
+    seed: int = 42
+    tracked_indices: Set[int] = field(default_factory=set)
+    sample_states: Dict[int, SampleState] = field(default_factory=dict)
+    user_selected: Set[int] = field(default_factory=set)
+    run_id: str = ""
+    _initialized: bool = False
+
+    def initialize(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        run_id: str,
+        user_indices: Optional[List[int]] = None,
+        model: Optional[Any] = None
+    ):
+        """
+        Initialize tracking with dataset and optional model predictions.
+
+        Args:
+            X: Feature matrix
+            y: Labels
+            run_id: Unique run identifier
+            user_indices: User-selected sample indices to always track
+            model: Optional model for boundary sample selection
+        """
+        self.run_id = run_id
+        self._rng = np.random.RandomState(self.seed)
+
+        n_samples = len(y)
+
+        # Start with user-selected samples
+        if user_indices:
+            self.user_selected = set(user_indices[:self.max_samples // 2])
+            self.tracked_indices = self.user_selected.copy()
+
+        remaining_slots = self.max_samples - len(self.tracked_indices)
+
+        if remaining_slots > 0:
+            # Try to select boundary/hard samples if model available
+            if model is not None and hasattr(model, 'predict_proba'):
+                boundary_indices = self._select_boundary_samples(X, y, model, remaining_slots)
+                self.tracked_indices.update(boundary_indices)
+
+        # Fill remaining with stratified random
+        remaining_slots = self.max_samples - len(self.tracked_indices)
+        if remaining_slots > 0:
+            available = set(range(n_samples)) - self.tracked_indices
+            random_indices = self._stratified_sample(
+                list(available), y, remaining_slots
+            )
+            self.tracked_indices.update(random_indices)
+
+        # Initialize sample states
+        for idx in self.tracked_indices:
+            self.sample_states[idx] = SampleState(
+                sample_id=idx,
+                true_label=y[idx] if hasattr(y, '__getitem__') else y.iloc[idx]
+            )
+
+        self._initialized = True
+
+    def _select_boundary_samples(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        model: Any,
+        n_samples: int
+    ) -> Set[int]:
+        """
+        Select samples near the decision boundary.
+
+        These are the most informative for understanding model learning.
+        """
+        try:
+            probas = model.predict_proba(X)
+
+            # Compute margin: difference between top 2 class probabilities
+            if probas.shape[1] >= 2:
+                sorted_probas = np.sort(probas, axis=1)
+                margins = sorted_probas[:, -1] - sorted_probas[:, -2]
+            else:
+                margins = np.abs(probas[:, 0] - 0.5)
+
+            # Lower margin = closer to boundary = more interesting
+            # Exclude already tracked
+            available_mask = np.ones(len(margins), dtype=bool)
+            for idx in self.tracked_indices:
+                available_mask[idx] = False
+
+            margins[~available_mask] = np.inf
+
+            # Select lowest margin samples
+            boundary_indices = np.argsort(margins)[:n_samples]
+            return set(boundary_indices.tolist())
+
+        except Exception:
+            # Fallback if predict_proba fails
+            return set()
+
+    def _stratified_sample(
+        self,
+        available: List[int],
+        y: np.ndarray,
+        n_samples: int
+    ) -> Set[int]:
+        """
+        Stratified random sampling to maintain class balance.
+        """
+        if not available:
+            return set()
+
+        # Group by class
+        class_indices: Dict[Any, List[int]] = {}
+        for idx in available:
+            label = y[idx] if hasattr(y, '__getitem__') else y.iloc[idx]
+            if label not in class_indices:
+                class_indices[label] = []
+            class_indices[label].append(idx)
+
+        # Sample proportionally from each class
+        selected = []
+        n_classes = len(class_indices)
+        per_class = max(1, n_samples // n_classes)
+
+        for label, indices in class_indices.items():
+            k = min(per_class, len(indices))
+            sampled = self._rng.choice(indices, size=k, replace=False)
+            selected.extend(sampled.tolist())
+
+        # Trim if over budget
+        if len(selected) > n_samples:
+            selected = self._rng.choice(selected, size=n_samples, replace=False).tolist()
+
+        return set(selected)
+
+    def record_predictions(
+        self,
+        epoch: int,
+        X: np.ndarray,
+        y: np.ndarray,
+        predictions: np.ndarray,
+        probabilities: Optional[np.ndarray] = None
+    ) -> List[LearningEvent]:
+        """
+        Record predictions for all tracked samples at this epoch.
+
+        Args:
+            epoch: Current epoch number
+            X: Full feature matrix
+            y: Full labels
+            predictions: Model predictions for all samples
+            probabilities: Optional probability matrix
+
+        Returns:
+            List of LearningEvents for this epoch
+        """
+        if not self._initialized:
+            raise RuntimeError("SampleTracker not initialized. Call initialize() first.")
+
+        events = []
+
+        for idx in self.tracked_indices:
+            true_label = y[idx] if hasattr(y, '__getitem__') else y.iloc[idx]
+            pred_label = predictions[idx]
+
+            # Compute confidence
+            if probabilities is not None:
+                proba_row = probabilities[idx]
+                confidence = float(np.max(proba_row))
+                proba_list = proba_row.tolist()
+
+                # Compute margin
+                sorted_p = np.sort(proba_row)
+                margin = float(sorted_p[-1] - sorted_p[-2]) if len(sorted_p) >= 2 else confidence
+            else:
+                confidence = 1.0  # No probability info
+                proba_list = None
+                margin = None
+
+            correct = (pred_label == true_label)
+
+            event = LearningEvent(
+                run_id=self.run_id,
+                epoch=epoch,
+                sample_id=idx,
+                true_label=true_label,
+                predicted_label=pred_label,
+                confidence=confidence,
+                correct=bool(correct),
+                margin=margin,
+                probabilities=proba_list
+            )
+
+            # Update sample state
+            self.sample_states[idx].add_event(event)
+            events.append(event)
+
+        return events
+
+    def get_epoch_summary(self, epoch: int) -> EpochSummary:
+        """
+        Generate aggregated summary for an epoch.
+        """
+        states = list(self.sample_states.values())
+
+        # Count by stability class
+        stability_counts = {
+            "stable_correct": 0,
+            "stable_wrong": 0,
+            "unstable": 0,
+            "late_learner": 0,
+            "unknown": 0
+        }
+
+        correct_count = 0
+        flip_count = 0
+
+        for state in states:
+            stability_counts[state.stability_class] += 1
+            if state.history and state.history[-1].correct:
+                correct_count += 1
+            flip_count += state.flip_count
+
+        return EpochSummary(
+            run_id=self.run_id,
+            epoch=epoch,
+            timestamp=__import__('time').time(),
+            total_tracked=len(states),
+            correct_count=correct_count,
+            flip_count=flip_count,
+            stable_correct=stability_counts["stable_correct"],
+            stable_wrong=stability_counts["stable_wrong"],
+            unstable=stability_counts["unstable"],
+            late_learners=stability_counts["late_learner"]
+        )
+
+    def get_top_flipping_samples(self, n: int = 10) -> List[SampleState]:
+        """Get samples with most prediction flips."""
+        sorted_states = sorted(
+            self.sample_states.values(),
+            key=lambda s: s.flip_count,
+            reverse=True
+        )
+        return sorted_states[:n]
+
+    def get_late_learners(self, threshold_epoch: int = 5) -> List[SampleState]:
+        """Get samples that became correct after threshold epoch."""
+        late = []
+        for state in self.sample_states.values():
+            first = state.first_correct_epoch
+            if first is not None and first >= threshold_epoch:
+                late.append(state)
+        return sorted(late, key=lambda s: s.first_correct_epoch or 999)
+
+    def get_never_correct(self) -> List[SampleState]:
+        """Get samples that were never correctly classified."""
+        return [
+            state for state in self.sample_states.values()
+            if state.first_correct_epoch is None and state.history
+        ]
+
+    def get_sample_state(self, sample_id: int) -> Optional[SampleState]:
+        """Get state for a specific sample."""
+        return self.sample_states.get(sample_id)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize tracker state for storage."""
+        return {
+            "max_samples": self.max_samples,
+            "seed": self.seed,
+            "run_id": self.run_id,
+            "tracked_indices": list(self.tracked_indices),
+            "user_selected": list(self.user_selected),
+            "sample_states": {
+                idx: {
+                    "sample_id": state.sample_id,
+                    "true_label": state.true_label,
+                    "history": [e.to_dict() for e in state.history]
+                }
+                for idx, state in self.sample_states.items()
+            }
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SampleTracker":
+        """Restore tracker from serialized state."""
+        tracker = cls(
+            max_samples=data["max_samples"],
+            seed=data["seed"],
+            run_id=data["run_id"]
+        )
+        tracker.tracked_indices = set(data["tracked_indices"])
+        tracker.user_selected = set(data.get("user_selected", []))
+
+        for idx_str, state_data in data["sample_states"].items():
+            idx = int(idx_str)
+            state = SampleState(
+                sample_id=state_data["sample_id"],
+                true_label=state_data["true_label"],
+                history=[LearningEvent.from_dict(e) for e in state_data["history"]]
+            )
+            tracker.sample_states[idx] = state
+
+        tracker._initialized = True
+        return tracker
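Used standalone, the tracker composes as in the following minimal sketch. It is not gradia's own entry point (the Trainer drives the tracker, as the engine.py diff below shows), and it assumes a scikit-learn-style classifier plus SampleState/EpochSummary models in gradia/events/models.py that expose the add_event, flip_count, stability_class, and first_correct_epoch members the tracker code relies on:

    # Hand-driven SampleTracker loop; LogisticRegression stands in for a
    # gradia-wrapped model. The model is static here, so no flips occur.
    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from gradia.events.tracker import SampleTracker

    rng = np.random.RandomState(0)
    X = rng.rand(500, 4)
    y = (X[:, 0] + X[:, 1] > 1.0).astype(int)
    model = LogisticRegression().fit(X, y)

    tracker = SampleTracker(max_samples=50, seed=42)
    # Boundary samples come from predict_proba margins; stratified random
    # sampling tops up the budget; indices 0 and 1 are always kept.
    tracker.initialize(X, y, run_id="run_demo", user_indices=[0, 1], model=model)

    for epoch in range(1, 6):
        events = tracker.record_predictions(
            epoch, X, y,
            predictions=model.predict(X),
            probabilities=model.predict_proba(X),
        )
        summary = tracker.get_epoch_summary(epoch)

    flippers = tracker.get_top_flipping_samples(5)
    restored = SampleTracker.from_dict(tracker.to_dict())  # round-trips via plain dicts

One design note worth flagging: to_dict() and the events serialize to plain dicts, so a run's timeline survives process restarts without pickling the model.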
gradia/trainer/engine.py
CHANGED
@@ -15,6 +15,7 @@ from ..models.base import GradiaModel
 from ..models.sklearn_wrappers import ModelFactory
 from ..core.scenario import Scenario
 from .callbacks import Callback, EventLogger
+from ..events import SampleTracker, TimelineLogger
 
 class Trainer:
     def __init__(self, scenario: Scenario, config: Dict[str, Any], run_dir: str):
@@ -28,6 +29,25 @@
             config['model'].get('params', {})
         )
         self.callbacks: List[Callback] = [EventLogger(run_dir)]
+
+        # v2.0: Learning Timeline support
+        self.enable_timeline = config.get('timeline', {}).get('enabled', True)
+        self.sample_tracker: SampleTracker = None
+        self.timeline_logger: TimelineLogger = None
+
+        if self.enable_timeline:
+            timeline_config = config.get('timeline', {})
+            self.sample_tracker = SampleTracker(
+                max_samples=timeline_config.get('max_samples', 100),
+                seed=config['training'].get('random_seed', 42)
+            )
+            self.timeline_logger = TimelineLogger(run_dir)
+
+        # Store data references for evaluation
+        self._X_test = None
+        self._y_test = None
+        self._X_train = None
+        self._y_train = None
 
     def run(self):
         print("DEBUG: Trainer.run() started.")
@@ -82,6 +102,25 @@
             random_state=self.config['training'].get('random_seed', 42)
         )
 
+        # Store for evaluation
+        self._X_train = X_train
+        self._y_train = y_train
+        self._X_test = X_test
+        self._y_test = y_test
+
+        # v2.0: Initialize sample tracker for Learning Timeline
+        run_id = f"run_{int(time.time())}"
+        if self.enable_timeline and self.sample_tracker:
+            user_samples = self.config.get('timeline', {}).get('user_samples', None)
+            self.sample_tracker.initialize(
+                X=X_test.values if hasattr(X_test, 'values') else X_test,
+                y=y_test.values if hasattr(y_test, 'values') else y_test,
+                run_id=run_id,
+                user_indices=user_samples
+            )
+            # Clear previous timeline logs for fresh run
+            self.timeline_logger.clear()
+
         # Notify Start
         epochs = self.config['training'].get('epochs', 10)
         self._dispatch('on_train_begin', {
@@ -103,7 +142,6 @@
         classes = np.unique(y) if self.scenario.task_type == 'classification' else None
 
         # TQDM Output to Console
-        import time
         with tqdm(range(1, epochs + 1), desc="Training", unit="epoch", colour="green") as pbar:
             for epoch in pbar:
                 # Small delay to visualize speed if too fast
@@ -113,6 +151,10 @@
 
                 # Evaluate
                 metrics = self._evaluate(X_train, y_train, X_test, y_test)
+
+                # v2.0: Record timeline events
+                self._record_timeline_epoch(epoch, X_test, y_test)
+
                 self._dispatch('on_epoch_end', epoch, metrics)
 
         else:
@@ -125,12 +167,15 @@
             metrics = self._evaluate(X_train, y_train, X_test, y_test)
 
             # Simulate progress bar so UI doesn't look broken
-            import time
             with tqdm(range(1, epochs + 1), desc="Training", unit="epoch", colour="blue") as pbar:
                 for epoch in pbar:
                     time.sleep(0.1)  # Simulate work
                     # We broadcast the SAME metrics for every "epoch" since the model doesn't change
                     # But it keeps the UI happy and consistent
+
+                    # v2.0: Record timeline events (same predictions each epoch for non-iterative)
+                    self._record_timeline_epoch(epoch, X_test, y_test)
+
                     self._dispatch('on_epoch_end', epoch, metrics)
 
                     # Update Progress Bar
@@ -145,11 +190,25 @@
 
         # 5. Finalize
         fi = self.model.get_feature_importance()
+
+        # v2.0: Finalize timeline logging
+        if self.enable_timeline and self.timeline_logger:
+            self.timeline_logger.save_tracker_state(self.sample_tracker.to_dict())
+            self.timeline_logger.finalize()
 
         # 6. Training Complete
+        timeline_summary = None
+        if self.enable_timeline and self.sample_tracker:
+            timeline_summary = {
+                "tracked_samples": len(self.sample_tracker.tracked_indices),
+                "top_flipping": [s.sample_id for s in self.sample_tracker.get_top_flipping_samples(5)],
+                "never_correct": len(self.sample_tracker.get_never_correct())
+            }
+
         self._dispatch("on_train_end", {
             "epoch": epochs,
-            "feature_importance": self.model.get_feature_importance()
+            "feature_importance": self.model.get_feature_importance(),
+            "timeline_summary": timeline_summary
         })
 
         # 7. Save Model
@@ -201,3 +260,116 @@
     def _dispatch(self, method_name, *args, **kwargs):
         for cb in self.callbacks:
             getattr(cb, method_name)(*args, **kwargs)
+
+    def _record_timeline_epoch(self, epoch: int, X_test, y_test):
+        """
+        Record sample-level predictions for Learning Timeline (v2.0).
+
+        Captures predictions for all tracked samples and logs events.
+        """
+        if not self.enable_timeline or not self.sample_tracker:
+            return
+
+        try:
+            # Get predictions for all test samples
+            X_arr = X_test.values if hasattr(X_test, 'values') else X_test
+            y_arr = y_test.values if hasattr(y_test, 'values') else y_test
+
+            predictions = self.model.predict(X_arr)
+
+            # Get probabilities if available
+            probabilities = None
+            if hasattr(self.model, 'predict_proba'):
+                try:
+                    probabilities = self.model.predict_proba(X_arr)
+                except Exception:
+                    pass
+
+            # Record predictions for tracked samples
+            events = self.sample_tracker.record_predictions(
+                epoch=epoch,
+                X=X_arr,
+                y=y_arr,
+                predictions=predictions,
+                probabilities=probabilities
+            )
+
+            # Log events
+            self.timeline_logger.log_events(events, flush=(epoch % 5 == 0))
+
+            # Log epoch summary
+            summary = self.sample_tracker.get_epoch_summary(epoch)
+            self.timeline_logger.log_summary(summary)
+
+        except Exception as e:
+            print(f"Warning: Timeline recording failed for epoch {epoch}: {e}")
+
+    def evaluate_full(self) -> Dict[str, Any]:
+        """
+        Run full evaluation on test set with confusion matrix.
+
+        Returns:
+            Dictionary with evaluation results including confusion matrix.
+        """
+        if self._X_test is None or self._y_test is None:
+            raise ValueError("No test data available. Run training first.")
+
+        results = {}
+
+        preds = self.model.predict(self._X_test)
+
+        if self.scenario.task_type == 'classification':
+            results['accuracy'] = accuracy_score(self._y_test, preds)
+            results['precision'] = precision_score(self._y_test, preds, average='weighted', zero_division=0)
+            results['recall'] = recall_score(self._y_test, preds, average='weighted', zero_division=0)
+            results['f1'] = f1_score(self._y_test, preds, average='weighted', zero_division=0)
+
+            # Confusion matrix
+            cm = confusion_matrix(self._y_test, preds)
+            results['confusion_matrix'] = cm.tolist()
+
+            # Class labels
+            classes = sorted(list(set(self._y_test.tolist() if hasattr(self._y_test, 'tolist') else self._y_test)))
+            results['classes'] = [str(c) for c in classes]
+        else:
+            results['mse'] = mean_squared_error(self._y_test, preds)
+            results['mae'] = mean_absolute_error(self._y_test, preds)
+            results['r2'] = r2_score(self._y_test, preds)
+
+        # v2.0: Include timeline insights if available
+        if self.enable_timeline and self.sample_tracker:
+            results['timeline_insights'] = {
+                'total_tracked': len(self.sample_tracker.tracked_indices),
+                'top_flipping': [
+                    {
+                        'sample_id': s.sample_id,
+                        'flip_count': s.flip_count,
+                        'true_label': str(s.true_label),
+                        'stability': s.stability_class
+                    }
+                    for s in self.sample_tracker.get_top_flipping_samples(10)
+                ],
+                'never_correct_count': len(self.sample_tracker.get_never_correct()),
+                'late_learners_count': len(self.sample_tracker.get_late_learners())
+            }
+
+        return results
+
+    def get_sample_timeline(self, sample_id: int) -> List[Dict[str, Any]]:
+        """
+        Get the learning timeline for a specific sample.
+
+        Args:
+            sample_id: Index of the sample to retrieve
+
+        Returns:
+            List of event dictionaries showing prediction evolution
+        """
+        if not self.enable_timeline or not self.sample_tracker:
+            return []
+
+        state = self.sample_tracker.get_sample_state(sample_id)
+        if not state:
+            return []
+
+        return [event.to_dict() for event in state.history]
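From the caller's side, the feature is configured through the new `timeline` block in the Trainer config dict. The sketch below enables it explicitly; the `timeline.enabled`, `timeline.max_samples`, `timeline.user_samples`, `training.random_seed`, and `training.epochs` keys are exactly the lookups visible in the diff above, while the Scenario value and the shape of the `model` entry beyond its `params` key are placeholders:

    # Hypothetical Trainer config exercising the v2.0 Learning Timeline.
    config = {
        "model": {"params": {}},                  # read via config['model'].get('params', {})
        "training": {"epochs": 10, "random_seed": 42},
        "timeline": {
            "enabled": True,                      # defaults to True when the block is absent
            "max_samples": 100,                   # tracked-sample budget
            "user_samples": [3, 17, 42],          # always tracked (first max_samples // 2 kept)
        },
    }

    # scenario: a gradia Scenario instance (construction not shown in this diff)
    trainer = Trainer(scenario=scenario, config=config, run_dir="runs/demo")
    trainer.run()
    timeline = trainer.get_sample_timeline(3)                # per-epoch events for one sample
    insights = trainer.evaluate_full()["timeline_insights"]  # flips, never-correct, late learners

Note the fail-soft design: _record_timeline_epoch wraps its body in a try/except that only prints a warning, and models without predict_proba fall back to confidence-1.0 events, so enabling the timeline cannot abort a training run.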