triggerflow 0.1.4__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +147 -0
  6. trigger_loader/processor.py +211 -0
  7. triggerflow/cli.py +122 -0
  8. triggerflow/core.py +127 -69
  9. triggerflow/interfaces/__init__.py +0 -0
  10. triggerflow/interfaces/uGT.py +127 -0
  11. triggerflow/mlflow_wrapper.py +190 -19
  12. triggerflow/starter/.gitignore +143 -0
  13. triggerflow/starter/README.md +0 -0
  14. triggerflow/starter/cookiecutter.json +5 -0
  15. triggerflow/starter/prompts.yml +9 -0
  16. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +84 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +15 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples_dummy.json +26 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +21 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +88 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +50 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +10 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +40 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +12 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +31 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +29 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  88. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  89. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  90. triggerflow/templates/build_ugt.tcl +46 -0
  91. triggerflow/templates/data_types.h +524 -0
  92. triggerflow/templates/makefile +3 -3
  93. triggerflow/templates/makefile_version +2 -2
  94. triggerflow/templates/model-gt.cpp +104 -0
  95. triggerflow/templates/model_template.cpp +19 -18
  96. triggerflow/templates/scales.h +1 -1
  97. triggerflow-0.2.4.dist-info/METADATA +192 -0
  98. triggerflow-0.2.4.dist-info/RECORD +102 -0
  99. triggerflow-0.2.4.dist-info/entry_points.txt +2 -0
  100. triggerflow-0.2.4.dist-info/top_level.txt +3 -0
  101. triggerflow-0.1.4.dist-info/METADATA +0 -61
  102. triggerflow-0.1.4.dist-info/RECORD +0 -11
  103. triggerflow-0.1.4.dist-info/top_level.txt +0 -1
  104. {triggerflow-0.1.4.dist-info → triggerflow-0.2.4.dist-info}/WHEEL +0 -0
trigger_dataset/__init__.py
File without changes
trigger_dataset/core.py ADDED
@@ -0,0 +1,88 @@
+ import warnings
+ from abc import ABC, abstractmethod
+ from fnmatch import filter as fnmatch_filter
+
+ import pandas as pd
+ import uproot
+
+
+ class TriggerDataset(ABC):
+     """
+     Abstract Base Class for loading data from ROOT files.
+
+     Users must inherit from this class and implement the abstract methods.
+     The core processing logic in `process_file` is fixed and cannot be overridden.
+     """
+
+     def __init__(self):
+         pass
+
+     @abstractmethod
+     def get_features(self) -> list[str]:
+         """
+         Return a list of branch names or patterns to keep from the dataset.
+         Accepts wildcards (e.g. Jet_*).
+         """
+         pass
+
+     @abstractmethod
+     def get_cut(self) -> str | None:
+         """
+         Return a string representing the cuts to apply to the data.
+         """
+         pass
+
+     @abstractmethod
+     def convert_to_pandas(self, data: dict) -> pd.DataFrame:
+         """
+         Convert the loaded data from a dictionary format to a pandas DataFrame.
+         """
+         pass
+
+     def _resolve_branches(self, all_branches: list) -> list[str]:
+         """Internal method to resolve wildcard patterns."""
+         selected = []
+         for pattern in self.get_features():
+             matched = fnmatch_filter(all_branches, pattern)
+             if not matched:
+                 warnings.warn(f"'{pattern}' did not match any branches.")
+             selected.extend(matched)
+         return sorted(list(set(selected)))
+
+     def _save_to_parquet(self, df: pd.DataFrame, output_path: str):
+         """
+         Save the processed DataFrame to a file.
+         """
+         df.to_parquet(output_path)
+
+     def _save_to_csv(self, df: pd.DataFrame, output_path: str):
+         """
+         Save the processed DataFrame to a CSV file.
+         """
+         df.to_csv(output_path, index=False)
+
+     def process_file(self, file_path: str, out_file_path: str) -> pd.DataFrame:
+         """
+         Loads and processes a single ROOT file.
+         """
+
+         with uproot.open(f"{file_path}") as f:
+             tree = f[self.get_tree_name()]
+             all_branches = tree.keys()
+             branches_to_load = self._resolve_branches(all_branches)
+
+             if not branches_to_load:
+                 return pd.DataFrame()
+
+             data = tree.arrays(branches_to_load, cut=self.get_cut(), how=dict)
+
+         df = self.convert_to_pandas(data)
+
+         if self.output_format == "parquet":
+             self._save_to_parquet(df, f"{out_file_path}.parquet")
+         elif self.output_format == "csv":
+             self._save_to_csv(df, f"{out_file_path}.csv")
+         else:
+             return pd.DataFrame()
+
+         return pd.DataFrame()
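As a rough guide to how the abstract interface above is meant to be used, here is a minimal sketch of a concrete subclass. The class name, branch list, and cut are hypothetical; note that process_file also relies on a get_tree_name() method and an output_format attribute that do not appear in this hunk, so the sketch supplies both as assumptions.

import pandas as pd

from trigger_dataset.core import TriggerDataset


class JetDataset(TriggerDataset):
    # Hypothetical example values; output_format and get_tree_name() are
    # assumed requirements of process_file that the hunk above does not define.
    output_format = "parquet"

    def get_tree_name(self) -> str:
        return "Events"  # assumed tree name inside the ROOT file

    def get_features(self) -> list[str]:
        return ["Jet_pt", "Jet_eta"]  # wildcards such as "Jet_*" also work

    def get_cut(self) -> str | None:
        return "Jet_pt > 30"  # forwarded to uproot's arrays(cut=...)

    def convert_to_pandas(self, data: dict) -> pd.DataFrame:
        # data maps branch names to awkward arrays (how=dict); how to flatten
        # jagged branches is up to the user, here we just keep Python lists.
        return pd.DataFrame({k: v.tolist() for k, v in data.items()})


# JetDataset().process_file("input.root", "out/jets")  # writes out/jets.parquet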
trigger_loader/__init__.py
File without changes
trigger_loader/cluster_manager.py ADDED
@@ -0,0 +1,107 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Any
+
+ from dask.distributed import Client, LocalCluster
+
+ logger = logging.getLogger(__name__)
+
+
+ class ClusterManager:
+     """Context manager to provision and tear down a Dask cluster.
+
+     Parameters
+     ----------
+     cluster_type : str
+         Backend to use ("local", "condor", "cuda", "kubernetes").
+     cluster_config : dict | None, optional
+         Keyword arguments forwarded to the specific cluster constructor.
+     jobs : int, optional
+         Desired number of jobs / workers (used for queue / scalable backends).
+     """
+
+     def __init__(
+         self,
+         cluster_type: str,
+         cluster_config: dict[str, Any] | None = None,
+         jobs: int = 1,
+     ) -> None:
+         if cluster_config is None:
+             cluster_config = {}
+         # Copy to avoid mutating caller's dict accidentally.
+         self.cluster_config: dict[str, Any] = dict(cluster_config)
+         self.cluster_type: str = cluster_type
+         self.jobs: int = jobs
+
+         self.cluster: Any | None = None
+         self.client: Any | None = None
+
+     # ------------------------------------------------------------------
+     # Context manager protocol
+     # ------------------------------------------------------------------
+     def __enter__(self): # -> distributed.Client (avoids importing type eagerly)
+         self._start_cluster()
+         return self.client
+
+     def __exit__(self, exc_type, exc, tb) -> bool: # noqa: D401 (simple)
+         self._close_cluster()
+         # Returning False propagates any exception (desired behavior)
+         return False
+
+     # ------------------------------------------------------------------
+     # Internal helpers
+     # ------------------------------------------------------------------
+     def _start_cluster(self) -> None:
+
+         ct = self.cluster_type.lower()
+
+         if ct == "local":
+             self.cluster = LocalCluster(**self.cluster_config)
+
+         elif ct == "condor":
+             from dask_jobqueue import HTCondorCluster
+             self.cluster = HTCondorCluster(**self.cluster_config)
+             if self.jobs and self.jobs > 0:
+                 # Scale to the requested number of jobs
+                 self.cluster.scale(jobs=self.jobs)
+
+         elif ct == "cuda":
+             from dask_cuda import LocalCUDACluster
+             self.cluster = LocalCUDACluster(**self.cluster_config)
+
+         elif ct == "kubernetes":
+             from dask_kubernetes import KubeCluster
+             self.cluster = KubeCluster(**self.cluster_config)
+             if self.jobs and self.jobs > 0:
+                 try:
+                     # Not all KubeCluster versions expose scale() identically
+                     self.cluster.scale(self.jobs)
+                 except Exception:
+                     pass # Best effort; ignore if unsupported
+
+         else:
+             raise ValueError(f"Unsupported cluster type: {self.cluster_type}")
+
+         self.client = Client(self.cluster)
+         dash = getattr(self.client, "dashboard_link", None)
+         if dash:
+             logger.info(f"Dask dashboard: {dash}")
+
+     def _close_cluster(self) -> None:
+         # Close client first so tasks wind down before cluster termination.
+         if self.client is not None:
+             try:
+                 self.client.close()
+             except Exception:
+                 pass
+             finally:
+                 self.client = None
+         if self.cluster is not None:
+             try:
+                 self.cluster.close()
+             except Exception:
+                 pass
+             finally:
+                 self.cluster = None
+
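A minimal usage sketch of ClusterManager with the "local" backend follows; the n_workers value is only an example and is forwarded unchanged to dask.distributed.LocalCluster.

from trigger_loader.cluster_manager import ClusterManager

# "local" only needs dask.distributed; "condor", "cuda" and "kubernetes"
# additionally require dask_jobqueue, dask_cuda or dask_kubernetes.
with ClusterManager("local", cluster_config={"n_workers": 2}) as client:
    futures = client.map(lambda x: x * x, range(10))
    print(client.gather(futures))
# Both client and cluster are closed on exit, even if the body raised.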
trigger_loader/loader.py ADDED
@@ -0,0 +1,147 @@
+ import json
+ import logging
+ import platform
+ import time
+ import uuid
+
+ import awkward as ak
+ import coffea
+ from coffea import processor
+ from coffea.nanoevents import NanoAODSchema
+
+ from .cluster_manager import ClusterManager
+ from .processor import TriggerProcessor
+
+ logger = logging.getLogger(__name__)
+
+
+ class TriggerLoader:
+     def __init__(self,
+                  sample_json: str,
+                  transform: callable,
+                  output_path: str,
+                  ):
+         self.transform = transform
+         self.fileset = self._load_sample_json(sample_json)
+         self.output_path = output_path
+         self.run_uuid = str(uuid.uuid4())
+
+     def _build_processor(self):
+         run_meta = {
+             "run_uuid": self.run_uuid,
+             "fileset_size": sum(len(v) if isinstance(v, list) else 1 for v in self.fileset.values()),
+             "coffea_version": coffea.__version__,
+             "awkward_version": ak.__version__,
+             "python_version": platform.python_version(),
+         }
+
+         return TriggerProcessor(
+             output_path=self.output_path,
+             transform=self.transform,
+             compression="zstd",
+             add_uuid=False,
+             run_uuid=self.run_uuid,
+             run_metadata=run_meta,
+         )
+
+     def _load_sample_json(self, sample_json: str) -> dict:
+         """
+         Loads the JSON and resolves file paths using the priority:
+         1. Explicit 'files' list or directory path (Local/Explicit)
+         2. 'DAS' query (Remote Fallback)
+
+         Returns the canonical coffea fileset format: {dataset_name: [file_path_list]}.
+         """
+         import glob
+         import os
+
+         # Helper function definition needed here if it's not imported:
+         # def _fetch_files_from_das(das_query: str) -> list[str]: ... (placeholder or actual implementation)
+
+         with open(sample_json) as f:
+             full_data = json.load(f)
+         dataset_metadata = full_data.get("samples", full_data)
+
+         fileset = {}
+         for ds_name, ds_info in dataset_metadata.items():
+             files = []
+
+             if "files" in ds_info:
+                 file_info = ds_info["files"]
+
+                 if isinstance(file_info, list):
+                     files = file_info
+
+                 elif isinstance(file_info, str):
+                     if os.path.isdir(file_info):
+                         path_glob = os.path.join(file_info, "*.root")
+                         files = glob.glob(path_glob)
+                         logger.info(f"Resolved {len(files)} files from directory {file_info}.")
+                     else:
+                         files = [file_info]
+
+                 if files:
+                     logger.info(f"Using {len(files)} local/explicit files for {ds_name}.")
+
+             if not files and "DAS" in ds_info:
+                 try:
+                     files = _fetch_files_from_das(ds_info["DAS"])
+                     logger.info(f"Resolved {len(files)} files via DAS for {ds_name}.")
+                 except NameError:
+                     logger.error("DAS fetching skipped: _fetch_files_from_das is not defined.")
+
+             if not files:
+                 logger.warning(f"No files found for dataset: {ds_name}. Skipping.")
+                 continue
+
+             fileset[ds_name] = files
+
+         return fileset
+
+     def _write_run_metadata_file(self, path: str, duration_s: float | None = None):
+         meta_path = f"{path}/run_metadata.json"
+         data = {
+             "run_uuid": self.run_uuid,
+             "duration_seconds": duration_s,
+         }
+         with open(meta_path, "w") as f:
+             json.dump(data, f, indent=2)
+
+     def _run(self, runner: processor.Runner, label: str):
+         logger.log(logging.INFO, f"Starting processing ({label})...")
+         start = time.time()
+         proc = self._build_processor()
+         print(self.fileset)
+
+         acc = runner(
+             self.fileset,
+             treename="Events",
+             processor_instance=proc
+         )
+         elapsed = time.time() - start
+         self._write_run_metadata_file(self.output_path, elapsed)
+         logger.log(logging.INFO, f"Finished in {elapsed:.2f}s (run_uuid={self.run_uuid})")
+         return acc
+
+     def run_distributed(self, cluster_type: str, cluster_config: dict,
+                         chunksize: int = 100_000, jobs: int = 1):
+         with ClusterManager(cluster_type, cluster_config, jobs) as client:
+             executor = processor.DaskExecutor(client=client)
+             runner = processor.Runner(
+                 executor=executor,
+                 schema=NanoAODSchema,
+                 chunksize=chunksize
+             )
+             self._run(runner, f"Distributed ({cluster_type})")
+
+     def run_local(self, num_workers: int = 4, chunksize: int = 100_000):
+         """
+         Run processing locally using a multi-processing executor.
+         """
+         executor = processor.FuturesExecutor(workers=num_workers)
+         runner = processor.Runner(
+             executor=executor,
+             schema=NanoAODSchema,
+             chunksize=chunksize
+         )
+         self._run(runner, f"Local ({num_workers} workers)")
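To illustrate how the pieces fit together, here is a hedged sketch of driving TriggerLoader locally. The samples.json layout in the comment follows the resolution order of _load_sample_json above; the dataset name, file path, and the Jet branches used in the transform (which assume NanoAOD-style input) are illustrative only.

import awkward as ak

from trigger_loader.loader import TriggerLoader

# samples.json (hypothetical):
# {"samples": {"ZeroBias": {"files": ["data/01_raw/zerobias_0.root"]}}}

def keep_jets(events):
    # NanoAODSchema exposes collections such as events.Jet; return an
    # awkward record array that TriggerProcessor can write to Parquet.
    return ak.zip({"jet_pt": events.Jet.pt, "jet_eta": events.Jet.eta})

loader = TriggerLoader(
    sample_json="samples.json",
    transform=keep_jets,
    output_path="data/02_loaded",
)
loader.run_local(num_workers=4, chunksize=50_000)
# Parquet chunks, manifest.jsonl and run_metadata.json end up in data/02_loaded.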
trigger_loader/processor.py ADDED
@@ -0,0 +1,211 @@
+ import datetime as dt
+ import hashlib
+ import json
+ import os
+ import time
+ import uuid
+ import warnings
+ from collections.abc import Callable
+ from typing import Any
+
+ import awkward as ak
+ import pyarrow.parquet as pq
+ from coffea import processor
+
+ warnings.filterwarnings("ignore", message="Found duplicate branch")
+
+
+ class TriggerProcessor(processor.ProcessorABC):
+     """
+     Coffea processor that applies a user transform to events and writes Parquet files.
+
+     This processor transforms event data and writes the results to Parquet files with
+     comprehensive metadata tracking for reproducibility and provenance.
+     """
+
+     def __init__(
+         self,
+         output_path: str,
+         transform: Callable[[Any], ak.Array],
+         compression: str = "zstd",
+         compression_level: int | None = None,
+         filename_template: str = "{dataset}_{fileuuid}_{start}-{stop}.parquet",
+         add_uuid: bool = False,
+         write_manifest: bool = True,
+         manifest_name: str = "manifest.jsonl",
+         run_uuid: str | None = None,
+         run_metadata: dict | None = None,
+         preserve_event_metadata: bool = True,
+     ):
+         """
+         Initialize the TriggerProcessor.
+
+         Args:
+             output_path: Directory where output files will be written
+             transform: Function to apply to events, returns awkward array
+             compression: Parquet compression algorithm
+             compression_level: Compression level (None for default)
+             filename_template: Template for output filenames
+             add_uuid: Whether to add UUID to filenames
+             write_manifest: Whether to write manifest file
+             manifest_name: Name of the manifest file
+             run_uuid: UUID for this processing run (generated if None)
+             run_metadata: Additional metadata for the run
+             preserve_event_metadata: Whether to preserve event-level metadata
+         """
+         self.output_path = output_path
+         self.transform = transform
+         self.compression = compression
+         self.compression_level = compression_level
+         self.filename_template = filename_template
+         self.add_uuid = add_uuid
+         self.write_manifest = write_manifest
+         self.manifest_name = manifest_name
+         self.run_uuid = run_uuid or str(uuid.uuid4())
+         self.run_metadata = run_metadata or {}
+         self.preserve_event_metadata = preserve_event_metadata
+
+         # Initialize output directory and paths
+         os.makedirs(self.output_path, exist_ok=True)
+         if write_manifest:
+             self._manifest_path = os.path.join(self.output_path, self.manifest_name)
+
+     @property
+     def accumulator(self):
+         """No aggregation needed (side-effect writing)."""
+         return {}
+
+     def process(self, events):
+         """
+         Process a chunk of events: transform and write to Parquet.
+
+         Args:
+             events: Input events from Coffea
+
+         Returns:
+             Empty dict (no accumulation needed)
+         """
+         # Apply transform and measure time
+         data, elapsed_s = self._apply_transform(events)
+
+         # Extract event metadata
+         event_meta = self._extract_event_metadata(events)
+         output_file = self._generate_output_filename(event_meta)
+
+         # Convert to Arrow table
+         table = ak.to_arrow_table(data)
+
+         # Build file metadata
+         file_meta = self._build_file_metadata(event_meta, table, elapsed_s)
+         table = self._embed_metadata_in_schema(table, file_meta, events.metadata)
+
+         # Write Parquet file
+         self._write_parquet(table, output_file)
+
+         # Write manifest entry
+         if self.write_manifest:
+             self._write_manifest_entry(output_file, file_meta)
+
+         return {}
+
+     def postprocess(self, accumulator):
+         """Postprocess accumulated results (no-op for this processor)."""
+         return accumulator
+
+     # ==================== Private Helper Methods ====================
+
+     def _apply_transform(self, events) -> tuple[ak.Array, float]:
+         """Apply user transform to events and measure execution time."""
+         t0 = time.time()
+         data = self.transform(events)
+         elapsed_s = time.time() - t0
+         return data, elapsed_s
+
+     def _extract_event_metadata(self, events) -> dict:
+         """Extract metadata from events object."""
+         source_file = None
+         if hasattr(events, "_events"):
+             source_file = getattr(events._events, "files", [None])[0]
+
+         return {
+             "start": events.metadata.get("entrystart", 0),
+             "stop": events.metadata.get("entrystop", 0),
+             "dataset": events.metadata.get("dataset", "unknown"),
+             "source_file": source_file,
+             "fileuuid": events.metadata.get("fileuuid"),
+         }
+
+     def _generate_output_filename(self, event_meta: dict) -> str:
+         """Generate output filename based on template and metadata."""
+         fname = self.filename_template.format(
+             dataset=event_meta["dataset"],
+             fileuuid=event_meta.get("fileuuid", "xx"),
+             start=event_meta["start"],
+             stop=event_meta["stop"]
+         )
+
+         if self.add_uuid:
+             stem, ext = os.path.splitext(fname)
+             fname = f"{stem}_{uuid.uuid4()}{ext}"
+
+         return os.path.join(self.output_path, fname)
+
+     def _build_file_metadata(self, event_meta: dict, table, elapsed_s: float) -> dict:
+         """Build comprehensive metadata dictionary for the output file."""
+         fileuuid = event_meta["fileuuid"]
+
+         return {
+             "run_uuid": self.run_uuid,
+             "dataset": event_meta["dataset"],
+             "source_root_file": event_meta["source_file"],
+             "fileuuid": str(fileuuid) if fileuuid is not None else None,
+             "entrystart": event_meta["start"],
+             "entrystop": event_meta["stop"],
+             "n_events_written": len(table),
+             "columns": table.column_names,
+             "created_utc": dt.datetime.now(dt.UTC).isoformat(timespec="seconds") + "Z",
+             "compression": self.compression,
+             "processing_time_s": round(elapsed_s, 6),
+         }
+
+     def _embed_metadata_in_schema(self, table, file_meta: dict, event_metadata: dict):
+         """Embed metadata into the Arrow table schema."""
+         schema = table.schema
+         existing = dict(schema.metadata or {})
+
+         # Add file metadata
+         existing[b"x-trigger-meta"] = json.dumps(
+             file_meta, separators=(",", ":")
+         ).encode()
+
+         # Optionally preserve event-level metadata
+         if self.preserve_event_metadata:
+             for k, v in event_metadata.items():
+                 if k not in file_meta:
+                     file_meta[f"eventmeta_{k}"] = v
+
+         # Add run metadata hash
+         if self.run_metadata:
+             existing.setdefault(
+                 b"x-run-meta-hash",
+                 hashlib.sha256(
+                     json.dumps(self.run_metadata, sort_keys=True).encode()
+                 ).hexdigest().encode()
+             )
+
+         return table.replace_schema_metadata(existing)
+
+     def _write_parquet(self, table, output_file: str):
+         """Write Arrow table to Parquet file."""
+         pq.write_table(
+             table,
+             output_file,
+             compression=self.compression,
+             compression_level=self.compression_level
+         )
+
+     def _write_manifest_entry(self, output_file: str, file_meta: dict):
+         """Write a single line to the manifest file."""
+         manifest_record = {"parquet_file": output_file, **file_meta}
+         with open(self._manifest_path, "a") as mf:
+             mf.write(json.dumps(manifest_record, separators=(",", ":")) + "\n")
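Since the processor embeds its provenance under the b"x-trigger-meta" schema key, the record can be read back with plain pyarrow. A minimal sketch follows; the file name is hypothetical and simply follows the default filename_template.

import json

import pyarrow.parquet as pq

# Read only the schema (no row data) and decode the embedded JSON metadata.
schema = pq.read_schema("data/02_loaded/ZeroBias_xx_0-100000.parquet")
meta = json.loads(schema.metadata[b"x-trigger-meta"])
print(meta["run_uuid"], meta["n_events_written"], meta["compression"])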
triggerflow/cli.py ADDED
@@ -0,0 +1,122 @@
+ import argparse
+ import logging
+ import shutil
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ import yaml
+ from cookiecutter.main import cookiecutter
+
+ try:
+     from importlib.resources import as_file
+     from importlib.resources import files as ir_files
+ except ImportError:
+     import importlib_resources
+     ir_files = importlib_resources.files # type: ignore
+     as_file = importlib_resources.as_file # type: ignore
+
+ logger = logging.getLogger(__name__)
+
+
+ def echo(cmd): logger.info(f"$ {' '.join(str(c) for c in cmd)}", flush=True)
+ def which_or_die(cmd):
+     if shutil.which(cmd) is None:
+         logger.info(f"Error: '{cmd}' not found on PATH.", file=sys.stderr); sys.exit(127)
+
+ def packaged_starter_root() -> Path:
+     """Real FS path to packaged starter folder."""
+     base = ir_files("triggerflow")
+     with as_file(base / "starter") as p:
+         p = Path(p)
+         if p.exists():
+             return p
+     with as_file(base) as p:
+         logger.info("Error: starter not found. Contents of package:", file=sys.stderr)
+         for child in Path(p).iterdir(): logger.info(" -", child.name, file=sys.stderr)
+         sys.exit(2)
+
+ def render_starter(project: str, out_dir: Path) -> Path:
+     """Render cookiecutter starter into out_dir; returns project path."""
+     starter = packaged_starter_root()
+     extra = {
+         "project_name": project,
+         "repo_name": project.replace("-", "_").lower(),
+         "python_package": project.replace("-", "_").lower(),
+     }
+
+     # cookiecutter.json expects bare keys, not cookiecutter.* here
+     out_dir.mkdir(parents=True, exist_ok=True)
+     proj_path = Path(cookiecutter(
+         template=str(starter),
+         no_input=True,
+         output_dir=str(out_dir),
+         extra_context=extra,
+     ))
+     return proj_path
+
+ def find_env_yml(project_dir: Path) -> Path:
+     # Prefer root environment.yml; else find under src/**/
+     candidates = []
+     root = project_dir / "environment.yml"
+     if root.exists(): candidates.append(root)
+     candidates += list((project_dir / "src").rglob("environment.yml"))
+     if not candidates:
+         logger.info(f"Error: environment.yml not found under {project_dir}", file=sys.stderr); sys.exit(3)
+     # stable preference
+     candidates.sort(key=lambda p: (0 if p.parent.name != "src" else 1, len(str(p))))
+     return candidates[0]
+
+ def conda_env_create_or_update(env_yml: Path) -> int:
+     which_or_die("conda")
+     # If YAML has 'name', override it to 'env_name' so updates are consistent
+     data = yaml.safe_load(env_yml.read_text(encoding="utf-8"))
+     if isinstance(data, dict):
+         tmp = env_yml.with_suffix(".rendered.yml")
+         tmp.write_text(yaml.safe_dump(data, sort_keys=False), encoding="utf-8")
+         env_yml = tmp
+
+     update = ["conda", "env", "update", "-f", str(env_yml), "--prune"]
+
+     echo(update); rc = subprocess.call(update)
+     if rc != 0: return rc
+
+     verify = ["conda", "run", "python", "-c", "import sys; logger.info(sys.executable)"]
+     echo(verify); subprocess.call(verify)
+     return 0
+
+ def cmd_new(project: str, output: Path) -> int:
+     proj_dir = render_starter(project, output)
+     return 0
+
+ def cmd_setup(project: str, output: Path) -> int:
+     # If project dir doesn’t exist yet, render it first
+     target = output / project
+     if not target.exists():
+         logger.info(f"Project '{project}' not found under {output}. Rendering starter first...")
+         render_starter(project, output)
+     env_yml = find_env_yml(target)
+     logger.info(f"Using environment file: {env_yml}")
+     return conda_env_create_or_update(env_yml)
+
+ def main():
+     parser = argparse.ArgumentParser(prog="triggerflow", description="Triggerflow CLI")
+     sub = parser.add_subparsers(dest="cmd", required=True)
+
+     p_new = sub.add_parser("new", help="Render a new project from the packaged starter (Cookiecutter)")
+     p_new.add_argument("project", help="Project name")
+     p_new.add_argument("--out", default=".", help="Output directory (default: .)")
+
+     p_setup = sub.add_parser("setup", help="Create/update conda env from the rendered project's environment.yml")
+     p_setup.add_argument("project", help="Project/env name")
+     p_setup.add_argument("--out", default=".", help="Project parent directory (default: .)")
+
+     args = parser.parse_args()
+     out = Path(getattr(args, "out", ".")).resolve()
+     if args.cmd == "new":
+         sys.exit(cmd_new(args.project, out))
+     else: # setup
+         sys.exit(cmd_setup(args.project, out))
+
+ if __name__ == "__main__":
+     main()
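The wheel also adds triggerflow-0.2.4.dist-info/entry_points.txt, which presumably registers this module as a console script, so the "new" and "setup" subcommands are normally run from a shell. As a rough, self-contained sketch, the same flow can be driven programmatically; the project name and output directory are hypothetical, and main() terminates via sys.exit().

import sys

from triggerflow.cli import main

# Equivalent to running "triggerflow new l1-menu-study --out ./projects".
sys.argv = ["triggerflow", "new", "l1-menu-study", "--out", "./projects"]
try:
    main()
except SystemExit as exc:
    print(f"triggerflow exited with code {exc.code}")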