triggerflow 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +154 -0
  6. trigger_loader/processor.py +212 -0
  7. triggerflow/__init__.py +0 -0
  8. triggerflow/cli.py +122 -0
  9. triggerflow/core.py +617 -0
  10. triggerflow/interfaces/__init__.py +0 -0
  11. triggerflow/interfaces/uGT.py +187 -0
  12. triggerflow/mlflow_wrapper.py +270 -0
  13. triggerflow/starter/.gitignore +143 -0
  14. triggerflow/starter/README.md +0 -0
  15. triggerflow/starter/cookiecutter.json +5 -0
  16. triggerflow/starter/prompts.yml +9 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +90 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/condor_config.json +11 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/cuda_config.json +4 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +24 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/settings.json +8 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/test.root +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +23 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_loader.py +101 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +49 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_loader.py +32 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +70 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +20 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +41 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +13 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +48 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +31 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  88. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  89. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  90. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  91. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  92. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  93. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  94. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  95. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  96. triggerflow/templates/build_ugt.tcl +46 -0
  97. triggerflow/templates/data_types.h +524 -0
  98. triggerflow/templates/makefile +28 -0
  99. triggerflow/templates/makefile_version +15 -0
  100. triggerflow/templates/model-gt.cpp +104 -0
  101. triggerflow/templates/model_template.cpp +63 -0
  102. triggerflow/templates/scales.h +20 -0
  103. triggerflow-0.3.4.dist-info/METADATA +206 -0
  104. triggerflow-0.3.4.dist-info/RECORD +107 -0
  105. triggerflow-0.3.4.dist-info/WHEEL +5 -0
  106. triggerflow-0.3.4.dist-info/entry_points.txt +2 -0
  107. triggerflow-0.3.4.dist-info/top_level.txt +3 -0
triggerflow/core.py ADDED
@@ -0,0 +1,617 @@
1
+ from pathlib import Path
2
+ import json
3
+ import yaml
4
+ import numpy as np
5
+ import tarfile
6
+ import importlib
7
+ from abc import ABC, abstractmethod
8
+ from typing import Optional, Dict, Any, Union
9
+ import shutil, warnings
10
+ import importlib.resources as pkg_resources
11
+ import triggerflow.templates
12
+ from importlib import import_module
13
+
14
+
15
class ModelConverter(ABC):
    """Interface for objects that turn a native model into an intermediate format."""

    @abstractmethod
    def convert(self, model, workspace: Path, **kwargs) -> Optional[Path]:
        """Convert ``model``, writing any output files below ``workspace``.

        Args:
            model: The native model object to convert.
            workspace: Directory in which converted files are placed.
            **kwargs: Converter-specific options.

        Returns:
            Whatever the concrete converter produces, or ``None`` when no
            conversion takes place.
        """
        return None
22
+
23
+
24
class CompilerStrategy(ABC):
    """Interface for back-ends that compile a model to firmware and reload it."""

    @abstractmethod
    def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
        """Compile ``model`` to firmware inside ``workspace``.

        Args:
            model: The native model object to compile.
            workspace: Directory that receives the build output.
            config: Optional compiler configuration mapping.
            **kwargs: Strategy-specific options.

        Returns:
            The compiled firmware model object.
        """
        return None

    @abstractmethod
    def load_compiled_model(self, workspace: Path) -> Any:
        """Return the firmware model previously compiled into ``workspace``."""
        return None
36
+
37
+
38
class ModelPredictor(ABC):
    """Interface for objects that run inference on NumPy input."""

    @abstractmethod
    def predict(self, input_data: np.ndarray) -> np.ndarray:
        """Return the model's predictions for ``input_data``."""
        return None
45
+
46
+
47
class KerasToQONNXConverter(ModelConverter):
    """Converts Keras models to QONNX format."""

    def convert(self, model, workspace: Path, **kwargs) -> tuple[Path, Any]:
        """Export ``model`` to QONNX and return the file path plus a cleaned graph.

        The exporter writes ``workspace / "model_qonnx.onnx"``. The graph it
        returns is then wrapped, cleaned, converted to channels-last, has its
        Gemm nodes rewritten to MatMul, and is cleaned once more. That cleaned
        graph is returned in memory only; this method does not write it back
        to disk.

        Args:
            model: Keras model; ``model.inputs[0].dtype`` is used for the
                export input signature.
            workspace: Directory that receives ``model_qonnx.onnx``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``(qonnx_path, cleaned_model)``. Callers unpack this pair, so the
            annotation reflects a tuple rather than a bare ``Path``.
        """
        # Heavy optional dependencies are imported lazily so that importing
        # this module does not require TensorFlow or QONNX.
        import tensorflow as tf
        from qonnx.converters import keras as keras_converter
        from qonnx.core.modelwrapper import ModelWrapper
        from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
        from qonnx.transformation.gemm_to_matmul import GemmToMatMul
        from qonnx.util.cleanup import cleanup_model

        qonnx_path = workspace / "model_qonnx.onnx"
        input_signature = [tf.TensorSpec(1, model.inputs[0].dtype, name="input_0")]
        qonnx_model, _ = keras_converter.from_keras(model, input_signature, output_path=qonnx_path)

        qonnx_model = cleanup_model(ModelWrapper(qonnx_model))
        qonnx_model = qonnx_model.transform(ConvertToChannelsLastAndClean())
        qonnx_model = qonnx_model.transform(GemmToMatMul())
        cleaned_model = cleanup_model(qonnx_model)

        return qonnx_path, cleaned_model
68
+
69
+
70
class NoOpConverter(ModelConverter):
    """Converter used when the model needs no intermediate representation."""

    def convert(self, model, workspace: Path, **kwargs) -> Optional[Path]:
        """Ignore all arguments and report that nothing was converted."""
        # ``None`` is the signal callers check to skip the conversion step.
        return None
75
+
76
+
77
class HLS4MLStrategy(CompilerStrategy):
    """Compilation strategy that builds firmware from a Keras model with hls4ml."""

    # Top-level compiler-config keys passed straight through to the hls4ml converter.
    _FORWARDED_KEYS = ("project_name", "namespace", "io_type", "backend", "write_weights_txt")

    def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
        """Convert ``model`` with hls4ml into ``workspace / "firmware"`` and compile it.

        Args:
            model: Keras model to convert.
            workspace: Existing workspace directory; a ``firmware``
                sub-directory is created inside it if missing.
            config: Optional compiler configuration. Keys listed in
                ``_FORWARDED_KEYS`` are forwarded to the converter. A
                ``"Model"`` mapping overrides the auto-generated hls4ml
                config: dict values are applied per layer (only for layers
                hls4ml already knows), other values are applied model-wide.
            **kwargs: Extra keyword arguments forwarded to
                ``hls4ml.converters.convert_from_keras_model``. They take
                precedence over values taken from ``config``.

        Returns:
            The compiled hls4ml model.
        """
        import hls4ml

        # ``config`` defaults to None; treat that as "no overrides" instead of
        # failing on the membership tests below.
        config = config or {}

        firmware_dir = workspace / "firmware"
        firmware_dir.mkdir(exist_ok=True)

        hls_config = hls4ml.utils.config_from_keras_model(model, granularity="name")

        hls_kwargs = {key: config[key] for key in self._FORWARDED_KEYS if key in config}
        hls_kwargs.update(kwargs)

        for key, value in (config.get("Model") or {}).items():
            if isinstance(value, dict):
                # Per-layer override: {setting: {layer_name: layer_value}}.
                for layer, layer_config in value.items():
                    if layer in hls_config["LayerName"]:
                        hls_config["LayerName"][layer][key] = layer_config
            else:
                hls_config["Model"][key] = value

        firmware_model = hls4ml.converters.convert_from_keras_model(
            model,
            hls_config=hls_config,
            output_dir=str(firmware_dir),
            **hls_kwargs,
        )

        firmware_model.compile()
        return firmware_model

    def load_compiled_model(self, workspace: Path) -> Any:
        """Link the existing hls4ml project in ``workspace / "firmware"`` and compile it."""
        from hls4ml.converters import link_existing_project

        firmware_model = link_existing_project(workspace / "firmware")
        firmware_model.compile()
        return firmware_model
117
+
118
+
119
class ConiferStrategy(CompilerStrategy):
    """Conifer compilation strategy for XGBoost models, unified config/workspace handling."""

    # File name used for the serialized conifer model inside the firmware directory.
    _MODEL_FILE = "firmware_model.fml"

    def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
        """Convert an XGBoost ``model`` with conifer, compile it and save it.

        Args:
            model: XGBoost model to convert.
            workspace: Existing workspace directory; a ``firmware``
                sub-directory is created inside it if missing.
            config: Compiler configuration. Must provide ``project_name``,
                ``fpga_part``, ``clock_period`` and ``Precision``. Every key
                in it is additionally copied verbatim into the conifer config.
            **kwargs: Extra conifer config entries, applied last so they
                override values coming from ``config``.

        Returns:
            The compiled conifer model.

        Raises:
            ValueError: If ``config`` is None.
            KeyError: If one of the required keys is missing from ``config``.
        """
        import conifer

        if config is None:
            raise ValueError(
                "ConiferStrategy.compile requires a config with "
                "'project_name', 'fpga_part', 'clock_period' and 'Precision'"
            )

        firmware_dir = workspace / "firmware"
        firmware_dir.mkdir(exist_ok=True)

        cfg = conifer.backends.xilinxhls.auto_config()
        cfg['OutputDir'] = str(firmware_dir)
        cfg['ProjectName'] = config['project_name']
        cfg['XilinxPart'] = config['fpga_part']
        cfg['ClockPeriod'] = config['clock_period']
        cfg['Precision'] = config['Precision']

        # Raw config keys are copied as well (original behaviour), followed by
        # explicit keyword overrides.
        for key, value in config.items():
            cfg[key] = value
        for key, value in kwargs.items():
            cfg[key] = value

        firmware_model = conifer.converters.convert_from_xgboost(model, config=cfg)
        firmware_model.compile()
        firmware_model.save(firmware_dir / self._MODEL_FILE)

        return firmware_model

    def load_compiled_model(self, workspace: Path) -> Any:
        """Reload and compile the conifer model saved by :meth:`compile`.

        ``compile`` writes the model to ``workspace / "firmware"``, so that
        location is tried first. The workspace root is kept as a fallback for
        archives laid out the way this method originally expected.
        """
        from conifer import load_model

        model_path = workspace / "firmware" / self._MODEL_FILE
        if not model_path.exists():
            model_path = workspace / self._MODEL_FILE

        firmware_model = load_model(model_path)
        firmware_model.compile()
        return firmware_model
153
+
154
+
155
class DA4MLStrategy(CompilerStrategy):
    """DA4ML compilation strategy (placeholder)."""

    def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
        """Not implemented yet; always raises ``NotImplementedError``.

        ``**kwargs`` is accepted to match the abstract signature, so callers
        that pass extra options get the intended error rather than a
        ``TypeError`` about unexpected keyword arguments.
        """
        raise NotImplementedError("DA4ML conversion without QONNX not yet implemented")

    def load_compiled_model(self, workspace: Path) -> Any:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("DA4ML loading not yet implemented")
163
+
164
+
165
class FINNStrategy(CompilerStrategy):
    """FINN compilation strategy (placeholder)."""

    def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
        """Not implemented yet; always raises ``NotImplementedError``.

        ``**kwargs`` is accepted to match the abstract signature, so callers
        that pass extra options get the intended error rather than a
        ``TypeError`` about unexpected keyword arguments.
        """
        raise NotImplementedError("FINN conversion without QONNX not yet implemented")

    def load_compiled_model(self, workspace: Path) -> Any:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("FINN loading not yet implemented")
173
+
174
+
175
class SoftwarePredictor(ModelPredictor):
    """Predictor that delegates to the native (software) model."""

    def __init__(self, model, backend: str):
        """Store ``model`` together with the lower-cased ``backend`` name."""
        self.backend = backend.lower()
        self.model = model

    def predict(self, input_data):
        """Run ``model.predict``; a 1-D input is first promoted to a one-row batch."""
        needs_batch_axis = input_data.ndim == 1
        batch = np.expand_dims(input_data, axis=0) if needs_batch_axis else input_data
        return self.model.predict(batch)
186
+
187
+
188
class QONNXPredictor(ModelPredictor):
    """Predictor that executes a QONNX graph one sample at a time."""

    def __init__(self, qonnx_model, input_name: str):
        """Remember the graph and the name of the input tensor to feed."""
        self.input_name = input_name
        self.qonnx_model = qonnx_model

    def predict(self, input_data: np.ndarray) -> np.ndarray:
        """Execute the graph for every row of ``input_data`` and stack the results.

        A 1-D input is treated as a single sample. Each row is cast to
        ``float32`` and reshaped to ``(1, -1)`` before execution; the
        ``"global_out"`` entry of each result is collected.
        """
        from qonnx.core.onnx_exec import execute_onnx

        batch = np.asarray(input_data)
        if batch.ndim == 1:
            batch = np.expand_dims(batch, axis=0)

        def run_row(index):
            row = batch[index].astype("float32").reshape(1, -1)
            result = execute_onnx(self.qonnx_model, {self.input_name: row})
            return result["global_out"]

        return np.vstack([run_row(index) for index in range(batch.shape[0])])
209
+
210
+
211
class FirmwarePredictor(ModelPredictor):
    """Predictor backed by a compiled firmware model."""

    def __init__(self, firmware_model, compiler):
        """Keep the firmware model and the compiler name.

        Raises:
            RuntimeError: If ``firmware_model`` is ``None``.
        """
        if firmware_model is None:
            raise RuntimeError("Firmware model not built.")
        self.compiler = compiler
        self.firmware_model = firmware_model

    def predict(self, input_data: np.ndarray) -> np.ndarray:
        """Evaluate the firmware model on ``input_data``.

        Models built with ``"conifer"`` are evaluated through
        ``decision_function``; every other compiler uses ``predict``.
        """
        if self.compiler != "conifer":
            return self.firmware_model.predict(input_data)
        return self.firmware_model.decision_function(input_data)
226
+
227
+
228
class ConverterFactory:
    """Chooses the model converter for a backend/compiler pair."""

    @staticmethod
    def create_converter(ml_backend: str, compiler: str) -> ModelConverter:
        """Return the converter matching ``ml_backend`` and ``compiler``.

        A QONNX converter is returned only for Keras + hls4ml when the
        installed ``keras`` version does not start with ``"3"``; every other
        case gets a no-op converter.
        """
        # Guard clause keeps the original short-circuit: ``compiler`` is only
        # inspected when the backend is Keras.
        if ml_backend.lower() != "keras" or compiler.lower() != "hls4ml":
            return NoOpConverter()

        import keras

        if keras.__version__.startswith("3"):
            return NoOpConverter()
        return KerasToQONNXConverter()
238
+
239
+
240
class CompilerFactory:
    """Chooses the compilation strategy for a backend/compiler pair."""

    @staticmethod
    def create_compiler(ml_backend: str, compiler: str) -> CompilerStrategy:
        """Return a new strategy instance for ``ml_backend`` and ``compiler``.

        Both names are compared case-insensitively. hls4ml requires the Keras
        backend and conifer requires XGBoost; da4ml and finn are selected by
        compiler name alone.

        Raises:
            RuntimeError: If the combination is not supported.
        """
        backend = ml_backend.lower()
        comp = compiler.lower()

        # Strategies tied to a specific (backend, compiler) pair.
        by_pair = {
            ("keras", "hls4ml"): HLS4MLStrategy,
            ("xgboost", "conifer"): ConiferStrategy,
        }
        # Strategies selected by compiler name regardless of backend.
        by_compiler = {
            "da4ml": DA4MLStrategy,
            "finn": FINNStrategy,
        }

        strategy_cls = by_pair.get((backend, comp))
        if strategy_cls is None:
            strategy_cls = by_compiler.get(comp)
        if strategy_cls is None:
            raise RuntimeError(f"Unsupported combination: ml_backend={backend}, compiler={comp}")
        return strategy_cls()
258
+
259
+
260
class WorkspaceManager:
    """Manages the workspace directory, tracked artifacts and metadata."""

    # Evaluated once, when the class body runs: the working directory at
    # import time determines the default workspace location.
    BASE_WORKSPACE = Path.cwd() / "triggermodel"

    # Metadata keys written to ``metadata.json``.
    _PERSISTED_KEYS = ("name", "ml_backend", "compiler")

    def __init__(self):
        """Start with the default workspace path and empty bookkeeping."""
        self.workspace = self.BASE_WORKSPACE
        self.artifacts = {"firmware": None}
        self.metadata = {
            "name": None,
            "ml_backend": None,
            "compiler": None,
            "versions": [],
        }

    def setup_workspace(self, name: str, ml_backend: str, compiler: str):
        """Create the workspace directory and record the model identity."""
        self.workspace.mkdir(parents=True, exist_ok=True)
        self.metadata.update({
            "name": name,
            "ml_backend": ml_backend,
            "compiler": compiler,
        })

    def save_native_model(self, model, ml_backend: str):
        """Save the native model into the workspace.

        Keras models go to ``keras_model.h5`` and XGBoost models to
        ``xgb_model.json``. Any other backend is silently ignored.
        """
        backend = ml_backend.lower()
        if backend == "keras":
            model.save(self.workspace / "keras_model.h5")
        elif backend == "xgboost":
            model.save_model(str(self.workspace / "xgb_model.json"))

    def add_artifact(self, key: str, value: Any):
        """Record ``value`` under ``key`` in the artifact table."""
        self.artifacts[key] = value

    def add_version(self, version_info: Dict):
        """Append ``version_info`` to the metadata's version list.

        The list is created on demand: ``metadata`` may have been replaced by
        the dict read from ``metadata.json``, which has no ``"versions"`` key.
        """
        self.metadata.setdefault("versions", []).append(version_info)

    def save_metadata(self):
        """Write name, backend and compiler to ``metadata.json`` in the workspace.

        Only the keys in ``_PERSISTED_KEYS`` are written (the version list is
        not); keys missing from ``metadata`` are stored as ``null``.
        """
        persisted = {key: self.metadata.get(key) for key in self._PERSISTED_KEYS}
        with open(self.workspace / "metadata.json", "w") as f:
            json.dump(persisted, f, indent=2)
307
+
308
+
309
class ModelSerializer:
    """Handles model serialization and deserialization."""

    @staticmethod
    def save(workspace: Path, path: str):
        """Serialize the workspace directory into a tar.xz archive at ``path``.

        Parent directories of ``path`` are created as needed. The workspace is
        stored under its own directory name as the single archive root.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with tarfile.open(path, mode="w:xz") as tar:
            tar.add(workspace, arcname=workspace.name)
        print(f"TriggerModel saved to {path}")

    @staticmethod
    def load(path: str) -> Optional[Dict[str, Any]]:
        """Extract a tar.xz archive into the current directory and read its metadata.

        The archive is expected to contain a top-level ``triggermodel``
        directory. If ``./triggermodel`` already exists the user is asked
        interactively whether to overwrite it.

        Returns:
            ``{"workspace": Path, "metadata": dict}``, or ``None`` when the
            user declines to overwrite an existing workspace.

        Raises:
            FileNotFoundError: If ``path`` does not exist (or the extracted
                archive contains no ``metadata.json``).
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"{path} does not exist")

        destination = Path.cwd()
        workspace = destination / "triggermodel"

        if workspace.exists():
            response = input(f"{workspace} already exists. Overwrite? [y/N]: ").strip().lower()
            if response != "y":
                print("Load cancelled by user.")
                return None
            shutil.rmtree(workspace)

        with tarfile.open(path, mode="r:xz") as tar:
            # The archive comes from outside this process. Where the
            # interpreter supports extraction filters, use the "data" filter
            # so members cannot escape ``destination``.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=destination, filter="data")
            else:
                tar.extractall(path=destination)

        # Load metadata
        metadata_path = workspace / "metadata.json"
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        return {
            "workspace": workspace,
            "metadata": metadata,
        }

    @staticmethod
    def load_native_model(workspace: Path, ml_backend: str):
        """Load the native model stored in ``workspace``.

        For Keras/QKeras, ``keras_model.h5`` is loaded with plain Keras first;
        if that fails, loading is retried with QKeras' quantized custom
        objects. If both attempts fail a message is printed and ``None`` is
        returned (best effort). For XGBoost, ``xgb_model.json`` is loaded into
        a ``Booster``.

        Raises:
            ValueError: If ``ml_backend`` is not a supported backend name.
        """
        backend = (ml_backend or "").lower()

        if backend in ("keras", "qkeras"):
            try:
                keras_models = importlib.import_module("keras.models")
            except ModuleNotFoundError:
                keras_models = importlib.import_module("tensorflow.keras.models")

            model_file = workspace / "keras_model.h5"
            try:
                return keras_models.load_model(model_file)
            except Exception:
                # Plain load failed; the model may contain QKeras layers.
                pass

            try:
                from qkeras.utils import _add_supported_quantized_objects

                custom_objects = {}
                _add_supported_quantized_objects(custom_objects)
                return keras_models.load_model(model_file, custom_objects=custom_objects)
            except Exception as err:
                print(f"Native model could not be loaded: {err}")
                return None

        if backend == "xgboost":
            import xgboost as xgb

            model = xgb.Booster()
            model.load_model(str(workspace / "xgb_model.json"))
            return model

        raise ValueError(f"Unsupported ml_backend: {ml_backend}")

    @staticmethod
    def load_qonnx_model(workspace: Path):
        """Load ``model_qonnx.onnx`` from ``workspace`` if it exists.

        Returns:
            ``(model, input_name)`` where ``input_name`` is the name of the
            graph's first input, or ``(None, None)`` when the file is absent.
        """
        qonnx_path = workspace / "model_qonnx.onnx"
        if not qonnx_path.exists():
            return None, None

        from qonnx.core.modelwrapper import ModelWrapper

        model = ModelWrapper(str(qonnx_path))
        input_name = model.graph.input[0].name
        return model, input_name
385
+
386
+
387
+ class TriggerModel:
388
+ def __init__(self, config: Union[str, Path, Dict], native_model, scales):
389
+ if isinstance(config, (str, Path)):
390
+ with open(config, "r") as f:
391
+ config = yaml.safe_load(f)
392
+ elif not isinstance(config, dict):
393
+ raise TypeError("config must be a dict or path to a YAML file")
394
+
395
+ self.native_model = native_model
396
+ self.scales = scales
397
+
398
+ self.compiler_cfg = config.get("compiler", {})
399
+ self.subsystem_cfg = config.get("subsystem", {})
400
+
401
+ self.name = self.compiler_cfg.get("name", "model")
402
+ self.ml_backend = self.compiler_cfg.get("ml_backend", "").lower()
403
+ self.compiler = self.compiler_cfg.get("compiler", "").lower()
404
+
405
+ self.n_outputs = self.compiler_cfg.get("n_outputs")
406
+ self.unscaled_type = self.subsystem_cfg.get("unscaled_type", "ap_fixed<16,6>")
407
+ self.namespace = self.compiler_cfg.get("namespace", "triggerflow")
408
+ self.project_name = self.compiler_cfg.get("project_name", "triggerflow")
409
+
410
+ if self.ml_backend not in ("keras", "xgboost"):
411
+ raise ValueError("Unsupported backend")
412
+
413
+ self.workspace_manager = WorkspaceManager()
414
+ self.converter = ConverterFactory.create_converter(self.ml_backend, self.compiler)
415
+ self.compiler_strategy = CompilerFactory.create_compiler(self.ml_backend, self.compiler)
416
+
417
+ self.firmware_model = None
418
+ self.model_qonnx = None
419
+ self.input_name = None
420
+
421
+
422
+ self.workspace_manager.setup_workspace(
423
+ self.name,
424
+ self.ml_backend,
425
+ self.compiler
426
+ )
427
+
428
+ @property
429
+ def workspace(self) -> Path:
430
+ """Get workspace path"""
431
+ return self.workspace_manager.workspace
432
+
433
+ @property
434
+ def artifacts(self) -> Dict[str, Any]:
435
+ """Get artifacts dictionary"""
436
+ return self.workspace_manager.artifacts
437
+
438
+ @property
439
+ def metadata(self) -> Dict[str, Any]:
440
+ """Get metadata dictionary"""
441
+ return self.workspace_manager.metadata
442
+
443
+ def __call__(self):
444
+ """Execute full model conversion and compilation pipeline using YAML config"""
445
+ self.parse_dataset_object()
446
+
447
+ # Save native model
448
+ self.workspace_manager.save_native_model(self.native_model, self.ml_backend)
449
+
450
+ # Convert model if needed
451
+ conversion_result = self.converter.convert(
452
+ self.native_model,
453
+ self.workspace_manager.workspace
454
+ )
455
+
456
+ if conversion_result is not None:
457
+ qonnx_path, self.model_qonnx = conversion_result
458
+ self.input_name = self.model_qonnx.graph.input[0].name
459
+ self.workspace_manager.add_artifact("qonnx", qonnx_path)
460
+ self.workspace_manager.add_version({"qonnx": str(qonnx_path)})
461
+
462
+
463
+ # Compile model
464
+ self.firmware_model = self.compiler_strategy.compile(
465
+ self.native_model,
466
+ self.workspace_manager.workspace,
467
+ self.compiler_cfg,
468
+ **self.compiler_cfg.get("kwargs", {})
469
+ )
470
+
471
+ self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
472
+ if self.compiler != "conifer" and self.scales is not None:
473
+ self.build_emulator(
474
+ self.scales['shifts'],
475
+ self.scales['offsets'],
476
+ self.n_outputs,
477
+ self.unscaled_type
478
+ )
479
+
480
+ subsystem_name = self.subsystem_cfg.get("name", "uGT")
481
+ interface_module = import_module(f"triggerflow.interfaces.{subsystem_name}")
482
+ build_firmware = getattr(interface_module, f"build_{subsystem_name.lower()}_model")
483
+
484
+ build_firmware(
485
+ subsystem_cfg=self.subsystem_cfg,
486
+ compiler_cfg=self.compiler_cfg,
487
+ workspace_manager=self.workspace_manager,
488
+ compiler=self.compiler,
489
+ scales=self.scales,
490
+ name=self.name,
491
+ n_outputs=self.n_outputs
492
+ )
493
+
494
+
495
+
496
+ self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
497
+ self.workspace_manager.save_metadata()
498
+
499
+
500
+ @staticmethod
501
+ def parse_dataset_object():
502
+ """Parse dataset object (placeholder)"""
503
+ pass
504
+
505
+ @staticmethod
506
+ def _render_template(template_path: Path, out_path: Path, context: dict):
507
+ """Simple template substitution"""
508
+ with open(template_path) as f:
509
+ template = f.read()
510
+ for k, v in context.items():
511
+ template = template.replace("{{" + k + "}}", str(v))
512
+ with open(out_path, "w") as f:
513
+ f.write(template)
514
+
515
+ def software_predict(self, input_data: np.ndarray) -> np.ndarray:
516
+ """Make predictions using software model"""
517
+ predictor = SoftwarePredictor(self.native_model, self.ml_backend)
518
+ return predictor.predict(input_data)
519
+
520
+ def qonnx_predict(self, input_data: np.ndarray) -> np.ndarray | None:
521
+ """Make predictions using QONNX model"""
522
+
523
+ if self.model_qonnx is None:
524
+ warnings.warn(
525
+ "QONNX model is not available. Prediction skipped.",
526
+ UserWarning
527
+ )
528
+ return None
529
+
530
+ predictor = QONNXPredictor(self.model_qonnx, self.input_name)
531
+ return predictor.predict(input_data)
532
+
533
+ def firmware_predict(self, input_data: np.ndarray) -> np.ndarray:
534
+ """Make predictions using firmware model"""
535
+ predictor = FirmwarePredictor(self.firmware_model, self.compiler)
536
+ return predictor.predict(input_data)
537
+
538
+ def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int, unscaled_type: str = "ap_fixed<16,6>"):
539
+ """Builds CMSSW emulator"""
540
+
541
+ emulator_dir = self.workspace / "emulator"
542
+ emulator_dir.mkdir(exist_ok=True)
543
+
544
+ model_dir = emulator_dir / self.name
545
+ model_dir.mkdir(exist_ok=True)
546
+
547
+ firmware_dir = self.workspace / "firmware" / "firmware"
548
+
549
+ shutil.copytree(firmware_dir, f"{model_dir}/NN", dirs_exist_ok=True)
550
+
551
+ # Access scales template from installed package
552
+ with pkg_resources.path(triggerflow.templates, "scales.h") as scales_template_path:
553
+ scales_out_path = model_dir / "scales.h"
554
+ context = {
555
+ "MODEL_NAME": self.name,
556
+ "N_INPUTS": len(ad_shift),
557
+ "N_OUTPUTS": n_outputs,
558
+ "AD_SHIFT": ", ".join(map(str, ad_shift)),
559
+ "AD_OFFSETS": ", ".join(map(str, ad_offsets)),
560
+ "UNSCALED_TYPE": unscaled_type,
561
+ "NAMESPACE": self.namespace,
562
+ "PROJECT_NAME": self.project_name,
563
+ }
564
+ self._render_template(scales_template_path, scales_out_path, context)
565
+
566
+ with pkg_resources.path(triggerflow.templates, "model_template.cpp") as emulator_template_path:
567
+ emulator_out_path = model_dir / "emulator.cpp"
568
+ self._render_template(emulator_template_path, emulator_out_path, context)
569
+
570
+ with pkg_resources.path(triggerflow.templates, "makefile_version") as makefile_template_path:
571
+ makefile_out_path = model_dir / "Makefile"
572
+ self._render_template(makefile_template_path, makefile_out_path, {"MODEL_NAME": self.name})
573
+
574
+ with pkg_resources.path(triggerflow.templates, "makefile") as makefile_template_path:
575
+ makefile_out_path = emulator_dir / "Makefile"
576
+ self._render_template(makefile_template_path, makefile_out_path, {"MODEL_NAME": self.name})
577
+
578
+
579
+ def save(self, path: str):
580
+ """Save the complete model to an archive"""
581
+ ModelSerializer.save(self.workspace_manager.workspace, path)
582
+
583
+ @classmethod
584
+ def load(cls, path: str) -> 'TriggerModel':
585
+ """Load a model from an archive"""
586
+ load_result = ModelSerializer.load(path)
587
+ if load_result is None:
588
+ return None
589
+
590
+ workspace = load_result["workspace"]
591
+ metadata = load_result["metadata"]
592
+
593
+ obj = cls.__new__(cls)
594
+ obj.workspace_manager = WorkspaceManager()
595
+ obj.workspace_manager.workspace = workspace
596
+ obj.workspace_manager.metadata = metadata
597
+ obj.workspace_manager.artifacts = {"firmware": workspace / "firmware"}
598
+
599
+ obj.name = metadata.get("name", "")
600
+ obj.ml_backend = metadata.get("ml_backend")
601
+ obj.compiler = metadata.get("compiler")
602
+
603
+ obj.native_model = ModelSerializer.load_native_model(workspace, obj.ml_backend)
604
+
605
+ obj.model_qonnx, obj.input_name = ModelSerializer.load_qonnx_model(workspace)
606
+
607
+ if obj.compiler.lower() in ("hls4ml", "conifer"):
608
+ obj.compiler_strategy = CompilerFactory.create_compiler(obj.ml_backend, obj.compiler)
609
+ obj.firmware_model = obj.compiler_strategy.load_compiled_model(workspace)
610
+ else:
611
+ obj.firmware_model = None
612
+ obj.compiler_strategy = None
613
+
614
+ obj.converter = ConverterFactory.create_converter(obj.ml_backend, obj.compiler)
615
+ obj.dataset_object = None
616
+
617
+ return obj
File without changes