umami-preprocessing 0.2.6__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {umami_preprocessing-0.2.6/umami_preprocessing.egg-info → umami_preprocessing-0.2.7}/PKG-INFO +22 -15
  2. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/pyproject.toml +21 -14
  3. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7/umami_preprocessing.egg-info}/PKG-INFO +22 -15
  4. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/SOURCES.txt +4 -0
  5. umami_preprocessing-0.2.7/umami_preprocessing.egg-info/requires.txt +20 -0
  6. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/__init__.py +1 -1
  7. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/components.py +3 -2
  8. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/preprocessing_config.py +64 -12
  9. umami_preprocessing-0.2.7/upp/classes/reweight_config.py +78 -0
  10. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/main.py +81 -5
  11. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/normalisation.py +4 -2
  12. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/resampling.py +2 -1
  13. umami_preprocessing-0.2.7/upp/stages/reweight.py +465 -0
  14. umami_preprocessing-0.2.7/upp/stages/rw_merge.py +314 -0
  15. umami_preprocessing-0.2.7/upp/stages/split_containers.py +386 -0
  16. umami_preprocessing-0.2.6/umami_preprocessing.egg-info/requires.txt +0 -15
  17. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/LICENSE +0 -0
  18. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/MANIFEST.in +0 -0
  19. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/README.md +0 -0
  20. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/setup.cfg +0 -0
  21. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
  22. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/entry_points.txt +0 -0
  23. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/top_level.txt +0 -0
  24. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/__init__.py +0 -0
  25. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/region.py +0 -0
  26. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/resampling_config.py +0 -0
  27. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/variable_config.py +0 -0
  28. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/__init__.py +0 -0
  29. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/hist.py +0 -0
  30. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/interpolation.py +0 -0
  31. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/merging.py +0 -0
  32. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/plot.py +0 -0
  33. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/__init__.py +0 -0
  34. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/check_input_samples.py +0 -0
  35. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/logger.py +0 -0
  36. {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/tools.py +0 -0
@@ -1,26 +1,33 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: umami-preprocessing
3
- Version: 0.2.6
4
- Summary: Preprocessing for jet tagging
3
+ Version: 0.2.7
4
+ Summary: ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)
5
+ Author: Alexander Froch
5
6
  License: MIT
6
7
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
7
- Requires-Python: <3.12,>=3.8
8
+ Project-URL: Issue Tracker, https://github.com/umami-hep/umami-preprocessing/issues
9
+ Requires-Python: <3.12,>=3.10
8
10
  Description-Content-Type: text/markdown
9
11
  License-File: LICENSE
10
- Requires-Dist: atlas-ftag-tools==0.2.15
11
- Requires-Dist: dotmap==1.3.30
12
- Requires-Dist: puma-hep==0.4.10
12
+ Requires-Dist: atlas-ftag-tools==0.2.17
13
+ Requires-Dist: dotmap>=1.3.30
14
+ Requires-Dist: numpy>=2.2.6
15
+ Requires-Dist: puma-hep==0.4.11
13
16
  Requires-Dist: pyyaml-include==1.3
14
- Requires-Dist: PyYAML>=6.0.1
15
- Requires-Dist: rich==12.6.0
16
- Requires-Dist: scipy>=1.15.2
17
+ Requires-Dist: PyYAML>=6.0.2
18
+ Requires-Dist: rich>=14.1.0
19
+ Requires-Dist: scipy>=1.15.3
17
20
  Provides-Extra: dev
18
- Requires-Dist: mypy==1.11.2; extra == "dev"
19
- Requires-Dist: pre-commit==3.5.0; extra == "dev"
20
- Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
21
- Requires-Dist: pytest-mock==3.11.1; extra == "dev"
22
- Requires-Dist: pytest>=7.2.2; extra == "dev"
23
- Requires-Dist: ruff==0.6.2; extra == "dev"
21
+ Requires-Dist: coverage>=7.10.6; extra == "dev"
22
+ Requires-Dist: ipykernel>=6.30.1; extra == "dev"
23
+ Requires-Dist: mypy>=1.18.1; extra == "dev"
24
+ Requires-Dist: pre-commit>=4.3.0; extra == "dev"
25
+ Requires-Dist: pydoclint>=0.7.3; extra == "dev"
26
+ Requires-Dist: pytest_notebook>=0.10.0; extra == "dev"
27
+ Requires-Dist: pytest-cov>=7.0.0; extra == "dev"
28
+ Requires-Dist: pytest-randomly>=4.0.1; extra == "dev"
29
+ Requires-Dist: pytest>=8.4.2; extra == "dev"
30
+ Requires-Dist: ruff>=0.13.0; extra == "dev"
24
31
  Dynamic: license-file
25
32
 
26
33
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
@@ -1,33 +1,40 @@
1
1
  [project]
2
2
  name = "umami-preprocessing"
3
- description = "Preprocessing for jet tagging"
3
+ description = "ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)"
4
+ authors = [{name="Alexander Froch"}]
4
5
  dynamic = ["version"]
5
6
  license = {text = "MIT"}
6
7
  readme = "README.md"
7
- requires-python = "<3.12,>=3.8"
8
+ requires-python = ">=3.10,<3.12"
8
9
 
9
10
  dependencies = [
10
- "atlas-ftag-tools==0.2.15",
11
- "dotmap==1.3.30",
12
- "puma-hep==0.4.10",
11
+ "atlas-ftag-tools==0.2.17",
12
+ "dotmap>=1.3.30",
13
+ "numpy>=2.2.6",
14
+ "puma-hep==0.4.11",
13
15
  "pyyaml-include==1.3",
14
- "PyYAML>=6.0.1",
15
- "rich==12.6.0",
16
- "scipy>=1.15.2",
16
+ "PyYAML>=6.0.2",
17
+ "rich>=14.1.0",
18
+ "scipy>=1.15.3",
17
19
  ]
18
20
 
19
21
  [project.optional-dependencies]
20
22
  dev = [
21
- "mypy==1.11.2",
22
- "pre-commit==3.5.0",
23
- "pytest-cov>=4.0.0",
24
- "pytest-mock==3.11.1",
25
- "pytest>=7.2.2",
26
- "ruff==0.6.2",
23
+ "coverage>=7.10.6",
24
+ "ipykernel>=6.30.1",
25
+ "mypy>=1.18.1",
26
+ "pre-commit>=4.3.0",
27
+ "pydoclint>=0.7.3",
28
+ "pytest_notebook>=0.10.0",
29
+ "pytest-cov>=7.0.0",
30
+ "pytest-randomly>=4.0.1",
31
+ "pytest>=8.4.2",
32
+ "ruff>=0.13.0",
27
33
  ]
28
34
 
29
35
  [project.urls]
30
36
  "Homepage" = "https://github.com/umami-hep/umami-preprocessing"
37
+ "Issue Tracker" = "https://github.com/umami-hep/umami-preprocessing/issues"
31
38
 
32
39
  [project.scripts]
33
40
  preprocess = "upp.main:main"
@@ -1,26 +1,33 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: umami-preprocessing
3
- Version: 0.2.6
4
- Summary: Preprocessing for jet tagging
3
+ Version: 0.2.7
4
+ Summary: ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)
5
+ Author: Alexander Froch
5
6
  License: MIT
6
7
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
7
- Requires-Python: <3.12,>=3.8
8
+ Project-URL: Issue Tracker, https://github.com/umami-hep/umami-preprocessing/issues
9
+ Requires-Python: <3.12,>=3.10
8
10
  Description-Content-Type: text/markdown
9
11
  License-File: LICENSE
10
- Requires-Dist: atlas-ftag-tools==0.2.15
11
- Requires-Dist: dotmap==1.3.30
12
- Requires-Dist: puma-hep==0.4.10
12
+ Requires-Dist: atlas-ftag-tools==0.2.17
13
+ Requires-Dist: dotmap>=1.3.30
14
+ Requires-Dist: numpy>=2.2.6
15
+ Requires-Dist: puma-hep==0.4.11
13
16
  Requires-Dist: pyyaml-include==1.3
14
- Requires-Dist: PyYAML>=6.0.1
15
- Requires-Dist: rich==12.6.0
16
- Requires-Dist: scipy>=1.15.2
17
+ Requires-Dist: PyYAML>=6.0.2
18
+ Requires-Dist: rich>=14.1.0
19
+ Requires-Dist: scipy>=1.15.3
17
20
  Provides-Extra: dev
18
- Requires-Dist: mypy==1.11.2; extra == "dev"
19
- Requires-Dist: pre-commit==3.5.0; extra == "dev"
20
- Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
21
- Requires-Dist: pytest-mock==3.11.1; extra == "dev"
22
- Requires-Dist: pytest>=7.2.2; extra == "dev"
23
- Requires-Dist: ruff==0.6.2; extra == "dev"
21
+ Requires-Dist: coverage>=7.10.6; extra == "dev"
22
+ Requires-Dist: ipykernel>=6.30.1; extra == "dev"
23
+ Requires-Dist: mypy>=1.18.1; extra == "dev"
24
+ Requires-Dist: pre-commit>=4.3.0; extra == "dev"
25
+ Requires-Dist: pydoclint>=0.7.3; extra == "dev"
26
+ Requires-Dist: pytest_notebook>=0.10.0; extra == "dev"
27
+ Requires-Dist: pytest-cov>=7.0.0; extra == "dev"
28
+ Requires-Dist: pytest-randomly>=4.0.1; extra == "dev"
29
+ Requires-Dist: pytest>=8.4.2; extra == "dev"
30
+ Requires-Dist: ruff>=0.13.0; extra == "dev"
24
31
  Dynamic: license-file
25
32
 
26
33
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
@@ -15,6 +15,7 @@ upp/classes/components.py
15
15
  upp/classes/preprocessing_config.py
16
16
  upp/classes/region.py
17
17
  upp/classes/resampling_config.py
18
+ upp/classes/reweight_config.py
18
19
  upp/classes/variable_config.py
19
20
  upp/stages/__init__.py
20
21
  upp/stages/hist.py
@@ -23,6 +24,9 @@ upp/stages/merging.py
23
24
  upp/stages/normalisation.py
24
25
  upp/stages/plot.py
25
26
  upp/stages/resampling.py
27
+ upp/stages/reweight.py
28
+ upp/stages/rw_merge.py
29
+ upp/stages/split_containers.py
26
30
  upp/utils/__init__.py
27
31
  upp/utils/check_input_samples.py
28
32
  upp/utils/logger.py
@@ -0,0 +1,20 @@
1
+ atlas-ftag-tools==0.2.17
2
+ dotmap>=1.3.30
3
+ numpy>=2.2.6
4
+ puma-hep==0.4.11
5
+ pyyaml-include==1.3
6
+ PyYAML>=6.0.2
7
+ rich>=14.1.0
8
+ scipy>=1.15.3
9
+
10
+ [dev]
11
+ coverage>=7.10.6
12
+ ipykernel>=6.30.1
13
+ mypy>=1.18.1
14
+ pre-commit>=4.3.0
15
+ pydoclint>=0.7.3
16
+ pytest_notebook>=0.10.0
17
+ pytest-cov>=7.0.0
18
+ pytest-randomly>=4.0.1
19
+ pytest>=8.4.2
20
+ ruff>=0.13.0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.6"
5
+ __version__ = "v0.2.7"
6
6
 
7
7
  from . import classes, stages, utils
8
8
  from .main import run_pp
@@ -336,6 +336,7 @@ class Components:
336
336
  pattern=pattern,
337
337
  ntuple_dir=config.ntuple_dir,
338
338
  name=component["sample"]["name"],
339
+ skip_checks=config.skip_checks,
339
340
  )
340
341
 
341
342
  # Create the Component instances for the different flavours
@@ -360,7 +361,7 @@ class Components:
360
361
  components = cls(component_list)
361
362
 
362
363
  # Check the flavour ratios
363
- if config.sampl_cfg.method is not None:
364
+ if config.sampl_cfg and config.sampl_cfg.method is not None:
364
365
  components.check_flavour_ratios()
365
366
 
366
367
  return components
@@ -510,7 +511,7 @@ class Components:
510
511
  def __getitem__(self, index):
511
512
  if isinstance(index, int):
512
513
  return self.components[index]
513
- if isinstance(index, (str, Label)):
514
+ if isinstance(index, str | Label):
514
515
  return self.components[self.flavours.index(index)]
515
516
 
516
517
  def __len__(self):
@@ -20,6 +20,7 @@ from yamlinclude import YamlIncludeConstructor
20
20
  from upp import __version__
21
21
  from upp.classes.components import Components
22
22
  from upp.classes.resampling_config import ResamplingConfig
23
+ from upp.classes.reweight_config import ReweightConfig
23
24
  from upp.classes.variable_config import VariableConfig
24
25
  from upp.utils.tools import path_append
25
26
 
@@ -114,6 +115,8 @@ class PreprocessingConfig:
114
115
  than this number, the final h5 output files are splitted in multiple smaller
115
116
  files with this number of jets per file. By default None which produces one
116
117
  huge output file.
118
+ skip_checks : bool, optional
119
+ Skip checks for the input files. This is used for grid submission
117
120
  skip_config_copy : bool, optional
118
121
  Decide, if the config copying is skipped or not. By default False
119
122
  """
@@ -137,6 +140,7 @@ class PreprocessingConfig:
137
140
  flavour_config: Path | None = None
138
141
  flavour_category: str = "standard"
139
142
  num_jets_per_output_file: int | None = None
143
+ skip_checks: bool = False
140
144
  skip_config_copy: bool = False
141
145
 
142
146
  def __post_init__(self):
@@ -154,11 +158,10 @@ class PreprocessingConfig:
154
158
  for field in dataclasses.fields(self):
155
159
  if field.type == "Path" and field.name != "out_fname" and field.name != "base_dir":
156
160
  setattr(self, field.name, self.get_path(Path(getattr(self, field.name))))
157
- if not self.ntuple_dir.exists():
161
+ if not self.ntuple_dir.exists() and not self.skip_checks:
158
162
  raise FileNotFoundError(f"Path {self.ntuple_dir} does not exist")
159
163
  self.components_dir = self.components_dir / self.split
160
164
  self.out_fname = self.out_dir / path_append(self.out_fname, self.split)
161
-
162
165
  # Define the content of the flavour label container
163
166
  if self.flavour_config:
164
167
  self.flavour_cont = LabelContainer.from_yaml(
@@ -177,12 +180,15 @@ class PreprocessingConfig:
177
180
  "flavours! If you want to use your own flavour config yaml file, please "
178
181
  "provide flavour_config!"
179
182
  )
180
-
181
183
  # configure classes
182
- sampl_cfg = copy(self.config["resampling"])
183
- if self.is_test:
184
- sampl_cfg["method"] = None
185
- self.sampl_cfg = ResamplingConfig(**sampl_cfg)
184
+ if sampl_cfg := self.config.get("resampling", None):
185
+ sampl_cfg = copy(sampl_cfg)
186
+ if self.is_test:
187
+ sampl_cfg["method"] = None
188
+ self.sampl_cfg = ResamplingConfig(**sampl_cfg)
189
+ else:
190
+ self.sampl_cfg = None
191
+
186
192
  self.components = Components.from_config(self)
187
193
 
188
194
  # get track selectors
@@ -196,13 +202,21 @@ class PreprocessingConfig:
196
202
  self.variables = VariableConfig(
197
203
  self.config["variables"], self.jets_name, self.is_test, selectors
198
204
  )
199
- self.variables = self.variables.add_jet_vars(
200
- list(self.config["resampling"]["variables"].keys()), "labels"
201
- )
205
+ if self.sampl_cfg is not None:
206
+ self.variables = self.variables.add_jet_vars(
207
+ list(self.config["resampling"]["variables"].keys()), "labels"
208
+ )
202
209
  self.transform = (
203
210
  Transform(**self.config["transform"]) if "transform" in self.config else None
204
211
  )
205
212
 
213
+ self.rw_config = (
214
+ ReweightConfig(
215
+ **self.config["reweighting"],
216
+ )
217
+ if "reweighting" in self.config
218
+ else None
219
+ )
206
220
  # reproducibility
207
221
  self.git_hash = get_git_hash(Path(__file__).parent)
208
222
  if self.git_hash is None:
@@ -214,7 +228,13 @@ class PreprocessingConfig:
214
228
  self.copy_config()
215
229
 
216
230
  @classmethod
217
- def from_file(cls, config_path: Path, split: Split, skip_config_copy: bool = False):
231
+ def from_file(
232
+ cls,
233
+ config_path: Path,
234
+ split: Split,
235
+ skip_checks: bool = False,
236
+ skip_config_copy: bool = False,
237
+ ):
218
238
  if not config_path.exists():
219
239
  raise FileNotFoundError(f"{config_path} does not exist - check your --config arg")
220
240
  with open(config_path) as file:
@@ -225,6 +245,7 @@ class PreprocessingConfig:
225
245
  config=config,
226
246
  skip_config_copy=skip_config_copy,
227
247
  **config["global"],
248
+ skip_checks=skip_checks,
228
249
  )
229
250
 
230
251
  def get_path(self, path: Path):
@@ -238,7 +259,7 @@ class PreprocessingConfig:
238
259
  def global_cuts(self):
239
260
  cuts_list = self.config["global_cuts"].get("common", [])
240
261
  cuts_list += self.config["global_cuts"][self.split]
241
- if not self.is_test:
262
+ if not self.is_test and self.config.get("resampling", None) is not None:
242
263
  for resampling_var, cfg in self.config["resampling"]["variables"].items():
243
264
  cuts_list.append([resampling_var, ">", cfg["bins"][0][0]])
244
265
  cuts_list.append([resampling_var, "<", cfg["bins"][-1][1]])
@@ -357,3 +378,34 @@ class PreprocessingConfig:
357
378
  f"Option value {option} is not supported! "
358
379
  "Only resampled and resampled_scaled_shuffled are."
359
380
  )
381
+
382
# Static because otherwise config paths end up getting messed up
@staticmethod
def get_input_files_with_split_components(config_path):
    """Map every input container to the split/component cuts that use it.

    The preprocessing config is loaded once per split (``train``, ``val`` and
    ``test``) with ``skip_checks=True``, so a missing local ntuple directory
    does not raise. For every component of each split, every entry of
    ``component.sample.pattern`` is registered as a container.

    The returned nested dictionary has the form::

        {
            "container_1": {
                "train_bjets": <cuts of that component>,
                "test_bjets": <cuts of that component>,
                ...
            },
            "container_2": {...},
        }

    i.e. it lists all splits and components that have to be run for each
    input container.

    Parameters
    ----------
    config_path : Path
        Path to the preprocessing config file.

    Returns
    -------
    dict
        Mapping ``container -> {"<split>_<component name>": component.cuts}``.
    """
    containers_with_splits = {}
    for split in ["train", "val", "test"]:
        split_config = PreprocessingConfig.from_file(config_path, split, skip_checks=True)
        for component in split_config.components.components:
            # NOTE(review): assumes sample.pattern is an iterable of container
            # names (not a bare string) - confirm against the Sample class.
            for container in component.sample.pattern:
                per_container = containers_with_splits.setdefault(container, {})
                per_container[f"{split}_{component.name}"] = component.cuts
    return containers_with_splits
@@ -0,0 +1,78 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass, field
5
+
6
+ import numpy as np
7
+
8
+
9
+ @dataclass
10
+ class ReweightConfig:
11
+ # Number of jets to estimate, if None, use the global num jets estimate
12
+ num_jets_estimate: None | int = None
13
+ merge_num_proc: int = 1 # Number of processes to use for merging
14
+ reweights: list[SingleReweightConfig] = field(default_factory=list)
15
+
16
+ def __post_init__(self):
17
+ if self.num_jets_estimate is not None and self.num_jets_estimate <= 0:
18
+ raise ValueError("num_jets_estimate must be a positive integer or None")
19
+
20
+ parsed_reweights = []
21
+ for rw in self.reweights:
22
+ parsed_reweights.append(SingleReweightConfig(**rw))
23
+ self.reweights = parsed_reweights
24
+
25
+
26
+ @dataclass
27
+ class SingleReweightConfig:
28
+ group: str # The group our variables in the h5 file are in
29
+ reweight_vars: list[str] # The variables we want to reweight
30
+ bins: list[np.ndarray] # The bins we want to use for the reweighting
31
+ class_var: str # The variable which contains the label we resample over, e.g. flavour
32
+ class_target: int | tuple | str | None = None
33
+ add_overflow: bool = True # Whether to add overflow bins
34
+
35
+ target_hist_func: Callable | None = None
36
+ target_hist_func_name: str | None = None
37
+
38
+ # TODO - this is the same as in resampling, maybe can cleanup
39
+ def get_bins_x(self, bins_x, upscale=1):
40
+ flat_bins = []
41
+ for i, sub_bins_x in enumerate(bins_x):
42
+ start, stop, nbins = sub_bins_x
43
+ b = np.linspace(start, stop, nbins * upscale + 1)
44
+ if i > 0:
45
+ b = b[1:]
46
+ flat_bins.append(b)
47
+ if self.add_overflow:
48
+ flat_bins = [np.array([-np.inf])] + flat_bins + [np.array([np.inf])]
49
+ return np.concatenate(flat_bins)
50
+
51
+ @property
52
+ def flat_bins(self):
53
+ return [self.get_bins_x(self.bins[k]) for k in self.reweight_vars]
54
+
55
+ def __post_init__(self):
56
+ if isinstance(self.class_target, str) and self.class_target not in [
57
+ "mean",
58
+ "min",
59
+ "max",
60
+ "uniform",
61
+ ]:
62
+ raise ValueError("class_target must be either 'mean', 'min', 'max' or an integer")
63
+
64
+ if self.target_hist_func is not None and self.target_hist_func_name is None:
65
+ self.target_hist_func_name = self.target_hist_func.__name__
66
+
67
+ def __repr__(self):
68
+ target_str = "target_"
69
+ if self.target_hist_func_name is not None:
70
+ target_str += f"{self.target_hist_func_name}_"
71
+ if self.class_target is not None:
72
+ if isinstance(self.class_target, list | tuple):
73
+ target_str += "_".join(map(str, self.class_target))
74
+ else:
75
+ target_str += f"{self.class_target}_{self.class_var}"
76
+ else:
77
+ target_str += "none"
78
+ return f"weight_{self.group}_{'_'.join(self.reweight_vars)}_{target_str}"
@@ -23,6 +23,9 @@ from upp.stages.merging import Merging
23
23
  from upp.stages.normalisation import Normalisation
24
24
  from upp.stages.plot import plot_resampling_dists
25
25
  from upp.stages.resampling import Resampling
26
+ from upp.stages.reweight import Reweight
27
+ from upp.stages.rw_merge import RWMerge
28
+ from upp.stages.split_containers import SplitContainers
26
29
  from upp.utils.check_input_samples import run_input_sample_check
27
30
  from upp.utils.logger import setup_logger
28
31
 
@@ -116,6 +119,28 @@ def parse_args(args: Any) -> argparse.Namespace:
116
119
  default=None,
117
120
  help="Component which is processed during --prep",
118
121
  )
122
+ parser.add_argument(
123
+ "--split-components",
124
+ action="store_true",
125
+ default=False,
126
+ help="Split containers into components",
127
+ )
128
+ parser.add_argument(
129
+ "--reweight", "--rw", action="store_true", default=False, help="Run the reweighting stage"
130
+ )
131
+ parser.add_argument(
132
+ "--rw-merge", "--rwm", action="store_true", default=False, help="Run the reweighting stage"
133
+ )
134
+ parser.add_argument(
135
+ "--rw-merge-idx",
136
+ "--rwm-idx",
137
+ type=str,
138
+ default=None,
139
+ help=(
140
+ "Commar seperated pair of indices representing the range of output "
141
+ "files to create, e.g '0,10' will create files 0 to 9"
142
+ ),
143
+ )
119
144
  parser.add_argument(
120
145
  "--region",
121
146
  default=None,
@@ -126,10 +151,37 @@ def parse_args(args: Any) -> argparse.Namespace:
126
151
  action="store_true",
127
152
  help="Skip the inital input sample check",
128
153
  )
154
+ parser.add_argument(
155
+ "--grid", action="store_true", help="Use when running the split stage on the grid. "
156
+ )
157
+ parser.add_argument(
158
+ "--container",
159
+ default=None,
160
+ type=str,
161
+ help="Container to use during the 'split-containers' stage. "
162
+ "If not specified, all containers in the config will be used.",
163
+ )
164
+ parser.add_argument(
165
+ "--files",
166
+ default=None,
167
+ help="comma-separated list of files to use during the 'split-containers' stage ",
168
+ )
129
169
 
130
170
  args = parser.parse_args(args)
131
171
  d = vars(args)
132
- ignore = ["config", "split", "component", "region"]
172
+ ignore = [
173
+ "config",
174
+ "split",
175
+ "component",
176
+ "region",
177
+ "container",
178
+ "files",
179
+ "grid",
180
+ "split_components",
181
+ "reweight",
182
+ "rw_merge",
183
+ "rw_merge_idx",
184
+ ]
133
185
  if not any(v for a, v in d.items() if a not in ignore):
134
186
  for v in d:
135
187
  if v not in ignore and d[v] is None:
@@ -151,10 +203,24 @@ def run_pp(args: argparse.Namespace) -> None:
151
203
  log.info("[bold green]Starting preprocessing...")
152
204
  start = datetime.now()
153
205
  log.info(f"Start time: {start.strftime('%Y-%m-%d %H:%M:%S')}")
154
-
155
206
  # load config
156
- config = PreprocessingConfig.from_file(args.config, args.split)
207
+ config = PreprocessingConfig.from_file(args.config, args.split, skip_checks=args.grid)
157
208
 
209
+ if args.split_components:
210
+ log.info("Splitting containers...")
211
+ split = SplitContainers(args.config)
212
+ split.run(
213
+ args.container,
214
+ args.files,
215
+ "." if args.grid else None, # Use current directory if not on grid
216
+ )
217
+ # If we aren't running on the grid, we create the metadata after splitting
218
+ if not args.grid:
219
+ split.create_meta_data()
220
+ if args.reweight:
221
+ log.info("Running reweighting...")
222
+ reweight = Reweight(config)
223
+ reweight.run()
158
224
  # create virtual datasets and pdf files
159
225
  if args.prep:
160
226
  # Check the input samples sizes
@@ -180,6 +246,16 @@ def run_pp(args: argparse.Namespace) -> None:
180
246
  if args.merge:
181
247
  merging = Merging(config)
182
248
  merging.run()
249
+ if args.rw_merge:
250
+ if args.rw_merge_idx:
251
+ rw_merge_idx = args.rw_merge_idx
252
+ assert "," in rw_merge_idx, "rw-merge-idx must be a comma-separated pair of indices"
253
+ rw_merge_idx = tuple(map(int, rw_merge_idx.split(",")))
254
+ assert len(rw_merge_idx) == 2, "rw-merge-idx must be a pair of indices"
255
+ else:
256
+ rw_merge_idx = None
257
+ rw_merge = RWMerge(config, rw_merge_idx)
258
+ rw_merge.run()
183
259
 
184
260
  # run the normalisation
185
261
  if args.norm and args.split == "train":
@@ -209,10 +285,10 @@ def main(args: Any | None = None) -> None:
209
285
  d = vars(args)
210
286
  for split in ["train", "val", "test"]:
211
287
  d["split"] = split
212
- log.info(f"[bold blue]{'-'*100}")
288
+ log.info(f"[bold blue]{'-' * 100}")
213
289
  title = f" {args.split} "
214
290
  log.info(f"[bold blue]{title:-^100}")
215
- log.info(f"[bold blue]{'-'*100}")
291
+ log.info(f"[bold blue]{'-' * 100}")
216
292
  run_pp(args)
217
293
  else:
218
294
  run_pp(args)
@@ -191,7 +191,7 @@ class Normalisation:
191
191
  "Class dict A has arrays of different lengths for the same"
192
192
  " variable. This should not happen."
193
193
  )
194
- counts_A = dict(zip(*class_dict_A[name][v]))
194
+ counts_A = dict(zip(*class_dict_A[name][v], strict=False))
195
195
  counts[i] += counts_A.get(label, 0)
196
196
  var[v] = (labels, counts)
197
197
  return class_dict_B
@@ -233,9 +233,11 @@ class Normalisation:
233
233
  """Run the normalisation calculation."""
234
234
  title = " Computing Normalisations "
235
235
  log.info(f"[bold green]{title:-^100}")
236
+ if self.config.rw_config is not None:
237
+ fname = str(self.config.out_fname).replace(".h5", "_vds.h5")
236
238
 
237
239
  # Get the correct output names if multiple output files were written
238
- if self.config.num_jets_per_output_file:
240
+ elif self.config.num_jets_per_output_file:
239
241
  fname = self.config.out_fname.parent / f"{self.config.out_fname.stem}*.h5"
240
242
 
241
243
  else:
@@ -15,7 +15,8 @@ from upp.stages.interpolation import subdivide_bins, upscale_array_regionally
15
15
  from upp.utils.logger import ProgressBar
16
16
 
17
17
  if TYPE_CHECKING: # pragma: no cover
18
- from typing import Any, Generator
18
+ from collections.abc import Generator
19
+ from typing import Any
19
20
 
20
21
  from rich.progress import Progress
21
22