gradia-1.0.0-py3-none-any.whl → gradia-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,337 @@
+ """
+ Sample Tracker for Learning Timeline (v2.0.0)
+
+ Manages which samples to track and maintains their state across epochs.
+ Implements deterministic, bounded sampling strategy.
+ """
+
+ from typing import List, Dict, Any, Optional, Set
+ import numpy as np
+ from dataclasses import dataclass, field
+
+ from .models import LearningEvent, SampleState, EpochSummary
+
+
+ @dataclass
+ class SampleTracker:
+     """
+     Tracks a bounded subset of samples throughout training.
+
+     Strategy:
+     - Auto-select hard samples (near decision boundary)
+     - Include user-selected samples if specified
+     - Cap at max_samples for performance
+     - Use deterministic seeding for reproducibility
+
+     Attributes:
+         max_samples: Maximum samples to track (default 100)
+         seed: Random seed for reproducibility
+         tracked_indices: Set of sample indices being tracked
+         sample_states: State history for each tracked sample
+     """
+     max_samples: int = 100
+     seed: int = 42
+     tracked_indices: Set[int] = field(default_factory=set)
+     sample_states: Dict[int, SampleState] = field(default_factory=dict)
+     user_selected: Set[int] = field(default_factory=set)
+     run_id: str = ""
+     _initialized: bool = False
+
+     def initialize(
+         self,
+         X: np.ndarray,
+         y: np.ndarray,
+         run_id: str,
+         user_indices: Optional[List[int]] = None,
+         model: Optional[Any] = None
+     ):
+         """
+         Initialize tracking with dataset and optional model predictions.
+
+         Args:
+             X: Feature matrix
+             y: Labels
+             run_id: Unique run identifier
+             user_indices: User-selected sample indices to always track
+             model: Optional model for boundary sample selection
+         """
+         self.run_id = run_id
+         self._rng = np.random.RandomState(self.seed)
+
+         n_samples = len(y)
+
+         # Start with user-selected samples
+         if user_indices:
+             self.user_selected = set(user_indices[:self.max_samples // 2])
+             self.tracked_indices = self.user_selected.copy()
+
+         remaining_slots = self.max_samples - len(self.tracked_indices)
+
+         if remaining_slots > 0:
+             # Try to select boundary/hard samples if model available
+             if model is not None and hasattr(model, 'predict_proba'):
+                 boundary_indices = self._select_boundary_samples(X, y, model, remaining_slots)
+                 self.tracked_indices.update(boundary_indices)
+
+             # Fill remaining with stratified random
+             remaining_slots = self.max_samples - len(self.tracked_indices)
+             if remaining_slots > 0:
+                 available = set(range(n_samples)) - self.tracked_indices
+                 random_indices = self._stratified_sample(
+                     list(available), y, remaining_slots
+                 )
+                 self.tracked_indices.update(random_indices)
+
+         # Initialize sample states
+         for idx in self.tracked_indices:
+             self.sample_states[idx] = SampleState(
+                 sample_id=idx,
+                 true_label=y[idx] if hasattr(y, '__getitem__') else y.iloc[idx]
+             )
+
+         self._initialized = True
+
+     def _select_boundary_samples(
+         self,
+         X: np.ndarray,
+         y: np.ndarray,
+         model: Any,
+         n_samples: int
+     ) -> Set[int]:
+         """
+         Select samples near the decision boundary.
+
+         These are the most informative for understanding model learning.
+         """
+         try:
+             probas = model.predict_proba(X)
+
+             # Compute margin: difference between top 2 class probabilities
+             if probas.shape[1] >= 2:
+                 sorted_probas = np.sort(probas, axis=1)
+                 margins = sorted_probas[:, -1] - sorted_probas[:, -2]
+             else:
+                 margins = np.abs(probas[:, 0] - 0.5)
+
+             # Lower margin = closer to boundary = more interesting
+             # Exclude already tracked
+             available_mask = np.ones(len(margins), dtype=bool)
+             for idx in self.tracked_indices:
+                 available_mask[idx] = False
+
+             margins[~available_mask] = np.inf
+
+             # Select lowest margin samples
+             boundary_indices = np.argsort(margins)[:n_samples]
+             return set(boundary_indices.tolist())
+
+         except Exception:
+             # Fallback if predict_proba fails
+             return set()
+
+     def _stratified_sample(
+         self,
+         available: List[int],
+         y: np.ndarray,
+         n_samples: int
+     ) -> Set[int]:
+         """
+         Stratified random sampling to maintain class balance.
+         """
+         if not available:
+             return set()
+
+         # Group by class
+         class_indices: Dict[Any, List[int]] = {}
+         for idx in available:
+             label = y[idx] if hasattr(y, '__getitem__') else y.iloc[idx]
+             if label not in class_indices:
+                 class_indices[label] = []
+             class_indices[label].append(idx)
+
+         # Sample proportionally from each class
+         selected = []
+         n_classes = len(class_indices)
+         per_class = max(1, n_samples // n_classes)
+
+         for label, indices in class_indices.items():
+             k = min(per_class, len(indices))
+             sampled = self._rng.choice(indices, size=k, replace=False)
+             selected.extend(sampled.tolist())
+
+         # Trim if over budget
+         if len(selected) > n_samples:
+             selected = self._rng.choice(selected, size=n_samples, replace=False).tolist()
+
+         return set(selected)
+
+     def record_predictions(
+         self,
+         epoch: int,
+         X: np.ndarray,
+         y: np.ndarray,
+         predictions: np.ndarray,
+         probabilities: Optional[np.ndarray] = None
+     ) -> List[LearningEvent]:
+         """
+         Record predictions for all tracked samples at this epoch.
+
+         Args:
+             epoch: Current epoch number
+             X: Full feature matrix
+             y: Full labels
+             predictions: Model predictions for all samples
+             probabilities: Optional probability matrix
+
+         Returns:
+             List of LearningEvents for this epoch
+         """
+         if not self._initialized:
+             raise RuntimeError("SampleTracker not initialized. Call initialize() first.")
+
+         events = []
+
+         for idx in self.tracked_indices:
+             true_label = y[idx] if hasattr(y, '__getitem__') else y.iloc[idx]
+             pred_label = predictions[idx]
+
+             # Compute confidence
+             if probabilities is not None:
+                 proba_row = probabilities[idx]
+                 confidence = float(np.max(proba_row))
+                 proba_list = proba_row.tolist()
+
+                 # Compute margin
+                 sorted_p = np.sort(proba_row)
+                 margin = float(sorted_p[-1] - sorted_p[-2]) if len(sorted_p) >= 2 else confidence
+             else:
+                 confidence = 1.0  # No probability info
+                 proba_list = None
+                 margin = None
+
+             correct = (pred_label == true_label)
+
+             event = LearningEvent(
+                 run_id=self.run_id,
+                 epoch=epoch,
+                 sample_id=idx,
+                 true_label=true_label,
+                 predicted_label=pred_label,
+                 confidence=confidence,
+                 correct=bool(correct),
+                 margin=margin,
+                 probabilities=proba_list
+             )
+
+             # Update sample state
+             self.sample_states[idx].add_event(event)
+             events.append(event)
+
+         return events
+
+     def get_epoch_summary(self, epoch: int) -> EpochSummary:
+         """
+         Generate aggregated summary for an epoch.
+         """
+         states = list(self.sample_states.values())
+
+         # Count by stability class
+         stability_counts = {
+             "stable_correct": 0,
+             "stable_wrong": 0,
+             "unstable": 0,
+             "late_learner": 0,
+             "unknown": 0
+         }
+
+         correct_count = 0
+         flip_count = 0
+
+         for state in states:
+             stability_counts[state.stability_class] += 1
+             if state.history and state.history[-1].correct:
+                 correct_count += 1
+             flip_count += state.flip_count
+
+         return EpochSummary(
+             run_id=self.run_id,
+             epoch=epoch,
+             timestamp=__import__('time').time(),
+             total_tracked=len(states),
+             correct_count=correct_count,
+             flip_count=flip_count,
+             stable_correct=stability_counts["stable_correct"],
+             stable_wrong=stability_counts["stable_wrong"],
+             unstable=stability_counts["unstable"],
+             late_learners=stability_counts["late_learner"]
+         )
+
+     def get_top_flipping_samples(self, n: int = 10) -> List[SampleState]:
+         """Get samples with most prediction flips."""
+         sorted_states = sorted(
+             self.sample_states.values(),
+             key=lambda s: s.flip_count,
+             reverse=True
+         )
+         return sorted_states[:n]
+
+     def get_late_learners(self, threshold_epoch: int = 5) -> List[SampleState]:
+         """Get samples that became correct after threshold epoch."""
+         late = []
+         for state in self.sample_states.values():
+             first = state.first_correct_epoch
+             if first is not None and first >= threshold_epoch:
+                 late.append(state)
+         return sorted(late, key=lambda s: s.first_correct_epoch or 999)
+
+     def get_never_correct(self) -> List[SampleState]:
+         """Get samples that were never correctly classified."""
+         return [
+             state for state in self.sample_states.values()
+             if state.first_correct_epoch is None and state.history
+         ]
+
+     def get_sample_state(self, sample_id: int) -> Optional[SampleState]:
+         """Get state for a specific sample."""
+         return self.sample_states.get(sample_id)
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Serialize tracker state for storage."""
+         return {
+             "max_samples": self.max_samples,
+             "seed": self.seed,
+             "run_id": self.run_id,
+             "tracked_indices": list(self.tracked_indices),
+             "user_selected": list(self.user_selected),
+             "sample_states": {
+                 idx: {
+                     "sample_id": state.sample_id,
+                     "true_label": state.true_label,
+                     "history": [e.to_dict() for e in state.history]
+                 }
+                 for idx, state in self.sample_states.items()
+             }
+         }
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "SampleTracker":
+         """Restore tracker from serialized state."""
+         tracker = cls(
+             max_samples=data["max_samples"],
+             seed=data["seed"],
+             run_id=data["run_id"]
+         )
+         tracker.tracked_indices = set(data["tracked_indices"])
+         tracker.user_selected = set(data.get("user_selected", []))
+
+         for idx_str, state_data in data["sample_states"].items():
+             idx = int(idx_str)
+             state = SampleState(
+                 sample_id=state_data["sample_id"],
+                 true_label=state_data["true_label"],
+                 history=[LearningEvent.from_dict(e) for e in state_data["history"]]
+             )
+             tracker.sample_states[idx] = state
+
+         tracker._initialized = True
+         return tracker
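
For orientation, here is a minimal usage sketch of the new tracker. It is not from the package: the dataset and model come from scikit-learn (`make_classification`, `SGDClassifier`), the import path is inferred from the `from ..events import SampleTracker` line in `gradia/trainer/engine.py` below, and the behavior of the `.models` dataclasses (`LearningEvent.to_dict`/`from_dict`, `SampleState.add_event`, `stability_class`) is assumed from how this file calls them.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier

    from gradia.events import SampleTracker  # path inferred from engine.py below

    X, y = make_classification(n_samples=500, random_state=0)
    model = SGDClassifier(loss="log_loss", random_state=0)
    classes = np.unique(y)

    tracker = SampleTracker(max_samples=50, seed=42)
    for epoch in range(1, 11):
        model.partial_fit(X, y, classes=classes)
        if epoch == 1:
            # Boundary selection needs a fitted model exposing predict_proba
            tracker.initialize(X, y, run_id="demo", user_indices=[0, 1], model=model)
        tracker.record_predictions(
            epoch=epoch,
            X=X,
            y=y,
            predictions=model.predict(X),
            probabilities=model.predict_proba(X),
        )

    print(tracker.get_epoch_summary(10))                   # aggregate stability counts
    print([s.sample_id for s in tracker.get_top_flipping_samples(5)])
    restored = SampleTracker.from_dict(tracker.to_dict())   # serialization round trip

Note the deterministic pieces: the same `seed` drives the stratified fill, so two runs over the same data track the same sample indices.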
gradia/trainer/engine.py CHANGED
@@ -15,6 +15,7 @@ from ..models.base import GradiaModel
  from ..models.sklearn_wrappers import ModelFactory
  from ..core.scenario import Scenario
  from .callbacks import Callback, EventLogger
+ from ..events import SampleTracker, TimelineLogger

  class Trainer:
      def __init__(self, scenario: Scenario, config: Dict[str, Any], run_dir: str):
@@ -28,6 +29,25 @@ class Trainer:
              config['model'].get('params', {})
          )
          self.callbacks: List[Callback] = [EventLogger(run_dir)]
+
+         # v2.0: Learning Timeline support
+         self.enable_timeline = config.get('timeline', {}).get('enabled', True)
+         self.sample_tracker: SampleTracker = None
+         self.timeline_logger: TimelineLogger = None
+
+         if self.enable_timeline:
+             timeline_config = config.get('timeline', {})
+             self.sample_tracker = SampleTracker(
+                 max_samples=timeline_config.get('max_samples', 100),
+                 seed=config['training'].get('random_seed', 42)
+             )
+             self.timeline_logger = TimelineLogger(run_dir)
+
+         # Store data references for evaluation
+         self._X_test = None
+         self._y_test = None
+         self._X_train = None
+         self._y_train = None

      def run(self):
          print("DEBUG: Trainer.run() started.")
@@ -82,6 +102,25 @@ class Trainer:
              random_state=self.config['training'].get('random_seed', 42)
          )

+         # Store for evaluation
+         self._X_train = X_train
+         self._y_train = y_train
+         self._X_test = X_test
+         self._y_test = y_test
+
+         # v2.0: Initialize sample tracker for Learning Timeline
+         run_id = f"run_{int(time.time())}"
+         if self.enable_timeline and self.sample_tracker:
+             user_samples = self.config.get('timeline', {}).get('user_samples', None)
+             self.sample_tracker.initialize(
+                 X=X_test.values if hasattr(X_test, 'values') else X_test,
+                 y=y_test.values if hasattr(y_test, 'values') else y_test,
+                 run_id=run_id,
+                 user_indices=user_samples
+             )
+             # Clear previous timeline logs for fresh run
+             self.timeline_logger.clear()
+
          # Notify Start
          epochs = self.config['training'].get('epochs', 10)
          self._dispatch('on_train_begin', {
@@ -103,7 +142,6 @@ class Trainer:
              classes = np.unique(y) if self.scenario.task_type == 'classification' else None

              # TQDM Output to Console
-             import time
              with tqdm(range(1, epochs + 1), desc="Training", unit="epoch", colour="green") as pbar:
                  for epoch in pbar:
                      # Small delay to visualize speed if too fast
@@ -113,6 +151,10 @@ class Trainer:

                      # Evaluate
                      metrics = self._evaluate(X_train, y_train, X_test, y_test)
+
+                     # v2.0: Record timeline events
+                     self._record_timeline_epoch(epoch, X_test, y_test)
+
                      self._dispatch('on_epoch_end', epoch, metrics)

          else:
@@ -125,12 +167,15 @@ class Trainer:
              metrics = self._evaluate(X_train, y_train, X_test, y_test)

              # Simulate progress bar so UI doesn't look broken
-             import time
              with tqdm(range(1, epochs + 1), desc="Training", unit="epoch", colour="blue") as pbar:
                  for epoch in pbar:
                      time.sleep(0.1)  # Simulate work
                      # We broadcast the SAME metrics for every "epoch" since the model doesn't change
                      # But it keeps the UI happy and consistent
+
+                     # v2.0: Record timeline events (same predictions each epoch for non-iterative)
+                     self._record_timeline_epoch(epoch, X_test, y_test)
+
                      self._dispatch('on_epoch_end', epoch, metrics)

                      # Update Progress Bar
@@ -145,11 +190,25 @@ class Trainer:

          # 5. Finalize
          fi = self.model.get_feature_importance()
+
+         # v2.0: Finalize timeline logging
+         if self.enable_timeline and self.timeline_logger:
+             self.timeline_logger.save_tracker_state(self.sample_tracker.to_dict())
+             self.timeline_logger.finalize()

          # 6. Training Complete
+         timeline_summary = None
+         if self.enable_timeline and self.sample_tracker:
+             timeline_summary = {
+                 "tracked_samples": len(self.sample_tracker.tracked_indices),
+                 "top_flipping": [s.sample_id for s in self.sample_tracker.get_top_flipping_samples(5)],
+                 "never_correct": len(self.sample_tracker.get_never_correct())
+             }
+
          self._dispatch("on_train_end", {
              "epoch": epochs,
-             "feature_importance": self.model.get_feature_importance()
+             "feature_importance": self.model.get_feature_importance(),
+             "timeline_summary": timeline_summary
          })

          # 7. Save Model
@@ -201,3 +260,116 @@ class Trainer:
      def _dispatch(self, method_name, *args, **kwargs):
          for cb in self.callbacks:
              getattr(cb, method_name)(*args, **kwargs)
+
+     def _record_timeline_epoch(self, epoch: int, X_test, y_test):
+         """
+         Record sample-level predictions for Learning Timeline (v2.0).
+
+         Captures predictions for all tracked samples and logs events.
+         """
+         if not self.enable_timeline or not self.sample_tracker:
+             return
+
+         try:
+             # Get predictions for all test samples
+             X_arr = X_test.values if hasattr(X_test, 'values') else X_test
+             y_arr = y_test.values if hasattr(y_test, 'values') else y_test
+
+             predictions = self.model.predict(X_arr)
+
+             # Get probabilities if available
+             probabilities = None
+             if hasattr(self.model, 'predict_proba'):
+                 try:
+                     probabilities = self.model.predict_proba(X_arr)
+                 except Exception:
+                     pass
+
+             # Record predictions for tracked samples
+             events = self.sample_tracker.record_predictions(
+                 epoch=epoch,
+                 X=X_arr,
+                 y=y_arr,
+                 predictions=predictions,
+                 probabilities=probabilities
+             )
+
+             # Log events
+             self.timeline_logger.log_events(events, flush=(epoch % 5 == 0))
+
+             # Log epoch summary
+             summary = self.sample_tracker.get_epoch_summary(epoch)
+             self.timeline_logger.log_summary(summary)
+
+         except Exception as e:
+             print(f"Warning: Timeline recording failed for epoch {epoch}: {e}")
+
+     def evaluate_full(self) -> Dict[str, Any]:
+         """
+         Run full evaluation on test set with confusion matrix.
+
+         Returns:
+             Dictionary with evaluation results including confusion matrix.
+         """
+         if self._X_test is None or self._y_test is None:
+             raise ValueError("No test data available. Run training first.")
+
+         results = {}
+
+         preds = self.model.predict(self._X_test)
+
+         if self.scenario.task_type == 'classification':
+             results['accuracy'] = accuracy_score(self._y_test, preds)
+             results['precision'] = precision_score(self._y_test, preds, average='weighted', zero_division=0)
+             results['recall'] = recall_score(self._y_test, preds, average='weighted', zero_division=0)
+             results['f1'] = f1_score(self._y_test, preds, average='weighted', zero_division=0)
+
+             # Confusion matrix
+             cm = confusion_matrix(self._y_test, preds)
+             results['confusion_matrix'] = cm.tolist()
+
+             # Class labels
+             classes = sorted(list(set(self._y_test.tolist() if hasattr(self._y_test, 'tolist') else self._y_test)))
+             results['classes'] = [str(c) for c in classes]
+         else:
+             results['mse'] = mean_squared_error(self._y_test, preds)
+             results['mae'] = mean_absolute_error(self._y_test, preds)
+             results['r2'] = r2_score(self._y_test, preds)
+
+         # v2.0: Include timeline insights if available
+         if self.enable_timeline and self.sample_tracker:
+             results['timeline_insights'] = {
+                 'total_tracked': len(self.sample_tracker.tracked_indices),
+                 'top_flipping': [
+                     {
+                         'sample_id': s.sample_id,
+                         'flip_count': s.flip_count,
+                         'true_label': str(s.true_label),
+                         'stability': s.stability_class
+                     }
+                     for s in self.sample_tracker.get_top_flipping_samples(10)
+                 ],
+                 'never_correct_count': len(self.sample_tracker.get_never_correct()),
+                 'late_learners_count': len(self.sample_tracker.get_late_learners())
+             }
+
+         return results
+
+     def get_sample_timeline(self, sample_id: int) -> List[Dict[str, Any]]:
+         """
+         Get the learning timeline for a specific sample.
+
+         Args:
+             sample_id: Index of the sample to retrieve
+
+         Returns:
+             List of event dictionaries showing prediction evolution
+         """
+         if not self.enable_timeline or not self.sample_tracker:
+             return []
+
+         state = self.sample_tracker.get_sample_state(sample_id)
+         if not state:
+             return []
+
+         return [event.to_dict() for event in state.history]
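
Taken together, the Trainer changes are driven entirely by the config keys read in this diff. A hedged sketch of those keys (the `Scenario` object, run directory layout, and any other model/training keys are not shown in the diff and are assumed):

    config = {
        "model": {"params": {}},                  # forwarded to ModelFactory; other model keys elided
        "training": {"epochs": 10, "random_seed": 42},
        "timeline": {
            "enabled": True,                      # Trainer defaults to True when absent
            "max_samples": 100,                   # cap forwarded to SampleTracker
            "user_samples": [3, 17, 42],          # optional indices to always track
        },
    }

    # Assumed call pattern, based on Trainer.__init__ and run() above:
    # trainer = Trainer(scenario, config, run_dir)
    # trainer.run()
    # results = trainer.evaluate_full()             # adds 'timeline_insights' when enabled
    # history = trainer.get_sample_timeline(17)     # per-sample prediction evolution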