pycityagent-2.0.0a47-cp312-cp312-macosx_11_0_arm64.whl → pycityagent-2.0.0a48-cp312-cp312-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33)
  1. pycityagent/__init__.py +3 -2
  2. pycityagent/agent.py +109 -4
  3. pycityagent/cityagent/__init__.py +20 -0
  4. pycityagent/cityagent/bankagent.py +54 -0
  5. pycityagent/cityagent/blocks/__init__.py +20 -0
  6. pycityagent/cityagent/blocks/cognition_block.py +304 -0
  7. pycityagent/cityagent/blocks/dispatcher.py +78 -0
  8. pycityagent/cityagent/blocks/economy_block.py +356 -0
  9. pycityagent/cityagent/blocks/mobility_block.py +258 -0
  10. pycityagent/cityagent/blocks/needs_block.py +305 -0
  11. pycityagent/cityagent/blocks/other_block.py +103 -0
  12. pycityagent/cityagent/blocks/plan_block.py +309 -0
  13. pycityagent/cityagent/blocks/social_block.py +345 -0
  14. pycityagent/cityagent/blocks/time_block.py +116 -0
  15. pycityagent/cityagent/blocks/utils.py +66 -0
  16. pycityagent/cityagent/firmagent.py +75 -0
  17. pycityagent/cityagent/governmentagent.py +60 -0
  18. pycityagent/cityagent/initial.py +98 -0
  19. pycityagent/cityagent/memory_config.py +202 -0
  20. pycityagent/cityagent/nbsagent.py +92 -0
  21. pycityagent/cityagent/societyagent.py +291 -0
  22. pycityagent/memory/memory.py +0 -18
  23. pycityagent/message/messager.py +6 -3
  24. pycityagent/simulation/agentgroup.py +118 -37
  25. pycityagent/simulation/simulation.py +311 -316
  26. pycityagent/workflow/block.py +66 -1
  27. pycityagent/workflow/tool.py +15 -11
  28. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/METADATA +2 -2
  29. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/RECORD +33 -14
  30. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/LICENSE +0 -0
  31. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/WHEEL +0 -0
  32. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/entry_points.txt +0 -0
  33. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,12 @@
1
1
  import asyncio
2
+ from collections.abc import Callable
2
3
  import json
3
4
  import logging
4
5
  import time
5
6
  import uuid
6
7
  from datetime import datetime, timezone
7
8
  from pathlib import Path
8
- from typing import Any, Optional
9
+ from typing import Any, Optional, Type, Union
9
10
  from uuid import UUID
10
11
 
11
12
  import fastavro
@@ -13,12 +14,12 @@ import pyproj
13
14
  import ray
14
15
  from langchain_core.embeddings import Embeddings
15
16
 
16
- from ..agent import Agent, CitizenAgent, InstitutionAgent
17
+ from ..agent import Agent, InstitutionAgent
17
18
  from ..economy.econ_client import EconomyClient
18
19
  from ..environment.simulator import Simulator
19
20
  from ..llm.llm import LLM
20
21
  from ..llm.llmconfig import LLMConfig
21
- from ..memory import FaissQuery
22
+ from ..memory import FaissQuery, Memory
22
23
  from ..message import Messager
23
24
  from ..metrics import MlflowClient
24
25
  from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
@@ -31,7 +32,9 @@ logger = logging.getLogger("pycityagent")
31
32
  class AgentGroup:
32
33
  def __init__(
33
34
  self,
34
- agents: list[Agent],
35
+ agent_class: Union[type[Agent], list[type[Agent]]],
36
+ number_of_agents: Union[int, list[int]],
37
+ memory_config_function_group: Union[Callable[[], tuple[dict, dict, dict]], list[Callable[[], tuple[dict, dict, dict]]]],
35
38
  config: dict,
36
39
  exp_id: str | UUID,
37
40
  exp_name: str,
@@ -42,15 +45,27 @@ class AgentGroup:
42
45
  mlflow_run_id: str,
43
46
  embedding_model: Embeddings,
44
47
  logging_level: int,
48
+ agent_config_file: Union[str, list[str]] = None,
45
49
  ):
46
50
  logger.setLevel(logging_level)
47
51
  self._uuid = str(uuid.uuid4())
48
- self.agents = agents
52
+ if not isinstance(agent_class, list):
53
+ agent_class = [agent_class]
54
+ if not isinstance(memory_config_function_group, list):
55
+ memory_config_function_group = [memory_config_function_group]
56
+ if not isinstance(number_of_agents, list):
57
+ number_of_agents = [number_of_agents]
58
+ self.agent_class = agent_class
59
+ self.number_of_agents = number_of_agents
60
+ self.memory_config_function_group = memory_config_function_group
61
+ self.agents: list[Agent] = []
62
+ self.id2agent: dict[str, Agent] = {}
49
63
  self.config = config
50
64
  self.exp_id = exp_id
51
65
  self.enable_avro = enable_avro
52
66
  self.enable_pgsql = enable_pgsql
53
67
  self.embedding_model = embedding_model
68
+ self.agent_config_file = agent_config_file
54
69
  if enable_avro:
55
70
  self.avro_path = avro_path / f"{self._uuid}"
56
71
  self.avro_path.mkdir(parents=True, exist_ok=True)
@@ -63,28 +78,33 @@ class AgentGroup:
63
78
  if self.enable_pgsql:
64
79
  pass
65
80
 
66
- self.messager = Messager.remote(
67
- hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
68
- port=config["simulator_request"]["mqtt"]["port"],
69
- username=config["simulator_request"]["mqtt"].get("username", None),
70
- password=config["simulator_request"]["mqtt"].get("password", None),
71
- )
81
+ # prepare Messager
82
+ if "mqtt" in config["simulator_request"]:
83
+ self.messager = Messager.remote(
84
+ hostname=config["simulator_request"]["mqtt"]["server"],
85
+ port=config["simulator_request"]["mqtt"]["port"],
86
+ username=config["simulator_request"]["mqtt"].get("username", None),
87
+ password=config["simulator_request"]["mqtt"].get("password", None),
88
+ )
89
+ else:
90
+ self.messager = None
91
+
72
92
  self.message_dispatch_task = None
73
93
  self._pgsql_writer = pgsql_writer
74
94
  self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
75
95
  self.initialized = False
76
96
  self.id2agent = {}
77
- # Step:1 prepare LLM client
97
+ # prepare LLM client
78
98
  llmConfig = LLMConfig(config["llm_request"])
79
99
  logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
80
100
  self.llm = LLM(llmConfig)
81
101
 
82
- # Step:2 prepare Simulator
102
+ # prepare Simulator
83
103
  logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
84
104
  self.simulator = Simulator(config["simulator_request"])
85
105
  self.projector = pyproj.Proj(self.simulator.map.header["projection"])
86
106
 
87
- # Step:3 prepare Economy client
107
+ # prepare Economy client
88
108
  if "economy" in config["simulator_request"]:
89
109
  logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
90
110
  self.economy_client = EconomyClient(
@@ -113,25 +133,62 @@ class AgentGroup:
113
133
  )
114
134
  else:
115
135
  self.faiss_query = None
116
- for agent in self.agents:
117
- agent.set_exp_id(self.exp_id) # type: ignore
118
- agent.set_llm_client(self.llm)
119
- agent.set_simulator(self.simulator)
120
- if self.economy_client is not None:
121
- agent.set_economy_client(self.economy_client)
122
- if self.mlflow_client is not None:
123
- agent.set_mlflow_client(self.mlflow_client)
124
- agent.set_messager(self.messager)
125
- if self.enable_avro:
126
- agent.set_avro_file(self.avro_file) # type: ignore
127
- if self.enable_pgsql:
128
- agent.set_pgsql_writer(self._pgsql_writer)
129
- # set memory.faiss_query
130
- if self.faiss_query is not None:
131
- agent.memory.set_faiss_query(self.faiss_query)
132
- # set memory.embedding model
133
- if self.embedding_model is not None:
134
- agent.memory.set_embedding_model(self.embedding_model)
136
+ for i in range(len(number_of_agents)):
137
+ agent_class_i = agent_class[i]
138
+ number_of_agents_i = number_of_agents[i]
139
+ for j in range(number_of_agents_i):
140
+ memory_config_function_group_i = memory_config_function_group[i]
141
+ extra_attributes, profile, base = memory_config_function_group_i()
142
+ memory = Memory(config=extra_attributes, profile=profile, base=base)
143
+ agent = agent_class_i(
144
+ name=f"{agent_class_i.__name__}_{i}",
145
+ memory=memory,
146
+ llm_client=self.llm,
147
+ economy_client=self.economy_client,
148
+ simulator=self.simulator,
149
+ )
150
+ agent.set_exp_id(self.exp_id) # type: ignore
151
+ if self.mlflow_client is not None:
152
+ agent.set_mlflow_client(self.mlflow_client)
153
+ if self.messager is not None:
154
+ agent.set_messager(self.messager)
155
+ if self.enable_avro:
156
+ agent.set_avro_file(self.avro_file) # type: ignore
157
+ if self.enable_pgsql:
158
+ agent.set_pgsql_writer(self._pgsql_writer)
159
+ if self.faiss_query is not None:
160
+ agent.memory.set_faiss_query(self.faiss_query)
161
+ if self.embedding_model is not None:
162
+ agent.memory.set_embedding_model(self.embedding_model)
163
+ if self.agent_config_file[i]:
164
+ agent.load_from_file(self.agent_config_file[i])
165
+ self.agents.append(agent)
166
+ self.id2agent[agent._uuid] = agent
167
+
168
@property
def agent_count(self):
    """Per-class agent counts as stored in ``number_of_agents``.

    Note: ``__init__`` normalizes ``number_of_agents`` to a list, so this is
    a list of counts (one per agent class), not a single integer.
    """
    return self.number_of_agents


@property
def agent_uuids(self):
    """UUIDs of all agents registered in ``id2agent``, in insertion order."""
    return list(self.id2agent)


@property
def agent_type(self):
    """Agent classes this group was built from (``agent_class``)."""
    return self.agent_class


# Plain-method mirrors of the properties above. NOTE(review): presumably
# these exist so callers holding an actor handle can invoke them as methods;
# confirm against callers before removing either form.
def get_agent_count(self):
    """Return ``self.agent_count``."""
    return self.agent_count


def get_agent_uuids(self):
    """Return ``self.agent_uuids``."""
    return self.agent_uuids


def get_agent_type(self):
    """Return ``self.agent_type``."""
    return self.agent_type
188
+
189
async def __aexit__(self, exc_type, exc_value, traceback):
    """Cancel the message-dispatch task and wait for it to finish.

    The task's outcome (including its ``CancelledError``) is collected by
    ``asyncio.gather(..., return_exceptions=True)`` instead of being raised
    here. Does nothing when no dispatch task was ever started.
    """
    # NOTE(review): the class defines ``__aexit__`` again immediately after
    # this one; the later definition replaces this one when the class body
    # executes, so one of the two should be removed.
    task = self.message_dispatch_task
    if task is None:
        # ``message_dispatch_task`` starts as None, and a group built without
        # an MQTT config gets ``messager = None``, so there may be no task.
        return
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
135
192
 
136
193
  async def __aexit__(self, exc_type, exc_value, traceback):
137
194
  self.message_dispatch_task.cancel() # type: ignore
@@ -145,7 +202,7 @@ class AgentGroup:
145
202
  self.id2agent = {agent._uuid: agent for agent in self.agents}
146
203
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
147
204
  await self.messager.connect.remote()
148
- if ray.get(self.messager.is_connected.remote()):
205
+ if await self.messager.is_connected.remote():
149
206
  await self.messager.start_listening.remote()
150
207
  topics = []
151
208
  agents = []
@@ -236,6 +293,31 @@ class AgentGroup:
236
293
  self.initialized = True
237
294
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
238
295
 
296
async def filter(
    self,
    types: Optional[list[Type["Agent"]]] = None,
    keys: Optional[list[str]] = None,
    values: Optional[list[Any]] = None,
) -> list[str]:
    """Return the UUIDs of agents in this group that match every criterion.

    Args:
        types: If non-empty, only agents whose exact class
            (``agent.__class__``) is in this list are considered; subclasses
            are not matched.
        keys: Memory keys to compare; ``keys[i]`` is paired with
            ``values[i]``.
        values: Expected values, positionally paired with ``keys``.

    Returns:
        UUIDs (``agent._uuid``) of matching agents, in ``self.agents`` order.
        When neither ``types`` nor ``keys`` is given, an empty list is
        returned (no criteria selects nothing, as before).

    Raises:
        ValueError: If ``keys`` is given but ``values`` is missing or shorter
            than ``keys``.
    """
    import inspect  # local import keeps this change self-contained

    criteria: list[tuple[str, Any]] = []
    if keys:
        if values is None or len(values) < len(keys):
            raise ValueError(
                "filter(): `values` must supply one expected value per key"
            )
        # Pair positionally; unlike `keys.index(key)`, this stays correct when
        # the same key appears more than once.
        criteria = list(zip(keys, values))

    # Preserve the original contract: with no criteria, nothing is selected.
    if not types and not criteria:
        return []

    filtered_uuids: list[str] = []
    for agent in self.agents:
        if types and agent.__class__ not in types:
            continue
        matched = True
        for key, expected in criteria:
            actual = agent.memory.get(key)
            # NOTE(review): the memory accessor may be a coroutine function;
            # compare the resolved value, never the coroutine object itself.
            if inspect.isawaitable(actual):
                actual = await actual
            if not actual == expected:
                matched = False
                break
        if matched:
            filtered_uuids.append(agent._uuid)
    return filtered_uuids
320
+
239
321
  async def gather(self, content: str):
240
322
  logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
241
323
  results = {}
@@ -253,7 +335,7 @@ class AgentGroup:
253
335
  async def message_dispatch(self):
254
336
  logger.debug(f"-----Starting message dispatch for group {self._uuid}")
255
337
  while True:
256
- if not ray.get(self.messager.is_connected.remote()):
338
+ if not await self.messager.is_connected.remote():
257
339
  logger.warning(
258
340
  "Messager is not connected. Skipping message processing."
259
341
  )
@@ -287,8 +369,7 @@ class AgentGroup:
287
369
  await agent.handle_user_survey_message(payload)
288
370
  elif topic_type == "gather":
289
371
  await agent.handle_gather_message(payload)
290
-
291
- await asyncio.sleep(0.5)
372
+ await asyncio.sleep(3)
292
373
 
293
374
  async def save_status(
294
375
  self, simulator_day: Optional[int] = None, simulator_t: Optional[int] = None