gohumanloop 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,839 @@
1
+ from typing import Dict, Any, Optional, List, Union, Set
2
+ import asyncio
3
+ import time
4
+
5
+ from gohumanloop.core.interface import (
6
+ HumanLoopManager, HumanLoopProvider, HumanLoopCallback,
7
+ HumanLoopResult, HumanLoopStatus, HumanLoopType
8
+ )
9
+
10
class DefaultHumanLoopManager(HumanLoopManager):
    """Default human-in-the-loop manager implementation.

    Routes requests to registered ``HumanLoopProvider`` instances and keeps
    track of how tasks, conversations and requests relate to each other, as
    well as per-request callbacks and timeout watchers.

    Every public coroutine ``async_xxx`` has a synchronous counterpart ``xxx``
    that drives the coroutine to completion through ``_run_sync``.
    """

    def __init__(
        self,
        initial_providers: Optional[Union[HumanLoopProvider, List[HumanLoopProvider]]] = None,
    ):
        """Create a manager, optionally pre-registering providers.

        Args:
            initial_providers: A single provider or a list (or tuple) of
                providers. Each one is registered under its ``name`` (an id is
                generated when the name is empty); the first provider
                registered becomes the default provider.
        """
        self.providers: Dict[str, HumanLoopProvider] = {}
        self.default_provider_id: Optional[str] = None

        # (conversation_id, request_id) -> HumanLoopCallback
        self._callbacks = {}
        # (conversation_id, request_id) -> asyncio.Task watching the request timeout
        self._timeout_tasks = {}
        # task_id -> set of conversation_ids
        self._task_conversations: Dict[str, Set[str]] = {}
        # conversation_id -> list of request_ids, in creation order
        self._conversation_requests: Dict[str, List[str]] = {}
        # (conversation_id, request_id) -> task_id
        self._request_task = {}
        # conversation_id -> provider_id the conversation is bound to
        self._conversation_provider: Dict[str, str] = {}

        if initial_providers is None:
            return
        if isinstance(initial_providers, (list, tuple)):
            providers = list(initial_providers)
        else:
            providers = [initial_providers]
        for provider in providers:
            # register_provider_sync makes the first registered provider the
            # default, using the id it was actually registered under.
            self.register_provider_sync(provider, provider.name)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _run_sync(coro: Any) -> Any:
        """Drive the coroutine *coro* to completion from synchronous code.

        When no event loop is running in the current thread, the thread's
        event loop (a new one if none exists or it is closed) runs the
        coroutine via ``run_until_complete``. Tasks it schedules, such as
        timeout watchers, therefore stay on that loop and can progress during
        later sync calls.

        When a loop IS already running in this thread it cannot be re-entered,
        so the coroutine is executed in a worker thread on a fresh loop
        (``asyncio.run``). The calling thread blocks until it finishes, and
        background tasks created by it are cancelled when that loop closes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop in this thread: handled below.
        else:
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                return executor.submit(asyncio.run, coro).result()

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    @staticmethod
    async def _async_notify_error(
        callback: Optional[HumanLoopCallback],
        provider: HumanLoopProvider,
        error: Exception,
    ) -> None:
        """Best-effort delivery of *error* to ``callback.async_on_humanloop_error``.

        Exceptions raised by the error callback itself are ignored so they
        never mask the original error being reported.
        """
        if callback is None:
            return
        try:
            await callback.async_on_humanloop_error(provider, error)
        except Exception:
            pass

    def _resolve_provider(self, conversation_id: str, provider_id: Optional[str]):
        """Return ``(provider_id, provider)`` for an operation on a conversation.

        An explicitly passed *provider_id* wins; when it is ``None`` the
        provider recorded for the conversation is used, falling back to the
        default provider.

        Raises:
            ValueError: if no matching provider is registered.
        """
        if provider_id is None:
            provider_id = self._conversation_provider.get(conversation_id) or self.default_provider_id
        if not provider_id or provider_id not in self.providers:
            raise ValueError(f"Provider '{provider_id}' not found")
        return provider_id, self.providers[provider_id]

    def _provider_id_for(self, provider: HumanLoopProvider) -> str:
        """Return the id *provider* is registered under (identity lookup).

        Providers may be registered under an id that differs from their
        ``name``; falls back to ``provider.name`` if the instance is not found.
        """
        for registered_id, registered in self.providers.items():
            if registered is provider:
                return registered_id
        return provider.name

    # ------------------------------------------------------------------
    # Provider registration
    # ------------------------------------------------------------------

    def register_provider_sync(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
        """Register *provider* synchronously (also used during initialisation).

        Args:
            provider: Provider instance to register.
            provider_id: Id to register it under. When empty, a ``provider_N``
                id that is not already in use is generated. An existing
                provider registered under the same explicit id is replaced.

        Returns:
            The id the provider was registered under. The first provider ever
            registered also becomes the default provider.
        """
        if not provider_id:
            index = len(self.providers) + 1
            provider_id = f"provider_{index}"
            # Never silently overwrite a provider that already owns this id.
            while provider_id in self.providers:
                index += 1
                provider_id = f"provider_{index}"

        self.providers[provider_id] = provider

        if not self.default_provider_id:
            self.default_provider_id = provider_id

        return provider_id

    async def async_register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
        """Register a human-in-the-loop provider; returns the id used."""
        return self.register_provider_sync(provider, provider_id)

    def register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
        """Register a human-in-the-loop provider (synchronous); returns the id used."""
        return self.register_provider_sync(provider, provider_id)

    # ------------------------------------------------------------------
    # Requests
    # ------------------------------------------------------------------

    async def async_request_humanloop(
        self,
        task_id: str,
        conversation_id: str,
        loop_type: HumanLoopType,
        context: Dict[str, Any],
        callback: Optional[HumanLoopCallback] = None,
        metadata: Optional[Dict[str, Any]] = None,
        provider_id: Optional[str] = None,
        timeout: Optional[int] = None,
        blocking: bool = False,
    ) -> Union[str, HumanLoopResult]:
        """Start a new human-in-the-loop request.

        Args:
            task_id: Task the conversation belongs to.
            conversation_id: Conversation the request is part of.
            loop_type: Kind of human interaction requested.
            context: Context forwarded to the provider.
            callback: Optional callback for updates, timeouts and errors.
            metadata: Optional metadata forwarded to the provider.
            provider_id: Provider to use; defaults to the default provider.
            timeout: Optional timeout in seconds; starts a timeout watcher and
                bounds the wait in blocking mode.
            blocking: When True, poll until the request leaves PENDING (or the
                timeout elapses) and return the ``HumanLoopResult``.

        Returns:
            The request id, or the ``HumanLoopResult`` when *blocking*.

        Raises:
            ValueError: if the provider is unknown, the conversation is bound
                to a different provider, or the provider returns no request id.
        """
        provider_id = provider_id or self.default_provider_id
        if not provider_id or provider_id not in self.providers:
            raise ValueError(f"Provider '{provider_id}' not found")

        # A conversation stays bound to the provider it was started with.
        bound_provider_id = self._conversation_provider.get(conversation_id)
        if bound_provider_id is not None and bound_provider_id != provider_id:
            raise ValueError(f"Conversation '{conversation_id}' already exists with a different provider")

        provider = self.providers[provider_id]

        try:
            result = await provider.async_request_humanloop(
                task_id=task_id,
                conversation_id=conversation_id,
                loop_type=loop_type,
                context=context,
                metadata=metadata,
                timeout=timeout
            )

            request_id = result.request_id
            if not request_id:
                raise ValueError(f"Failed to request humanloop for conversation '{conversation_id}'")

            # Record task -> conversation -> request relationships.
            self._task_conversations.setdefault(task_id, set()).add(conversation_id)
            self._conversation_requests.setdefault(conversation_id, []).append(request_id)
            self._request_task[(conversation_id, request_id)] = task_id
            self._conversation_provider[conversation_id] = provider_id

            if callback:
                self._callbacks[(conversation_id, request_id)] = callback

            if timeout:
                await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)

            if blocking:
                return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
            return request_id
        except Exception as e:
            await self._async_notify_error(callback, provider, e)
            raise  # Let the caller see the original error.

    def request_humanloop(
        self,
        task_id: str,
        conversation_id: str,
        loop_type: HumanLoopType,
        context: Dict[str, Any],
        callback: Optional[HumanLoopCallback] = None,
        metadata: Optional[Dict[str, Any]] = None,
        provider_id: Optional[str] = None,
        timeout: Optional[int] = None,
        blocking: bool = False,
    ) -> Union[str, HumanLoopResult]:
        """Synchronous version of ``async_request_humanloop``."""
        return self._run_sync(
            self.async_request_humanloop(
                task_id=task_id,
                conversation_id=conversation_id,
                loop_type=loop_type,
                context=context,
                callback=callback,
                metadata=metadata,
                provider_id=provider_id,
                timeout=timeout,
                blocking=blocking
            )
        )

    async def async_continue_humanloop(
        self,
        conversation_id: str,
        context: Dict[str, Any],
        callback: Optional[HumanLoopCallback] = None,
        metadata: Optional[Dict[str, Any]] = None,
        provider_id: Optional[str] = None,
        timeout: Optional[int] = None,
        blocking: bool = False,
    ) -> Union[str, HumanLoopResult]:
        """Continue an existing human-in-the-loop conversation.

        The provider the conversation is already bound to is preferred; for
        an unknown conversation *provider_id* or the default provider is used.

        Returns:
            The new request id, or the ``HumanLoopResult`` when *blocking*.

        Raises:
            ValueError: if *provider_id* conflicts with the conversation's
                provider, the provider is unknown, or no request id is returned.
        """
        bound_provider_id = self._conversation_provider.get(conversation_id)
        if bound_provider_id is not None:
            if provider_id and provider_id != bound_provider_id:
                raise ValueError(f"Conversation '{conversation_id}' already exists with provider '{bound_provider_id}'")
            provider_id = bound_provider_id
        else:
            provider_id = provider_id or self.default_provider_id

        if not provider_id or provider_id not in self.providers:
            raise ValueError(f"Provider '{provider_id}' not found")

        provider = self.providers[provider_id]

        try:
            result = await provider.async_continue_humanloop(
                conversation_id=conversation_id,
                context=context,
                metadata=metadata,
                timeout=timeout,
            )

            request_id = result.request_id
            if not request_id:
                raise ValueError(f"Failed to continue humanloop for conversation '{conversation_id}'")

            self._conversation_requests.setdefault(conversation_id, []).append(request_id)

            # Attach the new request to the task owning this conversation, if any.
            task_id = next(
                (t_id for t_id, convs in self._task_conversations.items() if conversation_id in convs),
                None,
            )
            if task_id:
                self._request_task[(conversation_id, request_id)] = task_id

            # Bind the conversation to this provider only if not bound yet.
            self._conversation_provider.setdefault(conversation_id, provider_id)

            if callback:
                self._callbacks[(conversation_id, request_id)] = callback

            if timeout:
                await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)

            if blocking:
                return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
            return request_id
        except Exception as e:
            await self._async_notify_error(callback, provider, e)
            raise  # Let the caller see the original error.

    def continue_humanloop(
        self,
        conversation_id: str,
        context: Dict[str, Any],
        callback: Optional[HumanLoopCallback] = None,
        metadata: Optional[Dict[str, Any]] = None,
        provider_id: Optional[str] = None,
        timeout: Optional[int] = None,
        blocking: bool = False,
    ) -> Union[str, HumanLoopResult]:
        """Synchronous version of ``async_continue_humanloop``."""
        return self._run_sync(
            self.async_continue_humanloop(
                conversation_id=conversation_id,
                context=context,
                callback=callback,
                metadata=metadata,
                provider_id=provider_id,
                timeout=timeout,
                blocking=blocking
            )
        )

    # ------------------------------------------------------------------
    # Status checks
    # ------------------------------------------------------------------

    async def async_check_request_status(
        self,
        conversation_id: str,
        request_id: str,
        provider_id: Optional[str] = None
    ) -> HumanLoopResult:
        """Query the provider for the status of one request.

        If a callback is registered for the request and the status is not
        PENDING, the callback's update hook is invoked.

        Raises:
            ValueError: if no matching provider is registered.
        """
        _, provider = self._resolve_provider(conversation_id, provider_id)
        key = (conversation_id, request_id)

        try:
            result = await provider.async_check_request_status(conversation_id, request_id)

            if key in self._callbacks and result.status != HumanLoopStatus.PENDING:
                await self._async_trigger_update_callback(conversation_id, request_id, provider, result)

            return result
        except Exception as e:
            await self._async_notify_error(self._callbacks.get(key), provider, e)
            raise  # Let the caller see the original error.

    def check_request_status(
        self,
        conversation_id: str,
        request_id: str,
        provider_id: Optional[str] = None
    ) -> HumanLoopResult:
        """Synchronous version of ``async_check_request_status``."""
        return self._run_sync(
            self.async_check_request_status(
                conversation_id=conversation_id,
                request_id=request_id,
                provider_id=provider_id
            )
        )

    async def async_check_conversation_status(
        self,
        conversation_id: str,
        provider_id: Optional[str] = None
    ) -> HumanLoopResult:
        """Query the provider for the status of a conversation.

        On failure the error is reported to the callback of the
        conversation's most recent request (if any) and re-raised.

        Raises:
            ValueError: if no matching provider is registered.
        """
        _, provider = self._resolve_provider(conversation_id, provider_id)

        try:
            return await provider.async_check_conversation_status(conversation_id)
        except Exception as e:
            request_ids = self._conversation_requests.get(conversation_id)
            if request_ids:
                await self._async_notify_error(
                    self._callbacks.get((conversation_id, request_ids[-1])), provider, e
                )
            raise  # Let the caller see the original error.

    def check_conversation_status(
        self,
        conversation_id: str,
        provider_id: Optional[str] = None
    ) -> HumanLoopResult:
        """Synchronous version of ``async_check_conversation_status``."""
        return self._run_sync(
            self.async_check_conversation_status(
                conversation_id=conversation_id,
                provider_id=provider_id
            )
        )

    # ------------------------------------------------------------------
    # Cancellation
    # ------------------------------------------------------------------

    async def async_cancel_request(
        self,
        conversation_id: str,
        request_id: str,
        provider_id: Optional[str] = None
    ) -> bool:
        """Cancel a single request.

        Local bookkeeping (timeout watcher, callback, task mapping, request
        list entry) is cleared first, then the provider is asked to cancel.

        Returns:
            The provider's cancellation result.

        Raises:
            ValueError: if no matching provider is registered.
        """
        _, provider = self._resolve_provider(conversation_id, provider_id)
        key = (conversation_id, request_id)

        timeout_task = self._timeout_tasks.pop(key, None)
        if timeout_task is not None:
            timeout_task.cancel()

        self._callbacks.pop(key, None)
        self._request_task.pop(key, None)

        request_ids = self._conversation_requests.get(conversation_id)
        if request_ids and request_id in request_ids:
            request_ids.remove(request_id)

        return await provider.async_cancel_request(conversation_id, request_id)

    def cancel_request(
        self,
        conversation_id: str,
        request_id: str,
        provider_id: Optional[str] = None
    ) -> bool:
        """Synchronous version of ``async_cancel_request``."""
        return self._run_sync(
            self.async_cancel_request(
                conversation_id=conversation_id,
                request_id=request_id,
                provider_id=provider_id
            )
        )

    async def async_cancel_conversation(
        self,
        conversation_id: str,
        provider_id: Optional[str] = None
    ) -> bool:
        """Cancel a whole conversation and drop all local state for it.

        Returns:
            The provider's cancellation result.

        Raises:
            ValueError: if no matching provider is registered.
        """
        _, provider = self._resolve_provider(conversation_id, provider_id)

        # Stop every timeout watcher belonging to this conversation.
        for key in [k for k in self._timeout_tasks if k[0] == conversation_id]:
            self._timeout_tasks.pop(key).cancel()

        for key in [k for k in self._callbacks if k[0] == conversation_id]:
            del self._callbacks[key]

        # Detach the conversation from its task(s); drop tasks left empty.
        owning_tasks = [t_id for t_id, convs in self._task_conversations.items() if conversation_id in convs]
        for task_id in owning_tasks:
            convs = self._task_conversations[task_id]
            convs.discard(conversation_id)
            if not convs:
                del self._task_conversations[task_id]

        self._conversation_requests.pop(conversation_id, None)
        for key in [k for k in self._request_task if k[0] == conversation_id]:
            del self._request_task[key]

        self._conversation_provider.pop(conversation_id, None)

        return await provider.async_cancel_conversation(conversation_id)

    def cancel_conversation(
        self,
        conversation_id: str,
        provider_id: Optional[str] = None
    ) -> bool:
        """Synchronous version of ``async_cancel_conversation``."""
        return self._run_sync(
            self.async_cancel_conversation(
                conversation_id=conversation_id,
                provider_id=provider_id
            )
        )

    # ------------------------------------------------------------------
    # Provider access
    # ------------------------------------------------------------------

    async def async_get_provider(
        self,
        provider_id: Optional[str] = None
    ) -> HumanLoopProvider:
        """Return the provider registered as *provider_id* (default if omitted).

        Raises:
            ValueError: if no matching provider is registered.
        """
        provider_id = provider_id or self.default_provider_id
        if not provider_id or provider_id not in self.providers:
            raise ValueError(f"Provider '{provider_id}' not found")

        return self.providers[provider_id]

    def get_provider(
        self,
        provider_id: Optional[str] = None
    ) -> HumanLoopProvider:
        """Synchronous version of ``async_get_provider``."""
        return self._run_sync(self.async_get_provider(provider_id=provider_id))

    async def async_list_providers(self) -> Dict[str, HumanLoopProvider]:
        """Return a copy of the ``provider_id -> provider`` mapping."""
        return dict(self.providers)

    def list_providers(self) -> Dict[str, HumanLoopProvider]:
        """Synchronous version of ``async_list_providers``."""
        return self._run_sync(self.async_list_providers())

    async def async_set_default_provider(
        self,
        provider_id: str
    ) -> bool:
        """Make *provider_id* the default provider; returns True.

        Raises:
            ValueError: if the provider is not registered.
        """
        if provider_id not in self.providers:
            raise ValueError(f"Provider '{provider_id}' not found")

        self.default_provider_id = provider_id
        return True

    def set_default_provider(
        self,
        provider_id: str
    ) -> bool:
        """Synchronous version of ``async_set_default_provider``."""
        return self._run_sync(self.async_set_default_provider(provider_id=provider_id))

    # ------------------------------------------------------------------
    # Timeouts, waiting and callbacks
    # ------------------------------------------------------------------

    async def _async_create_timeout_task(
        self,
        conversation_id: str,
        request_id: str,
        timeout: int,
        provider: HumanLoopProvider,
        callback: Optional[HumanLoopCallback]
    ):
        """Start a background watcher that enforces *timeout* for a request.

        Every *timeout* seconds the request status is re-checked:

        * PENDING    -> the callback's timeout hook fires (if any); watcher stops.
        * INPROGRESS -> the conversation is active; wait another *timeout*.
        * otherwise  -> the request finished; watcher stops.

        The watcher removes its own ``_timeout_tasks`` entry when it ends.
        """
        key = (conversation_id, request_id)
        # Look up by registered id: provider.name may differ from it.
        provider_id = self._provider_id_for(provider)

        async def watch_timeout():
            try:
                while True:
                    await asyncio.sleep(timeout)
                    result = await self.async_check_request_status(conversation_id, request_id, provider_id)
                    if result.status == HumanLoopStatus.INPROGRESS:
                        continue  # Extend the watch while the human is engaged.
                    if result.status == HumanLoopStatus.PENDING and callback:
                        try:
                            await callback.async_on_humanloop_timeout(provider=provider)
                        except Exception as e:
                            await self._async_notify_error(callback, provider, e)
                    return
            except Exception:
                # Status-check failures were already reported to the error
                # callback by async_check_request_status; swallowing here just
                # avoids an unretrieved task exception.
                return
            finally:
                if self._timeout_tasks.get(key) is asyncio.current_task():
                    del self._timeout_tasks[key]

        # Never orphan an earlier watcher for the same request.
        previous = self._timeout_tasks.pop(key, None)
        if previous is not None:
            previous.cancel()
        self._timeout_tasks[key] = asyncio.create_task(watch_timeout())

    async def _async_wait_for_result(
        self,
        conversation_id: str,
        request_id: str,
        provider: HumanLoopProvider,
        timeout: Optional[int] = None,
        poll_interval: float = 1.0,
    ) -> HumanLoopResult:
        """Poll a request until its status leaves PENDING.

        Args:
            timeout: Optional bound in seconds. After ``timeout`` plus one
                *poll_interval* of grace (so the provider can publish its own
                expiry status first) the latest result is returned even if it
                is still PENDING. Without a timeout the wait is unbounded.
            poll_interval: Seconds between status checks.
        """
        provider_id = self._provider_id_for(provider)
        deadline = time.monotonic() + timeout + poll_interval if timeout else None

        while True:
            result = await self.async_check_request_status(conversation_id, request_id, provider_id)

            if result.status != HumanLoopStatus.PENDING:
                return result
            if deadline is not None and time.monotonic() >= deadline:
                return result

            await asyncio.sleep(poll_interval)

    async def _async_trigger_update_callback(self, conversation_id: str, request_id: str, provider: HumanLoopProvider, result: HumanLoopResult):
        """Deliver *result* to the request's callback, if one is registered.

        Once the status is final (neither PENDING nor INPROGRESS) the callback
        is unregistered. Errors raised by the update hook are forwarded to the
        callback's error hook.
        """
        key = (conversation_id, request_id)
        callback: Optional[HumanLoopCallback] = self._callbacks.get(key)
        if callback is None:
            return
        try:
            await callback.async_on_humanloop_update(provider, result)
            if result.status not in (HumanLoopStatus.PENDING, HumanLoopStatus.INPROGRESS):
                self._callbacks.pop(key, None)
        except Exception as e:
            await self._async_notify_error(callback, provider, e)

    # ------------------------------------------------------------------
    # Task / conversation bookkeeping queries
    # ------------------------------------------------------------------

    async def async_get_task_conversations(self, task_id: str) -> List[str]:
        """Return the ids of all conversations associated with *task_id*.

        Args:
            task_id: Task id.

        Returns:
            List[str]: Conversation ids (empty if the task is unknown).
        """
        return list(self._task_conversations.get(task_id, set()))

    def get_task_conversations(self, task_id: str) -> List[str]:
        """Return the ids of all conversations associated with *task_id*.

        Args:
            task_id: Task id.

        Returns:
            List[str]: Conversation ids (empty if the task is unknown).
        """
        return list(self._task_conversations.get(task_id, set()))

    async def async_get_conversation_requests(self, conversation_id: str) -> List[str]:
        """Return the ids of all requests in *conversation_id*, oldest first.

        Args:
            conversation_id: Conversation id.

        Returns:
            List[str]: A copy of the request ids (empty if unknown).
        """
        return list(self._conversation_requests.get(conversation_id, []))

    def get_conversation_requests(self, conversation_id: str) -> List[str]:
        """Return the ids of all requests in *conversation_id*, oldest first.

        Args:
            conversation_id: Conversation id.

        Returns:
            List[str]: A copy of the request ids (empty if unknown).
        """
        return list(self._conversation_requests.get(conversation_id, []))

    async def async_get_request_task(self, conversation_id: str, request_id: str) -> Optional[str]:
        """Return the task id a request belongs to.

        Args:
            conversation_id: Conversation id.
            request_id: Request id.

        Returns:
            Optional[str]: The task id, or None if unknown.
        """
        return self._request_task.get((conversation_id, request_id))

    async def async_get_conversation_provider(self, conversation_id: str) -> Optional[str]:
        """Return the provider id a conversation is bound to.

        Args:
            conversation_id: Conversation id.

        Returns:
            Optional[str]: The provider id, or None if unknown.
        """
        return self._conversation_provider.get(conversation_id)

    async def async_check_conversation_exist(
        self,
        task_id: str,
        conversation_id: str,
    ) -> bool:
        """Check whether a conversation exists for a task.

        Args:
            task_id: Task identifier.
            conversation_id: Conversation identifier.

        Returns:
            bool: True if the conversation is associated with *task_id* and
            has at least one recorded request, otherwise False.
        """
        return (
            conversation_id in self._task_conversations.get(task_id, set())
            and bool(self._conversation_requests.get(conversation_id))
        )

    def check_conversation_exist(
        self,
        task_id: str,
        conversation_id: str,
    ) -> bool:
        """Check whether a conversation exists for a task.

        Args:
            task_id: Task identifier.
            conversation_id: Conversation identifier.

        Returns:
            bool: True if the conversation is associated with *task_id* and
            has at least one recorded request, otherwise False.
        """
        return (
            conversation_id in self._task_conversations.get(task_id, set())
            and bool(self._conversation_requests.get(conversation_id))
        )

    async def async_shutdown(self):
        """Shut the manager down (currently a no-op)."""
        pass

    def shutdown(self):
        """Shut the manager down (currently a no-op)."""
        pass