pycityagent 2.0.0a94__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a96__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +5 -5
- pycityagent/agent/agent_base.py +1 -6
- pycityagent/cityagent/__init__.py +6 -5
- pycityagent/cityagent/bankagent.py +2 -2
- pycityagent/cityagent/blocks/__init__.py +4 -4
- pycityagent/cityagent/blocks/cognition_block.py +7 -4
- pycityagent/cityagent/blocks/economy_block.py +227 -135
- pycityagent/cityagent/blocks/mobility_block.py +70 -27
- pycityagent/cityagent/blocks/needs_block.py +11 -12
- pycityagent/cityagent/blocks/other_block.py +2 -2
- pycityagent/cityagent/blocks/plan_block.py +22 -24
- pycityagent/cityagent/blocks/social_block.py +15 -17
- pycityagent/cityagent/blocks/utils.py +3 -2
- pycityagent/cityagent/firmagent.py +1 -1
- pycityagent/cityagent/governmentagent.py +1 -1
- pycityagent/cityagent/initial.py +1 -1
- pycityagent/cityagent/memory_config.py +0 -1
- pycityagent/cityagent/message_intercept.py +7 -8
- pycityagent/cityagent/nbsagent.py +1 -1
- pycityagent/cityagent/societyagent.py +1 -2
- pycityagent/configs/__init__.py +18 -0
- pycityagent/configs/exp_config.py +202 -0
- pycityagent/configs/sim_config.py +251 -0
- pycityagent/configs/utils.py +17 -0
- pycityagent/environment/__init__.py +2 -0
- pycityagent/{economy → environment/economy}/econ_client.py +14 -32
- pycityagent/environment/sim/sim_env.py +17 -24
- pycityagent/environment/simulator.py +36 -113
- pycityagent/llm/__init__.py +1 -2
- pycityagent/llm/llm.py +54 -167
- pycityagent/memory/memory.py +13 -12
- pycityagent/message/message_interceptor.py +5 -4
- pycityagent/message/messager.py +3 -5
- pycityagent/metrics/__init__.py +1 -1
- pycityagent/metrics/mlflow_client.py +20 -17
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +18 -20
- pycityagent/simulation/simulation.py +157 -210
- pycityagent/survey/manager.py +0 -2
- pycityagent/utils/__init__.py +3 -0
- pycityagent/utils/config_const.py +20 -0
- pycityagent/workflow/__init__.py +1 -2
- pycityagent/workflow/block.py +0 -3
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/METADATA +7 -24
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/RECORD +50 -46
- pycityagent/llm/llmconfig.py +0 -18
- /pycityagent/{economy → environment/economy}/__init__.py +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,202 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from collections.abc import Callable
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
6
|
+
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
|
9
|
+
from ..utils import WorkflowType
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from ..simulation import AgentSimulation
|
13
|
+
|
14
|
+
|
15
|
+
class WorkflowStep(BaseModel):
    """One step of an experiment workflow, validated by pydantic."""

    # Kind of workflow action this step represents.
    type: WorkflowType
    days: int = Field(default=1, description="Number of simulation days")
    times: int = Field(default=1, description="Step Execution Times")
    description: str = Field(default="no description")
    # Optional callable attached to the step; unset by default.
    func: Optional[Callable] = Field(default=None)
|
21
|
+
|
22
|
+
|
23
|
+
class AgentConfig(BaseModel):
    """Agent population sizes and agent-construction hooks for an experiment."""

    number_of_citizen: int = Field(default=1, description="Number of citizens")
    number_of_firm: int = Field(default=1, description="Number of firms")
    number_of_government: int = Field(default=1, description="Number of governments")
    number_of_bank: int = Field(default=1, description="Number of banks")
    number_of_nbs: int = Field(
        default=1, description="Number of neighborhood-based services"
    )
    group_size: int = Field(default=100, description="Size of agent groups")
    embedding_model: Any = Field(default=None, description="Embedding model")
    agent_class_configs: Optional[dict[Any, dict[str, Any]]] = Field(default=None)
    memory_config_func: Optional[dict[type["Any"], Callable]] = Field(default=None)
    memory_config_init_func: Optional[Callable] = Field(default=None)
    init_func: Optional[list[Callable[["AgentSimulation"], None]]] = Field(
        default=None
    )
    enable_institution: bool = Field(
        default=True,
        description="Whether institutions are enabled in the experiment",
    )

    @classmethod
    def create(
        cls,
        number_of_citizen: int = 1,
        number_of_firm: int = 1,
        number_of_government: int = 1,
        number_of_bank: int = 1,
        number_of_nbs: int = 1,
        group_size: int = 100,
        embedding_model: Any = None,
        agent_class_configs: Optional[dict[Any, dict[str, Any]]] = None,
        enable_institution: bool = True,
        memory_config_func: Optional[dict[type["Any"], Callable]] = None,
        memory_config_init_func: Optional[Callable] = None,
        init_func: Optional[list[Callable[["AgentSimulation"], None]]] = None,
    ) -> "AgentConfig":
        """Build an ``AgentConfig`` from keyword arguments.

        Every argument maps one-to-one onto the field of the same name, and
        the defaults mirror the field defaults.
        """
        settings: dict[str, Any] = {
            "number_of_citizen": number_of_citizen,
            "number_of_firm": number_of_firm,
            "number_of_government": number_of_government,
            "number_of_bank": number_of_bank,
            "number_of_nbs": number_of_nbs,
            "group_size": group_size,
            "embedding_model": embedding_model,
            "agent_class_configs": agent_class_configs,
            "enable_institution": enable_institution,
            "memory_config_func": memory_config_func,
            "memory_config_init_func": memory_config_init_func,
            "init_func": init_func,
        }
        return cls(**settings)
|
70
|
+
|
71
|
+
|
72
|
+
class EnvironmentConfig(BaseModel):
    """Free-text descriptions of the environmental conditions in the simulation."""

    weather: str = Field("The weather is normal")
    crime: str = Field("The crime rate is low")
    pollution: str = Field("The pollution level is low")
    temperature: str = Field("The temperature is normal")
    day: str = Field("Workday")

    @classmethod
    def create(
        cls,
        weather: str = "The weather is normal",
        crime: str = "The crime rate is low",
        pollution: str = "The pollution level is low",
        temperature: str = "The temperature is normal",
        day: str = "Workday",
    ) -> "EnvironmentConfig":
        """Build an ``EnvironmentConfig``; defaults mirror the field defaults."""
        conditions = {
            "weather": weather,
            "crime": crime,
            "pollution": pollution,
            "temperature": temperature,
            "day": day,
        }
        return cls(**conditions)
|
95
|
+
|
96
|
+
|
97
|
+
class MessageInterceptConfig(BaseModel):
    """Settings for message interception."""

    # Required; the regex restricts it to exactly "point" or "edge".
    mode: str = Field(..., pattern="^(point|edge)$")
    max_violation_time: int = Field(3)

    @classmethod
    def create(cls, mode: str, max_violation_time: int = 3) -> "MessageInterceptConfig":
        """Build a ``MessageInterceptConfig`` (``mode`` is validated by pydantic)."""
        settings = {"mode": mode, "max_violation_time": max_violation_time}
        return cls(**settings)
|
104
|
+
|
105
|
+
|
106
|
+
class ExpConfig(BaseModel):
    """Top-level experiment configuration with a fluent builder API.

    Each ``Set*`` method replaces one section of the configuration and returns
    ``self`` so calls can be chained, e.g.
    ``ExpConfig().SetAgentConfig(...).SetWorkFlow([...])``.
    """

    agent_config: Optional[AgentConfig] = None
    workflow: Optional[list[WorkflowStep]] = None
    environment: Optional[EnvironmentConfig] = EnvironmentConfig()
    message_intercept: Optional[MessageInterceptConfig] = None
    # List of (int, callable) pairs; the int's meaning is defined by the consumer.
    metric_extractors: Optional[list[tuple[int, Callable]]] = None
    logging_level: int = Field(logging.WARNING)
    exp_name: str = Field("default_experiment")
    llm_semaphore: int = Field(200)

    # The prop_* accessors expose a section with a non-Optional type for the
    # convenience of callers. They do NOT check for None: the section must have
    # been set (at construction or via the matching Set* method) beforehand.

    @property
    def prop_agent_config(self) -> AgentConfig:
        return self.agent_config  # type:ignore

    @property
    def prop_workflow(self) -> list[WorkflowStep]:
        return self.workflow  # type:ignore

    @property
    def prop_environment(self) -> EnvironmentConfig:
        return self.environment  # type:ignore

    @property
    def prop_message_intercept(self) -> MessageInterceptConfig:
        return self.message_intercept  # type:ignore

    @property
    def prop_metric_extractors(
        self,
    ) -> list[tuple[int, Callable]]:
        return self.metric_extractors  # type:ignore

    def SetAgentConfig(
        self,
        number_of_citizen: int = 1,
        number_of_firm: int = 1,
        number_of_government: int = 1,
        number_of_bank: int = 1,
        number_of_nbs: int = 1,
        group_size: int = 100,
        embedding_model: Any = None,
        # Annotated the same way as AgentConfig.agent_class_configs.
        agent_class_configs: Optional[dict[Any, dict[str, Any]]] = None,
        enable_institution: bool = True,
        memory_config_func: Optional[dict[type["Any"], Callable]] = None,
        memory_config_init_func: Optional[Callable] = None,
        init_func: Optional[list[Callable[["AgentSimulation"], None]]] = None,
    ) -> "ExpConfig":
        """Replace the agent section with a new ``AgentConfig``.

        All arguments are forwarded unchanged to ``AgentConfig.create``.

        Returns:
            ``self``, for chaining.
        """
        self.agent_config = AgentConfig.create(
            number_of_citizen=number_of_citizen,
            number_of_firm=number_of_firm,
            number_of_government=number_of_government,
            number_of_bank=number_of_bank,
            number_of_nbs=number_of_nbs,
            group_size=group_size,
            embedding_model=embedding_model,
            agent_class_configs=agent_class_configs,
            enable_institution=enable_institution,
            memory_config_func=memory_config_func,
            memory_config_init_func=memory_config_init_func,
            init_func=init_func,
        )
        return self

    def SetEnvironment(
        self,
        weather: str = "The weather is normal",
        crime: str = "The crime rate is low",
        pollution: str = "The pollution level is low",
        temperature: str = "The temperature is normal",
        day: str = "Workday",
    ) -> "ExpConfig":
        """Replace the environment section; returns ``self`` for chaining."""
        self.environment = EnvironmentConfig.create(
            weather=weather,
            crime=crime,
            pollution=pollution,
            temperature=temperature,
            day=day,
        )
        return self

    def SetMessageIntercept(
        self,
        mode: Union[Literal["point"], Literal["edge"]],
        max_violation_time: int = 3,
    ) -> "ExpConfig":
        """Enable message interception in ``mode``; returns ``self`` for chaining."""
        self.message_intercept = MessageInterceptConfig.create(
            mode=mode, max_violation_time=max_violation_time
        )
        return self

    def SetMetricExtractors(
        self, metric_extractors: list[tuple[int, Callable]]
    ) -> "ExpConfig":
        """Replace the metric extractors; returns ``self`` for chaining."""
        self.metric_extractors = metric_extractors
        return self

    def SetWorkFlow(self, workflows: list[WorkflowStep]) -> "ExpConfig":
        """Replace the workflow steps; returns ``self`` for chaining."""
        self.workflow = workflows
        return self
|
@@ -0,0 +1,251 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from pydantic import BaseModel, Field
|
4
|
+
|
5
|
+
from ..utils import LLMRequestType
|
6
|
+
|
7
|
+
# Public names exported by `from ... import *` on this module.
__all__ = ["SimConfig"]
|
10
|
+
|
11
|
+
|
12
|
+
class LLMRequestConfig(BaseModel):
    """Connection settings for the LLM provider; every field is required."""

    request_type: LLMRequestType = Field(
        ..., description="The type of the request or provider"
    )
    api_key: list[str] = Field(..., description="API key for accessing the service")
    model: str = Field(..., description="The model to use")

    @classmethod
    def create(
        cls, request_type: LLMRequestType, api_key: list[str], model: str
    ) -> "LLMRequestConfig":
        """Build an ``LLMRequestConfig`` from its three required values."""
        return cls(model=model, api_key=api_key, request_type=request_type)
|
24
|
+
|
25
|
+
|
26
|
+
class MQTTConfig(BaseModel):
    """MQTT broker connection settings; credentials are optional."""

    server: str = Field(..., description="MQTT server address")
    port: int = Field(..., description="Port number for MQTT connection")
    password: Optional[str] = Field(None, description="Password for MQTT connection")
    username: Optional[str] = Field(None, description="Username for MQTT connection")

    @classmethod
    def create(
        cls,
        server: str,
        port: int,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> "MQTTConfig":
        """Build an ``MQTTConfig``; omit ``username``/``password`` for anonymous access."""
        connection = {
            "server": server,
            "port": port,
            "username": username,
            "password": password,
        }
        return cls(**connection)
|
41
|
+
|
42
|
+
|
43
|
+
class SimulatorRequestConfig(BaseModel):
    """Parameters used when launching / talking to the city simulator."""

    task_name: str = Field(default="citysim", description="Name of the simulation task")
    max_day: int = Field(default=1000, description="Maximum number of days to simulate")
    start_step: int = Field(default=28800, description="Starting step of the simulation")
    # 31,536,000 = 60 * 60 * 24 * 365
    total_step: int = Field(
        default=60 * 60 * 24 * 365,
        description="Total number of steps in the simulation",
    )
    log_dir: str = Field(default="./log", description="Directory path for saving logs")
    min_step_time: int = Field(
        default=1000,
        description="Minimum time (in seconds) between simulation steps",
    )
    primary_node_ip: str = Field(
        default="localhost",
        description="Primary node IP address for distributed simulation",
    )

    @classmethod
    def create(
        cls,
        task_name: str = "citysim",
        max_day: int = 1000,
        start_step: int = 28800,
        total_step: int = 24 * 60 * 60 * 365,
        log_dir: str = "./log",
        min_step_time: int = 1000,
        primary_node_ip: str = "localhost",
    ) -> "SimulatorRequestConfig":
        """Build a ``SimulatorRequestConfig``; defaults mirror the field defaults."""
        params = {
            "task_name": task_name,
            "max_day": max_day,
            "start_step": start_step,
            "total_step": total_step,
            "log_dir": log_dir,
            "min_step_time": min_step_time,
            "primary_node_ip": primary_node_ip,
        }
        return cls(**params)
|
78
|
+
|
79
|
+
|
80
|
+
class MapRequestConfig(BaseModel):
    """Location of the map file to load."""

    file_path: str = Field(..., description="Path to the map file")

    @classmethod
    def create(cls, file_path: str) -> "MapRequestConfig":
        """Build a ``MapRequestConfig`` pointing at ``file_path``."""
        config = cls(file_path=file_path)
        return config
|
86
|
+
|
87
|
+
|
88
|
+
class MlflowConfig(BaseModel):
    """Connection settings for an MLflow tracking server."""

    username: Optional[str] = Field(None, description="Username for MLflow")
    password: Optional[str] = Field(None, description="Password for MLflow")
    mlflow_uri: str = Field(..., description="URI for MLflow server")

    @classmethod
    def create(
        cls, username: Optional[str], password: Optional[str], mlflow_uri: str
    ) -> "MlflowConfig":
        """Build an ``MlflowConfig``.

        Args:
            username: MLflow user name, or None when no credentials are used
                (the field itself is optional).
            password: MLflow password, or None.
            mlflow_uri: URI of the MLflow server (required).
        """
        return cls(username=username, password=password, mlflow_uri=mlflow_uri)
|
96
|
+
|
97
|
+
|
98
|
+
class PostgreSQLConfig(BaseModel):
    """Settings for PostgreSQL-backed storage."""

    enabled: Optional[bool] = Field(
        default=True, description="Whether PostgreSQL storage is enabled"
    )
    dsn: str = Field(..., description="Data source name for PostgreSQL")

    @classmethod
    def create(cls, dsn: str, enabled: bool = False) -> "PostgreSQLConfig":
        """Build a ``PostgreSQLConfig`` for ``dsn``.

        NOTE(review): this factory defaults ``enabled`` to False, whereas the
        field default (used by direct construction) is True.
        """
        return cls(enabled=enabled, dsn=dsn)
|
107
|
+
|
108
|
+
|
109
|
+
class AvroConfig(BaseModel):
    """Settings for Avro file storage."""

    enabled: Optional[bool] = Field(default=False, description="Whether Avro storage is enabled")
    path: str = Field(..., description="Avro file storage path")

    @classmethod
    def create(cls, path: Optional[str] = None, enabled: bool = False) -> "AvroConfig":
        """Build an ``AvroConfig``.

        NOTE(review): ``path`` defaults to None here, but the ``path`` field is
        a required ``str``, so calling this without a real path fails
        pydantic validation.
        """
        return cls(path=path, enabled=enabled)
|
116
|
+
|
117
|
+
|
118
|
+
class MetricRequest(BaseModel):
    """Metric-reporting settings; holds the MLflow config when one is set."""

    mlflow: Optional[MlflowConfig] = Field(default=None)
|
120
|
+
|
121
|
+
|
122
|
+
class SimStatus(BaseModel):
    """Status flags of a simulation."""

    # Starts as False.
    simulator_activated: bool = Field(default=False)
|
124
|
+
|
125
|
+
|
126
|
+
class SimConfig(BaseModel):
    """Simulation / infrastructure configuration with a fluent builder API.

    Every ``Set*`` method builds the matching sub-config, stores it on this
    instance and returns ``self`` so calls can be chained.
    """

    llm_request: Optional["LLMRequestConfig"] = None
    simulator_request: Optional["SimulatorRequestConfig"] = None
    mqtt: Optional["MQTTConfig"] = None
    map_request: Optional["MapRequestConfig"] = None
    metric_request: Optional[MetricRequest] = None
    pgsql: Optional["PostgreSQLConfig"] = None
    avro: Optional["AvroConfig"] = None
    simulator_server_address: Optional[str] = None
    # Status flags; left out of model_dump() (see below).
    status: Optional["SimStatus"] = SimStatus()

    # The prop_* accessors expose a section with a non-Optional type for the
    # convenience of callers. They do NOT check for None: the section must have
    # been set (at construction or via the matching Set* method) beforehand.

    @property
    def prop_llm_request(self) -> "LLMRequestConfig":
        return self.llm_request  # type:ignore

    @property
    def prop_status(self) -> "SimStatus":
        return self.status  # type:ignore

    @property
    def prop_simulator_request(self) -> "SimulatorRequestConfig":
        return self.simulator_request  # type:ignore

    @property
    def prop_mqtt(self) -> "MQTTConfig":
        return self.mqtt  # type:ignore

    @property
    def prop_map_request(self) -> "MapRequestConfig":
        return self.map_request  # type:ignore

    @property
    def prop_avro_config(self) -> "AvroConfig":
        return self.avro  # type:ignore

    @property
    def prop_postgre_sql_config(self) -> "PostgreSQLConfig":
        return self.pgsql  # type:ignore

    @property
    def prop_simulator_server_address(self) -> str:
        return self.simulator_server_address  # type:ignore

    @property
    def prop_metric_request(self) -> MetricRequest:
        return self.metric_request  # type:ignore

    def SetLLMRequest(
        self, request_type: LLMRequestType, api_key: list[str], model: str
    ) -> "SimConfig":
        """Configure the LLM provider; ``api_key`` must be a list of keys."""
        self.llm_request = LLMRequestConfig.create(request_type, api_key, model)
        return self

    def SetSimulatorRequest(
        self,
        task_name: str = "citysim",
        max_day: int = 1000,
        start_step: int = 28800,
        total_step: int = 24 * 60 * 60 * 365,
        log_dir: str = "./log",
        min_step_time: int = 1000,
        primary_node_ip: str = "localhost",
    ) -> "SimConfig":
        """Configure the simulator launch parameters; returns ``self``."""
        self.simulator_request = SimulatorRequestConfig.create(
            task_name=task_name,
            max_day=max_day,
            start_step=start_step,
            total_step=total_step,
            log_dir=log_dir,
            min_step_time=min_step_time,
            primary_node_ip=primary_node_ip,
        )
        return self

    def SetMQTT(
        self,
        server: str,
        port: int,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> "SimConfig":
        """Configure the MQTT broker connection; returns ``self``."""
        self.mqtt = MQTTConfig.create(server, port, username, password)
        return self

    def SetMapRequest(self, file_path: str) -> "SimConfig":
        """Set the map file to load; returns ``self``."""
        self.map_request = MapRequestConfig.create(file_path)
        return self

    def SetMetricRequest(
        self, username: Optional[str], password: Optional[str], mlflow_uri: str
    ) -> "SimConfig":
        """Configure MLflow metric reporting; returns ``self``.

        ``username`` and ``password`` may be None, matching the optional
        credential fields of ``MlflowConfig``.
        """
        self.metric_request = MetricRequest(
            mlflow=MlflowConfig.create(username, password, mlflow_uri)
        )
        return self

    def SetAvro(self, path: str, enabled: bool = False) -> "SimConfig":
        """Configure Avro file storage at ``path``; returns ``self``."""
        self.avro = AvroConfig.create(path, enabled)
        return self

    def SetPostgreSql(self, path: str, enabled: bool = False) -> "SimConfig":
        """Configure PostgreSQL storage; returns ``self``.

        Despite its name, ``path`` is passed through as the PostgreSQL DSN.
        The name is kept for backward compatibility with keyword callers.
        """
        self.pgsql = PostgreSQLConfig.create(path, enabled)
        return self

    def SetServerAddress(self, simulator_server_address: str) -> "SimConfig":
        """Set the simulator server address; returns ``self``."""
        self.simulator_server_address = simulator_server_address
        return self

    def model_dump(self, *args, **kwargs):
        """Dump like ``BaseModel.model_dump`` but drop the ``status`` entry.

        All arguments are forwarded to ``BaseModel.model_dump`` unchanged.
        Note that ``model_dump_json`` is not overridden.
        """
        exclude_fields = {
            "status",
        }
        data = super().model_dump(*args, **kwargs)
        return {k: v for k, v in data.items() if k not in exclude_fields}
|
240
|
+
|
241
|
+
|
242
|
+
if __name__ == "__main__":
    # Smoke demo of the fluent builder API.
    config = (
        SimConfig()
        # api_key is typed list[str]; a bare string is rejected by pydantic's
        # list validation, so the key is wrapped in a list.
        .SetLLMRequest("openai", ["key"], "model")  # type:ignore
        .SetMQTT("server", 1883, "username", "password")
        .SetMapRequest("./path/to/map")
        .SetMetricRequest("username", "password", "uri")
        .SetPostgreSql("dsn", True)
    )
    print(config.llm_request)
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any, Union
|
4
|
+
|
5
|
+
import yaml
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from .exp_config import ExpConfig
|
9
|
+
from .sim_config import SimConfig
|
10
|
+
|
11
|
+
|
12
|
+
def load_config_from_file(
    filepath: str, config_type: Union[type[SimConfig], type[ExpConfig]]
) -> Union[SimConfig, ExpConfig]:
    """Read a YAML file and build a configuration object from it.

    Args:
        filepath: Path of the YAML file to read (decoded as UTF-8).
        config_type: The configuration class to instantiate (``SimConfig`` or
            ``ExpConfig``). The top-level YAML mapping is passed to it as
            keyword arguments.

    Returns:
        An instance of ``config_type``. An empty file yields
        ``config_type()``.

    Raises:
        TypeError: If the top-level YAML value is present but is not a mapping.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(filepath, "r", encoding="utf-8") as file:
        data = yaml.safe_load(file)
    # safe_load returns None for an empty document. Treat that as "no settings"
    # instead of failing on `**None`.
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise TypeError(
            f"Top-level YAML content of {filepath!r} must be a mapping, "
            f"got {type(data).__name__}"
        )
    return config_type(**data)
|
@@ -9,7 +9,6 @@ import grpc
|
|
9
9
|
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
10
10
|
import pycityproto.city.economy.v2.org_service_pb2 as org_service
|
11
11
|
import pycityproto.city.economy.v2.org_service_pb2_grpc as org_grpc
|
12
|
-
from google.protobuf import descriptor
|
13
12
|
from google.protobuf.json_format import MessageToDict
|
14
13
|
|
15
14
|
logger = logging.getLogger("pycityagent")
|
@@ -19,27 +18,6 @@ __all__ = [
|
|
19
18
|
]
|
20
19
|
|
21
20
|
|
22
|
-
def _snake_to_pascal(snake_str):
|
23
|
-
_res = "".join(word.capitalize() or "_" for word in snake_str.split("_"))
|
24
|
-
for _word in {
|
25
|
-
"Gdp",
|
26
|
-
}:
|
27
|
-
if _word in _res:
|
28
|
-
_res = _res.replace(_word, _word.upper())
|
29
|
-
return _res
|
30
|
-
|
31
|
-
|
32
|
-
def camel_to_snake(d):
|
33
|
-
if not isinstance(d, dict):
|
34
|
-
return d
|
35
|
-
return {
|
36
|
-
re.sub("([a-z0-9])([A-Z])", r"\1_\2", k).lower(): (
|
37
|
-
camel_to_snake(v) if isinstance(v, dict) else v
|
38
|
-
)
|
39
|
-
for k, v in d.items()
|
40
|
-
}
|
41
|
-
|
42
|
-
|
43
21
|
def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.Channel:
|
44
22
|
"""
|
45
23
|
Create a gRPC asynchronous channel.
|
@@ -170,8 +148,9 @@ class EconomyClient:
|
|
170
148
|
agents = await self._aio_stub.BatchGet(
|
171
149
|
org_service.BatchGetRequest(ids=id, type="agent")
|
172
150
|
)
|
173
|
-
|
174
|
-
|
151
|
+
agent_dicts = MessageToDict(agents, preserving_proto_field_name=True)[
|
152
|
+
"agents"
|
153
|
+
]
|
175
154
|
log["consumption"] = time.time() - start_time
|
176
155
|
self._log_list.append(log)
|
177
156
|
return agent_dicts
|
@@ -179,10 +158,12 @@ class EconomyClient:
|
|
179
158
|
agent = await self._aio_stub.GetAgent(
|
180
159
|
org_service.GetAgentRequest(agent_id=id)
|
181
160
|
)
|
182
|
-
agent_dict = MessageToDict(agent)[
|
161
|
+
agent_dict: dict = MessageToDict(agent, preserving_proto_field_name=True)[
|
162
|
+
"agent"
|
163
|
+
]
|
183
164
|
log["consumption"] = time.time() - start_time
|
184
165
|
self._log_list.append(log)
|
185
|
-
return
|
166
|
+
return agent_dict
|
186
167
|
|
187
168
|
async def get_org(
|
188
169
|
self, id: Union[list[int], int]
|
@@ -202,17 +183,16 @@ class EconomyClient:
|
|
202
183
|
orgs = await self._aio_stub.BatchGet(
|
203
184
|
org_service.BatchGetRequest(ids=id, type="org")
|
204
185
|
)
|
205
|
-
|
206
|
-
org_dicts = [camel_to_snake(org) for org in orgs]
|
186
|
+
org_dicts = MessageToDict(orgs, preserving_proto_field_name=True)["orgs"]
|
207
187
|
log["consumption"] = time.time() - start_time
|
208
188
|
self._log_list.append(log)
|
209
189
|
return org_dicts
|
210
190
|
else:
|
211
191
|
org = await self._aio_stub.GetOrg(org_service.GetOrgRequest(org_id=id))
|
212
|
-
org_dict = MessageToDict(org)["org"]
|
192
|
+
org_dict = MessageToDict(org, preserving_proto_field_name=True)["org"]
|
213
193
|
log["consumption"] = time.time() - start_time
|
214
194
|
self._log_list.append(log)
|
215
|
-
return
|
195
|
+
return org_dict
|
216
196
|
|
217
197
|
async def get(
|
218
198
|
self,
|
@@ -510,7 +490,7 @@ class EconomyClient:
|
|
510
490
|
return (float(response.taxes_due), list(response.updated_incomes))
|
511
491
|
|
512
492
|
async def calculate_consumption(
|
513
|
-
self, org_ids: Union[int, list[int]], agent_id: int, demands: list[int]
|
493
|
+
self, org_ids: Union[int, list[int]], agent_id: int, demands: list[int], consumption_accumulation: bool = True
|
514
494
|
):
|
515
495
|
"""
|
516
496
|
Calculate consumption for agents based on their demands.
|
@@ -519,6 +499,7 @@ class EconomyClient:
|
|
519
499
|
- `org_ids` (`Union[int, list[int]]`): The ID of the firm providing goods or services.
|
520
500
|
- `agent_id` (`int`): The ID of the agent whose consumption is being calculated.
|
521
501
|
- `demands` (`List[int]`): A list of demand quantities corresponding to each agent.
|
502
|
+
- `consumption_accumulation` (`bool`): Weather accumulation.
|
522
503
|
|
523
504
|
- **Returns**:
|
524
505
|
- `Tuple[int, List[float]]`: A tuple containing the remaining inventory and updated currencies for each agent.
|
@@ -535,6 +516,7 @@ class EconomyClient:
|
|
535
516
|
firm_ids=org_ids,
|
536
517
|
agent_id=agent_id,
|
537
518
|
demands=demands,
|
519
|
+
consumption_accumulation = consumption_accumulation,
|
538
520
|
)
|
539
521
|
response: org_service.CalculateConsumptionResponse = (
|
540
522
|
await self._aio_stub.CalculateConsumption(request)
|
@@ -707,7 +689,7 @@ class EconomyClient:
|
|
707
689
|
"""
|
708
690
|
start_time = time.time()
|
709
691
|
log = {"req": "add_delta_value", "start_time": start_time, "consumption": 0}
|
710
|
-
pascal_key =
|
692
|
+
pascal_key = ''.join(x.title() for x in key.split('_'))
|
711
693
|
_request_type = getattr(org_service, f"Add{pascal_key}Request")
|
712
694
|
_request_func = getattr(self._aio_stub, f"Add{pascal_key}")
|
713
695
|
|
@@ -6,7 +6,7 @@ import warnings
|
|
6
6
|
from subprocess import Popen
|
7
7
|
from typing import Optional
|
8
8
|
|
9
|
-
|
9
|
+
import yaml
|
10
10
|
|
11
11
|
from ..utils import encode_to_base64, find_free_port
|
12
12
|
|
@@ -16,29 +16,22 @@ __all__ = ["ControlSimEnv"]
|
|
16
16
|
def _generate_yaml_config(
|
17
17
|
map_file: str, max_day: int, start_step: int, total_step: int
|
18
18
|
) -> str:
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
prefer_fixed_light: true
|
36
|
-
enable_collision_avoidance: false # 计算性能下降10倍,需要保证subloop>=5
|
37
|
-
enable_go_astray: true # 引入串行的路径规划调用,计算性能下降(幅度不确定)
|
38
|
-
lane_change_model: earliest # mobil (主动变道+强制变道,默认值) earliest (总是尽可能早地变道)
|
39
|
-
|
40
|
-
output:
|
41
|
-
"""
|
19
|
+
config_dict = {
|
20
|
+
"input": {"map": {"file": os.path.abspath(map_file)}},
|
21
|
+
"control": {
|
22
|
+
"day": max_day,
|
23
|
+
"step": {"start": start_step, "total": total_step, "interval": 1},
|
24
|
+
"skip_overtime_trip_when_init": True,
|
25
|
+
"enable_platoon": False,
|
26
|
+
"enable_indoor": False,
|
27
|
+
"prefer_fixed_light": True,
|
28
|
+
"enable_collision_avoidance": False,
|
29
|
+
"enable_go_astray": True,
|
30
|
+
"lane_change_model": "earliest",
|
31
|
+
},
|
32
|
+
"output": None,
|
33
|
+
}
|
34
|
+
return yaml.dump(config_dict, allow_unicode=True)
|
42
35
|
|
43
36
|
|
44
37
|
class ControlSimEnv:
|