pycityagent 2.0.0a1__py3-none-any.whl → 2.0.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +1 -1
- pycityagent/environment/sim/__init__.py +2 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/simulator.py +110 -159
- pycityagent/environment/utils/__init__.py +3 -1
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +242 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/simulation/simulation.py +153 -87
- pycityagent/workflow/__init__.py +5 -3
- pycityagent/workflow/block.py +31 -4
- pycityagent/workflow/trigger.py +108 -24
- {pycityagent-2.0.0a1.dist-info → pycityagent-2.0.0a3.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a1.dist-info → pycityagent-2.0.0a3.dist-info}/RECORD +15 -11
- {pycityagent-2.0.0a1.dist-info → pycityagent-2.0.0a3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,242 @@
|
|
1
|
+
# POI category table: maps a Chinese category name (the dict key, used verbatim
# at runtime) to a list of POI type strings belonging to that category.
# NOTE(review): the type strings presumably correspond to OpenStreetMap
# amenity/leisure tag values — confirm against the map data source.
# Categories overlap: the same type string can appear under several keys
# (e.g. "bar" is in both 室内娱乐场所 and 餐饮服务), so this is not a partition.
POI_CATG_DICT = {
    # Outdoor activities
    "户外活动": [
        "bandstand",
        "beach_resort",
        "bird_hide",
        "bleachers",
        "firepit",
        "fishing",
        "garden",
        "nature_reserve",
        "park",
        "picnic_table",
        "playground",
        "resort",
        "summer_camp",
        "swimming_area",
        "water_park",
        "wildlife_hide",
    ],
    # Indoor entertainment venues. This list also repeats every entry of the
    # food-and-drink (餐饮服务) and culture-and-arts (文化艺术) categories below.
    # NOTE(review): "disc_golf_course" / "golf_course" are usually outdoor —
    # confirm they belong here.
    "室内娱乐场所": [
        "adult_gaming_centre",
        "amusement_arcade",
        "bowling_alley",
        "disc_golf_course",
        "escape_game",
        "fitness_centre",
        "fitness_station",
        "golf_course",
        "miniature_golf",
        "sauna",
        "tanning_salon",
        "trampoline_park",
        "bar",
        "biergarten",
        "cafe",
        "fast_food",
        "food_court",
        "ice_cream",
        "pub",
        "restaurant",
        "arts_centre",
        "brothel",
        "casino",
        "cinema",
        "community_centre",
        "conference_centre",
        "events_venue",
        "exhibition_centre",
        "fountain",
        "gambling",
        "love_hotel",
        "music_venue",
        "nightclub",
        "planetarium",
        "public_bookcase",
        "social_centre",
        "stage",
        "stripclub",
        "studio",
        "swingerclub",
        "theatre",
    ],
    # Sports facilities
    "体育设施": [
        "horse_riding",
        "ice_rink",
        "marina",
        "pitch",
        "sports_centre",
        "sports_hall",
        "stadium",
        "track",
        "swimming_pool",
    ],
    # Water activities
    # NOTE(review): "ice_rink" is listed here as a water activity — confirm intended.
    "水上活动": [
        "beach_resort",
        "ice_rink",
        "marina",
        "slipway",
        "swimming_area",
        "swimming_pool",
        "water_park",
    ],
    # Nature and wildlife viewing
    "自然与野生动物观赏": [
        "bird_hide",
        "nature_reserve",
        "wildlife_hide",
        "hunting_stand",
    ],
    # Children's play areas
    "儿童游乐区": ["playground", "summer_camp", "miniature_golf", "dog_park"],
    # Food and drink services
    "餐饮服务": [
        "bar",
        "biergarten",
        "cafe",
        "fast_food",
        "food_court",
        "ice_cream",
        "pub",
        "restaurant",
    ],
    # Educational institutions
    "教育机构": [
        "college",
        "dancing_school",
        "driving_school",
        "first_aid_school",
        "kindergarten",
        "language_school",
        "library",
        "surf_school",
        "toy_library",
        "research_institute",
        "training",
        "music_school",
        "school",
        "traffic_park",
        "university",
    ],
    # Transportation facilities
    "交通设施": [
        "bicycle_parking",
        "bicycle_repair_station",
        "bicycle_rental",
        "bicycle_wash",
        "boat_rental",
        "boat_sharing",
        "bus_station",
        "car_rental",
        "car_sharing",
        "car_wash",
        "compressed_air",
        "vehicle_inspection",
        "charging_station",
        "driver_training",
        "ferry_terminal",
        "fuel",
        "grit_bin",
        "motorcycle_parking",
        "parking",
        "parking_entrance",
        "parking_space",
        "taxi",
        "weighbridge",
    ],
    # Financial services
    "金融服务": [
        "atm",
        "payment_terminal",
        "bank",
        "bureau_de_change",
        "money_transfer",
        "payment_centre",
    ],
    # Healthcare
    "医疗保健": [
        "baby_hatch",
        "clinic",
        "dentist",
        "doctors",
        "hospital",
        "nursing_home",
        "pharmacy",
        "social_facility",
        "veterinary",
    ],
    # Culture and arts
    "文化艺术": [
        "arts_centre",
        "brothel",
        "casino",
        "cinema",
        "community_centre",
        "conference_centre",
        "events_venue",
        "exhibition_centre",
        "fountain",
        "gambling",
        "love_hotel",
        "music_venue",
        "nightclub",
        "planetarium",
        "public_bookcase",
        "social_centre",
        "stage",
        "stripclub",
        "studio",
        "swingerclub",
        "theatre",
    ],
    # Public services
    "公共服务": [
        "bbq",
        "bench",
        "dog_toilet",
        "dressing_room",
        "drinking_water",
        "give_box",
        "lounge",
        "mailroom",
        "parcel_locker",
        "shelter",
        "shower",
        "telephone",
        "toilets",
        "water_point",
        "watering_place",
        "sanitary_dump_station",
        "recycling",
        "waste_basket",
        "waste_disposal",
        "waste_transfer_station",
        "post_box",
        "post_depot",
        "post_office",
        "courthouse",
        "fire_station",
        "police",
        "prison",
        "ranger_station",
        "townhall",
    ],
    # Other special-purpose places
    "其他特殊用途": [
        "animal_boarding",
        "animal_breeding",
        "animal_shelter",
        "animal_training",
        "baking_oven",
        "clock",
        "crematorium",
        "dive_centre",
        "funeral_hall",
        "grave_yard",
        "hunting_stand",
        "internet_cafe",
        "kitchen",
        "kneipp_water_cure",
        "lounger",
        "marketplace",
        "monastery",
        "mortuary",
        "photo_booth",
        "place_of_mourning",
        "place_of_worship",
        "public_bath",
        "public_building",
        "refugee_site",
        "vending_machine",
    ],
}
|
@@ -0,0 +1,11 @@
|
|
1
|
+
import socket
|
2
|
+
from contextlib import closing
|
3
|
+
|
4
|
+
__all__ = ["find_free_port"]
|
5
|
+
|
6
|
+
|
7
|
+
def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    A temporary IPv4 TCP socket is bound to port 0 on all interfaces so the
    kernel assigns a free port. The assigned port number is read back, and the
    socket is closed before returning. Because the socket is closed, another
    process could claim the port before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.bind(("", 0))
        # NOTE(review): set after bind(), so it has no effect on the bind above.
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = probe.getsockname()
        return port
|
@@ -1,13 +1,18 @@
|
|
1
1
|
import asyncio
|
2
2
|
import json
|
3
3
|
import logging
|
4
|
-
from datetime import datetime
|
5
|
-
|
4
|
+
from datetime import datetime
|
5
|
+
import random
|
6
|
+
from typing import Dict, List, Optional, Callable
|
7
|
+
from mosstool.map._map_util.const import AOI_START_ID
|
8
|
+
|
9
|
+
from pycityagent.llm.llm import LLM
|
10
|
+
from pycityagent.memory.memory import Memory
|
6
11
|
|
7
12
|
from ..agent import Agent
|
8
13
|
from ..environment import Simulator
|
9
14
|
from .interview import InterviewManager
|
10
|
-
from .survey import QuestionType,
|
15
|
+
from .survey import QuestionType, SurveyManager
|
11
16
|
from .ui import InterviewUI
|
12
17
|
|
13
18
|
logger = logging.getLogger(__name__)
|
@@ -15,9 +20,18 @@ logger = logging.getLogger(__name__)
|
|
15
20
|
|
16
21
|
class AgentSimulation:
|
17
22
|
"""城市智能体模拟器"""
|
18
|
-
|
19
|
-
|
23
|
+
def __init__(self, agent_class: type[Agent], simulator: Simulator, llm: LLM, agent_prefix: str = "agent_"):
|
24
|
+
"""
|
25
|
+
Args:
|
26
|
+
agent_class: 智能体类
|
27
|
+
simulator: 模拟器
|
28
|
+
llm: 语言模型
|
29
|
+
agent_prefix: 智能体名称前缀
|
30
|
+
"""
|
31
|
+
self.agent_class = agent_class
|
20
32
|
self.simulator = simulator
|
33
|
+
self.llm = llm
|
34
|
+
self.agent_prefix = agent_prefix
|
21
35
|
self._agents: Dict[str, Agent] = {}
|
22
36
|
self._interview_manager = InterviewManager()
|
23
37
|
self._interview_lock = asyncio.Lock()
|
@@ -28,20 +42,77 @@ class AgentSimulation:
|
|
28
42
|
self._blocked_agents: List[str] = [] # 新增:持续阻塞的智能体列表
|
29
43
|
self._survey_manager = SurveyManager()
|
30
44
|
|
31
|
-
def
|
32
|
-
"""
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
+
def init_agents(self, agent_count: int, memory_config_func: Optional[Callable] = None) -> None:
    """Create ``agent_count`` agents and register them in ``self._agents``.

    Agent ``i`` is named ``f"{self.agent_prefix}{i}"``. An agent already
    registered under that name is replaced.

    Args:
        agent_count: Number of agents to create.
        memory_config_func: Zero-argument callable returning an
            ``(extra_attributes, profile, base)`` tuple that configures each
            agent's ``Memory``. It is called once per agent. Defaults to
            ``self.default_memory_config_func``.
    """
    if memory_config_func is None:
        memory_config_func = self.default_memory_config_func

    for i in range(agent_count):
        agent_name = f"{self.agent_prefix}{i}"

        # Fetch a fresh Memory configuration for this agent.
        extra_attributes, profile, base = memory_config_func()
        memory = Memory(
            config=extra_attributes,
            # Shallow copies, so a config func that returns the same dict
            # objects on every call does not share the top-level profile/base
            # dicts between agents (extra_attributes is passed through as-is).
            profile=profile.copy(),
            base=base.copy(),
        )

        # Build the agent with its Memory and register it by name.
        agent = self.agent_class(
            name=agent_name,
            simulator=self.simulator,
            llm=self.llm,
            memory=memory,
        )
        self._agents[agent_name] = agent
|
75
|
+
|
76
|
+
def default_memory_config_func(self):
    """Build a randomized default Memory configuration for one agent.

    Returns:
        A ``(EXTRA_ATTRIBUTES, PROFILE, BASE)`` tuple.

        - EXTRA_ATTRIBUTES maps an attribute name to a ``(type, default, flag)``
          triple. Every flag here is ``True``; its meaning is defined by
          ``Memory`` (TODO confirm there).
        - PROFILE holds randomly sampled demographic fields.
        - BASE holds ``home`` and ``work`` locations, each an AOI id drawn
          uniformly from ``self.simulator.aois``.

    Raises:
        ValueError: If ``self.simulator.aois`` is empty.
    """
    EXTRA_ATTRIBUTES = {
        # Needs, each sampled uniformly from [0, 1)
        "needs": (dict, {
            'hungry': random.random(),  # hunger
            'tired': random.random(),  # fatigue
            'safe': random.random(),  # safety need
            'social': random.random(),  # social need
        }, True),
        "current_need": (str, "none", True),
        "current_plan": (list, [], True),
        "current_step": (dict, {"intention": "", "type": ""}, True),
        "execution_context" : (dict, {}, True),
        "plan_history": (list, [], True),
    }

    # NOTE(review): "sightly low" and "outgoint" look like typos for
    # "slightly low" / "outgoing". They are kept verbatim in case other code or
    # prompts match these exact strings.
    PROFILE = {
        "gender": random.choice(["male", "female"]),
        "education": random.choice(["Doctor", "Master", "Bachelor", "College", "High School"]),
        "consumption": random.choice(["sightly low", "low", "medium", "high"]),
        "occupation": random.choice(["Student", "Teacher", "Doctor", "Engineer", "Manager", "Businessman", "Artist", "Athlete", "Other"]),
        "age": random.randint(18, 65),
        "skill": random.choice(["Good at problem-solving", "Good at communication", "Good at creativity", "Good at teamwork", "Other"]),
        "family_consumption": random.choice(["low", "medium", "high"]),
        "personality": random.choice(["outgoint", "introvert", "ambivert", "extrovert"]),
        "income": random.randint(1000, 10000),
        "residence": random.choice(["city", "suburb", "rural"]),
        "race": random.choice(["Chinese", "American", "British", "French", "German", "Japanese", "Korean", "Russian", "Other"]),
        "religion": random.choice(["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]),
        "marital_status": random.choice(["not married", "married", "divorced", "widowed"]),
    }

    # random.choice() needs an indexable sequence. A dict_keys view is not
    # subscriptable and raises TypeError, so materialize the keys first.
    aois = list(self.simulator.aois.keys())
    if not aois:
        raise ValueError("simulator has no AOIs to choose home/work locations from")
    BASE = {
        "home": {"aoi_position": {"aoi_id": random.choice(aois)}},
        "work": {"aoi_position": {"aoi_id": random.choice(aois)}},
    }

    return EXTRA_ATTRIBUTES, PROFILE, BASE
|
45
116
|
|
46
117
|
def get_agent_runtime(self, agent_name: str) -> str:
|
47
118
|
"""获取智能体运行时间"""
|
@@ -133,74 +204,7 @@ class AgentSimulation:
|
|
133
204
|
except Exception as e:
|
134
205
|
logger.error(f"采访过程出错: {str(e)}")
|
135
206
|
return f"采访过程出现错误: {str(e)}"
|
136
|
-
|
137
|
-
async def run(
|
138
|
-
self,
|
139
|
-
steps: int = -1,
|
140
|
-
interval: float = 1.0,
|
141
|
-
start_ui: bool = True,
|
142
|
-
server_name: str = "127.0.0.1",
|
143
|
-
server_port: int = 7860,
|
144
|
-
):
|
145
|
-
"""运行模拟器
|
146
|
-
|
147
|
-
Args:
|
148
|
-
steps: 运行步数,默认为-1表示无限运行
|
149
|
-
interval: 智能体forward间隔时间,单位为秒,默认1秒
|
150
|
-
start_ui: 是否启动UI,默认为True
|
151
|
-
server_name: UI服务器地址,默认为"127.0.0.1"
|
152
|
-
server_port: UI服务器端口,默认为7860
|
153
|
-
"""
|
154
|
-
try:
|
155
|
-
self._interview_lock = asyncio.Lock()
|
156
|
-
# 初始化UI
|
157
|
-
if start_ui:
|
158
|
-
self._ui = InterviewUI(self)
|
159
|
-
interface = self._ui.create_interface()
|
160
|
-
interface.queue().launch(
|
161
|
-
server_name=server_name,
|
162
|
-
server_port=server_port,
|
163
|
-
prevent_thread_lock=True,
|
164
|
-
quiet=True,
|
165
|
-
)
|
166
|
-
print(
|
167
|
-
f"Gradio Frontend is running on http://{server_name}:{server_port}"
|
168
|
-
)
|
169
|
-
|
170
|
-
# 运行所有agents
|
171
|
-
tasks = []
|
172
|
-
for agent in self._agents.values():
|
173
|
-
tasks.append(self._run_agent(agent, steps, interval))
|
174
|
-
|
175
|
-
await asyncio.gather(*tasks)
|
176
|
-
|
177
|
-
except Exception as e:
|
178
|
-
logger.error(f"模拟器运行错误: {str(e)}")
|
179
|
-
raise
|
180
|
-
|
181
|
-
async def _run_agent(self, agent: Agent, steps: int = -1, interval: float = 1.0):
|
182
|
-
"""运行单个agent的包装器
|
183
|
-
|
184
|
-
Args:
|
185
|
-
agent: 要运行的能体
|
186
|
-
steps: 运行步数,默认为-1表示无限运行
|
187
|
-
interval: 智能体forward间隔时间,单位为秒
|
188
|
-
"""
|
189
|
-
step_count = 0
|
190
|
-
while steps == -1 or step_count < steps:
|
191
|
-
try:
|
192
|
-
if agent._name in self._blocked_agents:
|
193
|
-
await asyncio.sleep(interval)
|
194
|
-
continue
|
195
|
-
|
196
|
-
await agent.forward()
|
197
|
-
await asyncio.sleep(interval) # 控制运行频率
|
198
|
-
step_count += 1
|
199
|
-
|
200
|
-
except Exception as e:
|
201
|
-
logger.error(f"智能体 {agent._name} 运行错误: {str(e)}")
|
202
|
-
await asyncio.sleep(interval) # 发生错误时暂停一下
|
203
|
-
|
207
|
+
|
204
208
|
async def submit_survey(self, agent_name: str, survey_id: str) -> str:
|
205
209
|
"""向智能体提交问卷
|
206
210
|
|
@@ -284,3 +288,65 @@ class AgentSimulation:
|
|
284
288
|
if survey_dict["id"] == survey_id:
|
285
289
|
return survey_dict
|
286
290
|
return None
|
291
|
+
|
292
|
+
async def init_ui(
    self,
    server_name: str = "127.0.0.1",
    server_port: int = 7860,
):
    """Launch the Gradio interview UI without blocking the caller.

    Replaces ``self._interview_lock`` with a new ``asyncio.Lock`` and builds
    the interface via ``InterviewUI(self).create_interface()``. The interface
    is then launched with queuing enabled, and its address is printed.

    Args:
        server_name: Address the Gradio server listens on.
        server_port: Port the Gradio server listens on.
    """
    # Fresh lock created inside the running coroutine.
    self._interview_lock = asyncio.Lock()

    # Build and launch the Gradio UI.
    self._ui = InterviewUI(self)
    gradio_app = self._ui.create_interface()
    queued_app = gradio_app.queue()
    queued_app.launch(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,  # do not block the calling thread
        quiet=True,
    )
    print(f"Gradio Frontend is running on http://{server_name}:{server_port}")
|
311
|
+
|
312
|
+
async def step(self):
    """Advance the simulation one step: every agent runs ``forward()`` once.

    All agents in ``self._agents`` are run concurrently via
    ``asyncio.gather``. Any exception is logged and re-raised.
    """
    # NOTE(review): unlike `run`, this does not skip agents listed in
    # `self._blocked_agents` — confirm that is intended.
    try:
        pending = [member.forward() for member in self._agents.values()]
        await asyncio.gather(*pending)
    except Exception as e:
        logger.error(f"运行错误: {str(e)}")
        raise
|
322
|
+
|
323
|
+
async def run(
    self,
    day: int = 1,
    *,
    idle_interval: float = 1.0,
):
    """Run the simulation for a number of simulated days.

    Each iteration reads the simulator clock. While it is before the end time,
    every agent whose name is not in ``self._blocked_agents`` runs one
    ``forward()``, all concurrently.

    Args:
        day: Number of simulated days to run (default 1). The end time is
            ``start_time + day * 24 * 3600`` on the simulator clock.
        idle_interval: Seconds to sleep when no agent is runnable (all
            blocked, or none registered). Without this sleep the loop would
            spin without ever yielding to the event loop, starving other
            coroutines (e.g. the interview UI) that could unblock agents.

    Raises:
        Re-raises any exception from the clock or from an agent's
        ``forward()`` after logging it.
    """
    try:
        # NOTE(review): GetTime() is called synchronously and treated as
        # seconds — confirm against the Simulator implementation.
        start_time = self.simulator.GetTime()
        end_time = start_time + day * 24 * 3600  # days -> seconds

        while True:
            current_time = self.simulator.GetTime()
            if current_time >= end_time:
                break

            tasks = [
                agent.forward()
                for agent in self._agents.values()
                if agent.name not in self._blocked_agents
            ]

            if tasks:
                await asyncio.gather(*tasks)
            else:
                # `await asyncio.gather()` with no arguments finishes without
                # suspending, so explicitly yield control to the event loop.
                await asyncio.sleep(idle_interval)

    except Exception as e:
        logger.exception(f"模拟器运行错误: {str(e)}")
        raise
|
pycityagent/workflow/__init__.py
CHANGED
@@ -4,19 +4,21 @@
|
|
4
4
|
This module contains classes for creating blocks and running workflows.
|
5
5
|
"""
|
6
6
|
|
7
|
-
from .block import Block, log_and_check, log_and_check_with_memory
|
7
|
+
from .block import Block, log_and_check, log_and_check_with_memory, trigger_class
|
8
8
|
from .prompt import FormatPrompt
|
9
9
|
from .tool import GetMap, SencePOI, Tool
|
10
|
-
from .trigger import MemoryChangeTrigger,
|
10
|
+
from .trigger import MemoryChangeTrigger, TimeTrigger, EventTrigger
|
11
11
|
|
12
12
|
__all__ = [
|
13
13
|
"SencePOI",
|
14
14
|
"Tool",
|
15
15
|
"GetMap",
|
16
16
|
"MemoryChangeTrigger",
|
17
|
-
"
|
17
|
+
"TimeTrigger",
|
18
|
+
"EventTrigger",
|
18
19
|
"Block",
|
19
20
|
"log_and_check",
|
20
21
|
"log_and_check_with_memory",
|
21
22
|
"FormatPrompt",
|
23
|
+
"trigger_class",
|
22
24
|
]
|
pycityagent/workflow/block.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
import asyncio
|
2
2
|
import functools
|
3
3
|
import inspect
|
4
|
-
import time
|
5
4
|
from typing import Any, Callable, Coroutine, Optional, Union
|
6
5
|
|
6
|
+
from pycityagent.environment.simulator import Simulator
|
7
|
+
from pycityagent.workflow.trigger import EventTrigger
|
8
|
+
|
7
9
|
from ..llm import LLM
|
8
10
|
from ..memory import Memory
|
9
11
|
from ..utils.decorators import record_call_aio
|
10
12
|
|
11
|
-
TRIGGER_INTERVAL =
|
13
|
+
TRIGGER_INTERVAL = 1
|
12
14
|
|
13
15
|
|
14
16
|
def log_and_check_with_memory(
|
@@ -119,17 +121,42 @@ def log_and_check(
|
|
119
121
|
return decorator
|
120
122
|
|
121
123
|
|
124
|
+
def trigger_class():
    """Return a class decorator that gates ``forward`` on the instance's trigger.

    The decorated class's ``forward`` is replaced by an async wrapper. When
    ``self.trigger`` is not ``None``, the wrapper first awaits
    ``self.trigger.wait_for_trigger()``. It then delegates to the original
    ``forward`` with the same arguments and returns its result.
    ``functools.wraps`` preserves the original's name and docstring.
    """

    def decorator(cls):
        inner_forward = cls.forward

        async def gated_forward(self, *args, **kwargs):
            # Block until the trigger fires, if one is attached.
            if self.trigger is not None:
                await self.trigger.wait_for_trigger()
            return await inner_forward(self, *args, **kwargs)

        cls.forward = functools.wraps(inner_forward)(gated_forward)
        return cls

    return decorator
|
137
|
+
|
138
|
+
|
122
139
|
# Define a Block, similar to a layer in PyTorch
|
123
140
|
class Block:
|
124
141
|
def __init__(
    self,
    name: str,
    llm: Optional[LLM] = None,
    memory: Optional[Memory] = None,
    simulator: Optional[Simulator] = None,
    trigger: Optional[EventTrigger] = None,
):
    """Store the block's name and optional collaborators, and bind a trigger.

    Args:
        name: Block name, stored as ``self.name``.
        llm: Optional LLM, stored as ``self.llm``.
        memory: Optional Memory, stored as ``self.memory``.
        simulator: Optional Simulator, stored as ``self.simulator``.
        trigger: Optional EventTrigger. When given, its ``block`` attribute is
            set to this block and its ``initialize()`` is called before it is
            stored as ``self.trigger``. Otherwise ``self.trigger`` is ``None``.
    """
    self.name, self.llm = name, llm
    self.memory, self.simulator = memory, simulator
    if trigger is not None:
        # Inject this block into the trigger and initialize it right away.
        trigger.block = self
        trigger.initialize()
    self.trigger = trigger
|
158
|
+
|
159
|
+
async def forward(self):
|
133
160
|
"""
|
134
161
|
Each block performs a specific reasoning task.
|
135
162
|
To be overridden by specific block implementations.
|