umami-preprocessing 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/PKG-INFO +1 -1
  2. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/umami_preprocessing.egg-info/PKG-INFO +1 -1
  3. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/__init__.py +1 -1
  4. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/main.py +12 -3
  5. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/stages/hist.py +47 -6
  6. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/stages/resampling.py +103 -53
  7. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/README.md +0 -0
  8. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/pyproject.toml +0 -0
  9. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/setup.cfg +0 -0
  10. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/umami_preprocessing.egg-info/SOURCES.txt +0 -0
  11. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
  12. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/umami_preprocessing.egg-info/entry_points.txt +0 -0
  13. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/umami_preprocessing.egg-info/requires.txt +0 -0
  14. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/umami_preprocessing.egg-info/top_level.txt +0 -0
  15. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/classes/__init__.py +0 -0
  16. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/classes/components.py +0 -0
  17. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/classes/preprocessing_config.py +0 -0
  18. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/classes/region.py +0 -0
  19. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/classes/resampling_config.py +0 -0
  20. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/classes/variable_config.py +0 -0
  21. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/logger.py +0 -0
  22. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/stages/__init__.py +0 -0
  23. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/stages/interpolation.py +0 -0
  24. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/stages/merging.py +0 -0
  25. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/stages/normalisation.py +0 -0
  26. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/stages/plot.py +0 -0
  27. {umami_preprocessing-0.2.2 → umami_preprocessing-0.2.3}/upp/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: umami-preprocessing
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Preprocessing for jet tagging
5
5
  License: MIT
6
6
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: umami-preprocessing
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Preprocessing for jet tagging
5
5
  License: MIT
6
6
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
@@ -2,4 +2,4 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.2"
5
+ __version__ = "v0.2.3"
@@ -41,10 +41,16 @@ def parse_args(args):
41
41
  parser.add_argument("--no-plot", dest="plot", action="store_false")
42
42
  splits = ["train", "val", "test", "all"]
43
43
  parser.add_argument("--split", default="train", choices=splits, help="Which file to produce")
44
+ parser.add_argument(
45
+ "--component", default=None, help="Component which is processed during --prep"
46
+ )
47
+ parser.add_argument(
48
+ "--region", default=None, help="Region which is processed during --resample"
49
+ )
44
50
 
45
51
  args = parser.parse_args(args)
46
52
  d = vars(args)
47
- ignore = ["config", "split"]
53
+ ignore = ["config", "split", "component", "region"]
48
54
  if not any(v for a, v in d.items() if a not in ignore):
49
55
  for v in d:
50
56
  if v not in ignore and d[v] is None:
@@ -65,12 +71,15 @@ def run_pp(args) -> None:
65
71
 
66
72
  # create virtual datasets and pdf files
67
73
  if args.prep and args.split == "train":
68
- create_histograms(config)
74
+ create_histograms(
75
+ config=config,
76
+ component_to_run=args.component,
77
+ )
69
78
 
70
79
  # run the resampling
71
80
  if args.resample:
72
81
  resampling = Resampling(config)
73
- resampling.run()
82
+ resampling.run(region=args.region)
74
83
 
75
84
  # run the merging
76
85
  if args.merge:
@@ -5,6 +5,7 @@ import logging as log
5
5
  import math
6
6
  from dataclasses import dataclass
7
7
  from pathlib import Path
8
+ from typing import TYPE_CHECKING
8
9
 
9
10
  import h5py
10
11
  import numpy as np
@@ -13,6 +14,9 @@ from scipy.stats import binned_statistic_dd
13
14
 
14
15
  from upp.logger import setup_logger
15
16
 
17
+ if TYPE_CHECKING: # pragma: no cover
18
+ from upp.classes.preprocessing_config import PreprocessingConfig
19
+
16
20
 
17
21
  def bin_jets(array: dict, bins: list) -> np.ndarray:
18
22
  """Create the histogram and bins for the given resampling variables.
@@ -117,24 +121,39 @@ class Hist:
117
121
  return f["pbin"][:]
118
122
 
119
123
 
120
- def create_histograms(config) -> None:
124
+ def create_histograms(
125
+ config: PreprocessingConfig,
126
+ component_to_run: str | None = None,
127
+ ) -> None:
121
128
  """Create the virtual datasets and pdf files.
122
129
 
123
130
  Parameters
124
131
  ----------
125
132
  config : PreprocessingConfig object
126
133
  PreprocessingConfig object of the current preprocessing.
134
+ component_to_run : str | None
135
+ Component which should be run. By default (None), all components
136
+ are processed sequentially.
127
137
  """
138
+ # Setup the logger and load the variables used for resampling
128
139
  setup_logger()
140
+ sampl_vars = config.sampl_cfg.vars
129
141
 
130
142
  title = " Writing PDFs "
131
143
  log.info(f"[bold green]{title:-^100}")
132
-
133
144
  log.info(f"[bold green]Estimating PDFs using {config.num_jets_estimate_hist:,} jets...")
134
- sampl_vars = config.sampl_cfg.vars
145
+
146
+ # Create check variable to ensure at least one component was processed
147
+ component_processed = not component_to_run
148
+
149
+ # Process the different components
135
150
  for component in config.components:
151
+ # Check if only one component should be processed
152
+ if isinstance(component_to_run, str) and component_to_run != component.name:
153
+ continue
154
+
136
155
  log.info(f"Estimating {component} PDF using {config.num_jets_estimate_hist:,} samples...")
137
- component.setup_reader(config.batch_size, config.jets_name)
156
+ component.setup_reader(batch_size=config.batch_size, jets_name=config.jets_name)
138
157
  cuts_no_split = component.cuts.ignore(["eventNumber"])
139
158
 
140
159
  ###
@@ -146,7 +165,29 @@ def create_histograms(config) -> None:
146
165
  silent=False,
147
166
  raise_error=False,
148
167
  )
149
- jets = component.get_jets(sampl_vars, config.num_jets_estimate_hist, cuts_no_split)
150
- component.hist.write_hist(jets, sampl_vars, config.sampl_cfg.flat_bins)
168
+
169
+ # Load the jets from file used for resampling
170
+ jets = component.get_jets(
171
+ variables=sampl_vars,
172
+ num_jets=config.num_jets_estimate_hist,
173
+ cuts=cuts_no_split,
174
+ )
175
+
176
+ # Write out the hist used for resampling
177
+ component.hist.write_hist(
178
+ jets=jets,
179
+ resampling_vars=sampl_vars,
180
+ bins=config.sampl_cfg.flat_bins,
181
+ )
182
+
183
+ # Set the check variable to true
184
+ component_processed = True
185
+
186
+ # Raise error if no component was processed
187
+ if component_processed is False:
188
+ raise ValueError(
189
+ "No component processed during resampling! Check that you correctly spelled "
190
+ "the component name when running with --component!"
191
+ )
151
192
 
152
193
  log.info(f"[bold green]Saved to {config.components[0].hist.path.parent}/")
@@ -356,9 +356,14 @@ class Resampling:
356
356
  f" Jets are upsampled at most {np.max(component._ups_max):.0f} times"
357
357
  )
358
358
 
359
- def set_component_sampling_fractions(self):
359
+ def set_component_sampling_fractions(self, component: Component):
360
360
  """Automatically set the sampling fraction for each of the components.
361
361
 
362
+ Parameters
363
+ ----------
364
+ component : Component
365
+ Component for which the sampling fraction is set.
366
+
362
367
  Raises
363
368
  ------
364
369
  ValueError
@@ -366,71 +371,86 @@ class Resampling:
366
371
  """
367
372
  # Check that the sampling fraction must be found automatically
368
373
  if self.config.sampling_fraction == "auto" or self.config.sampling_fraction is None:
369
- log.info("[bold green]Sampling fraction chosen for each component automatically...")
374
+ log.info(
375
+ "[bold green]Sampling fraction will be chosen "
376
+ f"automatically for {component.name}..."
377
+ )
370
378
 
371
- # Loop over each component
372
- for component in self.components:
373
- # Target component always gets one as sampling fraction
374
- if component.is_target(self.config.target):
375
- component.sampling_fraction = 1
376
-
377
- else:
378
- sam_frac = component.get_auto_sampling_fraction(
379
- num_jets=component.num_jets,
380
- cuts=component.cuts,
381
- )
379
+ # Target component always gets one as sampling fraction
380
+ if component.is_target(self.config.target):
381
+ component.sampling_fraction = 1
382
+
383
+ else:
384
+ sam_frac = component.get_auto_sampling_fraction(
385
+ num_jets=component.num_jets,
386
+ cuts=component.cuts,
387
+ )
382
388
 
383
- # Raise an error/warning if the sampling fraction is above one
384
- # for the countup/pdf method
385
- if sam_frac > 1:
386
- if self.config.method == "countup":
387
- raise ValueError(
388
- f"[bold red]Sampling fraction of {sam_frac:.3f}>1 is"
389
- f" needed for component {component} This is not supported for"
390
- " countup method."
391
- )
392
- else:
393
- log.warning(
394
- f"[bold yellow]sampling fraction of {sam_frac:.3f}>1 is"
395
- f" needed for component {component}"
396
- )
397
-
398
- # Ensure the sampling fraction is at least above 0.1
399
- component.sampling_fraction = max(sam_frac, 0.1)
389
+ # Raise an error/warning if the sampling fraction is above one
390
+ # for the countup/pdf method
391
+ if sam_frac > 1:
392
+ if self.config.method == "countup":
393
+ raise ValueError(
394
+ f"[bold red]Sampling fraction of {sam_frac:.3f}>1 is"
395
+ f" needed for component {component} This is not supported for"
396
+ " countup method."
397
+ )
398
+ else:
399
+ log.warning(
400
+ f"[bold yellow]sampling fraction of {sam_frac:.3f}>1 is"
401
+ f" needed for component {component}"
402
+ )
403
+
404
+ # Ensure the sampling fraction is at least above 0.1
405
+ component.sampling_fraction = max(sam_frac, 0.1)
400
406
 
401
407
  else:
402
- # Set the sampling fraction for each component to the value defined
408
+ # Set the sampling fraction for the component to the value defined
403
409
  # in the config
404
- for component in self.components:
405
- if component.is_target(self.config.target):
406
- component.sampling_fraction = 1
410
+ if component.is_target(self.config.target):
411
+ component.sampling_fraction = 1
412
+
413
+ else:
414
+ component.sampling_fraction = self.config.sampling_fraction
407
415
 
408
- else:
409
- component.sampling_fraction = self.config.sampling_fraction
416
+ def run(self, region: str | None = None):
417
+ """Execute the resampling.
410
418
 
411
- def run(self):
412
- """Execute the resampling."""
419
+ Parameters
420
+ ----------
421
+ region : str | None, optional
422
+ Define which region is to be resampled, by default None
423
+ which works through the regions sequentially
424
+
425
+ Raises
426
+ ------
427
+ ValueError
428
+ If no region was processed during resampling
429
+ """
413
430
  title = " Running resampling "
414
431
  log.info(f"[bold green]{title:-^100}")
415
432
  log.info(f"Resampling method: {self.config.method}")
416
433
 
417
- # Setup the different components and readers/writers
434
+ # Setup the different components and readers/writers and their sampling fraction
418
435
  for component in self.components:
436
+ # Check if the component is needed for the region
437
+ if region and region not in component.name:
438
+ continue
439
+
419
440
  # just used for the writer configuration
420
441
  component.setup_reader(
421
442
  self.batch_size, jets_name=self.jets_name, transform=self.transform
422
443
  )
423
444
  component.setup_writer(self.variables, jets_name=self.jets_name)
424
445
 
425
- # Set samplig fraction if needed
426
- self.set_component_sampling_fractions()
446
+ # Set sampling fraction
447
+ self.set_component_sampling_fractions(component=component)
427
448
 
428
- # Check samples
429
- log.info(
430
- "[bold green]Checking requested num_jets based on a sampling fraction of"
431
- f" {self.config.sampling_fraction}..."
432
- )
433
- for component in self.components:
449
+ # Check that enough jets are available
450
+ log.info(
451
+ "[bold green]Checking requested num_jets based on a sampling fraction of"
452
+ f" {self.config.sampling_fraction}..."
453
+ )
434
454
  frac = component.sampling_fraction if self.select_func else 1
435
455
  component.check_num_jets(
436
456
  component.num_jets,
@@ -438,16 +458,46 @@ class Resampling:
438
458
  cuts=component.cuts,
439
459
  )
440
460
 
461
+ # Create check variable to ensure at least one region was processed
462
+ region_processed = not region
463
+
441
464
  # Run resampling
442
- for region, components in self.components.groupby_region():
465
+ for iter_region, iter_components in self.components.groupby_region():
466
+ # Check if a specific region for resampling was chosen
467
+ if region and region != iter_region.name:
468
+ continue
469
+
443
470
  log.info(f"[bold green]Running over region {region}...")
444
- self.run_on_region(components, region)
471
+ self.run_on_region(components=iter_components, region=iter_region)
472
+ region_processed = True
473
+
474
+ # Raise error if no region was processed
475
+ if region_processed is False:
476
+ raise ValueError(
477
+ "No region processed during resampling! Check that you correctly spelled "
478
+ "the region name when running with --region!"
479
+ )
445
480
 
446
481
  # Finalise the resampling
447
- unique = sum(component.writer.get_attr("unique_jets") for component in self.components)
448
- log.info(f"[bold green]Finished resampling a total of {self.components.num_jets:,} jets!")
449
- log.info(f"[bold green]Estimated unqiue jets: {unique:,.0f}")
450
- log.info(f"[bold green]Saved to {self.components.out_dir}/")
482
+ if region:
483
+ unique = 0
484
+ for component in self.components:
485
+ if region in component.name:
486
+ unique += component.writer.get_attr("unique_jets")
487
+ log.info(
488
+ f"[bold green]Finished resampling of region {region}. "
489
+ f"A total of {self.components.num_jets:,} jets!"
490
+ )
491
+ log.info(f"[bold green]Estimated unqiue jets: {unique:,.0f}")
492
+ log.info(f"[bold green]Saved to {self.components.out_dir}/")
493
+
494
+ else:
495
+ unique = sum(component.writer.get_attr("unique_jets") for component in self.components)
496
+ log.info(
497
+ f"[bold green]Finished resampling a total of {self.components.num_jets:,} jets!"
498
+ )
499
+ log.info(f"[bold green]Estimated unqiue jets: {unique:,.0f}")
500
+ log.info(f"[bold green]Saved to {self.components.out_dir}/")
451
501
 
452
502
  def get_num_bins_from_config(self) -> list[list[int]]:
453
503
  """Get the lengths of the binning regions in each variable from the config.