gohumanloop 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gohumanloop/__init__.py +15 -9
- gohumanloop/adapters/__init__.py +4 -4
- gohumanloop/adapters/langgraph_adapter.py +365 -220
- gohumanloop/cli/main.py +4 -1
- gohumanloop/core/interface.py +181 -215
- gohumanloop/core/manager.py +341 -361
- gohumanloop/manager/ghl_manager.py +223 -185
- gohumanloop/models/api_model.py +32 -7
- gohumanloop/models/glh_model.py +15 -11
- gohumanloop/providers/api_provider.py +233 -189
- gohumanloop/providers/base.py +179 -172
- gohumanloop/providers/email_provider.py +386 -325
- gohumanloop/providers/ghl_provider.py +19 -17
- gohumanloop/providers/terminal_provider.py +111 -92
- gohumanloop/utils/__init__.py +7 -1
- gohumanloop/utils/context_formatter.py +20 -15
- gohumanloop/utils/threadsafedict.py +64 -56
- gohumanloop/utils/utils.py +28 -28
- gohumanloop-0.0.6.dist-info/METADATA +259 -0
- gohumanloop-0.0.6.dist-info/RECORD +30 -0
- {gohumanloop-0.0.4.dist-info → gohumanloop-0.0.6.dist-info}/WHEEL +1 -1
- gohumanloop-0.0.4.dist-info/METADATA +0 -35
- gohumanloop-0.0.4.dist-info/RECORD +0 -30
- {gohumanloop-0.0.4.dist-info → gohumanloop-0.0.6.dist-info}/entry_points.txt +0 -0
- {gohumanloop-0.0.4.dist-info → gohumanloop-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {gohumanloop-0.0.4.dist-info → gohumanloop-0.0.6.dist-info}/top_level.txt +0 -0
gohumanloop/core/manager.py
CHANGED
@@ -1,19 +1,51 @@
|
|
1
|
-
from typing import Dict, Any, Optional, List, Union
|
1
|
+
from typing import Dict, Any, Optional, List, Union
|
2
2
|
import asyncio
|
3
|
-
import
|
3
|
+
from gohumanloop.utils import run_async_safely
|
4
4
|
|
5
5
|
from gohumanloop.core.interface import (
|
6
|
-
HumanLoopManager,
|
7
|
-
|
6
|
+
HumanLoopManager,
|
7
|
+
HumanLoopProvider,
|
8
|
+
HumanLoopCallback,
|
9
|
+
HumanLoopResult,
|
10
|
+
HumanLoopStatus,
|
11
|
+
HumanLoopType,
|
8
12
|
)
|
9
13
|
|
14
|
+
|
10
15
|
class DefaultHumanLoopManager(HumanLoopManager):
|
11
16
|
"""默认人机循环管理器实现"""
|
12
|
-
|
13
|
-
def __init__(
|
14
|
-
self
|
17
|
+
|
18
|
+
def __init__(
|
19
|
+
self,
|
20
|
+
initial_providers: Optional[
|
21
|
+
Union[HumanLoopProvider, List[HumanLoopProvider]]
|
22
|
+
] = None,
|
23
|
+
):
|
24
|
+
self.providers: dict[str, HumanLoopProvider] = {}
|
15
25
|
self.default_provider_id = None
|
16
|
-
|
26
|
+
|
27
|
+
# 存储请求和回调的映射
|
28
|
+
self._callbacks: dict[tuple[str, str], HumanLoopCallback] = {}
|
29
|
+
# 存储请求的超时任务
|
30
|
+
self._timeout_tasks: dict[tuple[str, str], asyncio.Task] = {}
|
31
|
+
|
32
|
+
# 存储task_id与conversation_id的映射关系
|
33
|
+
self._task_conversations: dict[
|
34
|
+
str, set[str]
|
35
|
+
] = {} # task_id -> Set[conversation_id]
|
36
|
+
# 存储conversation_id与request_id的映射关系
|
37
|
+
self._conversation_requests: dict[
|
38
|
+
str, list[str]
|
39
|
+
] = {} # conversation_id -> List[request_id]
|
40
|
+
# 存储request_id与task_id的反向映射
|
41
|
+
self._request_task: dict[
|
42
|
+
tuple[str, str], str
|
43
|
+
] = {} # (conversation_id, request_id) -> task_id
|
44
|
+
# 存储对话对应的provider_id
|
45
|
+
self._conversation_provider: dict[
|
46
|
+
str, str
|
47
|
+
] = {} # conversation_id -> provider_id
|
48
|
+
|
17
49
|
# 初始化提供者
|
18
50
|
if initial_providers:
|
19
51
|
if isinstance(initial_providers, list):
|
@@ -26,41 +58,33 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
26
58
|
# 处理单个提供者
|
27
59
|
self.register_provider_sync(initial_providers, initial_providers.name)
|
28
60
|
self.default_provider_id = initial_providers.name
|
29
|
-
|
30
|
-
|
31
|
-
self
|
32
|
-
|
33
|
-
self._timeout_tasks = {}
|
34
|
-
|
35
|
-
# 存储task_id与conversation_id的映射关系
|
36
|
-
self._task_conversations = {} # task_id -> Set[conversation_id]
|
37
|
-
# 存储conversation_id与request_id的映射关系
|
38
|
-
self._conversation_requests = {} # conversation_id -> List[request_id]
|
39
|
-
# 存储request_id与task_id的反向映射
|
40
|
-
self._request_task = {} # (conversation_id, request_id) -> task_id
|
41
|
-
# 存储对话对应的provider_id
|
42
|
-
self._conversation_provider = {} # conversation_id -> provider_id
|
43
|
-
|
44
|
-
def register_provider_sync(self, provider: HumanLoopProvider, provider_id: Optional[str]) -> str:
|
61
|
+
|
62
|
+
def register_provider_sync(
|
63
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str]
|
64
|
+
) -> str:
|
45
65
|
"""同步注册提供者(用于初始化)"""
|
46
66
|
if not provider_id:
|
47
67
|
provider_id = f"provider_{len(self.providers) + 1}"
|
48
|
-
|
68
|
+
|
49
69
|
self.providers[provider_id] = provider
|
50
|
-
|
70
|
+
|
51
71
|
if not self.default_provider_id:
|
52
72
|
self.default_provider_id = provider_id
|
53
|
-
|
73
|
+
|
54
74
|
return provider_id
|
55
|
-
|
56
|
-
async def async_register_provider(
|
75
|
+
|
76
|
+
async def async_register_provider(
|
77
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str] = None
|
78
|
+
) -> str:
|
57
79
|
"""注册人机循环提供者"""
|
58
80
|
return self.register_provider_sync(provider, provider_id)
|
59
|
-
|
60
|
-
def register_provider(
|
81
|
+
|
82
|
+
def register_provider(
|
83
|
+
self, provider: HumanLoopProvider, provider_id: Optional[str] = None
|
84
|
+
) -> str:
|
61
85
|
"""注册人机循环提供者(同步版本)"""
|
62
86
|
return self.register_provider_sync(provider, provider_id)
|
63
|
-
|
87
|
+
|
64
88
|
async def async_request_humanloop(
|
65
89
|
self,
|
66
90
|
task_id: str,
|
@@ -78,13 +102,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
78
102
|
provider_id = provider_id or self.default_provider_id
|
79
103
|
if not provider_id or provider_id not in self.providers:
|
80
104
|
raise ValueError(f"Provider '{provider_id}' not found")
|
81
|
-
|
105
|
+
|
82
106
|
# 检查对话是否已存在且使用了不同的提供者
|
83
|
-
if
|
84
|
-
|
85
|
-
|
107
|
+
if (
|
108
|
+
conversation_id in self._conversation_provider
|
109
|
+
and self._conversation_provider[conversation_id] != provider_id
|
110
|
+
):
|
111
|
+
raise ValueError(
|
112
|
+
f"Conversation '{conversation_id}' already exists with a different provider"
|
113
|
+
)
|
114
|
+
|
86
115
|
provider = self.providers[provider_id]
|
87
|
-
|
116
|
+
|
88
117
|
try:
|
89
118
|
# 发送请求
|
90
119
|
result = await provider.async_request_humanloop(
|
@@ -93,38 +122,44 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
93
122
|
loop_type=loop_type,
|
94
123
|
context=context,
|
95
124
|
metadata=metadata,
|
96
|
-
timeout=timeout
|
125
|
+
timeout=timeout,
|
97
126
|
)
|
98
|
-
|
127
|
+
|
99
128
|
request_id = result.request_id
|
100
|
-
|
129
|
+
|
101
130
|
if not request_id:
|
102
|
-
raise ValueError(
|
103
|
-
|
131
|
+
raise ValueError(
|
132
|
+
f"Failed to request humanloop for conversation '{conversation_id}'"
|
133
|
+
)
|
134
|
+
|
104
135
|
# 存储task_id、conversation_id和request_id的关系
|
105
136
|
if task_id not in self._task_conversations:
|
106
137
|
self._task_conversations[task_id] = set()
|
107
138
|
self._task_conversations[task_id].add(conversation_id)
|
108
|
-
|
139
|
+
|
109
140
|
if conversation_id not in self._conversation_requests:
|
110
141
|
self._conversation_requests[conversation_id] = []
|
111
142
|
self._conversation_requests[conversation_id].append(request_id)
|
112
|
-
|
143
|
+
|
113
144
|
self._request_task[(conversation_id, request_id)] = task_id
|
114
145
|
# 存储对话对应的provider_id
|
115
146
|
self._conversation_provider[conversation_id] = provider_id
|
116
|
-
|
147
|
+
|
117
148
|
# 如果提供了回调,存储它
|
118
149
|
if callback:
|
119
150
|
self._callbacks[(conversation_id, request_id)] = callback
|
120
|
-
|
151
|
+
|
121
152
|
# 如果设置了超时,创建超时任务
|
122
153
|
if timeout:
|
123
|
-
await self._async_create_timeout_task(
|
124
|
-
|
154
|
+
await self._async_create_timeout_task(
|
155
|
+
conversation_id, request_id, timeout, provider, callback
|
156
|
+
)
|
157
|
+
|
125
158
|
# 如果是阻塞模式,等待结果
|
126
159
|
if blocking:
|
127
|
-
return await self._async_wait_for_result(
|
160
|
+
return await self._async_wait_for_result(
|
161
|
+
conversation_id, request_id, provider, timeout
|
162
|
+
)
|
128
163
|
else:
|
129
164
|
return request_id
|
130
165
|
except Exception as e:
|
@@ -132,7 +167,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
132
167
|
if callback:
|
133
168
|
try:
|
134
169
|
await callback.async_on_humanloop_error(provider, e)
|
135
|
-
except:
|
170
|
+
except Exception:
|
136
171
|
# 如果错误回调也失败,只能忽略
|
137
172
|
pass
|
138
173
|
raise # 重新抛出异常,让调用者知道发生了错误
|
@@ -150,30 +185,22 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
150
185
|
blocking: bool = False,
|
151
186
|
) -> Union[str, HumanLoopResult]:
|
152
187
|
"""请求人机循环(同步版本)"""
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
loop_type=loop_type,
|
166
|
-
context=context,
|
167
|
-
callback=callback,
|
168
|
-
metadata=metadata,
|
169
|
-
provider_id=provider_id,
|
170
|
-
timeout=timeout,
|
171
|
-
blocking=blocking
|
172
|
-
)
|
188
|
+
|
189
|
+
result: Union[str, HumanLoopResult] = run_async_safely(
|
190
|
+
self.async_request_humanloop(
|
191
|
+
task_id=task_id,
|
192
|
+
conversation_id=conversation_id,
|
193
|
+
loop_type=loop_type,
|
194
|
+
context=context,
|
195
|
+
callback=callback,
|
196
|
+
metadata=metadata,
|
197
|
+
provider_id=provider_id,
|
198
|
+
timeout=timeout,
|
199
|
+
blocking=blocking,
|
173
200
|
)
|
174
|
-
|
175
|
-
|
176
|
-
|
201
|
+
)
|
202
|
+
|
203
|
+
return result
|
177
204
|
|
178
205
|
async def async_continue_humanloop(
|
179
206
|
self,
|
@@ -190,16 +217,18 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
190
217
|
if conversation_id in self._conversation_provider:
|
191
218
|
stored_provider_id = self._conversation_provider[conversation_id]
|
192
219
|
if provider_id and provider_id != stored_provider_id:
|
193
|
-
raise ValueError(
|
220
|
+
raise ValueError(
|
221
|
+
f"Conversation '{conversation_id}' already exists with provider '{stored_provider_id}'"
|
222
|
+
)
|
194
223
|
provider_id = stored_provider_id
|
195
224
|
else:
|
196
225
|
provider_id = provider_id or self.default_provider_id
|
197
|
-
|
226
|
+
|
198
227
|
if not provider_id or provider_id not in self.providers:
|
199
228
|
raise ValueError(f"Provider '{provider_id}' not found")
|
200
|
-
|
229
|
+
|
201
230
|
provider = self.providers[provider_id]
|
202
|
-
|
231
|
+
|
203
232
|
try:
|
204
233
|
# 发送继续请求
|
205
234
|
result = await provider.async_continue_humanloop(
|
@@ -208,42 +237,48 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
208
237
|
metadata=metadata,
|
209
238
|
timeout=timeout,
|
210
239
|
)
|
211
|
-
|
240
|
+
|
212
241
|
request_id = result.request_id
|
213
242
|
|
214
243
|
if not request_id:
|
215
|
-
raise ValueError(
|
216
|
-
|
244
|
+
raise ValueError(
|
245
|
+
f"Failed to continue humanloop for conversation '{conversation_id}'"
|
246
|
+
)
|
247
|
+
|
217
248
|
# 更新conversation_id和request_id的关系
|
218
249
|
if conversation_id not in self._conversation_requests:
|
219
250
|
self._conversation_requests[conversation_id] = []
|
220
251
|
self._conversation_requests[conversation_id].append(request_id)
|
221
|
-
|
252
|
+
|
222
253
|
# 查找此conversation_id对应的task_id
|
223
254
|
task_id = None
|
224
255
|
for t_id, convs in self._task_conversations.items():
|
225
256
|
if conversation_id in convs:
|
226
257
|
task_id = t_id
|
227
258
|
break
|
228
|
-
|
259
|
+
|
229
260
|
if task_id:
|
230
261
|
self._request_task[(conversation_id, request_id)] = task_id
|
231
|
-
|
262
|
+
|
232
263
|
# 存储对话对应的provider_id,如果对话不存在才存储
|
233
264
|
if conversation_id not in self._conversation_provider:
|
234
265
|
self._conversation_provider[conversation_id] = provider_id
|
235
|
-
|
266
|
+
|
236
267
|
# 如果提供了回调,存储它
|
237
268
|
if callback:
|
238
269
|
self._callbacks[(conversation_id, request_id)] = callback
|
239
|
-
|
270
|
+
|
240
271
|
# 如果设置了超时,创建超时任务
|
241
272
|
if timeout:
|
242
|
-
await self._async_create_timeout_task(
|
243
|
-
|
273
|
+
await self._async_create_timeout_task(
|
274
|
+
conversation_id, request_id, timeout, provider, callback
|
275
|
+
)
|
276
|
+
|
244
277
|
# 如果是阻塞模式,等待结果
|
245
278
|
if blocking:
|
246
|
-
return await self._async_wait_for_result(
|
279
|
+
return await self._async_wait_for_result(
|
280
|
+
conversation_id, request_id, provider, timeout
|
281
|
+
)
|
247
282
|
else:
|
248
283
|
return request_id
|
249
284
|
except Exception as e:
|
@@ -251,7 +286,7 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
251
286
|
if callback:
|
252
287
|
try:
|
253
288
|
await callback.async_on_humanloop_error(provider, e)
|
254
|
-
except:
|
289
|
+
except Exception:
|
255
290
|
# 如果错误回调也失败,只能忽略
|
256
291
|
pass
|
257
292
|
raise # 重新抛出异常,让调用者知道发生了错误
|
@@ -267,53 +302,49 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
267
302
|
blocking: bool = False,
|
268
303
|
) -> Union[str, HumanLoopResult]:
|
269
304
|
"""继续人机循环(同步版本)"""
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
conversation_id=conversation_id,
|
281
|
-
context=context,
|
282
|
-
callback=callback,
|
283
|
-
metadata=metadata,
|
284
|
-
provider_id=provider_id,
|
285
|
-
timeout=timeout,
|
286
|
-
blocking=blocking
|
287
|
-
)
|
305
|
+
|
306
|
+
result: Union[str, HumanLoopResult] = run_async_safely(
|
307
|
+
self.async_continue_humanloop(
|
308
|
+
conversation_id=conversation_id,
|
309
|
+
context=context,
|
310
|
+
callback=callback,
|
311
|
+
metadata=metadata,
|
312
|
+
provider_id=provider_id,
|
313
|
+
timeout=timeout,
|
314
|
+
blocking=blocking,
|
288
315
|
)
|
289
|
-
|
290
|
-
|
291
|
-
|
316
|
+
)
|
317
|
+
|
318
|
+
return result
|
292
319
|
|
293
320
|
async def async_check_request_status(
|
294
|
-
self,
|
295
|
-
conversation_id: str,
|
296
|
-
request_id: str,
|
297
|
-
provider_id: Optional[str] = None
|
321
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
298
322
|
) -> HumanLoopResult:
|
299
323
|
"""检查请求状态"""
|
300
324
|
# 如果没有指定provider_id,尝试从存储的映射中获取
|
301
325
|
if provider_id is None:
|
302
326
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
303
327
|
provider_id = stored_provider_id or self.default_provider_id
|
304
|
-
|
328
|
+
|
305
329
|
if not provider_id or provider_id not in self.providers:
|
306
330
|
raise ValueError(f"Provider '{provider_id}' not found")
|
307
|
-
|
331
|
+
|
308
332
|
provider = self.providers[provider_id]
|
309
|
-
|
333
|
+
|
310
334
|
try:
|
311
|
-
result = await provider.async_check_request_status(
|
312
|
-
|
335
|
+
result = await provider.async_check_request_status(
|
336
|
+
conversation_id, request_id
|
337
|
+
)
|
338
|
+
|
313
339
|
# 如果有回调且状态不是等待或进行中,触发状态更新回调
|
314
|
-
if (
|
315
|
-
|
316
|
-
|
340
|
+
if (
|
341
|
+
conversation_id,
|
342
|
+
request_id,
|
343
|
+
) in self._callbacks and result.status not in [HumanLoopStatus.PENDING]:
|
344
|
+
await self._async_trigger_update_callback(
|
345
|
+
conversation_id, request_id, provider, result
|
346
|
+
)
|
347
|
+
|
317
348
|
return result
|
318
349
|
except Exception as e:
|
319
350
|
# 处理检查状态过程中的异常
|
@@ -321,350 +352,263 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
321
352
|
if callback:
|
322
353
|
try:
|
323
354
|
await callback.async_on_humanloop_error(provider, e)
|
324
|
-
except:
|
355
|
+
except Exception:
|
325
356
|
# 如果错误回调也失败,只能忽略
|
326
357
|
pass
|
327
358
|
raise # 重新抛出异常,让调用者知道发生了错误
|
328
359
|
|
329
|
-
|
330
360
|
def check_request_status(
|
331
|
-
self,
|
332
|
-
conversation_id: str,
|
333
|
-
request_id: str,
|
334
|
-
provider_id: Optional[str] = None
|
361
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
335
362
|
) -> HumanLoopResult:
|
336
363
|
"""检查请求状态(同步版本)"""
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
try:
|
345
|
-
return loop.run_until_complete(
|
346
|
-
self.async_check_request_status(
|
347
|
-
conversation_id=conversation_id,
|
348
|
-
request_id=request_id,
|
349
|
-
provider_id=provider_id
|
350
|
-
)
|
364
|
+
|
365
|
+
result: HumanLoopResult = run_async_safely(
|
366
|
+
self.async_check_request_status(
|
367
|
+
conversation_id=conversation_id,
|
368
|
+
request_id=request_id,
|
369
|
+
provider_id=provider_id,
|
351
370
|
)
|
352
|
-
|
353
|
-
if loop != asyncio.get_event_loop():
|
354
|
-
loop.close()
|
371
|
+
)
|
355
372
|
|
373
|
+
return result
|
356
374
|
|
357
375
|
async def async_check_conversation_status(
|
358
|
-
self,
|
359
|
-
conversation_id: str,
|
360
|
-
provider_id: Optional[str] = None
|
376
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
361
377
|
) -> HumanLoopResult:
|
362
378
|
"""检查对话状态"""
|
363
379
|
# 优先使用对话已关联的提供者
|
364
380
|
if provider_id is None:
|
365
381
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
366
382
|
provider_id = stored_provider_id or self.default_provider_id
|
367
|
-
|
383
|
+
|
368
384
|
if not provider_id or provider_id not in self.providers:
|
369
385
|
raise ValueError(f"Provider '{provider_id}' not found")
|
370
|
-
|
386
|
+
|
371
387
|
# 检查对话指定provider_id或默认provider_id最后一次请求的状态
|
372
388
|
provider = self.providers[provider_id]
|
373
|
-
|
389
|
+
|
374
390
|
try:
|
375
391
|
# 检查对话指定provider_id或默认provider_id最后一次请求的状态
|
376
392
|
return await provider.async_check_conversation_status(conversation_id)
|
377
393
|
except Exception as e:
|
378
394
|
# 处理检查对话状态过程中的异常
|
379
395
|
# 尝试找到与此对话关联的最后一个请求的回调
|
380
|
-
if
|
396
|
+
if (
|
397
|
+
conversation_id in self._conversation_requests
|
398
|
+
and self._conversation_requests[conversation_id]
|
399
|
+
):
|
381
400
|
last_request_id = self._conversation_requests[conversation_id][-1]
|
382
401
|
callback = self._callbacks.get((conversation_id, last_request_id))
|
383
402
|
if callback:
|
384
403
|
try:
|
385
404
|
await callback.async_on_humanloop_error(provider, e)
|
386
|
-
except:
|
405
|
+
except Exception:
|
387
406
|
# 如果错误回调也失败,只能忽略
|
388
407
|
pass
|
389
408
|
raise # 重新抛出异常,让调用者知道发生了错误
|
390
|
-
|
409
|
+
|
391
410
|
def check_conversation_status(
|
392
|
-
self,
|
393
|
-
conversation_id: str,
|
394
|
-
provider_id: Optional[str] = None
|
411
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
395
412
|
) -> HumanLoopResult:
|
396
413
|
"""检查对话状态(同步版本)"""
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
asyncio.set_event_loop(new_loop)
|
402
|
-
loop = new_loop
|
403
|
-
|
404
|
-
try:
|
405
|
-
return loop.run_until_complete(
|
406
|
-
self.async_check_conversation_status(
|
407
|
-
conversation_id=conversation_id,
|
408
|
-
provider_id=provider_id
|
409
|
-
)
|
414
|
+
|
415
|
+
result: HumanLoopResult = run_async_safely(
|
416
|
+
self.async_check_conversation_status(
|
417
|
+
conversation_id=conversation_id, provider_id=provider_id
|
410
418
|
)
|
411
|
-
|
412
|
-
|
413
|
-
|
419
|
+
)
|
420
|
+
|
421
|
+
return result
|
414
422
|
|
415
423
|
async def async_cancel_request(
|
416
|
-
self,
|
417
|
-
conversation_id: str,
|
418
|
-
request_id: str,
|
419
|
-
provider_id: Optional[str] = None
|
424
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
420
425
|
) -> bool:
|
421
426
|
"""取消特定请求"""
|
422
427
|
if provider_id is None:
|
423
428
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
424
429
|
provider_id = stored_provider_id or self.default_provider_id
|
425
|
-
|
430
|
+
|
426
431
|
if not provider_id or provider_id not in self.providers:
|
427
432
|
raise ValueError(f"Provider '{provider_id}' not found")
|
428
|
-
|
433
|
+
|
429
434
|
provider = self.providers[provider_id]
|
430
435
|
|
431
436
|
# 取消超时任务
|
432
437
|
if (conversation_id, request_id) in self._timeout_tasks:
|
433
438
|
self._timeout_tasks[(conversation_id, request_id)].cancel()
|
434
439
|
del self._timeout_tasks[(conversation_id, request_id)]
|
435
|
-
|
440
|
+
|
436
441
|
# 从回调映射中删除
|
437
442
|
if (conversation_id, request_id) in self._callbacks:
|
438
443
|
del self._callbacks[(conversation_id, request_id)]
|
439
|
-
|
444
|
+
|
440
445
|
# 清理request关联
|
441
446
|
if (conversation_id, request_id) in self._request_task:
|
442
447
|
del self._request_task[(conversation_id, request_id)]
|
443
|
-
|
448
|
+
|
444
449
|
# 从conversation_requests中移除
|
445
450
|
if conversation_id in self._conversation_requests:
|
446
451
|
if request_id in self._conversation_requests[conversation_id]:
|
447
452
|
self._conversation_requests[conversation_id].remove(request_id)
|
448
|
-
|
453
|
+
|
449
454
|
return await provider.async_cancel_request(conversation_id, request_id)
|
450
|
-
|
451
455
|
|
452
456
|
def cancel_request(
|
453
|
-
self,
|
454
|
-
conversation_id: str,
|
455
|
-
request_id: str,
|
456
|
-
provider_id: Optional[str] = None
|
457
|
+
self, conversation_id: str, request_id: str, provider_id: Optional[str] = None
|
457
458
|
) -> bool:
|
458
459
|
"""取消特定请求(同步版本)"""
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
try:
|
467
|
-
return loop.run_until_complete(
|
468
|
-
self.async_cancel_request(
|
469
|
-
conversation_id=conversation_id,
|
470
|
-
request_id=request_id,
|
471
|
-
provider_id=provider_id
|
472
|
-
)
|
460
|
+
|
461
|
+
result: bool = run_async_safely(
|
462
|
+
self.async_cancel_request(
|
463
|
+
conversation_id=conversation_id,
|
464
|
+
request_id=request_id,
|
465
|
+
provider_id=provider_id,
|
473
466
|
)
|
474
|
-
|
475
|
-
if loop != asyncio.get_event_loop():
|
476
|
-
loop.close()
|
467
|
+
)
|
477
468
|
|
469
|
+
return result
|
478
470
|
|
479
471
|
async def async_cancel_conversation(
|
480
|
-
self,
|
481
|
-
conversation_id: str,
|
482
|
-
provider_id: Optional[str] = None
|
472
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
483
473
|
) -> bool:
|
484
474
|
"""取消整个对话"""
|
485
475
|
# 优先使用对话已关联的提供者
|
486
476
|
if provider_id is None:
|
487
477
|
stored_provider_id = self._conversation_provider.get(conversation_id)
|
488
478
|
provider_id = stored_provider_id or self.default_provider_id
|
489
|
-
|
479
|
+
|
490
480
|
if not provider_id or provider_id not in self.providers:
|
491
481
|
raise ValueError(f"Provider '{provider_id}' not found")
|
492
|
-
|
482
|
+
|
493
483
|
provider = self.providers[provider_id]
|
494
|
-
|
484
|
+
|
495
485
|
# 取消与此对话相关的所有超时任务和回调
|
496
486
|
keys_to_remove = []
|
497
487
|
for key in self._timeout_tasks:
|
498
488
|
if key[0] == conversation_id:
|
499
489
|
self._timeout_tasks[key].cancel()
|
500
490
|
keys_to_remove.append(key)
|
501
|
-
|
491
|
+
|
502
492
|
for key in keys_to_remove:
|
503
493
|
del self._timeout_tasks[key]
|
504
|
-
|
494
|
+
|
505
495
|
keys_to_remove = []
|
506
496
|
for key in self._callbacks:
|
507
497
|
if key[0] == conversation_id:
|
508
498
|
keys_to_remove.append(key)
|
509
|
-
|
499
|
+
|
510
500
|
for key in keys_to_remove:
|
511
501
|
del self._callbacks[key]
|
512
|
-
|
502
|
+
|
513
503
|
# 清理与此对话相关的task映射关系
|
514
504
|
# 1. 从task_conversations中移除此对话
|
515
505
|
task_ids_to_update = []
|
516
506
|
for task_id, convs in self._task_conversations.items():
|
517
507
|
if conversation_id in convs:
|
518
508
|
task_ids_to_update.append(task_id)
|
519
|
-
|
509
|
+
|
520
510
|
for task_id in task_ids_to_update:
|
521
511
|
self._task_conversations[task_id].remove(conversation_id)
|
522
512
|
# 如果task没有关联的对话了,可以考虑删除该task记录
|
523
513
|
if not self._task_conversations[task_id]:
|
524
514
|
del self._task_conversations[task_id]
|
525
|
-
|
515
|
+
|
526
516
|
# 2. 获取并清理所有与此对话相关的请求
|
527
517
|
request_ids = self._conversation_requests.get(conversation_id, [])
|
528
518
|
for request_id in request_ids:
|
529
519
|
# 清理request_task映射
|
530
520
|
if (conversation_id, request_id) in self._request_task:
|
531
521
|
del self._request_task[(conversation_id, request_id)]
|
532
|
-
|
522
|
+
|
533
523
|
# 3. 清理conversation_requests映射
|
534
524
|
if conversation_id in self._conversation_requests:
|
535
525
|
del self._conversation_requests[conversation_id]
|
536
|
-
|
526
|
+
|
537
527
|
# 4. 清理provider关联
|
538
528
|
if conversation_id in self._conversation_provider:
|
539
529
|
del self._conversation_provider[conversation_id]
|
540
|
-
|
530
|
+
|
541
531
|
return await provider.async_cancel_conversation(conversation_id)
|
542
|
-
|
543
532
|
|
544
533
|
def cancel_conversation(
|
545
|
-
self,
|
546
|
-
conversation_id: str,
|
547
|
-
provider_id: Optional[str] = None
|
534
|
+
self, conversation_id: str, provider_id: Optional[str] = None
|
548
535
|
) -> bool:
|
549
536
|
"""取消整个对话(同步版本)"""
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
asyncio.set_event_loop(new_loop)
|
555
|
-
loop = new_loop
|
556
|
-
|
557
|
-
try:
|
558
|
-
return loop.run_until_complete(
|
559
|
-
self.async_cancel_conversation(
|
560
|
-
conversation_id=conversation_id,
|
561
|
-
provider_id=provider_id
|
562
|
-
)
|
537
|
+
|
538
|
+
result: bool = run_async_safely(
|
539
|
+
self.async_cancel_conversation(
|
540
|
+
conversation_id=conversation_id, provider_id=provider_id
|
563
541
|
)
|
564
|
-
|
565
|
-
|
566
|
-
|
542
|
+
)
|
543
|
+
|
544
|
+
return result
|
567
545
|
|
568
|
-
|
569
546
|
async def async_get_provider(
|
570
|
-
self,
|
571
|
-
provider_id: Optional[str] = None
|
547
|
+
self, provider_id: Optional[str] = None
|
572
548
|
) -> HumanLoopProvider:
|
573
549
|
"""获取指定的提供者实例"""
|
574
550
|
provider_id = provider_id or self.default_provider_id
|
575
551
|
if not provider_id or provider_id not in self.providers:
|
576
552
|
raise ValueError(f"Provider '{provider_id}' not found")
|
577
|
-
|
553
|
+
|
578
554
|
return self.providers[provider_id]
|
579
|
-
|
580
|
-
def get_provider(
|
581
|
-
self,
|
582
|
-
provider_id: Optional[str] = None
|
583
|
-
) -> HumanLoopProvider:
|
555
|
+
|
556
|
+
def get_provider(self, provider_id: Optional[str] = None) -> HumanLoopProvider:
|
584
557
|
"""获取指定的提供者实例(同步版本)"""
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
try:
|
593
|
-
return loop.run_until_complete(
|
594
|
-
self.async_get_provider(provider_id=provider_id)
|
595
|
-
)
|
596
|
-
finally:
|
597
|
-
if loop != asyncio.get_event_loop():
|
598
|
-
loop.close()
|
558
|
+
|
559
|
+
result: HumanLoopProvider = run_async_safely(
|
560
|
+
self.async_get_provider(provider_id=provider_id)
|
561
|
+
)
|
562
|
+
|
563
|
+
return result
|
599
564
|
|
600
565
|
async def async_list_providers(self) -> Dict[str, HumanLoopProvider]:
|
601
566
|
"""列出所有注册的提供者"""
|
602
567
|
return self.providers
|
603
|
-
|
604
568
|
|
605
569
|
def list_providers(self) -> Dict[str, HumanLoopProvider]:
|
606
570
|
"""列出所有注册的提供者(同步版本)"""
|
607
|
-
loop = asyncio.get_event_loop()
|
608
|
-
if loop.is_running():
|
609
|
-
# 如果事件循环已经在运行,创建一个新的事件循环
|
610
|
-
new_loop = asyncio.new_event_loop()
|
611
|
-
asyncio.set_event_loop(new_loop)
|
612
|
-
loop = new_loop
|
613
|
-
|
614
|
-
try:
|
615
|
-
return loop.run_until_complete(self.async_list_providers())
|
616
|
-
finally:
|
617
|
-
if loop != asyncio.get_event_loop():
|
618
|
-
loop.close()
|
619
571
|
|
572
|
+
result: Dict[str, HumanLoopProvider] = run_async_safely(
|
573
|
+
self.async_list_providers()
|
574
|
+
)
|
620
575
|
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
) -> bool:
|
576
|
+
return result
|
577
|
+
|
578
|
+
async def async_set_default_provider(self, provider_id: str) -> bool:
|
625
579
|
"""设置默认提供者"""
|
626
580
|
if provider_id not in self.providers:
|
627
581
|
raise ValueError(f"Provider '{provider_id}' not found")
|
628
|
-
|
582
|
+
|
629
583
|
self.default_provider_id = provider_id
|
630
584
|
return True
|
631
|
-
|
632
585
|
|
633
|
-
def set_default_provider(
|
634
|
-
self,
|
635
|
-
provider_id: str
|
636
|
-
) -> bool:
|
586
|
+
def set_default_provider(self, provider_id: str) -> bool:
|
637
587
|
"""设置默认提供者(同步版本)"""
|
638
|
-
loop = asyncio.get_event_loop()
|
639
|
-
if loop.is_running():
|
640
|
-
# 如果事件循环已经在运行,创建一个新的事件循环
|
641
|
-
new_loop = asyncio.new_event_loop()
|
642
|
-
asyncio.set_event_loop(new_loop)
|
643
|
-
loop = new_loop
|
644
|
-
|
645
|
-
try:
|
646
|
-
return loop.run_until_complete(
|
647
|
-
self.async_set_default_provider(provider_id=provider_id)
|
648
|
-
)
|
649
|
-
finally:
|
650
|
-
if loop != asyncio.get_event_loop():
|
651
|
-
loop.close()
|
652
588
|
|
589
|
+
result: bool = run_async_safely(
|
590
|
+
self.async_set_default_provider(provider_id=provider_id)
|
591
|
+
)
|
592
|
+
|
593
|
+
return result
|
653
594
|
|
654
595
|
async def _async_create_timeout_task(
|
655
|
-
self,
|
596
|
+
self,
|
656
597
|
conversation_id: str,
|
657
|
-
request_id: str,
|
658
|
-
timeout: int,
|
598
|
+
request_id: str,
|
599
|
+
timeout: int,
|
659
600
|
provider: HumanLoopProvider,
|
660
|
-
callback: Optional[HumanLoopCallback]
|
661
|
-
):
|
601
|
+
callback: Optional[HumanLoopCallback],
|
602
|
+
) -> None:
|
662
603
|
"""创建超时任务"""
|
663
|
-
|
604
|
+
|
605
|
+
async def timeout_task() -> None:
|
664
606
|
await asyncio.sleep(timeout)
|
665
607
|
# 检查当前状态
|
666
|
-
result = await self.async_check_request_status(
|
667
|
-
|
608
|
+
result = await self.async_check_request_status(
|
609
|
+
conversation_id, request_id, provider.name
|
610
|
+
)
|
611
|
+
|
668
612
|
# 只有当状态为PENDING时才触发超时回调
|
669
613
|
# INPROGRESS状态表示对话正在进行中,不应视为超时
|
670
614
|
if result.status == HumanLoopStatus.PENDING:
|
@@ -678,55 +622,68 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
678
622
|
self._timeout_tasks[(conversation_id, request_id)].cancel()
|
679
623
|
new_task = asyncio.create_task(timeout_task())
|
680
624
|
self._timeout_tasks[(conversation_id, request_id)] = new_task
|
681
|
-
|
625
|
+
|
682
626
|
task = asyncio.create_task(timeout_task())
|
683
627
|
self._timeout_tasks[(conversation_id, request_id)] = task
|
684
|
-
|
628
|
+
|
685
629
|
async def _async_wait_for_result(
|
686
|
-
self,
|
630
|
+
self,
|
687
631
|
conversation_id: str,
|
688
|
-
request_id: str,
|
689
|
-
provider: HumanLoopProvider,
|
690
|
-
timeout: Optional[int] = None
|
632
|
+
request_id: str,
|
633
|
+
provider: HumanLoopProvider,
|
634
|
+
timeout: Optional[int] = None,
|
691
635
|
) -> HumanLoopResult:
|
692
636
|
"""等待循环结果"""
|
693
|
-
start_time = time.time()
|
694
637
|
poll_interval = 1.0 # 轮询间隔(秒)
|
695
|
-
|
638
|
+
|
696
639
|
while True:
|
697
|
-
result = await self.async_check_request_status(
|
698
|
-
|
699
|
-
|
640
|
+
result = await self.async_check_request_status(
|
641
|
+
conversation_id, request_id, provider.name
|
642
|
+
)
|
643
|
+
|
644
|
+
# 如果状态是最终状态(非PENDING),返回结果
|
700
645
|
if result.status != HumanLoopStatus.PENDING:
|
701
646
|
return result
|
702
|
-
|
647
|
+
|
703
648
|
# 等待一段时间后再次轮询
|
704
649
|
await asyncio.sleep(poll_interval)
|
705
|
-
|
706
|
-
async def _async_trigger_update_callback(
|
650
|
+
|
651
|
+
async def _async_trigger_update_callback(
|
652
|
+
self,
|
653
|
+
conversation_id: str,
|
654
|
+
request_id: str,
|
655
|
+
provider: HumanLoopProvider,
|
656
|
+
result: HumanLoopResult,
|
657
|
+
) -> None:
|
707
658
|
"""触发状态更新回调"""
|
708
|
-
callback: Optional[HumanLoopCallback] = self._callbacks.get(
|
659
|
+
callback: Optional[HumanLoopCallback] = self._callbacks.get(
|
660
|
+
(conversation_id, request_id)
|
661
|
+
)
|
709
662
|
if callback:
|
710
663
|
try:
|
711
|
-
await callback.
|
664
|
+
await callback.async_on_humanloop_update(provider, result)
|
712
665
|
# 如果状态是最终状态,可以考虑移除回调
|
713
|
-
if result.status not in [
|
666
|
+
if result.status not in [
|
667
|
+
HumanLoopStatus.PENDING,
|
668
|
+
HumanLoopStatus.INPROGRESS,
|
669
|
+
]:
|
714
670
|
del self._callbacks[(conversation_id, request_id)]
|
715
671
|
except Exception as e:
|
716
672
|
# 处理回调执行过程中的异常
|
717
673
|
try:
|
718
|
-
await callback.
|
719
|
-
except:
|
674
|
+
await callback.async_on_humanloop_error(provider, e)
|
675
|
+
except Exception:
|
720
676
|
# 如果错误回调也失败,只能忽略
|
721
677
|
pass
|
722
678
|
|
723
679
|
# 添加新方法用于获取task相关信息
|
680
|
+
|
724
681
|
async def async_get_task_conversations(self, task_id: str) -> List[str]:
|
725
682
|
"""获取任务关联的所有对话ID
|
726
|
-
|
683
|
+
|
727
684
|
Args:
|
728
685
|
task_id: 任务ID
|
729
|
-
|
686
|
+
|
730
687
|
Returns:
|
731
688
|
List[str]: 与任务关联的对话ID列表
|
732
689
|
"""
|
@@ -734,106 +691,129 @@ class DefaultHumanLoopManager(HumanLoopManager):
|
|
734
691
|
|
735
692
|
def get_task_conversations(self, task_id: str) -> List[str]:
|
736
693
|
"""获取任务关联的所有对话ID
|
737
|
-
|
694
|
+
|
738
695
|
Args:
|
739
696
|
task_id: 任务ID
|
740
|
-
|
697
|
+
|
741
698
|
Returns:
|
742
699
|
List[str]: 与任务关联的对话ID列表
|
743
700
|
"""
|
744
701
|
return list(self._task_conversations.get(task_id, set()))
|
745
|
-
|
702
|
+
|
746
703
|
async def async_get_conversation_requests(self, conversation_id: str) -> List[str]:
|
747
704
|
"""获取对话关联的所有请求ID
|
748
|
-
|
705
|
+
|
749
706
|
Args:
|
750
707
|
conversation_id: 对话ID
|
751
|
-
|
708
|
+
|
752
709
|
Returns:
|
753
710
|
List[str]: 与对话关联的请求ID列表
|
754
711
|
"""
|
755
|
-
|
712
|
+
ret: List[str] = self._conversation_requests.get(conversation_id, [])
|
713
|
+
return ret
|
756
714
|
|
757
715
|
def get_conversation_requests(self, conversation_id: str) -> List[str]:
|
758
716
|
"""获取对话关联的所有请求ID
|
759
|
-
|
717
|
+
|
760
718
|
Args:
|
761
719
|
conversation_id: 对话ID
|
762
|
-
|
720
|
+
|
763
721
|
Returns:
|
764
722
|
List[str]: 与对话关联的请求ID列表
|
765
723
|
"""
|
766
|
-
|
767
|
-
|
768
|
-
|
724
|
+
ret: List[str] = self._conversation_requests.get(conversation_id, [])
|
725
|
+
|
726
|
+
return ret
|
727
|
+
|
728
|
+
async def async_get_request_task(
|
729
|
+
self, conversation_id: str, request_id: str
|
730
|
+
) -> Optional[str]:
|
769
731
|
"""获取请求关联的任务ID
|
770
|
-
|
732
|
+
|
771
733
|
Args:
|
772
734
|
conversation_id: 对话ID
|
773
735
|
request_id: 请求ID
|
774
|
-
|
736
|
+
|
775
737
|
Returns:
|
776
738
|
Optional[str]: 关联的任务ID,如果不存在则返回None
|
777
739
|
"""
|
778
|
-
|
740
|
+
ret: Optional[str] = self._request_task.get((conversation_id, request_id))
|
741
|
+
|
742
|
+
return ret
|
779
743
|
|
780
|
-
async def async_get_conversation_provider(
|
744
|
+
async def async_get_conversation_provider(
|
745
|
+
self, conversation_id: str
|
746
|
+
) -> Optional[str]:
|
781
747
|
"""获取请求关联的提供者ID
|
782
|
-
|
748
|
+
|
783
749
|
Args:
|
784
750
|
conversation_id: 对话ID
|
785
|
-
|
751
|
+
|
786
752
|
Returns:
|
787
753
|
Optional[str]: 关联的提供者ID,如果不存在则返回None
|
788
754
|
"""
|
789
|
-
|
755
|
+
ret: Optional[str] = self._conversation_provider.get(conversation_id)
|
756
|
+
|
757
|
+
return ret
|
790
758
|
|
791
759
|
async def async_check_conversation_exist(
|
792
760
|
self,
|
793
|
-
task_id:str,
|
761
|
+
task_id: str,
|
794
762
|
conversation_id: str,
|
795
763
|
) -> bool:
|
796
764
|
"""判断对话是否已存在
|
797
|
-
|
765
|
+
|
798
766
|
Args:
|
799
767
|
conversation_id: 对话标识符
|
800
768
|
provider_id: 使用特定提供者的ID(可选)
|
801
|
-
|
769
|
+
|
802
770
|
Returns:
|
803
771
|
bool: 如果对话存在返回True,否则返回False
|
804
772
|
"""
|
805
773
|
# 检查task_id是否存在且conversation_id是否在该task的对话集合中
|
806
|
-
if
|
774
|
+
if (
|
775
|
+
task_id in self._task_conversations
|
776
|
+
and conversation_id in self._task_conversations[task_id]
|
777
|
+
):
|
807
778
|
# 进一步验证该对话是否有关联的请求
|
808
|
-
if
|
779
|
+
if (
|
780
|
+
conversation_id in self._conversation_requests
|
781
|
+
and self._conversation_requests[conversation_id]
|
782
|
+
):
|
809
783
|
return True
|
810
784
|
|
811
785
|
return False
|
812
786
|
|
813
787
|
def check_conversation_exist(
|
814
788
|
self,
|
815
|
-
task_id:str,
|
789
|
+
task_id: str,
|
816
790
|
conversation_id: str,
|
817
791
|
) -> bool:
|
818
792
|
"""判断对话是否已存在
|
819
|
-
|
793
|
+
|
820
794
|
Args:
|
821
795
|
conversation_id: 对话标识符
|
822
796
|
provider_id: 使用特定提供者的ID(可选)
|
823
|
-
|
797
|
+
|
824
798
|
Returns:
|
825
799
|
bool: 如果对话存在返回True,否则返回False
|
826
800
|
"""
|
827
801
|
# 检查task_id是否存在且conversation_id是否在该task的对话集合中
|
828
|
-
if
|
802
|
+
if (
|
803
|
+
task_id in self._task_conversations
|
804
|
+
and conversation_id in self._task_conversations[task_id]
|
805
|
+
):
|
829
806
|
# 进一步验证该对话是否有关联的请求
|
830
|
-
if
|
807
|
+
if (
|
808
|
+
conversation_id in self._conversation_requests
|
809
|
+
and self._conversation_requests[conversation_id]
|
810
|
+
):
|
831
811
|
return True
|
832
812
|
|
833
813
|
return False
|
834
814
|
|
835
|
-
async def async_shutdown(self):
|
836
|
-
|
815
|
+
async def async_shutdown(self) -> None:
|
816
|
+
pass
|
837
817
|
|
838
|
-
def shutdown(self):
|
839
|
-
|
818
|
+
def shutdown(self) -> None:
|
819
|
+
pass
|