atlas-schema 0.2.2__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
  Metadata-Version: 2.4
  Name: atlas-schema
- Version: 0.2.2
+ Version: 0.2.4
  Summary: Helper python package for ATLAS Common NTuple Analysis work.
  Project-URL: Homepage, https://github.com/scipp-atlas/atlas-schema
  Project-URL: Bug Tracker, https://github.com/scipp-atlas/atlas-schema/issues
  Project-URL: Discussions, https://github.com/scipp-atlas/atlas-schema/discussions
- Project-URL: Documentation, https://atlas-schema.readthedocs.io/en/v0.2.2/
+ Project-URL: Documentation, https://atlas-schema.readthedocs.io/en/v0.2.4/
  Project-URL: Releases, https://github.com/scipp-atlas/atlas-schema/releases
  Project-URL: Release Notes, https://atlas-schema.readthedocs.io/en/latest/history.html
  Author-email: Giordon Stark <kratsg@gmail.com>
@@ -251,7 +251,7 @@ Requires-Dist: tbump>=6.7.0; extra == 'test'
  Requires-Dist: twine; extra == 'test'
  Description-Content-Type: text/markdown

- # atlas-schema v0.2.2
+ # atlas-schema v0.2.4

  [![Actions Status][actions-badge]][actions-link]
  [![Documentation Status][rtd-badge]][rtd-link]
@@ -279,6 +279,129 @@ Description-Content-Type: text/markdown

  <!-- prettier-ignore-end -->

+ This is the Python package containing schemas and helper functions enabling
+ analyzers to work with ATLAS datasets (Monte Carlo and Data), using
+ [coffea](https://coffea-hep.readthedocs.io/en/latest/).
+
+ ## Hello World
+
+ The simplest example just gets you started processing a file:
+
+ ```python
+ from atlas_schema.schema import NtupleSchema
+ from coffea import dataset_tools
+ import awkward as ak
+
+ fileset = {"ttbar": {"files": {"path/to/ttbar.root": "tree_name"}}}
+ samples, report = dataset_tools.preprocess(fileset)
+
+
+ def noop(events):
+     return ak.fields(events)
+
+
+ fields = dataset_tools.apply_to_fileset(noop, samples, schemaclass=NtupleSchema)
+ print(fields)
+ ```
+
+ which produces something similar to
+
+ ```python
+ {
+     "ttbar": [
+         "dataTakingYear",
+         "mcChannelNumber",
+         "runNumber",
+         "eventNumber",
+         "lumiBlock",
+         "actualInteractionsPerCrossing",
+         "averageInteractionsPerCrossing",
+         "truthjet",
+         "PileupWeight",
+         "RandomRunNumber",
+         "met",
+         "recojet",
+         "truth",
+         "generatorWeight",
+         "beamSpotWeight",
+         "trigPassed",
+         "jvt",
+     ]
+ }
+ ```
+
+ However, a more involved example that applies a selection and fills a histogram
+ looks like this:
+
+ ```python
+ import awkward as ak
+ import dask
+ import hist.dask as had
+ import matplotlib.pyplot as plt
+ from coffea import processor
+ from coffea.nanoevents import NanoEventsFactory
+ from distributed import Client
+
+ from atlas_schema.schema import NtupleSchema
+
+
+ class MyFirstProcessor(processor.ProcessorABC):
+     def __init__(self):
+         pass
+
+     def process(self, events):
+         dataset = events.metadata["dataset"]
+         h_ph_pt = (
+             had.Hist.new.StrCat(["all", "pass", "fail"], name="isEM")
+             .Regular(200, 0.0, 2000.0, name="pt", label=r"$pt_{\gamma}$ [GeV]")
+             .Int64()
+         )
+
+         cut = ak.all(events.ph.isEM, axis=1)
+         h_ph_pt.fill(isEM="all", pt=ak.firsts(events.ph.pt / 1.0e3))
+         h_ph_pt.fill(isEM="pass", pt=ak.firsts(events[cut].ph.pt / 1.0e3))
+         h_ph_pt.fill(isEM="fail", pt=ak.firsts(events[~cut].ph.pt / 1.0e3))
+
+         return {
+             dataset: {
+                 "entries": ak.num(events, axis=0),
+                 "ph_pt": h_ph_pt,
+             }
+         }
+
+     def postprocess(self, accumulator):
+         pass
+
+
+ if __name__ == "__main__":
+     client = Client()
+
+     fname = "ntuple.root"
+     events = NanoEventsFactory.from_root(
+         {fname: "analysis"},
+         schemaclass=NtupleSchema,
+         metadata={"dataset": "700352.Zqqgamma.mc20d.v1"},
+     ).events()
+
+     p = MyFirstProcessor()
+     out = p.process(events)
+     (computed,) = dask.compute(out)
+     print(computed)
+
+     fig, ax = plt.subplots()
+     computed["700352.Zqqgamma.mc20d.v1"]["ph_pt"].plot1d(ax=ax)
+     ax.set_xscale("log")
+     ax.legend(title="Photon pT for Zqqgamma")
+
+     fig.savefig("ph_pt.pdf")
+ ```
+
+ which produces
+
+ <img src="https://raw.githubusercontent.com/scipp-atlas/atlas-schema/main/docs/_static/img/ph_pt.png" alt="three stacked histograms of photon pT, with each stack corresponding to: no selection, requiring the isEM flag, and inverting the isEM requirement" width="500" style="display: block; margin-left: auto; margin-right: auto;">
+
+ <!-- SPHINX-END -->
+

  ## Developer Notes

  ### Converting Enums from C++ to Python
@@ -0,0 +1,160 @@
+ # atlas-schema v0.2.4
+
+ [![Actions Status][actions-badge]][actions-link]
+ [![Documentation Status][rtd-badge]][rtd-link]
+
+ [![PyPI version][pypi-version]][pypi-link]
+ [![Conda-Forge][conda-badge]][conda-link]
+ [![PyPI platforms][pypi-platforms]][pypi-link]
+
+ [![GitHub Discussion][github-discussions-badge]][github-discussions-link]
+
+ <!-- SPHINX-START -->
+
+ <!-- prettier-ignore-start -->
+ [actions-badge]: https://github.com/scipp-atlas/atlas-schema/workflows/CI/badge.svg
+ [actions-link]: https://github.com/scipp-atlas/atlas-schema/actions
+ [conda-badge]: https://img.shields.io/conda/vn/conda-forge/atlas-schema
+ [conda-link]: https://github.com/conda-forge/atlas-schema-feedstock
+ [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
+ [github-discussions-link]: https://github.com/scipp-atlas/atlas-schema/discussions
+ [pypi-link]: https://pypi.org/project/atlas-schema/
+ [pypi-platforms]: https://img.shields.io/pypi/pyversions/atlas-schema
+ [pypi-version]: https://img.shields.io/pypi/v/atlas-schema
+ [rtd-badge]: https://readthedocs.org/projects/atlas-schema/badge/?version=latest
+ [rtd-link]: https://atlas-schema.readthedocs.io/en/latest/?badge=latest
+
+ <!-- prettier-ignore-end -->
+
+ This is the Python package containing schemas and helper functions enabling
+ analyzers to work with ATLAS datasets (Monte Carlo and Data), using
+ [coffea](https://coffea-hep.readthedocs.io/en/latest/).
+
+ ## Hello World
+
+ The simplest example just gets you started processing a file:
+
+ ```python
+ from atlas_schema.schema import NtupleSchema
+ from coffea import dataset_tools
+ import awkward as ak
+
+ fileset = {"ttbar": {"files": {"path/to/ttbar.root": "tree_name"}}}
+ samples, report = dataset_tools.preprocess(fileset)
+
+
+ def noop(events):
+     return ak.fields(events)
+
+
+ fields = dataset_tools.apply_to_fileset(noop, samples, schemaclass=NtupleSchema)
+ print(fields)
+ ```
+
+ which produces something similar to
+
+ ```python
+ {
+     "ttbar": [
+         "dataTakingYear",
+         "mcChannelNumber",
+         "runNumber",
+         "eventNumber",
+         "lumiBlock",
+         "actualInteractionsPerCrossing",
+         "averageInteractionsPerCrossing",
+         "truthjet",
+         "PileupWeight",
+         "RandomRunNumber",
+         "met",
+         "recojet",
+         "truth",
+         "generatorWeight",
+         "beamSpotWeight",
+         "trigPassed",
+         "jvt",
+     ]
+ }
+ ```
+
+ However, a more involved example that applies a selection and fills a histogram
+ looks like this:
+
+ ```python
+ import awkward as ak
+ import dask
+ import hist.dask as had
+ import matplotlib.pyplot as plt
+ from coffea import processor
+ from coffea.nanoevents import NanoEventsFactory
+ from distributed import Client
+
+ from atlas_schema.schema import NtupleSchema
+
+
+ class MyFirstProcessor(processor.ProcessorABC):
+     def __init__(self):
+         pass
+
+     def process(self, events):
+         dataset = events.metadata["dataset"]
+         h_ph_pt = (
+             had.Hist.new.StrCat(["all", "pass", "fail"], name="isEM")
+             .Regular(200, 0.0, 2000.0, name="pt", label=r"$pt_{\gamma}$ [GeV]")
+             .Int64()
+         )
+
+         cut = ak.all(events.ph.isEM, axis=1)
+         h_ph_pt.fill(isEM="all", pt=ak.firsts(events.ph.pt / 1.0e3))
+         h_ph_pt.fill(isEM="pass", pt=ak.firsts(events[cut].ph.pt / 1.0e3))
+         h_ph_pt.fill(isEM="fail", pt=ak.firsts(events[~cut].ph.pt / 1.0e3))
+
+         return {
+             dataset: {
+                 "entries": ak.num(events, axis=0),
+                 "ph_pt": h_ph_pt,
+             }
+         }
+
+     def postprocess(self, accumulator):
+         pass
+
+
+ if __name__ == "__main__":
+     client = Client()
+
+     fname = "ntuple.root"
+     events = NanoEventsFactory.from_root(
+         {fname: "analysis"},
+         schemaclass=NtupleSchema,
+         metadata={"dataset": "700352.Zqqgamma.mc20d.v1"},
+     ).events()
+
+     p = MyFirstProcessor()
+     out = p.process(events)
+     (computed,) = dask.compute(out)
+     print(computed)
+
+     fig, ax = plt.subplots()
+     computed["700352.Zqqgamma.mc20d.v1"]["ph_pt"].plot1d(ax=ax)
+     ax.set_xscale("log")
+     ax.legend(title="Photon pT for Zqqgamma")
+
+     fig.savefig("ph_pt.pdf")
+ ```
+
+ which produces
+
+ <img src="https://raw.githubusercontent.com/scipp-atlas/atlas-schema/main/docs/_static/img/ph_pt.png" alt="three stacked histograms of photon pT, with each stack corresponding to: no selection, requiring the isEM flag, and inverting the isEM requirement" width="500" style="display: block; margin-left: auto; margin-right: auto;">
+
+ <!-- SPHINX-END -->
+
+ ## Developer Notes
+
+ ### Converting Enums from C++ to Python
+
+ This useful `vim` substitution helps:
+
+ ```
+ %s/ \([A-Za-z]\+\)\s\+= \(\d\+\),\?/ \1: Annotated[int, "\1"] = \2
+ ```
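
To make the substitution concrete, here is a sketch of the transformation it performs. The C++ enum and the resulting Python class below are hypothetical illustrations, not taken from the package:

```python
# Hypothetical C++ input, line by line:
#     Loose  = 0,
#     Medium = 1,
#     Tight  = 2
# The substitution rewrites each `Name = value,` into an annotated
# assignment, which still creates an ordinary IntEnum member:
from enum import IntEnum
from typing import Annotated


class PhotonID(IntEnum):
    Loose: Annotated[int, "Loose"] = 0
    Medium: Annotated[int, "Medium"] = 1
    Tight: Annotated[int, "Tight"] = 2
```

The `Annotated[int, "<name>"]` wrapper keeps the original label attached to each member without changing its integer value.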
@@ -60,7 +60,7 @@ docs = [
  Homepage = "https://github.com/scipp-atlas/atlas-schema"
  "Bug Tracker" = "https://github.com/scipp-atlas/atlas-schema/issues"
  Discussions = "https://github.com/scipp-atlas/atlas-schema/discussions"
- Documentation = "https://atlas-schema.readthedocs.io/en/v0.2.2/"
+ Documentation = "https://atlas-schema.readthedocs.io/en/v0.2.4/"
  Releases = "https://github.com/scipp-atlas/atlas-schema/releases"
  "Release Notes" = "https://atlas-schema.readthedocs.io/en/latest/history.html"

@@ -111,13 +111,18 @@ addopts = [
  ]
  xfail_strict = true
  filterwarnings = [
-   "error",
+   "error",
  ]
+
  log_cli_level = "INFO"
  testpaths = [
+   "src",
    "tests",
    "docs",
  ]
+ norecursedirs = [
+   "tests/helpers"
+ ]

  [tool.coverage]
  run.source = ["atlas_schema"]
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.2'
- __version_tuple__ = version_tuple = (0, 2, 2)
+ __version__ = version = '0.2.4'
+ __version_tuple__ = version_tuple = (0, 2, 4)
@@ -230,12 +230,24 @@ JetArray.MomentumClass = vector.LorentzVectorArray # noqa: F821

  __all__ = [
      "Electron",
+     "ElectronArray",  # noqa: F822
+     "ElectronRecord",  # noqa: F822
      "Jet",
+     "JetArray",  # noqa: F822
+     "JetRecord",  # noqa: F822
      "MissingET",
+     "MissingETArray",  # noqa: F822
+     "MissingETRecord",  # noqa: F822
      "Muon",
+     "MuonArray",  # noqa: F822
+     "MuonRecord",  # noqa: F822
      "NtupleEvents",
      "Particle",
+     "ParticleArray",  # noqa: F822
+     "ParticleRecord",  # noqa: F822
      "Pass",
      "Photon",
+     "PhotonArray",  # noqa: F822
+     "PhotonRecord",  # noqa: F822
      "Weight",
  ]
@@ -0,0 +1,405 @@
+ from __future__ import annotations
+
+ import difflib
+ import warnings
+ from collections.abc import KeysView, ValuesView
+ from typing import Any, ClassVar
+
+ from coffea.nanoevents.schemas.base import BaseSchema, zip_forms
+
+ from atlas_schema.typing_compat import Behavior, Self
+
+
+ class NtupleSchema(BaseSchema):  # type: ignore[misc]
+     """The schema for building ATLAS ntuples following the typical centralized formats.
+
+     This schema is built from all branches found in a tree in the supplied
+     file, based on the naming pattern of the branches. This naming pattern is
+     typically assumed to be
+
+     .. code-block:: bash
+
+         {collection:str}_{subcollection:str}_{systematic:str}
+
+     where:
+     * ``collection`` is assumed to be a prefix with typical characters, following the regex ``[a-zA-Z][a-zA-Z0-9]*``; that is, starting with a letter of either case, followed by zero or more alphanumeric characters,
+     * ``subcollection`` is assumed to be anything with typical characters (allowing for underscores), following the regex ``[a-zA-Z_][a-zA-Z0-9_]*``; that is, starting with a letter of either case or an underscore, followed by zero or more alphanumeric characters or underscores, and
+     * ``systematic`` is assumed to be either ``NOSYS``, indicating a branch with potential systematic variations, or anything with typical characters (allowing for underscores), following the same regular expression as ``subcollection``.
+
+     Here, a collection refers to the top-level entry used to access an item: a collection called ``el`` will be accessible under the ``el`` attribute via ``events['el']`` or ``events.el``. A subcollection called ``pt`` will be accessible under that collection, such as ``events['el']['pt']`` or ``events.el.pt``. This is the power of the schema: providing more user-friendly (and programmatic) access to the underlying branches.
+
+     The above logic means that the branches below will be categorized as follows:
+
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | branch                        | collection        | subcollection         | systematic       |
+     +===============================+===================+=======================+==================+
+     | ``'eventNumber'``             | ``'eventNumber'`` | ``None``              | ``None``         |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'runNumber'``               | ``'runNumber'``   | ``None``              | ``None``         |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'el_pt_NOSYS'``             | ``'el'``          | ``'pt'``              | ``'NOSYS'``      |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'jet_cleanTightBad_NOSYS'`` | ``'jet'``         | ``'cleanTightBad'``   | ``'NOSYS'``      |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'jet_select_btag_NOSYS'``   | ``'jet'``         | ``'select_btag'``     | ``'NOSYS'``      |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'jet_e_NOSYS'``             | ``'jet'``         | ``'e'``               | ``'NOSYS'``      |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'truthel_phi'``             | ``'truthel'``     | ``'phi'``             | ``None``         |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'truthel_pt'``              | ``'truthel'``     | ``'pt'``              | ``None``         |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'ph_eta'``                  | ``'ph'``          | ``'eta'``             | ``None``         |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'ph_phi_SCALE__1up'``       | ``'ph'``          | ``'phi'``             | ``'SCALE__1up'`` |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'mu_TTVA_effSF_NOSYS'``     | ``'mu'``          | ``'TTVA_effSF'``      | ``'NOSYS'``      |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'recojet_antikt4PFlow_pt'`` | ``'recojet'``     | ``'antikt4PFlow_pt'`` | ``'NOSYS'``      |
+     +-------------------------------+-------------------+-----------------------+------------------+
+     | ``'recojet_antikt10UFO_m'``   | ``'recojet'``     | ``'antikt10UFO_m'``   | ``None``         |
+     +-------------------------------+-------------------+-----------------------+------------------+
+
+     Sometimes this logic is not what you want, and there are ways to teach ``NtupleSchema`` how to group some of these better for atypical cases. We can address these case-by-case.
+
+     **Singletons**
+
+     Sometimes you have particular branches that you don't want treated as a collection (with subcollections), and sometimes you will see warnings about this (see :ref:`faq`). There are some pre-defined ``singletons`` stored under :attr:`event_ids`, and these will be lazily treated as *singletons*. For other cases where you add your own branches, you can additionally extend this class to add your own :attr:`singletons`:
+
+     .. code-block:: python
+
+         from atlas_schema.schema import NtupleSchema
+
+
+         class MySchema(NtupleSchema):
+             singletons = {"RandomRunNumber"}
+
+     and use this schema in your analysis code. The rest of the logic will be handled for you, and you can access your singletons under ``events.RandomRunNumber`` as expected.
+
+     **Mixins (collections, subcollections)**
+
+     In more complicated scenarios, you might need to teach :class:`NtupleSchema` how to handle collections that end up having underscores in their name, or other characters that make the grouping non-trivial. In other scenarios, you want to tell the schema to assign a certain set of behaviors to a collection, rather than the default :class:`atlas_schema.methods.Particle` behavior. This is where :attr:`mixins` comes in. Similar to how :attr:`singletons` are handled, you extend this schema to include your own ``mixins``, pointing them at one of the behaviors defined in :mod:`atlas_schema.methods`.
+
+     Let's demonstrate both cases. Imagine you want your ``truthel`` collection above treated as :class:`atlas_schema.methods.Electron`; then you would extend the existing :attr:`mixins`:
+
+     .. code-block:: python
+
+         from atlas_schema.schema import NtupleSchema
+
+
+         class MySchema(NtupleSchema):
+             mixins = {"truthel": "Electron", **NtupleSchema.mixins}
+
+     Now, ``events.truthel`` will give you arrays zipped up with :class:`atlas_schema.methods.Electron` behaviors.
+
+     If instead you run into problems with mixing different branches in the same collection, because the default behavior of this schema described above is not smart enough to handle the atypical cases, you can explicitly fix this by defining your collections:
+
+     .. code-block:: python
+
+         from atlas_schema.schema import NtupleSchema
+
+
+         class MySchema(NtupleSchema):
+             mixins = {
+                 "recojet_antikt4PFlow": "Jet",
+                 "recojet_antikt10UFO": "Jet",
+                 **NtupleSchema.mixins,
+             }
+
+     Now, ``events.recojet_antikt4PFlow`` and ``events.recojet_antikt10UFO`` will be separate collections, instead of a single ``events.recojet`` that incorrectly merges branches from each of these collections.
+     """
+
+     __dask_capable__: ClassVar[bool] = True
+
+     warn_missing_crossrefs: ClassVar[bool] = True
+
+     #: Treat missing event-level branches as an error instead of a warning (default is ``False``)
+     error_missing_event_ids: ClassVar[bool] = False
+     #: Determine the closest behavior for a given branch, or treat the branch as :attr:`default_behavior` (default is ``True``)
+     identify_closest_behavior: ClassVar[bool] = True
+
+     #: event IDs to expect in data datasets
+     event_ids_data: ClassVar[set[str]] = {
+         "lumiBlock",
+         "averageInteractionsPerCrossing",
+         "actualInteractionsPerCrossing",
+         "dataTakingYear",
+     }
+     #: event IDs to expect in MC datasets
+     event_ids_mc: ClassVar[set[str]] = {
+         "mcChannelNumber",
+         "runNumber",
+         "eventNumber",
+         "mcEventWeights",
+     }
+     #: all event IDs to expect in the dataset
+     event_ids: ClassVar[set[str]] = {*event_ids_data, *event_ids_mc}
+
+     #: mixins defining the mapping from collection name to the behavior to use for that collection
+     mixins: ClassVar[dict[str, str]] = {
+         "el": "Electron",
+         "jet": "Jet",
+         "met": "MissingET",
+         "mu": "Muon",
+         "pass": "Pass",
+         "ph": "Photon",
+         "trigPassed": "Trigger",
+         "weight": "Weight",
+     }
+
+     #: additional branches to pass through with no zipping or additional interpretation (such as those stored as length-1 vectors)
+     singletons: ClassVar[set[str]] = set()
+
+     #: docstrings to assign for specific subcollections across the various collections identified by this schema
+     docstrings: ClassVar[dict[str, str]] = {
+         "charge": "charge",
+         "eta": "pseudorapidity",
+         "met": "missing transverse energy [MeV]",
+         "mass": "invariant mass [MeV]",
+         "pt": "transverse momentum [MeV]",
+         "phi": "azimuthal angle",
+     }
+
+     #: default behavior to use for any collection (default ``"NanoCollection"``, from :class:`coffea.nanoevents.methods.base.NanoCollection`)
+     default_behavior: ClassVar[str] = "NanoCollection"
+
+     def __init__(self, base_form: dict[str, Any], version: str = "latest"):
+         super().__init__(base_form)
+         self._version = version
+         if version == "latest":
+             pass
+         else:
+             pass
+         self._form["fields"], self._form["contents"] = self._build_collections(
+             self._form["fields"], self._form["contents"]
+         )
+         self._form["parameters"]["metadata"]["version"] = self._version
+
+     @classmethod
+     def v1(cls, base_form: dict[str, Any]) -> Self:
+         """Build the NtupleEvents
+
+         For example, one can use ``NanoEventsFactory.from_root("file.root", schemaclass=NtupleSchema.v1)``
+         to ensure NanoAODv7 compatibility.
+         """
+         return cls(base_form, version="1")
+
+     def _build_collections(
+         self, field_names: list[str], input_contents: list[Any]
+     ) -> tuple[KeysView[str], ValuesView[dict[str, Any]]]:
+         branch_forms = dict(zip(field_names, input_contents))
+
+         # parse into high-level records (collections, list collections, and singletons)
+         collections = {k.split("_")[0] for k in branch_forms}
+         collections -= self.event_ids
+         collections -= set(self.singletons)
+
+         # now handle any collections that we identified that are substrings of the items in the mixins
+         # convert all valid branch_forms into strings to make the lookups a bit faster
+         bf_str = ",".join(branch_forms.keys())
+         for mixin in self.mixins:
+             if mixin in collections:
+                 continue
+             if f",{mixin}_" not in bf_str and not bf_str.startswith(f"{mixin}_"):
+                 continue
+             if "_" in mixin:
+                 warnings.warn(
+                     f"I identified a mixin that I did not automatically identify as a collection because it contained an underscore: '{mixin}'. I will add this to the known collections. To suppress this warning next time, please create your ntuples with collections without underscores. [mixin-underscore]",
+                     RuntimeWarning,
+                     stacklevel=2,
+                 )
+                 collections.add(mixin)
+                 for collection in list(collections):
+                     if mixin.startswith(f"{collection}_"):
+                         warnings.warn(
+                             f"I found a misidentified collection: '{collection}'. I will remove this from the known collections. To suppress this warning next time, please create your ntuples with collections that are not similarly named with underscores. [collection-subset]",
+                             RuntimeWarning,
+                             stacklevel=2,
+                         )
+                         collections.remove(collection)
+                         break
+
+         # rename needed because easyjet breaks the AMG assumptions
+         # https://gitlab.cern.ch/easyjet/easyjet/-/issues/246
+         for k in list(branch_forms):
+             if "NOSYS" not in k:
+                 continue
+             branch_forms[k.replace("_NOSYS", "") + "_NOSYS"] = branch_forms.pop(k)
+
+         # these are collections with systematic variations
+         subcollections = {
+             k.split("__")[0].split("_", 1)[1].replace("_NOSYS", "")
+             for k in branch_forms
+             if "NOSYS" in k
+         }
+
+         # Check the presence of the event_ids
+         missing_event_ids = [
+             event_id for event_id in self.event_ids if event_id not in branch_forms
+         ]
+
+         if len(missing_event_ids) > 0:
+             if self.error_missing_event_ids:
+                 msg = f"There are missing event ID fields: {missing_event_ids} \n\n\
+ The event ID fields {self.event_ids} are necessary to perform sub-run identification \
+ (e.g. for corrections and sub-dividing data during different detector conditions),\
+ to cross-validate MC and Data (i.e. matching events for comparison), and to generate event displays. \
+ It's advised to never drop these branches from the dataformat.\n\n\
+ This error can be demoted to a warning by setting the class level variable error_missing_event_ids to False."
+                 raise RuntimeError(msg)
+             warnings.warn(
+                 f"Missing event_ids : {missing_event_ids}",
+                 RuntimeWarning,
+                 stacklevel=2,
+             )
+
+         output = {}
+
+         # first, register singletons (event-level, others)
+         for name in {*self.event_ids, *self.singletons}:
+             if name in missing_event_ids:
+                 continue
+             output[name] = branch_forms[name]
+
+         # next, go through and start grouping up collections
+         for name in collections:
+             content = {}
+             used = set()
+
+             for subname in subcollections:
+                 prefix = f"{name}_{subname}_"
+                 used.update({k for k in branch_forms if k.startswith(prefix)})
+                 subcontent = {
+                     k[len(prefix) :]: branch_forms[k]
+                     for k in branch_forms
+                     if k.startswith(prefix)
+                 }
+                 if subcontent:
+                     # create the nominal version
+                     content[subname] = branch_forms[f"{prefix}NOSYS"]
+                     # create a collection of the systematic variations for the given variable
+                     content[f"{subname}_syst"] = zip_forms(
+                         subcontent, f"{name}_syst", record_name="NanoCollection"
+                     )
+
+             content.update(
+                 {
+                     k[len(name) + 1 :]: branch_forms[k]
+                     for k in branch_forms
+                     if k.startswith(name + "_") and k not in used
+                 }
+             )
+
+             if not used and not content:
+                 warnings.warn(
+                     f"I identified a branch that likely does not have any leaves: '{name}'. I will treat this as a 'singleton'. To suppress this warning next time, please define your singletons explicitly. [singleton-undefined]",
+                     RuntimeWarning,
+                     stacklevel=2,
+                 )
+                 self.singletons.add(name)
+                 output[name] = branch_forms[name]
+
+             else:
+                 behavior = self.mixins.get(name, "")
+                 if not behavior:
+                     behavior = self.suggested_behavior(name)
+                     warnings.warn(
+                         f"I found a collection with no defined mixin: '{name}'. I will assume behavior: '{behavior}'. To suppress this warning next time, please define mixins for your custom collections. [mixin-undefined]",
+                         RuntimeWarning,
+                         stacklevel=2,
+                     )
+
+                 output[name] = zip_forms(content, name, record_name=behavior)
+
+             output[name].setdefault("parameters", {})
+             output[name]["parameters"].update({"collection_name": name})
+
+             if output[name]["class"] == "ListOffsetArray":
+                 if output[name]["content"]["class"] == "RecordArray":
+                     parameters = output[name]["content"]["fields"]
+                     contents = output[name]["content"]["contents"]
+                 else:
+                     # these are also singletons of another kind that we just pass through
+                     continue
+             elif output[name]["class"] == "RecordArray":
+                 parameters = output[name]["fields"]
+                 contents = output[name]["contents"]
+             elif output[name]["class"] == "NumpyArray":
+                 # these are singletons that we just pass through
+                 continue
+             else:
+                 msg = f"Unhandled class {output[name]['class']}"
+                 raise RuntimeError(msg)
+
+             # update docstrings as needed
+             # NB: must be before flattening for easier logic
+             for index, parameter in enumerate(parameters):
+                 if "parameters" not in contents[index]:
+                     continue
+
+                 parsed_name = parameter.replace("_NOSYS", "")
+                 contents[index]["parameters"]["__doc__"] = self.docstrings.get(
+                     parsed_name,
+                     contents[index]["parameters"].get(
+                         "__doc__", "no docstring available"
+                     ),
+                 )
+
+         return output.keys(), output.values()
+
+     @classmethod
+     def behavior(cls) -> Behavior:
+         """Behaviors necessary to implement this schema
+
+         Returns:
+             dict[str | tuple['*', str], type[awkward.Record]]: an :data:`awkward.behavior` dictionary
+         """
+         from atlas_schema.methods import behavior as roaster
+
+         return roaster
+
+     @classmethod
+     def suggested_behavior(cls, key: str, cutoff: float = 0.4) -> str:
+         """
+         Suggest a behavior to use for a provided collection or branch name.
+
+         Default behavior: :class:`~coffea.nanoevents.methods.base.NanoCollection`.
+
+         Note:
+             If :attr:`identify_closest_behavior` is ``False``, then this function will return the default behavior ``NanoCollection``.
+
+         Warning:
+             If no behavior is found above the *cutoff* score, then this function will return the default behavior.
+
+         Args:
+             key (str): collection name to suggest a matching behavior for
+             cutoff (float): optional cutoff (default ``0.4``), a float in the range ``[0, 1]``. Possibilities that don't score at least that similar to *key* are ignored.
+
+         Returns:
+             str: suggested behavior to use, by name
+
+         Example:
+             >>> from atlas_schema.schema import NtupleSchema
+             >>> NtupleSchema.suggested_behavior("truthjet")
+             'Jet'
+             >>> NtupleSchema.suggested_behavior("SignalElectron")
+             'Electron'
+             >>> NtupleSchema.suggested_behavior("generatorWeight")
+             'Weight'
+             >>> NtupleSchema.suggested_behavior("aVeryStrangelyNamedBranchWithNoMatch")
+             'NanoCollection'
+         """
+         if cls.identify_closest_behavior:
+             # lowercase everything to do case-insensitive matching
+             behaviors = [b for b in cls.behavior() if isinstance(b, str)]
+             behaviors_l = [b.lower() for b in behaviors]
+             results = difflib.get_close_matches(
+                 key.lower(), behaviors_l, n=1, cutoff=cutoff
+             )
+             if not results:
+                 return cls.default_behavior
+
+             behavior = results[0]
+             # need to identify the index and return the unlowered version
+             return behaviors[behaviors_l.index(behavior)]
+         return cls.default_behavior
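
As a usage note on the `_syst` grouping built in `_build_collections` above: each subcollection with systematic variations gets a nominal field plus a companion `<subcollection>_syst` record. A minimal sketch, assuming an ntuple whose branches include the hypothetical `el_pt_NOSYS` and `el_pt_SCALE__1up` (file and tree names are stand-ins):

```python
from coffea.nanoevents import NanoEventsFactory

from atlas_schema.schema import NtupleSchema

# "ntuple.root" and its "analysis" tree are hypothetical stand-ins.
events = NanoEventsFactory.from_root(
    {"ntuple.root": "analysis"},
    schemaclass=NtupleSchema,
).events()

nominal = events.el.pt  # nominal values, taken from `el_pt_NOSYS`
variations = events.el.pt_syst  # record zipping all `el_pt_*` variations
print(variations.fields)  # e.g. ['NOSYS', 'SCALE__1up']
```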
@@ -5,7 +5,7 @@ Typing helpers.
  from __future__ import annotations

  import sys
- from typing import Annotated
+ from typing import Annotated, Literal, Union

  import awkward

@@ -19,6 +19,6 @@ if sys.version_info >= (3, 11):
  else:
      from typing_extensions import Self

- Behavior: TypeAlias = dict[str, type[awkward.Record]]
+ Behavior: TypeAlias = dict[Union[str, tuple[Literal["*"], str]], type[awkward.Record]]

  __all__ = ("Annotated", "Behavior", "Self")
@@ -0,0 +1,49 @@
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import TypeVar, Union, cast
+
+ import awkward as ak
+ import dask_awkward as dak
+
+ Array = TypeVar("Array", bound=Union[dak.Array, ak.Array])
+ _E = TypeVar("_E", bound=Enum)
+
+
+ def isin(element: Array, test_elements: dak.Array | ak.Array, axis: int = -1) -> Array:
+     """
+     Find test_elements in element. Similar in API to :func:`numpy.isin`.
+
+     Calculates ``element in test_elements``, broadcasting over *element* only. Returns a boolean array of the same shape as *element* that is ``True`` where an element of *element* is in *test_elements* and ``False`` otherwise.
+
+     This works by first transforming *test_elements* to an array with one more
+     dimension than *element*, placing the *test_elements* at *axis*, and then doing a
+     comparison.
+
+     Args:
+         element (dask_awkward.Array or ak.Array): input array of values.
+         test_elements (dask_awkward.Array or ak.Array): one-dimensional set of values against which to test each value of *element*.
+         axis (int): the axis along which the comparison is performed
+
+     Returns:
+         dask_awkward.Array or ak.Array: result of comparison for test_elements in *element*
+
+     Example:
+         >>> import awkward as ak
+         >>> import atlas_schema as ats
+         >>> truth_origins = ak.Array([[1, 2, 3], [4], [5, 6, 7], [1]])
+         >>> prompt_origins = ak.Array([1, 2, 7])
+         >>> ats.isin(truth_origins, prompt_origins).to_list()
+         [[True, True, False], [False], [False, False, True], [True]]
+     """
+     assert test_elements.ndim == 1, "test_elements must be one-dimensional"
+     assert axis >= -1, "axis must be -1 or positive-valued"
+     assert axis < element.ndim + 1, "axis too large for the element"
+
+     # First, build up the transformation, with slice(None) indicating where to stick the test_elements
+     reshaper: list[None | slice] = [None] * element.ndim
+     axis = element.ndim if axis == -1 else axis
+     reshaper.insert(axis, slice(None))
+
+     # Note: reshaper needs to be a tuple for indexing purposes
+     return cast(Array, ak.any(element == test_elements[tuple(reshaper)], axis=-1))
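
To unpack the `reshaper` trick in `isin` above, here is the same indexing written out by hand for the two-dimensional docstring example (a sketch using only public awkward API):

```python
import awkward as ak

element = ak.Array([[1, 2, 3], [4], [5, 6, 7], [1]])
test_elements = ak.Array([1, 2, 7])

# For element.ndim == 2 and axis=-1, reshaper becomes
# (None, None, slice(None)): two new length-1 axes in front,
# leaving test_elements on the innermost axis.
lifted = test_elements[None, None, :]

# element broadcasts against lifted so each value is compared with all
# three test elements; ak.any then collapses the comparison axis.
result = ak.any(element == lifted, axis=-1)
print(result.to_list())
# [[True, True, False], [False], [False, False, True], [True]]
```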
@@ -1,37 +0,0 @@
- # atlas-schema v0.2.2
-
- [![Actions Status][actions-badge]][actions-link]
- [![Documentation Status][rtd-badge]][rtd-link]
-
- [![PyPI version][pypi-version]][pypi-link]
- [![Conda-Forge][conda-badge]][conda-link]
- [![PyPI platforms][pypi-platforms]][pypi-link]
-
- [![GitHub Discussion][github-discussions-badge]][github-discussions-link]
-
- <!-- SPHINX-START -->
-
- <!-- prettier-ignore-start -->
- [actions-badge]: https://github.com/scipp-atlas/atlas-schema/workflows/CI/badge.svg
- [actions-link]: https://github.com/scipp-atlas/atlas-schema/actions
- [conda-badge]: https://img.shields.io/conda/vn/conda-forge/atlas-schema
- [conda-link]: https://github.com/conda-forge/atlas-schema-feedstock
- [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
- [github-discussions-link]: https://github.com/scipp-atlas/atlas-schema/discussions
- [pypi-link]: https://pypi.org/project/atlas-schema/
- [pypi-platforms]: https://img.shields.io/pypi/pyversions/atlas-schema
- [pypi-version]: https://img.shields.io/pypi/v/atlas-schema
- [rtd-badge]: https://readthedocs.org/projects/atlas-schema/badge/?version=latest
- [rtd-link]: https://atlas-schema.readthedocs.io/en/latest/?badge=latest
-
- <!-- prettier-ignore-end -->
-
- ## Developer Notes
-
- ### Converting Enums from C++ to Python
-
- This useful `vim` substitution helps:
-
- ```
- %s/ \([A-Za-z]\+\)\s\+= \(\d\+\),\?/ \1: Annotated[int, "\1"] = \2
- ```
@@ -1,206 +0,0 @@
- from __future__ import annotations
-
- import warnings
- from collections.abc import KeysView, ValuesView
- from typing import Any, ClassVar
-
- from coffea.nanoevents.schemas.base import BaseSchema, zip_forms
-
- from atlas_schema.typing_compat import Behavior, Self
-
-
- class NtupleSchema(BaseSchema):  # type: ignore[misc]
-     """Ntuple schema builder
-
-     The Ntuple schema is built from all branches found in the supplied file, based on
-     the naming pattern of the branches. The following additional arrays are constructed:
-
-     - n/a
-     """
-
-     __dask_capable__ = True
-
-     warn_missing_crossrefs = True
-     error_missing_event_ids = False
-
-     event_ids_data: ClassVar[set[str]] = {
-         "lumiBlock",
-         "averageInteractionsPerCrossing",
-         "actualInteractionsPerCrossing",
-         "dataTakingYear",
-     }
-     event_ids_mc: ClassVar[set[str]] = {
-         "mcChannelNumber",
-         "runNumber",
-         "eventNumber",
-         "mcEventWeights",
-     }
-     event_ids: ClassVar[set[str]] = {*event_ids_data, *event_ids_mc}
-
-     mixins: ClassVar[dict[str, str]] = {
-         "el": "Electron",
-         "jet": "Jet",
-         "met": "MissingET",
-         "mu": "Muon",
-         "pass": "Pass",
-         "ph": "Photon",
-         "trigPassed": "Trigger",
-         "weight": "Weight",
-     }
-
-     # These are stored as length-1 vectors unnecessarily
-     singletons: ClassVar[list[str]] = []
-
-     docstrings: ClassVar[dict[str, str]] = {
-         "charge": "charge",
-         "eta": "pseudorapidity",
-         "met": "missing transverse energy [MeV]",
-         "mass": "invariant mass [MeV]",
-         "pt": "transverse momentum [MeV]",
-         "phi": "azimuthal angle",
-     }
-
-     def __init__(self, base_form: dict[str, Any], version: str = "latest"):
-         super().__init__(base_form)
-         self._version = version
-         if version == "latest":
-             pass
-         else:
-             pass
-         self._form["fields"], self._form["contents"] = self._build_collections(
-             self._form["fields"], self._form["contents"]
-         )
-         self._form["parameters"]["metadata"]["version"] = self._version
-
-     @classmethod
-     def v1(cls, base_form: dict[str, Any]) -> Self:
-         """Build the NtupleEvents
-
-         For example, one can use ``NanoEventsFactory.from_root("file.root", schemaclass=NtupleSchema.v1)``
-         to ensure NanoAODv7 compatibility.
-         """
-         return cls(base_form, version="1")
-
-     def _build_collections(
-         self, field_names: list[str], input_contents: list[Any]
-     ) -> tuple[KeysView[str], ValuesView[dict[str, Any]]]:
-         branch_forms = dict(zip(field_names, input_contents))
-
-         # parse into high-level records (collections, list collections, and singletons)
-         collections = {k.split("_")[0] for k in branch_forms}
-         collections -= self.event_ids
-         collections -= set(self.singletons)
-
-         # rename needed because easyjet breaks the AMG assumptions
-         # https://gitlab.cern.ch/easyjet/easyjet/-/issues/246
-         for k in list(branch_forms):
-             if "NOSYS" not in k:
-                 continue
-             branch_forms[k.replace("_NOSYS", "") + "_NOSYS"] = branch_forms.pop(k)
-
-         # these are collections with systematic variations
-         subcollections = {
-             k.split("__")[0].split("_", 1)[1].replace("_NOSYS", "")
-             for k in branch_forms
-             if "NOSYS" in k
-         }
-
-         # Check the presence of the event_ids
-         missing_event_ids = [
-             event_id for event_id in self.event_ids if event_id not in branch_forms
-         ]
-
-         if len(missing_event_ids) > 0:
-             if self.error_missing_event_ids:
-                 msg = f"There are missing event ID fields: {missing_event_ids} \n\n\
- The event ID fields {self.event_ids} are necessary to perform sub-run identification \
- (e.g. for corrections and sub-dividing data during different detector conditions),\
- to cross-validate MC and Data (i.e. matching events for comparison), and to generate event displays. \
- It's advised to never drop these branches from the dataformat.\n\n\
- This error can be demoted to a warning by setting the class level variable error_missing_event_ids to False."
-                 raise RuntimeError(msg)
-             warnings.warn(
-                 f"Missing event_ids : {missing_event_ids}",
-                 RuntimeWarning,
-                 stacklevel=2,
-             )
-
-         output = {}
-
-         # first, register the event-level stuff directly
-         for name in self.event_ids:
-             if name in missing_event_ids:
-                 continue
-             output[name] = branch_forms[name]
-
-         # next, go through and start grouping up collections
-         for name in collections:
-             mixin = self.mixins.get(name, "NanoCollection")
-             content = {}
-             used = set()
-
-             for subname in subcollections:
-                 prefix = f"{name}_{subname}_"
-                 used.update({k for k in branch_forms if k.startswith(prefix)})
-                 subcontent = {
-                     k[len(prefix) :]: branch_forms[k]
-                     for k in branch_forms
-                     if k.startswith(prefix)
-                 }
-                 if subcontent:
-                     # create the nominal version
-                     content[subname] = branch_forms[f"{prefix}NOSYS"]
-                     # create a collection of the systematic variations for the given variable
-                     content[f"{subname}_syst"] = zip_forms(
-                         subcontent, f"{name}_syst", record_name="NanoCollection"
-                     )
-
-             content.update(
-                 {
-                     k[len(name) + 1 :]: branch_forms[k]
-                     for k in branch_forms
-                     if k.startswith(name + "_") and k not in used
-                 }
-             )
-
-             output[name] = zip_forms(content, name, record_name=mixin)
-
-             output[name].setdefault("parameters", {})
-             output[name]["parameters"].update({"collection_name": name})
-
-             if output[name]["class"] == "ListOffsetArray":
-                 parameters = output[name]["content"]["fields"]
-                 contents = output[name]["content"]["contents"]
-             elif output[name]["class"] == "RecordArray":
-                 parameters = output[name]["fields"]
-                 contents = output[name]["contents"]
-             else:
-                 msg = f"Unhandled class {output[name]['class']}"
-                 raise RuntimeError(msg)
-             # update docstrings as needed
-             # NB: must be before flattening for easier logic
-             for index, parameter in enumerate(parameters):
-                 if "parameters" not in contents[index]:
-                     continue
-
-                 parsed_name = parameter.replace("_NOSYS", "")
-                 contents[index]["parameters"]["__doc__"] = self.docstrings.get(
-                     parsed_name,
-                     contents[index]["parameters"].get(
-                         "__doc__", "no docstring available"
-                     ),
-                 )
-
-             if name in self.singletons:
-                 # flatten! this 'promotes' the content of an inner dimension
-                 # upwards, effectively hiding one nested dimension
-                 output[name] = output[name]["content"]
-
-         return output.keys(), output.values()
-
-     @classmethod
-     def behavior(cls) -> Behavior:
-         """Behaviors necessary to implement this schema"""
-         from atlas_schema.methods import behavior as roaster
-
-         return roaster
@@ -1,39 +0,0 @@
- from __future__ import annotations
-
- from enum import Enum
- from typing import TypeVar, Union, cast
-
- import awkward as ak
- import dask_awkward as dak
-
- Array = TypeVar("Array", bound=Union[dak.Array, ak.Array])
- _E = TypeVar("_E", bound=Enum)
-
-
- def isin(haystack: Array, needles: dak.Array | ak.Array, axis: int = -1) -> Array:
-     """
-     Find needles in haystack.
-
-     This works by first transforming needles to an array with one more
-     dimension than the haystack, placing the needles at axis, and then doing a
-     comparison.
-
-     Args:
-         haystack (dak.Array or ak.Array): haystack of values.
-         needles (dak.Array or ak.Array): one-dimensional set of needles to find in haystack.
-         axis (int): the axis along which the comparison is performed
-
-     Returns:
-         dak.Array or ak.Array: result of comparison for needles in haystack
-     """
-     assert needles.ndim == 1, "Needles must be one-dimensional"
-     assert axis >= -1, "axis must be -1 or positive-valued"
-     assert axis < haystack.ndim + 1, "axis too large for the haystack"
-
-     # First, build up the transformation, with slice(None) indicating where to stick the needles
-     reshaper: list[None | slice] = [None] * haystack.ndim
-     axis = haystack.ndim if axis == -1 else axis
-     reshaper.insert(axis, slice(None))
-
-     # Note: reshaper needs to be a tuple for indexing purposes
-     return cast(Array, ak.any(haystack == needles[tuple(reshaper)], axis=-1))