pycityagent 2.0.0a49__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a50__cp39-cp39-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/__init__.py +12 -3
- pycityagent/agent/__init__.py +9 -0
- pycityagent/agent/agent.py +324 -0
- pycityagent/{agent.py → agent/agent_base.py} +41 -345
- pycityagent/cityagent/bankagent.py +28 -16
- pycityagent/cityagent/firmagent.py +63 -25
- pycityagent/cityagent/governmentagent.py +35 -19
- pycityagent/cityagent/initial.py +38 -28
- pycityagent/cityagent/memory_config.py +240 -128
- pycityagent/cityagent/nbsagent.py +81 -35
- pycityagent/cityagent/societyagent.py +155 -72
- pycityagent/simulation/agentgroup.py +2 -2
- pycityagent/simulation/simulation.py +94 -55
- pycityagent/tools/__init__.py +9 -0
- pycityagent/{workflow → tools}/tool.py +3 -1
- pycityagent/workflow/__init__.py +0 -5
- pycityagent/workflow/block.py +12 -10
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a50.dist-info}/METADATA +1 -2
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a50.dist-info}/RECORD +23 -20
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a50.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a50.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a50.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a50.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,29 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import inspect
|
4
5
|
import json
|
5
6
|
import logging
|
6
|
-
import random
|
7
7
|
import uuid
|
8
8
|
from abc import ABC, abstractmethod
|
9
|
-
from copy import deepcopy
|
10
9
|
from datetime import datetime, timezone
|
11
10
|
from enum import Enum
|
12
|
-
from typing import Any,
|
13
|
-
from uuid import UUID
|
11
|
+
from typing import Any, Optional, Union, get_type_hints
|
14
12
|
|
15
13
|
import fastavro
|
16
|
-
from pyparsing import Dict
|
17
14
|
import ray
|
18
|
-
from mosstool.util.format_converter import dict2pb
|
19
15
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
16
|
+
from pyparsing import Dict
|
20
17
|
|
21
|
-
from
|
22
|
-
|
23
|
-
from .
|
24
|
-
from
|
25
|
-
from
|
26
|
-
from .
|
27
|
-
from
|
28
|
-
from
|
29
|
-
from
|
30
|
-
from .utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
|
18
|
+
from ..economy import EconomyClient
|
19
|
+
from ..environment import Simulator
|
20
|
+
from ..environment.sim.person_service import PersonService
|
21
|
+
from ..llm import LLM
|
22
|
+
from ..memory import Memory
|
23
|
+
from ..message.messager import Messager
|
24
|
+
from ..metrics import MlflowClient
|
25
|
+
from ..utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
|
26
|
+
from ..workflow import Block
|
31
27
|
|
32
28
|
logger = logging.getLogger("pycityagent")
|
33
29
|
|
@@ -49,6 +45,7 @@ class Agent(ABC):
|
|
49
45
|
"""
|
50
46
|
Agent base class
|
51
47
|
"""
|
48
|
+
|
52
49
|
configurable_fields: list[str] = []
|
53
50
|
default_values: dict[str, Any] = {}
|
54
51
|
|
@@ -108,11 +105,7 @@ class Agent(ABC):
|
|
108
105
|
|
109
106
|
@classmethod
|
110
107
|
def export_class_config(cls) -> dict[str, Dict]:
|
111
|
-
result = {
|
112
|
-
"agent_name": cls.__name__,
|
113
|
-
"config": {},
|
114
|
-
"blocks": []
|
115
|
-
}
|
108
|
+
result = {"agent_name": cls.__name__, "config": {}, "blocks": []}
|
116
109
|
config = {
|
117
110
|
field: cls.default_values.get(field, "default_value")
|
118
111
|
for field in cls.configurable_fields
|
@@ -123,25 +116,29 @@ class Agent(ABC):
|
|
123
116
|
for attr_name, attr_type in hints.items():
|
124
117
|
if inspect.isclass(attr_type) and issubclass(attr_type, Block):
|
125
118
|
block_config = attr_type.export_class_config()
|
126
|
-
result["blocks"].append(
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
119
|
+
result["blocks"].append(
|
120
|
+
{
|
121
|
+
"name": attr_name,
|
122
|
+
"config": block_config,
|
123
|
+
"children": cls._export_subblocks(attr_type),
|
124
|
+
}
|
125
|
+
)
|
131
126
|
return result
|
132
127
|
|
133
128
|
@classmethod
|
134
|
-
def _export_subblocks(cls, block_cls:
|
129
|
+
def _export_subblocks(cls, block_cls: type[Block]) -> list[Dict]:
|
135
130
|
children = []
|
136
131
|
hints = get_type_hints(block_cls) # 获取类的注解
|
137
132
|
for attr_name, attr_type in hints.items():
|
138
133
|
if inspect.isclass(attr_type) and issubclass(attr_type, Block):
|
139
134
|
block_config = attr_type.export_class_config()
|
140
|
-
children.append(
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
135
|
+
children.append(
|
136
|
+
{
|
137
|
+
"name": attr_name,
|
138
|
+
"config": block_config,
|
139
|
+
"children": cls._export_subblocks(attr_type),
|
140
|
+
}
|
141
|
+
)
|
145
142
|
return children
|
146
143
|
|
147
144
|
@classmethod
|
@@ -151,28 +148,29 @@ class Agent(ABC):
|
|
151
148
|
json.dump(config, f, indent=4)
|
152
149
|
|
153
150
|
@classmethod
|
154
|
-
def import_block_config(cls, config: dict[str, list[
|
155
|
-
agent = cls(name=config["agent_name"])
|
151
|
+
def import_block_config(cls, config: dict[str, Union[list[dict], str]]) -> Agent:
|
152
|
+
agent = cls(name=config["agent_name"]) # type:ignore
|
156
153
|
|
157
|
-
def build_block(block_data:
|
154
|
+
def build_block(block_data: dict[str, Any]) -> Block:
|
158
155
|
block_cls = globals()[block_data["name"]]
|
159
156
|
block_instance = block_cls.import_config(block_data)
|
160
157
|
return block_instance
|
161
158
|
|
162
159
|
# 创建顶层Block
|
163
160
|
for block_data in config["blocks"]:
|
161
|
+
assert isinstance(block_data, dict)
|
164
162
|
block = build_block(block_data)
|
165
163
|
setattr(agent, block.name.lower(), block)
|
166
164
|
|
167
165
|
return agent
|
168
166
|
|
169
167
|
@classmethod
|
170
|
-
def import_from_file(cls, filepath: str) ->
|
168
|
+
def import_from_file(cls, filepath: str) -> Agent:
|
171
169
|
with open(filepath, "r") as f:
|
172
170
|
config = json.load(f)
|
173
171
|
return cls.import_block_config(config)
|
174
|
-
|
175
|
-
def load_from_config(self, config: dict[str, list[
|
172
|
+
|
173
|
+
def load_from_config(self, config: dict[str, list[dict]]) -> None:
|
176
174
|
"""
|
177
175
|
使用配置更新当前Agent实例的Block层次结构。
|
178
176
|
"""
|
@@ -185,13 +183,15 @@ class Agent(ABC):
|
|
185
183
|
# 递归更新或创建顶层Block
|
186
184
|
for block_data in config.get("blocks", []):
|
187
185
|
block_name = block_data["name"]
|
188
|
-
existing_block = getattr(self, block_name, None)
|
186
|
+
existing_block = getattr(self, block_name, None) # type:ignore
|
189
187
|
|
190
188
|
if existing_block:
|
191
189
|
# 如果Block已经存在,则递归更新
|
192
190
|
existing_block.load_from_config(block_data)
|
193
191
|
else:
|
194
|
-
raise KeyError(
|
192
|
+
raise KeyError(
|
193
|
+
f"Block '{block_name}' not found in agent '{self.__class__.__name__}'"
|
194
|
+
)
|
195
195
|
|
196
196
|
def load_from_file(self, filepath: str) -> None:
|
197
197
|
with open(filepath, "r") as f:
|
@@ -315,7 +315,7 @@ class Agent(ABC):
|
|
315
315
|
f"Copy Writer access before assignment, please `set_pgsql_writer` first!"
|
316
316
|
)
|
317
317
|
return self._pgsql_writer
|
318
|
-
|
318
|
+
|
319
319
|
async def messager_ping(self):
|
320
320
|
if self._messager is None:
|
321
321
|
raise RuntimeError("Messager is not set")
|
@@ -633,307 +633,3 @@ class Agent(ABC):
|
|
633
633
|
await self._messager.ping.remote()
|
634
634
|
if not self._blocked:
|
635
635
|
await self.forward()
|
636
|
-
|
637
|
-
|
638
|
-
class CitizenAgent(Agent):
|
639
|
-
"""
|
640
|
-
CitizenAgent: 城市居民智能体类及其定义
|
641
|
-
"""
|
642
|
-
|
643
|
-
def __init__(
|
644
|
-
self,
|
645
|
-
name: str,
|
646
|
-
llm_client: Optional[LLM] = None,
|
647
|
-
simulator: Optional[Simulator] = None,
|
648
|
-
mlflow_client: Optional[MlflowClient] = None,
|
649
|
-
memory: Optional[Memory] = None,
|
650
|
-
economy_client: Optional[EconomyClient] = None,
|
651
|
-
messager: Optional[Messager] = None, # type:ignore
|
652
|
-
avro_file: Optional[dict] = None,
|
653
|
-
) -> None:
|
654
|
-
super().__init__(
|
655
|
-
name=name,
|
656
|
-
type=AgentType.Citizen,
|
657
|
-
llm_client=llm_client,
|
658
|
-
economy_client=economy_client,
|
659
|
-
messager=messager,
|
660
|
-
simulator=simulator,
|
661
|
-
mlflow_client=mlflow_client,
|
662
|
-
memory=memory,
|
663
|
-
avro_file=avro_file,
|
664
|
-
)
|
665
|
-
|
666
|
-
async def bind_to_simulator(self):
|
667
|
-
await self._bind_to_simulator()
|
668
|
-
await self._bind_to_economy()
|
669
|
-
|
670
|
-
async def _bind_to_simulator(self):
|
671
|
-
"""
|
672
|
-
Bind Agent to Simulator
|
673
|
-
|
674
|
-
Args:
|
675
|
-
person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
|
676
|
-
"""
|
677
|
-
if self._simulator is None:
|
678
|
-
logger.warning("Simulator is not set")
|
679
|
-
return
|
680
|
-
if not self._has_bound_to_simulator:
|
681
|
-
FROM_MEMORY_KEYS = {
|
682
|
-
"attribute",
|
683
|
-
"home",
|
684
|
-
"work",
|
685
|
-
"vehicle_attribute",
|
686
|
-
"bus_attribute",
|
687
|
-
"pedestrian_attribute",
|
688
|
-
"bike_attribute",
|
689
|
-
}
|
690
|
-
simulator = self.simulator
|
691
|
-
memory = self.memory
|
692
|
-
person_id = await memory.get("id")
|
693
|
-
# ATTENTION:模拟器分配的id从0开始
|
694
|
-
if person_id >= 0:
|
695
|
-
await simulator.get_person(person_id)
|
696
|
-
logger.debug(f"Binding to Person `{person_id}` already in Simulator")
|
697
|
-
else:
|
698
|
-
dict_person = deepcopy(self._person_template)
|
699
|
-
for _key in FROM_MEMORY_KEYS:
|
700
|
-
try:
|
701
|
-
_value = await memory.get(_key)
|
702
|
-
if _value:
|
703
|
-
dict_person[_key] = _value
|
704
|
-
except KeyError as e:
|
705
|
-
continue
|
706
|
-
resp = await simulator.add_person(
|
707
|
-
dict2pb(dict_person, person_pb2.Person())
|
708
|
-
)
|
709
|
-
person_id = resp["person_id"]
|
710
|
-
await memory.update("id", person_id, protect_llm_read_only_fields=False)
|
711
|
-
logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
|
712
|
-
# 防止模拟器还没有到prepare阶段导致get_person出错
|
713
|
-
self._has_bound_to_simulator = True
|
714
|
-
self._agent_id = person_id
|
715
|
-
self.memory.set_agent_id(person_id)
|
716
|
-
|
717
|
-
async def _bind_to_economy(self):
|
718
|
-
if self._economy_client is None:
|
719
|
-
logger.warning("Economy client is not set")
|
720
|
-
return
|
721
|
-
if not self._has_bound_to_economy:
|
722
|
-
if self._has_bound_to_simulator:
|
723
|
-
try:
|
724
|
-
await self._economy_client.remove_agents([self._agent_id])
|
725
|
-
except:
|
726
|
-
pass
|
727
|
-
person_id = await self.memory.get("id")
|
728
|
-
currency = await self.memory.get("currency")
|
729
|
-
await self._economy_client.add_agents(
|
730
|
-
{
|
731
|
-
"id": person_id,
|
732
|
-
"currency": currency,
|
733
|
-
}
|
734
|
-
)
|
735
|
-
self._has_bound_to_economy = True
|
736
|
-
else:
|
737
|
-
logger.debug(
|
738
|
-
f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
|
739
|
-
)
|
740
|
-
|
741
|
-
async def handle_gather_message(self, payload: dict):
|
742
|
-
"""处理收到的消息,识别发送者"""
|
743
|
-
# 从消息中解析发送者 ID 和消息内容
|
744
|
-
target = payload["target"]
|
745
|
-
sender_id = payload["from"]
|
746
|
-
content = await self.memory.get(f"{target}")
|
747
|
-
payload = {
|
748
|
-
"from": self._uuid,
|
749
|
-
"content": content,
|
750
|
-
}
|
751
|
-
await self._send_message(sender_id, payload, "gather")
|
752
|
-
|
753
|
-
|
754
|
-
class InstitutionAgent(Agent):
|
755
|
-
"""
|
756
|
-
InstitutionAgent: 机构智能体类及其定义
|
757
|
-
"""
|
758
|
-
|
759
|
-
def __init__(
|
760
|
-
self,
|
761
|
-
name: str,
|
762
|
-
llm_client: Optional[LLM] = None,
|
763
|
-
simulator: Optional[Simulator] = None,
|
764
|
-
mlflow_client: Optional[MlflowClient] = None,
|
765
|
-
memory: Optional[Memory] = None,
|
766
|
-
economy_client: Optional[EconomyClient] = None,
|
767
|
-
messager: Optional[Messager] = None, # type:ignore
|
768
|
-
avro_file: Optional[dict] = None,
|
769
|
-
) -> None:
|
770
|
-
super().__init__(
|
771
|
-
name=name,
|
772
|
-
type=AgentType.Institution,
|
773
|
-
llm_client=llm_client,
|
774
|
-
economy_client=economy_client,
|
775
|
-
mlflow_client=mlflow_client,
|
776
|
-
messager=messager,
|
777
|
-
simulator=simulator,
|
778
|
-
memory=memory,
|
779
|
-
avro_file=avro_file,
|
780
|
-
)
|
781
|
-
# 添加响应收集器
|
782
|
-
self._gather_responses: dict[str, asyncio.Future] = {}
|
783
|
-
|
784
|
-
async def bind_to_simulator(self):
|
785
|
-
await self._bind_to_economy()
|
786
|
-
|
787
|
-
async def _bind_to_economy(self):
|
788
|
-
print("Debug:", self._economy_client, self._has_bound_to_economy)
|
789
|
-
if self._economy_client is None:
|
790
|
-
logger.debug("Economy client is not set")
|
791
|
-
return
|
792
|
-
if not self._has_bound_to_economy:
|
793
|
-
# TODO: More general id generation
|
794
|
-
_id = random.randint(100000, 999999)
|
795
|
-
self._agent_id = _id
|
796
|
-
self.memory.set_agent_id(_id)
|
797
|
-
map_header = self.simulator.map.header
|
798
|
-
# TODO: remove random position assignment
|
799
|
-
await self.memory.update(
|
800
|
-
"position",
|
801
|
-
{
|
802
|
-
"xy_position": {
|
803
|
-
"x": float(
|
804
|
-
random.randrange(
|
805
|
-
start=int(map_header["west"]),
|
806
|
-
stop=int(map_header["east"]),
|
807
|
-
)
|
808
|
-
),
|
809
|
-
"y": float(
|
810
|
-
random.randrange(
|
811
|
-
start=int(map_header["south"]),
|
812
|
-
stop=int(map_header["north"]),
|
813
|
-
)
|
814
|
-
),
|
815
|
-
}
|
816
|
-
},
|
817
|
-
protect_llm_read_only_fields=False,
|
818
|
-
)
|
819
|
-
await self.memory.update("id", _id, protect_llm_read_only_fields=False)
|
820
|
-
try:
|
821
|
-
await self._economy_client.remove_orgs([self._agent_id])
|
822
|
-
except:
|
823
|
-
pass
|
824
|
-
try:
|
825
|
-
_memory = self.memory
|
826
|
-
_id = await _memory.get("id")
|
827
|
-
_type = await _memory.get("type")
|
828
|
-
try:
|
829
|
-
nominal_gdp = await _memory.get("nominal_gdp")
|
830
|
-
except:
|
831
|
-
nominal_gdp = []
|
832
|
-
try:
|
833
|
-
real_gdp = await _memory.get("real_gdp")
|
834
|
-
except:
|
835
|
-
real_gdp = []
|
836
|
-
try:
|
837
|
-
unemployment = await _memory.get("unemployment")
|
838
|
-
except:
|
839
|
-
unemployment = []
|
840
|
-
try:
|
841
|
-
wages = await _memory.get("wages")
|
842
|
-
except:
|
843
|
-
wages = []
|
844
|
-
try:
|
845
|
-
prices = await _memory.get("prices")
|
846
|
-
except:
|
847
|
-
prices = []
|
848
|
-
try:
|
849
|
-
inventory = await _memory.get("inventory")
|
850
|
-
except:
|
851
|
-
inventory = 0
|
852
|
-
try:
|
853
|
-
price = await _memory.get("price")
|
854
|
-
except:
|
855
|
-
price = 0
|
856
|
-
try:
|
857
|
-
currency = await _memory.get("currency")
|
858
|
-
except:
|
859
|
-
currency = 0.0
|
860
|
-
try:
|
861
|
-
interest_rate = await _memory.get("interest_rate")
|
862
|
-
except:
|
863
|
-
interest_rate = 0.0
|
864
|
-
try:
|
865
|
-
bracket_cutoffs = await _memory.get("bracket_cutoffs")
|
866
|
-
except:
|
867
|
-
bracket_cutoffs = []
|
868
|
-
try:
|
869
|
-
bracket_rates = await _memory.get("bracket_rates")
|
870
|
-
except:
|
871
|
-
bracket_rates = []
|
872
|
-
await self._economy_client.add_orgs(
|
873
|
-
{
|
874
|
-
"id": _id,
|
875
|
-
"type": _type,
|
876
|
-
"nominal_gdp": nominal_gdp,
|
877
|
-
"real_gdp": real_gdp,
|
878
|
-
"unemployment": unemployment,
|
879
|
-
"wages": wages,
|
880
|
-
"prices": prices,
|
881
|
-
"inventory": inventory,
|
882
|
-
"price": price,
|
883
|
-
"currency": currency,
|
884
|
-
"interest_rate": interest_rate,
|
885
|
-
"bracket_cutoffs": bracket_cutoffs,
|
886
|
-
"bracket_rates": bracket_rates,
|
887
|
-
}
|
888
|
-
)
|
889
|
-
except Exception as e:
|
890
|
-
logger.error(f"Failed to bind to Economy: {e}")
|
891
|
-
self._has_bound_to_economy = True
|
892
|
-
|
893
|
-
async def handle_gather_message(self, payload: dict):
|
894
|
-
"""处理收到的消息,识别发送者"""
|
895
|
-
content = payload["content"]
|
896
|
-
sender_id = payload["from"]
|
897
|
-
|
898
|
-
# 将响应存储到对应的Future中
|
899
|
-
response_key = str(sender_id)
|
900
|
-
if response_key in self._gather_responses:
|
901
|
-
self._gather_responses[response_key].set_result(
|
902
|
-
{
|
903
|
-
"from": sender_id,
|
904
|
-
"content": content,
|
905
|
-
}
|
906
|
-
)
|
907
|
-
|
908
|
-
async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dict]:
|
909
|
-
"""从多个智能体收集消息
|
910
|
-
|
911
|
-
Args:
|
912
|
-
agent_uuids: 目标智能体UUID列表
|
913
|
-
target: 要收集的信息类型
|
914
|
-
|
915
|
-
Returns:
|
916
|
-
list[dict]: 收集到的所有响应
|
917
|
-
"""
|
918
|
-
# 为每个agent创建Future
|
919
|
-
futures = {}
|
920
|
-
for agent_uuid in agent_uuids:
|
921
|
-
futures[agent_uuid] = asyncio.Future()
|
922
|
-
self._gather_responses[agent_uuid] = futures[agent_uuid]
|
923
|
-
|
924
|
-
# 发送gather请求
|
925
|
-
payload = {
|
926
|
-
"from": self._uuid,
|
927
|
-
"target": target,
|
928
|
-
}
|
929
|
-
for agent_uuid in agent_uuids:
|
930
|
-
await self._send_message(agent_uuid, payload, "gather")
|
931
|
-
|
932
|
-
try:
|
933
|
-
# 等待所有响应
|
934
|
-
responses = await asyncio.gather(*futures.values())
|
935
|
-
return responses
|
936
|
-
finally:
|
937
|
-
# 清理Future
|
938
|
-
for key in futures:
|
939
|
-
self._gather_responses.pop(key, None)
|
@@ -11,22 +11,32 @@ import logging
|
|
11
11
|
|
12
12
|
logger = logging.getLogger("pycityagent")
|
13
13
|
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
14
|
+
|
15
|
+
class BankAgent(InstitutionAgent):
|
16
|
+
def __init__(
|
17
|
+
self,
|
18
|
+
name: str,
|
19
|
+
llm_client: Optional[LLM] = None,
|
20
|
+
simulator: Optional[Simulator] = None,
|
21
|
+
memory: Optional[Memory] = None,
|
22
|
+
economy_client: Optional[EconomyClient] = None,
|
23
|
+
messager: Optional[Messager] = None,
|
24
|
+
avro_file: Optional[dict] = None,
|
25
|
+
) -> None:
|
26
|
+
super().__init__(
|
27
|
+
name=name,
|
28
|
+
llm_client=llm_client,
|
29
|
+
simulator=simulator,
|
30
|
+
memory=memory,
|
31
|
+
economy_client=economy_client,
|
32
|
+
messager=messager,
|
33
|
+
avro_file=avro_file,
|
34
|
+
)
|
25
35
|
self.initailzed = False
|
26
36
|
self.last_time_trigger = None
|
27
37
|
self.time_diff = 30 * 24 * 60 * 60
|
28
38
|
self.forward_times = 0
|
29
|
-
|
39
|
+
|
30
40
|
async def month_trigger(self):
|
31
41
|
now_time = await self.simulator.get_time()
|
32
42
|
if self.last_time_trigger is None:
|
@@ -36,19 +46,21 @@ class BankAgent(InstitutionAgent):
|
|
36
46
|
self.last_time_trigger = now_time
|
37
47
|
return True
|
38
48
|
return False
|
39
|
-
|
49
|
+
|
40
50
|
async def gather_messages(self, agent_ids, content):
|
41
51
|
infos = await super().gather_messages(agent_ids, content)
|
42
|
-
return [info[
|
52
|
+
return [info["content"] for info in infos]
|
43
53
|
|
44
54
|
async def forward(self):
|
45
55
|
if await self.month_trigger():
|
46
56
|
citizens = await self.memory.get("citizens")
|
47
57
|
while True:
|
48
|
-
agents_forward = await self.gather_messages(citizens,
|
58
|
+
agents_forward = await self.gather_messages(citizens, "forward")
|
49
59
|
if np.all(np.array(agents_forward) > self.forward_times):
|
50
60
|
break
|
51
61
|
await asyncio.sleep(1)
|
52
62
|
self.forward_times += 1
|
53
63
|
for uuid in citizens:
|
54
|
-
await self.send_message_to_agent(
|
64
|
+
await self.send_message_to_agent(
|
65
|
+
uuid, f"bank_forward@{self.forward_times}"
|
66
|
+
)
|
@@ -11,7 +11,8 @@ import logging
|
|
11
11
|
|
12
12
|
logger = logging.getLogger("pycityagent")
|
13
13
|
|
14
|
-
|
14
|
+
|
15
|
+
class FirmAgent(InstitutionAgent):
|
15
16
|
configurable_fields = ["time_diff", "max_price_inflation", "max_wage_inflation"]
|
16
17
|
default_values = {
|
17
18
|
"time_diff": 30 * 24 * 60 * 60,
|
@@ -19,16 +20,25 @@ class FirmAgent(InstitutionAgent):
|
|
19
20
|
"max_wage_inflation": 0.05,
|
20
21
|
}
|
21
22
|
|
22
|
-
def __init__(
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
23
|
+
def __init__(
|
24
|
+
self,
|
25
|
+
name: str,
|
26
|
+
llm_client: Optional[LLM] = None,
|
27
|
+
simulator: Optional[Simulator] = None,
|
28
|
+
memory: Optional[Memory] = None,
|
29
|
+
economy_client: Optional[EconomyClient] = None,
|
30
|
+
messager: Optional[Messager] = None,
|
31
|
+
avro_file: Optional[dict] = None,
|
32
|
+
) -> None:
|
33
|
+
super().__init__(
|
34
|
+
name=name,
|
35
|
+
llm_client=llm_client,
|
36
|
+
simulator=simulator,
|
37
|
+
memory=memory,
|
38
|
+
economy_client=economy_client,
|
39
|
+
messager=messager,
|
40
|
+
avro_file=avro_file,
|
41
|
+
)
|
32
42
|
self.initailzed = False
|
33
43
|
self.last_time_trigger = None
|
34
44
|
self.forward_times = 0
|
@@ -45,31 +55,59 @@ class FirmAgent(InstitutionAgent):
|
|
45
55
|
self.last_time_trigger = now_time
|
46
56
|
return True
|
47
57
|
return False
|
48
|
-
|
58
|
+
|
49
59
|
async def gather_messages(self, agent_ids, content):
|
50
60
|
infos = await super().gather_messages(agent_ids, content)
|
51
|
-
return [info[
|
61
|
+
return [info["content"] for info in infos]
|
52
62
|
|
53
63
|
async def forward(self):
|
54
64
|
if await self.month_trigger():
|
55
65
|
employees = await self.memory.get("employees")
|
56
66
|
while True:
|
57
|
-
agents_forward = await self.gather_messages(employees,
|
67
|
+
agents_forward = await self.gather_messages(employees, "forward")
|
58
68
|
if np.all(np.array(agents_forward) > self.forward_times):
|
59
69
|
break
|
60
70
|
await asyncio.sleep(1)
|
61
|
-
goods_demand = await self.gather_messages(employees,
|
62
|
-
goods_consumption = await self.gather_messages(
|
63
|
-
|
71
|
+
goods_demand = await self.gather_messages(employees, "goods_demand")
|
72
|
+
goods_consumption = await self.gather_messages(
|
73
|
+
employees, "goods_consumption"
|
74
|
+
)
|
75
|
+
print(
|
76
|
+
f"goods_demand: {goods_demand}, goods_consumption: {goods_consumption}"
|
77
|
+
)
|
64
78
|
total_demand = sum(goods_demand)
|
65
|
-
last_inventory = sum(goods_consumption) + await self.economy_client.get(
|
66
|
-
|
67
|
-
|
68
|
-
|
79
|
+
last_inventory = sum(goods_consumption) + await self.economy_client.get(
|
80
|
+
self._agent_id, "inventory"
|
81
|
+
)
|
82
|
+
print(
|
83
|
+
f"total_demand: {total_demand}, last_inventory: {last_inventory}, goods_contumption: {sum(goods_consumption)}"
|
84
|
+
)
|
85
|
+
max_change_rate = (total_demand - last_inventory) / (
|
86
|
+
max(total_demand, last_inventory) + 1e-8
|
87
|
+
)
|
88
|
+
skills = await self.gather_messages(employees, "work_skill")
|
69
89
|
for skill, uuid in zip(skills, employees):
|
70
|
-
await self.send_message_to_agent(
|
71
|
-
|
72
|
-
|
90
|
+
await self.send_message_to_agent(
|
91
|
+
uuid,
|
92
|
+
f"work_skill@{max(skill*(1 + np.random.uniform(0, max_change_rate*self.max_wage_inflation)), 1)}",
|
93
|
+
)
|
94
|
+
price = await self.economy_client.get(self._agent_id, "price")
|
95
|
+
await self.economy_client.update(
|
96
|
+
self._agent_id,
|
97
|
+
"price",
|
98
|
+
max(
|
99
|
+
price
|
100
|
+
* (
|
101
|
+
1
|
102
|
+
+ np.random.uniform(
|
103
|
+
0, max_change_rate * self.max_price_inflation
|
104
|
+
)
|
105
|
+
),
|
106
|
+
1,
|
107
|
+
),
|
108
|
+
)
|
73
109
|
self.forward_times += 1
|
74
110
|
for uuid in employees:
|
75
|
-
await self.send_message_to_agent(
|
111
|
+
await self.send_message_to_agent(
|
112
|
+
uuid, f"firm_forward@{self.forward_times}"
|
113
|
+
)
|