ins-pricing 0.2.7-py3-none-any.whl → 0.2.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. ins_pricing/CHANGELOG.md +179 -0
  2. ins_pricing/RELEASE_NOTES_0.2.8.md +344 -0
  3. ins_pricing/modelling/core/bayesopt/utils.py +2 -1
  4. ins_pricing/modelling/explain/shap_utils.py +209 -6
  5. ins_pricing/pricing/calibration.py +125 -1
  6. ins_pricing/pricing/factors.py +110 -1
  7. ins_pricing/production/preprocess.py +166 -0
  8. ins_pricing/setup.py +1 -1
  9. ins_pricing/tests/governance/__init__.py +1 -0
  10. ins_pricing/tests/governance/test_audit.py +56 -0
  11. ins_pricing/tests/governance/test_registry.py +128 -0
  12. ins_pricing/tests/governance/test_release.py +74 -0
  13. ins_pricing/tests/pricing/__init__.py +1 -0
  14. ins_pricing/tests/pricing/test_calibration.py +72 -0
  15. ins_pricing/tests/pricing/test_exposure.py +64 -0
  16. ins_pricing/tests/pricing/test_factors.py +156 -0
  17. ins_pricing/tests/pricing/test_rate_table.py +40 -0
  18. ins_pricing/tests/production/__init__.py +1 -0
  19. ins_pricing/tests/production/test_monitoring.py +350 -0
  20. ins_pricing/tests/production/test_predict.py +233 -0
  21. ins_pricing/tests/production/test_preprocess.py +339 -0
  22. ins_pricing/tests/production/test_scoring.py +311 -0
  23. ins_pricing/utils/profiling.py +377 -0
  24. ins_pricing/utils/validation.py +427 -0
  25. ins_pricing-0.2.9.dist-info/METADATA +149 -0
  26. {ins_pricing-0.2.7.dist-info → ins_pricing-0.2.9.dist-info}/RECORD +28 -12
  27. ins_pricing/CHANGELOG_20260114.md +0 -275
  28. ins_pricing/CODE_REVIEW_IMPROVEMENTS.md +0 -715
  29. ins_pricing-0.2.7.dist-info/METADATA +0 -101
  30. {ins_pricing-0.2.7.dist-info → ins_pricing-0.2.9.dist-info}/WHEEL +0 -0
  31. {ins_pricing-0.2.7.dist-info → ins_pricing-0.2.9.dist-info}/top_level.txt +0 -0
ins_pricing/tests/production/test_scoring.py
@@ -0,0 +1,311 @@
+ """Tests for production scoring module."""
+
+ import numpy as np
+ import pandas as pd
+ import pytest
+
+
+ @pytest.fixture
+ def sample_predictions():
+     """Sample prediction data."""
+     return pd.DataFrame({
+         "actual": [100, 150, 200, 250, 300],
+         "predicted": [105, 145, 210, 240, 295],
+         "weight": [1.0, 1.0, 1.0, 1.0, 1.0]
+     })
+
+
+ @pytest.fixture
+ def classification_data():
+     """Sample classification data."""
+     return pd.DataFrame({
+         "actual": [0, 1, 1, 0, 1, 0, 1, 1, 0, 0],
+         "predicted_proba": [0.1, 0.9, 0.8, 0.2, 0.7, 0.3, 0.85, 0.6, 0.15, 0.25],
+         "predicted_class": [0, 1, 1, 0, 1, 0, 1, 1, 0, 0]
+     })
+
+
+ class TestRegressionMetrics:
+     """Test regression scoring metrics."""
+
+     def test_weighted_mse(self, sample_predictions):
+         """Test weighted mean squared error calculation."""
+         from ins_pricing.production.scoring import weighted_mse
+
+         mse = weighted_mse(
+             sample_predictions['actual'],
+             sample_predictions['predicted'],
+             sample_predictions['weight']
+         )
+
+         assert isinstance(mse, (int, float, np.number))
+         assert mse >= 0
+
+     def test_weighted_mae(self, sample_predictions):
+         """Test weighted mean absolute error calculation."""
+         from ins_pricing.production.scoring import weighted_mae
+
+         mae = weighted_mae(
+             sample_predictions['actual'],
+             sample_predictions['predicted'],
+             sample_predictions['weight']
+         )
+
+         assert isinstance(mae, (int, float, np.number))
+         assert mae >= 0
+
+     def test_weighted_r2(self, sample_predictions):
+         """Test weighted R² score calculation."""
+         from ins_pricing.production.scoring import weighted_r2
+
+         r2 = weighted_r2(
+             sample_predictions['actual'],
+             sample_predictions['predicted'],
+             sample_predictions['weight']
+         )
+
+         assert isinstance(r2, (int, float, np.number))
+         assert r2 <= 1.0
+
+     def test_mape(self, sample_predictions):
+         """Test mean absolute percentage error."""
+         from ins_pricing.production.scoring import mape
+
+         mape_score = mape(
+             sample_predictions['actual'],
+             sample_predictions['predicted']
+         )
+
+         assert isinstance(mape_score, (int, float, np.number))
+         assert mape_score >= 0
+
+     def test_metrics_with_zero_actuals(self):
+         """Test metrics handling when actual values are zero."""
+         from ins_pricing.production.scoring import mape
+
+         data = pd.DataFrame({
+             "actual": [0, 100, 200],
+             "predicted": [10, 105, 195]
+         })
+
+         # MAPE is undefined when actuals contain zeros, so an explicit error is expected
+         with pytest.raises((ValueError, ZeroDivisionError)):
+             mape(data['actual'], data['predicted'])
+
+
+ class TestClassificationMetrics:
+     """Test classification scoring metrics."""
+
+     def test_accuracy(self, classification_data):
+         """Test accuracy calculation."""
+         from ins_pricing.production.scoring import accuracy
+
+         acc = accuracy(
+             classification_data['actual'],
+             classification_data['predicted_class']
+         )
+
+         assert 0 <= acc <= 1
+
+     def test_precision_recall(self, classification_data):
+         """Test precision and recall calculation."""
+         from ins_pricing.production.scoring import precision_recall
+
+         precision, recall = precision_recall(
+             classification_data['actual'],
+             classification_data['predicted_class']
+         )
+
+         assert 0 <= precision <= 1
+         assert 0 <= recall <= 1
+
+     def test_f1_score(self, classification_data):
+         """Test F1 score calculation."""
+         from ins_pricing.production.scoring import f1_score
+
+         f1 = f1_score(
+             classification_data['actual'],
+             classification_data['predicted_class']
+         )
+
+         assert 0 <= f1 <= 1
+
+     def test_roc_auc(self, classification_data):
+         """Test ROC AUC calculation."""
+         from ins_pricing.production.scoring import roc_auc
+
+         auc = roc_auc(
+             classification_data['actual'],
+             classification_data['predicted_proba']
+         )
+
+         assert 0 <= auc <= 1
+
+     def test_confusion_matrix(self, classification_data):
+         """Test confusion matrix generation."""
+         from ins_pricing.production.scoring import confusion_matrix
+
+         cm = confusion_matrix(
+             classification_data['actual'],
+             classification_data['predicted_class']
+         )
+
+         assert cm.shape == (2, 2)
+         assert np.all(cm >= 0)
+
+
+ class TestInsuranceMetrics:
+     """Test insurance-specific metrics."""
+
+     def test_loss_ratio(self):
+         """Test loss ratio calculation."""
+         from ins_pricing.production.scoring import loss_ratio
+
+         data = pd.DataFrame({
+             "claims": [100, 200, 150],
+             "premiums": [120, 180, 160],
+             "exposure": [1.0, 1.0, 1.0]
+         })
+
+         lr = loss_ratio(data['claims'], data['premiums'], data['exposure'])
+
+         assert isinstance(lr, (int, float, np.number))
+         assert lr >= 0
+
+     def test_gini_coefficient(self, sample_predictions):
+         """Test Gini coefficient calculation."""
+         from ins_pricing.production.scoring import gini_coefficient
+
+         gini = gini_coefficient(
+             sample_predictions['actual'],
+             sample_predictions['predicted']
+         )
+
+         assert -1 <= gini <= 1
+
+     def test_lift_at_percentile(self, sample_predictions):
+         """Test lift calculation at specific percentile."""
+         from ins_pricing.production.scoring import lift_at_percentile
+
+         lift = lift_at_percentile(
+             sample_predictions['actual'],
+             sample_predictions['predicted'],
+             percentile=20
+         )
+
+         assert isinstance(lift, (int, float, np.number))
+
+
+ class TestMetricValidation:
+     """Test metric input validation."""
+
+     def test_mismatched_lengths(self):
+         """Test error on mismatched array lengths."""
+         from ins_pricing.production.scoring import weighted_mse
+
+         actual = np.array([1, 2, 3])
+         predicted = np.array([1, 2])  # Wrong length
+         weights = np.array([1, 1, 1])
+
+         with pytest.raises((ValueError, IndexError)):
+             weighted_mse(actual, predicted, weights)
+
+     def test_negative_weights(self):
+         """Test handling of negative weights."""
+         from ins_pricing.production.scoring import weighted_mse
+
+         actual = np.array([100, 200, 300])
+         predicted = np.array([105, 195, 310])
+         weights = np.array([1.0, -1.0, 1.0])  # Negative weight
+
+         with pytest.raises(ValueError):
+             weighted_mse(actual, predicted, weights)
+
+     def test_nan_values(self):
+         """Test handling of NaN values."""
+         from ins_pricing.production.scoring import weighted_mse
+
+         actual = np.array([100, np.nan, 300])
+         predicted = np.array([105, 195, 310])
+         weights = np.array([1.0, 1.0, 1.0])
+
+         with pytest.raises(ValueError):
+             weighted_mse(actual, predicted, weights)
+
+
+ class TestScoringReport:
+     """Test scoring report generation."""
+
+     def test_generate_regression_report(self, sample_predictions):
+         """Test comprehensive regression scoring report."""
+         from ins_pricing.production.scoring import generate_scoring_report
+
+         report = generate_scoring_report(
+             actual=sample_predictions['actual'],
+             predicted=sample_predictions['predicted'],
+             weights=sample_predictions['weight'],
+             task_type='regression'
+         )
+
+         assert 'mse' in report
+         assert 'mae' in report
+         assert 'r2' in report
+         assert all(isinstance(v, (int, float, np.number)) for v in report.values())
+
+     def test_generate_classification_report(self, classification_data):
+         """Test comprehensive classification scoring report."""
+         from ins_pricing.production.scoring import generate_scoring_report
+
+         report = generate_scoring_report(
+             actual=classification_data['actual'],
+             predicted=classification_data['predicted_class'],
+             predicted_proba=classification_data['predicted_proba'],
+             task_type='classification'
+         )
+
+         assert 'accuracy' in report
+         assert 'precision' in report
+         assert 'recall' in report
+         assert 'f1' in report
+         assert 'roc_auc' in report
+
+     def test_save_report_to_file(self, sample_predictions, tmp_path):
+         """Test saving scoring report to file."""
+         from ins_pricing.production.scoring import generate_scoring_report, save_report
+
+         report = generate_scoring_report(
+             actual=sample_predictions['actual'],
+             predicted=sample_predictions['predicted'],
+             task_type='regression'
+         )
+
+         output_path = tmp_path / "scoring_report.json"
+         save_report(report, output_path)
+
+         assert output_path.exists()
+
+
+ @pytest.mark.performance
+ class TestScoringPerformance:
+     """Test scoring performance on large datasets."""
+
+     def test_large_dataset_scoring(self):
+         """Test scoring metrics on large dataset."""
+         from ins_pricing.production.scoring import weighted_mse
+
+         n = 1_000_000
+         actual = np.random.uniform(100, 500, n)
+         predicted = actual + np.random.normal(0, 20, n)
+         weights = np.ones(n)
+
+         import time
+         start = time.time()
+         mse = weighted_mse(actual, predicted, weights)
+         elapsed = time.time() - start
+
+         assert isinstance(mse, (int, float, np.number))
+         assert elapsed < 1.0  # Should complete in under 1 second
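
Note that `ins_pricing/production/scoring.py` itself is not part of this diff, so the behaviour of the functions exercised above can only be inferred from the tests. A minimal sketch of `weighted_mse` consistent with the validation tests (length check, non-negative weights, NaN rejection); the package's real implementation may differ:

```python
# Hypothetical sketch of a weighted MSE satisfying the tests above;
# the actual ins_pricing.production.scoring implementation is not shown in this diff.
import numpy as np


def weighted_mse(actual, predicted, weights) -> float:
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    weights = np.asarray(weights, dtype=float)

    # The validation tests expect ValueError on malformed input
    if not (actual.shape == predicted.shape == weights.shape):
        raise ValueError("actual, predicted, and weights must have the same length")
    if np.any(weights < 0):
        raise ValueError("weights must be non-negative")
    if np.isnan(actual).any() or np.isnan(predicted).any():
        raise ValueError("inputs must not contain NaN")

    # Weighted average of squared errors; vectorised, so the 1M-row
    # performance test comfortably finishes in well under a second
    return float(np.average((actual - predicted) ** 2, weights=weights))
```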
ins_pricing/utils/profiling.py
@@ -0,0 +1,377 @@
+ """Performance profiling and memory monitoring utilities.
+
+ This module provides tools for tracking execution time, memory usage,
+ and GPU resources during model training and data processing.
+
+ Example:
+     >>> from ins_pricing.utils.profiling import profile_section
+     >>> with profile_section("Data Loading", logger=my_logger):
+     ...     data = load_large_dataset()
+     [Profile] Data Loading: 5.23s, RAM: +1250.3MB, GPU peak: 2048.5MB
+ """
+
+ from __future__ import annotations
+
+ import gc
+ import logging
+ import time
+ from contextlib import contextmanager
+ from typing import Optional
+
+ try:
+     import psutil
+     HAS_PSUTIL = True
+ except ImportError:
+     HAS_PSUTIL = False
+
+ try:
+     import torch
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+
+
+ @contextmanager
+ def profile_section(
+     name: str,
+     logger: Optional[logging.Logger] = None,
+     log_level: int = logging.INFO
+ ):
+     """Context manager for profiling code sections.
+
+     Tracks execution time, RAM usage, and GPU memory (if available).
+     Logs results when the context exits.
+
+     Args:
+         name: Name of the section being profiled
+         logger: Optional logger instance. If None, prints to stdout
+         log_level: Logging level (default: INFO)
+
+     Yields:
+         None
+
+     Example:
+         >>> with profile_section("Training Loop", logger):
+         ...     model.fit(X_train, y_train)
+         [Profile] Training Loop: 45.2s, RAM: +2100.5MB, GPU peak: 4096.0MB
+
+         >>> with profile_section("Preprocessing"):
+         ...     df = preprocess_data(raw_df)
+         [Profile] Preprocessing: 2.1s, RAM: +150.2MB
+     """
+     start_time = time.time()
+
+     # Track CPU memory
+     start_mem = None
+     if HAS_PSUTIL:
+         start_mem = psutil.Process().memory_info().rss / 1024 / 1024  # MB
+
+     # Track GPU memory
+     start_gpu_mem = None
+     if HAS_TORCH and torch.cuda.is_available():
+         torch.cuda.reset_peak_memory_stats()
+         start_gpu_mem = torch.cuda.memory_allocated() / 1024 / 1024  # MB
+
+     try:
+         yield
+     finally:
+         elapsed = time.time() - start_time
+
+         # Build profiling message
+         msg_parts = [f"[Profile] {name}: {elapsed:.2f}s"]
+
+         # Add RAM usage
+         if HAS_PSUTIL and start_mem is not None:
+             end_mem = psutil.Process().memory_info().rss / 1024 / 1024
+             mem_delta = end_mem - start_mem
+             msg_parts.append(f"RAM: {mem_delta:+.1f}MB")
+
+         # Add GPU memory
+         if HAS_TORCH and torch.cuda.is_available() and start_gpu_mem is not None:
+             peak_gpu = torch.cuda.max_memory_allocated() / 1024 / 1024
+             msg_parts.append(f"GPU peak: {peak_gpu:.1f}MB")
+
+         msg = ", ".join(msg_parts)
+
+         if logger:
+             logger.log(log_level, msg)
+         else:
+             print(msg)
+
+
+ def get_memory_info() -> dict:
+     """Get current memory usage information.
+
+     Returns:
+         Dictionary with memory statistics:
+             - rss_mb: Resident Set Size in MB (physical memory)
+             - vms_mb: Virtual Memory Size in MB
+             - percent: Memory usage percentage
+             - available_mb: Available system memory in MB
+             - gpu_allocated_mb: GPU memory allocated (if CUDA available)
+             - gpu_cached_mb: GPU memory cached (if CUDA available)
+
+     Example:
+         >>> info = get_memory_info()
+         >>> print(f"Using {info['rss_mb']:.1f} MB RAM")
+         Using 2048.5 MB RAM
+     """
+     info = {}
+
+     if HAS_PSUTIL:
+         process = psutil.Process()
+         mem = process.memory_info()
+         info['rss_mb'] = mem.rss / 1024 / 1024
+         info['vms_mb'] = mem.vms / 1024 / 1024
+
+         vm = psutil.virtual_memory()
+         info['percent'] = vm.percent
+         info['available_mb'] = vm.available / 1024 / 1024
+     else:
+         info['warning'] = 'psutil not available'
+
+     if HAS_TORCH and torch.cuda.is_available():
+         info['gpu_allocated_mb'] = torch.cuda.memory_allocated() / 1024 / 1024
+         info['gpu_cached_mb'] = torch.cuda.memory_reserved() / 1024 / 1024
+         info['gpu_max_allocated_mb'] = torch.cuda.max_memory_allocated() / 1024 / 1024
+
+     return info
+
+
+ def log_memory_usage(
+     logger: logging.Logger,
+     prefix: str = "",
+     level: int = logging.INFO
+ ) -> None:
+     """Log current memory usage.
+
+     Args:
+         logger: Logger instance
+         prefix: Optional prefix for log message
+         level: Logging level (default: INFO)
+
+     Example:
+         >>> log_memory_usage(logger, prefix="After epoch 10")
+         After epoch 10 - Memory: RSS=2048.5MB (45.2%) - GPU=1024.0MB
+     """
+     info = get_memory_info()
+
+     if 'warning' in info:
+         logger.log(level, f"{prefix} - Memory info not available (psutil missing)")
+         return
+
+     msg_parts = []
+     if prefix:
+         msg_parts.append(prefix)
+
+     ram_msg = f"RSS={info['rss_mb']:.1f}MB ({info['percent']:.1f}%)"
+     msg_parts.append(f"Memory: {ram_msg}")
+
+     if 'gpu_allocated_mb' in info:
+         gpu_msg = f"GPU={info['gpu_allocated_mb']:.1f}MB"
+         msg_parts.append(gpu_msg)
+
+     logger.log(level, " - ".join(msg_parts))
+
+
+ def check_memory_threshold(
+     threshold_gb: float = 32.0,
+     logger: Optional[logging.Logger] = None
+ ) -> bool:
+     """Check if memory usage exceeds threshold.
+
+     Args:
+         threshold_gb: Memory threshold in GB (default: 32.0)
+         logger: Optional logger for warnings
+
+     Returns:
+         True if memory usage exceeds threshold, False otherwise
+
+     Example:
+         >>> if check_memory_threshold(threshold_gb=16.0, logger=logger):
+         ...     torch.cuda.empty_cache()
+         ...     gc.collect()
+     """
+     if not HAS_PSUTIL:
+         return False
+
+     mem = psutil.Process().memory_info()
+     rss_gb = mem.rss / 1024 / 1024 / 1024
+
+     if rss_gb > threshold_gb:
+         if logger:
+             logger.warning(
+                 f"High memory usage detected: {rss_gb:.1f}GB "
+                 f"(threshold: {threshold_gb:.1f}GB)"
+             )
+         return True
+
+     return False
+
+
+ def cleanup_memory(logger: Optional[logging.Logger] = None) -> None:
+     """Force memory cleanup for CPU and GPU.
+
+     Args:
+         logger: Optional logger instance
+
+     Example:
+         >>> cleanup_memory(logger)
+         [Memory] Cleanup: freed 250.5MB RAM, 512.0MB GPU
+     """
+     if HAS_PSUTIL:
+         mem_before = psutil.Process().memory_info().rss / 1024 / 1024
+
+     gpu_before = None
+     if HAS_TORCH and torch.cuda.is_available():
+         gpu_before = torch.cuda.memory_allocated() / 1024 / 1024
+
+     # Perform cleanup
+     gc.collect()
+
+     if HAS_TORCH and torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     # Calculate freed memory
+     msg_parts = []
+
+     if HAS_PSUTIL:
+         mem_after = psutil.Process().memory_info().rss / 1024 / 1024
+         mem_freed = mem_before - mem_after
+         msg_parts.append(f"freed {mem_freed:.1f}MB RAM")
+
+     if gpu_before is not None:
+         gpu_after = torch.cuda.memory_allocated() / 1024 / 1024
+         gpu_freed = gpu_before - gpu_after
+         msg_parts.append(f"{gpu_freed:.1f}MB GPU")
+
+     # Join details after the header so the output matches the docstring example
+     msg = "[Memory] Cleanup: " + ", ".join(msg_parts)
+
+     if logger:
+         logger.info(msg)
+     else:
+         print(msg)
+
+
+ class MemoryMonitor:
+     """Memory monitoring context manager with automatic cleanup.
+
+     Monitors memory usage and optionally triggers cleanup if threshold exceeded.
+
+     Args:
+         name: Name of the monitored section
+         threshold_gb: Memory threshold for automatic cleanup (default: None, no cleanup)
+         logger: Optional logger instance
+
+     Example:
+         >>> with MemoryMonitor("Training", threshold_gb=16.0, logger=logger):
+         ...     for epoch in range(100):
+         ...         train_epoch(model, data)
+         [Memory] Training started: RAM=1024.5MB, GPU=512.0MB
+         [Memory] Training completed: RAM=2048.3MB (+1023.8MB), GPU=2048.0MB (+1536.0MB)
+     """
+
+     def __init__(
+         self,
+         name: str,
+         threshold_gb: Optional[float] = None,
+         logger: Optional[logging.Logger] = None
+     ):
+         self.name = name
+         self.threshold_gb = threshold_gb
+         self.logger = logger
+         self.start_mem = None
+         self.start_gpu = None
+
+     def __enter__(self):
+         if HAS_PSUTIL:
+             self.start_mem = psutil.Process().memory_info().rss / 1024 / 1024
+
+         if HAS_TORCH and torch.cuda.is_available():
+             self.start_gpu = torch.cuda.memory_allocated() / 1024 / 1024
+
+         # Log starting state; details are joined after the header so the
+         # output matches the docstring example
+         msg_parts = []
+         if self.start_mem is not None:
+             msg_parts.append(f"RAM={self.start_mem:.1f}MB")
+         if self.start_gpu is not None:
+             msg_parts.append(f"GPU={self.start_gpu:.1f}MB")
+
+         msg = f"[Memory] {self.name} started: " + ", ".join(msg_parts)
+         if self.logger:
+             self.logger.info(msg)
+         else:
+             print(msg)
+
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         # Calculate deltas
+         msg_parts = []
+
+         if HAS_PSUTIL and self.start_mem is not None:
+             end_mem = psutil.Process().memory_info().rss / 1024 / 1024
+             delta_mem = end_mem - self.start_mem
+             msg_parts.append(f"RAM={end_mem:.1f}MB ({delta_mem:+.1f}MB)")
+
+         if HAS_TORCH and torch.cuda.is_available() and self.start_gpu is not None:
+             end_gpu = torch.cuda.memory_allocated() / 1024 / 1024
+             delta_gpu = end_gpu - self.start_gpu
+             msg_parts.append(f"GPU={end_gpu:.1f}MB ({delta_gpu:+.1f}MB)")
+
+         msg = f"[Memory] {self.name} completed: " + ", ".join(msg_parts)
+         if self.logger:
+             self.logger.info(msg)
+         else:
+             print(msg)
+
+         # Check threshold and cleanup if needed
+         if self.threshold_gb is not None:
+             if check_memory_threshold(self.threshold_gb, self.logger):
+                 cleanup_memory(self.logger)
+
+
+ def profile_training_epoch(
+     epoch: int,
+     total_epochs: int,
+     logger: Optional[logging.Logger] = None,
+     cleanup_interval: int = 10
+ ) -> None:
+     """Log memory usage during training epochs with periodic cleanup.
+
+     Args:
+         epoch: Current epoch number
+         total_epochs: Total number of epochs
+         logger: Optional logger instance
+         cleanup_interval: Cleanup memory every N epochs (default: 10)
+
+     Example:
+         >>> for epoch in range(1, 101):
+         ...     train_one_epoch(model, data)
+         ...     profile_training_epoch(epoch, 100, logger, cleanup_interval=10)
+     """
+     log_memory_usage(
+         logger or logging.getLogger(__name__),
+         prefix=f"Epoch {epoch}/{total_epochs}",
+         level=logging.DEBUG
+     )
+
+     # Periodic cleanup
+     if epoch % cleanup_interval == 0:
+         if logger:
+             logger.info(f"Epoch {epoch}: Performing periodic memory cleanup")
+         cleanup_memory(logger)
+
+
+ # Convenience function for backward compatibility
+ def ensure_memory_cleanup(threshold_gb: float = 32.0) -> None:
+     """Check memory and cleanup if needed (simple function interface).
+
+     Args:
+         threshold_gb: Memory threshold in GB
+
+     Example:
+         >>> ensure_memory_cleanup(threshold_gb=16.0)
+     """
+     if check_memory_threshold(threshold_gb):
+         cleanup_memory()
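
The new `profiling.py` utilities compose naturally in a training loop. A usage sketch assuming only the functions defined above; `train_one_epoch` is a hypothetical stand-in workload, not an ins_pricing API:

```python
# Illustrative composition of the profiling utilities added in 0.2.9.
import logging

from ins_pricing.utils.profiling import (
    MemoryMonitor,
    profile_section,
    profile_training_epoch,
)

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("training")


def train_one_epoch() -> None:
    """Hypothetical stand-in workload so the example runs on its own."""
    _ = [i * i for i in range(100_000)]


# MemoryMonitor logs RAM/GPU at entry and exit and cleans up past 16GB;
# profile_section times each epoch; profile_training_epoch emits a
# DEBUG-level memory line per epoch and forces cleanup every 10 epochs.
with MemoryMonitor("Training", threshold_gb=16.0, logger=logger):
    for epoch in range(1, 21):
        with profile_section(f"Epoch {epoch}", logger=logger):
            train_one_epoch()
        profile_training_epoch(epoch, 20, logger, cleanup_interval=10)
```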