pycityagent 2.0.0a42__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a42.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a42.dist-info/METADATA +235 -0
- pycityagent-2.0.0a42.dist-info/RECORD +81 -0
- pycityagent-2.0.0a42.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a42.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a42.dist-info/top_level.txt +3 -0
@@ -0,0 +1,44 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
import subprocess
|
4
|
+
import signal
|
5
|
+
|
6
|
+
# Directory containing this module, and its parent (the package root that the
# bundled binaries are resolved against in `wrapper`).
_script_dir = os.path.dirname(os.path.abspath(__file__))
_parent_dir = os.path.dirname(_script_dir)


def wrapper(bin: str):
    """
    Run a binary located next to the package and exit with its return code.

    The binary is started with the current command line arguments
    (`sys.argv[1:]`), the current environment and the current standard
    streams. SIGINT, SIGTERM and SIGHUP received by this process are
    forwarded to the child while it is still running.

    Args:
    - bin (str): file name of the binary, resolved relative to `_parent_dir`.

    Raises:
    - SystemExit: always. Code 1 if the binary does not exist, otherwise the
      return code of the child process.
    """
    binary_path = os.path.join(_parent_dir, bin)
    if not os.path.exists(binary_path):
        print(f"Error: {binary_path} not found")
        sys.exit(1)

    # forward our own command line arguments to the binary
    args = sys.argv[1:]
    p = subprocess.Popen(
        [binary_path] + args,
        env=os.environ,
        stdin=sys.stdin,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )

    def signal_handler(sig, frame):
        # Relay the signal while the child is alive; once it has exited,
        # leave with its return code instead.
        if p.poll() is None:
            p.send_signal(sig)
        else:
            sys.exit(p.poll())

    # Not every platform defines every signal (e.g. SIGHUP), so register only
    # the ones that exist.
    for sig_name in ("SIGINT", "SIGTERM", "SIGHUP"):
        sig = getattr(signal, sig_name, None)
        if sig is not None:
            signal.signal(sig, signal_handler)

    # Block until the child exits instead of spinning on poll(), which would
    # keep a CPU core busy for the whole lifetime of the child. Then exit
    # with the same code as the child process.
    sys.exit(p.wait())
|
38
|
+
|
39
|
+
|
40
|
+
def pycityagent_sim():
    """Entry point that launches the bundled `pycityagent-sim` binary."""
    binary_name = "pycityagent-sim"
    wrapper(binary_name)
|
42
|
+
|
43
|
+
def pycityagent_ui():
    """Entry point that launches the bundled `pycityagent-ui` binary."""
    binary_name = "pycityagent-ui"
    wrapper(binary_name)
|
@@ -0,0 +1,355 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from typing import Any, Literal, Union
|
4
|
+
|
5
|
+
import grpc
|
6
|
+
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
7
|
+
import pycityproto.city.economy.v2.org_service_pb2 as org_service
|
8
|
+
import pycityproto.city.economy.v2.org_service_pb2_grpc as org_grpc
|
9
|
+
from google.protobuf import descriptor
|
10
|
+
|
11
|
+
# Bare attribute access: its result is discarded, so the only observable
# effect is an AttributeError at import time if `economyv2` lacks the name.
economyv2.ORG_TYPE_BANK

# Public API of this module.
__all__ = ["EconomyClient"]
|
16
|
+
|
17
|
+
|
18
|
+
def _snake_to_pascal(snake_str):
    """
    Convert a snake_case name into PascalCase.

    Each underscore-separated piece is capitalized (first letter upper case,
    the rest lower case). An empty piece, as produced by consecutive
    underscores or an empty input, is rendered as a literal "_". Afterwards
    every occurrence of "Gdp" is upper-cased to "GDP".

    Args:
    - snake_str (str): the snake_case name.

    Returns:
    - str: the PascalCase name.
    """
    pieces = []
    for chunk in snake_str.split("_"):
        capitalized = chunk.capitalize()
        pieces.append(capitalized if capitalized else "_")
    pascal = "".join(pieces)
    # acronyms that are written fully upper-case
    for acronym in ("Gdp",):
        pascal = pascal.replace(acronym, acronym.upper())
    return pascal
|
26
|
+
|
27
|
+
|
28
|
+
def _get_field_type_and_repeated(message, field_name: str) -> tuple[Any, bool]:
    """
    Look up the Python type and the repeated flag of a protobuf message field.

    Args:
    - message: an object exposing `DESCRIPTOR.fields_by_name`.
    - field_name (str): name of the field to inspect.

    Returns:
    - tuple[Any, bool]: `(python_type, is_repeated)`. `python_type` is `float`
      for TYPE_FLOAT fields, `int` for TYPE_INT32 fields and `None` for every
      other field type.

    Raises:
    - KeyError: if `field_name` is not a field of `message`.
    """
    # Only the descriptor lookup can raise KeyError, so keep the try minimal.
    try:
        field_descriptor = message.DESCRIPTOR.fields_by_name[field_name]
    except KeyError as exc:
        raise KeyError(
            f"Invalid message {message} and field name {field_name}!"
        ) from exc
    _type_mapping = {
        descriptor.FieldDescriptor.TYPE_FLOAT: float,
        descriptor.FieldDescriptor.TYPE_INT32: int,
    }
    is_repeated = field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED
    return (_type_mapping.get(field_descriptor.type), is_repeated)
|
42
|
+
|
43
|
+
|
44
|
+
def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.Channel:
    """
    Build an asynchronous grpc channel for the given address.

    A leading `http://` or `https://` scheme is stripped from the address.
    An `https://` address always yields a secure channel, even when `secure`
    is False.

    Args:
    - server_address (str): server address, optionally prefixed with a scheme.
    - secure (bool, optional): Whether to use a secure connection. Defaults to False.

    Returns:
    - grpc.aio.Channel: grpc asynchronous channel.

    Raises:
    - ValueError: if `secure` is True but the address starts with `http://`.
    """
    plain_scheme = server_address.startswith("http://")
    tls_scheme = server_address.startswith("https://")

    if plain_scheme or tls_scheme:
        # keep what follows the scheme separator
        server_address = server_address.split("//")[1]

    if plain_scheme and secure:
        raise ValueError("secure channel must use `https` or not use `http`")

    use_tls = secure or tls_scheme
    if not use_tls:
        return grpc.aio.insecure_channel(server_address)
    return grpc.aio.secure_channel(server_address, grpc.ssl_channel_credentials())
|
68
|
+
|
69
|
+
|
70
|
+
class EconomyClient:
    """
    Client side of Economy service
    """

    def __init__(self, server_address: str, secure: bool = False):
        """
        Constructor of EconomyClient

        Args:
        - server_address (str): Economy server address
        - secure (bool, optional): Whether to use a secure connection. Defaults to False.
        """
        self.server_address = server_address
        self.secure = secure
        aio_channel = _create_aio_channel(server_address, secure)
        self._aio_stub = org_grpc.OrgServiceStub(aio_channel)

    def __getstate__(self):
        """
        Return a copy of the instance state without the grpc stub.

        The stub is dropped because it is not picklable; it is rebuilt in
        `__setstate__`. A copy of `self.__dict__` is used so the live
        instance is not modified.
        """
        state = self.__dict__.copy()
        # Remove the non-picklable entries.
        del state["_aio_stub"]
        return state

    def __setstate__(self, state):
        """
        Restore the instance attributes from the unpickled state dictionary
        and re-create the grpc stub from the stored address and secure flag.
        """
        self.__dict__.update(state)
        # Re-initialize the channel after unpickling
        aio_channel = _create_aio_channel(self.server_address, self.secure)
        self._aio_stub = org_grpc.OrgServiceStub(aio_channel)

    async def get(
        self,
        id: int,
        key: str,
    ) -> Any:
        """
        Get specific value

        Calls the `Get<PascalKey>` rpc with a `Get<PascalKey>Request` and
        reads the field named `key` from the response.

        Args:
        - id (int): the id of `Org` or `Agent`.
        - key (str): the attribute to fetch.

        Returns:
        - Any: a list for repeated fields; otherwise the value converted to
          `float`/`int` when the field type is mapped, or the raw field value
          when it is not.
        """
        pascal_key = _snake_to_pascal(key)
        _request_type = getattr(org_service, f"Get{pascal_key}Request")
        _request_func = getattr(self._aio_stub, f"Get{pascal_key}")
        response = await _request_func(_request_type(org_id=id))
        value_type, is_repeated = _get_field_type_and_repeated(response, field_name=key)
        raw_value = getattr(response, key)
        if is_repeated:
            return list(raw_value)
        if value_type is None:
            # No Python type is mapped for this field type; calling None would
            # raise TypeError, so hand back the value unconverted.
            return raw_value
        return value_type(raw_value)

    async def update(
        self,
        id: int,
        key: str,
        value: Any,
        mode: Union[Literal["replace"], Literal["merge"]] = "replace",
    ) -> Any:
        """
        Update key-value pair

        Args:
        - id (int): the id of `Org` or `Agent`.
        - key (str): the attribute to update.
        - value (Any): the new value. In "merge" mode it is combined with the
          current value when both have the same type and that type is
          `set`, `dict` or `list`; otherwise it replaces the current value.
        - mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".

        Returns:
        - Any: the response of the `Set<PascalKey>` rpc.
        """
        pascal_key = _snake_to_pascal(key)
        _request_type = getattr(org_service, f"Set{pascal_key}Request")
        _request_func = getattr(self._aio_stub, f"Set{pascal_key}")
        if mode == "merge":
            orig_value = await self.get(id, key)
            _orig_type = type(orig_value)
            _new_type = type(value)
            if _orig_type != _new_type:
                # types differ: nothing is merged, `value` is sent as given
                logging.debug(
                    f"Inconsistent type of original value {_orig_type.__name__} and to-update value {_new_type.__name__}"
                )
            else:
                if isinstance(orig_value, set):
                    orig_value.update(set(value))
                    value = orig_value
                elif isinstance(orig_value, dict):
                    orig_value.update(dict(value))
                    value = orig_value
                elif isinstance(orig_value, list):
                    orig_value.extend(list(value))
                    value = orig_value
                else:
                    logging.warning(
                        f"Type of {type(orig_value)} does not support mode `merge`, using `replace` instead!"
                    )
        return await _request_func(
            _request_type(
                **{
                    "org_id": id,
                    key: value,
                }
            )
        )

    async def add_agents(self, configs: Union[list[dict], dict]):
        """
        Add one or more agents, issuing the `AddAgent` rpcs concurrently.

        Args:
        - configs (Union[list[dict], dict]): one config or a list of configs.
          Each config must contain `id`; `currency` is optional and defaults to 0.0.
        """
        if isinstance(configs, dict):
            configs = [configs]
        tasks = [
            self._aio_stub.AddAgent(
                org_service.AddAgentRequest(
                    agent=economyv2.Agent(
                        id=config["id"],
                        currency=config.get("currency", 0.0),
                    )
                )
            )
            for config in configs
        ]
        await asyncio.gather(*tasks)

    async def add_orgs(self, configs: Union[list[dict], dict]):
        """
        Add one or more orgs, issuing the `AddOrg` rpcs concurrently.

        Args:
        - configs (Union[list[dict], dict]): one config or a list of configs.
          Each config must contain `id` and `type`. Every other field is
          optional: list-valued fields default to an empty list, `inventory`
          and `price` to 0, `currency` and `interest_rate` to 0.0.
        """
        if isinstance(configs, dict):
            configs = [configs]
        tasks = [
            self._aio_stub.AddOrg(
                org_service.AddOrgRequest(
                    org=economyv2.Org(
                        id=config["id"],
                        type=config["type"],
                        nominal_gdp=config.get("nominal_gdp", []),
                        real_gdp=config.get("real_gdp", []),
                        unemployment=config.get("unemployment", []),
                        wages=config.get("wages", []),
                        prices=config.get("prices", []),
                        inventory=config.get("inventory", 0),
                        price=config.get("price", 0),
                        currency=config.get("currency", 0.0),
                        interest_rate=config.get("interest_rate", 0.0),
                        bracket_cutoffs=config.get("bracket_cutoffs", []),
                        bracket_rates=config.get("bracket_rates", []),
                    )
                )
            )
            for config in configs
        ]
        await asyncio.gather(*tasks)

    async def calculate_taxes_due(
        self,
        org_id: int,
        agent_ids: list[int],
        incomes: list[float],
        enable_redistribution: bool,
    ):
        """
        Call the `CalculateTaxesDue` rpc.

        Args:
        - org_id (int): sent as `government_id`.
        - agent_ids (list[int]): ids of the agents.
        - incomes (list[float]): incomes sent along with `agent_ids`.
        - enable_redistribution (bool): forwarded to the request unchanged.

        Returns:
        - tuple[float, list]: `(taxes_due, updated_incomes)` from the response.
        """
        request = org_service.CalculateTaxesDueRequest(
            government_id=org_id,
            agent_ids=agent_ids,
            incomes=incomes,
            enable_redistribution=enable_redistribution,
        )
        response: org_service.CalculateTaxesDueResponse = (
            await self._aio_stub.CalculateTaxesDue(request)
        )
        return (float(response.taxes_due), list(response.updated_incomes))

    async def calculate_consumption(
        self, org_id: int, agent_ids: list[int], demands: list[int]
    ):
        """
        Call the `CalculateConsumption` rpc.

        Args:
        - org_id (int): sent as `firm_id`.
        - agent_ids (list[int]): ids of the agents.
        - demands (list[int]): demands sent along with `agent_ids`.

        Returns:
        - tuple[int, list]: `(remain_inventory, updated_currencies)` from the response.
        """
        request = org_service.CalculateConsumptionRequest(
            firm_id=org_id,
            agent_ids=agent_ids,
            demands=demands,
        )
        response: org_service.CalculateConsumptionResponse = (
            await self._aio_stub.CalculateConsumption(request)
        )
        return (int(response.remain_inventory), list(response.updated_currencies))

    async def calculate_interest(self, org_id: int, agent_ids: list[int]):
        """
        Call the `CalculateInterest` rpc.

        Args:
        - org_id (int): sent as `bank_id`.
        - agent_ids (list[int]): ids of the agents.

        Returns:
        - tuple[float, list]: `(total_interest, updated_currencies)` from the response.
        """
        request = org_service.CalculateInterestRequest(
            bank_id=org_id,
            agent_ids=agent_ids,
        )
        response: org_service.CalculateInterestResponse = (
            await self._aio_stub.CalculateInterest(request)
        )
        return (float(response.total_interest), list(response.updated_currencies))

    async def remove_agents(self, agent_ids: Union[int, list[int]]):
        """
        Remove one or more agents, issuing the `RemoveAgent` rpcs concurrently.

        Args:
        - agent_ids (Union[int, list[int]]): a single id or a list of ids.
        """
        if isinstance(agent_ids, int):
            agent_ids = [agent_ids]
        tasks = [
            self._aio_stub.RemoveAgent(
                org_service.RemoveAgentRequest(agent_id=agent_id)
            )
            for agent_id in agent_ids
        ]
        await asyncio.gather(*tasks)

    async def remove_orgs(self, org_ids: Union[int, list[int]]):
        """
        Remove one or more orgs, issuing the `RemoveOrg` rpcs concurrently.

        Args:
        - org_ids (Union[int, list[int]]): a single id or a list of ids.
        """
        if isinstance(org_ids, int):
            org_ids = [org_ids]
        tasks = [
            self._aio_stub.RemoveOrg(org_service.RemoveOrgRequest(org_id=org_id))
            for org_id in org_ids
        ]
        await asyncio.gather(*tasks)

    async def save(self, file_path: str) -> tuple[list[int], list[int]]:
        """
        Call the `SaveEconomyEntities` rpc with the given file path.

        Args:
        - file_path (str): forwarded to the request unchanged.

        Returns:
        - tuple[list[int], list[int]]: `(agent_ids, org_ids)` from the response.
        """
        request = org_service.SaveEconomyEntitiesRequest(
            file_path=file_path,
        )
        response: org_service.SaveEconomyEntitiesResponse = (
            await self._aio_stub.SaveEconomyEntities(request)
        )
        # current agent ids and org ids
        return (list(response.agent_ids), list(response.org_ids))

    async def load(self, file_path: str):
        """
        Call the `LoadEconomyEntities` rpc with the given file path.

        Args:
        - file_path (str): forwarded to the request unchanged.

        Returns:
        - tuple[list, list]: `(agent_ids, org_ids)` from the response.
        """
        request = org_service.LoadEconomyEntitiesRequest(
            file_path=file_path,
        )
        response: org_service.LoadEconomyEntitiesResponse = (
            await self._aio_stub.LoadEconomyEntities(request)
        )
        # current agent ids and org ids
        return (list(response.agent_ids), list(response.org_ids))

    async def get_org_entity_ids(self, org_type: economyv2.OrgType) -> list[int]:
        """
        Call the `GetOrgEntityIds` rpc for the given org type.

        Args:
        - org_type (economyv2.OrgType): sent as `type`.

        Returns:
        - list[int]: `org_ids` from the response.
        """
        request = org_service.GetOrgEntityIdsRequest(
            type=org_type,
        )
        response: org_service.GetOrgEntityIdsResponse = (
            await self._aio_stub.GetOrgEntityIds(request)
        )
        return list(response.org_ids)

    async def add_delta_value(
        self,
        id: int,
        key: str,
        value: Any,
    ) -> Any:
        """
        Add a delta to the value stored under a key

        Args:
        - id (int): the id of `Org` or `Agent`.
        - key (str): the attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
        - value (Any): the delta, sent as `delta_<key>`.

        Returns:
        - Any: the response of the `Add<PascalKey>` rpc.

        Raises:
        - ValueError: if `key` is not one of the supported keys.
        """
        _available_keys = {
            "inventory",
            "price",
            "interest_rate",
            "currency",
        }
        # Validate before resolving the request type and rpc by name, so an
        # unsupported key raises this ValueError rather than an AttributeError
        # from getattr.
        if key not in _available_keys:
            raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
        pascal_key = _snake_to_pascal(key)
        _request_type = getattr(org_service, f"Add{pascal_key}Request")
        _request_func = getattr(self._aio_stub, f"Add{pascal_key}")
        return await _request_func(
            _request_type(
                **{
                    "org_id": id,
                    f"delta_{key}": value,
                }
            )
        )
|
File without changes
|
@@ -0,0 +1,198 @@
|
|
1
|
+
"""环境相关的Interaction定义"""
|
2
|
+
|
3
|
+
from enum import Enum
|
4
|
+
from typing import Callable, Optional, Any
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
from typing import Callable, Any
|
7
|
+
|
8
|
+
|
9
|
+
class ActionType(Enum):
    """
    Action Type enumeration, all actions are essentially data push

    Types:
    - Sim = 1, an action that interfaces with the simulator
    - Hub = 2, an action that interfaces with AppHub (the frontend)
    - Comp = 3, a composite type (may interact with both Sim and Hub)
    """

    Sim = 1
    Hub = 2
    Comp = 3
|
23
|
+
|
24
|
+
|
25
|
+
class Action:
    """
    - Action: base class of all actions
    """

    def __init__(
        self,
        agent,
        type: ActionType,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        """
        Default initialization

        Args:
        - agent (Agent): the related agent
        - type (ActionType)
        - source (str): data source key. Defaults to None. When None,
          `get_source` returns None; otherwise `get_source` reads the entry
          with this key from the agent's `Brain.Memory.Working.Reason` cache.
        - before (function): optional processing function applied to the
          cached data before it is returned by `get_source`.
        """
        self._agent = agent
        self._type = type
        self._source = source
        self._before = before

    def get_source(self):
        """
        Fetch the source data.

        Returns:
        - None when no source key is configured; otherwise the entry stored
          under the source key in the agent's Reason cache, passed through
          `before` when one was given.
        """
        if self._source is None:
            return None
        source = self._agent.Brain.Memory.Working.Reason[self._source]
        if self._before is not None:
            source = self._before(source)
        return source

    # NOTE: this class does not derive from ABC here, so the decorator alone
    # does not prevent instantiation of classes that lack Forward.
    @abstractmethod
    async def Forward(self):
        """Interface method, to be implemented by concrete actions"""
|
66
|
+
|
67
|
+
|
68
|
+
class SimAction(Action):
    """SimAction: an Action associated with the simulator"""

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        # the action type is fixed to Sim for this family of actions
        action_type = ActionType.Sim
        super().__init__(agent, action_type, source, before)
|
78
|
+
|
79
|
+
|
80
|
+
class HubAction(Action):
    """HubAction: an Action associated with AppHub"""

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        # the action type is fixed to Hub for this family of actions
        action_type = ActionType.Hub
        super().__init__(agent, action_type, source, before)
|
90
|
+
|
91
|
+
|
92
|
+
class SetSchedule(SimAction):
    """
    Synchronize agent's schedule to simulator -- only available for citizen type of agent
    """

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        super().__init__(agent, source, before)

    async def _sync_schedule(self, schedule) -> None:
        """Build the SetSchedule request from `schedule` and send it to the simulator."""
        # NOTE(review): the flag is set on the agent's `Scheduler.now`, not on
        # `schedule` itself -- presumably they are the same object; confirm.
        self._agent.Scheduler.now.is_set = True
        departure_time = schedule.time
        end = {
            "aoi_position": {
                "aoi_id": schedule.target_id_aoi,
                "poi_id": schedule.target_id_poi,
            }
        }
        trips = [
            {
                "mode": schedule.mode,
                "end": end,
                "departure_time": departure_time,
                "activity": schedule.description,
            }
        ]
        set_schedule = [
            {"trips": trips, "loop_count": 1, "departure_time": departure_time}
        ]

        # * interface with the simulator
        req = {"person_id": self._agent._id, "schedules": set_schedule}
        await self._agent._client.person_service.SetSchedule(req)

    async def Forward(self, schedule=None):
        """
        If current schedule has been synchronized to simulator: skip, else sync

        Args:
        - schedule: the schedule to synchronize. When None and a source key
          is configured, the schedule is taken from `get_source()` instead.
        """
        if schedule is None and self._source is not None:
            schedule = self.get_source()
        if schedule is not None and not schedule.is_set:
            await self._sync_schedule(schedule)
|
162
|
+
|
163
|
+
|
164
|
+
class SendAgentMessage(SimAction):
    """
    Send messages to other agents
    """

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        super().__init__(agent, source, before)

    async def _send(self, messages) -> None:
        """Build the Send request from `messages` and send it via the social service."""
        req = {
            "messages": [
                {
                    "from": self._agent._id,
                    "to": message["id"],
                    "message": message["message"],
                }
                for message in messages
            ]
        }
        await self._agent._client.social_service.Send(req)

    async def Forward(self, messages: Optional[dict] = None):
        """
        Send the given messages, or the ones from the configured source.

        Args:
        - messages: the messages to send. It is iterated, and each item is
          indexed with the keys "id" (the receiver) and "message" (the
          content). When None or empty and a source key is configured, the
          messages are taken from `get_source()` instead. Nothing is sent
          when the resulting messages are None or empty.
        """
        if (messages is None or len(messages) == 0) and self._source is not None:
            messages = self.get_source()
        if messages is not None and len(messages) > 0:
            await self._send(messages)
|
File without changes
|
File without changes
|