pycityagent 2.0.0a65__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a67__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. pycityagent/agent/agent.py +157 -57
  2. pycityagent/agent/agent_base.py +316 -43
  3. pycityagent/cityagent/bankagent.py +49 -9
  4. pycityagent/cityagent/blocks/__init__.py +1 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +54 -31
  6. pycityagent/cityagent/blocks/dispatcher.py +22 -17
  7. pycityagent/cityagent/blocks/economy_block.py +46 -32
  8. pycityagent/cityagent/blocks/mobility_block.py +209 -105
  9. pycityagent/cityagent/blocks/needs_block.py +101 -54
  10. pycityagent/cityagent/blocks/other_block.py +42 -33
  11. pycityagent/cityagent/blocks/plan_block.py +59 -42
  12. pycityagent/cityagent/blocks/social_block.py +167 -126
  13. pycityagent/cityagent/blocks/utils.py +13 -6
  14. pycityagent/cityagent/firmagent.py +17 -35
  15. pycityagent/cityagent/governmentagent.py +3 -3
  16. pycityagent/cityagent/initial.py +79 -49
  17. pycityagent/cityagent/memory_config.py +123 -94
  18. pycityagent/cityagent/message_intercept.py +0 -4
  19. pycityagent/cityagent/metrics.py +41 -0
  20. pycityagent/cityagent/nbsagent.py +24 -36
  21. pycityagent/cityagent/societyagent.py +9 -4
  22. pycityagent/cli/wrapper.py +2 -2
  23. pycityagent/economy/econ_client.py +407 -81
  24. pycityagent/environment/__init__.py +0 -3
  25. pycityagent/environment/sim/__init__.py +0 -3
  26. pycityagent/environment/sim/aoi_service.py +2 -2
  27. pycityagent/environment/sim/client.py +3 -31
  28. pycityagent/environment/sim/clock_service.py +2 -2
  29. pycityagent/environment/sim/lane_service.py +8 -8
  30. pycityagent/environment/sim/light_service.py +8 -8
  31. pycityagent/environment/sim/pause_service.py +9 -10
  32. pycityagent/environment/sim/person_service.py +20 -20
  33. pycityagent/environment/sim/road_service.py +2 -2
  34. pycityagent/environment/sim/sim_env.py +21 -5
  35. pycityagent/environment/sim/social_service.py +4 -4
  36. pycityagent/environment/simulator.py +249 -27
  37. pycityagent/environment/utils/__init__.py +2 -2
  38. pycityagent/environment/utils/geojson.py +2 -2
  39. pycityagent/environment/utils/grpc.py +4 -4
  40. pycityagent/environment/utils/map_utils.py +2 -2
  41. pycityagent/llm/embeddings.py +147 -28
  42. pycityagent/llm/llm.py +178 -111
  43. pycityagent/llm/llmconfig.py +5 -0
  44. pycityagent/llm/utils.py +4 -0
  45. pycityagent/memory/__init__.py +0 -4
  46. pycityagent/memory/const.py +2 -2
  47. pycityagent/memory/faiss_query.py +140 -61
  48. pycityagent/memory/memory.py +394 -91
  49. pycityagent/memory/memory_base.py +140 -34
  50. pycityagent/memory/profile.py +13 -13
  51. pycityagent/memory/self_define.py +13 -13
  52. pycityagent/memory/state.py +14 -14
  53. pycityagent/message/message_interceptor.py +253 -3
  54. pycityagent/message/messager.py +133 -6
  55. pycityagent/metrics/mlflow_client.py +47 -4
  56. pycityagent/pycityagent-sim +0 -0
  57. pycityagent/pycityagent-ui +0 -0
  58. pycityagent/simulation/__init__.py +3 -2
  59. pycityagent/simulation/agentgroup.py +150 -54
  60. pycityagent/simulation/simulation.py +276 -66
  61. pycityagent/survey/manager.py +45 -3
  62. pycityagent/survey/models.py +42 -2
  63. pycityagent/tools/__init__.py +1 -2
  64. pycityagent/tools/tool.py +93 -69
  65. pycityagent/utils/avro_schema.py +2 -2
  66. pycityagent/utils/parsers/code_block_parser.py +1 -1
  67. pycityagent/utils/parsers/json_parser.py +2 -2
  68. pycityagent/utils/parsers/parser_base.py +2 -2
  69. pycityagent/workflow/block.py +64 -13
  70. pycityagent/workflow/prompt.py +31 -23
  71. pycityagent/workflow/trigger.py +91 -24
  72. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/METADATA +2 -2
  73. pycityagent-2.0.0a67.dist-info/RECORD +97 -0
  74. pycityagent/environment/interact/__init__.py +0 -0
  75. pycityagent/environment/interact/interact.py +0 -198
  76. pycityagent/environment/message/__init__.py +0 -0
  77. pycityagent/environment/sence/__init__.py +0 -0
  78. pycityagent/environment/sence/static.py +0 -416
  79. pycityagent/environment/sidecar/__init__.py +0 -8
  80. pycityagent/environment/sidecar/sidecarv2.py +0 -109
  81. pycityagent/environment/sim/economy_services.py +0 -192
  82. pycityagent/metrics/utils/const.py +0 -0
  83. pycityagent-2.0.0a65.dist-info/RECORD +0 -105
  84. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/LICENSE +0 -0
  85. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/WHEEL +0 -0
  86. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/entry_points.txt +0 -0
  87. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
1
1
  import asyncio
2
2
  import logging
3
+ import re
4
+ import time
3
5
  from typing import Any, Literal, Union
4
6
 
5
7
  import grpc
@@ -7,6 +9,9 @@ import pycityproto.city.economy.v2.economy_pb2 as economyv2
7
9
  import pycityproto.city.economy.v2.org_service_pb2 as org_service
8
10
  import pycityproto.city.economy.v2.org_service_pb2_grpc as org_grpc
9
11
  from google.protobuf import descriptor
12
+ from google.protobuf.json_format import MessageToDict
13
+
14
+ logger = logging.getLogger("pycityagent")
10
15
 
11
16
  __all__ = [
12
17
  "EconomyClient",
@@ -22,33 +27,30 @@ def _snake_to_pascal(snake_str):
22
27
  _res = _res.replace(_word, _word.upper())
23
28
  return _res
24
29
 
25
-
26
- def _get_field_type_and_repeated(message, field_name: str) -> tuple[Any, bool]:
27
- try:
28
- field_descriptor = message.DESCRIPTOR.fields_by_name[field_name]
29
- field_type = field_descriptor.type
30
- _type_mapping = {
31
- descriptor.FieldDescriptor.TYPE_FLOAT: float,
32
- descriptor.FieldDescriptor.TYPE_INT32: int,
33
- }
34
- is_repeated = (
35
- field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED
36
- )
37
- return (_type_mapping.get(field_type), is_repeated)
38
- except KeyError:
39
- raise KeyError(f"Invalid message {message} and filed name {field_name}!")
40
-
30
+ def camel_to_snake(d):
31
+ if not isinstance(d, dict):
32
+ return d
33
+ return {re.sub('([a-z0-9])([A-Z])', r'\1_\2', k).lower(): camel_to_snake(v) if isinstance(v, dict) else v
34
+ for k, v in d.items()}
41
35
 
42
36
  def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.Channel:
43
37
  """
44
- Create a grpc asynchronous channel
38
+ Create a gRPC asynchronous channel.
39
+
40
+ - **Args**:
41
+ - `server_address` (`str`): The address of the server to connect to.
42
+ - `secure` (`bool`, optional): Whether to use a secure connection. Defaults to `False`.
43
+
44
+ - **Returns**:
45
+ - `grpc.aio.Channel`: A gRPC asynchronous channel for making RPC calls.
45
46
 
46
- Args:
47
- - server_address (str): server address.
48
- - secure (bool, optional): Defaults to False. Whether to use a secure connection. Defaults to False.
47
+ - **Raises**:
48
+ - `ValueError`: If a secure channel is requested but the server address starts with `http://`.
49
49
 
50
- Returns:
51
- - grpc.aio.Channel: grpc asynchronous channel.
50
+ - **Description**:
51
+ - This function creates and returns a gRPC asynchronous channel based on the provided server address and security flag.
52
+ - It ensures that if `secure=True`, then the server address does not start with `http://`.
53
+ - If the server address starts with `https://`, it will automatically switch to a secure connection even if `secure=False`.
52
54
  """
53
55
  if server_address.startswith("http://"):
54
56
  server_address = server_address.split("//")[1]
@@ -67,21 +69,44 @@ def _create_aio_channel(server_address: str, secure: bool = False) -> grpc.aio.C
67
69
 
68
70
  class EconomyClient:
69
71
  """
70
- Client side of Economy service
72
+ Client side of Economy service.
73
+
74
+ - **Description**:
75
+ - This class serves as a client interface to interact with the Economy Simulator via gRPC.
76
+ - It establishes an asynchronous connection and provides methods to communicate with the service.
71
77
  """
72
78
 
73
79
  def __init__(self, server_address: str, secure: bool = False):
74
80
  """
75
- Constructor of EconomyClient
81
+ Initialize the EconomyClient.
82
+
83
+ - **Args**:
84
+ - `server_address` (`str`): The address of the Economy server to connect to.
85
+ - `secure` (`bool`, optional): Whether to use a secure connection. Defaults to `False`.
86
+
87
+ - **Attributes**:
88
+ - `server_address` (`str`): The address of the Economy server.
89
+ - `secure` (`bool`): A flag indicating if a secure connection should be used.
90
+ - `_aio_stub` (`OrgServiceStub`): A gRPC stub used to make remote calls to the Economy service.
76
91
 
77
- Args:
78
- - server_address (str): Economy server address
79
- - secure (bool, optional): Defaults to False. Whether to use a secure connection. Defaults to False.
92
+ - **Description**:
93
+ - Initializes the EconomyClient with the specified server address and security preference.
94
+ - Creates an asynchronous gRPC channel using `_create_aio_channel`.
95
+ - Instantiates a gRPC stub (`_aio_stub`) for interacting with the Economy service.
80
96
  """
81
97
  self.server_address = server_address
82
98
  self.secure = secure
83
99
  aio_channel = _create_aio_channel(server_address, secure)
84
100
  self._aio_stub = org_grpc.OrgServiceStub(aio_channel)
101
+ self._agent_ids = set()
102
+ self._org_ids = set()
103
+ self._log_list = []
104
+
105
+ def get_log_list(self):
106
+ return self._log_list
107
+
108
+ def clear_log_list(self):
109
+ self._log_list = []
85
110
 
86
111
  def __getstate__(self):
87
112
  """
@@ -95,7 +120,7 @@ class EconomyClient:
95
120
  return state
96
121
 
97
122
  def __setstate__(self, state):
98
- """ "
123
+ """
99
124
  Restore instance attributes (i.e., filename and mode) from the
100
125
  unpickled state dictionary.
101
126
  """
@@ -104,6 +129,71 @@ class EconomyClient:
104
129
  aio_channel = _create_aio_channel(self.server_address, self.secure)
105
130
  self._aio_stub = org_grpc.OrgServiceStub(aio_channel)
106
131
 
132
+ async def get_ids(self):
133
+ """
134
+ Get the ids of agents and orgs
135
+ """
136
+ return self._agent_ids, self._org_ids
137
+
138
+ async def set_ids(self, agent_ids: set[int], org_ids: set[int]):
139
+ """
140
+ Set the ids of agents and orgs
141
+ """
142
+ self._agent_ids = agent_ids
143
+ self._org_ids = org_ids
144
+
145
+ async def get_agent(self, id: int) -> economyv2.Agent:
146
+ """
147
+ Get agent by id
148
+
149
+ - **Args**:
150
+ - `id` (`int`): The id of the agent.
151
+
152
+ - **Returns**:
153
+ - `economyv2.Agent`: The agent object.
154
+ """
155
+ start_time = time.time()
156
+ log = {
157
+ "req": "get_agent",
158
+ "start_time": start_time,
159
+ "consumption": 0
160
+ }
161
+ agent = await self._aio_stub.GetAgent(
162
+ org_service.GetAgentRequest(
163
+ agent_id=id
164
+ )
165
+ )
166
+ agent_dict = MessageToDict(agent)["agent"]
167
+ log["consumption"] = time.time() - start_time
168
+ self._log_list.append(log)
169
+ return camel_to_snake(agent_dict)
170
+
171
+ async def get_org(self, id: int) -> economyv2.Org:
172
+ """
173
+ Get org by id
174
+
175
+ - **Args**:
176
+ - `id` (`int`): The id of the org.
177
+
178
+ - **Returns**:
179
+ - `economyv2.Org`: The org object.
180
+ """
181
+ start_time = time.time()
182
+ log = {
183
+ "req": "get_org",
184
+ "start_time": start_time,
185
+ "consumption": 0
186
+ }
187
+ org = await self._aio_stub.GetOrg(
188
+ org_service.GetOrgRequest(
189
+ org_id=id
190
+ )
191
+ )
192
+ org_dict = MessageToDict(org)["org"]
193
+ log["consumption"] = time.time() - start_time
194
+ self._log_list.append(log)
195
+ return camel_to_snake(org_dict)
196
+
107
197
  async def get(
108
198
  self,
109
199
  id: int,
@@ -112,22 +202,29 @@ class EconomyClient:
112
202
  """
113
203
  Get specific value
114
204
 
115
- Args:
116
- - id (int): the id of `Org` or `Agent`.
117
- - key (str): the attribute to fetch.
205
+ - **Args**:
206
+ - `id` (`int`): The id of `Org` or `Agent`.
207
+ - `key` (`str`): The attribute to fetch.
118
208
 
119
- Returns:
120
- - Any
209
+ - **Returns**:
210
+ - Any
121
211
  """
122
- pascal_key = _snake_to_pascal(key)
123
- _request_type = getattr(org_service, f"Get{pascal_key}Request")
124
- _request_func = getattr(self._aio_stub, f"Get{pascal_key}")
125
- response = await _request_func(_request_type(org_id=id))
126
- value_type, is_repeated = _get_field_type_and_repeated(response, field_name=key)
127
- if is_repeated:
128
- return list(getattr(response, key))
212
+ start_time = time.time()
213
+ log = {
214
+ "req": "get",
215
+ "start_time": start_time,
216
+ "consumption": 0
217
+ }
218
+ if id not in self._agent_ids and id not in self._org_ids:
219
+ raise ValueError(f"Invalid id {id}, this id does not exist!")
220
+ request_type = "Org" if id in self._org_ids else "Agent"
221
+ if request_type == "Org":
222
+ response = await self.get_org(id)
129
223
  else:
130
- return value_type(getattr(response, key))
224
+ response = await self.get_agent(id)
225
+ log["consumption"] = time.time() - start_time
226
+ self._log_list.append(log)
227
+ return response[key]
131
228
 
132
229
  async def update(
133
230
  self,
@@ -139,68 +236,141 @@ class EconomyClient:
139
236
  """
140
237
  Update key-value pair
141
238
 
142
- Args:
143
- - id (int): the id of `Org` or `Agent`.
144
- - key (str): the attribute to update.
145
- - mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
146
-
239
+ - **Args**:
240
+ - `id` (`int`): The id of `Org` or `Agent`.
241
+ - `key` (`str`): The attribute to update.
242
+ - `mode` (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
147
243
 
148
- Returns:
149
- - Any
244
+ - **Returns**:
245
+ - Any
150
246
  """
151
- pascal_key = _snake_to_pascal(key)
152
- _request_type = getattr(org_service, f"Set{pascal_key}Request")
153
- _request_func = getattr(self._aio_stub, f"Set{pascal_key}")
247
+ start_time = time.time()
248
+ log = {
249
+ "req": "update",
250
+ "start_time": start_time,
251
+ "consumption": 0
252
+ }
253
+ if id not in self._agent_ids and id not in self._org_ids:
254
+ raise ValueError(f"Invalid id {id}, this id does not exist!")
255
+ request_type = "Org" if id in self._org_ids else "Agent"
256
+ if request_type == "Org":
257
+ original_value = await self.get_org(id)
258
+ else:
259
+ original_value = await self.get_agent(id)
154
260
  if mode == "merge":
155
- orig_value = await self.get(id, key)
261
+ orig_value = original_value[key]
156
262
  _orig_type = type(orig_value)
157
263
  _new_type = type(value)
158
264
  if _orig_type != _new_type:
159
- logging.debug(
265
+ logger.debug(
160
266
  f"Inconsistent type of original value {_orig_type.__name__} and to-update value {_new_type.__name__}"
161
267
  )
162
268
  else:
163
269
  if isinstance(orig_value, set):
164
270
  orig_value.update(set(value))
165
- value = orig_value
271
+ original_value[key] = orig_value
166
272
  elif isinstance(orig_value, dict):
167
273
  orig_value.update(dict(value))
168
- value = orig_value
274
+ original_value[key] = orig_value
169
275
  elif isinstance(orig_value, list):
170
276
  orig_value.extend(list(value))
171
- value = orig_value
277
+ original_value[key] = orig_value
172
278
  else:
173
- logging.warning(
279
+ logger.warning(
174
280
  f"Type of {type(orig_value)} does not support mode `merge`, using `replace` instead!"
175
281
  )
176
- return await _request_func(
177
- _request_type(
178
- **{
179
- "org_id": id,
180
- key: value,
181
- }
282
+ else:
283
+ original_value[key] = value
284
+ if request_type == "Org":
285
+ await self._aio_stub.UpdateOrg(
286
+ org_service.UpdateOrgRequest(
287
+ org=original_value
288
+ )
182
289
  )
183
- )
290
+ log["consumption"] = time.time() - start_time
291
+ self._log_list.append(log)
292
+ else:
293
+ await self._aio_stub.UpdateAgent(
294
+ org_service.UpdateAgentRequest(
295
+ agent=original_value
296
+ )
297
+ )
298
+ log["consumption"] = time.time() - start_time
299
+ self._log_list.append(log)
184
300
 
185
301
  async def add_agents(self, configs: Union[list[dict], dict]):
302
+ """
303
+ Add one or more agents to the economy system.
304
+
305
+ - **Args**:
306
+ - `configs` (`Union[list[dict], dict]`): A single configuration dictionary or a list of dictionaries,
307
+ each containing the necessary information to create an agent (e.g., id, currency).
308
+
309
+ - **Returns**:
310
+ - The method does not explicitly return any value but gathers the responses from adding each agent.
311
+
312
+ - **Description**:
313
+ - If a single configuration dictionary is provided, it is converted into a list.
314
+ - For each configuration in the list, a task is created to asynchronously add an agent using the provided configuration.
315
+ - All tasks are executed concurrently, and their results are gathered and returned.
316
+ """
317
+ start_time = time.time()
318
+ log = {
319
+ "req": "add_agents",
320
+ "start_time": start_time,
321
+ "consumption": 0
322
+ }
186
323
  if isinstance(configs, dict):
187
324
  configs = [configs]
325
+ for config in configs:
326
+ self._agent_ids.add(config["id"])
188
327
  tasks = [
189
328
  self._aio_stub.AddAgent(
190
329
  org_service.AddAgentRequest(
191
330
  agent=economyv2.Agent(
192
331
  id=config["id"],
193
332
  currency=config.get("currency", 0.0),
333
+ skill=config.get("skill", 0.0),
334
+ consumption=config.get("consumption", 0.0),
335
+ income=config.get("income", 0.0),
194
336
  )
195
337
  )
196
338
  )
197
339
  for config in configs
198
340
  ]
199
- responses = await asyncio.gather(*tasks)
341
+ await asyncio.gather(*tasks)
342
+ log["consumption"] = time.time() - start_time
343
+ self._log_list.append(log)
200
344
 
201
345
  async def add_orgs(self, configs: Union[list[dict], dict]):
346
+ """
347
+ Add one or more organizations to the economy system.
348
+
349
+ - **Args**:
350
+ - `configs` (`Union[List[Dict], Dict]`): A single configuration dictionary or a list of dictionaries,
351
+ each containing the necessary information to create an organization (e.g., id, type, nominal_gdp, etc.).
352
+
353
+ - **Returns**:
354
+ - `List`: A list of responses from adding each organization.
355
+
356
+ - **Raises**:
357
+ - `KeyError`: If a required field is missing from the config dictionary.
358
+
359
+ - **Description**:
360
+ - Ensures `configs` is always a list, even if only one config is provided.
361
+ - For each configuration in the list, creates a task to asynchronously add an organization using the provided configuration.
362
+ - Executes all tasks concurrently and gathers their results.
363
+ """
364
+ start_time = time.time()
365
+ log = {
366
+ "req": "add_orgs",
367
+ "start_time": start_time,
368
+ "consumption": 0
369
+ }
202
370
  if isinstance(configs, dict):
203
371
  configs = [configs]
372
+ for config in configs:
373
+ self._org_ids.add(config["id"])
204
374
  tasks = [
205
375
  self._aio_stub.AddOrg(
206
376
  org_service.AddOrgRequest(
@@ -218,20 +388,48 @@ class EconomyClient:
218
388
  interest_rate=config.get("interest_rate", 0.0),
219
389
  bracket_cutoffs=config.get("bracket_cutoffs", []),
220
390
  bracket_rates=config.get("bracket_rates", []),
391
+ consumption_currency=config.get("consumption_currency", []),
392
+ consumption_propensity=config.get("consumption_propensity", []),
393
+ income_currency=config.get("income_currency", []),
394
+ depression=config.get("depression", []),
395
+ locus_control=config.get("locus_control", []),
396
+ working_hours=config.get("working_hours", []),
397
+ employees=config.get("employees", []),
398
+ citizens=config.get("citizens", []),
221
399
  )
222
400
  )
223
401
  )
224
402
  for config in configs
225
403
  ]
226
- responses = await asyncio.gather(*tasks)
404
+ await asyncio.gather(*tasks)
405
+ log["consumption"] = time.time() - start_time
406
+ self._log_list.append(log)
227
407
 
228
408
  async def calculate_taxes_due(
229
409
  self,
230
- org_id: int,
410
+ org_id: Union[int, list[int]],
231
411
  agent_ids: list[int],
232
412
  incomes: list[float],
233
413
  enable_redistribution: bool,
234
414
  ):
415
+ """
416
+ Calculate the taxes due for agents based on their incomes.
417
+
418
+ - **Args**:
419
+ - `org_id` (`int`): The ID of the government organization.
420
+ - `agent_ids` (`List[int]`): A list of IDs for the agents whose taxes are being calculated.
421
+ - `incomes` (`List[float]`): A list of income values corresponding to each agent.
422
+ - `enable_redistribution` (`bool`): Flag indicating whether redistribution is enabled.
423
+
424
+ - **Returns**:
425
+ - `Tuple[float, List[float]]`: A tuple containing the total taxes due and updated incomes after tax calculation.
426
+ """
427
+ start_time = time.time()
428
+ log = {
429
+ "req": "calculate_taxes_due",
430
+ "start_time": start_time,
431
+ "consumption": 0
432
+ }
235
433
  request = org_service.CalculateTaxesDueRequest(
236
434
  government_id=org_id,
237
435
  agent_ids=agent_ids,
@@ -241,22 +439,59 @@ class EconomyClient:
241
439
  response: org_service.CalculateTaxesDueResponse = (
242
440
  await self._aio_stub.CalculateTaxesDue(request)
243
441
  )
442
+ log["consumption"] = time.time() - start_time
443
+ self._log_list.append(log)
244
444
  return (float(response.taxes_due), list(response.updated_incomes))
245
445
 
246
446
  async def calculate_consumption(
247
- self, org_id: int, agent_ids: list[int], demands: list[int]
447
+ self, org_ids: Union[int, list[int]], agent_id: int, demands: list[int]
248
448
  ):
449
+ """
450
+ Calculate consumption for agents based on their demands.
451
+
452
+ - **Args**:
453
+ - `org_ids` (`Union[int, list[int]]`): The ID of the firm providing goods or services.
454
+ - `agent_id` (`int`): The ID of the agent whose consumption is being calculated.
455
+ - `demands` (`List[int]`): A list of demand quantities corresponding to each agent.
456
+
457
+ - **Returns**:
458
+ - `Tuple[int, List[float]]`: A tuple containing the remaining inventory and updated currencies for each agent.
459
+ """
460
+ start_time = time.time()
461
+ log = {
462
+ "req": "calculate_consumption",
463
+ "start_time": start_time,
464
+ "consumption": 0
465
+ }
249
466
  request = org_service.CalculateConsumptionRequest(
250
- firm_id=org_id,
251
- agent_ids=agent_ids,
467
+ firm_ids=org_ids,
468
+ agent_id=agent_id,
252
469
  demands=demands,
253
470
  )
254
471
  response: org_service.CalculateConsumptionResponse = (
255
472
  await self._aio_stub.CalculateConsumption(request)
256
473
  )
474
+ log["consumption"] = time.time() - start_time
475
+ self._log_list.append(log)
257
476
  return (int(response.remain_inventory), list(response.updated_currencies))
258
477
 
259
478
  async def calculate_interest(self, org_id: int, agent_ids: list[int]):
479
+ """
480
+ Calculate interest for agents based on their accounts.
481
+
482
+ - **Args**:
483
+ - `org_id` (`int`): The ID of the bank.
484
+ - `agent_ids` (`List[int]`): A list of IDs for the agents whose interests are being calculated.
485
+
486
+ - **Returns**:
487
+ - `Tuple[float, List[float]]`: A tuple containing the total interest and updated currencies for each agent.
488
+ """
489
+ start_time = time.time()
490
+ log = {
491
+ "req": "calculate_interest",
492
+ "start_time": start_time,
493
+ "consumption": 0
494
+ }
260
495
  request = org_service.CalculateInterestRequest(
261
496
  bank_id=org_id,
262
497
  agent_ids=agent_ids,
@@ -264,9 +499,23 @@ class EconomyClient:
264
499
  response: org_service.CalculateInterestResponse = (
265
500
  await self._aio_stub.CalculateInterest(request)
266
501
  )
502
+ log["consumption"] = time.time() - start_time
503
+ self._log_list.append(log)
267
504
  return (float(response.total_interest), list(response.updated_currencies))
268
505
 
269
506
  async def remove_agents(self, agent_ids: Union[int, list[int]]):
507
+ """
508
+ Remove one or more agents from the system.
509
+
510
+ - **Args**:
511
+ - `org_ids` (`Union[int, List[int]]`): A single ID or a list of IDs for the agents to be removed.
512
+ """
513
+ start_time = time.time()
514
+ log = {
515
+ "req": "remove_agents",
516
+ "start_time": start_time,
517
+ "consumption": 0
518
+ }
270
519
  if isinstance(agent_ids, int):
271
520
  agent_ids = [agent_ids]
272
521
  tasks = [
@@ -275,18 +524,49 @@ class EconomyClient:
275
524
  )
276
525
  for agent_id in agent_ids
277
526
  ]
278
- responses = await asyncio.gather(*tasks)
527
+ await asyncio.gather(*tasks)
528
+ log["consumption"] = time.time() - start_time
529
+ self._log_list.append(log)
279
530
 
280
531
  async def remove_orgs(self, org_ids: Union[int, list[int]]):
532
+ """
533
+ Remove one or more organizations from the system.
534
+
535
+ - **Args**:
536
+ - `org_ids` (`Union[int, List[int]]`): A single ID or a list of IDs for the organizations to be removed.
537
+ """
538
+ start_time = time.time()
539
+ log = {
540
+ "req": "remove_orgs",
541
+ "start_time": start_time,
542
+ "consumption": 0
543
+ }
281
544
  if isinstance(org_ids, int):
282
545
  org_ids = [org_ids]
283
546
  tasks = [
284
547
  self._aio_stub.RemoveOrg(org_service.RemoveOrgRequest(org_id=org_id))
285
548
  for org_id in org_ids
286
549
  ]
287
- responses = await asyncio.gather(*tasks)
550
+ await asyncio.gather(*tasks)
551
+ log["consumption"] = time.time() - start_time
552
+ self._log_list.append(log)
288
553
 
289
554
  async def save(self, file_path: str) -> tuple[list[int], list[int]]:
555
+ """
556
+ Save the current state of all economy entities to a specified file.
557
+
558
+ - **Args**:
559
+ - `file_path` (`str`): The path to the file where the economy entities will be saved.
560
+
561
+ - **Returns**:
562
+ - `Tuple[List[int], List[int]]`: A tuple containing lists of agent IDs and organization IDs that were saved.
563
+ """
564
+ start_time = time.time()
565
+ log = {
566
+ "req": "save",
567
+ "start_time": start_time,
568
+ "consumption": 0
569
+ }
290
570
  request = org_service.SaveEconomyEntitiesRequest(
291
571
  file_path=file_path,
292
572
  )
@@ -294,9 +574,26 @@ class EconomyClient:
294
574
  await self._aio_stub.SaveEconomyEntities(request)
295
575
  )
296
576
  # current agent ids and org ids
577
+ log["consumption"] = time.time() - start_time
578
+ self._log_list.append(log)
297
579
  return (list(response.agent_ids), list(response.org_ids))
298
580
 
299
581
  async def load(self, file_path: str):
582
+ """
583
+ Load the state of economy entities from a specified file.
584
+
585
+ - **Args**:
586
+ - `file_path` (`str`): The path to the file from which the economy entities will be loaded.
587
+
588
+ - **Returns**:
589
+ - `Tuple[List[int], List[int]]`: A tuple containing lists of agent IDs and organization IDs that were loaded.
590
+ """
591
+ start_time = time.time()
592
+ log = {
593
+ "req": "load",
594
+ "start_time": start_time,
595
+ "consumption": 0
596
+ }
300
597
  request = org_service.LoadEconomyEntitiesRequest(
301
598
  file_path=file_path,
302
599
  )
@@ -304,46 +601,72 @@ class EconomyClient:
304
601
  await self._aio_stub.LoadEconomyEntities(request)
305
602
  )
306
603
  # current agent ids and org ids
604
+ log["consumption"] = time.time() - start_time
605
+ self._log_list.append(log)
307
606
  return (list(response.agent_ids), list(response.org_ids))
308
607
 
309
608
  async def get_org_entity_ids(self, org_type: economyv2.OrgType) -> list[int]:
609
+ """
610
+ Get the IDs of all organizations of a specific type.
611
+
612
+ - **Args**:
613
+ - `org_type` (`economyv2.OrgType`): The type of organizations whose IDs are to be retrieved.
614
+
615
+ - **Returns**:
616
+ - `List[int]`: A list of organization IDs matching the specified type.
617
+ """
618
+ start_time = time.time()
619
+ log = {
620
+ "req": "get_org_entity_ids",
621
+ "start_time": start_time,
622
+ "consumption": 0
623
+ }
310
624
  request = org_service.GetOrgEntityIdsRequest(
311
625
  type=org_type,
312
626
  )
313
627
  response: org_service.GetOrgEntityIdsResponse = (
314
628
  await self._aio_stub.GetOrgEntityIds(request)
315
629
  )
630
+ log["consumption"] = time.time() - start_time
631
+ self._log_list.append(log)
316
632
  return list(response.org_ids)
317
633
 
318
634
  async def add_delta_value(
319
635
  self,
320
- id: int,
636
+ id: Union[int, list[int]],
321
637
  key: str,
322
638
  value: Any,
323
639
  ) -> Any:
324
640
  """
325
- Add key-value pair
641
+ Add value pair
326
642
 
327
- Args:
328
- - id (int): the id of `Org` or `Agent`.
329
- - key (str): the attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
643
+ - **Args**:
644
+ - `id` (`int`): The id of `Org` or `Agent`.
645
+ - `key` (`str`): The attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
330
646
 
331
-
332
- Returns:
333
- - Any
647
+ - **Returns**:
648
+ - Any
334
649
  """
650
+ start_time = time.time()
651
+ log = {
652
+ "req": "add_delta_value",
653
+ "start_time": start_time,
654
+ "consumption": 0
655
+ }
335
656
  pascal_key = _snake_to_pascal(key)
336
657
  _request_type = getattr(org_service, f"Add{pascal_key}Request")
337
658
  _request_func = getattr(self._aio_stub, f"Add{pascal_key}")
659
+
338
660
  _available_keys = {
339
661
  "inventory",
340
662
  "price",
341
663
  "interest_rate",
342
664
  "currency",
665
+ "income"
343
666
  }
344
667
  if key not in _available_keys:
345
668
  raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
346
- return await _request_func(
669
+ response = await _request_func(
347
670
  _request_type(
348
671
  **{
349
672
  "org_id": id,
@@ -351,3 +674,6 @@ class EconomyClient:
351
674
  }
352
675
  )
353
676
  )
677
+ log["consumption"] = time.time() - start_time
678
+ self._log_list.append(log)
679
+ return response