ophyd-async 0.14.2__py3-none-any.whl → 0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ophyd_async/_version.py +2 -2
  2. ophyd_async/core/__init__.py +17 -5
  3. ophyd_async/core/{_table.py → _datatypes.py} +18 -9
  4. ophyd_async/core/_derived_signal.py +2 -2
  5. ophyd_async/core/_derived_signal_backend.py +1 -5
  6. ophyd_async/core/_device_filler.py +4 -6
  7. ophyd_async/core/_mock_signal_backend.py +25 -7
  8. ophyd_async/core/_mock_signal_utils.py +7 -11
  9. ophyd_async/core/_signal.py +11 -11
  10. ophyd_async/core/_signal_backend.py +7 -19
  11. ophyd_async/core/_soft_signal_backend.py +6 -6
  12. ophyd_async/core/_status.py +81 -4
  13. ophyd_async/core/_typing.py +0 -0
  14. ophyd_async/core/_utils.py +57 -7
  15. ophyd_async/epics/adcore/_core_io.py +12 -5
  16. ophyd_async/epics/adcore/_core_logic.py +1 -1
  17. ophyd_async/epics/core/__init__.py +2 -1
  18. ophyd_async/epics/core/_aioca.py +13 -3
  19. ophyd_async/epics/core/_epics_connector.py +4 -1
  20. ophyd_async/epics/core/_p4p.py +13 -3
  21. ophyd_async/epics/core/_signal.py +18 -6
  22. ophyd_async/epics/core/_util.py +23 -3
  23. ophyd_async/epics/demo/_motor.py +2 -2
  24. ophyd_async/epics/motor.py +15 -17
  25. ophyd_async/epics/odin/_odin_io.py +1 -1
  26. ophyd_async/epics/pmac/_pmac_io.py +23 -4
  27. ophyd_async/epics/pmac/_pmac_trajectory.py +47 -10
  28. ophyd_async/fastcs/eiger/_eiger_io.py +20 -1
  29. ophyd_async/fastcs/jungfrau/_signals.py +4 -1
  30. ophyd_async/fastcs/panda/_block.py +28 -6
  31. ophyd_async/fastcs/panda/_writer.py +1 -3
  32. ophyd_async/tango/core/_tango_transport.py +7 -17
  33. ophyd_async/tango/demo/_counter.py +2 -2
  34. {ophyd_async-0.14.2.dist-info → ophyd_async-0.15.dist-info}/METADATA +1 -1
  35. {ophyd_async-0.14.2.dist-info → ophyd_async-0.15.dist-info}/RECORD +38 -37
  36. {ophyd_async-0.14.2.dist-info → ophyd_async-0.15.dist-info}/WHEEL +1 -1
  37. {ophyd_async-0.14.2.dist-info → ophyd_async-0.15.dist-info}/licenses/LICENSE +0 -0
  38. {ophyd_async-0.14.2.dist-info → ophyd_async-0.15.dist-info}/top_level.txt +0 -0
ophyd_async/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.14.2'
32
- __version_tuple__ = version_tuple = (0, 14, 2)
31
+ __version__ = version = '0.15'
32
+ __version_tuple__ = version_tuple = (0, 15)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,5 +1,6 @@
1
1
  """The building blocks for making devices."""
2
2
 
3
+ from ._datatypes import Array1D, DTypeScalar_co, Table, TableSubclass
3
4
  from ._derived_signal import (
4
5
  DerivedSignalFactory,
5
6
  derived_signal_r,
@@ -23,7 +24,7 @@ from ._device import (
23
24
  default_mock_class,
24
25
  init_devices,
25
26
  )
26
- from ._device_filler import DeviceFiller
27
+ from ._device_filler import DeviceAnnotation, DeviceFiller
27
28
  from ._enums import (
28
29
  EnabledDisabled,
29
30
  EnableDisable,
@@ -44,7 +45,15 @@ from ._mock_signal_utils import (
44
45
  set_mock_value,
45
46
  set_mock_values,
46
47
  )
47
- from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable, Watcher
48
+ from ._protocol import (
49
+ AsyncConfigurable,
50
+ AsyncLocatable,
51
+ AsyncMovable,
52
+ AsyncPausable,
53
+ AsyncReadable,
54
+ AsyncStageable,
55
+ Watcher,
56
+ )
48
57
  from ._providers import (
49
58
  AutoIncrementFilenameProvider,
50
59
  AutoIncrementingPathProvider,
@@ -86,8 +95,6 @@ from ._signal import (
86
95
  walk_signal_sources,
87
96
  )
88
97
  from ._signal_backend import (
89
- Array1D,
90
- DTypeScalar_co,
91
98
  Primitive,
92
99
  SignalBackend,
93
100
  SignalDatatype,
@@ -97,7 +104,6 @@ from ._signal_backend import (
97
104
  )
98
105
  from ._soft_signal_backend import SoftSignalBackend
99
106
  from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
100
- from ._table import Table, TableSubclass
101
107
  from ._utils import (
102
108
  CALCULATE_TIMEOUT,
103
109
  DEFAULT_TIMEOUT,
@@ -117,6 +123,7 @@ from ._utils import (
117
123
  get_enum_cls,
118
124
  get_unique,
119
125
  in_micros,
126
+ non_zero,
120
127
  wait_for_connection,
121
128
  )
122
129
  from ._yaml_settings import YamlSettingsProvider
@@ -146,11 +153,15 @@ __all__ = [
146
153
  "Device",
147
154
  "DeviceConnector",
148
155
  "DeviceFiller",
156
+ "DeviceAnnotation",
149
157
  "DeviceVector",
150
158
  "init_devices",
151
159
  # Protocols
152
160
  "AsyncReadable",
153
161
  "AsyncConfigurable",
162
+ "AsyncLocatable",
163
+ "AsyncMovable",
164
+ "AsyncPausable",
154
165
  "AsyncStageable",
155
166
  "Watcher",
156
167
  # Status
@@ -255,6 +266,7 @@ __all__ = [
255
266
  "make_datakey",
256
267
  "wait_for_connection",
257
268
  "Ignore",
269
+ "non_zero",
258
270
  # Derived signal
259
271
  "derived_signal_r",
260
272
  "derived_signal_rw",
@@ -1,13 +1,25 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections.abc import Callable, Sequence
4
- from typing import Annotated, Any, TypeVar, get_origin, get_type_hints
4
+ from typing import Annotated, Any, TypeVar
5
5
 
6
6
  import numpy as np
7
7
  from pydantic import ConfigDict, Field, model_validator
8
8
  from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation
9
9
 
10
- from ._utils import ConfinedModel, get_dtype
10
+ from ._utils import ConfinedModel, cached_get_origin, cached_get_type_hints, get_dtype
11
+
12
+ DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic)
13
+ """A numpy dtype like [](#numpy.float64)."""
14
+
15
+
16
+ # To be a 1D array shape should really be tuple[int], but np.array()
17
+ # currently produces tuple[int, ...] even when it has 1D input args
18
+ # https://github.com/numpy/numpy/issues/28077#issuecomment-2566485178
19
+ Array1D = np.ndarray[tuple[int, ...], np.dtype[DTypeScalar_co]]
20
+ """A type alias for a 1D numpy array with a specific scalar data type.
21
+
22
+ E.g. `Array1D[np.float64]` is a 1D numpy array of 64-bit floats."""
11
23
 
12
24
  TableSubclass = TypeVar("TableSubclass", bound="Table")
13
25
 
@@ -75,11 +87,8 @@ class Table(ConfinedModel):
75
87
  # ...but forbid extra in subclasses so it gets validated
76
88
  cls.model_config = ConfigDict(validate_assignment=True, extra="forbid")
77
89
  # Change fields to have the correct annotations
78
- # TODO: refactor so we don't need this to break circular imports
79
- from ._signal_backend import Array1D
80
-
81
- for k, anno in get_type_hints(cls, localns={"Array1D": Array1D}).items():
82
- if get_origin(anno) is np.ndarray:
90
+ for k, anno in cached_get_type_hints(cls).items():
91
+ if cached_get_origin(anno) is np.ndarray:
83
92
  dtype = get_dtype(anno)
84
93
  new_anno = Annotated[
85
94
  anno,
@@ -88,7 +97,7 @@ class Table(ConfinedModel):
88
97
  ),
89
98
  Field(default_factory=_make_default_factory(dtype)),
90
99
  ]
91
- elif get_origin(anno) is Sequence:
100
+ elif cached_get_origin(anno) is Sequence:
92
101
  new_anno = Annotated[anno, Field(default_factory=list)]
93
102
  else:
94
103
  raise TypeError(f"Cannot use annotation {anno} in a Table")
@@ -156,7 +165,7 @@ class Table(ConfinedModel):
156
165
  raise AssertionError(f"Cannot construct Table from {data}")
157
166
  for field_name, field_value in cls.model_fields.items():
158
167
  if (
159
- get_origin(field_value.annotation) is np.ndarray
168
+ cached_get_origin(field_value.annotation) is np.ndarray
160
169
  and field_value.annotation
161
170
  and field_name in data_dict
162
171
  ):
@@ -7,7 +7,6 @@ from typing import (
7
7
  TypeVar,
8
8
  get_args,
9
9
  get_origin,
10
- get_type_hints,
11
10
  is_typeddict,
12
11
  )
13
12
 
@@ -22,6 +21,7 @@ from ._derived_signal_backend import (
22
21
  from ._device import Device
23
22
  from ._signal import Signal, SignalR, SignalRW, SignalT, SignalW
24
23
  from ._signal_backend import Primitive, SignalDatatypeT
24
+ from ._utils import cached_get_type_hints
25
25
 
26
26
 
27
27
  class DerivedSignalFactory(Generic[TransformT]):
@@ -208,7 +208,7 @@ class DerivedSignalFactory(Generic[TransformT]):
208
208
 
209
209
 
210
210
  def _get_return_datatype(func: Callable[..., SignalDatatypeT]) -> type[SignalDatatypeT]:
211
- args = get_type_hints(func)
211
+ args = cached_get_type_hints(func)
212
212
  if "return" not in args:
213
213
  msg = f"{func} does not have a type hint for it's return value"
214
214
  raise TypeError(msg)
@@ -291,11 +291,7 @@ class DerivedSignalBackend(SignalBackend[SignalDatatypeT]):
291
291
  )
292
292
  raise RuntimeError(msg)
293
293
 
294
- async def put(self, value: SignalDatatypeT | None, wait: bool) -> None:
295
- if wait is False:
296
- msg = "Cannot put with wait=False"
297
- raise RuntimeError(msg)
298
-
294
+ async def put(self, value: SignalDatatypeT | None) -> None:
299
295
  value = error_if_none(
300
296
  value,
301
297
  "Must be given a value to put",
@@ -13,15 +13,13 @@ from typing import (
13
13
  Union,
14
14
  cast,
15
15
  get_args,
16
- get_origin,
17
- get_type_hints,
18
16
  runtime_checkable,
19
17
  )
20
18
 
21
19
  from ._device import Device, DeviceConnector, DeviceVector
22
20
  from ._signal import Ignore, Signal, SignalX
23
21
  from ._signal_backend import SignalBackend, SignalDatatype
24
- from ._utils import get_origin_class
22
+ from ._utils import cached_get_origin, cached_get_type_hints, get_origin_class
25
23
 
26
24
  SignalBackendT = TypeVar("SignalBackendT", bound=SignalBackend)
27
25
  DeviceConnectorT = TypeVar("DeviceConnectorT", bound=DeviceConnector)
@@ -117,9 +115,9 @@ class DeviceFiller(Generic[SignalBackendT, DeviceConnectorT]):
117
115
  # https://github.com/python/cpython/issues/124840
118
116
  cls = type(self._device)
119
117
  # Get hints without Annotated for determining types
120
- hints = get_type_hints(cls)
118
+ hints = cached_get_type_hints(cls)
121
119
  # Get hints with Annotated for wrapping signals and backends
122
- extra_hints = get_type_hints(cls, include_extras=True)
120
+ extra_hints = cached_get_type_hints(cls, include_extras=True)
123
121
  for attr_name, annotation in hints.items():
124
122
  if annotation is Ignore:
125
123
  self.ignored_signals.add(attr_name)
@@ -128,7 +126,7 @@ class DeviceFiller(Generic[SignalBackendT, DeviceConnectorT]):
128
126
  args = get_args(annotation)
129
127
 
130
128
  if (
131
- get_origin(annotation) is Union
129
+ cached_get_origin(annotation) is Union
132
130
  and types.NoneType in args
133
131
  and len(args) == 2
134
132
  ):
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
- from collections.abc import Callable
4
+ from collections.abc import Awaitable, Callable
5
5
  from functools import cached_property
6
6
  from typing import TYPE_CHECKING
7
7
  from unittest.mock import AsyncMock
@@ -17,6 +17,11 @@ from ._utils import Callback
17
17
  if TYPE_CHECKING:
18
18
  from ._device import LazyMock
19
19
 
20
+ MockPutCallback = (
21
+ Callable[[SignalDatatypeT], SignalDatatypeT | None]
22
+ | Callable[[SignalDatatypeT], Awaitable[SignalDatatypeT | None]]
23
+ )
24
+
20
25
 
21
26
  class MockSignalBackend(SignalBackend[SignalDatatypeT]):
22
27
  """Signal backend for testing, created by ``Device.connect(mock=True)``."""
@@ -42,13 +47,27 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]):
42
47
 
43
48
  # use existing Mock if provided
44
49
  self.mock = mock
50
+ self._mock_put_callback: MockPutCallback | None = None
45
51
  super().__init__(datatype=self.initial_backend.datatype)
46
52
 
53
+ def set_mock_put_callback(self, callback: MockPutCallback | None):
54
+ if "put_mock" in self.__dict__:
55
+ # put_mock cached property exists, so set the side effect on it
56
+ self.put_mock.side_effect = callback
57
+ else:
58
+ # put_mock doesn't exist, don't create it as that would be slow
59
+ # so just keep it internally
60
+ self._mock_put_callback = callback
61
+
47
62
  @cached_property
48
63
  def put_mock(self) -> AsyncMock:
49
64
  """Return the mock that will track calls to `put()`."""
50
65
  put_mock = AsyncMock(
51
- name="put", spec=Callable, side_effect=lambda *_, **__: None
66
+ name="put",
67
+ spec=Callable,
68
+ side_effect=self._mock_put_callback
69
+ if self._mock_put_callback
70
+ else lambda v: None,
52
71
  )
53
72
  self.mock().attach_mock(put_mock, "put")
54
73
  return put_mock
@@ -73,13 +92,12 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]):
73
92
  put_proceeds.set()
74
93
  return put_proceeds
75
94
 
76
- async def put(self, value: SignalDatatypeT | None, wait: bool):
77
- new_value = await self.put_mock(value, wait=wait)
95
+ async def put(self, value: SignalDatatypeT | None):
96
+ new_value = await self.put_mock(value)
78
97
  if new_value is None:
79
98
  new_value = value
80
- await self.soft_backend.put(new_value, wait=wait)
81
- if wait:
82
- await self.put_proceeds.wait()
99
+ await self.soft_backend.put(new_value)
100
+ await self.put_proceeds.wait()
83
101
 
84
102
  async def get_reading(self) -> Reading:
85
103
  return await self.soft_backend.get_reading()
@@ -1,9 +1,9 @@
1
- from collections.abc import Awaitable, Callable, Iterable, Iterator
1
+ from collections.abc import Iterable, Iterator
2
2
  from contextlib import contextmanager
3
3
  from unittest.mock import AsyncMock, Mock
4
4
 
5
5
  from ._device import Device, DeviceMock
6
- from ._mock_signal_backend import MockSignalBackend
6
+ from ._mock_signal_backend import MockPutCallback, MockSignalBackend
7
7
  from ._signal import Signal, SignalConnector, SignalR
8
8
  from ._signal_backend import SignalDatatypeT
9
9
 
@@ -109,16 +109,12 @@ def set_mock_values(
109
109
 
110
110
 
111
111
  @contextmanager
112
- def _unset_side_effect_cm(put_mock: AsyncMock):
112
+ def _unset_side_effect_cm(backend: MockSignalBackend):
113
113
  yield
114
- put_mock.side_effect = None
114
+ backend.set_mock_put_callback(None)
115
115
 
116
116
 
117
- def callback_on_mock_put(
118
- signal: Signal[SignalDatatypeT],
119
- callback: Callable[[SignalDatatypeT, bool], SignalDatatypeT | None]
120
- | Callable[[SignalDatatypeT, bool], Awaitable[SignalDatatypeT | None]],
121
- ):
117
+ def callback_on_mock_put(signal: Signal[SignalDatatypeT], callback: MockPutCallback):
122
118
  """For setting a callback when a backend is put to.
123
119
 
124
120
  Can either be used in a context, with the callback being unset on exit, or
@@ -132,8 +128,8 @@ def callback_on_mock_put(
132
128
  context.
133
129
  """
134
130
  backend = _get_mock_signal_backend(signal)
135
- backend.put_mock.side_effect = callback
136
- return _unset_side_effect_cm(backend.put_mock)
131
+ backend.set_mock_put_callback(callback)
132
+ return _unset_side_effect_cm(backend)
137
133
 
138
134
 
139
135
  def set_mock_put_proceeds(signal: Signal, proceeds: bool):
@@ -279,13 +279,11 @@ class SignalW(Signal[SignalDatatypeT], Movable):
279
279
  async def set(
280
280
  self,
281
281
  value: SignalDatatypeT,
282
- wait=True,
283
282
  timeout: CalculatableTimeout = CALCULATE_TIMEOUT,
284
283
  ) -> None:
285
284
  """Set the value and return a status saying when it's done.
286
285
 
287
286
  :param value: The value to set.
288
- :param wait: If True, wait for the set to complete.
289
287
  :param timeout: The timeout for the set.
290
288
  """
291
289
  if timeout == CALCULATE_TIMEOUT:
@@ -299,9 +297,7 @@ class SignalW(Signal[SignalDatatypeT], Movable):
299
297
  wait_jitter=0,
300
298
  ):
301
299
  with attempt:
302
- await _wait_for(
303
- self._connector.backend.put(value, wait=wait), timeout, source
304
- )
300
+ await _wait_for(self._connector.backend.put(value), timeout, source)
305
301
  self.log.debug(f"Successfully put value {value} to backend at source {source}")
306
302
 
307
303
 
@@ -321,19 +317,16 @@ class SignalX(Signal):
321
317
  """Signal that puts the default value."""
322
318
 
323
319
  @AsyncStatus.wrap
324
- async def trigger(
325
- self, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT
326
- ) -> None:
320
+ async def trigger(self, timeout: CalculatableTimeout = CALCULATE_TIMEOUT) -> None:
327
321
  """Trigger the action and return a status saying when it's done.
328
322
 
329
- :param wait: If True, wait for the trigger to complete.
330
323
  :param timeout: The timeout for the trigger.
331
324
  """
332
325
  if timeout == CALCULATE_TIMEOUT:
333
326
  timeout = self._timeout
334
327
  source = self._connector.backend.source(self.name, read=False)
335
328
  self.log.debug(f"Putting default value to backend at source {source}")
336
- await _wait_for(self._connector.backend.put(None, wait=wait), timeout, source)
329
+ await _wait_for(self._connector.backend.put(None), timeout, source)
337
330
  self.log.debug(f"Successfully put default value to backend at source {source}")
338
331
 
339
332
 
@@ -508,6 +501,12 @@ async def observe_signals_value(
508
501
  f"{[signal.source for signal in signals]}. "
509
502
  f"Last observed signal and value were {last_item}"
510
503
  ) from exc
504
+ except asyncio.CancelledError as exc:
505
+ raise asyncio.CancelledError(
506
+ f"Cancelled Error while waiting {iteration_timeout}s to update "
507
+ f"{[signal.source for signal in signals]}. "
508
+ f"Last observed signal and value were {last_item}"
509
+ ) from exc
511
510
  if done_status and item is done_status:
512
511
  if exc := done_status.exception():
513
512
  raise exc
@@ -679,7 +678,8 @@ async def set_and_wait_for_value(
679
678
  await set_and_wait_for_value(device.parameter, 1)
680
679
  ```
681
680
  For busy record, or other Signals with pattern:
682
- - Set Signal with `wait=True` and stash the Status
681
+ - Set `wait=non_zero` when creating the signal
682
+ - Set Signal and stash the Status
683
683
  - Read the same Signal to check the operation has started
684
684
  - Return the Status so calling code can wait for operation to complete
685
685
  ```python
@@ -1,34 +1,22 @@
1
1
  from abc import abstractmethod
2
2
  from collections.abc import Sequence
3
- from typing import Generic, TypedDict, TypeVar, get_origin
3
+ from typing import Generic, TypedDict, TypeVar
4
4
 
5
5
  import numpy as np
6
6
  from bluesky.protocols import Reading
7
7
  from event_model import DataKey, Dtype, Limits
8
8
 
9
- from ophyd_async.core._utils import (
9
+ from ._datatypes import Array1D, Table
10
+ from ._utils import (
10
11
  Callback,
11
12
  EnumTypes,
12
13
  StrictEnum,
13
14
  SubsetEnum,
14
15
  SupersetEnum,
16
+ cached_get_origin,
15
17
  get_enum_cls,
16
18
  )
17
19
 
18
- from ._table import Table
19
-
20
- DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic)
21
- """A numpy dtype like [](#numpy.float64)."""
22
-
23
-
24
- # To be a 1D array shape should really be tuple[int], but np.array()
25
- # currently produces tuple[int, ...] even when it has 1D input args
26
- # https://github.com/numpy/numpy/issues/28077#issuecomment-2566485178
27
- Array1D = np.ndarray[tuple[int, ...], np.dtype[DTypeScalar_co]]
28
- """A type alias for a 1D numpy array with a specific scalar data type.
29
-
30
- E.g. `Array1D[np.float64]` is a 1D numpy array of 64-bit floats."""
31
-
32
20
  Primitive = bool | int | float | str
33
21
  SignalDatatype = (
34
22
  Primitive
@@ -91,7 +79,7 @@ class SignalBackend(Generic[SignalDatatypeT]):
91
79
  """Connect to underlying hardware."""
92
80
 
93
81
  @abstractmethod
94
- async def put(self, value: SignalDatatypeT | None, wait: bool):
82
+ async def put(self, value: SignalDatatypeT | None):
95
83
  """Put a value to the PV, if wait then wait for completion."""
96
84
 
97
85
  @abstractmethod
@@ -142,7 +130,7 @@ class SignalMetadata(TypedDict, total=False):
142
130
  def _datakey_dtype(datatype: type[SignalDatatype]) -> Dtype:
143
131
  if (
144
132
  datatype is np.ndarray
145
- or get_origin(datatype) in (Sequence, np.ndarray)
133
+ or cached_get_origin(datatype) in (Sequence, np.ndarray)
146
134
  or issubclass(datatype, Table)
147
135
  ):
148
136
  return "array"
@@ -161,7 +149,7 @@ def _datakey_dtype_numpy(
161
149
  # The value already has a dtype, use that
162
150
  return value.dtype
163
151
  elif (
164
- get_origin(datatype) is Sequence
152
+ cached_get_origin(datatype) is Sequence
165
153
  or datatype is str
166
154
  or issubclass(datatype, EnumTypes)
167
155
  ):
@@ -6,12 +6,13 @@ from abc import abstractmethod
6
6
  from collections.abc import Sequence
7
7
  from dataclasses import dataclass
8
8
  from functools import lru_cache
9
- from typing import Any, Generic, get_args, get_origin
9
+ from typing import Any, Generic, get_args
10
10
 
11
11
  import numpy as np
12
12
  from bluesky.protocols import Reading
13
13
  from event_model import DataKey
14
14
 
15
+ from ._datatypes import Table
15
16
  from ._signal_backend import (
16
17
  Array1D,
17
18
  EnumT,
@@ -24,8 +25,7 @@ from ._signal_backend import (
24
25
  make_datakey,
25
26
  make_metadata,
26
27
  )
27
- from ._table import Table
28
- from ._utils import Callback, get_dtype, get_enum_cls
28
+ from ._utils import Callback, cached_get_origin, get_dtype, get_enum_cls
29
29
 
30
30
 
31
31
  class SoftConverter(Generic[SignalDatatypeT]):
@@ -97,11 +97,11 @@ def make_converter(datatype: type[SignalDatatype]) -> SoftConverter:
97
97
  enum_cls = get_enum_cls(datatype)
98
98
  if datatype in (Sequence[str], typing.Sequence[str]):
99
99
  return SequenceStrSoftConverter()
100
- elif get_origin(datatype) in (Sequence, typing.Sequence) and enum_cls:
100
+ elif cached_get_origin(datatype) in (Sequence, typing.Sequence) and enum_cls:
101
101
  return SequenceEnumSoftConverter(enum_cls)
102
102
  elif datatype is np.ndarray:
103
103
  return NDArraySoftConverter()
104
- elif get_origin(datatype) == np.ndarray:
104
+ elif cached_get_origin(datatype) == np.ndarray:
105
105
  if datatype not in get_args(SignalDatatype):
106
106
  raise TypeError(f"Expected Array1D[dtype], got {datatype}")
107
107
  return NDArraySoftConverter(get_dtype(datatype))
@@ -160,7 +160,7 @@ class SoftSignalBackend(SignalBackend[SignalDatatypeT]):
160
160
  async def connect(self, timeout: float):
161
161
  pass
162
162
 
163
- async def put(self, value: SignalDatatypeT | None, wait: bool) -> None:
163
+ async def put(self, value: SignalDatatypeT | None) -> None:
164
164
  write_value = self.initial_value if value is None else value
165
165
  self.set_value(write_value)
166
166
 
@@ -3,8 +3,10 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import asyncio
6
+ import contextlib
6
7
  import functools
7
8
  import time
9
+ from asyncio import CancelledError
8
10
  from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
9
11
  from dataclasses import asdict, replace
10
12
  from typing import Generic
@@ -17,16 +19,37 @@ from ._utils import Callback, P, T, WatcherUpdate
17
19
 
18
20
 
19
21
  class AsyncStatusBase(Status, Awaitable[None]):
20
- """Convert asyncio awaitable to bluesky Status interface."""
22
+ """Convert asyncio awaitable to bluesky Status interface.
23
+
24
+ Can be used as an async context manager to automatically cancel the calling
25
+ task when the status completes. This is useful for bounding loop execution:
26
+ when the status completes, the calling task is cancelled, causing the loop
27
+ to exit. If the loop completes first, the status task is automatically cancelled.
28
+ """
21
29
 
22
30
  def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None):
23
31
  if isinstance(awaitable, asyncio.Task):
24
32
  self.task = awaitable
25
33
  else:
26
- self.task = asyncio.create_task(awaitable)
34
+
35
+ async def wait_with_error_message(awaitable):
36
+ try:
37
+ await awaitable
38
+ except CancelledError as e:
39
+ raise CancelledError(
40
+ f"CancelledError while awaiting {awaitable} on {name}"
41
+ ) from e
42
+
43
+ self.task = asyncio.create_task(wait_with_error_message(awaitable))
44
+ # There is a small chance we could be cancelled before
45
+ # wait_with_error_message starts.
46
+ # Avoid complaints about awaitable not awaited if task is
47
+ # pre-emptively cancelled, by ensuring it is always disposed
48
+ self.task.add_done_callback(lambda _: awaitable.close())
27
49
  self.task.add_done_callback(self._run_callbacks)
28
50
  self._callbacks: list[Callback[Status]] = []
29
51
  self._name = name
52
+ self._cancelled_error_ok = False
30
53
 
31
54
  def __await__(self):
32
55
  return self.task.__await__()
@@ -85,6 +108,36 @@ class AsyncStatusBase(Status, Awaitable[None]):
85
108
  f"task: {self.task.get_coro()}, {status}>"
86
109
  )
87
110
 
111
+ async def __aenter__(self):
112
+ # Grab the calling task, the one that is doing `with status``
113
+ calling_task = asyncio.current_task()
114
+ if calling_task is None:
115
+ raise RuntimeError("Can only use in a context manager inside a task")
116
+
117
+ def _cancel_calling_task(task: asyncio.Task, calling_task=calling_task):
118
+ # If no-one cancelled our child task, then it is expected
119
+ # that we want to break out of the calling task with block
120
+ # so mark that the CancelledError should be suppressed on exit
121
+ self._cancelled_error_ok = not task.cancelled()
122
+ calling_task.cancel()
123
+
124
+ # When our child task is done, then cancel the calling task
125
+ self.task.add_done_callback(_cancel_calling_task)
126
+ return self
127
+
128
+ async def __aexit__(self, exc_type, exc, tb):
129
+ self.task.cancel()
130
+ # Need to await the task to suppress teardown warnings, but
131
+ # we know it will raise CancelledError as we just cancelled it
132
+ with contextlib.suppress(CancelledError):
133
+ await self.task
134
+ if exc_type is CancelledError and self._cancelled_error_ok:
135
+ # Suppress error as we cancelled it in _cancel_calling_task
136
+ return True
137
+ else:
138
+ # Raise error as we didn't cause it
139
+ return False
140
+
88
141
  __str__ = __repr__
89
142
 
90
143
 
@@ -94,13 +147,37 @@ class AsyncStatus(AsyncStatusBase):
94
147
  :param awaitable: The coroutine or task to await.
95
148
  :param name: The name of the device, if available.
96
149
 
97
- For example:
150
+ Can be awaited like a standard Task:
151
+
98
152
  ```python
99
153
  status = AsyncStatus(asyncio.sleep(1))
100
154
  assert not status.done
101
- await status # waits for 1 second
155
+ await status # waits for 1 second
102
156
  assert status.done
103
157
  ```
158
+
159
+ Can also be used as a context manager to bound loop execution. When the status
160
+ completes, the calling task is cancelled, causing loops to exit:
161
+
162
+ ```python
163
+ async with motor.set(target_position):
164
+ async for value in observe_value(detector):
165
+ process_reading(value)
166
+ # Loop exits automatically when motor reaches position
167
+ ```
168
+
169
+ If the loop completes before the status, the status task is cancelled:
170
+
171
+ ```python
172
+ async with AsyncStatus(long_operation()):
173
+ for i in range(3):
174
+ await process_step(i)
175
+ # Loop completes, long_operation() is cancelled
176
+ ```
177
+
178
+ Note that the body of the with statement will only break at a suspension
179
+ point like `async for` or `await`, so body code without these suspension
180
+ points will continue even if the status completes.
104
181
  """
105
182
 
106
183
  @classmethod
File without changes