gohumanloop 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,51 @@
1
- from typing import Dict, Any, Optional, List, Union, Set
1
+ from typing import Dict, Any, Optional, List, Union
2
2
  import asyncio
3
- import time
3
+ from gohumanloop.utils import run_async_safely
4
4
 
5
5
  from gohumanloop.core.interface import (
6
- HumanLoopManager, HumanLoopProvider, HumanLoopCallback,
7
- HumanLoopResult, HumanLoopStatus, HumanLoopType
6
+ HumanLoopManager,
7
+ HumanLoopProvider,
8
+ HumanLoopCallback,
9
+ HumanLoopResult,
10
+ HumanLoopStatus,
11
+ HumanLoopType,
8
12
  )
9
13
 
14
+
10
15
  class DefaultHumanLoopManager(HumanLoopManager):
11
16
  """默认人机循环管理器实现"""
12
-
13
- def __init__(self, initial_providers: Optional[Union[HumanLoopProvider, List[HumanLoopProvider]]] = None):
14
- self.providers = {}
17
+
18
+ def __init__(
19
+ self,
20
+ initial_providers: Optional[
21
+ Union[HumanLoopProvider, List[HumanLoopProvider]]
22
+ ] = None,
23
+ ):
24
+ self.providers: dict[str, HumanLoopProvider] = {}
15
25
  self.default_provider_id = None
16
-
26
+
27
+ # 存储请求和回调的映射
28
+ self._callbacks: dict[tuple[str, str], HumanLoopCallback] = {}
29
+ # 存储请求的超时任务
30
+ self._timeout_tasks: dict[tuple[str, str], asyncio.Task] = {}
31
+
32
+ # 存储task_id与conversation_id的映射关系
33
+ self._task_conversations: dict[
34
+ str, set[str]
35
+ ] = {} # task_id -> Set[conversation_id]
36
+ # 存储conversation_id与request_id的映射关系
37
+ self._conversation_requests: dict[
38
+ str, list[str]
39
+ ] = {} # conversation_id -> List[request_id]
40
+ # 存储request_id与task_id的反向映射
41
+ self._request_task: dict[
42
+ tuple[str, str], str
43
+ ] = {} # (conversation_id, request_id) -> task_id
44
+ # 存储对话对应的provider_id
45
+ self._conversation_provider: dict[
46
+ str, str
47
+ ] = {} # conversation_id -> provider_id
48
+
17
49
  # 初始化提供者
18
50
  if initial_providers:
19
51
  if isinstance(initial_providers, list):
@@ -26,41 +58,33 @@ class DefaultHumanLoopManager(HumanLoopManager):
26
58
  # 处理单个提供者
27
59
  self.register_provider_sync(initial_providers, initial_providers.name)
28
60
  self.default_provider_id = initial_providers.name
29
-
30
- # 存储请求和回调的映射
31
- self._callbacks = {}
32
- # 存储请求的超时任务
33
- self._timeout_tasks = {}
34
-
35
- # 存储task_id与conversation_id的映射关系
36
- self._task_conversations = {} # task_id -> Set[conversation_id]
37
- # 存储conversation_id与request_id的映射关系
38
- self._conversation_requests = {} # conversation_id -> List[request_id]
39
- # 存储request_id与task_id的反向映射
40
- self._request_task = {} # (conversation_id, request_id) -> task_id
41
- # 存储对话对应的provider_id
42
- self._conversation_provider = {} # conversation_id -> provider_id
43
-
44
- def register_provider_sync(self, provider: HumanLoopProvider, provider_id: Optional[str]) -> str:
61
+
62
+ def register_provider_sync(
63
+ self, provider: HumanLoopProvider, provider_id: Optional[str]
64
+ ) -> str:
45
65
  """同步注册提供者(用于初始化)"""
46
66
  if not provider_id:
47
67
  provider_id = f"provider_{len(self.providers) + 1}"
48
-
68
+
49
69
  self.providers[provider_id] = provider
50
-
70
+
51
71
  if not self.default_provider_id:
52
72
  self.default_provider_id = provider_id
53
-
73
+
54
74
  return provider_id
55
-
56
- async def async_register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
75
+
76
+ async def async_register_provider(
77
+ self, provider: HumanLoopProvider, provider_id: Optional[str] = None
78
+ ) -> str:
57
79
  """注册人机循环提供者"""
58
80
  return self.register_provider_sync(provider, provider_id)
59
-
60
- def register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
81
+
82
+ def register_provider(
83
+ self, provider: HumanLoopProvider, provider_id: Optional[str] = None
84
+ ) -> str:
61
85
  """注册人机循环提供者(同步版本)"""
62
86
  return self.register_provider_sync(provider, provider_id)
63
-
87
+
64
88
  async def async_request_humanloop(
65
89
  self,
66
90
  task_id: str,
@@ -78,13 +102,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
78
102
  provider_id = provider_id or self.default_provider_id
79
103
  if not provider_id or provider_id not in self.providers:
80
104
  raise ValueError(f"Provider '{provider_id}' not found")
81
-
105
+
82
106
  # 检查对话是否已存在且使用了不同的提供者
83
- if conversation_id in self._conversation_provider and self._conversation_provider[conversation_id] != provider_id:
84
- raise ValueError(f"Conversation '{conversation_id}' already exists with a different provider")
85
-
107
+ if (
108
+ conversation_id in self._conversation_provider
109
+ and self._conversation_provider[conversation_id] != provider_id
110
+ ):
111
+ raise ValueError(
112
+ f"Conversation '{conversation_id}' already exists with a different provider"
113
+ )
114
+
86
115
  provider = self.providers[provider_id]
87
-
116
+
88
117
  try:
89
118
  # 发送请求
90
119
  result = await provider.async_request_humanloop(
@@ -93,38 +122,44 @@ class DefaultHumanLoopManager(HumanLoopManager):
93
122
  loop_type=loop_type,
94
123
  context=context,
95
124
  metadata=metadata,
96
- timeout=timeout
125
+ timeout=timeout,
97
126
  )
98
-
127
+
99
128
  request_id = result.request_id
100
-
129
+
101
130
  if not request_id:
102
- raise ValueError(f"Failed to request humanloop for conversation '{conversation_id}'")
103
-
131
+ raise ValueError(
132
+ f"Failed to request humanloop for conversation '{conversation_id}'"
133
+ )
134
+
104
135
  # 存储task_id、conversation_id和request_id的关系
105
136
  if task_id not in self._task_conversations:
106
137
  self._task_conversations[task_id] = set()
107
138
  self._task_conversations[task_id].add(conversation_id)
108
-
139
+
109
140
  if conversation_id not in self._conversation_requests:
110
141
  self._conversation_requests[conversation_id] = []
111
142
  self._conversation_requests[conversation_id].append(request_id)
112
-
143
+
113
144
  self._request_task[(conversation_id, request_id)] = task_id
114
145
  # 存储对话对应的provider_id
115
146
  self._conversation_provider[conversation_id] = provider_id
116
-
147
+
117
148
  # 如果提供了回调,存储它
118
149
  if callback:
119
150
  self._callbacks[(conversation_id, request_id)] = callback
120
-
151
+
121
152
  # 如果设置了超时,创建超时任务
122
153
  if timeout:
123
- await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)
124
-
154
+ await self._async_create_timeout_task(
155
+ conversation_id, request_id, timeout, provider, callback
156
+ )
157
+
125
158
  # 如果是阻塞模式,等待结果
126
159
  if blocking:
127
- return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
160
+ return await self._async_wait_for_result(
161
+ conversation_id, request_id, provider, timeout
162
+ )
128
163
  else:
129
164
  return request_id
130
165
  except Exception as e:
@@ -132,7 +167,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
132
167
  if callback:
133
168
  try:
134
169
  await callback.async_on_humanloop_error(provider, e)
135
- except:
170
+ except Exception:
136
171
  # 如果错误回调也失败,只能忽略
137
172
  pass
138
173
  raise # 重新抛出异常,让调用者知道发生了错误
@@ -150,30 +185,22 @@ class DefaultHumanLoopManager(HumanLoopManager):
150
185
  blocking: bool = False,
151
186
  ) -> Union[str, HumanLoopResult]:
152
187
  """请求人机循环(同步版本)"""
153
- loop = asyncio.get_event_loop()
154
- if loop.is_running():
155
- # 如果事件循环已经在运行,创建一个新的事件循环
156
- new_loop = asyncio.new_event_loop()
157
- asyncio.set_event_loop(new_loop)
158
- loop = new_loop
159
-
160
- try:
161
- return loop.run_until_complete(
162
- self.async_request_humanloop(
163
- task_id=task_id,
164
- conversation_id=conversation_id,
165
- loop_type=loop_type,
166
- context=context,
167
- callback=callback,
168
- metadata=metadata,
169
- provider_id=provider_id,
170
- timeout=timeout,
171
- blocking=blocking
172
- )
188
+
189
+ result: Union[str, HumanLoopResult] = run_async_safely(
190
+ self.async_request_humanloop(
191
+ task_id=task_id,
192
+ conversation_id=conversation_id,
193
+ loop_type=loop_type,
194
+ context=context,
195
+ callback=callback,
196
+ metadata=metadata,
197
+ provider_id=provider_id,
198
+ timeout=timeout,
199
+ blocking=blocking,
173
200
  )
174
- finally:
175
- if loop != asyncio.get_event_loop():
176
- loop.close()
201
+ )
202
+
203
+ return result
177
204
 
178
205
  async def async_continue_humanloop(
179
206
  self,
@@ -190,16 +217,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
190
217
  if conversation_id in self._conversation_provider:
191
218
  stored_provider_id = self._conversation_provider[conversation_id]
192
219
  if provider_id and provider_id != stored_provider_id:
193
- raise ValueError(f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'")
220
+ raise ValueError(
221
+ f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'"
222
+ )
194
223
  provider_id = stored_provider_id
195
224
  else:
196
225
  provider_id = provider_id or self.default_provider_id
197
-
226
+
198
227
  if not provider_id or provider_id not in self.providers:
199
228
  raise ValueError(f"Provider '{provider_id}' not found")
200
-
229
+
201
230
  provider = self.providers[provider_id]
202
-
231
+
203
232
  try:
204
233
  # 发送继续请求
205
234
  result = await provider.async_continue_humanloop(
@@ -208,42 +237,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
208
237
  metadata=metadata,
209
238
  timeout=timeout,
210
239
  )
211
-
240
+
212
241
  request_id = result.request_id
213
242
 
214
243
  if not request_id:
215
- raise ValueError(f"Failed to continue humanloop for conversation '{conversation_id}'")
216
-
244
+ raise ValueError(
245
+ f"Failed to continue humanloop for conversation '{conversation_id}'"
246
+ )
247
+
217
248
  # 更新conversation_id和request_id的关系
218
249
  if conversation_id not in self._conversation_requests:
219
250
  self._conversation_requests[conversation_id] = []
220
251
  self._conversation_requests[conversation_id].append(request_id)
221
-
252
+
222
253
  # 查找此conversation_id对应的task_id
223
254
  task_id = None
224
255
  for t_id, convs in self._task_conversations.items():
225
256
  if conversation_id in convs:
226
257
  task_id = t_id
227
258
  break
228
-
259
+
229
260
  if task_id:
230
261
  self._request_task[(conversation_id, request_id)] = task_id
231
-
262
+
232
263
  # 存储对话对应的provider_id,如果对话不存在才存储
233
264
  if conversation_id not in self._conversation_provider:
234
265
  self._conversation_provider[conversation_id] = provider_id
235
-
266
+
236
267
  # 如果提供了回调,存储它
237
268
  if callback:
238
269
  self._callbacks[(conversation_id, request_id)] = callback
239
-
270
+
240
271
  # 如果设置了超时,创建超时任务
241
272
  if timeout:
242
- await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)
243
-
273
+ await self._async_create_timeout_task(
274
+ conversation_id, request_id, timeout, provider, callback
275
+ )
276
+
244
277
  # 如果是阻塞模式,等待结果
245
278
  if blocking:
246
- return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
279
+ return await self._async_wait_for_result(
280
+ conversation_id, request_id, provider, timeout
281
+ )
247
282
  else:
248
283
  return request_id
249
284
  except Exception as e:
@@ -251,7 +286,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
251
286
  if callback:
252
287
  try:
253
288
  await callback.async_on_humanloop_error(provider, e)
254
- except:
289
+ except Exception:
255
290
  # 如果错误回调也失败,只能忽略
256
291
  pass
257
292
  raise # 重新抛出异常,让调用者知道发生了错误
@@ -267,53 +302,49 @@ class DefaultHumanLoopManager(HumanLoopManager):
267
302
  blocking: bool = False,
268
303
  ) -> Union[str, HumanLoopResult]:
269
304
  """继续人机循环(同步版本)"""
270
- loop = asyncio.get_event_loop()
271
- if loop.is_running():
272
- # 如果事件循环已经在运行,创建一个新的事件循环
273
- new_loop = asyncio.new_event_loop()
274
- asyncio.set_event_loop(new_loop)
275
- loop = new_loop
276
-
277
- try:
278
- return loop.run_until_complete(
279
- self.async_continue_humanloop(
280
- conversation_id=conversation_id,
281
- context=context,
282
- callback=callback,
283
- metadata=metadata,
284
- provider_id=provider_id,
285
- timeout=timeout,
286
- blocking=blocking
287
- )
305
+
306
+ result: Union[str, HumanLoopResult] = run_async_safely(
307
+ self.async_continue_humanloop(
308
+ conversation_id=conversation_id,
309
+ context=context,
310
+ callback=callback,
311
+ metadata=metadata,
312
+ provider_id=provider_id,
313
+ timeout=timeout,
314
+ blocking=blocking,
288
315
  )
289
- finally:
290
- if loop != asyncio.get_event_loop():
291
- loop.close()
316
+ )
317
+
318
+ return result
292
319
 
293
320
  async def async_check_request_status(
294
- self,
295
- conversation_id: str,
296
- request_id: str,
297
- provider_id: Optional[str] = None
321
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
298
322
  ) -> HumanLoopResult:
299
323
  """检查请求状态"""
300
324
  # 如果没有指定provider_id,尝试从存储的映射中获取
301
325
  if provider_id is None:
302
326
  stored_provider_id = self._conversation_provider.get(conversation_id)
303
327
  provider_id = stored_provider_id or self.default_provider_id
304
-
328
+
305
329
  if not provider_id or provider_id not in self.providers:
306
330
  raise ValueError(f"Provider '{provider_id}' not found")
307
-
331
+
308
332
  provider = self.providers[provider_id]
309
-
333
+
310
334
  try:
311
- result = await provider.async_check_request_status(conversation_id, request_id)
312
-
335
+ result = await provider.async_check_request_status(
336
+ conversation_id, request_id
337
+ )
338
+
313
339
  # 如果有回调且状态不是等待或进行中,触发状态更新回调
314
- if (conversation_id, request_id) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
315
- await self._async_trigger_update_callback(conversation_id, request_id, provider, result)
316
-
340
+ if (
341
+ conversation_id,
342
+ request_id,
343
+ ) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
344
+ await self._async_trigger_update_callback(
345
+ conversation_id, request_id, provider, result
346
+ )
347
+
317
348
  return result
318
349
  except Exception as e:
319
350
  # 处理检查状态过程中的异常
@@ -321,350 +352,263 @@ class DefaultHumanLoopManager(HumanLoopManager):
321
352
  if callback:
322
353
  try:
323
354
  await callback.async_on_humanloop_error(provider, e)
324
- except:
355
+ except Exception:
325
356
  # 如果错误回调也失败,只能忽略
326
357
  pass
327
358
  raise # 重新抛出异常,让调用者知道发生了错误
328
359
 
329
-
330
360
  def check_request_status(
331
- self,
332
- conversation_id: str,
333
- request_id: str,
334
- provider_id: Optional[str] = None
361
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
335
362
  ) -> HumanLoopResult:
336
363
  """检查请求状态(同步版本)"""
337
- loop = asyncio.get_event_loop()
338
- if loop.is_running():
339
- # 如果事件循环已经在运行,创建一个新的事件循环
340
- new_loop = asyncio.new_event_loop()
341
- asyncio.set_event_loop(new_loop)
342
- loop = new_loop
343
-
344
- try:
345
- return loop.run_until_complete(
346
- self.async_check_request_status(
347
- conversation_id=conversation_id,
348
- request_id=request_id,
349
- provider_id=provider_id
350
- )
364
+
365
+ result: HumanLoopResult = run_async_safely(
366
+ self.async_check_request_status(
367
+ conversation_id=conversation_id,
368
+ request_id=request_id,
369
+ provider_id=provider_id,
351
370
  )
352
- finally:
353
- if loop != asyncio.get_event_loop():
354
- loop.close()
371
+ )
355
372
 
373
+ return result
356
374
 
357
375
  async def async_check_conversation_status(
358
- self,
359
- conversation_id: str,
360
- provider_id: Optional[str] = None
376
+ self, conversation_id: str, provider_id: Optional[str] = None
361
377
  ) -> HumanLoopResult:
362
378
  """检查对话状态"""
363
379
  # 优先使用对话已关联的提供者
364
380
  if provider_id is None:
365
381
  stored_provider_id = self._conversation_provider.get(conversation_id)
366
382
  provider_id = stored_provider_id or self.default_provider_id
367
-
383
+
368
384
  if not provider_id or provider_id not in self.providers:
369
385
  raise ValueError(f"Provider '{provider_id}' not found")
370
-
386
+
371
387
  # 检查对话指定provider_id或默认provider_id最后一次请求的状态
372
388
  provider = self.providers[provider_id]
373
-
389
+
374
390
  try:
375
391
  # 检查对话指定provider_id或默认provider_id最后一次请求的状态
376
392
  return await provider.async_check_conversation_status(conversation_id)
377
393
  except Exception as e:
378
394
  # 处理检查对话状态过程中的异常
379
395
  # 尝试找到与此对话关联的最后一个请求的回调
380
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
396
+ if (
397
+ conversation_id in self._conversation_requests
398
+ and self._conversation_requests[conversation_id]
399
+ ):
381
400
  last_request_id = self._conversation_requests[conversation_id][-1]
382
401
  callback = self._callbacks.get((conversation_id, last_request_id))
383
402
  if callback:
384
403
  try:
385
404
  await callback.async_on_humanloop_error(provider, e)
386
- except:
405
+ except Exception:
387
406
  # 如果错误回调也失败,只能忽略
388
407
  pass
389
408
  raise # 重新抛出异常,让调用者知道发生了错误
390
-
409
+
391
410
  def check_conversation_status(
392
- self,
393
- conversation_id: str,
394
- provider_id: Optional[str] = None
411
+ self, conversation_id: str, provider_id: Optional[str] = None
395
412
  ) -> HumanLoopResult:
396
413
  """检查对话状态(同步版本)"""
397
- loop = asyncio.get_event_loop()
398
- if loop.is_running():
399
- # 如果事件循环已经在运行,创建一个新的事件循环
400
- new_loop = asyncio.new_event_loop()
401
- asyncio.set_event_loop(new_loop)
402
- loop = new_loop
403
-
404
- try:
405
- return loop.run_until_complete(
406
- self.async_check_conversation_status(
407
- conversation_id=conversation_id,
408
- provider_id=provider_id
409
- )
414
+
415
+ result: HumanLoopResult = run_async_safely(
416
+ self.async_check_conversation_status(
417
+ conversation_id=conversation_id, provider_id=provider_id
410
418
  )
411
- finally:
412
- if loop != asyncio.get_event_loop():
413
- loop.close()
419
+ )
420
+
421
+ return result
414
422
 
415
423
  async def async_cancel_request(
416
- self,
417
- conversation_id: str,
418
- request_id: str,
419
- provider_id: Optional[str] = None
424
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
420
425
  ) -> bool:
421
426
  """取消特定请求"""
422
427
  if provider_id is None:
423
428
  stored_provider_id = self._conversation_provider.get(conversation_id)
424
429
  provider_id = stored_provider_id or self.default_provider_id
425
-
430
+
426
431
  if not provider_id or provider_id not in self.providers:
427
432
  raise ValueError(f"Provider '{provider_id}' not found")
428
-
433
+
429
434
  provider = self.providers[provider_id]
430
435
 
431
436
  # 取消超时任务
432
437
  if (conversation_id, request_id) in self._timeout_tasks:
433
438
  self._timeout_tasks[(conversation_id, request_id)].cancel()
434
439
  del self._timeout_tasks[(conversation_id, request_id)]
435
-
440
+
436
441
  # 从回调映射中删除
437
442
  if (conversation_id, request_id) in self._callbacks:
438
443
  del self._callbacks[(conversation_id, request_id)]
439
-
444
+
440
445
  # 清理request关联
441
446
  if (conversation_id, request_id) in self._request_task:
442
447
  del self._request_task[(conversation_id, request_id)]
443
-
448
+
444
449
  # 从conversation_requests中移除
445
450
  if conversation_id in self._conversation_requests:
446
451
  if request_id in self._conversation_requests[conversation_id]:
447
452
  self._conversation_requests[conversation_id].remove(request_id)
448
-
453
+
449
454
  return await provider.async_cancel_request(conversation_id, request_id)
450
-
451
455
 
452
456
  def cancel_request(
453
- self,
454
- conversation_id: str,
455
- request_id: str,
456
- provider_id: Optional[str] = None
457
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
457
458
  ) -> bool:
458
459
  """取消特定请求(同步版本)"""
459
- loop = asyncio.get_event_loop()
460
- if loop.is_running():
461
- # 如果事件循环已经在运行,创建一个新的事件循环
462
- new_loop = asyncio.new_event_loop()
463
- asyncio.set_event_loop(new_loop)
464
- loop = new_loop
465
-
466
- try:
467
- return loop.run_until_complete(
468
- self.async_cancel_request(
469
- conversation_id=conversation_id,
470
- request_id=request_id,
471
- provider_id=provider_id
472
- )
460
+
461
+ result: bool = run_async_safely(
462
+ self.async_cancel_request(
463
+ conversation_id=conversation_id,
464
+ request_id=request_id,
465
+ provider_id=provider_id,
473
466
  )
474
- finally:
475
- if loop != asyncio.get_event_loop():
476
- loop.close()
467
+ )
477
468
 
469
+ return result
478
470
 
479
471
  async def async_cancel_conversation(
480
- self,
481
- conversation_id: str,
482
- provider_id: Optional[str] = None
472
+ self, conversation_id: str, provider_id: Optional[str] = None
483
473
  ) -> bool:
484
474
  """取消整个对话"""
485
475
  # 优先使用对话已关联的提供者
486
476
  if provider_id is None:
487
477
  stored_provider_id = self._conversation_provider.get(conversation_id)
488
478
  provider_id = stored_provider_id or self.default_provider_id
489
-
479
+
490
480
  if not provider_id or provider_id not in self.providers:
491
481
  raise ValueError(f"Provider '{provider_id}' not found")
492
-
482
+
493
483
  provider = self.providers[provider_id]
494
-
484
+
495
485
  # 取消与此对话相关的所有超时任务和回调
496
486
  keys_to_remove = []
497
487
  for key in self._timeout_tasks:
498
488
  if key[0] == conversation_id:
499
489
  self._timeout_tasks[key].cancel()
500
490
  keys_to_remove.append(key)
501
-
491
+
502
492
  for key in keys_to_remove:
503
493
  del self._timeout_tasks[key]
504
-
494
+
505
495
  keys_to_remove = []
506
496
  for key in self._callbacks:
507
497
  if key[0] == conversation_id:
508
498
  keys_to_remove.append(key)
509
-
499
+
510
500
  for key in keys_to_remove:
511
501
  del self._callbacks[key]
512
-
502
+
513
503
  # 清理与此对话相关的task映射关系
514
504
  # 1. 从task_conversations中移除此对话
515
505
  task_ids_to_update = []
516
506
  for task_id, convs in self._task_conversations.items():
517
507
  if conversation_id in convs:
518
508
  task_ids_to_update.append(task_id)
519
-
509
+
520
510
  for task_id in task_ids_to_update:
521
511
  self._task_conversations[task_id].remove(conversation_id)
522
512
  # 如果task没有关联的对话了,可以考虑删除该task记录
523
513
  if not self._task_conversations[task_id]:
524
514
  del self._task_conversations[task_id]
525
-
515
+
526
516
  # 2. 获取并清理所有与此对话相关的请求
527
517
  request_ids = self._conversation_requests.get(conversation_id, [])
528
518
  for request_id in request_ids:
529
519
  # 清理request_task映射
530
520
  if (conversation_id, request_id) in self._request_task:
531
521
  del self._request_task[(conversation_id, request_id)]
532
-
522
+
533
523
  # 3. 清理conversation_requests映射
534
524
  if conversation_id in self._conversation_requests:
535
525
  del self._conversation_requests[conversation_id]
536
-
526
+
537
527
  # 4. 清理provider关联
538
528
  if conversation_id in self._conversation_provider:
539
529
  del self._conversation_provider[conversation_id]
540
-
530
+
541
531
  return await provider.async_cancel_conversation(conversation_id)
542
-
543
532
 
544
533
  def cancel_conversation(
545
- self,
546
- conversation_id: str,
547
- provider_id: Optional[str] = None
534
+ self, conversation_id: str, provider_id: Optional[str] = None
548
535
  ) -> bool:
549
536
  """取消整个对话(同步版本)"""
550
- loop = asyncio.get_event_loop()
551
- if loop.is_running():
552
- # 如果事件循环已经在运行,创建一个新的事件循环
553
- new_loop = asyncio.new_event_loop()
554
- asyncio.set_event_loop(new_loop)
555
- loop = new_loop
556
-
557
- try:
558
- return loop.run_until_complete(
559
- self.async_cancel_conversation(
560
- conversation_id=conversation_id,
561
- provider_id=provider_id
562
- )
537
+
538
+ result: bool = run_async_safely(
539
+ self.async_cancel_conversation(
540
+ conversation_id=conversation_id, provider_id=provider_id
563
541
  )
564
- finally:
565
- if loop != asyncio.get_event_loop():
566
- loop.close()
542
+ )
543
+
544
+ return result
567
545
 
568
-
569
546
  async def async_get_provider(
570
- self,
571
- provider_id: Optional[str] = None
547
+ self, provider_id: Optional[str] = None
572
548
  ) -> HumanLoopProvider:
573
549
  """获取指定的提供者实例"""
574
550
  provider_id = provider_id or self.default_provider_id
575
551
  if not provider_id or provider_id not in self.providers:
576
552
  raise ValueError(f"Provider '{provider_id}' not found")
577
-
553
+
578
554
  return self.providers[provider_id]
579
-
580
- def get_provider(
581
- self,
582
- provider_id: Optional[str] = None
583
- ) -> HumanLoopProvider:
555
+
556
+ def get_provider(self, provider_id: Optional[str] = None) -> HumanLoopProvider:
584
557
  """获取指定的提供者实例(同步版本)"""
585
- loop = asyncio.get_event_loop()
586
- if loop.is_running():
587
- # 如果事件循环已经在运行,创建一个新的事件循环
588
- new_loop = asyncio.new_event_loop()
589
- asyncio.set_event_loop(new_loop)
590
- loop = new_loop
591
-
592
- try:
593
- return loop.run_until_complete(
594
- self.async_get_provider(provider_id=provider_id)
595
- )
596
- finally:
597
- if loop != asyncio.get_event_loop():
598
- loop.close()
558
+
559
+ result: HumanLoopProvider = run_async_safely(
560
+ self.async_get_provider(provider_id=provider_id)
561
+ )
562
+
563
+ return result
599
564
 
600
565
  async def async_list_providers(self) -> Dict[str, HumanLoopProvider]:
601
566
  """列出所有注册的提供者"""
602
567
  return self.providers
603
-
604
568
 
605
569
  def list_providers(self) -> Dict[str, HumanLoopProvider]:
606
570
  """列出所有注册的提供者(同步版本)"""
607
- loop = asyncio.get_event_loop()
608
- if loop.is_running():
609
- # 如果事件循环已经在运行,创建一个新的事件循环
610
- new_loop = asyncio.new_event_loop()
611
- asyncio.set_event_loop(new_loop)
612
- loop = new_loop
613
-
614
- try:
615
- return loop.run_until_complete(self.async_list_providers())
616
- finally:
617
- if loop != asyncio.get_event_loop():
618
- loop.close()
619
571
 
572
+ result: Dict[str, HumanLoopProvider] = run_async_safely(
573
+ self.async_list_providers()
574
+ )
620
575
 
621
- async def async_set_default_provider(
622
- self,
623
- provider_id: str
624
- ) -> bool:
576
+ return result
577
+
578
+ async def async_set_default_provider(self, provider_id: str) -> bool:
625
579
  """设置默认提供者"""
626
580
  if provider_id not in self.providers:
627
581
  raise ValueError(f"Provider '{provider_id}' not found")
628
-
582
+
629
583
  self.default_provider_id = provider_id
630
584
  return True
631
-
632
585
 
633
- def set_default_provider(
634
- self,
635
- provider_id: str
636
- ) -> bool:
586
+ def set_default_provider(self, provider_id: str) -> bool:
637
587
  """设置默认提供者(同步版本)"""
638
- loop = asyncio.get_event_loop()
639
- if loop.is_running():
640
- # 如果事件循环已经在运行,创建一个新的事件循环
641
- new_loop = asyncio.new_event_loop()
642
- asyncio.set_event_loop(new_loop)
643
- loop = new_loop
644
-
645
- try:
646
- return loop.run_until_complete(
647
- self.async_set_default_provider(provider_id=provider_id)
648
- )
649
- finally:
650
- if loop != asyncio.get_event_loop():
651
- loop.close()
652
588
 
589
+ result: bool = run_async_safely(
590
+ self.async_set_default_provider(provider_id=provider_id)
591
+ )
592
+
593
+ return result
653
594
 
654
595
  async def _async_create_timeout_task(
655
- self,
596
+ self,
656
597
  conversation_id: str,
657
- request_id: str,
658
- timeout: int,
598
+ request_id: str,
599
+ timeout: int,
659
600
  provider: HumanLoopProvider,
660
- callback: Optional[HumanLoopCallback]
661
- ):
601
+ callback: Optional[HumanLoopCallback],
602
+ ) -> None:
662
603
  """创建超时任务"""
663
- async def timeout_task():
604
+
605
+ async def timeout_task() -> None:
664
606
  await asyncio.sleep(timeout)
665
607
  # 检查当前状态
666
- result = await self.async_check_request_status(conversation_id, request_id, provider.name)
667
-
608
+ result = await self.async_check_request_status(
609
+ conversation_id, request_id, provider.name
610
+ )
611
+
668
612
  # 只有当状态为PENDING时才触发超时回调
669
613
  # INPROGRESS状态表示对话正在进行中,不应视为超时
670
614
  if result.status == HumanLoopStatus.PENDING:
@@ -678,55 +622,68 @@ class DefaultHumanLoopManager(HumanLoopManager):
678
622
  self._timeout_tasks[(conversation_id, request_id)].cancel()
679
623
  new_task = asyncio.create_task(timeout_task())
680
624
  self._timeout_tasks[(conversation_id, request_id)] = new_task
681
-
625
+
682
626
  task = asyncio.create_task(timeout_task())
683
627
  self._timeout_tasks[(conversation_id, request_id)] = task
684
-
628
+
685
629
  async def _async_wait_for_result(
686
- self,
630
+ self,
687
631
  conversation_id: str,
688
- request_id: str,
689
- provider: HumanLoopProvider,
690
- timeout: Optional[int] = None
632
+ request_id: str,
633
+ provider: HumanLoopProvider,
634
+ timeout: Optional[int] = None,
691
635
  ) -> HumanLoopResult:
692
636
  """等待循环结果"""
693
- start_time = time.time()
694
637
  poll_interval = 1.0 # 轮询间隔(秒)
695
-
638
+
696
639
  while True:
697
- result = await self.async_check_request_status(conversation_id, request_id, provider.name)
698
-
699
- #如果状态是最终状态(非PENDING),返回结果
640
+ result = await self.async_check_request_status(
641
+ conversation_id, request_id, provider.name
642
+ )
643
+
644
+ # 如果状态是最终状态(非PENDING),返回结果
700
645
  if result.status != HumanLoopStatus.PENDING:
701
646
  return result
702
-
647
+
703
648
  # 等待一段时间后再次轮询
704
649
  await asyncio.sleep(poll_interval)
705
-
706
- async def _async_trigger_update_callback(self, conversation_id: str, request_id: str, provider: HumanLoopProvider, result: HumanLoopResult):
650
+
651
+ async def _async_trigger_update_callback(
652
+ self,
653
+ conversation_id: str,
654
+ request_id: str,
655
+ provider: HumanLoopProvider,
656
+ result: HumanLoopResult,
657
+ ) -> None:
707
658
  """触发状态更新回调"""
708
- callback: Optional[HumanLoopCallback] = self._callbacks.get((conversation_id, request_id))
659
+ callback: Optional[HumanLoopCallback] = self._callbacks.get(
660
+ (conversation_id, request_id)
661
+ )
709
662
  if callback:
710
663
  try:
711
- await callback.on_humanloop_update(provider, result)
664
+ await callback.async_on_humanloop_update(provider, result)
712
665
  # 如果状态是最终状态,可以考虑移除回调
713
- if result.status not in [HumanLoopStatus.PENDING, HumanLoopStatus.INPROGRESS]:
666
+ if result.status not in [
667
+ HumanLoopStatus.PENDING,
668
+ HumanLoopStatus.INPROGRESS,
669
+ ]:
714
670
  del self._callbacks[(conversation_id, request_id)]
715
671
  except Exception as e:
716
672
  # 处理回调执行过程中的异常
717
673
  try:
718
- await callback.on_humanloop_error(provider, e)
719
- except:
674
+ await callback.async_on_humanloop_error(provider, e)
675
+ except Exception:
720
676
  # 如果错误回调也失败,只能忽略
721
677
  pass
722
678
 
723
679
  # 添加新方法用于获取task相关信息
680
+
724
681
  async def async_get_task_conversations(self, task_id: str) -> List[str]:
725
682
  """获取任务关联的所有对话ID
726
-
683
+
727
684
  Args:
728
685
  task_id: 任务ID
729
-
686
+
730
687
  Returns:
731
688
  List[str]: 与任务关联的对话ID列表
732
689
  """
@@ -734,106 +691,129 @@ class DefaultHumanLoopManager(HumanLoopManager):
734
691
 
735
692
  def get_task_conversations(self, task_id: str) -> List[str]:
736
693
  """获取任务关联的所有对话ID
737
-
694
+
738
695
  Args:
739
696
  task_id: 任务ID
740
-
697
+
741
698
  Returns:
742
699
  List[str]: 与任务关联的对话ID列表
743
700
  """
744
701
  return list(self._task_conversations.get(task_id, set()))
745
-
702
+
746
703
  async def async_get_conversation_requests(self, conversation_id: str) -> List[str]:
747
704
  """获取对话关联的所有请求ID
748
-
705
+
749
706
  Args:
750
707
  conversation_id: 对话ID
751
-
708
+
752
709
  Returns:
753
710
  List[str]: 与对话关联的请求ID列表
754
711
  """
755
- return self._conversation_requests.get(conversation_id, [])
712
+ ret: List[str] = self._conversation_requests.get(conversation_id, [])
713
+ return ret
756
714
 
757
715
  def get_conversation_requests(self, conversation_id: str) -> List[str]:
758
716
  """获取对话关联的所有请求ID
759
-
717
+
760
718
  Args:
761
719
  conversation_id: 对话ID
762
-
720
+
763
721
  Returns:
764
722
  List[str]: 与对话关联的请求ID列表
765
723
  """
766
- return self._conversation_requests.get(conversation_id, [])
767
-
768
- async def async_get_request_task(self, conversation_id: str, request_id: str) -> Optional[str]:
724
+ ret: List[str] = self._conversation_requests.get(conversation_id, [])
725
+
726
+ return ret
727
+
728
+ async def async_get_request_task(
729
+ self, conversation_id: str, request_id: str
730
+ ) -> Optional[str]:
769
731
  """获取请求关联的任务ID
770
-
732
+
771
733
  Args:
772
734
  conversation_id: 对话ID
773
735
  request_id: 请求ID
774
-
736
+
775
737
  Returns:
776
738
  Optional[str]: 关联的任务ID,如果不存在则返回None
777
739
  """
778
- return self._request_task.get((conversation_id, request_id))
740
+ ret: Optional[str] = self._request_task.get((conversation_id, request_id))
741
+
742
+ return ret
779
743
 
780
- async def async_get_conversation_provider(self, conversation_id: str) -> Optional[str]:
744
+ async def async_get_conversation_provider(
745
+ self, conversation_id: str
746
+ ) -> Optional[str]:
781
747
  """获取请求关联的提供者ID
782
-
748
+
783
749
  Args:
784
750
  conversation_id: 对话ID
785
-
751
+
786
752
  Returns:
787
753
  Optional[str]: 关联的提供者ID,如果不存在则返回None
788
754
  """
789
- return self._conversation_provider.get(conversation_id)
755
+ ret: Optional[str] = self._conversation_provider.get(conversation_id)
756
+
757
+ return ret
790
758
 
791
759
  async def async_check_conversation_exist(
792
760
  self,
793
- task_id:str,
761
+ task_id: str,
794
762
  conversation_id: str,
795
763
  ) -> bool:
796
764
  """判断对话是否已存在
797
-
765
+
798
766
  Args:
799
767
  conversation_id: 对话标识符
800
768
  provider_id: 使用特定提供者的ID(可选)
801
-
769
+
802
770
  Returns:
803
771
  bool: 如果对话存在返回True,否则返回False
804
772
  """
805
773
  # 检查task_id是否存在且conversation_id是否在该task的对话集合中
806
- if task_id in self._task_conversations and conversation_id in self._task_conversations[task_id]:
774
+ if (
775
+ task_id in self._task_conversations
776
+ and conversation_id in self._task_conversations[task_id]
777
+ ):
807
778
  # 进一步验证该对话是否有关联的请求
808
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
779
+ if (
780
+ conversation_id in self._conversation_requests
781
+ and self._conversation_requests[conversation_id]
782
+ ):
809
783
  return True
810
784
 
811
785
  return False
812
786
 
813
787
  def check_conversation_exist(
814
788
  self,
815
- task_id:str,
789
+ task_id: str,
816
790
  conversation_id: str,
817
791
  ) -> bool:
818
792
  """判断对话是否已存在
819
-
793
+
820
794
  Args:
821
795
  conversation_id: 对话标识符
822
796
  provider_id: 使用特定提供者的ID(可选)
823
-
797
+
824
798
  Returns:
825
799
  bool: 如果对话存在返回True,否则返回False
826
800
  """
827
801
  # 检查task_id是否存在且conversation_id是否在该task的对话集合中
828
- if task_id in self._task_conversations and conversation_id in self._task_conversations[task_id]:
802
+ if (
803
+ task_id in self._task_conversations
804
+ and conversation_id in self._task_conversations[task_id]
805
+ ):
829
806
  # 进一步验证该对话是否有关联的请求
830
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
807
+ if (
808
+ conversation_id in self._conversation_requests
809
+ and self._conversation_requests[conversation_id]
810
+ ):
831
811
  return True
832
812
 
833
813
  return False
834
814
 
835
- async def async_shutdown(self):
836
- pass
815
+ async def async_shutdown(self) -> None:
816
+ pass
837
817
 
838
- def shutdown(self):
839
- pass
818
+ def shutdown(self) -> None:
819
+ pass