gradia 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,215 @@
1
+ """
2
+ Timeline Logger for Learning Timeline (v2.0.0)
3
+
4
+ Handles persistent storage of learning events, compatible with existing
5
+ EventLogger infrastructure while extending for timeline data.
6
+ """
7
+
8
+ from typing import List, Dict, Any, Optional
9
+ from pathlib import Path
10
+ import json
11
+ import time
12
+ import threading
13
+ import os
14
+
15
+ from .models import LearningEvent, EpochSummary
16
+
17
+
18
class TimelineLogger:
    """
    Logs learning events to structured files for timeline visualization.

    Storage format:
        - timeline_events.jsonl: Raw LearningEvents (append-only)
        - timeline_summary.jsonl: EpochSummaries
        - timeline_state.json: Tracker state for resumption

    Thread-safe via shared lock with existing EventLogger; the in-memory
    event buffer is also accessed only while holding that lock.
    """

    def __init__(self, log_dir: str, lock: Optional[threading.Lock] = None):
        """
        Initialize timeline logger.

        Args:
            log_dir: Directory for log files (created if missing)
            lock: Optional shared lock (uses global if not provided)
        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Use shared lock from callbacks module for thread safety.
        # Imported lazily so module load does not create an import cycle.
        if lock is None:
            from ..trainer.callbacks import log_lock
            self._lock = log_lock
        else:
            self._lock = lock

        # File paths
        self.events_path = self.log_dir / "timeline_events.jsonl"
        self.summary_path = self.log_dir / "timeline_summary.jsonl"
        self.state_path = self.log_dir / "timeline_state.json"

        # In-memory buffer for batch writes; touch only under self._lock.
        self._event_buffer: List[LearningEvent] = []
        self._buffer_size = 50  # Flush every N events

    def log_events(self, events: List[LearningEvent], flush: bool = False):
        """
        Log a batch of learning events.

        Args:
            events: List of LearningEvents to log
            flush: Force immediate write to disk
        """
        # Mutate the buffer under the lock so a concurrent flush cannot
        # interleave with the extend (previously the buffer was touched
        # without the lock, racing with _flush_events).
        with self._lock:
            self._event_buffer.extend(events)
            should_flush = flush or len(self._event_buffer) >= self._buffer_size

        # _flush_events acquires the (non-reentrant) lock itself, so it
        # must be called outside the critical section above.
        if should_flush:
            self._flush_events()

    def log_event(self, event: LearningEvent):
        """Log a single learning event."""
        self.log_events([event])

    def log_summary(self, summary: EpochSummary):
        """Log an epoch summary as one JSON line."""
        with self._lock:
            with open(self.summary_path, "a") as f:
                f.write(json.dumps(summary.to_dict()) + "\n")
                f.flush()
                os.fsync(f.fileno())  # make the record durable on disk

    def save_tracker_state(self, tracker_data: Dict[str, Any]):
        """
        Save tracker state for run resumption/replay.

        Args:
            tracker_data: Serialized SampleTracker state
        """
        with self._lock:
            with open(self.state_path, "w") as f:
                # default=str keeps non-JSON-native values serializable
                json.dump(tracker_data, f, indent=2, default=str)
                f.flush()
                os.fsync(f.fileno())

    def load_tracker_state(self) -> Optional[Dict[str, Any]]:
        """Load saved tracker state if it exists, else None."""
        if not self.state_path.exists():
            return None

        with self._lock:
            with open(self.state_path, "r") as f:
                return json.load(f)

    def get_events(
        self,
        epoch: Optional[int] = None,
        sample_id: Optional[int] = None
    ) -> List[LearningEvent]:
        """
        Read events from log file with optional filtering.

        Args:
            epoch: Filter by epoch number
            sample_id: Filter by sample ID

        Returns:
            Matching events in file (append) order. Malformed or
            schema-mismatched lines are skipped.
        """
        events = []

        if not self.events_path.exists():
            return events

        with self._lock:
            with open(self.events_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        data = json.loads(line)
                        event = LearningEvent.from_dict(data)

                        # Apply filters
                        if epoch is not None and event.epoch != epoch:
                            continue
                        if sample_id is not None and event.sample_id != sample_id:
                            continue

                        events.append(event)
                    except (json.JSONDecodeError, KeyError, TypeError):
                        # TypeError covers records whose keys do not match
                        # the event schema, same tolerance as get_summaries.
                        continue

        return events

    def get_summaries(self) -> List[EpochSummary]:
        """Read all epoch summaries; malformed lines are skipped."""
        summaries = []

        if not self.summary_path.exists():
            return summaries

        with self._lock:
            with open(self.summary_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        data = json.loads(line)
                        summary = EpochSummary(**data)
                        summaries.append(summary)
                    except (json.JSONDecodeError, KeyError, TypeError):
                        continue

        return summaries

    def get_sample_timeline(self, sample_id: int) -> List[LearningEvent]:
        """Get full timeline for a specific sample."""
        return self.get_events(sample_id=sample_id)

    def get_latest_epoch(self) -> int:
        """Get the most recent epoch number logged (0 if none)."""
        summaries = self.get_summaries()
        if not summaries:
            return 0
        return max(s.epoch for s in summaries)

    def clear(self):
        """Clear all timeline logs (for new run)."""
        # Discard buffered events instead of flushing them: a flush would
        # only write into files that are unlinked immediately below.
        with self._lock:
            self._event_buffer.clear()
            for path in [self.events_path, self.summary_path, self.state_path]:
                if path.exists():
                    path.unlink()

    def _flush_events(self):
        """Write buffered events to disk (no-op when buffer is empty)."""
        with self._lock:
            if not self._event_buffer:
                return
            with open(self.events_path, "a") as f:
                for event in self._event_buffer:
                    f.write(json.dumps(event.to_dict(), default=str) + "\n")
                f.flush()
                os.fsync(f.fileno())

            # Clear only after a successful write, still under the lock,
            # so concurrently buffered events are never dropped.
            self._event_buffer.clear()

    def finalize(self):
        """Ensure all buffered data is written."""
        self._flush_events()

    def __del__(self):
        """Best-effort flush on garbage collection."""
        try:
            self._flush_events()
        except Exception:
            # Interpreter may be shutting down; nothing safe to do here.
            pass
207
+
208
+
209
def create_timeline_logger(run_dir: str) -> TimelineLogger:
    """
    Factory function to create a TimelineLogger.

    Convenience wrapper that handles path resolution: the run directory
    string is passed straight through as the logger's log directory.
    """
    return TimelineLogger(log_dir=run_dir)
@@ -0,0 +1,170 @@
1
+ """
2
+ Event Models for Learning Timeline (v2.0.0)
3
+
4
+ Defines the LearningEvent contract that all timeline visuals consume.
5
+ """
6
+
7
+ from dataclasses import dataclass, field, asdict
8
+ from typing import Optional, Any, Dict, List
9
+ from enum import Enum
10
+ import time
11
+
12
+
13
class EventType(str, Enum):
    """Types of events in the learning timeline.

    Subclasses ``str`` so members serialize directly as their string
    values (e.g. in JSON logs).
    """

    # A per-sample prediction snapshot.
    SAMPLE_PREDICTION = "sample_prediction"
    # Aggregated statistics for one epoch.
    EPOCH_SUMMARY = "epoch_summary"
    # A sample's predicted label changed between epochs.
    FLIP_DETECTED = "flip_detected"
    # NOTE(review): presumably a sample crossed a stability boundary —
    # confirm against the tracker that emits it.
    STABILITY_CHANGE = "stability_change"
19
+
20
+
21
@dataclass
class LearningEvent:
    """
    Core event model for sample-level prediction tracking.

    This is the internal contract between training logic and visualization.
    All timeline visuals consume LearningEvents, not training internals.

    Attributes:
        run_id: Unique identifier for the training run
        epoch: Current epoch number (1-indexed)
        sample_id: Index or identifier of the tracked sample
        true_label: Ground truth label for the sample
        predicted_label: Model's prediction for this sample at this epoch
        confidence: Prediction confidence/probability (0.0 to 1.0)
        correct: Whether prediction matches true label
        timestamp: Unix timestamp when event was recorded
        margin: Optional decision margin (distance from decision boundary)
        probabilities: Optional full probability distribution across classes
        metadata: Optional additional context
    """
    run_id: str
    epoch: int
    sample_id: int
    true_label: Any
    predicted_label: Any
    confidence: float
    correct: bool
    timestamp: float = field(default_factory=time.time)
    margin: Optional[float] = None
    probabilities: Optional[List[float]] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LearningEvent":
        """
        Reconstruct from dictionary.

        Unknown keys are ignored so that log records written by a newer
        schema (with extra fields) remain readable, instead of raising
        TypeError from the constructor call.
        """
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in data.items() if k in known})
62
+
63
+
64
@dataclass
class SampleState:
    """
    Tracks the learning state of a single sample across epochs.

    Used to compute stability metrics and detect flips.
    """
    sample_id: int
    true_label: Any
    history: List[LearningEvent] = field(default_factory=list)

    @property
    def flip_count(self) -> int:
        """Count how many times the prediction changed."""
        preds = [e.predicted_label for e in self.history]
        # Compare each consecutive pair of predictions.
        return sum(1 for prev, cur in zip(preds, preds[1:]) if prev != cur)

    @property
    def is_stable(self) -> bool:
        """Sample is stable if no flips in last 3 epochs."""
        if len(self.history) < 3:
            return False
        a, b, c = self.history[-3:]
        return a.predicted_label == b.predicted_label == c.predicted_label

    @property
    def stability_class(self) -> str:
        """
        Classify sample stability for visualization.

        Returns:
            'stable_correct': Consistently correct
            'stable_wrong': Consistently wrong
            'unstable': Predictions keep changing
            'late_learner': Recently became correct
        """
        if not self.history:
            return "unknown"

        window = self.history[-3:]
        stable = self.is_stable
        if stable and all(e.correct for e in window):
            return "stable_correct"
        if stable and not any(e.correct for e in window):
            return "stable_wrong"
        if self.flip_count > 2:
            return "unstable"

        # "Late learner": correct now but wrong three epochs ago.
        became_correct = (
            len(self.history) >= 3
            and self.history[-1].correct
            and not self.history[-3].correct
        )
        return "late_learner" if became_correct else "unstable"

    @property
    def first_correct_epoch(self) -> Optional[int]:
        """Return epoch when sample was first correctly classified."""
        return next((e.epoch for e in self.history if e.correct), None)

    @property
    def current_prediction(self) -> Optional[Any]:
        """Most recent prediction."""
        if not self.history:
            return None
        return self.history[-1].predicted_label

    @property
    def current_confidence(self) -> Optional[float]:
        """Most recent confidence score."""
        if not self.history:
            return None
        return self.history[-1].confidence

    def add_event(self, event: LearningEvent):
        """Record a new prediction event."""
        self.history.append(event)
144
+
145
+
146
@dataclass
class EpochSummary:
    """
    Aggregated summary of sample-level events for one epoch.

    Used for the Timeline Overview block.
    """
    run_id: str          # identifier of the training run
    epoch: int           # epoch this summary covers
    timestamp: float     # when the summary was recorded
    total_tracked: int   # number of samples tracked this epoch
    correct_count: int   # tracked samples predicted correctly
    flip_count: int      # prediction flips observed
    stable_correct: int  # samples classified 'stable_correct'
    stable_wrong: int    # samples classified 'stable_wrong'
    unstable: int        # samples classified 'unstable'
    late_learners: int   # samples classified 'late_learner'

    @property
    def accuracy(self) -> float:
        """Tracked sample accuracy (0.0 when nothing is tracked)."""
        if self.total_tracked <= 0:
            return 0.0
        return self.correct_count / self.total_tracked

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all fields to a plain dict."""
        return asdict(self)