pycityagent 2.0.0a65__cp310-cp310-macosx_11_0_arm64.whl → 2.0.0a67__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +157 -57
- pycityagent/agent/agent_base.py +316 -43
- pycityagent/cityagent/bankagent.py +49 -9
- pycityagent/cityagent/blocks/__init__.py +1 -2
- pycityagent/cityagent/blocks/cognition_block.py +54 -31
- pycityagent/cityagent/blocks/dispatcher.py +22 -17
- pycityagent/cityagent/blocks/economy_block.py +46 -32
- pycityagent/cityagent/blocks/mobility_block.py +209 -105
- pycityagent/cityagent/blocks/needs_block.py +101 -54
- pycityagent/cityagent/blocks/other_block.py +42 -33
- pycityagent/cityagent/blocks/plan_block.py +59 -42
- pycityagent/cityagent/blocks/social_block.py +167 -126
- pycityagent/cityagent/blocks/utils.py +13 -6
- pycityagent/cityagent/firmagent.py +17 -35
- pycityagent/cityagent/governmentagent.py +3 -3
- pycityagent/cityagent/initial.py +79 -49
- pycityagent/cityagent/memory_config.py +123 -94
- pycityagent/cityagent/message_intercept.py +0 -4
- pycityagent/cityagent/metrics.py +41 -0
- pycityagent/cityagent/nbsagent.py +24 -36
- pycityagent/cityagent/societyagent.py +9 -4
- pycityagent/cli/wrapper.py +2 -2
- pycityagent/economy/econ_client.py +407 -81
- pycityagent/environment/__init__.py +0 -3
- pycityagent/environment/sim/__init__.py +0 -3
- pycityagent/environment/sim/aoi_service.py +2 -2
- pycityagent/environment/sim/client.py +3 -31
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/lane_service.py +8 -8
- pycityagent/environment/sim/light_service.py +8 -8
- pycityagent/environment/sim/pause_service.py +9 -10
- pycityagent/environment/sim/person_service.py +20 -20
- pycityagent/environment/sim/road_service.py +2 -2
- pycityagent/environment/sim/sim_env.py +21 -5
- pycityagent/environment/sim/social_service.py +4 -4
- pycityagent/environment/simulator.py +249 -27
- pycityagent/environment/utils/__init__.py +2 -2
- pycityagent/environment/utils/geojson.py +2 -2
- pycityagent/environment/utils/grpc.py +4 -4
- pycityagent/environment/utils/map_utils.py +2 -2
- pycityagent/llm/embeddings.py +147 -28
- pycityagent/llm/llm.py +178 -111
- pycityagent/llm/llmconfig.py +5 -0
- pycityagent/llm/utils.py +4 -0
- pycityagent/memory/__init__.py +0 -4
- pycityagent/memory/const.py +2 -2
- pycityagent/memory/faiss_query.py +140 -61
- pycityagent/memory/memory.py +394 -91
- pycityagent/memory/memory_base.py +140 -34
- pycityagent/memory/profile.py +13 -13
- pycityagent/memory/self_define.py +13 -13
- pycityagent/memory/state.py +14 -14
- pycityagent/message/message_interceptor.py +253 -3
- pycityagent/message/messager.py +133 -6
- pycityagent/metrics/mlflow_client.py +47 -4
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +3 -2
- pycityagent/simulation/agentgroup.py +150 -54
- pycityagent/simulation/simulation.py +276 -66
- pycityagent/survey/manager.py +45 -3
- pycityagent/survey/models.py +42 -2
- pycityagent/tools/__init__.py +1 -2
- pycityagent/tools/tool.py +93 -69
- pycityagent/utils/avro_schema.py +2 -2
- pycityagent/utils/parsers/code_block_parser.py +1 -1
- pycityagent/utils/parsers/json_parser.py +2 -2
- pycityagent/utils/parsers/parser_base.py +2 -2
- pycityagent/workflow/block.py +64 -13
- pycityagent/workflow/prompt.py +31 -23
- pycityagent/workflow/trigger.py +91 -24
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/METADATA +2 -2
- pycityagent-2.0.0a67.dist-info/RECORD +97 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +0 -198
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +0 -416
- pycityagent/environment/sidecar/__init__.py +0 -8
- pycityagent/environment/sidecar/sidecarv2.py +0 -109
- pycityagent/environment/sim/economy_services.py +0 -192
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent-2.0.0a65.dist-info/RECORD +0 -105
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
|
+
import re
|
4
|
+
import time
|
3
5
|
from typing import Any, Literal, Union
|
4
6
|
|
5
7
|
import grpc
|
@@ -7,6 +9,9 @@ import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
|
7
9
|
import pycityproto.city.economy.v2.org_service_pb2 as org_service
|
8
10
|
import pycityproto.city.economy.v2.org_service_pb2_grpc as org_grpc
|
9
11
|
from google.protobuf import descriptor
|
12
|
+
from google.protobuf.json_format import MessageToDict
|
13
|
+
|
14
|
+
logger = logging.getLogger("pycityagent")
|
10
15
|
|
11
16
|
__all__ = [
|
12
17
|
"EconomyClient",
|
@@ -22,33 +27,30 @@ def _snake_to_pascal(snake_str):
|
|
22
27
|
_res = _res.replace(_word, _word.upper())
|
23
28
|
return _res
|
24
29
|
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
_type_mapping = {
|
31
|
-
descriptor.FieldDescriptor.TYPE_FLOAT: float,
|
32
|
-
descriptor.FieldDescriptor.TYPE_INT32: int,
|
33
|
-
}
|
34
|
-
is_repeated = (
|
35
|
-
field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED
|
36
|
-
)
|
37
|
-
return (_type_mapping.get(field_type), is_repeated)
|
38
|
-
except KeyError:
|
39
|
-
raise KeyError(f"Invalid message {message} and filed name {field_name}!")
|
40
|
-
|
30
|
+
def camel_to_snake(d):
|
31
|
+
if not isinstance(d, dict):
|
32
|
+
return d
|
33
|
+
return {re.sub('([a-z0-9])([A-Z])', r'\1_\2', k).lower(): camel_to_snake(v) if isinstance(v, dict) else v
|
34
|
+
for k, v in d.items()}
|
41
35
|
|
42
36
|
def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.Channel:
|
43
37
|
"""
|
44
|
-
Create a
|
38
|
+
Create a gRPC asynchronous channel.
|
39
|
+
|
40
|
+
- **Args**:
|
41
|
+
- `server_address` (`str`): The address of the server to connect to.
|
42
|
+
- `secure` (`bool`, optional): Whether to use a secure connection. Defaults to `False`.
|
43
|
+
|
44
|
+
- **Returns**:
|
45
|
+
- `grpc.aio.Channel`: A gRPC asynchronous channel for making RPC calls.
|
45
46
|
|
46
|
-
|
47
|
-
|
48
|
-
- secure (bool, optional): Defaults to False. Whether to use a secure connection. Defaults to False.
|
47
|
+
- **Raises**:
|
48
|
+
- `ValueError`: If a secure channel is requested but the server address starts with `http://`.
|
49
49
|
|
50
|
-
|
51
|
-
|
50
|
+
- **Description**:
|
51
|
+
- This function creates and returns a gRPC asynchronous channel based on the provided server address and security flag.
|
52
|
+
- It ensures that if `secure=True`, then the server address does not start with `http://`.
|
53
|
+
- If the server address starts with `https://`, it will automatically switch to a secure connection even if `secure=False`.
|
52
54
|
"""
|
53
55
|
if server_address.startswith("http://"):
|
54
56
|
server_address = server_address.split("//")[1]
|
@@ -67,21 +69,44 @@ def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.C
|
|
67
69
|
|
68
70
|
class EconomyClient:
|
69
71
|
"""
|
70
|
-
Client side of Economy service
|
72
|
+
Client side of Economy service.
|
73
|
+
|
74
|
+
- **Description**:
|
75
|
+
- This class serves as a client interface to interact with the Economy Simulator via gRPC.
|
76
|
+
- It establishes an asynchronous connection and provides methods to communicate with the service.
|
71
77
|
"""
|
72
78
|
|
73
79
|
def __init__(self, server_address: str, secure: bool = False):
|
74
80
|
"""
|
75
|
-
|
81
|
+
Initialize the EconomyClient.
|
82
|
+
|
83
|
+
- **Args**:
|
84
|
+
- `server_address` (`str`): The address of the Economy server to connect to.
|
85
|
+
- `secure` (`bool`, optional): Whether to use a secure connection. Defaults to `False`.
|
86
|
+
|
87
|
+
- **Attributes**:
|
88
|
+
- `server_address` (`str`): The address of the Economy server.
|
89
|
+
- `secure` (`bool`): A flag indicating if a secure connection should be used.
|
90
|
+
- `_aio_stub` (`OrgServiceStub`): A gRPC stub used to make remote calls to the Economy service.
|
76
91
|
|
77
|
-
|
78
|
-
|
79
|
-
|
92
|
+
- **Description**:
|
93
|
+
- Initializes the EconomyClient with the specified server address and security preference.
|
94
|
+
- Creates an asynchronous gRPC channel using `_create_aio_channel`.
|
95
|
+
- Instantiates a gRPC stub (`_aio_stub`) for interacting with the Economy service.
|
80
96
|
"""
|
81
97
|
self.server_address = server_address
|
82
98
|
self.secure = secure
|
83
99
|
aio_channel = _create_aio_channel(server_address, secure)
|
84
100
|
self._aio_stub = org_grpc.OrgServiceStub(aio_channel)
|
101
|
+
self._agent_ids = set()
|
102
|
+
self._org_ids = set()
|
103
|
+
self._log_list = []
|
104
|
+
|
105
|
+
def get_log_list(self):
|
106
|
+
return self._log_list
|
107
|
+
|
108
|
+
def clear_log_list(self):
|
109
|
+
self._log_list = []
|
85
110
|
|
86
111
|
def __getstate__(self):
|
87
112
|
"""
|
@@ -95,7 +120,7 @@ class EconomyClient:
|
|
95
120
|
return state
|
96
121
|
|
97
122
|
def __setstate__(self, state):
|
98
|
-
"""
|
123
|
+
"""
|
99
124
|
Restore instance attributes (i.e., filename and mode) from the
|
100
125
|
unpickled state dictionary.
|
101
126
|
"""
|
@@ -104,6 +129,71 @@ class EconomyClient:
|
|
104
129
|
aio_channel = _create_aio_channel(self.server_address, self.secure)
|
105
130
|
self._aio_stub = org_grpc.OrgServiceStub(aio_channel)
|
106
131
|
|
132
|
+
async def get_ids(self):
|
133
|
+
"""
|
134
|
+
Get the ids of agents and orgs
|
135
|
+
"""
|
136
|
+
return self._agent_ids, self._org_ids
|
137
|
+
|
138
|
+
async def set_ids(self, agent_ids: set[int], org_ids: set[int]):
|
139
|
+
"""
|
140
|
+
Set the ids of agents and orgs
|
141
|
+
"""
|
142
|
+
self._agent_ids = agent_ids
|
143
|
+
self._org_ids = org_ids
|
144
|
+
|
145
|
+
async def get_agent(self, id: int) -> economyv2.Agent:
|
146
|
+
"""
|
147
|
+
Get agent by id
|
148
|
+
|
149
|
+
- **Args**:
|
150
|
+
- `id` (`int`): The id of the agent.
|
151
|
+
|
152
|
+
- **Returns**:
|
153
|
+
- `economyv2.Agent`: The agent object.
|
154
|
+
"""
|
155
|
+
start_time = time.time()
|
156
|
+
log = {
|
157
|
+
"req": "get_agent",
|
158
|
+
"start_time": start_time,
|
159
|
+
"consumption": 0
|
160
|
+
}
|
161
|
+
agent = await self._aio_stub.GetAgent(
|
162
|
+
org_service.GetAgentRequest(
|
163
|
+
agent_id=id
|
164
|
+
)
|
165
|
+
)
|
166
|
+
agent_dict = MessageToDict(agent)["agent"]
|
167
|
+
log["consumption"] = time.time() - start_time
|
168
|
+
self._log_list.append(log)
|
169
|
+
return camel_to_snake(agent_dict)
|
170
|
+
|
171
|
+
async def get_org(self, id: int) -> economyv2.Org:
|
172
|
+
"""
|
173
|
+
Get org by id
|
174
|
+
|
175
|
+
- **Args**:
|
176
|
+
- `id` (`int`): The id of the org.
|
177
|
+
|
178
|
+
- **Returns**:
|
179
|
+
- `economyv2.Org`: The org object.
|
180
|
+
"""
|
181
|
+
start_time = time.time()
|
182
|
+
log = {
|
183
|
+
"req": "get_org",
|
184
|
+
"start_time": start_time,
|
185
|
+
"consumption": 0
|
186
|
+
}
|
187
|
+
org = await self._aio_stub.GetOrg(
|
188
|
+
org_service.GetOrgRequest(
|
189
|
+
org_id=id
|
190
|
+
)
|
191
|
+
)
|
192
|
+
org_dict = MessageToDict(org)["org"]
|
193
|
+
log["consumption"] = time.time() - start_time
|
194
|
+
self._log_list.append(log)
|
195
|
+
return camel_to_snake(org_dict)
|
196
|
+
|
107
197
|
async def get(
|
108
198
|
self,
|
109
199
|
id: int,
|
@@ -112,22 +202,29 @@ class EconomyClient:
|
|
112
202
|
"""
|
113
203
|
Get specific value
|
114
204
|
|
115
|
-
Args
|
116
|
-
|
117
|
-
|
205
|
+
- **Args**:
|
206
|
+
- `id` (`int`): The id of `Org` or `Agent`.
|
207
|
+
- `key` (`str`): The attribute to fetch.
|
118
208
|
|
119
|
-
Returns
|
120
|
-
|
209
|
+
- **Returns**:
|
210
|
+
- Any
|
121
211
|
"""
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
212
|
+
start_time = time.time()
|
213
|
+
log = {
|
214
|
+
"req": "get",
|
215
|
+
"start_time": start_time,
|
216
|
+
"consumption": 0
|
217
|
+
}
|
218
|
+
if id not in self._agent_ids and id not in self._org_ids:
|
219
|
+
raise ValueError(f"Invalid id {id}, this id does not exist!")
|
220
|
+
request_type = "Org" if id in self._org_ids else "Agent"
|
221
|
+
if request_type == "Org":
|
222
|
+
response = await self.get_org(id)
|
129
223
|
else:
|
130
|
-
|
224
|
+
response = await self.get_agent(id)
|
225
|
+
log["consumption"] = time.time() - start_time
|
226
|
+
self._log_list.append(log)
|
227
|
+
return response[key]
|
131
228
|
|
132
229
|
async def update(
|
133
230
|
self,
|
@@ -139,68 +236,141 @@ class EconomyClient:
|
|
139
236
|
"""
|
140
237
|
Update key-value pair
|
141
238
|
|
142
|
-
Args
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
239
|
+
- **Args**:
|
240
|
+
- `id` (`int`): The id of `Org` or `Agent`.
|
241
|
+
- `key` (`str`): The attribute to update.
|
242
|
+
- `mode` (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
|
147
243
|
|
148
|
-
Returns
|
149
|
-
|
244
|
+
- **Returns**:
|
245
|
+
- Any
|
150
246
|
"""
|
151
|
-
|
152
|
-
|
153
|
-
|
247
|
+
start_time = time.time()
|
248
|
+
log = {
|
249
|
+
"req": "update",
|
250
|
+
"start_time": start_time,
|
251
|
+
"consumption": 0
|
252
|
+
}
|
253
|
+
if id not in self._agent_ids and id not in self._org_ids:
|
254
|
+
raise ValueError(f"Invalid id {id}, this id does not exist!")
|
255
|
+
request_type = "Org" if id in self._org_ids else "Agent"
|
256
|
+
if request_type == "Org":
|
257
|
+
original_value = await self.get_org(id)
|
258
|
+
else:
|
259
|
+
original_value = await self.get_agent(id)
|
154
260
|
if mode == "merge":
|
155
|
-
orig_value =
|
261
|
+
orig_value = original_value[key]
|
156
262
|
_orig_type = type(orig_value)
|
157
263
|
_new_type = type(value)
|
158
264
|
if _orig_type != _new_type:
|
159
|
-
|
265
|
+
logger.debug(
|
160
266
|
f"Inconsistent type of original value {_orig_type.__name__} and to-update value {_new_type.__name__}"
|
161
267
|
)
|
162
268
|
else:
|
163
269
|
if isinstance(orig_value, set):
|
164
270
|
orig_value.update(set(value))
|
165
|
-
|
271
|
+
original_value[key] = orig_value
|
166
272
|
elif isinstance(orig_value, dict):
|
167
273
|
orig_value.update(dict(value))
|
168
|
-
|
274
|
+
original_value[key] = orig_value
|
169
275
|
elif isinstance(orig_value, list):
|
170
276
|
orig_value.extend(list(value))
|
171
|
-
|
277
|
+
original_value[key] = orig_value
|
172
278
|
else:
|
173
|
-
|
279
|
+
logger.warning(
|
174
280
|
f"Type of {type(orig_value)} does not support mode `merge`, using `replace` instead!"
|
175
281
|
)
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
282
|
+
else:
|
283
|
+
original_value[key] = value
|
284
|
+
if request_type == "Org":
|
285
|
+
await self._aio_stub.UpdateOrg(
|
286
|
+
org_service.UpdateOrgRequest(
|
287
|
+
org=original_value
|
288
|
+
)
|
182
289
|
)
|
183
|
-
|
290
|
+
log["consumption"] = time.time() - start_time
|
291
|
+
self._log_list.append(log)
|
292
|
+
else:
|
293
|
+
await self._aio_stub.UpdateAgent(
|
294
|
+
org_service.UpdateAgentRequest(
|
295
|
+
agent=original_value
|
296
|
+
)
|
297
|
+
)
|
298
|
+
log["consumption"] = time.time() - start_time
|
299
|
+
self._log_list.append(log)
|
184
300
|
|
185
301
|
async def add_agents(self, configs: Union[list[dict], dict]):
|
302
|
+
"""
|
303
|
+
Add one or more agents to the economy system.
|
304
|
+
|
305
|
+
- **Args**:
|
306
|
+
- `configs` (`Union[list[dict], dict]`): A single configuration dictionary or a list of dictionaries,
|
307
|
+
each containing the necessary information to create an agent (e.g., id, currency).
|
308
|
+
|
309
|
+
- **Returns**:
|
310
|
+
- The method does not explicitly return any value but gathers the responses from adding each agent.
|
311
|
+
|
312
|
+
- **Description**:
|
313
|
+
- If a single configuration dictionary is provided, it is converted into a list.
|
314
|
+
- For each configuration in the list, a task is created to asynchronously add an agent using the provided configuration.
|
315
|
+
- All tasks are executed concurrently, and their results are gathered and returned.
|
316
|
+
"""
|
317
|
+
start_time = time.time()
|
318
|
+
log = {
|
319
|
+
"req": "add_agents",
|
320
|
+
"start_time": start_time,
|
321
|
+
"consumption": 0
|
322
|
+
}
|
186
323
|
if isinstance(configs, dict):
|
187
324
|
configs = [configs]
|
325
|
+
for config in configs:
|
326
|
+
self._agent_ids.add(config["id"])
|
188
327
|
tasks = [
|
189
328
|
self._aio_stub.AddAgent(
|
190
329
|
org_service.AddAgentRequest(
|
191
330
|
agent=economyv2.Agent(
|
192
331
|
id=config["id"],
|
193
332
|
currency=config.get("currency", 0.0),
|
333
|
+
skill=config.get("skill", 0.0),
|
334
|
+
consumption=config.get("consumption", 0.0),
|
335
|
+
income=config.get("income", 0.0),
|
194
336
|
)
|
195
337
|
)
|
196
338
|
)
|
197
339
|
for config in configs
|
198
340
|
]
|
199
|
-
|
341
|
+
await asyncio.gather(*tasks)
|
342
|
+
log["consumption"] = time.time() - start_time
|
343
|
+
self._log_list.append(log)
|
200
344
|
|
201
345
|
async def add_orgs(self, configs: Union[list[dict], dict]):
|
346
|
+
"""
|
347
|
+
Add one or more organizations to the economy system.
|
348
|
+
|
349
|
+
- **Args**:
|
350
|
+
- `configs` (`Union[List[Dict], Dict]`): A single configuration dictionary or a list of dictionaries,
|
351
|
+
each containing the necessary information to create an organization (e.g., id, type, nominal_gdp, etc.).
|
352
|
+
|
353
|
+
- **Returns**:
|
354
|
+
- `List`: A list of responses from adding each organization.
|
355
|
+
|
356
|
+
- **Raises**:
|
357
|
+
- `KeyError`: If a required field is missing from the config dictionary.
|
358
|
+
|
359
|
+
- **Description**:
|
360
|
+
- Ensures `configs` is always a list, even if only one config is provided.
|
361
|
+
- For each configuration in the list, creates a task to asynchronously add an organization using the provided configuration.
|
362
|
+
- Executes all tasks concurrently and gathers their results.
|
363
|
+
"""
|
364
|
+
start_time = time.time()
|
365
|
+
log = {
|
366
|
+
"req": "add_orgs",
|
367
|
+
"start_time": start_time,
|
368
|
+
"consumption": 0
|
369
|
+
}
|
202
370
|
if isinstance(configs, dict):
|
203
371
|
configs = [configs]
|
372
|
+
for config in configs:
|
373
|
+
self._org_ids.add(config["id"])
|
204
374
|
tasks = [
|
205
375
|
self._aio_stub.AddOrg(
|
206
376
|
org_service.AddOrgRequest(
|
@@ -218,20 +388,48 @@ class EconomyClient:
|
|
218
388
|
interest_rate=config.get("interest_rate", 0.0),
|
219
389
|
bracket_cutoffs=config.get("bracket_cutoffs", []),
|
220
390
|
bracket_rates=config.get("bracket_rates", []),
|
391
|
+
consumption_currency=config.get("consumption_currency", []),
|
392
|
+
consumption_propensity=config.get("consumption_propensity", []),
|
393
|
+
income_currency=config.get("income_currency", []),
|
394
|
+
depression=config.get("depression", []),
|
395
|
+
locus_control=config.get("locus_control", []),
|
396
|
+
working_hours=config.get("working_hours", []),
|
397
|
+
employees=config.get("employees", []),
|
398
|
+
citizens=config.get("citizens", []),
|
221
399
|
)
|
222
400
|
)
|
223
401
|
)
|
224
402
|
for config in configs
|
225
403
|
]
|
226
|
-
|
404
|
+
await asyncio.gather(*tasks)
|
405
|
+
log["consumption"] = time.time() - start_time
|
406
|
+
self._log_list.append(log)
|
227
407
|
|
228
408
|
async def calculate_taxes_due(
|
229
409
|
self,
|
230
|
-
org_id: int,
|
410
|
+
org_id: Union[int, list[int]],
|
231
411
|
agent_ids: list[int],
|
232
412
|
incomes: list[float],
|
233
413
|
enable_redistribution: bool,
|
234
414
|
):
|
415
|
+
"""
|
416
|
+
Calculate the taxes due for agents based on their incomes.
|
417
|
+
|
418
|
+
- **Args**:
|
419
|
+
- `org_id` (`int`): The ID of the government organization.
|
420
|
+
- `agent_ids` (`List[int]`): A list of IDs for the agents whose taxes are being calculated.
|
421
|
+
- `incomes` (`List[float]`): A list of income values corresponding to each agent.
|
422
|
+
- `enable_redistribution` (`bool`): Flag indicating whether redistribution is enabled.
|
423
|
+
|
424
|
+
- **Returns**:
|
425
|
+
- `Tuple[float, List[float]]`: A tuple containing the total taxes due and updated incomes after tax calculation.
|
426
|
+
"""
|
427
|
+
start_time = time.time()
|
428
|
+
log = {
|
429
|
+
"req": "calculate_taxes_due",
|
430
|
+
"start_time": start_time,
|
431
|
+
"consumption": 0
|
432
|
+
}
|
235
433
|
request = org_service.CalculateTaxesDueRequest(
|
236
434
|
government_id=org_id,
|
237
435
|
agent_ids=agent_ids,
|
@@ -241,22 +439,59 @@ class EconomyClient:
|
|
241
439
|
response: org_service.CalculateTaxesDueResponse = (
|
242
440
|
await self._aio_stub.CalculateTaxesDue(request)
|
243
441
|
)
|
442
|
+
log["consumption"] = time.time() - start_time
|
443
|
+
self._log_list.append(log)
|
244
444
|
return (float(response.taxes_due), list(response.updated_incomes))
|
245
445
|
|
246
446
|
async def calculate_consumption(
|
247
|
-
self,
|
447
|
+
self, org_ids: Union[int, list[int]], agent_id: int, demands: list[int]
|
248
448
|
):
|
449
|
+
"""
|
450
|
+
Calculate consumption for agents based on their demands.
|
451
|
+
|
452
|
+
- **Args**:
|
453
|
+
- `org_ids` (`Union[int, list[int]]`): The ID of the firm providing goods or services.
|
454
|
+
- `agent_id` (`int`): The ID of the agent whose consumption is being calculated.
|
455
|
+
- `demands` (`List[int]`): A list of demand quantities corresponding to each agent.
|
456
|
+
|
457
|
+
- **Returns**:
|
458
|
+
- `Tuple[int, List[float]]`: A tuple containing the remaining inventory and updated currencies for each agent.
|
459
|
+
"""
|
460
|
+
start_time = time.time()
|
461
|
+
log = {
|
462
|
+
"req": "calculate_consumption",
|
463
|
+
"start_time": start_time,
|
464
|
+
"consumption": 0
|
465
|
+
}
|
249
466
|
request = org_service.CalculateConsumptionRequest(
|
250
|
-
|
251
|
-
|
467
|
+
firm_ids=org_ids,
|
468
|
+
agent_id=agent_id,
|
252
469
|
demands=demands,
|
253
470
|
)
|
254
471
|
response: org_service.CalculateConsumptionResponse = (
|
255
472
|
await self._aio_stub.CalculateConsumption(request)
|
256
473
|
)
|
474
|
+
log["consumption"] = time.time() - start_time
|
475
|
+
self._log_list.append(log)
|
257
476
|
return (int(response.remain_inventory), list(response.updated_currencies))
|
258
477
|
|
259
478
|
async def calculate_interest(self, org_id: int, agent_ids: list[int]):
|
479
|
+
"""
|
480
|
+
Calculate interest for agents based on their accounts.
|
481
|
+
|
482
|
+
- **Args**:
|
483
|
+
- `org_id` (`int`): The ID of the bank.
|
484
|
+
- `agent_ids` (`List[int]`): A list of IDs for the agents whose interests are being calculated.
|
485
|
+
|
486
|
+
- **Returns**:
|
487
|
+
- `Tuple[float, List[float]]`: A tuple containing the total interest and updated currencies for each agent.
|
488
|
+
"""
|
489
|
+
start_time = time.time()
|
490
|
+
log = {
|
491
|
+
"req": "calculate_interest",
|
492
|
+
"start_time": start_time,
|
493
|
+
"consumption": 0
|
494
|
+
}
|
260
495
|
request = org_service.CalculateInterestRequest(
|
261
496
|
bank_id=org_id,
|
262
497
|
agent_ids=agent_ids,
|
@@ -264,9 +499,23 @@ class EconomyClient:
|
|
264
499
|
response: org_service.CalculateInterestResponse = (
|
265
500
|
await self._aio_stub.CalculateInterest(request)
|
266
501
|
)
|
502
|
+
log["consumption"] = time.time() - start_time
|
503
|
+
self._log_list.append(log)
|
267
504
|
return (float(response.total_interest), list(response.updated_currencies))
|
268
505
|
|
269
506
|
async def remove_agents(self, agent_ids: Union[int, list[int]]):
|
507
|
+
"""
|
508
|
+
Remove one or more agents from the system.
|
509
|
+
|
510
|
+
- **Args**:
|
511
|
+
- `org_ids` (`Union[int, List[int]]`): A single ID or a list of IDs for the agents to be removed.
|
512
|
+
"""
|
513
|
+
start_time = time.time()
|
514
|
+
log = {
|
515
|
+
"req": "remove_agents",
|
516
|
+
"start_time": start_time,
|
517
|
+
"consumption": 0
|
518
|
+
}
|
270
519
|
if isinstance(agent_ids, int):
|
271
520
|
agent_ids = [agent_ids]
|
272
521
|
tasks = [
|
@@ -275,18 +524,49 @@ class EconomyClient:
|
|
275
524
|
)
|
276
525
|
for agent_id in agent_ids
|
277
526
|
]
|
278
|
-
|
527
|
+
await asyncio.gather(*tasks)
|
528
|
+
log["consumption"] = time.time() - start_time
|
529
|
+
self._log_list.append(log)
|
279
530
|
|
280
531
|
async def remove_orgs(self, org_ids: Union[int, list[int]]):
|
532
|
+
"""
|
533
|
+
Remove one or more organizations from the system.
|
534
|
+
|
535
|
+
- **Args**:
|
536
|
+
- `org_ids` (`Union[int, List[int]]`): A single ID or a list of IDs for the organizations to be removed.
|
537
|
+
"""
|
538
|
+
start_time = time.time()
|
539
|
+
log = {
|
540
|
+
"req": "remove_orgs",
|
541
|
+
"start_time": start_time,
|
542
|
+
"consumption": 0
|
543
|
+
}
|
281
544
|
if isinstance(org_ids, int):
|
282
545
|
org_ids = [org_ids]
|
283
546
|
tasks = [
|
284
547
|
self._aio_stub.RemoveOrg(org_service.RemoveOrgRequest(org_id=org_id))
|
285
548
|
for org_id in org_ids
|
286
549
|
]
|
287
|
-
|
550
|
+
await asyncio.gather(*tasks)
|
551
|
+
log["consumption"] = time.time() - start_time
|
552
|
+
self._log_list.append(log)
|
288
553
|
|
289
554
|
async def save(self, file_path: str) -> tuple[list[int], list[int]]:
|
555
|
+
"""
|
556
|
+
Save the current state of all economy entities to a specified file.
|
557
|
+
|
558
|
+
- **Args**:
|
559
|
+
- `file_path` (`str`): The path to the file where the economy entities will be saved.
|
560
|
+
|
561
|
+
- **Returns**:
|
562
|
+
- `Tuple[List[int], List[int]]`: A tuple containing lists of agent IDs and organization IDs that were saved.
|
563
|
+
"""
|
564
|
+
start_time = time.time()
|
565
|
+
log = {
|
566
|
+
"req": "save",
|
567
|
+
"start_time": start_time,
|
568
|
+
"consumption": 0
|
569
|
+
}
|
290
570
|
request = org_service.SaveEconomyEntitiesRequest(
|
291
571
|
file_path=file_path,
|
292
572
|
)
|
@@ -294,9 +574,26 @@ class EconomyClient:
|
|
294
574
|
await self._aio_stub.SaveEconomyEntities(request)
|
295
575
|
)
|
296
576
|
# current agent ids and org ids
|
577
|
+
log["consumption"] = time.time() - start_time
|
578
|
+
self._log_list.append(log)
|
297
579
|
return (list(response.agent_ids), list(response.org_ids))
|
298
580
|
|
299
581
|
async def load(self, file_path: str):
|
582
|
+
"""
|
583
|
+
Load the state of economy entities from a specified file.
|
584
|
+
|
585
|
+
- **Args**:
|
586
|
+
- `file_path` (`str`): The path to the file from which the economy entities will be loaded.
|
587
|
+
|
588
|
+
- **Returns**:
|
589
|
+
- `Tuple[List[int], List[int]]`: A tuple containing lists of agent IDs and organization IDs that were loaded.
|
590
|
+
"""
|
591
|
+
start_time = time.time()
|
592
|
+
log = {
|
593
|
+
"req": "load",
|
594
|
+
"start_time": start_time,
|
595
|
+
"consumption": 0
|
596
|
+
}
|
300
597
|
request = org_service.LoadEconomyEntitiesRequest(
|
301
598
|
file_path=file_path,
|
302
599
|
)
|
@@ -304,46 +601,72 @@ class EconomyClient:
|
|
304
601
|
await self._aio_stub.LoadEconomyEntities(request)
|
305
602
|
)
|
306
603
|
# current agent ids and org ids
|
604
|
+
log["consumption"] = time.time() - start_time
|
605
|
+
self._log_list.append(log)
|
307
606
|
return (list(response.agent_ids), list(response.org_ids))
|
308
607
|
|
309
608
|
async def get_org_entity_ids(self, org_type: economyv2.OrgType) -> list[int]:
|
609
|
+
"""
|
610
|
+
Get the IDs of all organizations of a specific type.
|
611
|
+
|
612
|
+
- **Args**:
|
613
|
+
- `org_type` (`economyv2.OrgType`): The type of organizations whose IDs are to be retrieved.
|
614
|
+
|
615
|
+
- **Returns**:
|
616
|
+
- `List[int]`: A list of organization IDs matching the specified type.
|
617
|
+
"""
|
618
|
+
start_time = time.time()
|
619
|
+
log = {
|
620
|
+
"req": "get_org_entity_ids",
|
621
|
+
"start_time": start_time,
|
622
|
+
"consumption": 0
|
623
|
+
}
|
310
624
|
request = org_service.GetOrgEntityIdsRequest(
|
311
625
|
type=org_type,
|
312
626
|
)
|
313
627
|
response: org_service.GetOrgEntityIdsResponse = (
|
314
628
|
await self._aio_stub.GetOrgEntityIds(request)
|
315
629
|
)
|
630
|
+
log["consumption"] = time.time() - start_time
|
631
|
+
self._log_list.append(log)
|
316
632
|
return list(response.org_ids)
|
317
633
|
|
318
634
|
async def add_delta_value(
|
319
635
|
self,
|
320
|
-
id: int,
|
636
|
+
id: Union[int, list[int]],
|
321
637
|
key: str,
|
322
638
|
value: Any,
|
323
639
|
) -> Any:
|
324
640
|
"""
|
325
|
-
Add
|
641
|
+
Add value pair
|
326
642
|
|
327
|
-
Args
|
328
|
-
|
329
|
-
|
643
|
+
- **Args**:
|
644
|
+
- `id` (`int`): The id of `Org` or `Agent`.
|
645
|
+
- `key` (`str`): The attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
|
330
646
|
|
331
|
-
|
332
|
-
|
333
|
-
- Any
|
647
|
+
- **Returns**:
|
648
|
+
- Any
|
334
649
|
"""
|
650
|
+
start_time = time.time()
|
651
|
+
log = {
|
652
|
+
"req": "add_delta_value",
|
653
|
+
"start_time": start_time,
|
654
|
+
"consumption": 0
|
655
|
+
}
|
335
656
|
pascal_key = _snake_to_pascal(key)
|
336
657
|
_request_type = getattr(org_service, f"Add{pascal_key}Request")
|
337
658
|
_request_func = getattr(self._aio_stub, f"Add{pascal_key}")
|
659
|
+
|
338
660
|
_available_keys = {
|
339
661
|
"inventory",
|
340
662
|
"price",
|
341
663
|
"interest_rate",
|
342
664
|
"currency",
|
665
|
+
"income"
|
343
666
|
}
|
344
667
|
if key not in _available_keys:
|
345
668
|
raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
|
346
|
-
|
669
|
+
response = await _request_func(
|
347
670
|
_request_type(
|
348
671
|
**{
|
349
672
|
"org_id": id,
|
@@ -351,3 +674,6 @@ class EconomyClient:
|
|
351
674
|
}
|
352
675
|
)
|
353
676
|
)
|
677
|
+
log["consumption"] = time.time() - start_time
|
678
|
+
self._log_list.append(log)
|
679
|
+
return response
|