pycityagent 1.0.0__py3-none-any.whl → 2.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +7 -3
- pycityagent/agent.py +180 -284
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +307 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/interact.py +141 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/{brain → environment/sence}/static.py +1 -1
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +27 -0
- pycityagent/environment/sim/aoi_service.py +38 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +43 -0
- pycityagent/environment/sim/economy_services.py +191 -0
- pycityagent/environment/sim/lane_service.py +110 -0
- pycityagent/environment/sim/light_service.py +120 -0
- pycityagent/environment/sim/person_service.py +294 -0
- pycityagent/environment/sim/road_service.py +38 -0
- pycityagent/environment/sim/social_service.py +58 -0
- pycityagent/environment/simulator.py +369 -0
- pycityagent/environment/utils/__init__.py +8 -0
- pycityagent/environment/utils/geojson.py +26 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/protobuf.py +39 -0
- pycityagent/llm/__init__.py +6 -0
- pycityagent/llm/embedding.py +136 -0
- pycityagent/llm/llm.py +430 -0
- pycityagent/llm/llmconfig.py +15 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +11 -0
- pycityagent/memory/const.py +41 -0
- pycityagent/memory/memory.py +453 -0
- pycityagent/memory/memory_base.py +168 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +27 -0
- pycityagent/message/__init__.py +0 -0
- pycityagent/simulation/__init__.py +7 -0
- pycityagent/simulation/interview.py +36 -0
- pycityagent/simulation/simulation.py +286 -0
- pycityagent/simulation/survey/__init__.py +9 -0
- pycityagent/simulation/survey/manager.py +67 -0
- pycityagent/simulation/survey/models.py +49 -0
- pycityagent/simulation/ui/__init__.py +3 -0
- pycityagent/simulation/ui/interface.py +602 -0
- pycityagent/utils/__init__.py +0 -0
- pycityagent/utils/decorators.py +89 -0
- pycityagent/utils/parsers/__init__.py +12 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/workflow/__init__.py +22 -0
- pycityagent/workflow/block.py +137 -0
- pycityagent/workflow/prompt.py +72 -0
- pycityagent/workflow/tool.py +246 -0
- pycityagent/workflow/trigger.py +66 -0
- pycityagent-2.0.0a1.dist-info/METADATA +208 -0
- pycityagent-2.0.0a1.dist-info/RECORD +65 -0
- {pycityagent-1.0.0.dist-info → pycityagent-2.0.0a1.dist-info}/WHEEL +1 -2
- pycityagent/ac/__init__.py +0 -6
- pycityagent/ac/ac.py +0 -50
- pycityagent/ac/action.py +0 -14
- pycityagent/ac/controled.py +0 -13
- pycityagent/ac/converse.py +0 -31
- pycityagent/ac/idle.py +0 -17
- pycityagent/ac/shop.py +0 -80
- pycityagent/ac/trip.py +0 -37
- pycityagent/brain/__init__.py +0 -10
- pycityagent/brain/brain.py +0 -52
- pycityagent/brain/brainfc.py +0 -10
- pycityagent/brain/memory.py +0 -541
- pycityagent/brain/persistence/social.py +0 -1
- pycityagent/brain/persistence/spatial.py +0 -14
- pycityagent/brain/reason/shop.py +0 -37
- pycityagent/brain/reason/social.py +0 -148
- pycityagent/brain/reason/trip.py +0 -67
- pycityagent/brain/reason/user.py +0 -122
- pycityagent/brain/retrive/social.py +0 -6
- pycityagent/brain/scheduler.py +0 -408
- pycityagent/brain/sence.py +0 -375
- pycityagent/cc/__init__.py +0 -5
- pycityagent/cc/cc.py +0 -102
- pycityagent/cc/conve.py +0 -6
- pycityagent/cc/idle.py +0 -20
- pycityagent/cc/shop.py +0 -6
- pycityagent/cc/trip.py +0 -13
- pycityagent/cc/user.py +0 -13
- pycityagent/hubconnector/__init__.py +0 -3
- pycityagent/hubconnector/hubconnector.py +0 -137
- pycityagent/image/__init__.py +0 -3
- pycityagent/image/image.py +0 -158
- pycityagent/simulator.py +0 -161
- pycityagent/st/__init__.py +0 -4
- pycityagent/st/st.py +0 -96
- pycityagent/urbanllm/__init__.py +0 -3
- pycityagent/urbanllm/urbanllm.py +0 -132
- pycityagent-1.0.0.dist-info/LICENSE +0 -21
- pycityagent-1.0.0.dist-info/METADATA +0 -181
- pycityagent-1.0.0.dist-info/RECORD +0 -48
- pycityagent-1.0.0.dist-info/top_level.txt +0 -1
- /pycityagent/{brain/persistence/__init__.py → config.py} +0 -0
- /pycityagent/{brain/reason → environment/interact}/__init__.py +0 -0
- /pycityagent/{brain/retrive → environment/message}/__init__.py +0 -0
@@ -0,0 +1,307 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from typing import Any, Literal, Union
|
4
|
+
|
5
|
+
import grpc
|
6
|
+
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
7
|
+
import pycityproto.city.economy.v2.org_service_pb2 as org_service
|
8
|
+
import pycityproto.city.economy.v2.org_service_pb2_grpc as org_grpc
|
9
|
+
from google.protobuf import descriptor
|
10
|
+
|
11
|
+
__all__ = [
|
12
|
+
"EconomyClient",
|
13
|
+
]
|
14
|
+
|
15
|
+
|
16
|
+
def _snake_to_pascal(snake_str):
|
17
|
+
_res = "".join(word.capitalize() or "_" for word in snake_str.split("_"))
|
18
|
+
for _word in {
|
19
|
+
"Gdp",
|
20
|
+
}:
|
21
|
+
if _word in _res:
|
22
|
+
_res = _res.replace(_word, _word.upper())
|
23
|
+
return _res
|
24
|
+
|
25
|
+
|
26
|
+
def _get_field_type_and_repeated(message, field_name: str) -> tuple[Any, bool]:
    """
    Look up the Python type and the repeated flag of a protobuf field.

    Args:
    - message: a protobuf message (instance or class) exposing `DESCRIPTOR`.
    - field_name (str): name of the field to inspect.

    Returns:
    - tuple[Any, bool]: `(python_type, is_repeated)`. `python_type` is the
      builtin used to convert a scalar value of that field, or `None` when the
      field type has no scalar mapping (message, enum, group).

    Raises:
    - KeyError: if `field_name` is not a field of `message`.
    """
    # Keep the try body to the only lookup that can raise KeyError.
    try:
        field_descriptor = message.DESCRIPTOR.fields_by_name[field_name]
    except KeyError as err:
        raise KeyError(
            f"Invalid message {message} and filed name {field_name}!"
        ) from err

    fd = descriptor.FieldDescriptor
    # Cover every protobuf scalar kind; with only FLOAT/INT32 mapped, callers
    # that invoke the returned type received None for double, int64, bool, ...
    _type_mapping = {
        fd.TYPE_FLOAT: float,
        fd.TYPE_DOUBLE: float,
        fd.TYPE_INT32: int,
        fd.TYPE_INT64: int,
        fd.TYPE_UINT32: int,
        fd.TYPE_UINT64: int,
        fd.TYPE_SINT32: int,
        fd.TYPE_SINT64: int,
        fd.TYPE_FIXED32: int,
        fd.TYPE_FIXED64: int,
        fd.TYPE_SFIXED32: int,
        fd.TYPE_SFIXED64: int,
        fd.TYPE_BOOL: bool,
        fd.TYPE_STRING: str,
        fd.TYPE_BYTES: bytes,
    }
    is_repeated = field_descriptor.label == fd.LABEL_REPEATED
    return (_type_mapping.get(field_descriptor.type), is_repeated)
|
40
|
+
|
41
|
+
|
42
|
+
def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.Channel:
    """
    Build an asynchronous gRPC channel to the given address.

    A leading "http://" or "https://" scheme is stripped from the address.
    An "https://" address always yields a secure channel; an "http://"
    address combined with `secure=True` is rejected.

    Args:
    - server_address (str): server address, with or without a scheme prefix.
    - secure (bool, optional): whether to use a secure connection. Defaults to False.

    Returns:
    - grpc.aio.Channel: grpc asynchronous channel.

    Raises:
    - ValueError: if `secure` is True but the address starts with "http://".
    """
    is_http = server_address.startswith("http://")
    is_https = server_address.startswith("https://")

    if is_http or is_https:
        # Drop the scheme, keeping what follows the first "//".
        server_address = server_address.split("//")[1]
    if is_http and secure:
        raise ValueError("secure channel must use `https` or not use `http`")

    use_tls = secure or is_https
    if not use_tls:
        return grpc.aio.insecure_channel(server_address)
    return grpc.aio.secure_channel(server_address, grpc.ssl_channel_credentials())
|
66
|
+
|
67
|
+
|
68
|
+
class EconomyClient:
    """
    Client side of Economy service.

    Wraps an asynchronous `OrgService` gRPC stub. The stub is dropped when the
    client is pickled and rebuilt from `server_address` / `secure` on unpickle.
    """

    def __init__(self, server_address: str, secure: bool = False):
        """
        Constructor of EconomyClient

        Args:
        - server_address (str): Economy server address
        - secure (bool, optional): Whether to use a secure connection. Defaults to False.
        """
        self.server_address = server_address
        self.secure = secure
        aio_channel = _create_aio_channel(server_address, secure)
        self._aio_stub = org_grpc.OrgServiceStub(aio_channel)

    def __getstate__(self):
        """
        Return a copy of the instance state without the gRPC stub.

        The stub is removed from the copy because it is not picklable; the
        original `__dict__` is left untouched.
        """
        state = self.__dict__.copy()
        # Tolerate a missing stub instead of raising KeyError.
        state.pop("_aio_stub", None)
        return state

    def __setstate__(self, state):
        """
        Restore the instance attributes from the unpickled state dictionary
        and re-create the gRPC stub from the stored address settings.
        """
        self.__dict__.update(state)
        # Re-initialize the channel after unpickling
        aio_channel = _create_aio_channel(self.server_address, self.secure)
        self._aio_stub = org_grpc.OrgServiceStub(aio_channel)

    async def get(
        self,
        id: int,
        key: str,
    ) -> Any:
        """
        Get specific value

        The RPC and request type are resolved by name from `key`
        (`Get<PascalKey>` / `Get<PascalKey>Request`), and the request is sent
        with `org_id=id`.

        Args:
        - id (int): the id of `Org` or `Agent`.
        - key (str): the attribute to fetch.

        Returns:
        - Any: a list for repeated fields, otherwise the field value converted
          to its mapped Python type (returned unconverted when no mapping exists).
        """
        pascal_key = _snake_to_pascal(key)
        _request_type = getattr(org_service, f"Get{pascal_key}Request")
        _request_func = getattr(self._aio_stub, f"Get{pascal_key}")
        response = await _request_func(_request_type(org_id=id))
        value_type, is_repeated = _get_field_type_and_repeated(response, field_name=key)
        raw_value = getattr(response, key)
        if is_repeated:
            return list(raw_value)
        if value_type is None:
            # No converter known for this field type: calling None would raise
            # TypeError, so hand back the protobuf value as-is.
            return raw_value
        return value_type(raw_value)

    async def update(
        self,
        id: int,
        key: str,
        value: Any,
        mode: Union[Literal["replace"], Literal["merge"]] = "replace",
    ) -> Any:
        """
        Update key-value pair

        In "merge" mode the current value is fetched first. When it has the
        same type as `value` and is a set, dict or list, the new items are
        merged into it and the merged container is sent. In every other case
        `value` is sent unchanged (i.e. the behaviour of "replace").

        Args:
        - id (int): the id of `Org` or `Agent`.
        - key (str): the attribute to update.
        - value (Any): the new value, or the items to merge.
        - mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".

        Returns:
        - Any: the response of the `Set<PascalKey>` RPC.
        """
        pascal_key = _snake_to_pascal(key)
        _request_type = getattr(org_service, f"Set{pascal_key}Request")
        _request_func = getattr(self._aio_stub, f"Set{pascal_key}")
        if mode == "merge":
            orig_value = await self.get(id, key)
            _orig_type = type(orig_value)
            _new_type = type(value)
            if _orig_type != _new_type:
                logging.debug(
                    f"Inconsistent type of original value {_orig_type.__name__} and to-update value {_new_type.__name__}"
                )
            else:
                if isinstance(orig_value, set):
                    orig_value.update(set(value))
                    value = orig_value
                elif isinstance(orig_value, dict):
                    orig_value.update(dict(value))
                    value = orig_value
                elif isinstance(orig_value, list):
                    orig_value.extend(list(value))
                    value = orig_value
                else:
                    logging.warning(
                        f"Type of {type(orig_value)} does not support mode `merge`, using `replace` instead!"
                    )
        return await _request_func(
            _request_type(
                **{
                    "org_id": id,
                    key: value,
                }
            )
        )

    async def add_agents(self, configs: Union[list[dict], dict]):
        """
        Register one or more agents, issuing the `AddAgent` RPCs concurrently.

        Args:
        - configs (Union[list[dict], dict]): one config or a list of configs.
          Each config must contain "id"; "currency" is optional (default 0.0).
        """
        if isinstance(configs, dict):
            configs = [configs]
        tasks = [
            self._aio_stub.AddAgent(
                org_service.AddAgentRequest(
                    agent=economyv2.Agent(
                        id=config["id"],
                        currency=config.get("currency", 0.0),
                    )
                )
            )
            for config in configs
        ]
        await asyncio.gather(*tasks)

    async def add_orgs(self, configs: Union[list[dict], dict]):
        """
        Register one or more orgs, issuing the `AddOrg` RPCs concurrently.

        Args:
        - configs (Union[list[dict], dict]): one config or a list of configs.
          Each config must contain "id" and "type". Every other field is
          optional and defaults to an empty list, 0 or 0.0 as shown below.
        """
        if isinstance(configs, dict):
            configs = [configs]
        tasks = [
            self._aio_stub.AddOrg(
                org_service.AddOrgRequest(
                    org=economyv2.Org(
                        id=config["id"],
                        type=config["type"],
                        nominal_gdp=config.get("nominal_gdp", []),
                        real_gdp=config.get("real_gdp", []),
                        unemployment=config.get("unemployment", []),
                        wages=config.get("wages", []),
                        prices=config.get("prices", []),
                        inventory=config.get("inventory", 0),
                        price=config.get("price", 0),
                        currency=config.get("currency", 0.0),
                        interest_rate=config.get("interest_rate", 0.0),
                        bracket_cutoffs=config.get("bracket_cutoffs", []),
                        bracket_rates=config.get("bracket_rates", []),
                    )
                )
            )
            for config in configs
        ]
        await asyncio.gather(*tasks)

    async def calculate_taxes_due(
        self,
        org_id: int,
        agent_ids: list[int],
        incomes: list[float],
        enable_redistribution: bool,
    ):
        """
        Call the `CalculateTaxesDue` RPC.

        Args:
        - org_id (int): sent as `government_id`.
        - agent_ids (list[int]): ids of the agents.
        - incomes (list[float]): incomes sent alongside `agent_ids`.
        - enable_redistribution (bool): forwarded to the request.

        Returns:
        - tuple[float, list]: `(taxes_due, updated_incomes)` from the response.
        """
        request = org_service.CalculateTaxesDueRequest(
            government_id=org_id,
            agent_ids=agent_ids,
            incomes=incomes,
            enable_redistribution=enable_redistribution,
        )
        response: org_service.CalculateTaxesDueResponse = (
            await self._aio_stub.CalculateTaxesDue(request)
        )
        return (float(response.taxes_due), list(response.updated_incomes))

    async def calculate_consumption(
        self, org_id: int, agent_ids: list[int], demands: list[int]
    ):
        """
        Call the `CalculateConsumption` RPC.

        Args:
        - org_id (int): sent as `firm_id`.
        - agent_ids (list[int]): ids of the agents.
        - demands (list[int]): demands sent alongside `agent_ids`.

        Returns:
        - tuple[int, list]: `(remain_inventory, updated_currencies)` from the response.
        """
        request = org_service.CalculateConsumptionRequest(
            firm_id=org_id,
            agent_ids=agent_ids,
            demands=demands,
        )
        response: org_service.CalculateConsumptionResponse = (
            await self._aio_stub.CalculateConsumption(request)
        )
        return (int(response.remain_inventory), list(response.updated_currencies))

    async def calculate_interest(self, org_id: int, agent_ids: list[int]):
        """
        Call the `CalculateInterest` RPC.

        Args:
        - org_id (int): sent as `bank_id`.
        - agent_ids (list[int]): ids of the agents.

        Returns:
        - tuple[float, list]: `(total_interest, updated_currencies)` from the response.
        """
        request = org_service.CalculateInterestRequest(
            bank_id=org_id,
            agent_ids=agent_ids,
        )
        response: org_service.CalculateInterestResponse = (
            await self._aio_stub.CalculateInterest(request)
        )
        return (float(response.total_interest), list(response.updated_currencies))

    async def remove_agents(self, agent_ids: Union[int, list[int]]):
        """
        Remove one or more agents, issuing the `RemoveAgent` RPCs concurrently.

        Args:
        - agent_ids (Union[int, list[int]]): a single id or a list of ids.
        """
        if isinstance(agent_ids, int):
            agent_ids = [agent_ids]
        tasks = [
            self._aio_stub.RemoveAgent(
                org_service.RemoveAgentRequest(agent_id=agent_id)
            )
            for agent_id in agent_ids
        ]
        await asyncio.gather(*tasks)

    async def remove_orgs(self, org_ids: Union[int, list[int]]):
        """
        Remove one or more orgs, issuing the `RemoveOrg` RPCs concurrently.

        Args:
        - org_ids (Union[int, list[int]]): a single id or a list of ids.
        """
        if isinstance(org_ids, int):
            org_ids = [org_ids]
        tasks = [
            self._aio_stub.RemoveOrg(org_service.RemoveOrgRequest(org_id=org_id))
            for org_id in org_ids
        ]
        await asyncio.gather(*tasks)

    async def save(self, file_path: str) -> tuple[list[int], list[int]]:
        """
        Call the `SaveEconomyEntities` RPC with the given file path.

        Args:
        - file_path (str): forwarded to the request.

        Returns:
        - tuple[list[int], list[int]]: current agent ids and org ids.
        """
        request = org_service.SaveEconomyEntitiesRequest(
            file_path=file_path,
        )
        response: org_service.SaveEconomyEntitiesResponse = (
            await self._aio_stub.SaveEconomyEntities(request)
        )
        # current agent ids and org ids
        return (list(response.agent_ids), list(response.org_ids))

    async def load(self, file_path: str):
        """
        Call the `LoadEconomyEntities` RPC with the given file path.

        Args:
        - file_path (str): forwarded to the request.

        Returns:
        - tuple[list[int], list[int]]: current agent ids and org ids.
        """
        request = org_service.LoadEconomyEntitiesRequest(
            file_path=file_path,
        )
        response: org_service.LoadEconomyEntitiesResponse = (
            await self._aio_stub.LoadEconomyEntities(request)
        )
        # current agent ids and org ids
        return (list(response.agent_ids), list(response.org_ids))
|
@@ -0,0 +1,141 @@
|
|
1
|
+
"""环境相关的Interaction定义"""
|
2
|
+
from enum import Enum
|
3
|
+
from typing import Callable, Optional, Any
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from typing import Callable, Any
|
6
|
+
|
7
|
+
class ActionType(Enum):
    """
    Action Type enumeration, all actions are essentially data push

    Types:
    - Sim = 1, an action that interfaces with the simulator
    - Hub = 2, an action that interfaces with AppHub (the frontend)
    - Comp = 3, a composite type (may interact with both Sim and Hub)
    """

    Sim = 1   # simulator-facing action
    Hub = 2   # AppHub (frontend)-facing action
    Comp = 3  # composite: may involve both Sim and Hub
|
20
|
+
|
21
|
+
class Action:
    """
    Base class of the actions an agent performs against its environment.

    Subclasses are expected to provide `Forward`.
    """

    def __init__(
        self,
        agent,
        type: ActionType,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        """
        Default initialization

        Args:
        - agent (Agent): the related agent
        - type (ActionType)
        - source (str): data source key. Defaults to None. When None, `Forward`
          works on the data passed in by the caller; otherwise the data is
          read from the working-memory Reason cache under this key.
        - before (function): preprocessing function, used when the data held
          in the Reason cache does not match the expected format.
        """
        self._agent = agent
        self._type = type
        self._source = source
        self._before = before

    def get_source(self):
        """
        Fetch the data referenced by `source`.

        Returns:
        - the cached Reason entry (passed through `before` when one was
          given), or None when no `source` was configured.
        """
        if self._source is None:
            return None
        source = self._agent.Brain.Memory.Working.Reason[self._source]
        if self._before is not None:
            source = self._before(source)
        return source

    # NOTE: this class does not derive from ABC, so @abstractmethod is only a
    # marker here; instantiation is not blocked.
    @abstractmethod
    async def Forward(self):
        """Interface method to be provided by subclasses."""
|
55
|
+
|
56
|
+
class SimAction(Action):
    """SimAction: an Action associated with the simulator (ActionType.Sim)."""

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        # Same as the base initializer with the action type fixed to Sim.
        super().__init__(agent, type=ActionType.Sim, source=source, before=before)
|
60
|
+
|
61
|
+
class HubAction(Action):
    """HubAction: an Action associated with AppHub (ActionType.Hub)."""

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        # Same as the base initializer with the action type fixed to Hub.
        super().__init__(agent, type=ActionType.Hub, source=source, before=before)
|
65
|
+
|
66
|
+
|
67
|
+
class SetSchedule(SimAction):
    """
    Synchronize agent's schedule to simulator —— only available for citizen type of agent
    """

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        super().__init__(agent, source, before)

    async def _sync_to_simulator(self, schedule) -> None:
        """Build the single-trip schedule request and send it to the simulator."""
        # NOTE(review): the flag is set on the agent's current scheduler entry,
        # not on the `schedule` argument — confirm they are the same object.
        self._agent.Scheduler.now.is_set = True
        departure_time = schedule.time
        mode = schedule.mode
        aoi_id = schedule.target_id_aoi
        poi_id = schedule.target_id_poi
        end = {'aoi_position': {'aoi_id': aoi_id, 'poi_id': poi_id}}
        activity = schedule.description
        trips = [{'mode': mode, 'end': end, 'departure_time': departure_time, 'activity': activity}]
        set_schedule = [{'trips': trips, 'loop_count': 1, 'departure_time': departure_time}]

        # * interface with the simulator
        req = {'person_id': self._agent._id, 'schedules': set_schedule}
        await self._agent._client.person_service.SetSchedule(req)

    async def Forward(self, schedule = None):
        """
        If current schedule has been synchronized to simulator: skip, else sync

        Args:
        - schedule: the schedule to synchronize. When None and a `source` was
          configured, the schedule is taken from `get_source()` instead.
        """
        if schedule is None:
            if self._source is None:
                return
            schedule = self.get_source()
        # Nothing to do without a schedule, or when it is already marked as set.
        if schedule is None or schedule.is_set:
            return
        await self._sync_to_simulator(schedule)
|
113
|
+
|
114
|
+
|
115
|
+
class SendAgentMessage(SimAction):
    """
    Send messages to other agents
    """

    def __init__(
        self,
        agent,
        source: Optional[str] = None,
        before: Optional[Callable[[list], Any]] = None,
    ) -> None:
        super().__init__(agent, source, before)

    async def _send(self, messages) -> None:
        """Wrap each message with this agent's id as sender and send the batch."""
        from_id = self._agent._id
        req = {
            'messages': [
                {'from': from_id, 'to': message['id'], 'message': message['message']}
                for message in messages
            ]
        }
        await self._agent._client.social_service.Send(req)

    async def Forward(self, messages: Optional[list] = None):
        """
        Send the given messages, or the ones read from `source`.

        Args:
        - messages: iterable of dicts, each with an 'id' (recipient) and a
          'message' entry. When it is None or empty and a `source` was
          configured, the messages are taken from `get_source()` instead.
        """
        nothing_given = messages is None or len(messages) == 0
        if nothing_given and self._source is not None:
            messages = self.get_source()
        if messages is None or len(messages) == 0:
            return
        await self._send(messages)
|
File without changes
|
@@ -0,0 +1,109 @@
|
|
1
|
+
import logging
|
2
|
+
from time import sleep
|
3
|
+
from typing import cast
|
4
|
+
|
5
|
+
import grpc
|
6
|
+
from pycityproto.city.sync.v2 import sync_service_pb2 as sync_service
|
7
|
+
from pycityproto.city.sync.v2 import sync_service_pb2_grpc as sync_grpc
|
8
|
+
|
9
|
+
from ..utils.grpc import create_channel
|
10
|
+
|
11
|
+
__all__ = ["OnlyClientSidecar"]
|
12
|
+
|
13
|
+
|
14
|
+
class OnlyClientSidecar:
    """
    Sidecar framework service (only supported as a client, does not support external gRPC services)
    """

    def __init__(self, name: str, syncer_address: str, secure: bool = False):
        """
        Args:
        - name (str): The registered name of this service on etcd.
        - syncer_address (str): syncer address.
        - secure (bool, optional): Whether to use a secure connection. Defaults to False.
        """
        self._name = name
        self._sync_stub = sync_grpc.SyncServiceStub(
            create_channel(syncer_address, secure)
        )

    def wait_url(self, name: str) -> str:
        """
        Get the uri of the service, retrying once per second until the call succeeds.

        Args:
        - name (str): Service registration name.

        Returns:
        - str: service url.
        """
        while True:
            try:
                resp = cast(
                    sync_service.GetURLResponse,
                    self._sync_stub.GetURL(sync_service.GetURLRequest(name=name)),
                )
            except grpc.RpcError as e:
                logging.warning("get uri failed, retrying..., %s", e)
                sleep(1)
            else:
                url = resp.url
                break

        logging.debug("get uri: %s for name=%s", url, name)
        return url

    def step(self, close: bool = False) -> bool:
        """
        synchronizer step up

        Sends `EnterStepSync` followed by `ExitStepSync` under this
        service's registered name.

        Args:
        - close (bool): Whether the simulation exited.

        Returns:
        - close (bool): Whether the simulation exited.
        """
        enter_req = sync_service.EnterStepSyncRequest(name=self._name)
        self._sync_stub.EnterStepSync(enter_req)
        exit_req = sync_service.ExitStepSyncRequest(name=self._name, close=close)
        return self._sync_stub.ExitStepSync(exit_req).close

    def init(self) -> bool:
        """
        Synchronizer initialization (performs one `step()` with default arguments).

        Returns:
        - close (bool): Whether the simulation exited.

        Examples:
        ```python
        close = client.init()
        print(close)
        # > False
        ```
        """
        return self.step()

    def close(self) -> bool:
        """
        Synchronizer close (performs one `step(True)`).

        Returns:
        - close (bool): Whether the simulation exited.
        """
        return self.step(True)

    def notify_step_ready(self):
        """
        Notify that the prepare phase is completed

        Currently a no-op.
        """
        ...
|
@@ -0,0 +1,27 @@
|
|
1
|
+
"""
|
2
|
+
模拟器gRPC接入客户端
|
3
|
+
Simulator gRPC access client
|
4
|
+
"""
|
5
|
+
|
6
|
+
from .person_service import PersonService
|
7
|
+
from .aoi_service import AoiService
|
8
|
+
from .client import CityClient
|
9
|
+
from .clock_service import ClockService
|
10
|
+
from .economy_services import EconomyOrgService, EconomyPersonService
|
11
|
+
from .lane_service import LaneService
|
12
|
+
from .road_service import RoadService
|
13
|
+
from .social_service import SocialService
|
14
|
+
from .light_service import LightService
|
15
|
+
|
16
|
+
__all__ = [
|
17
|
+
"CityClient",
|
18
|
+
"ClockService",
|
19
|
+
"PersonService",
|
20
|
+
"AoiService",
|
21
|
+
"LaneService",
|
22
|
+
"RoadService",
|
23
|
+
"SocialService",
|
24
|
+
"EconomyPersonService",
|
25
|
+
"EconomyOrgService",
|
26
|
+
"LightService",
|
27
|
+
]
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from typing import Any, Awaitable, Coroutine, cast, Union, Dict
|
2
|
+
|
3
|
+
import grpc
|
4
|
+
from google.protobuf.json_format import ParseDict
|
5
|
+
from pycityproto.city.map.v2 import aoi_service_pb2 as aoi_service
|
6
|
+
from pycityproto.city.map.v2 import aoi_service_pb2_grpc as aoi_grpc
|
7
|
+
|
8
|
+
from ..utils.protobuf import async_parse
|
9
|
+
|
10
|
+
__all__ = ["AoiService"]
|
11
|
+
|
12
|
+
|
13
|
+
class AoiService:
    """
    AOI service
    """

    def __init__(self, aio_channel: grpc.aio.Channel):
        """
        Args:
        - aio_channel (grpc.aio.Channel): asynchronous channel used to build the AOI stub.
        """
        self._aio_stub = aoi_grpc.AoiServiceStub(aio_channel)

    def GetAoi(
        self, req: Union[aoi_service.GetAoiRequest, dict], dict_return: bool = True
    ) -> Coroutine[Any, Any, Union[Dict[str, Any], aoi_service.GetAoiResponse]]:
        """
        get AOI information

        Args:
        - req (dict): https://cityproto.sim.fiblab.net/#city.map.v2.GetAoiRequest
        - dict_return (bool): forwarded to `async_parse`; presumably selects a
          dict result instead of the raw response — confirm against utils.protobuf.

        Returns:
        - https://cityproto.sim.fiblab.net/#city.map.v2.GetAoiResponse
        """
        # isinstance (rather than an exact type comparison) so that a request
        # message subclass is not passed to ParseDict as if it were a dict.
        if not isinstance(req, aoi_service.GetAoiRequest):
            req = ParseDict(req, aoi_service.GetAoiRequest())
        res = cast(Awaitable[aoi_service.GetAoiResponse], self._aio_stub.GetAoi(req))
        return async_parse(res, dict_return)
|