ophyd-async 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. ophyd_async/__init__.py +1 -4
  2. ophyd_async/_version.py +2 -2
  3. ophyd_async/core/__init__.py +91 -19
  4. ophyd_async/core/_providers.py +68 -0
  5. ophyd_async/core/async_status.py +90 -42
  6. ophyd_async/core/detector.py +341 -0
  7. ophyd_async/core/device.py +226 -0
  8. ophyd_async/core/device_save_loader.py +286 -0
  9. ophyd_async/core/flyer.py +85 -0
  10. ophyd_async/core/mock_signal_backend.py +82 -0
  11. ophyd_async/core/mock_signal_utils.py +145 -0
  12. ophyd_async/core/{_device/_signal/signal.py → signal.py} +249 -61
  13. ophyd_async/core/{_device/_backend/signal_backend.py → signal_backend.py} +12 -5
  14. ophyd_async/core/{_device/_backend/sim_signal_backend.py → soft_signal_backend.py} +54 -48
  15. ophyd_async/core/standard_readable.py +261 -0
  16. ophyd_async/core/utils.py +127 -30
  17. ophyd_async/epics/_backend/_aioca.py +62 -43
  18. ophyd_async/epics/_backend/_p4p.py +100 -52
  19. ophyd_async/epics/_backend/common.py +25 -0
  20. ophyd_async/epics/areadetector/__init__.py +16 -15
  21. ophyd_async/epics/areadetector/aravis.py +63 -0
  22. ophyd_async/epics/areadetector/controllers/__init__.py +5 -0
  23. ophyd_async/epics/areadetector/controllers/ad_sim_controller.py +52 -0
  24. ophyd_async/epics/areadetector/controllers/aravis_controller.py +78 -0
  25. ophyd_async/epics/areadetector/controllers/kinetix_controller.py +49 -0
  26. ophyd_async/epics/areadetector/controllers/pilatus_controller.py +61 -0
  27. ophyd_async/epics/areadetector/controllers/vimba_controller.py +66 -0
  28. ophyd_async/epics/areadetector/drivers/__init__.py +21 -0
  29. ophyd_async/epics/areadetector/drivers/ad_base.py +107 -0
  30. ophyd_async/epics/areadetector/drivers/aravis_driver.py +38 -0
  31. ophyd_async/epics/areadetector/drivers/kinetix_driver.py +27 -0
  32. ophyd_async/epics/areadetector/drivers/pilatus_driver.py +21 -0
  33. ophyd_async/epics/areadetector/drivers/vimba_driver.py +63 -0
  34. ophyd_async/epics/areadetector/kinetix.py +46 -0
  35. ophyd_async/epics/areadetector/pilatus.py +45 -0
  36. ophyd_async/epics/areadetector/single_trigger_det.py +18 -10
  37. ophyd_async/epics/areadetector/utils.py +91 -13
  38. ophyd_async/epics/areadetector/vimba.py +43 -0
  39. ophyd_async/epics/areadetector/writers/__init__.py +5 -0
  40. ophyd_async/epics/areadetector/writers/_hdfdataset.py +10 -0
  41. ophyd_async/epics/areadetector/writers/_hdffile.py +54 -0
  42. ophyd_async/epics/areadetector/writers/hdf_writer.py +142 -0
  43. ophyd_async/epics/areadetector/writers/nd_file_hdf.py +40 -0
  44. ophyd_async/epics/areadetector/writers/nd_plugin.py +38 -0
  45. ophyd_async/epics/demo/__init__.py +78 -51
  46. ophyd_async/epics/demo/demo_ad_sim_detector.py +35 -0
  47. ophyd_async/epics/motion/motor.py +67 -52
  48. ophyd_async/epics/pvi/__init__.py +3 -0
  49. ophyd_async/epics/pvi/pvi.py +318 -0
  50. ophyd_async/epics/signal/__init__.py +8 -3
  51. ophyd_async/epics/signal/signal.py +27 -10
  52. ophyd_async/log.py +130 -0
  53. ophyd_async/panda/__init__.py +24 -7
  54. ophyd_async/panda/_common_blocks.py +49 -0
  55. ophyd_async/panda/_hdf_panda.py +48 -0
  56. ophyd_async/panda/_panda_controller.py +37 -0
  57. ophyd_async/panda/_table.py +158 -0
  58. ophyd_async/panda/_trigger.py +39 -0
  59. ophyd_async/panda/_utils.py +15 -0
  60. ophyd_async/panda/writers/__init__.py +3 -0
  61. ophyd_async/panda/writers/_hdf_writer.py +220 -0
  62. ophyd_async/panda/writers/_panda_hdf_file.py +58 -0
  63. ophyd_async/plan_stubs/__init__.py +13 -0
  64. ophyd_async/plan_stubs/ensure_connected.py +22 -0
  65. ophyd_async/plan_stubs/fly.py +149 -0
  66. ophyd_async/protocols.py +126 -0
  67. ophyd_async/sim/__init__.py +11 -0
  68. ophyd_async/sim/demo/__init__.py +3 -0
  69. ophyd_async/sim/demo/sim_motor.py +103 -0
  70. ophyd_async/sim/pattern_generator.py +318 -0
  71. ophyd_async/sim/sim_pattern_detector_control.py +55 -0
  72. ophyd_async/sim/sim_pattern_detector_writer.py +34 -0
  73. ophyd_async/sim/sim_pattern_generator.py +37 -0
  74. {ophyd_async-0.1.0.dist-info → ophyd_async-0.3.0.dist-info}/METADATA +35 -67
  75. ophyd_async-0.3.0.dist-info/RECORD +86 -0
  76. {ophyd_async-0.1.0.dist-info → ophyd_async-0.3.0.dist-info}/WHEEL +1 -1
  77. ophyd_async/core/_device/__init__.py +0 -0
  78. ophyd_async/core/_device/_backend/__init__.py +0 -0
  79. ophyd_async/core/_device/_signal/__init__.py +0 -0
  80. ophyd_async/core/_device/device.py +0 -60
  81. ophyd_async/core/_device/device_collector.py +0 -121
  82. ophyd_async/core/_device/device_vector.py +0 -14
  83. ophyd_async/core/_device/standard_readable.py +0 -72
  84. ophyd_async/epics/areadetector/ad_driver.py +0 -18
  85. ophyd_async/epics/areadetector/directory_provider.py +0 -18
  86. ophyd_async/epics/areadetector/hdf_streamer_det.py +0 -167
  87. ophyd_async/epics/areadetector/nd_file_hdf.py +0 -22
  88. ophyd_async/epics/areadetector/nd_plugin.py +0 -13
  89. ophyd_async/epics/signal/pvi_get.py +0 -22
  90. ophyd_async/panda/panda.py +0 -332
  91. ophyd_async-0.1.0.dist-info/RECORD +0 -45
  92. {ophyd_async-0.1.0.dist-info → ophyd_async-0.3.0.dist-info}/LICENSE +0 -0
  93. {ophyd_async-0.1.0.dist-info → ophyd_async-0.3.0.dist-info}/entry_points.txt +0 -0
  94. {ophyd_async-0.1.0.dist-info → ophyd_async-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@
1
+ from enum import Enum
2
+ from functools import partial
3
+ from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union
4
+
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+ import yaml
8
+ from bluesky.plan_stubs import abs_set, wait
9
+ from bluesky.protocols import Location
10
+ from bluesky.utils import Msg
11
+ from epicscorelibs.ca.dbr import ca_array, ca_float, ca_int, ca_str
12
+
13
+ from .device import Device
14
+ from .signal import SignalRW
15
+
16
# The channel-access value wrapper types handled by ``ca_dbr_representer``.
CaType = Union[
    ca_float,
    ca_int,
    ca_str,
    ca_array,
]


def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node:
    """Represent a numpy array as an inline (flow-style) yaml sequence.

    The array is converted with ``tolist()`` first, so the yaml holds plain
    (possibly nested) Python lists rather than a numpy object.
    """
    as_list = array.tolist()
    return dumper.represent_sequence(
        "tag:yaml.org,2002:seq",
        as_list,
        flow_style=True,
    )
23
+
24
+
25
def ca_dbr_representer(dumper: yaml.Dumper, value: CaType) -> yaml.Node:
    """Represent a channel-access value wrapper with the matching representer.

    ``ca_float``, ``ca_int`` and ``ca_str`` values are handed to the dumper's
    ``represent_float``/``represent_int``/``represent_str``; ``ca_array`` values
    are handed to :func:`ndarray_representer` bound to this dumper.

    Raises
    ------
    KeyError
        If ``type(value)`` is not exactly one of the four supported classes
        (the lookup is by exact type, not ``isinstance``).
    """
    # Keys are the wrapper *classes* themselves, not instances of CaType.
    representers: Dict[type, Callable[[Any], yaml.Node]] = {
        ca_float: dumper.represent_float,
        ca_int: dumper.represent_int,
        ca_str: dumper.represent_str,
        # Arrays reuse the numpy representer, bound to this dumper.
        ca_array: partial(ndarray_representer, dumper),
    }
    return representers[type(value)](value)
36
+
37
+
38
class OphydDumper(yaml.Dumper):
    """yaml Dumper that writes Enum members as their underlying ``.value``."""

    def represent_data(self, data: Any) -> Any:
        # Non-Enum data goes straight to the standard yaml.Dumper dispatch.
        if not isinstance(data, Enum):
            return super().represent_data(data)
        # Enum members are unwrapped and re-dispatched on their value.
        return self.represent_data(data.value)
43
+
44
+
45
def get_signal_values(
    signals: Dict[str, SignalRW[Any]], ignore: Optional[List[str]] = None
) -> Generator[Msg, Sequence[Location[Any]], Dict[str, Any]]:
    """Get signal values in bulk.

    Used as part of saving the signals of a device to a yaml file.

    Parameters
    ----------
    signals : Dict[str, SignalRW]
        Dictionary mapping signal names to SignalRW objects. Often the direct
        result of :func:`walk_rw_signals`, in which case the names are dotted
        attribute paths on the device.

    ignore : Optional[List[str]]
        Optional list of signal names that should not be located.

    Returns
    -------
    Dict[str, Any]
        A dictionary mapping each located signal name to the ``"setpoint"`` of
        its location. Every name in ``ignore`` is included with the value None.

    See Also
    --------
    :func:`ophyd_async.core.walk_rw_signals`
    :func:`ophyd_async.core.save_to_yaml`
    """

    ignore = ignore or []
    # Set for O(1) membership tests; the list is kept for the ordered update below.
    ignored = set(ignore)
    selected_signals = {
        key: signal for key, signal in signals.items() if key not in ignored
    }
    selected_values = yield Msg("locate", *selected_signals.values())

    # The "locate" response may come back as a single dict rather than a
    # sequence (presumably when exactly one object was located); normalise it
    # so it can be zipped against the signal names below.
    if isinstance(selected_values, dict):
        selected_values = [selected_values]  # type: ignore

    assert selected_values is not None, "No signalRW's were able to be located"
    named_values = {
        key: value["setpoint"] for key, value in zip(selected_signals, selected_values)
    }
    # Ignored values placed in with value None so we know which ones were ignored
    named_values.update({key: None for key in ignore})
    return named_values
90
+
91
+
92
def walk_rw_signals(
    device: Device, path_prefix: Optional[str] = ""
) -> Dict[str, SignalRW[Any]]:
    """Collect every SignalRW reachable from a device, keyed by dotted path.

    Walks ``device.children()`` recursively and records each plain SignalRW
    under its dotted attribute path (e.g. ``"child.signal"``). Used as part of
    saving and loading a device.

    Parameters
    ----------
    device : Device
        Ophyd device to retrieve read-write signals from.

    path_prefix : str
        For internal use (prefix added during recursion); leave blank when
        calling the method.

    Returns
    -------
    SignalRWs : dict
        A dictionary matching the string attribute path of a SignalRW with the
        signal itself.

    See Also
    --------
    :func:`ophyd_async.core.get_signal_values`
    :func:`ophyd_async.core.save_to_yaml`

    """
    prefix = path_prefix or ""

    found: Dict[str, SignalRW[Any]] = {}
    for child_name, child in device.children():
        child_path = f"{prefix}{child_name}"
        # Exact type check (not isinstance): only plain SignalRW instances count.
        if type(child) is SignalRW:
            found[child_path] = child
        # Recurse into every child so signals on nested devices are found too.
        found.update(walk_rw_signals(child, path_prefix=child_path + "."))
    return found
132
+
133
+
134
def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
    """Serialise a phase or set of phases of SignalRW values to a yaml file.

    This is an ordinary function, not a plan: the file is written immediately.

    Parameters
    ----------
    phases : dict or list of dicts
        The values to save. Each item in the list is a separate phase used when
        loading a device. In general this is the (sorted) return value of
        `get_signal_values`.

    save_path : str
        Path of the yaml file to write to

    See Also
    --------
    :func:`ophyd_async.core.walk_rw_signals`
    :func:`ophyd_async.core.get_signal_values`
    :func:`ophyd_async.core.load_from_yaml`
    """
    # These registrations are made on PyYAML's shared ``yaml.Dumper`` class,
    # which OphydDumper subclasses, so they also affect other users of it.
    yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)
    for ca_type in (ca_float, ca_int, ca_str, ca_array):
        yaml.add_representer(ca_type, ca_dbr_representer, Dumper=yaml.Dumper)

    with open(save_path, "w") as stream:
        yaml.dump(phases, stream, Dumper=OphydDumper, default_flow_style=False)
162
+
163
+
164
def load_from_yaml(save_path: str) -> Sequence[Dict[str, Any]]:
    """Read saved signal values back from a yaml file.

    This is an ordinary function, not a plan: the file is read immediately and
    its parsed content returned.

    Parameters
    ----------
    save_path : str
        Path of the yaml file to load from

    Returns
    -------
    Sequence[Dict[str, Any]]
        The parsed yaml document; expected to be the list of phase dictionaries
        written by :func:`save_to_yaml`.

    See Also
    --------
    :func:`ophyd_async.core.save_to_yaml`
    :func:`ophyd_async.core.set_signal_values`
    """
    # NOTE(review): yaml.full_load accepts more than yaml.safe_load does;
    # only load save files from trusted locations.
    with open(save_path) as stream:
        loaded = yaml.full_load(stream)
    return loaded
179
+
180
+
181
def set_signal_values(
    signals: Dict[str, SignalRW[Any]], values: Sequence[Dict[str, Any]]
) -> Generator[Msg, None, None]:
    """Maps signals from a yaml file into device signals.

    ``values`` holds signal values grouped into phases. The phases are applied
    one after another: every signal in a phase is set, then the plan waits for
    the whole phase to finish before starting the next, so signals are set in
    the correct order.

    Parameters
    ----------
    signals : Dict[str, SignalRW[Any]]
        Dictionary of named signals to be updated if value found in values argument.
        Can be the output of :func:`walk_rw_signals()` for a device.

    values : Sequence[Dict[str, Any]]
        List of dictionaries of signal name and value pairs. If a name matches
        one in the signals argument, that signal is set to the value; names with
        no matching signal, and values of None, are skipped.
        The groups of signals are loaded in their list order.
        Can be the output of :func:`load_from_yaml()` for a yaml file.

    See Also
    --------
    :func:`ophyd_async.core.load_from_yaml`
    :func:`ophyd_async.core.walk_rw_signals`
    """
    for phase_number, phase in enumerate(values):
        group = f"load-phase{phase_number}"
        for key, value in phase.items():
            # None marks a value that was ignored when saving.
            if value is not None and key in signals:
                yield from abs_set(signals[key], value, group=group)
        # Block until every set issued for this phase has completed.
        yield from wait(group)
221
+
222
+
223
def load_device(device: Device, path: str):
    """Plan which loads saved values from a yaml file into a device.

    Reads the phases from ``path`` and applies them, phase by phase, to the
    device's read-write signals.

    Parameters
    ----------
    device: Device
        The device to load PVs into
    path: str
        Path of the yaml file to load

    See Also
    --------
    :func:`ophyd_async.core.save_device`
    """
    saved_phases = load_from_yaml(path)
    yield from set_signal_values(walk_rw_signals(device), saved_phases)
240
+
241
+
242
def all_at_once(values: Dict[str, Any]) -> Sequence[Dict[str, Any]]:
    """Default sorter: put every value into one phase so all are set at once."""
    single_phase = [values]
    return single_phase
245
+
246
+
247
def save_device(
    device: Device,
    path: str,
    sorter: Callable[[Dict[str, Any]], Sequence[Dict[str, Any]]] = all_at_once,
    ignore: Optional[List[str]] = None,
):
    """Plan that saves the state of all PV's on a device using a sorter.

    The default sorter assumes all saved PVs can be loaded at once, and therefore
    can be saved at one time, i.e. all PVs will appear on one list in the
    resulting yaml file.

    This can be a problem, because when the yaml is ingested with
    :func:`ophyd_async.core.load_device`, it will set all of those PVs at once.
    However, some PV's need to be set before others - this is device specific.

    Therefore, users should consider the order of device loading and write their
    own sorter algorithms accordingly.

    See :func:`ophyd_async.panda.phase_sorter` for a valid implementation of the
    sorter.

    Parameters
    ----------
    device : Device
        The device whose PVs should be saved.

    path : str
        The path where the resulting yaml should be saved to

    sorter : Callable[[Dict[str, Any]], Sequence[Dict[str, Any]]]
        Receives the ``{name: value}`` dict of saved values and returns the list
        of phases to write; each phase is loaded to completion before the next.

    ignore : Optional[List[str]]
        Signal names (dotted attribute paths) not to read; they are saved with
        the value None, which :func:`set_signal_values` skips when loading.

    See Also
    --------
    :func:`ophyd_async.core.load_device`
    """
    rw_signals = walk_rw_signals(device)
    values = yield from get_signal_values(rw_signals, ignore=ignore)
    phases = sorter(values)
    save_to_yaml(phases, path)
@@ -0,0 +1,85 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Generic, Sequence, TypeVar
3
+
4
+ from bluesky.protocols import DataKey, Flyable, Preparable, Reading, Stageable
5
+
6
+ from .async_status import AsyncStatus
7
+ from .device import Device
8
+ from .signal import SignalR
9
+ from .utils import merge_gathered_dicts
10
+
11
T = TypeVar("T")


class TriggerLogic(ABC, Generic[T]):
    """Abstract interface for the logic that drives a hardware-triggered flyscan.

    ``T`` is the type of the value passed to :meth:`prepare`. Concrete
    subclasses must implement all four coroutine methods.
    """

    @abstractmethod
    async def prepare(self, value: T):
        """Move to the start of the flyscan, configured by ``value``."""

    @abstractmethod
    async def kickoff(self):
        """Start the flyscan."""

    @abstractmethod
    async def complete(self):
        """Block until the flyscan is done."""

    @abstractmethod
    async def stop(self):
        """Stop flying and wait for everything to be stopped."""
30
+
31
+
32
class HardwareTriggeredFlyable(
    Device,
    Stageable,
    Preparable,
    Flyable,
    Generic[T],
):
    """Flyable device that delegates its flyscan behaviour to a TriggerLogic.

    Parameters
    ----------
    trigger_logic:
        Object whose prepare/kickoff/complete/stop coroutines do the work.
    configuration_signals:
        Signals reported by ``describe_configuration``/``read_configuration``.
    name:
        Device name, passed on to ``Device.__init__``.
    """

    def __init__(
        self,
        trigger_logic: TriggerLogic[T],
        configuration_signals: Sequence[SignalR] = (),
        name: str = "",
    ):
        # Own state is stored before Device.__init__ runs.
        self._trigger_logic = trigger_logic
        self._configuration_signals = tuple(configuration_signals)
        super().__init__(name=name)

    @property
    def trigger_logic(self) -> TriggerLogic[T]:
        """The TriggerLogic this flyable delegates to."""
        return self._trigger_logic

    @AsyncStatus.wrap
    async def stage(self) -> None:
        """Stage by unstaging first, i.e. asking the trigger logic to stop."""
        await self.unstage()

    @AsyncStatus.wrap
    async def unstage(self) -> None:
        """Ask the trigger logic to stop flying."""
        await self._trigger_logic.stop()

    def prepare(self, value: T) -> AsyncStatus:
        """Setup trajectories"""
        return AsyncStatus(self._prepare(value))

    async def _prepare(self, value: T) -> None:
        # Move to start and set up the flyscan
        await self._trigger_logic.prepare(value)

    @AsyncStatus.wrap
    async def kickoff(self) -> None:
        """Start the flyscan via the trigger logic."""
        await self._trigger_logic.kickoff()

    @AsyncStatus.wrap
    async def complete(self) -> None:
        """Wait for the trigger logic to report the flyscan as done."""
        await self._trigger_logic.complete()

    async def describe_configuration(self) -> Dict[str, DataKey]:
        """Merged ``describe()`` results of all configuration signals."""
        describes = [signal.describe() for signal in self._configuration_signals]
        return await merge_gathered_dicts(describes)

    async def read_configuration(self) -> Dict[str, Reading]:
        """Merged ``read()`` results of all configuration signals."""
        readings = [signal.read() for signal in self._configuration_signals]
        return await merge_gathered_dicts(readings)
@@ -0,0 +1,82 @@
1
+ import asyncio
2
+ from functools import cached_property
3
+ from typing import Callable, Optional, Type
4
+ from unittest.mock import Mock
5
+
6
+ from bluesky.protocols import Descriptor, Reading
7
+
8
+ from ophyd_async.core.signal_backend import SignalBackend
9
+ from ophyd_async.core.soft_signal_backend import SoftSignalBackend
10
+ from ophyd_async.core.utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
11
+
12
+
13
class MockSignalBackend(SignalBackend[T]):
    """Signal backend for tests that records puts and keeps values in memory.

    Values, readings and callbacks are delegated to a ``SoftSignalBackend``.
    This is the supplied ``initial_backend`` itself if it is already soft;
    otherwise it is a new soft backend of the resolved ``datatype``. Every
    ``put`` is also recorded on :attr:`put_mock`. ``put(..., wait=True)``
    additionally waits for :attr:`put_proceeds` to be set.
    """

    def __init__(
        self,
        datatype: Optional[Type[T]] = None,
        initial_backend: Optional[SignalBackend[T]] = None,
    ) -> None:
        # Wrapping a mock in another mock is not supported.
        if isinstance(initial_backend, MockSignalBackend):
            raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackends")

        self.initial_backend = initial_backend

        if datatype is None:
            # Without an explicit datatype, take it from the wrapped backend.
            assert (
                self.initial_backend
            ), "Must supply either initial_backend or datatype"
            datatype = self.initial_backend.datatype

        self.datatype = datatype

        # A soft backend is reused as-is; a hard backend (or none at all) is
        # mimicked by a fresh soft backend with the same datatype.
        self.soft_backend = (
            self.initial_backend
            if isinstance(self.initial_backend, SoftSignalBackend)
            else SoftSignalBackend(datatype=datatype)
        )

    def source(self, name: str) -> str:
        inner = self.initial_backend.source(name) if self.initial_backend else name
        return f"mock+{inner}"

    async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None:
        """Nothing to connect to; returns immediately."""

    @cached_property
    def put_mock(self) -> Mock:
        """Mock that records the arguments of every ``put`` call."""
        return Mock(name="put", spec=Callable)

    @cached_property
    def put_proceeds(self) -> asyncio.Event:
        """Event that ``put(..., wait=True)`` waits on; initially set."""
        event = asyncio.Event()
        event.set()
        return event

    async def put(self, value: Optional[T], wait=True, timeout=None):
        # Record the call, then store the value in the soft backend.
        self.put_mock(value, wait=wait, timeout=timeout)
        await self.soft_backend.put(value, wait=wait, timeout=timeout)

        if not wait:
            return
        await asyncio.wait_for(self.put_proceeds.wait(), timeout=timeout)

    def set_value(self, value: T):
        self.soft_backend.set_value(value)

    async def get_reading(self) -> Reading:
        return await self.soft_backend.get_reading()

    async def get_value(self) -> T:
        return await self.soft_backend.get_value()

    async def get_setpoint(self) -> T:
        """For a soft signal, the setpoint and readback values are the same."""
        return await self.soft_backend.get_setpoint()

    async def get_datakey(self, source: str) -> Descriptor:
        return await self.soft_backend.get_datakey(source)

    def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None:
        self.soft_backend.set_callback(callback)
@@ -0,0 +1,145 @@
1
+ from contextlib import asynccontextmanager, contextmanager
2
+ from typing import Any, Callable, Iterable
3
+ from unittest.mock import Mock
4
+
5
+ from ophyd_async.core.signal import Signal
6
+ from ophyd_async.core.utils import T
7
+
8
+ from .mock_signal_backend import MockSignalBackend
9
+
10
+
11
def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend:
    """Return ``signal``'s backend, checking that it is a ``MockSignalBackend``.

    Raises
    ------
    AssertionError
        If the backend is not a ``MockSignalBackend``. Raised explicitly rather
        than via ``assert`` so the check still runs under ``python -O``.
    """
    backend = signal._backend
    if not isinstance(backend, MockSignalBackend):
        raise AssertionError(
            "Expected to receive a `MockSignalBackend`, instead "
            f"received {type(backend)}."
        )
    return backend
17
+
18
+
19
def set_mock_value(signal: Signal[T], value: T):
    """Set the value of a signal that is in mock mode."""
    _get_mock_signal_backend(signal).set_value(value)
23
+
24
+
25
def set_mock_put_proceeds(signal: Signal, proceeds: bool):
    """Allow or block a put with wait=True from proceeding"""
    put_proceeds = _get_mock_signal_backend(signal).put_proceeds

    if not proceeds:
        # Cleared event: waiting puts stay pending until it is set again.
        put_proceeds.clear()
        return
    put_proceeds.set()
33
+
34
+
35
@asynccontextmanager
async def mock_puts_blocked(*signals: Signal):
    """Block ``put(..., wait=True)`` on the given mock signals inside the context.

    On exit, puts are allowed to proceed again via
    :func:`set_mock_put_proceeds`. This also happens when the body of the
    ``async with`` raises, so a failing test cannot leave the signals blocked.
    """
    for signal in signals:
        set_mock_put_proceeds(signal, False)
    try:
        yield
    finally:
        for signal in signals:
            set_mock_put_proceeds(signal, True)
42
+
43
+
44
def get_mock_put(signal: Signal) -> Mock:
    """Get the mock associated with the put call on the signal."""
    backend = _get_mock_signal_backend(signal)
    return backend.put_mock
47
+
48
+
49
def reset_mock_put_calls(signal: Signal):
    """Clear the calls recorded on the signal's put mock (``Mock.reset_mock``)."""
    _get_mock_signal_backend(signal).put_mock.reset_mock()
52
+
53
+
54
class _SetValuesIterator:
    """Iterator that sets a mock signal to each value it yields.

    Each ``next()`` takes the next item of ``values``, sets it on ``signal``
    with :func:`set_mock_value` and returns it.
    """

    # Garbage collected by the time __del__ is called unless we put it as a
    # class attribute here.
    require_all_consumed: bool = False

    def __init__(
        self,
        signal: Signal,
        values: Iterable[Any],
        require_all_consumed: bool = False,
    ):
        self.signal = signal
        self.values = values
        self.require_all_consumed = require_all_consumed
        # Number of values handed out so far.
        self.index = 0

        self.iterator = enumerate(values, start=1)

    def __iter__(self):
        return self

    def __next__(self):
        # Will propagate StopIteration once ``values`` is exhausted
        self.index, next_value = next(self.iterator)
        set_mock_value(self.signal, next_value)
        return next_value

    def __del__(self):
        if not self.require_all_consumed:
            return
        iterator = getattr(self, "iterator", None)
        if iterator is None:
            # __init__ did not finish; there is nothing to check.
            return
        # Probe the live iterator instead of re-iterating ``self.values``:
        # re-iterating gives the wrong count for one-shot iterators (already
        # drained) and never finishes for infinite ones such as itertools.cycle.
        exhausted = object()
        if next(iterator, exhausted) is not exhausted:
            raise AssertionError("Not all values have been consumed.")
84
+
85
+
86
def set_mock_values(
    signal: Signal,
    values: Iterable[Any],
    require_all_consumed: bool = False,
) -> _SetValuesIterator:
    """Iterator that sets a mock signal to each of ``values`` in turn.

    Each time the returned iterator is advanced, the next item of ``values`` is
    set on ``signal`` and returned. To repeat a sequence, pass a repeating
    iterable such as ``itertools.cycle([...])`` (with ``require_all_consumed``
    left False).

    Parameters
    ----------
    signal:
        A signal with a `MockSignalBackend` backend.
    values:
        An iterable of the values to set the signal to, on each iteration
        the value will be set.
    require_all_consumed:
        If True, an AssertionError will be raised if the iterator is deleted before
        all values have been consumed. Because the check runs in ``__del__``,
        Python reports that error as an unraisable exception rather than
        propagating it to the caller.

    Notes
    -----
    Example usage::

        for value_set in set_mock_values(signal, [1, 2, 3]):
            # do something

        it = set_mock_values(signal, [1, 2, 3], require_all_consumed=True)
        next(it)
        # do something
    """

    return _SetValuesIterator(
        signal,
        values,
        require_all_consumed=require_all_consumed,
    )
122
+
123
+
124
@contextmanager
def _unset_side_effect_cm(put_mock: Mock):
    """Context manager that clears ``put_mock.side_effect`` on exit.

    The side effect is cleared even when the body of the ``with`` raises, so a
    failing test does not leave a stale callback attached to the put mock.
    """
    try:
        yield
    finally:
        put_mock.side_effect = None
128
+
129
+
130
def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]):
    """For setting a callback when a backend is put to.

    The callback is installed as the ``side_effect`` of the signal's put mock.
    The returned object is a context manager that removes it again on exit, so
    this can be used either with ``with`` or as an ordinary function call (in
    which case the callback stays installed).

    Parameters
    ----------
    signal:
        A signal with a `MockSignalBackend` backend.
    callback:
        The callback to call when the backend is put to. ``MockSignalBackend``
        calls its put mock as ``put_mock(value, wait=..., timeout=...)``, so the
        callback must accept those keyword arguments as well as the value.
    """
    put_mock = _get_mock_signal_backend(signal).put_mock
    put_mock.side_effect = callback
    return _unset_side_effect_cm(put_mock)
+ return _unset_side_effect_cm(backend.put_mock)