triggerflow-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +154 -0
  6. trigger_loader/processor.py +212 -0
  7. triggerflow/__init__.py +0 -0
  8. triggerflow/cli.py +122 -0
  9. triggerflow/core.py +617 -0
  10. triggerflow/interfaces/__init__.py +0 -0
  11. triggerflow/interfaces/uGT.py +187 -0
  12. triggerflow/mlflow_wrapper.py +270 -0
  13. triggerflow/starter/.gitignore +143 -0
  14. triggerflow/starter/README.md +0 -0
  15. triggerflow/starter/cookiecutter.json +5 -0
  16. triggerflow/starter/prompts.yml +9 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +90 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/condor_config.json +11 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/cuda_config.json +4 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +24 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/settings.json +8 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/test.root +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +23 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_loader.py +101 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +49 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_loader.py +32 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +70 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +20 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +41 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +13 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +48 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +31 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  88. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  89. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  90. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  91. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  92. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  93. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  94. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  95. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  96. triggerflow/templates/build_ugt.tcl +46 -0
  97. triggerflow/templates/data_types.h +524 -0
  98. triggerflow/templates/makefile +28 -0
  99. triggerflow/templates/makefile_version +15 -0
  100. triggerflow/templates/model-gt.cpp +104 -0
  101. triggerflow/templates/model_template.cpp +63 -0
  102. triggerflow/templates/scales.h +20 -0
  103. triggerflow-0.3.4.dist-info/METADATA +206 -0
  104. triggerflow-0.3.4.dist-info/RECORD +107 -0
  105. triggerflow-0.3.4.dist-info/WHEEL +5 -0
  106. triggerflow-0.3.4.dist-info/entry_points.txt +2 -0
  107. triggerflow-0.3.4.dist-info/top_level.txt +3 -0

triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml
@@ -0,0 +1,23 @@
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.11
+   - pip
+   - pip:
+       - kedro
+       - kedro-viz
+       - kedro-datasets
+       - matplotlib
+       - mplhep
+       - shap
+       - scikit-learn
+       - pandas
+       - dvc
+       - shap
+       - ruff
+       - uproot
+       - awkward
+       - triggerflow
+       - coffea
+       - dask[distributed]

triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml
@@ -0,0 +1,50 @@
+ [build-system]
+ requires = [ "setuptools",]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ requires-python = ">=3.10"
+ name = "{{ cookiecutter.python_package }}"
+ readme = "README.md"
+ dynamic = [ "version",]
+ dependencies = [ "ipython>=8.10", "jupyterlab>=3.0", "notebook", "kedro~=1.0.0",]
+
+ [project.scripts]
+ {{ cookiecutter.project_name }} = "{{ cookiecutter.python_package }}.__main__:main"
+
+ [project.optional-dependencies]
+ dev = [ "pytest-cov~=3.0", "pytest-mock>=1.7.1, <2.0", "pytest~=7.2", "ruff~=0.1.8",]
+
+ [tool.kedro]
+ package_name = "{{ cookiecutter.python_package }}"
+ project_name = "{{ cookiecutter.project_name }}"
+ kedro_init_version = "1.0.0"
+ tools = "['Linting', 'Testing', 'Custom Logging', 'Data Structure']"
+ example_pipeline = "False"
+ source_dir = "src"
+
+ [tool.ruff]
+ line-length = 88
+ show-fixes = true
+ select = [ "F", "W", "E", "I", "UP", "PL", "T201",]
+ ignore = [ "E501",]
+
+ [project.entry-points."kedro.hooks"]
+
+ [tool.pytest.ini_options]
+ addopts = "--cov-report term-missing --cov src/{{ cookiecutter.python_package }} -ra"
+
+ [tool.coverage.report]
+ fail_under = 0
+ show_missing = true
+ exclude_lines = [ "pragma: no cover", "raise NotImplementedError",]
+
+ [tool.ruff.format]
+ docstring-code-format = true
+
+ [tool.setuptools.dynamic.version]
+ attr = "{{ cookiecutter.python_package }}.__version__"
+
+ [tool.setuptools.packages.find]
+ where = [ "src",]
+ namespaces = false

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py
@@ -0,0 +1,3 @@
+ """{{ cookiecutter.project_name }}"""
+
+ __version__ = "0.1"

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py
@@ -0,0 +1,25 @@
+ """{{ cookiecutter.project_name }} file for ensuring the package is executable
+ as `{{ cookiecutter.project_name }}` and `python -m {{ cookiecutter.python_package }}`
+ """
+
+ import sys
+ from pathlib import Path
+ from typing import Any
+
+ from kedro.framework.cli.utils import find_run_command
+ from kedro.framework.project import configure_project
+
+
+ def main(*args, **kwargs) -> Any:
+     package_name = Path(__file__).parent.name
+     configure_project(package_name)
+
+     interactive = hasattr(sys, "ps1")
+     kwargs["standalone_mode"] = not interactive
+
+     run = find_run_command(package_name)
+     return run(*args, **kwargs)
+
+
+ if __name__ == "__main__":
+     main()

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py
@@ -0,0 +1,20 @@
+ from kedro.io import AbstractDataset
+ from typing import Any
+
+
+ class AnyObject(AbstractDataset):
+     """
+     Abstract class which can be used for passing "Any" object
+     """
+
+     def __init__(self):
+         pass
+
+     def _load(self) -> None:
+         pass
+
+     def _save(self, data: Any) -> Any:
+         return data
+
+     def _describe(self) -> dict:
+         return {}
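
`AnyObject` is a pass-through: `_save` hands the object back unchanged and `_load` returns nothing, so it can shuttle arbitrary in-memory objects between nodes without persisting them. A minimal hedged sketch of that behaviour through a standard Kedro `DataCatalog` (the entry name is illustrative, not part of the package):

```python
# Hedged sketch: exercising AnyObject directly through a Kedro DataCatalog.
# "payload" is an arbitrary entry name chosen for illustration.
from kedro.io import DataCatalog

catalog = DataCatalog({"payload": AnyObject()})
catalog.save("payload", {"anything": 42})  # _save just returns the data, nothing is written
print(catalog.load("payload"))             # _load returns None
```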

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py
@@ -0,0 +1,137 @@
+ import logging, uproot, json, os
+ import pandas as pd
+ import numpy as np
+ from abc import abstractmethod
+ from fnmatch import filter as fnmatch_filter
+ from kedro.io import AbstractDataset
+
+
+ class BaseDataset(AbstractDataset):
+     """
+     Abstract Base Class for loading data from ROOT files.
+
+     Users must inherit from this class and implement the abstract methods.
+     The core processing logic in `_load` is fixed and cannot be overridden.
+     """
+
+     def __init__(self, sample_info: str, sample_key: str):
+         with open(sample_info, "r") as f:
+             data = json.load(f)
+         self._sample_info = data[sample_key]
+         self._sample_key = sample_key
+
+         # get logger for reporting
+         self.logger = logging.getLogger(__name__)
+         self.logger.info(f"Initializing dataset: {self.__class__.__name__}")
+
+     @abstractmethod
+     def get_branches_to_keep(self) -> list[str]:
+         """
+         USER MUST IMPLEMENT: Return a list of branch names or patterns (with wildcards)
+         to keep from the ROOT file.
+
+         Example:
+             return ["Jet_*", "PuppiMET_pt", "nJet"]
+         """
+         pass
+
+     @abstractmethod
+     def get_cut(self) -> str | None:
+         """
+         USER MUST IMPLEMENT: Return a string representing the cuts to apply to the data.
+         """
+         pass
+
+     @abstractmethod
+     def convert_to_pandas(self, data: dict) -> pd.DataFrame:
+         """
+         USER MUST IMPLEMENT: Convert the loaded data from a dictionary format to a pandas DataFrame.
+         """
+         pass
+
+     def get_tree_name(self) -> str:
+         return "Events"
+
+     def _resolve_branches(self, all_branches: list) -> list[str]:
+         """Internal method to resolve wildcard patterns."""
+         selected = []
+         for pattern in self.get_branches_to_keep():
+             matched = fnmatch_filter(all_branches, pattern)
+             if not matched:
+                 self.logger.warning(f"Pattern '{pattern}' did not match any branches.")
+             selected.extend(matched)
+         return sorted(list(set(selected)))
+
+     def _load(self) -> pd.DataFrame:
+         """
+         CORE LOGIC (NOT OVERRIDABLE): Loads and processes a single ROOT file.
+         """
+
+         # Process all files in sample
+         df = pd.DataFrame()
+
+         all_root_files = []
+         for key in self._sample_info.keys():
+             files = os.listdir(self._sample_info[key]["folder"])
+             cur_files = []
+             for file_pattern in self._sample_info[key]["file_pattern"]:
+                 for f in fnmatch_filter(files, file_pattern):
+                     cur_files.append(os.path.join(self._sample_info[key]["folder"], f))
+             all_root_files.append(cur_files)
+
+         is_signals = [
+             self._sample_info[key]["is_signal"] for key in self._sample_info.keys()
+         ]
+         self.logger.info("Processing files")
+         for root_files, is_signal in zip(all_root_files, is_signals):
+             self.logger.info(f"Processing files: {root_files}")
+             for root_file in root_files:
+                 if f"{root_file}" == "data/01_raw/samples_dummy.json":
+                     n = 100
+                     # generate dummy features
+                     dummy_data = {}
+                     for branch in self.get_branches_to_keep():
+                         dummy_data[branch] = np.random.randn(n)
+                     if is_signal:
+                         dummy_data["is_signal"] = np.ones(n)
+                     else:
+                         dummy_data["is_signal"] = np.zeros(n)
+
+                     cur_df = pd.DataFrame(dummy_data)
+
+                     # generate a binary target (0/1)
+                     cur_df["y"] = np.random.choice([0, 1], size=n)
+
+                     df = pd.concat([df, cur_df])
+
+                 else:
+                     with uproot.open(f"{root_file}") as f:
+                         tree = f[self.get_tree_name()]
+                         all_branches = tree.keys()
+                         branches_to_load = self._resolve_branches(all_branches)
+
+                         if not branches_to_load:
+                             self.logger.warning(
+                                 f"No valid branches to load for {root_file}. Skipping."
+                             )
+                             continue
+
+                         data = tree.arrays(branches_to_load, cut=self.get_cut())
+
+                         cur_df = self.convert_to_pandas(data)
+
+                         # set background or signal
+                         if is_signal:
+                             cur_df["is_signal"] = [1 for _ in range(len(cur_df))]
+                         else:
+                             cur_df["is_signal"] = [0 for _ in range(len(cur_df))]
+
+                         df = pd.concat([df, cur_df])
+
+         return df
+
+     def _save(self, data: pd.DataFrame) -> pd.DataFrame:
+         return data
+
+     def _describe(self) -> dict:
+         return {"output_sample_info": self._sample_info, "sample_key": self._sample_key}
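
The `_load` loop above fixes the shape of the samples JSON handed to `BaseDataset`: under the chosen `sample_key`, each sample provides a `folder`, a list of `file_pattern` globs and an `is_signal` flag. A sketch of that layout, inferred from the code (names and paths are illustrative assumptions):

```python
# Hedged sketch of the sample_info JSON structure that BaseDataset.__init__
# and _load() appear to expect; sample names and paths are placeholders.
sample_info = {
    "training": {                        # sample_key passed to the dataset
        "background_sample": {
            "folder": "data/01_raw",
            "file_pattern": ["*.root"],
            "is_signal": False,
        },
        "signal_sample": {
            "folder": "data/01_raw",
            "file_pattern": ["signal_*.root"],
            "is_signal": True,
        },
    },
}
```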

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_loader.py
@@ -0,0 +1,101 @@
+ import logging, json
+ from abc import abstractmethod
+ from kedro.io import AbstractDataset
+ from trigger_loader.loader import TriggerLoader
+ import pandas as pd
+ import numpy as np
+ from pathlib import Path
+
+
+ class BaseLoader(AbstractDataset):
+     """
+     Abstract Base Class for using the TriggerLoader.
+
+     Users must inherit from this class and implement the abstract methods.
+     The core processing logic in `_load` is fixed and cannot be overridden.
+     """
+
+     def __init__(self, sample_json: str, settings: str, config: str):
+         self.sample_json = sample_json
+         with open(settings, "r") as f:
+             self.settings = json.load(f)
+         with open(config, "r") as f:
+             self.config = json.load(f)
+
+         # get logger for reporting
+         self.logger = logging.getLogger(__name__)
+         self.logger.info(f"Initializing loader: {self.__class__.__name__}")
+
+     @abstractmethod
+     def transform(self, events):
+         """
+         USER MUST IMPLEMENT.
+         """
+         pass
+
+     def _load(self) -> pd.DataFrame:
+         """
+         CORE LOGIC (NOT OVERRIDABLE): Loads and processes a single ROOT file.
+         """
+
+         self.logger.info(f"Start Loading...")
+         loader = TriggerLoader(
+             sample_json=self.sample_json,
+             transform=self.transform,
+             output_path=self.settings["output_dir"]
+         )
+
+         if self.settings["run_local"]:
+             loader.run_local(
+                 num_workers=self.settings["num_workers"],
+                 chunksize=self.settings["chunksize"]
+             )
+         else:
+             loader.run_distributed(
+                 cluster_type=self.settings["cluster_type"],
+                 cluster_config=self.config,
+                 chunksize=self.settings["chunksize"],
+                 jobs=self.settings["jobs"]
+             )
+
+         # load last parquet file from manifest file for each dataset key
+         # from the meta_data
+         dataset_keys = set(loader.meta_data.keys())
+         manifest_path = Path(self.settings["output_dir"]) / "manifest.json"
+
+         last_records = {key: None for key in dataset_keys}
+
+         with manifest_path.open() as f:
+             for line in f:
+                 record = json.loads(line)
+                 dataset = record.get("dataset")
+
+                 if dataset in last_records:
+                     last_records[dataset] = record
+
+         # sanity check
+         missing = [k for k, v in last_records.items() if v is None]
+         if missing:
+             raise ValueError(f"No manifest entry found for datasets: {missing}")
+
+         final_dfs = []
+         for dataset_key, record in last_records.items():
+             file_path = record["parquet_file"]
+             df = pd.read_parquet(file_path)
+
+             if loader.meta_data[dataset_key]["is_signal"]:
+                 df["is_signal"] = np.ones(len(df), dtype=int)
+                 df["y"] = np.ones(len(df), dtype=int)
+             else:
+                 df["is_signal"] = np.zeros(len(df), dtype=int)
+                 df["y"] = np.zeros(len(df), dtype=int)
+
+             final_dfs.append(df)
+
+         return pd.concat(final_dfs, ignore_index=True)
+
+     def _save(self, data: pd.DataFrame) -> pd.DataFrame:
+         return data
+
+     def _describe(self) -> dict:
+         return {"sample_json": self.sample_json, "settings": self.settings, "config": self.config}
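
`BaseLoader._load` likewise reads its behaviour from a settings JSON: it decides between `run_local` and `run_distributed`, then collects the latest parquet file per dataset from `manifest.json` in the output directory. A sketch of the settings fields the code above actually references (values are illustrative assumptions):

```python
# Hedged sketch of the settings dict consumed by BaseLoader._load();
# only keys that the code above reads are listed, values are placeholders.
settings = {
    "output_dir": "data/02_loaded",  # where parquet files and manifest.json land
    "run_local": True,               # True -> loader.run_local(...)
    "num_workers": 4,
    "chunksize": 100_000,
    # used only when run_local is False:
    "cluster_type": "condor",
    "jobs": 10,
}
```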

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py
@@ -0,0 +1,49 @@
+ import logging, json
+ from glob import glob
+ from kedro.io import AbstractDataset
+
+ METADATA_CONFIG = {"x": 0}
+
+ class MetaDataset(AbstractDataset):
+     def __init__(self, filepath: str, sample_key: str):
+         self._filepath = filepath
+         self._sample_key = sample_key
+         self.logger = logging.getLogger(__name__)
+
+     def get_dasgoclient_metadata(self, das_name: str, config: dict) -> dict:
+         self.logger.info(f"Fetching DAS metadata for dataset: {das_name}")
+         return {"gridpack": "0.0.0"}
+
+     def _load(self) -> dict:
+         self.logger.info(f"Processing file: {self._filepath}")
+         with open(self._filepath, "r") as f:
+             data = json.load(f)
+         return data
+
+     def _save(self, samples: dict) -> None:
+         metadata = {}
+
+         dataset_content = samples.get(self._sample_key, {})
+
+         for sample_name, sample_info in dataset_content.items():
+             self.logger.info(f"Processing sample: {sample_name}")
+
+             sample_files = sample_info.get("files", "")
+
+             resolved_files = glob(sample_files) if isinstance(sample_files, str) else sample_files
+             sample_info["files"] = resolved_files
+
+             self.logger.info(f"Found {len(resolved_files)} files for {sample_name}.")
+
+             metadata[sample_name] = self.get_dasgoclient_metadata(
+                 sample_info.get("DAS", "Unknown"),
+                 METADATA_CONFIG
+             )
+
+         with open(self._filepath, "w") as f:
+             json.dump(metadata, f, indent=4)
+
+     def _describe(self) -> dict:
+         return {"filepath": self._filepath, "sample_key": self._sample_key}
+
+
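
`MetaDataset._save` expects the incoming samples dict to hold, under its `sample_key`, one entry per sample with a `files` glob (or explicit list) and a `DAS` dataset name. A sketch of that input, inferred from the `.get(...)` lookups above (values are placeholders):

```python
# Hedged sketch of the samples dict handled by MetaDataset._save();
# keys mirror the lookups in the code, values are placeholders.
samples = {
    "training": {                                 # sample_key
        "ttbar": {
            "files": "data/01_raw/ttbar_*.root",  # glob string or explicit list
            "DAS": "/SomeDataset/SomeCampaign/NANOAODSIM",
        },
    },
}
```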

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py
@@ -0,0 +1,35 @@
+ import pandas as pd
+ from .base_dataset import BaseDataset
+
+
+ class {{ cookiecutter.python_package }}Dataset(BaseDataset):
+     """
+     A custom dataset example.
+     """
+
+     def get_branches_to_keep(self) -> list[str]:
+         """
+         Define the branches you need.
+         """
+         return [
+             "PuppiMET_pt",
+             "CaloMET_pt",
+             "event", # <-- we need this for meta data
+             # "Jet_pt",
+             # "Jet_eta",
+             # "Jet_phi",
+             # "Jet_btag*", # Use a wildcard to get all b-tagging info
+             "nJet",
+         ]
+
+     def get_cut(self) -> str | None:
+         """
+         Apply a pre-selection cut to keep only events with exactly 1 jet.
+         """
+         return "nJet == 1"
+
+     def convert_to_pandas(self, data: dict):
+         """
+         Logic to convert from dict of (potentially nested) arrays to a pandas DataFrame.
+         """
+         return pd.DataFrame(data)

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_loader.py
@@ -0,0 +1,32 @@
+ import pandas as pd
+ import awkward as ak
+ from .base_loader import BaseLoader
+
+
+ class {{ cookiecutter.python_package }}Loader(BaseLoader):
+     """
+     A custom loader example.
+     """
+
+     def transform(self, events):
+
+         jets = events.Jet
+         pt = ak.fill_none(ak.pad_none(jets.pt , 2, axis=1, clip=True), -9999.9)
+         eta = ak.fill_none(ak.pad_none(jets.eta, 2, axis=1, clip=True), -9999.9)
+         phi = ak.fill_none(ak.pad_none(jets.phi, 2, axis=1, clip=True), -9999.9)
+
+         met = events.MET
+
+         result = ak.zip({
+             "event": events.event,
+             "jet_pt_1": pt[:, 0],
+             "jet_pt_2": pt[:, 1],
+             "jet_eta_1": eta[:, 0],
+             "jet_eta_2": eta[:, 1],
+             "jet_phi_1": phi[:, 0],
+             "jet_phi_2": phi[:, 1],
+             "met_pt": met.pt,
+             "met_phi": met.phi,
+         })
+
+         return result

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py
@@ -0,0 +1,155 @@
+ import inspect
+ import pandas as pd
+ from abc import ABC, abstractmethod
+ from typing import Any
+ from sklearn.base import BaseEstimator
+
+
+ class BaseModel(ABC, BaseEstimator):
+     """
+     Standard Wrapper for a model
+     """
+
+     def __init__(self, name: str, hps: dict):
+         self.name = name
+         # this will be overwritten after training
+         self.model = None
+         self.history = None
+         self.callbacks = []
+         self.hps = hps
+
+     @abstractmethod
+     def train(self, X: pd.DataFrame, y: pd.DataFrame, hps: dict, **kwargs):
+         """
+         User code function.
+         Args:
+             X: features
+             y: label
+             hps: hyperparameters
+             kwargs: anything else needed for training
+         """
+         pass
+
+     @abstractmethod
+     def build(self):
+         """
+         User code function to build the model.
+         """
+         pass
+
+     def predict(self, X: pd.DataFrame, **kwargs) -> pd.DataFrame:
+         """
+         Calculates predictions of the model
+         Args:
+             X: features
+
+         Returns:
+             predictions
+             (optional in user code) kwargs: anything else needed for predicting
+         """
+         y_pred = self.model.predict(X.astype("float32"))
+         return pd.DataFrame(y_pred)
+
+     def predict_proba(self, X: pd.DataFrame, **kwargs) -> pd.DataFrame:
+         """
+         Calculates proba predictions of the model
+         Args:
+             X: features
+
+         Returns:
+             predictions
+             (optional in user code) kwargs: anything else needed for predicting
+         """
+         y_pred = self.model.predict_proba(X.astype("float32"))
+         return pd.DataFrame(y_pred)
+
+     def fit(self, X: pd.DataFrame, y: pd.DataFrame):
+         """
+         Same as train but gets kwargs from __init__ for the sklearn API
+         Args:
+             X: features
+             y: label
+
+         X can also contain optional inputs (https://github.com/scikit-learn/scikit-learn/issues/2879),
+         which should be specified in the user code.
+         For example when the train function needs additional inputs:
+         ```python
+         curX = X.copy()
+         kwargs = {"S": curX["S"]}
+         del curX["S"]
+         self.train(curX, y, self.hps, **kwargs)
+         ```
+         """
+         self.train(X, y, self.hps)
+
+     def get_params(self, deep=True):
+         """
+         Get parameters for self.model and self.
+         Args:
+             deep : bool, default=True
+                 If True, will return the parameters for this estimator and
+                 contained subobjects that are estimators.
+
+         Returns:
+             params : dict
+                 Parameter names mapped to their values.
+         """
+         out = dict()
+         # if self.hps is set return them and not the default values
+         for key in self.hps:
+             out[key] = self.hps[key]
+         for key in get_param_names(self):
+             value = getattr(self, key)
+             if deep and hasattr(value, "get_params") and not isinstance(value, type):
+                 deep_items = value.get_params().items()
+                 out.update((key + "__" + k, val) for k, val in deep_items)
+             out[key] = value
+         return out
+
+     def set_params(self, **params):
+         """
+         Set the parameters of this estimator.
+
+         We overwrite the sklearn BaseEstimator and set params to self.hps
+         Args:
+             **params : dict
+                 Estimator parameters.
+
+         Returns:
+             self : estimator instance
+                 Estimator instance.
+         """
+         self.hps = params
+
+         return self
+
+
+ def get_param_names(cls):
+     """Get parameter names for the estimator"""
+     # fetch the constructor or the original constructor before
+     # deprecation wrapping if any
+     init = getattr(cls.__init__, "deprecated_original", cls.__init__)
+     if init is object.__init__:
+         # No explicit constructor to introspect
+         return []
+
+     # introspect the constructor arguments to find the model parameters
+     # to represent
+     init_signature = inspect.signature(init)
+     # Consider the constructor parameters excluding 'self'
+     parameters = [
+         p
+         for p in init_signature.parameters.values()
+         if p.name != "self" and p.kind != p.VAR_KEYWORD
+     ]
+     for p in parameters:
+         if p.kind == p.VAR_POSITIONAL:
+             raise RuntimeError(
+                 "scikit-learn estimators should always "
+                 "specify their parameters in the signature"
+                 " of their __init__ (no varargs)."
+                 " %s with constructor %s doesn't "
+                 " follow this convention." % (cls, init_signature)
+             )
+     # Extract and sort argument names excluding 'self'
+     return sorted([p.name for p in parameters])
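
Because `get_params` merges the entries of `self.hps` with the constructor arguments found by `get_param_names`, the hyperparameters passed at construction (or set later via `set_params`, which replaces `self.hps` wholesale) are what scikit-learn-style tooling sees. A small hedged sketch of that behaviour, using a hypothetical minimal subclass:

```python
# Hedged sketch: MyModel is a hypothetical subclass of the BaseModel defined
# above, just concrete enough to be instantiated.
class MyModel(BaseModel):
    def train(self, X, y, hps, **kwargs):
        pass

    def build(self):
        pass


model = MyModel(name="bdt", hps={"max_depth": 3})
print(model.get_params())
# roughly: {"max_depth": 3, "hps": {"max_depth": 3}, "name": "bdt"}

model.set_params(max_depth=5)  # replaces self.hps entirely
print(model.hps)               # {"max_depth": 5}
```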

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py
@@ -0,0 +1,16 @@
+ import pandas as pd
+ from .base_model import BaseModel
+ from sklearn.dummy import DummyClassifier
+
+
+ class {{ cookiecutter.python_package }}(BaseModel):
+     def train(self, X: pd.DataFrame, y: pd.DataFrame, hps: dict, **kwargs):
+         self.build()
+         self.history = self.model.fit(X, y)
+
+     def build(self):
+         """Build the test Model.
+         self.hps:
+             -
+         """
+         self.model = DummyClassifier()
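
With the `train` signature matching `BaseModel.train(X, y, hps, **kwargs)`, the starter model can be driven through the scikit-learn-style `fit`/`predict_proba` calls. A hedged usage sketch on toy data; `StarterModel` stands in for the rendered `{{ cookiecutter.python_package }}` class, and the column names are borrowed from the dataset example above:

```python
# Hedged sketch: toy end-to-end call of the starter's dummy model.
# "StarterModel" is a placeholder for the rendered class name.
import numpy as np
import pandas as pd

model = StarterModel(name="dummy", hps={})
X = pd.DataFrame({"PuppiMET_pt": np.random.rand(100), "nJet": np.ones(100)})
y = pd.Series(np.random.choice([0, 1], size=100))

model.fit(X, y)                 # -> train(X, y, self.hps) -> build() + DummyClassifier.fit
proba = model.predict_proba(X)  # DataFrame of class probabilities
```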

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py
@@ -0,0 +1,17 @@
+ """Project pipelines."""
+
+ from __future__ import annotations
+
+ from kedro.framework.project import find_pipelines
+ from kedro.pipeline import Pipeline
+
+
+ def register_pipelines() -> dict[str, Pipeline]:
+     """Register the project's pipelines.
+
+     Returns:
+         A mapping from pipeline names to ``Pipeline`` objects.
+     """
+     pipelines = find_pipelines()
+     pipelines["__default__"] = sum(pipelines.values())
+     return pipelines
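
`find_pipelines()` discovers one `Pipeline` per subpackage under `pipelines/` (here: compile, data_processing, load_data, model_training and model_validation), and `sum(pipelines.values())` concatenates them, since Kedro pipelines support `+`. A hedged sketch of the equivalent explicit composition, assuming exactly those five pipelines are found:

```python
# Hedged sketch: explicit equivalent of sum(pipelines.values()) for the
# starter's five pipelines; execution order is still resolved from node
# inputs/outputs, not from the order of addition.
pipelines["__default__"] = (
    pipelines["load_data"]
    + pipelines["data_processing"]
    + pipelines["model_training"]
    + pipelines["model_validation"]
    + pipelines["compile"]
)
```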

triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py
@@ -0,0 +1,10 @@
+ """
+ This is a boilerplate pipeline 'compile'
+ generated using Kedro 1.0.0
+ """
+
+ from .pipeline import create_pipeline
+
+ __all__ = ["create_pipeline"]
+
+ __version__ = "0.1"