attp-client 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
attp_client/router.py ADDED
@@ -0,0 +1,113 @@
1
+ import asyncio
2
+ from contextvars import ContextVar
3
+ from typing import Any, TypeVar, overload
4
+ import msgpack
5
+ from pydantic import TypeAdapter
6
+ from reactivex import Subject, defer, empty, from_future, of, operators as ops, throw, timer
7
+ from reactivex.scheduler.eventloop import AsyncIOScheduler
8
+ from attp_core.rs_api import PyAttpMessage, AttpCommand
9
+
10
+ from attp_client.errors.correlated_rpc_exception import CorrelatedRPCException
11
+ from attp_client.errors.serialization_error import SerializationError
12
+ from attp_client.interfaces.error import IErr
13
+ from attp_client.misc.fixed_basemodel import FixedBaseModel
14
+ from attp_client.misc.serializable import Serializable
15
+ from attp_client.session import SessionDriver
16
+ from attp_client.utils.context_awaiter import ContextAwaiter
17
+
18
+
19
T = TypeVar("T")


class AttpRouter:
    """Correlated request/response router on top of a :class:`SessionDriver`.

    Requests are sent through ``session``. Incoming frames are expected to be
    pushed into ``responder`` (``SessionDriver.listen`` does this) and are
    matched back to their request by correlation ID.
    """

    def __init__(
        self,
        responder: Subject[PyAttpMessage],
        session: SessionDriver
    ) -> None:
        self.responder = responder
        self.session = session
        # NOTE(review): not read anywhere in this class.
        self.context = ContextVar[str | None]("session_context", default=None)

    @overload
    async def send(
        self,
        route: str,
        data: FixedBaseModel | Serializable | None = ...,
        timeout: float = ...,
    ) -> Any: ...

    @overload
    async def send(
        self,
        route: str,
        data: FixedBaseModel | Serializable | None = ...,
        timeout: float = ..., *,
        expected_response: type[T],
    ) -> T | Any: ...

    async def send(
        self,
        route: str,
        data: FixedBaseModel | Serializable | None = None,
        timeout: float = 50, *,
        expected_response: type[T] | None = None
    ) -> T | Any:
        """Send a CALL frame on *route* and wait for its correlated ACK.

        Parameters
        ----------
        route : str
            Route pattern; ``SessionDriver.send_message`` resolves it against
            the server's "message" routes.
        data : FixedBaseModel | Serializable | None
            Request payload.
        timeout : float
            Seconds to wait for a reply. Every DEFER frame carrying this
            request's correlation ID restarts the wait with the same duration.
        expected_response : type | None
            Type the ACK payload is decoded into. ``None`` returns the
            msgpack-decoded payload as-is.

        Raises
        ------
        CorrelatedRPCException
            If the peer answers with an ERR frame.
        TimeoutError
            If no ACK arrives in time.
        SerializationError
            If the payload cannot be decoded into *expected_response*.
        """
        def request(_scheduler: Any):
            # Deferred so that the CALL frame is only sent once ContextAwaiter subscribes.
            sent = from_future(asyncio.ensure_future(self.session.send_message(route=route, data=data)))
            # NOTE(review): the responder filter is only subscribed after send_message
            # returns (and subscribe_on defers it further), so a reply arriving in
            # that window would be missed by the hot Subject. Confirm this is acceptable.
            return sent.pipe(
                ops.flat_map(lambda cid: self.__pipe_filter(cid, timeout=timeout)),
            )

        awaiter = ContextAwaiter[Any](defer(request))
        response_data = await awaiter.wait()

        return self.__format_response(expected_type=expected_response or Any, response_data=response_data)

    async def emit(self, route: str, data: FixedBaseModel | Serializable | None = None):
        """Fire-and-forget: forward *data* to ``SessionDriver.emit_message`` on *route*."""
        await self.session.emit_message(route, data)

    def __pipe_filter(self, awaiting_correlation_id: bytes, timeout: float):
        """Build the observable that resolves a single correlated request.

        Behaviour of the returned stream:

        - Frames on ``responder`` are kept only if they carry
          *awaiting_correlation_id*.
        - An ERR frame becomes a ``CorrelatedRPCException`` error.
        - The wait is bounded by *timeout* seconds. A DEFER frame restarts the
          bound, while any other non-ACK frame ends it with ``TimeoutError``.
        - The first ACK frame is emitted and the stream completes.
        """
        loop = asyncio.get_event_loop()
        asyncio_scheduler = AsyncIOScheduler(loop)
        return self.responder.pipe(
            ops.subscribe_on(asyncio_scheduler),
            ops.filter(lambda frame: frame.correlation_id == awaiting_correlation_id),
            ops.flat_map(self.__raise_on_error),
            #######################################
            ##       This is RPC Defer Handler     ##
            #######################################
            ops.timeout_with_mapper(
                timer(timeout, scheduler=asyncio_scheduler),
                # The mapper returns the observable whose emission times out the
                # *next* frame. DEFER extends the wait, and of(None) fires at once.
                # The original read a nonexistent ``frame_type`` attribute here, so
                # DEFER never extended anything.
                lambda frame: (
                    timer(timeout, scheduler=asyncio_scheduler)
                    if getattr(frame, "command_type", None) == AttpCommand.DEFER
                    else of(None)
                ),
                throw(TimeoutError("ATTP response failed."))
            ),
            ops.filter(lambda frame: frame.command_type == AttpCommand.ACK),
            ops.first(),
        )

    @staticmethod
    def __raise_on_error(frame: PyAttpMessage):
        """Turn an ERR frame into an error notification; pass any other frame through."""
        if frame.command_type != AttpCommand.ERR:
            return of(frame)

        err = IErr.mps(frame.payload) if frame.payload else IErr(detail={"code": "ErrorWithoutPayload"})
        return throw(CorrelatedRPCException.from_err_object(correlation_id=frame.correlation_id or b'<nocorrid>', err=err))

    @staticmethod
    def __is_model_type(candidate: Any) -> bool:
        """True only for real classes deriving from FixedBaseModel (safe for ``Any`` and generic aliases)."""
        try:
            return isinstance(candidate, type) and issubclass(candidate, FixedBaseModel)
        except TypeError:
            # e.g. ``dict[str, Any]`` on Python 3.10, where isinstance(alias, type) is True.
            return False

    def __format_response(self, expected_type: Any, response_data: PyAttpMessage):
        """Decode an ACK frame's payload into *expected_type*.

        - For a ``FixedBaseModel`` subclass, the payload is parsed with ``mps``.
        - For ``Any`` or ``None``, the msgpack-decoded payload is returned unchanged.
        - For any other type, the decoded payload is validated with a pydantic ``TypeAdapter``.

        Raises
        ------
        SerializationError
            If the payload is missing for a model type, or decoding/validation fails.
        """
        if self.__is_model_type(expected_type):
            if not response_data.payload:
                raise SerializationError(f"Nonetype payload received from session while expected type {expected_type.__name__}")
            try:
                return expected_type.mps(response_data.payload)
            except Exception as e:
                raise SerializationError(str(e)) from e

        serialized = msgpack.unpackb(response_data.payload) if response_data.payload else None

        # The original tested ``is not None`` here, which always returned early
        # and made the validation below unreachable.
        if expected_type is Any or expected_type is None:
            return serialized

        adapter = TypeAdapter(expected_type, config={"arbitrary_types_allowed": True})
        try:
            return adapter.validate_python(serialized)
        except Exception as e:
            raise SerializationError(str(e)) from e
attp_client/session.py ADDED
@@ -0,0 +1,316 @@
1
+ import asyncio
2
+ from logging import Logger, getLogger
3
+ import traceback
4
+ from typing import Any, Sequence
5
+ from uuid import uuid4
6
+
7
+ from reactivex import Subject
8
+
9
+ from attp_client.consts import ATTP_VERSION
10
+ from attp_client.errors.dead_session import DeadSessionError
11
+ from attp_client.errors.unauthenticated_error import UnauthenticatedError
12
+ from attp_client.interfaces.handshake.auth import IAuth
13
+ from attp_client.interfaces.error import IErr
14
+ from attp_client.interfaces.handshake.hello import IHello
15
+ from attp_client.interfaces.handshake.ready import IReady
16
+ from attp_client.interfaces.route_mappings import IRouteMapping
17
+ from attp_client.misc.fixed_basemodel import FixedBaseModel
18
+ from attp_core.rs_api import Session, PyAttpMessage, AttpCommand
19
+
20
+ from attp_client.misc.serializable import Serializable
21
+ from attp_client.types.route_mapping import AttpRouteMapping
22
+ from attp_client.utils import serializer
23
+ from attp_client.utils.route_mapper import resolve_route_by_id
24
+
25
+
26
class SessionDriver:
    """Client-side driver for a single ATTP session.

    Wraps an ``attp_core`` ``Session`` and is responsible for:

    - building and sending ATTP frames;
    - performing the AUTH / READY handshake;
    - queueing incoming frames so that :meth:`listen` can forward them.
    """

    _organization_id: int
    server_routes: Sequence[IRouteMapping] | None

    def __init__(
        self,
        session: Session,
        agt_token: str,
        organization_id: int,
        *,
        logger: Logger = getLogger("Ascender Framework"),
    ) -> None:
        self.agt_token = agt_token
        # Reset to None by close(); the `if not self.session` guards below then
        # raise DeadSessionError instead of AttributeError.
        self.session: Session | None = session
        self._organization_id = organization_id
        # Filled from the server's READY frame in handle_ready().
        self.server_routes = None
        self.logger = logger

        # Routes this client exposes; announced in the READY reply.
        self.client_routes: list[IRouteMapping] = []

        # Frames accepted by _on_event after the handshake; drained by listen().
        self.messages = asyncio.Queue[PyAttpMessage]()
        # Set by handle_ready() once the handshake has completed.
        self.auth_event = asyncio.Event()

    @property
    def is_connected(self) -> bool:
        """Whether a session object is still attached (False after close())."""
        return bool(self.session)

    @property
    def session_id(self) -> str | None:
        """Session ID reported by the underlying session.

        Raises DeadSessionError if the session is closed.
        """
        if not self.session:
            raise DeadSessionError(self.organization_id)

        return self.session.session_id

    @property
    def peername(self) -> str:
        """Peer address of the session, or ``"undefined"`` when unknown or closed."""
        if not self.session:
            return "undefined"

        return self.session.peername or "undefined"

    @property
    def is_authenticated(self) -> bool:
        """True once the READY handshake has completed (cleared again by close())."""
        return self.auth_event.is_set()

    @property
    def organization_id(self) -> int:
        return self._organization_id

    async def send_raw(self, frame: PyAttpMessage):
        """
        Send an already-built frame through the underlying session.

        Parameters
        ----------
        frame : PyAttpMessage
            Complete ATTP frame to transmit.

        Raises
        ------
        DeadSessionError
            If the session has been closed.
        """
        if not self.session:
            raise DeadSessionError(self.organization_id)

        return await self.session.send(frame)

    async def send_message(self, route: str | int, data: FixedBaseModel | Serializable | None) -> bytes:
        """
        Send a CALL frame that expects a correlated reply.

        Parameters
        ----------
        route : str | int
            Route pattern (resolved among the server's "message" routes) or a numeric route ID.
        data : FixedBaseModel | Serializable | None
            Payload, encoded with its ``mpd()`` method.

        Returns
        -------
        bytes
            Freshly generated correlation ID (UUID4 bytes) used to match the reply.

        Raises
        ------
        DeadSessionError
            If the session has been closed.
        UnauthenticatedError
            If the server's routes are not known yet.
        """
        if not self.session:
            raise DeadSessionError(self.organization_id)

        if not self.server_routes:
            raise UnauthenticatedError(f"Cannot send an ATTP message with acknowledgement to unauthenticated (route_mapping={route})")

        correlation_id = uuid4().bytes
        relevant_route = route

        if isinstance(route, str):
            relevant_route = resolve_route_by_id("message", route, self.server_routes).route_id

        frame = PyAttpMessage(int(relevant_route), AttpCommand.CALL, correlation_id=correlation_id, payload=data.mpd() if data is not None else None, version=ATTP_VERSION)
        # Replaces leftover debug print() calls.
        self.logger.debug("[cyan]ATTP[/] ┆ Sending CALL to route %s (correlation_id=%s)", relevant_route, correlation_id.hex())
        await self.send_raw(frame)

        return correlation_id

    async def emit_message(self, route: str | int, data: FixedBaseModel | Serializable | None) -> None:
        """
        Send an EMIT frame (no correlation ID, no reply expected).

        Parameters
        ----------
        route : str | int
            Route pattern (resolved among the server's "event" routes) or a numeric route ID.
        data : FixedBaseModel | Serializable | None
            Payload, encoded with its ``mpd()`` method.

        Raises
        ------
        UnauthenticatedError
            If the server's routes are not known yet.
        DeadSessionError
            If the session has been closed.
        """
        if not self.server_routes:
            raise UnauthenticatedError(f"Cannot send an ATTP message with acknowledgement to unauthenticated (route_mapping={route})")

        relevant_route = route

        if isinstance(route, str):
            relevant_route = resolve_route_by_id("event", route, self.server_routes).route_id

        # The original sent a CALL frame here, contradicting the documented EMIT semantics.
        frame = PyAttpMessage(int(relevant_route), AttpCommand.EMIT, correlation_id=None, payload=data.mpd() if data is not None else None, version=ATTP_VERSION)
        await self.send_raw(frame)

    async def authenticate(self, route_mappings: Sequence[AttpRouteMapping] | None, *, timeout: float = 10) -> None:
        """
        Send an AUTH frame and wait until the READY handshake completes.

        Parameters
        ----------
        route_mappings : Sequence[AttpRouteMapping] | None
            Client routes to announce. They are appended to ``client_routes``,
            which handle_ready() sends back in the READY reply.
        timeout : float, optional
            Seconds to wait for the handshake, by default 10.

        Raises
        ------
        DeadSessionError
            If the session has been closed.
        asyncio.TimeoutError
            If the handshake does not complete within *timeout*.
        """
        if route_mappings:
            self.client_routes.extend([IRouteMapping.from_route_mapper(mapper) for mapper in route_mappings])

        frame = PyAttpMessage(
            route_id=0,
            command_type=AttpCommand.AUTH,
            correlation_id=None,
            payload=IAuth(
                token=self.agt_token,
                organization_id=self.organization_id
            ).mpd(),
            version=ATTP_VERSION
        )
        await self.send_raw(frame)

        await asyncio.wait_for(self.auth_event.wait(), timeout)

    async def send_error(self, err: IErr, correlation_id: bytes | None = None, route: str | int = 0) -> None:
        """
        Send an ERR frame to the peer, either correlated to a request or standalone.

        Parameters
        ----------
        err : IErr
            Error details.
        correlation_id : bytes | None, optional
            Correlation ID of the request being answered; None for a non-correlated error.
        route : str | int, optional
            Route pattern (resolved among the server's "err" routes) or numeric ID, by default 0.
            Intended for non-correlated errors only (not enforced here).

        Raises
        ------
        UnauthenticatedError
            If the server's routes are not known yet.
        """
        relevant_route = route

        if not self.server_routes:
            raise UnauthenticatedError(f"Cannot send an ATTP message with acknowledgement to unauthenticated (route_mapping={route})")

        if isinstance(route, str):
            relevant_route = resolve_route_by_id("err", route, self.server_routes).route_id

        await self.send_raw(
            PyAttpMessage(
                route_id=int(relevant_route),
                command_type=AttpCommand.ERR,
                correlation_id=correlation_id,
                payload=err.mpd(),
                version=ATTP_VERSION
            )
        )

    async def respond(self, correlation_id: bytes, payload: FixedBaseModel | Any | None = None):
        """
        Answer a CALL with an ACK frame carrying *correlation_id*.

        Parameters
        ----------
        correlation_id : bytes
            Correlation ID of the request being answered.
        payload : FixedBaseModel | Any | None, optional
            Response data, encoded via ``serializer.deserialize``; None sends no payload.
        """
        frame = PyAttpMessage(
            route_id=0,
            command_type=AttpCommand.ACK,
            correlation_id=correlation_id,
            payload=serializer.deserialize(payload),
            version=ATTP_VERSION
        )

        await self.send_raw(frame)

    async def listen(self, responder: Subject[PyAttpMessage]) -> None:
        """
        Forward queued incoming frames to *responder* while authenticated.

        How the loop behaves:

        - ``_on_event`` puts frames on ``self.messages`` once the handshake is
          done, and each one is passed to ``responder.on_next``.
        - ``is_authenticated`` is checked before every wait, so calling this
          before authentication returns immediately.
        - After ``close()`` clears the flag, the loop stops at its next check.
          It may still be blocked in ``messages.get()`` until another frame is
          queued.
        """
        while self.is_authenticated:
            message = await self.messages.get()
            responder.on_next(message)

    async def close(self):
        """
        Send a DISCONNECT frame, then stop the listener and disconnect the session.

        Afterwards ``session`` is None (so guarded methods raise DeadSessionError)
        and ``is_authenticated`` is False.
        """
        await self.send_raw(PyAttpMessage(
            route_id=0,
            command_type=AttpCommand.DISCONNECT,
            correlation_id=None,
            payload=None,
            version=ATTP_VERSION
        ))
        # Previously `del self.session`, which made later `if not self.session`
        # checks raise AttributeError instead of DeadSessionError.
        session, self.session = self.session, None
        self.auth_event.clear()
        if session is not None:
            try:
                session.stop_listener()
            finally:
                session.disconnect()

    async def handle_ready(self, frame: IReady):
        """Store the server's routes, reply with our READY frame and mark the session authenticated."""
        self.server_routes = frame.server_routes

        data = PyAttpMessage(
            route_id=0,
            command_type=AttpCommand.READY,
            correlation_id=None,
            payload=IHello(proto="attp", ver="0.1", caps=[], mapping=self.client_routes).mpd(),
            # Was hard-coded b"01"; use the same constant as every other frame.
            version=ATTP_VERSION
        )
        await self.send_raw(data)
        self.auth_event.set()

    async def _on_event(self, events: list[PyAttpMessage]) -> None:
        """Session event handler: complete the handshake on READY, queue other frames once authenticated.

        Non-READY frames received before authentication are dropped. A failure
        while handling READY is logged and closes the session.
        """
        self.logger.debug(f"[cyan]ATTP[/] ┆ Received a new message from session {self.session_id} ")

        for event in events:
            if event.route_id == 0 and event.command_type == AttpCommand.READY:
                if not event.payload:
                    continue
                try:
                    await self.handle_ready(IReady.mps(event.payload))
                except Exception:
                    self.logger.exception("[cyan]ATTP[/] ┆ Failed to handle READY frame; closing session.")
                    await self.close()
                    break

            elif self.is_authenticated:
                self.logger.debug("[cyan]ATTP[/] ┆ Handing incoming message to a route handler.")
                self.messages.put_nowait(event)

    async def start_listener(self):
        """Register _on_event with the session and run its handler and listener concurrently.

        Raises DeadSessionError if the session is closed.
        """
        if not self.session:
            raise DeadSessionError(self.organization_id)
        self.session.add_event_handler(self._on_event)
        await asyncio.gather(
            self.session.start_handler(),
            self.session.start_listener()
        )
attp_client/tools.py ADDED
@@ -0,0 +1,59 @@
1
+ from typing import Any, Sequence
2
+ from uuid import UUID
3
+ from attp_client.misc.serializable import Serializable
4
+ from attp_client.router import AttpRouter
5
+ from attp_client.session import SessionDriver
6
+
7
+
8
class ToolsManager:
    """Client for the server's tool-catalog routes, sent through an :class:`AttpRouter`."""

    def __init__(self, router: AttpRouter) -> None:
        self.router = router

    async def register(
        self,
        catalog_name: str,
        name: str,
        description: str | None = None,
        schema_id: str | None = None,
        *,
        return_direct: bool = False,
        schema_ver: str = "1.0",
        timeout_ms: float = 20000,
        idempotent: bool = False,
        request_timeout: float = 30,
    ) -> UUID:
        """
        Register a tool in *catalog_name* and return the ID the server assigned.

        Parameters
        ----------
        catalog_name : str
            Catalog the tool is added to.
        name, description, schema_id, return_direct, schema_ver, timeout_ms, idempotent
            Sent verbatim as the ``tool`` object of the request.
        request_timeout : float, optional
            Seconds to wait for the server's reply, by default 30.

        Returns
        -------
        UUID
            Parsed from the ``assigned_id`` hex string in the reply.
        """
        response = await self.router.send(
            "tools:register",
            Serializable[dict[str, Any]]({
                "catalog": catalog_name,
                "tool": {
                    "name": name,
                    "description": description,
                    "schema_id": schema_id,
                    "return_direct": return_direct,
                    "schema_ver": schema_ver,
                    "timeout_ms": timeout_ms,
                    "idempotent": idempotent
                }
            }),
            timeout=request_timeout,
            expected_response=dict[str, Any]
        )

        return UUID(hex=response["assigned_id"])

    async def unregister(
        self,
        catalog_name: str,
        tool_id: str | Sequence[str],
        *,
        request_timeout: float = 30,
    ) -> str | list[str]:
        """
        Remove one or several tools from *catalog_name*.

        Parameters
        ----------
        catalog_name : str
            Catalog the tools belong to.
        tool_id : str | Sequence[str]
            A single tool ID or a sequence of IDs.
        request_timeout : float, optional
            Seconds to wait for the server's reply, by default 30.

        Returns
        -------
        str | list[str]
            The ``removed_ids`` field of the reply.
        """
        # A str is itself a Sequence[str], so keep it as-is. Any other sequence
        # (tuple, range-like, custom) is materialised into a plain list, because
        # the payload encoder (presumably msgpack via Serializable.mpd()) may not
        # accept arbitrary Sequence types.
        ids: str | list[str] = tool_id if isinstance(tool_id, str) else list(tool_id)

        # NOTE(review): this route is "tool:unregister" while register() uses
        # "tools:register" — confirm against the server's route table.
        response = await self.router.send(
            "tool:unregister",
            Serializable[dict[str, Any]]({
                "catalog": catalog_name,
                "tool_id": ids
            }),
            timeout=request_timeout,
            expected_response=dict[str, Any]
        )

        return response["removed_ids"]
@@ -0,0 +1,14 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Literal, TypeAlias
3
+
4
+ # from core.attp.interfaces.handshake.mapping import IRouteMapping
5
+
6
+ RouteType: TypeAlias = Literal["event", "message", "err", "disconnect", "connect"]
7
+
8
+
9
+ @dataclass(frozen=True, unsafe_hash=True)
10
+ class AttpRouteMapping:
11
+ pattern: str
12
+ route_id: int
13
+ route_type: RouteType
14
+ callback: Any
@@ -0,0 +1,40 @@
1
+ import asyncio
2
+ from typing import Generic, TypeVar
3
+ from reactivex import Observable
4
+
5
+
6
T = TypeVar("T")


class ContextAwaiter(Generic[T]):
    """Await the first value of an observable as an asyncio result.

    ``wait()`` subscribes to the observable and resolves with the first
    ``on_next`` value, or raises the error delivered by ``on_error``. If the
    observable completes without emitting, ``wait()`` raises ``RuntimeError``
    instead of hanging. The subscription is always disposed when ``wait()``
    returns, raises or is cancelled.

    Must be constructed while an event loop is running, because the result
    future is created in ``__init__``.
    """

    def __init__(
        self,
        observable: "Observable[T]",
    ) -> None:
        self.observable = observable
        self.response = asyncio.Future[T]()
        # Set when the observable has produced a value, an error or a completion.
        self.event = asyncio.Event()
        self.subscription = None

    async def wait(self) -> T:
        """Subscribe and return the first emitted value (or raise the observable's error)."""
        self.subscription = self.observable.subscribe(
            on_next=self.__define_response,
            on_error=self.__set_error,
            on_completed=self.__on_completed,
        )
        try:
            return await self.response
        finally:
            # Covers two cases:
            # 1. on_next fired synchronously inside subscribe(), before
            #    self.subscription was assigned, so it could not dispose it.
            # 2. This coroutine was cancelled (e.g. by wait_for).
            self.__dispose()

    def __dispose(self) -> None:
        if self.subscription is not None:
            self.subscription.dispose()
            self.subscription = None

    def __define_response(self, resp: T) -> None:
        self.__dispose()

        if not self.response.done():
            self.response.set_result(resp)
        self.event.set()

    def __set_error(self, exc: Exception) -> None:
        if not self.response.done():
            self.response.set_exception(exc)

        self.event.set()

    def __on_completed(self) -> None:
        # Without this, an observable that completes empty leaves wait() pending forever.
        if not self.response.done():
            self.response.set_exception(RuntimeError("Observable completed without emitting a value."))

        self.event.set()
@@ -0,0 +1,18 @@
1
+ from typing import Sequence
2
+
3
+ from attp_client.errors.not_found import NotFoundError
4
+ from attp_client.interfaces.route_mappings import IRouteMapping
5
+
6
+ from attp_client.types.route_mapping import RouteType
7
+
8
+
9
def resolve_route_by_id(route_type: RouteType, pattern: str, route_mapping: Sequence[IRouteMapping]):
    """Find the mapping whose ``pattern`` and ``route_type`` both match.

    Parameters
    ----------
    route_type : RouteType
        Category the route must belong to.
    pattern : str
        Exact route pattern to look up.
    route_mapping : Sequence[IRouteMapping]
        Mappings to search, in order; the first match wins.

    Returns
    -------
    IRouteMapping
        The first matching mapping.

    Raises
    ------
    NotFoundError
        If no mapping matches.
    """
    for candidate in route_mapping:
        if candidate.pattern == pattern and candidate.route_type == route_type:
            matched = candidate
            break
    else:
        matched = None

    if not matched:
        raise NotFoundError(f"Route {pattern} not found.")

    return matched
@@ -0,0 +1,25 @@
1
+ from typing import Any
2
+
3
+ import msgpack
4
+ from pydantic import TypeAdapter
5
+ from attp_client.misc.fixed_basemodel import FixedBaseModel
6
+
7
+
8
def serialize(data: bytes | None, model: type[FixedBaseModel] | Any) -> Any:
    """Decode a msgpack payload into *model*.

    Parameters
    ----------
    data : bytes | None
        Raw payload; empty or None yields None.
    model : type[FixedBaseModel] | Any
        The target type:

        - a ``FixedBaseModel`` subclass is parsed with its ``mps`` method;
        - any other type or typing construct (e.g. ``dict[str, Any]``) is
          msgpack-decoded and then validated with a pydantic ``TypeAdapter``.

    Returns
    -------
    Any
        The decoded value, or None for an empty payload.
    """
    if not data:
        return None

    # issubclass() raises TypeError for non-classes such as ``dict[str, Any]``
    # (and ``Any`` on 3.10), so only treat real FixedBaseModel subclasses as models.
    try:
        is_model = isinstance(model, type) and issubclass(model, FixedBaseModel)
    except TypeError:
        is_model = False

    if is_model:
        return model.mps(data)

    return TypeAdapter(model).validate_python(msgpack.unpackb(data))
16
+
17
+
18
def deserialize(data: FixedBaseModel | Any | None) -> bytes | None:
    """Encode *data* into payload bytes.

    The encoding depends on the input:

    - ``None`` gives ``None``;
    - a ``FixedBaseModel`` is encoded with its own ``mpd()``;
    - anything else goes through ``msgpack.packb``.
    """
    if isinstance(data, FixedBaseModel):
        return data.mpd()
    return None if data is None else msgpack.packb(data)