gohumanloop 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gohumanloop/__init__.py +6 -8
- gohumanloop/adapters/__init__.py +4 -4
- gohumanloop/adapters/langgraph_adapter.py +348 -207
- gohumanloop/cli/main.py +4 -1
- gohumanloop/core/interface.py +181 -215
- gohumanloop/core/manager.py +332 -265
- gohumanloop/manager/ghl_manager.py +223 -185
- gohumanloop/models/api_model.py +32 -7
- gohumanloop/models/glh_model.py +15 -11
- gohumanloop/providers/api_provider.py +233 -189
- gohumanloop/providers/base.py +179 -172
- gohumanloop/providers/email_provider.py +386 -325
- gohumanloop/providers/ghl_provider.py +19 -17
- gohumanloop/providers/terminal_provider.py +111 -92
- gohumanloop/utils/__init__.py +7 -1
- gohumanloop/utils/context_formatter.py +20 -15
- gohumanloop/utils/threadsafedict.py +64 -56
- gohumanloop/utils/utils.py +28 -28
- gohumanloop-0.0.6.dist-info/METADATA +259 -0
- gohumanloop-0.0.6.dist-info/RECORD +30 -0
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.6.dist-info}/WHEEL +1 -1
- gohumanloop-0.0.5.dist-info/METADATA +0 -35
- gohumanloop-0.0.5.dist-info/RECORD +0 -30
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.6.dist-info}/entry_points.txt +0 -0
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.6.dist-info}/top_level.txt +0 -0
gohumanloop/core/manager.py
CHANGED
@@ -1,20 +1,51 @@
|
|
1
|
-
from typing import Dict, Any, Optional, List, Union
|
1
|
+
from typing import Dict, Any, Optional, List, Union
|
2
2
|
import asyncio
|
3
|
-
import time
|
4
3
|
from gohumanloop.utils import run_async_safely
|
5
4
|
|
6
5
|
from gohumanloop.core.interface import (
|
7
|
-
HumanLoopManager,
|
8
|
-
|
6
|
+
HumanLoopManager,
|
7
|
+
HumanLoopProvider,
|
8
|
+
HumanLoopCallback,
|
9
|
+
HumanLoopResult,
|
10
|
+
HumanLoopStatus,
|
11
|
+
HumanLoopType,
|
9
12
|
)
|
10
13
|
|
14
|
+
|
11
15
|
class DefaultHumanLoopManager(HumanLoopManager):
|
12
16
|
"""默认人机循环管理器实现"""
|
13
|
-
|
14
|
-
def __init__(
|
15
|
-
self
|
17
|
+
|
18
|
+
def __init__(
|
19
|
+
self,
|
20
|
+
initial_providers: Optional[
|
21
|
+
Union[HumanLoopProvider, List[HumanLoopProvider]]
|
22
|
+
] = None,
|
23
|
+
):
|
24
|
+
self.providers: dict[str, HumanLoopProvider] = {}
|
16
25
|
self.default_provider_id = None
|
17
|
-
|
26
|
+
|
27
|
+
# 存储请求和回调的映射
|
28
|
+
self._callbacks: dict[tuple[str, str], HumanLoopCallback] = {}
|
29
|
+
# 存储请求的超时任务
|
30
|
+
self._timeout_tasks: dict[tuple[str, str], asyncio.Task] = {}
|
31
|
+
|
32
|
+
# 存储task_id与conversation_id的映射关系
|
33
|
+
self._task_conversations: dict[
|
34
|
+
str, set[str]
|
35
|
+
] = {} # task_id -> Set[conversation_id]
|
36
|
+
# 存储conversation_id与request_id的映射关系
|
37
|
+
self._conversation_requests: dict[
|
38
|
+
str, list[str]
|
39
|
+
] = {} # conversation_id -> List[request_id]
|
40
|
+
# 存储request_id与task_id的反向映射
|
41
|
+
self._request_task: dict[
|
42
|
+
tuple[str, str], str
|
43
|
+
] = {} # (conversation_id, request_id) -> task_id
|
44
|
+
# 存储对话对应的provider_id
|
45
|
+
self._conversation_provider: dict[
|
46
|
+
str, str
|
47
|
+
] = {} # conversation_id -> provider_id
|
48
|
+
|
18
49
|
# 初始化提供者
|
19
50
|
if initial_providers:
|
20
51
|
if isinstance(initial_providers, list):
|
@@ -27,41 +58,33 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
27
58
|
# 处理单个提供者
|
28
59
|
self.register_provider_sync(initial_providers, initial_providers.name)
|
29
60
|
self.default_provider_id = initial_providers.name
|
30
|
-
|
31
|
-
|
32
|
-
self
|
33
|
-
|
34
|
-
self._timeout_tasks = {}
|
35
|
-
|
36
|
-
# 存储task_id与conversation_id的映射关系
|
37
|
-
self._task_conversations = {} # task_id -> Set[conversation_id]
|
38
|
-
# 存储conversation_id与request_id的映射关系
|
39
|
-
self._conversation_requests = {} # conversation_id -> List[request_id]
|
40
|
-
# 存储request_id与task_id的反向映射
|
41
|
-
self._request_task = {} # (conversation_id, request_id) -> task_id
|
42
|
-
# 存储对话对应的provider_id
|
43
|
-
self._conversation_provider = {} # conversation_id -> provider_id
|
44
|
-
|
45
|
-
def register_provider_sync(self, provider: HumanLoopProvider, provider_id: Optional[str]) -> str:
|
61
|
+
|
62
|
+
def register_provider_sync(
|
63
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str]
|
64
|
+
) -> str:
|
46
65
|
"""同步注册提供者(用于初始化)"""
|
47
66
|
if not provider_id:
|
48
67
|
provider_id = f"provider_{len(self.providers) + 1}"
|
49
|
-
|
68
|
+
|
50
69
|
self.providers[provider_id] = provider
|
51
|
-
|
70
|
+
|
52
71
|
if not self.default_provider_id:
|
53
72
|
self.default_provider_id = provider_id
|
54
|
-
|
73
|
+
|
55
74
|
return provider_id
|
56
|
-
|
57
|
-
async def async_register_provider(
|
75
|
+
|
76
|
+
async def async_register_provider(
|
77
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str] = None
|
78
|
+
) -> str:
|
58
79
|
"""注册人机循环提供者"""
|
59
80
|
return self.register_provider_sync(provider, provider_id)
|
60
|
-
|
61
|
-
def register_provider(
|
81
|
+
|
82
|
+
def register_provider(
|
83
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str] = None
|
84
|
+
) -> str:
|
62
85
|
"""注册人机循环提供者(同步版本)"""
|
63
86
|
return self.register_provider_sync(provider, provider_id)
|
64
|
-
|
87
|
+
|
65
88
|
async def async_request_humanloop(
|
66
89
|
self,
|
67
90
|
task_id: str,
|
@@ -79,13 +102,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
79
102
|
provider_id = provider_id or self.default_provider_id
|
80
103
|
if not provider_id or provider_id not in self.providers:
|
81
104
|
raise ValueError(f"Provider '{provider_id}' not found")
|
82
|
-
|
105
|
+
|
83
106
|
# 检查对话是否已存在且使用了不同的提供者
|
84
|
-
if
|
85
|
-
|
86
|
-
|
107
|
+
if (
|
108
|
+
conversation_id in self._conversation_provider
|
109
|
+
and self._conversation_provider[conversation_id] != provider_id
|
110
|
+
):
|
111
|
+
raise ValueError(
|
112
|
+
f"Conversation '{conversation_id}' already exists with a different provider"
|
113
|
+
)
|
114
|
+
|
87
115
|
provider = self.providers[provider_id]
|
88
|
-
|
116
|
+
|
89
117
|
try:
|
90
118
|
# 发送请求
|
91
119
|
result = await provider.async_request_humanloop(
|
@@ -94,38 +122,44 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
94
122
|
loop_type=loop_type,
|
95
123
|
context=context,
|
96
124
|
metadata=metadata,
|
97
|
-
timeout=timeout
|
125
|
+
timeout=timeout,
|
98
126
|
)
|
99
|
-
|
127
|
+
|
100
128
|
request_id = result.request_id
|
101
|
-
|
129
|
+
|
102
130
|
if not request_id:
|
103
|
-
raise ValueError(
|
104
|
-
|
131
|
+
raise ValueError(
|
132
|
+
f"Failed to request humanloop for conversation '{conversation_id}'"
|
133
|
+
)
|
134
|
+
|
105
135
|
# 存储task_id、conversation_id和request_id的关系
|
106
136
|
if task_id not in self._task_conversations:
|
107
137
|
self._task_conversations[task_id] = set()
|
108
138
|
self._task_conversations[task_id].add(conversation_id)
|
109
|
-
|
139
|
+
|
110
140
|
if conversation_id not in self._conversation_requests:
|
111
141
|
self._conversation_requests[conversation_id] = []
|
112
142
|
self._conversation_requests[conversation_id].append(request_id)
|
113
|
-
|
143
|
+
|
114
144
|
self._request_task[(conversation_id, request_id)] = task_id
|
115
145
|
# 存储对话对应的provider_id
|
116
146
|
self._conversation_provider[conversation_id] = provider_id
|
117
|
-
|
147
|
+
|
118
148
|
# 如果提供了回调,存储它
|
119
149
|
if callback:
|
120
150
|
self._callbacks[(conversation_id, request_id)] = callback
|
121
|
-
|
151
|
+
|
122
152
|
# 如果设置了超时,创建超时任务
|
123
153
|
if timeout:
|
124
|
-
await self._async_create_timeout_task(
|
125
|
-
|
154
|
+
await self._async_create_timeout_task(
|
155
|
+
conversation_id, request_id, timeout, provider, callback
|
156
|
+
)
|
157
|
+
|
126
158
|
# 如果是阻塞模式,等待结果
|
127
159
|
if blocking:
|
128
|
-
return await self._async_wait_for_result(
|
160
|
+
return await self._async_wait_for_result(
|
161
|
+
conversation_id, request_id, provider, timeout
|
162
|
+
)
|
129
163
|
else:
|
130
164
|
return request_id
|
131
165
|
except Exception as e:
|
@@ -133,7 +167,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
133
167
|
if callback:
|
134
168
|
try:
|
135
169
|
await callback.async_on_humanloop_error(provider, e)
|
136
|
-
except:
|
170
|
+
except Exception:
|
137
171
|
# 如果错误回调也失败,只能忽略
|
138
172
|
pass
|
139
173
|
raise # 重新抛出异常,让调用者知道发生了错误
|
@@ -152,20 +186,21 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
152
186
|
) -> Union[str, HumanLoopResult]:
|
153
187
|
"""请求人机循环(同步版本)"""
|
154
188
|
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
)
|
189
|
+
result: Union[str, HumanLoopResult] = run_async_safely(
|
190
|
+
self.async_request_humanloop(
|
191
|
+
task_id=task_id,
|
192
|
+
conversation_id=conversation_id,
|
193
|
+
loop_type=loop_type,
|
194
|
+
context=context,
|
195
|
+
callback=callback,
|
196
|
+
metadata=metadata,
|
197
|
+
provider_id=provider_id,
|
198
|
+
timeout=timeout,
|
199
|
+
blocking=blocking,
|
167
200
|
)
|
201
|
+
)
|
168
202
|
|
203
|
+
return result
|
169
204
|
|
170
205
|
async def async_continue_humanloop(
|
171
206
|
self,
|
@@ -182,16 +217,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
182
217
|
if conversation_id in self._conversation_provider:
|
183
218
|
stored_provider_id = self._conversation_provider[conversation_id]
|
184
219
|
if provider_id and provider_id != stored_provider_id:
|
185
|
-
raise ValueError(
|
220
|
+
raise ValueError(
|
221
|
+
f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'"
|
222
|
+
)
|
186
223
|
provider_id = stored_provider_id
|
187
224
|
else:
|
188
225
|
provider_id = provider_id or self.default_provider_id
|
189
|
-
|
226
|
+
|
190
227
|
if not provider_id or provider_id not in self.providers:
|
191
228
|
raise ValueError(f"Provider '{provider_id}' not found")
|
192
|
-
|
229
|
+
|
193
230
|
provider = self.providers[provider_id]
|
194
|
-
|
231
|
+
|
195
232
|
try:
|
196
233
|
# 发送继续请求
|
197
234
|
result = await provider.async_continue_humanloop(
|
@@ -200,42 +237,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
200
237
|
metadata=metadata,
|
201
238
|
timeout=timeout,
|
202
239
|
)
|
203
|
-
|
240
|
+
|
204
241
|
request_id = result.request_id
|
205
242
|
|
206
243
|
if not request_id:
|
207
|
-
raise ValueError(
|
208
|
-
|
244
|
+
raise ValueError(
|
245
|
+
f"Failed to continue humanloop for conversation '{conversation_id}'"
|
246
|
+
)
|
247
|
+
|
209
248
|
# 更新conversation_id和request_id的关系
|
210
249
|
if conversation_id not in self._conversation_requests:
|
211
250
|
self._conversation_requests[conversation_id] = []
|
212
251
|
self._conversation_requests[conversation_id].append(request_id)
|
213
|
-
|
252
|
+
|
214
253
|
# 查找此conversation_id对应的task_id
|
215
254
|
task_id = None
|
216
255
|
for t_id, convs in self._task_conversations.items():
|
217
256
|
if conversation_id in convs:
|
218
257
|
task_id = t_id
|
219
258
|
break
|
220
|
-
|
259
|
+
|
221
260
|
if task_id:
|
222
261
|
self._request_task[(conversation_id, request_id)] = task_id
|
223
|
-
|
262
|
+
|
224
263
|
# 存储对话对应的provider_id,如果对话不存在才存储
|
225
264
|
if conversation_id not in self._conversation_provider:
|
226
265
|
self._conversation_provider[conversation_id] = provider_id
|
227
|
-
|
266
|
+
|
228
267
|
# 如果提供了回调,存储它
|
229
268
|
if callback:
|
230
269
|
self._callbacks[(conversation_id, request_id)] = callback
|
231
|
-
|
270
|
+
|
232
271
|
# 如果设置了超时,创建超时任务
|
233
272
|
if timeout:
|
234
|
-
await self._async_create_timeout_task(
|
235
|
-
|
273
|
+
await self._async_create_timeout_task(
|
274
|
+
conversation_id, request_id, timeout, provider, callback
|
275
|
+
)
|
276
|
+
|
236
277
|
# 如果是阻塞模式,等待结果
|
237
278
|
if blocking:
|
238
|
-
return await self._async_wait_for_result(
|
279
|
+
return await self._async_wait_for_result(
|
280
|
+
conversation_id, request_id, provider, timeout
|
281
|
+
)
|
239
282
|
else:
|
240
283
|
return request_id
|
241
284
|
except Exception as e:
|
@@ -243,7 +286,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
243
286
|
if callback:
|
244
287
|
try:
|
245
288
|
await callback.async_on_humanloop_error(provider, e)
|
246
|
-
except:
|
289
|
+
except Exception:
|
247
290
|
# 如果错误回调也失败,只能忽略
|
248
291
|
pass
|
249
292
|
raise # 重新抛出异常,让调用者知道发生了错误
|
@@ -260,42 +303,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
260
303
|
) -> Union[str, HumanLoopResult]:
|
261
304
|
"""继续人机循环(同步版本)"""
|
262
305
|
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
)
|
306
|
+
result: Union[str, HumanLoopResult] = run_async_safely(
|
307
|
+
self.async_continue_humanloop(
|
308
|
+
conversation_id=conversation_id,
|
309
|
+
context=context,
|
310
|
+
callback=callback,
|
311
|
+
metadata=metadata,
|
312
|
+
provider_id=provider_id,
|
313
|
+
timeout=timeout,
|
314
|
+
blocking=blocking,
|
273
315
|
)
|
316
|
+
)
|
317
|
+
|
318
|
+
return result
|
274
319
|
|
275
320
|
async def async_check_request_status(
|
276
|
-
self,
|
277
|
-
conversation_id: str,
|
278
|
-
request_id: str,
|
279
|
-
provider_id: Optional[str] = None
|
321
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
280
322
|
) -> HumanLoopResult:
|
281
323
|
"""检查请求状态"""
|
282
324
|
# 如果没有指定provider_id,尝试从存储的映射中获取
|
283
325
|
if provider_id is None:
|
284
326
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
285
327
|
provider_id = stored_provider_id or self.default_provider_id
|
286
|
-
|
328
|
+
|
287
329
|
if not provider_id or provider_id not in self.providers:
|
288
330
|
raise ValueError(f"Provider '{provider_id}' not found")
|
289
|
-
|
331
|
+
|
290
332
|
provider = self.providers[provider_id]
|
291
|
-
|
333
|
+
|
292
334
|
try:
|
293
|
-
result = await provider.async_check_request_status(
|
294
|
-
|
335
|
+
result = await provider.async_check_request_status(
|
336
|
+
conversation_id, request_id
|
337
|
+
)
|
338
|
+
|
295
339
|
# 如果有回调且状态不是等待或进行中,触发状态更新回调
|
296
|
-
if (
|
297
|
-
|
298
|
-
|
340
|
+
if (
|
341
|
+
conversation_id,
|
342
|
+
request_id,
|
343
|
+
) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
|
344
|
+
await self._async_trigger_update_callback(
|
345
|
+
conversation_id, request_id, provider, result
|
346
|
+
)
|
347
|
+
|
299
348
|
return result
|
300
349
|
except Exception as e:
|
301
350
|
# 处理检查状态过程中的异常
|
@@ -303,281 +352,263 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
303
352
|
if callback:
|
304
353
|
try:
|
305
354
|
await callback.async_on_humanloop_error(provider, e)
|
306
|
-
except:
|
355
|
+
except Exception:
|
307
356
|
# 如果错误回调也失败,只能忽略
|
308
357
|
pass
|
309
358
|
raise # 重新抛出异常,让调用者知道发生了错误
|
310
359
|
|
311
|
-
|
312
360
|
def check_request_status(
|
313
|
-
self,
|
314
|
-
conversation_id: str,
|
315
|
-
request_id: str,
|
316
|
-
provider_id: Optional[str] = None
|
361
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
317
362
|
) -> HumanLoopResult:
|
318
363
|
"""检查请求状态(同步版本)"""
|
319
364
|
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
)
|
365
|
+
result: HumanLoopResult = run_async_safely(
|
366
|
+
self.async_check_request_status(
|
367
|
+
conversation_id=conversation_id,
|
368
|
+
request_id=request_id,
|
369
|
+
provider_id=provider_id,
|
326
370
|
)
|
371
|
+
)
|
327
372
|
|
373
|
+
return result
|
328
374
|
|
329
375
|
async def async_check_conversation_status(
|
330
|
-
self,
|
331
|
-
conversation_id: str,
|
332
|
-
provider_id: Optional[str] = None
|
376
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
333
377
|
) -> HumanLoopResult:
|
334
378
|
"""检查对话状态"""
|
335
379
|
# 优先使用对话已关联的提供者
|
336
380
|
if provider_id is None:
|
337
381
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
338
382
|
provider_id = stored_provider_id or self.default_provider_id
|
339
|
-
|
383
|
+
|
340
384
|
if not provider_id or provider_id not in self.providers:
|
341
385
|
raise ValueError(f"Provider '{provider_id}' not found")
|
342
|
-
|
386
|
+
|
343
387
|
# 检查对话指定provider_id或默认provider_id最后一次请求的状态
|
344
388
|
provider = self.providers[provider_id]
|
345
|
-
|
389
|
+
|
346
390
|
try:
|
347
391
|
# 检查对话指定provider_id或默认provider_id最后一次请求的状态
|
348
392
|
return await provider.async_check_conversation_status(conversation_id)
|
349
393
|
except Exception as e:
|
350
394
|
# 处理检查对话状态过程中的异常
|
351
395
|
# 尝试找到与此对话关联的最后一个请求的回调
|
352
|
-
if
|
396
|
+
if (
|
397
|
+
conversation_id in self._conversation_requests
|
398
|
+
and self._conversation_requests[conversation_id]
|
399
|
+
):
|
353
400
|
last_request_id = self._conversation_requests[conversation_id][-1]
|
354
401
|
callback = self._callbacks.get((conversation_id, last_request_id))
|
355
402
|
if callback:
|
356
403
|
try:
|
357
404
|
await callback.async_on_humanloop_error(provider, e)
|
358
|
-
except:
|
405
|
+
except Exception:
|
359
406
|
# 如果错误回调也失败,只能忽略
|
360
407
|
pass
|
361
408
|
raise # 重新抛出异常,让调用者知道发生了错误
|
362
|
-
|
409
|
+
|
363
410
|
def check_conversation_status(
|
364
|
-
self,
|
365
|
-
conversation_id: str,
|
366
|
-
provider_id: Optional[str] = None
|
411
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
367
412
|
) -> HumanLoopResult:
|
368
413
|
"""检查对话状态(同步版本)"""
|
369
414
|
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
provider_id=provider_id
|
374
|
-
)
|
415
|
+
result: HumanLoopResult = run_async_safely(
|
416
|
+
self.async_check_conversation_status(
|
417
|
+
conversation_id=conversation_id, provider_id=provider_id
|
375
418
|
)
|
419
|
+
)
|
420
|
+
|
421
|
+
return result
|
376
422
|
|
377
423
|
async def async_cancel_request(
|
378
|
-
self,
|
379
|
-
conversation_id: str,
|
380
|
-
request_id: str,
|
381
|
-
provider_id: Optional[str] = None
|
424
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
382
425
|
) -> bool:
|
383
426
|
"""取消特定请求"""
|
384
427
|
if provider_id is None:
|
385
428
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
386
429
|
provider_id = stored_provider_id or self.default_provider_id
|
387
|
-
|
430
|
+
|
388
431
|
if not provider_id or provider_id not in self.providers:
|
389
432
|
raise ValueError(f"Provider '{provider_id}' not found")
|
390
|
-
|
433
|
+
|
391
434
|
provider = self.providers[provider_id]
|
392
435
|
|
393
436
|
# 取消超时任务
|
394
437
|
if (conversation_id, request_id) in self._timeout_tasks:
|
395
438
|
self._timeout_tasks[(conversation_id, request_id)].cancel()
|
396
439
|
del self._timeout_tasks[(conversation_id, request_id)]
|
397
|
-
|
440
|
+
|
398
441
|
# 从回调映射中删除
|
399
442
|
if (conversation_id, request_id) in self._callbacks:
|
400
443
|
del self._callbacks[(conversation_id, request_id)]
|
401
|
-
|
444
|
+
|
402
445
|
# 清理request关联
|
403
446
|
if (conversation_id, request_id) in self._request_task:
|
404
447
|
del self._request_task[(conversation_id, request_id)]
|
405
|
-
|
448
|
+
|
406
449
|
# 从conversation_requests中移除
|
407
450
|
if conversation_id in self._conversation_requests:
|
408
451
|
if request_id in self._conversation_requests[conversation_id]:
|
409
452
|
self._conversation_requests[conversation_id].remove(request_id)
|
410
|
-
|
453
|
+
|
411
454
|
return await provider.async_cancel_request(conversation_id, request_id)
|
412
|
-
|
413
455
|
|
414
456
|
def cancel_request(
|
415
|
-
self,
|
416
|
-
conversation_id: str,
|
417
|
-
request_id: str,
|
418
|
-
provider_id: Optional[str] = None
|
457
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
419
458
|
) -> bool:
|
420
459
|
"""取消特定请求(同步版本)"""
|
421
460
|
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
)
|
461
|
+
result: bool = run_async_safely(
|
462
|
+
self.async_cancel_request(
|
463
|
+
conversation_id=conversation_id,
|
464
|
+
request_id=request_id,
|
465
|
+
provider_id=provider_id,
|
428
466
|
)
|
467
|
+
)
|
468
|
+
|
469
|
+
return result
|
429
470
|
|
430
471
|
async def async_cancel_conversation(
|
431
|
-
self,
|
432
|
-
conversation_id: str,
|
433
|
-
provider_id: Optional[str] = None
|
472
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
434
473
|
) -> bool:
|
435
474
|
"""取消整个对话"""
|
436
475
|
# 优先使用对话已关联的提供者
|
437
476
|
if provider_id is None:
|
438
477
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
439
478
|
provider_id = stored_provider_id or self.default_provider_id
|
440
|
-
|
479
|
+
|
441
480
|
if not provider_id or provider_id not in self.providers:
|
442
481
|
raise ValueError(f"Provider '{provider_id}' not found")
|
443
|
-
|
482
|
+
|
444
483
|
provider = self.providers[provider_id]
|
445
|
-
|
484
|
+
|
446
485
|
# 取消与此对话相关的所有超时任务和回调
|
447
486
|
keys_to_remove = []
|
448
487
|
for key in self._timeout_tasks:
|
449
488
|
if key[0] == conversation_id:
|
450
489
|
self._timeout_tasks[key].cancel()
|
451
490
|
keys_to_remove.append(key)
|
452
|
-
|
491
|
+
|
453
492
|
for key in keys_to_remove:
|
454
493
|
del self._timeout_tasks[key]
|
455
|
-
|
494
|
+
|
456
495
|
keys_to_remove = []
|
457
496
|
for key in self._callbacks:
|
458
497
|
if key[0] == conversation_id:
|
459
498
|
keys_to_remove.append(key)
|
460
|
-
|
499
|
+
|
461
500
|
for key in keys_to_remove:
|
462
501
|
del self._callbacks[key]
|
463
|
-
|
502
|
+
|
464
503
|
# 清理与此对话相关的task映射关系
|
465
504
|
# 1. 从task_conversations中移除此对话
|
466
505
|
task_ids_to_update = []
|
467
506
|
for task_id, convs in self._task_conversations.items():
|
468
507
|
if conversation_id in convs:
|
469
508
|
task_ids_to_update.append(task_id)
|
470
|
-
|
509
|
+
|
471
510
|
for task_id in task_ids_to_update:
|
472
511
|
self._task_conversations[task_id].remove(conversation_id)
|
473
512
|
# 如果task没有关联的对话了,可以考虑删除该task记录
|
474
513
|
if not self._task_conversations[task_id]:
|
475
514
|
del self._task_conversations[task_id]
|
476
|
-
|
515
|
+
|
477
516
|
# 2. 获取并清理所有与此对话相关的请求
|
478
517
|
request_ids = self._conversation_requests.get(conversation_id, [])
|
479
518
|
for request_id in request_ids:
|
480
519
|
# 清理request_task映射
|
481
520
|
if (conversation_id, request_id) in self._request_task:
|
482
521
|
del self._request_task[(conversation_id, request_id)]
|
483
|
-
|
522
|
+
|
484
523
|
# 3. 清理conversation_requests映射
|
485
524
|
if conversation_id in self._conversation_requests:
|
486
525
|
del self._conversation_requests[conversation_id]
|
487
|
-
|
526
|
+
|
488
527
|
# 4. 清理provider关联
|
489
528
|
if conversation_id in self._conversation_provider:
|
490
529
|
del self._conversation_provider[conversation_id]
|
491
|
-
|
530
|
+
|
492
531
|
return await provider.async_cancel_conversation(conversation_id)
|
493
|
-
|
494
532
|
|
495
533
|
def cancel_conversation(
|
496
|
-
self,
|
497
|
-
conversation_id: str,
|
498
|
-
provider_id: Optional[str] = None
|
534
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
499
535
|
) -> bool:
|
500
536
|
"""取消整个对话(同步版本)"""
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
provider_id=provider_id
|
506
|
-
)
|
537
|
+
|
538
|
+
result: bool = run_async_safely(
|
539
|
+
self.async_cancel_conversation(
|
540
|
+
conversation_id=conversation_id, provider_id=provider_id
|
507
541
|
)
|
542
|
+
)
|
543
|
+
|
544
|
+
return result
|
508
545
|
|
509
546
|
async def async_get_provider(
|
510
|
-
self,
|
511
|
-
provider_id: Optional[str] = None
|
547
|
+
self, provider_id: Optional[str] = None
|
512
548
|
) -> HumanLoopProvider:
|
513
549
|
"""获取指定的提供者实例"""
|
514
550
|
provider_id = provider_id or self.default_provider_id
|
515
551
|
if not provider_id or provider_id not in self.providers:
|
516
552
|
raise ValueError(f"Provider '{provider_id}' not found")
|
517
|
-
|
553
|
+
|
518
554
|
return self.providers[provider_id]
|
519
|
-
|
520
|
-
def get_provider(
|
521
|
-
self,
|
522
|
-
provider_id: Optional[str] = None
|
523
|
-
) -> HumanLoopProvider:
|
555
|
+
|
556
|
+
def get_provider(self, provider_id: Optional[str] = None) -> HumanLoopProvider:
|
524
557
|
"""获取指定的提供者实例(同步版本)"""
|
525
558
|
|
526
|
-
|
527
|
-
|
528
|
-
|
559
|
+
result: HumanLoopProvider = run_async_safely(
|
560
|
+
self.async_get_provider(provider_id=provider_id)
|
561
|
+
)
|
562
|
+
|
563
|
+
return result
|
529
564
|
|
530
565
|
async def async_list_providers(self) -> Dict[str, HumanLoopProvider]:
|
531
566
|
"""列出所有注册的提供者"""
|
532
567
|
return self.providers
|
533
|
-
|
534
568
|
|
535
569
|
def list_providers(self) -> Dict[str, HumanLoopProvider]:
|
536
570
|
"""列出所有注册的提供者(同步版本)"""
|
537
|
-
|
538
|
-
return run_async_safely(
|
539
|
-
self.async_list_providers()
|
540
|
-
)
|
541
571
|
|
572
|
+
result: Dict[str, HumanLoopProvider] = run_async_safely(
|
573
|
+
self.async_list_providers()
|
574
|
+
)
|
542
575
|
|
576
|
+
return result
|
543
577
|
|
544
|
-
async def async_set_default_provider(
|
545
|
-
self,
|
546
|
-
provider_id: str
|
547
|
-
) -> bool:
|
578
|
+
async def async_set_default_provider(self, provider_id: str) -> bool:
|
548
579
|
"""设置默认提供者"""
|
549
580
|
if provider_id not in self.providers:
|
550
581
|
raise ValueError(f"Provider '{provider_id}' not found")
|
551
|
-
|
582
|
+
|
552
583
|
self.default_provider_id = provider_id
|
553
584
|
return True
|
554
|
-
|
555
585
|
|
556
|
-
def set_default_provider(
|
557
|
-
self,
|
558
|
-
provider_id: str
|
559
|
-
) -> bool:
|
586
|
+
def set_default_provider(self, provider_id: str) -> bool:
|
560
587
|
"""设置默认提供者(同步版本)"""
|
561
588
|
|
589
|
+
result: bool = run_async_safely(
|
590
|
+
self.async_set_default_provider(provider_id=provider_id)
|
591
|
+
)
|
562
592
|
|
563
|
-
return
|
564
|
-
self.async_set_default_provider(provider_id=provider_id)
|
565
|
-
)
|
593
|
+
return result
|
566
594
|
|
567
595
|
async def _async_create_timeout_task(
|
568
|
-
self,
|
596
|
+
self,
|
569
597
|
conversation_id: str,
|
570
|
-
request_id: str,
|
571
|
-
timeout: int,
|
598
|
+
request_id: str,
|
599
|
+
timeout: int,
|
572
600
|
provider: HumanLoopProvider,
|
573
|
-
callback: Optional[HumanLoopCallback]
|
574
|
-
):
|
601
|
+
callback: Optional[HumanLoopCallback],
|
602
|
+
) -> None:
|
575
603
|
"""创建超时任务"""
|
576
|
-
|
604
|
+
|
605
|
+
async def timeout_task() -> None:
|
577
606
|
await asyncio.sleep(timeout)
|
578
607
|
# 检查当前状态
|
579
|
-
result = await self.async_check_request_status(
|
580
|
-
|
608
|
+
result = await self.async_check_request_status(
|
609
|
+
conversation_id, request_id, provider.name
|
610
|
+
)
|
611
|
+
|
581
612
|
# 只有当状态为PENDING时才触发超时回调
|
582
613
|
# INPROGRESS状态表示对话正在进行中,不应视为超时
|
583
614
|
if result.status == HumanLoopStatus.PENDING:
|
@@ -591,55 +622,68 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
591
622
|
self._timeout_tasks[(conversation_id, request_id)].cancel()
|
592
623
|
new_task = asyncio.create_task(timeout_task())
|
593
624
|
self._timeout_tasks[(conversation_id, request_id)] = new_task
|
594
|
-
|
625
|
+
|
595
626
|
task = asyncio.create_task(timeout_task())
|
596
627
|
self._timeout_tasks[(conversation_id, request_id)] = task
|
597
|
-
|
628
|
+
|
598
629
|
async def _async_wait_for_result(
|
599
|
-
self,
|
630
|
+
self,
|
600
631
|
conversation_id: str,
|
601
|
-
request_id: str,
|
602
|
-
provider: HumanLoopProvider,
|
603
|
-
timeout: Optional[int] = None
|
632
|
+
request_id: str,
|
633
|
+
provider: HumanLoopProvider,
|
634
|
+
timeout: Optional[int] = None,
|
604
635
|
) -> HumanLoopResult:
|
605
636
|
"""等待循环结果"""
|
606
|
-
start_time = time.time()
|
607
637
|
poll_interval = 1.0 # 轮询间隔(秒)
|
608
|
-
|
638
|
+
|
609
639
|
while True:
|
610
|
-
result = await self.async_check_request_status(
|
611
|
-
|
612
|
-
|
640
|
+
result = await self.async_check_request_status(
|
641
|
+
conversation_id, request_id, provider.name
|
642
|
+
)
|
643
|
+
|
644
|
+
# 如果状态是最终状态(非PENDING),返回结果
|
613
645
|
if result.status != HumanLoopStatus.PENDING:
|
614
646
|
return result
|
615
|
-
|
647
|
+
|
616
648
|
# 等待一段时间后再次轮询
|
617
649
|
await asyncio.sleep(poll_interval)
|
618
|
-
|
619
|
-
async def _async_trigger_update_callback(
|
650
|
+
|
651
|
+
async def _async_trigger_update_callback(
|
652
|
+
self,
|
653
|
+
conversation_id: str,
|
654
|
+
request_id: str,
|
655
|
+
provider: HumanLoopProvider,
|
656
|
+
result: HumanLoopResult,
|
657
|
+
) -> None:
|
620
658
|
"""触发状态更新回调"""
|
621
|
-
callback: Optional[HumanLoopCallback] = self._callbacks.get(
|
659
|
+
callback: Optional[HumanLoopCallback] = self._callbacks.get(
|
660
|
+
(conversation_id, request_id)
|
661
|
+
)
|
622
662
|
if callback:
|
623
663
|
try:
|
624
|
-
await callback.
|
664
|
+
await callback.async_on_humanloop_update(provider, result)
|
625
665
|
# 如果状态是最终状态,可以考虑移除回调
|
626
|
-
if result.status not in [
|
666
|
+
if result.status not in [
|
667
|
+
HumanLoopStatus.PENDING,
|
668
|
+
HumanLoopStatus.INPROGRESS,
|
669
|
+
]:
|
627
670
|
del self._callbacks[(conversation_id, request_id)]
|
628
671
|
except Exception as e:
|
629
672
|
# 处理回调执行过程中的异常
|
630
673
|
try:
|
631
|
-
await callback.
|
632
|
-
except:
|
674
|
+
await callback.async_on_humanloop_error(provider, e)
|
675
|
+
except Exception:
|
633
676
|
# 如果错误回调也失败,只能忽略
|
634
677
|
pass
|
635
678
|
|
636
679
|
# 添加新方法用于获取task相关信息
|
680
|
+
|
637
681
|
async def async_get_task_conversations(self, task_id: str) -> List[str]:
|
638
682
|
"""获取任务关联的所有对话ID
|
639
|
-
|
683
|
+
|
640
684
|
Args:
|
641
685
|
task_id: 任务ID
|
642
|
-
|
686
|
+
|
643
687
|
Returns:
|
644
688
|
List[str]: 与任务关联的对话ID列表
|
645
689
|
"""
|
@@ -647,106 +691,129 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
647
691
|
|
648
692
|
def get_task_conversations(self, task_id: str) -> List[str]:
|
649
693
|
"""获取任务关联的所有对话ID
|
650
|
-
|
694
|
+
|
651
695
|
Args:
|
652
696
|
task_id: 任务ID
|
653
|
-
|
697
|
+
|
654
698
|
Returns:
|
655
699
|
List[str]: 与任务关联的对话ID列表
|
656
700
|
"""
|
657
701
|
return list(self._task_conversations.get(task_id, set()))
|
658
|
-
|
702
|
+
|
659
703
|
async def async_get_conversation_requests(self, conversation_id: str) -> List[str]:
|
660
704
|
"""获取对话关联的所有请求ID
|
661
|
-
|
705
|
+
|
662
706
|
Args:
|
663
707
|
conversation_id: 对话ID
|
664
|
-
|
708
|
+
|
665
709
|
Returns:
|
666
710
|
List[str]: 与对话关联的请求ID列表
|
667
711
|
"""
|
668
|
-
|
712
|
+
ret: List[str] = self._conversation_requests.get(conversation_id, [])
|
713
|
+
return ret
|
669
714
|
|
670
715
|
def get_conversation_requests(self, conversation_id: str) -> List[str]:
|
671
716
|
"""获取对话关联的所有请求ID
|
672
|
-
|
717
|
+
|
673
718
|
Args:
|
674
719
|
conversation_id: 对话ID
|
675
|
-
|
720
|
+
|
676
721
|
Returns:
|
677
722
|
List[str]: 与对话关联的请求ID列表
|
678
723
|
"""
|
679
|
-
|
680
|
-
|
681
|
-
|
724
|
+
ret: List[str] = self._conversation_requests.get(conversation_id, [])
|
725
|
+
|
726
|
+
return ret
|
727
|
+
|
728
|
+
async def async_get_request_task(
|
729
|
+
self, conversation_id: str, request_id: str
|
730
|
+
) -> Optional[str]:
|
682
731
|
"""获取请求关联的任务ID
|
683
|
-
|
732
|
+
|
684
733
|
Args:
|
685
734
|
conversation_id: 对话ID
|
686
735
|
request_id: 请求ID
|
687
|
-
|
736
|
+
|
688
737
|
Returns:
|
689
738
|
Optional[str]: 关联的任务ID,如果不存在则返回None
|
690
739
|
"""
|
691
|
-
|
740
|
+
ret: Optional[str] = self._request_task.get((conversation_id, request_id))
|
741
|
+
|
742
|
+
return ret
|
692
743
|
|
693
|
-
async def async_get_conversation_provider(
|
744
|
+
async def async_get_conversation_provider(
|
745
|
+
self, conversation_id: str
|
746
|
+
) -> Optional[str]:
|
694
747
|
"""获取请求关联的提供者ID
|
695
|
-
|
748
|
+
|
696
749
|
Args:
|
697
750
|
conversation_id: 对话ID
|
698
|
-
|
751
|
+
|
699
752
|
Returns:
|
700
753
|
Optional[str]: 关联的提供者ID,如果不存在则返回None
|
701
754
|
"""
|
702
|
-
|
755
|
+
ret: Optional[str] = self._conversation_provider.get(conversation_id)
|
756
|
+
|
757
|
+
return ret
|
703
758
|
|
704
759
|
async def async_check_conversation_exist(
|
705
760
|
self,
|
706
|
-
task_id:str,
|
761
|
+
task_id: str,
|
707
762
|
conversation_id: str,
|
708
763
|
) -> bool:
|
709
764
|
"""判断对话是否已存在
|
710
|
-
|
765
|
+
|
711
766
|
Args:
|
712
767
|
conversation_id: 对话标识符
|
713
768
|
provider_id: 使用特定提供者的ID(可选)
|
714
|
-
|
769
|
+
|
715
770
|
Returns:
|
716
771
|
bool: 如果对话存在返回True,否则返回False
|
717
772
|
"""
|
718
773
|
# 检查task_id是否存在且conversation_id是否在该task的对话集合中
|
719
|
-
if
|
774
|
+
if (
|
775
|
+
task_id in self._task_conversations
|
776
|
+
and conversation_id in self._task_conversations[task_id]
|
777
|
+
):
|
720
778
|
# 进一步验证该对话是否有关联的请求
|
721
|
-
if
|
779
|
+
if (
|
780
|
+
conversation_id in self._conversation_requests
|
781
|
+
and self._conversation_requests[conversation_id]
|
782
|
+
):
|
722
783
|
return True
|
723
784
|
|
724
785
|
return False
|
725
786
|
|
726
787
|
def check_conversation_exist(
|
727
788
|
self,
|
728
|
-
task_id:str,
|
789
|
+
task_id: str,
|
729
790
|
conversation_id: str,
|
730
791
|
) -> bool:
|
731
792
|
"""判断对话是否已存在
|
732
|
-
|
793
|
+
|
733
794
|
Args:
|
734
795
|
conversation_id: 对话标识符
|
735
796
|
provider_id: 使用特定提供者的ID(可选)
|
736
|
-
|
797
|
+
|
737
798
|
Returns:
|
738
799
|
bool: 如果对话存在返回True,否则返回False
|
739
800
|
"""
|
740
801
|
# 检查task_id是否存在且conversation_id是否在该task的对话集合中
|
741
|
-
if
|
802
|
+
if (
|
803
|
+
task_id in self._task_conversations
|
804
|
+
and conversation_id in self._task_conversations[task_id]
|
805
|
+
):
|
742
806
|
# 进一步验证该对话是否有关联的请求
|
743
|
-
if
|
807
|
+
if (
|
808
|
+
conversation_id in self._conversation_requests
|
809
|
+
and self._conversation_requests[conversation_id]
|
810
|
+
):
|
744
811
|
return True
|
745
812
|
|
746
813
|
return False
|
747
814
|
|
748
|
-
async def async_shutdown(self):
|
749
|
-
|
815
|
+
async def async_shutdown(self) -> None:
|
816
|
+
pass
|
750
817
|
|
751
|
-
def shutdown(self):
|
752
|
-
|
818
|
+
def shutdown(self) -> None:
|
819
|
+
pass
|