textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,133 @@
+# textpolicy/utils/logging/factory.py
+"""
+Factory functions for creating logger instances.
+"""
+
+from typing import Optional, List
+from .base import Logger
+from .wandb import WandbLogger
+from .tensorboard import TensorboardLogger
+from .console import ConsoleLogger
+from .multi import MultiLogger
+
+
+def create_logger(
+    logger_type: str = "console",
+    **kwargs
+) -> Logger:
+    """
+    Factory function to create logger instances.
+
+    Args:
+        logger_type: Type of logger ("wandb", "tensorboard", "console")
+        **kwargs: Logger-specific parameters
+
+    Returns:
+        Logger instance
+
+    Raises:
+        ValueError: If logger_type is unknown
+
+    Examples:
+        # Console logger
+        logger = create_logger("console", verbose=True)
+
+        # Wandb logger
+        logger = create_logger("wandb", project_name="my-project", run_name="test")
+
+        # TensorBoard logger
+        logger = create_logger("tensorboard", log_dir="./logs")
+    """
+    if logger_type == "wandb":
+        if "project_name" not in kwargs:
+            raise ValueError("project_name is required for wandb logger")
+        project_name = kwargs.pop("project_name")
+        return WandbLogger(project_name, **kwargs)
+    elif logger_type == "tensorboard":
+        if "log_dir" not in kwargs:
+            raise ValueError("log_dir is required for tensorboard logger")
+        log_dir = kwargs.pop("log_dir")
+        return TensorboardLogger(log_dir)
+    elif logger_type == "console":
+        return ConsoleLogger(**kwargs)
+    else:
+        available_types = ["wandb", "tensorboard", "console"]
+        raise ValueError(f"Unknown logger type: {logger_type}. Available: {available_types}")
+
+
+def create_multi_logger(
+    configs: List[dict]
+) -> MultiLogger:
+    """
+    Create a MultiLogger from a list of logger configurations.
+
+    Args:
+        configs: List of dictionaries with "type" and other parameters
+
+    Returns:
+        MultiLogger instance
+
+    Example:
+        logger = create_multi_logger([
+            {"type": "console", "verbose": True},
+            {"type": "wandb", "project_name": "my-project"}
+        ])
+    """
+    loggers = []
+    for config in configs:
+        logger_type = config.pop("type")
+        logger = create_logger(logger_type, **config)
+        loggers.append(logger)
+
+    return MultiLogger(loggers)
+
+
+def create_auto_logger(
+    project_name: Optional[str] = None,
+    log_dir: Optional[str] = None,
+    console: bool = True
+) -> Logger:
+    """
+    Automatically create appropriate logger based on available dependencies.
+
+    Priority order:
+    1. Wandb (if project_name provided and wandb available)
+    2. TensorBoard (if log_dir provided and tensorboard available)
+    3. Console (always available)
+
+    Args:
+        project_name: Wandb project name (enables wandb if available)
+        log_dir: TensorBoard log directory (enables tensorboard if available)
+        console: Whether to include console logging
+
+    Returns:
+        Logger instance (MultiLogger if multiple backends, single Logger otherwise)
+    """
+    loggers = []
+
+    # Try wandb first
+    if project_name:
+        try:
+            loggers.append(WandbLogger(project_name))
+        except ImportError:
+            print("Warning: wandb not available, skipping")
+
+    # Try tensorboard
+    if log_dir:
+        try:
+            loggers.append(TensorboardLogger(log_dir))
+        except ImportError:
+            print("Warning: tensorboard not available, skipping")
+
+    # Always add console if requested
+    if console:
+        loggers.append(ConsoleLogger(verbose=True))
+
+    # Return appropriate logger type
+    if len(loggers) == 0:
+        # Fallback to console
+        return ConsoleLogger(verbose=True)
+    elif len(loggers) == 1:
+        return loggers[0]
+    else:
+        return MultiLogger(loggers)
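For context, a minimal usage sketch of the factory API added above (the project name, log directory, and metric keys are hypothetical; the import path follows the module layout listed in this diff):

    from textpolicy.utils.logging.factory import create_auto_logger

    # Uses wandb and/or TensorBoard when importable, adds a console backend by
    # default, and returns a MultiLogger when more than one backend is active.
    logger = create_auto_logger(project_name="textpolicy-demo", log_dir="./runs")
    logger.log_metrics({"loss": 0.42, "reward_mean": 1.3}, step=100)
    logger.log_evaluation({"accuracy": 0.81}, step=100)
    logger.finish()

Note that create_multi_logger pops the "type" key from each config dict in place, so callers that want to reuse a config list should pass copies.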
@@ -0,0 +1,83 @@
+# textpolicy/utils/logging/multi.py
+"""
+Multi-logger for combining multiple logging backends.
+"""
+
+from typing import Dict, List
+from .base import Logger
+
+
+class MultiLogger(Logger):
+    """
+    Combine multiple logging backends into a single interface.
+
+    Features:
+    - Log to multiple backends simultaneously
+    - Graceful error handling (one logger failure doesn't stop others)
+    - Unified interface for complex logging setups
+
+    Example:
+        # Log to both wandb and console
+        logger = MultiLogger([
+            WandbLogger("my-project"),
+            ConsoleLogger(verbose=True)
+        ])
+    """
+
+    def __init__(self, loggers: List[Logger]):
+        """
+        Initialize multi-logger with list of backends.
+
+        Args:
+            loggers: List of Logger instances to combine
+
+        Raises:
+            ValueError: If no loggers provided
+        """
+        if not loggers:
+            raise ValueError("At least one logger must be provided")
+        self.loggers = loggers
+
+    def log_metrics(self, metrics: Dict[str, float], step: int):
+        """
+        Log training metrics to all backends.
+
+        Continues logging to other backends even if one fails.
+
+        Args:
+            metrics: Training metrics dictionary
+            step: Training step number
+        """
+        for logger in self.loggers:
+            try:
+                logger.log_metrics(metrics, step)
+            except Exception as e:
+                print(f"Warning: Logger {type(logger).__name__} failed: {e}")
+
+    def log_evaluation(self, metrics: Dict[str, float], step: int):
+        """
+        Log evaluation metrics to all backends.
+
+        Continues logging to other backends even if one fails.
+
+        Args:
+            metrics: Evaluation metrics dictionary
+            step: Training step when evaluation was performed
+        """
+        for logger in self.loggers:
+            try:
+                logger.log_evaluation(metrics, step)
+            except Exception as e:
+                print(f"Warning: Logger {type(logger).__name__} failed: {e}")
+
+    def finish(self):
+        """
+        Finish all loggers.
+
+        Attempts to finish all loggers even if some fail.
+        """
+        for logger in self.loggers:
+            try:
+                logger.finish()
+            except Exception as e:
+                print(f"Warning: Logger {type(logger).__name__} finish failed: {e}")
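A short sketch of the error-isolation behavior, assuming a hypothetical stub backend that implements the three-method Logger interface (log_metrics, log_evaluation, finish) used throughout this diff:

    from textpolicy.utils.logging.base import Logger
    from textpolicy.utils.logging.console import ConsoleLogger
    from textpolicy.utils.logging.multi import MultiLogger

    class FlakyLogger(Logger):
        """Hypothetical backend used only to demonstrate error isolation."""
        def log_metrics(self, metrics, step):
            raise RuntimeError("backend unavailable")
        def log_evaluation(self, metrics, step):
            raise RuntimeError("backend unavailable")
        def finish(self):
            pass

    logger = MultiLogger([FlakyLogger(), ConsoleLogger(verbose=True)])
    logger.log_metrics({"loss": 0.5}, step=1)  # prints a warning; the console backend still logs
    logger.finish()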
@@ -0,0 +1,65 @@
+# textpolicy/utils/logging/tensorboard.py
+"""
+TensorBoard logging integration.
+"""
+
+from typing import Dict
+from .base import Logger
+
+
+class TensorboardLogger(Logger):
+    """
+    TensorBoard integration for local experiment visualization.
+
+    Features:
+    - Local scalar metric visualization
+    - Histogram and distribution tracking
+    - Image and model graph visualization
+    - No external service dependency
+
+    Requires: pip install tensorboard
+    """
+
+    def __init__(self, log_dir: str):
+        """
+        Initialize TensorBoard logging.
+
+        Args:
+            log_dir: Directory to store TensorBoard log files
+
+        Raises:
+            ImportError: If tensorboard is not installed
+        """
+        try:
+            from torch.utils.tensorboard import SummaryWriter  # type: ignore
+            self.writer = SummaryWriter(log_dir)
+        except ImportError:
+            raise ImportError(
+                "tensorboard not installed. Install with: pip install tensorboard"
+            )
+
+    def log_metrics(self, metrics: Dict[str, float], step: int):
+        """
+        Log training metrics to TensorBoard with 'train/' prefix.
+
+        Args:
+            metrics: Training metrics dictionary
+            step: Training step number
+        """
+        for key, value in metrics.items():
+            self.writer.add_scalar(f"train/{key}", value, step)
+
+    def log_evaluation(self, metrics: Dict[str, float], step: int):
+        """
+        Log evaluation metrics to TensorBoard with 'eval/' prefix.
+
+        Args:
+            metrics: Evaluation metrics dictionary
+            step: Training step when evaluation was performed
+        """
+        for key, value in metrics.items():
+            self.writer.add_scalar(f"eval/{key}", value, step)
+
+    def finish(self):
+        """Close TensorBoard writer and flush remaining data."""
+        self.writer.close()
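One detail worth noting: the writer is imported from torch.utils.tensorboard, so PyTorch must be importable in addition to the tensorboard package named in the error message. A minimal sketch with a hypothetical log directory:

    from textpolicy.utils.logging.tensorboard import TensorboardLogger

    tb = TensorboardLogger(log_dir="./runs/exp1")
    tb.log_metrics({"loss": 0.37}, step=10)         # appears as train/loss
    tb.log_evaluation({"accuracy": 0.90}, step=10)  # appears as eval/accuracy
    tb.finish()  # closes the writer and flushes pending events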
@@ -0,0 +1,72 @@
+# textpolicy/utils/logging/wandb.py
+"""
+Weights & Biases (wandb) logging integration.
+"""
+
+from typing import Dict, Optional
+from .base import Logger
+
+
+class WandbLogger(Logger):
+    """
+    Weights & Biases integration for experiment tracking.
+
+    Features:
+    - Automatic experiment organization with projects
+    - Real-time metric visualization
+    - Hyperparameter tracking
+    - Model artifact management
+
+    Requires: pip install wandb
+    """
+
+    def __init__(self, project_name: str, run_name: Optional[str] = None, **kwargs):
+        """
+        Initialize wandb logging.
+
+        Args:
+            project_name: Wandb project name for organization
+            run_name: Optional run name (auto-generated if None)
+            **kwargs: Additional wandb.init() parameters (tags, config, etc.)
+
+        Raises:
+            ImportError: If wandb is not installed
+        """
+        try:
+            import wandb  # type: ignore
+            self.wandb = wandb
+            self.run = wandb.init(
+                project=project_name,
+                name=run_name,
+                **kwargs
+            )
+        except ImportError:
+            raise ImportError(
+                "wandb not installed. Install with: pip install wandb"
+            )
+
+    def log_metrics(self, metrics: Dict[str, float], step: int):
+        """
+        Log training metrics to wandb with 'train/' prefix.
+
+        Args:
+            metrics: Training metrics dictionary
+            step: Training step number
+        """
+        prefixed_metrics = {"train/" + k: v for k, v in metrics.items()}
+        self.wandb.log(prefixed_metrics, step=step)
+
+    def log_evaluation(self, metrics: Dict[str, float], step: int):
+        """
+        Log evaluation metrics to wandb with 'eval/' prefix.
+
+        Args:
+            metrics: Evaluation metrics dictionary
+            step: Training step when evaluation was performed
+        """
+        prefixed_metrics = {"eval/" + k: v for k, v in metrics.items()}
+        self.wandb.log(prefixed_metrics, step=step)
+
+    def finish(self):
+        """Finish wandb run and upload final data."""
+        self.wandb.finish()
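A minimal sketch of the wandb backend (project and run names are hypothetical; extra keyword arguments are forwarded to wandb.init()):

    from textpolicy.utils.logging.wandb import WandbLogger

    wb = WandbLogger("textpolicy-demo", run_name="baseline", tags=["smoke-test"])
    wb.log_metrics({"loss": 0.42}, step=1)         # logged as train/loss
    wb.log_evaluation({"accuracy": 0.80}, step=1)  # logged as eval/accuracy
    wb.finish()  # ends the run and uploads any remaining data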
@@ -0,0 +1,118 @@
+# textpolicy/utils/memory.py
+"""
+Memory monitoring utilities for TextPolicy.
+"""
+
+import gc
+from typing import Dict, Optional
+try:
+    import mlx.core as mx  # type: ignore
+    MLX_AVAILABLE = True
+except ImportError:
+    MLX_AVAILABLE = False
+
+
+def get_memory_stats() -> Dict[str, float]:
+    """
+    Get current memory usage statistics.
+
+    Returns:
+        Dictionary with memory statistics in MB
+    """
+    stats = {}
+
+    # MLX memory usage (Apple Silicon GPU/ANE)
+    if MLX_AVAILABLE:
+        try:
+            # MLX memory information
+            stats["mlx_memory_mb"] = mx.metal.get_active_memory() / 1024 / 1024
+            stats["mlx_peak_mb"] = mx.metal.get_peak_memory() / 1024 / 1024
+        except Exception as e:
+            print(f"Error getting MLX memory stats: {e}")
+            stats["mlx_memory_mb"] = 0.0
+            stats["mlx_peak_mb"] = 0.0
+
+    # Python memory usage
+    try:
+        import psutil  # type: ignore
+        process = psutil.Process()
+        stats["python_memory_mb"] = process.memory_info().rss / 1024 / 1024
+        stats["python_virtual_mb"] = process.memory_info().vms / 1024 / 1024
+    except ImportError:
+        stats["python_memory_mb"] = 0.0
+        stats["python_virtual_mb"] = 0.0
+
+    return stats
+
+
+def clear_memory():
+    """
+    Clear memory caches and run garbage collection.
+
+    Useful for freeing memory between training runs or evaluations.
+    """
+    # Python garbage collection
+    gc.collect()
+
+    # MLX memory cleanup
+    if MLX_AVAILABLE:
+        try:
+            mx.metal.clear_cache()
+        except Exception as e:
+            print(f"Error clearing MLX memory: {e}")
+            pass
+
+
+class MemoryMonitor:
+    """
+    Monitor memory usage during training.
+
+    Features:
+    - Track peak memory usage
+    - Automatic memory alerts
+    - Integration with logging systems
+    """
+
+    def __init__(self, alert_threshold_mb: float = 8000):
+        """
+        Initialize memory monitor.
+
+        Args:
+            alert_threshold_mb: Memory usage threshold for alerts (default 8GB)
+        """
+        self.alert_threshold = alert_threshold_mb
+        self.peak_stats = {}
+
+    def check_memory(self, step: Optional[int] = None) -> Dict[str, float]:
+        """
+        Check current memory usage and update peaks.
+
+        Args:
+            step: Optional training step for logging
+
+        Returns:
+            Current memory statistics
+        """
+        current = get_memory_stats()
+
+        # Update peaks
+        for key, value in current.items():
+            if key not in self.peak_stats or value > self.peak_stats[key]:
+                self.peak_stats[key] = value
+
+        # Check for alerts
+        total_memory = current.get("mlx_memory_mb", 0) + current.get("python_memory_mb", 0)
+        if total_memory > self.alert_threshold:
+            print(f"Memory alert: {total_memory:.1f}MB (threshold: {self.alert_threshold:.1f}MB)")
+            if step is not None:
+                print(f"  At training step: {step}")
+
+        return current
+
+    def get_peak_stats(self) -> Dict[str, float]:
+        """Get peak memory usage statistics."""
+        return self.peak_stats.copy()
+
+    def reset_peaks(self):
+        """Reset peak memory tracking."""
+        self.peak_stats.clear()
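A usage sketch of the memory utilities (the threshold and loop are illustrative; the psutil-backed fields fall back to 0.0 when psutil is not installed):

    from textpolicy.utils.memory import MemoryMonitor, clear_memory

    monitor = MemoryMonitor(alert_threshold_mb=4000)  # hypothetical 4 GB alert threshold
    for step in range(100):
        # ... one training step would run here ...
        monitor.check_memory(step=step)  # prints an alert when usage exceeds the threshold
    print(monitor.get_peak_stats())  # peak values seen across all checks
    clear_memory()  # gc.collect() plus an MLX cache clear when MLX is available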