triggerflow 0.1.12__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95):
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +95 -0
  6. trigger_loader/processor.py +211 -0
  7. triggerflow/cli.py +122 -0
  8. triggerflow/core.py +118 -114
  9. triggerflow/mlflow_wrapper.py +54 -49
  10. triggerflow/starter/.gitignore +143 -0
  11. triggerflow/starter/README.md +0 -0
  12. triggerflow/starter/cookiecutter.json +5 -0
  13. triggerflow/starter/prompts.yml +9 -0
  14. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  15. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  16. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +84 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +15 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples_dummy.json +26 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +21 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +88 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +50 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +10 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +40 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +12 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +31 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +29 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  88. triggerflow-0.2.dist-info/METADATA +97 -0
  89. triggerflow-0.2.dist-info/RECORD +97 -0
  90. triggerflow-0.2.dist-info/entry_points.txt +2 -0
  91. triggerflow-0.2.dist-info/top_level.txt +3 -0
  92. triggerflow-0.1.12.dist-info/METADATA +0 -61
  93. triggerflow-0.1.12.dist-info/RECORD +0 -11
  94. triggerflow-0.1.12.dist-info/top_level.txt +0 -1
  95. {triggerflow-0.1.12.dist-info → triggerflow-0.2.dist-info}/WHEEL +0 -0
triggerflow/core.py CHANGED
@@ -1,32 +1,38 @@
1
- from pathlib import Path
1
+ import importlib
2
+ import importlib.resources as pkg_resources
2
3
  import json
3
- import numpy as np
4
+ import logging
5
+ import shutil
4
6
  import tarfile
5
- import importlib
7
+ import warnings
6
8
  from abc import ABC, abstractmethod
7
- from typing import Optional, Dict, Any, Union
8
- import shutil, warnings
9
- import importlib.resources as pkg_resources
10
- import triggerflow.templates
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+
14
+ import triggerflow.templates
15
+
16
+ logger = logging.getLogger(__name__)
11
17
 
12
18
 
13
19
  class ModelConverter(ABC):
14
20
  """Abstract base class for model converters"""
15
-
21
+
16
22
  @abstractmethod
17
- def convert(self, model, workspace: Path, **kwargs) -> Optional[Path]:
23
+ def convert(self, model, workspace: Path, **kwargs) -> Path | None:
18
24
  """Convert model to intermediate format"""
19
25
  pass
20
26
 
21
27
 
22
28
  class CompilerStrategy(ABC):
23
29
  """Abstract base class for compilation strategies"""
24
-
30
+
25
31
  @abstractmethod
26
- def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
32
+ def compile(self, model, workspace: Path, config: dict | None = None, **kwargs) -> Any:
27
33
  """Compile model to firmware"""
28
34
  pass
29
-
35
+
30
36
  @abstractmethod
31
37
  def load_compiled_model(self, workspace: Path) -> Any:
32
38
  """Load a previously compiled model"""
@@ -35,7 +41,7 @@ class CompilerStrategy(ABC):
35
41
 
36
42
  class ModelPredictor(ABC):
37
43
  """Abstract base class for model predictors"""
38
-
44
+
39
45
  @abstractmethod
40
46
  def predict(self, input_data: np.ndarray) -> np.ndarray:
41
47
  """Make predictions using the model"""
@@ -44,7 +50,7 @@ class ModelPredictor(ABC):
44
50
 
45
51
  class KerasToQONNXConverter(ModelConverter):
46
52
  """Converts Keras models to QONNX format"""
47
-
53
+
48
54
  def convert(self, model, workspace: Path, **kwargs) -> Path:
49
55
  import tensorflow as tf
50
56
  from qonnx.converters import keras as keras_converter
@@ -52,7 +58,7 @@ class KerasToQONNXConverter(ModelConverter):
52
58
  from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
53
59
  from qonnx.transformation.gemm_to_matmul import GemmToMatMul
54
60
  from qonnx.util.cleanup import cleanup_model
55
-
61
+
56
62
  qonnx_path = workspace / "model_qonnx.onnx"
57
63
  input_signature = [tf.TensorSpec(1, model.inputs[0].dtype, name="input_0")]
58
64
  qonnx_model, _ = keras_converter.from_keras(model, input_signature, output_path=qonnx_path)
@@ -61,26 +67,26 @@ class KerasToQONNXConverter(ModelConverter):
61
67
  qonnx_model = qonnx_model.transform(ConvertToChannelsLastAndClean())
62
68
  qonnx_model = qonnx_model.transform(GemmToMatMul())
63
69
  cleaned_model = cleanup_model(qonnx_model)
64
-
70
+
65
71
  return qonnx_path, cleaned_model
66
72
 
67
73
 
68
74
  class NoOpConverter(ModelConverter):
69
75
  """No-operation converter for models that don't need conversion"""
70
-
71
- def convert(self, model, workspace: Path, **kwargs) -> Optional[Path]:
76
+
77
+ def convert(self, model, workspace: Path, **kwargs) -> Path | None:
72
78
  return None
73
79
 
74
80
 
75
81
  class HLS4MLStrategy(CompilerStrategy):
76
82
  """HLS4ML compilation strategy for Keras models"""
77
-
78
- def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
83
+
84
+ def compile(self, model, workspace: Path, config: dict | None = None, **kwargs) -> Any:
79
85
  import hls4ml
80
-
86
+
81
87
  firmware_dir = workspace / "firmware"
82
88
  firmware_dir.mkdir(exist_ok=True)
83
-
89
+
84
90
  cfg = config or hls4ml.utils.config_from_keras_model(model, granularity="name")
85
91
 
86
92
  hls_kwargs = {
@@ -103,10 +109,10 @@ class HLS4MLStrategy(CompilerStrategy):
103
109
  warnings.warn("Vivado not found in PATH. Firmware build failed.", UserWarning)
104
110
  firmware_model.save(workspace / "firmware_model.fml")
105
111
  return firmware_model
106
-
112
+
107
113
  def load_compiled_model(self, workspace: Path) -> Any:
108
114
  from hls4ml.converters import link_existing_project
109
-
115
+
110
116
  firmware_model = link_existing_project(workspace / "firmware")
111
117
  firmware_model.compile()
112
118
  return firmware_model
@@ -114,17 +120,15 @@ class HLS4MLStrategy(CompilerStrategy):
114
120
  class ConiferStrategy(CompilerStrategy):
115
121
  """Conifer compilation strategy for XGBoost models"""
116
122
 
117
- def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
123
+ def compile(self, model, workspace: Path, config: dict | None = None, **kwargs) -> Any:
118
124
  import conifer
119
- import shutil
120
- import warnings
121
125
 
122
126
  firmware_dir = workspace / "firmware"
123
127
  firmware_dir.mkdir(exist_ok=True)
124
-
128
+
125
129
  cfg = config or conifer.backends.xilinxhls.auto_config()
126
130
  cfg['OutputDir'] = str(firmware_dir)
127
-
131
+
128
132
  for key, value in kwargs.items():
129
133
  cfg[key] = value
130
134
 
@@ -144,31 +148,31 @@ class ConiferStrategy(CompilerStrategy):
144
148
 
145
149
  class DA4MLStrategy(CompilerStrategy):
146
150
  """DA4ML compilation strategy (placeholder)"""
147
-
148
- def compile(self, model, workspace: Path, config: Optional[Dict] = None) -> Any:
151
+
152
+ def compile(self, model, workspace: Path, config: dict | None = None) -> Any:
149
153
  raise NotImplementedError("DA4ML conversion without QONNX not yet implemented")
150
-
154
+
151
155
  def load_compiled_model(self, workspace: Path) -> Any:
152
156
  raise NotImplementedError("DA4ML loading not yet implemented")
153
157
 
154
158
 
155
159
  class FINNStrategy(CompilerStrategy):
156
160
  """FINN compilation strategy (placeholder)"""
157
-
158
- def compile(self, model, workspace: Path, config: Optional[Dict] = None) -> Any:
161
+
162
+ def compile(self, model, workspace: Path, config: dict | None = None) -> Any:
159
163
  raise NotImplementedError("FINN conversion without QONNX not yet implemented")
160
-
164
+
161
165
  def load_compiled_model(self, workspace: Path) -> Any:
162
166
  raise NotImplementedError("FINN loading not yet implemented")
163
167
 
164
168
 
165
169
  class SoftwarePredictor(ModelPredictor):
166
170
  """Software-based model predictor"""
167
-
171
+
168
172
  def __init__(self, model, backend: str):
169
173
  self.model = model
170
174
  self.backend = backend.lower()
171
-
175
+
172
176
  def predict(self, input_data):
173
177
  if input_data.ndim == 1:
174
178
  input_data = np.expand_dims(input_data, axis=0)
@@ -177,37 +181,37 @@ class SoftwarePredictor(ModelPredictor):
177
181
 
178
182
  class QONNXPredictor(ModelPredictor):
179
183
  """QONNX-based model predictor"""
180
-
184
+
181
185
  def __init__(self, qonnx_model, input_name: str):
182
186
  self.qonnx_model = qonnx_model
183
187
  self.input_name = input_name
184
-
188
+
185
189
  def predict(self, input_data: np.ndarray) -> np.ndarray:
186
190
  from qonnx.core.onnx_exec import execute_onnx
187
-
191
+
188
192
  input_data = np.asarray(input_data)
189
193
  if input_data.ndim == 1:
190
194
  input_data = np.expand_dims(input_data, axis=0)
191
-
195
+
192
196
  outputs = []
193
197
  for i in range(input_data.shape[0]):
194
198
  sample = input_data[i].astype("float32").reshape(1, -1)
195
199
  output_dict = execute_onnx(self.qonnx_model, {self.input_name: sample})
196
200
  outputs.append(output_dict["global_out"])
197
-
201
+
198
202
  return np.vstack(outputs)
199
203
 
200
204
 
201
205
  class FirmwarePredictor(ModelPredictor):
202
206
  """Firmware-based model predictor"""
203
-
207
+
204
208
  def __init__(self, firmware_model, compiler):
205
209
  if firmware_model is None:
206
210
  raise RuntimeError("Firmware model not built.")
207
211
  self.firmware_model = firmware_model
208
212
  self.compiler = compiler
209
-
210
-
213
+
214
+
211
215
  def predict(self, input_data: np.ndarray) -> np.ndarray:
212
216
  if self.compiler == "conifer":
213
217
  return self.firmware_model.decision_function(input_data)
@@ -217,7 +221,7 @@ class FirmwarePredictor(ModelPredictor):
217
221
 
218
222
  class ConverterFactory:
219
223
  """Factory for creating model converters"""
220
-
224
+
221
225
  @staticmethod
222
226
  def create_converter(ml_backend: str, compiler: str) -> ModelConverter:
223
227
  if ml_backend.lower() == "keras" and compiler.lower() == "hls4ml":
@@ -228,12 +232,12 @@ class ConverterFactory:
228
232
 
229
233
  class CompilerFactory:
230
234
  """Factory for creating compilation strategies"""
231
-
235
+
232
236
  @staticmethod
233
237
  def create_compiler(ml_backend: str, compiler: str) -> CompilerStrategy:
234
238
  backend = ml_backend.lower()
235
239
  comp = compiler.lower()
236
-
240
+
237
241
  if backend == "keras" and comp == "hls4ml":
238
242
  return HLS4MLStrategy()
239
243
  elif backend == "xgboost" and comp == "conifer":
@@ -248,9 +252,9 @@ class CompilerFactory:
248
252
 
249
253
  class WorkspaceManager:
250
254
  """Manages workspace directories and metadata"""
251
-
255
+
252
256
  BASE_WORKSPACE = Path.cwd() / "triggermodel"
253
-
257
+
254
258
  def __init__(self):
255
259
  self.workspace = self.BASE_WORKSPACE
256
260
  self.artifacts = {"firmware": None}
@@ -260,7 +264,7 @@ class WorkspaceManager:
260
264
  "compiler": None,
261
265
  "versions": []
262
266
  }
263
-
267
+
264
268
  def setup_workspace(self, name: str, ml_backend: str, compiler: str):
265
269
  """Initialize workspace and metadata"""
266
270
  self.workspace.mkdir(parents=True, exist_ok=True)
@@ -269,22 +273,22 @@ class WorkspaceManager:
269
273
  "ml_backend": ml_backend,
270
274
  "compiler": compiler,
271
275
  })
272
-
276
+
273
277
  def save_native_model(self, model, ml_backend: str):
274
278
  """Save the native model to workspace"""
275
279
  if ml_backend.lower() == "keras":
276
- model.save(self.workspace / "keras_model")
280
+ model.save(self.workspace / "keras_model.keras")
277
281
  elif ml_backend.lower() == "xgboost":
278
282
  model.save_model(str(self.workspace / "xgb_model.json"))
279
-
283
+
280
284
  def add_artifact(self, key: str, value: Any):
281
285
  """Add artifact to tracking"""
282
286
  self.artifacts[key] = value
283
-
284
- def add_version(self, version_info: Dict):
287
+
288
+ def add_version(self, version_info: dict):
285
289
  """Add version information"""
286
290
  self.metadata["versions"].append(version_info)
287
-
291
+
288
292
  def save_metadata(self):
289
293
  """Save metadata to file"""
290
294
  with open(self.workspace / "metadata.json", "w") as f:
@@ -297,7 +301,7 @@ class WorkspaceManager:
297
301
 
298
302
  class ModelSerializer:
299
303
  """Handles model serialization and deserialization"""
300
-
304
+
301
305
  @staticmethod
302
306
  def save(workspace: Path, path: str):
303
307
  """Serialize the workspace into a tar.xz archive"""
@@ -305,37 +309,37 @@ class ModelSerializer:
305
309
  path.parent.mkdir(parents=True, exist_ok=True)
306
310
  with tarfile.open(path, mode="w:xz") as tar:
307
311
  tar.add(workspace, arcname=workspace.name)
308
- print(f"TriggerModel saved to {path}")
309
-
312
+ logger.info(f"TriggerModel saved to {path}")
313
+
310
314
  @staticmethod
311
- def load(path: str) -> Dict[str, Any]:
315
+ def load(path: str) -> dict[str, Any]:
312
316
  """Load workspace from tar.xz archive"""
313
317
  path = Path(path)
314
318
  if not path.exists():
315
319
  raise FileNotFoundError(f"{path} does not exist")
316
-
320
+
317
321
  workspace = Path.cwd() / "triggermodel"
318
-
322
+
319
323
  if workspace.exists():
320
324
  response = input(f"{workspace} already exists. Overwrite? [y/N]: ").strip().lower()
321
325
  if response != "y":
322
- print("Load cancelled by user.")
326
+ logger.info("Load cancelled by user.")
323
327
  return None
324
328
  shutil.rmtree(workspace)
325
-
329
+
326
330
  with tarfile.open(path, mode="r:xz") as tar:
327
331
  tar.extractall(path=Path.cwd())
328
-
332
+
329
333
  # Load metadata
330
334
  metadata_path = workspace / "metadata.json"
331
- with open(metadata_path, "r") as f:
335
+ with open(metadata_path) as f:
332
336
  metadata = json.load(f)
333
-
337
+
334
338
  return {
335
339
  "workspace": workspace,
336
340
  "metadata": metadata
337
341
  }
338
-
342
+
339
343
  @staticmethod
340
344
  def load_native_model(workspace: Path, ml_backend: str):
341
345
  """Load native model from workspace"""
@@ -352,7 +356,7 @@ class ModelSerializer:
352
356
  return model
353
357
  else:
354
358
  raise ValueError(f"Unsupported ml_backend: {ml_backend}")
355
-
359
+
356
360
  @staticmethod
357
361
  def load_qonnx_model(workspace: Path):
358
362
  """Load QONNX model if it exists"""
@@ -366,13 +370,13 @@ class ModelSerializer:
366
370
 
367
371
  class TriggerModel:
368
372
  """Main facade class that orchestrates model conversion, compilation, and inference"""
369
-
373
+
370
374
  def __init__(self, name: str, ml_backend: str, n_outputs:int, compiler: str,
371
375
  native_model: object, compiler_config: dict = None, scales: dict = None):
372
-
376
+
373
377
  if ml_backend.lower() not in ("keras", "xgboost"):
374
378
  raise ValueError("Only Keras or XGBoost backends are currently supported.")
375
-
379
+
376
380
  self.name = name
377
381
  self.ml_backend = ml_backend.lower()
378
382
  self.scales = scales
@@ -380,67 +384,67 @@ class TriggerModel:
380
384
  self.compiler = compiler.lower()
381
385
  self.native_model = native_model
382
386
  self.compiler_conifg = compiler_config
383
-
387
+
384
388
  self.workspace_manager = WorkspaceManager()
385
389
  self.converter = ConverterFactory.create_converter(ml_backend, compiler)
386
390
  self.compiler_strategy = CompilerFactory.create_compiler(ml_backend, compiler)
387
-
391
+
388
392
  self.firmware_model = None
389
393
  self.model_qonnx = None
390
394
  self.input_name = None
391
-
395
+
392
396
  self.workspace_manager.setup_workspace(name, self.ml_backend, self.compiler)
393
-
397
+
394
398
  @property
395
399
  def workspace(self) -> Path:
396
400
  """Get workspace path"""
397
401
  return self.workspace_manager.workspace
398
-
402
+
399
403
  @property
400
- def artifacts(self) -> Dict[str, Any]:
404
+ def artifacts(self) -> dict[str, Any]:
401
405
  """Get artifacts dictionary"""
402
406
  return self.workspace_manager.artifacts
403
-
407
+
404
408
  @property
405
- def metadata(self) -> Dict[str, Any]:
409
+ def metadata(self) -> dict[str, Any]:
406
410
  """Get metadata dictionary"""
407
411
  return self.workspace_manager.metadata
408
-
412
+
409
413
  def __call__(self, **compiler_kwargs):
410
414
  """Execute the full model conversion and compilation pipeline"""
411
415
  self.parse_dataset_object()
412
-
416
+
413
417
  # Save native model
414
418
  self.workspace_manager.save_native_model(self.native_model, self.ml_backend)
415
-
419
+
416
420
  # Convert model if needed
417
421
  conversion_result = self.converter.convert(
418
- self.native_model,
422
+ self.native_model,
419
423
  self.workspace_manager.workspace
420
424
  )
421
-
425
+
422
426
  if conversion_result is not None:
423
427
  qonnx_path, self.model_qonnx = conversion_result
424
428
  self.input_name = self.model_qonnx.graph.input[0].name
425
429
  self.workspace_manager.add_artifact("qonnx", qonnx_path)
426
430
  self.workspace_manager.add_version({"qonnx": str(qonnx_path)})
427
-
431
+
428
432
  # Compile model
429
433
  self.firmware_model = self.compiler_strategy.compile(
430
434
  self.native_model,
431
435
  self.workspace_manager.workspace,
432
- self.compiler_conifg,
436
+ self.compiler_conifg,
433
437
  **compiler_kwargs
434
438
  )
435
-
439
+
436
440
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
437
441
 
438
- if self.compiler is not "conifer" and self.scales is not None:
442
+ if self.compiler != "conifer" and self.scales is not None:
439
443
  self.build_emulator(self.scales['shifts'], self.scales['offsets'], self.n_outputs)
440
-
444
+
441
445
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
442
446
  self.workspace_manager.save_metadata()
443
-
447
+
444
448
  @staticmethod
445
449
  def parse_dataset_object():
446
450
  """Parse dataset object (placeholder)"""
@@ -455,24 +459,24 @@ class TriggerModel:
455
459
  template = template.replace("{{" + k + "}}", str(v))
456
460
  with open(out_path, "w") as f:
457
461
  f.write(template)
458
-
462
+
459
463
  def software_predict(self, input_data: np.ndarray) -> np.ndarray:
460
464
  """Make predictions using software model"""
461
465
  predictor = SoftwarePredictor(self.native_model, self.ml_backend)
462
466
  return predictor.predict(input_data)
463
-
467
+
464
468
  def qonnx_predict(self, input_data: np.ndarray) -> np.ndarray:
465
469
  """Make predictions using QONNX model"""
466
470
  if self.model_qonnx is None:
467
471
  raise RuntimeError("QONNX model not available")
468
472
  predictor = QONNXPredictor(self.model_qonnx, self.input_name)
469
473
  return predictor.predict(input_data)
470
-
474
+
471
475
  def firmware_predict(self, input_data: np.ndarray) -> np.ndarray:
472
476
  """Make predictions using firmware model"""
473
477
  predictor = FirmwarePredictor(self.firmware_model, self.compiler)
474
478
  return predictor.predict(input_data)
475
-
479
+
476
480
  def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int):
477
481
  """
478
482
  Create an emulator directory for this model.
@@ -481,13 +485,13 @@ class TriggerModel:
481
485
  emulator_dir = self.workspace / "emulator"
482
486
  emulator_dir.mkdir(exist_ok=True)
483
487
 
484
- model_dir = emulator_dir / self.name
488
+ model_dir = emulator_dir / self.name
485
489
  model_dir.mkdir(exist_ok=True)
486
-
490
+
487
491
  firmware_dir = self.workspace / "firmware" / "firmware"
488
-
492
+
489
493
  shutil.copytree(firmware_dir, f"{model_dir}/NN", dirs_exist_ok=True)
490
-
494
+
491
495
  # Access scales template from installed package
492
496
  with pkg_resources.path(triggerflow.templates, "scales.h") as scales_template_path:
493
497
  scales_out_path = model_dir / "scales.h"
@@ -503,7 +507,7 @@ class TriggerModel:
503
507
  with pkg_resources.path(triggerflow.templates, "model_template.cpp") as emulator_template_path:
504
508
  emulator_out_path = model_dir / "emulator.cpp"
505
509
  self._render_template(emulator_template_path, emulator_out_path, context)
506
-
510
+
507
511
  with pkg_resources.path(triggerflow.templates, "makefile_version") as makefile_template_path:
508
512
  makefile_out_path = model_dir / "Makefile"
509
513
  self._render_template(makefile_template_path, makefile_out_path, {"MODEL_NAME": self.name})
@@ -511,44 +515,44 @@ class TriggerModel:
511
515
  with pkg_resources.path(triggerflow.templates, "makefile") as makefile_template_path:
512
516
  makefile_out_path = emulator_dir / "Makefile"
513
517
  self._render_template(makefile_template_path, makefile_out_path, {"MODEL_NAME": self.name})
514
-
515
-
518
+
519
+
516
520
  def save(self, path: str):
517
521
  """Save the complete model to an archive"""
518
522
  ModelSerializer.save(self.workspace_manager.workspace, path)
519
-
523
+
520
524
  @classmethod
521
525
  def load(cls, path: str) -> 'TriggerModel':
522
526
  """Load a model from an archive"""
523
527
  load_result = ModelSerializer.load(path)
524
528
  if load_result is None:
525
529
  return None
526
-
530
+
527
531
  workspace = load_result["workspace"]
528
532
  metadata = load_result["metadata"]
529
-
533
+
530
534
  obj = cls.__new__(cls)
531
535
  obj.workspace_manager = WorkspaceManager()
532
536
  obj.workspace_manager.workspace = workspace
533
537
  obj.workspace_manager.metadata = metadata
534
538
  obj.workspace_manager.artifacts = {"firmware": workspace / "firmware"}
535
-
539
+
536
540
  obj.name = metadata.get("name", "")
537
541
  obj.ml_backend = metadata.get("ml_backend")
538
542
  obj.compiler = metadata.get("compiler")
539
-
543
+
540
544
  obj.native_model = ModelSerializer.load_native_model(workspace, obj.ml_backend)
541
-
545
+
542
546
  obj.model_qonnx, obj.input_name = ModelSerializer.load_qonnx_model(workspace)
543
-
547
+
544
548
  if obj.compiler.lower() in ("hls4ml", "conifer"):
545
549
  obj.compiler_strategy = CompilerFactory.create_compiler(obj.ml_backend, obj.compiler)
546
550
  obj.firmware_model = obj.compiler_strategy.load_compiled_model(workspace)
547
551
  else:
548
552
  obj.firmware_model = None
549
553
  obj.compiler_strategy = None
550
-
554
+
551
555
  obj.converter = ConverterFactory.create_converter(obj.ml_backend, obj.compiler)
552
- obj.dataset_object = None
553
-
554
- return obj
556
+ obj.dataset_object = None
557
+
558
+ return obj