triggerflow 0.1.12__tar.gz → 0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. triggerflow-0.2/MANIFEST.in +2 -0
  2. triggerflow-0.2/PKG-INFO +97 -0
  3. triggerflow-0.2/README.md +77 -0
  4. triggerflow-0.2/pyproject.toml +49 -0
  5. triggerflow-0.2/src/trigger_dataset/core.py +88 -0
  6. triggerflow-0.2/src/trigger_loader/__init__.py +0 -0
  7. triggerflow-0.2/src/trigger_loader/cluster_manager.py +107 -0
  8. triggerflow-0.2/src/trigger_loader/loader.py +95 -0
  9. triggerflow-0.2/src/trigger_loader/processor.py +211 -0
  10. triggerflow-0.2/src/triggerflow/__init__.py +0 -0
  11. triggerflow-0.2/src/triggerflow/cli.py +122 -0
  12. {triggerflow-0.1.12 → triggerflow-0.2}/src/triggerflow/core.py +118 -114
  13. {triggerflow-0.1.12 → triggerflow-0.2}/src/triggerflow/mlflow_wrapper.py +54 -49
  14. triggerflow-0.2/src/triggerflow/starter/.gitignore +143 -0
  15. triggerflow-0.2/src/triggerflow/starter/README.md +0 -0
  16. triggerflow-0.2/src/triggerflow/starter/cookiecutter.json +5 -0
  17. triggerflow-0.2/src/triggerflow/starter/prompts.yml +9 -0
  18. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  19. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  20. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  21. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  22. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  23. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  24. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  25. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  26. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  27. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  28. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  29. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  30. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +84 -0
  31. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  32. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  33. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  34. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  35. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  36. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  37. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  38. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  39. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +15 -0
  40. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples_dummy.json +26 -0
  41. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  42. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  43. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  44. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  45. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  46. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  47. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  48. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +21 -0
  49. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  50. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  51. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  52. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  53. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  54. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +88 -0
  55. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  56. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  57. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  58. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  59. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  60. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  61. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +50 -0
  62. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +10 -0
  63. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  64. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +40 -0
  65. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  66. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  67. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +12 -0
  68. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  69. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  70. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +31 -0
  71. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  72. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  73. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +29 -0
  74. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  75. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  76. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  77. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  78. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  79. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  80. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  81. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  82. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  83. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  84. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  85. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  86. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  87. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  88. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  89. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  90. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  91. triggerflow-0.2/src/triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  92. triggerflow-0.2/src/triggerflow.egg-info/PKG-INFO +97 -0
  93. triggerflow-0.2/src/triggerflow.egg-info/SOURCES.txt +103 -0
  94. triggerflow-0.2/src/triggerflow.egg-info/entry_points.txt +2 -0
  95. triggerflow-0.2/src/triggerflow.egg-info/requires.txt +11 -0
  96. triggerflow-0.2/src/triggerflow.egg-info/top_level.txt +3 -0
  97. {triggerflow-0.1.12 → triggerflow-0.2}/tests/test.py +12 -19
  98. triggerflow-0.2/tests/test_loader.py +103 -0
  99. triggerflow-0.1.12/PKG-INFO +0 -61
  100. triggerflow-0.1.12/README.md +0 -50
  101. triggerflow-0.1.12/pyproject.toml +0 -24
  102. triggerflow-0.1.12/src/triggerflow.egg-info/PKG-INFO +0 -61
  103. triggerflow-0.1.12/src/triggerflow.egg-info/SOURCES.txt +0 -15
  104. triggerflow-0.1.12/src/triggerflow.egg-info/requires.txt +0 -1
  105. triggerflow-0.1.12/src/triggerflow.egg-info/top_level.txt +0 -1
  106. {triggerflow-0.1.12 → triggerflow-0.2}/setup.cfg +0 -0
  107. {triggerflow-0.1.12/src/triggerflow → triggerflow-0.2/src/trigger_dataset}/__init__.py +0 -0
  108. {triggerflow-0.1.12 → triggerflow-0.2}/src/triggerflow/templates/makefile +0 -0
  109. {triggerflow-0.1.12 → triggerflow-0.2}/src/triggerflow/templates/makefile_version +0 -0
  110. {triggerflow-0.1.12 → triggerflow-0.2}/src/triggerflow/templates/model_template.cpp +0 -0
  111. {triggerflow-0.1.12 → triggerflow-0.2}/src/triggerflow/templates/scales.h +0 -0
  112. {triggerflow-0.1.12 → triggerflow-0.2}/src/triggerflow.egg-info/dependency_links.txt +0 -0
@@ -0,0 +1,2 @@
+ recursive-include src/triggerflow/starter *
+ recursive-include src/triggerflow/templates *
@@ -0,0 +1,97 @@
+ Metadata-Version: 2.4
+ Name: triggerflow
+ Version: 0.2
+ Summary: Utilities for ML models targeting hardware triggers
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+ Requires-Dist: cookiecutter>=2.3
+ Requires-Dist: PyYAML>=6
+ Requires-Dist: Jinja2>=3
+ Requires-Dist: mlflow>=2.0
+ Requires-Dist: kedro==1.0.0
+ Provides-Extra: dev
+ Requires-Dist: pytest-cov~=3.0; extra == "dev"
+ Requires-Dist: pytest-mock<2.0,>=1.7.1; extra == "dev"
+ Requires-Dist: pytest~=7.2; extra == "dev"
+ Requires-Dist: ruff~=0.1.8; extra == "dev"
+
+ # Machine Learning for Hardware Triggers
+
+ `triggerflow` provides a set of utilities for Machine Learning models targeting FPGA deployment.
+ The `TriggerModel` class consolidates several Machine Learning frontends and compiler backends to construct a "trigger model". MLflow utilities are provided for logging, versioning, and loading trigger models.
+
+ ## Installation
+
+ ```bash
+ pip install triggerflow
+ ```
+
+ ## Usage
+
+ ```python
+ from triggerflow.core import TriggerModel
+
+ triggerflow = TriggerModel(name="my-trigger-model", ml_backend="Keras", compiler="hls4ml", model=model, compiler_config=None)
+ triggerflow()  # call the instance to construct the trigger model
+
+ # then:
+ output_software = triggerflow.software_predict(input_data)
+ output_firmware = triggerflow.firmware_predict(input_data)
+ output_qonnx = triggerflow.qonnx_predict(input_data)
+
+ # save and load trigger models:
+ triggerflow.save("triggerflow.tar.xz")
+
+ # in a separate session:
+ from triggerflow.core import TriggerModel
+ triggerflow = TriggerModel.load("triggerflow.tar.xz")
+ ```
+
+ ## Logging with MLflow
+
+ ```python
+ # logging with MLflow:
+ import mlflow
+ from triggerflow.mlflow_wrapper import log_model
+
+ mlflow.set_tracking_uri("https://ngt.cern.ch/models")
+ experiment_id = mlflow.create_experiment("example-experiment")
+
+ with mlflow.start_run(run_name="trial-v1", experiment_id=experiment_id):
+     log_model(triggerflow, registered_model_name="TriggerModel")
+ ```
+
+ ### Note: This package doesn't install its ML and compiler dependencies, so it won't disrupt specific training environments or custom compilers. For a reference environment, see `environment.yml`.
+
+ # Creating a kedro pipeline
+
+ This repository also ships a default kedro-based pipeline for trigger models.
+ A new pipeline can be created as follows.
+
+ NOTE: the pipeline name must not contain hyphens ("-") or uppercase letters!
+
+ ```bash
+ # Create a conda environment & activate it
+ conda create -n triggerflow python=3.11
+ conda activate triggerflow
+
+ # Install triggerflow
+ pip install triggerflow
+
+ # Create a pipeline
+ triggerflow new demo_pipeline
+
+ # NOTE: since triggerflow doesn't install its ML/compiler dependencies,
+ # update the conda env from the pipeline's environment.yml; this file can
+ # be adapted to the needs of the individual project
+ cd demo_pipeline
+ conda env update -n triggerflow --file environment.yml
+
+ # Run Kedro
+ kedro run
+ ```
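The MLflow section above only shows the logging side. For completeness, a minimal sketch of loading the registered model back in a later session, assuming `log_model` registers a standard MLflow pyfunc-compatible flavor (this diff does not confirm the flavor; the version number is illustrative):

```python
import mlflow

# Assumed: "TriggerModel" is the registered name from the logging example,
# and version 1 exists in the registry; "models:/<name>/<version>" is the
# plain MLflow registry URI convention, not a triggerflow-specific API.
mlflow.set_tracking_uri("https://ngt.cern.ch/models")
loaded = mlflow.pyfunc.load_model("models:/TriggerModel/1")
predictions = loaded.predict(input_data)  # input_data as in the usage block
```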
@@ -0,0 +1,77 @@
+ # Machine Learning for Hardware Triggers
+
+ `triggerflow` provides a set of utilities for Machine Learning models targeting FPGA deployment.
+ The `TriggerModel` class consolidates several Machine Learning frontends and compiler backends to construct a "trigger model". MLflow utilities are provided for logging, versioning, and loading trigger models.
+
+ ## Installation
+
+ ```bash
+ pip install triggerflow
+ ```
+
+ ## Usage
+
+ ```python
+ from triggerflow.core import TriggerModel
+
+ triggerflow = TriggerModel(name="my-trigger-model", ml_backend="Keras", compiler="hls4ml", model=model, compiler_config=None)
+ triggerflow()  # call the instance to construct the trigger model
+
+ # then:
+ output_software = triggerflow.software_predict(input_data)
+ output_firmware = triggerflow.firmware_predict(input_data)
+ output_qonnx = triggerflow.qonnx_predict(input_data)
+
+ # save and load trigger models:
+ triggerflow.save("triggerflow.tar.xz")
+
+ # in a separate session:
+ from triggerflow.core import TriggerModel
+ triggerflow = TriggerModel.load("triggerflow.tar.xz")
+ ```
+
+ ## Logging with MLflow
+
+ ```python
+ # logging with MLflow:
+ import mlflow
+ from triggerflow.mlflow_wrapper import log_model
+
+ mlflow.set_tracking_uri("https://ngt.cern.ch/models")
+ experiment_id = mlflow.create_experiment("example-experiment")
+
+ with mlflow.start_run(run_name="trial-v1", experiment_id=experiment_id):
+     log_model(triggerflow, registered_model_name="TriggerModel")
+ ```
+
+ ### Note: This package doesn't install its ML and compiler dependencies, so it won't disrupt specific training environments or custom compilers. For a reference environment, see `environment.yml`.
+
+ # Creating a kedro pipeline
+
+ This repository also ships a default kedro-based pipeline for trigger models.
+ A new pipeline can be created as follows.
+
+ NOTE: the pipeline name must not contain hyphens ("-") or uppercase letters!
+
+ ```bash
+ # Create a conda environment & activate it
+ conda create -n triggerflow python=3.11
+ conda activate triggerflow
+
+ # Install triggerflow
+ pip install triggerflow
+
+ # Create a pipeline
+ triggerflow new demo_pipeline
+
+ # NOTE: since triggerflow doesn't install its ML/compiler dependencies,
+ # update the conda env from the pipeline's environment.yml; this file can
+ # be adapted to the needs of the individual project
+ cd demo_pipeline
+ conda env update -n triggerflow --file environment.yml
+
+ # Run Kedro
+ kedro run
+ ```
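Putting the usage block together end to end: a runnable sketch assuming a toy Keras model, with the keyword spellings `model=` and `compiler_config=` inferred from the call shown in the README (the original snippet passed them positionally, which is not valid Python):

```python
import numpy as np
from tensorflow import keras

from triggerflow.core import TriggerModel

# A toy two-layer classifier standing in for a real trigger network.
model = keras.Sequential([
    keras.layers.Input(shape=(16,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

trigger = TriggerModel(
    name="my-trigger-model",
    ml_backend="Keras",
    compiler="hls4ml",
    model=model,           # keyword form assumed from the usage block
    compiler_config=None,  # None falls back to compiler defaults per the README
)
trigger()  # construct the trigger model (software + firmware views)

input_data = np.random.rand(4, 16).astype("float32")
out_sw = trigger.software_predict(input_data)
```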
@@ -0,0 +1,49 @@
+ [build-system]
+ requires = ["setuptools>=65.5", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "triggerflow"
+ version = "0.2"
+ description = "Utilities for ML models targeting hardware triggers"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "cookiecutter>=2.3",
+     "PyYAML>=6",
+     "Jinja2>=3",
+     "mlflow>=2.0",
+     "kedro==1.0.0",
+ ]
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest-cov~=3.0",
+     "pytest-mock>=1.7.1, <2.0",
+     "pytest~=7.2",
+     "ruff~=0.1.8",
+ ]
+
+ [tool.setuptools]
+ include-package-data = true
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ [tool.setuptools.package-data]
+ triggerflow = ["starter/**", "starter/**/*"]
+
+ [tool.ruff]
+ line-length = 88
+ show-fixes = true
+ select = ["F", "W", "E", "I", "UP", "PL", "T201"]
+ ignore = ["E501"]
+ extend-exclude = ["src/triggerflow/starter"]
+
+ # expose CLI entrypoint
+ [project.scripts]
+ triggerflow = "triggerflow.cli:main"
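The `[project.scripts]` table turns `triggerflow.cli:main` into a `triggerflow` console command. The released `cli.py` (file 11 above, +122 lines) is not shown in this hunk, so the sketch below only illustrates the required shape — a zero-argument callable — with hypothetical argument handling built around the `triggerflow new <name>` command the README documents:

```python
# Illustrative shape of src/triggerflow/cli.py, not the released file.
import argparse


def main() -> None:
    parser = argparse.ArgumentParser(prog="triggerflow")
    sub = parser.add_subparsers(dest="command", required=True)

    new = sub.add_parser("new", help="create a new kedro pipeline from the starter")
    new.add_argument("name", help="pipeline name (no '-' or uppercase letters)")

    args = parser.parse_args()
    if args.command == "new":
        ...  # e.g. drive cookiecutter over src/triggerflow/starter
```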
@@ -0,0 +1,88 @@
+ import warnings
+ from abc import ABC, abstractmethod
+ from fnmatch import filter as fnmatch_filter
+
+ import pandas as pd
+ import uproot
+
+
+ class TriggerDataset(ABC):
+     """
+     Abstract Base Class for loading data from ROOT files.
+
+     Users must inherit from this class and implement the abstract methods.
+     The core processing logic in `process_file` is fixed and cannot be overridden.
+     """
+
+     def __init__(self, output_format: str = "parquet"):
+         # Output format for the processed DataFrame ("parquet" or "csv");
+         # used by `process_file` below.
+         self.output_format = output_format
+
+     @abstractmethod
+     def get_tree_name(self) -> str:
+         """
+         Return the name of the TTree to read from each ROOT file.
+         """
+
+     @abstractmethod
+     def get_features(self) -> list[str]:
+         """
+         Return a list of branch names or patterns to keep from the dataset.
+         Accepts wildcards (e.g. Jet_*).
+         """
+
+     @abstractmethod
+     def get_cut(self) -> str | None:
+         """
+         Return a string representing the cuts to apply to the data.
+         """
+
+     @abstractmethod
+     def convert_to_pandas(self, data: dict) -> pd.DataFrame:
+         """
+         Convert the loaded data from a dictionary format to a pandas DataFrame.
+         """
+
+     def _resolve_branches(self, all_branches: list) -> list[str]:
+         """Internal method to resolve wildcard patterns."""
+         selected = []
+         for pattern in self.get_features():
+             matched = fnmatch_filter(all_branches, pattern)
+             if not matched:
+                 warnings.warn(f"'{pattern}' did not match any branches.")
+             selected.extend(matched)
+         return sorted(set(selected))
+
+     def _save_to_parquet(self, df: pd.DataFrame, output_path: str):
+         """
+         Save the processed DataFrame to a Parquet file.
+         """
+         df.to_parquet(output_path)
+
+     def _save_to_csv(self, df: pd.DataFrame, output_path: str):
+         """
+         Save the processed DataFrame to a CSV file.
+         """
+         df.to_csv(output_path, index=False)
+
+     def process_file(self, file_path: str, out_file_path: str) -> pd.DataFrame:
+         """
+         Load and process a single ROOT file, write it out, and return the DataFrame.
+         """
+         with uproot.open(file_path) as f:
+             tree = f[self.get_tree_name()]
+             all_branches = tree.keys()
+             branches_to_load = self._resolve_branches(all_branches)
+
+             if not branches_to_load:
+                 return pd.DataFrame()
+
+             data = tree.arrays(branches_to_load, cut=self.get_cut(), how=dict)
+
+         df = self.convert_to_pandas(data)
+
+         if self.output_format == "parquet":
+             self._save_to_parquet(df, f"{out_file_path}.parquet")
+         elif self.output_format == "csv":
+             self._save_to_csv(df, f"{out_file_path}.csv")
+         else:
+             raise ValueError(f"Unsupported output format: {self.output_format}")
+
+         return df
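Since `TriggerDataset` leaves four methods abstract (the hunk above adds `get_tree_name` and an `output_format` constructor argument, both of which `process_file` already relied on), a concrete dataset is a small subclass. A hedged sketch with hypothetical tree and branch names:

```python
import pandas as pd

from trigger_dataset.core import TriggerDataset


class JetDataset(TriggerDataset):
    """Hypothetical dataset keeping a couple of jet branches; names are illustrative."""

    def get_tree_name(self) -> str:
        return "Events"

    def get_features(self) -> list[str]:
        # Wildcards would be resolved against the tree by _resolve_branches.
        return ["Jet_pt", "Jet_eta"]

    def get_cut(self) -> str | None:
        return "Jet_pt > 30"  # forwarded to uproot's `cut` argument

    def convert_to_pandas(self, data: dict) -> pd.DataFrame:
        # `data` maps branch names to (possibly jagged) awkward arrays;
        # a naive list conversion is enough for a sketch.
        return pd.DataFrame({name: arr.tolist() for name, arr in data.items()})


ds = JetDataset(output_format="parquet")
df = ds.process_file("input.root", "output")  # writes output.parquet
```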
File without changes
@@ -0,0 +1,107 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Any
+
+ from dask.distributed import Client, LocalCluster
+ from dask_cuda import LocalCUDACluster
+ from dask_jobqueue import HTCondorCluster
+ from dask_kubernetes import KubeCluster
+
+ logger = logging.getLogger(__name__)
+
+
+ class ClusterManager:
+     """Context manager to provision and tear down a Dask cluster.
+
+     Parameters
+     ----------
+     cluster_type : str
+         Backend to use ("local", "condor", "cuda", "kubernetes").
+     cluster_config : dict | None, optional
+         Keyword arguments forwarded to the specific cluster constructor.
+     jobs : int, optional
+         Desired number of jobs / workers (used for queue / scalable backends).
+     """
+
+     def __init__(
+         self,
+         cluster_type: str,
+         cluster_config: dict[str, Any] | None = None,
+         jobs: int = 1,
+     ) -> None:
+         if cluster_config is None:
+             cluster_config = {}
+         # Copy to avoid mutating caller's dict accidentally.
+         self.cluster_config: dict[str, Any] = dict(cluster_config)
+         self.cluster_type: str = cluster_type
+         self.jobs: int = jobs
+
+         self.cluster: Any | None = None
+         self.client: Any | None = None
+
+     # ------------------------------------------------------------------
+     # Context manager protocol
+     # ------------------------------------------------------------------
+     def __enter__(self):  # -> distributed.Client (avoids importing type eagerly)
+         self._start_cluster()
+         return self.client
+
+     def __exit__(self, exc_type, exc, tb) -> bool:  # noqa: D401 (simple)
+         self._close_cluster()
+         # Returning False propagates any exception (desired behavior)
+         return False
+
+     # ------------------------------------------------------------------
+     # Internal helpers
+     # ------------------------------------------------------------------
+     def _start_cluster(self) -> None:
+         ct = self.cluster_type.lower()
+
+         if ct == "local":
+             self.cluster = LocalCluster(**self.cluster_config)
+
+         elif ct == "condor":
+             self.cluster = HTCondorCluster(**self.cluster_config)
+             if self.jobs and self.jobs > 0:
+                 # Scale to the requested number of jobs
+                 self.cluster.scale(jobs=self.jobs)
+
+         elif ct == "cuda":
+             self.cluster = LocalCUDACluster(**self.cluster_config)
+
+         elif ct == "kubernetes":
+             self.cluster = KubeCluster(**self.cluster_config)
+             if self.jobs and self.jobs > 0:
+                 try:
+                     # Not all KubeCluster versions expose scale() identically
+                     self.cluster.scale(self.jobs)
+                 except Exception:
+                     pass  # Best effort; ignore if unsupported
+
+         else:
+             raise ValueError(f"Unsupported cluster type: {self.cluster_type}")
+
+         self.client = Client(self.cluster)
+         dash = getattr(self.client, "dashboard_link", None)
+         if dash:
+             logger.info(f"Dask dashboard: {dash}")
+
+     def _close_cluster(self) -> None:
+         # Close client first so tasks wind down before cluster termination.
+         if self.client is not None:
+             try:
+                 self.client.close()
+             except Exception:
+                 pass
+             finally:
+                 self.client = None
+         if self.cluster is not None:
+             try:
+                 self.cluster.close()
+             except Exception:
+                 pass
+             finally:
+                 self.cluster = None
+
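For reference, the context-manager contract this class provides; a minimal usage sketch with a throwaway local cluster (`n_workers` and `threads_per_worker` are standard dask `LocalCluster` kwargs passed through `cluster_config`). Note that the module imports all four backends eagerly, so `dask_cuda`, `dask_jobqueue`, and `dask_kubernetes` must be installed even for a purely local run:

```python
from trigger_loader.cluster_manager import ClusterManager

with ClusterManager("local", {"n_workers": 2, "threads_per_worker": 1}) as client:
    # `client` is a dask.distributed.Client bound to the fresh cluster.
    future = client.submit(sum, range(10))
    result = future.result()  # 45
# On exit both the client and the cluster are closed, even if the body raised.
```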
@@ -0,0 +1,95 @@
+ import json
+ import logging
+ import platform
+ import time
+ import uuid
+ from collections.abc import Callable
+
+ import awkward as ak
+ import coffea
+ from coffea import processor
+ from coffea.nanoevents import NanoAODSchema
+
+ from .cluster_manager import ClusterManager
+ from .processor import TriggerProcessor
+
+ logger = logging.getLogger(__name__)
+
+
+ class TriggerLoader:
+     def __init__(
+         self,
+         sample_json: str,
+         transform: Callable,
+         output_path: str,
+     ):
+         self.transform = transform
+         self.fileset = self._load_sample_json(sample_json)
+         self.output_path = output_path
+         self.run_uuid = str(uuid.uuid4())
+
+     def _build_processor(self):
+         run_meta = {
+             "run_uuid": self.run_uuid,
+             "fileset_size": sum(len(v) if isinstance(v, list) else 1 for v in self.fileset.values()),
+             "coffea_version": coffea.__version__,
+             "awkward_version": ak.__version__,
+             "python_version": platform.python_version(),
+         }
+
+         return TriggerProcessor(
+             output_path=self.output_path,
+             transform=self.transform,
+             compression="zstd",
+             add_uuid=False,
+             run_uuid=self.run_uuid,
+             run_metadata=run_meta,
+         )
+
+     def _load_sample_json(self, sample_json: str) -> dict:
+         with open(sample_json) as f:
+             return json.load(f)
+
+     def _write_run_metadata_file(self, path: str, duration_s: float | None = None):
+         meta_path = f"{path}/run_metadata.json"
+         data = {
+             "run_uuid": self.run_uuid,
+             "duration_seconds": duration_s,
+         }
+         with open(meta_path, "w") as f:
+             json.dump(data, f, indent=2)
+
+     def _run(self, runner: processor.Runner, label: str):
+         logger.info(f"Starting processing ({label})...")
+         start = time.time()
+         proc = self._build_processor()
+         acc = runner(
+             self.fileset,
+             treename="Events",
+             processor_instance=proc,
+         )
+         elapsed = time.time() - start
+         self._write_run_metadata_file(self.output_path, elapsed)
+         logger.info(f"Finished in {elapsed:.2f}s (run_uuid={self.run_uuid})")
+         return acc
+
+     def run_distributed(self, cluster_type: str, cluster_config: dict,
+                         chunksize: int = 100_000, jobs: int = 1):
+         with ClusterManager(cluster_type, cluster_config, jobs) as client:
+             executor = processor.DaskExecutor(client=client)
+             runner = processor.Runner(
+                 executor=executor,
+                 schema=NanoAODSchema,
+                 chunksize=chunksize,
+             )
+             return self._run(runner, f"Distributed ({cluster_type})")
+
+     def run_local(self, num_workers: int = 4, chunksize: int = 100_000):
+         """
+         Run processing locally using a multiprocessing executor.
+         """
+         executor = processor.FuturesExecutor(workers=num_workers)
+         runner = processor.Runner(
+             executor=executor,
+             schema=NanoAODSchema,
+             chunksize=chunksize,
+         )
+         return self._run(runner, f"Local ({num_workers} workers)")
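A hedged sketch of driving `TriggerLoader` locally. The fileset JSON follows the coffea `Runner` convention (dataset name mapped to a list of ROOT file paths), and `select_muons` is a hypothetical transform: the exact contract `TriggerProcessor` expects for `transform` lives in `processor.py`, which is not shown in this hunk.

```python
import json

from trigger_loader.loader import TriggerLoader


def select_muons(events):
    # `events` is a NanoAOD events array (NanoAODSchema); keep a simple view.
    return events.Muon[events.Muon.pt > 20]


# Dataset name -> list of ROOT files, as coffea's Runner expects.
with open("samples.json", "w") as f:
    json.dump({"DYJets": ["dy_1.root", "dy_2.root"]}, f)

loader = TriggerLoader(
    sample_json="samples.json",
    transform=select_muons,
    output_path="./out",
)
loader.run_local(num_workers=4, chunksize=100_000)
```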