genai-protocol-lite 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- AIConnector/__init__.py +0 -0
- AIConnector/common/__init__.py +0 -0
- AIConnector/common/exceptions.py +13 -0
- AIConnector/common/logger.py +57 -0
- AIConnector/common/message.py +14 -0
- AIConnector/common/network.py +146 -0
- AIConnector/connector/__init__.py +0 -0
- AIConnector/connector/azure_connector.py +205 -0
- AIConnector/connector/base_connector.py +51 -0
- AIConnector/connector/peer_connection_manager.py +260 -0
- AIConnector/connector/ws_connector.py +213 -0
- AIConnector/core/__init__.py +0 -0
- AIConnector/core/chat_client.py +505 -0
- AIConnector/core/job.py +48 -0
- AIConnector/core/job_manager.py +219 -0
- AIConnector/core/message_factory.py +44 -0
- AIConnector/discovery/__init__.py +0 -0
- AIConnector/discovery/azure_discovery_service.py +206 -0
- AIConnector/discovery/base_discovery_service.py +27 -0
- AIConnector/discovery/discovery_service.py +226 -0
- AIConnector/session.py +274 -0
- genai_protocol_lite-1.0.0.dist-info/METADATA +186 -0
- genai_protocol_lite-1.0.0.dist-info/RECORD +26 -0
- genai_protocol_lite-1.0.0.dist-info/WHEEL +5 -0
- genai_protocol_lite-1.0.0.dist-info/licenses/LICENSE +201 -0
- genai_protocol_lite-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,505 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
import uuid
|
4
|
+
from typing import Optional, Dict, Any, List, Callable, Awaitable, Tuple
|
5
|
+
|
6
|
+
from AIConnector.common.message import MessageTypes
|
7
|
+
from AIConnector.common.network import NetworkConfig
|
8
|
+
from AIConnector.common.exceptions import PeerDisconnectedException, AgentConnectionError
|
9
|
+
from AIConnector.connector.peer_connection_manager import PeerConnectionManager
|
10
|
+
from .job_manager import JobManager
|
11
|
+
from .message_factory import MessageFactory
|
12
|
+
|
13
|
+
# Module-level logger used by every diagnostic message emitted from ChatClient.
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class ChatClient:
    """
    A peer-to-peer chat client that handles message exchange and job processing.

    Outgoing requests made with ``send_message`` are tracked as asyncio futures in
    ``pending_jobs_futures`` (keyed by queue id). Each future is tagged with
    ``peer_id``, ``created_at`` and ``heartbeat`` attributes:

    * FINAL_MESSAGE / ERROR messages resolve the future,
    * HEARTBEAT messages from the same peer refresh its ``heartbeat`` timestamp,
    * a background watchdog fails it with ``AgentConnectionError`` once heartbeats
      stop arriving, and a peer disconnect fails it with
      ``PeerDisconnectedException``.
    """

    # Heartbeat timeout in seconds. If no heartbeat is received within this period,
    # the pending job future is failed with AgentConnectionError.
    _HEARTBEAT_TIMEOUT = 30

    # Interval in seconds between two heartbeat-timeout scans of the watchdog.
    _TIMEOUT_CHECK_INTERVAL = 5

    def __init__(self, network_config: NetworkConfig, client_id: str, client_name: Optional[str] = None) -> None:
        """
        Initialize the ChatClient.

        No network activity happens here; call ``start()`` to bring the client up.

        Args:
            network_config (NetworkConfig): Network configuration settings.
            client_id (str): Unique identifier for the client.
            client_name (Optional[str]): Optional display name of the user
                (falls back to "DefaultUser").
        """
        self.client_id: str = client_id
        self.network_config: NetworkConfig = network_config
        self.client_name: str = client_name or "DefaultUser"
        # Created in start(); None until the client has been started.
        self.connector: Optional[PeerConnectionManager] = None
        # The job manager sends its replies back to peers through send_typed_message.
        self.job_manager: JobManager = JobManager(send_message_callback=self.send_typed_message)
        # queue_id -> future resolved with the peer's final answer.
        self.pending_jobs_futures: Dict[str, asyncio.Future] = {}

        # Lists for storing async callbacks for peer events.
        self._peer_connected_callbacks: List[Callable[[str], Awaitable[None]]] = []
        self._peer_disconnected_callbacks: List[Callable[[str], Awaitable[None]]] = []
        self._peer_discovered_callbacks: List[Callable[[str], Awaitable[None]]] = []
        self._peer_lost_callbacks: List[Callable[[str], Awaitable[None]]] = []

        # Background heartbeat watchdog, created in start() and cancelled in stop().
        self._timeout_task: Optional[asyncio.Task] = None

    async def start(self, host: str = "0.0.0.0") -> None:
        """
        Start the chat client by initializing the connector and job manager.

        Discovery is only enabled when at least one job has been registered
        (via ``register_job``) before this call.

        Args:
            host (str): The host address to bind the server.
        """
        # Enable discovery if there are jobs to process.
        allow_discovery = bool(self.job_manager.job_list())

        self.connector = PeerConnectionManager(
            on_message_callback=self._on_message_received,
            on_peer_connected=self._on_peer_connected,
            on_peer_disconnected=self._on_peer_disconnected,
            on_peer_discovered=self._on_peer_discovered,
            on_peer_lost=self._on_peer_lost,
            client_name=self.client_name,
            client_id=self.client_id,
            network_config=self.network_config,
            port=self.network_config.webserver_port,
        )
        await self.connector.start(
            host=host,
            port=self.network_config.webserver_port,
            allow_discovery=allow_discovery
        )
        self.job_manager.start()
        logger.debug(f"[ChatClient] Started (client_id={self.client_id})")

        # Start background task to periodically check heartbeat timeouts.
        self._timeout_task = asyncio.create_task(self._check_future_timeouts())

    async def stop(self) -> None:
        """
        Stop the chat client, including the connector, job manager and heartbeat watchdog.

        Futures that are still pending are failed with ``PeerDisconnectedException``.
        Without this, callers blocked in ``wait_for_result`` with no timeout would
        wait forever, because the watchdog that would otherwise time them out is
        cancelled here.
        """
        if self.connector:
            await self.connector.stop()
            logger.debug("[ChatClient] Connector stopped")
        if self.job_manager:
            await self.job_manager.stop()
        if self._timeout_task:
            self._timeout_task.cancel()
            try:
                await self._timeout_task
            except asyncio.CancelledError:
                pass
            self._timeout_task = None
        for job_id, fut in list(self.pending_jobs_futures.items()):
            if not fut.done():
                fut.set_exception(PeerDisconnectedException(f"Task {job_id} canceled: client stopped"))
        logger.debug("[ChatClient] Stopped")

    async def msg_type_hello(self, msg: Dict[str, Any]) -> None:
        """
        Handle a HELLO message (only logged).

        Args:
            msg (Dict[str, Any]): The received message.
        """
        from_id: str = msg.get("from_id")
        logger.debug(f"[ChatClient] Received HELLO from {from_id}")

    async def msg_type_text(self, msg: Dict[str, Any]) -> None:
        """
        Handle a TEXT message by handing its text to the job manager.

        This is also the fallback handler for unrecognised message types.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        peer_id: str = msg.get("from_id")
        text: str = msg.get("text", "")
        queue_id: str = msg.get("queue_id", "")
        # Process the text message via the job manager.
        await self.job_manager.process_message(msg=text, from_id=peer_id, queue_id=queue_id)
        logger.debug(f"[ChatClient] Client {self.client_id} received text: {text}")

    async def msg_job_list(self, msg: Dict[str, Any]) -> List[Any]:
        """
        Return the current job list.

        Note: when invoked through ``_on_message_received`` the return value is
        discarded.

        Args:
            msg (Dict[str, Any]): The received message (unused).

        Returns:
            List[Any]: The list of registered jobs.
        """
        return self.job_manager.job_list()

    async def msg_job_unknown(self, msg: Dict[str, Any]) -> None:
        """
        Log an unknown message type.

        Not wired into ``_on_message_received``, which falls back to the TEXT handler.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        msg_type: str = msg.get("type")
        logger.debug(f"[ChatClient] Unknown message type: {msg_type}")

    async def msg_final_message(self, msg: Dict[str, Any]) -> None:
        """
        Process a FINAL_MESSAGE by resolving the corresponding job future with its text.

        Args:
            msg (Dict[str, Any]): The message containing the final result.
        """
        logger.debug(f"[ChatClient] Final message received: {msg}")
        queue_id: str = msg.get("queue_id")
        text: str = msg.get("text")
        fut: Optional[asyncio.Future] = self.pending_jobs_futures.get(queue_id)
        if fut and not fut.done():
            fut.set_result(text)

        logger.debug(f"[ChatClient] Final message processed for job_id={queue_id}, result={text}")

    async def msg_system_message(self, msg: Dict[str, Any]) -> None:
        """
        Handle a SYSTEM_MESSAGE (only logged).

        Args:
            msg (Dict[str, Any]): The received message.
        """
        logger.debug(f"[ChatClient] System message received: {msg}")

    async def msg_heartbeat(self, msg: Dict[str, Any]) -> None:
        """
        Handle a HEARTBEAT message.

        Refreshes the ``heartbeat`` timestamp of every pending future addressed to
        the sending peer, which keeps the watchdog from failing it.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        logger.debug(f"[ChatClient] Heartbeat received: {msg}")
        peer_id = msg.get("from_id")
        current_time = asyncio.get_running_loop().time()
        for fut in self.pending_jobs_futures.values():
            if getattr(fut, "peer_id", None) == peer_id:
                fut.heartbeat = current_time

    async def msg_error(self, msg: Dict[str, Any]) -> None:
        """
        Handle an ERROR message by resolving the corresponding future.

        The future is resolved with the literal result ``"error"``, so
        ``wait_for_result`` reports it as ``(True, "error")``.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        logger.error(f"[ChatClient] Error message received: {msg}")
        queue_id: str = msg.get("queue_id")
        fut: Optional[asyncio.Future] = self.pending_jobs_futures.get(queue_id)
        if fut and not fut.done():
            fut.set_result("error")

    async def _on_message_received(self, msg: Dict[str, Any]) -> None:
        """
        Dispatch incoming messages to the appropriate handler based on the message type.

        Matching is case-insensitive. Messages whose type is missing, ``None`` or
        unknown are handled as TEXT.

        Args:
            msg (Dict[str, Any]): The received message.
        """
        # Map message types to their corresponding handlers.
        fns: Dict[str, Callable[[Dict[str, Any]], Awaitable[Any]]] = {
            MessageTypes.HELLO.value: self.msg_type_hello,
            MessageTypes.TEXT.value: self.msg_type_text,
            MessageTypes.JOB_LIST.value: self.msg_job_list,
            MessageTypes.FINAL_MESSAGE.value: self.msg_final_message,
            MessageTypes.SYSTEM_MESSAGE.value: self.msg_system_message,
            MessageTypes.HEARTBEAT.value: self.msg_heartbeat,
            MessageTypes.ERROR.value: self.msg_error,
        }
        # `or ""` also covers an explicit "type": None, which previously crashed on .lower().
        msg_type: str = str(msg.get("type") or "").lower()
        # Default to TEXT handler if the message type is not found.
        fn = fns.get(msg_type, self.msg_type_text)
        await fn(msg=msg)

    async def _on_peer_connected(self, peer_id: str) -> None:
        """
        Handle actions when a new peer connects.

        Args:
            peer_id (str): The identifier of the connected peer.
        """
        logger.debug(f"[ChatClient] Peer connected: {peer_id}")
        for callback in self._peer_connected_callbacks:
            await callback(peer_id)

    async def _on_peer_disconnected(self, peer_id: str) -> None:
        """
        Handle cleanup when a peer disconnects.

        Fails every pending future addressed to that peer with
        ``PeerDisconnectedException`` and drops it, then runs the registered callbacks.

        Args:
            peer_id (str): The identifier of the disconnected peer.
        """
        logger.debug(f"[ChatClient] Peer disconnected: {peer_id}")
        for job_id, fut in list(self.pending_jobs_futures.items()):
            if getattr(fut, "peer_id", None) == peer_id:
                if not fut.done():
                    fut.set_exception(PeerDisconnectedException(f"Task {job_id} canceled: peer {peer_id} disconnected"))
                self.pending_jobs_futures.pop(job_id, None)
        for callback in self._peer_disconnected_callbacks:
            await callback(peer_id)

    async def _on_peer_discovered(self, peer_id: str) -> None:
        """
        Handle actions when a new peer is discovered.

        Args:
            peer_id (str): The identifier of the discovered peer.
        """
        for callback in self._peer_discovered_callbacks:
            await callback(peer_id)

    async def _on_peer_lost(self, peer_id: str) -> None:
        """
        Handle actions when a peer is lost.

        Args:
            peer_id (str): The identifier of the lost peer.
        """
        for callback in self._peer_lost_callbacks:
            await callback(peer_id)

    async def get_peers_list(self) -> List[Dict[str, Any]]:
        """
        Retrieve the list of connected peers.

        Returns:
            List[Dict[str, Any]]: A list of peer information dictionaries
            (empty if the client has not been started).
        """
        if self.connector is None:
            return []
        return self.connector.get_peers_list()

    async def get_peers_by_client_name(self, target_client_name: str) -> List[Dict[str, Any]]:
        """
        Retrieve peers matching the given display name.

        Peers without a ``display_name`` entry are skipped instead of raising KeyError.

        Args:
            target_client_name (str): The display name to search for.

        Returns:
            List[Dict[str, Any]]: A list of matching peer information dictionaries.
        """
        peers: List[Dict[str, Any]] = await self.get_peers_list()
        return [peer for peer in peers if peer.get("display_name") == target_client_name]

    async def wait_all_peers(self, interval: float = 5, timeout: float = 30) -> List[Dict[str, Any]]:
        """
        Wait until at least one peer is discovered.

        Args:
            interval (float): Time interval between checks in seconds.
            timeout (float): Maximum time to wait in seconds.

        Returns:
            List[Dict[str, Any]]: A list of peer information dictionaries.

        Raises:
            TimeoutError: If no peer is discovered within the timeout period.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        while True:
            peers = await self.get_peers_list()
            if peers:
                return peers
            if loop.time() - start_time > timeout:
                raise TimeoutError(f"Couldn't find any peers within {timeout} seconds")
            await asyncio.sleep(interval)

    async def wait_for_peers(self, target_client_name: str, interval: float = 5, timeout: float = 30) -> Dict[str, Any]:
        """
        Wait until a peer with the given display name is found.

        Args:
            target_client_name (str): The display name to search for.
            interval (float): Time interval between checks in seconds.
            timeout (float): Maximum time to wait in seconds.

        Returns:
            Dict[str, Any]: The information dictionary for the first matching peer.

        Raises:
            TimeoutError: If no matching peer is found within the timeout period.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        while True:
            peers = await self.get_peers_by_client_name(target_client_name=target_client_name)
            if peers:
                return peers[0]
            if loop.time() - start_time > timeout:
                raise TimeoutError(f"Couldn't find {target_client_name} within {timeout} seconds")
            await asyncio.sleep(interval)

    async def send_typed_message(
        self,
        peer_id: str,
        text: str,
        message_type: str = MessageTypes.TEXT.value,
        job_id: Optional[str] = None,
        *args,
        **kwargs
    ) -> None:
        """
        Send a message with a specific type using the MessageFactory.

        Unlike ``send_message`` this does not register a pending future.

        Args:
            peer_id (str): The target peer's identifier.
            text (str): The message text.
            message_type (str): The type of the message.
            job_id (Optional[str]): Optional job identifier.
            *args: Additional positional arguments for MessageFactory.
            **kwargs: Additional keyword arguments for MessageFactory.

        Raises:
            RuntimeError: If the client has not been started.
        """
        message_factory = MessageFactory(
            message_type=message_type,
            from_id=self.client_id,
            job_id=job_id,
            *args,
            **kwargs
        )
        await self._send_message(peer_id=peer_id, msg=message_factory.generate_message(text=text))

    async def send_message(self, peer_id: str, text: str) -> str:
        """
        Send a TEXT message to a specific peer and create a pending job.

        The returned id can be passed to ``wait_for_result``. If sending fails, the
        pending entry is removed again and the exception propagates.

        Args:
            peer_id (str): The target peer's identifier.
            text (str): The message text.

        Returns:
            str: A unique queue/job identifier.

        Raises:
            RuntimeError: If the client has not been started.
        """
        queue_id: str = str(uuid.uuid4())
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        # Attach metadata to the future for heartbeat monitoring.
        future.peer_id = peer_id
        current_time = loop.time()
        future.created_at = current_time
        future.heartbeat = current_time
        self.pending_jobs_futures[queue_id] = future

        message_factory = MessageFactory(
            message_type=MessageTypes.TEXT.value,
            from_id=self.client_id,
            queue_id=queue_id,
        )
        try:
            await self._send_message(peer_id=peer_id, msg=message_factory.generate_message(text=text))
        except BaseException:
            # The caller never receives queue_id, so nobody could ever wait on or
            # clean up this future; drop it instead of leaking it.
            self.pending_jobs_futures.pop(queue_id, None)
            raise
        return queue_id

    async def wait_for_result(self, job_id: str, timeout: Optional[float] = None) -> Tuple[bool, str]:
        """
        Wait for the result of a job identified by job_id.

        Once this call returns or raises, the job's future is removed from
        ``pending_jobs_futures`` (it is finished at that point). A second call for
        the same job_id therefore raises ValueError.

        Args:
            job_id (str): The job identifier returned by ``send_message``.
            timeout (Optional[float]): Optional timeout in seconds.

        Returns:
            Tuple[bool, str]: ``(True, result_text)`` on success,
            ``(False, "Peer disconnected")`` if the peer disconnected (or the client
            was stopped), and ``(False, "couldn't connect to Agent")`` if heartbeats
            timed out.

        Raises:
            ValueError: If no pending future is found for the job_id.
            asyncio.TimeoutError: If the result is not received within the timeout.
        """
        fut = self.pending_jobs_futures.get(job_id)
        if fut is None:
            raise ValueError(f"No pending future for job_id={job_id}")
        try:
            result = await asyncio.wait_for(fut, timeout=timeout)
            return True, result
        except PeerDisconnectedException:
            return False, "Peer disconnected"
        except AgentConnectionError:
            return False, "couldn't connect to Agent"
        finally:
            # After wait_for the future is finished (result, exception, or cancelled
            # on timeout/cancellation). Drop it so the dict does not grow without bound.
            self.pending_jobs_futures.pop(job_id, None)

    async def _send_message(self, peer_id: str, msg: Dict[str, Any]) -> None:
        """
        Send a message to a peer using the underlying connector.

        Args:
            peer_id (str): The target peer's identifier.
            msg (Dict[str, Any]): The message payload.

        Raises:
            RuntimeError: If the client has not been started (no connector yet).
        """
        if self.connector is None:
            raise RuntimeError("ChatClient is not started; call start() first")
        await self.connector.send_message(peer_id, msg)

    async def register_job(self, job_call_back: Callable) -> None:
        """
        Register a new job callback with the job manager.

        Register jobs before ``start()`` so that discovery gets enabled.

        Args:
            job_call_back (Callable): The callback function to register.
        """
        self.job_manager.register_job(job_call_back=job_call_back)

    def register_on_peer_connected_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer connects (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_connected_callbacks:
            self._peer_connected_callbacks.append(callback)

    def register_on_peer_disconnected_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer disconnects (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_disconnected_callbacks:
            self._peer_disconnected_callbacks.append(callback)

    def register_on_peer_discovered_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer is discovered (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_discovered_callbacks:
            self._peer_discovered_callbacks.append(callback)

    def register_on_peer_lost_callback(self, callback: Callable[[str], Awaitable[None]]) -> None:
        """
        Register an async callback for when a peer is lost (duplicates are ignored).

        Args:
            callback (Callable[[str], Awaitable[None]]): Async function accepting a peer_id.
        """
        if callback not in self._peer_lost_callbacks:
            self._peer_lost_callbacks.append(callback)

    async def _check_future_timeouts(self) -> None:
        """
        Heartbeat watchdog; runs until cancelled by ``stop()``.

        Every ``_TIMEOUT_CHECK_INTERVAL`` seconds, each pending future whose last
        heartbeat is older than ``_HEARTBEAT_TIMEOUT`` seconds is removed and failed
        with ``AgentConnectionError``. Unexpected errors are logged and do not stop
        the loop.
        """
        while True:
            try:
                current_time = asyncio.get_running_loop().time()
                for job_id, fut in list(self.pending_jobs_futures.items()):
                    if fut.done():
                        continue
                    last_hb = getattr(fut, "heartbeat", None)
                    if last_hb is not None and (current_time - last_hb > self._HEARTBEAT_TIMEOUT):
                        self.pending_jobs_futures.pop(job_id, None)
                        fut.set_exception(
                            AgentConnectionError(
                                f"No heartbeat for job {job_id} from peer "
                                f"{getattr(fut, 'peer_id', None)} within {self._HEARTBEAT_TIMEOUT} seconds"
                            )
                        )
            except Exception as e:
                # Keep the watchdog alive; log with traceback for diagnosis.
                logger.exception(f"[ChatClient] Error in _check_future_timeouts: {e}")
            await asyncio.sleep(self._TIMEOUT_CHECK_INTERVAL)

    async def close_connection(self, peer_id: str):
        """
        Close the connection to a peer.

        Args:
            peer_id (str): Unique identifier of the peer whose connection is closed.

        Raises:
            RuntimeError: If the client has not been started (no connector yet).
        """
        if self.connector is None:
            raise RuntimeError("ChatClient is not started; call start() first")
        await self.connector.close_connection(peer_id=peer_id)
|
AIConnector/core/job.py
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
import asyncio
|
2
|
+
from datetime import datetime
|
3
|
+
from typing import Any, Optional
|
4
|
+
|
5
|
+
|
6
|
+
class Job:
    """
    A unit of work received from a peer.

    Bundles the job payload with the identifiers needed to route a reply, the
    heartbeat timestamps recorded while it runs, and the asyncio task executing it.
    """

    def __init__(self, job_id: str, data: Any, peer_id: str, queue_id: str) -> None:
        """
        Create a job with no task attached and no heartbeat recorded yet.

        Args:
            job_id (str): Unique identifier for the job.
            data (Any): Payload associated with the job.
            peer_id (str): Identifier of the peer the job relates to.
            queue_id (str): Identifier of the job queue.
        """
        # Identity and routing information.
        self.job_id: str = job_id
        self.data: Any = data
        self.peer_id: str = peer_id
        self.queue_id: str = queue_id

        # Runtime state, filled in later through set_task() / update_heartbeat().
        self.task: Optional[asyncio.Task] = None  # task executing the job
        self.heartbeat: Optional[datetime] = None  # most recent heartbeat
        self.start_heartbeat: Optional[datetime] = None  # first heartbeat ever recorded

    def set_task(self, task: asyncio.Task) -> None:
        """
        Attach the asyncio task that executes this job, replacing any previous one.

        Args:
            task (asyncio.Task): The task to associate with the job.
        """
        self.task = task

    def update_heartbeat(self, heartbeat: datetime) -> None:
        """
        Record a new heartbeat.

        The latest value always goes to ``heartbeat``. ``start_heartbeat`` is only
        set by the first call and never changes afterwards.

        Args:
            heartbeat (datetime): The heartbeat timestamp to record.
        """
        self.heartbeat = heartbeat
        if self.start_heartbeat is None:
            self.start_heartbeat = heartbeat
|