gohumanloop 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,51 @@
1
- from typing import Dict, Any, Optional, List, Union, Set
1
+ from typing import Dict, Any, Optional, List, Union
2
2
  import asyncio
3
- import time
4
3
  from gohumanloop.utils import run_async_safely
5
4
 
6
5
  from gohumanloop.core.interface import (
7
- HumanLoopManager, HumanLoopProvider, HumanLoopCallback,
8
- HumanLoopResult, HumanLoopStatus, HumanLoopType
6
+ HumanLoopManager,
7
+ HumanLoopProvider,
8
+ HumanLoopCallback,
9
+ HumanLoopResult,
10
+ HumanLoopStatus,
11
+ HumanLoopType,
9
12
  )
10
13
 
14
+
11
15
  class DefaultHumanLoopManager(HumanLoopManager):
12
16
  """默认人机循环管理器实现"""
13
-
14
- def __init__(self, initial_providers: Optional[Union[HumanLoopProvider, List[HumanLoopProvider]]] = None):
15
- self.providers = {}
17
+
18
+ def __init__(
19
+ self,
20
+ initial_providers: Optional[
21
+ Union[HumanLoopProvider, List[HumanLoopProvider]]
22
+ ] = None,
23
+ ):
24
+ self.providers: dict[str, HumanLoopProvider] = {}
16
25
  self.default_provider_id = None
17
-
26
+
27
+ # 存储请求和回调的映射
28
+ self._callbacks: dict[tuple[str, str], HumanLoopCallback] = {}
29
+ # 存储请求的超时任务
30
+ self._timeout_tasks: dict[tuple[str, str], asyncio.Task] = {}
31
+
32
+ # 存储task_id与conversation_id的映射关系
33
+ self._task_conversations: dict[
34
+ str, set[str]
35
+ ] = {} # task_id -> Set[conversation_id]
36
+ # 存储conversation_id与request_id的映射关系
37
+ self._conversation_requests: dict[
38
+ str, list[str]
39
+ ] = {} # conversation_id -> List[request_id]
40
+ # 存储request_id与task_id的反向映射
41
+ self._request_task: dict[
42
+ tuple[str, str], str
43
+ ] = {} # (conversation_id, request_id) -> task_id
44
+ # 存储对话对应的provider_id
45
+ self._conversation_provider: dict[
46
+ str, str
47
+ ] = {} # conversation_id -> provider_id
48
+
18
49
  # 初始化提供者
19
50
  if initial_providers:
20
51
  if isinstance(initial_providers, list):
@@ -27,41 +58,33 @@ class DefaultHumanLoopManager(HumanLoopManager):
27
58
  # 处理单个提供者
28
59
  self.register_provider_sync(initial_providers, initial_providers.name)
29
60
  self.default_provider_id = initial_providers.name
30
-
31
- # 存储请求和回调的映射
32
- self._callbacks = {}
33
- # 存储请求的超时任务
34
- self._timeout_tasks = {}
35
-
36
- # 存储task_id与conversation_id的映射关系
37
- self._task_conversations = {} # task_id -> Set[conversation_id]
38
- # 存储conversation_id与request_id的映射关系
39
- self._conversation_requests = {} # conversation_id -> List[request_id]
40
- # 存储request_id与task_id的反向映射
41
- self._request_task = {} # (conversation_id, request_id) -> task_id
42
- # 存储对话对应的provider_id
43
- self._conversation_provider = {} # conversation_id -> provider_id
44
-
45
- def register_provider_sync(self, provider: HumanLoopProvider, provider_id: Optional[str]) -> str:
61
+
62
+ def register_provider_sync(
63
+ self, provider: HumanLoopProvider, provider_id: Optional[str]
64
+ ) -> str:
46
65
  """同步注册提供者(用于初始化)"""
47
66
  if not provider_id:
48
67
  provider_id = f"provider_{len(self.providers) + 1}"
49
-
68
+
50
69
  self.providers[provider_id] = provider
51
-
70
+
52
71
  if not self.default_provider_id:
53
72
  self.default_provider_id = provider_id
54
-
73
+
55
74
  return provider_id
56
-
57
- async def async_register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
75
+
76
+ async def async_register_provider(
77
+ self, provider: HumanLoopProvider, provider_id: Optional[str] = None
78
+ ) -> str:
58
79
  """注册人机循环提供者"""
59
80
  return self.register_provider_sync(provider, provider_id)
60
-
61
- def register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
81
+
82
+ def register_provider(
83
+ self, provider: HumanLoopProvider, provider_id: Optional[str] = None
84
+ ) -> str:
62
85
  """注册人机循环提供者(同步版本)"""
63
86
  return self.register_provider_sync(provider, provider_id)
64
-
87
+
65
88
  async def async_request_humanloop(
66
89
  self,
67
90
  task_id: str,
@@ -79,13 +102,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
79
102
  provider_id = provider_id or self.default_provider_id
80
103
  if not provider_id or provider_id not in self.providers:
81
104
  raise ValueError(f"Provider '{provider_id}' not found")
82
-
105
+
83
106
  # 检查对话是否已存在且使用了不同的提供者
84
- if conversation_id in self._conversation_provider and self._conversation_provider[conversation_id] != provider_id:
85
- raise ValueError(f"Conversation '{conversation_id}' already exists with a different provider")
86
-
107
+ if (
108
+ conversation_id in self._conversation_provider
109
+ and self._conversation_provider[conversation_id] != provider_id
110
+ ):
111
+ raise ValueError(
112
+ f"Conversation '{conversation_id}' already exists with a different provider"
113
+ )
114
+
87
115
  provider = self.providers[provider_id]
88
-
116
+
89
117
  try:
90
118
  # 发送请求
91
119
  result = await provider.async_request_humanloop(
@@ -94,38 +122,44 @@ class DefaultHumanLoopManager(HumanLoopManager):
94
122
  loop_type=loop_type,
95
123
  context=context,
96
124
  metadata=metadata,
97
- timeout=timeout
125
+ timeout=timeout,
98
126
  )
99
-
127
+
100
128
  request_id = result.request_id
101
-
129
+
102
130
  if not request_id:
103
- raise ValueError(f"Failed to request humanloop for conversation '{conversation_id}'")
104
-
131
+ raise ValueError(
132
+ f"Failed to request humanloop for conversation '{conversation_id}'"
133
+ )
134
+
105
135
  # 存储task_id、conversation_id和request_id的关系
106
136
  if task_id not in self._task_conversations:
107
137
  self._task_conversations[task_id] = set()
108
138
  self._task_conversations[task_id].add(conversation_id)
109
-
139
+
110
140
  if conversation_id not in self._conversation_requests:
111
141
  self._conversation_requests[conversation_id] = []
112
142
  self._conversation_requests[conversation_id].append(request_id)
113
-
143
+
114
144
  self._request_task[(conversation_id, request_id)] = task_id
115
145
  # 存储对话对应的provider_id
116
146
  self._conversation_provider[conversation_id] = provider_id
117
-
147
+
118
148
  # 如果提供了回调,存储它
119
149
  if callback:
120
150
  self._callbacks[(conversation_id, request_id)] = callback
121
-
151
+
122
152
  # 如果设置了超时,创建超时任务
123
153
  if timeout:
124
- await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)
125
-
154
+ await self._async_create_timeout_task(
155
+ conversation_id, request_id, timeout, provider, callback
156
+ )
157
+
126
158
  # 如果是阻塞模式,等待结果
127
159
  if blocking:
128
- return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
160
+ return await self._async_wait_for_result(
161
+ conversation_id, request_id, provider, timeout
162
+ )
129
163
  else:
130
164
  return request_id
131
165
  except Exception as e:
@@ -133,7 +167,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
133
167
  if callback:
134
168
  try:
135
169
  await callback.async_on_humanloop_error(provider, e)
136
- except:
170
+ except Exception:
137
171
  # 如果错误回调也失败,只能忽略
138
172
  pass
139
173
  raise # 重新抛出异常,让调用者知道发生了错误
@@ -152,20 +186,21 @@ class DefaultHumanLoopManager(HumanLoopManager):
152
186
  ) -> Union[str, HumanLoopResult]:
153
187
  """请求人机循环(同步版本)"""
154
188
 
155
- return run_async_safely(
156
- self.async_request_humanloop(
157
- task_id=task_id,
158
- conversation_id=conversation_id,
159
- loop_type=loop_type,
160
- context=context,
161
- callback=callback,
162
- metadata=metadata,
163
- provider_id=provider_id,
164
- timeout=timeout,
165
- blocking=blocking
166
- )
189
+ result: Union[str, HumanLoopResult] = run_async_safely(
190
+ self.async_request_humanloop(
191
+ task_id=task_id,
192
+ conversation_id=conversation_id,
193
+ loop_type=loop_type,
194
+ context=context,
195
+ callback=callback,
196
+ metadata=metadata,
197
+ provider_id=provider_id,
198
+ timeout=timeout,
199
+ blocking=blocking,
167
200
  )
201
+ )
168
202
 
203
+ return result
169
204
 
170
205
  async def async_continue_humanloop(
171
206
  self,
@@ -182,16 +217,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
182
217
  if conversation_id in self._conversation_provider:
183
218
  stored_provider_id = self._conversation_provider[conversation_id]
184
219
  if provider_id and provider_id != stored_provider_id:
185
- raise ValueError(f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'")
220
+ raise ValueError(
221
+ f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'"
222
+ )
186
223
  provider_id = stored_provider_id
187
224
  else:
188
225
  provider_id = provider_id or self.default_provider_id
189
-
226
+
190
227
  if not provider_id or provider_id not in self.providers:
191
228
  raise ValueError(f"Provider '{provider_id}' not found")
192
-
229
+
193
230
  provider = self.providers[provider_id]
194
-
231
+
195
232
  try:
196
233
  # 发送继续请求
197
234
  result = await provider.async_continue_humanloop(
@@ -200,42 +237,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
200
237
  metadata=metadata,
201
238
  timeout=timeout,
202
239
  )
203
-
240
+
204
241
  request_id = result.request_id
205
242
 
206
243
  if not request_id:
207
- raise ValueError(f"Failed to continue humanloop for conversation '{conversation_id}'")
208
-
244
+ raise ValueError(
245
+ f"Failed to continue humanloop for conversation '{conversation_id}'"
246
+ )
247
+
209
248
  # 更新conversation_id和request_id的关系
210
249
  if conversation_id not in self._conversation_requests:
211
250
  self._conversation_requests[conversation_id] = []
212
251
  self._conversation_requests[conversation_id].append(request_id)
213
-
252
+
214
253
  # 查找此conversation_id对应的task_id
215
254
  task_id = None
216
255
  for t_id, convs in self._task_conversations.items():
217
256
  if conversation_id in convs:
218
257
  task_id = t_id
219
258
  break
220
-
259
+
221
260
  if task_id:
222
261
  self._request_task[(conversation_id, request_id)] = task_id
223
-
262
+
224
263
  # 存储对话对应的provider_id,如果对话不存在才存储
225
264
  if conversation_id not in self._conversation_provider:
226
265
  self._conversation_provider[conversation_id] = provider_id
227
-
266
+
228
267
  # 如果提供了回调,存储它
229
268
  if callback:
230
269
  self._callbacks[(conversation_id, request_id)] = callback
231
-
270
+
232
271
  # 如果设置了超时,创建超时任务
233
272
  if timeout:
234
- await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)
235
-
273
+ await self._async_create_timeout_task(
274
+ conversation_id, request_id, timeout, provider, callback
275
+ )
276
+
236
277
  # 如果是阻塞模式,等待结果
237
278
  if blocking:
238
- return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
279
+ return await self._async_wait_for_result(
280
+ conversation_id, request_id, provider, timeout
281
+ )
239
282
  else:
240
283
  return request_id
241
284
  except Exception as e:
@@ -243,7 +286,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
243
286
  if callback:
244
287
  try:
245
288
  await callback.async_on_humanloop_error(provider, e)
246
- except:
289
+ except Exception:
247
290
  # 如果错误回调也失败,只能忽略
248
291
  pass
249
292
  raise # 重新抛出异常,让调用者知道发生了错误
@@ -260,42 +303,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
260
303
  ) -> Union[str, HumanLoopResult]:
261
304
  """继续人机循环(同步版本)"""
262
305
 
263
- return run_async_safely(
264
- self.async_continue_humanloop(
265
- conversation_id=conversation_id,
266
- context=context,
267
- callback=callback,
268
- metadata=metadata,
269
- provider_id=provider_id,
270
- timeout=timeout,
271
- blocking=blocking
272
- )
306
+ result: Union[str, HumanLoopResult] = run_async_safely(
307
+ self.async_continue_humanloop(
308
+ conversation_id=conversation_id,
309
+ context=context,
310
+ callback=callback,
311
+ metadata=metadata,
312
+ provider_id=provider_id,
313
+ timeout=timeout,
314
+ blocking=blocking,
273
315
  )
316
+ )
317
+
318
+ return result
274
319
 
275
320
  async def async_check_request_status(
276
- self,
277
- conversation_id: str,
278
- request_id: str,
279
- provider_id: Optional[str] = None
321
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
280
322
  ) -> HumanLoopResult:
281
323
  """检查请求状态"""
282
324
  # 如果没有指定provider_id,尝试从存储的映射中获取
283
325
  if provider_id is None:
284
326
  stored_provider_id = self._conversation_provider.get(conversation_id)
285
327
  provider_id = stored_provider_id or self.default_provider_id
286
-
328
+
287
329
  if not provider_id or provider_id not in self.providers:
288
330
  raise ValueError(f"Provider '{provider_id}' not found")
289
-
331
+
290
332
  provider = self.providers[provider_id]
291
-
333
+
292
334
  try:
293
- result = await provider.async_check_request_status(conversation_id, request_id)
294
-
335
+ result = await provider.async_check_request_status(
336
+ conversation_id, request_id
337
+ )
338
+
295
339
  # 如果有回调且状态不是等待或进行中,触发状态更新回调
296
- if (conversation_id, request_id) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
297
- await self._async_trigger_update_callback(conversation_id, request_id, provider, result)
298
-
340
+ if (
341
+ conversation_id,
342
+ request_id,
343
+ ) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
344
+ await self._async_trigger_update_callback(
345
+ conversation_id, request_id, provider, result
346
+ )
347
+
299
348
  return result
300
349
  except Exception as e:
301
350
  # 处理检查状态过程中的异常
@@ -303,281 +352,263 @@ class DefaultHumanLoopManager(HumanLoopManager):
303
352
  if callback:
304
353
  try:
305
354
  await callback.async_on_humanloop_error(provider, e)
306
- except:
355
+ except Exception:
307
356
  # 如果错误回调也失败,只能忽略
308
357
  pass
309
358
  raise # 重新抛出异常,让调用者知道发生了错误
310
359
 
311
-
312
360
  def check_request_status(
313
- self,
314
- conversation_id: str,
315
- request_id: str,
316
- provider_id: Optional[str] = None
361
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
317
362
  ) -> HumanLoopResult:
318
363
  """检查请求状态(同步版本)"""
319
364
 
320
- return run_async_safely(
321
- self.async_check_request_status(
322
- conversation_id=conversation_id,
323
- request_id=request_id,
324
- provider_id=provider_id
325
- )
365
+ result: HumanLoopResult = run_async_safely(
366
+ self.async_check_request_status(
367
+ conversation_id=conversation_id,
368
+ request_id=request_id,
369
+ provider_id=provider_id,
326
370
  )
371
+ )
327
372
 
373
+ return result
328
374
 
329
375
  async def async_check_conversation_status(
330
- self,
331
- conversation_id: str,
332
- provider_id: Optional[str] = None
376
+ self, conversation_id: str, provider_id: Optional[str] = None
333
377
  ) -> HumanLoopResult:
334
378
  """检查对话状态"""
335
379
  # 优先使用对话已关联的提供者
336
380
  if provider_id is None:
337
381
  stored_provider_id = self._conversation_provider.get(conversation_id)
338
382
  provider_id = stored_provider_id or self.default_provider_id
339
-
383
+
340
384
  if not provider_id or provider_id not in self.providers:
341
385
  raise ValueError(f"Provider '{provider_id}' not found")
342
-
386
+
343
387
  # 检查对话指定provider_id或默认provider_id最后一次请求的状态
344
388
  provider = self.providers[provider_id]
345
-
389
+
346
390
  try:
347
391
  # 检查对话指定provider_id或默认provider_id最后一次请求的状态
348
392
  return await provider.async_check_conversation_status(conversation_id)
349
393
  except Exception as e:
350
394
  # 处理检查对话状态过程中的异常
351
395
  # 尝试找到与此对话关联的最后一个请求的回调
352
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
396
+ if (
397
+ conversation_id in self._conversation_requests
398
+ and self._conversation_requests[conversation_id]
399
+ ):
353
400
  last_request_id = self._conversation_requests[conversation_id][-1]
354
401
  callback = self._callbacks.get((conversation_id, last_request_id))
355
402
  if callback:
356
403
  try:
357
404
  await callback.async_on_humanloop_error(provider, e)
358
- except:
405
+ except Exception:
359
406
  # 如果错误回调也失败,只能忽略
360
407
  pass
361
408
  raise # 重新抛出异常,让调用者知道发生了错误
362
-
409
+
363
410
  def check_conversation_status(
364
- self,
365
- conversation_id: str,
366
- provider_id: Optional[str] = None
411
+ self, conversation_id: str, provider_id: Optional[str] = None
367
412
  ) -> HumanLoopResult:
368
413
  """检查对话状态(同步版本)"""
369
414
 
370
- return run_async_safely(
371
- self.async_check_conversation_status(
372
- conversation_id=conversation_id,
373
- provider_id=provider_id
374
- )
415
+ result: HumanLoopResult = run_async_safely(
416
+ self.async_check_conversation_status(
417
+ conversation_id=conversation_id, provider_id=provider_id
375
418
  )
419
+ )
420
+
421
+ return result
376
422
 
377
423
  async def async_cancel_request(
378
- self,
379
- conversation_id: str,
380
- request_id: str,
381
- provider_id: Optional[str] = None
424
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
382
425
  ) -> bool:
383
426
  """取消特定请求"""
384
427
  if provider_id is None:
385
428
  stored_provider_id = self._conversation_provider.get(conversation_id)
386
429
  provider_id = stored_provider_id or self.default_provider_id
387
-
430
+
388
431
  if not provider_id or provider_id not in self.providers:
389
432
  raise ValueError(f"Provider '{provider_id}' not found")
390
-
433
+
391
434
  provider = self.providers[provider_id]
392
435
 
393
436
  # 取消超时任务
394
437
  if (conversation_id, request_id) in self._timeout_tasks:
395
438
  self._timeout_tasks[(conversation_id, request_id)].cancel()
396
439
  del self._timeout_tasks[(conversation_id, request_id)]
397
-
440
+
398
441
  # 从回调映射中删除
399
442
  if (conversation_id, request_id) in self._callbacks:
400
443
  del self._callbacks[(conversation_id, request_id)]
401
-
444
+
402
445
  # 清理request关联
403
446
  if (conversation_id, request_id) in self._request_task:
404
447
  del self._request_task[(conversation_id, request_id)]
405
-
448
+
406
449
  # 从conversation_requests中移除
407
450
  if conversation_id in self._conversation_requests:
408
451
  if request_id in self._conversation_requests[conversation_id]:
409
452
  self._conversation_requests[conversation_id].remove(request_id)
410
-
453
+
411
454
  return await provider.async_cancel_request(conversation_id, request_id)
412
-
413
455
 
414
456
  def cancel_request(
415
- self,
416
- conversation_id: str,
417
- request_id: str,
418
- provider_id: Optional[str] = None
457
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
419
458
  ) -> bool:
420
459
  """取消特定请求(同步版本)"""
421
460
 
422
- return run_async_safely(
423
- self.async_cancel_request(
424
- conversation_id=conversation_id,
425
- request_id=request_id,
426
- provider_id=provider_id
427
- )
461
+ result: bool = run_async_safely(
462
+ self.async_cancel_request(
463
+ conversation_id=conversation_id,
464
+ request_id=request_id,
465
+ provider_id=provider_id,
428
466
  )
467
+ )
468
+
469
+ return result
429
470
 
430
471
  async def async_cancel_conversation(
431
- self,
432
- conversation_id: str,
433
- provider_id: Optional[str] = None
472
+ self, conversation_id: str, provider_id: Optional[str] = None
434
473
  ) -> bool:
435
474
  """取消整个对话"""
436
475
  # 优先使用对话已关联的提供者
437
476
  if provider_id is None:
438
477
  stored_provider_id = self._conversation_provider.get(conversation_id)
439
478
  provider_id = stored_provider_id or self.default_provider_id
440
-
479
+
441
480
  if not provider_id or provider_id not in self.providers:
442
481
  raise ValueError(f"Provider '{provider_id}' not found")
443
-
482
+
444
483
  provider = self.providers[provider_id]
445
-
484
+
446
485
  # 取消与此对话相关的所有超时任务和回调
447
486
  keys_to_remove = []
448
487
  for key in self._timeout_tasks:
449
488
  if key[0] == conversation_id:
450
489
  self._timeout_tasks[key].cancel()
451
490
  keys_to_remove.append(key)
452
-
491
+
453
492
  for key in keys_to_remove:
454
493
  del self._timeout_tasks[key]
455
-
494
+
456
495
  keys_to_remove = []
457
496
  for key in self._callbacks:
458
497
  if key[0] == conversation_id:
459
498
  keys_to_remove.append(key)
460
-
499
+
461
500
  for key in keys_to_remove:
462
501
  del self._callbacks[key]
463
-
502
+
464
503
  # 清理与此对话相关的task映射关系
465
504
  # 1. 从task_conversations中移除此对话
466
505
  task_ids_to_update = []
467
506
  for task_id, convs in self._task_conversations.items():
468
507
  if conversation_id in convs:
469
508
  task_ids_to_update.append(task_id)
470
-
509
+
471
510
  for task_id in task_ids_to_update:
472
511
  self._task_conversations[task_id].remove(conversation_id)
473
512
  # 如果task没有关联的对话了,可以考虑删除该task记录
474
513
  if not self._task_conversations[task_id]:
475
514
  del self._task_conversations[task_id]
476
-
515
+
477
516
  # 2. 获取并清理所有与此对话相关的请求
478
517
  request_ids = self._conversation_requests.get(conversation_id, [])
479
518
  for request_id in request_ids:
480
519
  # 清理request_task映射
481
520
  if (conversation_id, request_id) in self._request_task:
482
521
  del self._request_task[(conversation_id, request_id)]
483
-
522
+
484
523
  # 3. 清理conversation_requests映射
485
524
  if conversation_id in self._conversation_requests:
486
525
  del self._conversation_requests[conversation_id]
487
-
526
+
488
527
  # 4. 清理provider关联
489
528
  if conversation_id in self._conversation_provider:
490
529
  del self._conversation_provider[conversation_id]
491
-
530
+
492
531
  return await provider.async_cancel_conversation(conversation_id)
493
-
494
532
 
495
533
  def cancel_conversation(
496
- self,
497
- conversation_id: str,
498
- provider_id: Optional[str] = None
534
+ self, conversation_id: str, provider_id: Optional[str] = None
499
535
  ) -> bool:
500
536
  """取消整个对话(同步版本)"""
501
-
502
- return run_async_safely(
503
- self.async_cancel_conversation(
504
- conversation_id=conversation_id,
505
- provider_id=provider_id
506
- )
537
+
538
+ result: bool = run_async_safely(
539
+ self.async_cancel_conversation(
540
+ conversation_id=conversation_id, provider_id=provider_id
507
541
  )
542
+ )
543
+
544
+ return result
508
545
 
509
546
  async def async_get_provider(
510
- self,
511
- provider_id: Optional[str] = None
547
+ self, provider_id: Optional[str] = None
512
548
  ) -> HumanLoopProvider:
513
549
  """获取指定的提供者实例"""
514
550
  provider_id = provider_id or self.default_provider_id
515
551
  if not provider_id or provider_id not in self.providers:
516
552
  raise ValueError(f"Provider '{provider_id}' not found")
517
-
553
+
518
554
  return self.providers[provider_id]
519
-
520
- def get_provider(
521
- self,
522
- provider_id: Optional[str] = None
523
- ) -> HumanLoopProvider:
555
+
556
+ def get_provider(self, provider_id: Optional[str] = None) -> HumanLoopProvider:
524
557
  """获取指定的提供者实例(同步版本)"""
525
558
 
526
- return run_async_safely(
527
- self.async_get_provider(provider_id=provider_id)
528
- )
559
+ result: HumanLoopProvider = run_async_safely(
560
+ self.async_get_provider(provider_id=provider_id)
561
+ )
562
+
563
+ return result
529
564
 
530
565
  async def async_list_providers(self) -> Dict[str, HumanLoopProvider]:
531
566
  """列出所有注册的提供者"""
532
567
  return self.providers
533
-
534
568
 
535
569
  def list_providers(self) -> Dict[str, HumanLoopProvider]:
536
570
  """列出所有注册的提供者(同步版本)"""
537
-
538
- return run_async_safely(
539
- self.async_list_providers()
540
- )
541
571
 
572
+ result: Dict[str, HumanLoopProvider] = run_async_safely(
573
+ self.async_list_providers()
574
+ )
542
575
 
576
+ return result
543
577
 
544
- async def async_set_default_provider(
545
- self,
546
- provider_id: str
547
- ) -> bool:
578
+ async def async_set_default_provider(self, provider_id: str) -> bool:
548
579
  """设置默认提供者"""
549
580
  if provider_id not in self.providers:
550
581
  raise ValueError(f"Provider '{provider_id}' not found")
551
-
582
+
552
583
  self.default_provider_id = provider_id
553
584
  return True
554
-
555
585
 
556
- def set_default_provider(
557
- self,
558
- provider_id: str
559
- ) -> bool:
586
+ def set_default_provider(self, provider_id: str) -> bool:
560
587
  """设置默认提供者(同步版本)"""
561
588
 
589
+ result: bool = run_async_safely(
590
+ self.async_set_default_provider(provider_id=provider_id)
591
+ )
562
592
 
563
- return run_async_safely(
564
- self.async_set_default_provider(provider_id=provider_id)
565
- )
593
+ return result
566
594
 
567
595
  async def _async_create_timeout_task(
568
- self,
596
+ self,
569
597
  conversation_id: str,
570
- request_id: str,
571
- timeout: int,
598
+ request_id: str,
599
+ timeout: int,
572
600
  provider: HumanLoopProvider,
573
- callback: Optional[HumanLoopCallback]
574
- ):
601
+ callback: Optional[HumanLoopCallback],
602
+ ) -> None:
575
603
  """创建超时任务"""
576
- async def timeout_task():
604
+
605
+ async def timeout_task() -> None:
577
606
  await asyncio.sleep(timeout)
578
607
  # 检查当前状态
579
- result = await self.async_check_request_status(conversation_id, request_id, provider.name)
580
-
608
+ result = await self.async_check_request_status(
609
+ conversation_id, request_id, provider.name
610
+ )
611
+
581
612
  # 只有当状态为PENDING时才触发超时回调
582
613
  # INPROGRESS状态表示对话正在进行中,不应视为超时
583
614
  if result.status == HumanLoopStatus.PENDING:
@@ -591,55 +622,68 @@ class DefaultHumanLoopManager(HumanLoopManager):
591
622
  self._timeout_tasks[(conversation_id, request_id)].cancel()
592
623
  new_task = asyncio.create_task(timeout_task())
593
624
  self._timeout_tasks[(conversation_id, request_id)] = new_task
594
-
625
+
595
626
  task = asyncio.create_task(timeout_task())
596
627
  self._timeout_tasks[(conversation_id, request_id)] = task
597
-
628
+
598
629
  async def _async_wait_for_result(
599
- self,
630
+ self,
600
631
  conversation_id: str,
601
- request_id: str,
602
- provider: HumanLoopProvider,
603
- timeout: Optional[int] = None
632
+ request_id: str,
633
+ provider: HumanLoopProvider,
634
+ timeout: Optional[int] = None,
604
635
  ) -> HumanLoopResult:
605
636
  """等待循环结果"""
606
- start_time = time.time()
607
637
  poll_interval = 1.0 # 轮询间隔(秒)
608
-
638
+
609
639
  while True:
610
- result = await self.async_check_request_status(conversation_id, request_id, provider.name)
611
-
612
- #如果状态是最终状态(非PENDING),返回结果
640
+ result = await self.async_check_request_status(
641
+ conversation_id, request_id, provider.name
642
+ )
643
+
644
+ # 如果状态是最终状态(非PENDING),返回结果
613
645
  if result.status != HumanLoopStatus.PENDING:
614
646
  return result
615
-
647
+
616
648
  # 等待一段时间后再次轮询
617
649
  await asyncio.sleep(poll_interval)
618
-
619
- async def _async_trigger_update_callback(self, conversation_id: str, request_id: str, provider: HumanLoopProvider, result: HumanLoopResult):
650
+
651
+ async def _async_trigger_update_callback(
652
+ self,
653
+ conversation_id: str,
654
+ request_id: str,
655
+ provider: HumanLoopProvider,
656
+ result: HumanLoopResult,
657
+ ) -> None:
620
658
  """触发状态更新回调"""
621
- callback: Optional[HumanLoopCallback] = self._callbacks.get((conversation_id, request_id))
659
+ callback: Optional[HumanLoopCallback] = self._callbacks.get(
660
+ (conversation_id, request_id)
661
+ )
622
662
  if callback:
623
663
  try:
624
- await callback.on_humanloop_update(provider, result)
664
+ await callback.async_on_humanloop_update(provider, result)
625
665
  # 如果状态是最终状态,可以考虑移除回调
626
- if result.status not in [HumanLoopStatus.PENDING, HumanLoopStatus.INPROGRESS]:
666
+ if result.status not in [
667
+ HumanLoopStatus.PENDING,
668
+ HumanLoopStatus.INPROGRESS,
669
+ ]:
627
670
  del self._callbacks[(conversation_id, request_id)]
628
671
  except Exception as e:
629
672
  # 处理回调执行过程中的异常
630
673
  try:
631
- await callback.on_humanloop_error(provider, e)
632
- except:
674
+ await callback.async_on_humanloop_error(provider, e)
675
+ except Exception:
633
676
  # 如果错误回调也失败,只能忽略
634
677
  pass
635
678
 
636
679
  # 添加新方法用于获取task相关信息
680
+
637
681
  async def async_get_task_conversations(self, task_id: str) -> List[str]:
638
682
  """获取任务关联的所有对话ID
639
-
683
+
640
684
  Args:
641
685
  task_id: 任务ID
642
-
686
+
643
687
  Returns:
644
688
  List[str]: 与任务关联的对话ID列表
645
689
  """
@@ -647,106 +691,129 @@ class DefaultHumanLoopManager(HumanLoopManager):
647
691
 
648
692
  def get_task_conversations(self, task_id: str) -> List[str]:
649
693
  """获取任务关联的所有对话ID
650
-
694
+
651
695
  Args:
652
696
  task_id: 任务ID
653
-
697
+
654
698
  Returns:
655
699
  List[str]: 与任务关联的对话ID列表
656
700
  """
657
701
  return list(self._task_conversations.get(task_id, set()))
658
-
702
+
659
703
  async def async_get_conversation_requests(self, conversation_id: str) -> List[str]:
660
704
  """获取对话关联的所有请求ID
661
-
705
+
662
706
  Args:
663
707
  conversation_id: 对话ID
664
-
708
+
665
709
  Returns:
666
710
  List[str]: 与对话关联的请求ID列表
667
711
  """
668
- return self._conversation_requests.get(conversation_id, [])
712
+ ret: List[str] = self._conversation_requests.get(conversation_id, [])
713
+ return ret
669
714
 
670
715
  def get_conversation_requests(self, conversation_id: str) -> List[str]:
671
716
  """获取对话关联的所有请求ID
672
-
717
+
673
718
  Args:
674
719
  conversation_id: 对话ID
675
-
720
+
676
721
  Returns:
677
722
  List[str]: 与对话关联的请求ID列表
678
723
  """
679
- return self._conversation_requests.get(conversation_id, [])
680
-
681
- async def async_get_request_task(self, conversation_id: str, request_id: str) -> Optional[str]:
724
+ ret: List[str] = self._conversation_requests.get(conversation_id, [])
725
+
726
+ return ret
727
+
728
+ async def async_get_request_task(
729
+ self, conversation_id: str, request_id: str
730
+ ) -> Optional[str]:
682
731
  """获取请求关联的任务ID
683
-
732
+
684
733
  Args:
685
734
  conversation_id: 对话ID
686
735
  request_id: 请求ID
687
-
736
+
688
737
  Returns:
689
738
  Optional[str]: 关联的任务ID,如果不存在则返回None
690
739
  """
691
- return self._request_task.get((conversation_id, request_id))
740
+ ret: Optional[str] = self._request_task.get((conversation_id, request_id))
741
+
742
+ return ret
692
743
 
693
- async def async_get_conversation_provider(self, conversation_id: str) -> Optional[str]:
744
+ async def async_get_conversation_provider(
745
+ self, conversation_id: str
746
+ ) -> Optional[str]:
694
747
  """获取请求关联的提供者ID
695
-
748
+
696
749
  Args:
697
750
  conversation_id: 对话ID
698
-
751
+
699
752
  Returns:
700
753
  Optional[str]: 关联的提供者ID,如果不存在则返回None
701
754
  """
702
- return self._conversation_provider.get(conversation_id)
755
+ ret: Optional[str] = self._conversation_provider.get(conversation_id)
756
+
757
+ return ret
703
758
 
704
759
  async def async_check_conversation_exist(
705
760
  self,
706
- task_id:str,
761
+ task_id: str,
707
762
  conversation_id: str,
708
763
  ) -> bool:
709
764
  """判断对话是否已存在
710
-
765
+
711
766
  Args:
712
767
  conversation_id: 对话标识符
713
768
  provider_id: 使用特定提供者的ID(可选)
714
-
769
+
715
770
  Returns:
716
771
  bool: 如果对话存在返回True,否则返回False
717
772
  """
718
773
  # 检查task_id是否存在且conversation_id是否在该task的对话集合中
719
- if task_id in self._task_conversations and conversation_id in self._task_conversations[task_id]:
774
+ if (
775
+ task_id in self._task_conversations
776
+ and conversation_id in self._task_conversations[task_id]
777
+ ):
720
778
  # 进一步验证该对话是否有关联的请求
721
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
779
+ if (
780
+ conversation_id in self._conversation_requests
781
+ and self._conversation_requests[conversation_id]
782
+ ):
722
783
  return True
723
784
 
724
785
  return False
725
786
 
726
787
  def check_conversation_exist(
727
788
  self,
728
- task_id:str,
789
+ task_id: str,
729
790
  conversation_id: str,
730
791
  ) -> bool:
731
792
  """判断对话是否已存在
732
-
793
+
733
794
  Args:
734
795
  conversation_id: 对话标识符
735
796
  provider_id: 使用特定提供者的ID(可选)
736
-
797
+
737
798
  Returns:
738
799
  bool: 如果对话存在返回True,否则返回False
739
800
  """
740
801
  # 检查task_id是否存在且conversation_id是否在该task的对话集合中
741
- if task_id in self._task_conversations and conversation_id in self._task_conversations[task_id]:
802
+ if (
803
+ task_id in self._task_conversations
804
+ and conversation_id in self._task_conversations[task_id]
805
+ ):
742
806
  # 进一步验证该对话是否有关联的请求
743
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
807
+ if (
808
+ conversation_id in self._conversation_requests
809
+ and self._conversation_requests[conversation_id]
810
+ ):
744
811
  return True
745
812
 
746
813
  return False
747
814
 
748
- async def async_shutdown(self):
749
- pass
815
+ async def async_shutdown(self) -> None:
816
+ pass
750
817
 
751
- def shutdown(self):
752
- pass
818
+ def shutdown(self) -> None:
819
+ pass