ophyd-async 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (92)
  1. ophyd_async/_version.py +2 -2
  2. ophyd_async/core/__init__.py +34 -9
  3. ophyd_async/core/_detector.py +5 -10
  4. ophyd_async/core/_device.py +170 -68
  5. ophyd_async/core/_device_filler.py +269 -0
  6. ophyd_async/core/_device_save_loader.py +6 -7
  7. ophyd_async/core/_mock_signal_backend.py +35 -40
  8. ophyd_async/core/_mock_signal_utils.py +25 -16
  9. ophyd_async/core/_protocol.py +28 -8
  10. ophyd_async/core/_readable.py +133 -134
  11. ophyd_async/core/_signal.py +219 -163
  12. ophyd_async/core/_signal_backend.py +131 -64
  13. ophyd_async/core/_soft_signal_backend.py +131 -194
  14. ophyd_async/core/_status.py +22 -6
  15. ophyd_async/core/_table.py +102 -100
  16. ophyd_async/core/_utils.py +143 -32
  17. ophyd_async/epics/adaravis/_aravis_controller.py +2 -2
  18. ophyd_async/epics/adaravis/_aravis_io.py +8 -6
  19. ophyd_async/epics/adcore/_core_io.py +5 -7
  20. ophyd_async/epics/adcore/_core_logic.py +3 -1
  21. ophyd_async/epics/adcore/_hdf_writer.py +2 -2
  22. ophyd_async/epics/adcore/_single_trigger.py +6 -10
  23. ophyd_async/epics/adcore/_utils.py +15 -10
  24. ophyd_async/epics/adkinetix/__init__.py +2 -1
  25. ophyd_async/epics/adkinetix/_kinetix_controller.py +6 -3
  26. ophyd_async/epics/adkinetix/_kinetix_io.py +4 -5
  27. ophyd_async/epics/adpilatus/_pilatus_controller.py +2 -2
  28. ophyd_async/epics/adpilatus/_pilatus_io.py +3 -4
  29. ophyd_async/epics/adsimdetector/_sim_controller.py +2 -2
  30. ophyd_async/epics/advimba/__init__.py +4 -1
  31. ophyd_async/epics/advimba/_vimba_controller.py +6 -3
  32. ophyd_async/epics/advimba/_vimba_io.py +8 -9
  33. ophyd_async/epics/core/__init__.py +26 -0
  34. ophyd_async/epics/core/_aioca.py +323 -0
  35. ophyd_async/epics/core/_epics_connector.py +53 -0
  36. ophyd_async/epics/core/_epics_device.py +13 -0
  37. ophyd_async/epics/core/_p4p.py +383 -0
  38. ophyd_async/epics/core/_pvi_connector.py +91 -0
  39. ophyd_async/epics/core/_signal.py +171 -0
  40. ophyd_async/epics/core/_util.py +61 -0
  41. ophyd_async/epics/demo/_mover.py +4 -5
  42. ophyd_async/epics/demo/_sensor.py +14 -13
  43. ophyd_async/epics/eiger/_eiger.py +1 -2
  44. ophyd_async/epics/eiger/_eiger_controller.py +7 -2
  45. ophyd_async/epics/eiger/_eiger_io.py +3 -5
  46. ophyd_async/epics/eiger/_odin_io.py +5 -5
  47. ophyd_async/epics/motor.py +4 -5
  48. ophyd_async/epics/signal.py +11 -0
  49. ophyd_async/epics/testing/__init__.py +24 -0
  50. ophyd_async/epics/testing/_example_ioc.py +105 -0
  51. ophyd_async/epics/testing/_utils.py +78 -0
  52. ophyd_async/epics/testing/test_records.db +152 -0
  53. ophyd_async/epics/testing/test_records_pva.db +177 -0
  54. ophyd_async/fastcs/core.py +9 -0
  55. ophyd_async/fastcs/panda/__init__.py +4 -4
  56. ophyd_async/fastcs/panda/_block.py +18 -13
  57. ophyd_async/fastcs/panda/_control.py +3 -5
  58. ophyd_async/fastcs/panda/_hdf_panda.py +5 -19
  59. ophyd_async/fastcs/panda/_table.py +30 -52
  60. ophyd_async/fastcs/panda/_trigger.py +8 -8
  61. ophyd_async/fastcs/panda/_writer.py +2 -5
  62. ophyd_async/plan_stubs/_ensure_connected.py +20 -13
  63. ophyd_async/plan_stubs/_fly.py +2 -2
  64. ophyd_async/plan_stubs/_nd_attributes.py +5 -4
  65. ophyd_async/py.typed +0 -0
  66. ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py +1 -2
  67. ophyd_async/sim/demo/_sim_motor.py +3 -4
  68. ophyd_async/tango/__init__.py +0 -45
  69. ophyd_async/tango/{signal → core}/__init__.py +9 -6
  70. ophyd_async/tango/core/_base_device.py +132 -0
  71. ophyd_async/tango/{signal → core}/_signal.py +42 -53
  72. ophyd_async/tango/{base_devices → core}/_tango_readable.py +3 -4
  73. ophyd_async/tango/{signal → core}/_tango_transport.py +38 -40
  74. ophyd_async/tango/demo/_counter.py +12 -23
  75. ophyd_async/tango/demo/_mover.py +13 -13
  76. {ophyd_async-0.7.0.dist-info → ophyd_async-0.8.0.dist-info}/METADATA +52 -55
  77. ophyd_async-0.8.0.dist-info/RECORD +116 -0
  78. {ophyd_async-0.7.0.dist-info → ophyd_async-0.8.0.dist-info}/WHEEL +1 -1
  79. ophyd_async/epics/pvi/__init__.py +0 -3
  80. ophyd_async/epics/pvi/_pvi.py +0 -338
  81. ophyd_async/epics/signal/__init__.py +0 -21
  82. ophyd_async/epics/signal/_aioca.py +0 -378
  83. ophyd_async/epics/signal/_common.py +0 -57
  84. ophyd_async/epics/signal/_epics_transport.py +0 -34
  85. ophyd_async/epics/signal/_p4p.py +0 -518
  86. ophyd_async/epics/signal/_signal.py +0 -114
  87. ophyd_async/tango/base_devices/__init__.py +0 -4
  88. ophyd_async/tango/base_devices/_base_device.py +0 -225
  89. ophyd_async-0.7.0.dist-info/RECORD +0 -108
  90. {ophyd_async-0.7.0.dist-info → ophyd_async-0.8.0.dist-info}/LICENSE +0 -0
  91. {ophyd_async-0.7.0.dist-info → ophyd_async-0.8.0.dist-info}/entry_points.txt +0 -0
  92. {ophyd_async-0.7.0.dist-info → ophyd_async-0.8.0.dist-info}/top_level.txt +0 -0
ophyd_async/core/_status.py
@@ -13,6 +13,7 @@ from typing import (

 from bluesky.protocols import Status

+from ._device import Device
 from ._protocol import Watcher
 from ._utils import Callback, P, T, WatcherUpdate

@@ -23,13 +24,14 @@ WAS = TypeVar("WAS", bound="WatchableAsyncStatus")
 class AsyncStatusBase(Status):
     """Convert asyncio awaitable to bluesky Status interface"""

-    def __init__(self, awaitable: Coroutine | asyncio.Task):
+    def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None):
         if isinstance(awaitable, asyncio.Task):
             self.task = awaitable
         else:
             self.task = asyncio.create_task(awaitable)
         self.task.add_done_callback(self._run_callbacks)
         self._callbacks: list[Callback[Status]] = []
+        self._name = name

     def __await__(self):
         return self.task.__await__()
@@ -76,7 +78,11 @@ class AsyncStatusBase(Status):
             status = "done"
         else:
             status = "pending"
-        return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>"
+        device_str = f"device: {self._name}, " if self._name else ""
+        return (
+            f"<{type(self).__name__}, {device_str}"
+            f"task: {self.task.get_coro()}, {status}>"
+        )

     __str__ = __repr__

@@ -90,7 +96,11 @@ class AsyncStatus(AsyncStatusBase):

         @functools.wraps(f)
         def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
-            return cls(f(*args, **kwargs))
+            if args and isinstance(args[0], Device):
+                name = args[0].name
+            else:
+                name = None
+            return cls(f(*args, **kwargs), name=name)

         # type is actually functools._Wrapped[P, Awaitable, P, AS]
         # but functools._Wrapped is not necessarily available
@@ -100,11 +110,13 @@ class AsyncStatus(AsyncStatusBase):
 class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
     """Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

-    def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
+    def __init__(
+        self, iterator: AsyncIterator[WatcherUpdate[T]], name: str | None = None
+    ):
         self._watchers: list[Watcher] = []
         self._start = time.monotonic()
         self._last_update: WatcherUpdate[T] | None = None
-        super().__init__(self._notify_watchers_from(iterator))
+        super().__init__(self._notify_watchers_from(iterator), name)

     async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
         async for update in iterator:
@@ -136,7 +148,11 @@ class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):

         @functools.wraps(f)
         def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
-            return cls(f(*args, **kwargs))
+            if args and isinstance(args[0], Device):
+                name = args[0].name
+            else:
+                name = None
+            return cls(f(*args, **kwargs), name=name)

         return cast(Callable[P, WAS], wrap_f)

ophyd_async/core/_table.py
@@ -1,8 +1,13 @@
-from enum import Enum
-from typing import TypeVar, get_args, get_origin
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from typing import Annotated, Any, TypeVar, get_origin

 import numpy as np
-from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation
+
+from ._utils import get_dtype

 TableSubclass = TypeVar("TableSubclass", bound="Table")

@@ -14,37 +19,46 @@ def _concat(value1, value2):
     return value1 + value2


+def _make_default_factory(dtype: np.dtype) -> Callable[[], np.ndarray]:
+    def numpy_array_default_factory() -> np.ndarray:
+        return np.array([], dtype)
+
+    return numpy_array_default_factory
+
+
 class Table(BaseModel):
     """An abstraction of a Table of str to numpy array."""

-    model_config = ConfigDict(validate_assignment=True, strict=False)
-
-    @staticmethod
-    def row(cls: type[TableSubclass], **kwargs) -> TableSubclass:  # type: ignore
-        arrayified_kwargs = {}
-        for field_name, field_value in cls.model_fields.items():
-            value = kwargs.pop(field_name)
-            if field_value.default_factory is None:
-                raise ValueError(
-                    "`Table` models should have default factories for their "
-                    "mutable empty columns."
-                )
-            default_array = field_value.default_factory()
-            if isinstance(default_array, np.ndarray):
-                arrayified_kwargs[field_name] = np.array(
-                    [value], dtype=default_array.dtype
-                )
-            elif issubclass(type(value), Enum) and isinstance(value, str):
-                arrayified_kwargs[field_name] = [value]
+    # You can use Table in 2 ways:
+    # 1. Table(**whatever_pva_gives_us) when pvi adds a Signal to a Device that is not
+    #    type hinted
+    # 2. MyTable(**whatever_pva_gives_us) where the Signal is type hinted
+    #
+    # For 1 we want extra="allow" so it is passed through as is. There are no base class
+    # fields, only "extra" fields, so they must be allowed. For 2 we want extra="forbid"
+    # so it is strictly checked against the BaseModel we are supplied.
+    model_config = ConfigDict(extra="allow")
+
+    @classmethod
+    def __init_subclass__(cls):
+        # But forbit extra in subclasses so it gets validated
+        cls.model_config = ConfigDict(validate_assignment=True, extra="forbid")
+        # Change fields to have the correct annotations
+        for k, anno in cls.__annotations__.items():
+            if get_origin(anno) is np.ndarray:
+                dtype = get_dtype(anno)
+                new_anno = Annotated[
+                    anno,
+                    NpArrayPydanticAnnotation.factory(
+                        data_type=dtype.type, dimensions=1, strict_data_typing=False
+                    ),
+                    Field(default_factory=_make_default_factory(dtype)),
+                ]
+            elif get_origin(anno) is Sequence:
+                new_anno = Annotated[anno, Field(default_factory=list)]
             else:
-                raise TypeError(
-                    "Row column should be numpy arrays or sequence of string `Enum`."
-                )
-        if kwargs:
-            raise TypeError(
-                f"Unexpected keyword arguments {kwargs.keys()} for {cls.__name__}."
-            )
-        return cls(**arrayified_kwargs)
+                raise TypeError(f"Cannot use annotation {anno} in a Table")
+            cls.__annotations__[k] = new_anno

     def __add__(self, right: TableSubclass) -> TableSubclass:
         """Concatenate the arrays in field values."""
@@ -64,83 +78,71 @@ class Table(BaseModel):
             }
         )

+    def __eq__(self, value: object) -> bool:
+        return super().__eq__(value)
+
     def numpy_dtype(self) -> np.dtype:
         dtype = []
-        for field_name, field_value in self.model_fields.items():
-            if np.ndarray in (
-                get_origin(field_value.annotation),
-                field_value.annotation,
-            ):
-                dtype.append((field_name, getattr(self, field_name).dtype))
+        for k, v in self:
+            if isinstance(v, np.ndarray):
+                dtype.append((k, v.dtype))
             else:
-                enum_type = get_args(field_value.annotation)[0]
-                assert issubclass(enum_type, Enum)
-                enum_values = [element.value for element in enum_type]
-                max_length_in_enum = max(len(value) for value in enum_values)
-                dtype.append((field_name, np.dtype(f"<U{max_length_in_enum}")))
-
+                # TODO: use np.dtypes.StringDType when we can use in structured arrays
+                # https://github.com/numpy/numpy/issues/25693
+                dtype.append((k, np.dtype("S40")))
         return np.dtype(dtype)

-    def numpy_table(self):
-        # It would be nice to be able to use np.transpose for this,
-        # but it defaults to the largest dtype for everything.
-        dtype = self.numpy_dtype()
-        transposed_list = [
-            np.array(tuple(row), dtype=dtype)
-            for row in zip(*self.numpy_columns(), strict=False)
-        ]
-        transposed = np.array(transposed_list, dtype=dtype)
-        return transposed
-
-    def numpy_columns(self) -> list[np.ndarray]:
-        """Columns in the table can be lists of string enums or numpy arrays.
-
-        This method returns the columns, converting the string enums to numpy arrays.
-        """
-
-        columns = []
-        for field_name, field_value in self.model_fields.items():
-            if np.ndarray in (
-                get_origin(field_value.annotation),
-                field_value.annotation,
+    def numpy_table(self, selection: slice | None = None) -> np.ndarray:
+        array = None
+        for k, v in self:
+            if selection:
+                v = v[selection]
+            if array is None:
+                array = np.empty(v.shape, dtype=self.numpy_dtype())
+            array[k] = v
+        assert array is not None
+        return array
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_array_dtypes(cls, data: Any) -> Any:
+        if isinstance(data, dict):
+            data_dict = data
+        elif isinstance(data, Table):
+            data_dict = data.model_dump()
+        else:
+            raise AssertionError(f"Cannot construct Table from {data}")
+        for field_name, field_value in cls.model_fields.items():
+            if (
+                get_origin(field_value.annotation) is np.ndarray
+                and field_value.annotation
+                and field_name in data_dict
             ):
-                columns.append(getattr(self, field_name))
-            else:
-                enum_type = get_args(field_value.annotation)[0]
-                assert issubclass(enum_type, Enum)
-                enum_values = [element.value for element in enum_type]
-                max_length_in_enum = max(len(value) for value in enum_values)
-                dtype = np.dtype(f"<U{max_length_in_enum}")
-
-                columns.append(
-                    np.array(
-                        [enum.value for enum in getattr(self, field_name)], dtype=dtype
-                    )
+                data_value = data_dict[field_name]
+                expected_dtype = get_dtype(field_value.annotation)
+                # Convert to correct dtype, but only if we don't lose precision
+                # as a result
+                cast_value = np.array(data_value).astype(expected_dtype)
+                assert np.array_equal(data_value, cast_value), (
+                    f"{field_name}: Cannot cast {data_value} to {expected_dtype} "
+                    "without losing precision"
                 )
-
-        return columns
+                data_dict[field_name] = cast_value
+        return data_dict

     @model_validator(mode="after")
-    def validate_arrays(self) -> "Table":
-        first_length = len(next(iter(self))[1])
-        assert all(
-            len(field_value) == first_length for _, field_value in self
-        ), "Rows should all be of equal size."
-
-        if not all(
-            # Checks if the values are numpy subtypes if the array is a numpy array,
-            # or if the value is a string enum.
-            np.issubdtype(getattr(self, field_name).dtype, default_array.dtype)
-            if isinstance(
-                default_array := self.model_fields[field_name].default_factory(),  # type: ignore
-                np.ndarray,
-            )
-            else issubclass(get_args(field_value.annotation)[0], Enum)
-            for field_name, field_value in self.model_fields.items()
-        ):
-            raise ValueError(
-                f"Cannot construct a `{type(self).__name__}`, "
-                "some rows have incorrect types."
-            )
-
+    def validate_lengths(self) -> Table:
+        lengths: dict[int, set[str]] = {}
+        for field_name, field_value in self:
+            lengths.setdefault(len(field_value), set()).add(field_name)
+        assert len(lengths) <= 1, f"Columns should be same length, got {lengths=}"
         return self
+
+    def __len__(self) -> int:
+        return len(next(iter(self))[1])
+
+    def __getitem__(self, item: int | slice) -> np.ndarray:
+        if isinstance(item, int):
+            return self.numpy_table(slice(item, item + 1))
+        else:
+            return self.numpy_table(item)
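Continuing the hypothetical MyTable sketch above, the new __len__, __getitem__ and numpy_table() expose the table as rows of a structured numpy array:

t = MyTable(repeats=[1, 2], desc=["a", "b"])
print(len(t))        # 2: the length of the first column
print(t[0])          # one-row structured array via numpy_table(slice(0, 1))
print(t[0:2].dtype)  # roughly dtype([('repeats', '<u2'), ('desc', 'S40')])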
ophyd_async/core/_utils.py
@@ -2,23 +2,34 @@ from __future__ import annotations

 import asyncio
 import logging
-from collections.abc import Awaitable, Callable, Iterable
+from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import Generic, Literal, ParamSpec, TypeVar, get_origin
+from enum import Enum, EnumMeta
+from typing import Any, Generic, Literal, ParamSpec, TypeVar, get_args, get_origin
+from unittest.mock import Mock

 import numpy as np
-from bluesky.protocols import Reading
-from pydantic import BaseModel

 T = TypeVar("T")
 P = ParamSpec("P")
 Callback = Callable[[T], None]
-
-#: A function that will be called with the Reading and value when the
-#: monitor updates
-ReadingValueCallback = Callable[[Reading, T], None]
 DEFAULT_TIMEOUT = 10.0
-ErrorText = str | dict[str, Exception]
+ErrorText = str | Mapping[str, Exception]
+
+
+class StrictEnum(str, Enum):
+    """All members should exist in the Backend, and there will be no extras"""
+
+
+class SubsetEnumMeta(EnumMeta):
+    def __call__(self, value, *args, **kwargs):  # type: ignore
+        if isinstance(value, str) and not isinstance(value, self):
+            return value
+        return super().__call__(value, *args, **kwargs)
+
+
+class SubsetEnum(StrictEnum, metaclass=SubsetEnumMeta):
+    """All members should exist in the Backend, but there may be extras"""


 CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT"
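The split is the point of these two classes: a StrictEnum's members must match the backend's choices exactly, while SubsetEnumMeta lets any other string through __call__ unvalidated. A minimal sketch (TriggerSource is hypothetical):

from ophyd_async.core import SubsetEnum  # assuming the 0.8 re-export


class TriggerSource(SubsetEnum):  # hypothetical subset of a backend enum
    freerun = "Freerun"


print(TriggerSource("Line1"))              # "Line1": unknown strings pass through
print(TriggerSource.freerun == "Freerun")  # True: members are also str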
@@ -51,6 +62,13 @@ class NotConnected(Exception):

         self._errors = errors

+    @property
+    def sub_errors(self) -> Mapping[str, Exception]:
+        if isinstance(self._errors, dict):
+            return self._errors.copy()
+        else:
+            return {}
+
     def _format_sub_errors(self, name: str, error: Exception, indent="") -> str:
         if isinstance(error, NotConnected):
             error_txt = ":" + error.format_error_string(indent + self._indent_width)
@@ -81,6 +99,19 @@ class NotConnected(Exception):
     def __str__(self) -> str:
         return self.format_error_string(indent="")

+    @classmethod
+    def with_other_exceptions_logged(
+        cls, exceptions: Mapping[str, Exception]
+    ) -> NotConnected:
+        for name, exception in exceptions.items():
+            if not isinstance(exception, NotConnected):
+                logging.exception(
+                    f"device `{name}` raised unexpected exception "
+                    f"{type(exception).__name__}",
+                    exc_info=exception,
+                )
+        return NotConnected(exceptions)
+

 @dataclass(frozen=True)
 class WatcherUpdate(Generic[T]):
@@ -102,24 +133,41 @@ async def wait_for_connection(**coros: Awaitable[None]):

     Expected kwargs should be a mapping of names to coroutine tasks to execute.
     """
-    results = await asyncio.gather(*coros.values(), return_exceptions=True)
-    exceptions = {}
-
-    for name, result in zip(coros, results, strict=False):
-        if isinstance(result, Exception):
-            exceptions[name] = result
-            if not isinstance(result, NotConnected):
-                logging.exception(
-                    f"device `{name}` raised unexpected exception "
-                    f"{type(result).__name__}",
-                    exc_info=result,
-                )
+    exceptions: dict[str, Exception] = {}
+    if len(coros) == 1:
+        # Single device optimization
+        name, coro = coros.popitem()
+        try:
+            await coro
+        except Exception as e:
+            exceptions[name] = e
+    else:
+        # Use gather to connect in parallel
+        results = await asyncio.gather(*coros.values(), return_exceptions=True)
+        for name, result in zip(coros, results, strict=False):
+            if isinstance(result, Exception):
+                exceptions[name] = result

     if exceptions:
-        raise NotConnected(exceptions)
+        raise NotConnected.with_other_exceptions_logged(exceptions)


-def get_dtype(typ: type) -> np.dtype | None:
+def get_dtype(datatype: type) -> np.dtype:
+    """Get the runtime dtype from a numpy ndarray type annotation
+
+    >>> from ophyd_async.core import Array1D
+    >>> import numpy as np
+    >>> get_dtype(Array1D[np.int8])
+    dtype('int8')
+    """
+    if not get_origin(datatype) == np.ndarray:
+        raise TypeError(f"Expected Array1D[dtype], got {datatype}")
+    # datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]
+    # so extract numpy.float64 from it
+    return np.dtype(get_args(get_args(datatype)[1])[0])
+
+
+def get_enum_cls(datatype: type | None) -> type[StrictEnum] | None:
     """Get the runtime dtype from a numpy ndarray type annotation

     >>> import numpy.typing as npt
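How the reworked wait_for_connection behaves end to end, as a minimal sketch: a single kwarg takes the direct-await fast path, several go through gather, and failures surface via the new classmethod and sub_errors property (ok/bad are hypothetical stand-ins for Device.connect coroutines):

import asyncio

from ophyd_async.core import NotConnected, wait_for_connection


async def ok() -> None:
    pass


async def bad() -> None:
    raise NotConnected("no PV found")


async def main():
    await wait_for_connection(solo=ok())  # one coroutine: awaited directly
    try:
        await wait_for_connection(a=ok(), b=bad())  # parallel via gather
    except NotConnected as e:
        print(list(e.sub_errors))  # ['b']


asyncio.run(main())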
@@ -127,11 +175,15 @@ def get_dtype(typ: type) -> np.dtype | None:
     >>> get_dtype(npt.NDArray[np.int8])
     dtype('int8')
     """
-    if getattr(typ, "__origin__", None) == np.ndarray:
-        # datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]
-        # so extract numpy.float64 from it
-        return np.dtype(typ.__args__[1].__args__[0])  # type: ignore
-    return None
+    if get_origin(datatype) is Sequence:
+        datatype = get_args(datatype)[0]
+    if datatype and issubclass(datatype, Enum):
+        if not issubclass(datatype, StrictEnum):
+            raise TypeError(
+                f"{datatype} should inherit from .SubsetEnum "
+                "or ophyd_async.core.StrictEnum"
+            )
+        return datatype


 def get_unique(values: dict[str, T], types: str) -> T:
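What get_enum_cls accepts and rejects, as a minimal sketch (imported here from the private _utils module this diff modifies; Colour and Plain are hypothetical):

from collections.abc import Sequence
from enum import Enum

from ophyd_async.core._utils import StrictEnum, get_enum_cls


class Colour(StrictEnum):
    red = "Red"


print(get_enum_cls(Colour))            # <enum 'Colour'>
print(get_enum_cls(Sequence[Colour]))  # <enum 'Colour'>: Sequence is unwrapped
print(get_enum_cls(int))               # None: not an Enum at all


class Plain(Enum):
    a = "A"


try:
    get_enum_cls(Plain)  # plain Enums are rejected with a TypeError
except TypeError as e:
    print(e)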
@@ -187,7 +239,66 @@ def in_micros(t: float) -> int:
     return int(np.ceil(t * 1e6))


-def is_pydantic_model(datatype) -> bool:
-    while origin := get_origin(datatype):
-        datatype = origin
-    return datatype and issubclass(datatype, BaseModel)
+def get_origin_class(annotatation: Any) -> type | None:
+    origin = get_origin(annotatation) or annotatation
+    if isinstance(origin, type):
+        return origin
+
+
+class Reference(Generic[T]):
+    """Hide an object behind a reference.
+
+    Used to opt out of the naming/parent-child relationship of `Device`.
+
+    For example::
+
+        class DeviceWithRefToSignal(Device):
+            def __init__(self, signal: SignalRW[int]):
+                self.signal_ref = Reference(signal)
+                super().__init__()
+
+            def set(self, value) -> AsyncStatus:
+                return self.signal_ref().set(value + 1)
+
+    """
+
+    def __init__(self, obj: T):
+        self._obj = obj
+
+    def __call__(self) -> T:
+        return self._obj
+
+
+class LazyMock:
+    """A lazily created Mock to be used when connecting in mock mode.
+
+    Creating Mocks is reasonably expensive when each Device (and Signal)
+    requires its own, and the tree is only used when ``Signal.set()`` is
+    called. This class allows a tree of lazily connected Mocks to be
+    constructed so that when the leaf is created, so are its parents.
+    Any calls to the child are then accessible from the parent mock.
+
+    >>> parent = LazyMock()
+    >>> child = parent.child("child")
+    >>> child_mock = child()
+    >>> child_mock()  # doctest: +ELLIPSIS
+    <Mock name='mock.child()' id='...'>
+    >>> parent_mock = parent()
+    >>> parent_mock.mock_calls
+    [call.child()]
+    """
+
+    def __init__(self, name: str = "", parent: LazyMock | None = None) -> None:
+        self.parent = parent
+        self.name = name
+        self._mock: Mock | None = None
+
+    def child(self, name: str) -> LazyMock:
+        return LazyMock(name, self)
+
+    def __call__(self) -> Mock:
+        if self._mock is None:
+            self._mock = Mock(spec=object)
+            if self.parent is not None:
+                self.parent().attach_mock(self._mock, self.name)
+        return self._mock
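A runnable version of the Reference docstring example above (assuming Reference, soft_signal_rw and AsyncStatus are all importable from ophyd_async.core in 0.8): the signal is held behind the reference, so Device introspection does not claim it as a child.

import asyncio

from ophyd_async.core import AsyncStatus, Device, Reference, soft_signal_rw


class Adder(Device):  # hypothetical device holding a non-child signal
    def __init__(self, signal, name=""):
        self._signal_ref = Reference(signal)
        super().__init__(name=name)

    @AsyncStatus.wrap
    async def bump(self) -> None:
        sig = self._signal_ref()
        await sig.set(await sig.get_value() + 1)


async def main():
    sig = soft_signal_rw(int, initial_value=41)
    await sig.connect()
    await Adder(sig, name="adder").bump()
    print(await sig.get_value())  # 42


asyncio.run(main())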
ophyd_async/epics/adaravis/_aravis_controller.py
@@ -2,12 +2,12 @@ import asyncio
 from typing import Literal

 from ophyd_async.core import (
+    AsyncStatus,
     DetectorController,
     DetectorTrigger,
     TriggerInfo,
     set_and_wait_for_value,
 )
-from ophyd_async.core._status import AsyncStatus
 from ophyd_async.epics import adcore

 from ._aravis_io import AravisDriverIO, AravisTriggerMode, AravisTriggerSource
@@ -69,7 +69,7 @@ class AravisController(DetectorController):
                 f"use {trigger}"
             )
         if trigger == DetectorTrigger.internal:
-            return AravisTriggerMode.off, "Freerun"
+            return AravisTriggerMode.off, AravisTriggerSource.freerun
         else:
             return (AravisTriggerMode.on, f"Line{self.gpio_number}")  # type: ignore

ophyd_async/epics/adaravis/_aravis_io.py
@@ -1,11 +1,9 @@
-from enum import Enum
-
-from ophyd_async.core import SubsetEnum
+from ophyd_async.core import StrictEnum, SubsetEnum
 from ophyd_async.epics import adcore
-from ophyd_async.epics.signal import epics_signal_rw_rbv
+from ophyd_async.epics.core import epics_signal_rw_rbv


-class AravisTriggerMode(str, Enum):
+class AravisTriggerMode(StrictEnum):
     """GigEVision GenICAM standard: on=externally triggered"""

     on = "On"
@@ -19,7 +17,11 @@ class AravisTriggerMode(str, Enum):
 To prevent requiring one Enum class per possible configuration, we set as this Enum
 but read from the underlying signal as a str.
 """
-AravisTriggerSource = SubsetEnum["Freerun", "Line1"]
+
+
+class AravisTriggerSource(SubsetEnum):
+    freerun = "Freerun"
+    line1 = "Line1"


 class AravisDriverIO(adcore.ADBaseIO):
ophyd_async/epics/adcore/_core_io.py
@@ -1,7 +1,5 @@
-from enum import Enum
-
-from ophyd_async.core import Device
-from ophyd_async.epics.signal import (
+from ophyd_async.core import Device, StrictEnum
+from ophyd_async.epics.core import (
     epics_signal_r,
     epics_signal_rw,
     epics_signal_rw_rbv,
@@ -10,7 +8,7 @@ from ophyd_async.epics.signal import (
 from ._utils import ADBaseDataType, FileWriteMode, ImageMode


-class Callback(str, Enum):
+class Callback(StrictEnum):
     Enable = "Enable"
     Disable = "Disable"

@@ -68,7 +66,7 @@ class NDPluginStatsIO(NDPluginBaseIO):
         super().__init__(prefix, name)


-class DetectorState(str, Enum):
+class DetectorState(StrictEnum):
     """
     Default set of states of an AreaDetector driver.
     See definition in ADApp/ADSrc/ADDriver.h in https://github.com/areaDetector/ADCore
@@ -100,7 +98,7 @@ class ADBaseIO(NDArrayBaseIO):
         super().__init__(prefix, name=name)


-class Compression(str, Enum):
+class Compression(StrictEnum):
     none = "None"
     nbit = "N-bit"
     szip = "szip"
ophyd_async/epics/adcore/_core_logic.py
@@ -91,7 +91,9 @@ async def start_acquiring_driver_and_ensure_status(
     subsequent raising (if applicable) due to detector state.
     """

-    status = await set_and_wait_for_value(driver.acquire, True, timeout=timeout)
+    status = await set_and_wait_for_value(
+        driver.acquire, True, timeout=timeout, wait_for_set_completion=False
+    )

     async def complete_acquisition() -> None:
         """NOTE: possible race condition here between the callback from
ophyd_async/epics/adcore/_hdf_writer.py
@@ -134,9 +134,9 @@ class ADHDFWriter(DetectorWriter):
         describe = {
             ds.data_key: DataKey(
                 source=self.hdf.full_file_name.source,
-                shape=outer_shape + tuple(ds.shape),
+                shape=list(outer_shape + tuple(ds.shape)),
                 dtype="array" if ds.shape else "number",
-                dtype_numpy=ds.dtype_numpy,  # type: ignore
+                dtype_numpy=ds.dtype_numpy,
                 external="STREAM:",
             )
             for ds in self._datasets