essreduce 26.1.0-py3-none-any.whl → 26.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ess/reduce/live/__init__.py CHANGED
@@ -1,9 +1,3 @@
  # SPDX-License-Identifier: BSD-3-Clause
  # Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
  """Live data utilities"""
-
- from .workflow import LiveWorkflow
-
- __all__ = [
-     'LiveWorkflow',
- ]
ess/reduce/streaming.py CHANGED
@@ -5,12 +5,15 @@
  from abc import ABC, abstractmethod
  from collections.abc import Callable
  from copy import deepcopy
- from typing import Any, Generic, TypeVar
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

  import networkx as nx
  import sciline
  import scipp as sc

+ if TYPE_CHECKING:
+     import graphviz
+
  T = TypeVar('T')


@@ -108,6 +111,14 @@ class Accumulator(ABC, Generic[T]):
          Clear the accumulator, resetting it to its initial state.
          """

+     def on_finalize(self) -> None:
+         """
+         Called after finalize retrieves value.
+
+         Override this method to perform custom cleanup after each finalize cycle.
+         The default implementation does nothing.
+         """
+

  class EternalAccumulator(Accumulator[T]):
      """
@@ -352,11 +363,15 @@ class StreamProcessor:
          for key in context_keys:
              workflow[key] = None

+         # Store for visualization (copy in case caller modifies base_workflow later)
+         self._base_workflow_for_viz = base_workflow.copy()
+
          # Find and pre-compute static nodes as far down the graph as possible
          nodes = _find_descendants(workflow, dynamic_keys + context_keys)
          last_static = _find_parents(workflow, nodes) - nodes
          for key, value in base_workflow.compute(last_static).items():
              workflow[key] = value
+         self._cached_keys = last_static  # Store for visualization

          # Nodes that may need updating on context change but should be cached otherwise.
          dynamic_nodes = _find_descendants(workflow, dynamic_keys)
@@ -422,6 +437,9 @@ class StreamProcessor:
              needs_recompute |= self._context_key_to_cached_context_nodes_map[key]
          for key, value in context.items():
              self._context_workflow[key] = value
+             # Propagate context values to finalize workflow so providers that depend
+             # on context keys receive the updated values during finalize().
+             self._finalize_workflow[key] = value
          results = self._context_workflow.compute(needs_recompute)
          for key, value in results.items():
              if key in self._target_keys:
@@ -505,7 +523,10 @@ class StreamProcessor:
          """
          for key in self._accumulators:
              self._finalize_workflow[key] = self._accumulators[key].value
-         return self._finalize_workflow.compute(self._target_keys)
+         result = self._finalize_workflow.compute(self._target_keys)
+         for acc in self._accumulators.values():
+             acc.on_finalize()
+         return result

      def clear(self) -> None:
          """
@@ -517,15 +538,173 @@ class StreamProcessor:
          for accumulator in self._accumulators.values():
              accumulator.clear()

+     def _get_viz_workflow(self) -> sciline.Pipeline:
+         """Create the workflow used for visualization."""
+         viz_workflow = sciline.Pipeline()
+         for key in self._target_keys:
+             viz_workflow[key] = self._base_workflow_for_viz[key]
+         viz_graph = viz_workflow.underlying_graph
+         for key in self._dynamic_keys:
+             if key in viz_graph:
+                 viz_workflow[key] = None
+         for key in self._context_keys:
+             if key in viz_graph:
+                 viz_workflow[key] = None
+         return viz_workflow
+
+     def _classify_nodes(self, graph: nx.DiGraph) -> dict[str, set[Any]]:
+         """
+         Classify all nodes in the graph for visualization.
+
+         Node categories:
+
+         - static: Pre-computed once, not dependent on dynamic or context keys
+         - cached_keys: Subset of static nodes that are actually cached (last_static)
+         - dynamic_keys: Input entry points for chunk data
+         - dynamic_nodes: Downstream of dynamic keys, recomputed per chunk
+           (excludes nodes downstream of accumulators, which are computed in finalize)
+         - context_keys: Input entry points for context data
+         - context_dependent: Downstream of context keys but not dynamic keys
+         - accumulator_keys: Where values are aggregated across chunks
+         - target_keys: Final outputs computed in finalize()
+         - finalize_nodes: Downstream of accumulators, computed in finalize()
+         """
+         all_nodes = set(graph.nodes)
+         accumulator_keys = set(self._accumulators.keys())
+         target_keys = set(self._target_keys)
+
+         # Compute derived classifications
+         dynamic_descendants = _find_descendants(graph, self._dynamic_keys)
+         context_descendants = _find_descendants(graph, self._context_keys)
+
+         # Nodes downstream of accumulators are computed in finalize(), not per-chunk
+         accumulator_descendants = _find_descendants(graph, accumulator_keys)
+         finalize_nodes = accumulator_descendants - accumulator_keys
+
+         # Dynamic nodes: downstream of dynamic keys but NOT downstream of accumulators
+         # These are recomputed for each chunk
+         dynamic_nodes = (
+             dynamic_descendants - self._dynamic_keys - accumulator_descendants
+         )
+
+         # Context-dependent nodes: downstream of context but not of dynamic
+         context_dependent = (
+             context_descendants - dynamic_descendants - self._context_keys
+         )
+
+         # Static nodes: not dependent on dynamic or context
+         static_nodes = all_nodes - dynamic_descendants - context_descendants
+
+         return {
+             'static': static_nodes,
+             'cached_keys': self._cached_keys & all_nodes,
+             'dynamic_keys': self._dynamic_keys & all_nodes,
+             'dynamic_nodes': dynamic_nodes,
+             'context_keys': self._context_keys & all_nodes,
+             'context_dependent': context_dependent,
+             'accumulator_keys': accumulator_keys & all_nodes,
+             'target_keys': target_keys & all_nodes,
+             'finalize_nodes': finalize_nodes,
+         }
+
+     def visualize(
+         self,
+         compact: bool = False,
+         mode: Literal['data', 'task', 'both'] = 'data',
+         cluster_generics: bool = True,
+         cluster_color: str | None = '#f0f0ff',
+         show_legend: bool = True,
+         show_static_dependencies: bool = True,
+         **kwargs: Any,
+     ) -> 'graphviz.Digraph':
+         """
+         Visualize the workflow with node classification styling.
+
+         This post-processes sciline's visualization to add styling that highlights:
+
+         - Static nodes (gray): Pre-computed once, dependencies of cached nodes
+         - Static cached nodes (gray, thick border): Pre-computed and cached
+         - Dynamic keys (green, thick border): Input entry points for chunks
+         - Dynamic nodes (light green): Recomputed for each chunk
+         - Context keys (blue, thick border): Input entry points for context
+         - Context-dependent nodes (light blue): Cached until context changes
+         - Accumulator keys (orange cylinder): Aggregation points
+         - Finalize nodes (plum): Computed from accumulators during finalize
+         - Target keys (double border): Final outputs
+
+         Parameters
+         ----------
+         compact:
+             If True, parameter-table-dependent branches are collapsed.
+         mode:
+             'data' shows only data nodes, 'task' shows task nodes, 'both' shows all.
+         cluster_generics:
+             If True, generic products are grouped into clusters.
+         cluster_color:
+             Background color of clusters. If None, clusters are dotted.
+         show_legend:
+             If True, add a legend explaining the node styles.
+         show_static_dependencies:
+             If True (default), show all static nodes including dependencies of cached
+             nodes. If False, hide the ancestors of cached nodes to simplify the graph.
+         **kwargs:
+             Additional arguments passed to graphviz.Digraph.
+
+         Returns
+         -------
+         :
+             A graphviz.Digraph with styled nodes.
+         """
+         viz_workflow = self._get_viz_workflow()
+         if not show_static_dependencies:
+             # Create a pruned workflow that hides ancestors of cached nodes
+             viz_workflow = viz_workflow.copy()
+             for key in self._cached_keys:
+                 if key in viz_workflow.underlying_graph:
+                     viz_workflow[key] = None
+
+         graph = viz_workflow.underlying_graph
+
+         dot = viz_workflow.visualize(
+             compact=compact,
+             mode=mode,
+             cluster_generics=cluster_generics,
+             cluster_color=cluster_color,
+             **kwargs,
+         )
+
+         classifications = self._classify_nodes(graph)
+
+         # Build a mapping from formatted names to original keys
+         key_to_formatted: dict[Any, str] = {}
+         for key in graph.nodes:
+             key_to_formatted[key] = _format_key_for_graphviz(key)
+
+         # Apply styles by re-adding nodes with updated attributes
+         for key, formatted_name in key_to_formatted.items():
+             style = _get_node_style(key, classifications)
+             if style:
+                 dot.node(formatted_name, **style)
+
+         # Add legend
+         if show_legend:
+             _add_legend(dot, show_static_dependencies=show_static_dependencies)
+
+         return dot
+

  def _find_descendants(
-     workflow: sciline.Pipeline, keys: tuple[sciline.typing.Key, ...]
+     source: sciline.Pipeline | nx.DiGraph,
+     keys: set[sciline.typing.Key] | tuple[sciline.typing.Key, ...],
  ) -> set[sciline.typing.Key]:
-     graph = workflow.underlying_graph
+     """Find all descendants of any key in keys, including the keys themselves."""
+     graph = source.underlying_graph if hasattr(source, 'underlying_graph') else source
+     keys_set = set(keys)
      descendants = set()
-     for key in keys:
-         descendants |= nx.descendants(graph, key)
-     return descendants | set(keys)
+     for key in keys_set:
+         if key in graph:
+             descendants |= nx.descendants(graph, key)
+     return descendants | (keys_set & set(graph.nodes))


  def _find_parents(
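A usage sketch for the new method, assuming `processor` is a `StreamProcessor` instance and the optional graphviz dependency is installed:

```python
# Render the classified graph; hiding ancestors of cached static nodes keeps it small.
dot = processor.visualize(mode='data', show_static_dependencies=False)
dot.render('stream_graph', format='svg')  # standard graphviz.Digraph.render
```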
@@ -536,3 +715,194 @@ def _find_parents(
      for key in keys:
          parents |= set(graph.predecessors(key))
      return parents
+
+
+ # =============================================================================
+ # Visualization helpers
+ # =============================================================================
+
+ # Style definitions for each node category
+ # Priority order for overlapping categories (higher = takes precedence for fill)
+ _VIZ_STYLES = {
+     'static': {
+         'fillcolor': '#e8e8e8',  # Gray
+         'style': 'filled',
+         'priority': 0,
+     },
+     'cached_keys': {
+         'fillcolor': '#e8e8e8',  # Gray (same as static)
+         'style': 'filled',
+         'penwidth': '2.5',  # Thick border to distinguish from dependencies
+         'priority': 1,
+     },
+     'context_dependent': {
+         'fillcolor': '#d4e8f4',  # Light blue
+         'style': 'filled',
+         'priority': 2,
+     },
+     'context_keys': {
+         'fillcolor': '#87CEEB',  # Sky blue
+         'style': 'filled',
+         'penwidth': '2.5',
+         'color': 'black',  # Override sciline's red for unsatisfied
+         'fontcolor': 'black',
+         'priority': 3,
+     },
+     'dynamic_nodes': {
+         'fillcolor': '#d4f4d4',  # Light green
+         'style': 'filled',
+         'priority': 4,
+     },
+     'dynamic_keys': {
+         'fillcolor': '#90EE90',  # Light green (stronger)
+         'style': 'filled',
+         'penwidth': '2.5',
+         'color': 'black',  # Override sciline's red for unsatisfied
+         'fontcolor': 'black',
+         'priority': 5,
+     },
+     'accumulator_keys': {
+         'fillcolor': '#FFB347',  # Orange
+         'style': 'filled',
+         'shape': 'cylinder',
+         'priority': 6,
+     },
+     'finalize_nodes': {
+         'fillcolor': '#DDA0DD',  # Plum (more distinct from cluster color)
+         'style': 'filled',
+         'priority': 7,
+     },
+     'target_keys': {
+         'peripheries': '2',  # Double border
+         'priority': 8,
+     },
+ }
+
+
+ def _format_key_for_graphviz(key: Any) -> str:
+     """Format a key to match sciline's node naming convention."""
+     from sciline.visualize import _format_type
+
+     return _format_type(key).name
+
+
+ def _get_node_style(key: Any, classifications: dict[str, set[Any]]) -> dict[str, str]:
+     """
+     Determine the style for a node based on its classifications.
+
+     A node can belong to multiple categories. We combine styles with
+     higher priority categories taking precedence for conflicting attributes.
+     """
+     applicable = []
+     for category, keys in classifications.items():
+         if key in keys:
+             applicable.append((_VIZ_STYLES[category]['priority'], category))
+
+     if not applicable:
+         return {}
+
+     # Sort by priority and merge styles
+     applicable.sort()
+     merged: dict[str, str] = {}
+     for _, category in applicable:
+         style = _VIZ_STYLES[category].copy()
+         style.pop('priority')
+         merged.update(style)
+
+     return merged
+
+
+ def _add_legend(
+     dot: 'graphviz.Digraph', *, show_static_dependencies: bool = True
+ ) -> None:
+     """Add a legend subgraph explaining the node styles."""
+     with dot.subgraph(name='cluster_legend') as legend:
+         legend.attr(label='Legend', fontsize='14', style='rounded')
+         legend.attr('node', shape='rectangle', width='1.5', height='0.3')
+
+         # Track first node for edge chaining
+         prev_node = None
+
+         if show_static_dependencies:
+             legend.node(
+                 'legend_static',
+                 'Static',
+                 fillcolor='#e8e8e8',
+                 style='filled',
+             )
+             prev_node = 'legend_static'
+
+         legend.node(
+             'legend_cached',
+             'Static (cached)',
+             fillcolor='#e8e8e8',
+             style='filled',
+             penwidth='2.5',
+         )
+         if prev_node:
+             legend.edge(prev_node, 'legend_cached', style='invis')
+         prev_node = 'legend_cached'
+
+         legend.node(
+             'legend_context_key',
+             'Context key (input)',
+             fillcolor='#87CEEB',
+             style='filled',
+             penwidth='2.5',
+         )
+         legend.edge(prev_node, 'legend_context_key', style='invis')
+         prev_node = 'legend_context_key'
+
+         legend.node(
+             'legend_context_dep',
+             'Context-dependent',
+             fillcolor='#d4e8f4',
+             style='filled',
+         )
+         legend.edge(prev_node, 'legend_context_dep', style='invis')
+         prev_node = 'legend_context_dep'
+
+         legend.node(
+             'legend_dynamic_key',
+             'Dynamic key (input)',
+             fillcolor='#90EE90',
+             style='filled',
+             penwidth='2.5',
+         )
+         legend.edge(prev_node, 'legend_dynamic_key', style='invis')
+         prev_node = 'legend_dynamic_key'
+
+         legend.node(
+             'legend_dynamic_node',
+             'Dynamic (per chunk)',
+             fillcolor='#d4f4d4',
+             style='filled',
+         )
+         legend.edge(prev_node, 'legend_dynamic_node', style='invis')
+         prev_node = 'legend_dynamic_node'
+
+         legend.node(
+             'legend_accumulator',
+             'Accumulator',
+             fillcolor='#FFB347',
+             style='filled',
+             shape='cylinder',
+         )
+         legend.edge(prev_node, 'legend_accumulator', style='invis')
+         prev_node = 'legend_accumulator'
+
+         legend.node(
+             'legend_finalize',
+             'Finalize (from accum.)',
+             fillcolor='#DDA0DD',
+             style='filled',
+         )
+         legend.edge(prev_node, 'legend_finalize', style='invis')
+         prev_node = 'legend_finalize'
+
+         legend.node(
+             'legend_target',
+             'Target (output)',
+             peripheries='2',
+         )
+         legend.edge(prev_node, 'legend_target', style='invis')
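To illustrate the priority merge in `_get_node_style` (the node name and classification sets below are made up; in practice they come from `_classify_nodes`):

```python
# A node that is both a finalize node (priority 7) and a target key (priority 8)
# keeps the plum fill from 'finalize_nodes' and gains the double border.
classifications = {'finalize_nodes': {'IofQ'}, 'target_keys': {'IofQ'}}
assert _get_node_style('IofQ', classifications) == {
    'fillcolor': '#DDA0DD',
    'style': 'filled',
    'peripheries': '2',
}
```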
ess/reduce/time_of_flight/__init__.py CHANGED
@@ -33,6 +33,8 @@ from .types import (
      TofLookupTable,
      TofLookupTableFilename,
      TofMonitor,
+     WavelengthDetector,
+     WavelengthMonitor,
  )
  from .workflow import GenericTofWorkflow

@@ -60,6 +62,8 @@ __all__ = [
      "TofLookupTableFilename",
      "TofLookupTableWorkflow",
      "TofMonitor",
+     "WavelengthDetector",
+     "WavelengthMonitor",
      "providers",
      "simulate_chopper_cascade_using_tof",
  ]
ess/reduce/time_of_flight/eto_to_tof.py CHANGED
@@ -14,6 +14,7 @@ import scipp as sc
  import scippneutron as scn
  import scippnexus as snx
  from scippneutron._utils import elem_unit
+ from scippneutron.conversion.tof import wavelength_from_tof

  try:
      from .interpolator_numba import Interpolator as InterpolatorImpl
@@ -39,6 +40,8 @@ from .types import (
      TofDetector,
      TofLookupTable,
      TofMonitor,
+     WavelengthDetector,
+     WavelengthMonitor,
  )


@@ -401,14 +404,15 @@ def _compute_tof_data(
  ) -> sc.DataArray:
      if da.bins is None:
          data = _time_of_flight_data_histogram(da=da, lookup=lookup, ltotal=ltotal)
-         return rebin_strictly_increasing(data, dim='tof')
+         out = rebin_strictly_increasing(data, dim='tof')
      else:
-         return _time_of_flight_data_events(
+         out = _time_of_flight_data_events(
              da=da,
              lookup=lookup,
              ltotal=ltotal,
              pulse_stride_offset=pulse_stride_offset,
          )
+     return out.assign_coords(Ltotal=ltotal)


  def detector_time_of_flight_data(
@@ -418,8 +422,9 @@ def detector_time_of_flight_data(
      pulse_stride_offset: PulseStrideOffset,
  ) -> TofDetector[RunType]:
      """
-     Convert the time-of-arrival data to time-of-flight data using a lookup table.
-     The output data will have a time-of-flight coordinate.
+     Convert the time-of-arrival (event_time_offset) data to time-of-flight data using a
+     lookup table.
+     The output data will have two new coordinates: time-of-flight and Ltotal.

      Parameters
      ----------
@@ -452,8 +457,9 @@ def monitor_time_of_flight_data(
      pulse_stride_offset: PulseStrideOffset,
  ) -> TofMonitor[RunType, MonitorType]:
      """
-     Convert the time-of-arrival data to time-of-flight data using a lookup table.
-     The output data will have a time-of-flight coordinate.
+     Convert the time-of-arrival (event_time_offset) data to time-of-flight data using a
+     lookup table.
+     The output data will have two new coordinates: time-of-flight and Ltotal.

      Parameters
      ----------
@@ -526,6 +532,47 @@ def detector_time_of_arrival_data(
      return ToaDetector[RunType](result)


+ def _tof_to_wavelength(da: sc.DataArray) -> sc.DataArray:
+     """
+     Convert time-of-flight data to wavelength data.
+
+     Here we assume that the input data contains a Ltotal coordinate, which is required
+     for the conversion.
+     This coordinate is assigned in the ``_compute_tof_data`` function.
+     """
+     return da.transform_coords(
+         'wavelength', graph={"wavelength": wavelength_from_tof}, keep_intermediate=False
+     )
+
+
+ def detector_wavelength_data(
+     detector_data: TofDetector[RunType],
+ ) -> WavelengthDetector[RunType]:
+     """
+     Convert time-of-flight coordinate of the detector data to wavelength.
+
+     Parameters
+     ----------
+     da:
+         Detector data with time-of-flight coordinate.
+     """
+     return WavelengthDetector[RunType](_tof_to_wavelength(detector_data))
+
+
+ def monitor_wavelength_data(
+     monitor_data: TofMonitor[RunType, MonitorType],
+ ) -> WavelengthMonitor[RunType, MonitorType]:
+     """
+     Convert time-of-flight coordinate of the monitor data to wavelength.
+
+     Parameters
+     ----------
+     da:
+         Monitor data with time-of-flight coordinate.
+     """
+     return WavelengthMonitor[RunType, MonitorType](_tof_to_wavelength(monitor_data))
+
+
  def providers() -> tuple[Callable]:
      """
      Providers of the time-of-flight workflow.
@@ -536,4 +583,6 @@
          detector_ltotal_from_straight_line_approximation,
          monitor_ltotal_from_straight_line_approximation,
          detector_time_of_arrival_data,
+         detector_wavelength_data,
+         monitor_wavelength_data,
      )
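The same conversion on a standalone data array, as a sketch of what `_tof_to_wavelength` does under the hood (the numbers are arbitrary; the `Ltotal` coordinate is the one attached by `_compute_tof_data`):

```python
import scipp as sc
from scippneutron.conversion.tof import wavelength_from_tof

da = sc.DataArray(
    sc.ones(dims=['event'], shape=[3]),
    coords={
        'tof': sc.array(dims=['event'], values=[1.0e4, 2.0e4, 3.0e4], unit='us'),
        'Ltotal': sc.scalar(76.0, unit='m'),
    },
)
# transform_coords matches the kernel's arguments ('tof', 'Ltotal') by name.
wav = da.transform_coords('wavelength', graph={'wavelength': wavelength_from_tof})
```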
ess/reduce/time_of_flight/lut.py CHANGED
@@ -4,7 +4,6 @@
  Utilities for computing time-of-flight lookup tables from neutron simulations.
  """

- import math
  from dataclasses import dataclass
  from typing import NewType

@@ -17,11 +16,10 @@ from .types import TofLookupTable


  @dataclass
- class SimulationResults:
+ class BeamlineComponentReading:
      """
-     Results of a time-of-flight simulation used to create a lookup table.
-
-     The results (apart from ``distance``) should be flat lists (1d arrays) of length N
+     Reading at a given position along the beamline from a time-of-flight simulation.
+     The data (apart from ``distance``) should be flat lists (1d arrays) of length N
      where N is the number of neutrons, containing the properties of the neutrons in the
      simulation.

@@ -30,9 +28,6 @@ class SimulationResults:
      time_of_arrival:
          Time of arrival of the neutrons at the position where the events were recorded
          (1d array of size N).
-     speed:
-         Speed of the neutrons, typically derived from the wavelength of the neutrons
-         (1d array of size N).
      wavelength:
          Wavelength of the neutrons (1d array of size N).
      weight:
@@ -43,15 +38,39 @@
      For a ``tof`` simulation, this is just the position of the detector where the
      events are recorded. For a ``McStas`` simulation, this is the distance between
      the source and the event monitor.
-     choppers:
-         The parameters of the choppers used in the simulation (if any).
      """

      time_of_arrival: sc.Variable
-     speed: sc.Variable
      wavelength: sc.Variable
      weight: sc.Variable
      distance: sc.Variable
+
+     def __post_init__(self):
+         self.speed = (sc.constants.h / sc.constants.m_n) / self.wavelength
+
+
+ @dataclass
+ class SimulationResults:
+     """
+     Results of a time-of-flight simulation used to create a lookup table.
+     It should contain readings at various positions along the beamline, e.g., at
+     the source and after each chopper.
+     It also contains the chopper parameters used in the simulation, so it can be
+     determined if this simulation is compatible with a given experiment.
+
+     Parameters
+     ----------
+     readings:
+         A dict of :class:`BeamlineComponentReading` objects representing the readings at
+         various positions along the beamline. The keys in the dict should correspond to
+         the names of the components (e.g., 'source', 'chopper1', etc.).
+     choppers:
+         The chopper parameters used in the simulation (if any). These are used to verify
+         that the simulation is compatible with a given experiment (comparing chopper
+         openings, frequencies, phases, etc.).
+     """
+
+     readings: dict[str, BeamlineComponentReading]
      choppers: DiskChoppers[AnyRun] | None = None


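Constructing a reading directly, as a sketch: `speed` is derived in `__post_init__` from the de Broglie relation v = h / (m_n λ), i.e. roughly 2.2 km/s for a 1.8 Å neutron. The module path in the import is the file shown in this diff.

```python
import scipp as sc

from ess.reduce.time_of_flight.lut import BeamlineComponentReading

reading = BeamlineComponentReading(
    time_of_arrival=sc.array(dims=['event'], values=[100.0, 200.0], unit='us'),
    wavelength=sc.array(dims=['event'], values=[1.8, 3.6], unit='angstrom'),
    weight=sc.array(dims=['event'], values=[1.0, 1.0]),
    distance=sc.scalar(6.5, unit='m'),
)
# Derived speeds; converting the unit gives ~2198 m/s and ~1099 m/s.
print(reading.speed.to(unit='m/s'))
```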
@@ -145,11 +164,10 @@ def _mask_large_uncertainty(table: sc.DataArray, error_threshold: float):
      table.values[mask.values] = np.nan


- def _compute_mean_tof_in_distance_range(
-     simulation: SimulationResults,
-     distance_bins: sc.Variable,
+ def _compute_mean_tof(
+     simulation: BeamlineComponentReading,
+     distance: sc.Variable,
      time_bins: sc.Variable,
-     distance_unit: str,
      time_unit: str,
      frame_period: sc.Variable,
      time_bins_half_width: sc.Variable,
@@ -162,12 +180,10 @@
      ----------
      simulation:
          Results of a time-of-flight simulation used to create a lookup table.
-     distance_bins:
-         Bin edges for the distance axis in the lookup table.
+     distance:
+         Distance where table is computed.
      time_bins:
          Bin edges for the event_time_offset axis in the lookup table.
-     distance_unit:
-         Unit of the distance axis.
      time_unit:
          Unit of the event_time_offset axis.
      frame_period:
@@ -175,23 +191,17 @@
      time_bins_half_width:
          Half width of the time bins in the event_time_offset axis.
      """
-     simulation_distance = simulation.distance.to(unit=distance_unit)
-     distances = sc.midpoints(distance_bins)
+     travel_length = distance - simulation.distance.to(unit=distance.unit)
      # Compute arrival and flight times for all neutrons
-     toas = simulation.time_of_arrival + (distances / simulation.speed).to(
+     toas = simulation.time_of_arrival + (travel_length / simulation.speed).to(
          unit=time_unit, copy=False
      )
-     dist = distances + simulation_distance
-     tofs = dist * (sc.constants.m_n / sc.constants.h) * simulation.wavelength
+     tofs = distance / simulation.speed

      data = sc.DataArray(
-         data=sc.broadcast(simulation.weight, sizes=toas.sizes),
-         coords={
-             "toa": toas,
-             "tof": tofs.to(unit=time_unit, copy=False),
-             "distance": dist,
-         },
-     ).flatten(to="event")
+         data=simulation.weight,
+         coords={"toa": toas, "tof": tofs.to(unit=time_unit, copy=False)},
+     )

      # Add the event_time_offset coordinate, wrapped to the frame_period
      data.coords['event_time_offset'] = data.coords['toa'] % frame_period
@@ -205,18 +215,14 @@
      # data.coords['event_time_offset'] %= pulse_period - time_bins_half_width
      data.coords['event_time_offset'] %= frame_period - time_bins_half_width

-     binned = data.bin(
-         distance=distance_bins + simulation_distance, event_time_offset=time_bins
-     )
-
+     binned = data.bin(event_time_offset=time_bins)
+     binned_sum = binned.bins.sum()
      # Weighted mean of tof inside each bin
-     mean_tof = (
-         binned.bins.data * binned.bins.coords["tof"]
-     ).bins.sum() / binned.bins.sum()
+     mean_tof = (binned.bins.data * binned.bins.coords["tof"]).bins.sum() / binned_sum
      # Compute the variance of the tofs to track regions with large uncertainty
      variance = (
          binned.bins.data * (binned.bins.coords["tof"] - mean_tof) ** 2
-     ).bins.sum() / binned.bins.sum()
+     ).bins.sum() / binned_sum

      mean_tof.variances = variance.values
      return mean_tof
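The weighted mean and variance computed per `event_time_offset` bin above are the standard ones; on plain NumPy arrays, for a single bin:

```python
import numpy as np

w = np.array([1.0, 2.0, 1.0])          # neutron weights in one bin
tof = np.array([100.0, 110.0, 130.0])  # per-neutron time-of-flight
mean = (w * tof).sum() / w.sum()               # 112.5
var = (w * (tof - mean) ** 2).sum() / w.sum()  # 118.75, stored as .variances
```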
@@ -239,7 +245,7 @@
      ----------
      simulation:
          Results of a time-of-flight simulation used to create a lookup table.
-         The results should be a flat table with columns for time-of-arrival, speed,
+         The results should be a flat table with columns for time-of-arrival,
          wavelength, and weight.
      ltotal_range:
          Range of total flight path lengths from the source to the detector.
@@ -294,15 +300,14 @@
      left edge.
      """
      distance_unit = "m"
-     time_unit = simulation.time_of_arrival.unit
+     time_unit = "us"
      res = distance_resolution.to(unit=distance_unit)
      pulse_period = pulse_period.to(unit=time_unit)
      frame_period = pulse_period * pulse_stride

-     min_dist, max_dist = (
-         x.to(unit=distance_unit) - simulation.distance.to(unit=distance_unit)
-         for x in ltotal_range
-     )
+     min_dist = ltotal_range[0].to(unit=distance_unit)
+     max_dist = ltotal_range[1].to(unit=distance_unit)
+
      # We need to bin the data below, to compute the weighted mean of the wavelength.
      # This results in data with bin edges.
      # However, the 2d interpolator expects bin centers.
@@ -314,7 +319,7 @@
      # ensure that the last bin is not cut off. We want the upper edge to be higher than
      # the maximum distance, hence we pad with an additional 1.5 x resolution.
      pad = 2.0 * res
-     distance_bins = sc.arange('distance', min_dist - pad, max_dist + pad, res)
+     distances = sc.arange('distance', min_dist - pad, max_dist + pad, res)

      # Create some time bins for event_time_offset.
      # We want our final table to strictly cover the range [0, frame_period].
@@ -329,22 +334,36 @@
      time_bins_half_width = 0.5 * (time_bins[1] - time_bins[0])
      time_bins -= time_bins_half_width

-     # To avoid a too large RAM usage, we compute the table in chunks, and piece them
-     # together at the end.
-     ndist = len(distance_bins) - 1
-     max_size = 2e7
-     total_size = ndist * len(simulation.time_of_arrival)
-     nchunks = math.ceil(total_size / max_size)
-     chunk_size = math.ceil(ndist / nchunks)
+     # Sort simulation readings by reverse distance
+     sorted_simulation_results = sorted(
+         simulation.readings.values(), key=lambda x: x.distance.value, reverse=True
+     )
+
      pieces = []
-     for i in range(nchunks):
-         dist_edges = distance_bins[i * chunk_size : (i + 1) * chunk_size + 1]
+     # To avoid large RAM usage, and having to split the distances into chunks
+     # according to which component reading to use, we simply loop over distances one
+     # by one here.
+     for dist in distances:
+         # Find the correct simulation reading
+         simulation_reading = None
+         for reading in sorted_simulation_results:
+             if dist.value >= reading.distance.to(unit=dist.unit).value:
+                 simulation_reading = reading
+                 break
+         if simulation_reading is None:
+             closest = sorted_simulation_results[-1]
+             raise ValueError(
+                 "Building the Tof lookup table failed: the requested position "
+                 f"{dist.value} {dist.unit} is before the component with the lowest "
+                 "distance in the simulation. The first component in the beamline "
+                 f"has distance {closest.distance.value} {closest.distance.unit}."
+             )
+
          pieces.append(
-             _compute_mean_tof_in_distance_range(
-                 simulation=simulation,
-                 distance_bins=dist_edges,
+             _compute_mean_tof(
+                 simulation=simulation_reading,
+                 distance=dist,
                  time_bins=time_bins,
-                 distance_unit=distance_unit,
                  time_unit=time_unit,
                  frame_period=frame_period,
                  time_bins_half_width=time_bins_half_width,
@@ -352,7 +371,6 @@
          )

      table = sc.concat(pieces, 'distance')
-     table.coords["distance"] = sc.midpoints(table.coords["distance"])
      table.coords["event_time_offset"] = sc.midpoints(table.coords["event_time_offset"])

      # Copy the left edge to the right to create periodic boundary conditions
@@ -361,7 +379,7 @@
              [table.data, table.data['event_time_offset', 0]], dim='event_time_offset'
          ),
          coords={
-             "distance": table.coords["distance"],
+             "distance": distances,
              "event_time_offset": sc.concat(
                  [table.coords["event_time_offset"], frame_period],
                  dim='event_time_offset',
@@ -388,6 +406,24 @@
      )


+ def _to_component_reading(component):
+     events = component.data.squeeze().flatten(to='event')
+     sel = sc.full(value=True, sizes=events.sizes)
+     for key in {'blocked_by_others', 'blocked_by_me'} & set(events.masks.keys()):
+         sel &= ~events.masks[key]
+     events = events[sel]
+     # If the component is a source, use 'birth_time' as 'toa'
+     toa = (
+         events.coords["toa"] if "toa" in events.coords else events.coords["birth_time"]
+     )
+     return BeamlineComponentReading(
+         time_of_arrival=toa,
+         wavelength=events.coords["wavelength"],
+         weight=events.data,
+         distance=component.distance,
+     )
+
+
  def simulate_chopper_cascade_using_tof(
      choppers: DiskChoppers[AnyRun],
      source_position: SourcePosition,
@@ -432,31 +468,14 @@ def simulate_chopper_cascade_using_tof(
      source = tof.Source(
          facility=facility, neutrons=neutrons, pulses=pulse_stride, seed=seed
      )
+     sim_readings = {"source": _to_component_reading(source)}
      if not tof_choppers:
-         events = source.data.squeeze().flatten(to='event')
-         return SimulationResults(
-             time_of_arrival=events.coords["birth_time"],
-             speed=events.coords["speed"],
-             wavelength=events.coords["wavelength"],
-             weight=events.data,
-             distance=0.0 * sc.units.m,
-         )
+         return SimulationResults(readings=sim_readings, choppers=None)
      model = tof.Model(source=source, choppers=tof_choppers)
      results = model.run()
-     # Find name of the furthest chopper in tof_choppers
-     furthest_chopper = max(tof_choppers, key=lambda c: c.distance)
-     events = results[furthest_chopper.name].data.squeeze().flatten(to='event')
-     events = events[
-         ~(events.masks["blocked_by_others"] | events.masks["blocked_by_me"])
-     ]
-     return SimulationResults(
-         time_of_arrival=events.coords["toa"],
-         speed=events.coords["speed"],
-         wavelength=events.coords["wavelength"],
-         weight=events.data,
-         distance=furthest_chopper.distance,
-         choppers=choppers,
-     )
+     for name, ch in results.choppers.items():
+         sim_readings[name] = _to_component_reading(ch)
+     return SimulationResults(readings=sim_readings, choppers=choppers)


  def TofLookupTableWorkflow():
ess/reduce/time_of_flight/types.py CHANGED
@@ -87,3 +87,11 @@ class ToaDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):

  class TofMonitor(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray):
      """Monitor data with time-of-flight coordinate."""
+
+
+ class WavelengthDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):
+     """Detector data with wavelength coordinate."""
+
+
+ class WavelengthMonitor(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray):
+     """Monitor data with wavelength coordinate."""
ess/reduce/time_of_flight/workflow.py CHANGED
@@ -7,11 +7,7 @@ import scipp as sc

  from ..nexus import GenericNeXusWorkflow
  from . import eto_to_tof
- from .types import (
-     PulseStrideOffset,
-     TofLookupTable,
-     TofLookupTableFilename,
- )
+ from .types import PulseStrideOffset, TofLookupTable, TofLookupTableFilename


  def load_tof_lookup_table(
essreduce-26.1.0.dist-info/METADATA → essreduce-26.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: essreduce
- Version: 26.1.0
+ Version: 26.2.0
  Summary: Common data reduction tools for the ESS facility
  Author: Scipp contributors
  License-Expression: BSD-3-Clause
@@ -25,6 +25,7 @@ Requires-Dist: scipp>=25.04.0
  Requires-Dist: scippneutron>=25.11.1
  Requires-Dist: scippnexus>=25.06.0
  Provides-Extra: test
+ Requires-Dist: graphviz>=0.20; extra == "test"
  Requires-Dist: ipywidgets>=8.1; extra == "test"
  Requires-Dist: matplotlib>=3.10.7; extra == "test"
  Requires-Dist: numba>=0.59; extra == "test"
essreduce-26.1.0.dist-info/RECORD → essreduce-26.2.0.dist-info/RECORD CHANGED
@@ -3,16 +3,15 @@ ess/reduce/logging.py,sha256=6n8Czq4LZ3OK9ENlKsWSI1M3KvKv6_HSoUiV4__IUlU,357
  ess/reduce/normalization.py,sha256=r8H6SZgT94a1HE9qZ6Bx3N6c3VG3FzlJPzoCVMNI5-0,13081
  ess/reduce/parameter.py,sha256=4sCfoKOI2HuO_Q7JLH_jAXnEOFANSn5P3NdaOBzhJxc,4635
  ess/reduce/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- ess/reduce/streaming.py,sha256=zbqxQz5dASDq4ZVyx-TdbapBXMyBttImCYz_6WOj4pg,17978
+ ess/reduce/streaming.py,sha256=zdVosXbrtys5Z9gJ1dhnqoLeg2t8abicHBluytmrHCs,31150
  ess/reduce/ui.py,sha256=zmorAbDwX1cU3ygDT--OP58o0qU7OBcmJz03jPeYSLA,10884
  ess/reduce/uncertainty.py,sha256=LR4O6ApB6Z-W9gC_XW0ajupl8yFG-du0eee1AX_R-gk,6990
  ess/reduce/workflow.py,sha256=738-lcdgsORYfQ4A0UTk2IgnbVxC3jBdpscpaOFIpdc,3114
  ess/reduce/data/__init__.py,sha256=uDtqkmKA_Zwtj6II25zntz9T812XhdCn3tktYev4uyY,486
  ess/reduce/data/_registry.py,sha256=ngJMzP-AuMN0EKtws5vYSEPsv_Bn3TZjjIvNUKWQDeA,13992
- ess/reduce/live/__init__.py,sha256=jPQVhihRVNtEDrE20PoKkclKV2aBF1lS7cCHootgFgI,204
+ ess/reduce/live/__init__.py,sha256=AFcqRbIhyqAo-oKV7gNIOfHRxfoVDiULf9vB4cEckHk,133
  ess/reduce/live/raw.py,sha256=z3JzKl1tOH51z1PWT3MJERactSFRXrNI_MBmpAHX71g,31094
  ess/reduce/live/roi.py,sha256=t65SfGtCtb8r-f4hkfg2I02CEnOp6Hh5Tv9qOqPOeK0,10588
- ess/reduce/live/workflow.py,sha256=bsbwvTqPhRO6mC__3b7MgU7DWwAnOvGvG-t2n22EKq8,4285
  ess/reduce/nexus/__init__.py,sha256=xXc982vZqRba4jR4z5hA2iim17Z7niw4KlS1aLFbn1Q,1107
  ess/reduce/nexus/_nexus_loader.py,sha256=5J26y_t-kabj0ik0jf3OLSYda3lDLDQhvPd2_ro7Q_0,23927
  ess/reduce/nexus/json_generator.py,sha256=ME2Xn8L7Oi3uHJk9ZZdCRQTRX-OV_wh9-DJn07Alplk,2529
@@ -20,15 +19,15 @@ ess/reduce/nexus/json_nexus.py,sha256=QrVc0p424nZ5dHX9gebAJppTw6lGZq9404P_OFl1gi
  ess/reduce/nexus/types.py,sha256=g5oBBEYPH7urF1tDP0tqXtixhQN8JDpe8vmiKrPiUW0,9320
  ess/reduce/nexus/workflow.py,sha256=bVRnVZ6HTEdIFwZv61JuvFUeTt9efUwe1MR65gBhyw8,24995
  ess/reduce/scripts/grow_nexus.py,sha256=hET3h06M0xlJd62E3palNLFvJMyNax2kK4XyJcOhl-I,3387
- ess/reduce/time_of_flight/__init__.py,sha256=bNxhK0uePltQpCW2sdNTpPdzXL6dt1IT1ri_cJ5VTL8,1561
- ess/reduce/time_of_flight/eto_to_tof.py,sha256=imKuN7IARMqBjmi8kjAcsseTFg6OD8ORas9X1FolgFY,18777
+ ess/reduce/time_of_flight/__init__.py,sha256=LENpuVY2nP-YKVqUEPlmnvWdVaWBEKPinPITk-2crrA,1659
+ ess/reduce/time_of_flight/eto_to_tof.py,sha256=aSQYzYW5rHhObGU9tu5UuZ-M0xfjnI0NdfJTTQa_-ww,20276
  ess/reduce/time_of_flight/fakes.py,sha256=BqpO56PQyO9ua7QlZw6xXMAPBrqjKZEM_jc-VB83CyE,4289
  ess/reduce/time_of_flight/interpolator_numba.py,sha256=wh2YS3j2rOu30v1Ok3xNHcwS7t8eEtZyZvbfXOCtgrQ,3835
  ess/reduce/time_of_flight/interpolator_scipy.py,sha256=_InoAPuMm2qhJKZQBAHOGRFqtvvuQ8TStoN7j_YgS4M,1853
- ess/reduce/time_of_flight/lut.py,sha256=8AupwtfB2983DoyzFgnRjWP3J3s_oPghi3XlXaaxxow,18768
+ ess/reduce/time_of_flight/lut.py,sha256=JKZCbDYXOME6c5nMX01VvsAPU0JQgWTy64s2jZOI-7s,19746
  ess/reduce/time_of_flight/resample.py,sha256=Opmi-JA4zNH725l9VB99U4O9UlM37f5ACTCGtwBcows,3718
- ess/reduce/time_of_flight/types.py,sha256=FsSueM6OjJdF80uJHj-TNuyVAci8ixFvMuRMt9oHKDQ,3310
- ess/reduce/time_of_flight/workflow.py,sha256=2jUxeSmP0KweQTctAzIFJLm7Odf_e7kZzAc8MAMKBEs,3084
+ ess/reduce/time_of_flight/types.py,sha256=Eb0l8_KlL7swORefRqbkMSCq6oxNeujU28-gSNXX9JU,3575
+ ess/reduce/time_of_flight/workflow.py,sha256=adcEAwUrxx2I6IP-f0RNeUqDlMbE2i3s7evZBmmicpI,3067
  ess/reduce/widgets/__init__.py,sha256=SoSHBv8Dc3QXV9HUvPhjSYWMwKTGYZLpsWwsShIO97Q,5325
  ess/reduce/widgets/_base.py,sha256=_wN3FOlXgx_u0c-A_3yyoIH-SdUvDENGgquh9S-h5GI,4852
  ess/reduce/widgets/_binedges_widget.py,sha256=ZCQsGjYHnJr9GFUn7NjoZc1CdsnAzm_fMzyF-fTKKVY,2785
@@ -41,9 +40,9 @@ ess/reduce/widgets/_spinner.py,sha256=2VY4Fhfa7HMXox2O7UbofcdKsYG-AJGrsgGJB85nDX
  ess/reduce/widgets/_string_widget.py,sha256=iPAdfANyXHf-nkfhgkyH6gQDklia0LebLTmwi3m-iYQ,1482
  ess/reduce/widgets/_switchable_widget.py,sha256=fjKz99SKLhIF1BLgGVBSKKn3Lu_jYBwDYGeAjbJY3Q8,2390
  ess/reduce/widgets/_vector_widget.py,sha256=aTaBqCFHZQhrIoX6-sSqFWCPePEW8HQt5kUio8jP1t8,1203
- essreduce-26.1.0.dist-info/licenses/LICENSE,sha256=nVEiume4Qj6jMYfSRjHTM2jtJ4FGu0g-5Sdh7osfEYw,1553
- essreduce-26.1.0.dist-info/METADATA,sha256=WTZ9G8OVIHDjHsgeW9zgv-XJBv6mzYc5HAlxBY-5uXQ,1987
- essreduce-26.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- essreduce-26.1.0.dist-info/entry_points.txt,sha256=PMZOIYzCifHMTe4pK3HbhxUwxjFaZizYlLD0td4Isb0,66
- essreduce-26.1.0.dist-info/top_level.txt,sha256=0JxTCgMKPLKtp14wb1-RKisQPQWX7i96innZNvHBr-s,4
- essreduce-26.1.0.dist-info/RECORD,,
+ essreduce-26.2.0.dist-info/licenses/LICENSE,sha256=nVEiume4Qj6jMYfSRjHTM2jtJ4FGu0g-5Sdh7osfEYw,1553
+ essreduce-26.2.0.dist-info/METADATA,sha256=BGHgAZ8pZ7iSElnZZ0eAFlUZQhxk1xjg8ivigzMu6h8,2034
+ essreduce-26.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ essreduce-26.2.0.dist-info/entry_points.txt,sha256=PMZOIYzCifHMTe4pK3HbhxUwxjFaZizYlLD0td4Isb0,66
+ essreduce-26.2.0.dist-info/top_level.txt,sha256=0JxTCgMKPLKtp14wb1-RKisQPQWX7i96innZNvHBr-s,4
+ essreduce-26.2.0.dist-info/RECORD,,
essreduce-26.1.0.dist-info/WHEEL → essreduce-26.2.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

ess/reduce/live/workflow.py DELETED
@@ -1,128 +0,0 @@
- # SPDX-License-Identifier: BSD-3-Clause
- # Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
- """Tools for creating live data reduction workflows for Beamlime."""
-
- from pathlib import Path
- from typing import NewType, TypeVar
-
- import sciline
- import scipp as sc
- import scippnexus as snx
-
- from ess.reduce import streaming
- from ess.reduce.nexus import types as nt
- from ess.reduce.nexus.json_nexus import JSONGroup
-
- JSONEventData = NewType('JSONEventData', dict[str, JSONGroup])
-
-
- def _load_json_event_data(name: str, nxevent_data: JSONEventData) -> sc.DataArray:
-     return snx.Group(nxevent_data[name], definitions=snx.base_definitions())[()]
-
-
- def load_json_event_data_for_sample_run(
-     name: nt.NeXusName[nt.Component], nxevent_data: JSONEventData
- ) -> nt.NeXusData[nt.Component, nt.SampleRun]:
-     return nt.NeXusData[nt.Component, nt.SampleRun](
-         _load_json_event_data(name, nxevent_data)
-     )
-
-
- def load_json_event_data_for_sample_transmission_run(
-     name: nt.NeXusName[nt.Component], nxevent_data: JSONEventData
- ) -> nt.NeXusData[nt.Component, nt.TransmissionRun[nt.SampleRun]]:
-     return nt.NeXusData[nt.Component, nt.TransmissionRun[nt.SampleRun]](
-         _load_json_event_data(name, nxevent_data)
-     )
-
-
- T = TypeVar('T', bound='LiveWorkflow')
-
-
- class LiveWorkflow:
-     """A workflow class that fulfills Beamlime's LiveWorkflow protocol."""
-
-     def __init__(
-         self,
-         *,
-         streamed: streaming.StreamProcessor,
-         outputs: dict[str, sciline.typing.Key],
-     ) -> None:
-         self._streamed = streamed
-         self._outputs = outputs
-
-     @classmethod
-     def from_workflow(
-         cls: type[T],
-         *,
-         workflow: sciline.Pipeline,
-         accumulators: dict[sciline.typing.Key, streaming.Accumulator],
-         outputs: dict[str, sciline.typing.Key],
-         run_type: type[nt.RunType],
-         nexus_filename: Path,
-     ) -> T:
-         """
-         Create a live workflow from a base workflow and other parameters.
-
-         Parameters
-         ----------
-         workflow:
-             Base workflow to use for live data reduction.
-         accumulators:
-             Accumulators forwarded to the stream processor.
-         outputs:
-             Mapping from output names to keys in the workflow. The keys correspond to
-             workflow results that will be computed.
-         run_type:
-             Type of the run to process. This defines which run is the dynamic run being
-             processed. The NeXus template file will be set as the filename for this run.
-         nexus_filename:
-             Path to the NeXus file to process.
-
-         Returns
-         -------
-         :
-             Live workflow object.
-         """
-
-         workflow = workflow.copy()
-         if run_type is nt.SampleRun:
-             workflow.insert(load_json_event_data_for_sample_run)
-         elif run_type is nt.TransmissionRun[nt.SampleRun]:
-             workflow.insert(load_json_event_data_for_sample_transmission_run)
-         else:
-             raise NotImplementedError(f"Run type {run_type} not supported yet.")
-         workflow[nt.Filename[run_type]] = nexus_filename
-         streamed = streaming.StreamProcessor(
-             base_workflow=workflow,
-             dynamic_keys=(JSONEventData,),
-             target_keys=outputs.values(),
-             accumulators=accumulators,
-         )
-         return cls(streamed=streamed, outputs=outputs)
-
-     def __call__(
-         self, nxevent_data: dict[str, JSONGroup], nxlog: dict[str, JSONGroup]
-     ) -> dict[str, sc.DataArray]:
-         """
-         Implements the __call__ method required by the LiveWorkflow protocol.
-
-         Parameters
-         ----------
-         nxevent_data:
-             NeXus event data.
-         nxlog:
-             NeXus log data. WARNING: This is currently not used.
-
-         Returns
-         -------
-         :
-             Dictionary of computed and plottable results.
-         """
-         # Beamlime passes full path, but the workflow only needs the name of the monitor
-         # or detector group.
-         nxevent_data = {
-             key.lstrip('/').split('/')[2]: value for key, value in nxevent_data.items()
-         }
-         results = self._streamed.add_chunk({JSONEventData: nxevent_data})
-         return {name: results[key] for name, key in self._outputs.items()}