pycityagent 2.0.0a47__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a49__cp39-cp39-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/__init__.py +3 -2
- pycityagent/agent.py +109 -4
- pycityagent/cityagent/__init__.py +20 -0
- pycityagent/cityagent/bankagent.py +54 -0
- pycityagent/cityagent/blocks/__init__.py +20 -0
- pycityagent/cityagent/blocks/cognition_block.py +304 -0
- pycityagent/cityagent/blocks/dispatcher.py +78 -0
- pycityagent/cityagent/blocks/economy_block.py +356 -0
- pycityagent/cityagent/blocks/mobility_block.py +258 -0
- pycityagent/cityagent/blocks/needs_block.py +305 -0
- pycityagent/cityagent/blocks/other_block.py +103 -0
- pycityagent/cityagent/blocks/plan_block.py +309 -0
- pycityagent/cityagent/blocks/social_block.py +345 -0
- pycityagent/cityagent/blocks/time_block.py +116 -0
- pycityagent/cityagent/blocks/utils.py +66 -0
- pycityagent/cityagent/firmagent.py +75 -0
- pycityagent/cityagent/governmentagent.py +60 -0
- pycityagent/cityagent/initial.py +98 -0
- pycityagent/cityagent/memory_config.py +202 -0
- pycityagent/cityagent/nbsagent.py +92 -0
- pycityagent/cityagent/societyagent.py +291 -0
- pycityagent/memory/memory.py +0 -18
- pycityagent/message/messager.py +6 -3
- pycityagent/simulation/agentgroup.py +123 -37
- pycityagent/simulation/simulation.py +311 -316
- pycityagent/workflow/block.py +66 -1
- pycityagent/workflow/tool.py +9 -4
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/RECORD +33 -14
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/top_level.txt +0 -0
@@ -3,9 +3,10 @@ import json
|
|
3
3
|
import logging
|
4
4
|
import time
|
5
5
|
import uuid
|
6
|
+
from collections.abc import Callable
|
6
7
|
from datetime import datetime, timezone
|
7
8
|
from pathlib import Path
|
8
|
-
from typing import Any, Optional
|
9
|
+
from typing import Any, Optional, Type, Union
|
9
10
|
from uuid import UUID
|
10
11
|
|
11
12
|
import fastavro
|
@@ -13,12 +14,12 @@ import pyproj
|
|
13
14
|
import ray
|
14
15
|
from langchain_core.embeddings import Embeddings
|
15
16
|
|
16
|
-
from ..agent import Agent,
|
17
|
+
from ..agent import Agent, InstitutionAgent
|
17
18
|
from ..economy.econ_client import EconomyClient
|
18
19
|
from ..environment.simulator import Simulator
|
19
20
|
from ..llm.llm import LLM
|
20
21
|
from ..llm.llmconfig import LLMConfig
|
21
|
-
from ..memory import FaissQuery
|
22
|
+
from ..memory import FaissQuery, Memory
|
22
23
|
from ..message import Messager
|
23
24
|
from ..metrics import MlflowClient
|
24
25
|
from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
|
@@ -31,7 +32,12 @@ logger = logging.getLogger("pycityagent")
|
|
31
32
|
class AgentGroup:
|
32
33
|
def __init__(
|
33
34
|
self,
|
34
|
-
|
35
|
+
agent_class: Union[type[Agent], list[type[Agent]]],
|
36
|
+
number_of_agents: Union[int, list[int]],
|
37
|
+
memory_config_function_group: Union[
|
38
|
+
Callable[[], tuple[dict, dict, dict]],
|
39
|
+
list[Callable[[], tuple[dict, dict, dict]]],
|
40
|
+
],
|
35
41
|
config: dict,
|
36
42
|
exp_id: str | UUID,
|
37
43
|
exp_name: str,
|
@@ -42,15 +48,27 @@ class AgentGroup:
|
|
42
48
|
mlflow_run_id: str,
|
43
49
|
embedding_model: Embeddings,
|
44
50
|
logging_level: int,
|
51
|
+
agent_config_file: Union[str, list[str]] = None,
|
45
52
|
):
|
46
53
|
logger.setLevel(logging_level)
|
47
54
|
self._uuid = str(uuid.uuid4())
|
48
|
-
|
55
|
+
if not isinstance(agent_class, list):
|
56
|
+
agent_class = [agent_class]
|
57
|
+
if not isinstance(memory_config_function_group, list):
|
58
|
+
memory_config_function_group = [memory_config_function_group]
|
59
|
+
if not isinstance(number_of_agents, list):
|
60
|
+
number_of_agents = [number_of_agents]
|
61
|
+
self.agent_class = agent_class
|
62
|
+
self.number_of_agents = number_of_agents
|
63
|
+
self.memory_config_function_group = memory_config_function_group
|
64
|
+
self.agents: list[Agent] = []
|
65
|
+
self.id2agent: dict[str, Agent] = {}
|
49
66
|
self.config = config
|
50
67
|
self.exp_id = exp_id
|
51
68
|
self.enable_avro = enable_avro
|
52
69
|
self.enable_pgsql = enable_pgsql
|
53
70
|
self.embedding_model = embedding_model
|
71
|
+
self.agent_config_file = agent_config_file
|
54
72
|
if enable_avro:
|
55
73
|
self.avro_path = avro_path / f"{self._uuid}"
|
56
74
|
self.avro_path.mkdir(parents=True, exist_ok=True)
|
@@ -63,28 +81,33 @@ class AgentGroup:
|
|
63
81
|
if self.enable_pgsql:
|
64
82
|
pass
|
65
83
|
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
84
|
+
# prepare Messager
|
85
|
+
if "mqtt" in config["simulator_request"]:
|
86
|
+
self.messager = Messager.remote(
|
87
|
+
hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
|
88
|
+
port=config["simulator_request"]["mqtt"]["port"],
|
89
|
+
username=config["simulator_request"]["mqtt"].get("username", None),
|
90
|
+
password=config["simulator_request"]["mqtt"].get("password", None),
|
91
|
+
)
|
92
|
+
else:
|
93
|
+
self.messager = None
|
94
|
+
|
72
95
|
self.message_dispatch_task = None
|
73
96
|
self._pgsql_writer = pgsql_writer
|
74
97
|
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
75
98
|
self.initialized = False
|
76
99
|
self.id2agent = {}
|
77
|
-
#
|
100
|
+
# prepare LLM client
|
78
101
|
llmConfig = LLMConfig(config["llm_request"])
|
79
102
|
logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
|
80
103
|
self.llm = LLM(llmConfig)
|
81
104
|
|
82
|
-
#
|
105
|
+
# prepare Simulator
|
83
106
|
logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
|
84
107
|
self.simulator = Simulator(config["simulator_request"])
|
85
108
|
self.projector = pyproj.Proj(self.simulator.map.header["projection"])
|
86
109
|
|
87
|
-
#
|
110
|
+
# prepare Economy client
|
88
111
|
if "economy" in config["simulator_request"]:
|
89
112
|
logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
|
90
113
|
self.economy_client = EconomyClient(
|
@@ -113,25 +136,58 @@ class AgentGroup:
|
|
113
136
|
)
|
114
137
|
else:
|
115
138
|
self.faiss_query = None
|
116
|
-
for
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
agent
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
139
|
+
for i in range(len(number_of_agents)):
|
140
|
+
agent_class_i = agent_class[i]
|
141
|
+
number_of_agents_i = number_of_agents[i]
|
142
|
+
for j in range(number_of_agents_i):
|
143
|
+
memory_config_function_group_i = memory_config_function_group[i]
|
144
|
+
extra_attributes, profile, base = memory_config_function_group_i()
|
145
|
+
memory = Memory(config=extra_attributes, profile=profile, base=base)
|
146
|
+
agent = agent_class_i(
|
147
|
+
name=f"{agent_class_i.__name__}_{i}",
|
148
|
+
memory=memory,
|
149
|
+
llm_client=self.llm,
|
150
|
+
economy_client=self.economy_client,
|
151
|
+
simulator=self.simulator,
|
152
|
+
)
|
153
|
+
agent.set_exp_id(self.exp_id) # type: ignore
|
154
|
+
if self.mlflow_client is not None:
|
155
|
+
agent.set_mlflow_client(self.mlflow_client)
|
156
|
+
if self.messager is not None:
|
157
|
+
agent.set_messager(self.messager)
|
158
|
+
if self.enable_avro:
|
159
|
+
agent.set_avro_file(self.avro_file) # type: ignore
|
160
|
+
if self.enable_pgsql:
|
161
|
+
agent.set_pgsql_writer(self._pgsql_writer)
|
162
|
+
if self.faiss_query is not None:
|
163
|
+
agent.memory.set_faiss_query(self.faiss_query)
|
164
|
+
if self.embedding_model is not None:
|
165
|
+
agent.memory.set_embedding_model(self.embedding_model)
|
166
|
+
if self.agent_config_file[i]:
|
167
|
+
agent.load_from_file(self.agent_config_file[i])
|
168
|
+
self.agents.append(agent)
|
169
|
+
self.id2agent[agent._uuid] = agent
|
170
|
+
|
171
|
+
@property
|
172
|
+
def agent_count(self):
|
173
|
+
return self.number_of_agents
|
174
|
+
|
175
|
+
@property
|
176
|
+
def agent_uuids(self):
|
177
|
+
return list(self.id2agent.keys())
|
178
|
+
|
179
|
+
@property
|
180
|
+
def agent_type(self):
|
181
|
+
return self.agent_class
|
182
|
+
|
183
|
+
def get_agent_count(self):
|
184
|
+
return self.agent_count
|
185
|
+
|
186
|
+
def get_agent_uuids(self):
|
187
|
+
return self.agent_uuids
|
188
|
+
|
189
|
+
def get_agent_type(self):
|
190
|
+
return self.agent_type
|
135
191
|
|
136
192
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
137
193
|
self.message_dispatch_task.cancel() # type: ignore
|
@@ -144,8 +200,9 @@ class AgentGroup:
|
|
144
200
|
await agent.bind_to_simulator() # type: ignore
|
145
201
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
146
202
|
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
203
|
+
assert self.messager is not None
|
147
204
|
await self.messager.connect.remote()
|
148
|
-
if
|
205
|
+
if await self.messager.is_connected.remote():
|
149
206
|
await self.messager.start_listening.remote()
|
150
207
|
topics = []
|
151
208
|
agents = []
|
@@ -236,6 +293,35 @@ class AgentGroup:
|
|
236
293
|
self.initialized = True
|
237
294
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
238
295
|
|
296
|
+
async def filter(
|
297
|
+
self,
|
298
|
+
types: Optional[list[Type[Agent]]] = None,
|
299
|
+
keys: Optional[list[str]] = None,
|
300
|
+
values: Optional[list[Any]] = None,
|
301
|
+
) -> list[str]:
|
302
|
+
filtered_uuids = []
|
303
|
+
for agent in self.agents:
|
304
|
+
add = True
|
305
|
+
if types:
|
306
|
+
if agent.__class__ in types:
|
307
|
+
if keys:
|
308
|
+
for key in keys:
|
309
|
+
assert values is not None
|
310
|
+
if not agent.memory.get(key) == values[keys.index(key)]:
|
311
|
+
add = False
|
312
|
+
break
|
313
|
+
if add:
|
314
|
+
filtered_uuids.append(agent._uuid)
|
315
|
+
elif keys:
|
316
|
+
for key in keys:
|
317
|
+
assert values is not None
|
318
|
+
if not agent.memory.get(key) == values[keys.index(key)]:
|
319
|
+
add = False
|
320
|
+
break
|
321
|
+
if add:
|
322
|
+
filtered_uuids.append(agent._uuid)
|
323
|
+
return filtered_uuids
|
324
|
+
|
239
325
|
async def gather(self, content: str):
|
240
326
|
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
241
327
|
results = {}
|
@@ -253,7 +339,8 @@ class AgentGroup:
|
|
253
339
|
async def message_dispatch(self):
|
254
340
|
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
255
341
|
while True:
|
256
|
-
|
342
|
+
assert self.messager is not None
|
343
|
+
if not await self.messager.is_connected.remote():
|
257
344
|
logger.warning(
|
258
345
|
"Messager is not connected. Skipping message processing."
|
259
346
|
)
|
@@ -287,8 +374,7 @@ class AgentGroup:
|
|
287
374
|
await agent.handle_user_survey_message(payload)
|
288
375
|
elif topic_type == "gather":
|
289
376
|
await agent.handle_gather_message(payload)
|
290
|
-
|
291
|
-
await asyncio.sleep(0.5)
|
377
|
+
await asyncio.sleep(3)
|
292
378
|
|
293
379
|
async def save_status(
|
294
380
|
self, simulator_day: Optional[int] = None, simulator_t: Optional[int] = None
|