openarchx 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openarchx/__init__.py +11 -0
- openarchx/core/tensor.py +179 -0
- openarchx/cuda/__init__.py +27 -0
- openarchx/cuda/cuda_ops.py +296 -0
- openarchx/layers/activations.py +63 -0
- openarchx/layers/base.py +40 -0
- openarchx/layers/cnn.py +145 -0
- openarchx/layers/transformer.py +131 -0
- openarchx/nn/__init__.py +26 -0
- openarchx/nn/activations.py +127 -0
- openarchx/nn/containers.py +174 -0
- openarchx/nn/dropout.py +121 -0
- openarchx/nn/layers.py +338 -0
- openarchx/nn/losses.py +156 -0
- openarchx/nn/module.py +18 -0
- openarchx/nn/padding.py +120 -0
- openarchx/nn/pooling.py +318 -0
- openarchx/nn/rnn.py +226 -0
- openarchx/nn/transformers.py +187 -0
- openarchx/optimizers/adam.py +49 -0
- openarchx/optimizers/adaptive.py +63 -0
- openarchx/optimizers/base.py +24 -0
- openarchx/optimizers/modern.py +98 -0
- openarchx/optimizers/optx.py +91 -0
- openarchx/optimizers/sgd.py +63 -0
- openarchx/quantum/circuit.py +92 -0
- openarchx/quantum/gates.py +126 -0
- openarchx/utils/__init__.py +50 -0
- openarchx/utils/data.py +229 -0
- openarchx/utils/huggingface.py +288 -0
- openarchx/utils/losses.py +21 -0
- openarchx/utils/model_io.py +553 -0
- openarchx/utils/pytorch.py +420 -0
- openarchx/utils/tensorflow.py +467 -0
- openarchx/utils/transforms.py +259 -0
- openarchx-0.1.0.dist-info/METADATA +180 -0
- openarchx-0.1.0.dist-info/RECORD +43 -0
- openarchx-0.1.0.dist-info/WHEEL +5 -0
- openarchx-0.1.0.dist-info/licenses/LICENSE +21 -0
- openarchx-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/test_cuda_ops.py +205 -0
- tests/test_integrations.py +236 -0
@@ -0,0 +1,21 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from ..core.tensor import Tensor
|
3
|
+
|
4
|
+
def mse_loss(pred: Tensor, target) -> Tensor:
    """Mean Squared Error Loss.

    Args:
        pred: Predictions as an OpenArchX Tensor.
        target: Targets; wrapped in a Tensor when given as anything else.

    Returns:
        A Tensor holding ``mean((pred - target) ** 2)`` computed via Tensor ops.
    """
    target = target if isinstance(target, Tensor) else Tensor(target)
    # Two independent subtractions (rather than reusing one difference node)
    # keep the autograd graph identical to the original formulation.
    squared_error = (pred - target) * (pred - target)
    return squared_error.mean()
|
9
|
+
|
10
|
+
def cross_entropy_loss(pred: Tensor, target) -> Tensor:
    """Cross Entropy Loss for classification with numerical stability.

    Args:
        pred: Predicted probabilities as a Tensor. Values are clipped to
            ``[1e-7, 1 - 1e-7]`` before the log; presumably already
            softmax-normalized (TODO confirm with callers).
        target: Target distribution (e.g. one-hot); wrapped in a Tensor when
            given as anything else. Assumed to broadcast against ``pred``.

    Returns:
        ``-mean(sum(target * log(clip(pred)), axis=-1))`` as a Tensor.
    """
    target = target if isinstance(target, Tensor) else Tensor(target)

    eps = 1e-7
    # NOTE(review): building a fresh Tensor from the raw ``pred.data`` array
    # presumably detaches it from pred's autograd graph, so no gradient would
    # reach ``pred`` through this loss — confirm against core.tensor.Tensor.
    probs = Tensor(np.clip(pred.data, eps, 1.0 - eps))

    per_row = (target * probs.log()).sum(axis=-1)
    return -(per_row.mean())
|
@@ -0,0 +1,553 @@
|
|
1
|
+
"""
|
2
|
+
Model I/O Utilities for OpenArchX.
|
3
|
+
|
4
|
+
This module provides utilities for saving and loading OpenArchX models to/from
|
5
|
+
disk in the native .oaxm format, as well as conversion utilities for other formats.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import os
|
9
|
+
import json
|
10
|
+
import pickle
|
11
|
+
import numpy as np
|
12
|
+
import importlib.util
|
13
|
+
from ..core.tensor import Tensor
|
14
|
+
|
15
|
+
|
16
|
+
class ModelSerializer:
    """Utility for serializing OpenArchX models to the native .oaxm format.

    An .oaxm file is a pickled dict with the keys ``'parameters'`` (name ->
    numpy array or raw value) and ``'metadata'``, plus ``'config'`` or
    ``'architecture'`` when the model exposes one. The pickle may be
    gzip-compressed.
    """

    @staticmethod
    def save_model(model, filepath, metadata=None, compress=True):
        """
        Save an OpenArchX model to a .oaxm file.

        Args:
            model: The OpenArchX model to save. Parameters come from
                ``model.state_dict()`` when available, otherwise from the
                public Tensor attributes in ``model.__dict__``.
            filepath: The path (str or os.PathLike) to save the model to. If
                the extension is not .oaxm, it will be appended.
            metadata: Optional dictionary of metadata to save with the model.
                It is merged over the default metadata (format_version,
                framework, model_type), so caller keys win but
                ``model_type`` is kept unless explicitly overridden.
            compress: Whether to gzip-compress the saved model file.

        Returns:
            The path (str) to the saved model file.
        """
        import gzip

        filepath = os.fspath(filepath)
        if not filepath.endswith('.oaxm'):
            filepath += '.oaxm'

        os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)

        if hasattr(model, 'state_dict'):
            params = model.state_dict()
        else:
            # Fallback for models without state_dict: public Tensor attributes.
            params = {
                key: value
                for key, value in model.__dict__.items()
                if not key.startswith('_') and isinstance(value, Tensor)
            }

        # Store raw numpy data instead of Tensor objects.
        model_state = {
            'parameters': {
                name: param.data if isinstance(param, Tensor) else param
                for name, param in params.items()
            },
        }

        if hasattr(model, 'config'):
            model_state['config'] = model.config
        elif hasattr(model, 'architecture'):
            model_state['architecture'] = model.architecture

        # Merge caller metadata over the defaults (as ModelConverter does) so
        # that 'model_type' survives and load_model can still infer the class.
        model_state['metadata'] = {
            'format_version': '1.0',
            'framework': 'openarchx',
            'model_type': model.__class__.__name__,
        }
        if metadata is not None:
            model_state['metadata'].update(metadata)

        opener = gzip.open if compress else open
        with opener(filepath, 'wb') as f:
            pickle.dump(model_state, f)

        return filepath

    @staticmethod
    def _resolve_model_class(model_type):
        """Map a stored ``model_type`` name to a class, or raise ValueError."""
        # Models registered via this module's ModelRegistry take precedence.
        model_class = ModelRegistry.get(model_type)
        if model_class is not None:
            return model_class

        # Legacy lookup hook; the module may not exist in this package.
        try:
            from ..nn.models import get_model_class as _legacy_lookup
            model_class = _legacy_lookup(model_type)
        except (ImportError, AttributeError):
            model_class = None

        if model_class is None:
            raise ValueError(f"Could not import model class: {model_type}")
        return model_class

    @staticmethod
    def load_model(filepath, model_class=None):
        """
        Load an OpenArchX model from a .oaxm file.

        SECURITY: the file is unpickled, which can execute arbitrary code.
        Only load files from trusted sources.

        Args:
            filepath: The path to the model file (plain or gzip-compressed).
            model_class: The model class to instantiate. If None, the class
                is looked up from the stored ``metadata['model_type']`` via
                ModelRegistry (see register_model).

        Returns:
            The loaded OpenArchX model. It is constructed with
            ``config=...`` when a config was saved, otherwise with no args.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If no class is given and none can be resolved from
                the stored metadata.
        """
        model_state = _load_oaxm_state(filepath)

        # Wrap numpy arrays back into Tensor objects.
        tensor_params = {
            name: Tensor(param) if isinstance(param, np.ndarray) else param
            for name, param in model_state.get('parameters', {}).items()
        }

        config = model_state.get('config')

        if model_class is None:
            metadata = model_state.get('metadata') or {}
            model_type = metadata.get('model_type')
            if model_type is None:
                raise ValueError("Model class must be provided if not stored in metadata")
            model_class = ModelSerializer._resolve_model_class(model_type)

        model = model_class(config=config) if config is not None else model_class()

        if hasattr(model, 'load_state_dict'):
            model.load_state_dict(tensor_params)
        else:
            # Fallback: only overwrite attributes the model already defines.
            for name, param in tensor_params.items():
                if hasattr(model, name):
                    setattr(model, name, param)

        return model
|
164
|
+
|
165
|
+
|
166
|
+
class ModelConverter:
    """Utility for converting model weights between PyTorch/TensorFlow and .oaxm.

    Only parameter values are converted (stored as numpy arrays in a pickled
    state dict). Architectures are not converted, so the ``to_*`` methods
    need a pre-built target model to load the weights into.
    """

    @staticmethod
    def _base_metadata(model, original_framework):
        """Return the default metadata recorded for a model converted from *original_framework*."""
        return {
            'format_version': '1.0',
            'framework': 'openarchx',
            'original_framework': original_framework,
            'model_type': model.__class__.__name__,
        }

    @staticmethod
    def _write_state(model_state, output_file):
        """Pickle *model_state* (uncompressed) to *output_file* and return the path written.

        A ``.oaxm`` suffix is appended when missing, and parent directories
        are created as needed.
        """
        output_file = os.fspath(output_file)
        if not output_file.endswith('.oaxm'):
            output_file += '.oaxm'
        os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
        with open(output_file, 'wb') as f:
            pickle.dump(model_state, f)
        return output_file

    @staticmethod
    def _tf_weight_key(weight):
        """Return the key under which a TensorFlow/Keras weight is stored.

        In Keras 3, ``Variable.name`` is only the leaf name (e.g. ``kernel``),
        which repeats across layers, while ``Variable.path`` holds the scoped
        name. Older tf.keras variables have no ``path`` and use a scoped
        ``name`` with a ``:0`` output-index suffix, which is stripped.
        """
        key = getattr(weight, 'path', None) or weight.name
        if key.endswith(':0'):
            key = key[:-2]
        return key

    @staticmethod
    def from_pytorch(torch_model, output_file, metadata=None):
        """
        Convert a PyTorch model to OpenArchX .oaxm format.

        Args:
            torch_model: The PyTorch model to convert (uses its state_dict()).
            output_file: The path to save the converted model; ``.oaxm`` is
                appended if missing.
            metadata: Optional metadata merged over the default metadata.

        Returns:
            The path to the saved (uncompressed) .oaxm file.

        Raises:
            ImportError: If PyTorch is not installed.
        """
        if importlib.util.find_spec("torch") is None:
            raise ImportError("PyTorch is required. Install with 'pip install torch'")

        # state_dict() may also contain non-tensor "extra state" objects;
        # those are stored as-is instead of crashing on .detach().
        numpy_params = {
            name: param.detach().cpu().numpy() if hasattr(param, 'detach') else param
            for name, param in torch_model.state_dict().items()
        }

        model_state = {
            'parameters': numpy_params,
            'metadata': ModelConverter._base_metadata(torch_model, 'pytorch'),
        }
        if metadata is not None:
            model_state['metadata'].update(metadata)

        return ModelConverter._write_state(model_state, output_file)

    @staticmethod
    def from_tensorflow(tf_model, output_file, metadata=None):
        """
        Convert a TensorFlow model to OpenArchX .oaxm format.

        Args:
            tf_model: The TensorFlow/Keras model to convert.
            output_file: The path to save the converted model; ``.oaxm`` is
                appended if missing.
            metadata: Optional metadata merged over the default metadata.

        Returns:
            The path to the saved (uncompressed) .oaxm file.

        Raises:
            ImportError: If TensorFlow is not installed.
            ValueError: If two weights map to the same key (which would
                otherwise silently drop one of them).
        """
        if importlib.util.find_spec("tensorflow") is None:
            raise ImportError("TensorFlow is required. Install with 'pip install tensorflow'")

        # get_weights() returns values in the same order as tf_model.weights.
        numpy_params = {}
        for weight, value in zip(tf_model.weights, tf_model.get_weights()):
            key = ModelConverter._tf_weight_key(weight)
            if key in numpy_params:
                raise ValueError(f"Duplicate TensorFlow weight name: {key}")
            numpy_params[key] = value

        model_state = {
            'parameters': numpy_params,
            'metadata': ModelConverter._base_metadata(tf_model, 'tensorflow'),
        }
        if metadata is not None:
            model_state['metadata'].update(metadata)

        return ModelConverter._write_state(model_state, output_file)

    @staticmethod
    def to_pytorch(oaxm_file, torch_model=None):
        """
        Load the parameters of an OpenArchX .oaxm file into a PyTorch model.

        Args:
            oaxm_file: The path to the .oaxm model file.
            torch_model: The PyTorch model to load the parameters into. It is
                required despite the None default: the .oaxm file does not
                describe the architecture, so no model can be built from it.

        Returns:
            The same ``torch_model`` with the loaded parameters.

        Raises:
            ImportError: If PyTorch is not installed.
            ValueError: If ``torch_model`` is None.
            FileNotFoundError: If ``oaxm_file`` does not exist.
        """
        if importlib.util.find_spec("torch") is None:
            raise ImportError("PyTorch is required. Install with 'pip install torch'")
        if torch_model is None:
            raise ValueError("A PyTorch model must be provided to load parameters into")

        import torch

        parameters = _load_oaxm_state(oaxm_file).get('parameters', {})

        torch_state_dict = {
            name: torch.tensor(param) if isinstance(param, np.ndarray) else param
            for name, param in parameters.items()
        }
        torch_model.load_state_dict(torch_state_dict)
        return torch_model

    @staticmethod
    def to_tensorflow(oaxm_file, tf_model=None):
        """
        Load the parameters of an OpenArchX .oaxm file into a TensorFlow model.

        Args:
            oaxm_file: The path to the .oaxm model file.
            tf_model: The TensorFlow/Keras model to load the parameters into.
                It is required despite the None default: the .oaxm file does
                not describe the architecture.

        Returns:
            The same ``tf_model`` with the loaded weights.

        Raises:
            ImportError: If TensorFlow is not installed.
            ValueError: If ``tf_model`` is None or a weight is missing from
                the file.
            FileNotFoundError: If ``oaxm_file`` does not exist.
        """
        if importlib.util.find_spec("tensorflow") is None:
            raise ImportError("TensorFlow is required. Install with 'pip install tensorflow'")
        if tf_model is None:
            raise ValueError("A TensorFlow model must be provided to load parameters into")

        parameters = _load_oaxm_state(oaxm_file).get('parameters', {})

        weights = []
        for weight in tf_model.weights:
            key = ModelConverter._tf_weight_key(weight)
            # Files written by earlier versions keyed weights by
            # ``name.replace(':0', '')``; accept those too.
            legacy_key = weight.name.replace(':0', '')
            if key in parameters:
                weights.append(parameters[key])
            elif legacy_key in parameters:
                weights.append(parameters[legacy_key])
            else:
                raise ValueError(f"Parameter not found in .oaxm file: {key}")

        tf_model.set_weights(weights)
        return tf_model
|
363
|
+
|
364
|
+
|
365
|
+
class ModelRegistry:
    """Process-wide name -> model-class mapping used when loading models.

    The mapping lives in the class attribute ``_registry`` and is shared by
    every caller.
    """

    _registry = {}

    @classmethod
    def register(cls, name, model_class):
        """
        Associate ``name`` with ``model_class``, replacing any previous entry.

        Args:
            name: Key to store the class under.
            model_class: The class object to store.
        """
        registry = cls._registry
        registry[name] = model_class

    @classmethod
    def get(cls, name):
        """
        Look up a registered model class.

        Args:
            name: Key the class was registered under.

        Returns:
            The registered class, or None when ``name`` is unknown.
        """
        return cls._registry.get(name)

    @classmethod
    def list_models(cls):
        """
        Return the registered names.

        Returns:
            A new list of registered names, in registration order.
        """
        return [name for name in cls._registry]
|
403
|
+
|
404
|
+
|
405
|
+
# Helper functions
|
406
|
+
|
407
|
+
def _is_gzipped(filepath):
|
408
|
+
"""Check if a file is gzipped."""
|
409
|
+
with open(filepath, 'rb') as f:
|
410
|
+
return f.read(2) == b'\x1f\x8b'
|
411
|
+
|
412
|
+
|
413
|
+
def _load_oaxm_state(filepath):
|
414
|
+
"""Load the state from a .oaxm file."""
|
415
|
+
# Check if the file exists
|
416
|
+
if not os.path.exists(filepath):
|
417
|
+
raise FileNotFoundError(f"Model file not found: {filepath}")
|
418
|
+
|
419
|
+
# Load the model state
|
420
|
+
if filepath.endswith('.oaxm.gz') or (filepath.endswith('.oaxm') and _is_gzipped(filepath)):
|
421
|
+
import gzip
|
422
|
+
with gzip.open(filepath, 'rb') as f:
|
423
|
+
return pickle.load(f)
|
424
|
+
else:
|
425
|
+
with open(filepath, 'rb') as f:
|
426
|
+
return pickle.load(f)
|
427
|
+
|
428
|
+
|
429
|
+
# Convenience functions
|
430
|
+
|
431
|
+
def save_model(model, filepath, metadata=None, compress=True):
    """
    Write an OpenArchX model to disk in the .oaxm format.

    Convenience wrapper that delegates to ``ModelSerializer.save_model``.

    Args:
        model: The OpenArchX model to persist.
        filepath: Destination path; ``.oaxm`` is appended when missing.
        metadata: Optional metadata dictionary stored alongside the weights.
        compress: When True, the file is gzip-compressed.

    Returns:
        The path the model was written to.
    """
    saved_path = ModelSerializer.save_model(
        model, filepath, metadata=metadata, compress=compress
    )
    return saved_path
|
445
|
+
|
446
|
+
|
447
|
+
def load_model(filepath, model_class=None):
    """
    Read an OpenArchX model back from a .oaxm file.

    Convenience wrapper that delegates to ``ModelSerializer.load_model``.

    Args:
        filepath: Path of the .oaxm file to read.
        model_class: Class to instantiate. When None, the class is resolved
            from the ``model_type`` entry in the stored metadata.

    Returns:
        The reconstructed OpenArchX model.
    """
    loaded = ModelSerializer.load_model(filepath, model_class=model_class)
    return loaded
|
460
|
+
|
461
|
+
|
462
|
+
def convert_from_pytorch(torch_model, output_file, metadata=None):
    """
    Export a PyTorch model's parameters to an OpenArchX .oaxm file.

    Convenience wrapper that delegates to ``ModelConverter.from_pytorch``.

    Args:
        torch_model: PyTorch model whose state_dict is exported.
        output_file: Destination path; ``.oaxm`` is appended when missing.
        metadata: Optional extra metadata merged into the stored metadata.

    Returns:
        The path of the written .oaxm file.
    """
    written = ModelConverter.from_pytorch(
        torch_model, output_file, metadata=metadata
    )
    return written
|
475
|
+
|
476
|
+
|
477
|
+
def convert_from_tensorflow(tf_model, output_file, metadata=None):
    """
    Export a TensorFlow model's weights to an OpenArchX .oaxm file.

    Convenience wrapper that delegates to ``ModelConverter.from_tensorflow``.

    Args:
        tf_model: TensorFlow/Keras model whose weights are exported.
        output_file: Destination path; ``.oaxm`` is appended when missing.
        metadata: Optional extra metadata merged into the stored metadata.

    Returns:
        The path of the written .oaxm file.
    """
    written = ModelConverter.from_tensorflow(
        tf_model, output_file, metadata=metadata
    )
    return written
|
490
|
+
|
491
|
+
|
492
|
+
def convert_to_pytorch(oaxm_file, torch_model=None):
    """
    Load the parameters stored in an OpenArchX .oaxm file into a PyTorch model.

    Convenience wrapper around ``ModelConverter.to_pytorch``.

    Args:
        oaxm_file: The path to the .oaxm model file.
        torch_model: The PyTorch model to load the parameters into. Although
            it defaults to None, it is effectively required: the .oaxm file
            does not describe the architecture, so passing None raises
            ValueError rather than building a model.

    Returns:
        The same ``torch_model`` instance with the loaded parameters.

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If ``torch_model`` is None.
        FileNotFoundError: If ``oaxm_file`` does not exist.
    """
    return ModelConverter.to_pytorch(oaxm_file, torch_model)
|
505
|
+
|
506
|
+
|
507
|
+
def convert_to_tensorflow(oaxm_file, tf_model=None):
    """
    Load the parameters stored in an OpenArchX .oaxm file into a TensorFlow model.

    Convenience wrapper around ``ModelConverter.to_tensorflow``.

    Args:
        oaxm_file: The path to the .oaxm model file.
        tf_model: The TensorFlow/Keras model to load the weights into.
            Although it defaults to None, it is effectively required: the
            .oaxm file does not describe the architecture, so passing None
            raises ValueError rather than building a model.

    Returns:
        The same ``tf_model`` instance with the loaded weights.

    Raises:
        ImportError: If TensorFlow is not installed.
        ValueError: If ``tf_model`` is None or a weight is missing from the file.
        FileNotFoundError: If ``oaxm_file`` does not exist.
    """
    return ModelConverter.to_tensorflow(oaxm_file, tf_model)
|
520
|
+
|
521
|
+
|
522
|
+
def register_model(name, model_class):
    """
    Make ``model_class`` discoverable under ``name`` in the model registry.

    Args:
        name: Registry key; an existing entry with the same key is replaced.
        model_class: The class object to store.
    """
    ModelRegistry.register(name, model_class)
    return None
|
531
|
+
|
532
|
+
|
533
|
+
def get_model_class(name):
    """
    Look up a model class in the model registry.

    Args:
        name: Key the class was registered under.

    Returns:
        The registered class, or None when no class uses that key.
    """
    found = ModelRegistry.get(name)
    return found
|
544
|
+
|
545
|
+
|
546
|
+
def list_registered_models():
    """
    Report every name currently present in the model registry.

    Returns:
        A list of registered model names.
    """
    names = ModelRegistry.list_models()
    return names
|