gohumanloop 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,53 @@
1
- from typing import Dict, Any, Optional, List, Union, Set
1
+ from typing import Dict, Any, Optional, List, Union
2
2
  import asyncio
3
- import time
3
+ from datetime import datetime
4
4
  from gohumanloop.utils import run_async_safely
5
5
 
6
6
  from gohumanloop.core.interface import (
7
- HumanLoopManager, HumanLoopProvider, HumanLoopCallback,
8
- HumanLoopResult, HumanLoopStatus, HumanLoopType
7
+ HumanLoopRequest,
8
+ HumanLoopManager,
9
+ HumanLoopProvider,
10
+ HumanLoopCallback,
11
+ HumanLoopResult,
12
+ HumanLoopStatus,
13
+ HumanLoopType,
9
14
  )
10
15
 
16
+
11
17
  class DefaultHumanLoopManager(HumanLoopManager):
12
18
  """默认人机循环管理器实现"""
13
-
14
- def __init__(self, initial_providers: Optional[Union[HumanLoopProvider, List[HumanLoopProvider]]] = None):
15
- self.providers = {}
19
+
20
+ def __init__(
21
+ self,
22
+ initial_providers: Optional[
23
+ Union[HumanLoopProvider, List[HumanLoopProvider]]
24
+ ] = None,
25
+ ):
26
+ self.providers: dict[str, HumanLoopProvider] = {}
16
27
  self.default_provider_id = None
17
-
28
+
29
+ # 存储请求和回调的映射
30
+ self._callbacks: dict[tuple[str, str], HumanLoopCallback] = {}
31
+ # 存储请求的超时任务
32
+ self._timeout_tasks: dict[tuple[str, str], asyncio.Task] = {}
33
+
34
+ # 存储task_id与conversation_id的映射关系
35
+ self._task_conversations: dict[
36
+ str, set[str]
37
+ ] = {} # task_id -> Set[conversation_id]
38
+ # 存储conversation_id与request_id的映射关系
39
+ self._conversation_requests: dict[
40
+ str, list[str]
41
+ ] = {} # conversation_id -> List[request_id]
42
+ # 存储request_id与task_id的反向映射
43
+ self._request_task: dict[
44
+ tuple[str, str], str
45
+ ] = {} # (conversation_id, request_id) -> task_id
46
+ # 存储对话对应的provider_id
47
+ self._conversation_provider: dict[
48
+ str, str
49
+ ] = {} # conversation_id -> provider_id
50
+
18
51
  # 初始化提供者
19
52
  if initial_providers:
20
53
  if isinstance(initial_providers, list):
@@ -27,41 +60,33 @@ class DefaultHumanLoopManager(HumanLoopManager):
27
60
  # 处理单个提供者
28
61
  self.register_provider_sync(initial_providers, initial_providers.name)
29
62
  self.default_provider_id = initial_providers.name
30
-
31
- # 存储请求和回调的映射
32
- self._callbacks = {}
33
- # 存储请求的超时任务
34
- self._timeout_tasks = {}
35
-
36
- # 存储task_id与conversation_id的映射关系
37
- self._task_conversations = {} # task_id -> Set[conversation_id]
38
- # 存储conversation_id与request_id的映射关系
39
- self._conversation_requests = {} # conversation_id -> List[request_id]
40
- # 存储request_id与task_id的反向映射
41
- self._request_task = {} # (conversation_id, request_id) -> task_id
42
- # 存储对话对应的provider_id
43
- self._conversation_provider = {} # conversation_id -> provider_id
44
-
45
- def register_provider_sync(self, provider: HumanLoopProvider, provider_id: Optional[str]) -> str:
63
+
64
+ def register_provider_sync(
65
+ self, provider: HumanLoopProvider, provider_id: Optional[str]
66
+ ) -> str:
46
67
  """同步注册提供者(用于初始化)"""
47
68
  if not provider_id:
48
69
  provider_id = f"provider_{len(self.providers) + 1}"
49
-
70
+
50
71
  self.providers[provider_id] = provider
51
-
72
+
52
73
  if not self.default_provider_id:
53
74
  self.default_provider_id = provider_id
54
-
75
+
55
76
  return provider_id
56
-
57
- async def async_register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
77
+
78
+ async def async_register_provider(
79
+ self, provider: HumanLoopProvider, provider_id: Optional[str] = None
80
+ ) -> str:
58
81
  """注册人机循环提供者"""
59
82
  return self.register_provider_sync(provider, provider_id)
60
-
61
- def register_provider(self, provider: HumanLoopProvider, provider_id: Optional[str] = None) -> str:
83
+
84
+ def register_provider(
85
+ self, provider: HumanLoopProvider, provider_id: Optional[str] = None
86
+ ) -> str:
62
87
  """注册人机循环提供者(同步版本)"""
63
88
  return self.register_provider_sync(provider, provider_id)
64
-
89
+
65
90
  async def async_request_humanloop(
66
91
  self,
67
92
  task_id: str,
@@ -79,13 +104,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
79
104
  provider_id = provider_id or self.default_provider_id
80
105
  if not provider_id or provider_id not in self.providers:
81
106
  raise ValueError(f"Provider '{provider_id}' not found")
82
-
107
+
83
108
  # 检查对话是否已存在且使用了不同的提供者
84
- if conversation_id in self._conversation_provider and self._conversation_provider[conversation_id] != provider_id:
85
- raise ValueError(f"Conversation '{conversation_id}' already exists with a different provider")
86
-
109
+ if (
110
+ conversation_id in self._conversation_provider
111
+ and self._conversation_provider[conversation_id] != provider_id
112
+ ):
113
+ raise ValueError(
114
+ f"Conversation '{conversation_id}' already exists with a different provider"
115
+ )
116
+
87
117
  provider = self.providers[provider_id]
88
-
118
+
89
119
  try:
90
120
  # 发送请求
91
121
  result = await provider.async_request_humanloop(
@@ -94,38 +124,63 @@ class DefaultHumanLoopManager(HumanLoopManager):
94
124
  loop_type=loop_type,
95
125
  context=context,
96
126
  metadata=metadata,
97
- timeout=timeout
127
+ timeout=timeout,
98
128
  )
99
-
129
+
100
130
  request_id = result.request_id
101
-
131
+
102
132
  if not request_id:
103
- raise ValueError(f"Failed to request humanloop for conversation '{conversation_id}'")
104
-
133
+ raise ValueError(
134
+ f"Failed to request humanloop for conversation '{conversation_id}'"
135
+ )
136
+
105
137
  # 存储task_id、conversation_id和request_id的关系
106
138
  if task_id not in self._task_conversations:
107
139
  self._task_conversations[task_id] = set()
108
140
  self._task_conversations[task_id].add(conversation_id)
109
-
141
+
110
142
  if conversation_id not in self._conversation_requests:
111
143
  self._conversation_requests[conversation_id] = []
112
144
  self._conversation_requests[conversation_id].append(request_id)
113
-
145
+
114
146
  self._request_task[(conversation_id, request_id)] = task_id
115
147
  # 存储对话对应的provider_id
116
148
  self._conversation_provider[conversation_id] = provider_id
117
-
149
+
118
150
  # 如果提供了回调,存储它
119
151
  if callback:
152
+ try:
153
+ # 创建请求对象
154
+ request = HumanLoopRequest(
155
+ task_id=task_id,
156
+ conversation_id=conversation_id,
157
+ loop_type=loop_type,
158
+ context=context,
159
+ metadata=metadata or {},
160
+ timeout=timeout,
161
+ created_at=datetime.now(),
162
+ )
163
+ await callback.async_on_humanloop_request(provider, request)
164
+ except Exception as e:
165
+ # 处理回调执行过程中的异常
166
+ try:
167
+ await callback.async_on_humanloop_error(provider, e)
168
+ except Exception:
169
+ # 如果错误回调也失败,只能忽略
170
+ pass
120
171
  self._callbacks[(conversation_id, request_id)] = callback
121
-
172
+
122
173
  # 如果设置了超时,创建超时任务
123
174
  if timeout:
124
- await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)
125
-
175
+ await self._async_create_timeout_task(
176
+ conversation_id, request_id, timeout, provider, callback
177
+ )
178
+
126
179
  # 如果是阻塞模式,等待结果
127
180
  if blocking:
128
- return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
181
+ return await self._async_wait_for_result(
182
+ conversation_id, request_id, provider, timeout
183
+ )
129
184
  else:
130
185
  return request_id
131
186
  except Exception as e:
@@ -133,7 +188,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
133
188
  if callback:
134
189
  try:
135
190
  await callback.async_on_humanloop_error(provider, e)
136
- except:
191
+ except Exception:
137
192
  # 如果错误回调也失败,只能忽略
138
193
  pass
139
194
  raise # 重新抛出异常,让调用者知道发生了错误
@@ -152,20 +207,21 @@ class DefaultHumanLoopManager(HumanLoopManager):
152
207
  ) -> Union[str, HumanLoopResult]:
153
208
  """请求人机循环(同步版本)"""
154
209
 
155
- return run_async_safely(
156
- self.async_request_humanloop(
157
- task_id=task_id,
158
- conversation_id=conversation_id,
159
- loop_type=loop_type,
160
- context=context,
161
- callback=callback,
162
- metadata=metadata,
163
- provider_id=provider_id,
164
- timeout=timeout,
165
- blocking=blocking
166
- )
210
+ result: Union[str, HumanLoopResult] = run_async_safely(
211
+ self.async_request_humanloop(
212
+ task_id=task_id,
213
+ conversation_id=conversation_id,
214
+ loop_type=loop_type,
215
+ context=context,
216
+ callback=callback,
217
+ metadata=metadata,
218
+ provider_id=provider_id,
219
+ timeout=timeout,
220
+ blocking=blocking,
167
221
  )
222
+ )
168
223
 
224
+ return result
169
225
 
170
226
  async def async_continue_humanloop(
171
227
  self,
@@ -182,16 +238,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
182
238
  if conversation_id in self._conversation_provider:
183
239
  stored_provider_id = self._conversation_provider[conversation_id]
184
240
  if provider_id and provider_id != stored_provider_id:
185
- raise ValueError(f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'")
241
+ raise ValueError(
242
+ f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'"
243
+ )
186
244
  provider_id = stored_provider_id
187
245
  else:
188
246
  provider_id = provider_id or self.default_provider_id
189
-
247
+
190
248
  if not provider_id or provider_id not in self.providers:
191
249
  raise ValueError(f"Provider '{provider_id}' not found")
192
-
250
+
193
251
  provider = self.providers[provider_id]
194
-
252
+
195
253
  try:
196
254
  # 发送继续请求
197
255
  result = await provider.async_continue_humanloop(
@@ -200,42 +258,67 @@ class DefaultHumanLoopManager(HumanLoopManager):
200
258
  metadata=metadata,
201
259
  timeout=timeout,
202
260
  )
203
-
261
+
204
262
  request_id = result.request_id
205
263
 
206
264
  if not request_id:
207
- raise ValueError(f"Failed to continue humanloop for conversation '{conversation_id}'")
208
-
265
+ raise ValueError(
266
+ f"Failed to continue humanloop for conversation '{conversation_id}'"
267
+ )
268
+
209
269
  # 更新conversation_id和request_id的关系
210
270
  if conversation_id not in self._conversation_requests:
211
271
  self._conversation_requests[conversation_id] = []
212
272
  self._conversation_requests[conversation_id].append(request_id)
213
-
273
+
214
274
  # 查找此conversation_id对应的task_id
215
275
  task_id = None
216
276
  for t_id, convs in self._task_conversations.items():
217
277
  if conversation_id in convs:
218
278
  task_id = t_id
219
279
  break
220
-
280
+
221
281
  if task_id:
222
282
  self._request_task[(conversation_id, request_id)] = task_id
223
-
283
+
224
284
  # 存储对话对应的provider_id,如果对话不存在才存储
225
285
  if conversation_id not in self._conversation_provider:
226
286
  self._conversation_provider[conversation_id] = provider_id
227
-
287
+
228
288
  # 如果提供了回调,存储它
229
289
  if callback:
290
+ try:
291
+ # 创建请求对象
292
+ request = HumanLoopRequest(
293
+ task_id=task_id or "",
294
+ conversation_id=conversation_id,
295
+ loop_type=HumanLoopType.CONVERSATION,
296
+ context=context,
297
+ metadata=metadata or {},
298
+ timeout=timeout,
299
+ created_at=datetime.now(),
300
+ )
301
+ await callback.async_on_humanloop_request(provider, request)
302
+ except Exception as e:
303
+ # 处理回调执行过程中的异常
304
+ try:
305
+ await callback.async_on_humanloop_error(provider, e)
306
+ except Exception:
307
+ # 如果错误回调也失败,只能忽略
308
+ pass
230
309
  self._callbacks[(conversation_id, request_id)] = callback
231
-
310
+
232
311
  # 如果设置了超时,创建超时任务
233
312
  if timeout:
234
- await self._async_create_timeout_task(conversation_id, request_id, timeout, provider, callback)
235
-
313
+ await self._async_create_timeout_task(
314
+ conversation_id, request_id, timeout, provider, callback
315
+ )
316
+
236
317
  # 如果是阻塞模式,等待结果
237
318
  if blocking:
238
- return await self._async_wait_for_result(conversation_id, request_id, provider, timeout)
319
+ return await self._async_wait_for_result(
320
+ conversation_id, request_id, provider, timeout
321
+ )
239
322
  else:
240
323
  return request_id
241
324
  except Exception as e:
@@ -243,7 +326,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
243
326
  if callback:
244
327
  try:
245
328
  await callback.async_on_humanloop_error(provider, e)
246
- except:
329
+ except Exception:
247
330
  # 如果错误回调也失败,只能忽略
248
331
  pass
249
332
  raise # 重新抛出异常,让调用者知道发生了错误
@@ -260,42 +343,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
260
343
  ) -> Union[str, HumanLoopResult]:
261
344
  """继续人机循环(同步版本)"""
262
345
 
263
- return run_async_safely(
264
- self.async_continue_humanloop(
265
- conversation_id=conversation_id,
266
- context=context,
267
- callback=callback,
268
- metadata=metadata,
269
- provider_id=provider_id,
270
- timeout=timeout,
271
- blocking=blocking
272
- )
346
+ result: Union[str, HumanLoopResult] = run_async_safely(
347
+ self.async_continue_humanloop(
348
+ conversation_id=conversation_id,
349
+ context=context,
350
+ callback=callback,
351
+ metadata=metadata,
352
+ provider_id=provider_id,
353
+ timeout=timeout,
354
+ blocking=blocking,
273
355
  )
356
+ )
357
+
358
+ return result
274
359
 
275
360
  async def async_check_request_status(
276
- self,
277
- conversation_id: str,
278
- request_id: str,
279
- provider_id: Optional[str] = None
361
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
280
362
  ) -> HumanLoopResult:
281
363
  """检查请求状态"""
282
364
  # 如果没有指定provider_id,尝试从存储的映射中获取
283
365
  if provider_id is None:
284
366
  stored_provider_id = self._conversation_provider.get(conversation_id)
285
367
  provider_id = stored_provider_id or self.default_provider_id
286
-
368
+
287
369
  if not provider_id or provider_id not in self.providers:
288
370
  raise ValueError(f"Provider '{provider_id}' not found")
289
-
371
+
290
372
  provider = self.providers[provider_id]
291
-
373
+
292
374
  try:
293
- result = await provider.async_check_request_status(conversation_id, request_id)
294
-
375
+ result = await provider.async_check_request_status(
376
+ conversation_id, request_id
377
+ )
378
+
295
379
  # 如果有回调且状态不是等待或进行中,触发状态更新回调
296
- if (conversation_id, request_id) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
297
- await self._async_trigger_update_callback(conversation_id, request_id, provider, result)
298
-
380
+ if (
381
+ conversation_id,
382
+ request_id,
383
+ ) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
384
+ await self._async_trigger_update_callback(
385
+ conversation_id, request_id, provider, result
386
+ )
387
+
299
388
  return result
300
389
  except Exception as e:
301
390
  # 处理检查状态过程中的异常
@@ -303,286 +392,270 @@ class DefaultHumanLoopManager(HumanLoopManager):
303
392
  if callback:
304
393
  try:
305
394
  await callback.async_on_humanloop_error(provider, e)
306
- except:
395
+ except Exception:
307
396
  # 如果错误回调也失败,只能忽略
308
397
  pass
309
398
  raise # 重新抛出异常,让调用者知道发生了错误
310
399
 
311
-
312
400
  def check_request_status(
313
- self,
314
- conversation_id: str,
315
- request_id: str,
316
- provider_id: Optional[str] = None
401
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
317
402
  ) -> HumanLoopResult:
318
403
  """检查请求状态(同步版本)"""
319
404
 
320
- return run_async_safely(
321
- self.async_check_request_status(
322
- conversation_id=conversation_id,
323
- request_id=request_id,
324
- provider_id=provider_id
325
- )
405
+ result: HumanLoopResult = run_async_safely(
406
+ self.async_check_request_status(
407
+ conversation_id=conversation_id,
408
+ request_id=request_id,
409
+ provider_id=provider_id,
326
410
  )
411
+ )
327
412
 
413
+ return result
328
414
 
329
415
  async def async_check_conversation_status(
330
- self,
331
- conversation_id: str,
332
- provider_id: Optional[str] = None
416
+ self, conversation_id: str, provider_id: Optional[str] = None
333
417
  ) -> HumanLoopResult:
334
418
  """检查对话状态"""
335
419
  # 优先使用对话已关联的提供者
336
420
  if provider_id is None:
337
421
  stored_provider_id = self._conversation_provider.get(conversation_id)
338
422
  provider_id = stored_provider_id or self.default_provider_id
339
-
423
+
340
424
  if not provider_id or provider_id not in self.providers:
341
425
  raise ValueError(f"Provider '{provider_id}' not found")
342
-
426
+
343
427
  # 检查对话指定provider_id或默认provider_id最后一次请求的状态
344
428
  provider = self.providers[provider_id]
345
-
429
+
346
430
  try:
347
431
  # 检查对话指定provider_id或默认provider_id最后一次请求的状态
348
432
  return await provider.async_check_conversation_status(conversation_id)
349
433
  except Exception as e:
350
434
  # 处理检查对话状态过程中的异常
351
435
  # 尝试找到与此对话关联的最后一个请求的回调
352
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
436
+ if (
437
+ conversation_id in self._conversation_requests
438
+ and self._conversation_requests[conversation_id]
439
+ ):
353
440
  last_request_id = self._conversation_requests[conversation_id][-1]
354
441
  callback = self._callbacks.get((conversation_id, last_request_id))
355
442
  if callback:
356
443
  try:
357
444
  await callback.async_on_humanloop_error(provider, e)
358
- except:
445
+ except Exception:
359
446
  # 如果错误回调也失败,只能忽略
360
447
  pass
361
448
  raise # 重新抛出异常,让调用者知道发生了错误
362
-
449
+
363
450
  def check_conversation_status(
364
- self,
365
- conversation_id: str,
366
- provider_id: Optional[str] = None
451
+ self, conversation_id: str, provider_id: Optional[str] = None
367
452
  ) -> HumanLoopResult:
368
453
  """检查对话状态(同步版本)"""
369
454
 
370
- return run_async_safely(
371
- self.async_check_conversation_status(
372
- conversation_id=conversation_id,
373
- provider_id=provider_id
374
- )
455
+ result: HumanLoopResult = run_async_safely(
456
+ self.async_check_conversation_status(
457
+ conversation_id=conversation_id, provider_id=provider_id
375
458
  )
459
+ )
460
+
461
+ return result
376
462
 
377
463
  async def async_cancel_request(
378
- self,
379
- conversation_id: str,
380
- request_id: str,
381
- provider_id: Optional[str] = None
464
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
382
465
  ) -> bool:
383
466
  """取消特定请求"""
384
467
  if provider_id is None:
385
468
  stored_provider_id = self._conversation_provider.get(conversation_id)
386
469
  provider_id = stored_provider_id or self.default_provider_id
387
-
470
+
388
471
  if not provider_id or provider_id not in self.providers:
389
472
  raise ValueError(f"Provider '{provider_id}' not found")
390
-
473
+
391
474
  provider = self.providers[provider_id]
392
475
 
393
476
  # 取消超时任务
394
477
  if (conversation_id, request_id) in self._timeout_tasks:
395
478
  self._timeout_tasks[(conversation_id, request_id)].cancel()
396
479
  del self._timeout_tasks[(conversation_id, request_id)]
397
-
480
+
398
481
  # 从回调映射中删除
399
482
  if (conversation_id, request_id) in self._callbacks:
400
483
  del self._callbacks[(conversation_id, request_id)]
401
-
484
+
402
485
  # 清理request关联
403
486
  if (conversation_id, request_id) in self._request_task:
404
487
  del self._request_task[(conversation_id, request_id)]
405
-
488
+
406
489
  # 从conversation_requests中移除
407
490
  if conversation_id in self._conversation_requests:
408
491
  if request_id in self._conversation_requests[conversation_id]:
409
492
  self._conversation_requests[conversation_id].remove(request_id)
410
-
493
+
411
494
  return await provider.async_cancel_request(conversation_id, request_id)
412
-
413
495
 
414
496
  def cancel_request(
415
- self,
416
- conversation_id: str,
417
- request_id: str,
418
- provider_id: Optional[str] = None
497
+ self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
419
498
  ) -> bool:
420
499
  """取消特定请求(同步版本)"""
421
500
 
422
- return run_async_safely(
423
- self.async_cancel_request(
424
- conversation_id=conversation_id,
425
- request_id=request_id,
426
- provider_id=provider_id
427
- )
501
+ result: bool = run_async_safely(
502
+ self.async_cancel_request(
503
+ conversation_id=conversation_id,
504
+ request_id=request_id,
505
+ provider_id=provider_id,
428
506
  )
507
+ )
508
+
509
+ return result
429
510
 
430
511
  async def async_cancel_conversation(
431
- self,
432
- conversation_id: str,
433
- provider_id: Optional[str] = None
512
+ self, conversation_id: str, provider_id: Optional[str] = None
434
513
  ) -> bool:
435
514
  """取消整个对话"""
436
515
  # 优先使用对话已关联的提供者
437
516
  if provider_id is None:
438
517
  stored_provider_id = self._conversation_provider.get(conversation_id)
439
518
  provider_id = stored_provider_id or self.default_provider_id
440
-
519
+
441
520
  if not provider_id or provider_id not in self.providers:
442
521
  raise ValueError(f"Provider '{provider_id}' not found")
443
-
522
+
444
523
  provider = self.providers[provider_id]
445
-
524
+
446
525
  # 取消与此对话相关的所有超时任务和回调
447
526
  keys_to_remove = []
448
527
  for key in self._timeout_tasks:
449
528
  if key[0] == conversation_id:
450
529
  self._timeout_tasks[key].cancel()
451
530
  keys_to_remove.append(key)
452
-
531
+
453
532
  for key in keys_to_remove:
454
533
  del self._timeout_tasks[key]
455
-
534
+
456
535
  keys_to_remove = []
457
536
  for key in self._callbacks:
458
537
  if key[0] == conversation_id:
459
538
  keys_to_remove.append(key)
460
-
539
+
461
540
  for key in keys_to_remove:
462
541
  del self._callbacks[key]
463
-
542
+
464
543
  # 清理与此对话相关的task映射关系
465
544
  # 1. 从task_conversations中移除此对话
466
545
  task_ids_to_update = []
467
546
  for task_id, convs in self._task_conversations.items():
468
547
  if conversation_id in convs:
469
548
  task_ids_to_update.append(task_id)
470
-
549
+
471
550
  for task_id in task_ids_to_update:
472
551
  self._task_conversations[task_id].remove(conversation_id)
473
552
  # 如果task没有关联的对话了,可以考虑删除该task记录
474
553
  if not self._task_conversations[task_id]:
475
554
  del self._task_conversations[task_id]
476
-
555
+
477
556
  # 2. 获取并清理所有与此对话相关的请求
478
557
  request_ids = self._conversation_requests.get(conversation_id, [])
479
558
  for request_id in request_ids:
480
559
  # 清理request_task映射
481
560
  if (conversation_id, request_id) in self._request_task:
482
561
  del self._request_task[(conversation_id, request_id)]
483
-
562
+
484
563
  # 3. 清理conversation_requests映射
485
564
  if conversation_id in self._conversation_requests:
486
565
  del self._conversation_requests[conversation_id]
487
-
566
+
488
567
  # 4. 清理provider关联
489
568
  if conversation_id in self._conversation_provider:
490
569
  del self._conversation_provider[conversation_id]
491
-
570
+
492
571
  return await provider.async_cancel_conversation(conversation_id)
493
-
494
572
 
495
573
  def cancel_conversation(
496
- self,
497
- conversation_id: str,
498
- provider_id: Optional[str] = None
574
+ self, conversation_id: str, provider_id: Optional[str] = None
499
575
  ) -> bool:
500
576
  """取消整个对话(同步版本)"""
501
-
502
- return run_async_safely(
503
- self.async_cancel_conversation(
504
- conversation_id=conversation_id,
505
- provider_id=provider_id
506
- )
577
+
578
+ result: bool = run_async_safely(
579
+ self.async_cancel_conversation(
580
+ conversation_id=conversation_id, provider_id=provider_id
507
581
  )
582
+ )
583
+
584
+ return result
508
585
 
509
586
  async def async_get_provider(
510
- self,
511
- provider_id: Optional[str] = None
587
+ self, provider_id: Optional[str] = None
512
588
  ) -> HumanLoopProvider:
513
589
  """获取指定的提供者实例"""
514
590
  provider_id = provider_id or self.default_provider_id
515
591
  if not provider_id or provider_id not in self.providers:
516
592
  raise ValueError(f"Provider '{provider_id}' not found")
517
-
593
+
518
594
  return self.providers[provider_id]
519
-
520
- def get_provider(
521
- self,
522
- provider_id: Optional[str] = None
523
- ) -> HumanLoopProvider:
595
+
596
+ def get_provider(self, provider_id: Optional[str] = None) -> HumanLoopProvider:
524
597
  """获取指定的提供者实例(同步版本)"""
525
598
 
526
- return run_async_safely(
527
- self.async_get_provider(provider_id=provider_id)
528
- )
599
+ result: HumanLoopProvider = run_async_safely(
600
+ self.async_get_provider(provider_id=provider_id)
601
+ )
602
+
603
+ return result
529
604
 
530
605
  async def async_list_providers(self) -> Dict[str, HumanLoopProvider]:
531
606
  """列出所有注册的提供者"""
532
607
  return self.providers
533
-
534
608
 
535
609
  def list_providers(self) -> Dict[str, HumanLoopProvider]:
536
610
  """列出所有注册的提供者(同步版本)"""
537
-
538
- return run_async_safely(
539
- self.async_list_providers()
540
- )
541
611
 
612
+ result: Dict[str, HumanLoopProvider] = run_async_safely(
613
+ self.async_list_providers()
614
+ )
542
615
 
616
+ return result
543
617
 
544
- async def async_set_default_provider(
545
- self,
546
- provider_id: str
547
- ) -> bool:
618
+ async def async_set_default_provider(self, provider_id: str) -> bool:
548
619
  """设置默认提供者"""
549
620
  if provider_id not in self.providers:
550
621
  raise ValueError(f"Provider '{provider_id}' not found")
551
-
622
+
552
623
  self.default_provider_id = provider_id
553
624
  return True
554
-
555
625
 
556
- def set_default_provider(
557
- self,
558
- provider_id: str
559
- ) -> bool:
626
+ def set_default_provider(self, provider_id: str) -> bool:
560
627
  """设置默认提供者(同步版本)"""
561
628
 
629
+ result: bool = run_async_safely(
630
+ self.async_set_default_provider(provider_id=provider_id)
631
+ )
562
632
 
563
- return run_async_safely(
564
- self.async_set_default_provider(provider_id=provider_id)
565
- )
633
+ return result
566
634
 
567
635
  async def _async_create_timeout_task(
568
- self,
636
+ self,
569
637
  conversation_id: str,
570
- request_id: str,
571
- timeout: int,
638
+ request_id: str,
639
+ timeout: int,
572
640
  provider: HumanLoopProvider,
573
- callback: Optional[HumanLoopCallback]
574
- ):
641
+ callback: Optional[HumanLoopCallback],
642
+ ) -> None:
575
643
  """创建超时任务"""
576
- async def timeout_task():
644
+
645
+ async def timeout_task() -> None:
577
646
  await asyncio.sleep(timeout)
578
647
  # 检查当前状态
579
- result = await self.async_check_request_status(conversation_id, request_id, provider.name)
580
-
648
+ result = await self.async_check_request_status(
649
+ conversation_id, request_id, provider.name
650
+ )
651
+
581
652
  # 只有当状态为PENDING时才触发超时回调
582
653
  # INPROGRESS状态表示对话正在进行中,不应视为超时
583
654
  if result.status == HumanLoopStatus.PENDING:
584
655
  if callback:
585
- await callback.async_on_humanloop_timeout(provider=provider)
656
+ await callback.async_on_humanloop_timeout(
657
+ provider=provider, result=result
658
+ )
586
659
  # 如果状态是INPROGRESS,重置超时任务
587
660
  elif result.status == HumanLoopStatus.INPROGRESS:
588
661
  # 对于进行中的对话,我们可以选择延长超时时间
@@ -591,55 +664,68 @@ class DefaultHumanLoopManager(HumanLoopManager):
591
664
  self._timeout_tasks[(conversation_id, request_id)].cancel()
592
665
  new_task = asyncio.create_task(timeout_task())
593
666
  self._timeout_tasks[(conversation_id, request_id)] = new_task
594
-
667
+
595
668
  task = asyncio.create_task(timeout_task())
596
669
  self._timeout_tasks[(conversation_id, request_id)] = task
597
-
670
+
598
671
  async def _async_wait_for_result(
599
- self,
672
+ self,
600
673
  conversation_id: str,
601
- request_id: str,
602
- provider: HumanLoopProvider,
603
- timeout: Optional[int] = None
674
+ request_id: str,
675
+ provider: HumanLoopProvider,
676
+ timeout: Optional[int] = None,
604
677
  ) -> HumanLoopResult:
605
678
  """等待循环结果"""
606
- start_time = time.time()
607
679
  poll_interval = 1.0 # 轮询间隔(秒)
608
-
680
+
609
681
  while True:
610
- result = await self.async_check_request_status(conversation_id, request_id, provider.name)
611
-
612
- #如果状态是最终状态(非PENDING),返回结果
682
+ result = await self.async_check_request_status(
683
+ conversation_id, request_id, provider.name
684
+ )
685
+
686
+ # 如果状态是最终状态(非PENDING),返回结果
613
687
  if result.status != HumanLoopStatus.PENDING:
614
688
  return result
615
-
689
+
616
690
  # 等待一段时间后再次轮询
617
691
  await asyncio.sleep(poll_interval)
618
-
619
- async def _async_trigger_update_callback(self, conversation_id: str, request_id: str, provider: HumanLoopProvider, result: HumanLoopResult):
692
+
693
+ async def _async_trigger_update_callback(
694
+ self,
695
+ conversation_id: str,
696
+ request_id: str,
697
+ provider: HumanLoopProvider,
698
+ result: HumanLoopResult,
699
+ ) -> None:
620
700
  """触发状态更新回调"""
621
- callback: Optional[HumanLoopCallback] = self._callbacks.get((conversation_id, request_id))
701
+ callback: Optional[HumanLoopCallback] = self._callbacks.get(
702
+ (conversation_id, request_id)
703
+ )
622
704
  if callback:
623
705
  try:
624
- await callback.on_humanloop_update(provider, result)
706
+ await callback.async_on_humanloop_update(provider, result)
625
707
  # 如果状态是最终状态,可以考虑移除回调
626
- if result.status not in [HumanLoopStatus.PENDING, HumanLoopStatus.INPROGRESS]:
708
+ if result.status not in [
709
+ HumanLoopStatus.PENDING,
710
+ HumanLoopStatus.INPROGRESS,
711
+ ]:
627
712
  del self._callbacks[(conversation_id, request_id)]
628
713
  except Exception as e:
629
714
  # 处理回调执行过程中的异常
630
715
  try:
631
- await callback.on_humanloop_error(provider, e)
632
- except:
716
+ await callback.async_on_humanloop_error(provider, e)
717
+ except Exception:
633
718
  # 如果错误回调也失败,只能忽略
634
719
  pass
635
720
 
636
721
  # 添加新方法用于获取task相关信息
722
+
637
723
  async def async_get_task_conversations(self, task_id: str) -> List[str]:
638
724
  """获取任务关联的所有对话ID
639
-
725
+
640
726
  Args:
641
727
  task_id: 任务ID
642
-
728
+
643
729
  Returns:
644
730
  List[str]: 与任务关联的对话ID列表
645
731
  """
@@ -647,106 +733,129 @@ class DefaultHumanLoopManager(HumanLoopManager):
647
733
 
648
734
  def get_task_conversations(self, task_id: str) -> List[str]:
649
735
  """获取任务关联的所有对话ID
650
-
736
+
651
737
  Args:
652
738
  task_id: 任务ID
653
-
739
+
654
740
  Returns:
655
741
  List[str]: 与任务关联的对话ID列表
656
742
  """
657
743
  return list(self._task_conversations.get(task_id, set()))
658
-
744
+
659
745
  async def async_get_conversation_requests(self, conversation_id: str) -> List[str]:
660
746
  """获取对话关联的所有请求ID
661
-
747
+
662
748
  Args:
663
749
  conversation_id: 对话ID
664
-
750
+
665
751
  Returns:
666
752
  List[str]: 与对话关联的请求ID列表
667
753
  """
668
- return self._conversation_requests.get(conversation_id, [])
754
+ ret: List[str] = self._conversation_requests.get(conversation_id, [])
755
+ return ret
669
756
 
670
757
  def get_conversation_requests(self, conversation_id: str) -> List[str]:
671
758
  """获取对话关联的所有请求ID
672
-
759
+
673
760
  Args:
674
761
  conversation_id: 对话ID
675
-
762
+
676
763
  Returns:
677
764
  List[str]: 与对话关联的请求ID列表
678
765
  """
679
- return self._conversation_requests.get(conversation_id, [])
680
-
681
- async def async_get_request_task(self, conversation_id: str, request_id: str) -> Optional[str]:
766
+ ret: List[str] = self._conversation_requests.get(conversation_id, [])
767
+
768
+ return ret
769
+
770
+ async def async_get_request_task(
771
+ self, conversation_id: str, request_id: str
772
+ ) -> Optional[str]:
682
773
  """获取请求关联的任务ID
683
-
774
+
684
775
  Args:
685
776
  conversation_id: 对话ID
686
777
  request_id: 请求ID
687
-
778
+
688
779
  Returns:
689
780
  Optional[str]: 关联的任务ID,如果不存在则返回None
690
781
  """
691
- return self._request_task.get((conversation_id, request_id))
782
+ ret: Optional[str] = self._request_task.get((conversation_id, request_id))
783
+
784
+ return ret
692
785
 
693
- async def async_get_conversation_provider(self, conversation_id: str) -> Optional[str]:
786
+ async def async_get_conversation_provider(
787
+ self, conversation_id: str
788
+ ) -> Optional[str]:
694
789
  """获取请求关联的提供者ID
695
-
790
+
696
791
  Args:
697
792
  conversation_id: 对话ID
698
-
793
+
699
794
  Returns:
700
795
  Optional[str]: 关联的提供者ID,如果不存在则返回None
701
796
  """
702
- return self._conversation_provider.get(conversation_id)
797
+ ret: Optional[str] = self._conversation_provider.get(conversation_id)
798
+
799
+ return ret
703
800
 
704
801
  async def async_check_conversation_exist(
705
802
  self,
706
- task_id:str,
803
+ task_id: str,
707
804
  conversation_id: str,
708
805
  ) -> bool:
709
806
  """判断对话是否已存在
710
-
807
+
711
808
  Args:
712
809
  conversation_id: 对话标识符
713
810
  provider_id: 使用特定提供者的ID(可选)
714
-
811
+
715
812
  Returns:
716
813
  bool: 如果对话存在返回True,否则返回False
717
814
  """
718
815
  # 检查task_id是否存在且conversation_id是否在该task的对话集合中
719
- if task_id in self._task_conversations and conversation_id in self._task_conversations[task_id]:
816
+ if (
817
+ task_id in self._task_conversations
818
+ and conversation_id in self._task_conversations[task_id]
819
+ ):
720
820
  # 进一步验证该对话是否有关联的请求
721
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
821
+ if (
822
+ conversation_id in self._conversation_requests
823
+ and self._conversation_requests[conversation_id]
824
+ ):
722
825
  return True
723
826
 
724
827
  return False
725
828
 
726
829
  def check_conversation_exist(
727
830
  self,
728
- task_id:str,
831
+ task_id: str,
729
832
  conversation_id: str,
730
833
  ) -> bool:
731
834
  """判断对话是否已存在
732
-
835
+
733
836
  Args:
734
837
  conversation_id: 对话标识符
735
838
  provider_id: 使用特定提供者的ID(可选)
736
-
839
+
737
840
  Returns:
738
841
  bool: 如果对话存在返回True,否则返回False
739
842
  """
740
843
  # 检查task_id是否存在且conversation_id是否在该task的对话集合中
741
- if task_id in self._task_conversations and conversation_id in self._task_conversations[task_id]:
844
+ if (
845
+ task_id in self._task_conversations
846
+ and conversation_id in self._task_conversations[task_id]
847
+ ):
742
848
  # 进一步验证该对话是否有关联的请求
743
- if conversation_id in self._conversation_requests and self._conversation_requests[conversation_id]:
849
+ if (
850
+ conversation_id in self._conversation_requests
851
+ and self._conversation_requests[conversation_id]
852
+ ):
744
853
  return True
745
854
 
746
855
  return False
747
856
 
748
- async def async_shutdown(self):
749
- pass
857
+ async def async_shutdown(self) -> None:
858
+ pass
750
859
 
751
- def shutdown(self):
752
- pass
860
+ def shutdown(self) -> None:
861
+ pass