pycityagent 2.0.0a65__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a66__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,12 @@
1
+ import math
1
2
  from typing import List
2
3
  from .dispatcher import BlockDispatcher
3
4
  from pycityagent.environment.simulator import Simulator
4
5
  from pycityagent.llm import LLM
5
6
  from pycityagent.memory import Memory
6
7
  from pycityagent.workflow.block import Block
8
+ import numpy as np
9
+ from operator import itemgetter
7
10
  import random
8
11
  import logging
9
12
  logger = logging.getLogger("pycityagent")
@@ -38,18 +41,83 @@ Current temperature: {temperature}
38
41
  Your current emotion: {emotion_types}
39
42
  Your current thought: {thought}
40
43
 
41
- Please analyze how these emotions would affect travel willingness and return only a single integer number between 1000-100000 representing the maximum travel radius in meters. A more positive emotional state generally leads to greater willingness to travel further.
44
+ Please analyze how these emotions would affect travel willingness and return only a single integer number between 3000-100000 representing the maximum travel radius in meters. A more positive emotional state generally leads to greater willingness to travel further.
42
45
 
43
46
  Return only the integer number without any additional text or explanation."""
44
47
 
48
def gravity_model(pois, sample_size=50):
    """Turn candidate POIs into a sampled, probability-weighted shortlist.

    Each POI within 10 km is assigned to a 1 km-wide distance ring. Its
    gravity weight is the ring's POI density (count / ring area) divided by
    the squared distance, so that distant places are less likely to be
    selected. ``sample_size`` candidates are then drawn (with replacement)
    with probability proportional to ``1 / sqrt(distance)``, and the weights
    of the drawn candidates are normalised so they sum to 1.

    Args:
        pois: Sequence of ``(info, distance)`` pairs, where ``info`` is a
            mapping with ``"name"`` and ``"id"`` keys and ``distance`` is a
            number (presumably metres, given the 1000-wide rings -- confirm
            against the caller).
        sample_size: Number of candidates to draw; defaults to 50.

    Returns:
        A list of ``(name, id, probability, distance)`` tuples of length
        ``sample_size``, or an empty list when ``pois`` is empty.
    """
    if len(pois) == 0:
        return []

    ring_width = 1000
    ring_count = 10

    def ring_of(distance):
        # 1-based ring index for distances in [0, 10 km), otherwise None.
        if 0 <= distance < ring_width * ring_count:
            return int(distance // ring_width) + 1
        return None

    def clamp(distance):
        # Avoid division by (near) zero for POIs at the origin.
        return distance if distance > 1 else 1

    # Single pass: remember each POI's ring and count POIs per ring.
    rings = [ring_of(poi[1]) for poi in pois]
    ring_sizes = {}
    for ring in rings:
        if ring is not None:
            ring_sizes[ring] = ring_sizes.get(ring, 0) + 1

    candidates = []
    for poi, ring in zip(pois, rings):
        if ring is None:
            continue
        ring_area = math.pi * (
            (ring * ring_width) ** 2 - ((ring - 1) * ring_width) ** 2
        )
        density = ring_sizes[ring] / ring_area
        distance = clamp(poi[1])
        # Distance decay uses the square of the distance.
        weight = density / (distance**2)
        candidates.append((poi[0]["name"], poi[0]["id"], weight, distance))

    if not candidates:
        # Every POI is outside the 10 km rings. Previously this left the
        # candidate list empty and np.random.choice raised; fall back to a
        # pure inverse-square weighting over all POIs instead.
        for poi in pois:
            distance = clamp(poi[1])
            candidates.append(
                (poi[0]["name"], poi[0]["id"], 1.0 / (distance**2), distance)
            )

    # Pre-selection probability decays with the square root of the distance.
    pre_prob = np.array([1 / math.sqrt(item[3]) for item in candidates])
    pre_prob = pre_prob / np.sum(pre_prob)
    picked = np.random.choice(len(candidates), size=sample_size, p=pre_prob)
    sampled = [candidates[int(index)] for index in picked]

    # Normalise the weights of the sampled candidates into probabilities.
    weight_sum = sum(item[2] for item in sampled)
    return [
        (item[0], item[1], item[2] / weight_sum, item[3]) for item in sampled
    ]
112
+
45
113
  class PlaceSelectionBlock(Block):
46
114
  """
47
- 选择目的地
115
+ Select destination
48
116
  PlaceSelectionBlock
49
117
  """
50
118
  configurable_fields: List[str] = ["search_limit"]
51
119
  default_values = {
52
- "search_limit": 10
120
+ "search_limit": 10000
53
121
  }
54
122
 
55
123
  def __init__(self, llm: LLM, memory: Memory, simulator: Simulator):
@@ -59,7 +127,7 @@ class PlaceSelectionBlock(Block):
59
127
  self.secondTypeSelectionPrompt = FormatPrompt(PLACE_SECOND_TYPE_SELECTION_PROMPT)
60
128
  self.radiusPrompt = FormatPrompt(RADIUS_PROMPT)
61
129
  # configurable fields
62
- self.search_limit = 10
130
+ self.search_limit = 10000
63
131
 
64
132
  async def forward(self, step, context):
65
133
  self.typeSelectionPrompt.format(
@@ -106,9 +174,15 @@ class PlaceSelectionBlock(Block):
106
174
  limit=self.search_limit
107
175
  )
108
176
  if len(pois) > 0:
109
- poi = random.choice(pois)[0]
110
- nextPlace = (poi['name'], poi['aoi_id'])
111
- # 将地点信息保存到context中
177
+ pois = gravity_model(pois)
178
+ probabilities = [item[2] for item in pois]
179
+ options = list(range(len(pois)))
180
+ sample = np.random.choice(
181
+ options, size=1, p=probabilities
182
+ ) # sample based on the probability value
183
+ nextPlace = pois[sample[0]]
184
+ nextPlace = (nextPlace[0], nextPlace[1])
185
+ # save the destination to context
112
186
  context['next_place'] = nextPlace
113
187
  node_id = await self.memory.stream.add_mobility(description=f"For {step['intention']}, I selected the destination: {nextPlace}")
114
188
  return {
@@ -121,7 +195,7 @@ class PlaceSelectionBlock(Block):
121
195
  simmap = self.simulator.map
122
196
  poi = random.choice(list(simmap.pois.values()))
123
197
  nextPlace = (poi['name'], poi['aoi_id'])
124
- # 将地点信息保存到context
198
+ # save the destination to context
125
199
  context['next_place'] = nextPlace
126
200
  node_id = await self.memory.stream.add_mobility(description=f"For {step['intention']}, I selected the destination: {nextPlace}")
127
201
  return {
@@ -133,7 +207,7 @@ class PlaceSelectionBlock(Block):
133
207
 
134
208
  class MoveBlock(Block):
135
209
  """
136
- 移动操作
210
+ Execute mobility operations
137
211
  MoveBlock
138
212
  """
139
213
  def __init__(self, llm: LLM, memory: Memory, simulator: Simulator):
@@ -17,9 +17,6 @@ Profile Information:
17
17
  - Age: {age}
18
18
  - Monthly Income: {income}
19
19
 
20
- Current Emotion: {emotion_types}
21
- Current Thought: {thought}
22
-
23
20
  Current Time: {now_time}
24
21
 
25
22
  Please initialize the agent's satisfaction levels and parameters based on the profile above. Return the values in JSON format with the following structure:
@@ -77,9 +74,6 @@ Current satisfaction:
77
74
  - safety_satisfaction: {safety_satisfaction}
78
75
  - social_satisfaction: {social_satisfaction}
79
76
 
80
- Current Emotion: {emotion_types}
81
- Current Thought: {thought}
82
-
83
77
  Please evaluate and adjust the value of {current_need} satisfaction based on the execution results above.
84
78
 
85
79
  Notes:
@@ -134,8 +128,6 @@ class NeedsBlock(Block):
134
128
  occupation=await self.memory.status.get("occupation"),
135
129
  age=await self.memory.status.get("age"),
136
130
  income=await self.memory.status.get("income"),
137
- emotion_types=await self.memory.status.get("emotion_types"),
138
- thought=await self.memory.status.get("thought"),
139
131
  now_time=await self.simulator.get_time(format_time=True)
140
132
  )
141
133
  response = await self.llm.atext_request(
@@ -285,8 +277,6 @@ class NeedsBlock(Block):
285
277
  energy_satisfaction=await self.memory.status.get("energy_satisfaction"),
286
278
  safety_satisfaction=await self.memory.status.get("safety_satisfaction"),
287
279
  social_satisfaction=await self.memory.status.get("social_satisfaction"),
288
- emotion_types=await self.memory.status.get("emotion_types"),
289
- thought=await self.memory.status.get("thought")
290
280
  )
291
281
 
292
282
  response = await self.llm.atext_request(
@@ -210,16 +210,6 @@ class FindPersonBlock(Block):
210
210
 
211
211
  class MessageBlock(Block):
212
212
  """生成并发送消息"""
213
- configurable_fields: List[str] = ["default_message_template", "to_discuss"]
214
- default_values = {
215
- "default_message_template": """
216
- As a {gender} {occupation} with {education} education and {personality} personality,
217
- generate a message for a friend (relationship strength: {relationship_score}/100)
218
- about {intention}.
219
- """,
220
- "to_discuss": []
221
- }
222
-
223
213
  def __init__(self, agent, llm: LLM, memory: Memory, simulator: Simulator):
224
214
  super().__init__("MessageBlock", llm=llm, memory=memory, simulator=simulator)
225
215
  self.agent = agent
@@ -73,22 +73,19 @@ async def bind_agent_info(simulation):
73
73
  infos = await simulation.gather("id")
74
74
  citizen_uuids = await simulation.filter(types=[SocietyAgent])
75
75
  firm_uuids = await simulation.filter(types=[FirmAgent])
76
- locations = await simulation.gather("location", firm_uuids)
77
- locations_plain = {}
78
- for info in locations:
79
- for k, v in info.items():
80
- locations_plain[k] = v
81
76
  government_uuids = await simulation.filter(types=[GovernmentAgent])
82
77
  bank_uuids = await simulation.filter(types=[BankAgent])
83
78
  nbs_uuids = await simulation.filter(types=[NBSAgent])
84
79
  citizen_agent_ids = []
85
80
  firm_ids = []
81
+ id2uuid = {}
86
82
  for info in infos:
87
83
  for k, v in info.items():
88
84
  if k in citizen_uuids:
89
85
  citizen_agent_ids.append(v)
90
86
  elif k in firm_uuids:
91
87
  firm_ids.append(v)
88
+ id2uuid[v] = k
92
89
  elif k in government_uuids:
93
90
  government_id = v
94
91
  elif k in bank_uuids:
@@ -97,9 +94,7 @@ async def bind_agent_info(simulation):
97
94
  nbs_id = v
98
95
  for citizen_uuid in citizen_uuids:
99
96
  random_firm_id = random.choice(firm_ids)
100
- location = locations_plain[random_firm_id]
101
97
  await simulation.update(citizen_uuid, "firm_id", random_firm_id)
102
- await simulation.update(citizen_uuid, "work", location)
103
98
  await simulation.update(citizen_uuid, "government_id", government_id)
104
99
  await simulation.update(citizen_uuid, "bank_id", bank_id)
105
100
  await simulation.update(citizen_uuid, "nbs_id", nbs_id)
@@ -5,6 +5,7 @@ import numpy as np
5
5
  import pycityproto.city.economy.v2.economy_pb2 as economyv2
6
6
  from mosstool.map._map_util.const import AOI_START_ID
7
7
 
8
+ from .firmagent import FirmAgent
8
9
  pareto_param = 8
9
10
  payment_max_skill_multiplier = 950
10
11
  payment_max_skill_multiplier = float(payment_max_skill_multiplier)
@@ -14,17 +15,24 @@ clipped_skills = np.minimum(pmsm, (pmsm - 1) * pareto_samples + 1)
14
15
  sorted_clipped_skills = np.sort(clipped_skills, axis=1)
15
16
  agent_skills = list(sorted_clipped_skills.mean(axis=0))
16
17
 
18
+ work_locations = [AOI_START_ID + random.randint(1000, 10000) for _ in range(1000)]
19
+
20
async def memory_config_init(simulation):
    """Regenerate the module-level pool of work-location AOI ids.

    Replaces ``work_locations`` with one random AOI id per firm, so that the
    memory config functions reading ``work_locations`` draw from a pool sized
    to the simulation.

    Args:
        simulation: Object exposing ``agent_count``, a mapping from agent
            class to the number of agents of that class.
    """
    global work_locations
    # Use .get so a simulation configured without FirmAgent does not raise
    # KeyError here.
    number_of_firm = simulation.agent_count.get(FirmAgent, 0)
    # Keep at least one location: the memory config functions call
    # random.choice(work_locations), which raises IndexError on an empty list.
    work_locations = [
        AOI_START_ID + random.randint(1000, 10000)
        for _ in range(max(number_of_firm, 1))
    ]
17
24
 
18
25
  def memory_config_societyagent():
26
+ global work_locations
19
27
  EXTRA_ATTRIBUTES = {
20
28
  "type": (str, "citizen"),
21
29
  "city": (str, "New York", True),
22
30
 
23
31
  # Needs Model
24
- "hunger_satisfaction": (float, random.random(), True), # 饥饿满意度
25
- "energy_satisfaction": (float, random.random(), True), # 精力满意度
26
- "safety_satisfaction": (float, random.random(), True), # 安全满意度
27
- "social_satisfaction": (float, random.random(), True), # 社交满意度
32
+ "hunger_satisfaction": (float, random.random(), True), # hunger satisfaction
33
+ "energy_satisfaction": (float, random.random(), True), # energy satisfaction
34
+ "safety_satisfaction": (float, random.random(), True), # safety satisfaction
35
+ "social_satisfaction": (float, random.random(), True), # social satisfaction
28
36
  "current_need": (str, "none", True),
29
37
 
30
38
  # Plan Behavior Model
@@ -186,7 +194,7 @@ def memory_config_societyagent():
186
194
  "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1000, 10000)}
187
195
  },
188
196
  "work": {
189
- "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1000, 10000)}
197
+ "aoi_position": {"aoi_id": random.choice(work_locations)}
190
198
  },
191
199
  }
192
200
 
@@ -194,10 +202,11 @@ def memory_config_societyagent():
194
202
 
195
203
 
196
204
  def memory_config_firm():
205
+ global work_locations
197
206
  EXTRA_ATTRIBUTES = {
198
207
  "type": (int, economyv2.ORG_TYPE_FIRM),
199
208
  "location": {
200
- "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1000, 10000)}
209
+ "aoi_position": {"aoi_id": random.choice(work_locations)}
201
210
  },
202
211
  "price": (float, float(np.mean(agent_skills))),
203
212
  "inventory": (int, 0),
@@ -234,7 +234,8 @@ class SocietyAgent(CitizenAgent):
234
234
  await self.update_with_sim()
235
235
 
236
236
  # check last step
237
- if not await self.check_and_update_step():
237
+ ifpass = await self.check_and_update_step()
238
+ if not ifpass:
238
239
  return
239
240
 
240
241
  await self.planAndActionBlock.forward()
pycityagent/llm/llm.py CHANGED
@@ -24,7 +24,6 @@ from .utils import *
24
24
 
25
25
  os.environ["GRPC_VERBOSITY"] = "ERROR"
26
26
 
27
-
28
27
  class LLM:
29
28
  """
30
29
  大语言模型对象
@@ -46,6 +45,7 @@ class LLM:
46
45
  api_keys = [api_keys]
47
46
 
48
47
  self._aclients = []
48
+ self._client_usage = []
49
49
 
50
50
  for api_key in api_keys:
51
51
  if self.config.text["request_type"] == "openai":
@@ -69,6 +69,11 @@ class LLM:
69
69
  f"Unsupported `request_type` {self.config.text['request_type']}!"
70
70
  )
71
71
  self._aclients.append(client)
72
+ self._client_usage.append({
73
+ "prompt_tokens": 0,
74
+ "completion_tokens": 0,
75
+ "request_number": 0
76
+ })
72
77
 
73
78
  def set_semaphore(self, number_of_coroutine: int):
74
79
  self.semaphore = asyncio.Semaphore(number_of_coroutine)
@@ -81,45 +86,62 @@ class LLM:
81
86
  clear the storage of used tokens to start a new log message
82
87
  Only support OpenAI category API right now, including OpenAI, Deepseek
83
88
  """
84
- self.prompt_tokens_used = 0
85
- self.completion_tokens_used = 0
86
- self.request_number = 0
89
+ for usage in self._client_usage:
90
+ usage["prompt_tokens"] = 0
91
+ usage["completion_tokens"] = 0
92
+ usage["request_number"] = 0
93
+
94
+ def get_consumption(self):
95
+ consumption = {}
96
+ for i, usage in enumerate(self._client_usage):
97
+ consumption[f"api-key-{i+1}"] = {
98
+ "total_tokens": usage["prompt_tokens"] + usage["completion_tokens"],
99
+ "request_number": usage["request_number"]
100
+ }
101
+ return consumption
87
102
 
88
103
  def show_consumption(
89
104
  self, input_price: Optional[float] = None, output_price: Optional[float] = None
90
105
  ):
91
106
  """
92
- if you give the input and output price of using model, this function will also calculate the consumption for you
107
+ Show consumption for each API key separately
93
108
  """
94
- total_token = self.prompt_tokens_used + self.completion_tokens_used
95
- if self.completion_tokens_used != 0:
96
- rate = self.prompt_tokens_used / self.completion_tokens_used
97
- else:
98
- rate = "nan"
99
- if self.request_number != 0:
100
- TcA = total_token / self.request_number
101
- else:
102
- TcA = "nan"
103
- out = f"""Request Number: {self.request_number}
104
- Token Usage:
105
- - Total tokens: {total_token}
106
- - Prompt tokens: {self.prompt_tokens_used}
107
- - Completion tokens: {self.completion_tokens_used}
108
- - Token per request: {TcA}
109
- - Prompt:Completion ratio: {rate}:1"""
110
- if input_price != None and output_price != None:
111
- consumption = (
112
- self.prompt_tokens_used / 1000000 * input_price
113
- + self.completion_tokens_used / 1000000 * output_price
114
- )
115
- out += f"\n - Cost Estimation: {consumption}"
116
- print(out)
117
- return {
118
- "total": total_token,
119
- "prompt": self.prompt_tokens_used,
120
- "completion": self.completion_tokens_used,
121
- "ratio": rate,
109
+ total_stats = {
110
+ "total": 0,
111
+ "prompt": 0,
112
+ "completion": 0,
113
+ "requests": 0
122
114
  }
115
+
116
+ for i, usage in enumerate(self._client_usage):
117
+ prompt_tokens = usage["prompt_tokens"]
118
+ completion_tokens = usage["completion_tokens"]
119
+ requests = usage["request_number"]
120
+ total_tokens = prompt_tokens + completion_tokens
121
+
122
+ total_stats["total"] += total_tokens
123
+ total_stats["prompt"] += prompt_tokens
124
+ total_stats["completion"] += completion_tokens
125
+ total_stats["requests"] += requests
126
+
127
+ rate = prompt_tokens / completion_tokens if completion_tokens != 0 else "nan"
128
+ tokens_per_request = total_tokens / requests if requests != 0 else "nan"
129
+
130
+ print(f"\nAPI Key #{i+1}:")
131
+ print(f"Request Number: {requests}")
132
+ print("Token Usage:")
133
+ print(f" - Total tokens: {total_tokens}")
134
+ print(f" - Prompt tokens: {prompt_tokens}")
135
+ print(f" - Completion tokens: {completion_tokens}")
136
+ print(f" - Token per request: {tokens_per_request}")
137
+ print(f" - Prompt:Completion ratio: {rate}:1")
138
+
139
+ if input_price is not None and output_price is not None:
140
+ consumption = (prompt_tokens / 1000000 * input_price +
141
+ completion_tokens / 1000000 * output_price)
142
+ print(f" - Cost Estimation: {consumption}")
143
+
144
+ return total_stats
123
145
 
124
146
  def _get_next_client(self):
125
147
  """获取下一个要使用的客户端"""
@@ -168,9 +190,9 @@ Token Usage:
168
190
  tools=tools,
169
191
  tool_choice=tool_choice,
170
192
  ) # type: ignore
171
- self.prompt_tokens_used += response.usage.prompt_tokens # type: ignore
172
- self.completion_tokens_used += response.usage.completion_tokens # type: ignore
173
- self.request_number += 1
193
+ self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
194
+ self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
195
+ self._client_usage[self._current_client_index]["request_number"] += 1
174
196
  if tools and response.choices[0].message.tool_calls:
175
197
  return json.loads(
176
198
  response.choices[0]
@@ -193,9 +215,9 @@ Token Usage:
193
215
  tools=tools,
194
216
  tool_choice=tool_choice,
195
217
  ) # type: ignore
196
- self.prompt_tokens_used += response.usage.prompt_tokens # type: ignore
197
- self.completion_tokens_used += response.usage.completion_tokens # type: ignore
198
- self.request_number += 1
218
+ self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
219
+ self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
220
+ self._client_usage[self._current_client_index]["request_number"] += 1
199
221
  if tools and response.choices[0].message.tool_calls:
200
222
  return json.loads(
201
223
  response.choices[0]
@@ -248,9 +270,9 @@ Token Usage:
248
270
  if task_status != "SUCCESS":
249
271
  raise Exception(f"Task failed with status: {task_status}")
250
272
 
251
- self.prompt_tokens_used += result_response.usage.prompt_tokens # type: ignore
252
- self.completion_tokens_used += result_response.usage.completion_tokens # type: ignore
253
- self.request_number += 1
273
+ self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
274
+ self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
275
+ self._client_usage[self._current_client_index]["request_number"] += 1
254
276
  if tools and result_response.choices[0].message.tool_calls: # type: ignore
255
277
  return json.loads(
256
278
  result_response.choices[0] # type: ignore
@@ -330,7 +330,7 @@ class StreamMemory:
330
330
  # 找到所有对应的记忆
331
331
  target_memories = []
332
332
  for memory in self._memories:
333
- if id(memory) in memory_ids:
333
+ if memory.id in memory_ids:
334
334
  target_memories.append(memory)
335
335
 
336
336
  if not target_memories:
@@ -186,7 +186,7 @@ class AgentGroup:
186
186
  self.message_dispatch_task.cancel() # type: ignore
187
187
  await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
188
188
 
189
- async def insert_agents(self):
189
+ async def insert_agent(self):
190
190
  bind_tasks = []
191
191
  for agent in self.agents:
192
192
  bind_tasks.append(agent.bind_to_simulator()) # type: ignore
@@ -200,7 +200,7 @@ class AgentGroup:
200
200
  if day == 0:
201
201
  break
202
202
  await asyncio.sleep(1)
203
- await self.insert_agents()
203
+ await self.insert_agent()
204
204
  self.id2agent = {agent._uuid: agent for agent in self.agents}
205
205
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
206
206
  assert self.messager is not None
@@ -705,6 +705,9 @@ class AgentGroup:
705
705
  )
706
706
  )
707
707
 
708
+ def get_llm_consumption(self):
709
+ return self.llm.get_consumption()
710
+
708
711
  async def step(self):
709
712
  try:
710
713
  tasks = [agent.run() for agent in self.agents]
@@ -14,9 +14,10 @@ from langchain_core.embeddings import Embeddings
14
14
 
15
15
  from ..agent import Agent, InstitutionAgent
16
16
  from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
17
- SocietyAgent, memory_config_bank, memory_config_firm,
17
+ SocietyAgent)
18
+ from ..cityagent.memory_config import (memory_config_bank, memory_config_firm,
18
19
  memory_config_government, memory_config_nbs,
19
- memory_config_societyagent)
20
+ memory_config_societyagent, memory_config_init)
20
21
  from ..cityagent.initial import bind_agent_info, initialize_social_network
21
22
  from ..cityagent.message_intercept import (EdgeMessageBlock,
22
23
  MessageBlockListener,
@@ -35,16 +36,15 @@ from .storage.pg import PgWriter, create_pg_tables
35
36
 
36
37
  logger = logging.getLogger("pycityagent")
37
38
 
38
-
39
39
  class AgentSimulation:
40
- """城市智能体模拟器"""
40
+ """Agent Simulation"""
41
41
 
42
42
  def __init__(
43
43
  self,
44
44
  config: dict,
45
45
  agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
46
46
  agent_config_file: Optional[dict] = None,
47
- metric_extractor: Optional[list[tuple[int, Callable]]] = None,
47
+ metric_extractors: Optional[list[tuple[int, Callable]]] = None,
48
48
  enable_institution: bool = True,
49
49
  agent_prefix: str = "agent_",
50
50
  exp_name: str = "default_experiment",
@@ -52,10 +52,14 @@ class AgentSimulation:
52
52
  ):
53
53
  """
54
54
  Args:
55
- agent_class: 智能体类
56
- config: 配置
57
- agent_prefix: 智能体名称前缀
58
- exp_name: 实验名称
55
+ config: Configuration
56
+ agent_class: Agent class
57
+ agent_config_file: Agent configuration file
58
+ metric_extractors: Metric extractor
59
+ enable_institution: Whether to enable institution
60
+ agent_prefix: Agent name prefix
61
+ exp_name: Experiment name
62
+ logging_level: Logging level
59
63
  """
60
64
  self.exp_id = str(uuid.uuid4())
61
65
  if isinstance(agent_class, list):
@@ -166,11 +170,12 @@ class AgentSimulation:
166
170
  experiment_name=exp_name,
167
171
  run_id=mlflow_run_id,
168
172
  )
169
- self.metric_extractor = metric_extractor
173
+ if metric_extractors is not None:
174
+ self.metric_extractors = metric_extractors
170
175
  else:
171
176
  logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
172
177
  self.mlflow_client = None
173
- self.metric_extractor = None
178
+ self.metric_extractors = None
174
179
 
175
180
  # pg
176
181
  _pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
@@ -216,8 +221,9 @@ class AgentSimulation:
216
221
  - enable_institution: bool, default is True
217
222
  - agent_config:
218
223
  - agent_config_file: Optional[dict[type[Agent], str]]
224
+ - memory_config_init_func: Optional[Callable]
219
225
  - memory_config_func: Optional[dict[type[Agent], Callable]]
220
- - metric_extractor: Optional[list[tuple[int, Callable]]]
226
+ - metric_extractors: Optional[list[tuple[int, Callable]]]
221
227
  - init_func: Optional[list[Callable[AgentSimulation, None]]]
222
228
  - group_size: Optional[int]
223
229
  - embedding_model: Optional[EmbeddingModel]
@@ -258,7 +264,7 @@ class AgentSimulation:
258
264
  simulation = cls(
259
265
  config=simulation_config,
260
266
  agent_config_file=config["agent_config"].get("agent_config_file", None),
261
- metric_extractor=config["agent_config"].get("metric_extractor", None),
267
+ metric_extractors=config["agent_config"].get("metric_extractors", None),
262
268
  enable_institution=config.get("enable_institution", True),
263
269
  exp_name=config.get("exp_name", "default_experiment"),
264
270
  logging_level=config.get("logging_level", logging.WARNING),
@@ -274,12 +280,16 @@ class AgentSimulation:
274
280
  )
275
281
  simulation._simulator.set_environment(environment)
276
282
  logger.info("Initializing Agents...")
277
- agent_count = []
278
- agent_count.append(config["agent_config"]["number_of_citizen"])
279
- agent_count.append(config["agent_config"]["number_of_firm"])
280
- agent_count.append(config["agent_config"]["number_of_government"])
281
- agent_count.append(config["agent_config"]["number_of_bank"])
282
- agent_count.append(config["agent_config"]["number_of_nbs"])
283
+ agent_count = {
284
+ SocietyAgent: config["agent_config"].get("number_of_citizen", 0),
285
+ FirmAgent: config["agent_config"].get("number_of_firm", 0),
286
+ GovernmentAgent: config["agent_config"].get("number_of_government", 0),
287
+ BankAgent: config["agent_config"].get("number_of_bank", 0),
288
+ NBSAgent: config["agent_config"].get("number_of_nbs", 0),
289
+ }
290
+ if agent_count.get(SocietyAgent, 0) == 0:
291
+ raise ValueError("number_of_citizen is required")
292
+
283
293
  # support MessageInterceptor
284
294
  if "message_intercept" in config:
285
295
  _intercept_config = config["message_intercept"]
@@ -318,7 +328,8 @@ class AgentSimulation:
318
328
  embedding_model=config["agent_config"].get(
319
329
  "embedding_model", SimpleEmbedding()
320
330
  ),
321
- memory_config_func=config["agent_config"].get("memory_config_func", None),
331
+ memory_config_func=config["agent_config"].get("memory_config_func", None),
332
+ memory_config_init_func=config["agent_config"].get("memory_config_init_func", None),
322
333
  **_message_intercept_kwargs,
323
334
  environment=environment,
324
335
  )
@@ -465,7 +476,7 @@ class AgentSimulation:
465
476
 
466
477
  async def init_agents(
467
478
  self,
468
- agent_count: Union[int, list[int]],
479
+ agent_count: dict[type[Agent], int],
469
480
  group_size: int = 10000,
470
481
  pg_sql_writers: int = 32,
471
482
  message_interceptors: int = 1,
@@ -473,6 +484,7 @@ class AgentSimulation:
473
484
  social_black_list: Optional[list[tuple[str, str]]] = None,
474
485
  message_listener: Optional[MessageBlockListenerBase] = None,
475
486
  embedding_model: Embeddings = SimpleEmbedding(),
487
+ memory_config_init_func: Optional[Callable] = None,
476
488
  memory_config_func: Optional[dict[type[Agent], Callable]] = None,
477
489
  environment: Optional[dict[str, str]] = None,
478
490
  ) -> None:
@@ -486,12 +498,13 @@ class AgentSimulation:
486
498
  memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 每个元素表示一个智能体类创建的Memory配置函数
487
499
  environment: 环境变量,用于更新模拟器的环境变量
488
500
  """
489
- if not isinstance(agent_count, list):
490
- agent_count = [agent_count]
501
+ self.agent_count = agent_count
491
502
 
492
503
  if len(self.agent_class) != len(agent_count):
493
- raise ValueError("agent_classagent_count的长度不一致")
504
+ raise ValueError("The length of agent_class and agent_count does not match")
494
505
 
506
+ if memory_config_init_func is not None:
507
+ await memory_config_init(self)
495
508
  if memory_config_func is None:
496
509
  memory_config_func = self.default_memory_config_func # type:ignore
497
510
 
@@ -503,9 +516,11 @@ class AgentSimulation:
503
516
  citizen_params = []
504
517
 
505
518
  # 收集所有参数
519
+ print(self.agent_class)
520
+ print(agent_count)
506
521
  for i in range(len(self.agent_class)):
507
522
  agent_class = self.agent_class[i]
508
- agent_count_i = agent_count[i]
523
+ agent_count_i = agent_count[agent_class]
509
524
  assert memory_config_func is not None
510
525
  memory_config_func_i = memory_config_func.get(
511
526
  agent_class, self.default_memory_config_func[agent_class] # type:ignore
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pycityagent
3
- Version: 2.0.0a65
3
+ Version: 2.0.0a66
4
4
  Summary: LLM-based city environment agent building library
5
5
  Author-email: Yuwei Yan <pinkgranite86@gmail.com>, Junbo Yan <yanjb20thu@gmali.com>, Jun Zhang <zhangjun990222@gmali.com>
6
6
  License: MIT License
@@ -1,3 +1,9 @@
1
+ pycityagent-2.0.0a66.dist-info/RECORD,,
2
+ pycityagent-2.0.0a66.dist-info/LICENSE,sha256=n2HPXiupinpyHMnIkbCf3OTYd3KMqbmldu1e7av0CAU,1084
3
+ pycityagent-2.0.0a66.dist-info/WHEEL,sha256=NW1RskY9zow1Y68W-gXg0oZyBRAugI1JHywIzAIai5o,109
4
+ pycityagent-2.0.0a66.dist-info/entry_points.txt,sha256=BZcne49AAIFv-hawxGnPbblea7X3MtAtoPyDX8L4OC4,132
5
+ pycityagent-2.0.0a66.dist-info/top_level.txt,sha256=yOmeu6cSXmiUtScu53a3s0p7BGtLMaV0aff83EHCTic,43
6
+ pycityagent-2.0.0a66.dist-info/METADATA,sha256=Gd-mcYNBdEIpmg-DYNadjFxVQicXesWi6GK71M85Gfs,9110
1
7
  pycityagent/pycityagent-sim,sha256=vskCJGHJEh0B2dUfmYlVyrcy3sDZ3kBNwjqcYUZpmO8,35449490
2
8
  pycityagent/__init__.py,sha256=PUKWTXc-xdMG7px8oTNclodsILUgypANj2Z647sY63k,808
3
9
  pycityagent/pycityagent-ui,sha256=K2XXJhxIoIk4QWty5i-0FuzZJekkFlbeqrJgPX3tbdE,41225346
@@ -10,10 +16,10 @@ pycityagent/tools/__init__.py,sha256=XtdtGyWeFyK1YOUvWkykBWxemtmwQjWUIuuyU1-gosQ
10
16
  pycityagent/tools/tool.py,sha256=D-ESFlX7EESm5mcvs2zRlGEQTzXbVfQc8G7Vpz8TmAw,8651
11
17
  pycityagent/llm/llmconfig.py,sha256=4Ylf4OFSBEFy8jrOneeX0HvPhWEaF5jGvy1HkXK08Ro,436
12
18
  pycityagent/llm/__init__.py,sha256=iWs6FLgrbRVIiqOf4ILS89gkVCTvS7HFC3vG-MWuyko,205
13
- pycityagent/llm/llm.py,sha256=4AJlTj9llT909gzx9SfcBvbUJrWqHzEo3DaHWgQ3-3I,15852
19
+ pycityagent/llm/llm.py,sha256=upm246fUurltONXies4I4oiZ7NkG0Xk8nj4xVRTNV9o,17215
14
20
  pycityagent/llm/embeddings.py,sha256=2_P4TWm3sJKFdGDx2Q1a2AEapFopDctIXsGuntvmP6E,6816
15
21
  pycityagent/llm/utils.py,sha256=hoNPhvomb1u6lhFX0GctFipw74hVKb7bvUBDqwBzBYw,160
16
- pycityagent/memory/memory.py,sha256=MusbnD-Us5IF16CkYRSlcer-TEbaPsHxQEYzcugo0N4,34589
22
+ pycityagent/memory/memory.py,sha256=-o_W0uC_cOPyWPxdXiCLmU9XiTHF82DQYb53RQyYgeg,34588
17
23
  pycityagent/memory/profile.py,sha256=q8ZS9IBmHCg_X1GONUvXK85P6tCepTKQgXKuvuXYNXw,5203
18
24
  pycityagent/memory/__init__.py,sha256=_Vfdo1HcLWsuuz34_i8e91nnLVYADpMlHHSVaB3xgIk,297
19
25
  pycityagent/memory/memory_base.py,sha256=QG_j3BxZvkadFEeE3uBR_kjl_xcXD1aHUVs8GEF3d6w,5654
@@ -22,9 +28,9 @@ pycityagent/memory/utils.py,sha256=oJWLdPeJy_jcdKcDTo9JAH9kDZhqjoQhhv_zT9qWC0w,8
22
28
  pycityagent/memory/const.py,sha256=6zpJPJXWoH9-yf4RARYYff586agCoud9BRn7sPERB1g,932
23
29
  pycityagent/memory/faiss_query.py,sha256=V3rIw6d1_xcpNqZBbAYz3qfjVNE7NfJ7xOS5SibPtVU,13180
24
30
  pycityagent/memory/state.py,sha256=TYItiyDtehMEQaSBN7PpNrnNxdDM5jGppr9R9Ufv3kA,5134
25
- pycityagent/simulation/simulation.py,sha256=smWCN3qt6i0C-mWd26WM5yVLsfUAYZIgNsHF0fhYAhM,36949
31
+ pycityagent/simulation/simulation.py,sha256=QRD994g4fRzDTbuCEmYGFdyFxix5a5XnXikgy4L19oc,37765
26
32
  pycityagent/simulation/__init__.py,sha256=P5czbcg2d8S0nbbnsQXFIhwzO4CennAhZM8OmKvAeYw,194
27
- pycityagent/simulation/agentgroup.py,sha256=PBmIp6waJR40Gjki9tTXtKMFJ_c5s8SIsi98FHRg-34,31598
33
+ pycityagent/simulation/agentgroup.py,sha256=BpUWmN_CWR7-PJhq9e4pSsVjXZWcoW3Wt-gh5MIxQjQ,31674
28
34
  pycityagent/simulation/storage/pg.py,sha256=xRshSOGttW-p0re0fNBOjOpb-nQ5msIE2LsdT79_E_Y,8425
29
35
  pycityagent/message/message_interceptor.py,sha256=w8XTyZStQtMjILpeAX3VMhAWcYAuaxCgSMwXQU1OryM,8951
30
36
  pycityagent/message/__init__.py,sha256=f5QH7DKPqEAMyfSlBMnl3uouOKlsoel909STlIe7nUk,276
@@ -75,31 +81,25 @@ pycityagent/environment/sim/clock_service.py,sha256=gBUujvX_vIFMKVfcLRyk1GcpRRL6
75
81
  pycityagent/environment/sim/road_service.py,sha256=bKyn3_me0sGmaJVyF6eNeFbdU-9C1yWsa9L7pieDJzg,1285
76
82
  pycityagent/environment/interact/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
83
  pycityagent/environment/interact/interact.py,sha256=ifxPPzuHeqLHIZ_6zvfXMoBOnBsXNIP4bYp7OJ7pnEQ,6588
78
- pycityagent/cityagent/memory_config.py,sha256=8PqSeCdIgkp5a9KN2VdyaraMACu93d5W9TLzYTozmfo,10977
84
+ pycityagent/cityagent/memory_config.py,sha256=nBUkLbxmiNtw4J4s34n4brgyFXnhpNYvQxws0c2n-ho,11356
79
85
  pycityagent/cityagent/bankagent.py,sha256=lr4GEcqt-iwA7DXoDry0WXkV6benmdaAyLpswqSpKlY,2120
80
86
  pycityagent/cityagent/__init__.py,sha256=gcBQ-a50XegFtjigQ7xDXRBZrywBKqifiQFSRnEF8gM,572
81
87
  pycityagent/cityagent/firmagent.py,sha256=UVlNN0lpa4cC4PZVqYzQhbc5VJ2oGsA1731mhbCjnR8,4109
82
88
  pycityagent/cityagent/nbsagent.py,sha256=WIXW__6dZ5IrqBqDCjvGbrCshpXzuFRV3Ww6gkYw7p4,4387
83
- pycityagent/cityagent/initial.py,sha256=BXU9ndY-yIASwoRvFosbhUptGukd68PM1h1vmTEaa34,4946
84
- pycityagent/cityagent/societyagent.py,sha256=RwRS7XaPNFmlTCRJgIpTNkDCWQCG48lpf2SjnZvtcHQ,20288
89
+ pycityagent/cityagent/initial.py,sha256=k9iolPtDj5KiOvg_FaDWPpqhLCpP19T8gmGTKLISa40,4694
90
+ pycityagent/cityagent/societyagent.py,sha256=REGEBWGT7ScKVS9o-x8iJMWSoAPXrMRpBZhAxxYuxS4,20312
85
91
  pycityagent/cityagent/message_intercept.py,sha256=1YMOs6-6bbAaTt7RfMn-ALVIcp0frHN7oqGUkWRy5xE,4519
86
92
  pycityagent/cityagent/governmentagent.py,sha256=HJLuhvEmllu_1KnFEJsYCIasaBJT0BV9Cn_4Y2QGPqg,2791
87
93
  pycityagent/cityagent/blocks/dispatcher.py,sha256=mEa1r3tRS3KI1BMZR_w_sbUGzOj6aUJuiUrsHv1n2n0,2943
88
- pycityagent/cityagent/blocks/needs_block.py,sha256=s8LikgtKORfo_Sw9SQ5_3biNPTof15QuUs4cDynXCyM,15332
94
+ pycityagent/cityagent/blocks/needs_block.py,sha256=ZAf1cQq1N73YBBPEejYF2vAfEkXXTAVUcvqHIy-9Rhs,14935
89
95
  pycityagent/cityagent/blocks/cognition_block.py,sha256=zDbyyLh5GEqje9INJUJA1gMSDPW0wX5lt6yNu97XXn0,14818
90
- pycityagent/cityagent/blocks/social_block.py,sha256=y46mPK9SLvcOHYB64l6qz5ZgT0dWNmX-C3Cusj_44mE,15540
96
+ pycityagent/cityagent/blocks/social_block.py,sha256=Dust9Tpu145h24gbJzEyy_a-IzbnZ0KtTNwTvteVb6w,15138
91
97
  pycityagent/cityagent/blocks/__init__.py,sha256=wydR0s-cCRWgdvQetkfQnD_PU8vC3eTmt2zntcb4fSA,452
92
98
  pycityagent/cityagent/blocks/economy_block.py,sha256=m5B67cgGZ9nKWtrYeak5gxMoCoKlRbATAsXpFajYKyg,19129
93
99
  pycityagent/cityagent/blocks/utils.py,sha256=8O5p1B8JlreIJTGXKAP03rTcn7MvFSR8qJ1_hhszboU,2065
94
100
  pycityagent/cityagent/blocks/other_block.py,sha256=NnDwxQAO5XZ7Uxe-n3qtrfNItHlwFYk2MQsh2GYDKMQ,4338
95
101
  pycityagent/cityagent/blocks/plan_block.py,sha256=v04ePs-6b86TyaP3fl9HPMwWh3_lHp4cjyoEQBHkoDU,11280
96
- pycityagent/cityagent/blocks/mobility_block.py,sha256=f9PlHYX_sCpfOgIhxx3DSA9aViW_e-yNIFpVGwZixUI,12623
102
+ pycityagent/cityagent/blocks/mobility_block.py,sha256=rFRqyVnZG-BMXU0VNAkyrpHFYLv0_QUv4o8vI1wdG5A,15042
97
103
  pycityagent/survey/models.py,sha256=YE50UUt5qJ0O_lIUsSY6XFCGUTkJVNu_L1gAhaCJ2fs,3546
98
104
  pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
99
105
  pycityagent/survey/manager.py,sha256=S5IkwTdelsdtZETChRcfCEczzwSrry_Fly9MY4s3rbk,1681
100
- pycityagent-2.0.0a65.dist-info/RECORD,,
101
- pycityagent-2.0.0a65.dist-info/LICENSE,sha256=n2HPXiupinpyHMnIkbCf3OTYd3KMqbmldu1e7av0CAU,1084
102
- pycityagent-2.0.0a65.dist-info/WHEEL,sha256=NW1RskY9zow1Y68W-gXg0oZyBRAugI1JHywIzAIai5o,109
103
- pycityagent-2.0.0a65.dist-info/entry_points.txt,sha256=BZcne49AAIFv-hawxGnPbblea7X3MtAtoPyDX8L4OC4,132
104
- pycityagent-2.0.0a65.dist-info/top_level.txt,sha256=yOmeu6cSXmiUtScu53a3s0p7BGtLMaV0aff83EHCTic,43
105
- pycityagent-2.0.0a65.dist-info/METADATA,sha256=4zdmqXA1BgOR40yYPnJ1ZC3S3A13TvDy6FChU3-Q9JI,9110