quantmllibrary 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantml/__init__.py +74 -0
- quantml/autograd.py +154 -0
- quantml/cli/__init__.py +10 -0
- quantml/cli/run_experiment.py +385 -0
- quantml/config/__init__.py +28 -0
- quantml/config/config.py +259 -0
- quantml/data/__init__.py +33 -0
- quantml/data/cache.py +149 -0
- quantml/data/feature_store.py +234 -0
- quantml/data/futures.py +254 -0
- quantml/data/loaders.py +236 -0
- quantml/data/memory_optimizer.py +234 -0
- quantml/data/validators.py +390 -0
- quantml/experiments/__init__.py +23 -0
- quantml/experiments/logger.py +208 -0
- quantml/experiments/results.py +158 -0
- quantml/experiments/tracker.py +223 -0
- quantml/features/__init__.py +25 -0
- quantml/features/base.py +104 -0
- quantml/features/gap_features.py +124 -0
- quantml/features/registry.py +138 -0
- quantml/features/volatility_features.py +140 -0
- quantml/features/volume_features.py +142 -0
- quantml/functional.py +37 -0
- quantml/models/__init__.py +27 -0
- quantml/models/attention.py +258 -0
- quantml/models/dropout.py +130 -0
- quantml/models/gru.py +319 -0
- quantml/models/linear.py +112 -0
- quantml/models/lstm.py +353 -0
- quantml/models/mlp.py +286 -0
- quantml/models/normalization.py +289 -0
- quantml/models/rnn.py +154 -0
- quantml/models/tcn.py +238 -0
- quantml/online.py +209 -0
- quantml/ops.py +1707 -0
- quantml/optim/__init__.py +42 -0
- quantml/optim/adafactor.py +206 -0
- quantml/optim/adagrad.py +157 -0
- quantml/optim/adam.py +267 -0
- quantml/optim/lookahead.py +97 -0
- quantml/optim/quant_optimizer.py +228 -0
- quantml/optim/radam.py +192 -0
- quantml/optim/rmsprop.py +203 -0
- quantml/optim/schedulers.py +286 -0
- quantml/optim/sgd.py +181 -0
- quantml/py.typed +0 -0
- quantml/streaming.py +175 -0
- quantml/tensor.py +462 -0
- quantml/time_series.py +447 -0
- quantml/training/__init__.py +135 -0
- quantml/training/alpha_eval.py +203 -0
- quantml/training/backtest.py +280 -0
- quantml/training/backtest_analysis.py +168 -0
- quantml/training/cv.py +106 -0
- quantml/training/data_loader.py +177 -0
- quantml/training/ensemble.py +84 -0
- quantml/training/feature_importance.py +135 -0
- quantml/training/features.py +364 -0
- quantml/training/futures_backtest.py +266 -0
- quantml/training/gradient_clipping.py +206 -0
- quantml/training/losses.py +248 -0
- quantml/training/lr_finder.py +127 -0
- quantml/training/metrics.py +376 -0
- quantml/training/regularization.py +89 -0
- quantml/training/trainer.py +239 -0
- quantml/training/walk_forward.py +190 -0
- quantml/utils/__init__.py +51 -0
- quantml/utils/gradient_check.py +274 -0
- quantml/utils/logging.py +181 -0
- quantml/utils/ops_cpu.py +231 -0
- quantml/utils/profiling.py +364 -0
- quantml/utils/reproducibility.py +220 -0
- quantml/utils/serialization.py +335 -0
- quantmllibrary-0.1.0.dist-info/METADATA +536 -0
- quantmllibrary-0.1.0.dist-info/RECORD +79 -0
- quantmllibrary-0.1.0.dist-info/WHEEL +5 -0
- quantmllibrary-0.1.0.dist-info/licenses/LICENSE +22 -0
- quantmllibrary-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reproducibility utilities for QuantML.
|
|
3
|
+
|
|
4
|
+
Provides random seed management, version tracking, and experiment metadata.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import random
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
from typing import Optional, Dict, Any
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
import json
|
|
13
|
+
|
|
14
|
+
# Try to import NumPy
|
|
15
|
+
try:
|
|
16
|
+
import numpy as np
|
|
17
|
+
HAS_NUMPY = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
HAS_NUMPY = False
|
|
20
|
+
np = None
|
|
21
|
+
|
|
22
|
+
# Try to get git info
|
|
23
|
+
def get_git_info() -> Dict[str, Optional[str]]:
    """Get git repository information (commit hash, branch, dirty flag).

    Best-effort: any value that cannot be determined (git not installed,
    not inside a repository, ...) is left as None.

    Returns:
        Dict with keys 'commit_hash', 'branch', 'is_dirty'.
    """
    info: Dict[str, Optional[str]] = {
        'commit_hash': None,
        'branch': None,
        'is_dirty': None
    }

    try:
        import subprocess

        def _git_query(args):
            # Run a read-only git query; return stripped stdout or None.
            try:
                return subprocess.check_output(
                    ['git'] + args,
                    stderr=subprocess.DEVNULL
                ).decode().strip()
            except Exception:
                return None

        info['commit_hash'] = _git_query(['rev-parse', 'HEAD'])
        info['branch'] = _git_query(['rev-parse', '--abbrev-ref', 'HEAD'])

        # Only probe dirtiness when we actually resolved a commit:
        # outside a repo `git diff` exits nonzero (128) and the original
        # code would have reported is_dirty=True spuriously.
        if info['commit_hash'] is not None:
            try:
                # Diff against HEAD so that staged-but-uncommitted changes
                # are detected too; plain `git diff --quiet` only looks at
                # unstaged changes.
                result = subprocess.run(
                    ['git', 'diff', '--quiet', 'HEAD'],
                    stderr=subprocess.DEVNULL
                )
                info['is_dirty'] = result.returncode != 0
            except Exception:
                pass

    except Exception:
        pass

    return info
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_library_version() -> str:
    """Return the installed QuantML version string, or "unknown"."""
    try:
        from quantml import __version__ as version
    except ImportError:
        # Library not importable (e.g. running from a detached checkout).
        return "unknown"
    return version
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def set_random_seed(seed: int, use_cuda: bool = False):
    """
    Seed every available random number generator for reproducibility.

    Seeds Python's `random`, NumPy (when installed), and — when present —
    PyTorch and TensorFlow. Also exports PYTHONHASHSEED.
    NOTE(review): setting PYTHONHASHSEED at runtime only affects child
    processes, not hashing in the current interpreter — confirm intent.

    Args:
        seed: Random seed value
        use_cuda: Whether to set CUDA random seed (if available)
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    if HAS_NUMPY:
        np.random.seed(seed)

    # PyTorch is an optional dependency; seed it only when importable.
    try:
        import torch
    except ImportError:
        pass
    else:
        torch.manual_seed(seed)
        if use_cuda and torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            # Deterministic cuDNN trades speed for reproducibility.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    # TensorFlow is optional as well.
    try:
        import tensorflow as tf
    except ImportError:
        pass
    else:
        tf.random.set_seed(seed)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def get_environment_info() -> Dict[str, Any]:
    """
    Collect environment details for reproducibility records.

    Returns:
        Dict with timestamp, Python/platform info, QuantML version, git
        metadata, and the versions of common scientific libraries that
        happen to be installed.
    """
    info: Dict[str, Any] = {
        'timestamp': datetime.now().isoformat(),
        'python_version': sys.version,
        'platform': sys.platform,
        'quantml_version': get_library_version(),
        'git_info': get_git_info()
    }

    if HAS_NUMPY:
        info['numpy_version'] = np.__version__

    # Record versions of optional third-party libraries when importable.
    for lib_name in ['pandas', 'scipy', 'sklearn']:
        try:
            module = __import__(lib_name)
        except ImportError:
            continue
        version = getattr(module, '__version__', None)
        if version is not None:
            info[f'{lib_name}_version'] = version

    return info
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def save_experiment_metadata(
    metadata: Dict[str, Any],
    filepath: str
):
    """
    Write experiment metadata to a JSON file.

    The parent directory is created if it does not already exist.

    Args:
        metadata: Metadata dictionary
        filepath: Path to save metadata
    """
    directory = os.path.dirname(filepath) or '.'
    os.makedirs(directory, exist_ok=True)

    with open(filepath, 'w') as fh:
        json.dump(metadata, fh, indent=2)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def load_experiment_metadata(filepath: str) -> Dict[str, Any]:
    """
    Read experiment metadata back from a JSON file.

    Args:
        filepath: Path to metadata file

    Returns:
        Metadata dictionary
    """
    with open(filepath, 'r') as fh:
        metadata = json.load(fh)
    return metadata
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def create_experiment_id(prefix: str = "exp") -> str:
    """
    Build a unique experiment identifier.

    The ID has the shape "<prefix>_YYYYMMDD_HHMMSS_<rand>", where the
    trailing component is a random 4-digit number.

    Args:
        prefix: Prefix for experiment ID

    Returns:
        Experiment ID string
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix = random.randint(1000, 9999)
    return "_".join([prefix, stamp, str(suffix)])
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class ReproducibilityContext:
    """Context manager that seeds RNGs and records experiment metadata."""

    def __init__(self, seed: int, experiment_name: Optional[str] = None):
        """
        Initialize reproducibility context.

        Args:
            seed: Random seed
            experiment_name: Optional experiment name
        """
        self.seed = seed
        self.experiment_name = experiment_name
        # Populated on __enter__ with environment info plus the seed.
        self.metadata = None

    def __enter__(self):
        """Seed all RNGs and capture environment metadata."""
        set_random_seed(self.seed)
        meta = get_environment_info()
        meta['random_seed'] = self.seed
        if self.experiment_name:
            meta['experiment_name'] = self.experiment_name
        self.metadata = meta
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context; nothing to clean up, exceptions propagate."""
        return None

    def get_metadata(self) -> Dict[str, Any]:
        """Return a copy of the captured metadata (empty dict before entry)."""
        if not self.metadata:
            return {}
        return self.metadata.copy()
|
|
220
|
+
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model serialization utilities for saving and loading models.
|
|
3
|
+
|
|
4
|
+
This module provides functions to save and load model weights,
|
|
5
|
+
as well as full training checkpoints.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from typing import Optional, Dict, Any, List, Union
|
|
11
|
+
from quantml.tensor import Tensor
|
|
12
|
+
|
|
13
|
+
# Try to import NumPy for efficient storage
|
|
14
|
+
try:
|
|
15
|
+
import numpy as np
|
|
16
|
+
HAS_NUMPY = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
HAS_NUMPY = False
|
|
19
|
+
np = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def save_model(model, path: str, format: str = 'json') -> None:
    """
    Save model weights to file.

    Args:
        model: Model with parameters() method
        path: Path to save weights
        format: 'json' or 'npz' (requires NumPy)

    Raises:
        RuntimeError: If 'npz' is requested but NumPy is unavailable.
        ValueError: If the format string is not recognized.

    Examples:
        >>> from quantml.models import Linear
        >>> model = Linear(10, 1)
        >>> save_model(model, 'model_weights.json')
    """
    params = model.parameters()

    if format == 'npz':
        if not HAS_NUMPY:
            raise RuntimeError("NumPy required for npz format")
        _save_npz(params, path)
    elif format == 'json':
        _save_json(params, path)
    else:
        raise ValueError(f"Unknown format: {format}. Use 'json' or 'npz'")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def load_model(model, path: str, format: Optional[str] = None) -> None:
    """
    Load model weights from file, updating the model in-place.

    Args:
        model: Model with parameters() method (weights will be loaded in-place)
        path: Path to load weights from
        format: 'json' or 'npz' (auto-detected from extension if None)

    Raises:
        RuntimeError: If 'npz' is requested but NumPy is unavailable.
        ValueError: If the format is unknown, or the file's parameter
            count does not match the model.

    Examples:
        >>> from quantml.models import Linear
        >>> model = Linear(10, 1)
        >>> load_model(model, 'model_weights.json')
    """
    if format is None:
        # Infer the format from the extension; JSON is the default.
        format = 'npz' if path.endswith('.npz') else 'json'

    params = model.parameters()

    if format == 'json':
        loaded = _load_json(path)
    elif format == 'npz':
        if not HAS_NUMPY:
            raise RuntimeError("NumPy required for npz format")
        loaded = _load_npz(path)
    else:
        raise ValueError(f"Unknown format: {format}")

    if len(loaded) != len(params):
        raise ValueError(f"Parameter count mismatch: model has {len(params)}, file has {len(loaded)}")

    # Copy the saved values into the existing parameter tensors.
    for param, loaded_data in zip(params, loaded):
        _update_tensor_data(param, loaded_data)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def save_checkpoint(
    model,
    optimizer,
    epoch: int,
    path: str,
    loss: Optional[float] = None,
    metrics: Optional[Dict[str, float]] = None,
    extra: Optional[Dict[str, Any]] = None
) -> None:
    """
    Save a full training checkpoint as JSON.

    Includes model weights, optimizer state, epoch, and optional metrics.

    Args:
        model: Model with parameters() method
        optimizer: Optimizer with state to save
        epoch: Current epoch number
        path: Path to save checkpoint
        loss: Optional current loss value
        metrics: Optional dict of metric values
        extra: Optional extra data to save

    Examples:
        >>> save_checkpoint(model, optimizer, epoch=10, path='checkpoint.json')
    """
    checkpoint: Dict[str, Any] = {
        'epoch': epoch,
        'model_state': _params_to_list(model.parameters()),
        'optimizer_state': _get_optimizer_state(optimizer),
    }

    # Only embed the optional fields that were actually provided.
    for key, value in (('loss', loss), ('metrics', metrics), ('extra', extra)):
        if value is not None:
            checkpoint[key] = value

    with open(path, 'w') as fh:
        json.dump(checkpoint, fh, indent=2)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def load_checkpoint(
    model,
    optimizer,
    path: str
) -> Dict[str, Any]:
    """
    Load a training checkpoint, restoring model weights and optimizer state.

    Args:
        model: Model with parameters() method
        optimizer: Optimizer to restore state to
        path: Path to load checkpoint from

    Returns:
        Dict with 'epoch', 'loss', 'metrics', and any extra data

    Raises:
        ValueError: If the checkpoint's parameter count does not match
            the model.

    Examples:
        >>> info = load_checkpoint(model, optimizer, 'checkpoint.json')
        >>> print(f"Resuming from epoch {info['epoch']}")
    """
    with open(path, 'r') as fh:
        checkpoint = json.load(fh)

    params = model.parameters()
    loaded = checkpoint['model_state']

    if len(loaded) != len(params):
        raise ValueError(f"Parameter count mismatch: model has {len(params)}, checkpoint has {len(loaded)}")

    # Copy saved weights into the model tensors in-place.
    for param, loaded_data in zip(params, loaded):
        _update_tensor_data(param, loaded_data)

    if 'optimizer_state' in checkpoint:
        _set_optimizer_state(optimizer, checkpoint['optimizer_state'])

    # Surface the bookkeeping fields; missing ones default to None/0.
    return {
        'epoch': checkpoint.get('epoch', 0),
        'loss': checkpoint.get('loss'),
        'metrics': checkpoint.get('metrics'),
        'extra': checkpoint.get('extra')
    }
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def get_model_state_dict(model) -> Dict[str, List]:
    """
    Get model state as a dictionary.

    Parameters are keyed positionally as "param_0", "param_1", ...

    Args:
        model: Model with parameters() method

    Returns:
        Dict mapping parameter names to their data
    """
    state: Dict[str, List] = {}
    for index, param in enumerate(model.parameters()):
        # Tensors (anything with a .data attribute) are converted to
        # nested lists; everything else is stored as-is.
        if hasattr(param, 'data'):
            state[f"param_{index}"] = _tensor_to_list(param)
        else:
            state[f"param_{index}"] = param
    return state
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def set_model_state_dict(model, state_dict: Dict[str, List]) -> None:
    """
    Set model state from a dictionary.

    Parameters are matched positionally via the "param_<i>" naming
    scheme; parameters absent from `state_dict` are left unchanged.

    Args:
        model: Model with parameters() method
        state_dict: Dict mapping parameter names to their data
    """
    for index, param in enumerate(model.parameters()):
        key = f"param_{index}"
        if key in state_dict:
            _update_tensor_data(param, state_dict[key])
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# ============================================================================
|
|
219
|
+
# Private helper functions
|
|
220
|
+
# ============================================================================
|
|
221
|
+
|
|
222
|
+
def _save_json(params: List[Tensor], path: str) -> None:
    """Serialize parameters as nested lists and write them to a JSON file."""
    with open(path, 'w') as fh:
        json.dump(_params_to_list(params), fh, indent=2)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _load_json(path: str) -> List:
|
|
230
|
+
"""Load parameters from JSON file."""
|
|
231
|
+
with open(path, 'r') as f:
|
|
232
|
+
return json.load(f)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _save_npz(params: List[Tensor], path: str) -> None:
    """Save parameters to a NumPy .npz archive, keyed as "param_<i>"."""
    arrays = {
        f'param_{index}': _tensor_to_numpy(param)
        for index, param in enumerate(params)
    }
    np.savez(path, **arrays)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _load_npz(path: str) -> List:
|
|
245
|
+
"""Load parameters from NumPy npz file."""
|
|
246
|
+
data = np.load(path)
|
|
247
|
+
# Sort by parameter index
|
|
248
|
+
keys = sorted(data.files, key=lambda x: int(x.split('_')[1]))
|
|
249
|
+
return [data[k].tolist() for k in keys]
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _params_to_list(params: List[Tensor]) -> List:
    """Convert each parameter Tensor into a JSON-serializable nested list."""
    return list(map(_tensor_to_list, params))
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _tensor_to_list(t: Tensor) -> List:
    """Convert a Tensor's backing data to a nested Python list.

    Handles list-backed and NumPy-backed tensors; any other iterable
    data is coerced with list().
    """
    data = t.data
    if isinstance(data, list):
        return data
    # NumPy arrays expose tolist(); otherwise fall back to list().
    try:
        to_list = data.tolist
    except AttributeError:
        return list(data)
    return to_list()
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _tensor_to_numpy(t: Tensor):
    """Return the tensor's contents as a NumPy array.

    NOTE(review): `t.numpy` is treated as an attribute/property holding a
    cached array (not a method) — confirm against the Tensor class.
    """
    cached = getattr(t, 'numpy', None)
    if cached is not None:
        return cached
    data = t.data
    # List-backed tensors are converted; array-likes pass through as-is.
    if isinstance(data, list):
        return np.array(data)
    return data
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _update_tensor_data(param: Tensor, data) -> None:
    """Overwrite a tensor's backing data in-place.

    `data` may be a nested list or a NumPy array; it is always stored as
    plain Python lists in `param._data`.
    """
    if isinstance(data, list):
        new_data = data
    elif hasattr(data, 'tolist'):
        # NumPy array: convert to nested lists.
        new_data = data.tolist()
    else:
        new_data = list(data)
    param._data = new_data

    # Invalidate any cached NumPy view so it gets rebuilt from _data.
    if hasattr(param, '_np_array'):
        param._np_array = None
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _get_optimizer_state(optimizer) -> Dict[str, Any]:
|
|
293
|
+
"""Extract optimizer state for serialization."""
|
|
294
|
+
state = {
|
|
295
|
+
'lr': getattr(optimizer, 'lr', None),
|
|
296
|
+
'step': getattr(optimizer, '_step', 0),
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
# Handle Adam-like optimizers with momentum buffers
|
|
300
|
+
if hasattr(optimizer, '_m'):
|
|
301
|
+
state['m'] = [_tensor_to_list(m) if hasattr(m, 'data') else m
|
|
302
|
+
for m in optimizer._m]
|
|
303
|
+
if hasattr(optimizer, '_v'):
|
|
304
|
+
state['v'] = [_tensor_to_list(v) if hasattr(v, 'data') else v
|
|
305
|
+
for v in optimizer._v]
|
|
306
|
+
|
|
307
|
+
# Handle SGD momentum
|
|
308
|
+
if hasattr(optimizer, '_velocity'):
|
|
309
|
+
state['velocity'] = [_tensor_to_list(v) if hasattr(v, 'data') else v
|
|
310
|
+
for v in optimizer._velocity]
|
|
311
|
+
|
|
312
|
+
return state
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _set_optimizer_state(optimizer, state: Dict[str, Any]) -> None:
|
|
316
|
+
"""Restore optimizer state from serialized data."""
|
|
317
|
+
if 'step' in state:
|
|
318
|
+
optimizer._step = state['step']
|
|
319
|
+
|
|
320
|
+
# Handle Adam-like optimizers
|
|
321
|
+
if 'm' in state and hasattr(optimizer, '_m'):
|
|
322
|
+
for i, m_data in enumerate(state['m']):
|
|
323
|
+
if i < len(optimizer._m):
|
|
324
|
+
_update_tensor_data(optimizer._m[i], m_data)
|
|
325
|
+
|
|
326
|
+
if 'v' in state and hasattr(optimizer, '_v'):
|
|
327
|
+
for i, v_data in enumerate(state['v']):
|
|
328
|
+
if i < len(optimizer._v):
|
|
329
|
+
_update_tensor_data(optimizer._v[i], v_data)
|
|
330
|
+
|
|
331
|
+
# Handle SGD momentum
|
|
332
|
+
if 'velocity' in state and hasattr(optimizer, '_velocity'):
|
|
333
|
+
for i, v_data in enumerate(state['velocity']):
|
|
334
|
+
if i < len(optimizer._velocity):
|
|
335
|
+
_update_tensor_data(optimizer._velocity[i], v_data)
|