umami-preprocessing 0.0.6__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/PKG-INFO +9 -10
  2. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/pyproject.toml +11 -19
  3. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/PKG-INFO +9 -10
  4. umami_preprocessing-0.2.0/umami_preprocessing.egg-info/requires.txt +15 -0
  5. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/__init__.py +1 -1
  6. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/components.py +28 -16
  7. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/preprocessing_config.py +52 -14
  8. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/variable_config.py +4 -1
  9. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/main.py +18 -20
  10. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/hist.py +9 -5
  11. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/interpolation.py +4 -4
  12. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/merging.py +12 -2
  13. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/normalisation.py +5 -3
  14. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/plot.py +32 -14
  15. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/resampling.py +13 -12
  16. umami-preprocessing-0.0.6/umami_preprocessing.egg-info/requires.txt +0 -16
  17. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/README.md +0 -0
  18. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/setup.cfg +0 -0
  19. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/SOURCES.txt +0 -0
  20. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
  21. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/entry_points.txt +0 -0
  22. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/top_level.txt +0 -0
  23. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/__init__.py +0 -0
  24. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/region.py +0 -0
  25. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/resampling_config.py +0 -0
  26. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/logger.py +0 -0
  27. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/__init__.py +0 -0
  28. {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/utils.py +0 -0
@@ -1,26 +1,25 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: umami-preprocessing
3
- Version: 0.0.6
3
+ Version: 0.2.0
4
4
  Summary: Preprocessing for jet tagging
5
5
  License: MIT
6
6
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
7
- Requires-Python: >=3.8
7
+ Requires-Python: <3.12,>=3.8
8
8
  Description-Content-Type: text/markdown
9
9
  Requires-Dist: pyyaml-include==1.3
10
10
  Requires-Dist: PyYAML==6.0.1
11
11
  Requires-Dist: rich==12.6.0
12
12
  Requires-Dist: scipy==1.10.1
13
- Requires-Dist: puma-hep==0.3.0
14
- Requires-Dist: atlas-ftag-tools==0.1.10
13
+ Requires-Dist: puma-hep==0.4.1
14
+ Requires-Dist: atlas-ftag-tools==0.2.7
15
15
  Requires-Dist: dotmap==1.3.30
16
16
  Provides-Extra: dev
17
- Requires-Dist: black==23.9.1; extra == "dev"
18
- Requires-Dist: ruff==0.0.289; extra == "dev"
17
+ Requires-Dist: ruff==0.1.6; extra == "dev"
19
18
  Requires-Dist: mypy==1.5.1; extra == "dev"
20
- Requires-Dist: pre-commit==3.1.1; extra == "dev"
21
- Requires-Dist: pytest==7.2.2; extra == "dev"
19
+ Requires-Dist: pre-commit==3.5.0; extra == "dev"
20
+ Requires-Dist: pytest>=7.0.1; extra == "dev"
22
21
  Requires-Dist: pytest-mock==3.11.1; extra == "dev"
23
- Requires-Dist: pytest-cov==4.0.0; extra == "dev"
22
+ Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
24
23
 
25
24
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
26
25
  [![codecov](https://codecov.io/gh/umami-hep/umami-preprocessing/graph/badge.svg?token=K8MJI20UZO)](https://codecov.io/gh/umami-hep/umami-preprocessing)
@@ -4,27 +4,26 @@ description = "Preprocessing for jet tagging"
4
4
  dynamic = ["version"]
5
5
  license = {text = "MIT"}
6
6
  readme = "README.md"
7
- requires-python = ">=3.8"
7
+ requires-python = "<3.12,>=3.8"
8
8
 
9
9
  dependencies = [
10
10
  "pyyaml-include==1.3",
11
11
  "PyYAML==6.0.1",
12
12
  "rich==12.6.0",
13
13
  "scipy==1.10.1",
14
- "puma-hep==0.3.0",
15
- "atlas-ftag-tools==0.1.10",
14
+ "puma-hep==0.4.1",
15
+ "atlas-ftag-tools==0.2.7",
16
16
  "dotmap==1.3.30"
17
17
  ]
18
18
 
19
19
  [project.optional-dependencies]
20
20
  dev = [
21
- "black==23.9.1",
22
- "ruff==0.0.289",
21
+ "ruff==0.1.6",
23
22
  "mypy==1.5.1",
24
- "pre-commit==3.1.1",
25
- "pytest==7.2.2",
23
+ "pre-commit==3.5.0",
24
+ "pytest>=7.0.1",
26
25
  "pytest-mock==3.11.1",
27
- "pytest-cov==4.0.0",
26
+ "pytest-cov>=3.0.0",
28
27
  ]
29
28
 
30
29
  [project.urls]
@@ -44,24 +43,17 @@ version = {attr = "upp.__version__"}
44
43
  requires = ["setuptools>=62"]
45
44
  build-backend = "setuptools.build_meta"
46
45
 
47
- [tool.black]
48
- line-length = 100
49
- preview = "True"
50
-
51
46
  [tool.ruff]
52
- select = ["I", "E", "W", "F", "B", "UP", "ARG", "SIM", "TID", "RUF", "D2", "D3", "D4"]
53
- ignore = ["D211", "D213", "RUF005"]
47
+ lint.select = ["I", "E", "W", "F", "B", "UP", "ARG", "SIM", "TID", "RUF", "D2", "D3", "D4"]
48
+ lint.ignore = ["D211", "D213", "RUF005"]
54
49
  line-length = 100
55
50
 
56
- [tool.ruff.isort]
51
+ [tool.ruff.lint.isort]
57
52
  required-imports = ["from __future__ import annotations"]
58
53
 
59
- [tool.ruff.pydocstyle]
54
+ [tool.ruff.lint.pydocstyle]
60
55
  convention = "numpy" # Accepts: "google", "numpy", or "pep257".
61
56
 
62
- [mypy]
63
- ignore_missing_imports = "True"
64
-
65
57
  [tool.pytest.ini_options]
66
58
  log_cli_level = "debug"
67
59
  filterwarnings = ["ignore::DeprecationWarning"]
@@ -1,26 +1,25 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: umami-preprocessing
3
- Version: 0.0.6
3
+ Version: 0.2.0
4
4
  Summary: Preprocessing for jet tagging
5
5
  License: MIT
6
6
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
7
- Requires-Python: >=3.8
7
+ Requires-Python: <3.12,>=3.8
8
8
  Description-Content-Type: text/markdown
9
9
  Requires-Dist: pyyaml-include==1.3
10
10
  Requires-Dist: PyYAML==6.0.1
11
11
  Requires-Dist: rich==12.6.0
12
12
  Requires-Dist: scipy==1.10.1
13
- Requires-Dist: puma-hep==0.3.0
14
- Requires-Dist: atlas-ftag-tools==0.1.10
13
+ Requires-Dist: puma-hep==0.4.1
14
+ Requires-Dist: atlas-ftag-tools==0.2.7
15
15
  Requires-Dist: dotmap==1.3.30
16
16
  Provides-Extra: dev
17
- Requires-Dist: black==23.9.1; extra == "dev"
18
- Requires-Dist: ruff==0.0.289; extra == "dev"
17
+ Requires-Dist: ruff==0.1.6; extra == "dev"
19
18
  Requires-Dist: mypy==1.5.1; extra == "dev"
20
- Requires-Dist: pre-commit==3.1.1; extra == "dev"
21
- Requires-Dist: pytest==7.2.2; extra == "dev"
19
+ Requires-Dist: pre-commit==3.5.0; extra == "dev"
20
+ Requires-Dist: pytest>=7.0.1; extra == "dev"
22
21
  Requires-Dist: pytest-mock==3.11.1; extra == "dev"
23
- Requires-Dist: pytest-cov==4.0.0; extra == "dev"
22
+ Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
24
23
 
25
24
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
26
25
  [![codecov](https://codecov.io/gh/umami-hep/umami-preprocessing/graph/badge.svg?token=K8MJI20UZO)](https://codecov.io/gh/umami-hep/umami-preprocessing)
@@ -0,0 +1,15 @@
1
+ pyyaml-include==1.3
2
+ PyYAML==6.0.1
3
+ rich==12.6.0
4
+ scipy==1.10.1
5
+ puma-hep==0.4.1
6
+ atlas-ftag-tools==0.2.7
7
+ dotmap==1.3.30
8
+
9
+ [dev]
10
+ ruff==0.1.6
11
+ mypy==1.5.1
12
+ pre-commit==3.5.0
13
+ pytest>=7.0.1
14
+ pytest-mock==3.11.1
15
+ pytest-cov>=3.0.0
@@ -1,4 +1,4 @@
1
1
  """UPP: Umami PreProcessing."""
2
2
  from __future__ import annotations
3
3
 
4
- __version__ = "v0.0.6"
4
+ __version__ = "v0.2.0"
@@ -5,7 +5,7 @@ from dataclasses import dataclass
5
5
  from pathlib import Path
6
6
 
7
7
  import numpy as np
8
- from ftag import Cuts, Flavour, Flavours, Sample
8
+ from ftag import Cuts, Label, Sample
9
9
  from ftag.hdf5 import H5Reader, H5Writer
10
10
 
11
11
  from upp.classes.region import Region
@@ -16,26 +16,28 @@ from upp.stages.hist import Hist
16
16
  class Component:
17
17
  region: Region
18
18
  sample: Sample
19
- flavour: Flavour
19
+ flavour: Label
20
20
  global_cuts: Cuts
21
21
  dirname: Path
22
22
  num_jets: int
23
- num_jets_estimate: int
24
- equal_jets: bool = True
23
+ num_jets_estimate_available: int
24
+ equal_jets: bool
25
25
 
26
26
  def __post_init__(self):
27
27
  self.hist = Hist(self.dirname.parent.parent / "hists" / f"hist_{self.name}.h5")
28
28
 
29
- def setup_reader(self, batch_size, fname=None, **kwargs):
29
+ def setup_reader(self, batch_size, jets_name="jets", fname=None, **kwargs):
30
30
  if fname is None:
31
31
  fname = self.sample.path
32
- self.reader = H5Reader(fname, batch_size, equal_jets=self.equal_jets, **kwargs)
32
+ self.reader = H5Reader(
33
+ fname, batch_size, jets_name=jets_name, equal_jets=self.equal_jets, **kwargs
34
+ )
33
35
  log.debug(f"Setup component reader at: {fname}")
34
36
 
35
- def setup_writer(self, variables):
37
+ def setup_writer(self, variables, jets_name="jets"):
36
38
  dtypes = self.reader.dtypes(variables.combined())
37
39
  shapes = self.reader.shapes(self.num_jets, variables.keys())
38
- self.writer = H5Writer(self.out_path, dtypes, shapes)
40
+ self.writer = H5Writer(self.out_path, dtypes, shapes, jets_name=jets_name)
39
41
  log.debug(f"Setup component writer at: {self.out_path}")
40
42
 
41
43
  @property
@@ -61,7 +63,10 @@ class Component:
61
63
  self, num_req, sampling_frac=None, cuts=None, silent=False, raise_error=True
62
64
  ):
63
65
  # Check if num_jets jets are aviailable after the cuts and sampling fraction
64
- total = self.reader.estimate_available_jets(cuts, self.num_jets_estimate)
66
+ num_est = (
67
+ None if self.num_jets_estimate_available <= 0 else self.num_jets_estimate_available
68
+ )
69
+ total = self.reader.estimate_available_jets(cuts, num_est)
65
70
  available = total
66
71
  if sampling_frac:
67
72
  available = int(total * sampling_frac)
@@ -77,11 +82,17 @@ class Component:
77
82
 
78
83
  if not silent:
79
84
  log.debug(f"Sampling fraction {sampling_frac}")
80
- log.info(f"Estimated {available:,} {self} jets available - {num_req:,} requested")
85
+ log.info(
86
+ f"Estimated {available:,} {self} jets available - {num_req:,} requested"
87
+ f"({self.reader.num_jets:,} in {self.sample})"
88
+ )
81
89
 
82
90
  def get_auto_sampling_frac(self, num_jets, cuts=None, silent=False):
83
- total = self.reader.estimate_available_jets(cuts, self.num_jets_estimate)
84
- auto_sampling_frac = round(1.05 * num_jets / total, 3) # 1.05 is a tolerance factor
91
+ num_est = (
92
+ None if self.num_jets_estimate_available <= 0 else self.num_jets_estimate_available
93
+ )
94
+ total = self.reader.estimate_available_jets(cuts, num_est)
95
+ auto_sampling_frac = round(1.1 * num_jets / total, 3) # 1.1 is a tolerance factor
85
96
  if not silent:
86
97
  log.debug(f"optimal sampling fraction {auto_sampling_frac:.3f}")
87
98
  return auto_sampling_frac
@@ -102,6 +113,7 @@ class Components:
102
113
  def from_config(cls, pp_cfg):
103
114
  components = []
104
115
  for c in pp_cfg.config["components"]:
116
+ assert "equal_jets" not in c, "equal_jets flag should be set in the sample config"
105
117
  region_cuts = Cuts.empty() if pp_cfg.is_test else Cuts.from_list(c["region"]["cuts"])
106
118
  region = Region(c["region"]["name"], region_cuts + pp_cfg.global_cuts)
107
119
  pattern = c["sample"]["pattern"]
@@ -119,11 +131,11 @@ class Components:
119
131
  Component(
120
132
  region,
121
133
  sample,
122
- Flavours[name],
134
+ pp_cfg.flavour_cont[name],
123
135
  pp_cfg.global_cuts,
124
136
  pp_cfg.components_dir,
125
137
  num_jets,
126
- pp_cfg.num_jets_estimate,
138
+ pp_cfg.num_jets_estimate_available,
127
139
  equal_jets,
128
140
  )
129
141
  )
@@ -193,7 +205,7 @@ class Components:
193
205
 
194
206
  @property
195
207
  def dsids(self):
196
- return list(set(sum([c.sample.dsid for c in self], [])))
208
+ return list(set(sum([c.sample.dsid for c in self], []))) # noqa: RUF017
197
209
 
198
210
  def groupby_region(self):
199
211
  return [(r, Components([c for c in self if c.region == r])) for r in self.regions]
@@ -207,7 +219,7 @@ class Components:
207
219
  def __getitem__(self, index):
208
220
  if isinstance(index, int):
209
221
  return self.components[index]
210
- if isinstance(index, (str, Flavour)):
222
+ if isinstance(index, (str, Label)):
211
223
  return self.components[self.flavours.index(index)]
212
224
 
213
225
  def __len__(self):
@@ -6,12 +6,14 @@ import logging as log
6
6
  from copy import copy
7
7
  from dataclasses import dataclass
8
8
  from pathlib import Path
9
- from subprocess import CalledProcessError, check_output
10
9
  from typing import Literal
11
10
 
12
11
  import yaml
13
12
  from dotmap import DotMap
14
13
  from ftag import Cuts
14
+ from ftag.git_check import get_git_hash
15
+ from ftag.labels import LabelContainer
16
+ from ftag.track_selector import TrackSelector
15
17
  from ftag.transform import Transform
16
18
  from yamlinclude import YamlIncludeConstructor
17
19
 
@@ -42,6 +44,7 @@ class PreprocessingConfig:
42
44
  For example:
43
45
  ```yaml
44
46
  global:
47
+ jets_name: jets
45
48
  batch_size: 1_000_000
46
49
  num_jets_estimate: 5_000_000
47
50
  base_dir: /my/stuff/
@@ -69,8 +72,21 @@ class PreprocessingConfig:
69
72
  especially to the `countup` method to achive best agreement of target and resampled
70
73
  distributions.
71
74
  num_jets_estimate : int
75
+ Any of the further three arguments that are not specified will default to this value
76
+ Is equal to 1_000_000 by default.
77
+ num_jets_estimate_available : int | None
78
+ A sabsample taken from the whole sample to estimate the number of jets after the cuts.
79
+ Please keep this number high in order to not get poisson error of more then 5%.
80
+ If time allows you can use -1 to get a precise number of jets and not just an estimate
81
+ although it will be slow for large datasets. Is equal to num_jets_estimate by default.
82
+ num_jets_estimate_hist : int
72
83
  Number of jets of each flavour that are used to construct histograms for probability
73
84
  density function estimation. Larger numbers give a better quality estmate of the pdfs.
85
+ Is equal to num_jets_estimate by default.
86
+ num_jets_estimate_norm : int
87
+ Number of jets of each flavour that are used to estimate shifting and scaling during
88
+ normalisation step. Larger numbers give a better quality estmates.
89
+ Is equal to num_jets_estimate by default.
74
90
  jets_name : str
75
91
  Name of the jets dataset in the input file.
76
92
  """
@@ -85,24 +101,50 @@ class PreprocessingConfig:
85
101
  out_fname: Path = Path("pp_output.h5")
86
102
  batch_size: int = 100_000
87
103
  num_jets_estimate: int = 1_000_000
104
+ num_jets_estimate_available: int | None = None
105
+ num_jets_estimate_hist: int | None = None
106
+ num_jets_estimate_norm: int | None = None
88
107
  merge_test_samples: bool = False
89
108
  jets_name: str = "jets"
109
+ flavour_config: Path | None = None
90
110
 
91
111
  def __post_init__(self):
92
112
  # postprocess paths
113
+ if self.num_jets_estimate:
114
+ if self.num_jets_estimate_available is None:
115
+ self.num_jets_estimate_available = max(self.num_jets_estimate, int(1e6))
116
+ if self.num_jets_estimate_hist is None:
117
+ self.num_jets_estimate_hist = self.num_jets_estimate
118
+ if self.num_jets_estimate_norm is None:
119
+ self.num_jets_estimate_norm = self.num_jets_estimate
120
+
93
121
  for field in dataclasses.fields(self):
94
- if field.type == "Path" and field.name != "out_fname":
122
+ if field.type == "Path" and field.name != "out_fname" and field.name != "base_dir":
95
123
  setattr(self, field.name, self.get_path(Path(getattr(self, field.name))))
96
124
  if not self.ntuple_dir.exists():
97
125
  raise FileNotFoundError(f"Path {self.ntuple_dir} does not exist")
98
126
  self.components_dir = self.components_dir / self.split
99
127
  self.out_fname = self.out_dir / path_append(self.out_fname, self.split)
128
+ self.flavour_cont = LabelContainer.from_yaml(self.flavour_config)
100
129
 
101
130
  # configure classes
102
131
  sampl_cfg = copy(self.config["resampling"])
103
- self.sampl_cfg = ResamplingConfig(sampl_cfg.pop("variables"), **sampl_cfg)
132
+ if self.is_test:
133
+ sampl_cfg["method"] = None
134
+ self.sampl_cfg = ResamplingConfig(**sampl_cfg)
104
135
  self.components = Components.from_config(self)
105
- self.variables = VariableConfig(self.config["variables"], self.jets_name, self.is_test)
136
+
137
+ # get track selectors
138
+ vc = self.config["variables"]
139
+ selectors = {}
140
+ for name, groups in vc.items():
141
+ if selection := groups.get("selection", None):
142
+ selectors[name] = TrackSelector(Cuts.from_list(selection))
143
+
144
+ # configure variables
145
+ self.variables = VariableConfig(
146
+ self.config["variables"], self.jets_name, self.is_test, selectors
147
+ )
106
148
  self.variables = self.variables.add_jet_vars(
107
149
  list(self.config["resampling"]["variables"].keys()), "labels"
108
150
  )
@@ -110,17 +152,13 @@ class PreprocessingConfig:
110
152
  Transform(**self.config["transform"]) if "transform" in self.config else None
111
153
  )
112
154
 
113
- # copy config
114
- try:
115
- git_hash = check_output(
116
- ["git", "rev-parse", "--short", "HEAD"], cwd=Path(__file__).parent
117
- )
118
- self.git_hash = git_hash.decode("ascii").strip()
119
- self.config["pp_git_hash"] = self.git_hash
120
- except CalledProcessError:
121
- log.warning("Could not get git hash")
155
+ # reproducibility
156
+ self.git_hash = get_git_hash(Path(__file__).parent)
157
+ if self.git_hash is None:
122
158
  self.git_hash = __version__
123
- self.config["pp_git_hash"] = self.git_hash
159
+ self.config["upp_hash"] = self.git_hash
160
+
161
+ # copy config
124
162
  self.copy_config()
125
163
 
126
164
  @classmethod
@@ -3,12 +3,15 @@ from __future__ import annotations
3
3
  from copy import deepcopy
4
4
  from dataclasses import dataclass
5
5
 
6
+ from ftag.track_selector import TrackSelector
7
+
6
8
 
7
9
  @dataclass(frozen=True)
8
10
  class VariableConfig:
9
11
  variables: dict[str, dict[str, list[str]]]
10
12
  jets_name: str = "jets"
11
13
  keep_all: bool = False
14
+ selectors: dict[str, TrackSelector] | None = None
12
15
 
13
16
  def __post_init__(self):
14
17
  for track_vars in self.tracks.values():
@@ -33,7 +36,7 @@ class VariableConfig:
33
36
 
34
37
  def add_jet_vars(self, variables: list[str], kind: str = "inputs") -> VariableConfig:
35
38
  """Return a new VariableConfig instance."""
36
- vc = VariableConfig(deepcopy(self.variables), self.jets_name, self.keep_all)
39
+ vc = VariableConfig(deepcopy(self.variables), self.jets_name, self.keep_all, self.selectors)
37
40
  vc.jets[kind] = list(dict.fromkeys(vc.jets[kind] + variables))
38
41
  return vc
39
42
 
@@ -8,11 +8,13 @@ To run without certain stages, include the corresponding negative flag.
8
8
  Note that all stages are required to run the pipeline. If you want to disable resampling,
9
9
  you need to set method: none in your config file.
10
10
  """
11
+
11
12
  from __future__ import annotations
12
13
 
13
14
  import argparse
14
15
  from datetime import datetime
15
- from pathlib import Path
16
+
17
+ from ftag.cli_utils import HelpFormatter, valid_path
16
18
 
17
19
  from upp.classes.preprocessing_config import PreprocessingConfig
18
20
  from upp.logger import setup_logger
@@ -23,30 +25,24 @@ from upp.stages.plot import plot_initial_resampling_dists, plot_resampled_dists
23
25
  from upp.stages.resampling import Resampling
24
26
 
25
27
 
26
- class HelpFormatter(argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): ...
27
-
28
-
29
- def parse_args():
30
- abool = "store_true"
31
- parser = argparse.ArgumentParser(
32
- description=__doc__,
33
- formatter_class=HelpFormatter,
34
- )
35
- parser.add_argument("--config", required=True, type=Path, help="Path to config file")
36
- parser.add_argument("--prep", action=abool, default=None, help="Estimate and write PDFs")
28
+ def parse_args(args):
29
+ _st = "store_true"
30
+ parser = argparse.ArgumentParser(description=__doc__, formatter_class=HelpFormatter)
31
+ parser.add_argument("--config", required=True, type=valid_path, help="Path to config file")
32
+ parser.add_argument("--prep", action=_st, default=None, help="Estimate and write PDFs")
37
33
  parser.add_argument("--no-prep", dest="prep", action="store_false")
38
- parser.add_argument("--resample", action=abool, default=None, help="Run resampling")
34
+ parser.add_argument("--resample", action=_st, default=None, help="Run resampling")
39
35
  parser.add_argument("--no-resample", dest="resample", action="store_false")
40
- parser.add_argument("--merge", action=abool, default=None, help="Run merging")
36
+ parser.add_argument("--merge", action=_st, default=None, help="Run merging")
41
37
  parser.add_argument("--no-merge", dest="merge", action="store_false")
42
- parser.add_argument("--norm", action=abool, default=None, help="Compute normalisations")
38
+ parser.add_argument("--norm", action=_st, default=None, help="Compute normalisations")
43
39
  parser.add_argument("--no-norm", dest="norm", action="store_false")
44
- parser.add_argument("--plot", action=abool, default=None, help="Plot resampled distributions")
40
+ parser.add_argument("--plot", action=_st, default=None, help="Plot output distributions")
45
41
  parser.add_argument("--no-plot", dest="plot", action="store_false")
46
42
  splits = ["train", "val", "test", "all"]
47
43
  parser.add_argument("--split", default="train", choices=splits, help="Which file to produce")
48
44
 
49
- args = parser.parse_args()
45
+ args = parser.parse_args(args)
50
46
  d = vars(args)
51
47
  ignore = ["config", "split"]
52
48
  if not any(v for a, v in d.items() if a not in ignore):
@@ -65,7 +61,7 @@ def run_pp(args) -> None:
65
61
  log.info(f"Start time: {start.strftime('%Y-%m-%d %H:%M:%S')}")
66
62
 
67
63
  # load config
68
- config = PreprocessingConfig.from_file(Path(args.config), args.split)
64
+ config = PreprocessingConfig.from_file(args.config, args.split)
69
65
 
70
66
  # create virtual datasets and pdf files
71
67
  if args.prep and args.split == "train":
@@ -88,6 +84,8 @@ def run_pp(args) -> None:
88
84
 
89
85
  # make plots
90
86
  if args.plot:
87
+ title = " Plotting "
88
+ log.info(f"[bold green]{title:-^100}")
91
89
  plot_initial_resampling_dists(config=config)
92
90
  plot_resampled_dists(config=config, stage=args.split)
93
91
 
@@ -99,8 +97,8 @@ def run_pp(args) -> None:
99
97
  log.info(f"Elapsed time: {str(end - start).split('.')[0]}")
100
98
 
101
99
 
102
- def main() -> None:
103
- args = parse_args()
100
+ def main(args=None) -> None:
101
+ args = parse_args(args)
104
102
  log = setup_logger()
105
103
 
106
104
  if args.split == "all":
@@ -130,19 +130,23 @@ def create_histograms(config) -> None:
130
130
  title = " Writing PDFs "
131
131
  log.info(f"[bold green]{title:-^100}")
132
132
 
133
- log.info(f"[bold green]Estimating PDFs using {config.num_jets_estimate:,} jets...")
133
+ log.info(f"[bold green]Estimating PDFs using {config.num_jets_estimate_hist:,} jets...")
134
134
  sampl_vars = config.sampl_cfg.vars
135
135
  for c in config.components:
136
- log.info(f"Estimating PDF for {c}")
137
- c.setup_reader(config.batch_size)
136
+ log.info(f"Estimating {c} PDF using {config.num_jets_estimate_hist:,} samples...")
137
+ c.setup_reader(config.batch_size, config.jets_name)
138
138
  cuts_no_split = c.cuts.ignore(["eventNumber"])
139
+
140
+ ###
141
+ # TODO: return the number of jets here and pass to the next function to get started
142
+ ###
139
143
  c.check_num_jets(
140
- config.num_jets_estimate,
144
+ config.num_jets_estimate_hist,
141
145
  cuts=cuts_no_split,
142
146
  silent=False,
143
147
  raise_error=False,
144
148
  )
145
- jets = c.get_jets(sampl_vars, config.num_jets_estimate, cuts_no_split)
149
+ jets = c.get_jets(sampl_vars, config.num_jets_estimate_hist, cuts_no_split)
146
150
  c.hist.write_hist(jets, sampl_vars, config.sampl_cfg.flat_bins)
147
151
 
148
152
  log.info(f"[bold green]Saved to {config.components[0].hist.path.parent}/")
@@ -70,7 +70,7 @@ def upscale_array(
70
70
  def upscale_array_regionally(
71
71
  array: np.array,
72
72
  upscl: int,
73
- regionlengthsd: list,
73
+ num_bins: list,
74
74
  order: int = 3,
75
75
  mode: str = "nearest",
76
76
  positive: bool = True,
@@ -83,7 +83,7 @@ def upscale_array_regionally(
83
83
  array to be upscaled
84
84
  upscl : int
85
85
  upscaling factor
86
- regionlengthsd : list
86
+ num_bins : list
87
87
  list of lists of region lengths in each dimension,
88
88
  region lengths should sum to the length of the array in that dimension
89
89
  order : int, optional
@@ -99,10 +99,10 @@ def upscale_array_regionally(
99
99
  Array that is upscaled by a factor of upscl
100
100
  """
101
101
  up_array = np.empty(shape=[ds * upscl for ds in array.shape])
102
- starts = [np.cumsum([0] + regionlengths)[:-1] for regionlengths in regionlengthsd]
102
+ starts = [np.cumsum([0] + regionlengths)[:-1] for regionlengths in num_bins]
103
103
  starts_grid = np.meshgrid(*starts)
104
104
  starts_grid = [starts_grid[i].flatten() for i in range(len(starts_grid))]
105
- finishes = [np.cumsum(regionlengths) for regionlengths in regionlengthsd]
105
+ finishes = [np.cumsum(regionlengths) for regionlengths in num_bins]
106
106
  finishes_grid = np.meshgrid(*finishes)
107
107
  finishes_grid = [finishes_grid[i].flatten() for i in range(len(finishes_grid))]
108
108
  d = len(array.shape)
@@ -17,11 +17,13 @@ class Merging:
17
17
  self.components = config.components
18
18
  self.variables = config.variables
19
19
  self.batch_size = config.batch_size
20
- self.jets_name = self.ppc.jets_name
20
+ self.jets_name = config.jets_name
21
21
  self.rng = np.random.default_rng(42)
22
22
  self.flavours = self.components.flavours
23
23
 
24
24
  def add_jet_flavour_label(self, jets, component):
25
+ if "flavour_label" in jets.dtype.names:
26
+ return jets
25
27
  int_label = self.flavours.index(component.flavour)
26
28
  label_array = np.full(len(jets), int_label, dtype=[("flavour_label", "i4")])
27
29
  return join_structured_arrays([jets, label_array])
@@ -49,6 +51,13 @@ class Merging:
49
51
  if all(c.complete for c in components):
50
52
  return False
51
53
 
54
+ # apply track selections
55
+ for name in self.variables.variables:
56
+ if name == self.jets_name:
57
+ continue
58
+ if selector := self.variables.selectors.get(name):
59
+ merged[name] = selector(merged[name])
60
+
52
61
  # write
53
62
  self.writer.write(merged)
54
63
  return len(merged[self.jets_name])
@@ -57,7 +66,7 @@ class Merging:
57
66
  # setup inputs
58
67
  for c in components:
59
68
  batch_size = self.batch_size * c.num_jets // components.num_jets + 1
60
- c.setup_reader(batch_size, fname=c.out_path)
69
+ c.setup_reader(batch_size, fname=c.out_path, jets_name=self.jets_name)
61
70
  c.stream = c.reader.stream(self.variables.combined(), c.reader.num_jets)
62
71
  c.complete = False
63
72
 
@@ -70,6 +79,7 @@ class Merging:
70
79
  components[0].reader.dtypes(self.variables.combined()),
71
80
  components[0].reader.shapes(components.num_jets, self.variables.keys()),
72
81
  add_flavour_label=self.jets_name,
82
+ jets_name=self.jets_name,
73
83
  )
74
84
  self.writer.add_attr("flavour_label", [f.name for f in self.flavours], self.jets_name)
75
85
  self.writer.add_attr("unique_jets", components.unique_jets)
@@ -16,7 +16,7 @@ class Normalisation:
16
16
  self.components = config.components
17
17
  self.variables = config.variables
18
18
  self.jets_name = self.ppc.jets_name
19
- self.num_jets = config.num_jets_estimate
19
+ self.num_jets = config.num_jets_estimate_norm
20
20
  self.norm_fname = config.out_dir / config.config.get("norm_fname", "norm_dict.yaml")
21
21
  self.class_fname = config.out_dir / config.config.get("class_fname", "class_dict.yaml")
22
22
 
@@ -62,7 +62,7 @@ class Normalisation:
62
62
  return combined
63
63
 
64
64
  def get_class_dict(self, batch):
65
- ignore = ["VertexIndex", "ftagTruthParentBarcode", "barcode"]
65
+ ignore = ["VertexIndex", "ftagTruthParentBarcode", "barcode", "eventNumber", "jetFoldHash"]
66
66
  class_dict = {k: {} for k in self.variables}
67
67
  for name, array in batch.items():
68
68
  if name != self.variables.jets_name:
@@ -118,7 +118,9 @@ class Normalisation:
118
118
  log.info(f"[bold green]{title:-^100}")
119
119
 
120
120
  # setup reader
121
- reader = H5Reader(self.ppc.out_fname, self.ppc.batch_size, precision="full")
121
+ reader = H5Reader(
122
+ self.ppc.out_fname, self.ppc.batch_size, precision="full", jets_name=self.jets_name
123
+ )
122
124
  log.debug(f"Setup reader at: {self.ppc.out_fname}")
123
125
 
124
126
  norm_dict = None
@@ -5,6 +5,7 @@ from pathlib import Path
5
5
 
6
6
  from ftag import Flavours
7
7
  from ftag.hdf5 import H5Reader
8
+ from ftag.labels import LabelContainer
8
9
  from puma import Histogram, HistogramPlot
9
10
 
10
11
  from upp.utils import path_append
@@ -14,6 +15,7 @@ def load_jets(
14
15
  paths: str | list,
15
16
  variable: str,
16
17
  flavour_label="flavour_label",
18
+ jets_name="jets",
17
19
  ) -> dict:
18
20
  """
19
21
  Load the variables and labels from the jets in a given file(s).
@@ -28,15 +30,18 @@ def load_jets(
28
30
  flavour_label : str, optional
29
31
  Name of the flavour label variable which is used for the labels,
30
32
  by default "flavour_label"
33
+ jets_name: str, optional
34
+ Name of the jet dataset / the global objects
35
+ by default "jets"
31
36
 
32
37
  Returns
33
38
  -------
34
39
  dict
35
40
  Dict with the loaded variable and labels.
36
41
  """
37
- variables = {"jets": [flavour_label, variable]}
38
- reader = H5Reader(paths, batch_size=1000)
39
- df = reader.load(variables, num_jets=10000)["jets"]
42
+ variables = {jets_name: [flavour_label, variable]}
43
+ reader = H5Reader(paths, batch_size=1000, jets_name=jets_name)
44
+ df = reader.load(variables, num_jets=10000)[jets_name]
40
45
  return df
41
46
 
42
47
 
@@ -45,8 +50,10 @@ def make_hist(
45
50
  flavours: list,
46
51
  variable: str,
47
52
  in_paths: str | list,
53
+ jets_name: str = "jets",
48
54
  bins_range: tuple | None = None,
49
55
  suffix: str = "",
56
+ flavour_cont: LabelContainer = Flavours,
50
57
  ) -> None:
51
58
  """
52
59
  Create and plot the histogram and save it to disk.
@@ -64,6 +71,9 @@ def make_hist(
64
71
  Variable that is to be histogrammed and plotted.
65
72
  in_paths : str
66
73
  Path to the files from which the jets are loaded.
74
+ jets_name: str, optional
75
+ Name of the jet dataset / the global objects
76
+ by default "jets"
67
77
  bins_range : tuple, optional
68
78
  bins_range argument from from puma.HistogramPlot,
69
79
  by default None
@@ -72,11 +82,11 @@ def make_hist(
72
82
  output name, by default "".
73
83
  """
74
84
  # Load the variable from the jets
75
- df = load_jets(in_paths, variable)
85
+ df = load_jets(in_paths, variable, jets_name=jets_name)
76
86
 
77
87
  # Setup the histogram
78
88
  plot = HistogramPlot(
79
- ylabel="Normalised Number of jets",
89
+ ylabel=f"Normalised Number of {jets_name}",
80
90
  atlas_second_tag="$\\sqrt{s}=13$ TeV",
81
91
  xlabel=variable,
82
92
  bins=50,
@@ -94,8 +104,8 @@ def make_hist(
94
104
  plot.add(
95
105
  Histogram(
96
106
  df[df["flavour_label"] == label_value][variable],
97
- label=Flavours[label_string].label,
98
- colour=Flavours[label_string].colour,
107
+ label=flavour_cont[label_string].label,
108
+ colour=flavour_cont[label_string].colour,
99
109
  )
100
110
  )
101
111
 
@@ -118,6 +128,7 @@ def make_hist_initial(
118
128
  flavours: list,
119
129
  variable: str,
120
130
  in_paths_list: str | list,
131
+ jets_name: str = "jets",
121
132
  bins_range: tuple | None = None,
122
133
  suffix: str = "",
123
134
  jets_to_plot: int = -1,
@@ -125,7 +136,7 @@ def make_hist_initial(
125
136
  suffixes: list | None = None,
126
137
  out_format: str = "png",
127
138
  ) -> None:
128
- """Make inistal dist plots.
139
+ """Make initial distribution plots.
129
140
 
130
141
  Plot the initial distribution of the given variable
131
142
  for multiple different samples (like ttbar, zpext, etc.)
@@ -145,6 +156,9 @@ def make_hist_initial(
145
156
  in_paths_list : str | list
146
157
  String or list of strings with the paths to the files
147
158
  from which the jets are loaded.
159
+ jets_name: str, optional
160
+ Name of the jet dataset / the global objects
161
+ by default "jets"
148
162
  bins_range : tuple, optional
149
163
  bins_range argument from from puma.HistogramPlot,
150
164
  by default None
@@ -163,7 +177,7 @@ def make_hist_initial(
163
177
  """
164
178
  # Setup the histogram
165
179
  plot = HistogramPlot(
166
- ylabel="Normalised Number of jets",
180
+ ylabel=f"Normalised Number of {jets_name}",
167
181
  atlas_second_tag="$\\sqrt{s}=13$ TeV",
168
182
  xlabel=variable,
169
183
  bins=100,
@@ -187,7 +201,7 @@ def make_hist_initial(
187
201
  # Loop over the different samples
188
202
  for i, in_paths in enumerate(in_paths_list):
189
203
  # Load jets from the file
190
- reader = H5Reader(in_paths, batch_size=10000)
204
+ reader = H5Reader(in_paths, batch_size=10000, jets_name=jets_name)
191
205
 
192
206
  # Loop over the flavours
193
207
  for flavour in flavours:
@@ -197,12 +211,10 @@ def make_hist_initial(
197
211
  plot.add(
198
212
  Histogram(
199
213
  reader.load(
200
- {"jets": [variable]},
214
+ {jets_name: [variable]},
201
215
  num_jets=jets_to_plot,
202
216
  cuts=flavour.cuts,
203
- )[
204
- "jets"
205
- ][variable],
217
+ )[jets_name][variable],
206
218
  label=flavour.label + " " + suffixes[i],
207
219
  colour=flavour.colour,
208
220
  linestyle=linestiles[i],
@@ -250,6 +262,7 @@ def plot_initial_resampling_dists(config) -> None:
250
262
  flavours=config.components.flavours,
251
263
  variable=var,
252
264
  in_paths_list=paths,
265
+ jets_name=config.jets_name,
253
266
  jets_to_plot=100000,
254
267
  out_dir=config.out_dir / "plots",
255
268
  suffixes=suffixes,
@@ -260,6 +273,7 @@ def plot_initial_resampling_dists(config) -> None:
260
273
  flavours=config.components.flavours,
261
274
  variable=var,
262
275
  in_paths_list=paths,
276
+ jets_name=config.jets_name,
263
277
  bins_range=(0, 500e3),
264
278
  suffix="low",
265
279
  jets_to_plot=100000,
@@ -293,15 +307,19 @@ def plot_resampled_dists(config, stage: str) -> None:
293
307
  make_hist(
294
308
  stage=stage,
295
309
  flavours=config.components.flavours,
310
+ flavour_cont=config.flavour_cont,
296
311
  variable=var,
297
312
  in_paths=paths,
313
+ jets_name=config.jets_name,
298
314
  )
299
315
  if "pt" in var:
300
316
  make_hist(
301
317
  stage=stage,
302
318
  flavours=config.components.flavours,
319
+ flavour_cont=config.flavour_cont,
303
320
  variable=var,
304
321
  in_paths=paths,
322
+ jets_name=config.jets_name,
305
323
  bins_range=(0, 500e3),
306
324
  suffix="low",
307
325
  )
@@ -38,14 +38,14 @@ class Resampling:
38
38
  self.components = config.components
39
39
  self.variables = config.variables
40
40
  self.batch_size = config.batch_size
41
- self.is_test = config.is_test
42
- self.num_jets_estimate = config.num_jets_estimate
41
+ self.jets_name = config.jets_name
43
42
  self.upscale_pdf = config.sampl_cfg.upscale_pdf or 1
44
- self.regionlengthsd = self.get_regionlengthsd_from_config()
43
+ self.num_bins = self.get_num_bins_from_config()
45
44
  self.methods_map = {
46
45
  "pdf": self.pdf_select_func,
47
46
  "countup": self.countup_select_func,
48
47
  "none": None,
48
+ None: None,
49
49
  }
50
50
  if self.config.method not in self.methods_map:
51
51
  raise ValueError(
@@ -97,7 +97,7 @@ class Resampling:
97
97
  num_samples = int(len(jets) * component.sampling_fraction)
98
98
  ratios = safe_divide(self.target.hist.pbin, component.hist.pbin)
99
99
  if self.upscale_pdf > 1:
100
- ratios = upscale_array_regionally(ratios, self.upscale_pdf, self.regionlengthsd)
100
+ ratios = upscale_array_regionally(ratios, self.upscale_pdf, self.num_bins)
101
101
  probs = ratios[binnumbers]
102
102
  idx = random.choices(np.arange(len(jets)), weights=probs, k=num_samples)
103
103
  return idx
@@ -123,7 +123,7 @@ class Resampling:
123
123
 
124
124
  # apply sampling
125
125
  idx = np.arange(len(batch_out[self.variables.jets_name]))
126
- if c != self.target and not self.is_test and self.select_func:
126
+ if c != self.target and self.select_func:
127
127
  idx = self.select_func(batch_out[self.variables.jets_name], c)
128
128
  if len(idx) == 0:
129
129
  continue
@@ -177,6 +177,7 @@ class Resampling:
177
177
  reader = H5Reader(
178
178
  sample.path,
179
179
  self.batch_size,
180
+ jets_name=self.jets_name,
180
181
  equal_jets=equal_jets_flag,
181
182
  transform=self.transform,
182
183
  )
@@ -242,8 +243,8 @@ class Resampling:
242
243
  # setup i/o
243
244
  for c in self.components:
244
245
  # just used for the writer configuration
245
- c.setup_reader(self.batch_size, transform=self.transform)
246
- c.setup_writer(self.variables)
246
+ c.setup_reader(self.batch_size, jets_name=self.jets_name, transform=self.transform)
247
+ c.setup_writer(self.variables, jets_name=self.jets_name)
247
248
 
248
249
  # set samplig fraction if needed
249
250
  self.set_component_sampling_fractions()
@@ -254,7 +255,7 @@ class Resampling:
254
255
  f" {self.config.sampling_fraction}..."
255
256
  )
256
257
  for c in self.components:
257
- frac = c.sampling_fraction if not self.is_test else 1
258
+ frac = c.sampling_fraction if self.select_func else 1
258
259
  c.check_num_jets(c.num_jets, sampling_frac=frac, cuts=c.cuts)
259
260
 
260
261
  # run resampling
@@ -268,7 +269,7 @@ class Resampling:
268
269
  log.info(f"[bold green]Estimated unqiue jets: {unique:,.0f}")
269
270
  log.info(f"[bold green]Saved to {self.components.out_dir}/")
270
271
 
271
- def get_regionlengthsd_from_config(self) -> list[list[int]]:
272
+ def get_num_bins_from_config(self) -> list[list[int]]:
272
273
  """Get the lengths of the binning regions in each variable from the config.
273
274
 
274
275
  Returns
@@ -276,7 +277,7 @@ class Resampling:
276
277
  typing.List[typing.List[int]]
277
278
  lengths of the binning regions in each variable from the config
278
279
  """
279
- regionlengthsd = []
280
+ num_bins = []
280
281
  for row in self.config.bins.values():
281
- regionlengthsd.append([sub[-1] for sub in row])
282
- return regionlengthsd
282
+ num_bins.append([sub[-1] for sub in row])
283
+ return num_bins
@@ -1,16 +0,0 @@
1
- pyyaml-include==1.3
2
- PyYAML==6.0.1
3
- rich==12.6.0
4
- scipy==1.10.1
5
- puma-hep==0.3.0
6
- atlas-ftag-tools==0.1.10
7
- dotmap==1.3.30
8
-
9
- [dev]
10
- black==23.9.1
11
- ruff==0.0.289
12
- mypy==1.5.1
13
- pre-commit==3.1.1
14
- pytest==7.2.2
15
- pytest-mock==3.11.1
16
- pytest-cov==4.0.0