pycityagent 2.0.0a47__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a49__cp312-cp312-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. pycityagent/__init__.py +3 -2
  2. pycityagent/agent.py +109 -4
  3. pycityagent/cityagent/__init__.py +20 -0
  4. pycityagent/cityagent/bankagent.py +54 -0
  5. pycityagent/cityagent/blocks/__init__.py +20 -0
  6. pycityagent/cityagent/blocks/cognition_block.py +304 -0
  7. pycityagent/cityagent/blocks/dispatcher.py +78 -0
  8. pycityagent/cityagent/blocks/economy_block.py +356 -0
  9. pycityagent/cityagent/blocks/mobility_block.py +258 -0
  10. pycityagent/cityagent/blocks/needs_block.py +305 -0
  11. pycityagent/cityagent/blocks/other_block.py +103 -0
  12. pycityagent/cityagent/blocks/plan_block.py +309 -0
  13. pycityagent/cityagent/blocks/social_block.py +345 -0
  14. pycityagent/cityagent/blocks/time_block.py +116 -0
  15. pycityagent/cityagent/blocks/utils.py +66 -0
  16. pycityagent/cityagent/firmagent.py +75 -0
  17. pycityagent/cityagent/governmentagent.py +60 -0
  18. pycityagent/cityagent/initial.py +98 -0
  19. pycityagent/cityagent/memory_config.py +202 -0
  20. pycityagent/cityagent/nbsagent.py +92 -0
  21. pycityagent/cityagent/societyagent.py +291 -0
  22. pycityagent/memory/memory.py +0 -18
  23. pycityagent/message/messager.py +6 -3
  24. pycityagent/simulation/agentgroup.py +123 -37
  25. pycityagent/simulation/simulation.py +311 -316
  26. pycityagent/workflow/block.py +66 -1
  27. pycityagent/workflow/tool.py +9 -4
  28. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/METADATA +2 -2
  29. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/RECORD +33 -14
  30. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/LICENSE +0 -0
  31. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/WHEEL +0 -0
  32. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/entry_points.txt +0 -0
  33. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a49.dist-info}/top_level.txt +0 -0
@@ -3,9 +3,10 @@ import json
3
3
  import logging
4
4
  import time
5
5
  import uuid
6
+ from collections.abc import Callable
6
7
  from datetime import datetime, timezone
7
8
  from pathlib import Path
8
- from typing import Any, Optional
9
+ from typing import Any, Optional, Type, Union
9
10
  from uuid import UUID
10
11
 
11
12
  import fastavro
@@ -13,12 +14,12 @@ import pyproj
13
14
  import ray
14
15
  from langchain_core.embeddings import Embeddings
15
16
 
16
- from ..agent import Agent, CitizenAgent, InstitutionAgent
17
+ from ..agent import Agent, InstitutionAgent
17
18
  from ..economy.econ_client import EconomyClient
18
19
  from ..environment.simulator import Simulator
19
20
  from ..llm.llm import LLM
20
21
  from ..llm.llmconfig import LLMConfig
21
- from ..memory import FaissQuery
22
+ from ..memory import FaissQuery, Memory
22
23
  from ..message import Messager
23
24
  from ..metrics import MlflowClient
24
25
  from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
@@ -31,7 +32,12 @@ logger = logging.getLogger("pycityagent")
31
32
  class AgentGroup:
32
33
  def __init__(
33
34
  self,
34
- agents: list[Agent],
35
+ agent_class: Union[type[Agent], list[type[Agent]]],
36
+ number_of_agents: Union[int, list[int]],
37
+ memory_config_function_group: Union[
38
+ Callable[[], tuple[dict, dict, dict]],
39
+ list[Callable[[], tuple[dict, dict, dict]]],
40
+ ],
35
41
  config: dict,
36
42
  exp_id: str | UUID,
37
43
  exp_name: str,
@@ -42,15 +48,27 @@ class AgentGroup:
42
48
  mlflow_run_id: str,
43
49
  embedding_model: Embeddings,
44
50
  logging_level: int,
51
+ agent_config_file: Union[str, list[str]] = None,
45
52
  ):
46
53
  logger.setLevel(logging_level)
47
54
  self._uuid = str(uuid.uuid4())
48
- self.agents = agents
55
+ if not isinstance(agent_class, list):
56
+ agent_class = [agent_class]
57
+ if not isinstance(memory_config_function_group, list):
58
+ memory_config_function_group = [memory_config_function_group]
59
+ if not isinstance(number_of_agents, list):
60
+ number_of_agents = [number_of_agents]
61
+ self.agent_class = agent_class
62
+ self.number_of_agents = number_of_agents
63
+ self.memory_config_function_group = memory_config_function_group
64
+ self.agents: list[Agent] = []
65
+ self.id2agent: dict[str, Agent] = {}
49
66
  self.config = config
50
67
  self.exp_id = exp_id
51
68
  self.enable_avro = enable_avro
52
69
  self.enable_pgsql = enable_pgsql
53
70
  self.embedding_model = embedding_model
71
+ self.agent_config_file = agent_config_file
54
72
  if enable_avro:
55
73
  self.avro_path = avro_path / f"{self._uuid}"
56
74
  self.avro_path.mkdir(parents=True, exist_ok=True)
@@ -63,28 +81,33 @@ class AgentGroup:
63
81
  if self.enable_pgsql:
64
82
  pass
65
83
 
66
- self.messager = Messager.remote(
67
- hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
68
- port=config["simulator_request"]["mqtt"]["port"],
69
- username=config["simulator_request"]["mqtt"].get("username", None),
70
- password=config["simulator_request"]["mqtt"].get("password", None),
71
- )
84
+ # prepare Messager
85
+ if "mqtt" in config["simulator_request"]:
86
+ self.messager = Messager.remote(
87
+ hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
88
+ port=config["simulator_request"]["mqtt"]["port"],
89
+ username=config["simulator_request"]["mqtt"].get("username", None),
90
+ password=config["simulator_request"]["mqtt"].get("password", None),
91
+ )
92
+ else:
93
+ self.messager = None
94
+
72
95
  self.message_dispatch_task = None
73
96
  self._pgsql_writer = pgsql_writer
74
97
  self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
75
98
  self.initialized = False
76
99
  self.id2agent = {}
77
- # Step:1 prepare LLM client
100
+ # prepare LLM client
78
101
  llmConfig = LLMConfig(config["llm_request"])
79
102
  logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
80
103
  self.llm = LLM(llmConfig)
81
104
 
82
- # Step:2 prepare Simulator
105
+ # prepare Simulator
83
106
  logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
84
107
  self.simulator = Simulator(config["simulator_request"])
85
108
  self.projector = pyproj.Proj(self.simulator.map.header["projection"])
86
109
 
87
- # Step:3 prepare Economy client
110
+ # prepare Economy client
88
111
  if "economy" in config["simulator_request"]:
89
112
  logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
90
113
  self.economy_client = EconomyClient(
@@ -113,25 +136,58 @@ class AgentGroup:
113
136
  )
114
137
  else:
115
138
  self.faiss_query = None
116
- for agent in self.agents:
117
- agent.set_exp_id(self.exp_id) # type: ignore
118
- agent.set_llm_client(self.llm)
119
- agent.set_simulator(self.simulator)
120
- if self.economy_client is not None:
121
- agent.set_economy_client(self.economy_client)
122
- if self.mlflow_client is not None:
123
- agent.set_mlflow_client(self.mlflow_client)
124
- agent.set_messager(self.messager)
125
- if self.enable_avro:
126
- agent.set_avro_file(self.avro_file) # type: ignore
127
- if self.enable_pgsql:
128
- agent.set_pgsql_writer(self._pgsql_writer)
129
- # set memory.faiss_query
130
- if self.faiss_query is not None:
131
- agent.memory.set_faiss_query(self.faiss_query)
132
- # set memory.embedding model
133
- if self.embedding_model is not None:
134
- agent.memory.set_embedding_model(self.embedding_model)
139
+ for i in range(len(number_of_agents)):
140
+ agent_class_i = agent_class[i]
141
+ number_of_agents_i = number_of_agents[i]
142
+ for j in range(number_of_agents_i):
143
+ memory_config_function_group_i = memory_config_function_group[i]
144
+ extra_attributes, profile, base = memory_config_function_group_i()
145
+ memory = Memory(config=extra_attributes, profile=profile, base=base)
146
+ agent = agent_class_i(
147
+ name=f"{agent_class_i.__name__}_{i}",
148
+ memory=memory,
149
+ llm_client=self.llm,
150
+ economy_client=self.economy_client,
151
+ simulator=self.simulator,
152
+ )
153
+ agent.set_exp_id(self.exp_id) # type: ignore
154
+ if self.mlflow_client is not None:
155
+ agent.set_mlflow_client(self.mlflow_client)
156
+ if self.messager is not None:
157
+ agent.set_messager(self.messager)
158
+ if self.enable_avro:
159
+ agent.set_avro_file(self.avro_file) # type: ignore
160
+ if self.enable_pgsql:
161
+ agent.set_pgsql_writer(self._pgsql_writer)
162
+ if self.faiss_query is not None:
163
+ agent.memory.set_faiss_query(self.faiss_query)
164
+ if self.embedding_model is not None:
165
+ agent.memory.set_embedding_model(self.embedding_model)
166
+ if self.agent_config_file[i]:
167
+ agent.load_from_file(self.agent_config_file[i])
168
+ self.agents.append(agent)
169
+ self.id2agent[agent._uuid] = agent
170
+
171
+ @property
172
+ def agent_count(self):
173
+ return self.number_of_agents
174
+
175
+ @property
176
+ def agent_uuids(self):
177
+ return list(self.id2agent.keys())
178
+
179
+ @property
180
+ def agent_type(self):
181
+ return self.agent_class
182
+
183
+ def get_agent_count(self):
184
+ return self.agent_count
185
+
186
+ def get_agent_uuids(self):
187
+ return self.agent_uuids
188
+
189
+ def get_agent_type(self):
190
+ return self.agent_type
135
191
 
136
192
  async def __aexit__(self, exc_type, exc_value, traceback):
137
193
  self.message_dispatch_task.cancel() # type: ignore
@@ -144,8 +200,9 @@ class AgentGroup:
144
200
  await agent.bind_to_simulator() # type: ignore
145
201
  self.id2agent = {agent._uuid: agent for agent in self.agents}
146
202
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
203
+ assert self.messager is not None
147
204
  await self.messager.connect.remote()
148
- if ray.get(self.messager.is_connected.remote()):
205
+ if await self.messager.is_connected.remote():
149
206
  await self.messager.start_listening.remote()
150
207
  topics = []
151
208
  agents = []
@@ -236,6 +293,35 @@ class AgentGroup:
236
293
  self.initialized = True
237
294
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
238
295
 
296
+ async def filter(
297
+ self,
298
+ types: Optional[list[Type[Agent]]] = None,
299
+ keys: Optional[list[str]] = None,
300
+ values: Optional[list[Any]] = None,
301
+ ) -> list[str]:
302
+ filtered_uuids = []
303
+ for agent in self.agents:
304
+ add = True
305
+ if types:
306
+ if agent.__class__ in types:
307
+ if keys:
308
+ for key in keys:
309
+ assert values is not None
310
+ if not agent.memory.get(key) == values[keys.index(key)]:
311
+ add = False
312
+ break
313
+ if add:
314
+ filtered_uuids.append(agent._uuid)
315
+ elif keys:
316
+ for key in keys:
317
+ assert values is not None
318
+ if not agent.memory.get(key) == values[keys.index(key)]:
319
+ add = False
320
+ break
321
+ if add:
322
+ filtered_uuids.append(agent._uuid)
323
+ return filtered_uuids
324
+
239
325
  async def gather(self, content: str):
240
326
  logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
241
327
  results = {}
@@ -253,7 +339,8 @@ class AgentGroup:
253
339
  async def message_dispatch(self):
254
340
  logger.debug(f"-----Starting message dispatch for group {self._uuid}")
255
341
  while True:
256
- if not ray.get(self.messager.is_connected.remote()):
342
+ assert self.messager is not None
343
+ if not await self.messager.is_connected.remote():
257
344
  logger.warning(
258
345
  "Messager is not connected. Skipping message processing."
259
346
  )
@@ -287,8 +374,7 @@ class AgentGroup:
287
374
  await agent.handle_user_survey_message(payload)
288
375
  elif topic_type == "gather":
289
376
  await agent.handle_gather_message(payload)
290
-
291
- await asyncio.sleep(0.5)
377
+ await asyncio.sleep(3)
292
378
 
293
379
  async def save_status(
294
380
  self, simulator_day: Optional[int] = None, simulator_t: Optional[int] = None