triggerflow 0.1.12__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +95 -0
  6. trigger_loader/processor.py +211 -0
  7. triggerflow/cli.py +122 -0
  8. triggerflow/core.py +29 -9
  9. triggerflow/mlflow_wrapper.py +54 -49
  10. triggerflow/starter/.gitignore +143 -0
  11. triggerflow/starter/README.md +0 -0
  12. triggerflow/starter/cookiecutter.json +5 -0
  13. triggerflow/starter/prompts.yml +9 -0
  14. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  15. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  16. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +84 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +15 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples_dummy.json +26 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +21 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +88 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +50 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +10 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +40 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +12 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +31 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +29 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  88. triggerflow/templates/makefile +3 -3
  89. triggerflow/templates/makefile_version +2 -2
  90. triggerflow/templates/model_template.cpp +19 -18
  91. triggerflow/templates/scales.h +1 -1
  92. triggerflow-0.2.1.dist-info/METADATA +97 -0
  93. triggerflow-0.2.1.dist-info/RECORD +97 -0
  94. triggerflow-0.2.1.dist-info/entry_points.txt +2 -0
  95. triggerflow-0.2.1.dist-info/top_level.txt +3 -0
  96. triggerflow-0.1.12.dist-info/METADATA +0 -61
  97. triggerflow-0.1.12.dist-info/RECORD +0 -11
  98. triggerflow-0.1.12.dist-info/top_level.txt +0 -1
  99. {triggerflow-0.1.12.dist-info → triggerflow-0.2.1.dist-info}/WHEEL +0 -0
triggerflow/core.py CHANGED
@@ -111,6 +111,7 @@ class HLS4MLStrategy(CompilerStrategy):
111
111
  firmware_model.compile()
112
112
  return firmware_model
113
113
 
114
+
114
115
  class ConiferStrategy(CompilerStrategy):
115
116
  """Conifer compilation strategy for XGBoost models"""
116
117
 
@@ -118,21 +119,32 @@ class ConiferStrategy(CompilerStrategy):
118
119
  import conifer
119
120
  import shutil
120
121
  import warnings
122
+ import os
121
123
 
122
124
  firmware_dir = workspace / "firmware"
123
125
  firmware_dir.mkdir(exist_ok=True)
126
+ os.environ['JSON_ROOT'] = '/eos/user/m/maglowac/TriggerModel/json'
127
+ os.environ['XILINX_AP_INCLUDE'] = '/eos/user/m/maglowac/TriggerModel/HLS_arbitrary_Precision_Types/include'
128
+
124
129
 
125
- cfg = config or conifer.backends.xilinxhls.auto_config()
130
+ cfg = conifer.backends.xilinxhls.auto_config()#config or conifer.backends.cpp.auto_config()
126
131
  cfg['OutputDir'] = str(firmware_dir)
127
132
 
128
133
  for key, value in kwargs.items():
129
134
  cfg[key] = value
130
135
 
136
+ print(cfg)
131
137
  firmware_model = conifer.converters.convert_from_xgboost(
132
138
  model,
133
139
  config=cfg
134
140
  )
135
141
 
142
+ firmware_model.write()
143
+ proj_name = cfg.get('ProjectName', 'my_prj')
144
+ bridge_file = firmware_dir / "bridge.cpp"
145
+ text = bridge_file.read_text()
146
+ text = text.replace("my_prj.h", f"{proj_name}.h")
147
+ bridge_file.write_text(text)
136
148
  firmware_model.compile()
137
149
  if shutil.which("vivado") is not None:
138
150
  firmware_model.build()
@@ -141,6 +153,14 @@ class ConiferStrategy(CompilerStrategy):
141
153
 
142
154
  firmware_model.save(firmware_dir / "firmware_model.fml")
143
155
  return firmware_model
156
+
157
+ def load_compiled_model(self, workspace: Path) -> Any:
158
+ from conifer import load_model
159
+
160
+ firmware_model = load_model(workspace / "firmware_model.fml")
161
+ firmware_model.compile()
162
+ return firmware_model
163
+
144
164
 
145
165
  class DA4MLStrategy(CompilerStrategy):
146
166
  """DA4ML compilation strategy (placeholder)"""
@@ -368,7 +388,7 @@ class TriggerModel:
368
388
  """Main facade class that orchestrates model conversion, compilation, and inference"""
369
389
 
370
390
  def __init__(self, name: str, ml_backend: str, n_outputs:int, compiler: str,
371
- native_model: object, compiler_config: dict = None, scales: dict = None):
391
+ native_model: object, compiler_config: dict = None, scales: dict = None, unscaled_type: str = "ap_fixed<16,6>"):
372
392
 
373
393
  if ml_backend.lower() not in ("keras", "xgboost"):
374
394
  raise ValueError("Only Keras or XGBoost backends are currently supported.")
@@ -376,6 +396,7 @@ class TriggerModel:
376
396
  self.name = name
377
397
  self.ml_backend = ml_backend.lower()
378
398
  self.scales = scales
399
+ self.unscaled_type = unscaled_type
379
400
  self.n_outputs = n_outputs
380
401
  self.compiler = compiler.lower()
381
402
  self.native_model = native_model
@@ -436,7 +457,7 @@ class TriggerModel:
436
457
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
437
458
 
438
459
  if self.compiler is not "conifer" and self.scales is not None:
439
- self.build_emulator(self.scales['shifts'], self.scales['offsets'], self.n_outputs)
460
+ self.build_emulator(self.scales['shifts'], self.scales['offsets'], self.n_outputs, self.unscaled_type)
440
461
 
441
462
  self.workspace_manager.add_artifact("firmware", self.workspace_manager.workspace / "firmware")
442
463
  self.workspace_manager.save_metadata()
@@ -473,11 +494,9 @@ class TriggerModel:
473
494
  predictor = FirmwarePredictor(self.firmware_model, self.compiler)
474
495
  return predictor.predict(input_data)
475
496
 
476
- def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int):
477
- """
478
- Create an emulator directory for this model.
479
- Copies HLS sources and generates emulator scaffolding.
480
- """
497
+ def build_emulator(self, ad_shift: list, ad_offsets: list, n_outputs: int, unscaled_type: str = "ap_fixed<16,6>"):
498
+ """Builds CMSSW emulator"""
499
+
481
500
  emulator_dir = self.workspace / "emulator"
482
501
  emulator_dir.mkdir(exist_ok=True)
483
502
 
@@ -497,12 +516,13 @@ class TriggerModel:
497
516
  "N_OUTPUTS": n_outputs,
498
517
  "AD_SHIFT": ", ".join(map(str, ad_shift)),
499
518
  "AD_OFFSETS": ", ".join(map(str, ad_offsets)),
519
+ "UNSCALED_TYPE": unscaled_type,
500
520
  }
501
521
  self._render_template(scales_template_path, scales_out_path, context)
502
522
 
503
523
  with pkg_resources.path(triggerflow.templates, "model_template.cpp") as emulator_template_path:
504
524
  emulator_out_path = model_dir / "emulator.cpp"
505
- self._render_template(emulator_template_path, emulator_out_path, context)
525
+ self._render_template(emulator_template_path, emulator_out_path, context)
506
526
 
507
527
  with pkg_resources.path(triggerflow.templates, "makefile_version") as makefile_template_path:
508
528
  makefile_out_path = model_dir / "Makefile"
@@ -1,34 +1,40 @@
1
1
  # trigger_mlflow.py
2
- import mlflow
2
+ import datetime
3
+ import logging
3
4
  import os
4
- import mlflow.pyfunc
5
5
  import tempfile
6
6
  from pathlib import Path
7
- from typing import Dict, Any
7
+ from typing import Any
8
+
9
+ import mlflow
10
+ import mlflow.pyfunc
8
11
  from mlflow.tracking import MlflowClient
12
+
9
13
  from .core import TriggerModel
10
14
 
15
+ logger = logging.getLogger(__name__)
11
16
 
12
- def setup_mlflow(mlflow_uri: str = None,
13
- web_eos_url: str = None,
14
- web_eos_path: str = None,
15
- model_name: str = None,
16
- experiment_name: str = None,
17
- run_name: str = None,
18
- experiment_id: str = None,
19
- run_id: str = None,
20
- creat_web_eos_dir: bool = False,
21
- save_env_file: bool = False,
22
- auto_configure: bool = False
23
- ):
17
+
18
+ def setup_mlflow(
19
+ mlflow_uri: str = None,
20
+ web_eos_url: str = None,
21
+ web_eos_path: str = None,
22
+ model_name: str = None,
23
+ experiment_name: str = None,
24
+ run_name: str = None,
25
+ experiment_id: str = None,
26
+ run_id: str = None,
27
+ creat_web_eos_dir: bool = False,
28
+ save_env_file: bool = False,
29
+ auto_configure: bool = False
30
+ ):
24
31
 
25
32
  # Set the MLflow tracking URI
26
33
  if mlflow_uri is None:
27
34
  mlflow_uri = os.getenv('MLFLOW_URI', 'https://ngt.cern.ch/models')
28
35
  mlflow.set_tracking_uri(mlflow_uri)
29
36
  os.environ["MLFLOW_URI"] = mlflow_uri
30
- print(f"Using MLflow tracking URI: {mlflow_uri}")
31
-
37
+ logger.info(f"Using MLflow tracking URI: {mlflow_uri}")
32
38
 
33
39
  # Set the model name
34
40
  if model_name is None:
@@ -37,7 +43,7 @@ def setup_mlflow(mlflow_uri: str = None,
37
43
  else:
38
44
  model_name = os.getenv('CI_COMMIT_BRANCH', 'Test-Model')
39
45
  os.environ["MLFLOW_MODEL_NAME"] = model_name
40
- print(f"Using model name: {model_name}")
46
+ logger.info(f"Using model name: {model_name}")
41
47
 
42
48
 
43
49
  # Set the experiment name
@@ -47,7 +53,7 @@ def setup_mlflow(mlflow_uri: str = None,
47
53
  else:
48
54
  experiment_name = os.getenv('CI_COMMIT_BRANCH', 'Test-Training-Torso')
49
55
  os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_name
50
- print(f"Using experiment name: {experiment_name}")
56
+ logger.info(f"Using experiment name: {experiment_name}")
51
57
 
52
58
 
53
59
  # Set the run name
@@ -58,10 +64,9 @@ def setup_mlflow(mlflow_uri: str = None,
58
64
  else:
59
65
  run_name = f"{os.getenv('CI_PIPELINE_ID')}"
60
66
  else:
61
- import datetime
62
67
  run_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
63
68
  os.environ["MLFLOW_RUN_NAME"] = run_name
64
- print(f"Using run name: {run_name}")
69
+ logger.info(f"Using run name: {run_name}")
65
70
 
66
71
 
67
72
  # Create a new experiment or get the existing one
@@ -73,7 +78,7 @@ def setup_mlflow(mlflow_uri: str = None,
73
78
  experiment_id = mlflow.create_experiment(experiment_name)
74
79
  except mlflow.exceptions.MlflowException:
75
80
  experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
76
-
81
+
77
82
  check_experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
78
83
  if str(check_experiment_id) != str(experiment_id):
79
84
  raise ValueError(f"Provided experiment_id {experiment_id} does not match the ID of experiment_name {experiment_name} ({check_experiment_id})")
@@ -85,7 +90,7 @@ def setup_mlflow(mlflow_uri: str = None,
85
90
 
86
91
  mlflow.set_experiment(experiment_id=experiment_id)
87
92
  os.environ["MLFLOW_EXPERIMENT_ID"] = experiment_id
88
- print(f"Using experiment ID: {experiment_id}")
93
+ logger.info(f"Using experiment ID: {experiment_id}")
89
94
 
90
95
 
91
96
  # Start a new MLflow run
@@ -99,9 +104,9 @@ def setup_mlflow(mlflow_uri: str = None,
99
104
  check_run_info = mlflow.get_run(run_id)
100
105
  if str(check_run_info.info.experiment_id) != str(experiment_id):
101
106
  raise ValueError(f"Provided run_id {run_id} does not belong to experiment_id {experiment_id} (found {check_run_info.info.experiment_id})")
102
-
107
+
103
108
  os.environ["MLFLOW_RUN_ID"] = run_id
104
- print(f"Started run with ID: {run_id}")
109
+ logger.info(f"Started run with ID: {run_id}")
105
110
 
106
111
 
107
112
  if creat_web_eos_dir:
@@ -109,21 +114,21 @@ def setup_mlflow(mlflow_uri: str = None,
109
114
  if web_eos_url is None:
110
115
  web_eos_url = os.getenv('WEB_EOS_URL', 'https://ngt-modeltraining.web.cern.ch/')
111
116
  os.environ["WEB_EOS_URL"] = web_eos_url
112
- print(f"Using WEB_EOS_URL: {web_eos_url}")
117
+ logger.info(f"Using WEB_EOS_URL: {web_eos_url}")
113
118
 
114
119
  # Set the web_eos_path
115
120
  if web_eos_path is None:
116
121
  web_eos_path = os.getenv('WEB_EOS_PATH', '/eos/user/m/mlflowngt/backend/www')
117
122
  os.environ["WEB_EOS_PATH"] = web_eos_path
118
- print(f"Using WEB_EOS_PATH: {web_eos_path}")
123
+ logger.info(f"Using WEB_EOS_PATH: {web_eos_path}")
119
124
 
120
125
  # Create WebEOS experiment dir
121
126
  web_eos_experiment_dir = os.path.join(web_eos_path, experiment_name, run_name)
122
127
  web_eos_experiment_url = os.path.join(web_eos_url, experiment_name, run_name)
123
128
  os.makedirs(web_eos_experiment_dir, exist_ok=True)
124
- print(f"Created WebEOS experiment directory: {web_eos_experiment_dir}")
125
- print(f"Using WebEOS experiment URL: {web_eos_experiment_url}")
126
-
129
+ logger.info(f"Created WebEOS experiment directory: {web_eos_experiment_dir}")
130
+ logger.info(f"Using WebEOS experiment URL: {web_eos_experiment_url}")
131
+
127
132
  else:
128
133
  web_eos_url=None
129
134
  web_eos_path=None
@@ -133,7 +138,7 @@ def setup_mlflow(mlflow_uri: str = None,
133
138
 
134
139
  # Save environment variables to a file for later steps in CI/CD pipelines
135
140
  if save_env_file and os.getenv("CI") == "true":
136
- print(f"Saving MLflow environment variables to {os.getenv('CI_ENV_FILE', 'mlflow.env')}")
141
+ logger.info(f"Saving MLflow environment variables to {os.getenv('CI_ENV_FILE', 'mlflow.env')}")
137
142
  with open(os.getenv('CI_ENV_FILE', 'mlflow.env'), 'a') as f:
138
143
  f.write(f"MLFLOW_URI={mlflow_uri}\n")
139
144
  f.write(f"MLFLOW_MODEL_NAME={model_name}\n")
@@ -149,8 +154,8 @@ def setup_mlflow(mlflow_uri: str = None,
149
154
  f.write(f"WEB_EOS_EXPERIMENT_URL={web_eos_experiment_url}\n")
150
155
 
151
156
  if auto_configure:
152
- print("Auto_configure is set to true. Exporting AUTO_CONFIGURE=true")
153
- f.write(f"AUTO_CONFIGURE=true\n")
157
+ logger.info("Auto_configure is set to true. Exporting AUTO_CONFIGURE=true")
158
+ f.write("AUTO_CONFIGURE=true\n")
154
159
 
155
160
  return {
156
161
  "experiment_name": experiment_name,
@@ -166,17 +171,17 @@ def setup_mlflow(mlflow_uri: str = None,
166
171
  }
167
172
 
168
173
  if os.getenv("AUTO_CONFIGURE") == "true":
169
- print("AUTO_CONFIGURE is true and running in CI environment. Setting up mlflow...")
174
+ logger.info("AUTO_CONFIGURE is true and running in CI environment. Setting up mlflow...")
170
175
  setup_mlflow()
171
176
  else:
172
- print("AUTO_CONFIGURE is not set. Skipping mlflow run setup")
177
+ logger.info("AUTO_CONFIGURE is not set. Skipping mlflow run setup")
173
178
 
174
179
  class MLflowWrapper(mlflow.pyfunc.PythonModel):
175
180
  """PyFunc wrapper for TriggerModel; backend can be set at runtime."""
176
181
  def load_context(self, context):
177
- archive_path = Path(context.artifacts["trigger_model"])
182
+ archive_path = Path(context.artifacts["triggerflow"])
178
183
  self.model = TriggerModel.load(archive_path)
179
- self.backend = "software"
184
+ self.backend = "software"
180
185
 
181
186
  def predict(self, context, model_input):
182
187
  if self.backend == "software":
@@ -198,22 +203,22 @@ class MLflowWrapper(mlflow.pyfunc.PythonModel):
198
203
  return {"error": "Model info not available"}
199
204
 
200
205
 
201
- def _get_pip_requirements(trigger_model: TriggerModel) -> list:
206
+ def _get_pip_requirements(triggerflow: TriggerModel) -> list:
202
207
  requirements = ["numpy"]
203
- if trigger_model.ml_backend == "keras":
208
+ if triggerflow.ml_backend == "keras":
204
209
  requirements.extend(["tensorflow", "keras"])
205
- elif trigger_model.ml_backend == "xgboost":
210
+ elif triggerflow.ml_backend == "xgboost":
206
211
  requirements.append("xgboost")
207
- if trigger_model.compiler == "hls4ml":
212
+ if triggerflow.compiler == "hls4ml":
208
213
  requirements.append("hls4ml")
209
- elif trigger_model.compiler == "conifer":
214
+ elif triggerflow.compiler == "conifer":
210
215
  requirements.append("conifer")
211
- if hasattr(trigger_model, "model_qonnx") and trigger_model.model_qonnx is not None:
216
+ if hasattr(triggerflow, "model_qonnx") and triggerflow.model_qonnx is not None:
212
217
  requirements.append("qonnx")
213
218
  return requirements
214
219
 
215
220
 
216
- def log_model(trigger_model: TriggerModel, registered_model_name: str = None, artifact_path: str = "TriggerModel"):
221
+ def log_model(triggerflow: TriggerModel, registered_model_name: str, artifact_path: str = "TriggerModel"):
217
222
  """Log a TriggerModel as a PyFunc model and register it in the Model Registry."""
218
223
  if not registered_model_name:
219
224
  if not os.getenv("MLFLOW_MODEL_NAME"):
@@ -227,13 +232,13 @@ def log_model(trigger_model: TriggerModel, registered_model_name: str = None, ar
227
232
  run = mlflow.active_run()
228
233
  with tempfile.TemporaryDirectory() as tmpdir:
229
234
  archive_path = Path(tmpdir) / "triggermodel.tar.xz"
230
- trigger_model.save(archive_path)
235
+ triggerflow.save(archive_path)
231
236
 
232
237
  mlflow.pyfunc.log_model(
233
238
  artifact_path=artifact_path,
234
239
  python_model=MLflowWrapper(),
235
- artifacts={"trigger_model": str(archive_path)},
236
- pip_requirements=_get_pip_requirements(trigger_model)
240
+ artifacts={"triggerflow": str(archive_path)},
241
+ pip_requirements=_get_pip_requirements(triggerflow)
237
242
  )
238
243
 
239
244
  # register model (always required)
@@ -255,11 +260,11 @@ def load_model(model_uri: str) -> mlflow.pyfunc.PyFuncModel:
255
260
 
256
261
  def load_full_model(model_uri: str) -> TriggerModel:
257
262
  local_path = mlflow.artifacts.download_artifacts(model_uri)
258
- archive_path = Path(local_path) / "trigger_model" / "triggermodel.tar.xz"
263
+ archive_path = Path(local_path) / "triggerflow" / "triggermodel.tar.xz"
259
264
  return TriggerModel.load(archive_path)
260
265
 
261
266
 
262
- def get_model_info(model_uri: str) -> Dict[str, Any]:
267
+ def get_model_info(model_uri: str) -> dict[str, Any]:
263
268
  model = mlflow.pyfunc.load_model(model_uri)
264
269
  if hasattr(model._model_impl, "get_model_info"):
265
270
  return model._model_impl.get_model_info()
@@ -0,0 +1,143 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # pyenv
82
+ # For a library or package, you might want to ignore these files since the code is
83
+ # intended to run in multiple environments; otherwise, check them in:
84
+ # .python-version
85
+
86
+ # pipenv
87
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
89
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
90
+ # install all needed dependencies.
91
+ #Pipfile.lock
92
+
93
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94
+ __pypackages__/
95
+
96
+ # Celery stuff
97
+ celerybeat-schedule
98
+ celerybeat.pid
99
+
100
+ # SageMath parsed files
101
+ *.sage.py
102
+
103
+ # Environments
104
+ .env
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ # pytype static type analyzer
131
+ .pytype/
132
+
133
+ # Cython debug symbols
134
+ cython_debug/
135
+
136
+ .vscode/
137
+ info.log
138
+
139
+ # IntelliJ
140
+ .idea/
141
+ *.iml
142
+ out/
143
+ .idea_modules/
File without changes
@@ -0,0 +1,5 @@
1
+ {
2
+ "project_name": "triggerflow-pipeline",
3
+ "repo_name": "{{ cookiecutter.project_name.strip().replace(' ', '-').replace('_', '-').lower() }}",
4
+ "python_package": "{{ cookiecutter.project_name.strip().replace(' ', '_').replace('-', '_').lower() }}"
5
+ }
@@ -0,0 +1,9 @@
1
+ project_name:
2
+ title: "Project Name"
3
+ text: |
4
+ Please enter a human readable name for your new project.
5
+ Spaces, hyphens, and underscores are allowed.
6
+ regex_validator: "^[\\w -]{2,}$"
7
+ error_message: |
8
+ It must contain only alphanumeric symbols, spaces, underscores and hyphens and
9
+ be at least 2 characters long.
@@ -0,0 +1,3 @@
1
+ # Add patterns of files dvc should ignore, which could improve
2
+ # the performance. Learn more at
3
+ # https://dvc.org/doc/user-guide/dvcignore
@@ -0,0 +1,143 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # pyenv
82
+ # For a library or package, you might want to ignore these files since the code is
83
+ # intended to run in multiple environments; otherwise, check them in:
84
+ # .python-version
85
+
86
+ # pipenv
87
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
89
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
90
+ # install all needed dependencies.
91
+ #Pipfile.lock
92
+
93
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94
+ __pypackages__/
95
+
96
+ # Celery stuff
97
+ celerybeat-schedule
98
+ celerybeat.pid
99
+
100
+ # SageMath parsed files
101
+ *.sage.py
102
+
103
+ # Environments
104
+ .env
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ # pytype static type analyzer
131
+ .pytype/
132
+
133
+ # Cython debug symbols
134
+ cython_debug/
135
+
136
+ .vscode/
137
+ info.log
138
+
139
+ # IntelliJ
140
+ .idea/
141
+ *.iml
142
+ out/
143
+ .idea_modules/