loopkit 0.0.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
loopkit/logger.py ADDED
@@ -0,0 +1,729 @@
1
+ import inspect
2
+ import json
3
+ import logging
4
+ import sys
5
+ import time
6
+ import uuid
7
+ from contextlib import contextmanager
8
+ from pathlib import Path
9
+ from typing import Dict, Optional
10
+
11
+ import loopkit
12
+
13
+
14
class ColorFormatter(logging.Formatter):
    """Formatter that adds ANSI colors to console log output.

    ``Stage [...]`` and ``Timer [...]`` messages get a dedicated color with a
    bold bracketed prefix; every other record is colored by its level.
    Intended for TTY console handlers only — file handlers should use a plain
    ``logging.Formatter`` so logs stay grep-able.
    """

    # ANSI color codes, keyed by level name
    COLORS = {
        "DEBUG": "\033[36m",  # Cyan
        "INFO": "\033[32m",  # Green
        "WARNING": "\033[33m",  # Yellow
        "ERROR": "\033[31m",  # Red
        "CRITICAL": "\033[35m",  # Magenta
    }
    RESET = "\033[0m"
    BOLD = "\033[1m"
    DIM = "\033[2m"

    # Special colors for specific message types
    STAGE_COLOR = "\033[34m"  # Blue
    TIMER_COLOR = "\033[36m"  # Cyan
    METRIC_COLOR = "\033[35m"  # Magenta

    def _color_bracketed(self, message: str, color: str) -> str:
        """Color a ``Name [tag] rest`` message.

        The portion up to and including the first ``]`` is rendered bold in
        *color*; the remainder is rendered in *color* without bold. If the
        message has no ``]``, the whole message is colored uniformly.
        """
        parts = message.split("]", 1)
        if len(parts) == 2:
            head = parts[0] + "]"
            rest = parts[1]
            return (
                f"{color}{self.BOLD}{head}{self.RESET}"
                f"{color}{rest}{self.RESET}"
            )
        return f"{color}{message}{self.RESET}"

    def format(self, record):
        """Format a record as ``<LEVEL>: <message>`` with ANSI colors.

        Single return path: the bold level name is always prefixed, and only
        the message coloring varies by message type (previously the level-name
        formatting and return were duplicated across branches).
        """
        message = record.getMessage()
        level_color = self.COLORS.get(record.levelname, "")
        # Level name is always bold in the level's color.
        level_name = f"{level_color}{self.BOLD}{record.levelname}{self.RESET}"

        if message.startswith("Stage ["):
            # Stage messages in blue with bold stage name
            colored_message = self._color_bracketed(message, self.STAGE_COLOR)
        elif message.startswith("Timer ["):
            # Timer messages in cyan with bold timer name
            colored_message = self._color_bracketed(message, self.TIMER_COLOR)
        else:
            # Regular messages: whole message in the level color
            colored_message = f"{level_color}{message}{self.RESET}"

        return f"{level_name}: {colored_message}"
76
+
77
+
78
class ExperimentLogger:
    """Experiment logger with structured logging and metrics tracking.

    This logger provides two equivalent APIs for text logging:
    1. Generic: logger.log(message, level='INFO')
    2. Level-specific: logger.info(message), logger.warning(message), logger.error(message)

    For metrics tracking, use logger.log_metric(step, split, name, value)

    Features:
    - Rank-aware logging (DDP compatible)
    - Human-readable and machine-readable logs
    - Metrics tracking with best model tracking
    - Profiling context managers (can be disabled globally)
    - Configurable log levels

    Args:
        run_dir: Directory for log files (created if missing)
        rank: Process rank (0 for main process)
        run_id: Unique run identifier; falsy values fall back to "unknown"
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR); defaults to
            the global ``loopkit.log_level``
        console_output: Whether to output to console (only rank 0); defaults
            to the global ``loopkit.verbose``

    Examples:
        >>> # Basic usage
        >>> logger = ExperimentLogger(run_dir='runs/exp1')
        >>>
        >>> # Text logging (two equivalent ways)
        >>> logger.info("Training started")  # Preferred
        >>> logger.log("Training started", level="INFO")  # Alternative
        >>>
        >>> # Metrics logging
        >>> logger.log_metric(step=0, split='train', name='loss', value=0.5)
        >>>
        >>> # Profiling (enabled by default)
        >>> with logger.timer('data_loading'):
        >>>     data = load_data()
        >>>
        >>> # Disable profiling globally for production
        >>> import loopkit
        >>> loopkit.timer = False  # or set LK_TIMER=0 environment variable
    """

    def __init__(
        self,
        run_dir: Path | str,
        rank: int = 0,
        run_id: Optional[str] = None,
        log_level: Optional[str] = None,
        console_output: Optional[bool] = None,
    ):
        # Convert to Path if string
        if isinstance(run_dir, str):
            run_dir = Path(run_dir)

        self.run_dir = run_dir
        self.rank = rank
        # NOTE: any falsy run_id (None or "") becomes "unknown".
        self.run_id = run_id or "unknown"

        run_dir.mkdir(parents=True, exist_ok=True)

        # Use global log level if not specified
        if log_level is None:
            log_level = loopkit.log_level

        # Use global verbose flag if console_output not specified
        if console_output is None:
            console_output = loopkit.verbose

        # Human-readable log (per-rank) - NO COLORS in file
        log_filename = "exp.log" if rank == 0 else f"exp_rank{rank}.log"
        self.log_file = run_dir / log_filename
        self.log_handler = logging.FileHandler(self.log_file)
        self.log_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )

        # JSONL log for machine reading (per-rank); appended to lazily by
        # log()/log_dict(), so no file is created here.
        jsonl_filename = "events.jsonl" if rank == 0 else f"events_rank{rank}.jsonl"
        self.jsonl_file = run_dir / jsonl_filename

        # Console handler (rank-0 only by default) - WITH COLORS
        self.console_handler = None
        if rank == 0 and console_output:
            self.console_handler = logging.StreamHandler(sys.stdout)
            # Check if stdout supports colors (not piped/redirected)
            use_colors = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
            if use_colors:
                self.console_handler.setFormatter(ColorFormatter())
            else:
                # Fallback to plain format if output is redirected
                self.console_handler.setFormatter(
                    logging.Formatter("%(levelname)s: %(message)s")
                )

        # Metrics CSV (rank-0 only) - respect track_metrics flag; header is
        # written only once so resumed runs keep appending to the same file.
        self.metrics_file = run_dir / "metrics.csv"
        if rank == 0 and loopkit.track_metrics and not self.metrics_file.exists():
            with open(self.metrics_file, "w") as f:
                f.write("step,split,name,value,wall_time\n")

        # Best metrics tracking (per metric)
        self.best_metrics: Dict[
            str, Dict
        ] = {}  # {metric_name: {'value': ..., 'step': ..., 'mode': ...}}
        self.best_file = run_dir / "best.json"
        # Reload previous bests so resumed runs compare against history.
        if rank == 0 and self.best_file.exists():
            with open(self.best_file) as f:
                self.best_metrics = json.load(f)

        # Setup logger (unique name to avoid handler accumulation)
        logger_name = f"loopkit_rank_{rank}_{uuid.uuid4().hex[:8]}"
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(getattr(logging, log_level.upper()))

        # Clear any existing handlers (in case logger was reused)
        self.logger.handlers.clear()

        self.logger.addHandler(self.log_handler)
        if self.console_handler:
            self.logger.addHandler(self.console_handler)

        # Prevent propagation to root logger
        self.logger.propagate = False

        # Stage tracking (for hierarchical context)
        self._current_stage: Optional[str] = None
        self._stage_stack: list = []

        # Rate limiting tracking
        self._rate_limit_counters: Dict[str, int] = {}
        self._rate_limit_last_log: Dict[str, float] = {}

        # Time estimation tracking (for automatic ETA in log_metric)
        self._eta_tracking: Dict[
            str, Dict
        ] = {}  # {key: {'start_time': ..., 'start_step': ..., 'total_steps': ...}}
215
+
216
+ def log(self, message: str, level: str = "INFO", **kwargs):
217
+ """Log a message with context.
218
+
219
+ Args:
220
+ message: The log message
221
+ level: Log level (DEBUG, INFO, WARNING, ERROR)
222
+ **kwargs: Additional context to include in log
223
+ """
224
+ # Check if this level should be logged
225
+ log_level_num = getattr(logging, level.upper(), logging.INFO)
226
+ logger_level_num = self.logger.level
227
+
228
+ # Only write to JSONL and logger if level is high enough
229
+ if log_level_num >= logger_level_num:
230
+ # Write to JSONL
231
+ log_entry = {
232
+ "timestamp": time.time(),
233
+ "level": level.upper(),
234
+ "message": message,
235
+ "run_id": self.run_id,
236
+ "rank": self.rank,
237
+ **kwargs,
238
+ }
239
+
240
+ with open(self.jsonl_file, "a") as f:
241
+ f.write(json.dumps(log_entry) + "\n")
242
+
243
+ # Write to logger
244
+ self.logger.log(log_level_num, message)
245
+
246
    def debug(self, message: str, **kwargs):
        """Log *message* at DEBUG level (shorthand for ``log(..., "DEBUG")``)."""
        self.log(message, "DEBUG", **kwargs)

    def info(self, message: str, **kwargs):
        """Log *message* at INFO level (shorthand for ``log(..., "INFO")``)."""
        self.log(message, "INFO", **kwargs)

    def warning(self, message: str, **kwargs):
        """Log *message* at WARNING level (shorthand for ``log(..., "WARNING")``)."""
        self.log(message, "WARNING", **kwargs)

    def error(self, message: str, **kwargs):
        """Log *message* at ERROR level (shorthand for ``log(..., "ERROR")``)."""
        self.log(message, "ERROR", **kwargs)
257
+
258
+ @staticmethod
259
+ def _format_time(seconds: float) -> str:
260
+ """Format seconds into human-readable time string.
261
+
262
+ Args:
263
+ seconds: Time in seconds
264
+
265
+ Returns:
266
+ Formatted time string (e.g., "2h 15m 30s", "45s", "1d 3h 20m")
267
+ """
268
+ if seconds < 0:
269
+ return "0s"
270
+
271
+ days = int(seconds // 86400)
272
+ seconds %= 86400
273
+ hours = int(seconds // 3600)
274
+ seconds %= 3600
275
+ minutes = int(seconds // 60)
276
+ secs = int(seconds % 60)
277
+
278
+ parts = []
279
+ if days > 0:
280
+ parts.append(f"{days}d")
281
+ if hours > 0:
282
+ parts.append(f"{hours}h")
283
+ if minutes > 0:
284
+ parts.append(f"{minutes}m")
285
+ if secs > 0 or not parts:
286
+ parts.append(f"{secs}s")
287
+
288
+ return " ".join(parts)
289
+
290
    def log_metric(
        self,
        step: int,
        split: str,
        name: str,
        value: float,
        track_best: bool = True,
        mode: Optional[str] = None,
        total_steps: Optional[int] = None,
        eta_key: Optional[str] = None,
    ):
        """Log a metric to CSV and update best tracking.

        Rank-0 only; a no-op when ``loopkit.track_metrics`` is disabled.
        Appends one CSV row per call, optionally logs a progress/ETA line,
        and persists per-metric bests to ``best.json`` when they improve.

        Args:
            step: Training step
            split: Data split (train/val/test)
            name: Metric name
            value: Metric value
            track_best: Whether to track this as a best metric
            mode: 'min' or 'max' for best tracking (auto-detected from the
                metric name if None; unknown names default to 'min')
            total_steps: Total number of steps (enables automatic ETA estimation)
            eta_key: Key for ETA tracking (auto-generated from split/name if None)
        """
        if self.rank != 0:
            return  # Only rank 0 logs metrics

        # Skip if metrics tracking is disabled
        if not loopkit.track_metrics:
            return

        wall_time = time.time()
        # Append-only CSV; header was written (once) in __init__.
        with open(self.metrics_file, "a") as f:
            f.write(f"{step},{split},{name},{value},{wall_time}\n")

        # Automatic ETA estimation if total_steps is provided
        if total_steps is not None and step > 0:
            # Auto-generate key if not provided
            if eta_key is None:
                eta_key = f"{split}/{name}"

            # Initialize tracking for this metric if first time; ETA is then
            # measured relative to this first observed step, not step 0.
            if eta_key not in self._eta_tracking:
                self._eta_tracking[eta_key] = {
                    "start_time": wall_time,
                    "start_step": step,
                    "total_steps": total_steps,
                }

            # Calculate ETA from average time per step since tracking began.
            tracking = self._eta_tracking[eta_key]
            elapsed = wall_time - tracking["start_time"]
            steps_done = step - tracking["start_step"]

            if steps_done > 0:
                steps_remaining = total_steps - step
                time_per_step = elapsed / steps_done
                eta_seconds = time_per_step * steps_remaining

                # Log ETA info (only if significant progress made to avoid
                # spam): roughly every 5% of total_steps, plus the final step.
                if steps_done % max(1, total_steps // 20) == 0 or step == total_steps:
                    progress_pct = 100 * step / total_steps
                    eta_str = self._format_time(eta_seconds)
                    elapsed_str = self._format_time(elapsed)

                    self.info(
                        f"Progress [{split}/{name}]: {step}/{total_steps} ({progress_pct:.1f}%) | "
                        f"Elapsed: {elapsed_str} | ETA: {eta_str}",
                        step=step,
                        progress=progress_pct,
                        eta_seconds=eta_seconds,
                        elapsed_seconds=elapsed,
                    )

        # Update best metrics tracking
        if track_best:
            # Auto-detect mode if not specified, from substrings of the name:
            # loss/error -> minimize, acc/f1/auc -> maximize.
            if mode is None:
                name_lower = name.lower()
                if any(x in name_lower for x in ["loss", "error"]):
                    mode = "min"
                elif any(x in name_lower for x in ["acc", "accuracy", "f1", "auc"]):
                    mode = "max"
                else:
                    mode = "min"  # Default to min

            metric_key = f"{split}/{name}"

            # Check if this is a new best (first observation always is).
            is_best = False
            if metric_key not in self.best_metrics:
                is_best = True
            else:
                prev_best = self.best_metrics[metric_key]["value"]
                # Strict inequality: ties do not replace the earlier best.
                is_better = (mode == "min" and value < prev_best) or (
                    mode == "max" and value > prev_best
                )
                if is_better:
                    is_best = True

            # Update if best
            if is_best:
                self.best_metrics[metric_key] = {
                    "value": value,
                    "step": step,
                    "mode": mode,
                }

                # Save best metrics (rewritten in full on every improvement).
                with open(self.best_file, "w") as f:
                    json.dump(self.best_metrics, f, indent=2)
400
+
401
+ def get_best_metric(self, name: str, split: str = "val") -> Optional[Dict]:
402
+ """Get the best value for a metric.
403
+
404
+ Args:
405
+ name: Metric name
406
+ split: Data split
407
+
408
+ Returns:
409
+ dict: Best metric info with 'value', 'step', 'mode' or None
410
+ """
411
+ metric_key = f"{split}/{name}"
412
+ return self.best_metrics.get(metric_key)
413
+
414
    @contextmanager
    def timer(self, name: str, log_result: bool = True):
        """Context manager for timing code sections.

        Provides simple wall-clock timing using time.perf_counter().
        For detailed GPU/CPU profiling, use torch.profiler directly.

        Timing can be globally disabled by setting loopkit.timer = False
        or environment variable LK_TIMER=0; when disabled the yielded dict
        keeps elapsed=0.0 and nothing is logged.

        Args:
            name: Name of the timed section
            log_result: Whether to log the timing result (rank 0 only)

        Usage:
            # Basic timing
            with logger.timer("data_loading"):
                data = load_data()

            # Get elapsed time
            with logger.timer("training_step") as result:
                loss = model(batch)
            print(f"Step took {result['elapsed']:.4f}s")

            # Silent timing (no logging)
            with logger.timer("forward", log_result=False) as result:
                output = model(input)
            elapsed = result['elapsed']

        Yields:
            dict: Dictionary with 'elapsed' key (in seconds), updated when
            the context exits (read it after the with-block, not inside).

        Example with PyTorch Profiler:
            # Use torch.profiler for detailed profiling
            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=True,
            ) as prof:
                with logger.timer("training_step"):
                    loss = model(batch)
                    loss.backward()
            prof.export_chrome_trace("trace.json")
        """
        result = {"elapsed": 0.0}

        # Check if profiling is enabled globally; bail out cheaply when off.
        if not loopkit.timer:
            yield result
            return

        start_time = time.perf_counter()

        try:
            yield result
        finally:
            # Always record elapsed time, even if the body raised.
            result["elapsed"] = time.perf_counter() - start_time

            if log_result and self.rank == 0:
                self.info(
                    f"Timer [{name}]: {result['elapsed']:.4f}s",
                    section=name,
                    elapsed=result["elapsed"],
                )
480
+
481
    @contextmanager
    def stage(self, name: str, **metadata):
        """Context manager for tracking training stages with hierarchical context.

        Automatically logs stage entry/exit, tracks duration, and provides context
        for metrics and logs. Stages can be nested to create hierarchies
        (e.g., epoch → train → batch); the logged stage name is the
        'parent/child' path built from the enclosing stage.

        Args:
            name: Name of the stage (e.g., 'epoch', 'train', 'validation')
            **metadata: Additional metadata to log (e.g., epoch=5, batch=10)

        Usage:
            # Simple stage
            with logger.stage("training"):
                train_model()

            # With metadata
            with logger.stage("epoch", epoch=5, lr=0.001):
                train_epoch()

            # Nested stages
            with logger.stage("epoch", epoch=5):
                with logger.stage("train"):
                    train_loss = train_epoch()
                with logger.stage("validation"):
                    val_loss = validate()

        Yields:
            dict: Stage info with 'name', 'metadata', 'elapsed' keys
            ('elapsed' is filled in when the context exits)
        """
        stage_info = {"name": name, "metadata": metadata, "elapsed": 0.0}

        # Build hierarchical stage path from the innermost enclosing stage.
        # NOTE: the parent path uses the stage's plain 'name', so deep nests
        # show only the immediate parent, not the full ancestry.
        if self._stage_stack:
            parent = self._stage_stack[-1]["name"]
            full_name = f"{parent}/{name}"
        else:
            full_name = name

        # Log stage entry (rank 0 only)
        if self.rank == 0:
            metadata_str = ", ".join(f"{k}={v}" for k, v in metadata.items())
            msg = f"Stage [{full_name}]"
            if metadata_str:
                msg += f" ({metadata_str})"
            msg += " - START"
            self.info(msg, stage=full_name, stage_event="start", **metadata)

        # Push to stack; remember the previous current stage so nested
        # stages restore it on exit.
        self._stage_stack.append(stage_info)
        old_stage = self._current_stage
        self._current_stage = full_name

        start_time = time.perf_counter()

        try:
            yield stage_info
        finally:
            # Runs even if the stage body raised, keeping the stack balanced.
            elapsed = time.perf_counter() - start_time
            stage_info["elapsed"] = elapsed

            # Log stage exit (rank 0 only)
            if self.rank == 0:
                msg = f"Stage [{full_name}] - END ({elapsed:.4f}s)"
                self.info(
                    msg, stage=full_name, stage_event="end", elapsed=elapsed, **metadata
                )

            # Pop from stack and restore the previous stage context.
            self._stage_stack.pop()
            self._current_stage = old_stage
553
+
554
+ def should_run(
555
+ self, every: int = None, seconds: float = None, key: str = None
556
+ ) -> bool:
557
+ """Check if code should run based on rate limiting.
558
+
559
+ Returns a boolean indicating whether to execute. Use with 'if' for clean syntax.
560
+
561
+ Args:
562
+ every: Execute every N iterations (mutually exclusive with seconds)
563
+ seconds: Execute every N seconds (mutually exclusive with every)
564
+ key: Unique key for this rate limiter (auto-generated if None)
565
+
566
+ Returns:
567
+ bool: True if code should execute, False if skipped
568
+
569
+ Usage:
570
+ # Clean 'if' syntax
571
+ for step in range(10000):
572
+ if logger.should_run(every=100):
573
+ logger.info(f"Step {step}")
574
+ save_checkpoint()
575
+
576
+ # Time-based rate limiting
577
+ for batch in dataloader:
578
+ if logger.should_run(seconds=5.0):
579
+ logger.info("Expensive logging")
580
+ """
581
+ if every is None and seconds is None:
582
+ raise ValueError("Must specify either 'every' or 'seconds'")
583
+ if every is not None and seconds is not None:
584
+ raise ValueError("Cannot specify both 'every' and 'seconds'")
585
+
586
+ # Auto-generate key from caller location if not provided
587
+ if key is None:
588
+ frame = inspect.currentframe()
589
+ caller_frame = frame.f_back if frame else None
590
+ if caller_frame:
591
+ key = f"{caller_frame.f_code.co_filename}:{caller_frame.f_lineno}"
592
+ else:
593
+ key = "default"
594
+
595
+ should_execute = False
596
+
597
+ if every is not None:
598
+ # Iteration-based rate limiting
599
+ count = self._rate_limit_counters.get(key, 0)
600
+ self._rate_limit_counters[key] = count + 1
601
+
602
+ if count % every == 0:
603
+ should_execute = True
604
+ else:
605
+ # Time-based rate limiting
606
+ current_time = time.time()
607
+ last_execute_time = self._rate_limit_last_log.get(key, 0.0)
608
+
609
+ if current_time - last_execute_time >= seconds:
610
+ should_execute = True
611
+ self._rate_limit_last_log[key] = current_time
612
+
613
+ return should_execute
614
+
615
+ # Backward compatibility alias - wraps should_run as context manager
616
+ @contextmanager
617
+ def do_every(self, every: int = None, seconds: float = None, key: str = None):
618
+ """Deprecated: Use should_run() instead.
619
+
620
+ Context manager version for backward compatibility.
621
+ """
622
+ should_execute = self.should_run(every=every, seconds=seconds, key=key)
623
+ yield should_execute
624
+
625
    @contextmanager
    def log_every(self, every: int = None, seconds: float = None, key: str = None):
        """Context manager for rate-limited logging.

        Logs only every N iterations or every N seconds, useful for reducing
        log spam in tight training loops while still capturing data in JSONL.
        Suppression works by temporarily detaching the console handler, so
        file and JSONL logging still happen on suppressed iterations.

        Args:
            every: Log every N iterations (mutually exclusive with seconds)
            seconds: Log every N seconds (mutually exclusive with every)
            key: Unique key for this rate limiter (auto-generated if None)

        Usage:
            # Log every 100 iterations
            for step in range(10000):
                loss = train_step()
                with logger.log_every(every=100):
                    logger.info(f"Step {step}: loss={loss:.4f}")

            # Log every 5 seconds
            for batch in dataloader:
                with logger.log_every(seconds=5.0):
                    logger.info(f"Processing batch...")

            # Multiple rate limiters with different keys
            for step in range(1000):
                with logger.log_every(every=10, key="loss"):
                    logger.info(f"Loss: {loss:.4f}")
                with logger.log_every(every=100, key="detailed"):
                    logger.info(f"Detailed metrics: {metrics}")

        Yields:
            bool: True if logging should occur, False if suppressed
        """
        # Use should_run to determine if we should log
        should_log = self.should_run(every=every, seconds=seconds, key=key)

        # Temporarily suppress console logging if not time to log
        if not should_log and self.console_handler:
            # Remove console handler temporarily
            if self.console_handler in self.logger.handlers:
                self.logger.removeHandler(self.console_handler)
                try:
                    yield should_log
                finally:
                    # Restore console handler (guard against double-add if
                    # something inside the body re-attached it).
                    if self.console_handler not in self.logger.handlers:
                        self.logger.addHandler(self.console_handler)
            else:
                # Handler not present, just yield
                yield should_log
        else:
            yield should_log
678
+
679
+ def set_log_level(self, level: str):
680
+ """Change the logging level.
681
+
682
+ Args:
683
+ level: One of DEBUG, INFO, WARNING, ERROR
684
+ """
685
+ self.logger.setLevel(getattr(logging, level.upper()))
686
+
687
+ def close(self):
688
+ """Close all log handlers."""
689
+ for handler in self.logger.handlers[:]:
690
+ handler.close()
691
+ self.logger.removeHandler(handler)
692
+
693
+ def log_dict(self, data: Dict, level: str = "INFO"):
694
+ """Log a dictionary as a single entry.
695
+
696
+ Args:
697
+ data: Dictionary to log
698
+ level: Log level
699
+ """
700
+ log_entry = {
701
+ "timestamp": time.time(),
702
+ "level": level.upper(),
703
+ "run_id": self.run_id,
704
+ "rank": self.rank,
705
+ **data,
706
+ }
707
+
708
+ with open(self.jsonl_file, "a") as f:
709
+ f.write(json.dumps(log_entry) + "\n")
710
+
711
+ def save_metadata(self, metadata: Dict, filename: str = "metadata.json"):
712
+ """Save metadata to a JSON file.
713
+
714
+ Args:
715
+ metadata: Dictionary containing metadata
716
+ filename: Name of file to save to
717
+ """
718
+ filepath = self.run_dir / filename
719
+ with open(filepath, "w") as f:
720
+ json.dump(metadata, f, indent=2)
721
+
722
    def __enter__(self):
        """Context manager entry: returns self for `with ExperimentLogger(...) as log:`."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close handlers; exceptions propagate (returns False)."""
        self.close()
        return False