jupyter-agent 2025.6.103__py3-none-any.whl → 2025.6.105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jupyter_agent/bot_actions.py +270 -0
- jupyter_agent/bot_agents/__init__.py +0 -42
- jupyter_agent/bot_agents/base.py +85 -45
- jupyter_agent/bot_agents/master_planner.py +2 -0
- jupyter_agent/bot_agents/output_task_result.py +6 -7
- jupyter_agent/bot_agents/request_user_supply.py +186 -0
- jupyter_agent/bot_agents/task_planner_v3.py +12 -13
- jupyter_agent/bot_agents/task_reasoner.py +2 -2
- jupyter_agent/bot_agents/task_structrue_reasoner.py +19 -12
- jupyter_agent/bot_agents/task_structrue_summarier.py +19 -18
- jupyter_agent/bot_agents/task_summarier.py +2 -2
- jupyter_agent/bot_agents/task_verifier.py +1 -1
- jupyter_agent/bot_agents/task_verify_summarier.py +5 -6
- jupyter_agent/bot_chat.py +2 -2
- jupyter_agent/bot_contexts.py +28 -23
- jupyter_agent/bot_evaluation.py +325 -0
- jupyter_agent/bot_evaluators/__init__.py +0 -0
- jupyter_agent/bot_evaluators/base.py +42 -0
- jupyter_agent/bot_evaluators/dummy_flow.py +20 -0
- jupyter_agent/bot_evaluators/dummy_global.py +20 -0
- jupyter_agent/bot_evaluators/dummy_task.py +20 -0
- jupyter_agent/bot_evaluators/flow_global_planning.py +88 -0
- jupyter_agent/bot_evaluators/flow_task_executor.py +152 -0
- jupyter_agent/bot_flows/__init__.py +0 -4
- jupyter_agent/bot_flows/base.py +114 -10
- jupyter_agent/bot_flows/master_planner.py +7 -2
- jupyter_agent/bot_flows/task_executor_v3.py +45 -20
- jupyter_agent/bot_magics.py +108 -53
- jupyter_agent/bot_outputs.py +56 -3
- jupyter_agent/utils.py +20 -31
- {jupyter_agent-2025.6.103.dist-info → jupyter_agent-2025.6.105.dist-info}/METADATA +39 -8
- jupyter_agent-2025.6.105.dist-info/RECORD +40 -0
- jupyter_agent-2025.6.105.dist-info/entry_points.txt +2 -0
- jupyter_agent/bot_agents/task_planner_v1.py +0 -158
- jupyter_agent/bot_agents/task_planner_v2.py +0 -172
- jupyter_agent/bot_flows/task_executor_v1.py +0 -86
- jupyter_agent/bot_flows/task_executor_v2.py +0 -84
- jupyter_agent-2025.6.103.dist-info/RECORD +0 -33
- {jupyter_agent-2025.6.103.dist-info → jupyter_agent-2025.6.105.dist-info}/WHEEL +0 -0
- {jupyter_agent-2025.6.103.dist-info → jupyter_agent-2025.6.105.dist-info}/licenses/LICENSE +0 -0
- {jupyter_agent-2025.6.103.dist-info → jupyter_agent-2025.6.105.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,270 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 viewstar000
|
3
|
+
|
4
|
+
This software is released under the MIT License.
|
5
|
+
https://opensource.org/licenses/MIT
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
import time
|
10
|
+
import uuid
|
11
|
+
import threading
|
12
|
+
import queue
|
13
|
+
import traceback
|
14
|
+
import importlib
|
15
|
+
import socket
|
16
|
+
|
17
|
+
from enum import Enum
|
18
|
+
from typing import Optional, Dict, List, Any
|
19
|
+
from pydantic import BaseModel, Field
|
20
|
+
from wsgiref.simple_server import make_server
|
21
|
+
from bottle import default_app, get, post, request, response
|
22
|
+
from .utils import get_env_capbilities
|
23
|
+
|
24
|
+
|
25
|
+
class ActionBase(BaseModel):
    # Common envelope shared by every action model in this module.
    # `timestamp` and `uuid` are filled in by __init__ when left falsy.
    timestamp: float = Field(default=0)
    uuid: str = Field(default="")
    source: str = Field(default="")
    action: str = Field()  # required here; subclasses give it a fixed default name
    params: Dict[str, Any] = Field(default={})

    def __init__(self, **data):
        """Validate *data*, then assign a creation time and a random id if they are missing."""
        super().__init__(**data)
        if not self.timestamp:
            self.timestamp = time.time()
        if not self.uuid:
            self.uuid = str(uuid.uuid4())
|
36
|
+
|
37
|
+
|
38
|
+
class ReplyActionBase(ActionBase):
    # Base for actions that expect a reply; ActionDispatcher.send_action fills
    # these with the dispatcher's host/port when a reply is requested.
    reply_host: str = Field(default="")
    reply_port: int = Field(default=0)
|
41
|
+
|
42
|
+
|
43
|
+
class SetCellContentParams(BaseModel):
    # Parameters carried by ActionSetCellContent.
    # index: target cell relative to the current one (-1 previous, 0 current, 1 next).
    index: int = Field(default=1)
    type: str = Field(default="code")  # "code" or "markdown"
    source: str = Field(default="")
    tags: List[str] = Field(default=[])
    metadata: Dict[str, Any] = Field(default={})
|
49
|
+
|
50
|
+
|
51
|
+
class ActionSetCellContent(ActionBase):
    # "set_cell_content" action; the payload is a SetCellContentParams.
    action: str = Field(default="set_cell_content")
    params: SetCellContentParams = Field(default=SetCellContentParams())
|
55
|
+
|
56
|
+
|
57
|
+
class ConfirmChoiceItem(BaseModel):
    # One option offered in a RequestUserConfirmParams; `value` is required.
    label: str = Field(default="")
    value: str = Field()
|
60
|
+
|
61
|
+
|
62
|
+
class RequestUserConfirmParams(BaseModel):
    # Payload of ActionRequestUserConfirm: the prompt text, the selectable
    # choices, and the value used as `default`.
    prompt: str = Field(default="")
    choices: List[ConfirmChoiceItem] = Field(default=[])
    default: str = Field(default="")
|
66
|
+
|
67
|
+
|
68
|
+
class ActionRequestUserConfirm(ReplyActionBase):
    # "request_user_confirm" action; expects a reply (it is a ReplyActionBase).
    action: str = Field(default="request_user_confirm")
    params: RequestUserConfirmParams = Field(default=RequestUserConfirmParams())
|
72
|
+
|
73
|
+
|
74
|
+
class ReceiveUserConfirmParams(BaseModel):
    # Payload of ActionReceiveUserConfirm: the chosen result.
    result: str = Field(default="")
|
76
|
+
|
77
|
+
|
78
|
+
class ActionReceiveUserConfirm(ActionBase):
    # "receive_user_confirm" action carrying a ReceiveUserConfirmParams.
    action: str = Field(default="receive_user_confirm")
    params: ReceiveUserConfirmParams = Field(default=ReceiveUserConfirmParams())
|
82
|
+
|
83
|
+
|
84
|
+
class RequestUserSupplyInfo(BaseModel):
    # A single item the user is asked to supply. The description/examples
    # strings end up in the model's JSON schema, so they are kept verbatim.
    prompt: str = Field(
        description="需要用户补充详细信息的Prompt",
        examples=["请补充与...相关的详细的信息", "请确认...是否...", "请提供..."],
    )
    example: Optional[str] = Field(
        default=None,
        description="示例",
        examples=["..."],
    )
|
90
|
+
|
91
|
+
|
92
|
+
class UserSupplyInfoReply(BaseModel):
    # The user's answer (`reply`) paired with the prompt it answers.
    prompt: str = Field(
        description="需要用户补充详细信息的Prompt",
        examples=["..."],
    )
    reply: str = Field(
        description="用户补充的详细信息",
        examples=["..."],
    )
|
95
|
+
|
96
|
+
|
97
|
+
class RequestUserSupplyInfoParams(BaseModel):
    # Payload of ActionRequestUserSupplyInfo: a title plus the list of items to ask for.
    title: str = Field(default="")
    issues: List[RequestUserSupplyInfo] = Field(default=[])
|
100
|
+
|
101
|
+
|
102
|
+
class ActionRequestUserSupplyInfo(ReplyActionBase):
    # "request_user_supply_info" action; expects a reply (it is a ReplyActionBase).
    action: str = Field(default="request_user_supply_info")
    params: RequestUserSupplyInfoParams = Field(default=RequestUserSupplyInfoParams())
|
106
|
+
|
107
|
+
|
108
|
+
class ReceiveUserSupplyInfoParams(BaseModel):
    # Payload of ActionReceiveUserSupplyInfo. `replies` has no default, so it
    # is required; the examples below are model instances and are part of the
    # JSON schema metadata.
    replies: List[UserSupplyInfoReply] = Field(
        description="完成补充确认的信息列表",
        examples=[
            UserSupplyInfoReply(prompt="请确认...是否...", reply="是"),
            UserSupplyInfoReply(prompt="请补充...", reply="..."),
        ],
    )
|
116
|
+
|
117
|
+
|
118
|
+
class ActionReceiveUserSupplyInfo(ActionBase):
    # "receive_user_supply_info" action; default payload has an empty replies list.
    action: str = Field(default="receive_user_supply_info")
    params: ReceiveUserSupplyInfoParams = Field(default=ReceiveUserSupplyInfoParams(replies=[]))
|
121
|
+
|
122
|
+
|
123
|
+
def request_user_reply(prompts: list[RequestUserSupplyInfo]) -> list[UserSupplyInfoReply]:
    """Ask the user each prompt on stdin and collect the answers.

    Args:
        prompts: items to ask, in order.

    Returns:
        One UserSupplyInfoReply per prompt, in the same order.
    """
    replies = []
    for item in prompts:
        # `example` is optional; only show the hint when there is one, otherwise
        # the user would see a literal "(例如: None)".
        question = f"{item.prompt} (例如: {item.example})" if item.example else item.prompt
        replies.append(UserSupplyInfoReply(prompt=item.prompt, reply=input(question)))
    return replies
|
129
|
+
|
130
|
+
|
131
|
+
def get_action_class(action_name: str) -> type[ActionBase]:
    """Return the ActionBase subclass in this module matching *action_name*.

    A class matches when its class name equals *action_name* or when the
    default of its ``action`` field does. Module globals are scanned in
    definition order and the first match wins.

    Raises:
        ValueError: if no action class matches.
    """
    action_classes = [
        candidate
        for candidate in globals().values()
        if isinstance(candidate, type) and issubclass(candidate, ActionBase)
    ]
    for cls in action_classes:
        if cls.__name__ == action_name or cls.model_fields["action"].default == action_name:
            return cls
    raise ValueError(f"Unknown action: {action_name}")
|
137
|
+
|
138
|
+
|
139
|
+
class ActionReply(BaseModel):
    # Record of a reply received for an action, stored by the action's uuid.
    reply_timestamp: float = Field()  # set to time.time() when the reply is received
    retrieved_timestamp: float = Field(default=0)  # set when a caller fetches the reply
    uuid: str = Field()
    source: str = Field(default="")
    action: str = Field(default="")
    retrieved: bool = Field(default=False)
    reply: ActionBase = Field()
|
147
|
+
|
148
|
+
|
149
|
+
class ActionDispatcher(threading.Thread):
    """Daemon thread that serves the bottle app and tracks actions and their replies.

    Actions given to :meth:`send_action` are queued as dicts in
    :attr:`action_queue`. Replies are stored in :attr:`action_replies`, keyed
    by action uuid. The WSGI server is only created when the environment
    capabilities enable ``user_confirm`` or ``user_supply_info``; otherwise
    the thread starts and returns immediately from :meth:`run`.
    """

    def __init__(self, host="127.0.0.1", port=0, app=None):
        super().__init__(daemon=True)
        self.action_queue = queue.Queue()
        self.action_replies: dict[str, ActionReply] = {}
        self.app = app or default_app()
        self.host = host
        self.port = port
        self.server = None
        capabilities = get_env_capbilities()
        if capabilities.user_confirm or capabilities.user_supply_info:
            # Bind directly (port 0 lets the OS choose) and read the real port
            # back. Probing a free port first and binding it later left a window
            # in which another process could take that port.
            self.server = make_server(self.host, self.port, self.app)
            self.port = self.server.server_port
        self.start()

    def select_port(self, host):
        """Return a port number that was free on *host* when probed.

        Kept for backward compatibility; __init__ no longer uses it because the
        port may be taken by someone else before it is bound again.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((host, 0))
            return sock.getsockname()[1]

    def run(self):
        """Serve HTTP requests until :meth:`close` is called (no-op without a server)."""
        server = self.server
        if server is not None:
            server.serve_forever()

    def close(self):
        """Stop the server, if any, and release its socket. Safe to call repeatedly."""
        server = self.server
        if server is None:
            return
        if self.is_alive():
            # shutdown() blocks until serve_forever() returns; calling it when the
            # serving thread never started would block forever.
            server.shutdown()
        server.server_close()
        self.server = None

    def __del__(self):
        # __init__ may have failed before `server` was assigned.
        if getattr(self, "server", None) is not None:
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()

    def send_action(self, action: ActionBase, need_reply: bool = False):
        """Queue *action* (as a dict) and render it via ``bot_outputs.output_action``.

        Args:
            action: the action to send; must be a ReplyActionBase when
                *need_reply* is true.
            need_reply: when true, stamp the action with this dispatcher's
                host/port as its reply address.
        """
        if need_reply:
            assert isinstance(action, ReplyActionBase)
            action.reply_host = self.host
            action.reply_port = self.port
        action.timestamp = action.timestamp or time.time()
        # Keep an existing id (callers may already track it) and only create one
        # when missing. This used `and`, which replaced every existing id.
        action.uuid = action.uuid or str(uuid.uuid4())
        self.action_queue.put(action.model_dump())
        bot_outputs = importlib.import_module(".bot_outputs", __package__)
        bot_outputs.output_action(action)

    def get_action_reply(
        self, action: ReplyActionBase, wait: bool = True, timeout: Optional[float] = None
    ) -> Optional[ActionBase]:
        """Return the reply received for *action*, or None if there is none.

        Args:
            action: the action whose ``uuid`` identifies the reply.
            wait: poll (once per second) until a reply arrives.
            timeout: when waiting, give up after this many seconds; None waits
                indefinitely (the previous behaviour).

        A returned reply is marked as retrieved, with the retrieval time.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while wait and action.uuid not in self.action_replies:
            if deadline is None:
                time.sleep(1)
                continue
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(1.0, remaining))
        record = self.action_replies.get(action.uuid)
        if record is None:
            return None
        record.retrieved = True
        record.retrieved_timestamp = time.time()
        return record.reply
|
209
|
+
|
210
|
+
|
211
|
+
# Module-wide dispatcher instance, created lazily by get_action_dispatcher().
_default_action_dispatcher = None


def get_action_dispatcher() -> ActionDispatcher:
    """Return the shared ActionDispatcher, (re)creating it when absent or no longer alive."""
    global _default_action_dispatcher

    current = _default_action_dispatcher
    if current and current.is_alive():
        return current
    if current:
        # The old dispatcher's thread has exited; release it before replacing it.
        current.close()
    _default_action_dispatcher = ActionDispatcher()
    return _default_action_dispatcher
|
223
|
+
|
224
|
+
|
225
|
+
def close_action_dispatcher():
    """Close the shared ActionDispatcher, if one exists, and forget it."""
    global _default_action_dispatcher

    dispatcher = _default_action_dispatcher
    if not dispatcher:
        return
    dispatcher.close()
    _default_action_dispatcher = None
|
231
|
+
|
232
|
+
|
233
|
+
@get("/echo")
def echo():
    """Liveness endpoint: always answers ``{"status": "OK"}`` as JSON."""
    payload = {"status": "OK"}
    response.content_type = "application/json"
    return json.dumps(payload)
|
237
|
+
|
238
|
+
|
239
|
+
@post("/action_reply")
def action_reply():
    """Record a reply for a previously sent action.

    Query parameters: ``uuid`` (required) identifies the original action.
    ``a`` and ``s`` optionally override the body's ``action`` and ``source``.
    The JSON body is used to build the matching action model. Any failure is
    reported as a JSON ``{"status": "ERROR", ...}`` body.
    """
    try:
        action_uuid = request.GET["uuid"]  # type: ignore
        action_name = request.GET.get("a") or request.json.get("action")  # type: ignore
        action_source = request.GET.get("s") or request.json.get("source")  # type: ignore
        reply_cls = get_action_class(action_name)
        reply = reply_cls(**request.json)  # type: ignore
        record = ActionReply(
            reply_timestamp=time.time(),
            uuid=action_uuid,
            source=action_source,
            action=action_name,
            reply=reply,
        )
        get_action_dispatcher().action_replies[record.uuid] = record
        response.content_type = "application/json"
        return json.dumps({"status": "OK"})
    except Exception as e:
        response.content_type = "application/json"
        return json.dumps(
            {"status": "ERROR", "error": f"{type(e).__name__}: {e}", "traceback": traceback.format_exc()}
        )
|
255
|
+
|
256
|
+
|
257
|
+
@get("/action_fetch")
def action_fetch():
    """Pop the next queued action without blocking and return it as JSON.

    Returns ``{"status": "OK", "action": {...}}``, ``{"status": "EMPTY"}``
    when the queue is empty, or ``{"status": "ERROR", ...}`` on failure.
    """
    response.content_type = "application/json"
    try:
        pending = get_action_dispatcher().action_queue.get_nowait()
        body = {"status": "OK", "action": pending}
        return json.dumps(body)
    except queue.Empty:
        return json.dumps({"status": "EMPTY"})
    except Exception as e:
        return json.dumps(
            {"status": "ERROR", "error": f"{type(e).__name__}: {e}", "traceback": traceback.format_exc()}
        )
|
@@ -1,42 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Copyright (c) 2025 viewstar000
|
3
|
-
|
4
|
-
This software is released under the MIT License.
|
5
|
-
https://opensource.org/licenses/MIT
|
6
|
-
"""
|
7
|
-
|
8
|
-
from .base import BaseChatAgent, AgentFactory
|
9
|
-
from .master_planner import MasterPlannerAgent
|
10
|
-
from .output_task_result import OutputTaskResult
|
11
|
-
from .task_code_executor import CodeExecutor
|
12
|
-
from .task_planner_v1 import TaskPlannerAgentV1
|
13
|
-
from .task_planner_v2 import TaskPlannerAgentV2
|
14
|
-
from .task_planner_v3 import TaskPlannerAgentV3
|
15
|
-
from .task_coder import TaskCodingAgent
|
16
|
-
from .task_debuger import CodeDebugerAgent
|
17
|
-
from .task_verifier import TaskVerifyAgent, TaskVerifyState
|
18
|
-
from .task_summarier import TaskSummaryAgent
|
19
|
-
from .task_verify_summarier import TaskVerifySummaryAgent
|
20
|
-
from .task_structrue_summarier import TaskStructureSummaryAgent
|
21
|
-
from .task_reasoner import TaskReasoningAgent
|
22
|
-
from .task_structrue_reasoner import TaskStructureReasoningAgent
|
23
|
-
|
24
|
-
__all__ = [
|
25
|
-
"AgentFactory",
|
26
|
-
"BaseChatAgent",
|
27
|
-
"CodeDebugerAgent",
|
28
|
-
"CodeExecutor",
|
29
|
-
"MasterPlannerAgent",
|
30
|
-
"TaskCodingAgent",
|
31
|
-
"TaskPlannerAgentV1",
|
32
|
-
"TaskPlannerAgentV2",
|
33
|
-
"TaskPlannerAgentV3",
|
34
|
-
"TaskReasoningAgent",
|
35
|
-
"TaskStructureReasoningAgent",
|
36
|
-
"TaskStructureSummaryAgent",
|
37
|
-
"TaskSummaryAgent",
|
38
|
-
"TaskVerifyAgent",
|
39
|
-
"TaskVerifyState",
|
40
|
-
"TaskVerifySummaryAgent",
|
41
|
-
"OutputTaskResult",
|
42
|
-
]
|
jupyter_agent/bot_agents/base.py
CHANGED
@@ -7,11 +7,13 @@ https://opensource.org/licenses/MIT
|
|
7
7
|
|
8
8
|
import json
|
9
9
|
import importlib
|
10
|
+
import traceback
|
10
11
|
|
11
12
|
from typing import Tuple, Any
|
12
13
|
from enum import Enum, unique
|
14
|
+
from pydantic import BaseModel, Field
|
13
15
|
from IPython.display import Markdown
|
14
|
-
from ..bot_outputs import _C, flush_output
|
16
|
+
from ..bot_outputs import _C, _O, _W, _T, flush_output
|
15
17
|
from ..bot_chat import BotChat
|
16
18
|
from ..utils import no_indent
|
17
19
|
|
@@ -161,12 +163,13 @@ class BaseChatAgent(BotChat, BaseAgent):
|
|
161
163
|
DISPLAY_REPLY = True
|
162
164
|
COMBINE_REPLY = AgentCombineReply.MERGE
|
163
165
|
ACCEPT_EMPYT_REPLY = False
|
166
|
+
REPLY_ERROR_RETRIES = 1
|
164
167
|
MODEL_TYPE = AgentModelType.REASONING
|
165
168
|
|
166
|
-
def __init__(self, notebook_context,
|
169
|
+
def __init__(self, notebook_context, **chat_kwargs):
|
167
170
|
"""初始化基础任务代理"""
|
168
171
|
BaseAgent.__init__(self, notebook_context)
|
169
|
-
BotChat.__init__(self,
|
172
|
+
BotChat.__init__(self, **chat_kwargs)
|
170
173
|
|
171
174
|
def prepare_contexts(self, **kwargs):
|
172
175
|
contexts = {
|
@@ -185,8 +188,16 @@ class BaseChatAgent(BotChat, BaseAgent):
|
|
185
188
|
}
|
186
189
|
else:
|
187
190
|
json_example = {}
|
188
|
-
|
189
|
-
|
191
|
+
|
192
|
+
def _default(o):
|
193
|
+
if isinstance(o, BaseModel):
|
194
|
+
return o.model_dump()
|
195
|
+
if isinstance(o, Enum):
|
196
|
+
return o.value
|
197
|
+
return repr(o)
|
198
|
+
|
199
|
+
contexts["OUTPUT_JSON_SCHEMA"] = json.dumps(json_schema, indent=2, ensure_ascii=False, default=_default)
|
200
|
+
contexts["OUTPUT_JSON_EXAMPLE"] = json.dumps(json_example, indent=2, ensure_ascii=False, default=_default)
|
190
201
|
contexts.update(kwargs)
|
191
202
|
return contexts
|
192
203
|
|
@@ -220,30 +231,41 @@ class BaseChatAgent(BotChat, BaseAgent):
|
|
220
231
|
|
221
232
|
def combine_json_replies(self, replies):
|
222
233
|
json_replies = [reply for reply in replies if reply["type"] == "code" and reply["lang"] == "json"]
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
if self.
|
231
|
-
json_obj =
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
234
|
+
assert self.COMBINE_REPLY in [
|
235
|
+
AgentCombineReply.FIRST,
|
236
|
+
AgentCombineReply.LAST,
|
237
|
+
AgentCombineReply.LIST,
|
238
|
+
AgentCombineReply.MERGE,
|
239
|
+
]
|
240
|
+
try:
|
241
|
+
if self.COMBINE_REPLY == AgentCombineReply.FIRST:
|
242
|
+
json_obj = json.loads(json_replies[0]["content"])
|
243
|
+
if self.OUTPUT_JSON_SCHEMA:
|
244
|
+
json_obj = self.OUTPUT_JSON_SCHEMA(**json_obj)
|
245
|
+
return json_obj
|
246
|
+
elif self.COMBINE_REPLY == AgentCombineReply.LAST:
|
247
|
+
json_obj = json.loads(json_replies[-1]["content"])
|
248
|
+
if self.OUTPUT_JSON_SCHEMA:
|
249
|
+
json_obj = self.OUTPUT_JSON_SCHEMA(**json_obj)
|
250
|
+
return json_obj
|
251
|
+
elif self.COMBINE_REPLY == AgentCombineReply.LIST:
|
252
|
+
json_objs = [json.loads(reply["content"]) for reply in json_replies]
|
253
|
+
if self.OUTPUT_JSON_SCHEMA:
|
254
|
+
json_objs = [self.OUTPUT_JSON_SCHEMA(**json_obj) for json_obj in json_objs]
|
255
|
+
return json_objs
|
256
|
+
elif self.COMBINE_REPLY == AgentCombineReply.MERGE:
|
257
|
+
json_obj = {}
|
258
|
+
for json_reply in json_replies:
|
259
|
+
json_obj.update(json.loads(json_reply["content"]))
|
260
|
+
if self.OUTPUT_JSON_SCHEMA:
|
261
|
+
json_obj = self.OUTPUT_JSON_SCHEMA(**json_obj)
|
262
|
+
return json_obj
|
263
|
+
else:
|
264
|
+
return False
|
265
|
+
except Exception as e:
|
266
|
+
_T(f"提取JSON失败: {type(e).__name__}: {e}")
|
267
|
+
_W(traceback.format_exc())
|
268
|
+
return False
|
247
269
|
|
248
270
|
def combine_text_replies(self, replies):
|
249
271
|
text_replies = [reply for reply in replies if reply["type"] == "text"]
|
@@ -274,10 +296,22 @@ class BaseChatAgent(BotChat, BaseAgent):
|
|
274
296
|
def __call__(self, **kwargs) -> Tuple[bool, Any]:
|
275
297
|
contexts = self.prepare_contexts(**kwargs)
|
276
298
|
messages = self.create_messages(contexts)
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
299
|
+
reply_retries = 0
|
300
|
+
while reply_retries <= self.REPLY_ERROR_RETRIES:
|
301
|
+
replies = self.chat(messages.get(), display_reply=self.DISPLAY_REPLY)
|
302
|
+
reply = self.combine_replies(replies)
|
303
|
+
if reply is False:
|
304
|
+
reply_retries += 1
|
305
|
+
if reply_retries > self.REPLY_ERROR_RETRIES:
|
306
|
+
raise ValueError("Failed to get reply")
|
307
|
+
_W("Failed to get reply, retrying...")
|
308
|
+
elif not self.ACCEPT_EMPYT_REPLY and not reply:
|
309
|
+
reply_retries += 1
|
310
|
+
if reply_retries > self.REPLY_ERROR_RETRIES:
|
311
|
+
raise ValueError("Reply is empty")
|
312
|
+
_W("Reply is empty, retrying...")
|
313
|
+
else:
|
314
|
+
break
|
281
315
|
result = self.on_reply(reply)
|
282
316
|
flush_output()
|
283
317
|
if not isinstance(result, tuple):
|
@@ -300,25 +334,31 @@ class AgentFactory:
|
|
300
334
|
"model": model_name,
|
301
335
|
}
|
302
336
|
|
303
|
-
def
|
304
|
-
|
337
|
+
def get_agent_class(self, agent_class):
|
305
338
|
if isinstance(agent_class, str):
|
306
339
|
bot_agents = importlib.import_module("..bot_agents", __package__)
|
307
340
|
agent_class = getattr(bot_agents, agent_class)
|
341
|
+
assert issubclass(agent_class, BaseAgent), "Unsupported agent class: {}".format(agent_class)
|
342
|
+
return agent_class
|
308
343
|
|
344
|
+
def get_chat_kwargs(self, agent_class):
|
309
345
|
if issubclass(agent_class, BaseChatAgent):
|
310
346
|
agent_model = agent_class.MODEL_TYPE if hasattr(agent_class, "MODEL_TYPE") else AgentModelType.DEFAULT
|
311
|
-
|
312
|
-
|
313
|
-
base_url=self.models.get(agent_model, {}).get("api_url")
|
347
|
+
chat_kwargs = {
|
348
|
+
"base_url": self.models.get(agent_model, {}).get("api_url")
|
314
349
|
or self.models[AgentModelType.DEFAULT]["api_url"],
|
315
|
-
api_key
|
350
|
+
"api_key": self.models.get(agent_model, {}).get("api_key")
|
316
351
|
or self.models[AgentModelType.DEFAULT]["api_key"],
|
317
|
-
model_name
|
352
|
+
"model_name": self.models.get(agent_model, {}).get("model")
|
318
353
|
or self.models[AgentModelType.DEFAULT]["model"],
|
319
|
-
|
320
|
-
)
|
321
|
-
|
322
|
-
return agent_class(notebook_context=self.notebook_context)
|
354
|
+
}
|
355
|
+
chat_kwargs.update(self.chat_kwargs)
|
356
|
+
return chat_kwargs
|
323
357
|
else:
|
324
|
-
|
358
|
+
return {}
|
359
|
+
|
360
|
+
def __call__(self, agent_class):
|
361
|
+
|
362
|
+
agent_class = self.get_agent_class(agent_class)
|
363
|
+
chat_kwargs = self.get_chat_kwargs(agent_class)
|
364
|
+
return agent_class(self.notebook_context, **chat_kwargs)
|
@@ -8,6 +8,7 @@ https://opensource.org/licenses/MIT
|
|
8
8
|
from IPython.display import Markdown
|
9
9
|
from .base import BaseChatAgent, AgentModelType
|
10
10
|
from ..bot_outputs import _C, ReplyType
|
11
|
+
from ..bot_evaluators.dummy_task import DummyTaskEvaluator
|
11
12
|
|
12
13
|
MASTER_PLANNER_PROMPT = """\
|
13
14
|
**角色定义**:
|
@@ -40,6 +41,7 @@ class MasterPlannerAgent(BaseChatAgent):
|
|
40
41
|
PROMPT = MASTER_PLANNER_PROMPT
|
41
42
|
DISPLAY_REPLY = False
|
42
43
|
MODEL_TYPE = AgentModelType.PLANNER
|
44
|
+
EVALUATORS = {None: DummyTaskEvaluator}
|
43
45
|
|
44
46
|
def on_reply(self, reply):
|
45
47
|
_C(Markdown(reply), reply_type=ReplyType.TASK_RESULT)
|
@@ -17,13 +17,12 @@ class OutputTaskResult(BaseAgent):
|
|
17
17
|
def __call__(self):
|
18
18
|
"""执行代码逻辑"""
|
19
19
|
if self.task.result:
|
20
|
-
|
21
|
-
_C(Markdown(self.task.result), reply_type=ReplyType.TASK_RESULT)
|
20
|
+
_M("### 任务结果\n\n" + self.task.result)
|
22
21
|
if self.task.important_infos:
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
22
|
+
_B(
|
23
|
+
json.dumps(self.task.important_infos, indent=4, ensure_ascii=False),
|
24
|
+
title="重要信息",
|
25
|
+
format="code",
|
26
|
+
code_language="json",
|
28
27
|
)
|
29
28
|
return False, None
|