jupyter-agent 2025.6.100__py3-none-any.whl → 2025.6.101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jupyter_agent/__init__.py +0 -0
- jupyter_agent/bot_agents/__init__.py +42 -0
- jupyter_agent/bot_agents/base.py +324 -0
- jupyter_agent/bot_agents/master_planner.py +45 -0
- jupyter_agent/bot_agents/output_task_result.py +29 -0
- jupyter_agent/bot_agents/task_code_executor.py +53 -0
- jupyter_agent/bot_agents/task_coder.py +71 -0
- jupyter_agent/bot_agents/task_debuger.py +69 -0
- jupyter_agent/bot_agents/task_planner_v1.py +158 -0
- jupyter_agent/bot_agents/task_planner_v2.py +172 -0
- jupyter_agent/bot_agents/task_planner_v3.py +189 -0
- jupyter_agent/bot_agents/task_reasoner.py +61 -0
- jupyter_agent/bot_agents/task_structrue_reasoner.py +106 -0
- jupyter_agent/bot_agents/task_structrue_summarier.py +123 -0
- jupyter_agent/bot_agents/task_summarier.py +76 -0
- jupyter_agent/bot_agents/task_verifier.py +99 -0
- jupyter_agent/bot_agents/task_verify_summarier.py +134 -0
- jupyter_agent/bot_chat.py +218 -0
- jupyter_agent/bot_contexts.py +466 -0
- jupyter_agent/bot_flows/__init__.py +20 -0
- jupyter_agent/bot_flows/base.py +209 -0
- jupyter_agent/bot_flows/master_planner.py +16 -0
- jupyter_agent/bot_flows/task_executor_v1.py +86 -0
- jupyter_agent/bot_flows/task_executor_v2.py +84 -0
- jupyter_agent/bot_flows/task_executor_v3.py +89 -0
- jupyter_agent/bot_magics.py +127 -0
- jupyter_agent/bot_outputs.py +480 -0
- jupyter_agent/utils.py +138 -0
- {jupyter_agent-2025.6.100.dist-info → jupyter_agent-2025.6.101.dist-info}/METADATA +13 -7
- jupyter_agent-2025.6.101.dist-info/RECORD +33 -0
- jupyter_agent-2025.6.101.dist-info/top_level.txt +1 -0
- jupyter_agent-2025.6.100.dist-info/RECORD +0 -5
- jupyter_agent-2025.6.100.dist-info/top_level.txt +0 -1
- {jupyter_agent-2025.6.100.dist-info → jupyter_agent-2025.6.101.dist-info}/WHEEL +0 -0
- {jupyter_agent-2025.6.100.dist-info → jupyter_agent-2025.6.101.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,123 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 viewstar000
|
3
|
+
|
4
|
+
This software is released under the MIT License.
|
5
|
+
https://opensource.org/licenses/MIT
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
|
10
|
+
from enum import Enum
|
11
|
+
from typing import List, Optional, Dict, Any
|
12
|
+
from pydantic import BaseModel, Field
|
13
|
+
from IPython.display import Markdown
|
14
|
+
from .base import BaseChatAgent, AgentOutputFormat
|
15
|
+
from ..bot_outputs import ReplyType, _D, _I, _W, _E, _F, _M, _B, _C, _O, markdown_block
|
16
|
+
from ..utils import RequestUserPrompt, format_user_prompts
|
17
|
+
|
18
|
+
TASK_SUMMARY_PROMPT = """\
|
19
|
+
**角色定义**:
|
20
|
+
|
21
|
+
你是一个信息提炼专家,能够从分析结果中提取关键结论。
|
22
|
+
|
23
|
+
**任务要求**:
|
24
|
+
|
25
|
+
- 将代码执行的输出与结果转化为**人类可读的总结**
|
26
|
+
- 包含以下内容:
|
27
|
+
1. 代码执行结果总结
|
28
|
+
2. 核心发现(如"Electronics类别月均增长12%")
|
29
|
+
3. 数据支撑(引用关键数值或图表)
|
30
|
+
4. 其它建议(如新子任务Prompt等)
|
31
|
+
- 在引用其它已完成的子任务的结果时,特别是其important_infos中的信息,要保证准确、清晰、完整,不要出现任何误导信息
|
32
|
+
|
33
|
+
注:任务代码执行的结果不会记录在全局上下文中,只有任务总结的结果会记录在全局上下文中,
|
34
|
+
因此任务总结中应包含对代码执行结果的简要说明,以便后续子任务使用。
|
35
|
+
|
36
|
+
{% include "TASK_OUTPUT_FORMAT" %}
|
37
|
+
|
38
|
+
---
|
39
|
+
|
40
|
+
{% include "TASK_CONTEXTS" %}
|
41
|
+
|
42
|
+
---
|
43
|
+
|
44
|
+
{% include "CODE_CONTEXTS" %}
|
45
|
+
|
46
|
+
---
|
47
|
+
|
48
|
+
**当前子任务信息**:
|
49
|
+
|
50
|
+
### 当前子任务目标:
|
51
|
+
{{ task.subject }}
|
52
|
+
|
53
|
+
### 当前子任务代码需求:
|
54
|
+
{{ task.coding_prompt }}
|
55
|
+
|
56
|
+
### 当前代码:
|
57
|
+
```python
|
58
|
+
{{ task.source }}
|
59
|
+
```
|
60
|
+
|
61
|
+
### 当前代码执行的输出与结果:
|
62
|
+
{{ task.output }}
|
63
|
+
|
64
|
+
### 当前任务总结要求:
|
65
|
+
{{ task.summary_prompt }}
|
66
|
+
|
67
|
+
---
|
68
|
+
|
69
|
+
请按要求输出任务总结:
|
70
|
+
"""
|
71
|
+
|
72
|
+
|
73
|
+
# Request asking the user to supply more detailed information.
# Documented with comments instead of a class docstring: pydantic copies a
# model docstring into the generated JSON schema description.
class RequestInfo(BaseModel):
    # Prompt text describing which details the user should add.
    prompt: str = Field(
        description="需要用户补充更详细的信息的 Prompt",
        examples=["请补充与...相关的详细的信息"],
    )
    # Optional example; defaults to None when not provided.
    example: Optional[str] = Field(
        default=None,
        description="示例",
        examples=["..."],
    )
|
76
|
+
|
77
|
+
|
78
|
+
# JSON reply schema for the structured task-summary agent.
# Documented with comments instead of a class docstring: pydantic copies a
# model docstring into the generated JSON schema description.
class TaskStructureSumaryOutput(BaseModel):
    # Required free-text summary of the task.
    summary: str = Field(
        description="任务总结的详细描述",
        examples=["..."],
    )

    # Optional structured key/value facts extracted from the summary.
    important_infos: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "任务总结中的重要信息,特别是需要后续子任务重点关注的信息。"
            "注意:该字段仅支持结构化信息,不能使用代码、长文本等非结构化信息"
        ),
        examples=[
            {
                "..._constraint": "...",
                "..._expression": "...",
                "..._patterns": ["...", "..."],
                "..._execution_strategies": ["...", "..."],
                "..._features": {"...": "...", "...": "..."},
                "..._mapping_rules": {"...": "...", "...": "..."},
                "...": "...",
            }
        ],
    )

    # Optional list of simple confirmation questions for the user.
    request_confirm_infos: Optional[List[RequestUserPrompt]] = Field(
        default=None,
        description="需要用户补充确认的信息,问题应尽量简单,只需要用户回答是/否或在备选项中选择即可",
    )
|
100
|
+
|
101
|
+
|
102
|
+
class TaskStructureSummaryAgent(BaseChatAgent):
    # Chat agent configured with the summary prompt and a JSON reply that is
    # parsed into TaskStructureSumaryOutput.

    PROMPT = TASK_SUMMARY_PROMPT
    OUTPUT_FORMAT = AgentOutputFormat.JSON
    OUTPUT_JSON_SCHEMA = TaskStructureSumaryOutput
    DISPLAY_REPLY = True

    def on_reply(self, reply: TaskStructureSumaryOutput):
        """Publish the structured summary and store it on the current task.

        The summary is emitted through ``_C`` as a ``TASK_RESULT`` reply and
        saved under the task data key ``"result"``. When present, the
        important infos are saved under ``"important_infos"`` and shown as a
        JSON block, and the confirmation prompts are shown after a heading.

        Raises:
            AssertionError: if ``reply.summary`` is empty.
        """
        summary_text = reply.summary
        assert summary_text, "Reply is empty"

        _C(Markdown("### 任务总结\n\n" + summary_text), reply_type=ReplyType.TASK_RESULT)
        self.task.set_data("result", summary_text)

        infos = reply.important_infos
        if infos:
            self.task.set_data("important_infos", infos)
            infos_json = json.dumps(infos, indent=4, ensure_ascii=False)
            _O(markdown_block("```json\n" + infos_json + "\n```", title="重要信息"))

        confirm_prompts = reply.request_confirm_infos
        if confirm_prompts:
            _O(Markdown("### 需要补充确认的信息\n"))
            rendered_prompts = format_user_prompts(confirm_prompts, title="用户补充确认信息")
            _O(Markdown(rendered_prompts))
|
@@ -0,0 +1,76 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 viewstar000
|
3
|
+
|
4
|
+
This software is released under the MIT License.
|
5
|
+
https://opensource.org/licenses/MIT
|
6
|
+
"""
|
7
|
+
|
8
|
+
from IPython.display import Markdown
|
9
|
+
from .base import BaseChatAgent, AgentOutputFormat
|
10
|
+
from ..bot_outputs import _D, _I, _W, _E, _F, _M, _B, _C
|
11
|
+
from ..bot_outputs import ReplyType
|
12
|
+
|
13
|
+
TASK_SUMMARY_PROMPT = """\
|
14
|
+
**角色定义**:
|
15
|
+
|
16
|
+
你是一个信息提炼专家,能够从分析结果中提取关键结论。
|
17
|
+
|
18
|
+
**任务要求**:
|
19
|
+
|
20
|
+
- 将代码执行的输出与结果转化为**人类可读的总结**
|
21
|
+
- 包含以下内容:
|
22
|
+
1. 代码执行结果总结
|
23
|
+
2. 核心发现(如"Electronics类别月均增长12%")
|
24
|
+
3. 数据支撑(引用关键数值或图表)
|
25
|
+
4. 其它建议(如新子任务Prompt等)
|
26
|
+
|
27
|
+
注:任务代码执行的结果不会记录在全局上下文中,只有任务总结的结果会记录在全局上下文中,
|
28
|
+
因此任务总结中应包含对代码执行结果的简要说明,以便后续子任务使用。
|
29
|
+
|
30
|
+
{% include "TASK_OUTPUT_FORMAT" %}
|
31
|
+
|
32
|
+
---
|
33
|
+
|
34
|
+
{% include "TASK_CONTEXTS" %}
|
35
|
+
|
36
|
+
---
|
37
|
+
|
38
|
+
{% include "CODE_CONTEXTS" %}
|
39
|
+
|
40
|
+
---
|
41
|
+
|
42
|
+
**当前子任务信息**:
|
43
|
+
|
44
|
+
### 当前子任务目标:
|
45
|
+
{{ task.subject }}
|
46
|
+
|
47
|
+
### 当前子任务代码需求:
|
48
|
+
{{ task.coding_prompt }}
|
49
|
+
|
50
|
+
### 当前代码:
|
51
|
+
```python
|
52
|
+
{{ task.source }}
|
53
|
+
```
|
54
|
+
|
55
|
+
### 当前代码执行的输出与结果:
|
56
|
+
{{ task.output }}
|
57
|
+
|
58
|
+
### 当前任务总结要求:
|
59
|
+
{{ task.summary_prompt }}
|
60
|
+
|
61
|
+
---
|
62
|
+
|
63
|
+
请按要求输出任务总结:
|
64
|
+
"""
|
65
|
+
|
66
|
+
|
67
|
+
class TaskSummaryAgent(BaseChatAgent):
    # Chat agent configured with the summary prompt and a plain-text reply.

    PROMPT = TASK_SUMMARY_PROMPT
    OUTPUT_FORMAT = AgentOutputFormat.TEXT
    DISPLAY_REPLY = False

    def on_reply(self, reply: str):
        """Publish the text summary and store it on the current task.

        The reply is emitted through ``_C`` as a ``TASK_RESULT`` reply and
        saved under the task data key ``"result"``.

        Raises:
            AssertionError: if ``reply`` is empty or ``None``.
        """
        # Validate before building the Markdown: concatenating first would
        # raise TypeError for a None reply and would display an empty summary
        # before rejecting it.
        assert reply, "Reply is empty"
        _C(Markdown("### 任务总结\n" + reply), reply_type=ReplyType.TASK_RESULT)
        self.task.set_data("result", reply)
|
@@ -0,0 +1,99 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 viewstar000
|
3
|
+
|
4
|
+
This software is released under the MIT License.
|
5
|
+
https://opensource.org/licenses/MIT
|
6
|
+
"""
|
7
|
+
|
8
|
+
from enum import Enum
|
9
|
+
from typing import List
|
10
|
+
from pydantic import BaseModel, Field
|
11
|
+
from IPython.display import Markdown
|
12
|
+
from .base import BaseChatAgent, AgentOutputFormat
|
13
|
+
from ..bot_outputs import _D, _I, _W, _E, _F, _M, _B, _C, _O
|
14
|
+
from ..bot_outputs import ReplyType
|
15
|
+
|
16
|
+
|
17
|
+
TASK_VERIFY_PROMPT = """\
|
18
|
+
**角色定义**:
|
19
|
+
|
20
|
+
你是一个数据质量检查员,负责验证子任务的输出与结果的正确性。
|
21
|
+
|
22
|
+
**任务要求**:
|
23
|
+
|
24
|
+
- 对比子任务Prompt的预期输出和实际结果,验证以下维度:
|
25
|
+
1. 数据完整性(如无缺失值、数据量合理)
|
26
|
+
2. 逻辑一致性(如增长率计算正确)
|
27
|
+
- 输出验证结果和改进建议(如需要重新运行子任务则标记为失败)
|
28
|
+
|
29
|
+
{% include "TASK_OUTPUT_FORMAT" %}
|
30
|
+
|
31
|
+
---
|
32
|
+
|
33
|
+
{% include "TASK_CONTEXTS" %}
|
34
|
+
|
35
|
+
---
|
36
|
+
|
37
|
+
{% include "CODE_CONTEXTS" %}
|
38
|
+
|
39
|
+
---
|
40
|
+
|
41
|
+
**当前子任务信息**:
|
42
|
+
|
43
|
+
### 当前子任务目标:
|
44
|
+
{{ task.subject }}
|
45
|
+
|
46
|
+
### 当前子任务代码需求:
|
47
|
+
{{ task.coding_prompt }}
|
48
|
+
|
49
|
+
### 当前代码:
|
50
|
+
```python
|
51
|
+
{{ task.source }}
|
52
|
+
```
|
53
|
+
|
54
|
+
### 当前输出:
|
55
|
+
{{ task.output }}
|
56
|
+
|
57
|
+
### 当前任务验证条件:
|
58
|
+
{{ task.verify_prompt }}
|
59
|
+
|
60
|
+
---
|
61
|
+
|
62
|
+
请按要求输出验证结果:
|
63
|
+
"""
|
64
|
+
|
65
|
+
|
66
|
+
# Outcome reported by the verification agent.
# Kept as a plain Enum (no str mixin), so members do not compare equal to
# their raw string values.
class TaskVerifyState(Enum):
    # Verification did not pass.
    FAILED = "failed"
    # Verification passed.
    PASSED = "passed"
|
69
|
+
|
70
|
+
|
71
|
+
# JSON reply schema for the verification agent.
# Documented with comments instead of a class docstring: pydantic copies a
# model docstring into the generated JSON schema description.
class TaskVerifyOutput(BaseModel):
    # Required verification outcome.
    state: TaskVerifyState = Field(
        description="任务验证结果",
        examples=[TaskVerifyState.PASSED.value],
    )
    # Problems found during verification; defaults to an empty list.
    issues: List[str] = Field(
        default=[],
        description="任务验证失败问题清单, 任务验证失败时必填, 任务验证通过时返回空列表",
        examples=[["...未包含...字段...", "..字段值缺失...", "...字段值未在合理范围内...", "..."]],
    )
|
78
|
+
|
79
|
+
|
80
|
+
class TaskVerifyAgent(BaseChatAgent):
    # Chat agent configured with the verify prompt and a JSON reply that is
    # parsed into TaskVerifyOutput.

    PROMPT = TASK_VERIFY_PROMPT
    OUTPUT_FORMAT = AgentOutputFormat.JSON
    OUTPUT_JSON_SCHEMA = TaskVerifyOutput

    def on_reply(self, reply: TaskVerifyOutput):
        """Report the verification outcome and record any issues on the task.

        Returns:
            A ``(flag, state)`` tuple: ``(False, state)`` when the state is
            ``PASSED`` and ``(True, state)`` otherwise. The flag presumably
            tells the caller that verification failed; confirm against the
            flow that consumes it.
        """
        if reply.state != TaskVerifyState.PASSED:
            _M("### 任务验证不通过!\n")
            # One Markdown bullet per issue; empty string when there are none.
            task_issue = "".join("- {}\n".format(item) for item in reply.issues or [])
            self.task.set_data("issue", task_issue)
            _M(task_issue)
            return True, reply.state

        _M("### 任务验证通过!")
        return False, reply.state
|
@@ -0,0 +1,134 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 viewstar000
|
3
|
+
|
4
|
+
This software is released under the MIT License.
|
5
|
+
https://opensource.org/licenses/MIT
|
6
|
+
"""
|
7
|
+
|
8
|
+
from enum import Enum
|
9
|
+
from typing import List, Optional
|
10
|
+
from pydantic import BaseModel, Field
|
11
|
+
from IPython.display import Markdown
|
12
|
+
from .base import BaseChatAgent, AgentOutputFormat
|
13
|
+
from ..bot_outputs import _D, _I, _W, _E, _F, _M, _B, _C, _O
|
14
|
+
from ..bot_outputs import ReplyType
|
15
|
+
|
16
|
+
|
17
|
+
TASK_SUMMARY_PROMPT = """\
|
18
|
+
**角色定义**:
|
19
|
+
|
20
|
+
你是一个任务总结规划专家,能够从分析结果中提取关键结论,并对任务结果进行总结分析,给出优化建议。
|
21
|
+
|
22
|
+
**任务要求**:
|
23
|
+
|
24
|
+
- 对任务代码的执行结果进行进一步的推理分析总结,并输出**人类可读的总结**,包含以下内容:
|
25
|
+
1. 代码执行结果总结
|
26
|
+
1. 核心发现(如"Electronics类别月均增长12%")
|
27
|
+
2. 数据支撑(引用关键数值或图表)
|
28
|
+
3. 其它建议(如新子任务Prompt等)
|
29
|
+
- 若代码的执行结果不满足当前子任务的要求,则输出**人类可读的修改建议**,包含以下内容:
|
30
|
+
1. 当前结果不满足子任务目标的具体原因
|
31
|
+
2. 修改后的代码生成Prompt,包括:
|
32
|
+
- 需生成的代码类型(如数据处理、建模、可视化等)
|
33
|
+
- 具体输入(数据、变量、参数等)
|
34
|
+
- 预期输出形式(变量名、图表、文本等)
|
35
|
+
- 代码执行的结果仅在当前子任务中可见,不会记录在全局上下文中
|
36
|
+
3. 修改后的分析总结Prompt,包括:
|
37
|
+
- 说明本子任务结果总结的要点和输出要素,以便后续子任务使用
|
38
|
+
- 验证总结的结果会记录在全局上下文中
|
39
|
+
|
40
|
+
注:任务代码执行的结果不会记录在全局上下文中,只有任务总结的结果会记录在全局上下文中,
|
41
|
+
因此任务总结中应包含对代码执行结果的简要说明,以便后续子任务使用。
|
42
|
+
|
43
|
+
{% include "TASK_OUTPUT_FORMAT" %}
|
44
|
+
|
45
|
+
---
|
46
|
+
|
47
|
+
{% include "TASK_CONTEXTS" %}
|
48
|
+
|
49
|
+
---
|
50
|
+
|
51
|
+
{% include "CODE_CONTEXTS" %}
|
52
|
+
|
53
|
+
---
|
54
|
+
|
55
|
+
**当前子任务信息**:
|
56
|
+
|
57
|
+
### 当前子任务目标:
|
58
|
+
{{ task.subject }}
|
59
|
+
|
60
|
+
### 当前子任务代码需求:
|
61
|
+
{{ task.coding_prompt }}
|
62
|
+
|
63
|
+
### 当前代码:
|
64
|
+
```python
|
65
|
+
{{ task.source }}
|
66
|
+
```
|
67
|
+
|
68
|
+
### 当前输出:
|
69
|
+
{{ task.output }}
|
70
|
+
|
71
|
+
### 当前任务总结要求:
|
72
|
+
{{ task.summary_prompt }}
|
73
|
+
|
74
|
+
---
|
75
|
+
|
76
|
+
请按要求输出验证结果:
|
77
|
+
"""
|
78
|
+
|
79
|
+
|
80
|
+
# Outcome reported by the verify-and-summarise agent.
# The str mixin makes each member compare equal to its raw string value.
class TaskSummaryState(str, Enum):
    # The summary was produced.
    SUCCESS = "success"
    # The result does not satisfy the sub-task.
    NOT_SATISFY = "not_satisfy"
|
83
|
+
|
84
|
+
|
85
|
+
# Suggested changes returned when a sub-task result is not satisfactory.
# Documented with comments instead of a class docstring: pydantic copies a
# model docstring into the generated JSON schema description.
class TaskEnhancement(BaseModel):
    # Reasons the current result falls short; defaults to an empty list.
    issues: List[str] = Field(
        default=[],
        description="当前子任务不满足要求的问题清单",
        examples=[["...", "..."]],
    )
    # Revised code-generation prompt; defaults to an empty string.
    code_prompt: str = Field(
        default="",
        description="修改后的代码生成Prompt",
        examples=["..."],
    )
    # Revised summary prompt; defaults to an empty string.
    summary_prompt: str = Field(
        default="",
        description="修改后的分析总结Prompt",
        examples=["..."],
    )
|
89
|
+
|
90
|
+
|
91
|
+
# JSON reply schema for the verify-and-summarise agent.
# Documented with comments instead of a class docstring: pydantic copies a
# model docstring into the generated JSON schema description.
class TaskSummaryOutput(BaseModel):
    # Required outcome of the summarisation step.
    state: TaskSummaryState = Field(description="是否完成总结", examples=[TaskSummaryState.SUCCESS.value])

    # Summary text. `.value` is interpolated explicitly because formatting a
    # (str, Enum) member inside an f-string yields either the raw value or
    # "TaskSummaryState.SUCCESS" depending on the Python version.
    summary: str = Field(
        "",
        description=f'任务总结的详细描述,在 state="{TaskSummaryState.SUCCESS.value}" 时必填',
        examples=["..."],
    )

    # Suggested changes. The example lists only keys that TaskEnhancement
    # actually declares (issues, code_prompt, summary_prompt).
    enhancement: Optional[TaskEnhancement] = Field(
        None,
        description=f"任务不满足要求时的修改建议,在 state='{TaskSummaryState.NOT_SATISFY.value}' 时必填",
        examples=[{"issues": ["...", "..."], "code_prompt": "...", "summary_prompt": "..."}],
    )
|
101
|
+
|
102
|
+
|
103
|
+
class TaskVerifySummaryAgent(BaseChatAgent):
    # Chat agent configured with the summary prompt and a JSON reply that is
    # parsed into TaskSummaryOutput.

    PROMPT = TASK_SUMMARY_PROMPT
    OUTPUT_FORMAT = AgentOutputFormat.JSON
    OUTPUT_JSON_SCHEMA = TaskSummaryOutput

    def on_reply(self, reply: TaskSummaryOutput):
        """Store the summary, or record the requested rework on the task.

        On ``SUCCESS`` the summary is emitted as a ``TASK_RESULT`` reply and
        saved under the task data key ``"result"``. Otherwise the issues and
        the revised prompts are saved under ``"issue"``, ``"coding_prompt"``
        and ``"summary_prompt"`` and then displayed.

        Returns:
            ``(False, state)`` on success, ``(True, state)`` otherwise. The
            flag presumably tells the caller that the task must be redone;
            confirm against the flow that consumes it.

        Raises:
            AssertionError: if the fields required for the reported state are
                empty.
        """
        if reply.state == TaskSummaryState.SUCCESS:
            _O(Markdown("### 任务总结"))
            summary_text = reply.summary
            assert summary_text, "Summary is empty"
            _C(Markdown(summary_text), reply_type=ReplyType.TASK_RESULT)
            self.task.set_data("result", summary_text)
            return False, reply.state

        _M("### 任务验证不通过!\n")
        assert reply.enhancement, "Enhancement is empty"
        enhancement = reply.enhancement
        assert enhancement.issues, "Issues is empty"
        assert enhancement.code_prompt, "Code prompt is empty"
        assert enhancement.summary_prompt, "Summary prompt is empty"

        # One Markdown bullet per issue.
        task_issue = "".join("- {}\n".format(item) for item in enhancement.issues or [])

        self.task.set_data("issue", task_issue)
        self.task.set_data("coding_prompt", enhancement.code_prompt)
        self.task.set_data("summary_prompt", enhancement.summary_prompt)

        _M(task_issue)
        _M("### 修改后的子任务信息\n")
        _M(f"### 当前子任务代码需求:\n\n{enhancement.code_prompt}")
        _M(f"### 当前子任务总结要求:\n\n{enhancement.summary_prompt}")
        return True, reply.state
|
@@ -0,0 +1,218 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 viewstar000
|
3
|
+
|
4
|
+
This software is released under the MIT License.
|
5
|
+
https://opensource.org/licenses/MIT
|
6
|
+
"""
|
7
|
+
|
8
|
+
import re
|
9
|
+
import json
|
10
|
+
import jinja2
|
11
|
+
import openai
|
12
|
+
|
13
|
+
|
14
|
+
from .bot_outputs import _D, _I, _W, _E, _F, _B, _M
|
15
|
+
|
16
|
+
|
17
|
+
class ChatMessages:
    """Builds the message list for a chat request.

    Text content is rendered as a Jinja template when a context is available,
    and consecutive parts with the same role are merged into one message.
    """

    def __init__(self, contexts=None, templates=None, display_message=True):
        """Store the settings and build the Jinja environment.

        Args:
            contexts: Default template variables used by ``add``.
            templates: Optional mapping handed to ``jinja2.DictLoader``.
            display_message: Whether ``add`` echoes each part through ``_B``.
        """
        self.messages = []
        self.contexts = contexts
        self.templates = templates
        self.display_message = display_message

        env_options = {"trim_blocks": True, "lstrip_blocks": True}
        if self.templates is not None:
            env_options["loader"] = jinja2.DictLoader(self.templates)
        self.jinja_env = jinja2.Environment(**env_options)

        def _to_json(value):
            # Template filter: pretty-print a value as non-ASCII-escaped JSON.
            return json.dumps(value, indent=2, ensure_ascii=False)

        self.jinja_env.filters["json"] = _to_json

    def add(self, content, role="user", content_type="text", tpl_context=None):
        """Append one content part, rendering it as a template when possible.

        Args:
            content: Text, or template source, of the part.
            role: Message role; a part whose role equals the last message's
                role is appended to that message.
            content_type: Only ``"text"`` is supported.
            tpl_context: Template variables; falls back to ``self.contexts``
                when falsy.

        Raises:
            NotImplementedError: for any ``content_type`` other than ``"text"``.
        """
        if content_type != "text":
            raise NotImplementedError

        render_vars = tpl_context or self.contexts
        if render_vars is not None:
            content = self.jinja_env.from_string(content).render(**render_vars)

        part = {"type": content_type, "text": content}
        _D("Adding message: role={}, content_type={}".format(role, content_type))
        if self.display_message:
            _B(content, title="Chat Message {}: {}".format(role, content_type))

        if self.messages and self.messages[-1]["role"] == role:
            self.messages[-1]["content"].append(part)
        else:
            self.messages.append({"role": role, "content": [part]})

    def get(self):
        """Return the accumulated message list itself, not a copy."""
        return self.messages

    def clear(self):
        """Replace the accumulated messages with a new empty list."""
        self.messages = []
|
52
|
+
|
53
|
+
|
54
|
+
class BotChat:
    """Chat mixin providing chat-related helpers."""

    display_think = True
    display_message = True
    display_response = False

    def __init__(self, base_url, api_key, model_name, **chat_kwargs):
        """Store the connection settings and display options.

        Args:
            base_url: Base URL handed to ``openai.OpenAI``.
            api_key: API key handed to ``openai.OpenAI``.
            model_name: Model name used for chat completion requests.
            **chat_kwargs: Optional ``display_think``, ``display_message`` and
                ``display_response`` overrides. The historical misspelling
                ``dispaly_think`` is still accepted.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.model_name = model_name
        # Accept the correctly spelled option while still honouring the
        # misspelled key and attribute that existing callers may rely on.
        show_think = chat_kwargs.get(
            "display_think", chat_kwargs.get("dispaly_think", self.display_think)
        )
        self.display_think = show_think
        self.dispaly_think = show_think
        self.display_message = chat_kwargs.get("display_message", self.display_message)
        self.display_response = chat_kwargs.get("display_response", self.display_response)

    def parse_reply(self, reply, ret_think_block=False, ret_empty_block=False, display_reply=True):
        """Split a reply into think, code, fence and text blocks.

        Args:
            reply: Raw reply text; ``None`` or ``""`` yields nothing.
            ret_think_block: Whether ``<think>`` blocks are yielded.
            ret_empty_block: Whether whitespace-only blocks are yielded.
            display_reply: Whether blocks with visible text are displayed.

        Yields:
            Dicts with ``type`` (``"think"``, ``"code"``, ``"fence"`` or
            ``"text"``), ``content`` and ``raw``; code blocks also carry
            ``lang``.
        """
        if not reply:
            # A missing or empty reply has no blocks; re.split would raise
            # TypeError on None.
            return

        def _has_text(value):
            return bool(value and value.strip())

        def _opens_think(token):
            return token == "<think>"

        def _opens_code(token):
            return token.startswith("```") and len(token) > 3

        def _read_block(tokens, closing, opens_nested):
            # Collect tokens up to the closing delimiter or the end of input.
            # NOTE(review): a nested block keeps its closing delimiter but
            # drops its opening token; this mirrors the previous behaviour.
            parts = []
            for token in tokens:
                if token is None:
                    continue
                if token == closing:
                    break
                if opens_nested(token):
                    parts.append(_read_block(tokens, closing, opens_nested))
                    parts.append(closing)
                else:
                    parts.append(token)
            return "".join(parts)

        # The capture groups keep delimiters in the result; groups that did
        # not take part in a match show up as None and are skipped.
        tokens = iter(re.split(r"(<think>)|(</think>)|(```[a-zA-Z_0-9]+)|(```)", reply))
        for token in tokens:
            if not token:
                continue
            if token == "<think>":
                think_block = _read_block(tokens, "</think>", _opens_think)
                raw_think_block = token + think_block + "</think>"
                show_think = getattr(self, "dispaly_think", self.display_think)
                if (show_think or display_reply) and _has_text(think_block):
                    _B(think_block, title="Thought Block")
                if ret_think_block and (ret_empty_block or _has_text(think_block)):
                    yield {"type": "think", "content": think_block, "raw": raw_think_block}
            elif _opens_code(token):
                content = _read_block(tokens, "```", _opens_code)
                raw_content = token + content + "```"
                lang = token[3:].lower()
                if display_reply and _has_text(content):
                    _B(content, title="Code Block", format="code", code_language=lang)
                if ret_empty_block or _has_text(content):
                    yield {"type": "code", "lang": lang, "content": content, "raw": raw_content}
            elif token == "```":
                content = _read_block(tokens, "```", _opens_code)
                raw_content = token + content + "```"
                if display_reply and _has_text(content):
                    _B(content, title="Fence Block", format="code", code_language="text")
                if ret_empty_block or _has_text(content):
                    yield {"type": "fence", "content": content, "raw": raw_content}
            else:
                if display_reply and _has_text(token):
                    _M(token)
                if ret_empty_block or _has_text(token):
                    yield {"type": "text", "content": token, "raw": token}

    def create_messages(self, contexts=None, templates=None):
        """Return a ChatMessages builder that uses this object's display_message setting."""
        return ChatMessages(contexts=contexts, templates=templates, display_message=self.display_message)

    def chat(
        self,
        messages,
        ret_think_block=False,
        ret_empty_block=False,
        display_reply=True,
        max_tokens=32 * 1024,
        max_completion_tokens=4 * 1024,
        temperature=0.8,
        n=1,
        **kwargs,
    ):
        """Send a chat completion request and return the parsed reply blocks.

        Args:
            messages: Message list passed to the chat completions API.
            ret_think_block: Forwarded to ``parse_reply``.
            ret_empty_block: Forwarded to ``parse_reply``.
            display_reply: Forwarded to ``parse_reply``.
            max_tokens: Forwarded to the API.
            max_completion_tokens: Forwarded to the API.
            temperature: Forwarded to the API.
            n: Forwarded to the API.
            **kwargs: Extra arguments forwarded to the API call.

        Returns:
            The list of blocks from ``parse_reply``, or ``[]`` when the
            response has no choice, no message or no text content.
        """
        # Size statistics are only logged, so they must not abort the request
        # on an empty list or on string content.
        sizes = []
        for message in messages or []:
            content = message.get("content") if isinstance(message, dict) else None
            if isinstance(content, str):
                sizes.append(len(content))
            elif content:
                sizes.extend(
                    len(part.get("text") or "") for part in content if isinstance(part, dict)
                )
        total_size = sum(sizes)
        _D("Total message size: {} chars, {}".format(total_size, sizes))
        _I("Connecting to OpenAI API: {}".format(self.base_url or "default"))
        openai_client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
        _I("Sending request to OpenAI API, model: {}".format(self.model_name))
        # NOTE(review): max_tokens and max_completion_tokens are both sent;
        # confirm that the target endpoint accepts the pair.
        response = openai_client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=max_tokens,
            max_completion_tokens=max_completion_tokens,
            temperature=temperature,
            n=n,
            **kwargs,
        )
        if not response.choices or not response.choices[0].message:
            _E("No valid response from OpenAI API")
            return []

        reply = response.choices[0].message.content
        _I("Received response from OpenAI API")
        _D("Response content: " + repr(reply)[:50])
        if reply is None:
            # The message exists but carries no text, so there is nothing to parse.
            _E("No text content in response from OpenAI API")
            return []
        if self.display_response:
            _B(reply, title="Chat Response")
        return list(
            self.parse_reply(
                reply,
                ret_think_block=ret_think_block,
                ret_empty_block=ret_empty_block,
                display_reply=display_reply,
            )
        )
|