ophyd-async 0.14.1__py3-none-any.whl → 0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (38)
  1. ophyd_async/_version.py +2 -2
  2. ophyd_async/core/__init__.py +17 -5
  3. ophyd_async/core/{_table.py → _datatypes.py} +18 -9
  4. ophyd_async/core/_derived_signal.py +57 -24
  5. ophyd_async/core/_derived_signal_backend.py +1 -5
  6. ophyd_async/core/_device_filler.py +30 -7
  7. ophyd_async/core/_mock_signal_backend.py +25 -7
  8. ophyd_async/core/_mock_signal_utils.py +7 -11
  9. ophyd_async/core/_signal.py +11 -11
  10. ophyd_async/core/_signal_backend.py +7 -19
  11. ophyd_async/core/_soft_signal_backend.py +6 -6
  12. ophyd_async/core/_status.py +81 -4
  13. ophyd_async/core/_typing.py +0 -0
  14. ophyd_async/core/_utils.py +57 -7
  15. ophyd_async/epics/adcore/_core_io.py +12 -5
  16. ophyd_async/epics/adcore/_core_logic.py +1 -1
  17. ophyd_async/epics/core/__init__.py +2 -1
  18. ophyd_async/epics/core/_aioca.py +13 -3
  19. ophyd_async/epics/core/_epics_connector.py +4 -1
  20. ophyd_async/epics/core/_p4p.py +13 -3
  21. ophyd_async/epics/core/_signal.py +18 -6
  22. ophyd_async/epics/core/_util.py +23 -3
  23. ophyd_async/epics/demo/_motor.py +2 -2
  24. ophyd_async/epics/motor.py +15 -17
  25. ophyd_async/epics/odin/_odin_io.py +1 -1
  26. ophyd_async/epics/pmac/_pmac_io.py +23 -4
  27. ophyd_async/epics/pmac/_pmac_trajectory.py +47 -10
  28. ophyd_async/fastcs/eiger/_eiger_io.py +20 -1
  29. ophyd_async/fastcs/jungfrau/_signals.py +4 -1
  30. ophyd_async/fastcs/panda/_block.py +28 -6
  31. ophyd_async/fastcs/panda/_writer.py +1 -3
  32. ophyd_async/tango/core/_tango_transport.py +7 -17
  33. ophyd_async/tango/demo/_counter.py +2 -2
  34. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/METADATA +1 -1
  35. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/RECORD +38 -37
  36. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/WHEEL +1 -1
  37. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/licenses/LICENSE +0 -0
  38. {ophyd_async-0.14.1.dist-info → ophyd_async-0.15.dist-info}/top_level.txt +0 -0
@@ -6,12 +6,13 @@ from abc import abstractmethod
6
6
  from collections.abc import Sequence
7
7
  from dataclasses import dataclass
8
8
  from functools import lru_cache
9
- from typing import Any, Generic, get_args, get_origin
9
+ from typing import Any, Generic, get_args
10
10
 
11
11
  import numpy as np
12
12
  from bluesky.protocols import Reading
13
13
  from event_model import DataKey
14
14
 
15
+ from ._datatypes import Table
15
16
  from ._signal_backend import (
16
17
  Array1D,
17
18
  EnumT,
@@ -24,8 +25,7 @@ from ._signal_backend import (
24
25
  make_datakey,
25
26
  make_metadata,
26
27
  )
27
- from ._table import Table
28
- from ._utils import Callback, get_dtype, get_enum_cls
28
+ from ._utils import Callback, cached_get_origin, get_dtype, get_enum_cls
29
29
 
30
30
 
31
31
  class SoftConverter(Generic[SignalDatatypeT]):
@@ -97,11 +97,11 @@ def make_converter(datatype: type[SignalDatatype]) -> SoftConverter:
97
97
  enum_cls = get_enum_cls(datatype)
98
98
  if datatype in (Sequence[str], typing.Sequence[str]):
99
99
  return SequenceStrSoftConverter()
100
- elif get_origin(datatype) in (Sequence, typing.Sequence) and enum_cls:
100
+ elif cached_get_origin(datatype) in (Sequence, typing.Sequence) and enum_cls:
101
101
  return SequenceEnumSoftConverter(enum_cls)
102
102
  elif datatype is np.ndarray:
103
103
  return NDArraySoftConverter()
104
- elif get_origin(datatype) == np.ndarray:
104
+ elif cached_get_origin(datatype) == np.ndarray:
105
105
  if datatype not in get_args(SignalDatatype):
106
106
  raise TypeError(f"Expected Array1D[dtype], got {datatype}")
107
107
  return NDArraySoftConverter(get_dtype(datatype))
@@ -160,7 +160,7 @@ class SoftSignalBackend(SignalBackend[SignalDatatypeT]):
160
160
  async def connect(self, timeout: float):
161
161
  pass
162
162
 
163
- async def put(self, value: SignalDatatypeT | None, wait: bool) -> None:
163
+ async def put(self, value: SignalDatatypeT | None) -> None:
164
164
  write_value = self.initial_value if value is None else value
165
165
  self.set_value(write_value)
166
166
 
@@ -3,8 +3,10 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import asyncio
6
+ import contextlib
6
7
  import functools
7
8
  import time
9
+ from asyncio import CancelledError
8
10
  from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
9
11
  from dataclasses import asdict, replace
10
12
  from typing import Generic
@@ -17,16 +19,37 @@ from ._utils import Callback, P, T, WatcherUpdate
17
19
 
18
20
 
19
21
  class AsyncStatusBase(Status, Awaitable[None]):
20
- """Convert asyncio awaitable to bluesky Status interface."""
22
+ """Convert asyncio awaitable to bluesky Status interface.
23
+
24
+ Can be used as an async context manager to automatically cancel the calling
25
+ task when the status completes. This is useful for bounding loop execution:
26
+ when the status completes, the calling task is cancelled, causing the loop
27
+ to exit. If the loop completes first, the status task is automatically cancelled.
28
+ """
21
29
 
22
30
  def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None):
23
31
  if isinstance(awaitable, asyncio.Task):
24
32
  self.task = awaitable
25
33
  else:
26
- self.task = asyncio.create_task(awaitable)
34
+
35
+ async def wait_with_error_message(awaitable):
36
+ try:
37
+ await awaitable
38
+ except CancelledError as e:
39
+ raise CancelledError(
40
+ f"CancelledError while awaiting {awaitable} on {name}"
41
+ ) from e
42
+
43
+ self.task = asyncio.create_task(wait_with_error_message(awaitable))
44
+ # There is a small chance we could be cancelled before
45
+ # wait_with_error_message starts.
46
+ # Avoid complaints about awaitable not awaited if task is
47
+ # pre-emptively cancelled, by ensuring it is always disposed
48
+ self.task.add_done_callback(lambda _: awaitable.close())
27
49
  self.task.add_done_callback(self._run_callbacks)
28
50
  self._callbacks: list[Callback[Status]] = []
29
51
  self._name = name
52
+ self._cancelled_error_ok = False
30
53
 
31
54
  def __await__(self):
32
55
  return self.task.__await__()
@@ -85,6 +108,36 @@ class AsyncStatusBase(Status, Awaitable[None]):
85
108
  f"task: {self.task.get_coro()}, {status}>"
86
109
  )
87
110
 
111
+ async def __aenter__(self):
112
+ # Grab the calling task, the one that is doing `with status``
113
+ calling_task = asyncio.current_task()
114
+ if calling_task is None:
115
+ raise RuntimeError("Can only use in a context manager inside a task")
116
+
117
+ def _cancel_calling_task(task: asyncio.Task, calling_task=calling_task):
118
+ # If no-one cancelled our child task, then it is expected
119
+ # that we want to break out of the calling task with block
120
+ # so mark that the CancelledError should be suppressed on exit
121
+ self._cancelled_error_ok = not task.cancelled()
122
+ calling_task.cancel()
123
+
124
+ # When our child task is done, then cancel the calling task
125
+ self.task.add_done_callback(_cancel_calling_task)
126
+ return self
127
+
128
+ async def __aexit__(self, exc_type, exc, tb):
129
+ self.task.cancel()
130
+ # Need to await the task to suppress teardown warnings, but
131
+ # we know it will raise CancelledError as we just cancelled it
132
+ with contextlib.suppress(CancelledError):
133
+ await self.task
134
+ if exc_type is CancelledError and self._cancelled_error_ok:
135
+ # Suppress error as we cancelled it in _cancel_calling_task
136
+ return True
137
+ else:
138
+ # Raise error as we didn't cause it
139
+ return False
140
+
88
141
  __str__ = __repr__
89
142
 
90
143
 
@@ -94,13 +147,37 @@ class AsyncStatus(AsyncStatusBase):
94
147
  :param awaitable: The coroutine or task to await.
95
148
  :param name: The name of the device, if available.
96
149
 
97
- For example:
150
+ Can be awaited like a standard Task:
151
+
98
152
  ```python
99
153
  status = AsyncStatus(asyncio.sleep(1))
100
154
  assert not status.done
101
- await status # waits for 1 second
155
+ await status # waits for 1 second
102
156
  assert status.done
103
157
  ```
158
+
159
+ Can also be used as a context manager to bound loop execution. When the status
160
+ completes, the calling task is cancelled, causing loops to exit:
161
+
162
+ ```python
163
+ async with motor.set(target_position):
164
+ async for value in observe_value(detector):
165
+ process_reading(value)
166
+ # Loop exits automatically when motor reaches position
167
+ ```
168
+
169
+ If the loop completes before the status, the status task is cancelled:
170
+
171
+ ```python
172
+ async with AsyncStatus(long_operation()):
173
+ for i in range(3):
174
+ await process_step(i)
175
+ # Loop completes, long_operation() is cancelled
176
+ ```
177
+
178
+ Note that the body of the with statement will only break at a suspension
179
+ point like `async for` or `await`, so body code without these suspension
180
+ points will continue even if the status completes.
104
181
  """
105
182
 
106
183
  @classmethod
File without changes
@@ -5,6 +5,8 @@ import logging
5
5
  from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
6
6
  from dataclasses import dataclass
7
7
  from enum import Enum, EnumMeta, StrEnum
8
+ from functools import lru_cache
9
+ from inspect import isawaitable
8
10
  from typing import (
9
11
  Any,
10
12
  Generic,
@@ -13,6 +15,7 @@ from typing import (
13
15
  TypeVar,
14
16
  get_args,
15
17
  get_origin,
18
+ get_type_hints,
16
19
  )
17
20
 
18
21
  import numpy as np
@@ -204,6 +207,20 @@ async def wait_for_connection(**coros: Awaitable[None]):
204
207
  raise NotConnectedError.with_other_exceptions_logged(exceptions)
205
208
 
206
209
 
210
+ # Cache get_type_hints calls to avoid expensive introspection across the codebase
211
+ @lru_cache(maxsize=512)
212
+ def cached_get_type_hints(cls: type, include_extras: bool = False) -> dict[str, Any]:
213
+ """Get type hints with caching to avoid expensive introspection."""
214
+ return get_type_hints(cls, include_extras=include_extras)
215
+
216
+
217
+ # Cache get_origin calls to avoid expensive type introspection
218
+ @lru_cache(maxsize=512)
219
+ def cached_get_origin(tp: Any) -> Any:
220
+ """Get the origin of a type with caching."""
221
+ return get_origin(tp)
222
+
223
+
207
224
  def get_dtype(datatype: type) -> np.dtype:
208
225
  """Get the runtime dtype from a numpy ndarray type annotation.
209
226
 
@@ -215,7 +232,7 @@ def get_dtype(datatype: type) -> np.dtype:
215
232
 
216
233
  ```
217
234
  """
218
- if not get_origin(datatype) == np.ndarray:
235
+ if not cached_get_origin(datatype) == np.ndarray:
219
236
  raise TypeError(f"Expected Array1D[dtype], got {datatype}")
220
237
  # datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]
221
238
  # so extract numpy.float64 from it
@@ -240,7 +257,7 @@ def get_enum_cls(datatype: type | None) -> type[EnumTypes] | None:
240
257
 
241
258
  ```
242
259
  """
243
- if get_origin(datatype) is Sequence:
260
+ if cached_get_origin(datatype) is Sequence:
244
261
  datatype = get_args(datatype)[0]
245
262
  datatype = get_origin_class(datatype)
246
263
  if datatype and issubclass(datatype, Enum):
@@ -292,10 +309,36 @@ async def merge_gathered_dicts(
292
309
  return ret
293
310
 
294
311
 
295
- async def gather_dict(coros: Mapping[T, Awaitable[V]]) -> dict[T, V]:
296
- """Take named coros and return a dict of their name to their return value."""
297
- values = await asyncio.gather(*coros.values())
298
- return dict(zip(coros, values, strict=True))
312
+ def _partition_awaitable(
313
+ maybe_awaitables: Iterable[T | Awaitable[T]],
314
+ ) -> tuple[dict[int, Awaitable[T]], dict[int, T]]:
315
+ awaitable: dict[int, Awaitable[T]] = {}
316
+ not_awaitable: dict[int, T] = {}
317
+ for i, x in enumerate(maybe_awaitables):
318
+ if isawaitable(x):
319
+ awaitable[i] = x
320
+ else:
321
+ not_awaitable[i] = x
322
+ return awaitable, not_awaitable
323
+
324
+
325
+ async def gather_dict(coros: dict[T | Awaitable[T], V | Awaitable[V]]) -> dict[T, V]:
326
+ """Await any coros in the keys or values of a dictionary."""
327
+ k_awaitable, k_not_awaitable = _partition_awaitable(coros.keys())
328
+ v_awaitable, v_not_awaitable = _partition_awaitable(coros.values())
329
+
330
+ # Await all awaitables in parallel
331
+ k_results, v_results = await asyncio.gather(
332
+ asyncio.gather(*k_awaitable.values()),
333
+ asyncio.gather(*v_awaitable.values()),
334
+ )
335
+
336
+ # Combine awaited and non-awaited values by index
337
+ k_map = k_not_awaitable | dict(zip(k_awaitable, k_results, strict=True))
338
+ v_map = v_not_awaitable | dict(zip(v_awaitable, v_results, strict=True))
339
+
340
+ # Reconstruct dict in original index order
341
+ return {k_map[i]: v_map[i] for i in range(len(coros))}
299
342
 
300
343
 
301
344
  def in_micros(t: float) -> int:
@@ -310,8 +353,10 @@ def in_micros(t: float) -> int:
310
353
  return int(np.ceil(t * 1e6))
311
354
 
312
355
 
356
+ @lru_cache(maxsize=512)
313
357
  def get_origin_class(annotatation: Any) -> type | None:
314
- origin = get_origin(annotatation) or annotatation
358
+ """Get the origin class of a type annotation with caching."""
359
+ origin = cached_get_origin(annotatation) or annotatation
315
360
  if isinstance(origin, type):
316
361
  return origin
317
362
  return None
@@ -363,3 +408,8 @@ def error_if_none(value: T | None, msg: str) -> T:
363
408
  if value is None:
364
409
  raise RuntimeError(msg)
365
410
  return value
411
+
412
+
413
+ def non_zero(value):
414
+ """Return True if the value cast to an int is not zero."""
415
+ return int(value) != 0
@@ -8,8 +8,9 @@ from ophyd_async.core import (
8
8
  SignalR,
9
9
  SignalRW,
10
10
  StrictEnum,
11
+ non_zero,
11
12
  )
12
- from ophyd_async.epics.core import EpicsDevice, PvSuffix
13
+ from ophyd_async.epics.core import EpicsDevice, EpicsOptions, PvSuffix
13
14
 
14
15
  from ._utils import ADBaseDataType, ADFileWriteMode, ADImageMode, convert_ad_dtype_to_np
15
16
 
@@ -23,7 +24,7 @@ class NDArrayBaseIO(EpicsDevice):
23
24
 
24
25
  unique_id: A[SignalR[int], PvSuffix("UniqueId_RBV")]
25
26
  nd_attributes_file: A[SignalRW[str], PvSuffix("NDAttributesFile")]
26
- acquire: A[SignalRW[bool], PvSuffix.rbv("Acquire")]
27
+ acquire: A[SignalRW[bool], PvSuffix.rbv("Acquire"), EpicsOptions(wait=non_zero)]
27
28
  array_size_x: A[SignalR[int], PvSuffix("ArraySizeX_RBV")]
28
29
  array_size_y: A[SignalR[int], PvSuffix("ArraySizeY_RBV")]
29
30
  data_type: A[SignalR[ADBaseDataType], PvSuffix("DataType_RBV")]
@@ -200,7 +201,7 @@ class NDFileIO(NDArrayBaseIO):
200
201
  file_write_mode: A[SignalRW[ADFileWriteMode], PvSuffix.rbv("FileWriteMode")]
201
202
  num_capture: A[SignalRW[int], PvSuffix.rbv("NumCapture")]
202
203
  num_captured: A[SignalR[int], PvSuffix("NumCaptured_RBV")]
203
- capture: A[SignalRW[bool], PvSuffix.rbv("Capture")]
204
+ capture: A[SignalRW[bool], PvSuffix.rbv("Capture"), EpicsOptions(wait=non_zero)]
204
205
  array_size0: A[SignalR[int], PvSuffix("ArraySize0")]
205
206
  array_size1: A[SignalR[int], PvSuffix("ArraySize1")]
206
207
  create_directory: A[SignalRW[int], PvSuffix("CreateDirectory")]
@@ -240,11 +241,17 @@ class NDCBFlushOnSoftTrgMode(StrictEnum):
240
241
 
241
242
 
242
243
  class NDPluginCBIO(NDPluginBaseIO):
244
+ """Plugin that outputs pre/post-trigger NDArrays based on defined conditions.
245
+
246
+ This mirrors the interface provided by ADCore//Db/NDCircularBuff.template
247
+ See HTML docs at https://areadetector.github.io/areaDetector/ADCore/NDPluginCircularBuff.html
248
+ """
249
+
243
250
  pre_count: A[SignalRW[int], PvSuffix.rbv("PreCount")]
244
251
  post_count: A[SignalRW[int], PvSuffix.rbv("PostCount")]
245
252
  preset_trigger_count: A[SignalRW[int], PvSuffix.rbv("PresetTriggerCount")]
246
- trigger: A[SignalRW[bool], PvSuffix.rbv("Trigger")]
247
- capture: A[SignalRW[bool], PvSuffix.rbv("Capture")]
253
+ trigger: A[SignalRW[bool], PvSuffix.rbv("Trigger"), EpicsOptions(wait=non_zero)]
254
+ capture: A[SignalRW[bool], PvSuffix.rbv("Capture"), EpicsOptions(wait=non_zero)]
248
255
  flush_on_soft_trg: A[
249
256
  SignalRW[NDCBFlushOnSoftTrgMode], PvSuffix.rbv("FlushOnSoftTrg")
250
257
  ]
@@ -218,7 +218,7 @@ class ADBaseContAcqController(ADBaseController[ADBaseIO]):
218
218
  )
219
219
 
220
220
  # Send the trigger to begin acquisition
221
- await self.cb_plugin.trigger.set(True, wait=False)
221
+ await self.cb_plugin.trigger.set(True)
222
222
 
223
223
  async def disarm(self) -> None:
224
224
  await stop_busy_record(self.cb_plugin.capture, False)
@@ -10,7 +10,7 @@ from ._signal import (
10
10
  epics_signal_w,
11
11
  epics_signal_x,
12
12
  )
13
- from ._util import stop_busy_record
13
+ from ._util import EpicsOptions, stop_busy_record
14
14
 
15
15
  __all__ = [
16
16
  "PviDeviceConnector",
@@ -25,4 +25,5 @@ __all__ = [
25
25
  "epics_signal_w",
26
26
  "epics_signal_x",
27
27
  "stop_busy_record",
28
+ "EpicsOptions",
28
29
  ]
@@ -36,7 +36,12 @@ from ophyd_async.core import (
36
36
  wait_for_connection,
37
37
  )
38
38
 
39
- from ._util import EpicsSignalBackend, format_datatype, get_supported_values
39
+ from ._util import (
40
+ EpicsOptions,
41
+ EpicsSignalBackend,
42
+ format_datatype,
43
+ get_supported_values,
44
+ )
40
45
 
41
46
  logger = logging.getLogger("ophyd_async")
42
47
 
@@ -255,12 +260,13 @@ class CaSignalBackend(EpicsSignalBackend[SignalDatatypeT]):
255
260
  datatype: type[SignalDatatypeT] | None,
256
261
  read_pv: str = "",
257
262
  write_pv: str = "",
263
+ options: EpicsOptions | None = None,
258
264
  ):
259
265
  self.converter: CaConverter = DisconnectedCaConverter(float, dbr.DBR_DOUBLE)
260
266
  self.initial_values: dict[str, AugmentedValue] = {}
261
267
  self.subscription: Subscription | None = None
262
268
  self._all_updates = _all_updates()
263
- super().__init__(datatype, read_pv, write_pv)
269
+ super().__init__(datatype, read_pv, write_pv, options)
264
270
 
265
271
  def source(self, name: str, read: bool):
266
272
  return f"ca://{self.read_pv if read else self.write_pv}"
@@ -299,11 +305,15 @@ class CaSignalBackend(EpicsSignalBackend[SignalDatatypeT]):
299
305
  "alarm_severity": -1 if value.severity > 2 else value.severity,
300
306
  }
301
307
 
302
- async def put(self, value: SignalDatatypeT | None, wait: bool):
308
+ async def put(self, value: SignalDatatypeT | None):
303
309
  if value is None:
304
310
  write_value = self.initial_values[self.write_pv]
305
311
  else:
306
312
  write_value = self.converter.write_value(value)
313
+ if callable(self.options.wait):
314
+ wait = self.options.wait(value)
315
+ else:
316
+ wait = self.options.wait
307
317
  try:
308
318
  await caput(
309
319
  self.write_pv,
@@ -5,7 +5,8 @@ from typing import Any
5
5
 
6
6
  from ophyd_async.core import Device, DeviceConnector, DeviceFiller
7
7
 
8
- from ._signal import EpicsSignalBackend, get_signal_backend_type, split_protocol_from_pv
8
+ from ._signal import get_signal_backend_type, split_protocol_from_pv
9
+ from ._util import EpicsOptions, EpicsSignalBackend
9
10
 
10
11
 
11
12
  @dataclass
@@ -44,6 +45,8 @@ def fill_backend_with_prefix(
44
45
  backend.write_pv = prefix + (
45
46
  annotation.write_suffix or annotation.read_suffix
46
47
  )
48
+ elif isinstance(annotation, EpicsOptions):
49
+ backend.options = annotation
47
50
  else:
48
51
  unhandled.append(annotation)
49
52
  annotations.extend(unhandled)
@@ -29,7 +29,12 @@ from ophyd_async.core import (
29
29
  wait_for_connection,
30
30
  )
31
31
 
32
- from ._util import EpicsSignalBackend, format_datatype, get_supported_values
32
+ from ._util import (
33
+ EpicsOptions,
34
+ EpicsSignalBackend,
35
+ format_datatype,
36
+ get_supported_values,
37
+ )
33
38
 
34
39
  logger = logging.getLogger("ophyd_async")
35
40
 
@@ -347,11 +352,12 @@ class PvaSignalBackend(EpicsSignalBackend[SignalDatatypeT]):
347
352
  datatype: type[SignalDatatypeT] | None,
348
353
  read_pv: str = "",
349
354
  write_pv: str = "",
355
+ options: EpicsOptions | None = None,
350
356
  ):
351
357
  self.converter: PvaConverter = DisconnectedPvaConverter(float)
352
358
  self.initial_values: dict[str, Any] = {}
353
359
  self.subscription: Subscription | None = None
354
- super().__init__(datatype, read_pv, write_pv)
360
+ super().__init__(datatype, read_pv, write_pv, options)
355
361
 
356
362
  def source(self, name: str, read: bool):
357
363
  return f"pva://{self.read_pv if read else self.write_pv}"
@@ -380,11 +386,15 @@ class PvaSignalBackend(EpicsSignalBackend[SignalDatatypeT]):
380
386
  "alarm_severity": -1 if sv > 2 else sv,
381
387
  }
382
388
 
383
- async def put(self, value: SignalDatatypeT | None, wait: bool):
389
+ async def put(self, value: SignalDatatypeT | None):
384
390
  if value is None:
385
391
  write_value = self.initial_values[self.write_pv]["value"]
386
392
  else:
387
393
  write_value = self.converter.write_value(value)
394
+ if callable(self.options.wait):
395
+ wait = self.options.wait(value)
396
+ else:
397
+ wait = self.options.wait
388
398
  await context().put(self.write_pv, {"value": write_value}, wait=wait)
389
399
 
390
400
  async def get_datakey(self, source: str) -> DataKey:
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Callable
5
6
  from enum import Enum
6
7
 
7
8
  from ophyd_async.core import (
@@ -15,7 +16,7 @@ from ophyd_async.core import (
15
16
  get_unique,
16
17
  )
17
18
 
18
- from ._util import EpicsSignalBackend, get_pv_basename_and_field
19
+ from ._util import EpicsOptions, EpicsSignalBackend, get_pv_basename_and_field
19
20
 
20
21
 
21
22
  class EpicsProtocol(Enum):
@@ -78,14 +79,18 @@ def get_signal_backend_type(protocol: EpicsProtocol) -> type[EpicsSignalBackend]
78
79
 
79
80
 
80
81
  def _epics_signal_backend(
81
- datatype: type[SignalDatatypeT] | None, read_pv: str, write_pv: str
82
+ datatype: type[SignalDatatypeT] | None,
83
+ read_pv: str,
84
+ write_pv: str,
85
+ options: EpicsOptions | None = None,
82
86
  ) -> SignalBackend[SignalDatatypeT]:
83
87
  """Create an epics signal backend."""
84
88
  r_protocol, r_pv = split_protocol_from_pv(read_pv)
85
89
  w_protocol, w_pv = split_protocol_from_pv(write_pv)
86
90
  protocol = get_unique({read_pv: r_protocol, write_pv: w_protocol}, "protocols")
91
+
87
92
  signal_backend_type = get_signal_backend_type(protocol)
88
- return signal_backend_type(datatype, r_pv, w_pv)
93
+ return signal_backend_type(datatype, r_pv, w_pv, options)
89
94
 
90
95
 
91
96
  def epics_signal_rw(
@@ -95,6 +100,7 @@ def epics_signal_rw(
95
100
  name: str = "",
96
101
  timeout: float = DEFAULT_TIMEOUT,
97
102
  attempts: int = 1,
103
+ wait: bool | Callable[[SignalDatatypeT], bool] = True,
98
104
  ) -> SignalRW[SignalDatatypeT]:
99
105
  """Create a `SignalRW` backed by 1 or 2 EPICS PVs.
100
106
 
@@ -104,7 +110,9 @@ def epics_signal_rw(
104
110
  :param name: The name of the signal (defaults to empty string)
105
111
  :param timeout: A timeout to be used when reading (not connecting) this signal
106
112
  """
107
- backend = _epics_signal_backend(datatype, read_pv, write_pv or read_pv)
113
+ backend = _epics_signal_backend(
114
+ datatype, read_pv, write_pv or read_pv, EpicsOptions(wait=wait)
115
+ )
108
116
  return SignalRW(backend, name=name, timeout=timeout, attempts=attempts)
109
117
 
110
118
 
@@ -115,6 +123,7 @@ def epics_signal_rw_rbv(
115
123
  name: str = "",
116
124
  timeout: float = DEFAULT_TIMEOUT,
117
125
  attempts: int = 1,
126
+ wait: bool | Callable[[SignalDatatypeT], bool] = True,
118
127
  ) -> SignalRW[SignalDatatypeT]:
119
128
  """Create a `SignalRW` backed by 1 or 2 EPICS PVs, with a suffix on the readback pv.
120
129
 
@@ -131,7 +140,7 @@ def epics_signal_rw_rbv(
131
140
  read_pv = f"{write_pv}{read_suffix}"
132
141
 
133
142
  return epics_signal_rw(
134
- datatype, read_pv, write_pv, name, timeout=timeout, attempts=attempts
143
+ datatype, read_pv, write_pv, name, timeout=timeout, attempts=attempts, wait=wait
135
144
  )
136
145
 
137
146
 
@@ -158,6 +167,7 @@ def epics_signal_w(
158
167
  name: str = "",
159
168
  timeout: float = DEFAULT_TIMEOUT,
160
169
  attempts: int = 1,
170
+ wait: bool | Callable[[SignalDatatypeT], bool] = True,
161
171
  ) -> SignalW[SignalDatatypeT]:
162
172
  """Create a `SignalW` backed by 1 EPICS PVs.
163
173
 
@@ -166,7 +176,9 @@ def epics_signal_w(
166
176
  :param name: The name of the signal (defaults to empty string)
167
177
  :param timeout: A timeout to be used when reading (not connecting) this signal
168
178
  """
169
- backend = _epics_signal_backend(datatype, write_pv, write_pv)
179
+ backend = _epics_signal_backend(
180
+ datatype, write_pv, write_pv, EpicsOptions(wait=wait)
181
+ )
170
182
  return SignalW(backend, name=name, timeout=timeout, attempts=attempts)
171
183
 
172
184
 
@@ -1,5 +1,6 @@
1
- from collections.abc import Mapping, Sequence
2
- from typing import Any, TypeVar, get_args, get_origin
1
+ from collections.abc import Callable, Mapping, Sequence
2
+ from dataclasses import dataclass
3
+ from typing import Any, Generic, TypeVar, get_args, get_origin
3
4
 
4
5
  import numpy as np
5
6
 
@@ -19,6 +20,23 @@ from ophyd_async.core import (
19
20
  T = TypeVar("T")
20
21
 
21
22
 
23
+ @dataclass
24
+ class EpicsOptions(Generic[SignalDatatypeT]):
25
+ """Options for EPICS Signals."""
26
+
27
+ wait: bool | Callable[[SignalDatatypeT], bool] = True
28
+ """Whether to wait for server-side completion of the operation:
29
+
30
+ - `True`: Return when server-side operation has completed
31
+ - `False`: Return when server-side operation has started
32
+ - `callable`: Call with the value being put to decide whether to wait
33
+
34
+ For example, use `EpicsOption(wait=non_zero)` for busy records like
35
+ areaDetector acquire PVs that should not wait when being set to zero
36
+ as it causes a deadlock.
37
+ """
38
+
39
+
22
40
  def get_pv_basename_and_field(pv: str) -> tuple[str, str | None]:
23
41
  """Split PV into record name and field."""
24
42
  if "." in pv:
@@ -75,9 +93,11 @@ class EpicsSignalBackend(SignalBackend[SignalDatatypeT]):
75
93
  datatype: type[SignalDatatypeT] | None,
76
94
  read_pv: str = "",
77
95
  write_pv: str = "",
96
+ options: EpicsOptions | None = None,
78
97
  ):
79
98
  self.read_pv = read_pv
80
99
  self.write_pv = write_pv
100
+ self.options = options or EpicsOptions()
81
101
  super().__init__(datatype)
82
102
 
83
103
 
@@ -86,5 +106,5 @@ async def stop_busy_record(
86
106
  value: SignalDatatypeT,
87
107
  timeout: float = DEFAULT_TIMEOUT,
88
108
  ) -> None:
89
- await signal.set(value, wait=False)
109
+ await signal.set(value)
90
110
  await wait_for_value(signal, value, timeout=timeout)
@@ -55,8 +55,8 @@ class DemoMotor(EpicsDevice, StandardReadable, Movable, Stoppable):
55
55
  # If not supplied, calculate a suitable timeout for the move
56
56
  if timeout == CALCULATE_TIMEOUT:
57
57
  timeout = abs(new_position - old_position) / velocity + DEFAULT_TIMEOUT
58
- # Wait for the value to set, but don't wait for put completion callback
59
- await self.setpoint.set(new_position, wait=False)
58
+ # Setting the setpoint starts the motion
59
+ await self.setpoint.set(new_position)
60
60
  # Observe the readback Signal, and on each new position...
61
61
  async for current_position in observe_value(
62
62
  self.readback, done_timeout=timeout