pyglaze 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,165 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Literal
5
+
6
+ import numpy as np
7
+ from scipy.interpolate import CubicSpline
8
+
9
+ from .pulse import Pulse
10
+
11
+ if TYPE_CHECKING:
12
+ from datetime import datetime
13
+
14
+ from pyglaze.helpers.types import FloatArray
15
+
16
__all__ = ["UnprocessedWaveform"]

# Names accepted by ``UnprocessedWaveform.reconstruct``.
RecoMethod = Literal["cubic_spline"]


@dataclass(frozen=True)
class UnprocessedWaveform:
    """A dataclass representing an unprocessed waveform. No assumptions are made about the delay or signal.

    Args:
        time: The time values recorded by the lock-in amp during the scan.
        signal: The signal values recorded by the lock-in amp during the scan.
    """

    time: FloatArray
    signal: FloatArray

    @classmethod
    def from_polar_coords(
        cls: type[UnprocessedWaveform],
        time: FloatArray,
        radius: FloatArray,
        theta: FloatArray,
        rotation_angle: float | None = None,
    ) -> UnprocessedWaveform:
        """Create an UnprocessedWaveform object from raw lock-in amp output.

        The phase is rotated by ``rotation_angle`` and the projection
        ``radius * cos(theta - rotation_angle)`` is kept as the signal.

        Args:
            time: The time values recorded by the lock-in amp during the scan.
            radius: The radius values recorded by the lock-in amp during the scan.
            theta: The theta values recorded by the lock-in amp during the scan (in degrees).
            rotation_angle: The angle (degrees) to rotate lockin signal to align along x-axis.
                If not given, will use the angle at the maximum value of R.

        Returns:
            A waveform (of type ``cls``) with the given ``time`` and the projected signal.
        """
        rot_ang = theta[np.argmax(radius)] if rotation_angle is None else rotation_angle

        # rotate such that all signal lies along X
        new_theta = theta - rot_ang
        signal = radius * np.cos(new_theta * np.pi / 180.0)
        return cls(time, signal)

    @classmethod
    def from_dict(
        cls: type[UnprocessedWaveform], d: dict[str, FloatArray | list[float] | None]
    ) -> UnprocessedWaveform:
        """Create an UnprocessedWaveform object from a dictionary.

        Args:
            d: A dictionary containing the keys 'time', 'signal'.

        Returns:
            A waveform (of type ``cls``) whose fields are ``np.array`` copies of the values.

        Raises:
            KeyError: If 'time' or 'signal' is missing from ``d``.
        """
        # Use ``cls`` (not the hard-coded class) so subclasses round-trip correctly.
        return cls(time=np.array(d["time"]), signal=np.array(d["signal"]))

    def reconstruct(
        self: UnprocessedWaveform, method: RecoMethod, times: FloatArray | None = None
    ) -> UnprocessedWaveform:
        """Resample the waveform at the given times using the given method.

        If no times are given, ``len(self.time)`` evenly spaced times from the
        first recorded time (``self.time[0]``) to the last (``self.time[-1]``),
        both included, are used.

        Args:
            method: Name of reconstruction method.
            times: Optional array of delay times.

        Returns:
            A new waveform sampled at ``times``.

        Raises:
            ValueError: When an unknown reconstruction method is requested
        """
        if times is None:
            times = np.linspace(
                self.time[0], self.time[-1], len(self.time), endpoint=True
            )

        if method == "cubic_spline":
            return UnprocessedWaveform(
                times, CubicSpline(self.time, self.signal)(times)
            )

        msg = f"Unknown reconstruction method: {method}"
        raise ValueError(msg)

    @classmethod
    def average(
        cls: type[UnprocessedWaveform], waveforms: list[UnprocessedWaveform]
    ) -> UnprocessedWaveform:
        """Computes the average of a list of UnprocessedWaveform objects.

        Args:
            waveforms: List of waveforms

        Returns:
            The single waveform itself if only one is given, otherwise a new
            waveform (of type ``cls``) with the point-wise mean signal.

        Raises:
            ValueError: If ``waveforms`` is empty.
        """
        if not waveforms:
            msg = "Cannot average an empty list of waveforms"
            raise ValueError(msg)
        if len(waveforms) == 1:
            return waveforms[0]
        signals = np.array([waveform.signal for waveform in waveforms])
        # NOTE: the time axis of the first waveform is reused; the waveforms are
        # assumed to share the same sampling times (not checked here).
        return cls(time=waveforms[0].time, signal=np.mean(signals, axis=0))

    def from_triangular_waveform(
        self: UnprocessedWaveform, ramp: Literal["up", "down"]
    ) -> UnprocessedWaveform:
        """Picks out the pulse from a scan with fiberstretchers driven by a triangular waveform.

        The ramps are located from the positions of the minimum and maximum
        time. A ramp split across the end and start of the recording is
        stitched back together, and the down ramp is flipped so the returned
        time axis is increasing.

        Args:
            ramp: Whether to pick out the pulse from the upgoing or downgoing ramp of the triangle wave

        Raises:
            ValueError: If 'ramp' is neither 'up' or 'down'

        Returns:
            Raw waveform
        """
        argmax = np.argmax(self.time)
        argmin = np.argmin(self.time)
        min_before_max = argmin < argmax
        if ramp == "up" and min_before_max:  # down up down
            t = self.time[argmin : argmax + 1]
            s = self.signal[argmin : argmax + 1]
        elif ramp == "up" and not min_before_max:  # up down up
            t = np.concatenate((self.time[argmin:], self.time[: argmax + 1]))
            s = np.concatenate((self.signal[argmin:], self.signal[: argmax + 1]))
        elif ramp == "down" and min_before_max:  # down up down
            t = np.flip(np.concatenate((self.time[argmax:], self.time[: argmin + 1])))
            s = np.flip(
                np.concatenate((self.signal[argmax:], self.signal[: argmin + 1]))
            )
        elif ramp == "down" and not min_before_max:  # up down up
            t = np.flip(self.time[argmax : argmin + 1])
            s = np.flip(self.signal[argmax : argmin + 1])
        else:
            msg = "'ramp' must be either 'up' or 'down'"
            raise ValueError(msg)

        return UnprocessedWaveform(time=t, signal=s)

    def as_pulse(self: UnprocessedWaveform) -> Pulse:
        """Converts the current waveform to a Pulse object."""
        return Pulse(time=self.time, signal=self.signal)
153
+
154
+
155
+ @dataclass
156
+ class _TimestampedWaveform:
157
+ """A data class representing a terahertz pulse with a timestamp.
158
+
159
+ Args:
160
+ timestamp: The timestamp of the pulse given by the Toptica server.
161
+ waveform: The terahertz pulse received from the Toptica server.
162
+ """
163
+
164
+ timestamp: datetime
165
+ waveform: UnprocessedWaveform
@@ -0,0 +1,15 @@
1
+ from .configuration import ForceDeviceConfiguration, Interval, LeDeviceConfiguration
2
+ from .delayunit import NonuniformDelay, UniformDelay, list_delayunits, load_delayunit
3
+ from .identifiers import get_device_id, list_devices
4
+
5
# Public API of the device package, grouped by topic (order preserved).
__all__ = [
    # device configurations
    "LeDeviceConfiguration",
    "ForceDeviceConfiguration",
    "Interval",
    # delay units
    "NonuniformDelay",
    "UniformDelay",
    "list_delayunits",
    "load_delayunit",
    # device discovery / identification
    "get_device_id",
    "list_devices",
]
@@ -0,0 +1,447 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import logging
5
+ import time
6
+ from dataclasses import dataclass, field
7
+ from enum import Enum
8
+ from functools import cached_property
9
+ from math import modf
10
+ from typing import TYPE_CHECKING, ClassVar, overload
11
+
12
+ import numpy as np
13
+ import serial
14
+ from bitstring import BitArray
15
+ from serial import serialutil
16
+
17
+ from pyglaze.device.configuration import (
18
+ DeviceConfiguration,
19
+ ForceDeviceConfiguration,
20
+ Interval,
21
+ LeDeviceConfiguration,
22
+ )
23
+ from pyglaze.device.delayunit import Delay, load_delayunit
24
+ from pyglaze.devtools.mock_device import _mock_device_factory
25
+ from pyglaze.helpers.utilities import LOGGER_NAME, _BackoffRetry
26
+
27
+ if TYPE_CHECKING:
28
+ from pyglaze.devtools.mock_device import (
29
+ ForceMockDevice,
30
+ LeMockDevice,
31
+ MockDevice,
32
+ )
33
+ from pyglaze.helpers.types import FloatArray
34
+
35
+
36
@dataclass
class _ForceAmpCom:
    """Serial driver for the Force lock-in amplifier.

    Every command is a text string terminated by a carriage return. Settings
    commands are acknowledged with a line starting with ``OK_RESPONSE``; any
    other reply makes :meth:`_get_response` raise ``serialutil.SerialException``.

    Args:
        config: Device configuration (timing, sweep length, waveform,
            modulation voltages, DAC bounds and scan intervals).
        CONT_SCAN_UPDATE_FREQ: Amount of data, in seconds, collected per call to
            :meth:`read_continuous_data`. Despite its name it is used as a
            period, not a frequency.
    """

    config: ForceDeviceConfiguration
    CONT_SCAN_UPDATE_FREQ: float = 1  # seconds
    # Name-mangled serial handle; created in __post_init__, never passed in.
    __ser: ForceMockDevice | serial.Serial = field(init=False)

    ENCODING: ClassVar[str] = "utf-8"
    OK_RESPONSE: ClassVar[str] = "!A,OK"
    # Number of lookup-table entries written by write_list.
    N_POINTS: ClassVar[int] = 10000
    # Full-scale DAC code; the configured dac_lower/upper_bound are divided by it.
    DAC_BITWIDTH: ClassVar[int] = 65535  # bit-width of amp DAC
    # DO NOT change - antennas will break.
    MIN_ALLOWED_MOD_VOLTAGE: ClassVar[float] = -1.0
    MAX_ALLOWED_MOD_VOLTAGE: ClassVar[float] = 0.5

    @cached_property
    def scanning_points(self: _ForceAmpCom) -> int:
        """Number of points acquired during one sweep.

        One point takes ``integration_periods / modulation_frequency`` seconds;
        this is the number of whole points that fit in ``sweep_length_ms``.
        """
        time_pr_point = (
            self.config.integration_periods / self.config.modulation_frequency
        )
        return int(self.config.sweep_length_ms * 1e-3 / time_pr_point)

    @cached_property
    def _scan_intervals(self: _ForceAmpCom) -> list[Interval]:
        """Configured scan intervals, falling back to the full range [0, 1].

        Shared by :attr:`_squished_intervals` and :attr:`times` so both agree
        when no intervals are configured.
        """
        return self.config.scan_intervals or [Interval(lower=0.0, upper=1.0)]

    @cached_property
    def _squished_intervals(self: _ForceAmpCom) -> list[Interval]:
        """Intervals squished into effective DAC range."""
        return _squish_intervals(
            intervals=self._scan_intervals,
            lower_bound=self.config.dac_lower_bound,
            upper_bound=self.config.dac_upper_bound,
            bitwidth=self.DAC_BITWIDTH,
        )

    @cached_property
    def times(self: _ForceAmpCom) -> FloatArray:
        """Delay times of the ``scanning_points`` points of one sweep.

        Points are split between the intervals in proportion to their length,
        and each interval is converted to delay times by the configured delay
        unit.
        """
        return _delay_from_intervals(
            delayunit=load_delayunit(self.config.delayunit),
            # Must use the same (possibly defaulted) intervals as the point split;
            # the raw config value may be empty/None.
            intervals=self._scan_intervals,
            points_per_interval=_points_per_interval(
                self.scanning_points, self._squished_intervals
            ),
        )

    @cached_property
    def scanning_list(self: _ForceAmpCom) -> list[float]:
        """Lookup-table values: ``N_POINTS`` normalized DAC positions.

        Each squished interval receives a share of ``N_POINTS`` proportional to
        its length, sampled uniformly with the upper end excluded.
        """
        scanning_list: list[float] = []
        for interval, n_points in zip(
            self._squished_intervals,
            _points_per_interval(self.N_POINTS, self._squished_intervals),
        ):
            scanning_list.extend(
                np.linspace(interval.lower, interval.upper, n_points, endpoint=False)
            )

        return scanning_list

    @property
    def datapoints_per_update(self: _ForceAmpCom) -> int:
        """Number of points read per :meth:`read_continuous_data` call."""
        return int(
            self.CONT_SCAN_UPDATE_FREQ
            / (self.config.integration_periods / self.config.modulation_frequency)
        )

    def __post_init__(self: _ForceAmpCom) -> None:
        self.__ser = _serial_factory(self.config)

    def __del__(self: _ForceAmpCom) -> None:
        """Closes connection when class instance goes out of scope."""
        with contextlib.suppress(AttributeError):
            # If the serial device does not exist, self.__ser is never created - hence catch
            self.__ser.close()

    def write_all(self: _ForceAmpCom) -> list[str]:
        """Send the complete configuration and return every response, in order."""
        return [
            self.write_period_and_frequency(),
            self.write_sweep_length(),
            self.write_waveform(),
            self.write_modulation_voltage(),
            *self.write_list(),
        ]

    def write_period_and_frequency(self: _ForceAmpCom) -> str:
        """Send integration periods and modulation frequency; return the ack."""
        s = f"!set timing,{self.config.integration_periods},{self.config.modulation_frequency}\r"
        return self._encode_send_response(s)

    def write_sweep_length(self: _ForceAmpCom) -> str:
        """Send the sweep length (ms); return the ack."""
        s = f"!set sweep length,{self.config.sweep_length_ms}\r"
        return self._encode_send_response(s)

    def write_waveform(self: _ForceAmpCom) -> str:
        """Send the modulation waveform setting; return the ack."""
        s = f"!set wave,{self.config.modulation_waveform}\r"
        return self._encode_send_response(s)

    def write_modulation_voltage(self: _ForceAmpCom) -> str:
        """Send the modulation voltage range after checking it is allowed.

        Raises:
            ValueError: If either voltage lies outside
                ``[MIN_ALLOWED_MOD_VOLTAGE, MAX_ALLOWED_MOD_VOLTAGE]``.
        """
        min_v = self.config.min_modulation_voltage
        max_v = self.config.max_modulation_voltage
        crit1 = self.MIN_ALLOWED_MOD_VOLTAGE <= min_v <= self.MAX_ALLOWED_MOD_VOLTAGE
        crit2 = self.MIN_ALLOWED_MOD_VOLTAGE <= max_v <= self.MAX_ALLOWED_MOD_VOLTAGE

        if crit1 and crit2:
            s = f"!set generator,{min_v},{max_v}\r"
            return self._encode_send_response(s)

        msg = f"Modulation voltages min: {min_v:.1f}, max: {max_v:.1f} not allowed."
        raise ValueError(msg)

    def write_list(self: _ForceAmpCom) -> list[str]:
        """Write every ``scanning_list`` entry to the lookup table.

        A single response line is read after all entries are sent and split
        on carriage returns.
        """
        for iteration, entry in enumerate(self.scanning_list):
            string = f"!lut,{iteration},{entry}\r"
            self._encode_and_send(string)
        return self._get_response().split("\r")

    def start_scan(self: _ForceAmpCom) -> tuple[str, np.ndarray]:
        """Run one sweep and collect its data.

        Returns:
            The command sent and a ``(scanning_points, 3)`` array whose columns
            are delay time, radius and angle. Rows are filled from ``!R``
            entries until a ``!D`` entry is seen.
        """
        start_command = "!s,\r"
        self._encode_and_send(start_command)
        responses = self._get_response().split("\r")
        output_array = np.zeros((self.scanning_points, 3))
        output_array[:, 0] = self.times
        iteration = 0
        for entry in responses:
            if "!R" in entry:
                radius, angle = self._format_output(entry)
                output_array[iteration, 1] = radius
                output_array[iteration, 2] = angle
                iteration += 1
            elif "!D" in entry:
                break
        return start_command, output_array

    def start_continuous_scan(self: _ForceAmpCom) -> tuple[str, list[str]]:
        """Enable continuous output; return the command and both acknowledgements."""
        start_command = "!dat,1\r"
        self._encode_and_send(start_command)
        # Call self._read_until() twice, because amp returns !A,OK twice for
        # continuous output (for some unknown reason)
        responses = [self._read_until(expected=b"\r") for _ in range(2)]
        return start_command, responses

    def stop_continuous_scan(self: _ForceAmpCom) -> tuple[str, str]:
        """Disable continuous output; return the command and the response read up to ``!A,OK``."""
        stop_command = "!dat,0\r"
        self._encode_and_send(stop_command)
        response = self._read_until(expected=b"!A,OK\r")
        return stop_command, response

    def read_continuous_data(self: _ForceAmpCom) -> FloatArray:
        """Read ``datapoints_per_update`` lines of continuous output.

        Returns:
            A ``(datapoints_per_update, 3)`` array; column 0 is a placeholder
            axis ``linspace(0, 1)``, columns 1 and 2 are radius and angle.
        """
        output_array = np.zeros((self.datapoints_per_update, 3))
        output_array[:, 0] = np.linspace(0, 1, self.datapoints_per_update)
        for iteration in range(self.datapoints_per_update):
            amp_output = self._read_until(expected=b"\r")
            radius, angle = self._format_output(amp_output)
            output_array[iteration, 1] = radius
            output_array[iteration, 2] = angle
        return output_array

    def _encode_send_response(self: _ForceAmpCom, command: str) -> str:
        """Send ``command`` and return the validated response line."""
        self._encode_and_send(command)
        return self._get_response()

    def _encode_and_send(self: _ForceAmpCom, command: str) -> None:
        """Encode ``command`` with ``ENCODING`` and write it to the port."""
        self.__ser.write(command.encode(self.ENCODING))

    # Presumably retried with backoff by _BackoffRetry on failure - confirm in helpers.
    @_BackoffRetry(backoff_base=0.2, logger=logging.getLogger(LOGGER_NAME))
    def _get_response(self: _ForceAmpCom) -> str:
        """Read one line and check it starts with ``OK_RESPONSE``.

        Raises:
            serialutil.SerialException: If the line does not start with ``OK_RESPONSE``.
        """
        r = self.__ser.readline().decode(self.ENCODING).strip()
        if r[: len(self.OK_RESPONSE)] != self.OK_RESPONSE:
            msg = f"Expected response '{self.OK_RESPONSE}', received: '{r}'"
            raise serialutil.SerialException(msg)

        return r

    def _read_until(self: _ForceAmpCom, expected: bytes) -> str:
        """Read up to and including ``expected``; return the stripped text."""
        return self.__ser.read_until(expected=expected).decode(self.ENCODING).strip()

    def _format_output(self: _ForceAmpCom, amp_output: str) -> tuple[float, float]:
        """Format output from Force LIA to radius and angle.

        The line is split on commas; fields 1 and 2 are radius and angle.
        """
        response_list = amp_output.split(",")
        return float(response_list[1]), float(response_list[2])
210
+
211
+
212
@dataclass
class _LeAmpCom:
    """Serial driver for the LE lock-in amplifier.

    Commands are single characters (the ``*_COMMAND`` constants). Bulk data
    (settings and the scanning list) follows as 16-bit little-endian unsigned
    integers, and scan results are read back as little-endian 32-bit floats.

    Args:
        config: Device configuration (number of points, integration periods,
            EMA flag, DAC bounds, delay unit and scan intervals).
    """

    config: LeDeviceConfiguration

    # Name-mangled serial handle; created in __post_init__, never passed in.
    __ser: serial.Serial | LeMockDevice = field(init=False)

    ENCODING: ClassVar[str] = "utf-8"
    DAC_BITWIDTH: ClassVar[int] = 4096  # 12-bit DAC

    OK_RESPONSE: ClassVar[str] = "ACK"
    START_COMMAND: ClassVar[str] = "G"
    FETCH_COMMAND: ClassVar[str] = "R"
    STATUS_COMMAND: ClassVar[str] = "H"
    SEND_LIST_COMMAND: ClassVar[str] = "L"
    SEND_SETTINGS_COMMAND: ClassVar[str] = "S"

    @cached_property
    def scanning_points(self: _LeAmpCom) -> int:
        """Number of points per sweep, taken directly from the configuration."""
        return self.config.n_points

    @cached_property
    def _scan_intervals(self: _LeAmpCom) -> list[Interval]:
        """Configured scan intervals, falling back to the full range [0, 1].

        Shared by :attr:`_squished_intervals` and :attr:`times` so both agree
        when no intervals are configured.
        """
        return self.config.scan_intervals or [Interval(lower=0.0, upper=1.0)]

    @cached_property
    def times(self: _LeAmpCom) -> FloatArray:
        """Delay times of the ``scanning_points`` points of one sweep.

        Points are split between the intervals in proportion to their length,
        and each interval is converted to delay times by the configured delay
        unit.
        """
        return _delay_from_intervals(
            delayunit=load_delayunit(self.config.delayunit),
            # Must use the same (possibly defaulted) intervals as the point split;
            # the raw config value may be empty/None.
            intervals=self._scan_intervals,
            points_per_interval=_points_per_interval(
                self.scanning_points, self._squished_intervals
            ),
        )

    @cached_property
    def scanning_list(self: _LeAmpCom) -> list[int]:
        """Integer DAC codes for every scanning point.

        Each squished interval is converted to DAC codes and sampled uniformly
        (upper end excluded); values are truncated with ``astype(int)``.
        """
        scanning_list: list[int] = []
        for interval, n_points in zip(
            self._squished_intervals,
            _points_per_interval(self.scanning_points, self._squished_intervals),
        ):
            denormalized = self._denormalize_interval(interval)
            scanning_list.extend(
                np.linspace(
                    denormalized[0], denormalized[1], n_points, endpoint=False
                ).astype(int),
            )
        return scanning_list

    def __post_init__(self: _LeAmpCom) -> None:
        self.__ser = _serial_factory(self.config)

    def __del__(self: _LeAmpCom) -> None:
        """Closes connection when class instance goes out of scope."""
        with contextlib.suppress(AttributeError):
            # If the serial device does not exist, self.__ser is never created - hence catch
            self.__ser.close()

    def write_all(self: _LeAmpCom) -> list[str]:
        """Send settings and scanning list; return both final responses."""
        return [
            self.write_list_length_and_integration_periods_and_use_ema(),
            self.write_list(),
        ]

    def write_list_length_and_integration_periods_and_use_ema(self: _LeAmpCom) -> str:
        """Send the settings command followed by the three raw settings values.

        The acknowledgement to the command itself is discarded; the response
        read after the raw values is returned.
        """
        self._encode_send_response(self.SEND_SETTINGS_COMMAND)
        self._raw_byte_send(
            [self.scanning_points, self.config.integration_periods, self.config.use_ema]
        )
        return self._get_response()

    def write_list(self: _LeAmpCom) -> str:
        """Send the list command followed by ``scanning_list`` as raw values."""
        self._encode_send_response(self.SEND_LIST_COMMAND)
        self._raw_byte_send(self.scanning_list)
        return self._get_response()

    def start_scan(self: _LeAmpCom) -> tuple[str, np.ndarray]:
        """Run one sweep, wait for it to finish and fetch the data.

        Returns:
            ``START_COMMAND`` and a ``(scanning_points, 3)`` array whose columns
            are delay time, radius and angle (degrees).
        """
        self._encode_send_response(self.START_COMMAND)
        self._await_scan_finished()
        Xs, Ys = self._read_scan()

        radii, angles = self._convert_to_r_angle(Xs, Ys)

        output_array = np.zeros((self.scanning_points, 3))
        output_array[:, 0] = self.times
        output_array[:, 1] = radii
        output_array[:, 2] = angles

        return self.START_COMMAND, output_array

    @cached_property
    def _squished_intervals(self: _LeAmpCom) -> list[Interval]:
        """Intervals squished into effective DAC range."""
        return _squish_intervals(
            intervals=self._scan_intervals,
            lower_bound=self.config.fs_dac_lower_bound,
            upper_bound=self.config.fs_dac_upper_bound,
            bitwidth=self.DAC_BITWIDTH,
        )

    def _convert_to_r_angle(
        self: _LeAmpCom, Xs: list, Ys: list
    ) -> tuple[FloatArray, FloatArray]:
        """Convert X/Y components to radius and angle in degrees."""
        r = np.sqrt(np.array(Xs) ** 2 + np.array(Ys) ** 2)
        angle = np.arctan2(np.array(Ys), np.array(Xs))
        return r, np.rad2deg(angle)

    def _denormalize_interval(self: _LeAmpCom, interval: Interval) -> list[int]:
        """Scale a normalized interval to truncated integer DAC codes ``[lower, upper]``."""
        lower = int(interval.lower * self.DAC_BITWIDTH)
        upper = int(interval.upper * self.DAC_BITWIDTH)
        return [lower, upper]

    def _encode_send_response(self: _LeAmpCom, command: str) -> str:
        """Send ``command`` and return the next response line."""
        self._encode_and_send(command)
        return self._get_response()

    def _encode_and_send(self: _LeAmpCom, command: str) -> None:
        """Encode ``command`` with ``ENCODING`` and write it to the port."""
        self.__ser.write(command.encode(self.ENCODING))

    def _raw_byte_send(self: _LeAmpCom, values: list[int]) -> None:
        """Pack each value as a 16-bit little-endian unsigned int and write them in one go."""
        c = BitArray()
        for value in values:
            c.append(BitArray(uintle=value, length=16))
        self.__ser.write(c.tobytes())

    def _await_scan_finished(self: _LeAmpCom) -> None:
        """Sleep one sweep length, then poll status every 1 % of a sweep.

        NOTE(review): the loop only ends when the status is no longer SCANNING
        (or ``_get_status`` raises); there is no timeout.
        """
        time.sleep(self.config._sweep_length_ms * 1.0e-3)  # noqa: SLF001, access to private attribute for backwards compatibility
        status = self._get_status()

        while status == _LeStatus.SCANNING:
            time.sleep(self.config._sweep_length_ms * 1e-3 * 0.01)  # noqa: SLF001, access to private attribute for backwards compatibility
            status = self._get_status()

    # Presumably retried with backoff by _BackoffRetry on failure - confirm in helpers.
    @_BackoffRetry(backoff_base=0.05, logger=logging.getLogger(LOGGER_NAME))
    def _get_response(self: _LeAmpCom) -> str:
        """Read one response with ``read_until()``'s default terminator; return it stripped."""
        return self.__ser.read_until().decode(self.ENCODING).strip()

    @_BackoffRetry(
        backoff_base=1e-2, max_tries=5, logger=logging.getLogger(LOGGER_NAME)
    )
    def _read_scan(self: _LeAmpCom) -> tuple[list[float], list[float]]:
        """Fetch the scan: ``scanning_points`` X floats followed by as many Y floats.

        Raises:
            serialutil.SerialException: If fewer bytes than expected arrive.
        """
        self._encode_and_send(self.FETCH_COMMAND)

        # X block followed by Y block, 4 bytes per little-endian float.
        bytes_to_receive = self.scanning_points * 4 + self.scanning_points * 4
        scan_bytes = self.__ser.read(bytes_to_receive)
        if len(scan_bytes) != bytes_to_receive:
            msg = f"received {len(scan_bytes)} bytes, expected {bytes_to_receive}"
            raise serialutil.SerialException(msg)

        Xs = [
            BitArray(bytes=scan_bytes[d : d + 4]).floatle
            for d in range(0, self.scanning_points * 4, 4)
        ]
        Ys = [
            BitArray(bytes=scan_bytes[d : d + 4]).floatle
            for d in range(self.scanning_points * 4, self.scanning_points * 8, 4)
        ]

        return Xs, Ys

    def _get_status(self: _LeAmpCom) -> _LeStatus:
        """Query the device status.

        Raises:
            ValueError: If the reply matches no ``_LeStatus`` value.
        """
        msg = self._encode_send_response(self.STATUS_COMMAND)
        if msg == _LeStatus.SCANNING.value:
            return _LeStatus.SCANNING
        if msg == _LeStatus.IDLE.value:
            return _LeStatus.IDLE
        msg = f"Unknown status: {msg}"
        raise ValueError(msg)
376
+
377
+
378
+ class _LeStatus(Enum):
379
+ SCANNING = "Error: Scan is ongoing."
380
+ IDLE = "ACK: Idle."
381
+
382
+
383
@overload
def _serial_factory(
    config: ForceDeviceConfiguration,
) -> serial.Serial | ForceMockDevice: ...


@overload
def _serial_factory(config: LeDeviceConfiguration) -> serial.Serial | LeMockDevice: ...


def _serial_factory(config: DeviceConfiguration) -> serial.Serial | MockDevice:
    """Open the transport for ``config``.

    A port name containing ``"mock_device"`` yields a mock device from
    ``_mock_device_factory``; any other name opens a real ``serial.Serial``
    with the configured baud rate and timeout.
    """
    if "mock_device" not in config.amp_port:
        return serial.Serial(
            port=config.amp_port,
            baudrate=config.amp_baudrate,
            timeout=config.amp_timeout_seconds,
        )
    return _mock_device_factory(config)
402
+
403
+
404
+ def _points_per_interval(n_points: int, intervals: list[Interval]) -> list[int]:
405
+ """Divides a total number of points between intervals."""
406
+ interval_lengths = [interval.length for interval in intervals]
407
+ total_length = sum(interval_lengths)
408
+
409
+ points_per_interval_floats = [
410
+ n_points * length / total_length for length in interval_lengths
411
+ ]
412
+ points_per_interval = [int(e) for e in points_per_interval_floats]
413
+
414
+ # We must distribute the remainder from the int operation to get the right amount of total points
415
+ remainders = [modf(num)[0] for num in points_per_interval_floats]
416
+ sorted_indices = np.flip(np.argsort(remainders))
417
+ for i in range(int(0.5 + np.sum(remainders))):
418
+ points_per_interval[sorted_indices[i]] += 1
419
+
420
+ return points_per_interval
421
+
422
+
423
+ def _squish_intervals(
424
+ intervals: list[Interval], lower_bound: int, upper_bound: int, bitwidth: int
425
+ ) -> list[Interval]:
426
+ """Squish scanning intervals into effective DAC range."""
427
+ lower = lower_bound / bitwidth
428
+ upper = upper_bound / bitwidth
429
+
430
+ def f(x: float) -> float:
431
+ return lower + (upper - lower) * x
432
+
433
+ return [Interval(f(interval.lower), f(interval.upper)) for interval in intervals]
434
+
435
+
436
+ def _delay_from_intervals(
437
+ delayunit: Delay, intervals: list[Interval], points_per_interval: list[int]
438
+ ) -> FloatArray:
439
+ """Convert a list of intervals to a list of delay times."""
440
+ times: list[float] = []
441
+ for interval, n_points in zip(intervals, points_per_interval):
442
+ times.extend(
443
+ delayunit(
444
+ np.linspace(interval.lower, interval.upper, n_points, endpoint=False)
445
+ )
446
+ )
447
+ return np.array(times)