triggerflow 0.1.4__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +147 -0
  6. trigger_loader/processor.py +211 -0
  7. triggerflow/cli.py +122 -0
  8. triggerflow/core.py +127 -69
  9. triggerflow/interfaces/__init__.py +0 -0
  10. triggerflow/interfaces/uGT.py +127 -0
  11. triggerflow/mlflow_wrapper.py +190 -19
  12. triggerflow/starter/.gitignore +143 -0
  13. triggerflow/starter/README.md +0 -0
  14. triggerflow/starter/cookiecutter.json +5 -0
  15. triggerflow/starter/prompts.yml +9 -0
  16. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +84 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +15 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples_dummy.json +26 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +21 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +88 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +50 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +10 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +40 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +12 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +31 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +29 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  88. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  89. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  90. triggerflow/templates/build_ugt.tcl +46 -0
  91. triggerflow/templates/data_types.h +524 -0
  92. triggerflow/templates/makefile +3 -3
  93. triggerflow/templates/makefile_version +2 -2
  94. triggerflow/templates/model-gt.cpp +104 -0
  95. triggerflow/templates/model_template.cpp +19 -18
  96. triggerflow/templates/scales.h +1 -1
  97. triggerflow-0.2.4.dist-info/METADATA +192 -0
  98. triggerflow-0.2.4.dist-info/RECORD +102 -0
  99. triggerflow-0.2.4.dist-info/entry_points.txt +2 -0
  100. triggerflow-0.2.4.dist-info/top_level.txt +3 -0
  101. triggerflow-0.1.4.dist-info/METADATA +0 -61
  102. triggerflow-0.1.4.dist-info/RECORD +0 -11
  103. triggerflow-0.1.4.dist-info/top_level.txt +0 -1
  104. {triggerflow-0.1.4.dist-info → triggerflow-0.2.4.dist-info}/WHEEL +0 -0
triggerflow/core.py CHANGED
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import yaml
 import numpy as np
 import tarfile
 import importlib
@@ -8,6 +9,7 @@ from typing import Optional, Dict, Any, Union
 import shutil, warnings
 import importlib.resources as pkg_resources
 import triggerflow.templates
+from triggerflow.interfaces.uGT import build_ugt_model


 class ModelConverter(ABC):
@@ -73,37 +75,31 @@ class NoOpConverter(ModelConverter):


 class HLS4MLStrategy(CompilerStrategy):
-    """HLS4ML compilation strategy for Keras models"""
-
-    def compile(self, model, workspace: Path, config: Optional[Dict] = None, **kwargs) -> Any:
+    def compile(self, model, workspace: Path, config: Optional[Dict] = None) -> Any:
         import hls4ml
-
+
         firmware_dir = workspace / "firmware"
         firmware_dir.mkdir(exist_ok=True)
-
-        cfg = config or hls4ml.utils.config_from_keras_model(model, granularity="name")

-        hls_kwargs = {
-            "hls_config": cfg,
-            "output_dir": str(firmware_dir),
-            "io_type": "io_stream",
-            "backend": "Vitis"
-        }
-        hls_kwargs.update(kwargs)
+        hls_config = hls4ml.utils.config_from_keras_model(model, granularity="name")
+        hls_kwargs = {}
+
+        for key in ["project_name", "namespace", "io_type", "backend", "write_weights_txt"]:
+            if key in config:
+                hls_kwargs[key] = config[key]

         firmware_model = hls4ml.converters.convert_from_keras_model(
             model,
+            hls_config=hls_config,
+            output_dir=str(firmware_dir),
             **hls_kwargs
         )

         firmware_model.compile()
-        if shutil.which("vivado") is not None:
-            firmware_model.build()
-        else:
-            warnings.warn("Vivado not found in PATH. Firmware build failed.", UserWarning)
         firmware_model.save(workspace / "firmware_model.fml")
         return firmware_model
-
+
+
     def load_compiled_model(self, workspace: Path) -> Any:
         from hls4ml.converters import link_existing_project
@@ -113,27 +109,31 @@ class HLS4MLStrategy(CompilerStrategy):


 class ConiferStrategy(CompilerStrategy):
-    """Conifer compilation strategy for XGBoost models"""
-
+    """Conifer compilation strategy for XGBoost models, unified config/workspace handling."""
+
     def compile(self, model, workspace: Path, config: Optional[Dict] = None) -> Any:
         import conifer
-
+        import os
+
         firmware_dir = workspace / "firmware"
         firmware_dir.mkdir(exist_ok=True)
-
-        cfg = config or conifer.backends.xilinxhls.auto_config()
-        firmware_model = conifer.converters.convert_from_xgboost(
-            model,
-            config=cfg,
-            output_dir=str(firmware_dir)
-        )
+
+        cfg = conifer.backends.xilinxhls.auto_config()
+        cfg['OutputDir'] = str(firmware_dir)
+        cfg['ProjectName'] = config['project_name']
+        cfg['XilinxPart'] = config['fpga_part']
+        cfg['ClockPeriod'] = config['clock_period']
+
+        if config:
+            for key, value in config.items():
+                cfg[key] = value
+
+        firmware_model = conifer.converters.convert_from_xgboost(model, config=cfg)
         firmware_model.compile()
-        if shutil.which("vivado") is not None:
-            firmware_model.build()
-        else:
-            warnings.warn("Vivado not found in PATH. Firmware build failed.", UserWarning)
-        firmware_model.save(workspace / "firmware_model.fml")
+        firmware_model.save(firmware_dir / "firmware_model.fml")
+
         return firmware_model
+

     def load_compiled_model(self, workspace: Path) -> Any:
         from conifer import load_model
@@ -202,13 +202,18 @@ class QONNXPredictor(ModelPredictor):
 class FirmwarePredictor(ModelPredictor):
     """Firmware-based model predictor"""

-    def __init__(self, firmware_model):
+    def __init__(self, firmware_model, compiler):
         if firmware_model is None:
             raise RuntimeError("Firmware model not built.")
         self.firmware_model = firmware_model
+        self.compiler = compiler
+

     def predict(self, input_data: np.ndarray) -> np.ndarray:
-        return self.firmware_model.predict(input_data)
+        if self.compiler == "conifer":
+            return self.firmware_model.decision_function(input_data)
+        else:
+            return self.firmware_model.predict(input_data)


 class ConverterFactory:
@@ -360,33 +365,45 @@ class ModelSerializer:
             return model, input_name
         return None, None

+
 class TriggerModel:
-    """Main facade class that orchestrates model conversion, compilation, and inference"""
-
-    def __init__(self, name: str, ml_backend: str, scales: dict, n_outputs:int, compiler: str,
-                 native_model: object, dataset_object: object, compiler_config: dict = None):
-
-        if ml_backend.lower() not in ("keras", "xgboost"):
-            raise ValueError("Only Keras or XGBoost backends are currently supported.")
-
-        self.name = name
-        self.ml_backend = ml_backend.lower()
-        self.scales = scales
-        self.n_outputs = n_outputs
-        self.compiler = compiler.lower()
+    def __init__(self, config: Union[str, Path, Dict], native_model, scales):
+        if isinstance(config, (str, Path)):
+            with open(config, "r") as f:
+                config = yaml.safe_load(f)
+        elif not isinstance(config, dict):
+            raise TypeError("config must be a dict or path to a YAML file")
+
         self.native_model = native_model
-        self.dataset_object = dataset_object
-        self.compiler_conifg = compiler_config
-
+        self.scales = scales
+
+        self.compiler_cfg = config.get("compiler", {})
+        self.subsystem_cfg = config.get("subsystem", {})
+
+        self.name = self.compiler_cfg.get("name", "model")
+        self.ml_backend = self.compiler_cfg.get("ml_backend", "").lower()
+        self.compiler = self.compiler_cfg.get("compiler", "").lower()
+
+        self.n_outputs = self.compiler_cfg.get("n_outputs")
+        self.unscaled_type = self.subsystem_cfg.get("unscaled_type", "ap_fixed<16,6>")
+
+        if self.ml_backend not in ("keras", "xgboost"):
+            raise ValueError("Unsupported backend")
+
         self.workspace_manager = WorkspaceManager()
-        self.converter = ConverterFactory.create_converter(ml_backend, compiler)
-        self.compiler_strategy = CompilerFactory.create_compiler(ml_backend, compiler)
-
+        self.converter = ConverterFactory.create_converter(self.ml_backend, self.compiler)
+        self.compiler_strategy = CompilerFactory.create_compiler(self.ml_backend, self.compiler)
+
         self.firmware_model = None
         self.model_qonnx = None
         self.input_name = None

-        self.workspace_manager.setup_workspace(name, self.ml_backend, self.compiler)
+
+        self.workspace_manager.setup_workspace(
+            self.name,
+            self.ml_backend,
+            self.compiler
+        )

     @property
     def workspace(self) -> Path:
@@ -403,8 +420,8 @@ class TriggerModel:
         """Get metadata dictionary"""
         return self.workspace_manager.metadata

-    def __call__(self, **compiler_kwargs):
-        """Execute the full model conversion and compilation pipeline"""
+    def __call__(self):
+        """Execute full model conversion and compilation pipeline using YAML config"""
         self.parse_dataset_object()

         # Save native model
@@ -421,21 +438,57 @@ class TriggerModel:
         self.input_name = self.model_qonnx.graph.input[0].name
         self.workspace_manager.add_artifact("qonnx", qonnx_path)
         self.workspace_manager.add_version({"qonnx": str(qonnx_path)})
+

         # Compile model
         self.firmware_model = self.compiler_strategy.compile(
             self.native_model,
             self.workspace_manager.workspace,
-            self.compiler_conifg,
-            **compiler_kwargs
+            self.compiler_cfg,
+            **self.compiler_cfg.get("kwargs", {})
         )

         self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
+        if self.compiler != "conifer" and self.scales is not None:
+            self.build_emulator(
+                self.scales['shifts'],
+                self.scales['offsets'],
+                self.n_outputs,
+                self.unscaled_type
+            )

-        self.build_emulator(self.scales['shifts'], self.scales['offsets'], self.n_outputs)

+        if shutil.which("vivado") is not None:
+            build_ugt_model(
+                templates_dir=self.subsystem_cfg.get("templates_dir", Path("templates")),
+                firmware_dir=self.workspace_manager.workspace / "firmware",
+                compiler = self.compiler,
+                model_name=self.name,
+                n_inputs=self.subsystem_cfg["n_inputs"],
+                n_outputs=self.subsystem_cfg.get("n_outputs", self.n_outputs),
+                nn_offsets=self.scales["offsets"],
+                nn_shifts=self.scales["shifts"],
+                muon_size=self.subsystem_cfg.get("muon_size", 0),
+                jet_size=self.subsystem_cfg.get("jet_size", 0),
+                egamma_size=self.subsystem_cfg.get("egamma_size", 0),
+                tau_size=self.subsystem_cfg.get("tau_size", 0),
+                output_type=self.subsystem_cfg.get("output_type", "result_t"),
+                offset_type=self.subsystem_cfg.get("offset_type", "ap_fixed<10,10>"),
+                shift_type=self.subsystem_cfg.get("shift_type", "ap_fixed<10,10>"),
+                object_features=self.subsystem_cfg.get("object_features"),
+                global_features=self.subsystem_cfg.get("global_features")
+            )
+        else:
+            warnings.warn(
+                "Vivado executable not found on the system PATH. "
+                "Skipping FW build. ",
+                UserWarning
+            )
+
+
         self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
         self.workspace_manager.save_metadata()
+

     @staticmethod
     def parse_dataset_object():
@@ -457,23 +510,27 @@ class TriggerModel:
         predictor = SoftwarePredictor(self.native_model, self.ml_backend)
         return predictor.predict(input_data)

-    def qonnx_predict(self, input_data: np.ndarray) -> np.ndarray:
+    def qonnx_predict(self, input_data: np.ndarray) -> np.ndarray | None:
         """Make predictions using QONNX model"""
+
         if self.model_qonnx is None:
-            raise RuntimeError("QONNX model not available")
+            warnings.warn(
+                "QONNX model is not available. Prediction skipped.",
+                UserWarning
+            )
+            return None
+
         predictor = QONNXPredictor(self.model_qonnx, self.input_name)
         return predictor.predict(input_data)

     def firmware_predict(self, input_data: np.ndarray) -> np.ndarray:
         """Make predictions using firmware model"""
-        predictor = FirmwarePredictor(self.firmware_model)
+        predictor = FirmwarePredictor(self.firmware_model, self.compiler)
         return predictor.predict(input_data)

-    def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int):
-        """
-        Create an emulator directory for this model.
-        Copies HLS sources and generates emulator scaffolding.
-        """
+    def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int, unscaled_type: str = "ap_fixed<16,6>"):
+        """Builds CMSSW emulator"""
+
         emulator_dir = self.workspace / "emulator"
         emulator_dir.mkdir(exist_ok=True)

@@ -493,12 +550,13 @@ class TriggerModel:
             "N_OUTPUTS": n_outputs,
             "AD_SHIFT": ", ".join(map(str, ad_shift)),
             "AD_OFFSETS": ", ".join(map(str, ad_offsets)),
+            "UNSCALED_TYPE": unscaled_type,
         }
         self._render_template(scales_template_path, scales_out_path, context)

         with pkg_resources.path(triggerflow.templates, "model_template.cpp") as emulator_template_path:
             emulator_out_path = model_dir / "emulator.cpp"
-            self._render_template(emulator_template_path, emulator_out_path, context)
+            self._render_template(emulator_template_path, emulator_out_path, context)

         with pkg_resources.path(triggerflow.templates, "makefile_version") as makefile_template_path:
             makefile_out_path = model_dir / "Makefile"
triggerflow/interfaces/uGT.py ADDED
@@ -0,0 +1,127 @@
+from pathlib import Path
+import shutil
+import pkg_resources
+from jinja2 import Template
+import re
+
+def _render_template(template_file: str, output_file: Path, context: dict):
+    with open(template_file, "r") as f:
+        template_text = f.read()
+
+    template = Template(template_text)
+    rendered = template.render(**context)
+
+    with open(output_file, "w") as f:
+        f.write(rendered)
+
+def build_ugt_model(
+    templates_dir: Path,
+    firmware_dir: Path,
+    compiler: str,
+    model_name: str,
+    n_inputs: int,
+    n_outputs: int,
+    nn_offsets: list,
+    nn_shifts: list,
+    muon_size: int,
+    jet_size: int,
+    egamma_size: int,
+    tau_size: int,
+    output_type: str = "result_t",
+    offset_type: str = "ap_fixed<10,10>",
+    shift_type: str = "ap_fixed<10,10>",
+    object_features: dict = None,
+    global_features: list = None
+):
+    """
+    Render uGT top func.
+    """
+
+
+    if object_features is None:
+        object_features = {
+            "muons": ["pt", "eta_extrapolated", "phi_extrapolated"],
+            "jets": ["et", "eta", "phi"],
+            "egammas": ["et", "eta", "phi"],
+            "taus": ["et", "eta", "phi"]
+        }
+
+    if global_features is None:
+        global_features = [
+            "et.et",
+            "ht.et",
+            "etmiss.et", "etmiss.phi",
+            "htmiss.et", "htmiss.phi",
+            "ethfmiss.et", "ethfmiss.phi",
+            "hthfmiss.et", "hthfmiss.phi"
+        ]
+
+    header_path = firmware_dir / "firmware" / f"{model_name}_project.h"
+    if compiler.lower() == "conifer":
+        output_layer = "score"
+        output_type = "score_arr_t"
+        header_path = firmware_dir / "firmware" / f"{model_name}_project.h"
+        removal_pattern = re.compile(
+            r',\s*score_t\s+tree_scores\[BDT::fn_classes\(n_classes\)\s*\*\s*n_trees\]',
+            re.DOTALL
+        )
+        modified_content = removal_pattern.sub('', header_path.read_text(encoding='utf-8'))
+        header_path.write_text(modified_content, encoding='utf-8')
+        out = output_layer
+    else:
+        header_content = header_path.read_text(encoding='utf-8')
+        layer_pattern = re.compile(
+            r'result_t\s+(\w+)\[\d+\]\s*\)',
+            re.DOTALL
+        )
+        match = layer_pattern.search(header_content)
+        layer_name = match.group(1)
+        output_layer = f"{layer_name}[{n_outputs}]"
+        out = layer_name
+
+
+    context = {
+        "MODEL_NAME": model_name,
+        "N_INPUTS": n_inputs,
+        "N_OUTPUTS": n_outputs,
+        "NN_OFFSETS": ", ".join(map(str, nn_offsets)),
+        "NN_SHIFTS": ", ".join(map(str, nn_shifts)),
+        "MUON_SIZE": muon_size,
+        "JET_SIZE": jet_size,
+        "EGAMMA_SIZE": egamma_size,
+        "TAU_SIZE": tau_size,
+        "OUTPUT_TYPE": output_type,
+        "OUTPUT_LAYER": output_layer,
+        "OUT": out,
+        "OFFSET_TYPE": offset_type,
+        "SHIFT_TYPE": shift_type,
+        "MUON_FEATURES": object_features["muons"],
+        "JET_FEATURES": object_features["jets"],
+        "EGAMMA_FEATURES": object_features["egammas"],
+        "TAU_FEATURES": object_features["taus"],
+        "GLOBAL_FEATURES": global_features
+    }
+
+    context_tcl = {
+        "MODEL_NAME": model_name,
+    }
+
+    out_path = firmware_dir / "firmware/model-gt.cpp"
+
+    _render_template(f"{templates_dir}/model-gt.cpp", out_path, context)
+
+    out_path = firmware_dir / "firmware/build_ugt.tcl"
+    _render_template(f"{templates_dir}/build_ugt.tcl", out_path, context_tcl)
+
+    shutil.copy(f"{templates_dir}/data_types.h", firmware_dir / "firmware")
+
+
+    subprocess.run(
+        ["vitis_hls", "-f", "build_ugt.tcl"],
+        cwd=firmware_dir/"firmware",
+        check=True
+    )
+
+
+
+
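As the added module shows, build_ugt_model renders model-gt.cpp and build_ugt.tcl from the package templates into the firmware project, copies data_types.h alongside them, and then invokes vitis_hls on the generated Tcl script. The sketch below is a minimal direct call; the template location, firmware directory, model name, feature count, and object sizes are placeholders, and the header layout expected inside firmware_dir/"firmware" is assumed to match what HLS4MLStrategy.compile produces.

from pathlib import Path
from triggerflow.interfaces.uGT import build_ugt_model

build_ugt_model(
    templates_dir=Path("triggerflow/templates"),   # build_ugt.tcl / model-gt.cpp / data_types.h
    firmware_dir=Path("workspace/firmware"),       # existing hls4ml project directory (assumed)
    compiler="hls4ml",
    model_name="ad_model",
    n_inputs=57,
    n_outputs=1,
    nn_offsets=[0] * 57,
    nn_shifts=[0] * 57,
    muon_size=4,
    jet_size=10,
    egamma_size=4,
    tau_size=4,
)
# Renders firmware/model-gt.cpp and firmware/build_ugt.tcl, copies data_types.h,
# then runs `vitis_hls -f build_ugt.tcl` inside firmware_dir / "firmware" via subprocess.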