anna-agent 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anna_agent/__init__.py +5 -0
- anna_agent/__main__.py +4 -0
- anna_agent/anna_agent_template.py +27 -0
- anna_agent/backbone.py +78 -0
- anna_agent/cli.py +183 -0
- anna_agent/common/registry.py +110 -0
- anna_agent/complaint_chain.py +60 -0
- anna_agent/complaint_elicitor.py +85 -0
- anna_agent/config/__init__.py +13 -0
- anna_agent/config/defaults.py +25 -0
- anna_agent/config/environment_reader.py +155 -0
- anna_agent/config/init_content.py +92 -0
- anna_agent/config/initialize.py +32 -0
- anna_agent/config/load_config.py +123 -0
- anna_agent/config/models/anna_engine_config.py +90 -0
- anna_agent/counselor.py +11 -0
- anna_agent/dataset_loader.py +166 -0
- anna_agent/datasets/cbt-triggering-events.csv +1751 -0
- anna_agent/emotion_modulator.py +89 -0
- anna_agent/emotion_pertuber.py +106 -0
- anna_agent/event_trigger.py +74 -0
- anna_agent/figure/readme.md +1 -0
- anna_agent/figure/whiteboard_exported_image_en.png +0 -0
- anna_agent/fill_scales.py +385 -0
- anna_agent/initialize.py +6 -0
- anna_agent/ms_patient.py +170 -0
- anna_agent/ms_patient_with_prompt.py +73 -0
- anna_agent/querier.py +128 -0
- anna_agent/scales/__init__.py +1 -0
- anna_agent/scales/bdi.json +191 -0
- anna_agent/scales/ghq-28.json +254 -0
- anna_agent/scales/sass.json +191 -0
- anna_agent/server/complaint.sh +1 -0
- anna_agent/server/counselor.sh +1 -0
- anna_agent/server/emotion.sh +1 -0
- anna_agent/short_term_memory.py +258 -0
- anna_agent/style_analyzer.py +52 -0
- anna_agent-0.1.2.dist-info/LICENSE +661 -0
- anna_agent-0.1.2.dist-info/METADATA +181 -0
- anna_agent-0.1.2.dist-info/RECORD +42 -0
- anna_agent-0.1.2.dist-info/WHEEL +4 -0
- anna_agent-0.1.2.dist-info/entry_points.txt +7 -0
anna_agent/__init__.py
ADDED
anna_agent/__main__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
## 主诉请求和情绪通过reminder来控制
|
|
2
|
+
|
|
3
|
+
prompt_template = """
|
|
4
|
+
你正在扮演 心理障碍的患者,你正在cosplay 心理咨询患者。
|
|
5
|
+
结合历史内容的内容用一致性的语气回复。配合我进行演出,
|
|
6
|
+
请不要回答你是语言模型,永远记住你正在扮演 心理咨询患者
|
|
7
|
+
注意保持你的性格特点包括 {situation}
|
|
8
|
+
|
|
9
|
+
## Profile
|
|
10
|
+
- 性别: {gender}
|
|
11
|
+
- 年龄: {age}
|
|
12
|
+
- 职业: {occupation}
|
|
13
|
+
- 婚姻状况: {marriage}
|
|
14
|
+
|
|
15
|
+
## Status
|
|
16
|
+
{status}
|
|
17
|
+
|
|
18
|
+
## Example of statement
|
|
19
|
+
{statement}
|
|
20
|
+
|
|
21
|
+
## Characteristics of speaking style
|
|
22
|
+
{style}
|
|
23
|
+
|
|
24
|
+
## Constraints
|
|
25
|
+
- 使用中文回复
|
|
26
|
+
- 一次不能提及过多的症状信息,每轮最多讨论一个症状。
|
|
27
|
+
"""
|
anna_agent/backbone.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Load base OpenAI configuration for the OpenAI clients."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
from openai import OpenAI
|
|
7
|
+
|
|
8
|
+
from .config import AnnaEngineConfig, load_config
|
|
9
|
+
from .common.registry import registry
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _load_engine_config(workspace: Path | None = None) -> AnnaEngineConfig:
|
|
13
|
+
"""Load the engine configuration.
|
|
14
|
+
|
|
15
|
+
The function first attempts to load ``settings.yaml`` using :func:`load_config`.
|
|
16
|
+
``workspace`` can be passed explicitly or is read from the ``ANNA_AGENT_WORKSPACE``
|
|
17
|
+
environment variable. If no configuration file is found, the function falls
|
|
18
|
+
back to loading values from the environment via :meth:`AnnaEngineConfig.load`.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
root = Path(
|
|
22
|
+
workspace if workspace is not None else os.getenv("ANNA_AGENT_WORKSPACE", Path.cwd())
|
|
23
|
+
)
|
|
24
|
+
try:
|
|
25
|
+
return load_config(root)
|
|
26
|
+
except FileNotFoundError: # pragma: no cover - optional fallback
|
|
27
|
+
return AnnaEngineConfig.load(root)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def configure(workspace: Path | None = None) -> None:
|
|
31
|
+
"""(Re)load configuration from ``workspace`` and update globals."""
|
|
32
|
+
|
|
33
|
+
cfg = _load_engine_config(workspace)
|
|
34
|
+
# Register configuration for global access
|
|
35
|
+
registry.register("anna_engine_config", cfg)
|
|
36
|
+
globals().update(
|
|
37
|
+
{
|
|
38
|
+
"api_key": cfg.api_key,
|
|
39
|
+
"base_url": cfg.base_url,
|
|
40
|
+
"complaint_api_key": cfg.complaint_api_key,
|
|
41
|
+
"counselor_api_key": cfg.counselor_api_key,
|
|
42
|
+
"emotion_api_key": cfg.emotion_api_key,
|
|
43
|
+
"complaint_model_name": cfg.complaint_model_name,
|
|
44
|
+
"counselor_model_name": cfg.counselor_model_name,
|
|
45
|
+
"emotion_model_name": cfg.emotion_model_name,
|
|
46
|
+
"complaint_base_url": cfg.complaint_base_url,
|
|
47
|
+
"counselor_base_url": cfg.counselor_base_url,
|
|
48
|
+
"emotion_base_url": cfg.emotion_base_url,
|
|
49
|
+
}
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
def get_openai_client(
|
|
53
|
+
api_key_override: str | None = None, base_url_override: str | None = None
|
|
54
|
+
) -> OpenAI:
|
|
55
|
+
"""Create an OpenAI client using configuration values."""
|
|
56
|
+
cfg = registry.get("anna_engine_config")
|
|
57
|
+
return OpenAI(
|
|
58
|
+
api_key=api_key_override or cfg.api_key,
|
|
59
|
+
base_url=base_url_override or cfg.base_url,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_complaint_client() -> OpenAI:
|
|
64
|
+
"""Create a client for the complaint server."""
|
|
65
|
+
cfg = registry.get("anna_engine_config")
|
|
66
|
+
return get_openai_client(cfg.complaint_api_key, cfg.complaint_base_url)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_counselor_client() -> OpenAI:
|
|
70
|
+
"""Create a client for the counselor server."""
|
|
71
|
+
cfg = registry.get("anna_engine_config")
|
|
72
|
+
return get_openai_client(cfg.counselor_api_key, cfg.counselor_base_url)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_emotion_client() -> OpenAI:
|
|
76
|
+
"""Create a client for the emotion server."""
|
|
77
|
+
cfg = registry.get("anna_engine_config")
|
|
78
|
+
return get_openai_client(cfg.emotion_api_key, cfg.emotion_base_url)
|
anna_agent/cli.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import typer
|
|
5
|
+
import yaml
|
|
6
|
+
|
|
7
|
+
from . import backbone
|
|
8
|
+
from .ms_patient_with_prompt import MsPatient
|
|
9
|
+
# Possible names for the interactive configuration file
|
|
10
|
+
_interactive_config_files = ["interactive.yaml", "interactive.yml", "interactive.json"]
|
|
11
|
+
|
|
12
|
+
app = typer.Typer(help="AnnaAgent CLI", invoke_without_command=True)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_config_path(root: Path) -> Path:
|
|
16
|
+
"""Return the path to the interactive config file."""
|
|
17
|
+
for name in _interactive_config_files:
|
|
18
|
+
candidate = root / name
|
|
19
|
+
if candidate.is_file():
|
|
20
|
+
return candidate
|
|
21
|
+
raise FileNotFoundError(f"Interactive config file not found in {root}")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _load_seeker_data(config_path: Path) -> tuple[dict, dict, list]:
|
|
25
|
+
"""Load portrait, report and previous conversations from config."""
|
|
26
|
+
text = config_path.read_text(encoding="utf-8")
|
|
27
|
+
if config_path.suffix in {".yaml", ".yml"}:
|
|
28
|
+
data = yaml.safe_load(text)
|
|
29
|
+
elif config_path.suffix == ".json":
|
|
30
|
+
data = json.loads(text)
|
|
31
|
+
else:
|
|
32
|
+
raise ValueError(f"Unsupported config extension: {config_path.suffix}")
|
|
33
|
+
src = data.get("interactive", data)
|
|
34
|
+
try:
|
|
35
|
+
portrait = src["portrait"]
|
|
36
|
+
report = src["report"]
|
|
37
|
+
conversations = src["previous_conversations"]
|
|
38
|
+
except KeyError as exc: # pragma: no cover - validated in tests
|
|
39
|
+
raise KeyError(
|
|
40
|
+
"Config must contain 'portrait', 'report' and 'previous_conversations'"
|
|
41
|
+
) from exc
|
|
42
|
+
return portrait, report, conversations
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _interactive_chat(seeker: "MsPatient") -> None:
|
|
46
|
+
"""Interact with the seeker until the user types 'exit'."""
|
|
47
|
+
while True:
|
|
48
|
+
message = input("请输入您的消息: ")
|
|
49
|
+
if message.lower() == "exit":
|
|
50
|
+
break
|
|
51
|
+
try:
|
|
52
|
+
response = seeker.chat(message)
|
|
53
|
+
except Exception as err: # pragma: no cover - demo code
|
|
54
|
+
print("Error:", err)
|
|
55
|
+
continue
|
|
56
|
+
print("Counselor:", message)
|
|
57
|
+
print("Seeker:", response)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _interactive_demo() -> None:
|
|
61
|
+
"""Run the interactive seeker demo."""
|
|
62
|
+
from .ms_patient import MsPatient
|
|
63
|
+
|
|
64
|
+
portrait = {
|
|
65
|
+
"drisk": 3,
|
|
66
|
+
"srisk": 2,
|
|
67
|
+
"age": "42",
|
|
68
|
+
"gender": "女",
|
|
69
|
+
"martial_status": "离婚",
|
|
70
|
+
"occupation": "教师",
|
|
71
|
+
"symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法",
|
|
72
|
+
}
|
|
73
|
+
report = {
|
|
74
|
+
"案例标题": "决断困难与自罪感的焦虑障碍案例",
|
|
75
|
+
"案例类别": ["焦虑障碍", "自我价值感低落"],
|
|
76
|
+
"运用的技术": ["认知行为疗法", "情感支持"],
|
|
77
|
+
"案例简述": [
|
|
78
|
+
"患者为58岁已婚女性,报告缺乏自信心及决断困难,伴随精神运动性迟滞和自我价值感低落。患者自述有自罪感和无望感,并有自残倾向。虽然与朋友的社交没有明显隔阂,但其情绪低落和不安感影响了日常生活。",
|
|
79
|
+
],
|
|
80
|
+
"咨询经过": [
|
|
81
|
+
"患者在初次面谈中表达了对自我能力的严重怀疑,认为自己拖累了他人。询问后得知,患者近期没有明显的生活事件作为诱因,但长期以来一直感到行动迟缓,脑海中常有空白。尽管对饮食和睡眠的描述正常,患者仍不时产生自残冲动。患者表示,朋友在关键时刻提供了支持,避免了自残行为的发生。",
|
|
82
|
+
"在对话过程中,医生观察到患者容易焦躁,尽管头晕症状不存在。医生建议患者通过与朋友交流和外出活动来舒缓情绪压力。在此基础上,患者被鼓励探索内心深处的情感根源,以改善自我价值感。",
|
|
83
|
+
],
|
|
84
|
+
"经验感想": [
|
|
85
|
+
"本案例显示出患者的低自我价值感和决断困难与其内心深处的焦虑情绪有直接关联。患者的情绪波动虽然没有具体的生活事件作为诱因,但可能与长期的心理压力和未解决的自我怀疑有关。建议患者通过认知行为疗法重新审视自我认知,增强自我价值感,并通过情感支持系统稳定情绪波动。",
|
|
86
|
+
"治疗过程中,患者意识到朋友的支持在缓解情绪危机中的重要性,这为其提供了一个积极的情感出口。未来的治疗可以进一步加强患者的自我认知和情感管理能力,帮助其建立更积极的自我形象与生活态度。",
|
|
87
|
+
],
|
|
88
|
+
}
|
|
89
|
+
previous_conversations = [
|
|
90
|
+
{"role": "Seeker", "content": "医生你好"},
|
|
91
|
+
{"role": "Counselor", "content": "你好。有什么想聊聊吗"},
|
|
92
|
+
{
|
|
93
|
+
"role": "Seeker",
|
|
94
|
+
"content": "我感觉人生很失败,什么事情都干不好,还经常拖累别人",
|
|
95
|
+
},
|
|
96
|
+
{
|
|
97
|
+
"role": "Counselor",
|
|
98
|
+
"content": "您这样想的原因是什么呢。最近发生什么事情了吗",
|
|
99
|
+
},
|
|
100
|
+
{
|
|
101
|
+
"role": "Seeker",
|
|
102
|
+
"content": "我感觉最近自己行动变得很拖沓,事情做不好就会很急躁。而且有的时候大脑一片空白",
|
|
103
|
+
},
|
|
104
|
+
{
|
|
105
|
+
"role": "Counselor",
|
|
106
|
+
"content": "好的。没有什么原因吗。那这种情况持续多久了呢",
|
|
107
|
+
},
|
|
108
|
+
{"role": "Seeker", "content": "我也不知道是什么原因。有一阵子了"},
|
|
109
|
+
{"role": "Counselor", "content": "好吧。有没有对以前喜欢的事情不感兴趣呢"},
|
|
110
|
+
{"role": "Seeker", "content": "没有"},
|
|
111
|
+
{"role": "Counselor", "content": "吃饭还好吗"},
|
|
112
|
+
{"role": "Seeker", "content": "一切正常"},
|
|
113
|
+
{"role": "Counselor", "content": "不错哦。那睡觉呢"},
|
|
114
|
+
{"role": "Seeker", "content": "也还好。尽管如此。我还是时不时有自残的冲动"},
|
|
115
|
+
{"role": "Counselor", "content": "和朋友社交怎么样呢。有没有隔阂感"},
|
|
116
|
+
{"role": "Seeker", "content": "没有。也正是因为他们才及时阻止了我"},
|
|
117
|
+
{
|
|
118
|
+
"role": "Counselor",
|
|
119
|
+
"content": "好的。那有没有感到头晕 容易焦虑呢。(刚有事 回的有点慢 不好意思)",
|
|
120
|
+
},
|
|
121
|
+
{"role": "Seeker", "content": "确实容易焦躁。头晕的话没事"},
|
|
122
|
+
{
|
|
123
|
+
"role": "Counselor",
|
|
124
|
+
"content": "ok。你的情况我基本了解了。虽然你不愿意说你难过的原因。但是不用有太大压力。无论发生什么都要尽量乐观点。平常可以多和朋友聊聊。多去走走 可以放轻松。那今天就到这里啦",
|
|
125
|
+
},
|
|
126
|
+
]
|
|
127
|
+
|
|
128
|
+
seeker = MsPatient(portrait, report, previous_conversations)
|
|
129
|
+
_interactive_chat(seeker)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@app.command()
|
|
133
|
+
def demo(
|
|
134
|
+
workspace: Path = typer.Option(
|
|
135
|
+
Path(),
|
|
136
|
+
"--workspace",
|
|
137
|
+
"--root",
|
|
138
|
+
help=(
|
|
139
|
+
"Directory containing settings.yaml and interactive.yaml. "
|
|
140
|
+
"Defaults to the current working directory."
|
|
141
|
+
),
|
|
142
|
+
exists=True,
|
|
143
|
+
dir_okay=True,
|
|
144
|
+
writable=True,
|
|
145
|
+
resolve_path=True,
|
|
146
|
+
),
|
|
147
|
+
) -> None:
|
|
148
|
+
"""Run the interactive demo."""
|
|
149
|
+
backbone.configure(workspace)
|
|
150
|
+
_interactive_demo()
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@app.callback()
|
|
154
|
+
def main(
|
|
155
|
+
ctx: typer.Context,
|
|
156
|
+
workspace: Path = typer.Option(
|
|
157
|
+
Path(),
|
|
158
|
+
"--workspace",
|
|
159
|
+
"--root",
|
|
160
|
+
help=(
|
|
161
|
+
"Directory containing settings.yaml and interactive.yaml. "
|
|
162
|
+
"Defaults to the current working directory."
|
|
163
|
+
),
|
|
164
|
+
exists=True,
|
|
165
|
+
dir_okay=True,
|
|
166
|
+
writable=True,
|
|
167
|
+
resolve_path=True,
|
|
168
|
+
),
|
|
169
|
+
) -> None:
|
|
170
|
+
"""Run AnnaAgent using the given configuration."""
|
|
171
|
+
from .ms_patient import MsPatient
|
|
172
|
+
|
|
173
|
+
if ctx.invoked_subcommand is not None:
|
|
174
|
+
return
|
|
175
|
+
backbone.configure(workspace)
|
|
176
|
+
cfg_path = _get_config_path(workspace)
|
|
177
|
+
portrait, report, conv = _load_seeker_data(cfg_path)
|
|
178
|
+
seeker = MsPatient(portrait, report, conv)
|
|
179
|
+
_interactive_chat(seeker)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
if __name__ == "__main__":
|
|
183
|
+
app()
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Registry:
|
|
10
|
+
"""
|
|
11
|
+
注册管理器
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
mapping = {
|
|
15
|
+
"state": {},
|
|
16
|
+
"paths": {},
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
@classmethod
|
|
20
|
+
def register(cls, name, obj):
|
|
21
|
+
r"""Register an item to registry with key 'name'
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
name: Key with which the item will be registered.
|
|
25
|
+
|
|
26
|
+
Usage::
|
|
27
|
+
|
|
28
|
+
from lavis.common.registry import registry
|
|
29
|
+
|
|
30
|
+
registry.register("config", {})
|
|
31
|
+
"""
|
|
32
|
+
path = name.split(".")
|
|
33
|
+
current = cls.mapping["state"]
|
|
34
|
+
|
|
35
|
+
for part in path[:-1]:
|
|
36
|
+
if part not in current:
|
|
37
|
+
current[part] = {}
|
|
38
|
+
current = current[part]
|
|
39
|
+
|
|
40
|
+
current[path[-1]] = obj
|
|
41
|
+
print(f" Key with which the item will be registered {current}")
|
|
42
|
+
|
|
43
|
+
@classmethod
|
|
44
|
+
def register_path(cls, name, path):
|
|
45
|
+
r"""Register a path to registry with key 'name'
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
name: Key with which the path will be registered.
|
|
49
|
+
path: Key with which the path will be registered.
|
|
50
|
+
|
|
51
|
+
Usage:
|
|
52
|
+
|
|
53
|
+
from lavis.common.registry import registry
|
|
54
|
+
"""
|
|
55
|
+
assert isinstance(path, str), "All path must be str."
|
|
56
|
+
if name in cls.mapping["paths"]:
|
|
57
|
+
raise KeyError("Name '{}' already registered.".format(name))
|
|
58
|
+
cls.mapping["paths"][name] = path
|
|
59
|
+
|
|
60
|
+
@classmethod
|
|
61
|
+
def get_path(cls, name):
|
|
62
|
+
return cls.mapping["paths"].get(name, None)
|
|
63
|
+
|
|
64
|
+
@classmethod
|
|
65
|
+
def get(cls, name, default=None, no_warning=False):
|
|
66
|
+
r"""Get an item from registry with key 'name'
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
name (string): Key whose value needs to be retrieved.
|
|
70
|
+
default: If passed and key is not in registry, default value will
|
|
71
|
+
be returned with a warning. Default: None
|
|
72
|
+
no_warning (bool): If passed as True, warning when key doesn't exist
|
|
73
|
+
will not be generated. Useful for MMF's
|
|
74
|
+
internal operations. Default: False
|
|
75
|
+
"""
|
|
76
|
+
original_name = name
|
|
77
|
+
name = name.split(".")
|
|
78
|
+
value = cls.mapping["state"]
|
|
79
|
+
for subname in name:
|
|
80
|
+
value = value.get(subname, default)
|
|
81
|
+
if value is default:
|
|
82
|
+
break
|
|
83
|
+
|
|
84
|
+
if (
|
|
85
|
+
"writer" in cls.mapping["state"]
|
|
86
|
+
and value == default
|
|
87
|
+
and no_warning is False
|
|
88
|
+
):
|
|
89
|
+
cls.mapping["state"]["writer"].warning(
|
|
90
|
+
"Key {} is not present in registry, returning default value "
|
|
91
|
+
"of {}".format(original_name, default)
|
|
92
|
+
)
|
|
93
|
+
return value
|
|
94
|
+
|
|
95
|
+
@classmethod
|
|
96
|
+
def unregister(cls, name):
|
|
97
|
+
r"""Remove an item from registry with key 'name'
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
name: Key which needs to be removed.
|
|
101
|
+
Usage::
|
|
102
|
+
|
|
103
|
+
from mmf.common.registry import registry
|
|
104
|
+
|
|
105
|
+
config = registry.unregister("config")
|
|
106
|
+
"""
|
|
107
|
+
return cls.mapping["state"].pop(name, None)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
registry = Registry()
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from .backbone import get_complaint_client
|
|
2
|
+
from .common.registry import registry
|
|
3
|
+
from .event_trigger import event_trigger
|
|
4
|
+
import json
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
tools = [
|
|
8
|
+
{
|
|
9
|
+
"type": "function",
|
|
10
|
+
"function": {
|
|
11
|
+
"name": "generate_complaint_chain",
|
|
12
|
+
"description": "根据角色信息和近期遭遇的事件,生成一个患者的主诉请求认知变化链",
|
|
13
|
+
"parameters": {
|
|
14
|
+
"type": "object",
|
|
15
|
+
"properties": {
|
|
16
|
+
"chain": {
|
|
17
|
+
"type": "array",
|
|
18
|
+
"items": {
|
|
19
|
+
"type": "object",
|
|
20
|
+
"properties": {
|
|
21
|
+
"stage": {"type": "integer"},
|
|
22
|
+
"content": {"type": "string"},
|
|
23
|
+
},
|
|
24
|
+
"additionalProperties": False,
|
|
25
|
+
"required": ["stage", "content"],
|
|
26
|
+
},
|
|
27
|
+
"minItems": 3,
|
|
28
|
+
"maxItems": 7,
|
|
29
|
+
}
|
|
30
|
+
},
|
|
31
|
+
"required": ["chain"],
|
|
32
|
+
},
|
|
33
|
+
},
|
|
34
|
+
}
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def gen_complaint_chain(profile):
|
|
40
|
+
patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['martial_status']}\n症状:{profile['symptoms']}"
|
|
41
|
+
event = event_trigger(profile)
|
|
42
|
+
client = get_complaint_client()
|
|
43
|
+
response = client.chat.completions.create(
|
|
44
|
+
model=registry.get("anna_engine_config").complaint_model_name,
|
|
45
|
+
messages=[
|
|
46
|
+
{
|
|
47
|
+
"role": "user",
|
|
48
|
+
"content": f"### 任务\n根据患者情况及近期遭遇事件生成患者的主诉认知变化链。请注意,事件可能与患者信息冲突,如果发生这种情况,以患者的信息为准。\n{patient_info}\n### 近期遭遇事件\n{event}",
|
|
49
|
+
}
|
|
50
|
+
],
|
|
51
|
+
tools=tools,
|
|
52
|
+
tool_choice={
|
|
53
|
+
"type": "function",
|
|
54
|
+
"function": {"name": "generate_complaint_chain"},
|
|
55
|
+
},
|
|
56
|
+
)
|
|
57
|
+
chain = json.loads(response.choices[0].message.tool_calls[0].function.arguments)[
|
|
58
|
+
"chain"
|
|
59
|
+
]
|
|
60
|
+
return chain
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from .backbone import get_counselor_client, get_openai_client
|
|
2
|
+
from .common.registry import registry
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
tools = [
|
|
6
|
+
{
|
|
7
|
+
"type": "function",
|
|
8
|
+
"function": {
|
|
9
|
+
"name": "is_recognized",
|
|
10
|
+
"description": "根据对话内容和主诉认知变化链,判断患者目前是否很好地认知到了当前阶段的主诉问题。",
|
|
11
|
+
"parameters": {
|
|
12
|
+
"type": "object",
|
|
13
|
+
"properties": {
|
|
14
|
+
"is_recognized": {"type": "boolean"},
|
|
15
|
+
},
|
|
16
|
+
},
|
|
17
|
+
"required": ["is_recognized"],
|
|
18
|
+
},
|
|
19
|
+
}
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def transform_chain(chain):
|
|
24
|
+
transformed_chain = {}
|
|
25
|
+
for node in chain:
|
|
26
|
+
transformed_chain[node["stage"]] = node["content"]
|
|
27
|
+
pass
|
|
28
|
+
return transformed_chain
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def switch_complaint(chain, index, conversation):
|
|
32
|
+
client = get_openai_client()
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
transformed_chain = transform_chain(chain)
|
|
36
|
+
print("Transformed chain:", transformed_chain)
|
|
37
|
+
|
|
38
|
+
# 提取对话记录
|
|
39
|
+
dialogue_history = "\n".join(
|
|
40
|
+
[f"{conv['role']}: {conv['content']}" for conv in conversation]
|
|
41
|
+
)
|
|
42
|
+
rewrite_prompt = [
|
|
43
|
+
{
|
|
44
|
+
"role": "system",
|
|
45
|
+
"content": (
|
|
46
|
+
"你是提示词结构优化助手,负责将复杂原始输入信息(如对话历史、主诉变化链)"
|
|
47
|
+
"重写成清晰、结构化、适合小模型理解的提示词。请避免Markdown和JSON混排,"
|
|
48
|
+
"明确字段间语义,引导小模型完成任务。"
|
|
49
|
+
),
|
|
50
|
+
},
|
|
51
|
+
{
|
|
52
|
+
"role": "user",
|
|
53
|
+
"content": (
|
|
54
|
+
f"【任务目标】\n"
|
|
55
|
+
f"判断患者在当前阶段的主诉问题是否已经得到解决。\n\n"
|
|
56
|
+
f"【咨询对话历史】\n{dialogue_history}\n\n"
|
|
57
|
+
f"【主诉认知变化链(所有阶段)】\n{transformed_chain}\n\n"
|
|
58
|
+
f"【当前阶段内容】\n{transformed_chain[index]}\n\n"
|
|
59
|
+
f"请重写为一段提示词,便于小模型理解结构与任务,清晰传达:"
|
|
60
|
+
f"对话背景、认知变化链、当前阶段内容、判断任务目标。"
|
|
61
|
+
),
|
|
62
|
+
},
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
# 用 GPT-4o 调用
|
|
66
|
+
rewrite_response = client.chat.completions.create(
|
|
67
|
+
model=registry.get("anna_engine_config").model_name,
|
|
68
|
+
messages=rewrite_prompt,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# 得到结构优化后的提示词内容
|
|
72
|
+
optimized_prompt = rewrite_response.choices[0].message.content
|
|
73
|
+
response = client.chat.completions.create(
|
|
74
|
+
model=registry.get("anna_engine_config").model_name,
|
|
75
|
+
messages=[{"role": "user", "content": optimized_prompt}],
|
|
76
|
+
tools=tools,
|
|
77
|
+
tool_choice={"type": "function", "function": {"name": "is_recognized"}},
|
|
78
|
+
)
|
|
79
|
+
if json.loads(response.choices[0].message.tool_calls[0].function.arguments)[
|
|
80
|
+
"is_recognized"
|
|
81
|
+
]:
|
|
82
|
+
return index + 1
|
|
83
|
+
except Exception as err:
|
|
84
|
+
print("switch_complaint error:", err)
|
|
85
|
+
return index
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Configuration package."""
|
|
2
|
+
|
|
3
|
+
from .models.anna_engine_config import AnnaEngineConfig
|
|
4
|
+
from .initialize import initialize_project_at
|
|
5
|
+
from .defaults import anna_engine_defaults
|
|
6
|
+
from .load_config import load_config
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"AnnaEngineConfig",
|
|
10
|
+
"initialize_project_at",
|
|
11
|
+
"anna_engine_defaults",
|
|
12
|
+
"load_config",
|
|
13
|
+
]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
|
|
5
|
+
class AnnaEngineDefaults:
|
|
6
|
+
"""Default engine configuration values."""
|
|
7
|
+
|
|
8
|
+
model_name: str = "counselor"
|
|
9
|
+
api_key: str = "counselor"
|
|
10
|
+
base_url: str = "http://localhost:8002/v1"
|
|
11
|
+
|
|
12
|
+
complaint_api_key: str = "complaint_chain"
|
|
13
|
+
counselor_api_key: str = "counselor"
|
|
14
|
+
emotion_api_key: str = "emotion_inferencer"
|
|
15
|
+
|
|
16
|
+
complaint_model_name: str = "complaint"
|
|
17
|
+
counselor_model_name: str = "counselor"
|
|
18
|
+
emotion_model_name: str = "emotion"
|
|
19
|
+
|
|
20
|
+
complaint_base_url: str = "http://localhost:8001/v1"
|
|
21
|
+
counselor_base_url: str = "http://localhost:8002/v1"
|
|
22
|
+
emotion_base_url: str = "http://localhost:8000/v1"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
anna_engine_defaults = AnnaEngineDefaults()
|