pycityagent 2.0.0a93__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a95__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. pycityagent/agent/agent.py +5 -5
  2. pycityagent/agent/agent_base.py +1 -2
  3. pycityagent/cityagent/__init__.py +6 -5
  4. pycityagent/cityagent/bankagent.py +2 -2
  5. pycityagent/cityagent/blocks/__init__.py +4 -4
  6. pycityagent/cityagent/blocks/cognition_block.py +7 -4
  7. pycityagent/cityagent/blocks/economy_block.py +227 -135
  8. pycityagent/cityagent/blocks/mobility_block.py +70 -27
  9. pycityagent/cityagent/blocks/needs_block.py +11 -12
  10. pycityagent/cityagent/blocks/other_block.py +2 -2
  11. pycityagent/cityagent/blocks/plan_block.py +22 -24
  12. pycityagent/cityagent/blocks/social_block.py +15 -17
  13. pycityagent/cityagent/blocks/utils.py +3 -2
  14. pycityagent/cityagent/firmagent.py +1 -1
  15. pycityagent/cityagent/governmentagent.py +1 -1
  16. pycityagent/cityagent/initial.py +1 -1
  17. pycityagent/cityagent/memory_config.py +0 -1
  18. pycityagent/cityagent/message_intercept.py +7 -8
  19. pycityagent/cityagent/nbsagent.py +1 -1
  20. pycityagent/cityagent/societyagent.py +1 -2
  21. pycityagent/configs/__init__.py +18 -0
  22. pycityagent/configs/exp_config.py +202 -0
  23. pycityagent/configs/sim_config.py +251 -0
  24. pycityagent/configs/utils.py +17 -0
  25. pycityagent/environment/__init__.py +2 -0
  26. pycityagent/{economy → environment/economy}/econ_client.py +14 -32
  27. pycityagent/environment/sim/sim_env.py +17 -24
  28. pycityagent/environment/simulator.py +36 -113
  29. pycityagent/llm/__init__.py +1 -2
  30. pycityagent/llm/llm.py +60 -166
  31. pycityagent/memory/memory.py +13 -12
  32. pycityagent/message/message_interceptor.py +5 -4
  33. pycityagent/message/messager.py +3 -5
  34. pycityagent/metrics/__init__.py +1 -1
  35. pycityagent/metrics/mlflow_client.py +20 -17
  36. pycityagent/pycityagent-sim +0 -0
  37. pycityagent/simulation/agentgroup.py +17 -19
  38. pycityagent/simulation/simulation.py +157 -210
  39. pycityagent/survey/manager.py +0 -2
  40. pycityagent/utils/__init__.py +3 -0
  41. pycityagent/utils/config_const.py +20 -0
  42. pycityagent/workflow/__init__.py +1 -2
  43. pycityagent/workflow/block.py +0 -3
  44. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/METADATA +7 -24
  45. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/RECORD +50 -46
  46. pycityagent/llm/llmconfig.py +0 -18
  47. /pycityagent/{economy → environment/economy}/__init__.py +0 -0
  48. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/LICENSE +0 -0
  49. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/WHEEL +0 -0
  50. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/entry_points.txt +0 -0
  51. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,202 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from collections.abc import Callable
5
+ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+ from ..utils import WorkflowType
10
+
11
+ if TYPE_CHECKING:
12
+ from ..simulation import AgentSimulation
13
+
14
+
15
+ class WorkflowStep(BaseModel):
16
+ type: WorkflowType
17
+ days: int = Field(1, description="Number of simulation days")
18
+ times: int = Field(1, description="Step Execution Times")
19
+ description: str = Field("no description")
20
+ func: Optional[Callable] = None
21
+
22
+
23
+ class AgentConfig(BaseModel):
24
+
25
+ number_of_citizen: int = Field(1, description="Number of citizens")
26
+ number_of_firm: int = Field(1, description="Number of firms")
27
+ number_of_government: int = Field(1, description="Number of governments")
28
+ number_of_bank: int = Field(1, description="Number of banks")
29
+ number_of_nbs: int = Field(1, description="Number of neighborhood-based services")
30
+ group_size: int = Field(100, description="Size of agent groups")
31
+ embedding_model: Any = Field(None, description="Embedding model")
32
+ agent_class_configs: Optional[dict[Any, dict[str, list[dict]]]] = None
33
+ memory_config_func: Optional[dict[type["Any"], Callable]] = None
34
+ memory_config_init_func: Optional[Callable] = Field(None)
35
+ init_func: Optional[list[Callable[["AgentSimulation"], None]]] = None
36
+ enable_institution: bool = Field(
37
+ True, description="Whether institutions are enabled in the experiment"
38
+ )
39
+
40
+ @classmethod
41
+ def create(
42
+ cls,
43
+ number_of_citizen: int = 1,
44
+ number_of_firm: int = 1,
45
+ number_of_government: int = 1,
46
+ number_of_bank: int = 1,
47
+ number_of_nbs: int = 1,
48
+ group_size: int = 100,
49
+ embedding_model: Any = None,
50
+ agent_class_configs: Optional[dict[Any, dict[str, list[dict]]]] = None,
51
+ enable_institution: bool = True,
52
+ memory_config_func: Optional[dict[type["Any"], Callable]] = None,
53
+ memory_config_init_func: Optional[Callable] = None,
54
+ init_func: Optional[list[Callable[["AgentSimulation"], None]]] = None,
55
+ ) -> "AgentConfig":
56
+ return cls(
57
+ number_of_citizen=number_of_citizen,
58
+ number_of_firm=number_of_firm,
59
+ number_of_government=number_of_government,
60
+ number_of_bank=number_of_bank,
61
+ number_of_nbs=number_of_nbs,
62
+ group_size=group_size,
63
+ embedding_model=embedding_model,
64
+ agent_class_configs=agent_class_configs,
65
+ enable_institution=enable_institution,
66
+ memory_config_func=memory_config_func,
67
+ memory_config_init_func=memory_config_init_func,
68
+ init_func=init_func,
69
+ )
70
+
71
+
72
+ class EnvironmentConfig(BaseModel):
73
+ weather: str = Field(default="The weather is normal")
74
+ crime: str = Field(default="The crime rate is low")
75
+ pollution: str = Field(default="The pollution level is low")
76
+ temperature: str = Field(default="The temperature is normal")
77
+ day: str = Field(default="Workday")
78
+
79
+ @classmethod
80
+ def create(
81
+ cls,
82
+ weather: str = "The weather is normal",
83
+ crime: str = "The crime rate is low",
84
+ pollution: str = "The pollution level is low",
85
+ temperature: str = "The temperature is normal",
86
+ day: str = "Workday",
87
+ ) -> "EnvironmentConfig":
88
+ return cls(
89
+ weather=weather,
90
+ crime=crime,
91
+ pollution=pollution,
92
+ temperature=temperature,
93
+ day=day,
94
+ )
95
+
96
+
97
+ class MessageInterceptConfig(BaseModel):
98
+ mode: str = Field(..., pattern="^(point|edge)$")
99
+ max_violation_time: int = Field(default=3)
100
+
101
+ @classmethod
102
+ def create(cls, mode: str, max_violation_time: int = 3) -> "MessageInterceptConfig":
103
+ return cls(mode=mode, max_violation_time=max_violation_time)
104
+
105
+
106
+ class ExpConfig(BaseModel):
107
+ agent_config: Optional[AgentConfig] = None
108
+ workflow: Optional[list[WorkflowStep]] = None
109
+ environment: Optional[EnvironmentConfig] = EnvironmentConfig()
110
+ message_intercept: Optional[MessageInterceptConfig] = None
111
+ metric_extractors: Optional[list[tuple[int, Callable]]] = None
112
+ logging_level: int = Field(logging.WARNING)
113
+ exp_name: str = Field("default_experiment")
114
+ llm_semaphore: int = Field(200)
115
+
116
+ @property
117
+ def prop_agent_config(self) -> AgentConfig:
118
+ return self.agent_config # type:ignore
119
+
120
+ @property
121
+ def prop_workflow(self) -> list[WorkflowStep]:
122
+ return self.workflow # type:ignore
123
+
124
+ @property
125
+ def prop_environment(self) -> EnvironmentConfig:
126
+ return self.environment # type:ignore
127
+
128
+ @property
129
+ def prop_message_intercept(self) -> MessageInterceptConfig:
130
+ return self.message_intercept # type:ignore
131
+
132
+ @property
133
+ def prop_metric_extractors(
134
+ self,
135
+ ) -> list[tuple[int, Callable]]:
136
+ return self.metric_extractors # type:ignore
137
+
138
+ def SetAgentConfig(
139
+ self,
140
+ number_of_citizen: int = 1,
141
+ number_of_firm: int = 1,
142
+ number_of_government: int = 1,
143
+ number_of_bank: int = 1,
144
+ number_of_nbs: int = 1,
145
+ group_size: int = 100,
146
+ embedding_model: Any = None,
147
+ agent_class_configs: Optional[dict[Any, dict[str, list[dict]]]] = None,
148
+ enable_institution: bool = True,
149
+ memory_config_func: Optional[dict[type["Any"], Callable]] = None,
150
+ memory_config_init_func: Optional[Callable] = None,
151
+ init_func: Optional[list[Callable[["AgentSimulation"], None]]] = None,
152
+ ) -> "ExpConfig":
153
+ self.agent_config = AgentConfig.create(
154
+ number_of_citizen=number_of_citizen,
155
+ number_of_firm=number_of_firm,
156
+ number_of_government=number_of_government,
157
+ number_of_bank=number_of_bank,
158
+ number_of_nbs=number_of_nbs,
159
+ group_size=group_size,
160
+ embedding_model=embedding_model,
161
+ agent_class_configs=agent_class_configs,
162
+ enable_institution=enable_institution,
163
+ memory_config_func=memory_config_func,
164
+ memory_config_init_func=memory_config_init_func,
165
+ init_func=init_func,
166
+ )
167
+ return self
168
+
169
+ def SetEnvironment(
170
+ self,
171
+ weather: str = "The weather is normal",
172
+ crime: str = "The crime rate is low",
173
+ pollution: str = "The pollution level is low",
174
+ temperature: str = "The temperature is normal",
175
+ day: str = "Workday",
176
+ ) -> "ExpConfig":
177
+ self.environment = EnvironmentConfig.create(
178
+ weather=weather,
179
+ crime=crime,
180
+ pollution=pollution,
181
+ temperature=temperature,
182
+ day=day,
183
+ )
184
+ return self
185
+
186
+ def SetMessageIntercept(
187
+ self,
188
+ mode: Union[Literal["point"], Literal["edge"]],
189
+ max_violation_time: int = 3,
190
+ ) -> "ExpConfig":
191
+ self.message_intercept = MessageInterceptConfig.create(
192
+ mode=mode, max_violation_time=max_violation_time
193
+ )
194
+ return self
195
+
196
+ def SetMetricExtractors(self, metric_extractors: list[tuple[int, Callable]]):
197
+ self.metric_extractors = metric_extractors
198
+ return self
199
+
200
+ def SetWorkFlow(self, workflows: list[WorkflowStep]):
201
+ self.workflow = workflows
202
+ return self
@@ -0,0 +1,251 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from ..utils import LLMRequestType
6
+
7
+ __all__ = [
8
+ "SimConfig",
9
+ ]
10
+
11
+
12
+ class LLMRequestConfig(BaseModel):
13
+ request_type: LLMRequestType = Field(
14
+ ..., description="The type of the request or provider"
15
+ )
16
+ api_key: list[str] = Field(..., description="API key for accessing the service")
17
+ model: str = Field(..., description="The model to use")
18
+
19
+ @classmethod
20
+ def create(
21
+ cls, request_type: LLMRequestType, api_key: list[str], model: str
22
+ ) -> "LLMRequestConfig":
23
+ return cls(request_type=request_type, api_key=api_key, model=model)
24
+
25
+
26
+ class MQTTConfig(BaseModel):
27
+ server: str = Field(..., description="MQTT server address")
28
+ port: int = Field(..., description="Port number for MQTT connection")
29
+ password: Optional[str] = Field(None, description="Password for MQTT connection")
30
+ username: Optional[str] = Field(None, description="Username for MQTT connection")
31
+
32
+ @classmethod
33
+ def create(
34
+ cls,
35
+ server: str,
36
+ port: int,
37
+ username: Optional[str] = None,
38
+ password: Optional[str] = None,
39
+ ) -> "MQTTConfig":
40
+ return cls(server=server, username=username, port=port, password=password)
41
+
42
+
43
+ class SimulatorRequestConfig(BaseModel):
44
+ task_name: str = Field("citysim", description="Name of the simulation task")
45
+ max_day: int = Field(1000, description="Maximum number of days to simulate")
46
+ start_step: int = Field(28800, description="Starting step of the simulation")
47
+ total_step: int = Field(
48
+ 24 * 60 * 60 * 365, description="Total number of steps in the simulation"
49
+ )
50
+ log_dir: str = Field("./log", description="Directory path for saving logs")
51
+ min_step_time: int = Field(
52
+ 1000, description="Minimum time (in seconds) between simulation steps"
53
+ )
54
+ primary_node_ip: str = Field(
55
+ "localhost", description="Primary node IP address for distributed simulation"
56
+ )
57
+
58
+ @classmethod
59
+ def create(
60
+ cls,
61
+ task_name: str = "citysim",
62
+ max_day: int = 1000,
63
+ start_step: int = 28800,
64
+ total_step: int = 24 * 60 * 60 * 365,
65
+ log_dir: str = "./log",
66
+ min_step_time: int = 1000,
67
+ primary_node_ip: str = "localhost",
68
+ ) -> "SimulatorRequestConfig":
69
+ return cls(
70
+ task_name=task_name,
71
+ max_day=max_day,
72
+ start_step=start_step,
73
+ total_step=total_step,
74
+ log_dir=log_dir,
75
+ min_step_time=min_step_time,
76
+ primary_node_ip=primary_node_ip,
77
+ )
78
+
79
+
80
+ class MapRequestConfig(BaseModel):
81
+ file_path: str = Field(..., description="Path to the map file")
82
+
83
+ @classmethod
84
+ def create(cls, file_path: str) -> "MapRequestConfig":
85
+ return cls(file_path=file_path)
86
+
87
+
88
+ class MlflowConfig(BaseModel):
89
+ username: Optional[str] = Field(None, description="Username for MLflow")
90
+ password: Optional[str] = Field(None, description="Password for MLflow")
91
+ mlflow_uri: str = Field(..., description="URI for MLflow server")
92
+
93
+ @classmethod
94
+ def create(cls, username: str, password: str, mlflow_uri: str) -> "MlflowConfig":
95
+ return cls(username=username, password=password, mlflow_uri=mlflow_uri)
96
+
97
+
98
+ class PostgreSQLConfig(BaseModel):
99
+ enabled: Optional[bool] = Field(
100
+ True, description="Whether PostgreSQL storage is enabled"
101
+ )
102
+ dsn: str = Field(..., description="Data source name for PostgreSQL")
103
+
104
+ @classmethod
105
+ def create(cls, dsn: str, enabled: bool = False) -> "PostgreSQLConfig":
106
+ return cls(dsn=dsn, enabled=enabled)
107
+
108
+
109
+ class AvroConfig(BaseModel):
110
+ enabled: Optional[bool] = Field(False, description="Whether Avro storage is enabled")
111
+ path: str = Field(..., description="Avro file storage path")
112
+
113
+ @classmethod
114
+ def create(cls, path: Optional[str] = None, enabled: bool = False) -> "AvroConfig":
115
+ return cls(enabled=enabled, path=path)
116
+
117
+
118
+ class MetricRequest(BaseModel):
119
+ mlflow: Optional[MlflowConfig] = Field(None)
120
+
121
+
122
+ class SimStatus(BaseModel):
123
+ simulator_activated: bool = False
124
+
125
+
126
+ class SimConfig(BaseModel):
127
+ llm_request: Optional["LLMRequestConfig"] = None
128
+ simulator_request: Optional["SimulatorRequestConfig"] = None
129
+ mqtt: Optional["MQTTConfig"] = None
130
+ map_request: Optional["MapRequestConfig"] = None
131
+ metric_request: Optional[MetricRequest] = None
132
+ pgsql: Optional["PostgreSQLConfig"] = None
133
+ avro: Optional["AvroConfig"] = None
134
+ simulator_server_address: Optional[str] = None
135
+ status: Optional["SimStatus"] = SimStatus()
136
+
137
+ @property
138
+ def prop_llm_request(self) -> "LLMRequestConfig":
139
+ return self.llm_request # type:ignore
140
+
141
+ @property
142
+ def prop_status(self) -> "SimStatus":
143
+ return self.status # type:ignore
144
+
145
+ @property
146
+ def prop_simulator_request(self) -> "SimulatorRequestConfig":
147
+ return self.simulator_request # type:ignore
148
+
149
+ @property
150
+ def prop_mqtt(self) -> "MQTTConfig":
151
+ return self.mqtt # type:ignore
152
+
153
+ @property
154
+ def prop_map_request(self) -> "MapRequestConfig":
155
+ return self.map_request # type:ignore
156
+
157
+ @property
158
+ def prop_avro_config(self) -> "AvroConfig":
159
+ return self.avro # type:ignore
160
+
161
+ @property
162
+ def prop_postgre_sql_config(self) -> "PostgreSQLConfig":
163
+ return self.pgsql # type:ignore
164
+
165
+ @property
166
+ def prop_simulator_server_address(self) -> str:
167
+ return self.simulator_server_address # type:ignore
168
+
169
+ @property
170
+ def prop_metric_request(self) -> MetricRequest:
171
+ return self.metric_request # type:ignore
172
+
173
+ def SetLLMRequest(
174
+ self, request_type: LLMRequestType, api_key: list[str], model: str
175
+ ) -> "SimConfig":
176
+ self.llm_request = LLMRequestConfig.create(request_type, api_key, model)
177
+ return self
178
+
179
+ def SetSimulatorRequest(
180
+ self,
181
+ task_name: str = "citysim",
182
+ max_day: int = 1000,
183
+ start_step: int = 28800,
184
+ total_step: int = 24 * 60 * 60 * 365,
185
+ log_dir: str = "./log",
186
+ min_step_time: int = 1000,
187
+ primary_node_ip: str = "localhost",
188
+ ) -> "SimConfig":
189
+ self.simulator_request = SimulatorRequestConfig.create(
190
+ task_name=task_name,
191
+ max_day=max_day,
192
+ start_step=start_step,
193
+ total_step=total_step,
194
+ log_dir=log_dir,
195
+ min_step_time=min_step_time,
196
+ primary_node_ip=primary_node_ip,
197
+ )
198
+ return self
199
+
200
+ def SetMQTT(
201
+ self,
202
+ server: str,
203
+ port: int,
204
+ username: Optional[str] = None,
205
+ password: Optional[str] = None,
206
+ ) -> "SimConfig":
207
+ self.mqtt = MQTTConfig.create(server, port, username, password)
208
+ return self
209
+
210
+ def SetMapRequest(self, file_path: str) -> "SimConfig":
211
+ self.map_request = MapRequestConfig.create(file_path)
212
+ return self
213
+
214
+ def SetMetricRequest(
215
+ self, username: str, password: str, mlflow_uri: str
216
+ ) -> "SimConfig":
217
+ self.metric_request = MetricRequest(
218
+ mlflow=MlflowConfig.create(username, password, mlflow_uri)
219
+ )
220
+ return self
221
+
222
+ def SetAvro(self, path: str, enabled: bool = False) -> "SimConfig":
223
+ self.avro = AvroConfig.create(path, enabled)
224
+ return self
225
+
226
+ def SetPostgreSql(self, path: str, enabled: bool = False) -> "SimConfig":
227
+ self.pgsql = PostgreSQLConfig.create(path, enabled)
228
+ return self
229
+
230
+ def SetServerAddress(self, simulator_server_address: str) -> "SimConfig":
231
+ self.simulator_server_address = simulator_server_address
232
+ return self
233
+
234
+ def model_dump(self, *args, **kwargs):
235
+ exclude_fields = {
236
+ "status",
237
+ }
238
+ data = super().model_dump(*args, **kwargs)
239
+ return {k: v for k, v in data.items() if k not in exclude_fields}
240
+
241
+
242
+ if __name__ == "__main__":
243
+ config = (
244
+ SimConfig()
245
+ .SetLLMRequest("openai", "key", "model") # type:ignore
246
+ .SetMQTT("server", 1883, "username", "password")
247
+ .SetMapRequest("./path/to/map")
248
+ .SetMetricRequest("username", "password", "uri")
249
+ .SetPostgreSql("dsn", True)
250
+ )
251
+ print(config.llm_request)
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Union
4
+
5
+ import yaml
6
+
7
+ if TYPE_CHECKING:
8
+ from .exp_config import ExpConfig
9
+ from .sim_config import SimConfig
10
+
11
+
12
+ def load_config_from_file(
13
+ filepath: str, config_type: Union[type[SimConfig], type[ExpConfig]]
14
+ ) -> Union[SimConfig, ExpConfig]:
15
+ with open(filepath, "r") as file:
16
+ data = yaml.safe_load(file)
17
+ return config_type(**data)
@@ -2,9 +2,11 @@
2
2
 
3
3
  from .sim import AoiService, PersonService
4
4
  from .simulator import Simulator
5
+ from .economy import EconomyClient
5
6
 
6
7
  __all__ = [
7
8
  "Simulator",
8
9
  "PersonService",
9
10
  "AoiService",
11
+ "EconomyClient",
10
12
  ]
@@ -9,7 +9,6 @@ import grpc
9
9
  import pycityproto.city.economy.v2.economy_pb2 as economyv2
10
10
  import pycityproto.city.economy.v2.org_service_pb2 as org_service
11
11
  import pycityproto.city.economy.v2.org_service_pb2_grpc as org_grpc
12
- from google.protobuf import descriptor
13
12
  from google.protobuf.json_format import MessageToDict
14
13
 
15
14
  logger = logging.getLogger("pycityagent")
@@ -19,27 +18,6 @@ __all__ = [
19
18
  ]
20
19
 
21
20
 
22
- def _snake_to_pascal(snake_str):
23
- _res = "".join(word.capitalize() or "_" for word in snake_str.split("_"))
24
- for _word in {
25
- "Gdp",
26
- }:
27
- if _word in _res:
28
- _res = _res.replace(_word, _word.upper())
29
- return _res
30
-
31
-
32
- def camel_to_snake(d):
33
- if not isinstance(d, dict):
34
- return d
35
- return {
36
- re.sub("([a-z0-9])([A-Z])", r"\1_\2", k).lower(): (
37
- camel_to_snake(v) if isinstance(v, dict) else v
38
- )
39
- for k, v in d.items()
40
- }
41
-
42
-
43
21
  def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.Channel:
44
22
  """
45
23
  Create a gRPC asynchronous channel.
@@ -170,8 +148,9 @@ class EconomyClient:
170
148
  agents = await self._aio_stub.BatchGet(
171
149
  org_service.BatchGetRequest(ids=id, type="agent")
172
150
  )
173
- agents = MessageToDict(agents)["agents"]
174
- agent_dicts = [camel_to_snake(agent) for agent in agents]
151
+ agent_dicts = MessageToDict(agents, preserving_proto_field_name=True)[
152
+ "agents"
153
+ ]
175
154
  log["consumption"] = time.time() - start_time
176
155
  self._log_list.append(log)
177
156
  return agent_dicts
@@ -179,10 +158,12 @@ class EconomyClient:
179
158
  agent = await self._aio_stub.GetAgent(
180
159
  org_service.GetAgentRequest(agent_id=id)
181
160
  )
182
- agent_dict = MessageToDict(agent)["agent"]
161
+ agent_dict: dict = MessageToDict(agent, preserving_proto_field_name=True)[
162
+ "agent"
163
+ ]
183
164
  log["consumption"] = time.time() - start_time
184
165
  self._log_list.append(log)
185
- return camel_to_snake(agent_dict)
166
+ return agent_dict
186
167
 
187
168
  async def get_org(
188
169
  self, id: Union[list[int], int]
@@ -202,17 +183,16 @@ class EconomyClient:
202
183
  orgs = await self._aio_stub.BatchGet(
203
184
  org_service.BatchGetRequest(ids=id, type="org")
204
185
  )
205
- orgs = MessageToDict(orgs)["orgs"]
206
- org_dicts = [camel_to_snake(org) for org in orgs]
186
+ org_dicts = MessageToDict(orgs, preserving_proto_field_name=True)["orgs"]
207
187
  log["consumption"] = time.time() - start_time
208
188
  self._log_list.append(log)
209
189
  return org_dicts
210
190
  else:
211
191
  org = await self._aio_stub.GetOrg(org_service.GetOrgRequest(org_id=id))
212
- org_dict = MessageToDict(org)["org"]
192
+ org_dict = MessageToDict(org, preserving_proto_field_name=True)["org"]
213
193
  log["consumption"] = time.time() - start_time
214
194
  self._log_list.append(log)
215
- return camel_to_snake(org_dict)
195
+ return org_dict
216
196
 
217
197
  async def get(
218
198
  self,
@@ -510,7 +490,7 @@ class EconomyClient:
510
490
  return (float(response.taxes_due), list(response.updated_incomes))
511
491
 
512
492
  async def calculate_consumption(
513
- self, org_ids: Union[int, list[int]], agent_id: int, demands: list[int]
493
+ self, org_ids: Union[int, list[int]], agent_id: int, demands: list[int], consumption_accumulation: bool = True
514
494
  ):
515
495
  """
516
496
  Calculate consumption for agents based on their demands.
@@ -519,6 +499,7 @@ class EconomyClient:
519
499
  - `org_ids` (`Union[int, list[int]]`): The ID of the firm providing goods or services.
520
500
  - `agent_id` (`int`): The ID of the agent whose consumption is being calculated.
521
501
  - `demands` (`List[int]`): A list of demand quantities corresponding to each agent.
502
+ - `consumption_accumulation` (`bool`): Whether consumption is accumulated.
522
503
 
523
504
  - **Returns**:
524
505
  - `Tuple[int, List[float]]`: A tuple containing the remaining inventory and updated currencies for each agent.
@@ -535,6 +516,7 @@ class EconomyClient:
535
516
  firm_ids=org_ids,
536
517
  agent_id=agent_id,
537
518
  demands=demands,
519
+ consumption_accumulation = consumption_accumulation,
538
520
  )
539
521
  response: org_service.CalculateConsumptionResponse = (
540
522
  await self._aio_stub.CalculateConsumption(request)
@@ -707,7 +689,7 @@ class EconomyClient:
707
689
  """
708
690
  start_time = time.time()
709
691
  log = {"req": "add_delta_value", "start_time": start_time, "consumption": 0}
710
- pascal_key = _snake_to_pascal(key)
692
+ pascal_key = ''.join(x.title() for x in key.split('_'))
711
693
  _request_type = getattr(org_service, f"Add{pascal_key}Request")
712
694
  _request_func = getattr(self._aio_stub, f"Add{pascal_key}")
713
695
 
@@ -6,7 +6,7 @@ import warnings
6
6
  from subprocess import Popen
7
7
  from typing import Optional
8
8
 
9
- from pycitydata.map import Map
9
+ import yaml
10
10
 
11
11
  from ..utils import encode_to_base64, find_free_port
12
12
 
@@ -16,29 +16,22 @@ __all__ = ["ControlSimEnv"]
16
16
  def _generate_yaml_config(
17
17
  map_file: str, max_day: int, start_step: int, total_step: int
18
18
  ) -> str:
19
- map_file = os.path.abspath(map_file)
20
- return f"""
21
- input:
22
- # 地图
23
- map:
24
- file: "{map_file}"
25
-
26
- control:
27
- day: {max_day}
28
- step:
29
- start: {start_step}
30
- total: {total_step}
31
- interval: 1
32
- skip_overtime_trip_when_init: true
33
- enable_platoon: false
34
- enable_indoor: false
35
- prefer_fixed_light: true
36
- enable_collision_avoidance: false # 计算性能下降10倍,需要保证subloop>=5
37
- enable_go_astray: true # 引入串行的路径规划调用,计算性能下降(幅度不确定)
38
- lane_change_model: earliest # mobil (主动变道+强制变道,默认值) earliest (总是尽可能早地变道)
39
-
40
- output:
41
- """
19
+ config_dict = {
20
+ "input": {"map": {"file": os.path.abspath(map_file)}},
21
+ "control": {
22
+ "day": max_day,
23
+ "step": {"start": start_step, "total": total_step, "interval": 1},
24
+ "skip_overtime_trip_when_init": True,
25
+ "enable_platoon": False,
26
+ "enable_indoor": False,
27
+ "prefer_fixed_light": True,
28
+ "enable_collision_avoidance": False,
29
+ "enable_go_astray": True,
30
+ "lane_change_model": "earliest",
31
+ },
32
+ "output": None,
33
+ }
34
+ return yaml.dump(config_dict, allow_unicode=True)
42
35
 
43
36
 
44
37
  class ControlSimEnv: