pycityagent 2.0.0a6__py3-none-any.whl → 2.0.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +5 -3
- pycityagent/environment/interact/interact.py +86 -29
- pycityagent/environment/sence/static.py +3 -2
- pycityagent/environment/sim/aoi_service.py +1 -1
- pycityagent/environment/sim/economy_services.py +1 -1
- pycityagent/environment/sim/road_service.py +1 -1
- pycityagent/environment/sim/social_service.py +1 -1
- pycityagent/environment/simulator.py +6 -4
- pycityagent/environment/utils/__init__.py +5 -1
- pycityagent/llm/__init__.py +1 -1
- pycityagent/llm/embedding.py +36 -35
- pycityagent/llm/llm.py +197 -161
- pycityagent/llm/llmconfig.py +7 -9
- pycityagent/llm/utils.py +2 -2
- pycityagent/memory/memory.py +1 -2
- pycityagent/memory/memory_base.py +1 -2
- pycityagent/memory/profile.py +1 -2
- pycityagent/memory/self_define.py +1 -2
- pycityagent/memory/state.py +1 -2
- pycityagent/message/__init__.py +1 -1
- pycityagent/message/messager.py +11 -4
- pycityagent/simulation/__init__.py +1 -1
- pycityagent/simulation/agentgroup.py +13 -6
- pycityagent/simulation/interview.py +9 -5
- pycityagent/simulation/simulation.py +86 -33
- pycityagent/simulation/survey/__init__.py +1 -6
- pycityagent/simulation/survey/manager.py +22 -21
- pycityagent/simulation/survey/models.py +8 -5
- pycityagent/utils/decorators.py +14 -4
- pycityagent/utils/parsers/__init__.py +2 -1
- pycityagent/workflow/block.py +4 -3
- pycityagent/workflow/prompt.py +16 -9
- pycityagent/workflow/tool.py +1 -2
- pycityagent/workflow/trigger.py +36 -23
- {pycityagent-2.0.0a6.dist-info → pycityagent-2.0.0a7.dist-info}/METADATA +1 -1
- pycityagent-2.0.0a7.dist-info/RECORD +70 -0
- pycityagent-2.0.0a6.dist-info/RECORD +0 -70
- {pycityagent-2.0.0a6.dist-info → pycityagent-2.0.0a7.dist-info}/WHEEL +0 -0
pycityagent/memory/profile.py
CHANGED
@@ -3,8 +3,7 @@ Agent Profile
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
from copy import deepcopy
|
6
|
-
from typing import
|
7
|
-
Union, cast)
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
|
8
7
|
|
9
8
|
from ..utils.decorators import lock_decorator
|
10
9
|
from .const import *
|
@@ -3,8 +3,7 @@ Self Define Data
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
from copy import deepcopy
|
6
|
-
from typing import
|
7
|
-
Union, cast)
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
|
8
7
|
|
9
8
|
from ..utils.decorators import lock_decorator
|
10
9
|
from .const import *
|
pycityagent/memory/state.py
CHANGED
@@ -3,8 +3,7 @@ Agent State
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
from copy import deepcopy
|
6
|
-
from typing import
|
7
|
-
Union, cast)
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
|
8
7
|
|
9
8
|
from ..utils.decorators import lock_decorator
|
10
9
|
from .const import *
|
pycityagent/message/__init__.py
CHANGED
pycityagent/message/messager.py
CHANGED
@@ -4,9 +4,14 @@ import logging
|
|
4
4
|
import math
|
5
5
|
from aiomqtt import Client
|
6
6
|
|
7
|
+
|
7
8
|
class Messager:
|
8
|
-
def __init__(
|
9
|
-
self
|
9
|
+
def __init__(
|
10
|
+
self, hostname, port=1883, username=None, password=None, timeout=math.inf
|
11
|
+
):
|
12
|
+
self.client = Client(
|
13
|
+
hostname, port=port, username=username, password=password, timeout=timeout
|
14
|
+
)
|
10
15
|
self.connected = False # 是否已连接标志
|
11
16
|
self.message_queue = asyncio.Queue() # 用于存储接收到的消息
|
12
17
|
self.subscribers = {} # 订阅者信息,topic -> Agent 映射
|
@@ -31,7 +36,9 @@ class Messager:
|
|
31
36
|
|
32
37
|
async def subscribe(self, topic, agent):
|
33
38
|
if not self.is_connected():
|
34
|
-
logging.error(
|
39
|
+
logging.error(
|
40
|
+
f"Cannot subscribe to {topic} because not connected to the Broker."
|
41
|
+
)
|
35
42
|
return
|
36
43
|
await self.client.subscribe(topic)
|
37
44
|
self.subscribers[topic] = agent
|
@@ -48,7 +55,7 @@ class Messager:
|
|
48
55
|
while not self.message_queue.empty():
|
49
56
|
messages.append(await self.message_queue.get())
|
50
57
|
return messages
|
51
|
-
|
58
|
+
|
52
59
|
async def send_message(self, topic: str, payload: str, sender_id: int):
|
53
60
|
"""通过 Messager 发送消息,包含发送者 ID"""
|
54
61
|
# 构造消息,payload 中加入 sender_id 以便接收者识别
|
@@ -8,13 +8,19 @@ from pycityagent.llm.llm import LLM
|
|
8
8
|
from pycityagent.llm.llmconfig import LLMConfig
|
9
9
|
from pycityagent.message import Messager
|
10
10
|
|
11
|
+
|
11
12
|
@ray.remote
|
12
13
|
class AgentGroup:
|
13
14
|
def __init__(self, agents: list[Agent], config: dict, exp_id: str):
|
14
15
|
self.agents = agents
|
15
16
|
self.config = config
|
16
17
|
self.exp_id = exp_id
|
17
|
-
self.messager = Messager(
|
18
|
+
self.messager = Messager(
|
19
|
+
hostname=config["simulator_request"]["mqtt"]["server"],
|
20
|
+
port=config["simulator_request"]["mqtt"]["port"],
|
21
|
+
username=config["simulator_request"]["mqtt"].get("username", None),
|
22
|
+
password=config["simulator_request"]["mqtt"].get("password", None),
|
23
|
+
)
|
18
24
|
self.initialized = False
|
19
25
|
|
20
26
|
# Step:1 prepare LLM client
|
@@ -28,7 +34,9 @@ class AgentGroup:
|
|
28
34
|
|
29
35
|
# Step:3 prepare Economy client
|
30
36
|
logging.info("-----Creating Economy client in remote...")
|
31
|
-
self.economy_client = EconomyClient(
|
37
|
+
self.economy_client = EconomyClient(
|
38
|
+
config["simulator_request"]["economy"]["server"]
|
39
|
+
)
|
32
40
|
|
33
41
|
for agent in self.agents:
|
34
42
|
agent.set_exp_id(self.exp_id)
|
@@ -72,7 +80,7 @@ class AgentGroup:
|
|
72
80
|
|
73
81
|
# 添加解码步骤,将bytes转换为str
|
74
82
|
if isinstance(payload, bytes):
|
75
|
-
payload = payload.decode(
|
83
|
+
payload = payload.decode("utf-8")
|
76
84
|
# 提取 agent_id(主题格式为 "/exps/{exp_id}/agents/{agent_id}/chat")
|
77
85
|
_, _, _, agent_id, _ = topic.strip("/").split("/")
|
78
86
|
agent_id = int(agent_id)
|
@@ -96,15 +104,14 @@ class AgentGroup:
|
|
96
104
|
start_time = await self.simulator.get_time()
|
97
105
|
# 计算结束时间(秒)
|
98
106
|
end_time = start_time + day * 24 * 3600 # 将天数转换为秒
|
99
|
-
|
107
|
+
|
100
108
|
while True:
|
101
109
|
current_time = await self.simulator.get_time()
|
102
110
|
if current_time >= end_time:
|
103
111
|
break
|
104
|
-
|
112
|
+
|
105
113
|
await self.step()
|
106
114
|
|
107
115
|
except Exception as e:
|
108
116
|
logging.error(f"模拟器运行错误: {str(e)}")
|
109
117
|
raise
|
110
|
-
|
@@ -2,20 +2,24 @@ from dataclasses import dataclass
|
|
2
2
|
from datetime import datetime
|
3
3
|
from typing import List, Optional
|
4
4
|
|
5
|
+
|
5
6
|
@dataclass
|
6
7
|
class InterviewRecord:
|
7
8
|
"""采访记录"""
|
9
|
+
|
8
10
|
timestamp: datetime
|
9
11
|
agent_name: str
|
10
12
|
question: str
|
11
13
|
response: str
|
12
14
|
blocking: bool
|
13
15
|
|
16
|
+
|
14
17
|
class InterviewManager:
|
15
18
|
"""采访管理器"""
|
19
|
+
|
16
20
|
def __init__(self):
|
17
21
|
self._history: List[InterviewRecord] = []
|
18
|
-
|
22
|
+
|
19
23
|
def add_record(self, agent_name: str, question: str, response: str, blocking: bool):
|
20
24
|
"""添加采访记录"""
|
21
25
|
record = InterviewRecord(
|
@@ -23,14 +27,14 @@ class InterviewManager:
|
|
23
27
|
agent_name=agent_name,
|
24
28
|
question=question,
|
25
29
|
response=response,
|
26
|
-
blocking=blocking
|
30
|
+
blocking=blocking,
|
27
31
|
)
|
28
32
|
self._history.append(record)
|
29
|
-
|
33
|
+
|
30
34
|
def get_agent_history(self, agent_name: str) -> List[InterviewRecord]:
|
31
35
|
"""获取指定智能体的采访历史"""
|
32
36
|
return [r for r in self._history if r.agent_name == agent_name]
|
33
|
-
|
37
|
+
|
34
38
|
def get_recent_history(self, limit: int = 10) -> List[InterviewRecord]:
|
35
39
|
"""获取最近的采访记录"""
|
36
|
-
return sorted(self._history, key=lambda x: x.timestamp, reverse=True)[:limit]
|
40
|
+
return sorted(self._history, key=lambda x: x.timestamp, reverse=True)[:limit]
|
@@ -14,12 +14,16 @@ from .interview import InterviewManager
|
|
14
14
|
from .survey import QuestionType, SurveyManager
|
15
15
|
from .ui import InterviewUI
|
16
16
|
from .agentgroup import AgentGroup
|
17
|
+
|
17
18
|
logger = logging.getLogger(__name__)
|
18
19
|
|
19
20
|
|
20
21
|
class AgentSimulation:
|
21
22
|
"""城市智能体模拟器"""
|
22
|
-
|
23
|
+
|
24
|
+
def __init__(
|
25
|
+
self, agent_class: type[Agent], config: dict, agent_prefix: str = "agent_"
|
26
|
+
):
|
23
27
|
"""
|
24
28
|
Args:
|
25
29
|
agent_class: 智能体类
|
@@ -41,16 +45,21 @@ class AgentSimulation:
|
|
41
45
|
self._blocked_agents: List[str] = [] # 新增:持续阻塞的智能体列表
|
42
46
|
self._survey_manager = SurveyManager()
|
43
47
|
|
44
|
-
async def init_agents(
|
48
|
+
async def init_agents(
|
49
|
+
self,
|
50
|
+
agent_count: int,
|
51
|
+
group_size: int = 1000,
|
52
|
+
memory_config_func: Callable = None,
|
53
|
+
) -> None:
|
45
54
|
"""初始化智能体
|
46
|
-
|
55
|
+
|
47
56
|
Args:
|
48
57
|
agent_count: 要创建的总智能体数量
|
49
58
|
group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
|
50
59
|
memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组
|
51
60
|
"""
|
52
61
|
if memory_config_func is None:
|
53
|
-
memory_config_func = self.default_memory_config_func
|
62
|
+
memory_config_func = self.default_memory_config_func
|
54
63
|
|
55
64
|
for i in range(agent_count):
|
56
65
|
agent_name = f"{self.agent_prefix}{i}"
|
@@ -58,11 +67,9 @@ class AgentSimulation:
|
|
58
67
|
# 获取Memory配置
|
59
68
|
extra_attributes, profile, base = memory_config_func()
|
60
69
|
memory = Memory(
|
61
|
-
config=extra_attributes,
|
62
|
-
profile=profile.copy(),
|
63
|
-
base=base.copy()
|
70
|
+
config=extra_attributes, profile=profile.copy(), base=base.copy()
|
64
71
|
)
|
65
|
-
|
72
|
+
|
66
73
|
# 创建智能体时传入Memory配置
|
67
74
|
agent = self.agent_class(
|
68
75
|
name=agent_name,
|
@@ -73,12 +80,12 @@ class AgentSimulation:
|
|
73
80
|
|
74
81
|
# 计算需要的组数,向上取整以处理不足一组的情况
|
75
82
|
num_group = (agent_count + group_size - 1) // group_size
|
76
|
-
|
83
|
+
|
77
84
|
for i in range(num_group):
|
78
85
|
# 计算当前组的起始和结束索引
|
79
86
|
start_idx = i * group_size
|
80
87
|
end_idx = min((i + 1) * group_size, agent_count)
|
81
|
-
|
88
|
+
|
82
89
|
# 获取当前组的agents
|
83
90
|
agents = list(self._agents.values())[start_idx:end_idx]
|
84
91
|
group_name = f"{self.agent_prefix}_group_{i}"
|
@@ -89,39 +96,87 @@ class AgentSimulation:
|
|
89
96
|
"""默认的Memory配置函数"""
|
90
97
|
EXTRA_ATTRIBUTES = {
|
91
98
|
# 需求信息
|
92
|
-
"needs": (
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
99
|
+
"needs": (
|
100
|
+
dict,
|
101
|
+
{
|
102
|
+
"hungry": random.random(), # 饥饿感
|
103
|
+
"tired": random.random(), # 疲劳感
|
104
|
+
"safe": random.random(), # 安全需
|
105
|
+
"social": random.random(), # 社会需求
|
106
|
+
},
|
107
|
+
True,
|
108
|
+
),
|
98
109
|
"current_need": (str, "none", True),
|
99
110
|
"current_plan": (list, [], True),
|
100
111
|
"current_step": (dict, {"intention": "", "type": ""}, True),
|
101
|
-
"execution_context"
|
112
|
+
"execution_context": (dict, {}, True),
|
102
113
|
"plan_history": (list, [], True),
|
103
114
|
}
|
104
115
|
|
105
116
|
PROFILE = {
|
106
117
|
"gender": random.choice(["male", "female"]),
|
107
|
-
"education": random.choice(
|
118
|
+
"education": random.choice(
|
119
|
+
["Doctor", "Master", "Bachelor", "College", "High School"]
|
120
|
+
),
|
108
121
|
"consumption": random.choice(["sightly low", "low", "medium", "high"]),
|
109
|
-
"occupation": random.choice(
|
122
|
+
"occupation": random.choice(
|
123
|
+
[
|
124
|
+
"Student",
|
125
|
+
"Teacher",
|
126
|
+
"Doctor",
|
127
|
+
"Engineer",
|
128
|
+
"Manager",
|
129
|
+
"Businessman",
|
130
|
+
"Artist",
|
131
|
+
"Athlete",
|
132
|
+
"Other",
|
133
|
+
]
|
134
|
+
),
|
110
135
|
"age": random.randint(18, 65),
|
111
|
-
"skill": random.choice(
|
136
|
+
"skill": random.choice(
|
137
|
+
[
|
138
|
+
"Good at problem-solving",
|
139
|
+
"Good at communication",
|
140
|
+
"Good at creativity",
|
141
|
+
"Good at teamwork",
|
142
|
+
"Other",
|
143
|
+
]
|
144
|
+
),
|
112
145
|
"family_consumption": random.choice(["low", "medium", "high"]),
|
113
|
-
"personality": random.choice(
|
146
|
+
"personality": random.choice(
|
147
|
+
["outgoint", "introvert", "ambivert", "extrovert"]
|
148
|
+
),
|
114
149
|
"income": random.randint(1000, 10000),
|
115
150
|
"currency": random.randint(10000, 100000),
|
116
151
|
"residence": random.choice(["city", "suburb", "rural"]),
|
117
|
-
"race": random.choice(
|
118
|
-
|
119
|
-
|
120
|
-
|
152
|
+
"race": random.choice(
|
153
|
+
[
|
154
|
+
"Chinese",
|
155
|
+
"American",
|
156
|
+
"British",
|
157
|
+
"French",
|
158
|
+
"German",
|
159
|
+
"Japanese",
|
160
|
+
"Korean",
|
161
|
+
"Russian",
|
162
|
+
"Other",
|
163
|
+
]
|
164
|
+
),
|
165
|
+
"religion": random.choice(
|
166
|
+
["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]
|
167
|
+
),
|
168
|
+
"marital_status": random.choice(
|
169
|
+
["not married", "married", "divorced", "widowed"]
|
170
|
+
),
|
171
|
+
}
|
121
172
|
|
122
173
|
BASE = {
|
123
|
-
"home": {
|
124
|
-
|
174
|
+
"home": {
|
175
|
+
"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
|
176
|
+
},
|
177
|
+
"work": {
|
178
|
+
"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
|
179
|
+
},
|
125
180
|
}
|
126
181
|
|
127
182
|
return EXTRA_ATTRIBUTES, PROFILE, BASE
|
@@ -218,7 +273,7 @@ class AgentSimulation:
|
|
218
273
|
except Exception as e:
|
219
274
|
logger.error(f"采访过程出错: {str(e)}")
|
220
275
|
return f"采访过程出现错误: {str(e)}"
|
221
|
-
|
276
|
+
|
222
277
|
async def submit_survey(self, agent_name: str, survey_id: str) -> str:
|
223
278
|
"""向智能体提交问卷
|
224
279
|
|
@@ -319,9 +374,7 @@ class AgentSimulation:
|
|
319
374
|
prevent_thread_lock=True,
|
320
375
|
quiet=True,
|
321
376
|
)
|
322
|
-
logger.info(
|
323
|
-
f"Gradio Frontend is running on http://{server_name}:{server_port}"
|
324
|
-
)
|
377
|
+
logger.info(f"Gradio Frontend is running on http://{server_name}:{server_port}")
|
325
378
|
|
326
379
|
async def step(self):
|
327
380
|
"""运行一步, 即每个智能体执行一次forward"""
|
@@ -333,7 +386,7 @@ class AgentSimulation:
|
|
333
386
|
except Exception as e:
|
334
387
|
logger.error(f"运行错误: {str(e)}")
|
335
388
|
raise
|
336
|
-
|
389
|
+
|
337
390
|
async def run(
|
338
391
|
self,
|
339
392
|
day: int = 1,
|
@@ -348,7 +401,7 @@ class AgentSimulation:
|
|
348
401
|
tasks = []
|
349
402
|
for group in self._groups.values():
|
350
403
|
tasks.append(group.run.remote(day))
|
351
|
-
|
404
|
+
|
352
405
|
await asyncio.gather(*tasks)
|
353
406
|
|
354
407
|
except Exception as e:
|
@@ -4,14 +4,17 @@ import uuid
|
|
4
4
|
import json
|
5
5
|
from .models import Survey, Question, QuestionType
|
6
6
|
|
7
|
+
|
7
8
|
class SurveyManager:
|
8
9
|
def __init__(self):
|
9
10
|
self._surveys: Dict[str, Survey] = {}
|
10
|
-
|
11
|
-
def create_survey(
|
11
|
+
|
12
|
+
def create_survey(
|
13
|
+
self, title: str, description: str, questions: List[dict]
|
14
|
+
) -> Survey:
|
12
15
|
"""创建新问卷"""
|
13
16
|
survey_id = str(uuid.uuid4())
|
14
|
-
|
17
|
+
|
15
18
|
# 转换问题数据
|
16
19
|
survey_questions = []
|
17
20
|
for q in questions:
|
@@ -21,47 +24,45 @@ class SurveyManager:
|
|
21
24
|
required=q.get("required", True),
|
22
25
|
options=q.get("options", []),
|
23
26
|
min_rating=q.get("min_rating", 1),
|
24
|
-
max_rating=q.get("max_rating", 5)
|
27
|
+
max_rating=q.get("max_rating", 5),
|
25
28
|
)
|
26
29
|
survey_questions.append(question)
|
27
|
-
|
30
|
+
|
28
31
|
survey = Survey(
|
29
32
|
id=survey_id,
|
30
33
|
title=title,
|
31
34
|
description=description,
|
32
|
-
questions=survey_questions
|
35
|
+
questions=survey_questions,
|
33
36
|
)
|
34
|
-
|
37
|
+
|
35
38
|
self._surveys[survey_id] = survey
|
36
39
|
return survey
|
37
|
-
|
40
|
+
|
38
41
|
def get_survey(self, survey_id: str) -> Optional[Survey]:
|
39
42
|
"""获取指定问卷"""
|
40
43
|
return self._surveys.get(survey_id)
|
41
|
-
|
44
|
+
|
42
45
|
def get_all_surveys(self) -> List[Survey]:
|
43
46
|
"""获取所有问卷"""
|
44
47
|
return list(self._surveys.values())
|
45
|
-
|
48
|
+
|
46
49
|
def add_response(self, survey_id: str, agent_name: str, response: dict) -> bool:
|
47
50
|
"""添加问卷回答"""
|
48
51
|
survey = self.get_survey(survey_id)
|
49
52
|
if not survey:
|
50
53
|
return False
|
51
|
-
|
52
|
-
survey.responses[agent_name] = {
|
53
|
-
"timestamp": datetime.now(),
|
54
|
-
**response
|
55
|
-
}
|
54
|
+
|
55
|
+
survey.responses[agent_name] = {"timestamp": datetime.now(), **response}
|
56
56
|
return True
|
57
|
-
|
57
|
+
|
58
58
|
def export_results(self, survey_id: str) -> str:
|
59
59
|
"""导出问卷结果"""
|
60
60
|
survey = self.get_survey(survey_id)
|
61
61
|
if not survey:
|
62
62
|
return json.dumps({"error": "问卷不存在"})
|
63
|
-
|
64
|
-
return json.dumps(
|
65
|
-
"survey": survey.to_dict(),
|
66
|
-
|
67
|
-
|
63
|
+
|
64
|
+
return json.dumps(
|
65
|
+
{"survey": survey.to_dict(), "responses": survey.responses},
|
66
|
+
ensure_ascii=False,
|
67
|
+
indent=2,
|
68
|
+
)
|
@@ -4,6 +4,7 @@ from datetime import datetime
|
|
4
4
|
from enum import Enum
|
5
5
|
import uuid
|
6
6
|
|
7
|
+
|
7
8
|
class QuestionType(Enum):
|
8
9
|
TEXT = "文本"
|
9
10
|
SINGLE_CHOICE = "单选"
|
@@ -11,6 +12,7 @@ class QuestionType(Enum):
|
|
11
12
|
RATING = "评分"
|
12
13
|
LIKERT = "李克特量表"
|
13
14
|
|
15
|
+
|
14
16
|
@dataclass
|
15
17
|
class Question:
|
16
18
|
content: str
|
@@ -19,7 +21,7 @@ class Question:
|
|
19
21
|
options: List[str] = field(default_factory=list)
|
20
22
|
min_rating: int = 1
|
21
23
|
max_rating: int = 5
|
22
|
-
|
24
|
+
|
23
25
|
def to_dict(self) -> dict:
|
24
26
|
return {
|
25
27
|
"content": self.content,
|
@@ -27,9 +29,10 @@ class Question:
|
|
27
29
|
"required": self.required,
|
28
30
|
"options": self.options,
|
29
31
|
"min_rating": self.min_rating,
|
30
|
-
"max_rating": self.max_rating
|
32
|
+
"max_rating": self.max_rating,
|
31
33
|
}
|
32
34
|
|
35
|
+
|
33
36
|
@dataclass
|
34
37
|
class Survey:
|
35
38
|
id: str
|
@@ -38,12 +41,12 @@ class Survey:
|
|
38
41
|
questions: List[Question]
|
39
42
|
responses: Dict[str, dict] = field(default_factory=dict)
|
40
43
|
created_at: datetime = field(default_factory=datetime.now)
|
41
|
-
|
44
|
+
|
42
45
|
def to_dict(self) -> dict:
|
43
46
|
return {
|
44
47
|
"id": self.id,
|
45
48
|
"title": self.title,
|
46
49
|
"description": self.description,
|
47
50
|
"questions": [q.to_dict() for q in self.questions],
|
48
|
-
"response_count": len(self.responses)
|
49
|
-
}
|
51
|
+
"response_count": len(self.responses),
|
52
|
+
}
|
pycityagent/utils/decorators.py
CHANGED
@@ -2,18 +2,21 @@ import time
|
|
2
2
|
import functools
|
3
3
|
import inspect
|
4
4
|
|
5
|
-
CALLING_STRING =
|
5
|
+
CALLING_STRING = 'function: `{func_name}` in "{file_path}", line {line_number}, arguments: `{arguments}` start time: `{start_time}` end time: `{end_time}` output: `{output}`'
|
6
6
|
|
7
|
-
__all__ =[
|
7
|
+
__all__ = [
|
8
8
|
"record_call_aio",
|
9
9
|
"record_call",
|
10
10
|
"lock_decorator",
|
11
11
|
]
|
12
|
+
|
13
|
+
|
12
14
|
def record_call_aio(record_function_calling: bool = True):
|
13
15
|
"""
|
14
16
|
Decorator to log the async function call details if `record_function_calling` is True.
|
15
17
|
"""
|
16
|
-
|
18
|
+
|
19
|
+
def decorator(func):
|
17
20
|
async def wrapper(*args, **kwargs):
|
18
21
|
cur_frame = inspect.currentframe()
|
19
22
|
assert cur_frame is not None
|
@@ -40,14 +43,18 @@ def record_call_aio(record_function_calling: bool = True):
|
|
40
43
|
)
|
41
44
|
)
|
42
45
|
return result
|
46
|
+
|
43
47
|
return wrapper
|
48
|
+
|
44
49
|
return decorator
|
45
50
|
|
51
|
+
|
46
52
|
def record_call(record_function_calling: bool = True):
|
47
53
|
"""
|
48
54
|
Decorator to log the function call details if `record_function_calling` is True.
|
49
55
|
"""
|
50
|
-
|
56
|
+
|
57
|
+
def decorator(func):
|
51
58
|
def wrapper(*args, **kwargs):
|
52
59
|
cur_frame = inspect.currentframe()
|
53
60
|
assert cur_frame is not None
|
@@ -74,9 +81,12 @@ def record_call(record_function_calling: bool = True):
|
|
74
81
|
)
|
75
82
|
)
|
76
83
|
return result
|
84
|
+
|
77
85
|
return wrapper
|
86
|
+
|
78
87
|
return decorator
|
79
88
|
|
89
|
+
|
80
90
|
def lock_decorator(func):
|
81
91
|
async def wrapper(self, *args, **kwargs):
|
82
92
|
lock = self._lock
|
pycityagent/workflow/block.py
CHANGED
@@ -87,7 +87,7 @@ def log_and_check(
|
|
87
87
|
A decorator that logs function calls and optionally checks a condition before executing the function.
|
88
88
|
|
89
89
|
This decorator is specifically designed to be used with the `block` method.
|
90
|
-
|
90
|
+
|
91
91
|
Args:
|
92
92
|
condition (Callable): A condition function that must be satisfied before the decorated function is executed.
|
93
93
|
Can be synchronous or asynchronous.
|
@@ -124,15 +124,16 @@ def log_and_check(
|
|
124
124
|
def trigger_class():
|
125
125
|
def decorator(cls):
|
126
126
|
original_forward = cls.forward
|
127
|
-
|
127
|
+
|
128
128
|
@functools.wraps(original_forward)
|
129
129
|
async def wrapped_forward(self, *args, **kwargs):
|
130
130
|
if self.trigger is not None:
|
131
131
|
await self.trigger.wait_for_trigger()
|
132
132
|
return await original_forward(self, *args, **kwargs)
|
133
|
-
|
133
|
+
|
134
134
|
cls.forward = wrapped_forward
|
135
135
|
return cls
|
136
|
+
|
136
137
|
return decorator
|
137
138
|
|
138
139
|
|