pycityagent 2.0.0a15__py3-none-any.whl → 2.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +37 -33
- pycityagent/simulation/agentgroup.py +4 -4
- pycityagent/simulation/simulation.py +88 -53
- pycityagent/survey/manager.py +10 -14
- pycityagent/survey/models.py +24 -24
- pycityagent/workflow/tool.py +1 -1
- {pycityagent-2.0.0a15.dist-info → pycityagent-2.0.0a16.dist-info}/METADATA +3 -2
- {pycityagent-2.0.0a15.dist-info → pycityagent-2.0.0a16.dist-info}/RECORD +9 -9
- {pycityagent-2.0.0a15.dist-info → pycityagent-2.0.0a16.dist-info}/WHEEL +0 -0
pycityagent/agent.py
CHANGED
@@ -9,7 +9,7 @@ from enum import Enum
|
|
9
9
|
import logging
|
10
10
|
import random
|
11
11
|
import uuid
|
12
|
-
from typing import Dict, List, Optional
|
12
|
+
from typing import Dict, List, Optional,Any
|
13
13
|
|
14
14
|
import fastavro
|
15
15
|
|
@@ -80,6 +80,7 @@ class Agent(ABC):
|
|
80
80
|
self._simulator = simulator
|
81
81
|
self._memory = memory
|
82
82
|
self._exp_id = -1
|
83
|
+
self._agent_id = -1
|
83
84
|
self._has_bound_to_simulator = False
|
84
85
|
self._has_bound_to_economy = False
|
85
86
|
self._blocked = False
|
@@ -224,8 +225,8 @@ class Agent(ABC):
|
|
224
225
|
return
|
225
226
|
response_to_avro = [{
|
226
227
|
"id": self._uuid,
|
227
|
-
"day": await self.
|
228
|
-
"t": await self.
|
228
|
+
"day": await self.simulator.get_simulator_day(),
|
229
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
229
230
|
"survey_id": survey["id"],
|
230
231
|
"result": survey_response,
|
231
232
|
"created_at": int(datetime.now().timestamp() * 1000),
|
@@ -273,8 +274,8 @@ class Agent(ABC):
|
|
273
274
|
async def _process_interview(self, payload: dict):
|
274
275
|
auros = [{
|
275
276
|
"id": self._uuid,
|
276
|
-
"day": await self.
|
277
|
-
"t": await self.
|
277
|
+
"day": await self.simulator.get_simulator_day(),
|
278
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
278
279
|
"type": 2,
|
279
280
|
"speaker": "user",
|
280
281
|
"content": payload["content"],
|
@@ -284,8 +285,8 @@ class Agent(ABC):
|
|
284
285
|
response = await self.generate_user_chat_response(question)
|
285
286
|
auros.append({
|
286
287
|
"id": self._uuid,
|
287
|
-
"day": await self.
|
288
|
-
"t": await self.
|
288
|
+
"day": await self.simulator.get_simulator_day(),
|
289
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
289
290
|
"type": 2,
|
290
291
|
"speaker": "",
|
291
292
|
"content": response,
|
@@ -297,7 +298,9 @@ class Agent(ABC):
|
|
297
298
|
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
298
299
|
|
299
300
|
async def process_agent_chat_response(self, payload: dict) -> str:
|
300
|
-
|
301
|
+
resp = f"Agent {self._uuid} received agent chat response: {payload}"
|
302
|
+
logger.info(resp)
|
303
|
+
return resp
|
301
304
|
|
302
305
|
async def _process_agent_chat(self, payload: dict):
|
303
306
|
auros = [{
|
@@ -334,7 +337,7 @@ class Agent(ABC):
|
|
334
337
|
logger.info(f"Agent {self._uuid} received user survey message: {payload}")
|
335
338
|
asyncio.create_task(self._process_survey(payload["data"]))
|
336
339
|
|
337
|
-
async def handle_gather_message(self, payload:
|
340
|
+
async def handle_gather_message(self, payload: Any):
|
338
341
|
raise NotImplementedError
|
339
342
|
|
340
343
|
# MQTT send message
|
@@ -357,14 +360,14 @@ class Agent(ABC):
|
|
357
360
|
"from": self._uuid,
|
358
361
|
"content": content,
|
359
362
|
"timestamp": int(datetime.now().timestamp() * 1000),
|
360
|
-
"day": await self.
|
361
|
-
"t": await self.
|
363
|
+
"day": await self.simulator.get_simulator_day(),
|
364
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
362
365
|
}
|
363
366
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
364
367
|
auros = [{
|
365
368
|
"id": self._uuid,
|
366
|
-
"day": await self.
|
367
|
-
"t": await self.
|
369
|
+
"day": await self.simulator.get_simulator_day(),
|
370
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
368
371
|
"type": 1,
|
369
372
|
"speaker": self._uuid,
|
370
373
|
"content": content,
|
@@ -440,8 +443,8 @@ class CitizenAgent(Agent):
|
|
440
443
|
"pedestrian_attribute",
|
441
444
|
"bike_attribute",
|
442
445
|
}
|
443
|
-
simulator = self.
|
444
|
-
memory = self.
|
446
|
+
simulator = self.simulator
|
447
|
+
memory = self.memory
|
445
448
|
person_id = await memory.get("id")
|
446
449
|
# ATTENTION:模拟器分配的id从0开始
|
447
450
|
if person_id >= 0:
|
@@ -478,11 +481,11 @@ class CitizenAgent(Agent):
|
|
478
481
|
await self._economy_client.remove_agents([self._agent_id])
|
479
482
|
except:
|
480
483
|
pass
|
481
|
-
person_id = await self.
|
484
|
+
person_id = await self.memory.get("id")
|
482
485
|
await self._economy_client.add_agents(
|
483
486
|
{
|
484
487
|
"id": person_id,
|
485
|
-
"currency": await self.
|
488
|
+
"currency": await self.memory.get("currency"),
|
486
489
|
}
|
487
490
|
)
|
488
491
|
self._has_bound_to_economy = True
|
@@ -543,62 +546,63 @@ class InstitutionAgent(Agent):
|
|
543
546
|
# TODO: More general id generation
|
544
547
|
_id = random.randint(100000, 999999)
|
545
548
|
self._agent_id = _id
|
546
|
-
await self.
|
549
|
+
await self.memory.update("id", _id, protect_llm_read_only_fields=False)
|
547
550
|
try:
|
548
551
|
await self._economy_client.remove_orgs([self._agent_id])
|
549
552
|
except:
|
550
553
|
pass
|
551
554
|
try:
|
552
|
-
|
553
|
-
|
555
|
+
_memory = self.memory
|
556
|
+
_id = await _memory.get("id")
|
557
|
+
_type = await _memory.get("type")
|
554
558
|
try:
|
555
|
-
nominal_gdp = await
|
559
|
+
nominal_gdp = await _memory.get("nominal_gdp")
|
556
560
|
except:
|
557
561
|
nominal_gdp = []
|
558
562
|
try:
|
559
|
-
real_gdp = await
|
563
|
+
real_gdp = await _memory.get("real_gdp")
|
560
564
|
except:
|
561
565
|
real_gdp = []
|
562
566
|
try:
|
563
|
-
unemployment = await
|
567
|
+
unemployment = await _memory.get("unemployment")
|
564
568
|
except:
|
565
569
|
unemployment = []
|
566
570
|
try:
|
567
|
-
wages = await
|
571
|
+
wages = await _memory.get("wages")
|
568
572
|
except:
|
569
573
|
wages = []
|
570
574
|
try:
|
571
|
-
prices = await
|
575
|
+
prices = await _memory.get("prices")
|
572
576
|
except:
|
573
577
|
prices = []
|
574
578
|
try:
|
575
|
-
inventory = await
|
579
|
+
inventory = await _memory.get("inventory")
|
576
580
|
except:
|
577
581
|
inventory = 0
|
578
582
|
try:
|
579
|
-
price = await
|
583
|
+
price = await _memory.get("price")
|
580
584
|
except:
|
581
585
|
price = 0
|
582
586
|
try:
|
583
|
-
currency = await
|
587
|
+
currency = await _memory.get("currency")
|
584
588
|
except:
|
585
589
|
currency = 0.0
|
586
590
|
try:
|
587
|
-
interest_rate = await
|
591
|
+
interest_rate = await _memory.get("interest_rate")
|
588
592
|
except:
|
589
593
|
interest_rate = 0.0
|
590
594
|
try:
|
591
|
-
bracket_cutoffs = await
|
595
|
+
bracket_cutoffs = await _memory.get("bracket_cutoffs")
|
592
596
|
except:
|
593
597
|
bracket_cutoffs = []
|
594
598
|
try:
|
595
|
-
bracket_rates = await
|
599
|
+
bracket_rates = await _memory.get("bracket_rates")
|
596
600
|
except:
|
597
601
|
bracket_rates = []
|
598
602
|
await self._economy_client.add_orgs(
|
599
603
|
{
|
600
|
-
"id":
|
601
|
-
"type":
|
604
|
+
"id": _id,
|
605
|
+
"type": _type,
|
602
606
|
"nominal_gdp": nominal_gdp,
|
603
607
|
"real_gdp": real_gdp,
|
604
608
|
"unemployment": unemployment,
|
@@ -20,7 +20,7 @@ logger = logging.getLogger("pycityagent")
|
|
20
20
|
|
21
21
|
@ray.remote
|
22
22
|
class AgentGroup:
|
23
|
-
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int
|
23
|
+
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int):
|
24
24
|
logger.setLevel(logging_level)
|
25
25
|
self._uuid = str(uuid.uuid4())
|
26
26
|
self.agents = agents
|
@@ -64,20 +64,20 @@ class AgentGroup:
|
|
64
64
|
self.economy_client = None
|
65
65
|
|
66
66
|
for agent in self.agents:
|
67
|
-
agent.set_exp_id(self.exp_id)
|
67
|
+
agent.set_exp_id(self.exp_id) # type: ignore
|
68
68
|
agent.set_llm_client(self.llm)
|
69
69
|
agent.set_simulator(self.simulator)
|
70
70
|
if self.economy_client is not None:
|
71
71
|
agent.set_economy_client(self.economy_client)
|
72
72
|
agent.set_messager(self.messager)
|
73
73
|
if self.enable_avro:
|
74
|
-
agent.set_avro_file(self.avro_file)
|
74
|
+
agent.set_avro_file(self.avro_file) # type: ignore
|
75
75
|
|
76
76
|
async def init_agents(self):
|
77
77
|
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
78
78
|
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
79
79
|
for agent in self.agents:
|
80
|
-
await agent.bind_to_simulator()
|
80
|
+
await agent.bind_to_simulator() # type: ignore
|
81
81
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
82
82
|
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
83
83
|
await self.messager.connect()
|
@@ -2,19 +2,22 @@ import asyncio
|
|
2
2
|
import json
|
3
3
|
import logging
|
4
4
|
import os
|
5
|
-
|
5
|
+
import random
|
6
6
|
import uuid
|
7
|
+
from collections.abc import Sequence
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
7
9
|
from datetime import datetime, timezone
|
8
|
-
import
|
9
|
-
from typing import Dict, List, Optional,
|
10
|
-
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
12
|
+
|
11
13
|
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
14
|
+
import yaml
|
15
|
+
from mosstool.map._map_util.const import AOI_START_ID
|
16
|
+
|
12
17
|
from pycityagent.environment.simulator import Simulator
|
13
18
|
from pycityagent.memory.memory import Memory
|
14
19
|
from pycityagent.message.messager import Messager
|
15
20
|
from pycityagent.survey import Survey
|
16
|
-
import yaml
|
17
|
-
from concurrent.futures import ThreadPoolExecutor
|
18
21
|
|
19
22
|
from ..agent import Agent, InstitutionAgent
|
20
23
|
from .agentgroup import AgentGroup
|
@@ -31,7 +34,7 @@ class AgentSimulation:
|
|
31
34
|
config: dict,
|
32
35
|
agent_prefix: str = "agent_",
|
33
36
|
exp_name: str = "default_experiment",
|
34
|
-
logging_level: int = logging.WARNING
|
37
|
+
logging_level: int = logging.WARNING,
|
35
38
|
):
|
36
39
|
"""
|
37
40
|
Args:
|
@@ -50,8 +53,8 @@ class AgentSimulation:
|
|
50
53
|
self._simulator = Simulator(config["simulator_request"])
|
51
54
|
self.agent_prefix = agent_prefix
|
52
55
|
self._agents: Dict[uuid.UUID, Agent] = {}
|
53
|
-
self._groups: Dict[str, AgentGroup] = {}
|
54
|
-
self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {}
|
56
|
+
self._groups: Dict[str, AgentGroup] = {} # type:ignore
|
57
|
+
self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {} # type:ignore
|
55
58
|
self._agent_uuids: List[uuid.UUID] = []
|
56
59
|
self._user_chat_topics: Dict[uuid.UUID, str] = {}
|
57
60
|
self._user_survey_topics: Dict[uuid.UUID, str] = {}
|
@@ -89,33 +92,44 @@ class AgentSimulation:
|
|
89
92
|
"cur_t": 0.0,
|
90
93
|
"config": json.dumps(config),
|
91
94
|
"error": "",
|
92
|
-
"created_at": datetime.now(timezone.utc).isoformat()
|
95
|
+
"created_at": datetime.now(timezone.utc).isoformat(),
|
93
96
|
}
|
94
|
-
|
97
|
+
|
95
98
|
# 创建异步任务保存实验信息
|
96
99
|
self._exp_info_file = self._avro_path / "experiment_info.yaml"
|
97
|
-
with open(self._exp_info_file,
|
100
|
+
with open(self._exp_info_file, "w") as f:
|
98
101
|
yaml.dump(self._exp_info, f)
|
99
102
|
|
100
103
|
@property
|
101
104
|
def agents(self):
|
102
105
|
return self._agents
|
103
|
-
|
106
|
+
|
104
107
|
@property
|
105
108
|
def groups(self):
|
106
109
|
return self._groups
|
107
|
-
|
110
|
+
|
108
111
|
@property
|
109
112
|
def agent_uuids(self):
|
110
113
|
return self._agent_uuids
|
111
|
-
|
114
|
+
|
112
115
|
@property
|
113
116
|
def agent_uuid2group(self):
|
114
117
|
return self._agent_uuid2group
|
115
|
-
|
116
|
-
def create_remote_group(
|
118
|
+
|
119
|
+
def create_remote_group(
|
120
|
+
self,
|
121
|
+
group_name: str,
|
122
|
+
agents: list[Agent],
|
123
|
+
config: dict,
|
124
|
+
exp_id: str,
|
125
|
+
enable_avro: bool,
|
126
|
+
avro_path: Path,
|
127
|
+
logging_level: int = logging.WARNING,
|
128
|
+
):
|
117
129
|
"""创建远程组"""
|
118
|
-
group = AgentGroup.remote(
|
130
|
+
group = AgentGroup.remote(
|
131
|
+
agents, config, exp_id, enable_avro, avro_path, logging_level
|
132
|
+
)
|
119
133
|
return group_name, group, agents
|
120
134
|
|
121
135
|
async def init_agents(
|
@@ -164,7 +178,7 @@ class AgentSimulation:
|
|
164
178
|
# 使用线程池并行创建 AgentGroup
|
165
179
|
group_creation_params = []
|
166
180
|
class_init_index = 0
|
167
|
-
|
181
|
+
|
168
182
|
# 首先收集所有需要创建的组的参数
|
169
183
|
for i in range(len(self.agent_class)):
|
170
184
|
agent_class = self.agent_class[i]
|
@@ -175,9 +189,7 @@ class AgentSimulation:
|
|
175
189
|
|
176
190
|
# 获取Memory配置
|
177
191
|
extra_attributes, profile, base = memory_config_func_i()
|
178
|
-
memory = Memory(
|
179
|
-
config=extra_attributes, profile=profile, base=base
|
180
|
-
)
|
192
|
+
memory = Memory(config=extra_attributes, profile=profile, base=base)
|
181
193
|
|
182
194
|
# 创建智能体时传入Memory配置
|
183
195
|
agent = agent_class(
|
@@ -190,34 +202,36 @@ class AgentSimulation:
|
|
190
202
|
|
191
203
|
# 计算需要的组数,向上取整以处理不足一组的情况
|
192
204
|
num_group = (agent_count_i + group_size - 1) // group_size
|
193
|
-
|
205
|
+
|
194
206
|
for k in range(num_group):
|
195
207
|
start_idx = class_init_index + k * group_size
|
196
208
|
end_idx = min(
|
197
209
|
class_init_index + (k + 1) * group_size, # 修正了索引计算
|
198
|
-
class_init_index + agent_count_i
|
210
|
+
class_init_index + agent_count_i,
|
199
211
|
)
|
200
|
-
|
212
|
+
|
201
213
|
agents = list(self._agents.values())[start_idx:end_idx]
|
202
214
|
group_name = f"AgentType_{i}_Group_{k}"
|
203
|
-
|
215
|
+
|
204
216
|
# 收集创建参数
|
205
|
-
group_creation_params.append((
|
206
|
-
|
207
|
-
agents
|
208
|
-
))
|
209
|
-
|
217
|
+
group_creation_params.append((group_name, agents))
|
218
|
+
|
210
219
|
class_init_index += agent_count_i
|
211
220
|
|
212
221
|
# 收集所有创建组的参数
|
213
222
|
creation_tasks = []
|
214
223
|
for group_name, agents in group_creation_params:
|
215
224
|
# 直接创建异步任务
|
216
|
-
group = AgentGroup.remote(
|
217
|
-
|
218
|
-
|
225
|
+
group = AgentGroup.remote(
|
226
|
+
agents,
|
227
|
+
self.config,
|
228
|
+
self.exp_id,
|
229
|
+
self._enable_avro,
|
230
|
+
self._avro_path,
|
231
|
+
self.logging_level,
|
232
|
+
)
|
219
233
|
creation_tasks.append((group_name, group, agents))
|
220
|
-
|
234
|
+
|
221
235
|
# 更新数据结构
|
222
236
|
for group_name, group, agents in creation_tasks:
|
223
237
|
self._groups[group_name] = group
|
@@ -233,7 +247,9 @@ class AgentSimulation:
|
|
233
247
|
# 设置用户主题
|
234
248
|
for uuid, agent in self._agents.items():
|
235
249
|
self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
|
236
|
-
self._user_survey_topics[uuid] =
|
250
|
+
self._user_survey_topics[uuid] = (
|
251
|
+
f"exps/{self.exp_id}/agents/{uuid}/user-survey"
|
252
|
+
)
|
237
253
|
|
238
254
|
async def gather(self, content: str):
|
239
255
|
"""收集智能体的特定信息"""
|
@@ -250,7 +266,18 @@ class AgentSimulation:
|
|
250
266
|
def default_memory_config_institution(self):
|
251
267
|
"""默认的Memory配置函数"""
|
252
268
|
EXTRA_ATTRIBUTES = {
|
253
|
-
"type": (
|
269
|
+
"type": (
|
270
|
+
int,
|
271
|
+
random.choice(
|
272
|
+
[
|
273
|
+
economyv2.ORG_TYPE_BANK,
|
274
|
+
economyv2.ORG_TYPE_GOVERNMENT,
|
275
|
+
economyv2.ORG_TYPE_FIRM,
|
276
|
+
economyv2.ORG_TYPE_NBS,
|
277
|
+
economyv2.ORG_TYPE_UNSPECIFIED,
|
278
|
+
]
|
279
|
+
),
|
280
|
+
),
|
254
281
|
"nominal_gdp": (list, [], True),
|
255
282
|
"real_gdp": (list, [], True),
|
256
283
|
"unemployment": (list, [], True),
|
@@ -364,29 +391,35 @@ class AgentSimulation:
|
|
364
391
|
}
|
365
392
|
|
366
393
|
return EXTRA_ATTRIBUTES, PROFILE, BASE
|
367
|
-
|
368
|
-
async def send_survey(
|
394
|
+
|
395
|
+
async def send_survey(
|
396
|
+
self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None
|
397
|
+
):
|
369
398
|
"""发送问卷"""
|
370
|
-
|
399
|
+
survey_dict = survey.to_dict()
|
371
400
|
if agent_uuids is None:
|
372
401
|
agent_uuids = self._agent_uuids
|
373
402
|
payload = {
|
374
403
|
"from": "none",
|
375
|
-
"survey_id":
|
404
|
+
"survey_id": survey_dict["id"],
|
376
405
|
"timestamp": int(datetime.now().timestamp() * 1000),
|
377
|
-
"data":
|
406
|
+
"data": survey_dict,
|
378
407
|
}
|
379
408
|
for uuid in agent_uuids:
|
380
409
|
topic = self._user_survey_topics[uuid]
|
381
410
|
await self._messager.send_message(topic, payload)
|
382
411
|
|
383
|
-
async def send_interview_message(
|
412
|
+
async def send_interview_message(
|
413
|
+
self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]
|
414
|
+
):
|
384
415
|
"""发送面试消息"""
|
385
416
|
payload = {
|
386
417
|
"from": "none",
|
387
418
|
"content": content,
|
388
419
|
"timestamp": int(datetime.now().timestamp() * 1000),
|
389
420
|
}
|
421
|
+
if not isinstance(agent_uuids, Sequence):
|
422
|
+
agent_uuids = [agent_uuids]
|
390
423
|
for uuid in agent_uuids:
|
391
424
|
topic = self._user_chat_topics[uuid]
|
392
425
|
await self._messager.send_message(topic, payload)
|
@@ -405,7 +438,7 @@ class AgentSimulation:
|
|
405
438
|
async def _save_exp_info(self) -> None:
|
406
439
|
"""异步保存实验信息到YAML文件"""
|
407
440
|
try:
|
408
|
-
with open(self._exp_info_file,
|
441
|
+
with open(self._exp_info_file, "w") as f:
|
409
442
|
yaml.dump(self._exp_info, f)
|
410
443
|
except Exception as e:
|
411
444
|
logger.error(f"保存实验信息失败: {str(e)}")
|
@@ -418,20 +451,22 @@ class AgentSimulation:
|
|
418
451
|
|
419
452
|
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
420
453
|
"""监控实验状态并更新
|
421
|
-
|
454
|
+
|
422
455
|
Args:
|
423
456
|
stop_event: 用于通知监控任务停止的事件
|
424
457
|
"""
|
425
458
|
try:
|
426
|
-
while not stop_event.is_set():
|
459
|
+
while not stop_event.is_set():
|
427
460
|
# 更新实验状态
|
428
461
|
# 假设所有group的cur_day和cur_t是同步的,取第一个即可
|
429
462
|
self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
|
430
|
-
self._exp_info["cur_t"] =
|
463
|
+
self._exp_info["cur_t"] = (
|
464
|
+
await self._simulator.get_simulator_second_from_start_of_day()
|
465
|
+
)
|
431
466
|
await self._save_exp_info()
|
432
|
-
|
467
|
+
|
433
468
|
await asyncio.sleep(1) # 避免过于频繁的更新
|
434
|
-
except asyncio.
|
469
|
+
except asyncio.CancelledError:
|
435
470
|
# 正常取消,不需要特殊处理
|
436
471
|
pass
|
437
472
|
except Exception as e:
|
@@ -446,12 +481,12 @@ class AgentSimulation:
|
|
446
481
|
try:
|
447
482
|
self._exp_info["num_day"] += day
|
448
483
|
await self._update_exp_status(1) # 更新状态为运行中
|
449
|
-
|
484
|
+
|
450
485
|
# 创建停止事件
|
451
486
|
stop_event = asyncio.Event()
|
452
487
|
# 创建监控任务
|
453
488
|
monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
|
454
|
-
|
489
|
+
|
455
490
|
try:
|
456
491
|
tasks = []
|
457
492
|
for group in self._groups.values():
|
@@ -459,13 +494,13 @@ class AgentSimulation:
|
|
459
494
|
|
460
495
|
# 等待所有group运行完成
|
461
496
|
await asyncio.gather(*tasks)
|
462
|
-
|
497
|
+
|
463
498
|
finally:
|
464
499
|
# 设置停止事件
|
465
500
|
stop_event.set()
|
466
501
|
# 等待监控任务结束
|
467
502
|
await monitor_task
|
468
|
-
|
503
|
+
|
469
504
|
# 运行成功后更新状态
|
470
505
|
await self._update_exp_status(2)
|
471
506
|
|
pycityagent/survey/manager.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
|
-
from typing import List, Dict, Optional
|
2
|
-
from datetime import datetime
|
3
|
-
import uuid
|
4
1
|
import json
|
5
|
-
|
2
|
+
import uuid
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
from .models import Page, Question, QuestionType, Survey
|
6
7
|
|
7
8
|
|
8
9
|
class SurveyManager:
|
9
10
|
def __init__(self):
|
10
|
-
self._surveys:
|
11
|
+
self._surveys: dict[str, Survey] = {}
|
11
12
|
|
12
|
-
def create_survey(
|
13
|
-
self, title: str, description: str, pages: List[dict]
|
14
|
-
) -> Survey:
|
13
|
+
def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
|
15
14
|
"""创建新问卷"""
|
16
15
|
survey_id = uuid.uuid4()
|
17
16
|
|
@@ -32,11 +31,8 @@ class SurveyManager:
|
|
32
31
|
max_rating=q.get("max_rating", 5),
|
33
32
|
)
|
34
33
|
questions.append(question)
|
35
|
-
|
36
|
-
page = Page(
|
37
|
-
name=page_data["name"],
|
38
|
-
elements=questions
|
39
|
-
)
|
34
|
+
|
35
|
+
page = Page(name=page_data["name"], elements=questions)
|
40
36
|
survey_pages.append(page)
|
41
37
|
|
42
38
|
survey = Survey(
|
@@ -53,6 +49,6 @@ class SurveyManager:
|
|
53
49
|
"""获取指定问卷"""
|
54
50
|
return self._surveys.get(survey_id)
|
55
51
|
|
56
|
-
def get_all_surveys(self) ->
|
52
|
+
def get_all_surveys(self) -> list[Survey]:
|
57
53
|
"""获取所有问卷"""
|
58
54
|
return list(self._surveys.values())
|
pycityagent/survey/models.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
|
+
import json
|
2
|
+
import uuid
|
1
3
|
from dataclasses import dataclass, field
|
2
|
-
from typing import List, Dict, Optional
|
3
4
|
from datetime import datetime
|
4
5
|
from enum import Enum
|
5
|
-
import
|
6
|
-
import json
|
6
|
+
from typing import Any
|
7
7
|
|
8
8
|
|
9
9
|
class QuestionType(Enum):
|
@@ -20,19 +20,20 @@ class Question:
|
|
20
20
|
name: str
|
21
21
|
title: str
|
22
22
|
type: QuestionType
|
23
|
-
choices:
|
24
|
-
columns:
|
25
|
-
rows:
|
23
|
+
choices: list[str] = field(default_factory=list)
|
24
|
+
columns: list[str] = field(default_factory=list)
|
25
|
+
rows: list[str] = field(default_factory=list)
|
26
|
+
required: bool = True
|
26
27
|
min_rating: int = 1
|
27
28
|
max_rating: int = 5
|
28
29
|
|
29
30
|
def to_dict(self) -> dict:
|
30
|
-
base_dict = {
|
31
|
+
base_dict: dict[str, Any] = {
|
31
32
|
"type": self.type.value,
|
32
33
|
"name": self.name,
|
33
34
|
"title": self.title,
|
34
35
|
}
|
35
|
-
|
36
|
+
|
36
37
|
if self.type in [QuestionType.RADIO, QuestionType.CHECKBOX]:
|
37
38
|
base_dict["choices"] = self.choices
|
38
39
|
elif self.type == QuestionType.MATRIX:
|
@@ -41,20 +42,17 @@ class Question:
|
|
41
42
|
elif self.type == QuestionType.RATING:
|
42
43
|
base_dict["min_rating"] = self.min_rating
|
43
44
|
base_dict["max_rating"] = self.max_rating
|
44
|
-
|
45
|
+
|
45
46
|
return base_dict
|
46
47
|
|
47
48
|
|
48
49
|
@dataclass
|
49
50
|
class Page:
|
50
51
|
name: str
|
51
|
-
elements:
|
52
|
+
elements: list[Question]
|
52
53
|
|
53
54
|
def to_dict(self) -> dict:
|
54
|
-
return {
|
55
|
-
"name": self.name,
|
56
|
-
"elements": [q.to_dict() for q in self.elements]
|
57
|
-
}
|
55
|
+
return {"name": self.name, "elements": [q.to_dict() for q in self.elements]}
|
58
56
|
|
59
57
|
|
60
58
|
@dataclass
|
@@ -62,8 +60,8 @@ class Survey:
|
|
62
60
|
id: uuid.UUID
|
63
61
|
title: str
|
64
62
|
description: str
|
65
|
-
pages:
|
66
|
-
responses:
|
63
|
+
pages: list[Page]
|
64
|
+
responses: dict[str, dict] = field(default_factory=dict)
|
67
65
|
created_at: datetime = field(default_factory=datetime.now)
|
68
66
|
|
69
67
|
def to_dict(self) -> dict:
|
@@ -83,12 +81,12 @@ class Survey:
|
|
83
81
|
"description": self.description,
|
84
82
|
"pages": [p.to_dict() for p in self.pages],
|
85
83
|
"responses": self.responses,
|
86
|
-
"created_at": self.created_at.isoformat()
|
84
|
+
"created_at": self.created_at.isoformat(),
|
87
85
|
}
|
88
86
|
return json.dumps(survey_dict)
|
89
87
|
|
90
88
|
@classmethod
|
91
|
-
def from_json(cls, json_str: str) ->
|
89
|
+
def from_json(cls, json_str: str) -> "Survey":
|
92
90
|
"""Create a Survey instance from a JSON string"""
|
93
91
|
data = json.loads(json_str)
|
94
92
|
pages = [
|
@@ -104,17 +102,19 @@ class Survey:
|
|
104
102
|
columns=q.get("columns", []),
|
105
103
|
rows=q.get("rows", []),
|
106
104
|
min_rating=q.get("min_rating", 1),
|
107
|
-
max_rating=q.get("max_rating", 5)
|
108
|
-
)
|
109
|
-
|
110
|
-
|
105
|
+
max_rating=q.get("max_rating", 5),
|
106
|
+
)
|
107
|
+
for q in p["elements"]
|
108
|
+
],
|
109
|
+
)
|
110
|
+
for p in data["pages"]
|
111
111
|
]
|
112
|
-
|
112
|
+
|
113
113
|
return cls(
|
114
114
|
id=uuid.UUID(data["id"]),
|
115
115
|
title=data["title"],
|
116
116
|
description=data["description"],
|
117
117
|
pages=pages,
|
118
118
|
responses=data.get("responses", {}),
|
119
|
-
created_at=datetime.fromisoformat(data["created_at"])
|
119
|
+
created_at=datetime.fromisoformat(data["created_at"]),
|
120
120
|
)
|
pycityagent/workflow/tool.py
CHANGED
@@ -139,7 +139,7 @@ class UpdateWithSimulator(Tool):
|
|
139
139
|
if agent._simulator is None:
|
140
140
|
return
|
141
141
|
if not agent._has_bound_to_simulator:
|
142
|
-
await
|
142
|
+
await agent._bind_to_simulator() # type: ignore
|
143
143
|
simulator = agent.simulator
|
144
144
|
memory = agent.memory
|
145
145
|
person_id = await memory.get("id")
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pycityagent
|
3
|
-
Version: 2.0.
|
3
|
+
Version: 2.0.0a16
|
4
4
|
Summary: LLM-based城市环境agent构建库
|
5
5
|
License: MIT
|
6
6
|
Author: Yuwei Yan
|
@@ -24,6 +24,7 @@ Requires-Dist: fastavro (>=1.10.0,<2.0.0)
|
|
24
24
|
Requires-Dist: geojson (==3.1.0)
|
25
25
|
Requires-Dist: gradio (>=5.7.1,<6.0.0)
|
26
26
|
Requires-Dist: grpcio (==1.67.1)
|
27
|
+
Requires-Dist: langchain-core (>=0.3.28,<0.4.0)
|
27
28
|
Requires-Dist: matplotlib (==3.8.3)
|
28
29
|
Requires-Dist: mosstool (==1.0.24)
|
29
30
|
Requires-Dist: networkx (==3.2.1)
|
@@ -33,7 +34,7 @@ Requires-Dist: pandavro (>=1.8.0,<2.0.0)
|
|
33
34
|
Requires-Dist: poetry (>=1.2.2)
|
34
35
|
Requires-Dist: protobuf (<=4.24.0)
|
35
36
|
Requires-Dist: pycitydata (==1.0.0)
|
36
|
-
Requires-Dist: pycityproto (
|
37
|
+
Requires-Dist: pycityproto (>=2.1.3,<3.0.0)
|
37
38
|
Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
|
38
39
|
Requires-Dist: ray (>=2.40.0,<3.0.0)
|
39
40
|
Requires-Dist: sidecar (==0.7.0)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
pycityagent/__init__.py,sha256=EDxt3Su3lH1IMh9suNw7GeGL7UrXeWiZTw5KWNznDzc,637
|
2
|
-
pycityagent/agent.py,sha256=
|
2
|
+
pycityagent/agent.py,sha256=t9W9sKxtQ0EkMxL78kAjAu-rXigEK6gyLY0IEA4DbnQ,23143
|
3
3
|
pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
|
4
4
|
pycityagent/economy/econ_client.py,sha256=wcuNtcpkSijJwNkt2mXw3SshYy4SBy6qbvJ0VQ7Aovo,10854
|
5
5
|
pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
|
@@ -46,11 +46,11 @@ pycityagent/memory/utils.py,sha256=wLNlNlZ-AY9VB8kbUIy0UQSYh26FOQABbhmKQkit5o8,8
|
|
46
46
|
pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
|
47
47
|
pycityagent/message/messager.py,sha256=W_OVlNGcreHSBf6v-DrEnfNCXExB78ySr0w26MSncfU,2541
|
48
48
|
pycityagent/simulation/__init__.py,sha256=jYaqaNpzM5M_e_ykISS_M-mIyYdzJXJWhgpfBpA6l5k,111
|
49
|
-
pycityagent/simulation/agentgroup.py,sha256=
|
50
|
-
pycityagent/simulation/simulation.py,sha256=
|
49
|
+
pycityagent/simulation/agentgroup.py,sha256=M19XWJRWyjMAYS0_RIOBQ2C7I1MuVYIaX3DgehGZL2Y,12541
|
50
|
+
pycityagent/simulation/simulation.py,sha256=TndrMZSm6qe_wgfv9h6mL9oAiAEHbq6KHBwdwUGG_3k,19261
|
51
51
|
pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
|
52
|
-
pycityagent/survey/manager.py,sha256=
|
53
|
-
pycityagent/survey/models.py,sha256=
|
52
|
+
pycityagent/survey/manager.py,sha256=S5IkwTdelsdtZETChRcfCEczzwSrry_Fly9MY4s3rbk,1681
|
53
|
+
pycityagent/survey/models.py,sha256=YE50UUt5qJ0O_lIUsSY6XFCGUTkJVNu_L1gAhaCJ2fs,3546
|
54
54
|
pycityagent/utils/__init__.py,sha256=xXEMhVfFeOJUXjczaHv9DJqYNp57rc6FibtS7CfrVbA,305
|
55
55
|
pycityagent/utils/avro_schema.py,sha256=DHM3bOo8m0dJf8oSwyOWzVeXrH6OERmzA_a5vS4So4M,4255
|
56
56
|
pycityagent/utils/decorators.py,sha256=Gk3r41hfk6awui40tbwpq3C7wC7jHaRmLRlcJFlLQCE,3160
|
@@ -62,8 +62,8 @@ pycityagent/utils/survey_util.py,sha256=Be9nptmu2JtesFNemPgORh_2GsN7rcDYGQS9Zfvc
|
|
62
62
|
pycityagent/workflow/__init__.py,sha256=EyCcjB6LyBim-5iAOPe4m2qfvghEPqu1ZdGfy4KPeZ8,551
|
63
63
|
pycityagent/workflow/block.py,sha256=6EmiRMLdOZC1wMlmLMIjfrp9TuiI7Gw4s3nnXVMbrnw,6031
|
64
64
|
pycityagent/workflow/prompt.py,sha256=tY69nDO8fgYfF_dOA-iceR8pAhkYmCqoox8uRPqEuGY,2956
|
65
|
-
pycityagent/workflow/tool.py,sha256=
|
65
|
+
pycityagent/workflow/tool.py,sha256=_bCluIX8HTC8ZW6a-wrMB3Uhx2yzD8sM8XFDI3vd0MM,6642
|
66
66
|
pycityagent/workflow/trigger.py,sha256=t5X_i0WtL32bipZSsq_E3UUyYYudYLxQUpvxbgClp2s,5683
|
67
|
-
pycityagent-2.0.
|
68
|
-
pycityagent-2.0.
|
69
|
-
pycityagent-2.0.
|
67
|
+
pycityagent-2.0.0a16.dist-info/METADATA,sha256=ACLlN0UDbuwDNRv5Olf-K-W7LJWmdZ8ksmeTiclp0Zk,7760
|
68
|
+
pycityagent-2.0.0a16.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
69
|
+
pycityagent-2.0.0a16.dist-info/RECORD,,
|
File without changes
|