triggerflow 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +154 -0
  6. trigger_loader/processor.py +212 -0
  7. triggerflow/__init__.py +0 -0
  8. triggerflow/cli.py +122 -0
  9. triggerflow/core.py +617 -0
  10. triggerflow/interfaces/__init__.py +0 -0
  11. triggerflow/interfaces/uGT.py +187 -0
  12. triggerflow/mlflow_wrapper.py +270 -0
  13. triggerflow/starter/.gitignore +143 -0
  14. triggerflow/starter/README.md +0 -0
  15. triggerflow/starter/cookiecutter.json +5 -0
  16. triggerflow/starter/prompts.yml +9 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +90 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/condor_config.json +11 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/cuda_config.json +4 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +24 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/settings.json +8 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/test.root +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +23 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_loader.py +101 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +49 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_loader.py +32 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +70 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +20 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +41 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +13 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +48 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +31 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  88. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  89. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  90. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  91. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  92. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  93. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  94. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  95. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  96. triggerflow/templates/build_ugt.tcl +46 -0
  97. triggerflow/templates/data_types.h +524 -0
  98. triggerflow/templates/makefile +28 -0
  99. triggerflow/templates/makefile_version +15 -0
  100. triggerflow/templates/model-gt.cpp +104 -0
  101. triggerflow/templates/model_template.cpp +63 -0
  102. triggerflow/templates/scales.h +20 -0
  103. triggerflow-0.3.4.dist-info/METADATA +206 -0
  104. triggerflow-0.3.4.dist-info/RECORD +107 -0
  105. triggerflow-0.3.4.dist-info/WHEEL +5 -0
  106. triggerflow-0.3.4.dist-info/entry_points.txt +2 -0
  107. triggerflow-0.3.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,187 @@
1
+ from pathlib import Path
2
+ import shutil, warnings, subprocess
3
+ import pkg_resources
4
+ from jinja2 import Template
5
+ import re
6
+
7
+
8
def _render_template(template_name: str, output_file: Path, context: dict):
    """Render a Jinja2 template shipped inside the ``triggerflow`` package.

    Parameters
    ----------
    template_name : str
        Resource path of the template relative to the ``triggerflow`` package
        (e.g. ``"templates/model-gt.cpp"``).
    output_file : Path
        Destination file; overwritten if it already exists.
    context : dict
        Variables passed to ``Template.render``.
    """
    raw = pkg_resources.resource_string("triggerflow", template_name)
    rendered = Template(raw.decode('utf-8')).render(**context)
    with open(output_file, "w") as out:
        out.write(rendered)
19
+
20
+
21
def build_ugt_model(
    subsystem_cfg: dict,
    compiler_cfg: dict,
    workspace_manager,
    compiler,
    scales: dict,
    name: str,
    n_outputs: int
):
    """Assemble the configuration for the uGT firmware wrapper and render it.

    Reads object/feature definitions and input sizes from ``subsystem_cfg``,
    project naming from ``compiler_cfg``, and forwards everything (together
    with the NN scaling constants in ``scales``) to :func:`uGT`.
    """
    firmware_dir = workspace_manager.workspace / "firmware"
    templates_dir = subsystem_cfg.get("templates_dir", Path("templates"))

    # One feature list per configured input object collection.
    object_features = {
        obj_name: obj_cfg.get("features", [])
        for obj_name, obj_cfg in subsystem_cfg.get("objects", {}).items()
    }

    uGT(
        templates_dir=templates_dir,
        firmware_dir=firmware_dir,
        compiler=compiler,
        model_name=name,
        namespace=compiler_cfg.get("namespace", "triggerflow"),
        project_name=compiler_cfg.get("project_name", "triggerflow"),
        n_inputs=subsystem_cfg["n_inputs"],
        # Subsystem config may override the model's nominal output count.
        n_outputs=subsystem_cfg.get("n_outputs", n_outputs),
        nn_offsets=scales["offsets"],
        nn_shifts=scales["shifts"],
        muon_size=subsystem_cfg.get("muon_size", 0),
        jet_size=subsystem_cfg.get("jet_size", 0),
        egamma_size=subsystem_cfg.get("egamma_size", 0),
        tau_size=subsystem_cfg.get("tau_size", 0),
        output_type=subsystem_cfg.get("output_type", "result_t"),
        offset_type=subsystem_cfg.get("offset_type", "ap_fixed<10,10>"),
        shift_type=subsystem_cfg.get("shift_type", "ap_fixed<10,10>"),
        object_features=object_features,
        global_features=subsystem_cfg.get("global_features")
    )
67
+
68
+
69
def uGT(
    templates_dir: Path,
    firmware_dir: Path,
    compiler: str,
    model_name: str,
    namespace: str,
    project_name: str,
    n_inputs: int,
    n_outputs: int,
    nn_offsets: list,
    nn_shifts: list,
    muon_size: int,
    jet_size: int,
    egamma_size: int,
    tau_size: int,
    output_type: str = "result_t",
    offset_type: str = "ap_fixed<10,10>",
    shift_type: str = "ap_fixed<10,10>",
    object_features: dict = None,
    global_features: list = None
):
    """
    Render the uGT top function (``model-gt.cpp``) and HLS build script, then
    run the firmware build if the HLS toolchain is on PATH.

    Raises
    ------
    RuntimeError
        If the generated project header does not contain a recognizable
        output-layer declaration (non-conifer compilers only).
    subprocess.CalledProcessError
        If ``vitis_hls`` is found but the build fails.
    """
    # NOTE: defaults are assigned per-call (the parameter default is None)
    # to avoid sharing a mutable default between invocations.
    if object_features is None:
        object_features = {
            "muons": ["pt", "eta_extrapolated", "phi_extrapolated"],
            "jets": ["et", "eta", "phi"],
            "egammas": ["et", "eta", "phi"],
            "taus": ["et", "eta", "phi"]
        }

    if global_features is None:
        global_features = [
            "et.et",
            "ht.et",
            "etmiss.et", "etmiss.phi",
            "htmiss.et", "htmiss.phi",
            "ethfmiss.et", "ethfmiss.phi",
            "hthfmiss.et", "hthfmiss.phi"
        ]

    header_path = firmware_dir / "firmware" / f"{model_name}_project.h"
    if compiler.lower() == "conifer":
        # conifer's generated top function carries an extra debug argument
        # (tree_scores) the uGT wrapper does not provide; strip it from the
        # header so the rendered caller matches the actual signature.
        output_layer = "score"
        output_type = "score_arr_t"
        removal_pattern = re.compile(
            r',\s*score_t\s+tree_scores\[BDT::fn_classes\(n_classes\)\s*\*\s*n_trees\]',
            re.DOTALL
        )
        modified_content = removal_pattern.sub('', header_path.read_text(encoding='utf-8'))
        header_path.write_text(modified_content, encoding='utf-8')
        out = output_layer
    else:
        # hls4ml-style path: recover the output layer's name from the first
        # "result_t <name>[...]" declaration in the project header.
        header_content = header_path.read_text(encoding='utf-8')
        layer_pattern = re.compile(
            r'result_t\s+(\w+)\[[\w\d_]+\]',
            re.DOTALL
        )
        match = layer_pattern.search(header_content)
        if match is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                f"No 'result_t <layer>[...]' declaration found in {header_path}; "
                "cannot determine the output layer for the uGT wrapper."
            )
        layer_name = match.group(1)
        output_layer = f"{layer_name}[{n_outputs}]"
        out = layer_name

    # Template context for the C++ wrapper.
    context = {
        "MODEL_NAME": model_name,
        "NAMESPACE": namespace,
        "PROJECT_NAME": project_name,
        "N_INPUTS": n_inputs,
        "N_OUTPUTS": n_outputs,
        "NN_OFFSETS": ", ".join(map(str, nn_offsets)),
        "NN_SHIFTS": ", ".join(map(str, nn_shifts)),
        "MUON_SIZE": muon_size,
        "JET_SIZE": jet_size,
        "EGAMMA_SIZE": egamma_size,
        "TAU_SIZE": tau_size,
        "OUTPUT_TYPE": output_type,
        "OUTPUT_LAYER": output_layer,
        "OUT": out,
        "OFFSET_TYPE": offset_type,
        "SHIFT_TYPE": shift_type,
        "MUON_FEATURES": object_features["muons"],
        "JET_FEATURES": object_features["jets"],
        "EGAMMA_FEATURES": object_features["egammas"],
        "TAU_FEATURES": object_features["taus"],
        "GLOBAL_FEATURES": global_features
    }

    # The build script only needs the model name.
    context_tcl = {
        "MODEL_NAME": model_name,
    }

    out_path = firmware_dir / "firmware/model-gt.cpp"
    _render_template("templates/model-gt.cpp", out_path, context)

    out_path = firmware_dir / "firmware/build_ugt.tcl"
    _render_template("templates/build_ugt.tcl", out_path, context_tcl)

    shutil.copy(
        pkg_resources.resource_filename("triggerflow", "templates/data_types.h"),
        firmware_dir / "firmware"
    )

    # Probe for the binary we actually invoke (vitis_hls). Previously this
    # checked for "vivado", which skipped builds on hosts that ship vitis_hls
    # without a vivado wrapper and crashed on the opposite setup.
    if shutil.which("vitis_hls") is not None:
        subprocess.run(
            ["vitis_hls", "-f", "build_ugt.tcl"],
            cwd=firmware_dir / "firmware",
            check=True
        )
    else:
        warnings.warn(
            "vitis_hls executable not found on the system PATH. "
            "Skipping FW build. ",
            UserWarning
        )
@@ -0,0 +1,270 @@
1
+ # trigger_mlflow.py
2
+ import datetime
3
+ import logging
4
+ import os
5
+ import tempfile
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import mlflow
10
+ import mlflow.pyfunc
11
+ from mlflow.tracking import MlflowClient
12
+
13
+ from .core import TriggerModel
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def setup_mlflow(
    mlflow_uri: str = None,
    web_eos_url: str = None,
    web_eos_path: str = None,
    model_name: str = None,
    experiment_name: str = None,
    run_name: str = None,
    experiment_id: str = None,
    run_id: str = None,
    creat_web_eos_dir: bool = False,  # NOTE(review): misspelling of "create"; kept for backward compatibility
    save_env_file: bool = False,
    auto_configure: bool = False
):
    """Configure MLflow tracking for a training run.

    Every argument left as ``None`` is resolved from an environment variable
    (with a hard-coded fallback), and each resolved value is written back into
    ``os.environ`` so later pipeline steps and child processes see the same
    configuration. In GitLab CI (``CI=true``) several defaults derive from CI
    pipeline variables.

    Raises ``ValueError`` when an explicitly provided ``experiment_id`` or
    ``run_id`` is inconsistent with the resolved ``experiment_name``.

    Returns a dict with the resolved experiment/run identifiers and (if
    ``creat_web_eos_dir`` is true) the WebEOS directory/URL, otherwise the
    WebEOS entries are ``None``.
    """

    # Set the MLflow tracking URI
    if mlflow_uri is None:
        mlflow_uri = os.getenv('MLFLOW_URI', 'https://ngt.cern.ch/models')
    mlflow.set_tracking_uri(mlflow_uri)
    os.environ["MLFLOW_URI"] = mlflow_uri
    logger.info(f"Using MLflow tracking URI: {mlflow_uri}")

    # Set the model name
    # Precedence: explicit arg > MLFLOW_MODEL_NAME > CI branch name > fallback.
    if model_name is None:
        if os.getenv('MLFLOW_MODEL_NAME'):
            model_name = os.getenv('MLFLOW_MODEL_NAME')
        else:
            model_name = os.getenv('CI_COMMIT_BRANCH', 'Test-Model')
    os.environ["MLFLOW_MODEL_NAME"] = model_name
    logger.info(f"Using model name: {model_name}")


    # Set the experiment name
    # Same precedence scheme as the model name.
    if experiment_name is None:
        if os.getenv('MLFLOW_EXPERIMENT_NAME'):
            experiment_name = os.getenv('MLFLOW_EXPERIMENT_NAME')
        else:
            experiment_name = os.getenv('CI_COMMIT_BRANCH', 'Test-Training-Torso')
    os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_name
    logger.info(f"Using experiment name: {experiment_name}")


    # Set the run name
    # In CI the run name encodes the pipeline id(s); locally it is a timestamp.
    if run_name is None:
        if os.getenv('CI') == 'true':
            if os.getenv('CI_PARENT_PIPELINE_ID'):
                run_name = f"{os.getenv('CI_PARENT_PIPELINE_ID')}-{os.getenv('CI_PIPELINE_ID')}"
            else:
                run_name = f"{os.getenv('CI_PIPELINE_ID')}"
        else:
            run_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    os.environ["MLFLOW_RUN_NAME"] = run_name
    logger.info(f"Using run name: {run_name}")


    # Create a new experiment or get the existing one
    if experiment_id is None:
        if os.getenv("MLFLOW_EXPERIMENT_ID"):
            experiment_id = os.getenv("MLFLOW_EXPERIMENT_ID")
        else:
            try:
                # create_experiment raises if the name already exists;
                # fall back to looking the existing experiment up.
                experiment_id = mlflow.create_experiment(experiment_name)
            except mlflow.exceptions.MlflowException:
                experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id

    # Guard against an id/name mismatch when both were supplied independently.
    check_experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
    if str(check_experiment_id) != str(experiment_id):
        raise ValueError(f"Provided experiment_id {experiment_id} does not match the ID of experiment_name {experiment_name} ({check_experiment_id})")

    # if mlflow.get_experiment_by_name(experiment_name).experiment_id is None:
    #     experiment_id = mlflow.create_experiment(experiment_name)
    # else:
    #     experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id

    mlflow.set_experiment(experiment_id=experiment_id)
    os.environ["MLFLOW_EXPERIMENT_ID"] = experiment_id
    logger.info(f"Using experiment ID: {experiment_id}")


    # Start a new MLflow run
    # A fresh run is opened (and immediately closed) only to obtain its id;
    # later steps re-attach via MLFLOW_RUN_ID.
    if run_id is None:
        if os.getenv("MLFLOW_RUN_ID"):
            run_id = os.getenv("MLFLOW_RUN_ID")
        else:
            with mlflow.start_run(experiment_id=experiment_id, run_name=run_name) as run:
                run_id = run.info.run_id

    # Guard against a run that belongs to a different experiment.
    check_run_info = mlflow.get_run(run_id)
    if str(check_run_info.info.experiment_id) != str(experiment_id):
        raise ValueError(f"Provided run_id {run_id} does not belong to experiment_id {experiment_id} (found {check_run_info.info.experiment_id})")

    os.environ["MLFLOW_RUN_ID"] = run_id
    logger.info(f"Started run with ID: {run_id}")


    if creat_web_eos_dir:
        # Set the web_eos_url
        if web_eos_url is None:
            web_eos_url = os.getenv('WEB_EOS_URL', 'https://ngt-modeltraining.web.cern.ch/')
        os.environ["WEB_EOS_URL"] = web_eos_url
        logger.info(f"Using WEB_EOS_URL: {web_eos_url}")

        # Set the web_eos_path
        if web_eos_path is None:
            web_eos_path = os.getenv('WEB_EOS_PATH', '/eos/user/m/mlflowngt/backend/www')
        os.environ["WEB_EOS_PATH"] = web_eos_path
        logger.info(f"Using WEB_EOS_PATH: {web_eos_path}")

        # Create WebEOS experiment dir
        # NOTE(review): os.path.join on a URL relies on web_eos_url lacking a
        # scheme-relative part; presumably fine for the defaults above.
        web_eos_experiment_dir = os.path.join(web_eos_path, experiment_name, run_name)
        web_eos_experiment_url = os.path.join(web_eos_url, experiment_name, run_name)
        os.makedirs(web_eos_experiment_dir, exist_ok=True)
        logger.info(f"Created WebEOS experiment directory: {web_eos_experiment_dir}")
        logger.info(f"Using WebEOS experiment URL: {web_eos_experiment_url}")

    else:
        web_eos_url=None
        web_eos_path=None
        web_eos_experiment_dir=None
        web_eos_experiment_url=None


    # Save environment variables to a file for later steps in CI/CD pipelines
    # (appended, so repeated calls accumulate entries in the env file).
    if save_env_file and os.getenv("CI") == "true":
        logger.info(f"Saving MLflow environment variables to {os.getenv('CI_ENV_FILE', 'mlflow.env')}")
        with open(os.getenv('CI_ENV_FILE', 'mlflow.env'), 'a') as f:
            f.write(f"MLFLOW_URI={mlflow_uri}\n")
            f.write(f"MLFLOW_MODEL_NAME={model_name}\n")
            f.write(f"MLFLOW_EXPERIMENT_NAME={experiment_name}\n")
            f.write(f"MLFLOW_RUN_NAME={run_name}\n")
            f.write(f"MLFLOW_EXPERIMENT_ID={experiment_id}\n")
            f.write(f"MLFLOW_RUN_ID={run_id}\n")

            if creat_web_eos_dir:
                f.write(f"WEB_EOS_URL={web_eos_url}\n")
                f.write(f"WEB_EOS_PATH={web_eos_path}\n")
                f.write(f"WEB_EOS_EXPERIMENT_DIR={web_eos_experiment_dir}\n")
                f.write(f"WEB_EOS_EXPERIMENT_URL={web_eos_experiment_url}\n")

            if auto_configure:
                logger.info("Auto_configure is set to true. Exporting AUTO_CONFIGURE=true")
                f.write("AUTO_CONFIGURE=true\n")

    return {
        "experiment_name": experiment_name,
        "run_name": run_name,
        "experiment_id": experiment_id,
        "run_id": run_id,
        "mlflow_uri": mlflow_uri,
        "model_name": model_name,
        "web_eos_url": web_eos_url,
        "web_eos_path": web_eos_path,
        "web_eos_experiment_dir": web_eos_experiment_dir,
        "web_eos_experiment_url": web_eos_experiment_url,
    }
172
+
173
# Opt-in auto-configuration: importing this module with AUTO_CONFIGURE=true
# immediately sets up (or re-attaches to) an MLflow run using env defaults.
if os.getenv("AUTO_CONFIGURE") == "true":
    # Message corrected: the condition only checks AUTO_CONFIGURE, not CI.
    logger.info("AUTO_CONFIGURE is true. Setting up mlflow...")
    setup_mlflow()
else:
    logger.info("AUTO_CONFIGURE is not set to 'true'. Skipping mlflow run setup")
178
+
179
class MLflowWrapper(mlflow.pyfunc.PythonModel):
    """PyFunc wrapper for TriggerModel; backend can be set at runtime."""

    def load_context(self, context):
        # Restore the archived TriggerModel; default to the pure-software path.
        archive = Path(context.artifacts["triggerflow"])
        self.model = TriggerModel.load(archive)
        self.backend = "software"

    def predict(self, context, model_input):
        # Dispatch on the currently selected backend; guard clauses keep the
        # unavailable-backend errors close to their checks.
        backend = self.backend
        if backend == "software":
            return self.model.software_predict(model_input)
        if backend == "qonnx":
            if self.model.model_qonnx is None:
                raise RuntimeError("QONNX model not available.")
            return self.model.qonnx_predict(model_input)
        if backend == "firmware":
            if self.model.firmware_model is None:
                raise RuntimeError("Firmware model not available.")
            return self.model.firmware_predict(model_input)
        raise ValueError(f"Unsupported backend: {self.backend}")

    def get_model_info(self):
        # Delegate when the wrapped model exposes metadata.
        if hasattr(self.model, "get_model_info"):
            return self.model.get_model_info()
        return {"error": "Model info not available"}
204
+
205
+
206
def _get_pip_requirements(triggerflow: TriggerModel) -> list:
    """Derive the pip requirements for serving this TriggerModel.

    Always includes numpy; adds the ML backend, the firmware compiler,
    and qonnx when a QONNX export is attached.
    """
    requirements = ["numpy"]

    backend_pkgs = {
        "keras": ["tensorflow", "keras"],
        "xgboost": ["xgboost"],
    }
    requirements += backend_pkgs.get(triggerflow.ml_backend, [])

    compiler_pkgs = {
        "hls4ml": ["hls4ml"],
        "conifer": ["conifer"],
    }
    requirements += compiler_pkgs.get(triggerflow.compiler, [])

    if getattr(triggerflow, "model_qonnx", None) is not None:
        requirements.append("qonnx")
    return requirements
219
+
220
+
221
def log_model(triggerflow: TriggerModel, registered_model_name: str, artifact_path: str = "TriggerModel"):
    """Log a TriggerModel as a PyFunc model and register it in the Model Registry."""
    # Resolve the registry name: explicit argument wins, env var is the fallback.
    if not registered_model_name:
        registered_model_name = os.getenv("MLFLOW_MODEL_NAME")
        if not registered_model_name:
            raise ValueError("registered_model_name must be provided and non-empty")

    run = mlflow.active_run()
    if run is None:
        raise RuntimeError("No active MLflow run. Start a run before logging.")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Serialize the full model into a temporary archive and log it as the
        # pyfunc artifact bundle.
        archive_path = Path(tmpdir) / "triggermodel.tar.xz"
        triggerflow.save(archive_path)

        mlflow.pyfunc.log_model(
            artifact_path=artifact_path,
            python_model=MLflowWrapper(),
            artifacts={"triggerflow": str(archive_path)},
            pip_requirements=_get_pip_requirements(triggerflow)
        )

        # register model (always required)
        client = MlflowClient()
        model_uri = f"runs:/{run.info.run_id}/{artifact_path}"
        try:
            client.get_registered_model(registered_model_name)
        except mlflow.exceptions.RestException:
            # First version for this name: create the registered model entry.
            client.create_registered_model(registered_model_name)
        client.create_model_version(
            name=registered_model_name,
            source=model_uri,
            run_id=run.info.run_id
        )
256
+
257
def load_model(model_uri: str) -> mlflow.pyfunc.PyFuncModel:
    """Load the logged model as a generic MLflow PyFunc model."""
    return mlflow.pyfunc.load_model(model_uri)
259
+
260
+
261
def load_full_model(model_uri: str) -> TriggerModel:
    """Download the logged artifacts and reconstruct the full TriggerModel.

    Unlike :func:`load_model`, this returns the native TriggerModel object
    (not the PyFunc wrapper), restored from the archived tarball.
    """
    artifacts_root = Path(mlflow.artifacts.download_artifacts(model_uri))
    return TriggerModel.load(artifacts_root / "artifacts" / "triggermodel.tar.xz")
265
+
266
def get_model_info(model_uri: str) -> dict[str, Any]:
    """Load a logged model and return its metadata dict, if it exposes one."""
    # _model_impl is the wrapped MLflowWrapper instance behind the pyfunc model.
    impl = mlflow.pyfunc.load_model(model_uri)._model_impl
    if hasattr(impl, "get_model_info"):
        return impl.get_model_info()
    return {"error": "Model info not available"}
@@ -0,0 +1,143 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # pyenv
82
+ # For a library or package, you might want to ignore these files since the code is
83
+ # intended to run in multiple environments; otherwise, check them in:
84
+ # .python-version
85
+
86
+ # pipenv
87
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
89
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
90
+ # install all needed dependencies.
91
+ #Pipfile.lock
92
+
93
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94
+ __pypackages__/
95
+
96
+ # Celery stuff
97
+ celerybeat-schedule
98
+ celerybeat.pid
99
+
100
+ # SageMath parsed files
101
+ *.sage.py
102
+
103
+ # Environments
104
+ .env
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ # pytype static type analyzer
131
+ .pytype/
132
+
133
+ # Cython debug symbols
134
+ cython_debug/
135
+
136
+ .vscode/
137
+ info.log
138
+
139
+ # IntelliJ
140
+ .idea/
141
+ *.iml
142
+ out/
143
+ .idea_modules/
File without changes
@@ -0,0 +1,5 @@
1
+ {
2
+ "project_name": "triggerflow-pipeline",
3
+ "repo_name": "{{ cookiecutter.project_name.strip().replace(' ', '-').replace('_', '-').lower() }}",
4
+ "python_package": "{{ cookiecutter.project_name.strip().replace(' ', '_').replace('-', '_').lower() }}"
5
+ }
@@ -0,0 +1,9 @@
1
+ project_name:
2
+ title: "Project Name"
3
+ text: |
4
+ Please enter a human readable name for your new project.
5
+ Spaces, hyphens, and underscores are allowed.
6
+ regex_validator: "^[\\w -]{2,}$"
7
+ error_message: |
8
+ It must contain only alphanumeric symbols, spaces, underscores and hyphens and
9
+ be at least 2 characters long.
@@ -0,0 +1,3 @@
1
+ # Add patterns of files dvc should ignore, which could improve
2
+ # the performance. Learn more at
3
+ # https://dvc.org/doc/user-guide/dvcignore