triggerflow 0.1.11__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +95 -0
  6. trigger_loader/processor.py +211 -0
  7. triggerflow/cli.py +122 -0
  8. triggerflow/core.py +119 -120
  9. triggerflow/mlflow_wrapper.py +54 -49
  10. triggerflow/starter/.gitignore +143 -0
  11. triggerflow/starter/README.md +0 -0
  12. triggerflow/starter/cookiecutter.json +5 -0
  13. triggerflow/starter/prompts.yml +9 -0
  14. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  15. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  16. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +84 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +15 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples_dummy.json +26 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +21 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +88 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +50 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +10 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +40 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +12 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +31 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +29 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  88. triggerflow-0.2.dist-info/METADATA +97 -0
  89. triggerflow-0.2.dist-info/RECORD +97 -0
  90. triggerflow-0.2.dist-info/entry_points.txt +2 -0
  91. triggerflow-0.2.dist-info/top_level.txt +3 -0
  92. triggerflow-0.1.11.dist-info/METADATA +0 -61
  93. triggerflow-0.1.11.dist-info/RECORD +0 -11
  94. triggerflow-0.1.11.dist-info/top_level.txt +0 -1
  95. {triggerflow-0.1.11.dist-info → triggerflow-0.2.dist-info}/WHEEL +0 -0
triggerflow/core.py CHANGED
@@ -1,32 +1,38 @@
1
- from pathlib import Path
1
+ import importlib
2
+ import importlib.resources as pkg_resources
2
3
  import json
3
- import numpy as np
4
+ import logging
5
+ import shutil
4
6
  import tarfile
5
- import importlib
7
+ import warnings
6
8
  from abc import ABC, abstractmethod
7
- from typing import Optional, Dict, Any, Union
8
- import shutil, warnings
9
- import importlib.resources as pkg_resources
10
- import triggerflow.templates
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+
14
+ import triggerflow.templates
15
+
16
+ logger = logging.getLogger(__name__)
11
17
 
12
18
 
13
19
  class ModelConverter(ABC):
14
20
  """Abstract base class for model converters"""
15
-
21
+
16
22
  @abstractmethod
17
- def convert(self, model, workspace: Path, **kwargs) -> Optional[Path]:
23
+ def convert(self, model, workspace: Path, **kwargs) -> Path | None:
18
24
  """Convert model to intermediate format"""
19
25
  pass
20
26
 
21
27
 
22
28
  class CompilerStrategy(ABC):
23
29
  """Abstract base class for compilation strategies"""
24
-
30
+
25
31
  @abstractmethod
26
- def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
32
+ def compile(self, model, workspace: Path, config: dict | None = None, **kwargs) -> Any:
27
33
  """Compile model to firmware"""
28
34
  pass
29
-
35
+
30
36
  @abstractmethod
31
37
  def load_compiled_model(self, workspace: Path) -> Any:
32
38
  """Load a previously compiled model"""
@@ -35,7 +41,7 @@ class CompilerStrategy(ABC):
35
41
 
36
42
  class ModelPredictor(ABC):
37
43
  """Abstract base class for model predictors"""
38
-
44
+
39
45
  @abstractmethod
40
46
  def predict(self, input_data: np.ndarray) -> np.ndarray:
41
47
  """Make predictions using the model"""
@@ -44,7 +50,7 @@ class ModelPredictor(ABC):
44
50
 
45
51
  class KerasToQONNXConverter(ModelConverter):
46
52
  """Converts Keras models to QONNX format"""
47
-
53
+
48
54
  def convert(self, model, workspace: Path, **kwargs) -> Path:
49
55
  import tensorflow as tf
50
56
  from qonnx.converters import keras as keras_converter
@@ -52,7 +58,7 @@ class KerasToQONNXConverter(ModelConverter):
52
58
  from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
53
59
  from qonnx.transformation.gemm_to_matmul import GemmToMatMul
54
60
  from qonnx.util.cleanup import cleanup_model
55
-
61
+
56
62
  qonnx_path = workspace / "model_qonnx.onnx"
57
63
  input_signature = [tf.TensorSpec(1, model.inputs[0].dtype, name="input_0")]
58
64
  qonnx_model, _ = keras_converter.from_keras(model, input_signature, output_path=qonnx_path)
@@ -61,26 +67,26 @@ class KerasToQONNXConverter(ModelConverter):
61
67
  qonnx_model = qonnx_model.transform(ConvertToChannelsLastAndClean())
62
68
  qonnx_model = qonnx_model.transform(GemmToMatMul())
63
69
  cleaned_model = cleanup_model(qonnx_model)
64
-
70
+
65
71
  return qonnx_path, cleaned_model
66
72
 
67
73
 
68
74
  class NoOpConverter(ModelConverter):
69
75
  """No-operation converter for models that don't need conversion"""
70
-
71
- def convert(self, model, workspace: Path, **kwargs) -> Optional[Path]:
76
+
77
+ def convert(self, model, workspace: Path, **kwargs) -> Path | None:
72
78
  return None
73
79
 
74
80
 
75
81
  class HLS4MLStrategy(CompilerStrategy):
76
82
  """HLS4ML compilation strategy for Keras models"""
77
-
78
- def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
83
+
84
+ def compile(self, model, workspace: Path, config: dict | None = None, **kwargs) -> Any:
79
85
  import hls4ml
80
-
86
+
81
87
  firmware_dir = workspace / "firmware"
82
88
  firmware_dir.mkdir(exist_ok=True)
83
-
89
+
84
90
  cfg = config or hls4ml.utils.config_from_keras_model(model, granularity="name")
85
91
 
86
92
  hls_kwargs = {
@@ -103,10 +109,10 @@ class HLS4MLStrategy(CompilerStrategy):
103
109
  warnings.warn("Vivado not found in PATH. Firmware build failed.", UserWarning)
104
110
  firmware_model.save(workspace / "firmware_model.fml")
105
111
  return firmware_model
106
-
112
+
107
113
  def load_compiled_model(self, workspace: Path) -> Any:
108
114
  from hls4ml.converters import link_existing_project
109
-
115
+
110
116
  firmware_model = link_existing_project(workspace / "firmware")
111
117
  firmware_model.compile()
112
118
  return firmware_model
@@ -114,30 +120,23 @@ class HLS4MLStrategy(CompilerStrategy):
114
120
  class ConiferStrategy(CompilerStrategy):
115
121
  """Conifer compilation strategy for XGBoost models"""
116
122
 
117
- def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
123
+ def compile(self, model, workspace: Path, config: dict | None = None, **kwargs) -> Any:
118
124
  import conifer
119
- import os
120
- import shutil
121
- import warnings
122
125
 
123
126
  firmware_dir = workspace / "firmware"
124
127
  firmware_dir.mkdir(exist_ok=True)
125
128
 
126
-
127
- project_name = kwargs.pop('project_name', None)
128
-
129
129
  cfg = config or conifer.backends.xilinxhls.auto_config()
130
130
  cfg['OutputDir'] = str(firmware_dir)
131
- cfg.update(kwargs)
131
+
132
+ for key, value in kwargs.items():
133
+ cfg[key] = value
132
134
 
133
135
  firmware_model = conifer.converters.convert_from_xgboost(
134
136
  model,
135
137
  config=cfg
136
138
  )
137
139
 
138
- if project_name:
139
- firmware_model.config.project_name = project_name
140
-
141
140
  firmware_model.compile()
142
141
  if shutil.which("vivado") is not None:
143
142
  firmware_model.build()
@@ -149,31 +148,31 @@ class ConiferStrategy(CompilerStrategy):
149
148
 
150
149
  class DA4MLStrategy(CompilerStrategy):
151
150
  """DA4ML compilation strategy (placeholder)"""
152
-
153
- def compile(self, model, workspace: Path, config: Optional[Dict] = None) -> Any:
151
+
152
+ def compile(self, model, workspace: Path, config: dict | None = None) -> Any:
154
153
  raise NotImplementedError("DA4ML conversion without QONNX not yet implemented")
155
-
154
+
156
155
  def load_compiled_model(self, workspace: Path) -> Any:
157
156
  raise NotImplementedError("DA4ML loading not yet implemented")
158
157
 
159
158
 
160
159
  class FINNStrategy(CompilerStrategy):
161
160
  """FINN compilation strategy (placeholder)"""
162
-
163
- def compile(self, model, workspace: Path, config: Optional[Dict] = None) -> Any:
161
+
162
+ def compile(self, model, workspace: Path, config: dict | None = None) -> Any:
164
163
  raise NotImplementedError("FINN conversion without QONNX not yet implemented")
165
-
164
+
166
165
  def load_compiled_model(self, workspace: Path) -> Any:
167
166
  raise NotImplementedError("FINN loading not yet implemented")
168
167
 
169
168
 
170
169
  class SoftwarePredictor(ModelPredictor):
171
170
  """Software-based model predictor"""
172
-
171
+
173
172
  def __init__(self, model, backend: str):
174
173
  self.model = model
175
174
  self.backend = backend.lower()
176
-
175
+
177
176
  def predict(self, input_data):
178
177
  if input_data.ndim == 1:
179
178
  input_data = np.expand_dims(input_data, axis=0)
@@ -182,37 +181,37 @@ class SoftwarePredictor(ModelPredictor):
182
181
 
183
182
  class QONNXPredictor(ModelPredictor):
184
183
  """QONNX-based model predictor"""
185
-
184
+
186
185
  def __init__(self, qonnx_model, input_name: str):
187
186
  self.qonnx_model = qonnx_model
188
187
  self.input_name = input_name
189
-
188
+
190
189
  def predict(self, input_data: np.ndarray) -> np.ndarray:
191
190
  from qonnx.core.onnx_exec import execute_onnx
192
-
191
+
193
192
  input_data = np.asarray(input_data)
194
193
  if input_data.ndim == 1:
195
194
  input_data = np.expand_dims(input_data, axis=0)
196
-
195
+
197
196
  outputs = []
198
197
  for i in range(input_data.shape[0]):
199
198
  sample = input_data[i].astype("float32").reshape(1, -1)
200
199
  output_dict = execute_onnx(self.qonnx_model, {self.input_name: sample})
201
200
  outputs.append(output_dict["global_out"])
202
-
201
+
203
202
  return np.vstack(outputs)
204
203
 
205
204
 
206
205
  class FirmwarePredictor(ModelPredictor):
207
206
  """Firmware-based model predictor"""
208
-
207
+
209
208
  def __init__(self, firmware_model, compiler):
210
209
  if firmware_model is None:
211
210
  raise RuntimeError("Firmware model not built.")
212
211
  self.firmware_model = firmware_model
213
212
  self.compiler = compiler
214
-
215
-
213
+
214
+
216
215
  def predict(self, input_data: np.ndarray) -> np.ndarray:
217
216
  if self.compiler == "conifer":
218
217
  return self.firmware_model.decision_function(input_data)
@@ -222,7 +221,7 @@ class FirmwarePredictor(ModelPredictor):
222
221
 
223
222
  class ConverterFactory:
224
223
  """Factory for creating model converters"""
225
-
224
+
226
225
  @staticmethod
227
226
  def create_converter(ml_backend: str, compiler: str) -> ModelConverter:
228
227
  if ml_backend.lower() == "keras" and compiler.lower() == "hls4ml":
@@ -233,12 +232,12 @@ class ConverterFactory:
233
232
 
234
233
  class CompilerFactory:
235
234
  """Factory for creating compilation strategies"""
236
-
235
+
237
236
  @staticmethod
238
237
  def create_compiler(ml_backend: str, compiler: str) -> CompilerStrategy:
239
238
  backend = ml_backend.lower()
240
239
  comp = compiler.lower()
241
-
240
+
242
241
  if backend == "keras" and comp == "hls4ml":
243
242
  return HLS4MLStrategy()
244
243
  elif backend == "xgboost" and comp == "conifer":
@@ -253,9 +252,9 @@ class CompilerFactory:
253
252
 
254
253
  class WorkspaceManager:
255
254
  """Manages workspace directories and metadata"""
256
-
255
+
257
256
  BASE_WORKSPACE = Path.cwd() / "triggermodel"
258
-
257
+
259
258
  def __init__(self):
260
259
  self.workspace = self.BASE_WORKSPACE
261
260
  self.artifacts = {"firmware": None}
@@ -265,7 +264,7 @@ class WorkspaceManager:
265
264
  "compiler": None,
266
265
  "versions": []
267
266
  }
268
-
267
+
269
268
  def setup_workspace(self, name: str, ml_backend: str, compiler: str):
270
269
  """Initialize workspace and metadata"""
271
270
  self.workspace.mkdir(parents=True, exist_ok=True)
@@ -274,22 +273,22 @@ class WorkspaceManager:
274
273
  "ml_backend": ml_backend,
275
274
  "compiler": compiler,
276
275
  })
277
-
276
+
278
277
  def save_native_model(self, model, ml_backend: str):
279
278
  """Save the native model to workspace"""
280
279
  if ml_backend.lower() == "keras":
281
- model.save(self.workspace / "keras_model")
280
+ model.save(self.workspace / "keras_model.keras")
282
281
  elif ml_backend.lower() == "xgboost":
283
282
  model.save_model(str(self.workspace / "xgb_model.json"))
284
-
283
+
285
284
  def add_artifact(self, key: str, value: Any):
286
285
  """Add artifact to tracking"""
287
286
  self.artifacts[key] = value
288
-
289
- def add_version(self, version_info: Dict):
287
+
288
+ def add_version(self, version_info: dict):
290
289
  """Add version information"""
291
290
  self.metadata["versions"].append(version_info)
292
-
291
+
293
292
  def save_metadata(self):
294
293
  """Save metadata to file"""
295
294
  with open(self.workspace / "metadata.json", "w") as f:
@@ -302,7 +301,7 @@ class WorkspaceManager:
302
301
 
303
302
  class ModelSerializer:
304
303
  """Handles model serialization and deserialization"""
305
-
304
+
306
305
  @staticmethod
307
306
  def save(workspace: Path, path: str):
308
307
  """Serialize the workspace into a tar.xz archive"""
@@ -310,37 +309,37 @@ class ModelSerializer:
310
309
  path.parent.mkdir(parents=True, exist_ok=True)
311
310
  with tarfile.open(path, mode="w:xz") as tar:
312
311
  tar.add(workspace, arcname=workspace.name)
313
- print(f"TriggerModel saved to {path}")
314
-
312
+ logger.info(f"TriggerModel saved to {path}")
313
+
315
314
  @staticmethod
316
- def load(path: str) -> Dict[str, Any]:
315
+ def load(path: str) -> dict[str, Any]:
317
316
  """Load workspace from tar.xz archive"""
318
317
  path = Path(path)
319
318
  if not path.exists():
320
319
  raise FileNotFoundError(f"{path} does not exist")
321
-
320
+
322
321
  workspace = Path.cwd() / "triggermodel"
323
-
322
+
324
323
  if workspace.exists():
325
324
  response = input(f"{workspace} already exists. Overwrite? [y/N]: ").strip().lower()
326
325
  if response != "y":
327
- print("Load cancelled by user.")
326
+ logger.info("Load cancelled by user.")
328
327
  return None
329
328
  shutil.rmtree(workspace)
330
-
329
+
331
330
  with tarfile.open(path, mode="r:xz") as tar:
332
331
  tar.extractall(path=Path.cwd())
333
-
332
+
334
333
  # Load metadata
335
334
  metadata_path = workspace / "metadata.json"
336
- with open(metadata_path, "r") as f:
335
+ with open(metadata_path) as f:
337
336
  metadata = json.load(f)
338
-
337
+
339
338
  return {
340
339
  "workspace": workspace,
341
340
  "metadata": metadata
342
341
  }
343
-
342
+
344
343
  @staticmethod
345
344
  def load_native_model(workspace: Path, ml_backend: str):
346
345
  """Load native model from workspace"""
@@ -357,7 +356,7 @@ class ModelSerializer:
357
356
  return model
358
357
  else:
359
358
  raise ValueError(f"Unsupported ml_backend: {ml_backend}")
360
-
359
+
361
360
  @staticmethod
362
361
  def load_qonnx_model(workspace: Path):
363
362
  """Load QONNX model if it exists"""
@@ -371,13 +370,13 @@ class ModelSerializer:
371
370
 
372
371
  class TriggerModel:
373
372
  """Main facade class that orchestrates model conversion, compilation, and inference"""
374
-
373
+
375
374
  def __init__(self, name: str, ml_backend: str, n_outputs:int, compiler: str,
376
375
  native_model: object, compiler_config: dict = None, scales: dict = None):
377
-
376
+
378
377
  if ml_backend.lower() not in ("keras", "xgboost"):
379
378
  raise ValueError("Only Keras or XGBoost backends are currently supported.")
380
-
379
+
381
380
  self.name = name
382
381
  self.ml_backend = ml_backend.lower()
383
382
  self.scales = scales
@@ -385,67 +384,67 @@ class TriggerModel:
385
384
  self.compiler = compiler.lower()
386
385
  self.native_model = native_model
387
386
  self.compiler_conifg = compiler_config
388
-
387
+
389
388
  self.workspace_manager = WorkspaceManager()
390
389
  self.converter = ConverterFactory.create_converter(ml_backend, compiler)
391
390
  self.compiler_strategy = CompilerFactory.create_compiler(ml_backend, compiler)
392
-
391
+
393
392
  self.firmware_model = None
394
393
  self.model_qonnx = None
395
394
  self.input_name = None
396
-
395
+
397
396
  self.workspace_manager.setup_workspace(name, self.ml_backend, self.compiler)
398
-
397
+
399
398
  @property
400
399
  def workspace(self) -> Path:
401
400
  """Get workspace path"""
402
401
  return self.workspace_manager.workspace
403
-
402
+
404
403
  @property
405
- def artifacts(self) -> Dict[str, Any]:
404
+ def artifacts(self) -> dict[str, Any]:
406
405
  """Get artifacts dictionary"""
407
406
  return self.workspace_manager.artifacts
408
-
407
+
409
408
  @property
410
- def metadata(self) -> Dict[str, Any]:
409
+ def metadata(self) -> dict[str, Any]:
411
410
  """Get metadata dictionary"""
412
411
  return self.workspace_manager.metadata
413
-
412
+
414
413
  def __call__(self, **compiler_kwargs):
415
414
  """Execute the full model conversion and compilation pipeline"""
416
415
  self.parse_dataset_object()
417
-
416
+
418
417
  # Save native model
419
418
  self.workspace_manager.save_native_model(self.native_model, self.ml_backend)
420
-
419
+
421
420
  # Convert model if needed
422
421
  conversion_result = self.converter.convert(
423
- self.native_model,
422
+ self.native_model,
424
423
  self.workspace_manager.workspace
425
424
  )
426
-
425
+
427
426
  if conversion_result is not None:
428
427
  qonnx_path, self.model_qonnx = conversion_result
429
428
  self.input_name = self.model_qonnx.graph.input[0].name
430
429
  self.workspace_manager.add_artifact("qonnx", qonnx_path)
431
430
  self.workspace_manager.add_version({"qonnx": str(qonnx_path)})
432
-
431
+
433
432
  # Compile model
434
433
  self.firmware_model = self.compiler_strategy.compile(
435
434
  self.native_model,
436
435
  self.workspace_manager.workspace,
437
- self.compiler_conifg,
436
+ self.compiler_conifg,
438
437
  **compiler_kwargs
439
438
  )
440
-
439
+
441
440
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
442
441
 
443
- if self.compiler is not "conifer" and self.scales is not None:
442
+ if self.compiler != "conifer" and self.scales is not None:
444
443
  self.build_emulator(self.scales['shifts'], self.scales['offsets'], self.n_outputs)
445
-
444
+
446
445
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
447
446
  self.workspace_manager.save_metadata()
448
-
447
+
449
448
  @staticmethod
450
449
  def parse_dataset_object():
451
450
  """Parse dataset object (placeholder)"""
@@ -460,24 +459,24 @@ class TriggerModel:
460
459
  template = template.replace("{{" + k + "}}", str(v))
461
460
  with open(out_path, "w") as f:
462
461
  f.write(template)
463
-
462
+
464
463
  def software_predict(self, input_data: np.ndarray) -> np.ndarray:
465
464
  """Make predictions using software model"""
466
465
  predictor = SoftwarePredictor(self.native_model, self.ml_backend)
467
466
  return predictor.predict(input_data)
468
-
467
+
469
468
  def qonnx_predict(self, input_data: np.ndarray) -> np.ndarray:
470
469
  """Make predictions using QONNX model"""
471
470
  if self.model_qonnx is None:
472
471
  raise RuntimeError("QONNX model not available")
473
472
  predictor = QONNXPredictor(self.model_qonnx, self.input_name)
474
473
  return predictor.predict(input_data)
475
-
474
+
476
475
  def firmware_predict(self, input_data: np.ndarray) -> np.ndarray:
477
476
  """Make predictions using firmware model"""
478
477
  predictor = FirmwarePredictor(self.firmware_model, self.compiler)
479
478
  return predictor.predict(input_data)
480
-
479
+
481
480
  def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int):
482
481
  """
483
482
  Create an emulator directory for this model.
@@ -486,13 +485,13 @@ class TriggerModel:
486
485
  emulator_dir = self.workspace / "emulator"
487
486
  emulator_dir.mkdir(exist_ok=True)
488
487
 
489
- model_dir = emulator_dir / self.name
488
+ model_dir = emulator_dir / self.name
490
489
  model_dir.mkdir(exist_ok=True)
491
-
490
+
492
491
  firmware_dir = self.workspace / "firmware" / "firmware"
493
-
492
+
494
493
  shutil.copytree(firmware_dir, f"{model_dir}/NN", dirs_exist_ok=True)
495
-
494
+
496
495
  # Access scales template from installed package
497
496
  with pkg_resources.path(triggerflow.templates, "scales.h") as scales_template_path:
498
497
  scales_out_path = model_dir / "scales.h"
@@ -508,7 +507,7 @@ class TriggerModel:
508
507
  with pkg_resources.path(triggerflow.templates, "model_template.cpp") as emulator_template_path:
509
508
  emulator_out_path = model_dir / "emulator.cpp"
510
509
  self._render_template(emulator_template_path, emulator_out_path, context)
511
-
510
+
512
511
  with pkg_resources.path(triggerflow.templates, "makefile_version") as makefile_template_path:
513
512
  makefile_out_path = model_dir / "Makefile"
514
513
  self._render_template(makefile_template_path, makefile_out_path, {"MODEL_NAME": self.name})
@@ -516,44 +515,44 @@ class TriggerModel:
516
515
  with pkg_resources.path(triggerflow.templates, "makefile") as makefile_template_path:
517
516
  makefile_out_path = emulator_dir / "Makefile"
518
517
  self._render_template(makefile_template_path, makefile_out_path, {"MODEL_NAME": self.name})
519
-
520
-
518
+
519
+
521
520
  def save(self, path: str):
522
521
  """Save the complete model to an archive"""
523
522
  ModelSerializer.save(self.workspace_manager.workspace, path)
524
-
523
+
525
524
  @classmethod
526
525
  def load(cls, path: str) -> 'TriggerModel':
527
526
  """Load a model from an archive"""
528
527
  load_result = ModelSerializer.load(path)
529
528
  if load_result is None:
530
529
  return None
531
-
530
+
532
531
  workspace = load_result["workspace"]
533
532
  metadata = load_result["metadata"]
534
-
533
+
535
534
  obj = cls.__new__(cls)
536
535
  obj.workspace_manager = WorkspaceManager()
537
536
  obj.workspace_manager.workspace = workspace
538
537
  obj.workspace_manager.metadata = metadata
539
538
  obj.workspace_manager.artifacts = {"firmware": workspace / "firmware"}
540
-
539
+
541
540
  obj.name = metadata.get("name", "")
542
541
  obj.ml_backend = metadata.get("ml_backend")
543
542
  obj.compiler = metadata.get("compiler")
544
-
543
+
545
544
  obj.native_model = ModelSerializer.load_native_model(workspace, obj.ml_backend)
546
-
545
+
547
546
  obj.model_qonnx, obj.input_name = ModelSerializer.load_qonnx_model(workspace)
548
-
547
+
549
548
  if obj.compiler.lower() in ("hls4ml", "conifer"):
550
549
  obj.compiler_strategy = CompilerFactory.create_compiler(obj.ml_backend, obj.compiler)
551
550
  obj.firmware_model = obj.compiler_strategy.load_compiled_model(workspace)
552
551
  else:
553
552
  obj.firmware_model = None
554
553
  obj.compiler_strategy = None
555
-
554
+
556
555
  obj.converter = ConverterFactory.create_converter(obj.ml_backend, obj.compiler)
557
- obj.dataset_object = None
558
-
559
- return obj
556
+ obj.dataset_object = None
557
+
558
+ return obj