loopkit 0.0.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- loopkit/__init__.py +85 -0
- loopkit/cli/__init__.py +3 -0
- loopkit/cli/compare.py +274 -0
- loopkit/cli/main.py +62 -0
- loopkit/cli/sweep.py +313 -0
- loopkit/cli/visualize.py +285 -0
- loopkit/config.py +1236 -0
- loopkit/git.py +154 -0
- loopkit/logger.py +729 -0
- loopkit/monitor.py +259 -0
- loopkit/torch/__init__.py +43 -0
- loopkit/torch/checkpoint.py +339 -0
- loopkit/torch/mp.py +346 -0
- loopkit/tracking.py +203 -0
- loopkit/utils.py +75 -0
- loopkit-0.0.1a1.dist-info/METADATA +44 -0
- loopkit-0.0.1a1.dist-info/RECORD +21 -0
- loopkit-0.0.1a1.dist-info/WHEEL +5 -0
- loopkit-0.0.1a1.dist-info/entry_points.txt +2 -0
- loopkit-0.0.1a1.dist-info/licenses/LICENSE +7 -0
- loopkit-0.0.1a1.dist-info/top_level.txt +1 -0
loopkit/logger.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
import time
|
|
6
|
+
import uuid
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Dict, Optional
|
|
10
|
+
|
|
11
|
+
import loopkit
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ColorFormatter(logging.Formatter):
    """Console log formatter that wraps records in ANSI color codes.

    Level names are rendered bold in a per-level color. Messages beginning
    with ``Stage [`` or ``Timer [`` get a dedicated color, with the
    bracketed ``Name [...]`` prefix additionally bolded. Intended for TTY
    output only — the caller installs a plain ``logging.Formatter`` when
    stdout is not a TTY.

    NOTE(review): ``format`` ignores ``record.exc_info`` / ``record.stack_info``,
    so exception tracebacks are not rendered by this formatter — confirm
    this is intentional before relying on it for error reporting.
    """

    # ANSI color codes, keyed by logging level name
    COLORS = {
        "DEBUG": "\033[36m",  # Cyan
        "INFO": "\033[32m",  # Green
        "WARNING": "\033[33m",  # Yellow
        "ERROR": "\033[31m",  # Red
        "CRITICAL": "\033[35m",  # Magenta
    }
    RESET = "\033[0m"
    BOLD = "\033[1m"
    DIM = "\033[2m"

    # Special colors for specific message types
    STAGE_COLOR = "\033[34m"  # Blue
    TIMER_COLOR = "\033[36m"  # Cyan
    METRIC_COLOR = "\033[35m"  # Magenta

    def _colorize_bracketed(self, message: str, color: str) -> str:
        """Color *message* in *color*, bolding a leading ``Name [...]`` prefix.

        Splits on the first ``]``; if present, the prefix (including the
        bracket) is bold+colored and the remainder plainly colored. With no
        closing bracket the whole message is simply colored.
        """
        parts = message.split("]", 1)
        if len(parts) == 2:
            prefix = parts[0] + "]"
            rest = parts[1]
            return (
                f"{color}{self.BOLD}{prefix}{self.RESET}"
                f"{color}{rest}{self.RESET}"
            )
        return f"{color}{message}{self.RESET}"

    def format(self, record):
        """Format *record* as ``LEVEL: message`` with ANSI colors.

        Args:
            record: The ``logging.LogRecord`` to render.

        Returns:
            str: The colorized single-line representation.
        """
        message = record.getMessage()

        # Level color falls back to "" (no coloring) for unknown level names.
        level_color = self.COLORS.get(record.levelname, "")

        # Special coloring for stage/timer messages; everything else is
        # colored uniformly by level.
        if message.startswith("Stage ["):
            colored_message = self._colorize_bracketed(message, self.STAGE_COLOR)
        elif message.startswith("Timer ["):
            colored_message = self._colorize_bracketed(message, self.TIMER_COLOR)
        else:
            colored_message = f"{level_color}{message}{self.RESET}"

        # Level name is always bold in the level color.
        level_name = f"{level_color}{self.BOLD}{record.levelname}{self.RESET}"
        return f"{level_name}: {colored_message}"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class ExperimentLogger:
    """Experiment logger with structured logging and metrics tracking.

    This logger provides two equivalent APIs for text logging:
    1. Generic: logger.log(message, level='INFO')
    2. Level-specific: logger.info(message), logger.warning(message), logger.error(message)

    For metrics tracking, use logger.log_metric(step, split, name, value)

    Features:
    - Rank-aware logging (DDP compatible)
    - Human-readable and machine-readable logs
    - Metrics tracking with best model tracking
    - Profiling context managers (can be disabled globally)
    - Configurable log levels

    Args:
        run_dir: Directory for log files
        rank: Process rank (0 for main process)
        run_id: Unique run identifier
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
        console_output: Whether to output to console (only rank 0)

    Examples:
        >>> # Basic usage
        >>> logger = ExperimentLogger(run_dir='runs/exp1')
        >>>
        >>> # Text logging (two equivalent ways)
        >>> logger.info("Training started")  # Preferred
        >>> logger.log("Training started", level="INFO")  # Alternative
        >>>
        >>> # Metrics logging
        >>> logger.log_metric(step=0, split='train', name='loss', value=0.5)
        >>>
        >>> # Profiling (enabled by default)
        >>> with logger.timer('data_loading'):
        >>>     data = load_data()
        >>>
        >>> # Disable profiling globally for production
        >>> import loopkit
        >>> loopkit.timer = False  # or set EM_TIMER=0 environment variable

    NOTE(review): the example above names ``EM_TIMER`` but :meth:`timer`'s
    docstring names ``LK_TIMER`` — one of the two is stale; confirm the
    actual environment variable name against loopkit's config module.
    """

    def __init__(
        self,
        run_dir: Path | str,
        rank: int = 0,
        run_id: Optional[str] = None,
        log_level: Optional[str] = None,
        console_output: Optional[bool] = None,
    ):
        # Convert to Path if string
        if isinstance(run_dir, str):
            run_dir = Path(run_dir)

        self.run_dir = run_dir
        self.rank = rank
        # Fallback label used in every JSONL record when no run_id is given.
        self.run_id = run_id or "unknown"

        run_dir.mkdir(parents=True, exist_ok=True)

        # Use global log level if not specified
        if log_level is None:
            log_level = loopkit.log_level

        # Use global verbose flag if console_output not specified
        if console_output is None:
            console_output = loopkit.verbose

        # Human-readable log (per-rank) - NO COLORS in file
        log_filename = "exp.log" if rank == 0 else f"exp_rank{rank}.log"
        self.log_file = run_dir / log_filename
        self.log_handler = logging.FileHandler(self.log_file)
        self.log_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )

        # JSONL log for machine reading (per-rank); opened append-mode on
        # each write rather than held open (see log()/log_dict()).
        jsonl_filename = "events.jsonl" if rank == 0 else f"events_rank{rank}.jsonl"
        self.jsonl_file = run_dir / jsonl_filename

        # Console handler (rank-0 only by default) - WITH COLORS
        self.console_handler = None
        if rank == 0 and console_output:
            self.console_handler = logging.StreamHandler(sys.stdout)
            # Check if stdout supports colors (not piped/redirected)
            use_colors = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
            if use_colors:
                self.console_handler.setFormatter(ColorFormatter())
            else:
                # Fallback to plain format if output is redirected
                self.console_handler.setFormatter(
                    logging.Formatter("%(levelname)s: %(message)s")
                )

        # Metrics CSV (rank-0 only) - respect track_metrics flag.
        # Header is written only once; an existing file is appended to.
        self.metrics_file = run_dir / "metrics.csv"
        if rank == 0 and loopkit.track_metrics and not self.metrics_file.exists():
            with open(self.metrics_file, "w") as f:
                f.write("step,split,name,value,wall_time\n")

        # Best metrics tracking (per metric)
        self.best_metrics: Dict[
            str, Dict
        ] = {}  # {metric_name: {'value': ..., 'step': ..., 'mode': ...}}
        self.best_file = run_dir / "best.json"
        # Resuming: reload previously saved bests so comparisons continue
        # across restarts.
        if rank == 0 and self.best_file.exists():
            with open(self.best_file) as f:
                self.best_metrics = json.load(f)

        # Setup logger (unique name to avoid handler accumulation)
        logger_name = f"loopkit_rank_{rank}_{uuid.uuid4().hex[:8]}"
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(getattr(logging, log_level.upper()))

        # Clear any existing handlers (in case logger was reused)
        self.logger.handlers.clear()

        self.logger.addHandler(self.log_handler)
        if self.console_handler:
            self.logger.addHandler(self.console_handler)

        # Prevent propagation to root logger
        self.logger.propagate = False

        # Stage tracking (for hierarchical context)
        self._current_stage: Optional[str] = None
        self._stage_stack: list = []

        # Rate limiting tracking (see should_run())
        self._rate_limit_counters: Dict[str, int] = {}
        self._rate_limit_last_log: Dict[str, float] = {}

        # Time estimation tracking (for automatic ETA in log_metric)
        self._eta_tracking: Dict[
            str, Dict
        ] = {}  # {key: {'start_time': ..., 'start_step': ..., 'total_steps': ...}}

    def log(self, message: str, level: str = "INFO", **kwargs):
        """Log a message with context.

        Writes the record both to the JSONL event stream (with **kwargs
        merged into the entry) and to the standard logging handlers
        (file, and console on rank 0).

        Args:
            message: The log message
            level: Log level (DEBUG, INFO, WARNING, ERROR); unknown names
                fall back to INFO
            **kwargs: Additional context to include in the JSONL entry
        """
        # Check if this level should be logged
        log_level_num = getattr(logging, level.upper(), logging.INFO)
        logger_level_num = self.logger.level

        # Only write to JSONL and logger if level is high enough
        if log_level_num >= logger_level_num:
            # Write to JSONL
            log_entry = {
                "timestamp": time.time(),
                "level": level.upper(),
                "message": message,
                "run_id": self.run_id,
                "rank": self.rank,
                **kwargs,
            }

            with open(self.jsonl_file, "a") as f:
                f.write(json.dumps(log_entry) + "\n")

            # Write to logger
            self.logger.log(log_level_num, message)

    def debug(self, message: str, **kwargs):
        """Log *message* at DEBUG level (shorthand for log(..., 'DEBUG'))."""
        self.log(message, "DEBUG", **kwargs)

    def info(self, message: str, **kwargs):
        """Log *message* at INFO level (shorthand for log(..., 'INFO'))."""
        self.log(message, "INFO", **kwargs)

    def warning(self, message: str, **kwargs):
        """Log *message* at WARNING level (shorthand for log(..., 'WARNING'))."""
        self.log(message, "WARNING", **kwargs)

    def error(self, message: str, **kwargs):
        """Log *message* at ERROR level (shorthand for log(..., 'ERROR'))."""
        self.log(message, "ERROR", **kwargs)

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format seconds into human-readable time string.

        Args:
            seconds: Time in seconds

        Returns:
            Formatted time string (e.g., "2h 15m 30s", "45s", "1d 3h 20m")
        """
        if seconds < 0:
            return "0s"

        days = int(seconds // 86400)
        seconds %= 86400
        hours = int(seconds // 3600)
        seconds %= 3600
        minutes = int(seconds // 60)
        secs = int(seconds % 60)

        # Omit zero-valued units, but always emit at least "0s".
        parts = []
        if days > 0:
            parts.append(f"{days}d")
        if hours > 0:
            parts.append(f"{hours}h")
        if minutes > 0:
            parts.append(f"{minutes}m")
        if secs > 0 or not parts:
            parts.append(f"{secs}s")

        return " ".join(parts)

    def log_metric(
        self,
        step: int,
        split: str,
        name: str,
        value: float,
        track_best: bool = True,
        mode: Optional[str] = None,
        total_steps: Optional[int] = None,
        eta_key: Optional[str] = None,
    ):
        """Log a metric to CSV and update best tracking.

        No-op on ranks other than 0 and when loopkit.track_metrics is off.

        NOTE(review): the CSV row is a raw comma-join, so a `split` or
        `name` containing a comma would corrupt the file — confirm callers
        never pass such names, or switch to the csv module.

        Args:
            step: Training step
            split: Data split (train/val/test)
            name: Metric name
            value: Metric value
            track_best: Whether to track this as a best metric
            mode: 'min' or 'max' for best tracking (auto-detected if None)
            total_steps: Total number of steps (enables automatic ETA estimation)
            eta_key: Key for ETA tracking (auto-generated from split/name if None)
        """
        if self.rank != 0:
            return  # Only rank 0 logs metrics

        # Skip if metrics tracking is disabled
        if not loopkit.track_metrics:
            return

        wall_time = time.time()
        with open(self.metrics_file, "a") as f:
            f.write(f"{step},{split},{name},{value},{wall_time}\n")

        # Automatic ETA estimation if total_steps is provided
        if total_steps is not None and step > 0:
            # Auto-generate key if not provided
            if eta_key is None:
                eta_key = f"{split}/{name}"

            # Initialize tracking for this metric if first time; the first
            # observed step becomes the baseline for the rate estimate.
            if eta_key not in self._eta_tracking:
                self._eta_tracking[eta_key] = {
                    "start_time": wall_time,
                    "start_step": step,
                    "total_steps": total_steps,
                }

            # Calculate ETA from the average time-per-step since baseline.
            tracking = self._eta_tracking[eta_key]
            elapsed = wall_time - tracking["start_time"]
            steps_done = step - tracking["start_step"]

            if steps_done > 0:
                steps_remaining = total_steps - step
                time_per_step = elapsed / steps_done
                eta_seconds = time_per_step * steps_remaining

                # Log ETA info (only if significant progress made to avoid spam)
                # — roughly every 5% of total_steps, plus the final step.
                if steps_done % max(1, total_steps // 20) == 0 or step == total_steps:
                    progress_pct = 100 * step / total_steps
                    eta_str = self._format_time(eta_seconds)
                    elapsed_str = self._format_time(elapsed)

                    self.info(
                        f"Progress [{split}/{name}]: {step}/{total_steps} ({progress_pct:.1f}%) | "
                        f"Elapsed: {elapsed_str} | ETA: {eta_str}",
                        step=step,
                        progress=progress_pct,
                        eta_seconds=eta_seconds,
                        elapsed_seconds=elapsed,
                    )

        # Update best metrics tracking
        if track_best:
            # Auto-detect mode if not specified, from common metric-name
            # substrings.
            if mode is None:
                name_lower = name.lower()
                if any(x in name_lower for x in ["loss", "error"]):
                    mode = "min"
                elif any(x in name_lower for x in ["acc", "accuracy", "f1", "auc"]):
                    mode = "max"
                else:
                    mode = "min"  # Default to min

            metric_key = f"{split}/{name}"

            # Check if this is a new best
            is_best = False
            if metric_key not in self.best_metrics:
                is_best = True
            else:
                prev_best = self.best_metrics[metric_key]["value"]
                is_better = (mode == "min" and value < prev_best) or (
                    mode == "max" and value > prev_best
                )
                if is_better:
                    is_best = True

            # Update if best
            if is_best:
                self.best_metrics[metric_key] = {
                    "value": value,
                    "step": step,
                    "mode": mode,
                }

                # Save best metrics (rewritten in full on every new best)
                with open(self.best_file, "w") as f:
                    json.dump(self.best_metrics, f, indent=2)

    def get_best_metric(self, name: str, split: str = "val") -> Optional[Dict]:
        """Get the best value for a metric.

        Args:
            name: Metric name
            split: Data split

        Returns:
            dict: Best metric info with 'value', 'step', 'mode' or None
        """
        metric_key = f"{split}/{name}"
        return self.best_metrics.get(metric_key)

    @contextmanager
    def timer(self, name: str, log_result: bool = True):
        """Context manager for timing code sections.

        Provides simple wall-clock timing using time.perf_counter().
        For detailed GPU/CPU profiling, use torch.profiler directly.

        Timing can be globally disabled by setting loopkit.timer = False
        or environment variable LK_TIMER=0.
        (NOTE(review): class docstring says EM_TIMER — confirm which name
        loopkit actually reads.)

        Args:
            name: Name of the timed section
            log_result: Whether to log the timing result

        Usage:
            # Basic timing
            with logger.timer("data_loading"):
                data = load_data()

            # Get elapsed time
            with logger.timer("training_step") as result:
                loss = model(batch)
            print(f"Step took {result['elapsed']:.4f}s")

            # Silent timing (no logging)
            with logger.timer("forward", log_result=False) as result:
                output = model(input)
            elapsed = result['elapsed']

        Yields:
            dict: Dictionary with 'elapsed' key (in seconds), updated when
            context exits (stays 0.0 when timing is globally disabled)

        Example with PyTorch Profiler:
            # Use torch.profiler for detailed profiling
            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=True,
            ) as prof:
                with logger.timer("training_step"):
                    loss = model(batch)
                    loss.backward()
            prof.export_chrome_trace("trace.json")
        """
        result = {"elapsed": 0.0}

        # Check if profiling is enabled globally; when disabled, the body
        # still runs but nothing is measured or logged.
        if not loopkit.timer:
            yield result
            return

        start_time = time.perf_counter()

        try:
            yield result
        finally:
            # Record elapsed even when the body raised.
            result["elapsed"] = time.perf_counter() - start_time

            if log_result and self.rank == 0:
                self.info(
                    f"Timer [{name}]: {result['elapsed']:.4f}s",
                    section=name,
                    elapsed=result["elapsed"],
                )

    @contextmanager
    def stage(self, name: str, **metadata):
        """Context manager for tracking training stages with hierarchical context.

        Automatically logs stage entry/exit, tracks duration, and provides context
        for metrics and logs. Stages can be nested to create hierarchies
        (e.g., epoch → train → batch).

        NOTE(review): the stack stores each stage's *short* name, and the
        path is built as ``parent_short_name/name`` — so nesting three or
        more levels deep drops grandparent components (e.g. "train/batch",
        not "epoch/train/batch"). Confirm whether that is intentional.

        Args:
            name: Name of the stage (e.g., 'epoch', 'train', 'validation')
            **metadata: Additional metadata to log (e.g., epoch=5, batch=10)

        Usage:
            # Simple stage
            with logger.stage("training"):
                train_model()

            # With metadata
            with logger.stage("epoch", epoch=5, lr=0.001):
                train_epoch()

            # Nested stages
            with logger.stage("epoch", epoch=5):
                with logger.stage("train"):
                    train_loss = train_epoch()
                with logger.stage("validation"):
                    val_loss = validate()

        Yields:
            dict: Stage info with 'name', 'metadata', 'elapsed' keys
        """
        stage_info = {"name": name, "metadata": metadata, "elapsed": 0.0}

        # Build hierarchical stage path
        if self._stage_stack:
            parent = self._stage_stack[-1]["name"]
            full_name = f"{parent}/{name}"
        else:
            full_name = name

        # Log stage entry
        if self.rank == 0:
            metadata_str = ", ".join(f"{k}={v}" for k, v in metadata.items())
            msg = f"Stage [{full_name}]"
            if metadata_str:
                msg += f" ({metadata_str})"
            msg += " - START"
            self.info(msg, stage=full_name, stage_event="start", **metadata)

        # Push to stack
        self._stage_stack.append(stage_info)
        old_stage = self._current_stage
        self._current_stage = full_name

        start_time = time.perf_counter()

        try:
            yield stage_info
        finally:
            elapsed = time.perf_counter() - start_time
            stage_info["elapsed"] = elapsed

            # Log stage exit (also runs if the body raised)
            if self.rank == 0:
                msg = f"Stage [{full_name}] - END ({elapsed:.4f}s)"
                self.info(
                    msg, stage=full_name, stage_event="end", elapsed=elapsed, **metadata
                )

            # Pop from stack
            self._stage_stack.pop()
            self._current_stage = old_stage

    def should_run(
        self, every: Optional[int] = None, seconds: Optional[float] = None, key: Optional[str] = None
    ) -> bool:
        """Check if code should run based on rate limiting.

        Returns a boolean indicating whether to execute. Use with 'if' for clean syntax.

        Args:
            every: Execute every N iterations (mutually exclusive with seconds)
            seconds: Execute every N seconds (mutually exclusive with every)
            key: Unique key for this rate limiter (auto-generated if None)

        Returns:
            bool: True if code should execute, False if skipped

        Raises:
            ValueError: If neither or both of `every`/`seconds` are given

        Usage:
            # Clean 'if' syntax
            for step in range(10000):
                if logger.should_run(every=100):
                    logger.info(f"Step {step}")
                    save_checkpoint()

            # Time-based rate limiting
            for batch in dataloader:
                if logger.should_run(seconds=5.0):
                    logger.info("Expensive logging")
        """
        if every is None and seconds is None:
            raise ValueError("Must specify either 'every' or 'seconds'")
        if every is not None and seconds is not None:
            raise ValueError("Cannot specify both 'every' and 'seconds'")

        # Auto-generate key from caller location if not provided, so each
        # distinct call site gets its own independent counter/clock.
        if key is None:
            frame = inspect.currentframe()
            caller_frame = frame.f_back if frame else None
            if caller_frame:
                key = f"{caller_frame.f_code.co_filename}:{caller_frame.f_lineno}"
            else:
                key = "default"

        should_execute = False

        if every is not None:
            # Iteration-based rate limiting: fires on the first call
            # (count 0) and every `every` calls thereafter.
            count = self._rate_limit_counters.get(key, 0)
            self._rate_limit_counters[key] = count + 1

            if count % every == 0:
                should_execute = True
        else:
            # Time-based rate limiting: fires when at least `seconds` have
            # passed since the last time this key fired (always fires on
            # first call, since the stored default is 0.0).
            current_time = time.time()
            last_execute_time = self._rate_limit_last_log.get(key, 0.0)

            if current_time - last_execute_time >= seconds:
                should_execute = True
                self._rate_limit_last_log[key] = current_time

        return should_execute

    # Backward compatibility alias - wraps should_run as context manager
    @contextmanager
    def do_every(self, every: Optional[int] = None, seconds: Optional[float] = None, key: Optional[str] = None):
        """Deprecated: Use should_run() instead.

        Context manager version for backward compatibility. Yields the
        boolean decision; it does not suppress anything by itself.
        """
        should_execute = self.should_run(every=every, seconds=seconds, key=key)
        yield should_execute

    @contextmanager
    def log_every(self, every: Optional[int] = None, seconds: Optional[float] = None, key: Optional[str] = None):
        """Context manager for rate-limited logging.

        Logs only every N iterations or every N seconds, useful for reducing
        log spam in tight training loops while still capturing data in JSONL.

        Only the *console* handler is suppressed on skipped iterations —
        the file handler and the JSONL stream still receive every record.

        Args:
            every: Log every N iterations (mutually exclusive with seconds)
            seconds: Log every N seconds (mutually exclusive with every)
            key: Unique key for this rate limiter (auto-generated if None)

        Usage:
            # Log every 100 iterations
            for step in range(10000):
                loss = train_step()
                with logger.log_every(every=100):
                    logger.info(f"Step {step}: loss={loss:.4f}")

            # Log every 5 seconds
            for batch in dataloader:
                with logger.log_every(seconds=5.0):
                    logger.info(f"Processing batch...")

            # Multiple rate limiters with different keys
            for step in range(1000):
                with logger.log_every(every=10, key="loss"):
                    logger.info(f"Loss: {loss:.4f}")
                with logger.log_every(every=100, key="detailed"):
                    logger.info(f"Detailed metrics: {metrics}")

        Yields:
            bool: True if logging should occur, False if suppressed
        """
        # Use should_run to determine if we should log
        should_log = self.should_run(every=every, seconds=seconds, key=key)

        # Temporarily suppress console logging if not time to log
        if not should_log and self.console_handler:
            # Remove console handler temporarily
            if self.console_handler in self.logger.handlers:
                self.logger.removeHandler(self.console_handler)
                try:
                    yield should_log
                finally:
                    # Restore console handler (guard against double-add if
                    # something re-attached it inside the body)
                    if self.console_handler not in self.logger.handlers:
                        self.logger.addHandler(self.console_handler)
            else:
                # Handler not present, just yield
                yield should_log
        else:
            yield should_log

    def set_log_level(self, level: str):
        """Change the logging level.

        Args:
            level: One of DEBUG, INFO, WARNING, ERROR
        """
        self.logger.setLevel(getattr(logging, level.upper()))

    def close(self):
        """Close all log handlers and detach them from the logger."""
        # Iterate over a copy because removeHandler mutates the list.
        for handler in self.logger.handlers[:]:
            handler.close()
            self.logger.removeHandler(handler)

    def log_dict(self, data: Dict, level: str = "INFO"):
        """Log a dictionary as a single JSONL entry.

        Unlike log(), this bypasses the logging handlers entirely (no
        console/file output) and is not filtered by log level.

        Args:
            data: Dictionary to log (merged into the entry; may override
                the standard fields)
            level: Log level recorded in the entry
        """
        log_entry = {
            "timestamp": time.time(),
            "level": level.upper(),
            "run_id": self.run_id,
            "rank": self.rank,
            **data,
        }

        with open(self.jsonl_file, "a") as f:
            f.write(json.dumps(log_entry) + "\n")

    def save_metadata(self, metadata: Dict, filename: str = "metadata.json"):
        """Save metadata to a JSON file in the run directory.

        Args:
            metadata: Dictionary containing metadata (must be JSON-serializable)
            filename: Name of file to save to
        """
        filepath = self.run_dir / filename
        with open(filepath, "w") as f:
            json.dump(metadata, f, indent=2)

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close handlers; exceptions propagate."""
        self.close()
        return False