ophyd_async-0.14.1-py3-none-any.whl → ophyd_async-0.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ophyd_async/_version.py +2 -2
  2. ophyd_async/core/__init__.py +17 -5
  3. ophyd_async/core/{_table.py → _datatypes.py} +18 -9
  4. ophyd_async/core/_derived_signal.py +57 -24
  5. ophyd_async/core/_derived_signal_backend.py +1 -5
  6. ophyd_async/core/_device_filler.py +30 -7
  7. ophyd_async/core/_mock_signal_backend.py +25 -7
  8. ophyd_async/core/_mock_signal_utils.py +7 -11
  9. ophyd_async/core/_signal.py +11 -11
  10. ophyd_async/core/_signal_backend.py +7 -19
  11. ophyd_async/core/_soft_signal_backend.py +6 -6
  12. ophyd_async/core/_status.py +81 -4
  13. ophyd_async/core/_typing.py +0 -0
  14. ophyd_async/core/_utils.py +57 -7
  15. ophyd_async/epics/adcore/_core_io.py +12 -5
  16. ophyd_async/epics/adcore/_core_logic.py +1 -1
  17. ophyd_async/epics/core/__init__.py +2 -1
  18. ophyd_async/epics/core/_aioca.py +13 -3
  19. ophyd_async/epics/core/_epics_connector.py +4 -1
  20. ophyd_async/epics/core/_p4p.py +13 -3
  21. ophyd_async/epics/core/_signal.py +18 -6
  22. ophyd_async/epics/core/_util.py +23 -3
  23. ophyd_async/epics/demo/_motor.py +2 -2
  24. ophyd_async/epics/motor.py +15 -17
  25. ophyd_async/epics/odin/_odin_io.py +1 -1
  26. ophyd_async/epics/pmac/_pmac_io.py +23 -4
  27. ophyd_async/epics/pmac/_pmac_trajectory.py +47 -10
  28. ophyd_async/fastcs/eiger/_eiger_io.py +20 -1
  29. ophyd_async/fastcs/jungfrau/_signals.py +4 -1
  30. ophyd_async/fastcs/panda/_block.py +28 -6
  31. ophyd_async/fastcs/panda/_writer.py +1 -3
  32. ophyd_async/tango/core/_tango_transport.py +7 -17
  33. ophyd_async/tango/demo/_counter.py +2 -2
  34. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/METADATA +1 -1
  35. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/RECORD +38 -37
  36. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/WHEEL +1 -1
  37. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/licenses/LICENSE +0 -0
  38. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/top_level.txt +0 -0
ophyd_async/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.14.1'
32
- __version_tuple__ = version_tuple = (0, 14, 1)
31
+ __version__ = version = '0.15'
32
+ __version_tuple__ = version_tuple = (0, 15)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,5 +1,6 @@
1
1
  """The building blocks for making devices."""
2
2
 
3
+ from ._datatypes import Array1D, DTypeScalar_co, Table, TableSubclass
3
4
  from ._derived_signal import (
4
5
  DerivedSignalFactory,
5
6
  derived_signal_r,
@@ -23,7 +24,7 @@ from ._device import (
23
24
  default_mock_class,
24
25
  init_devices,
25
26
  )
26
- from ._device_filler import DeviceFiller
27
+ from ._device_filler import DeviceAnnotation, DeviceFiller
27
28
  from ._enums import (
28
29
  EnabledDisabled,
29
30
  EnableDisable,
@@ -44,7 +45,15 @@ from ._mock_signal_utils import (
44
45
  set_mock_value,
45
46
  set_mock_values,
46
47
  )
47
- from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable, Watcher
48
+ from ._protocol import (
49
+ AsyncConfigurable,
50
+ AsyncLocatable,
51
+ AsyncMovable,
52
+ AsyncPausable,
53
+ AsyncReadable,
54
+ AsyncStageable,
55
+ Watcher,
56
+ )
48
57
  from ._providers import (
49
58
  AutoIncrementFilenameProvider,
50
59
  AutoIncrementingPathProvider,
@@ -86,8 +95,6 @@ from ._signal import (
86
95
  walk_signal_sources,
87
96
  )
88
97
  from ._signal_backend import (
89
- Array1D,
90
- DTypeScalar_co,
91
98
  Primitive,
92
99
  SignalBackend,
93
100
  SignalDatatype,
@@ -97,7 +104,6 @@ from ._signal_backend import (
97
104
  )
98
105
  from ._soft_signal_backend import SoftSignalBackend
99
106
  from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
100
- from ._table import Table, TableSubclass
101
107
  from ._utils import (
102
108
  CALCULATE_TIMEOUT,
103
109
  DEFAULT_TIMEOUT,
@@ -117,6 +123,7 @@ from ._utils import (
117
123
  get_enum_cls,
118
124
  get_unique,
119
125
  in_micros,
126
+ non_zero,
120
127
  wait_for_connection,
121
128
  )
122
129
  from ._yaml_settings import YamlSettingsProvider
@@ -146,11 +153,15 @@ __all__ = [
146
153
  "Device",
147
154
  "DeviceConnector",
148
155
  "DeviceFiller",
156
+ "DeviceAnnotation",
149
157
  "DeviceVector",
150
158
  "init_devices",
151
159
  # Protocols
152
160
  "AsyncReadable",
153
161
  "AsyncConfigurable",
162
+ "AsyncLocatable",
163
+ "AsyncMovable",
164
+ "AsyncPausable",
154
165
  "AsyncStageable",
155
166
  "Watcher",
156
167
  # Status
@@ -255,6 +266,7 @@ __all__ = [
255
266
  "make_datakey",
256
267
  "wait_for_connection",
257
268
  "Ignore",
269
+ "non_zero",
258
270
  # Derived signal
259
271
  "derived_signal_r",
260
272
  "derived_signal_rw",
@@ -1,13 +1,25 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections.abc import Callable, Sequence
4
- from typing import Annotated, Any, TypeVar, get_origin, get_type_hints
4
+ from typing import Annotated, Any, TypeVar
5
5
 
6
6
  import numpy as np
7
7
  from pydantic import ConfigDict, Field, model_validator
8
8
  from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation
9
9
 
10
- from ._utils import ConfinedModel, get_dtype
10
+ from ._utils import ConfinedModel, cached_get_origin, cached_get_type_hints, get_dtype
11
+
12
+ DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic)
13
+ """A numpy dtype like [](#numpy.float64)."""
14
+
15
+
16
+ # To be a 1D array shape should really be tuple[int], but np.array()
17
+ # currently produces tuple[int, ...] even when it has 1D input args
18
+ # https://github.com/numpy/numpy/issues/28077#issuecomment-2566485178
19
+ Array1D = np.ndarray[tuple[int, ...], np.dtype[DTypeScalar_co]]
20
+ """A type alias for a 1D numpy array with a specific scalar data type.
21
+
22
+ E.g. `Array1D[np.float64]` is a 1D numpy array of 64-bit floats."""
11
23
 
12
24
  TableSubclass = TypeVar("TableSubclass", bound="Table")
13
25
 
@@ -75,11 +87,8 @@ class Table(ConfinedModel):
75
87
  # ...but forbid extra in subclasses so it gets validated
76
88
  cls.model_config = ConfigDict(validate_assignment=True, extra="forbid")
77
89
  # Change fields to have the correct annotations
78
- # TODO: refactor so we don't need this to break circular imports
79
- from ._signal_backend import Array1D
80
-
81
- for k, anno in get_type_hints(cls, localns={"Array1D": Array1D}).items():
82
- if get_origin(anno) is np.ndarray:
90
+ for k, anno in cached_get_type_hints(cls).items():
91
+ if cached_get_origin(anno) is np.ndarray:
83
92
  dtype = get_dtype(anno)
84
93
  new_anno = Annotated[
85
94
  anno,
@@ -88,7 +97,7 @@ class Table(ConfinedModel):
88
97
  ),
89
98
  Field(default_factory=_make_default_factory(dtype)),
90
99
  ]
91
- elif get_origin(anno) is Sequence:
100
+ elif cached_get_origin(anno) is Sequence:
92
101
  new_anno = Annotated[anno, Field(default_factory=list)]
93
102
  else:
94
103
  raise TypeError(f"Cannot use annotation {anno} in a Table")
@@ -156,7 +165,7 @@ class Table(ConfinedModel):
156
165
  raise AssertionError(f"Cannot construct Table from {data}")
157
166
  for field_name, field_value in cls.model_fields.items():
158
167
  if (
159
- get_origin(field_value.annotation) is np.ndarray
168
+ cached_get_origin(field_value.annotation) is np.ndarray
160
169
  and field_value.annotation
161
170
  and field_name in data_dict
162
171
  ):
@@ -1,5 +1,14 @@
1
- from collections.abc import Awaitable, Callable
2
- from typing import Any, Generic, get_args, get_origin, get_type_hints, is_typeddict
1
+ import functools
2
+ from collections.abc import Awaitable, Callable, Mapping
3
+ from inspect import Parameter, signature
4
+ from typing import (
5
+ Any,
6
+ Generic,
7
+ TypeVar,
8
+ get_args,
9
+ get_origin,
10
+ is_typeddict,
11
+ )
3
12
 
4
13
  from bluesky.protocols import Locatable
5
14
 
@@ -12,6 +21,7 @@ from ._derived_signal_backend import (
12
21
  from ._device import Device
13
22
  from ._signal import Signal, SignalR, SignalRW, SignalT, SignalW
14
23
  from ._signal_backend import Primitive, SignalDatatypeT
24
+ from ._utils import cached_get_type_hints
15
25
 
16
26
 
17
27
  class DerivedSignalFactory(Generic[TransformT]):
@@ -53,12 +63,13 @@ class DerivedSignalFactory(Generic[TransformT]):
53
63
  # Populate expected parameters and types
54
64
  expected = {
55
65
  **{k: f.annotation for k, f in transform_cls.model_fields.items()},
56
- **{
57
- k: v
58
- for k, v in get_type_hints(transform_cls.raw_to_derived).items()
59
- if k not in {"self", "return"}
60
- },
66
+ **_get_params_types_dict(transform_cls.raw_to_derived),
61
67
  }
68
+ if empty_keys := [k for k, v in expected.items() if v == Parameter.empty]:
69
+ raise TypeError(
70
+ f"{transform_cls.raw_to_derived} is missing a type "
71
+ f"hint for arguments: {empty_keys}"
72
+ )
62
73
 
63
74
  # Populate received parameters and types
64
75
  # Use Primitive's type, Signal's datatype,
@@ -76,7 +87,19 @@ class DerivedSignalFactory(Generic[TransformT]):
76
87
  f"Expected the following to be passed as keyword arguments "
77
88
  f"{expected}, got {received}"
78
89
  )
79
- raise TypeError(msg)
90
+ if set(expected.keys()) - set(received.keys()):
91
+ raise TypeError(msg)
92
+
93
+ for k in set(expected.keys()):
94
+ if isinstance(expected[k], type):
95
+ if not issubclass(received[k], expected[k]):
96
+ raise TypeError(msg)
97
+ elif isinstance(expected[k], TypeVar):
98
+ bound = expected[k].__bound__
99
+ if isinstance(bound, type) and not issubclass(
100
+ received[k], bound
101
+ ):
102
+ raise TypeError(msg)
80
103
  self._set_derived_takes_dict = (
81
104
  is_typeddict(_get_first_arg_datatype(set_derived)) if set_derived else False
82
105
  )
@@ -185,7 +208,7 @@ class DerivedSignalFactory(Generic[TransformT]):
185
208
 
186
209
 
187
210
  def _get_return_datatype(func: Callable[..., SignalDatatypeT]) -> type[SignalDatatypeT]:
188
- args = get_type_hints(func)
211
+ args = cached_get_type_hints(func)
189
212
  if "return" not in args:
190
213
  msg = f"{func} does not have a type hint for it's return value"
191
214
  raise TypeError(msg)
@@ -195,28 +218,28 @@ def _get_return_datatype(func: Callable[..., SignalDatatypeT]) -> type[SignalDat
195
218
  def _get_first_arg_datatype(
196
219
  func: Callable[[SignalDatatypeT], Any],
197
220
  ) -> type[SignalDatatypeT]:
198
- args = get_type_hints(func)
199
- args.pop("return", None)
221
+ args = _get_params_types_dict(func)
200
222
  if not args:
201
223
  msg = f"{func} does not have a type hinted argument"
202
224
  raise TypeError(msg)
203
225
  return list(args.values())[0]
204
226
 
205
227
 
228
+ def _get_params_types_dict(inspected_function: Callable) -> Mapping[str, Any]:
229
+ sig = signature(inspected_function, eval_str=True)
230
+ exclude_keys = {"self", "args", "kwargs", "cls"}
231
+ return {k: v.annotation for k, v in sig.parameters.items() if k not in exclude_keys}
232
+
233
+
206
234
  def _make_factory(
207
- raw_to_derived: Callable[..., SignalDatatypeT] | None = None,
235
+ raw_to_derived_func: Callable[..., SignalDatatypeT] | None = None,
208
236
  set_derived: Callable[[SignalDatatypeT], Awaitable[None]] | None = None,
209
237
  raw_devices_and_constants: dict[str, Device | Primitive] | None = None,
210
238
  ) -> DerivedSignalFactory:
211
- if raw_to_derived:
239
+ if raw_to_derived_func:
212
240
 
213
241
  class DerivedTransform(Transform):
214
- def raw_to_derived(self, **kwargs) -> dict[str, SignalDatatypeT]:
215
- return {"value": raw_to_derived(**kwargs)}
216
-
217
- # Update the signature for raw_to_derived to match what we are passed as this
218
- # will be checked in DerivedSignalFactory
219
- DerivedTransform.raw_to_derived.__annotations__ = get_type_hints(raw_to_derived)
242
+ raw_to_derived = _dict_wrapper(raw_to_derived_func)
220
243
 
221
244
  return DerivedSignalFactory(
222
245
  DerivedTransform,
@@ -245,7 +268,7 @@ def derived_signal_r(
245
268
  The names of these arguments must match the arguments of raw_to_derived.
246
269
  """
247
270
  factory = _make_factory(
248
- raw_to_derived=raw_to_derived,
271
+ raw_to_derived_func=raw_to_derived,
249
272
  raw_devices_and_constants=raw_devices_and_constants,
250
273
  )
251
274
  return factory.derived_signal_r(
@@ -278,16 +301,16 @@ def derived_signal_rw(
278
301
  The names of these arguments must match the arguments of raw_to_derived.
279
302
  """
280
303
  raw_to_derived_datatype = _get_return_datatype(raw_to_derived)
281
- set_derived_datatype = _get_first_arg_datatype(set_derived)
282
- if raw_to_derived_datatype != set_derived_datatype:
304
+ set_derived_arg_datatype = _get_first_arg_datatype(set_derived)
305
+ if raw_to_derived_datatype != set_derived_arg_datatype:
283
306
  msg = (
284
307
  f"{raw_to_derived} has datatype {raw_to_derived_datatype} "
285
- f"!= {set_derived_datatype} datatype {set_derived_datatype}"
308
+ f"!= {set_derived_arg_datatype} datatype {set_derived_arg_datatype}"
286
309
  )
287
310
  raise TypeError(msg)
288
311
 
289
312
  factory = _make_factory(
290
- raw_to_derived=raw_to_derived,
313
+ raw_to_derived_func=raw_to_derived,
291
314
  set_derived=set_derived,
292
315
  raw_devices_and_constants=raw_devices_and_constants,
293
316
  )
@@ -343,3 +366,13 @@ def _partition_by_keys(data: dict, keys: set) -> tuple[dict, dict]:
343
366
  else:
344
367
  group_excluded[k] = v
345
368
  return group_excluded, group_included
369
+
370
+
371
+ def _dict_wrapper(
372
+ fn: Callable[..., SignalDatatypeT],
373
+ ) -> Callable[..., dict[str, SignalDatatypeT]]:
374
+ @functools.wraps(fn)
375
+ def wrapped(self, **kwargs):
376
+ return {"value": fn(**kwargs)}
377
+
378
+ return wrapped
@@ -291,11 +291,7 @@ class DerivedSignalBackend(SignalBackend[SignalDatatypeT]):
291
291
  )
292
292
  raise RuntimeError(msg)
293
293
 
294
- async def put(self, value: SignalDatatypeT | None, wait: bool) -> None:
295
- if wait is False:
296
- msg = "Cannot put with wait=False"
297
- raise RuntimeError(msg)
298
-
294
+ async def put(self, value: SignalDatatypeT | None) -> None:
299
295
  value = error_if_none(
300
296
  value,
301
297
  "Must be given a value to put",
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import types
3
4
  from abc import abstractmethod
4
5
  from collections.abc import Callable, Iterator, Sequence
5
6
  from typing import (
@@ -9,16 +10,16 @@ from typing import (
9
10
  NoReturn,
10
11
  Protocol,
11
12
  TypeVar,
13
+ Union,
12
14
  cast,
13
15
  get_args,
14
- get_type_hints,
15
16
  runtime_checkable,
16
17
  )
17
18
 
18
19
  from ._device import Device, DeviceConnector, DeviceVector
19
20
  from ._signal import Ignore, Signal, SignalX
20
21
  from ._signal_backend import SignalBackend, SignalDatatype
21
- from ._utils import get_origin_class
22
+ from ._utils import cached_get_origin, cached_get_type_hints, get_origin_class
22
23
 
23
24
  SignalBackendT = TypeVar("SignalBackendT", bound=SignalBackend)
24
25
  DeviceConnectorT = TypeVar("DeviceConnectorT", bound=DeviceConnector)
@@ -76,6 +77,7 @@ class DeviceFiller(Generic[SignalBackendT, DeviceConnectorT]):
76
77
  self._extras: dict[UniqueName, Sequence[Any]] = {}
77
78
  self._signal_datatype: dict[LogicalName, type | None] = {}
78
79
  self._vector_device_type: dict[LogicalName, type[Device] | None] = {}
80
+ self._optional_devices: set[str] = set()
79
81
  self.ignored_signals: set[str] = set()
80
82
  # Backends and Connectors stored ready for the connection phase
81
83
  self._unfilled_backends: dict[
@@ -113,14 +115,28 @@ class DeviceFiller(Generic[SignalBackendT, DeviceConnectorT]):
113
115
  # https://github.com/python/cpython/issues/124840
114
116
  cls = type(self._device)
115
117
  # Get hints without Annotated for determining types
116
- hints = get_type_hints(cls)
118
+ hints = cached_get_type_hints(cls)
117
119
  # Get hints with Annotated for wrapping signals and backends
118
- extra_hints = get_type_hints(cls, include_extras=True)
120
+ extra_hints = cached_get_type_hints(cls, include_extras=True)
119
121
  for attr_name, annotation in hints.items():
120
122
  if annotation is Ignore:
121
123
  self.ignored_signals.add(attr_name)
122
124
  name = UniqueName(attr_name)
123
125
  origin = get_origin_class(annotation)
126
+ args = get_args(annotation)
127
+
128
+ if (
129
+ cached_get_origin(annotation) is Union
130
+ and types.NoneType in args
131
+ and len(args) == 2
132
+ ):
133
+ # Annotation is an Union with two arguments, one of which is None
134
+ # Make this signal an optional parameter and set origin to T
135
+ # so the device is added to unfilled_connectors
136
+ self._optional_devices.add(name)
137
+ (annotation,) = [x for x in args if x is not types.NoneType]
138
+ origin = get_origin_class(annotation)
139
+
124
140
  if (
125
141
  name == "parent"
126
142
  or name.startswith("_")
@@ -241,10 +257,17 @@ class DeviceFiller(Generic[SignalBackendT, DeviceConnectorT]):
241
257
  :param source: The source of the data that should have done the filling, for
242
258
  reporting as an error message
243
259
  """
244
- unfilled = sorted(set(self._unfilled_connectors).union(self._unfilled_backends))
245
- if unfilled:
260
+ unfilled = set(self._unfilled_connectors).union(self._unfilled_backends)
261
+ unfilled_optional = sorted(unfilled.intersection(self._optional_devices))
262
+
263
+ for name in unfilled_optional:
264
+ setattr(self._device, name, None)
265
+
266
+ required = sorted(unfilled.difference(unfilled_optional))
267
+
268
+ if required:
246
269
  raise RuntimeError(
247
- f"{self._device.name}: cannot provision {unfilled} from {source}"
270
+ f"{self._device.name}: cannot provision {required} from {source}"
248
271
  )
249
272
 
250
273
  def _ensure_device_vector(self, name: LogicalName) -> DeviceVector:
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
- from collections.abc import Callable
4
+ from collections.abc import Awaitable, Callable
5
5
  from functools import cached_property
6
6
  from typing import TYPE_CHECKING
7
7
  from unittest.mock import AsyncMock
@@ -17,6 +17,11 @@ from ._utils import Callback
17
17
  if TYPE_CHECKING:
18
18
  from ._device import LazyMock
19
19
 
20
+ MockPutCallback = (
21
+ Callable[[SignalDatatypeT], SignalDatatypeT | None]
22
+ | Callable[[SignalDatatypeT], Awaitable[SignalDatatypeT | None]]
23
+ )
24
+
20
25
 
21
26
  class MockSignalBackend(SignalBackend[SignalDatatypeT]):
22
27
  """Signal backend for testing, created by ``Device.connect(mock=True)``."""
@@ -42,13 +47,27 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]):
42
47
 
43
48
  # use existing Mock if provided
44
49
  self.mock = mock
50
+ self._mock_put_callback: MockPutCallback | None = None
45
51
  super().__init__(datatype=self.initial_backend.datatype)
46
52
 
53
+ def set_mock_put_callback(self, callback: MockPutCallback | None):
54
+ if "put_mock" in self.__dict__:
55
+ # put_mock cached property exists, so set the side effect on it
56
+ self.put_mock.side_effect = callback
57
+ else:
58
+ # put_mock doesn't exist, don't create it as that would be slow
59
+ # so just keep it internally
60
+ self._mock_put_callback = callback
61
+
47
62
  @cached_property
48
63
  def put_mock(self) -> AsyncMock:
49
64
  """Return the mock that will track calls to `put()`."""
50
65
  put_mock = AsyncMock(
51
- name="put", spec=Callable, side_effect=lambda *_, **__: None
66
+ name="put",
67
+ spec=Callable,
68
+ side_effect=self._mock_put_callback
69
+ if self._mock_put_callback
70
+ else lambda v: None,
52
71
  )
53
72
  self.mock().attach_mock(put_mock, "put")
54
73
  return put_mock
@@ -73,13 +92,12 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]):
73
92
  put_proceeds.set()
74
93
  return put_proceeds
75
94
 
76
- async def put(self, value: SignalDatatypeT | None, wait: bool):
77
- new_value = await self.put_mock(value, wait=wait)
95
+ async def put(self, value: SignalDatatypeT | None):
96
+ new_value = await self.put_mock(value)
78
97
  if new_value is None:
79
98
  new_value = value
80
- await self.soft_backend.put(new_value, wait=wait)
81
- if wait:
82
- await self.put_proceeds.wait()
99
+ await self.soft_backend.put(new_value)
100
+ await self.put_proceeds.wait()
83
101
 
84
102
  async def get_reading(self) -> Reading:
85
103
  return await self.soft_backend.get_reading()
@@ -1,9 +1,9 @@
1
- from collections.abc import Awaitable, Callable, Iterable, Iterator
1
+ from collections.abc import Iterable, Iterator
2
2
  from contextlib import contextmanager
3
3
  from unittest.mock import AsyncMock, Mock
4
4
 
5
5
  from ._device import Device, DeviceMock
6
- from ._mock_signal_backend import MockSignalBackend
6
+ from ._mock_signal_backend import MockPutCallback, MockSignalBackend
7
7
  from ._signal import Signal, SignalConnector, SignalR
8
8
  from ._signal_backend import SignalDatatypeT
9
9
 
@@ -109,16 +109,12 @@ def set_mock_values(
109
109
 
110
110
 
111
111
  @contextmanager
112
- def _unset_side_effect_cm(put_mock: AsyncMock):
112
+ def _unset_side_effect_cm(backend: MockSignalBackend):
113
113
  yield
114
- put_mock.side_effect = None
114
+ backend.set_mock_put_callback(None)
115
115
 
116
116
 
117
- def callback_on_mock_put(
118
- signal: Signal[SignalDatatypeT],
119
- callback: Callable[[SignalDatatypeT, bool], SignalDatatypeT | None]
120
- | Callable[[SignalDatatypeT, bool], Awaitable[SignalDatatypeT | None]],
121
- ):
117
+ def callback_on_mock_put(signal: Signal[SignalDatatypeT], callback: MockPutCallback):
122
118
  """For setting a callback when a backend is put to.
123
119
 
124
120
  Can either be used in a context, with the callback being unset on exit, or
@@ -132,8 +128,8 @@ def callback_on_mock_put(
132
128
  context.
133
129
  """
134
130
  backend = _get_mock_signal_backend(signal)
135
- backend.put_mock.side_effect = callback
136
- return _unset_side_effect_cm(backend.put_mock)
131
+ backend.set_mock_put_callback(callback)
132
+ return _unset_side_effect_cm(backend)
137
133
 
138
134
 
139
135
  def set_mock_put_proceeds(signal: Signal, proceeds: bool):
@@ -279,13 +279,11 @@ class SignalW(Signal[SignalDatatypeT], Movable):
279
279
  async def set(
280
280
  self,
281
281
  value: SignalDatatypeT,
282
- wait=True,
283
282
  timeout: CalculatableTimeout = CALCULATE_TIMEOUT,
284
283
  ) -> None:
285
284
  """Set the value and return a status saying when it's done.
286
285
 
287
286
  :param value: The value to set.
288
- :param wait: If True, wait for the set to complete.
289
287
  :param timeout: The timeout for the set.
290
288
  """
291
289
  if timeout == CALCULATE_TIMEOUT:
@@ -299,9 +297,7 @@ class SignalW(Signal[SignalDatatypeT], Movable):
299
297
  wait_jitter=0,
300
298
  ):
301
299
  with attempt:
302
- await _wait_for(
303
- self._connector.backend.put(value, wait=wait), timeout, source
304
- )
300
+ await _wait_for(self._connector.backend.put(value), timeout, source)
305
301
  self.log.debug(f"Successfully put value {value} to backend at source {source}")
306
302
 
307
303
 
@@ -321,19 +317,16 @@ class SignalX(Signal):
321
317
  """Signal that puts the default value."""
322
318
 
323
319
  @AsyncStatus.wrap
324
- async def trigger(
325
- self, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT
326
- ) -> None:
320
+ async def trigger(self, timeout: CalculatableTimeout = CALCULATE_TIMEOUT) -> None:
327
321
  """Trigger the action and return a status saying when it's done.
328
322
 
329
- :param wait: If True, wait for the trigger to complete.
330
323
  :param timeout: The timeout for the trigger.
331
324
  """
332
325
  if timeout == CALCULATE_TIMEOUT:
333
326
  timeout = self._timeout
334
327
  source = self._connector.backend.source(self.name, read=False)
335
328
  self.log.debug(f"Putting default value to backend at source {source}")
336
- await _wait_for(self._connector.backend.put(None, wait=wait), timeout, source)
329
+ await _wait_for(self._connector.backend.put(None), timeout, source)
337
330
  self.log.debug(f"Successfully put default value to backend at source {source}")
338
331
 
339
332
 
@@ -508,6 +501,12 @@ async def observe_signals_value(
508
501
  f"{[signal.source for signal in signals]}. "
509
502
  f"Last observed signal and value were {last_item}"
510
503
  ) from exc
504
+ except asyncio.CancelledError as exc:
505
+ raise asyncio.CancelledError(
506
+ f"Cancelled Error while waiting {iteration_timeout}s to update "
507
+ f"{[signal.source for signal in signals]}. "
508
+ f"Last observed signal and value were {last_item}"
509
+ ) from exc
511
510
  if done_status and item is done_status:
512
511
  if exc := done_status.exception():
513
512
  raise exc
@@ -679,7 +678,8 @@ async def set_and_wait_for_value(
679
678
  await set_and_wait_for_value(device.parameter, 1)
680
679
  ```
681
680
  For busy record, or other Signals with pattern:
682
- - Set Signal with `wait=True` and stash the Status
681
+ - Set `wait=non_zero` when creating the signal
682
+ - Set Signal and stash the Status
683
683
  - Read the same Signal to check the operation has started
684
684
  - Return the Status so calling code can wait for operation to complete
685
685
  ```python
@@ -1,34 +1,22 @@
1
1
  from abc import abstractmethod
2
2
  from collections.abc import Sequence
3
- from typing import Generic, TypedDict, TypeVar, get_origin
3
+ from typing import Generic, TypedDict, TypeVar
4
4
 
5
5
  import numpy as np
6
6
  from bluesky.protocols import Reading
7
7
  from event_model import DataKey, Dtype, Limits
8
8
 
9
- from ophyd_async.core._utils import (
9
+ from ._datatypes import Array1D, Table
10
+ from ._utils import (
10
11
  Callback,
11
12
  EnumTypes,
12
13
  StrictEnum,
13
14
  SubsetEnum,
14
15
  SupersetEnum,
16
+ cached_get_origin,
15
17
  get_enum_cls,
16
18
  )
17
19
 
18
- from ._table import Table
19
-
20
- DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic)
21
- """A numpy dtype like [](#numpy.float64)."""
22
-
23
-
24
- # To be a 1D array shape should really be tuple[int], but np.array()
25
- # currently produces tuple[int, ...] even when it has 1D input args
26
- # https://github.com/numpy/numpy/issues/28077#issuecomment-2566485178
27
- Array1D = np.ndarray[tuple[int, ...], np.dtype[DTypeScalar_co]]
28
- """A type alias for a 1D numpy array with a specific scalar data type.
29
-
30
- E.g. `Array1D[np.float64]` is a 1D numpy array of 64-bit floats."""
31
-
32
20
  Primitive = bool | int | float | str
33
21
  SignalDatatype = (
34
22
  Primitive
@@ -91,7 +79,7 @@ class SignalBackend(Generic[SignalDatatypeT]):
91
79
  """Connect to underlying hardware."""
92
80
 
93
81
  @abstractmethod
94
- async def put(self, value: SignalDatatypeT | None, wait: bool):
82
+ async def put(self, value: SignalDatatypeT | None):
95
83
  """Put a value to the PV, if wait then wait for completion."""
96
84
 
97
85
  @abstractmethod
@@ -142,7 +130,7 @@ class SignalMetadata(TypedDict, total=False):
142
130
  def _datakey_dtype(datatype: type[SignalDatatype]) -> Dtype:
143
131
  if (
144
132
  datatype is np.ndarray
145
- or get_origin(datatype) in (Sequence, np.ndarray)
133
+ or cached_get_origin(datatype) in (Sequence, np.ndarray)
146
134
  or issubclass(datatype, Table)
147
135
  ):
148
136
  return "array"
@@ -161,7 +149,7 @@ def _datakey_dtype_numpy(
161
149
  # The value already has a dtype, use that
162
150
  return value.dtype
163
151
  elif (
164
- get_origin(datatype) is Sequence
152
+ cached_get_origin(datatype) is Sequence
165
153
  or datatype is str
166
154
  or issubclass(datatype, EnumTypes)
167
155
  ):