ephys-link 2.0.0__py3-none-any.whl → 2.0.0b1__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.
@@ -1,274 +1,200 @@
1
- """Socket.IO Server.
2
-
3
- Responsible to managing the Socket.IO connection and events.
4
- Directs events to the platform handler or handles them directly.
5
-
6
- Usage:
7
- Instantiate Server with the appropriate options, platform handler, and console.
8
- Then call `launch()` to start the server.
9
-
10
- ```python
11
- Server(options, platform_handler, console).launch()
12
- ```
13
- """
14
-
15
- from asyncio import get_event_loop, run
16
- from collections.abc import Callable, Coroutine
17
- from json import JSONDecodeError, dumps, loads
18
- from typing import Any, TypeVar, final
19
- from uuid import uuid4
20
-
21
- from aiohttp.web import Application, run_app
22
- from pydantic import ValidationError
23
- from socketio import AsyncClient, AsyncServer # pyright: ignore [reportMissingTypeStubs]
24
- from vbl_aquarium.models.ephys_link import (
25
- EphysLinkOptions,
26
- SetDepthRequest,
27
- SetInsideBrainRequest,
28
- SetPositionRequest,
29
- )
30
- from vbl_aquarium.models.proxy import PinpointIdResponse
31
- from vbl_aquarium.utils.vbl_base_model import VBLBaseModel
32
-
33
- from ephys_link.__about__ import __version__
34
- from ephys_link.back_end.platform_handler import PlatformHandler
35
- from ephys_link.utils.console import Console
36
- from ephys_link.utils.constants import PORT
37
-
38
- # Server message generic types.
39
- INPUT_TYPE = TypeVar("INPUT_TYPE", bound=VBLBaseModel)
40
- OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=VBLBaseModel)
41
-
42
-
43
- @final
44
- class Server:
45
- def __init__(self, options: EphysLinkOptions, platform_handler: PlatformHandler, console: Console) -> None:
46
- """Initialize server fields based on options and platform handler.
47
-
48
- Args:
49
- options: Launch options object.
50
- platform_handler: Platform handler instance.
51
- console: Console instance.
52
- """
53
-
54
- # Save fields.
55
- self._options = options
56
- self._platform_handler = platform_handler
57
- self._console = console
58
-
59
- # Initialize based on proxy usage.
60
- self._sio: AsyncServer | AsyncClient = AsyncClient() if self._options.use_proxy else AsyncServer()
61
- if not self._options.use_proxy:
62
- # Exit if _sio is not a Server.
63
- if not isinstance(self._sio, AsyncServer):
64
- error = "Server not initialized."
65
- self._console.critical_print(error)
66
- raise TypeError(error)
67
-
68
- self._app = Application()
69
- self._sio.attach(self._app) # pyright: ignore [reportUnknownMemberType]
70
-
71
- # Bind connection events.
72
- _ = self._sio.on("connect", self.connect) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
73
- _ = self._sio.on("disconnect", self.disconnect) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
74
-
75
- # Store connected client.
76
- self._client_sid: str = ""
77
-
78
- # Generate Pinpoint ID for proxy usage.
79
- self._pinpoint_id = str(uuid4())[:8]
80
-
81
- # Bind events.
82
- _ = self._sio.on("*", self.platform_event_handler) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
83
-
84
- def launch(self) -> None:
85
- """Launch the server.
86
-
87
- Based on the options, either connect to a proxy or launch the server locally.
88
- """
89
-
90
- # List platform and available manipulators.
91
- self._console.info_print("PLATFORM", self._platform_handler.get_display_name())
92
- self._console.info_print(
93
- "MANIPULATORS",
94
- str(get_event_loop().run_until_complete(self._platform_handler.get_manipulators()).manipulators),
95
- )
96
-
97
- # Launch server
98
- if self._options.use_proxy:
99
- self._console.info_print("PINPOINT ID", self._pinpoint_id)
100
-
101
- async def connect_proxy() -> None:
102
- # Exit if _sio is not a proxy client.
103
- if not isinstance(self._sio, AsyncClient):
104
- error = "Proxy client not initialized."
105
- self._console.critical_print(error)
106
- raise TypeError(error)
107
-
108
- # noinspection HttpUrlsUsage
109
- await self._sio.connect(f"http://{self._options.proxy_address}:{PORT}") # pyright: ignore [reportUnknownMemberType]
110
- await self._sio.wait()
111
-
112
- run(connect_proxy())
113
- else:
114
- run_app(self._app, port=PORT)
115
-
116
- # Helper functions.
117
- def _malformed_request_response(self, request: str, data: tuple[tuple[Any], ...]) -> str: # pyright: ignore [reportExplicitAny]
118
- """Return a response for a malformed request.
119
-
120
- Args:
121
- request: Original request.
122
- data: Request data.
123
-
124
- Returns:
125
- Response for a malformed request.
126
- """
127
- self._console.error_print("MALFORMED REQUEST", f"{request}: {data}")
128
- return dumps({"error": "Malformed request."})
129
-
130
- async def _run_if_data_available(
131
- self,
132
- function: Callable[[str], Coroutine[Any, Any, VBLBaseModel]], # pyright: ignore [reportExplicitAny]
133
- event: str,
134
- data: tuple[tuple[Any], ...], # pyright: ignore [reportExplicitAny]
135
- ) -> str:
136
- """Run a function if data is available.
137
-
138
- Args:
139
- function: Function to run.
140
- event: Event name.
141
- data: Event data.
142
-
143
- Returns:
144
- Response data from function.
145
- """
146
- request_data = data[1]
147
- if request_data:
148
- return str((await function(str(request_data))).to_json_string())
149
- return self._malformed_request_response(event, request_data)
150
-
151
- async def _run_if_data_parses(
152
- self,
153
- function: Callable[[INPUT_TYPE], Coroutine[Any, Any, OUTPUT_TYPE]], # pyright: ignore [reportExplicitAny]
154
- data_type: type[INPUT_TYPE],
155
- event: str,
156
- data: tuple[tuple[Any], ...], # pyright: ignore [reportExplicitAny]
157
- ) -> str:
158
- """Run a function if data parses.
159
-
160
- Args:
161
- function: Function to run.
162
- data_type: Data type to parse.
163
- event: Event name.
164
- data: Event data.
165
-
166
- Returns:
167
- Response data from function.
168
- """
169
- request_data = data[1]
170
- if request_data:
171
- try:
172
- parsed_data = data_type(**loads(str(request_data)))
173
- except JSONDecodeError:
174
- return self._malformed_request_response(event, request_data)
175
- except ValidationError as e:
176
- self._console.exception_error_print(event, e)
177
- return self._malformed_request_response(event, request_data)
178
- else:
179
- return str((await function(parsed_data)).to_json_string())
180
- return self._malformed_request_response(event, request_data)
181
-
182
- # Event Handlers.
183
-
184
- async def connect(self, sid: str, _: str) -> bool:
185
- """Handle connections to the server.
186
-
187
- Args:
188
- sid: Socket session ID.
189
- _: Extra connection data (unused).
190
-
191
- Returns:
192
- False on error to refuse connection, True otherwise.
193
- """
194
- self._console.info_print("CONNECTION REQUEST", sid)
195
-
196
- if self._client_sid == "":
197
- self._client_sid = sid
198
- self._console.info_print("CONNECTION GRANTED", sid)
199
- return True
200
-
201
- self._console.error_print(
202
- "CONNECTION REFUSED", f"Cannot connect {sid} as {self._client_sid} is already connected."
203
- )
204
- return False
205
-
206
- async def disconnect(self, sid: str) -> None:
207
- """Handle disconnections from the server.
208
-
209
- Args:
210
- sid: Socket session ID.
211
- """
212
- self._console.info_print("DISCONNECTED", sid)
213
-
214
- # Reset client SID if it matches.
215
- if self._client_sid == sid:
216
- self._client_sid = ""
217
- else:
218
- self._console.error_print("DISCONNECTION", f"Client {sid} disconnected without being connected.")
219
-
220
- async def platform_event_handler(self, event: str, *args: tuple[Any]) -> str: # pyright: ignore [reportExplicitAny]
221
- """Handle events from the server.
222
-
223
- Matches incoming events based on the Socket.IO API.
224
-
225
- Args:
226
- event: Event name.
227
- args: Event arguments.
228
-
229
- Returns:
230
- Response data.
231
- """
232
-
233
- # Log event.
234
- self._console.debug_print("EVENT", event)
235
-
236
- # Handle event.
237
- match event:
238
- # Server metadata.
239
- case "get_version":
240
- return __version__
241
- case "get_pinpoint_id":
242
- return PinpointIdResponse(pinpoint_id=self._pinpoint_id, is_requester=False).to_json_string()
243
- case "get_platform_info":
244
- return (await self._platform_handler.get_platform_info()).to_json_string()
245
-
246
- # Manipulator commands.
247
- case "get_manipulators":
248
- return str((await self._platform_handler.get_manipulators()).to_json_string())
249
- case "get_position":
250
- return await self._run_if_data_available(self._platform_handler.get_position, event, args)
251
- case "get_angles":
252
- return await self._run_if_data_available(self._platform_handler.get_angles, event, args)
253
- case "get_shank_count":
254
- return await self._run_if_data_available(self._platform_handler.get_shank_count, event, args)
255
- case "set_position":
256
- return await self._run_if_data_parses(
257
- self._platform_handler.set_position, SetPositionRequest, event, args
258
- )
259
- case "set_depth":
260
- return await self._run_if_data_parses(self._platform_handler.set_depth, SetDepthRequest, event, args)
261
- case "set_inside_brain":
262
- return await self._run_if_data_parses(
263
- self._platform_handler.set_inside_brain, SetInsideBrainRequest, event, args
264
- )
265
- case "stop":
266
- request_data = args[1]
267
- if request_data:
268
- return await self._platform_handler.stop(str(request_data))
269
- return self._malformed_request_response(event, request_data)
270
- case "stop_all":
271
- return await self._platform_handler.stop_all()
272
- case _:
273
- self._console.error_print("EVENT", f"Unknown event: {event}.")
274
- return dumps({"error": "Unknown event."})
1
+ from asyncio import get_event_loop, run
2
+ from collections.abc import Callable, Coroutine
3
+ from json import JSONDecodeError, dumps, loads
4
+ from typing import Any
5
+
6
+ from aiohttp.web import Application, run_app
7
+ from pydantic import ValidationError
8
+ from socketio import AsyncClient, AsyncServer
9
+ from vbl_aquarium.models.ephys_link import (
10
+ EphysLinkOptions,
11
+ SetDepthRequest,
12
+ SetInsideBrainRequest,
13
+ SetPositionRequest,
14
+ )
15
+ from vbl_aquarium.models.generic import VBLBaseModel
16
+
17
+ from ephys_link.back_end.platform_handler import PlatformHandler
18
+ from ephys_link.util.common import PORT, check_for_updates, server_preamble
19
+ from ephys_link.util.console import Console
20
+
21
+
22
+ class Server:
23
+ def __init__(self, options: EphysLinkOptions, platform_handler: PlatformHandler, console: Console) -> None:
24
+ """Initialize server fields based on options and platform handler."""
25
+
26
+ # Save fields.
27
+ self._options = options
28
+ self._platform_handler = platform_handler
29
+ self._console = console
30
+
31
+ # Initialize based on proxy usage.
32
+ self._sio: AsyncServer | AsyncClient = AsyncClient() if self._options.use_proxy else AsyncServer()
33
+ if not self._options.use_proxy:
34
+ self._app = Application()
35
+ self._sio.attach(self._app)
36
+
37
+ # Bind connection events.
38
+ self._sio.on("connect", self.connect)
39
+ self._sio.on("disconnect", self.disconnect)
40
+
41
+ # Store connected client.
42
+ self._client_sid: str = ""
43
+
44
+ # Bind events.
45
+ self._sio.on("*", self.platform_event_handler)
46
+
47
+ # Server launch.
48
+ def launch(self) -> None:
49
+ # Preamble.
50
+ server_preamble()
51
+
52
+ # Check for updates.
53
+ check_for_updates()
54
+
55
+ # List platform and available manipulators.
56
+ self._console.info_print("PLATFORM", self._platform_handler.get_platform_type())
57
+ self._console.info_print(
58
+ "MANIPULATORS",
59
+ str(get_event_loop().run_until_complete(self._platform_handler.get_manipulators()).manipulators),
60
+ )
61
+
62
+ # Launch server
63
+ if self._options.use_proxy:
64
+ self._console.info_print("PINPOINT ID", self._platform_handler.get_pinpoint_id().pinpoint_id)
65
+
66
+ async def connect_proxy() -> None:
67
+ # noinspection HttpUrlsUsage
68
+ await self._sio.connect(f"http://{self._options.proxy_address}:{PORT}")
69
+ await self._sio.wait()
70
+
71
+ run(connect_proxy())
72
+ else:
73
+ run_app(self._app, port=PORT)
74
+
75
+ # Helper functions.
76
+ def _malformed_request_response(self, request: str, data: tuple[tuple[Any], ...]) -> str:
77
+ """Return a response for a malformed request."""
78
+ self._console.labeled_error_print("MALFORMED REQUEST", f"{request}: {data}")
79
+ return dumps({"error": "Malformed request."})
80
+
81
+ async def _run_if_data_available(
82
+ self, function: Callable[[str], Coroutine[Any, Any, VBLBaseModel]], event: str, data: tuple[tuple[Any], ...]
83
+ ) -> str:
84
+ """Run a function if data is available."""
85
+ request_data = data[1]
86
+ if request_data:
87
+ return str((await function(str(request_data))).to_json_string())
88
+ return self._malformed_request_response(event, request_data)
89
+
90
+ async def _run_if_data_parses(
91
+ self,
92
+ function: Callable[[VBLBaseModel], Coroutine[Any, Any, VBLBaseModel]],
93
+ data_type: type[VBLBaseModel],
94
+ event: str,
95
+ data: tuple[tuple[Any], ...],
96
+ ) -> str:
97
+ """Run a function if data parses."""
98
+ request_data = data[1]
99
+ if request_data:
100
+ try:
101
+ parsed_data = data_type(**loads(str(request_data)))
102
+ except JSONDecodeError:
103
+ return self._malformed_request_response(event, request_data)
104
+ except ValidationError as e:
105
+ self._console.exception_error_print(event, e)
106
+ return self._malformed_request_response(event, request_data)
107
+ else:
108
+ return str((await function(parsed_data)).to_json_string())
109
+ return self._malformed_request_response(event, request_data)
110
+
111
+ # Event Handlers.
112
+
113
+ async def connect(self, sid: str, _: str) -> bool:
114
+ """Handle connections to the server
115
+
116
+ :param sid: Socket session ID.
117
+ :type sid: str
118
+ :param _: Extra connection data (unused).
119
+ :type _: str
120
+ :returns: False on error to refuse connection, True otherwise.
121
+ :rtype: bool
122
+ """
123
+ self._console.info_print("CONNECTION REQUEST", sid)
124
+
125
+ if self._client_sid == "":
126
+ self._client_sid = sid
127
+ self._console.info_print("CONNECTION GRANTED", sid)
128
+ return True
129
+
130
+ self._console.error_print(f"CONNECTION REFUSED to {sid}. Client {self._client_sid} already connected.")
131
+ return False
132
+
133
+ async def disconnect(self, sid: str) -> None:
134
+ """Handle disconnections from the server
135
+
136
+ :param sid: Socket session ID.
137
+ :type sid: str
138
+ """
139
+ self._console.info_print("DISCONNECTED", sid)
140
+
141
+ # Reset client SID if it matches.
142
+ if self._client_sid == sid:
143
+ self._client_sid = ""
144
+ else:
145
+ self._console.error_print(f"Client {sid} disconnected without being connected.")
146
+
147
+ # noinspection PyTypeChecker
148
+ async def platform_event_handler(self, event: str, *args: tuple[Any]) -> str:
149
+ """Handle events from the server
150
+
151
+ :param event: Event name.
152
+ :type event: str
153
+ :param args: Event arguments.
154
+ :type args: tuple[Any]
155
+ :returns: Response data.
156
+ :rtype: str
157
+ """
158
+
159
+ # Log event.
160
+ self._console.debug_print("EVENT", event)
161
+
162
+ # Handle event.
163
+ match event:
164
+ # Server metadata.
165
+ case "get_version":
166
+ return self._platform_handler.get_version()
167
+ case "get_pinpoint_id":
168
+ return str(self._platform_handler.get_pinpoint_id().to_json_string())
169
+ case "get_platform_type":
170
+ return self._platform_handler.get_platform_type()
171
+
172
+ # Manipulator commands.
173
+ case "get_manipulators":
174
+ return str((await self._platform_handler.get_manipulators()).to_json_string())
175
+ case "get_position":
176
+ return await self._run_if_data_available(self._platform_handler.get_position, event, args)
177
+ case "get_angles":
178
+ return await self._run_if_data_available(self._platform_handler.get_angles, event, args)
179
+ case "get_shank_count":
180
+ return await self._run_if_data_available(self._platform_handler.get_shank_count, event, args)
181
+ case "set_position":
182
+ return await self._run_if_data_parses(
183
+ self._platform_handler.set_position, SetPositionRequest, event, args
184
+ )
185
+ case "set_depth":
186
+ return await self._run_if_data_parses(self._platform_handler.set_depth, SetDepthRequest, event, args)
187
+ case "set_inside_brain":
188
+ return await self._run_if_data_parses(
189
+ self._platform_handler.set_inside_brain, SetInsideBrainRequest, event, args
190
+ )
191
+ case "stop":
192
+ request_data = args[1]
193
+ if request_data:
194
+ return await self._platform_handler.stop(str(request_data))
195
+ return self._malformed_request_response(event, request_data)
196
+ case "stop_all":
197
+ return await self._platform_handler.stop_all()
198
+ case _:
199
+ self._console.error_print(f"Unknown event: {event}.")
200
+ return dumps({"error": "Unknown event."})
@@ -1,84 +1,54 @@
1
- from typing import final, override
2
-
3
- from vbl_aquarium.models.unity import Vector3, Vector4
4
-
5
- from ephys_link.utils.base_binding import BaseBinding
6
- from ephys_link.utils.converters import list_to_vector4
7
-
8
-
9
- @final
10
- class FakeBinding(BaseBinding):
11
- def __init__(self) -> None:
12
- """Initialize fake manipulator infos."""
13
-
14
- self._positions = [Vector4() for _ in range(8)]
15
- self._angles = [
16
- Vector3(x=90, y=60, z=0),
17
- Vector3(x=-90, y=60, z=0),
18
- Vector3(x=180, y=60, z=0),
19
- Vector3(x=0, y=60, z=0),
20
- Vector3(x=45, y=30, z=0),
21
- Vector3(x=-45, y=30, z=0),
22
- Vector3(x=135, y=30, z=0),
23
- Vector3(x=-135, y=30, z=0),
24
- ]
25
-
26
- @staticmethod
27
- @override
28
- def get_display_name() -> str:
29
- return "Fake Manipulator"
30
-
31
- @staticmethod
32
- @override
33
- def get_cli_name() -> str:
34
- return "fake"
35
-
36
- @override
37
- async def get_manipulators(self) -> list[str]:
38
- return list(map(str, range(8)))
39
-
40
- @override
41
- async def get_axes_count(self) -> int:
42
- return 4
43
-
44
- @override
45
- def get_dimensions(self) -> Vector4:
46
- return list_to_vector4([20] * 4)
47
-
48
- @override
49
- async def get_position(self, manipulator_id: str) -> Vector4:
50
- return self._positions[int(manipulator_id)]
51
-
52
- @override
53
- async def get_angles(self, manipulator_id: str) -> Vector3:
54
- return self._angles[int(manipulator_id)]
55
-
56
- @override
57
- async def get_shank_count(self, manipulator_id: str) -> int:
58
- return 1
59
-
60
- @override
61
- def get_movement_tolerance(self) -> float:
62
- return 0.001
63
-
64
- @override
65
- async def set_position(self, manipulator_id: str, position: Vector4, speed: float) -> Vector4:
66
- self._positions[int(manipulator_id)] = position
67
- return position
68
-
69
- @override
70
- async def set_depth(self, manipulator_id: str, depth: float, speed: float) -> float:
71
- self._positions[int(manipulator_id)].w = depth
72
- return depth
73
-
74
- @override
75
- async def stop(self, manipulator_id: str) -> None:
76
- pass
77
-
78
- @override
79
- def platform_space_to_unified_space(self, platform_space: Vector4) -> Vector4:
80
- return platform_space
81
-
82
- @override
83
- def unified_space_to_platform_space(self, unified_space: Vector4) -> Vector4:
84
- return unified_space
1
+ from vbl_aquarium.models.unity import Vector3, Vector4
2
+
3
+ from ephys_link.util.base_bindings import BaseBindings
4
+
5
+
6
+ class FakeBindings(BaseBindings):
7
+ def __init__(self) -> None:
8
+ """Initialize fake manipulator infos."""
9
+
10
+ self._positions = [Vector4() for _ in range(8)]
11
+ self._angles = [
12
+ Vector3(x=90, y=60, z=0),
13
+ Vector3(x=-90, y=60, z=0),
14
+ Vector3(x=180, y=60, z=0),
15
+ Vector3(x=0, y=60, z=0),
16
+ Vector3(x=45, y=30, z=0),
17
+ Vector3(x=-45, y=30, z=0),
18
+ Vector3(x=135, y=30, z=0),
19
+ Vector3(x=-135, y=30, z=0),
20
+ ]
21
+
22
+ async def get_manipulators(self) -> list[str]:
23
+ return list(map(str, range(8)))
24
+
25
+ async def get_num_axes(self) -> int:
26
+ return 4
27
+
28
+ def get_dimensions(self) -> Vector4:
29
+ return Vector4(x=20, y=20, z=20, w=20)
30
+
31
+ async def get_position(self, manipulator_id: str) -> Vector4:
32
+ return self._positions[int(manipulator_id)]
33
+
34
+ async def get_angles(self, manipulator_id: str) -> Vector3:
35
+ return self._angles[int(manipulator_id)]
36
+
37
+ async def get_shank_count(self, _: str) -> int:
38
+ return 1
39
+
40
+ async def get_movement_tolerance(self) -> float:
41
+ return 0.001
42
+
43
+ async def set_position(self, manipulator_id: str, position: Vector4, _: float) -> Vector4:
44
+ self._positions[int(manipulator_id)] = position
45
+ return position
46
+
47
+ async def stop(self, _: str) -> None:
48
+ pass
49
+
50
+ def platform_space_to_unified_space(self, platform_space: Vector4) -> Vector4:
51
+ pass
52
+
53
+ def unified_space_to_platform_space(self, unified_space: Vector4) -> Vector4:
54
+ pass