gohumanloop 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gohumanloop/__init__.py +6 -8
- gohumanloop/adapters/__init__.py +6 -6
- gohumanloop/adapters/base_adapter.py +838 -0
- gohumanloop/adapters/langgraph_adapter.py +140 -632
- gohumanloop/cli/main.py +4 -1
- gohumanloop/core/interface.py +194 -217
- gohumanloop/core/manager.py +375 -266
- gohumanloop/manager/ghl_manager.py +223 -185
- gohumanloop/models/api_model.py +32 -7
- gohumanloop/models/glh_model.py +15 -11
- gohumanloop/providers/api_provider.py +233 -189
- gohumanloop/providers/base.py +179 -172
- gohumanloop/providers/email_provider.py +386 -325
- gohumanloop/providers/ghl_provider.py +19 -17
- gohumanloop/providers/terminal_provider.py +111 -92
- gohumanloop/utils/__init__.py +7 -1
- gohumanloop/utils/context_formatter.py +20 -15
- gohumanloop/utils/threadsafedict.py +64 -56
- gohumanloop/utils/utils.py +28 -28
- gohumanloop-0.0.7.dist-info/METADATA +298 -0
- gohumanloop-0.0.7.dist-info/RECORD +31 -0
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.7.dist-info}/WHEEL +1 -1
- gohumanloop-0.0.5.dist-info/METADATA +0 -35
- gohumanloop-0.0.5.dist-info/RECORD +0 -30
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.7.dist-info}/entry_points.txt +0 -0
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {gohumanloop-0.0.5.dist-info → gohumanloop-0.0.7.dist-info}/top_level.txt +0 -0
gohumanloop/core/manager.py
CHANGED
@@ -1,20 +1,53 @@
|
|
1
|
-
from typing import Dict, Any, Optional, List, Union
|
1
|
+
from typing import Dict, Any, Optional, List, Union
|
2
2
|
import asyncio
|
3
|
-
import
|
3
|
+
from datetime import datetime
|
4
4
|
from gohumanloop.utils import run_async_safely
|
5
5
|
|
6
6
|
from gohumanloop.core.interface import (
|
7
|
-
|
8
|
-
|
7
|
+
HumanLoopRequest,
|
8
|
+
HumanLoopManager,
|
9
|
+
HumanLoopProvider,
|
10
|
+
HumanLoopCallback,
|
11
|
+
HumanLoopResult,
|
12
|
+
HumanLoopStatus,
|
13
|
+
HumanLoopType,
|
9
14
|
)
|
10
15
|
|
16
|
+
|
11
17
|
class DefaultHumanLoopManager(HumanLoopManager):
|
12
18
|
"""默认人机循环管理器实现"""
|
13
|
-
|
14
|
-
def __init__(
|
15
|
-
self
|
19
|
+
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
initial_providers: Optional[
|
23
|
+
Union[HumanLoopProvider, List[HumanLoopProvider]]
|
24
|
+
] = None,
|
25
|
+
):
|
26
|
+
self.providers: dict[str, HumanLoopProvider] = {}
|
16
27
|
self.default_provider_id = None
|
17
|
-
|
28
|
+
|
29
|
+
# 存储请求和回调的映射
|
30
|
+
self._callbacks: dict[tuple[str, str], HumanLoopCallback] = {}
|
31
|
+
# 存储请求的超时任务
|
32
|
+
self._timeout_tasks: dict[tuple[str, str], asyncio.Task] = {}
|
33
|
+
|
34
|
+
# 存储task_id与conversation_id的映射关系
|
35
|
+
self._task_conversations: dict[
|
36
|
+
str, set[str]
|
37
|
+
] = {} # task_id -> Set[conversation_id]
|
38
|
+
# 存储conversation_id与request_id的映射关系
|
39
|
+
self._conversation_requests: dict[
|
40
|
+
str, list[str]
|
41
|
+
] = {} # conversation_id -> List[request_id]
|
42
|
+
# 存储request_id与task_id的反向映射
|
43
|
+
self._request_task: dict[
|
44
|
+
tuple[str, str], str
|
45
|
+
] = {} # (conversation_id, request_id) -> task_id
|
46
|
+
# 存储对话对应的provider_id
|
47
|
+
self._conversation_provider: dict[
|
48
|
+
str, str
|
49
|
+
] = {} # conversation_id -> provider_id
|
50
|
+
|
18
51
|
# 初始化提供者
|
19
52
|
if initial_providers:
|
20
53
|
if isinstance(initial_providers, list):
|
@@ -27,41 +60,33 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
27
60
|
# 处理单个提供者
|
28
61
|
self.register_provider_sync(initial_providers, initial_providers.name)
|
29
62
|
self.default_provider_id = initial_providers.name
|
30
|
-
|
31
|
-
|
32
|
-
self
|
33
|
-
|
34
|
-
self._timeout_tasks = {}
|
35
|
-
|
36
|
-
# 存储task_id与conversation_id的映射关系
|
37
|
-
self._task_conversations = {} # task_id -> Set[conversation_id]
|
38
|
-
# 存储conversation_id与request_id的映射关系
|
39
|
-
self._conversation_requests = {} # conversation_id -> List[request_id]
|
40
|
-
# 存储request_id与task_id的反向映射
|
41
|
-
self._request_task = {} # (conversation_id, request_id) -> task_id
|
42
|
-
# 存储对话对应的provider_id
|
43
|
-
self._conversation_provider = {} # conversation_id -> provider_id
|
44
|
-
|
45
|
-
def register_provider_sync(self, provider: HumanLoopProvider, provider_id: Optional[str]) -> str:
|
63
|
+
|
64
|
+
def register_provider_sync(
|
65
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str]
|
66
|
+
) -> str:
|
46
67
|
"""同步注册提供者(用于初始化)"""
|
47
68
|
if not provider_id:
|
48
69
|
provider_id = f"provider_{len(self.providers) + 1}"
|
49
|
-
|
70
|
+
|
50
71
|
self.providers[provider_id] = provider
|
51
|
-
|
72
|
+
|
52
73
|
if not self.default_provider_id:
|
53
74
|
self.default_provider_id = provider_id
|
54
|
-
|
75
|
+
|
55
76
|
return provider_id
|
56
|
-
|
57
|
-
async def async_register_provider(
|
77
|
+
|
78
|
+
async def async_register_provider(
|
79
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str] = None
|
80
|
+
) -> str:
|
58
81
|
"""注册人机循环提供者"""
|
59
82
|
return self.register_provider_sync(provider, provider_id)
|
60
|
-
|
61
|
-
def register_provider(
|
83
|
+
|
84
|
+
def register_provider(
|
85
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str] = None
|
86
|
+
) -> str:
|
62
87
|
"""注册人机循环提供者(同步版本)"""
|
63
88
|
return self.register_provider_sync(provider, provider_id)
|
64
|
-
|
89
|
+
|
65
90
|
async def async_request_humanloop(
|
66
91
|
self,
|
67
92
|
task_id: str,
|
@@ -79,13 +104,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
79
104
|
provider_id = provider_id or self.default_provider_id
|
80
105
|
if not provider_id or provider_id not in self.providers:
|
81
106
|
raise ValueError(f"Provider '{provider_id}' not found")
|
82
|
-
|
107
|
+
|
83
108
|
# 检查对话是否已存在且使用了不同的提供者
|
84
|
-
if
|
85
|
-
|
86
|
-
|
109
|
+
if (
|
110
|
+
conversation_id in self._conversation_provider
|
111
|
+
and self._conversation_provider[conversation_id] != provider_id
|
112
|
+
):
|
113
|
+
raise ValueError(
|
114
|
+
f"Conversation '{conversation_id}' already exists with a different provider"
|
115
|
+
)
|
116
|
+
|
87
117
|
provider = self.providers[provider_id]
|
88
|
-
|
118
|
+
|
89
119
|
try:
|
90
120
|
# 发送请求
|
91
121
|
result = await provider.async_request_humanloop(
|
@@ -94,38 +124,63 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
94
124
|
loop_type=loop_type,
|
95
125
|
context=context,
|
96
126
|
metadata=metadata,
|
97
|
-
timeout=timeout
|
127
|
+
timeout=timeout,
|
98
128
|
)
|
99
|
-
|
129
|
+
|
100
130
|
request_id = result.request_id
|
101
|
-
|
131
|
+
|
102
132
|
if not request_id:
|
103
|
-
raise ValueError(
|
104
|
-
|
133
|
+
raise ValueError(
|
134
|
+
f"Failed to request humanloop for conversation '{conversation_id}'"
|
135
|
+
)
|
136
|
+
|
105
137
|
# 存储task_id、conversation_id和request_id的关系
|
106
138
|
if task_id not in self._task_conversations:
|
107
139
|
self._task_conversations[task_id] = set()
|
108
140
|
self._task_conversations[task_id].add(conversation_id)
|
109
|
-
|
141
|
+
|
110
142
|
if conversation_id not in self._conversation_requests:
|
111
143
|
self._conversation_requests[conversation_id] = []
|
112
144
|
self._conversation_requests[conversation_id].append(request_id)
|
113
|
-
|
145
|
+
|
114
146
|
self._request_task[(conversation_id, request_id)] = task_id
|
115
147
|
# 存储对话对应的provider_id
|
116
148
|
self._conversation_provider[conversation_id] = provider_id
|
117
|
-
|
149
|
+
|
118
150
|
# 如果提供了回调,存储它
|
119
151
|
if callback:
|
152
|
+
try:
|
153
|
+
# 创建请求对象
|
154
|
+
request = HumanLoopRequest(
|
155
|
+
task_id=task_id,
|
156
|
+
conversation_id=conversation_id,
|
157
|
+
loop_type=loop_type,
|
158
|
+
context=context,
|
159
|
+
metadata=metadata or {},
|
160
|
+
timeout=timeout,
|
161
|
+
created_at=datetime.now(),
|
162
|
+
)
|
163
|
+
await callback.async_on_humanloop_request(provider, request)
|
164
|
+
except Exception as e:
|
165
|
+
# 处理回调执行过程中的异常
|
166
|
+
try:
|
167
|
+
await callback.async_on_humanloop_error(provider, e)
|
168
|
+
except Exception:
|
169
|
+
# 如果错误回调也失败,只能忽略
|
170
|
+
pass
|
120
171
|
self._callbacks[(conversation_id, request_id)] = callback
|
121
|
-
|
172
|
+
|
122
173
|
# 如果设置了超时,创建超时任务
|
123
174
|
if timeout:
|
124
|
-
await self._async_create_timeout_task(
|
125
|
-
|
175
|
+
await self._async_create_timeout_task(
|
176
|
+
conversation_id, request_id, timeout, provider, callback
|
177
|
+
)
|
178
|
+
|
126
179
|
# 如果是阻塞模式,等待结果
|
127
180
|
if blocking:
|
128
|
-
return await self._async_wait_for_result(
|
181
|
+
return await self._async_wait_for_result(
|
182
|
+
conversation_id, request_id, provider, timeout
|
183
|
+
)
|
129
184
|
else:
|
130
185
|
return request_id
|
131
186
|
except Exception as e:
|
@@ -133,7 +188,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
133
188
|
if callback:
|
134
189
|
try:
|
135
190
|
await callback.async_on_humanloop_error(provider, e)
|
136
|
-
except:
|
191
|
+
except Exception:
|
137
192
|
# 如果错误回调也失败,只能忽略
|
138
193
|
pass
|
139
194
|
raise # 重新抛出异常,让调用者知道发生了错误
|
@@ -152,20 +207,21 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
152
207
|
) -> Union[str, HumanLoopResult]:
|
153
208
|
"""请求人机循环(同步版本)"""
|
154
209
|
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
)
|
210
|
+
result: Union[str, HumanLoopResult] = run_async_safely(
|
211
|
+
self.async_request_humanloop(
|
212
|
+
task_id=task_id,
|
213
|
+
conversation_id=conversation_id,
|
214
|
+
loop_type=loop_type,
|
215
|
+
context=context,
|
216
|
+
callback=callback,
|
217
|
+
metadata=metadata,
|
218
|
+
provider_id=provider_id,
|
219
|
+
timeout=timeout,
|
220
|
+
blocking=blocking,
|
167
221
|
)
|
222
|
+
)
|
168
223
|
|
224
|
+
return result
|
169
225
|
|
170
226
|
async def async_continue_humanloop(
|
171
227
|
self,
|
@@ -182,16 +238,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
182
238
|
if conversation_id in self._conversation_provider:
|
183
239
|
stored_provider_id = self._conversation_provider[conversation_id]
|
184
240
|
if provider_id and provider_id != stored_provider_id:
|
185
|
-
raise ValueError(
|
241
|
+
raise ValueError(
|
242
|
+
f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'"
|
243
|
+
)
|
186
244
|
provider_id = stored_provider_id
|
187
245
|
else:
|
188
246
|
provider_id = provider_id or self.default_provider_id
|
189
|
-
|
247
|
+
|
190
248
|
if not provider_id or provider_id not in self.providers:
|
191
249
|
raise ValueError(f"Provider '{provider_id}' not found")
|
192
|
-
|
250
|
+
|
193
251
|
provider = self.providers[provider_id]
|
194
|
-
|
252
|
+
|
195
253
|
try:
|
196
254
|
# 发送继续请求
|
197
255
|
result = await provider.async_continue_humanloop(
|
@@ -200,42 +258,67 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
200
258
|
metadata=metadata,
|
201
259
|
timeout=timeout,
|
202
260
|
)
|
203
|
-
|
261
|
+
|
204
262
|
request_id = result.request_id
|
205
263
|
|
206
264
|
if not request_id:
|
207
|
-
raise ValueError(
|
208
|
-
|
265
|
+
raise ValueError(
|
266
|
+
f"Failed to continue humanloop for conversation '{conversation_id}'"
|
267
|
+
)
|
268
|
+
|
209
269
|
# 更新conversation_id和request_id的关系
|
210
270
|
if conversation_id not in self._conversation_requests:
|
211
271
|
self._conversation_requests[conversation_id] = []
|
212
272
|
self._conversation_requests[conversation_id].append(request_id)
|
213
|
-
|
273
|
+
|
214
274
|
# 查找此conversation_id对应的task_id
|
215
275
|
task_id = None
|
216
276
|
for t_id, convs in self._task_conversations.items():
|
217
277
|
if conversation_id in convs:
|
218
278
|
task_id = t_id
|
219
279
|
break
|
220
|
-
|
280
|
+
|
221
281
|
if task_id:
|
222
282
|
self._request_task[(conversation_id, request_id)] = task_id
|
223
|
-
|
283
|
+
|
224
284
|
# 存储对话对应的provider_id,如果对话不存在才存储
|
225
285
|
if conversation_id not in self._conversation_provider:
|
226
286
|
self._conversation_provider[conversation_id] = provider_id
|
227
|
-
|
287
|
+
|
228
288
|
# 如果提供了回调,存储它
|
229
289
|
if callback:
|
290
|
+
try:
|
291
|
+
# 创建请求对象
|
292
|
+
request = HumanLoopRequest(
|
293
|
+
task_id=task_id or "",
|
294
|
+
conversation_id=conversation_id,
|
295
|
+
loop_type=HumanLoopType.CONVERSATION,
|
296
|
+
context=context,
|
297
|
+
metadata=metadata or {},
|
298
|
+
timeout=timeout,
|
299
|
+
created_at=datetime.now(),
|
300
|
+
)
|
301
|
+
await callback.async_on_humanloop_request(provider, request)
|
302
|
+
except Exception as e:
|
303
|
+
# 处理回调执行过程中的异常
|
304
|
+
try:
|
305
|
+
await callback.async_on_humanloop_error(provider, e)
|
306
|
+
except Exception:
|
307
|
+
# 如果错误回调也失败,只能忽略
|
308
|
+
pass
|
230
309
|
self._callbacks[(conversation_id, request_id)] = callback
|
231
|
-
|
310
|
+
|
232
311
|
# 如果设置了超时,创建超时任务
|
233
312
|
if timeout:
|
234
|
-
await self._async_create_timeout_task(
|
235
|
-
|
313
|
+
await self._async_create_timeout_task(
|
314
|
+
conversation_id, request_id, timeout, provider, callback
|
315
|
+
)
|
316
|
+
|
236
317
|
# 如果是阻塞模式,等待结果
|
237
318
|
if blocking:
|
238
|
-
return await self._async_wait_for_result(
|
319
|
+
return await self._async_wait_for_result(
|
320
|
+
conversation_id, request_id, provider, timeout
|
321
|
+
)
|
239
322
|
else:
|
240
323
|
return request_id
|
241
324
|
except Exception as e:
|
@@ -243,7 +326,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
243
326
|
if callback:
|
244
327
|
try:
|
245
328
|
await callback.async_on_humanloop_error(provider, e)
|
246
|
-
except:
|
329
|
+
except Exception:
|
247
330
|
# 如果错误回调也失败,只能忽略
|
248
331
|
pass
|
249
332
|
raise # 重新抛出异常,让调用者知道发生了错误
|
@@ -260,42 +343,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
260
343
|
) -> Union[str, HumanLoopResult]:
|
261
344
|
"""继续人机循环(同步版本)"""
|
262
345
|
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
)
|
346
|
+
result: Union[str, HumanLoopResult] = run_async_safely(
|
347
|
+
self.async_continue_humanloop(
|
348
|
+
conversation_id=conversation_id,
|
349
|
+
context=context,
|
350
|
+
callback=callback,
|
351
|
+
metadata=metadata,
|
352
|
+
provider_id=provider_id,
|
353
|
+
timeout=timeout,
|
354
|
+
blocking=blocking,
|
273
355
|
)
|
356
|
+
)
|
357
|
+
|
358
|
+
return result
|
274
359
|
|
275
360
|
async def async_check_request_status(
|
276
|
-
self,
|
277
|
-
conversation_id: str,
|
278
|
-
request_id: str,
|
279
|
-
provider_id: Optional[str] = None
|
361
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
280
362
|
) -> HumanLoopResult:
|
281
363
|
"""检查请求状态"""
|
282
364
|
# 如果没有指定provider_id,尝试从存储的映射中获取
|
283
365
|
if provider_id is None:
|
284
366
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
285
367
|
provider_id = stored_provider_id or self.default_provider_id
|
286
|
-
|
368
|
+
|
287
369
|
if not provider_id or provider_id not in self.providers:
|
288
370
|
raise ValueError(f"Provider '{provider_id}' not found")
|
289
|
-
|
371
|
+
|
290
372
|
provider = self.providers[provider_id]
|
291
|
-
|
373
|
+
|
292
374
|
try:
|
293
|
-
result = await provider.async_check_request_status(
|
294
|
-
|
375
|
+
result = await provider.async_check_request_status(
|
376
|
+
conversation_id, request_id
|
377
|
+
)
|
378
|
+
|
295
379
|
# 如果有回调且状态不是等待或进行中,触发状态更新回调
|
296
|
-
if (
|
297
|
-
|
298
|
-
|
380
|
+
if (
|
381
|
+
conversation_id,
|
382
|
+
request_id,
|
383
|
+
) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
|
384
|
+
await self._async_trigger_update_callback(
|
385
|
+
conversation_id, request_id, provider, result
|
386
|
+
)
|
387
|
+
|
299
388
|
return result
|
300
389
|
except Exception as e:
|
301
390
|
# 处理检查状态过程中的异常
|
@@ -303,286 +392,270 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
303
392
|
if callback:
|
304
393
|
try:
|
305
394
|
await callback.async_on_humanloop_error(provider, e)
|
306
|
-
except:
|
395
|
+
except Exception:
|
307
396
|
# 如果错误回调也失败,只能忽略
|
308
397
|
pass
|
309
398
|
raise # 重新抛出异常,让调用者知道发生了错误
|
310
399
|
|
311
|
-
|
312
400
|
def check_request_status(
|
313
|
-
self,
|
314
|
-
conversation_id: str,
|
315
|
-
request_id: str,
|
316
|
-
provider_id: Optional[str] = None
|
401
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
317
402
|
) -> HumanLoopResult:
|
318
403
|
"""检查请求状态(同步版本)"""
|
319
404
|
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
)
|
405
|
+
result: HumanLoopResult = run_async_safely(
|
406
|
+
self.async_check_request_status(
|
407
|
+
conversation_id=conversation_id,
|
408
|
+
request_id=request_id,
|
409
|
+
provider_id=provider_id,
|
326
410
|
)
|
411
|
+
)
|
327
412
|
|
413
|
+
return result
|
328
414
|
|
329
415
|
async def async_check_conversation_status(
|
330
|
-
self,
|
331
|
-
conversation_id: str,
|
332
|
-
provider_id: Optional[str] = None
|
416
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
333
417
|
) -> HumanLoopResult:
|
334
418
|
"""检查对话状态"""
|
335
419
|
# 优先使用对话已关联的提供者
|
336
420
|
if provider_id is None:
|
337
421
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
338
422
|
provider_id = stored_provider_id or self.default_provider_id
|
339
|
-
|
423
|
+
|
340
424
|
if not provider_id or provider_id not in self.providers:
|
341
425
|
raise ValueError(f"Provider '{provider_id}' not found")
|
342
|
-
|
426
|
+
|
343
427
|
# 检查对话指定provider_id或默认provider_id最后一次请求的状态
|
344
428
|
provider = self.providers[provider_id]
|
345
|
-
|
429
|
+
|
346
430
|
try:
|
347
431
|
# 检查对话指定provider_id或默认provider_id最后一次请求的状态
|
348
432
|
return await provider.async_check_conversation_status(conversation_id)
|
349
433
|
except Exception as e:
|
350
434
|
# 处理检查对话状态过程中的异常
|
351
435
|
# 尝试找到与此对话关联的最后一个请求的回调
|
352
|
-
if
|
436
|
+
if (
|
437
|
+
conversation_id in self._conversation_requests
|
438
|
+
and self._conversation_requests[conversation_id]
|
439
|
+
):
|
353
440
|
last_request_id = self._conversation_requests[conversation_id][-1]
|
354
441
|
callback = self._callbacks.get((conversation_id, last_request_id))
|
355
442
|
if callback:
|
356
443
|
try:
|
357
444
|
await callback.async_on_humanloop_error(provider, e)
|
358
|
-
except:
|
445
|
+
except Exception:
|
359
446
|
# 如果错误回调也失败,只能忽略
|
360
447
|
pass
|
361
448
|
raise # 重新抛出异常,让调用者知道发生了错误
|
362
|
-
|
449
|
+
|
363
450
|
def check_conversation_status(
|
364
|
-
self,
|
365
|
-
conversation_id: str,
|
366
|
-
provider_id: Optional[str] = None
|
451
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
367
452
|
) -> HumanLoopResult:
|
368
453
|
"""检查对话状态(同步版本)"""
|
369
454
|
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
provider_id=provider_id
|
374
|
-
)
|
455
|
+
result: HumanLoopResult = run_async_safely(
|
456
|
+
self.async_check_conversation_status(
|
457
|
+
conversation_id=conversation_id, provider_id=provider_id
|
375
458
|
)
|
459
|
+
)
|
460
|
+
|
461
|
+
return result
|
376
462
|
|
377
463
|
async def async_cancel_request(
|
378
|
-
self,
|
379
|
-
conversation_id: str,
|
380
|
-
request_id: str,
|
381
|
-
provider_id: Optional[str] = None
|
464
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
382
465
|
) -> bool:
|
383
466
|
"""取消特定请求"""
|
384
467
|
if provider_id is None:
|
385
468
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
386
469
|
provider_id = stored_provider_id or self.default_provider_id
|
387
|
-
|
470
|
+
|
388
471
|
if not provider_id or provider_id not in self.providers:
|
389
472
|
raise ValueError(f"Provider '{provider_id}' not found")
|
390
|
-
|
473
|
+
|
391
474
|
provider = self.providers[provider_id]
|
392
475
|
|
393
476
|
# 取消超时任务
|
394
477
|
if (conversation_id, request_id) in self._timeout_tasks:
|
395
478
|
self._timeout_tasks[(conversation_id, request_id)].cancel()
|
396
479
|
del self._timeout_tasks[(conversation_id, request_id)]
|
397
|
-
|
480
|
+
|
398
481
|
# 从回调映射中删除
|
399
482
|
if (conversation_id, request_id) in self._callbacks:
|
400
483
|
del self._callbacks[(conversation_id, request_id)]
|
401
|
-
|
484
|
+
|
402
485
|
# 清理request关联
|
403
486
|
if (conversation_id, request_id) in self._request_task:
|
404
487
|
del self._request_task[(conversation_id, request_id)]
|
405
|
-
|
488
|
+
|
406
489
|
# 从conversation_requests中移除
|
407
490
|
if conversation_id in self._conversation_requests:
|
408
491
|
if request_id in self._conversation_requests[conversation_id]:
|
409
492
|
self._conversation_requests[conversation_id].remove(request_id)
|
410
|
-
|
493
|
+
|
411
494
|
return await provider.async_cancel_request(conversation_id, request_id)
|
412
|
-
|
413
495
|
|
414
496
|
def cancel_request(
|
415
|
-
self,
|
416
|
-
conversation_id: str,
|
417
|
-
request_id: str,
|
418
|
-
provider_id: Optional[str] = None
|
497
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
419
498
|
) -> bool:
|
420
499
|
"""取消特定请求(同步版本)"""
|
421
500
|
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
)
|
501
|
+
result: bool = run_async_safely(
|
502
|
+
self.async_cancel_request(
|
503
|
+
conversation_id=conversation_id,
|
504
|
+
request_id=request_id,
|
505
|
+
provider_id=provider_id,
|
428
506
|
)
|
507
|
+
)
|
508
|
+
|
509
|
+
return result
|
429
510
|
|
430
511
|
async def async_cancel_conversation(
|
431
|
-
self,
|
432
|
-
conversation_id: str,
|
433
|
-
provider_id: Optional[str] = None
|
512
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
434
513
|
) -> bool:
|
435
514
|
"""取消整个对话"""
|
436
515
|
# 优先使用对话已关联的提供者
|
437
516
|
if provider_id is None:
|
438
517
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
439
518
|
provider_id = stored_provider_id or self.default_provider_id
|
440
|
-
|
519
|
+
|
441
520
|
if not provider_id or provider_id not in self.providers:
|
442
521
|
raise ValueError(f"Provider '{provider_id}' not found")
|
443
|
-
|
522
|
+
|
444
523
|
provider = self.providers[provider_id]
|
445
|
-
|
524
|
+
|
446
525
|
# 取消与此对话相关的所有超时任务和回调
|
447
526
|
keys_to_remove = []
|
448
527
|
for key in self._timeout_tasks:
|
449
528
|
if key[0] == conversation_id:
|
450
529
|
self._timeout_tasks[key].cancel()
|
451
530
|
keys_to_remove.append(key)
|
452
|
-
|
531
|
+
|
453
532
|
for key in keys_to_remove:
|
454
533
|
del self._timeout_tasks[key]
|
455
|
-
|
534
|
+
|
456
535
|
keys_to_remove = []
|
457
536
|
for key in self._callbacks:
|
458
537
|
if key[0] == conversation_id:
|
459
538
|
keys_to_remove.append(key)
|
460
|
-
|
539
|
+
|
461
540
|
for key in keys_to_remove:
|
462
541
|
del self._callbacks[key]
|
463
|
-
|
542
|
+
|
464
543
|
# 清理与此对话相关的task映射关系
|
465
544
|
# 1. 从task_conversations中移除此对话
|
466
545
|
task_ids_to_update = []
|
467
546
|
for task_id, convs in self._task_conversations.items():
|
468
547
|
if conversation_id in convs:
|
469
548
|
task_ids_to_update.append(task_id)
|
470
|
-
|
549
|
+
|
471
550
|
for task_id in task_ids_to_update:
|
472
551
|
self._task_conversations[task_id].remove(conversation_id)
|
473
552
|
# 如果task没有关联的对话了,可以考虑删除该task记录
|
474
553
|
if not self._task_conversations[task_id]:
|
475
554
|
del self._task_conversations[task_id]
|
476
|
-
|
555
|
+
|
477
556
|
# 2. 获取并清理所有与此对话相关的请求
|
478
557
|
request_ids = self._conversation_requests.get(conversation_id, [])
|
479
558
|
for request_id in request_ids:
|
480
559
|
# 清理request_task映射
|
481
560
|
if (conversation_id, request_id) in self._request_task:
|
482
561
|
del self._request_task[(conversation_id, request_id)]
|
483
|
-
|
562
|
+
|
484
563
|
# 3. 清理conversation_requests映射
|
485
564
|
if conversation_id in self._conversation_requests:
|
486
565
|
del self._conversation_requests[conversation_id]
|
487
|
-
|
566
|
+
|
488
567
|
# 4. 清理provider关联
|
489
568
|
if conversation_id in self._conversation_provider:
|
490
569
|
del self._conversation_provider[conversation_id]
|
491
|
-
|
570
|
+
|
492
571
|
return await provider.async_cancel_conversation(conversation_id)
|
493
|
-
|
494
572
|
|
495
573
|
def cancel_conversation(
|
496
|
-
self,
|
497
|
-
conversation_id: str,
|
498
|
-
provider_id: Optional[str] = None
|
574
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
499
575
|
) -> bool:
|
500
576
|
"""取消整个对话(同步版本)"""
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
provider_id=provider_id
|
506
|
-
)
|
577
|
+
|
578
|
+
result: bool = run_async_safely(
|
579
|
+
self.async_cancel_conversation(
|
580
|
+
conversation_id=conversation_id, provider_id=provider_id
|
507
581
|
)
|
582
|
+
)
|
583
|
+
|
584
|
+
return result
|
508
585
|
|
509
586
|
async def async_get_provider(
|
510
|
-
self,
|
511
|
-
provider_id: Optional[str] = None
|
587
|
+
self, provider_id: Optional[str] = None
|
512
588
|
) -> HumanLoopProvider:
|
513
589
|
"""获取指定的提供者实例"""
|
514
590
|
provider_id = provider_id or self.default_provider_id
|
515
591
|
if not provider_id or provider_id not in self.providers:
|
516
592
|
raise ValueError(f"Provider '{provider_id}' not found")
|
517
|
-
|
593
|
+
|
518
594
|
return self.providers[provider_id]
|
519
|
-
|
520
|
-
def get_provider(
|
521
|
-
self,
|
522
|
-
provider_id: Optional[str] = None
|
523
|
-
) -> HumanLoopProvider:
|
595
|
+
|
596
|
+
def get_provider(self, provider_id: Optional[str] = None) -> HumanLoopProvider:
|
524
597
|
"""获取指定的提供者实例(同步版本)"""
|
525
598
|
|
526
|
-
|
527
|
-
|
528
|
-
|
599
|
+
result: HumanLoopProvider = run_async_safely(
|
600
|
+
self.async_get_provider(provider_id=provider_id)
|
601
|
+
)
|
602
|
+
|
603
|
+
return result
|
529
604
|
|
530
605
|
async def async_list_providers(self) -> Dict[str, HumanLoopProvider]:
|
531
606
|
"""列出所有注册的提供者"""
|
532
607
|
return self.providers
|
533
|
-
|
534
608
|
|
535
609
|
def list_providers(self) -> Dict[str, HumanLoopProvider]:
|
536
610
|
"""列出所有注册的提供者(同步版本)"""
|
537
|
-
|
538
|
-
return run_async_safely(
|
539
|
-
self.async_list_providers()
|
540
|
-
)
|
541
611
|
|
612
|
+
result: Dict[str, HumanLoopProvider] = run_async_safely(
|
613
|
+
self.async_list_providers()
|
614
|
+
)
|
542
615
|
|
616
|
+
return result
|
543
617
|
|
544
|
-
async def async_set_default_provider(
|
545
|
-
self,
|
546
|
-
provider_id: str
|
547
|
-
) -> bool:
|
618
|
+
async def async_set_default_provider(self, provider_id: str) -> bool:
|
548
619
|
"""设置默认提供者"""
|
549
620
|
if provider_id not in self.providers:
|
550
621
|
raise ValueError(f"Provider '{provider_id}' not found")
|
551
|
-
|
622
|
+
|
552
623
|
self.default_provider_id = provider_id
|
553
624
|
return True
|
554
|
-
|
555
625
|
|
556
|
-
def set_default_provider(
|
557
|
-
self,
|
558
|
-
provider_id: str
|
559
|
-
) -> bool:
|
626
|
+
def set_default_provider(self, provider_id: str) -> bool:
|
560
627
|
"""设置默认提供者(同步版本)"""
|
561
628
|
|
629
|
+
result: bool = run_async_safely(
|
630
|
+
self.async_set_default_provider(provider_id=provider_id)
|
631
|
+
)
|
562
632
|
|
563
|
-
return
|
564
|
-
self.async_set_default_provider(provider_id=provider_id)
|
565
|
-
)
|
633
|
+
return result
|
566
634
|
|
567
635
|
async def _async_create_timeout_task(
|
568
|
-
self,
|
636
|
+
self,
|
569
637
|
conversation_id: str,
|
570
|
-
request_id: str,
|
571
|
-
timeout: int,
|
638
|
+
request_id: str,
|
639
|
+
timeout: int,
|
572
640
|
provider: HumanLoopProvider,
|
573
|
-
callback: Optional[HumanLoopCallback]
|
574
|
-
):
|
641
|
+
callback: Optional[HumanLoopCallback],
|
642
|
+
) -> None:
|
575
643
|
"""创建超时任务"""
|
576
|
-
|
644
|
+
|
645
|
+
async def timeout_task() -> None:
|
577
646
|
await asyncio.sleep(timeout)
|
578
647
|
# 检查当前状态
|
579
|
-
result = await self.async_check_request_status(
|
580
|
-
|
648
|
+
result = await self.async_check_request_status(
|
649
|
+
conversation_id, request_id, provider.name
|
650
|
+
)
|
651
|
+
|
581
652
|
# 只有当状态为PENDING时才触发超时回调
|
582
653
|
# INPROGRESS状态表示对话正在进行中,不应视为超时
|
583
654
|
if result.status == HumanLoopStatus.PENDING:
|
584
655
|
if callback:
|
585
|
-
await callback.async_on_humanloop_timeout(
|
656
|
+
await callback.async_on_humanloop_timeout(
|
657
|
+
provider=provider, result=result
|
658
|
+
)
|
586
659
|
# 如果状态是INPROGRESS,重置超时任务
|
587
660
|
elif result.status == HumanLoopStatus.INPROGRESS:
|
588
661
|
# 对于进行中的对话,我们可以选择延长超时时间
|
@@ -591,55 +664,68 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
591
664
|
self._timeout_tasks[(conversation_id, request_id)].cancel()
|
592
665
|
new_task = asyncio.create_task(timeout_task())
|
593
666
|
self._timeout_tasks[(conversation_id, request_id)] = new_task
|
594
|
-
|
667
|
+
|
595
668
|
task = asyncio.create_task(timeout_task())
|
596
669
|
self._timeout_tasks[(conversation_id, request_id)] = task
|
597
|
-
|
670
|
+
|
598
671
|
async def _async_wait_for_result(
|
599
|
-
self,
|
672
|
+
self,
|
600
673
|
conversation_id: str,
|
601
|
-
request_id: str,
|
602
|
-
provider: HumanLoopProvider,
|
603
|
-
timeout: Optional[int] = None
|
674
|
+
request_id: str,
|
675
|
+
provider: HumanLoopProvider,
|
676
|
+
timeout: Optional[int] = None,
|
604
677
|
) -> HumanLoopResult:
|
605
678
|
"""等待循环结果"""
|
606
|
-
start_time = time.time()
|
607
679
|
poll_interval = 1.0 # 轮询间隔(秒)
|
608
|
-
|
680
|
+
|
609
681
|
while True:
|
610
|
-
result = await self.async_check_request_status(
|
611
|
-
|
612
|
-
|
682
|
+
result = await self.async_check_request_status(
|
683
|
+
conversation_id, request_id, provider.name
|
684
|
+
)
|
685
|
+
|
686
|
+
# 如果状态是最终状态(非PENDING),返回结果
|
613
687
|
if result.status != HumanLoopStatus.PENDING:
|
614
688
|
return result
|
615
|
-
|
689
|
+
|
616
690
|
# 等待一段时间后再次轮询
|
617
691
|
await asyncio.sleep(poll_interval)
|
618
|
-
|
619
|
-
async def _async_trigger_update_callback(
|
692
|
+
|
693
|
+
async def _async_trigger_update_callback(
|
694
|
+
self,
|
695
|
+
conversation_id: str,
|
696
|
+
request_id: str,
|
697
|
+
provider: HumanLoopProvider,
|
698
|
+
result: HumanLoopResult,
|
699
|
+
) -> None:
|
620
700
|
"""触发状态更新回调"""
|
621
|
-
callback: Optional[HumanLoopCallback] = self._callbacks.get(
|
701
|
+
callback: Optional[HumanLoopCallback] = self._callbacks.get(
|
702
|
+
(conversation_id, request_id)
|
703
|
+
)
|
622
704
|
if callback:
|
623
705
|
try:
|
624
|
-
await callback.
|
706
|
+
await callback.async_on_humanloop_update(provider, result)
|
625
707
|
# 如果状态是最终状态,可以考虑移除回调
|
626
|
-
if result.status not in [
|
708
|
+
if result.status not in [
|
709
|
+
HumanLoopStatus.PENDING,
|
710
|
+
HumanLoopStatus.INPROGRESS,
|
711
|
+
]:
|
627
712
|
del self._callbacks[(conversation_id, request_id)]
|
628
713
|
except Exception as e:
|
629
714
|
# 处理回调执行过程中的异常
|
630
715
|
try:
|
631
|
-
await callback.
|
632
|
-
except:
|
716
|
+
await callback.async_on_humanloop_error(provider, e)
|
717
|
+
except Exception:
|
633
718
|
# 如果错误回调也失败,只能忽略
|
634
719
|
pass
|
635
720
|
|
636
721
|
# 添加新方法用于获取task相关信息
|
722
|
+
|
637
723
|
async def async_get_task_conversations(self, task_id: str) -> List[str]:
|
638
724
|
"""获取任务关联的所有对话ID
|
639
|
-
|
725
|
+
|
640
726
|
Args:
|
641
727
|
task_id: 任务ID
|
642
|
-
|
728
|
+
|
643
729
|
Returns:
|
644
730
|
List[str]: 与任务关联的对话ID列表
|
645
731
|
"""
|
@@ -647,106 +733,129 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
647
733
|
|
648
734
|
def get_task_conversations(self, task_id: str) -> List[str]:
|
649
735
|
"""获取任务关联的所有对话ID
|
650
|
-
|
736
|
+
|
651
737
|
Args:
|
652
738
|
task_id: 任务ID
|
653
|
-
|
739
|
+
|
654
740
|
Returns:
|
655
741
|
List[str]: 与任务关联的对话ID列表
|
656
742
|
"""
|
657
743
|
return list(self._task_conversations.get(task_id, set()))
|
658
|
-
|
744
|
+
|
659
745
|
async def async_get_conversation_requests(self, conversation_id: str) -> List[str]:
|
660
746
|
"""获取对话关联的所有请求ID
|
661
|
-
|
747
|
+
|
662
748
|
Args:
|
663
749
|
conversation_id: 对话ID
|
664
|
-
|
750
|
+
|
665
751
|
Returns:
|
666
752
|
List[str]: 与对话关联的请求ID列表
|
667
753
|
"""
|
668
|
-
|
754
|
+
ret: List[str] = self._conversation_requests.get(conversation_id, [])
|
755
|
+
return ret
|
669
756
|
|
670
757
|
def get_conversation_requests(self, conversation_id: str) -> List[str]:
|
671
758
|
"""获取对话关联的所有请求ID
|
672
|
-
|
759
|
+
|
673
760
|
Args:
|
674
761
|
conversation_id: 对话ID
|
675
|
-
|
762
|
+
|
676
763
|
Returns:
|
677
764
|
List[str]: 与对话关联的请求ID列表
|
678
765
|
"""
|
679
|
-
|
680
|
-
|
681
|
-
|
766
|
+
ret: List[str] = self._conversation_requests.get(conversation_id, [])
|
767
|
+
|
768
|
+
return ret
|
769
|
+
|
770
|
+
async def async_get_request_task(
|
771
|
+
self, conversation_id: str, request_id: str
|
772
|
+
) -> Optional[str]:
|
682
773
|
"""获取请求关联的任务ID
|
683
|
-
|
774
|
+
|
684
775
|
Args:
|
685
776
|
conversation_id: 对话ID
|
686
777
|
request_id: 请求ID
|
687
|
-
|
778
|
+
|
688
779
|
Returns:
|
689
780
|
Optional[str]: 关联的任务ID,如果不存在则返回None
|
690
781
|
"""
|
691
|
-
|
782
|
+
ret: Optional[str] = self._request_task.get((conversation_id, request_id))
|
783
|
+
|
784
|
+
return ret
|
692
785
|
|
693
|
-
async def async_get_conversation_provider(
|
786
|
+
async def async_get_conversation_provider(
|
787
|
+
self, conversation_id: str
|
788
|
+
) -> Optional[str]:
|
694
789
|
"""获取请求关联的提供者ID
|
695
|
-
|
790
|
+
|
696
791
|
Args:
|
697
792
|
conversation_id: 对话ID
|
698
|
-
|
793
|
+
|
699
794
|
Returns:
|
700
795
|
Optional[str]: 关联的提供者ID,如果不存在则返回None
|
701
796
|
"""
|
702
|
-
|
797
|
+
ret: Optional[str] = self._conversation_provider.get(conversation_id)
|
798
|
+
|
799
|
+
return ret
|
703
800
|
|
704
801
|
async def async_check_conversation_exist(
|
705
802
|
self,
|
706
|
-
task_id:str,
|
803
|
+
task_id: str,
|
707
804
|
conversation_id: str,
|
708
805
|
) -> bool:
|
709
806
|
"""判断对话是否已存在
|
710
|
-
|
807
|
+
|
711
808
|
Args:
|
712
809
|
conversation_id: 对话标识符
|
713
810
|
provider_id: 使用特定提供者的ID(可选)
|
714
|
-
|
811
|
+
|
715
812
|
Returns:
|
716
813
|
bool: 如果对话存在返回True,否则返回False
|
717
814
|
"""
|
718
815
|
# 检查task_id是否存在且conversation_id是否在该task的对话集合中
|
719
|
-
if
|
816
|
+
if (
|
817
|
+
task_id in self._task_conversations
|
818
|
+
and conversation_id in self._task_conversations[task_id]
|
819
|
+
):
|
720
820
|
# 进一步验证该对话是否有关联的请求
|
721
|
-
if
|
821
|
+
if (
|
822
|
+
conversation_id in self._conversation_requests
|
823
|
+
and self._conversation_requests[conversation_id]
|
824
|
+
):
|
722
825
|
return True
|
723
826
|
|
724
827
|
return False
|
725
828
|
|
726
829
|
def check_conversation_exist(
|
727
830
|
self,
|
728
|
-
task_id:str,
|
831
|
+
task_id: str,
|
729
832
|
conversation_id: str,
|
730
833
|
) -> bool:
|
731
834
|
"""判断对话是否已存在
|
732
|
-
|
835
|
+
|
733
836
|
Args:
|
734
837
|
conversation_id: 对话标识符
|
735
838
|
provider_id: 使用特定提供者的ID(可选)
|
736
|
-
|
839
|
+
|
737
840
|
Returns:
|
738
841
|
bool: 如果对话存在返回True,否则返回False
|
739
842
|
"""
|
740
843
|
# 检查task_id是否存在且conversation_id是否在该task的对话集合中
|
741
|
-
if
|
844
|
+
if (
|
845
|
+
task_id in self._task_conversations
|
846
|
+
and conversation_id in self._task_conversations[task_id]
|
847
|
+
):
|
742
848
|
# 进一步验证该对话是否有关联的请求
|
743
|
-
if
|
849
|
+
if (
|
850
|
+
conversation_id in self._conversation_requests
|
851
|
+
and self._conversation_requests[conversation_id]
|
852
|
+
):
|
744
853
|
return True
|
745
854
|
|
746
855
|
return False
|
747
856
|
|
748
|
-
async def async_shutdown(self):
|
749
|
-
|
857
|
+
async def async_shutdown(self) -> None:
|
858
|
+
pass
|
750
859
|
|
751
|
-
def shutdown(self):
|
752
|
-
|
860
|
+
def shutdown(self) -> None:
|
861
|
+
pass
|