essreduce 26.1.1__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,3 @@
1
1
  # SPDX-License-Identifier: BSD-3-Clause
2
2
  # Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
3
3
  """Live data utilities"""
4
-
5
- from .workflow import LiveWorkflow
6
-
7
- __all__ = [
8
- 'LiveWorkflow',
9
- ]
ess/reduce/streaming.py CHANGED
@@ -5,12 +5,15 @@
5
5
  from abc import ABC, abstractmethod
6
6
  from collections.abc import Callable
7
7
  from copy import deepcopy
8
- from typing import Any, Generic, TypeVar
8
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
9
9
 
10
10
  import networkx as nx
11
11
  import sciline
12
12
  import scipp as sc
13
13
 
14
+ if TYPE_CHECKING:
15
+ import graphviz
16
+
14
17
  T = TypeVar('T')
15
18
 
16
19
 
@@ -360,11 +363,15 @@ class StreamProcessor:
360
363
  for key in context_keys:
361
364
  workflow[key] = None
362
365
 
366
+ # Store for visualization (copy in case caller modifies base_workflow later)
367
+ self._base_workflow_for_viz = base_workflow.copy()
368
+
363
369
  # Find and pre-compute static nodes as far down the graph as possible
364
370
  nodes = _find_descendants(workflow, dynamic_keys + context_keys)
365
371
  last_static = _find_parents(workflow, nodes) - nodes
366
372
  for key, value in base_workflow.compute(last_static).items():
367
373
  workflow[key] = value
374
+ self._cached_keys = last_static # Store for visualization
368
375
 
369
376
  # Nodes that may need updating on context change but should be cached otherwise.
370
377
  dynamic_nodes = _find_descendants(workflow, dynamic_keys)
@@ -531,15 +538,173 @@ class StreamProcessor:
531
538
  for accumulator in self._accumulators.values():
532
539
  accumulator.clear()
533
540
 
541
def _get_viz_workflow(self) -> sciline.Pipeline:
    """
    Build a pipeline containing only what the target keys need, for visualization.

    The graph of each target key is taken from the copy of the base workflow stored
    at construction time. Every dynamic or context key that appears in the resulting
    graph is then set to ``None``. This is the same way the processing workflow sets
    its context keys to ``None``.
    """
    pipeline = sciline.Pipeline()
    for target in self._target_keys:
        pipeline[target] = self._base_workflow_for_viz[target]
    graph = pipeline.underlying_graph
    # Dynamic keys are handled before context keys.
    for input_key in (*self._dynamic_keys, *self._context_keys):
        if input_key in graph:
            pipeline[input_key] = None
    return pipeline
554
+
555
def _classify_nodes(self, graph: nx.DiGraph) -> dict[str, set[Any]]:
    """
    Assign the nodes of ``graph`` to the visualization categories.

    A node may belong to more than one category. The returned dict maps each
    category name to the set of nodes in that category:

    - static: depends on neither dynamic nor context keys; computed once
    - cached_keys: the static nodes that were actually pre-computed and cached
    - dynamic_keys: entry points for chunk data
    - dynamic_nodes: downstream of dynamic keys and recomputed for every chunk
      (nodes downstream of accumulators are excluded, they run in finalize())
    - context_keys: entry points for context data
    - context_dependent: downstream of context keys but not of dynamic keys
    - accumulator_keys: points where values are aggregated across chunks
    - target_keys: final outputs produced by finalize()
    - finalize_nodes: downstream of accumulators, computed in finalize()
    """
    nodes = set(graph.nodes)
    accumulators = set(self._accumulators)

    from_dynamic = _find_descendants(graph, self._dynamic_keys)
    from_context = _find_descendants(graph, self._context_keys)
    # Everything below an accumulator is evaluated in finalize(), not per chunk.
    from_accumulators = _find_descendants(graph, accumulators)

    return {
        'static': nodes - from_dynamic - from_context,
        'cached_keys': self._cached_keys & nodes,
        'dynamic_keys': self._dynamic_keys & nodes,
        'dynamic_nodes': from_dynamic - self._dynamic_keys - from_accumulators,
        'context_keys': self._context_keys & nodes,
        'context_dependent': from_context - from_dynamic - self._context_keys,
        'accumulator_keys': accumulators & nodes,
        'target_keys': set(self._target_keys) & nodes,
        'finalize_nodes': from_accumulators - accumulators,
    }
609
+
610
def visualize(
    self,
    compact: bool = False,
    mode: Literal['data', 'task', 'both'] = 'data',
    cluster_generics: bool = True,
    cluster_color: str | None = '#f0f0ff',
    show_legend: bool = True,
    show_static_dependencies: bool = True,
    **kwargs: Any,
) -> 'graphviz.Digraph':
    """
    Visualize the workflow with node classification styling.

    This post-processes sciline's visualization to add styling that highlights:

    - Static nodes (gray): Pre-computed once, dependencies of cached nodes
    - Static cached nodes (gray, thick border): Pre-computed and cached
    - Dynamic keys (green, thick border): Input entry points for chunks
    - Dynamic nodes (light green): Recomputed for each chunk
    - Context keys (blue, thick border): Input entry points for context
    - Context-dependent nodes (light blue): Cached until context changes
    - Accumulator keys (orange cylinder): Aggregation points
    - Finalize nodes (plum): Computed from accumulators during finalize
    - Target keys (double border): Final outputs

    Parameters
    ----------
    compact:
        If True, parameter-table-dependent branches are collapsed.
    mode:
        'data' shows only data nodes, 'task' shows task nodes, 'both' shows all.
    cluster_generics:
        If True, generic products are grouped into clusters.
    cluster_color:
        Background color of clusters. If None, clusters are dotted.
    show_legend:
        If True, add a legend explaining the node styles.
    show_static_dependencies:
        If True (default), show all static nodes including dependencies of cached
        nodes. If False, hide the ancestors of cached nodes to simplify the graph.
    **kwargs:
        Additional keyword arguments, forwarded to sciline's ``visualize`` (and
        from there to graphviz.Digraph).

    Returns
    -------
    :
        A graphviz.Digraph with styled nodes.
    """
    # _get_viz_workflow builds a brand-new pipeline on every call, so it can be
    # modified in place here without copying it first.
    viz_workflow = self._get_viz_workflow()
    if not show_static_dependencies:
        # Turn cached keys into inputs so that their ancestors are hidden.
        for key in self._cached_keys:
            # Query the graph afresh on each iteration: every assignment can prune
            # ancestors, possibly including other cached keys.
            if key in viz_workflow.underlying_graph:
                viz_workflow[key] = None

    graph = viz_workflow.underlying_graph

    dot = viz_workflow.visualize(
        compact=compact,
        mode=mode,
        cluster_generics=cluster_generics,
        cluster_color=cluster_color,
        **kwargs,
    )

    classifications = self._classify_nodes(graph)

    # Apply styles by re-declaring each node under the name sciline gave it, with
    # the extra attributes. NOTE(review): if a formatted name does not match a
    # node sciline emitted (e.g. in compact mode), graphviz adds a new,
    # unconnected node instead; confirm the names always agree.
    for key in graph.nodes:
        style = _get_node_style(key, classifications)
        if style:
            dot.node(_format_key_for_graphviz(key), **style)

    if show_legend:
        _add_legend(dot, show_static_dependencies=show_static_dependencies)

    return dot
694
+
534
695
 
535
696
def _find_descendants(
    source: sciline.Pipeline | nx.DiGraph,
    keys: set[sciline.typing.Key] | tuple[sciline.typing.Key, ...],
) -> set[sciline.typing.Key]:
    """
    Return the given keys together with everything downstream of them.

    ``source`` may be a sciline pipeline, in which case its ``underlying_graph`` is
    used, or a networkx graph. Keys that are not nodes of the graph are ignored,
    both as starting points and in the returned set.
    """
    graph = getattr(source, 'underlying_graph', source)
    present = {key for key in set(keys) if key in graph}
    result = set(present)
    for start in present:
        result.update(nx.descendants(graph, start))
    return result
543
708
 
544
709
 
545
710
  def _find_parents(
@@ -550,3 +715,194 @@ def _find_parents(
550
715
  for key in keys:
551
716
  parents |= set(graph.predecessors(key))
552
717
  return parents
718
+
719
+
720
+ # =============================================================================
721
+ # Visualization helpers
722
+ # =============================================================================
723
+
724
+ # Style definitions for each node category
725
+ # Priority order for overlapping categories (higher = takes precedence for fill)
726
+ _VIZ_STYLES = {
727
+ 'static': {
728
+ 'fillcolor': '#e8e8e8', # Gray
729
+ 'style': 'filled',
730
+ 'priority': 0,
731
+ },
732
+ 'cached_keys': {
733
+ 'fillcolor': '#e8e8e8', # Gray (same as static)
734
+ 'style': 'filled',
735
+ 'penwidth': '2.5', # Thick border to distinguish from dependencies
736
+ 'priority': 1,
737
+ },
738
+ 'context_dependent': {
739
+ 'fillcolor': '#d4e8f4', # Light blue
740
+ 'style': 'filled',
741
+ 'priority': 2,
742
+ },
743
+ 'context_keys': {
744
+ 'fillcolor': '#87CEEB', # Sky blue
745
+ 'style': 'filled',
746
+ 'penwidth': '2.5',
747
+ 'color': 'black', # Override sciline's red for unsatisfied
748
+ 'fontcolor': 'black',
749
+ 'priority': 3,
750
+ },
751
+ 'dynamic_nodes': {
752
+ 'fillcolor': '#d4f4d4', # Light green
753
+ 'style': 'filled',
754
+ 'priority': 4,
755
+ },
756
+ 'dynamic_keys': {
757
+ 'fillcolor': '#90EE90', # Light green (stronger)
758
+ 'style': 'filled',
759
+ 'penwidth': '2.5',
760
+ 'color': 'black', # Override sciline's red for unsatisfied
761
+ 'fontcolor': 'black',
762
+ 'priority': 5,
763
+ },
764
+ 'accumulator_keys': {
765
+ 'fillcolor': '#FFB347', # Orange
766
+ 'style': 'filled',
767
+ 'shape': 'cylinder',
768
+ 'priority': 6,
769
+ },
770
+ 'finalize_nodes': {
771
+ 'fillcolor': '#DDA0DD', # Plum (more distinct from cluster color)
772
+ 'style': 'filled',
773
+ 'priority': 7,
774
+ },
775
+ 'target_keys': {
776
+ 'peripheries': '2', # Double border
777
+ 'priority': 8,
778
+ },
779
+ }
780
+
781
+
782
def _format_key_for_graphviz(key: Any) -> str:
    """
    Return the graphviz node name that sciline's visualization uses for ``key``.

    This relies on sciline's private ``_format_type`` helper, imported when the
    function is called. Being private, it may change between sciline releases.
    """
    from sciline.visualize import _format_type as format_type

    formatted = format_type(key)
    return formatted.name
787
+
788
+
789
def _get_node_style(key: Any, classifications: dict[str, set[Any]]) -> dict[str, str]:
    """
    Combine the styles of every category that ``key`` belongs to.

    Category styles are applied in ascending ``priority`` order, with ties broken
    by category name. Attributes from a higher-priority category therefore
    overwrite those from lower ones. The ``priority`` entry itself is not part of
    the result. An empty dict is returned if ``key`` is in no category.
    """
    matching = sorted(
        (_VIZ_STYLES[category]['priority'], category)
        for category, members in classifications.items()
        if key in members
    )
    merged: dict[str, str] = {}
    for _, category in matching:
        merged.update(
            (attr, value)
            for attr, value in _VIZ_STYLES[category].items()
            if attr != 'priority'
        )
    return merged
813
+
814
+
815
+ def _add_legend(
816
+ dot: 'graphviz.Digraph', *, show_static_dependencies: bool = True
817
+ ) -> None:
818
+ """Add a legend subgraph explaining the node styles."""
819
+ with dot.subgraph(name='cluster_legend') as legend:
820
+ legend.attr(label='Legend', fontsize='14', style='rounded')
821
+ legend.attr('node', shape='rectangle', width='1.5', height='0.3')
822
+
823
+ # Track first node for edge chaining
824
+ prev_node = None
825
+
826
+ if show_static_dependencies:
827
+ legend.node(
828
+ 'legend_static',
829
+ 'Static',
830
+ fillcolor='#e8e8e8',
831
+ style='filled',
832
+ )
833
+ prev_node = 'legend_static'
834
+
835
+ legend.node(
836
+ 'legend_cached',
837
+ 'Static (cached)',
838
+ fillcolor='#e8e8e8',
839
+ style='filled',
840
+ penwidth='2.5',
841
+ )
842
+ if prev_node:
843
+ legend.edge(prev_node, 'legend_cached', style='invis')
844
+ prev_node = 'legend_cached'
845
+
846
+ legend.node(
847
+ 'legend_context_key',
848
+ 'Context key (input)',
849
+ fillcolor='#87CEEB',
850
+ style='filled',
851
+ penwidth='2.5',
852
+ )
853
+ legend.edge(prev_node, 'legend_context_key', style='invis')
854
+ prev_node = 'legend_context_key'
855
+
856
+ legend.node(
857
+ 'legend_context_dep',
858
+ 'Context-dependent',
859
+ fillcolor='#d4e8f4',
860
+ style='filled',
861
+ )
862
+ legend.edge(prev_node, 'legend_context_dep', style='invis')
863
+ prev_node = 'legend_context_dep'
864
+
865
+ legend.node(
866
+ 'legend_dynamic_key',
867
+ 'Dynamic key (input)',
868
+ fillcolor='#90EE90',
869
+ style='filled',
870
+ penwidth='2.5',
871
+ )
872
+ legend.edge(prev_node, 'legend_dynamic_key', style='invis')
873
+ prev_node = 'legend_dynamic_key'
874
+
875
+ legend.node(
876
+ 'legend_dynamic_node',
877
+ 'Dynamic (per chunk)',
878
+ fillcolor='#d4f4d4',
879
+ style='filled',
880
+ )
881
+ legend.edge(prev_node, 'legend_dynamic_node', style='invis')
882
+ prev_node = 'legend_dynamic_node'
883
+
884
+ legend.node(
885
+ 'legend_accumulator',
886
+ 'Accumulator',
887
+ fillcolor='#FFB347',
888
+ style='filled',
889
+ shape='cylinder',
890
+ )
891
+ legend.edge(prev_node, 'legend_accumulator', style='invis')
892
+ prev_node = 'legend_accumulator'
893
+
894
+ legend.node(
895
+ 'legend_finalize',
896
+ 'Finalize (from accum.)',
897
+ fillcolor='#DDA0DD',
898
+ style='filled',
899
+ )
900
+ legend.edge(prev_node, 'legend_finalize', style='invis')
901
+ prev_node = 'legend_finalize'
902
+
903
+ legend.node(
904
+ 'legend_target',
905
+ 'Target (output)',
906
+ peripheries='2',
907
+ )
908
+ legend.edge(prev_node, 'legend_target', style='invis')
@@ -33,6 +33,8 @@ from .types import (
33
33
  TofLookupTable,
34
34
  TofLookupTableFilename,
35
35
  TofMonitor,
36
+ WavelengthDetector,
37
+ WavelengthMonitor,
36
38
  )
37
39
  from .workflow import GenericTofWorkflow
38
40
 
@@ -60,6 +62,8 @@ __all__ = [
60
62
  "TofLookupTableFilename",
61
63
  "TofLookupTableWorkflow",
62
64
  "TofMonitor",
65
+ "WavelengthDetector",
66
+ "WavelengthMonitor",
63
67
  "providers",
64
68
  "simulate_chopper_cascade_using_tof",
65
69
  ]
@@ -14,6 +14,7 @@ import scipp as sc
14
14
  import scippneutron as scn
15
15
  import scippnexus as snx
16
16
  from scippneutron._utils import elem_unit
17
+ from scippneutron.conversion.tof import wavelength_from_tof
17
18
 
18
19
  try:
19
20
  from .interpolator_numba import Interpolator as InterpolatorImpl
@@ -39,6 +40,8 @@ from .types import (
39
40
  TofDetector,
40
41
  TofLookupTable,
41
42
  TofMonitor,
43
+ WavelengthDetector,
44
+ WavelengthMonitor,
42
45
  )
43
46
 
44
47
 
@@ -401,14 +404,15 @@ def _compute_tof_data(
401
404
  ) -> sc.DataArray:
402
405
  if da.bins is None:
403
406
  data = _time_of_flight_data_histogram(da=da, lookup=lookup, ltotal=ltotal)
404
- return rebin_strictly_increasing(data, dim='tof')
407
+ out = rebin_strictly_increasing(data, dim='tof')
405
408
  else:
406
- return _time_of_flight_data_events(
409
+ out = _time_of_flight_data_events(
407
410
  da=da,
408
411
  lookup=lookup,
409
412
  ltotal=ltotal,
410
413
  pulse_stride_offset=pulse_stride_offset,
411
414
  )
415
+ return out.assign_coords(Ltotal=ltotal)
412
416
 
413
417
 
414
418
  def detector_time_of_flight_data(
@@ -418,8 +422,9 @@ def detector_time_of_flight_data(
418
422
  pulse_stride_offset: PulseStrideOffset,
419
423
  ) -> TofDetector[RunType]:
420
424
  """
421
- Convert the time-of-arrival data to time-of-flight data using a lookup table.
422
- The output data will have a time-of-flight coordinate.
425
+ Convert the time-of-arrival (event_time_offset) data to time-of-flight data using a
426
+ lookup table.
427
+ The output data will have two new coordinates: time-of-flight and Ltotal.
423
428
 
424
429
  Parameters
425
430
  ----------
@@ -452,8 +457,9 @@ def monitor_time_of_flight_data(
452
457
  pulse_stride_offset: PulseStrideOffset,
453
458
  ) -> TofMonitor[RunType, MonitorType]:
454
459
  """
455
- Convert the time-of-arrival data to time-of-flight data using a lookup table.
456
- The output data will have a time-of-flight coordinate.
460
+ Convert the time-of-arrival (event_time_offset) data to time-of-flight data using a
461
+ lookup table.
462
+ The output data will have two new coordinates: time-of-flight and Ltotal.
457
463
 
458
464
  Parameters
459
465
  ----------
@@ -526,6 +532,47 @@ def detector_time_of_arrival_data(
526
532
  return ToaDetector[RunType](result)
527
533
 
528
534
 
535
def _tof_to_wavelength(da: sc.DataArray) -> sc.DataArray:
    """
    Convert time-of-flight data to wavelength data.

    The conversion applies ``wavelength_from_tof`` through ``transform_coords``.
    Besides the time-of-flight coordinate, this requires an ``Ltotal`` coordinate
    on ``da``, which ``_compute_tof_data`` assigns. Intermediate coordinates are
    not kept.
    """
    conversion_graph = {"wavelength": wavelength_from_tof}
    return da.transform_coords(
        'wavelength', graph=conversion_graph, keep_intermediate=False
    )
546
+
547
+
548
def detector_wavelength_data(
    detector_data: TofDetector[RunType],
) -> WavelengthDetector[RunType]:
    """
    Convert the time-of-flight coordinate of the detector data to wavelength.

    Parameters
    ----------
    detector_data:
        Detector data with a time-of-flight coordinate. It must also carry an
        ``Ltotal`` coordinate, as assigned by the time-of-flight providers.

    Returns
    -------
    :
        Detector data with a wavelength coordinate.
    """
    return WavelengthDetector[RunType](_tof_to_wavelength(detector_data))
560
+
561
+
562
def monitor_wavelength_data(
    monitor_data: TofMonitor[RunType, MonitorType],
) -> WavelengthMonitor[RunType, MonitorType]:
    """
    Convert the time-of-flight coordinate of the monitor data to wavelength.

    Parameters
    ----------
    monitor_data:
        Monitor data with a time-of-flight coordinate. It must also carry an
        ``Ltotal`` coordinate, as assigned by the time-of-flight providers.

    Returns
    -------
    :
        Monitor data with a wavelength coordinate.
    """
    return WavelengthMonitor[RunType, MonitorType](_tof_to_wavelength(monitor_data))
574
+
575
+
529
576
  def providers() -> tuple[Callable]:
530
577
  """
531
578
  Providers of the time-of-flight workflow.
@@ -536,4 +583,6 @@ def providers() -> tuple[Callable]:
536
583
  detector_ltotal_from_straight_line_approximation,
537
584
  monitor_ltotal_from_straight_line_approximation,
538
585
  detector_time_of_arrival_data,
586
+ detector_wavelength_data,
587
+ monitor_wavelength_data,
539
588
  )
@@ -4,7 +4,6 @@
4
4
  Utilities for computing time-of-flight lookup tables from neutron simulations.
5
5
  """
6
6
 
7
- import math
8
7
  from dataclasses import dataclass
9
8
  from typing import NewType
10
9
 
@@ -17,11 +16,10 @@ from .types import TofLookupTable
17
16
 
18
17
 
19
18
  @dataclass
20
- class SimulationResults:
19
+ class BeamlineComponentReading:
21
20
  """
22
- Results of a time-of-flight simulation used to create a lookup table.
23
-
24
- The results (apart from ``distance``) should be flat lists (1d arrays) of length N
21
+ Reading at a given position along the beamline from a time-of-flight simulation.
22
+ The data (apart from ``distance``) should be flat lists (1d arrays) of length N
25
23
  where N is the number of neutrons, containing the properties of the neutrons in the
26
24
  simulation.
27
25
 
@@ -40,20 +38,42 @@ class SimulationResults:
40
38
  For a ``tof`` simulation, this is just the position of the detector where the
41
39
  events are recorded. For a ``McStas`` simulation, this is the distance between
42
40
  the source and the event monitor.
43
- choppers:
44
- The parameters of the choppers used in the simulation (if any).
45
41
  """
46
42
 
47
43
  time_of_arrival: sc.Variable
48
44
  wavelength: sc.Variable
49
45
  weight: sc.Variable
50
46
  distance: sc.Variable
51
- choppers: DiskChoppers[AnyRun] | None = None
52
47
 
53
48
  def __post_init__(self):
54
49
  self.speed = (sc.constants.h / sc.constants.m_n) / self.wavelength
55
50
 
56
51
 
52
@dataclass
class SimulationResults:
    """
    Output of a time-of-flight simulation, used to build a lookup table.

    Holds neutron readings recorded at several positions along the beamline, for
    example at the source and after each chopper. It also records the chopper
    settings of the simulation, so that its compatibility with a given experiment
    can be checked.

    Parameters
    ----------
    readings:
        Mapping from component name (e.g. ``'source'``, ``'chopper1'``) to the
        :class:`BeamlineComponentReading` taken at that component.
    choppers:
        Chopper parameters used in the simulation, or None if there were none.
        They allow checking the simulation against an experiment by comparing
        chopper openings, frequencies, phases, etc.
    """

    readings: dict[str, BeamlineComponentReading]
    choppers: DiskChoppers[AnyRun] | None = None
75
+
76
+
57
77
  NumberOfSimulatedNeutrons = NewType("NumberOfSimulatedNeutrons", int)
58
78
  """
59
79
  Number of neutrons simulated in the simulation that is used to create the lookup table.
@@ -144,11 +164,10 @@ def _mask_large_uncertainty(table: sc.DataArray, error_threshold: float):
144
164
  table.values[mask.values] = np.nan
145
165
 
146
166
 
147
- def _compute_mean_tof_in_distance_range(
148
- simulation: SimulationResults,
149
- distance_bins: sc.Variable,
167
+ def _compute_mean_tof(
168
+ simulation: BeamlineComponentReading,
169
+ distance: sc.Variable,
150
170
  time_bins: sc.Variable,
151
- distance_unit: str,
152
171
  time_unit: str,
153
172
  frame_period: sc.Variable,
154
173
  time_bins_half_width: sc.Variable,
@@ -161,12 +180,10 @@ def _compute_mean_tof_in_distance_range(
161
180
  ----------
162
181
  simulation:
163
182
  Results of a time-of-flight simulation used to create a lookup table.
164
- distance_bins:
165
- Bin edges for the distance axis in the lookup table.
183
+ distance:
184
+ Distance where table is computed.
166
185
  time_bins:
167
186
  Bin edges for the event_time_offset axis in the lookup table.
168
- distance_unit:
169
- Unit of the distance axis.
170
187
  time_unit:
171
188
  Unit of the event_time_offset axis.
172
189
  frame_period:
@@ -174,23 +191,17 @@ def _compute_mean_tof_in_distance_range(
174
191
  time_bins_half_width:
175
192
  Half width of the time bins in the event_time_offset axis.
176
193
  """
177
- simulation_distance = simulation.distance.to(unit=distance_unit)
178
- distances = sc.midpoints(distance_bins)
194
+ travel_length = distance - simulation.distance.to(unit=distance.unit)
179
195
  # Compute arrival and flight times for all neutrons
180
- toas = simulation.time_of_arrival + (distances / simulation.speed).to(
196
+ toas = simulation.time_of_arrival + (travel_length / simulation.speed).to(
181
197
  unit=time_unit, copy=False
182
198
  )
183
- dist = distances + simulation_distance
184
- tofs = dist * (sc.constants.m_n / sc.constants.h) * simulation.wavelength
199
+ tofs = distance / simulation.speed
185
200
 
186
201
  data = sc.DataArray(
187
- data=sc.broadcast(simulation.weight, sizes=toas.sizes),
188
- coords={
189
- "toa": toas,
190
- "tof": tofs.to(unit=time_unit, copy=False),
191
- "distance": dist,
192
- },
193
- ).flatten(to="event")
202
+ data=simulation.weight,
203
+ coords={"toa": toas, "tof": tofs.to(unit=time_unit, copy=False)},
204
+ )
194
205
 
195
206
  # Add the event_time_offset coordinate, wrapped to the frame_period
196
207
  data.coords['event_time_offset'] = data.coords['toa'] % frame_period
@@ -204,18 +215,14 @@ def _compute_mean_tof_in_distance_range(
204
215
  # data.coords['event_time_offset'] %= pulse_period - time_bins_half_width
205
216
  data.coords['event_time_offset'] %= frame_period - time_bins_half_width
206
217
 
207
- binned = data.bin(
208
- distance=distance_bins + simulation_distance, event_time_offset=time_bins
209
- )
210
-
218
+ binned = data.bin(event_time_offset=time_bins)
219
+ binned_sum = binned.bins.sum()
211
220
  # Weighted mean of tof inside each bin
212
- mean_tof = (
213
- binned.bins.data * binned.bins.coords["tof"]
214
- ).bins.sum() / binned.bins.sum()
221
+ mean_tof = (binned.bins.data * binned.bins.coords["tof"]).bins.sum() / binned_sum
215
222
  # Compute the variance of the tofs to track regions with large uncertainty
216
223
  variance = (
217
224
  binned.bins.data * (binned.bins.coords["tof"] - mean_tof) ** 2
218
- ).bins.sum() / binned.bins.sum()
225
+ ).bins.sum() / binned_sum
219
226
 
220
227
  mean_tof.variances = variance.values
221
228
  return mean_tof
@@ -293,15 +300,14 @@ def make_tof_lookup_table(
293
300
  left edge.
294
301
  """
295
302
  distance_unit = "m"
296
- time_unit = simulation.time_of_arrival.unit
303
+ time_unit = "us"
297
304
  res = distance_resolution.to(unit=distance_unit)
298
305
  pulse_period = pulse_period.to(unit=time_unit)
299
306
  frame_period = pulse_period * pulse_stride
300
307
 
301
- min_dist, max_dist = (
302
- x.to(unit=distance_unit) - simulation.distance.to(unit=distance_unit)
303
- for x in ltotal_range
304
- )
308
+ min_dist = ltotal_range[0].to(unit=distance_unit)
309
+ max_dist = ltotal_range[1].to(unit=distance_unit)
310
+
305
311
  # We need to bin the data below, to compute the weighted mean of the wavelength.
306
312
  # This results in data with bin edges.
307
313
  # However, the 2d interpolator expects bin centers.
@@ -313,7 +319,7 @@ def make_tof_lookup_table(
313
319
  # ensure that the last bin is not cut off. We want the upper edge to be higher than
314
320
  # the maximum distance, hence we pad with an additional 1.5 x resolution.
315
321
  pad = 2.0 * res
316
- distance_bins = sc.arange('distance', min_dist - pad, max_dist + pad, res)
322
+ distances = sc.arange('distance', min_dist - pad, max_dist + pad, res)
317
323
 
318
324
  # Create some time bins for event_time_offset.
319
325
  # We want our final table to strictly cover the range [0, frame_period].
@@ -328,22 +334,36 @@ def make_tof_lookup_table(
328
334
  time_bins_half_width = 0.5 * (time_bins[1] - time_bins[0])
329
335
  time_bins -= time_bins_half_width
330
336
 
331
- # To avoid a too large RAM usage, we compute the table in chunks, and piece them
332
- # together at the end.
333
- ndist = len(distance_bins) - 1
334
- max_size = 2e7
335
- total_size = ndist * len(simulation.time_of_arrival)
336
- nchunks = math.ceil(total_size / max_size)
337
- chunk_size = math.ceil(ndist / nchunks)
337
+ # Sort simulation readings by reverse distance
338
+ sorted_simulation_results = sorted(
339
+ simulation.readings.values(), key=lambda x: x.distance.value, reverse=True
340
+ )
341
+
338
342
  pieces = []
339
- for i in range(nchunks):
340
- dist_edges = distance_bins[i * chunk_size : (i + 1) * chunk_size + 1]
343
+ # To avoid large RAM usage, and having to split the distances into chunks
344
+ # according to which component reading to use, we simply loop over distances one
345
+ # by one here.
346
+ for dist in distances:
347
+ # Find the correct simulation reading
348
+ simulation_reading = None
349
+ for reading in sorted_simulation_results:
350
+ if dist.value >= reading.distance.to(unit=dist.unit).value:
351
+ simulation_reading = reading
352
+ break
353
+ if simulation_reading is None:
354
+ closest = sorted_simulation_results[-1]
355
+ raise ValueError(
356
+ "Building the Tof lookup table failed: the requested position "
357
+ f"{dist.value} {dist.unit} is before the component with the lowest "
358
+ "distance in the simulation. The first component in the beamline "
359
+ f"has distance {closest.distance.value} {closest.distance.unit}."
360
+ )
361
+
341
362
  pieces.append(
342
- _compute_mean_tof_in_distance_range(
343
- simulation=simulation,
344
- distance_bins=dist_edges,
363
+ _compute_mean_tof(
364
+ simulation=simulation_reading,
365
+ distance=dist,
345
366
  time_bins=time_bins,
346
- distance_unit=distance_unit,
347
367
  time_unit=time_unit,
348
368
  frame_period=frame_period,
349
369
  time_bins_half_width=time_bins_half_width,
@@ -351,7 +371,6 @@ def make_tof_lookup_table(
351
371
  )
352
372
 
353
373
  table = sc.concat(pieces, 'distance')
354
- table.coords["distance"] = sc.midpoints(table.coords["distance"])
355
374
  table.coords["event_time_offset"] = sc.midpoints(table.coords["event_time_offset"])
356
375
 
357
376
  # Copy the left edge to the right to create periodic boundary conditions
@@ -360,7 +379,7 @@ def make_tof_lookup_table(
360
379
  [table.data, table.data['event_time_offset', 0]], dim='event_time_offset'
361
380
  ),
362
381
  coords={
363
- "distance": table.coords["distance"],
382
+ "distance": distances,
364
383
  "event_time_offset": sc.concat(
365
384
  [table.coords["event_time_offset"], frame_period],
366
385
  dim='event_time_offset',
@@ -387,6 +406,24 @@ def make_tof_lookup_table(
387
406
  )
388
407
 
389
408
 
409
def _to_component_reading(component):
    """
    Convert one component of a ``tof`` simulation into a ``BeamlineComponentReading``.

    ``component`` is the ``tof`` source or a chopper from the model results. It
    needs ``data`` and ``distance`` attributes. Events flagged by the
    ``blocked_by_others`` or ``blocked_by_me`` masks, when present, are dropped.
    Components without a ``toa`` coordinate, such as the source, use
    ``birth_time`` as the time of arrival.
    """
    events = component.data.squeeze().flatten(to='event')
    keep = sc.full(value=True, sizes=events.sizes)
    for mask_name in ('blocked_by_others', 'blocked_by_me'):
        if mask_name in events.masks:
            keep &= ~events.masks[mask_name]
    events = events[keep]

    coords = events.coords
    toa = coords["toa"] if "toa" in coords else coords["birth_time"]
    return BeamlineComponentReading(
        time_of_arrival=toa,
        wavelength=coords["wavelength"],
        weight=events.data,
        distance=component.distance,
    )
425
+
426
+
390
427
  def simulate_chopper_cascade_using_tof(
391
428
  choppers: DiskChoppers[AnyRun],
392
429
  source_position: SourcePosition,
@@ -431,29 +468,14 @@ def simulate_chopper_cascade_using_tof(
431
468
  source = tof.Source(
432
469
  facility=facility, neutrons=neutrons, pulses=pulse_stride, seed=seed
433
470
  )
471
+ sim_readings = {"source": _to_component_reading(source)}
434
472
  if not tof_choppers:
435
- events = source.data.squeeze().flatten(to='event')
436
- return SimulationResults(
437
- time_of_arrival=events.coords["birth_time"],
438
- wavelength=events.coords["wavelength"],
439
- weight=events.data,
440
- distance=0.0 * sc.units.m,
441
- )
473
+ return SimulationResults(readings=sim_readings, choppers=None)
442
474
  model = tof.Model(source=source, choppers=tof_choppers)
443
475
  results = model.run()
444
- # Find name of the furthest chopper in tof_choppers
445
- furthest_chopper = max(tof_choppers, key=lambda c: c.distance)
446
- events = results[furthest_chopper.name].data.squeeze().flatten(to='event')
447
- events = events[
448
- ~(events.masks["blocked_by_others"] | events.masks["blocked_by_me"])
449
- ]
450
- return SimulationResults(
451
- time_of_arrival=events.coords["toa"],
452
- wavelength=events.coords["wavelength"],
453
- weight=events.data,
454
- distance=furthest_chopper.distance,
455
- choppers=choppers,
456
- )
476
+ for name, ch in results.choppers.items():
477
+ sim_readings[name] = _to_component_reading(ch)
478
+ return SimulationResults(readings=sim_readings, choppers=choppers)
457
479
 
458
480
 
459
481
  def TofLookupTableWorkflow():
@@ -87,3 +87,11 @@ class ToaDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):
87
87
 
88
88
  class TofMonitor(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray):
89
89
  """Monitor data with time-of-flight coordinate."""
90
+
91
+
92
+ class WavelengthDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):
93
+ """Detector data with wavelength coordinate."""
94
+
95
+
96
+ class WavelengthMonitor(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray):
97
+ """Monitor data with wavelength coordinate."""
@@ -7,11 +7,7 @@ import scipp as sc
7
7
 
8
8
  from ..nexus import GenericNeXusWorkflow
9
9
  from . import eto_to_tof
10
- from .types import (
11
- PulseStrideOffset,
12
- TofLookupTable,
13
- TofLookupTableFilename,
14
- )
10
+ from .types import PulseStrideOffset, TofLookupTable, TofLookupTableFilename
15
11
 
16
12
 
17
13
  def load_tof_lookup_table(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: essreduce
3
- Version: 26.1.1
3
+ Version: 26.2.0
4
4
  Summary: Common data reduction tools for the ESS facility
5
5
  Author: Scipp contributors
6
6
  License-Expression: BSD-3-Clause
@@ -25,6 +25,7 @@ Requires-Dist: scipp>=25.04.0
25
25
  Requires-Dist: scippneutron>=25.11.1
26
26
  Requires-Dist: scippnexus>=25.06.0
27
27
  Provides-Extra: test
28
+ Requires-Dist: graphviz>=0.20; extra == "test"
28
29
  Requires-Dist: ipywidgets>=8.1; extra == "test"
29
30
  Requires-Dist: matplotlib>=3.10.7; extra == "test"
30
31
  Requires-Dist: numba>=0.59; extra == "test"
@@ -3,16 +3,15 @@ ess/reduce/logging.py,sha256=6n8Czq4LZ3OK9ENlKsWSI1M3KvKv6_HSoUiV4__IUlU,357
3
3
  ess/reduce/normalization.py,sha256=r8H6SZgT94a1HE9qZ6Bx3N6c3VG3FzlJPzoCVMNI5-0,13081
4
4
  ess/reduce/parameter.py,sha256=4sCfoKOI2HuO_Q7JLH_jAXnEOFANSn5P3NdaOBzhJxc,4635
5
5
  ess/reduce/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- ess/reduce/streaming.py,sha256=s-Pz6fhnaUPMjJ_lvSIdwF8WZjTK48lMDKQmCnqZf3c,18529
6
+ ess/reduce/streaming.py,sha256=zdVosXbrtys5Z9gJ1dhnqoLeg2t8abicHBluytmrHCs,31150
7
7
  ess/reduce/ui.py,sha256=zmorAbDwX1cU3ygDT--OP58o0qU7OBcmJz03jPeYSLA,10884
8
8
  ess/reduce/uncertainty.py,sha256=LR4O6ApB6Z-W9gC_XW0ajupl8yFG-du0eee1AX_R-gk,6990
9
9
  ess/reduce/workflow.py,sha256=738-lcdgsORYfQ4A0UTk2IgnbVxC3jBdpscpaOFIpdc,3114
10
10
  ess/reduce/data/__init__.py,sha256=uDtqkmKA_Zwtj6II25zntz9T812XhdCn3tktYev4uyY,486
11
11
  ess/reduce/data/_registry.py,sha256=ngJMzP-AuMN0EKtws5vYSEPsv_Bn3TZjjIvNUKWQDeA,13992
12
- ess/reduce/live/__init__.py,sha256=jPQVhihRVNtEDrE20PoKkclKV2aBF1lS7cCHootgFgI,204
12
+ ess/reduce/live/__init__.py,sha256=AFcqRbIhyqAo-oKV7gNIOfHRxfoVDiULf9vB4cEckHk,133
13
13
  ess/reduce/live/raw.py,sha256=z3JzKl1tOH51z1PWT3MJERactSFRXrNI_MBmpAHX71g,31094
14
14
  ess/reduce/live/roi.py,sha256=t65SfGtCtb8r-f4hkfg2I02CEnOp6Hh5Tv9qOqPOeK0,10588
15
- ess/reduce/live/workflow.py,sha256=bsbwvTqPhRO6mC__3b7MgU7DWwAnOvGvG-t2n22EKq8,4285
16
15
  ess/reduce/nexus/__init__.py,sha256=xXc982vZqRba4jR4z5hA2iim17Z7niw4KlS1aLFbn1Q,1107
17
16
  ess/reduce/nexus/_nexus_loader.py,sha256=5J26y_t-kabj0ik0jf3OLSYda3lDLDQhvPd2_ro7Q_0,23927
18
17
  ess/reduce/nexus/json_generator.py,sha256=ME2Xn8L7Oi3uHJk9ZZdCRQTRX-OV_wh9-DJn07Alplk,2529
@@ -20,15 +19,15 @@ ess/reduce/nexus/json_nexus.py,sha256=QrVc0p424nZ5dHX9gebAJppTw6lGZq9404P_OFl1gi
20
19
  ess/reduce/nexus/types.py,sha256=g5oBBEYPH7urF1tDP0tqXtixhQN8JDpe8vmiKrPiUW0,9320
21
20
  ess/reduce/nexus/workflow.py,sha256=bVRnVZ6HTEdIFwZv61JuvFUeTt9efUwe1MR65gBhyw8,24995
22
21
  ess/reduce/scripts/grow_nexus.py,sha256=hET3h06M0xlJd62E3palNLFvJMyNax2kK4XyJcOhl-I,3387
23
- ess/reduce/time_of_flight/__init__.py,sha256=bNxhK0uePltQpCW2sdNTpPdzXL6dt1IT1ri_cJ5VTL8,1561
24
- ess/reduce/time_of_flight/eto_to_tof.py,sha256=imKuN7IARMqBjmi8kjAcsseTFg6OD8ORas9X1FolgFY,18777
22
+ ess/reduce/time_of_flight/__init__.py,sha256=LENpuVY2nP-YKVqUEPlmnvWdVaWBEKPinPITk-2crrA,1659
23
+ ess/reduce/time_of_flight/eto_to_tof.py,sha256=aSQYzYW5rHhObGU9tu5UuZ-M0xfjnI0NdfJTTQa_-ww,20276
25
24
  ess/reduce/time_of_flight/fakes.py,sha256=BqpO56PQyO9ua7QlZw6xXMAPBrqjKZEM_jc-VB83CyE,4289
26
25
  ess/reduce/time_of_flight/interpolator_numba.py,sha256=wh2YS3j2rOu30v1Ok3xNHcwS7t8eEtZyZvbfXOCtgrQ,3835
27
26
  ess/reduce/time_of_flight/interpolator_scipy.py,sha256=_InoAPuMm2qhJKZQBAHOGRFqtvvuQ8TStoN7j_YgS4M,1853
28
- ess/reduce/time_of_flight/lut.py,sha256=aFaP6CnbTVkScTfCkyvyKt7PCZcJjUNdlS9l-sLHr8c,18637
27
+ ess/reduce/time_of_flight/lut.py,sha256=JKZCbDYXOME6c5nMX01VvsAPU0JQgWTy64s2jZOI-7s,19746
29
28
  ess/reduce/time_of_flight/resample.py,sha256=Opmi-JA4zNH725l9VB99U4O9UlM37f5ACTCGtwBcows,3718
30
- ess/reduce/time_of_flight/types.py,sha256=FsSueM6OjJdF80uJHj-TNuyVAci8ixFvMuRMt9oHKDQ,3310
31
- ess/reduce/time_of_flight/workflow.py,sha256=2jUxeSmP0KweQTctAzIFJLm7Odf_e7kZzAc8MAMKBEs,3084
29
+ ess/reduce/time_of_flight/types.py,sha256=Eb0l8_KlL7swORefRqbkMSCq6oxNeujU28-gSNXX9JU,3575
30
+ ess/reduce/time_of_flight/workflow.py,sha256=adcEAwUrxx2I6IP-f0RNeUqDlMbE2i3s7evZBmmicpI,3067
32
31
  ess/reduce/widgets/__init__.py,sha256=SoSHBv8Dc3QXV9HUvPhjSYWMwKTGYZLpsWwsShIO97Q,5325
33
32
  ess/reduce/widgets/_base.py,sha256=_wN3FOlXgx_u0c-A_3yyoIH-SdUvDENGgquh9S-h5GI,4852
34
33
  ess/reduce/widgets/_binedges_widget.py,sha256=ZCQsGjYHnJr9GFUn7NjoZc1CdsnAzm_fMzyF-fTKKVY,2785
@@ -41,9 +40,9 @@ ess/reduce/widgets/_spinner.py,sha256=2VY4Fhfa7HMXox2O7UbofcdKsYG-AJGrsgGJB85nDX
41
40
  ess/reduce/widgets/_string_widget.py,sha256=iPAdfANyXHf-nkfhgkyH6gQDklia0LebLTmwi3m-iYQ,1482
42
41
  ess/reduce/widgets/_switchable_widget.py,sha256=fjKz99SKLhIF1BLgGVBSKKn3Lu_jYBwDYGeAjbJY3Q8,2390
43
42
  ess/reduce/widgets/_vector_widget.py,sha256=aTaBqCFHZQhrIoX6-sSqFWCPePEW8HQt5kUio8jP1t8,1203
44
- essreduce-26.1.1.dist-info/licenses/LICENSE,sha256=nVEiume4Qj6jMYfSRjHTM2jtJ4FGu0g-5Sdh7osfEYw,1553
45
- essreduce-26.1.1.dist-info/METADATA,sha256=fwMRSA5kTQ26OFwAM_7xjpvdcz-C3KMiC02MK153XTE,1987
46
- essreduce-26.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
47
- essreduce-26.1.1.dist-info/entry_points.txt,sha256=PMZOIYzCifHMTe4pK3HbhxUwxjFaZizYlLD0td4Isb0,66
48
- essreduce-26.1.1.dist-info/top_level.txt,sha256=0JxTCgMKPLKtp14wb1-RKisQPQWX7i96innZNvHBr-s,4
49
- essreduce-26.1.1.dist-info/RECORD,,
43
+ essreduce-26.2.0.dist-info/licenses/LICENSE,sha256=nVEiume4Qj6jMYfSRjHTM2jtJ4FGu0g-5Sdh7osfEYw,1553
44
+ essreduce-26.2.0.dist-info/METADATA,sha256=BGHgAZ8pZ7iSElnZZ0eAFlUZQhxk1xjg8ivigzMu6h8,2034
45
+ essreduce-26.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
46
+ essreduce-26.2.0.dist-info/entry_points.txt,sha256=PMZOIYzCifHMTe4pK3HbhxUwxjFaZizYlLD0td4Isb0,66
47
+ essreduce-26.2.0.dist-info/top_level.txt,sha256=0JxTCgMKPLKtp14wb1-RKisQPQWX7i96innZNvHBr-s,4
48
+ essreduce-26.2.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,128 +0,0 @@
1
- # SPDX-License-Identifier: BSD-3-Clause
2
- # Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
3
- """Tools for creating live data reduction workflows for Beamlime."""
4
-
5
- from pathlib import Path
6
- from typing import NewType, TypeVar
7
-
8
- import sciline
9
- import scipp as sc
10
- import scippnexus as snx
11
-
12
- from ess.reduce import streaming
13
- from ess.reduce.nexus import types as nt
14
- from ess.reduce.nexus.json_nexus import JSONGroup
15
-
16
- JSONEventData = NewType('JSONEventData', dict[str, JSONGroup])
17
-
18
-
19
- def _load_json_event_data(name: str, nxevent_data: JSONEventData) -> sc.DataArray:
20
- return snx.Group(nxevent_data[name], definitions=snx.base_definitions())[()]
21
-
22
-
23
- def load_json_event_data_for_sample_run(
24
- name: nt.NeXusName[nt.Component], nxevent_data: JSONEventData
25
- ) -> nt.NeXusData[nt.Component, nt.SampleRun]:
26
- return nt.NeXusData[nt.Component, nt.SampleRun](
27
- _load_json_event_data(name, nxevent_data)
28
- )
29
-
30
-
31
- def load_json_event_data_for_sample_transmission_run(
32
- name: nt.NeXusName[nt.Component], nxevent_data: JSONEventData
33
- ) -> nt.NeXusData[nt.Component, nt.TransmissionRun[nt.SampleRun]]:
34
- return nt.NeXusData[nt.Component, nt.TransmissionRun[nt.SampleRun]](
35
- _load_json_event_data(name, nxevent_data)
36
- )
37
-
38
-
39
- T = TypeVar('T', bound='LiveWorkflow')
40
-
41
-
42
- class LiveWorkflow:
43
- """A workflow class that fulfills Beamlime's LiveWorkflow protocol."""
44
-
45
- def __init__(
46
- self,
47
- *,
48
- streamed: streaming.StreamProcessor,
49
- outputs: dict[str, sciline.typing.Key],
50
- ) -> None:
51
- self._streamed = streamed
52
- self._outputs = outputs
53
-
54
- @classmethod
55
- def from_workflow(
56
- cls: type[T],
57
- *,
58
- workflow: sciline.Pipeline,
59
- accumulators: dict[sciline.typing.Key, streaming.Accumulator],
60
- outputs: dict[str, sciline.typing.Key],
61
- run_type: type[nt.RunType],
62
- nexus_filename: Path,
63
- ) -> T:
64
- """
65
- Create a live workflow from a base workflow and other parameters.
66
-
67
- Parameters
68
- ----------
69
- workflow:
70
- Base workflow to use for live data reduction.
71
- accumulators:
72
- Accumulators forwarded to the stream processor.
73
- outputs:
74
- Mapping from output names to keys in the workflow. The keys correspond to
75
- workflow results that will be computed.
76
- run_type:
77
- Type of the run to process. This defines which run is the dynamic run being
78
- processed. The NeXus template file will be set as the filename for this run.
79
- nexus_filename:
80
- Path to the NeXus file to process.
81
-
82
- Returns
83
- -------
84
- :
85
- Live workflow object.
86
- """
87
-
88
- workflow = workflow.copy()
89
- if run_type is nt.SampleRun:
90
- workflow.insert(load_json_event_data_for_sample_run)
91
- elif run_type is nt.TransmissionRun[nt.SampleRun]:
92
- workflow.insert(load_json_event_data_for_sample_transmission_run)
93
- else:
94
- raise NotImplementedError(f"Run type {run_type} not supported yet.")
95
- workflow[nt.Filename[run_type]] = nexus_filename
96
- streamed = streaming.StreamProcessor(
97
- base_workflow=workflow,
98
- dynamic_keys=(JSONEventData,),
99
- target_keys=outputs.values(),
100
- accumulators=accumulators,
101
- )
102
- return cls(streamed=streamed, outputs=outputs)
103
-
104
- def __call__(
105
- self, nxevent_data: dict[str, JSONGroup], nxlog: dict[str, JSONGroup]
106
- ) -> dict[str, sc.DataArray]:
107
- """
108
- Implements the __call__ method required by the LiveWorkflow protocol.
109
-
110
- Parameters
111
- ----------
112
- nxevent_data:
113
- NeXus event data.
114
- nxlog:
115
- NeXus log data. WARNING: This is currently not used.
116
-
117
- Returns
118
- -------
119
- :
120
- Dictionary of computed and plottable results.
121
- """
122
- # Beamlime passes full path, but the workflow only needs the name of the monitor
123
- # or detector group.
124
- nxevent_data = {
125
- key.lstrip('/').split('/')[2]: value for key, value in nxevent_data.items()
126
- }
127
- results = self._streamed.add_chunk({JSONEventData: nxevent_data})
128
- return {name: results[key] for name, key in self._outputs.items()}