genai-protocol-lite 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,505 @@
1
+ import asyncio
2
+ import logging
3
+ import uuid
4
+ from typing import Optional, Dict, Any, List, Callable, Awaitable, Tuple
5
+
6
+ from AIConnector.common.message import MessageTypes
7
+ from AIConnector.common.network import NetworkConfig
8
+ from AIConnector.common.exceptions import PeerDisconnectedException, AgentConnectionError
9
+ from AIConnector.connector.peer_connection_manager import PeerConnectionManager
10
+ from .job_manager import JobManager
11
+ from .message_factory import MessageFactory
12
+
13
# Module-level logger named after this module; ChatClient routes its diagnostics here.
logger: logging.Logger = logging.getLogger(name=__name__)
14
+
15
+
16
class ChatClient:
    """
    A peer-to-peer chat client that handles message exchange and job processing.

    Requests created with :meth:`send_message` are tracked as asyncio futures in
    ``pending_jobs_futures`` (keyed by queue id). Such a future is:

    * resolved with the message text when a FINAL_MESSAGE for its queue id arrives,
    * resolved with the string ``"error"`` when an ERROR message for it arrives,
    * failed with ``PeerDisconnectedException`` when the target peer disconnects,
    * failed with ``AgentConnectionError`` when no heartbeat from the target peer
      has been seen for ``_HEARTBEAT_TIMEOUT`` seconds.
    """

    # Heartbeat timeout in seconds. If no heartbeat is received from the target
    # peer within this period, the pending job future is failed with
    # AgentConnectionError and removed from ``pending_jobs_futures``.
    _HEARTBEAT_TIMEOUT = 30

    # Seconds between two sweeps of the heartbeat watchdog (_check_future_timeouts).
    _TIMEOUT_CHECK_INTERVAL = 5

    def __init__(self, network_config: NetworkConfig, client_id: str, client_name: Optional[str] = None) -> None:
        """
        Initialize the ChatClient.

        Nothing is started here; call :meth:`start` to open the connector.

        Args:
            network_config (NetworkConfig): Network configuration settings.
            client_id (str): Unique identifier for the client.
            client_name (Optional[str]): Optional display name of the user;
                defaults to ``"DefaultUser"``.
        """
        self.client_id: str = client_id
        self.network_config: NetworkConfig = network_config
        self.client_name: str = client_name or "DefaultUser"
        self.connector: Optional[PeerConnectionManager] = None
        # The job manager sends its replies back through send_typed_message.
        self.job_manager: JobManager = JobManager(send_message_callback=self.send_typed_message)
        # queue_id -> future awaiting that request's final result.
        self.pending_jobs_futures: Dict[str, asyncio.Future] = {}

        # Lists for storing async callbacks for peer events.
        self._peer_connected_callbacks: List[Callable[[str], Awaitable[None]]] = []
        self._peer_disconnected_callbacks: List[Callable[[str], Awaitable[None]]] = []
        self._peer_discovered_callbacks: List[Callable[[str], Awaitable[None]]] = []
        self._peer_lost_callbacks: List[Callable[[str], Awaitable[None]]] = []

        # Background heartbeat watchdog task; created by start(), cancelled by stop().
        self._timeout_task: Optional[asyncio.Task] = None

    async def start(self, host: str = "0.0.0.0") -> None:
        """
        Start the chat client by initializing the connector and job manager.

        Peer discovery is enabled only if at least one job is already registered
        with the job manager when this is called.

        Args:
            host (str): The host address to bind the server.
        """
        # Enable discovery if there are jobs to process.
        allow_discovery = bool(self.job_manager.job_list())

        self.connector = PeerConnectionManager(
            on_message_callback=self._on_message_received,
            on_peer_connected=self._on_peer_connected,
            on_peer_disconnected=self._on_peer_disconnected,
            on_peer_discovered=self._on_peer_discovered,
            on_peer_lost=self._on_peer_lost,
            client_name=self.client_name,
            client_id=self.client_id,
            network_config=self.network_config,
            port=self.network_config.webserver_port,
        )
        await self.connector.start(
            host=host,
            port=self.network_config.webserver_port,
            allow_discovery=allow_discovery
        )
        self.job_manager.start()
        logger.debug(f"[ChatClient] Started (client_id={self.client_id})")

        # Start the background heartbeat watchdog. If one is already running
        # (start() called twice), reuse it; otherwise the older task would be
        # overwritten, leak, and never be cancelled by stop().
        if self._timeout_task is None or self._timeout_task.done():
            self._timeout_task = asyncio.create_task(self._check_future_timeouts())

    async def stop(self) -> None:
        """
        Stop the chat client, including the connector and job manager.

        The heartbeat watchdog is cancelled even if stopping the connector or
        the job manager raises.
        """
        try:
            if self.connector:
                await self.connector.stop()
                logger.debug("[ChatClient] Connector stopped")
            if self.job_manager:
                await self.job_manager.stop()
        finally:
            if self._timeout_task:
                self._timeout_task.cancel()
                try:
                    await self._timeout_task
                except asyncio.CancelledError:
                    pass
                self._timeout_task = None
        logger.debug("[ChatClient] Stopped")

    async def msg_type_hello(self, msg: Dict[str, Any]) -> None:
        """
        Handle a HELLO message (logged only).

        Args:
            msg (Dict[str, Any]): The received message.
        """
        from_id: str = msg.get("from_id")
        logger.debug(f"[ChatClient] Received HELLO from {from_id}")

    async def msg_type_text(self, msg: Dict[str, Any]) -> None:
        """
        Handle a TEXT message by passing it to the job manager.

        Args:
            msg (Dict[str, Any]): The received message; ``from_id``, ``text`` and
                ``queue_id`` are read from it.
        """
        peer_id: str = msg.get("from_id")
        text: str = msg.get("text", "")
        queue_id: str = msg.get("queue_id", "")
        # Process the text message via the job manager.
        await self.job_manager.process_message(msg=text, from_id=peer_id, queue_id=queue_id)
        logger.debug(f"[ChatClient] Client {self.client_id} received text: {text}")

    async def msg_job_list(self, msg: Dict[str, Any]) -> List[Any]:
        """
        Return the current job list.

        Args:
            msg (Dict[str, Any]): The received message (unused).

        Returns:
            List[Any]: The list of registered jobs.
        """
        return self.job_manager.job_list()

    async def msg_job_unknown(self, msg: Dict[str, Any]) -> None:
        """
        Handle an unknown message type (logged only).

        NOTE(review): not referenced by _on_message_received, which routes
        unknown types to the TEXT handler instead.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        msg_type: str = msg.get("type")
        logger.debug(f"[ChatClient] Unknown message type: {msg_type}")

    async def msg_final_message(self, msg: Dict[str, Any]) -> None:
        """
        Process a FINAL_MESSAGE by resolving the corresponding job future.

        The future stays in ``pending_jobs_futures`` so that a later
        :meth:`wait_for_result` call can still collect it.

        Args:
            msg (Dict[str, Any]): The message containing the final result.
        """
        logger.debug(f"[ChatClient] Final message received: {msg}")
        queue_id: str = msg.get("queue_id")
        text: str = msg.get("text")
        fut: Optional[asyncio.Future] = self.pending_jobs_futures.get(queue_id)
        if fut and not fut.done():
            fut.set_result(text)

        logger.debug(f"[ChatClient] Final message processed for job_id={queue_id}, result={text}")

    async def msg_system_message(self, msg: Dict[str, Any]) -> None:
        """
        Handle a SYSTEM_MESSAGE (logged only).

        Args:
            msg (Dict[str, Any]): The received message.
        """
        logger.debug(f"[ChatClient] System message received: {msg}")

    async def msg_heartbeat(self, msg: Dict[str, Any]) -> None:
        """
        Handle a HEARTBEAT message.

        Refreshes the ``heartbeat`` timestamp of every pending future whose
        target peer sent the heartbeat.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        logger.debug(f"[ChatClient] Heartbeat received: {msg}")
        peer_id = msg.get("from_id")
        current_time = asyncio.get_running_loop().time()
        for fut in self.pending_jobs_futures.values():
            if getattr(fut, "peer_id", None) == peer_id:
                fut.heartbeat = current_time

    async def msg_error(self, msg: Dict[str, Any]) -> None:
        """
        Handle an ERROR message by resolving the corresponding future with ``"error"``.

        NOTE(review): the future is resolved (not failed), so wait_for_result
        reports ``(True, "error")`` for this case; confirm callers expect that.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        logger.error(f"[ChatClient] Error message received: {msg}")
        queue_id: str = msg.get("queue_id")
        fut: Optional[asyncio.Future] = self.pending_jobs_futures.get(queue_id)
        if fut and not fut.done():
            fut.set_result("error")

    async def _on_message_received(self, msg: Dict[str, Any]) -> None:
        """
        Dispatch incoming messages to the appropriate handler based on the message type.

        The ``type`` field is matched case-insensitively against the
        ``MessageTypes`` values. Unknown or missing types go to the TEXT handler.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        # Map message types to their corresponding handlers.
        fns: Dict[str, Callable[..., Awaitable[Any]]] = {
            MessageTypes.HELLO.value: self.msg_type_hello,
            MessageTypes.TEXT.value: self.msg_type_text,
            MessageTypes.JOB_LIST.value: self.msg_job_list,
            MessageTypes.FINAL_MESSAGE.value: self.msg_final_message,
            MessageTypes.SYSTEM_MESSAGE.value: self.msg_system_message,
            MessageTypes.HEARTBEAT.value: self.msg_heartbeat,
            MessageTypes.ERROR.value: self.msg_error,
        }
        # A missing "type" and an explicit None are both treated as an empty type,
        # instead of crashing on None.lower().
        msg_type: str = (msg.get("type") or "").lower()
        # Default to TEXT handler if the message type is not found.
        fn = fns.get(msg_type, self.msg_type_text)
        await fn(msg=msg)

    async def _on_peer_connected(self, peer_id: str) -> None:
        """
        Handle actions when a new peer connects by running the registered callbacks in order.

        Args:
            peer_id (str): The identifier of the connected peer.
        """
        logger.debug(f"[ChatClient] Peer connected: {peer_id}")
        for callback in self._peer_connected_callbacks:
            await callback(peer_id)

    async def _on_peer_disconnected(self, peer_id: str) -> None:
        """
        Handle cleanup when a peer disconnects.

        Every pending future targeting that peer is failed with
        PeerDisconnectedException (if still pending) and removed. The registered
        callbacks then run in order.

        Args:
            peer_id (str): The identifier of the disconnected peer.
        """
        logger.debug(f"[ChatClient] Peer disconnected: {peer_id}")
        # Iterate over a snapshot because entries are deleted inside the loop.
        for job_id, fut in list(self.pending_jobs_futures.items()):
            if getattr(fut, "peer_id", None) == peer_id:
                if not fut.done():
                    fut.set_exception(PeerDisconnectedException(f"Task {job_id} canceled: peer {peer_id} disconnected"))
                del self.pending_jobs_futures[job_id]
        for callback in self._peer_disconnected_callbacks:
            await callback(peer_id)

    async def _on_peer_discovered(self, peer_id: str) -> None:
        """
        Handle actions when a new peer is discovered by running the registered callbacks in order.

        Args:
            peer_id (str): The identifier of the discovered peer.
        """
        for callback in self._peer_discovered_callbacks:
            await callback(peer_id)

    async def _on_peer_lost(self, peer_id: str) -> None:
        """
        Handle actions when a peer is lost by running the registered callbacks in order.

        Args:
            peer_id (str): The identifier of the lost peer.
        """
        for callback in self._peer_lost_callbacks:
            await callback(peer_id)

    async def get_peers_list(self) -> List[Dict[str, Any]]:
        """
        Retrieve the list of connected peers.

        Returns:
            List[Dict[str, Any]]: A list of peer information dictionaries, or an
            empty list if the client has not been started.
        """
        if self.connector is None:
            return []
        return self.connector.get_peers_list()

    async def get_peers_by_client_name(self, target_client_name: str) -> List[Dict[str, Any]]:
        """
        Retrieve peers matching the given display name.

        Args:
            target_client_name (str): The display name to search for.

        Returns:
            List[Dict[str, Any]]: A list of matching peer information dictionaries.
            Peers without a ``display_name`` entry never match.
        """
        peers: List[Dict[str, Any]] = await self.get_peers_list()
        return [peer for peer in peers if peer.get("display_name") == target_client_name]

    async def wait_all_peers(self, interval: float = 5, timeout: float = 30) -> List[Dict[str, Any]]:
        """
        Wait until at least one peer is discovered.

        Args:
            interval (float): Time interval between checks in seconds.
            timeout (float): Maximum time to wait in seconds.

        Returns:
            List[Dict[str, Any]]: A list of peer information dictionaries.

        Raises:
            TimeoutError: If no peer is discovered within the timeout period.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        while True:
            peers = await self.get_peers_list()
            if peers:
                return peers
            if loop.time() - start_time > timeout:
                raise TimeoutError(f"Couldn't find any peers within {timeout} seconds")
            await asyncio.sleep(interval)

    async def wait_for_peers(self, target_client_name: str, interval: float = 5, timeout: float = 30) -> Dict[str, Any]:
        """
        Wait until a peer with the given display name is found.

        Args:
            target_client_name (str): The display name to search for.
            interval (float): Time interval between checks in seconds.
            timeout (float): Maximum time to wait in seconds.

        Returns:
            Dict[str, Any]: The information dictionary for the first matching peer.

        Raises:
            TimeoutError: If no matching peer is found within the timeout period.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        while True:
            peers = await self.get_peers_by_client_name(target_client_name=target_client_name)
            if peers:
                return peers[0]
            if loop.time() - start_time > timeout:
                raise TimeoutError(f"Couldn't find {target_client_name} within {timeout} seconds")
            await asyncio.sleep(interval)

    async def send_typed_message(
        self,
        peer_id: str,
        text: str,
        message_type: str = MessageTypes.TEXT.value,
        job_id: Optional[str] = None,
        *args,
        **kwargs
    ) -> None:
        """
        Send a message with a specific type using the MessageFactory.

        Unlike :meth:`send_message`, this does not register a pending future.

        Args:
            peer_id (str): The target peer's identifier.
            text (str): The message text.
            message_type (str): The type of the message.
            job_id (Optional[str]): Optional job identifier.
            *args: Additional positional arguments for MessageFactory.
            **kwargs: Additional keyword arguments for MessageFactory.

        Raises:
            RuntimeError: If the client has not been started.
        """
        # NOTE(review): non-empty *args are bound positionally *before* the
        # keyword arguments below, so they may collide with MessageFactory's
        # leading parameters; confirm against MessageFactory's signature.
        message_factory = MessageFactory(
            message_type=message_type,
            from_id=self.client_id,
            job_id=job_id,
            *args,
            **kwargs
        )
        await self._send_message(peer_id=peer_id, msg=message_factory.generate_message(text=text))

    async def send_message(self, peer_id: str, text: str) -> str:
        """
        Send a TEXT message to a specific peer and create a pending job.

        The pending future is registered before sending, so a reply that arrives
        while the send is in progress is not lost. If building or sending the
        message fails, the future is removed again and the error is re-raised.

        Args:
            peer_id (str): The target peer's identifier.
            text (str): The message text.

        Returns:
            str: A unique queue/job identifier to pass to :meth:`wait_for_result`.

        Raises:
            RuntimeError: If the client has not been started.
        """
        queue_id: str = str(uuid.uuid4())
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        # Attach metadata to the future for heartbeat monitoring.
        future.peer_id = peer_id
        current_time = loop.time()
        future.created_at = current_time
        future.heartbeat = current_time
        self.pending_jobs_futures[queue_id] = future

        try:
            message_factory = MessageFactory(
                message_type=MessageTypes.TEXT.value,
                from_id=self.client_id,
                queue_id=queue_id,
            )
            await self._send_message(peer_id=peer_id, msg=message_factory.generate_message(text=text))
        except BaseException:
            # The caller never receives queue_id, so nobody could ever await this
            # future; drop it instead of leaving it for the heartbeat watchdog.
            self.pending_jobs_futures.pop(queue_id, None)
            raise
        return queue_id

    async def wait_for_result(self, job_id: str, timeout: Optional[float] = None) -> Tuple[bool, str]:
        """
        Wait for the result of a job identified by job_id.

        When this call finishes (with a result, an error, or a timeout), the job's
        future is removed from ``pending_jobs_futures`` so that completed jobs do
        not accumulate.

        Args:
            job_id (str): The job identifier returned by :meth:`send_message`.
            timeout (Optional[float]): Optional timeout in seconds; ``None`` waits
                without a time limit.

        Returns:
            Tuple[bool, str]: ``(True, result)`` when the future is resolved (an
            ERROR message resolves it with ``"error"``), ``(False, "Peer
            disconnected")`` if the peer disconnected, or ``(False, "couldn't
            connect to Agent")`` if the heartbeat watchdog gave up on the peer.

        Raises:
            ValueError: If no pending future is found for the job_id.
            asyncio.TimeoutError: If the result is not received within the timeout.
        """
        fut = self.pending_jobs_futures.get(job_id)
        if fut is None:
            raise ValueError(f"No pending future for job_id={job_id}")
        try:
            result = await asyncio.wait_for(fut, timeout=timeout)
            return True, result
        except PeerDisconnectedException:
            return False, "Peer disconnected"
        except AgentConnectionError:
            return False, "couldn't connect to Agent"
        finally:
            # The future is finished (or was cancelled by wait_for on timeout);
            # forget it, but only if the entry still refers to this same future.
            if self.pending_jobs_futures.get(job_id) is fut:
                del self.pending_jobs_futures[job_id]

    def _require_connector(self) -> PeerConnectionManager:
        """
        Return the active connector.

        Raises:
            RuntimeError: If :meth:`start` has not been called yet.
        """
        if self.connector is None:
            raise RuntimeError("ChatClient is not started; call start() first")
        return self.connector

    async def _send_message(self, peer_id: str, msg: Dict[str, Any]) -> None:
        """
        Send a message to a peer using the underlying connector.

        Args:
            peer_id (str): The target peer's identifier.
            msg (Dict[str, Any]): The message payload.

        Raises:
            RuntimeError: If the client has not been started.
        """
        await self._require_connector().send_message(peer_id, msg)

    async def register_job(self, job_call_back: Callable) -> None:
        """
        Register a new job callback with the job manager.

        Register jobs before :meth:`start`: discovery is only enabled when jobs
        exist at start time.

        Args:
            job_call_back (Callable): The callback function to register.
        """
        self.job_manager.register_job(job_call_back=job_call_back)

    def register_on_peer_connected_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer connects (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_connected_callbacks:
            self._peer_connected_callbacks.append(callback)

    def register_on_peer_disconnected_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer disconnects (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_disconnected_callbacks:
            self._peer_disconnected_callbacks.append(callback)

    def register_on_peer_discovered_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer is discovered (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_discovered_callbacks:
            self._peer_discovered_callbacks.append(callback)

    def register_on_peer_lost_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer is lost (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_lost_callbacks:
            self._peer_lost_callbacks.append(callback)

    async def _check_future_timeouts(self) -> None:
        """
        Heartbeat watchdog; runs until cancelled (see :meth:`stop`).

        Every ``_TIMEOUT_CHECK_INTERVAL`` seconds, each pending future whose last
        heartbeat is older than ``_HEARTBEAT_TIMEOUT`` seconds is removed from
        ``pending_jobs_futures`` and failed with ``AgentConnectionError``. An
        error during a sweep is logged and does not stop the watchdog.
        """
        while True:
            try:
                current_time = asyncio.get_running_loop().time()
                # Snapshot: entries are deleted while iterating.
                for job_id, fut in list(self.pending_jobs_futures.items()):
                    if fut.done():
                        continue
                    last_hb = getattr(fut, "heartbeat", None)
                    if last_hb is not None and (current_time - last_hb > self._HEARTBEAT_TIMEOUT):
                        del self.pending_jobs_futures[job_id]
                        # asyncio.Future.set_exception instantiates an exception
                        # class passed to it.
                        fut.set_exception(AgentConnectionError)
            except Exception as e:
                logger.exception(f"[ChatClient] Error in _check_future_timeouts: {e}")
            await asyncio.sleep(self._TIMEOUT_CHECK_INTERVAL)

    async def close_connection(self, peer_id: str):
        """
        Close the connection to a peer via the underlying connector.

        Args:
            peer_id (str): Unique identifier of the peer whose connection is closed.

        Raises:
            RuntimeError: If the client has not been started.
        """
        await self._require_connector().close_connection(peer_id=peer_id)
@@ -0,0 +1,48 @@
1
+ import asyncio
2
+ from datetime import datetime
3
+ from typing import Any, Optional
4
+
5
+
6
class Job:
    """
    A unit of work tied to a peer.

    Stores the job payload and its routing identifiers. It also tracks the first
    and latest heartbeat timestamps and, once scheduled, the asyncio task that
    runs it.
    """

    def __init__(self, job_id: str, data: Any, peer_id: str, queue_id: str) -> None:
        """
        Build a Job with no heartbeat recorded and no task attached yet.

        Args:
            job_id (str): Unique identifier for the job.
            data (Any): Payload associated with the job.
            peer_id (str): Identifier of the peer related to the job.
            queue_id (str): Identifier for the job queue.
        """
        self.job_id: str = job_id
        self.data: Any = data
        self.peer_id: str = peer_id
        self.queue_id: str = queue_id
        # First heartbeat ever seen (set once by update_heartbeat).
        self.start_heartbeat: Optional[datetime] = None
        # Most recent heartbeat seen.
        self.heartbeat: Optional[datetime] = None
        # Task running the job; attached later through set_task().
        self.task: Optional[asyncio.Task] = None

    def set_task(self, task: asyncio.Task) -> None:
        """
        Attach the asyncio task that executes this job, replacing any previous one.

        Args:
            task (asyncio.Task): The task to store on the job.
        """
        self.task = task

    def update_heartbeat(self, heartbeat: datetime) -> None:
        """
        Record a heartbeat.

        The latest heartbeat is always overwritten. The start heartbeat is only
        filled in by the first call.

        Args:
            heartbeat (datetime): Timestamp of the heartbeat being recorded.
        """
        self.heartbeat = heartbeat
        if self.start_heartbeat is None:
            # First heartbeat: it also marks when monitoring began.
            self.start_heartbeat = heartbeat