triggerflow 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: triggerflow
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: Utilities for ML models targeting hardware triggers
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: MIT License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "triggerflow"
7
- version = "0.1.4"
7
+ version = "0.1.6"
8
8
  description = "Utilities for ML models targeting hardware triggers"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -114,25 +114,32 @@ class HLS4MLStrategy(CompilerStrategy):
114
114
 
115
115
  class ConiferStrategy(CompilerStrategy):
116
116
  """Conifer compilation strategy for XGBoost models"""
117
-
118
- def compile(self, model, workspace: Path, config: Optional[Dict] = None) -> Any:
117
+
118
+ def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
119
119
  import conifer
120
-
120
+ import os
121
+ import shutil
122
+ import warnings
123
+
121
124
  firmware_dir = workspace / "firmware"
122
125
  firmware_dir.mkdir(exist_ok=True)
123
-
126
+
124
127
  cfg = config or conifer.backends.xilinxhls.auto_config()
128
+ cfg['OutputDir'] = str(firmware_dir)
129
+ cfg.update(kwargs)
130
+
125
131
  firmware_model = conifer.converters.convert_from_xgboost(
126
132
  model,
127
- config=cfg,
128
- output_dir=str(firmware_dir)
133
+ config=cfg
129
134
  )
135
+
130
136
  firmware_model.compile()
131
137
  if shutil.which("vivado") is not None:
132
- firmware_model.build()
138
+ firmware_model.build()
133
139
  else:
134
140
  warnings.warn("Vivado not found in PATH. Firmware build failed.", UserWarning)
135
- firmware_model.save(workspace / "firmware_model.fml")
141
+
142
+ firmware_model.save(firmware_dir / "firmware_model.fml")
136
143
  return firmware_model
137
144
 
138
145
  def load_compiled_model(self, workspace: Path) -> Any:
@@ -202,13 +209,18 @@ class QONNXPredictor(ModelPredictor):
202
209
  class FirmwarePredictor(ModelPredictor):
203
210
  """Firmware-based model predictor"""
204
211
 
205
- def __init__(self, firmware_model):
212
+ def __init__(self, firmware_model, compiler):
206
213
  if firmware_model is None:
207
214
  raise RuntimeError("Firmware model not built.")
208
215
  self.firmware_model = firmware_model
216
+ self.compiler = compiler
217
+
209
218
 
210
219
  def predict(self, input_data: np.ndarray) -> np.ndarray:
211
- return self.firmware_model.predict(input_data)
220
+ if self.compiler == "conifer":
221
+ return self.firmware_model.decision_function(input_data)
222
+ else:
223
+ return self.firmware_model.predict(input_data)
212
224
 
213
225
 
214
226
  class ConverterFactory:
@@ -363,8 +375,8 @@ class ModelSerializer:
363
375
  class TriggerModel:
364
376
  """Main facade class that orchestrates model conversion, compilation, and inference"""
365
377
 
366
- def __init__(self, name: str, ml_backend: str, scales: dict, n_outputs:int, compiler: str,
367
- native_model: object, dataset_object: object, compiler_config: dict = None):
378
+ def __init__(self, name: str, ml_backend: str, n_outputs:int, compiler: str,
379
+ native_model: object, compiler_config: dict = None, scales: dict = None):
368
380
 
369
381
  if ml_backend.lower() not in ("keras", "xgboost"):
370
382
  raise ValueError("Only Keras or XGBoost backends are currently supported.")
@@ -375,7 +387,6 @@ class TriggerModel:
375
387
  self.n_outputs = n_outputs
376
388
  self.compiler = compiler.lower()
377
389
  self.native_model = native_model
378
- self.dataset_object = dataset_object
379
390
  self.compiler_conifg = compiler_config
380
391
 
381
392
  self.workspace_manager = WorkspaceManager()
@@ -432,7 +443,8 @@ class TriggerModel:
432
443
 
433
444
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
434
445
 
435
- self.build_emulator(self.scales['shifts'], self.scales['offsets'], self.n_outputs)
446
+ if self.compiler is not "conifer" and self.scales is not None:
447
+ self.build_emulator(self.scales['shifts'], self.scales['offsets'], self.n_outputs)
436
448
 
437
449
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
438
450
  self.workspace_manager.save_metadata()
@@ -466,7 +478,7 @@ class TriggerModel:
466
478
 
467
479
  def firmware_predict(self, input_data: np.ndarray) -> np.ndarray:
468
480
  """Make predictions using firmware model"""
469
- predictor = FirmwarePredictor(self.firmware_model)
481
+ predictor = FirmwarePredictor(self.firmware_model, self.compiler)
470
482
  return predictor.predict(input_data)
471
483
 
472
484
  def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int):
@@ -5,7 +5,7 @@ import tempfile
5
5
  from pathlib import Path
6
6
  from typing import Dict, Any
7
7
  from mlflow.tracking import MlflowClient
8
- from core import TriggerModel
8
+ from .core import TriggerModel
9
9
 
10
10
 
11
11
  class MLflowWrapper(mlflow.pyfunc.PythonModel):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: triggerflow
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: Utilities for ML models targeting hardware triggers
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: MIT License
@@ -1,5 +1,6 @@
1
1
  import pytest
2
2
  from triggerflow.core import TriggerModel
3
+ from triggerflow.mlflow_wrapper import log_model
3
4
  import numpy as np
4
5
  from qkeras.qlayers import QDense, QActivation
5
6
  from qkeras.quantizers import quantized_bits
@@ -46,11 +47,10 @@ def test_predict():
46
47
  trigger_model = TriggerModel(
47
48
  name=name,
48
49
  ml_backend="Keras",
49
- scales=scales,
50
50
  n_outputs=int(1),
51
51
  compiler="hls4ml",
52
52
  native_model=dummy_model,
53
- dataset_object=None,
53
+ scales=scales,
54
54
  compiler_config=None
55
55
  )
56
56
  trigger_model(project_name = name+"_project", namespace = name, write_weights_txt = False)
File without changes
File without changes