pycityagent 1.1.9__py3-none-any.whl → 1.1.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- pycityagent/__init__.py +2 -1
- pycityagent/ac/__init__.py +1 -1
- pycityagent/ac/citizen_actions/trip.py +2 -3
- pycityagent/ac/hub_actions.py +216 -16
- pycityagent/agent.py +2 -0
- pycityagent/agent_citizen.py +10 -2
- pycityagent/agent_func.py +93 -13
- pycityagent/agent_group.py +84 -0
- pycityagent/brain/brain.py +1 -0
- pycityagent/brain/memory.py +15 -15
- pycityagent/brain/scheduler.py +1 -1
- pycityagent/brain/sence.py +74 -140
- pycityagent/hubconnector/__init__.py +1 -1
- pycityagent/hubconnector/hubconnector.py +353 -8
- pycityagent/simulator.py +149 -7
- pycityagent/urbanllm/urbanllm.py +305 -5
- pycityagent/utils.py +178 -0
- {pycityagent-1.1.9.dist-info → pycityagent-1.1.11.dist-info}/METADATA +11 -11
- {pycityagent-1.1.9.dist-info → pycityagent-1.1.11.dist-info}/RECORD +22 -20
- {pycityagent-1.1.9.dist-info → pycityagent-1.1.11.dist-info}/WHEEL +1 -1
- {pycityagent-1.1.9.dist-info → pycityagent-1.1.11.dist-info}/LICENSE +0 -0
- {pycityagent-1.1.9.dist-info → pycityagent-1.1.11.dist-info}/top_level.txt +0 -0
pycityagent/simulator.py
CHANGED
@@ -1,6 +1,6 @@
 """Simulator: 城市模拟器类及其定义"""
 
-from typing import Optional, Union
+from typing import Optional, Union, Tuple
 from datetime import datetime, timedelta
 import asyncio
 from pycitysim import *
@@ -58,7 +58,7 @@ class Simulator:
 
         self.map = map.Map(
             mongo_uri = "mongodb://sim:FiblabSim1001@mgo.db.fiblab.tech:8635/",
-            mongo_db =
+            mongo_db = config['map_request']['mongo_db'],
             mongo_coll = config['map_request']['mongo_coll'],
             cache_dir = config['map_request']['cache_dir'],
         )
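
The fix supplies the previously missing `mongo_db` value from the config. For orientation, the `map_request` block consumed here would look roughly like this (a sketch only: the keys are the ones read above, the values are hypothetical):

    config = {
        'map_request': {
            'mongo_db': 'llmsim',          # hypothetical database name
            'mongo_coll': 'map_beijing',   # hypothetical collection name
            'cache_dir': './cache',
        },
        # plus 'route_request', 'simulator', and 'streetview_request'
        # blocks, which later hunks read the same way
    }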
@@ -67,6 +67,12 @@
         - Simulator map object
         """
 
+        self.pois_matrix: dict[str, list[list[list]]] = {}
+        """
+        pois的基于区块的划分——方便快速粗略地查询poi
+        通过Simulator.set_pois_matrix()初始化
+        """
+
         self.routing = RoutingClient(self.config['route_request']['server'])
         """
         - 导航服务grpc客户端
@@ -78,6 +84,20 @@
         - 模拟城市当前时间
         - The current time of simulator
         """
+        self.poi_cate = {'10': 'eat',
+                         '13': 'shopping',
+                         '18': 'sports',
+                         '22': 'excursion',
+                         '16': 'entertainment',
+                         '20': 'medical tratment',
+                         '14': 'trivialities',
+                         '25': 'financial',
+                         '12': 'government and political services',
+                         '23': 'cultural institutions',
+                         '28': 'residence'}
+        self.map_x_gap = None
+        self.map_y_gap = None
+        self.poi_matrix_centers = []
 
     # * Agent相关
     def FindAgentsByArea(self, req: dict, status=None):
@@ -105,7 +125,7 @@
         resp.motions = motions
         return resp
 
-    async def GetCitizenAgent(self, name:str, id:int):
+    async def GetCitizenAgent(self, name:str, id:int) -> CitizenAgent:
         """
         获取agent
         Get Agent
@@ -119,8 +139,9 @@
         """
         await self.GetTime()
         resp = await self._client.person_service.GetPerson({"person_id": id})
-
-
+        print(f"Agent {id}: {resp}")
+        base = resp['person']['base']
+        motion = resp['person']['motion']
         agent = CitizenAgent(
             name,
             self.config['simulator']['server'],
@@ -132,7 +153,7 @@
         agent.set_streetview_config(self.config['streetview_request'])
         return agent
 
-    async def GetFuncAgent(self,
+    async def GetFuncAgent(self, name:str) -> FuncAgent:
         """
         获取一个Func Agent模板
 
@@ -145,7 +166,6 @@
         """
         agent = FuncAgent(
             name,
-            id+10000000,
             self.config['simulator']['server'],
             simulator=self
         )
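
After these two signature changes, `GetCitizenAgent` returns a typed `CitizenAgent` for an existing simulator person, and `GetFuncAgent` takes only a name (the `id+10000000` offset is gone). A call sketch, assuming an initialized `Simulator`; the name and person id are hypothetical:

    citizen = await sim.GetCitizenAgent('Alice', 10023)   # hypothetical person id
    helper = await sim.GetFuncAgent('data-collector')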
@@ -160,6 +180,128 @@
         """
         print("Not Implemented Yet")
         pass
+
+    def set_poi_matrix(self, map:dict=None, row_number:int=12, col_number:int=10, radius:int=10000):
+        """
+        初始化pois_matrix
+
+        Args:
+        - map (dict): 地图参数
+                east, west, north, south
+        - row_number (int): 行数
+        - col_number (int): 列数
+        - radius (int): 搜索半径, 单位m
+        """
+        if map == None:
+            self.matrix_map = self.map
+        else:
+            self.matrix_map = map
+        print(f"Building Poi searching matrix, Row_number: {row_number}, Col_number: {col_number}, Radius: {radius}m")
+        self.map_x_gap = (self.matrix_map.header['east'] - self.matrix_map.header['west']) / col_number
+        self.map_y_gap = (self.matrix_map.header['north'] - self.matrix_map.header['south']) / row_number
+        for i in range(row_number):
+            self.poi_matrix_centers.append([])
+            for j in range(col_number):
+                center_x = self.matrix_map.header['west'] + self.map_x_gap*j + self.map_x_gap/2
+                center_y = self.matrix_map.header['south'] + self.map_y_gap*i + self.map_y_gap/2
+                self.poi_matrix_centers[i].append((center_x, center_y))
+
+        for pre in self.poi_cate.keys():
+            print(f"Building matrix for Poi category: {pre}")
+            self.pois_matrix[pre] = []
+            for row_centers in self.poi_matrix_centers:
+                row_pois = []
+                for center in row_centers:
+                    pois = self.map.query_pois(center=center, radius=radius, category_prefix=pre)
+                    row_pois.append(pois)
+                self.pois_matrix[pre].append(row_pois)
+        print("Finished")
+
+    def get_pois_from_matrix(self, center:Tuple[float, float], prefix:str):
+        """
+        从poi搜索矩阵中快速获取poi
+
+        Args:
+        - center (Tuple[float, float]): 位置信息
+        - prefix (str): 类型前缀
+        """
+        if self.map_x_gap == None:
+            print("Set Poi Matrix first")
+            return
+        elif prefix not in self.poi_cate.keys():
+            print(f"Wrong prefix, only {self.poi_cate.keys()} is usable")
+            return
+        elif center[0] > self.matrix_map.header['east'] or center[0] < self.matrix_map.header['west'] or center[1] > self.matrix_map.header['north'] or center[1] < self.matrix_map.header['south']:
+            print("Wrong center")
+            return
+
+        # 矩阵匹配
+        rows = int((center[1]-self.matrix_map.header['south'])/self.map_y_gap)
+        cols = int((center[0]-self.matrix_map.header['west'])/self.map_x_gap)
+        pois = self.pois_matrix[prefix][rows][cols]
+        return pois
+
+    def get_cat_from_pois(self, pois:list):
+        cat_2_num = {}
+        for poi in pois:
+            cate = poi['category'][:2]
+            if cate not in self.poi_cate.keys():
+                continue
+            if cate in cat_2_num.keys():
+                cat_2_num[cate] += 1
+            else:
+                cat_2_num[cate] = 1
+        max_cat = ""
+        max_num = 0
+        for key in cat_2_num.keys():
+            if cat_2_num[key] > max_num:
+                max_num = cat_2_num[key]
+                max_cat = self.poi_cate[key]
+        return max_cat
+
+    def get_poi_matrix_in_rec(self, center:Tuple[float, float], radius:int=2500, rows:int=5, cols:int=5):
+        """
+        获取以center为中心的正方形区域内的poi集合
+
+        Args:
+        - center (Tuple[float, float]): 中心位置点
+        - radius (int): 半径
+        """
+        north = center[1] + radius
+        south = center[1] - radius
+        west = center[0] - radius
+        east = center[0] + radius
+        x_gap = (east-west)/cols
+        y_gap = (north-south)/rows
+        matrix = []
+        for i in range(rows):
+            matrix.append([])
+            for j in range(cols):
+                matrix[i].append([])
+        pois = []
+        for poi in self.map.pois.values():
+            x = poi['position']['x']
+            y = poi['position']['y']
+            if x > west and x < east and y > south and y < north:
+                row_index = int((y-south)/x_gap)
+                col_index = int((x-west)/y_gap)
+                matrix[row_index][col_index].append(poi)
+        matrix_type = []
+        for i in range(rows):
+            for j in range(cols):
+                matrix_type.append(self.get_cat_from_pois(matrix[i][j]))
+        poi_total_number = []
+        poi_type_number = []
+        for i in range(rows):
+            for j in range(cols):
+                poi_total_number.append(len(matrix[i][j]))
+                number = 0
+                for poi in matrix[i][j]:
+                    if poi['category'][:2] in self.poi_cate.keys() and self.poi_cate[poi['category'][:2]] == matrix_type[i*cols+j]:
+                        number += 1
+                poi_type_number.append(number)
+
+        return matrix, matrix_type, poi_total_number, poi_type_number
 
     async def GetTime(self, format_time:bool=False, format:Optional[str]="%H:%M:%S") -> Union[int, str]:
         """
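
The new POI matrix trades per-query `query_pois` calls for a precomputed per-category grid over the map's bounding box: `set_poi_matrix` buckets POIs once, after which a lookup is two integer divisions. A usage sketch, assuming a configured `Simulator` named `sim` (grid shape and the coordinate are hypothetical):

    sim.set_poi_matrix(row_number=12, col_number=10, radius=10000)

    # A lookup resolves (x, y) to a grid cell, roughly:
    #   row = int((y - header['south']) / map_y_gap)
    #   col = int((x - header['west']) / map_x_gap)
    eat_pois = sim.get_pois_from_matrix(center=(447000.0, 4417000.0), prefix='10')  # '10' -> 'eat'

    # The rectangle variant summarizes a square window around a point:
    cells, cell_types, totals, type_counts = sim.get_poi_matrix_in_rec(
        center=(447000.0, 4417000.0), radius=2500)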
pycityagent/urbanllm/urbanllm.py
CHANGED
@@ -1,6 +1,8 @@
 """UrbanLLM: 智能能力类及其定义"""
 
-from openai import OpenAI
+from openai import OpenAI, AsyncOpenAI, APIConnectionError, OpenAIError
+import openai
+import asyncio
 from http import HTTPStatus
 import dashscope
 import requests
@@ -9,10 +11,55 @@ from PIL import Image
 from io import BytesIO
 from typing import Union
 import base64
+import aiohttp
+import re
 
 def encode_image(image_path):
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')
+
+class VarSet:
+    def __init__(self, variable_batch, variable_description, K):
+        self.var_batch = variable_batch
+        self.var_des = variable_description
+        self.var_size = len(list(variable_description.keys()))
+        self.batch_size = K
+
+    def set_variables(self, variables:list[dict]):
+        if len(variables) <= 0 or len(variables) != self.batch_size or (len(variables[0]) != self.var_size):
+            print(f"Your input variables mis-match the size of current Variable Set: Batch size-{self.batch_size}, Variable size-{self.var_size}")
+            return None
+        for i in range(len(self.var_batch)):
+            for var, _ in self.var_batch[i].items():
+                self.var_batch[i][var] = variables[i][var]
+
+    def get_variables(self):
+        return self.var_batch
+
+    def get_var_des(self):
+        var_des = ""
+        for var, desc in self.var_des.items():
+            var_des += f"batch[x][{var}] -- {desc}\n"
+        return var_des
+
+class PromptBatch:
+    def __init__(self, prompt_batch:str, variable_batch, vasriable_description, K):
+        self.prompt_raw = prompt_batch
+        self.prompt_set = None
+        self.batch_size = K
+        self.var_set = VarSet(variable_batch, vasriable_description, K)
+
+    def set_variables(self, variables:list[dict]):
+        self.var_set.set_variables(variables)
+        self.prompt_set = self.prompt_raw.format(batch=self.get_variables())
+
+    def get_variables(self):
+        return self.var_set.var_batch
+
+    def get_prompt(self):
+        if self.prompt_set == None:
+            return self.prompt_raw
+        return self.prompt_set
 
 class LLMConfig:
     """
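
`VarSet` and `PromptBatch` together hold a batch template whose placeholders have the form `{batch[i][slot]}` and fill it via `str.format`. A minimal sketch of the intended flow, with hypothetical data (batch size 2, one variable slot per request):

    template = ('### Request 1 ###\nLifestyle: {batch[0][0]}\n'
                '### Request 2 ###\nLifestyle: {batch[1][0]}\n')
    pb = PromptBatch(template,
                     [{0: ''}, {0: ''}],   # variable_batch: one dict of slots per request
                     {0: 'Lifestyle'},     # variable_description
                     2)                    # K, the batch size
    pb.set_variables([{0: 'early riser'}, {0: 'night owl'}])
    print(pb.get_prompt())                 # template with both requests filled in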
@@ -37,6 +84,53 @@ class UrbanLLM:
     """
     def __init__(self, config: LLMConfig) -> None:
         self.config = config
+        self.prompt_tokens_used = 0
+        self.completion_tokens_used = 0
+        self.request_number = 0
+        self.semaphore = None
+        self._aclient = AsyncOpenAI(api_key=self.config.text['api_key'], timeout=300)
+
+    def set_semaphore(self, number_of_coroutine:int):
+        self.semaphore = asyncio.Semaphore(number_of_coroutine)
+
+    def clear_semaphore(self):
+        self.semaphore = None
+
+    def clear_used(self):
+        """
+        clear the storage of used tokens to start a new log message
+        Only support OpenAI category API right now, including OpenAI, Deepseek
+        """
+        self.prompt_tokens_used = 0
+        self.completion_tokens_used = 0
+        self.request_number = 0
+
+    def show_consumption(self, input_price:float=None, output_price:float=None):
+        """
+        if you give the input and output price of using model, this function will also calculate the consumption for you
+        """
+        total_token = self.prompt_tokens_used + self.completion_tokens_used
+        if self.completion_tokens_used != 0:
+            rate = self.prompt_tokens_used/self.completion_tokens_used
+        else:
+            rate = 'nan'
+        if self.request_number != 0:
+            TcA = total_token/self.request_number
+        else:
+            TcA = 'nan'
+        out = f"""Request Number: {self.request_number}
+Token Usage:
+- Total tokens: {total_token}
+- Prompt tokens: {self.prompt_tokens_used}
+- Completion tokens: {self.completion_tokens_used}
+- Token per request: {TcA}
+- Prompt:Completion ratio: {rate}:1"""
+        if input_price != None and output_price != None:
+            consumption = self.prompt_tokens_used/1000000*input_price + self.completion_tokens_used/1000000*output_price
+            out += f"\n - Cost Estimation: {consumption}"
+        print(out)
+        return {"total": total_token, "prompt": self.prompt_tokens_used, "completion": self.completion_tokens_used, "ratio": rate}
+
 
     def text_request(self, dialog:list[dict], temperature:float=1, max_tokens:int=None, top_p:float=None, frequency_penalty:float=None, presence_penalty:float=None) -> str:
         """
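
The new counters turn cost audits into two calls: reset, run, report. A usage sketch, assuming a configured `UrbanLLM` instance named `llm` (the prices are hypothetical, per million tokens):

    llm.clear_used()
    llm.text_request([{'role': 'user', 'content': 'hello'}])
    stats = llm.show_consumption(input_price=0.5, output_price=1.5)
    # stats -> {'total': ..., 'prompt': ..., 'completion': ..., 'ratio': ...}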
@@ -72,6 +166,9 @@
                 frequency_penalty=frequency_penalty,
                 presence_penalty=presence_penalty
             )
+            self.prompt_tokens_used += response.usage.prompt_tokens
+            self.completion_tokens_used += response.usage.completion_tokens
+            self.request_number += 1
             return response.choices[0].message.content
         elif self.config.text['request_type'] == 'qwen':
             response = dashscope.Generation.call(
@@ -83,12 +180,128 @@
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content']
         else:
-            return "Error: {}".format(response.status_code)
+            return "Error: {}, {}".format(response.status_code, response.message)
+        elif self.config.text['request_type'] == 'deepseek':
+            client = OpenAI(
+                api_key=self.config.text['api_key'],
+                base_url="https://api.deepseek.com/beta",
+            )
+            response = client.chat.completions.create(
+                model=self.config.text['model'],
+                messages=dialog,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
+                stream=False,
+            )
+            self.prompt_tokens_used += response.usage.prompt_tokens
+            self.completion_tokens_used += response.usage.completion_tokens
+            self.request_number += 1
+            return response.choices[0].message.content
         else:
             print("ERROR: Wrong Config")
             return "wrong config"
+
+    async def atext_request(self, dialog:list[dict], temperature:float=1, max_tokens:int=None, top_p:float=None, frequency_penalty:float=None, presence_penalty:float=None, timeout:int=300, retries=3):
+        """
+        异步版文本请求
+        """
+        if self.config.text['request_type'] == 'openai':
+            for attempt in range(retries):
+                try:
+                    if self.semaphore != None:
+                        async with self.semaphore:
+                            response = await self._aclient.chat.completions.create(
+                                model=self.config.text['model'],
+                                messages=dialog,
+                                temperature=temperature,
+                                max_tokens=max_tokens,
+                                top_p=top_p,
+                                frequency_penalty=frequency_penalty,
+                                presence_penalty=presence_penalty,
+                                stream=False,
+                                timeout=timeout,
+                            )
+                            self.prompt_tokens_used += response.usage.prompt_tokens
+                            self.completion_tokens_used += response.usage.completion_tokens
+                            self.request_number += 1
+                            return response.choices[0].message.content
+                    else:
+                        response = await self._aclient.chat.completions.create(
+                            model=self.config.text['model'],
+                            messages=dialog,
+                            temperature=temperature,
+                            max_tokens=max_tokens,
+                            top_p=top_p,
+                            frequency_penalty=frequency_penalty,
+                            presence_penalty=presence_penalty,
+                            stream=False,
+                            timeout=timeout,
+                        )
+                        self.prompt_tokens_used += response.usage.prompt_tokens
+                        self.completion_tokens_used += response.usage.completion_tokens
+                        self.request_number += 1
+                        return response.choices[0].message.content
+                except APIConnectionError as e:
+                    print("API connection error:", e)
+                    if attempt < retries - 1:
+                        await asyncio.sleep(2 ** attempt)
+                    else:
+                        raise e
+                except OpenAIError as e:
+                    if hasattr(e, 'http_status'):
+                        print(f"HTTP status code: {e.http_status}")
+                    else:
+                        print("An error occurred:", e)
+                    if attempt < retries - 1:
+                        await asyncio.sleep(2 ** attempt)
+                    else:
+                        raise e
+        elif self.config.text['request_type'] == 'qwen':
+            async with aiohttp.ClientSession() as session:
+                api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+                headers = {"Content-Type": "application/json", "Authorization": f"{self.config.text['api_key']}"}
+                payload = {
+                    'model': self.config.text['model'],
+                    'input': {
+                        'messages': dialog
+                    }
+                }
+                async with session.post(api_url, json=payload, headers=headers) as resp:
+                    # 错误检查
+                    response_json = await resp.json()
+                    if 'code' in response_json.keys():
+                        raise Exception(f"Error: {response_json['code']}, {response_json['message']}")
+                    else:
+                        return response_json['output']['text']
+        elif self.config.text['request_type'] == 'deepseek':
+            client = AsyncOpenAI(
+                api_key=self.config.text['api_key'],
+                base_url="https://api.deepseek.com/beta",
+            )
+            response = await client.chat.completions.create(
+                model="deepseek-chat",
+                messages=dialog,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
+                stream=False,
+                timeout=timeout,
+            )
+            self.prompt_tokens_used += response.usage.prompt_tokens
+            self.completion_tokens_used += response.usage.completion_tokens
+            self.request_number += 1
+            return response.choices[0].message.content
+        else:
+            print("ERROR: Wrong Config")
+            return "wrong config"
+
 
-    def img_understand(self, img_path:Union[str, list[str]], prompt:str=None) -> str:
+    async def img_understand(self, img_path:Union[str, list[str]], prompt:str=None) -> str:
         """
         图像理解
         Image understanding
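
With a semaphore set, `atext_request` coroutines can be fanned out with `asyncio.gather` while at most N requests are in flight, and connection failures retry with exponential backoff (2**attempt seconds). A sketch assuming a configured instance `llm` (the dialogs are hypothetical):

    import asyncio

    async def run_batch(llm, dialogs):
        llm.set_semaphore(8)   # cap concurrency at 8 in-flight requests
        try:
            return await asyncio.gather(*(llm.atext_request(d) for d in dialogs))
        finally:
            llm.clear_semaphore()

    # asyncio.run(run_batch(llm, [[{'role': 'user', 'content': 'hi'}]]))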
@@ -169,7 +382,7 @@
             print("ERROR: wrong image understanding type, only 'openai' and 'openai' is available")
             return "Error"
 
-    def img_generate(self, prompt:str, size:str='512*512', quantity:int = 1):
+    async def img_generate(self, prompt:str, size:str='512*512', quantity:int = 1):
         """
         图像生成
         Image generation
@@ -197,4 +410,91 @@
         else:
             print('Failed, status_code: %s, code: %s, message: %s' %
                   (rsp.status_code, rsp.code, rsp.message))
-            return None
+            return None
+
+    async def convert_prompt_to_batch(self, prompt_raw, K, type:str="mr"):
+        """
+        将单个目标的 prompt_raw 转换为适合一次批处理调用的 prompt_batch。
+        prompt_raw: 单目标的prompt,包含占位符!<INPUT X>!
+        K: 批处理的大小
+        type: 批处理prompt类型, 'mr': Multi-request; 'gr': Gather-request
+        返回值: prompt_batch 和 variable_batch
+        """
+        # 分割说明性文本和待格式化的 prompt
+        sections = prompt_raw.split("<commentblockmarker>###</commentblockmarker>")
+        description_section = sections[0].strip()  # 说明文本部分
+        prompt_section = sections[1].strip()  # 待格式化的 prompt
+
+        # 1. 提取变量说明
+        variable_info = {}
+        variable_lines = re.findall(r"!<INPUT (\d+)>! -- (.+)", description_section)
+        for var, desc in variable_lines:
+            variable_info[int(var)] = desc.strip()
+
+        # 找到所有的占位符,如!<INPUT X>!
+        placeholders = re.findall(r"!<INPUT (\d+)>!", prompt_section)
+
+        # 构建prompt_batch,使用批次变量占位符,并合并为一个请求
+        prompt_batch = ""
+        for i in range(K):
+            current_prompt = prompt_section
+            for ph in placeholders:
+                current_prompt = current_prompt.replace(f"!<INPUT {ph}>!", f"{{batch[{i}][{ph}]}}")
+            prompt_batch += f"### Request {i + 1} ###\n" + current_prompt + "\n"
+
+        # 构建variable_batch,假设每个批次有K个值
+        variable_batch = [{int(ph): f"input_{i}_{ph}" for ph in placeholders} for i in range(K)]
+        if type == "mr":
+            return PromptBatch(prompt_batch, variable_batch, variable_info, K)
+        elif type == "gr":
+            messages = [{
+                "role": "system",
+                "content": """
+You are given a batch prompt structured in multiple requests and a variable description. Your task is to simplify and aggregate this prompt by removing redundant request headers and combining variables across different requests into a more compact format.
+
+### Instructions:
+1. Identify all unique sections in the prompt that appear across different requests. These sections will be aggregated into a single instance.
+2. For each variable in the prompt (e.g., `{batch[0][0]}`, `{batch[0][1]}`), group them together across all requests, and present them as a list or in a numbered format.
+3. Remove all redundant request headers (e.g., `### Request 1 ###`) and replace them with a single unified block where variables are grouped by their positions.
+4. Ensure the output is structured clearly for LLM processing, avoiding unnecessary repetition.
+5. Ensure the output only contains the transformed prompt.
+
+Here is an example transformation:
+Original prompt:
+### Request 1 ###
+{batch[0][0]}
+In general, Your Lifestyle is as follows: {batch[0][1]}
+Today is {batch[0][2]}, your wake up hour:
+### Request 2 ###
+{batch[1][0]}
+In general, Your Lifestyle is as follows: {batch[1][1]}
+Today is {batch[1][2]}, your wake up hour:
+Variable description:
+batch[x][0] -- Identity Stable Set
+batch[x][1] -- Lifestyle
+batch[x][2] -- Day
+
+Transformed prompt:
+Your Identity Stable Set:
+1. {batch[0][0]}
+2. {batch[1][0]}
+Your Lifestyle:
+1. {batch[0][1]}
+2. {batch[1][1]}
+The Day:
+1. {batch[0][2]}
+2. {batch[1][2]}
+Today, your wake up hour is:
+""",
+            }]
+            var_des = ""
+            for var, desc in variable_info.items():
+                var_des += f"batch[x][{var}] -- {desc}\n"
+            user_prompt = f"""Here is the original batch prompt:
+{prompt_batch}
+The Variable description:
+{var_des}
+"""
+            messages.append({'role': 'user', 'content': user_prompt})
+            prompt_batch = await self.atext_request(messages)
+            return PromptBatch(prompt_batch, variable_batch, variable_info, K)
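
`convert_prompt_to_batch` expects a raw prompt in two parts separated by `<commentblockmarker>###</commentblockmarker>`: variable descriptions of the form `!<INPUT n>! -- desc` above the marker, and the template with `!<INPUT n>!` placeholders below it. A sketch of the 'mr' path, inside an async function, with a hypothetical one-variable template:

    prompt_raw = (
        '!<INPUT 0>! -- Lifestyle\n'
        '<commentblockmarker>###</commentblockmarker>\n'
        'In general, your lifestyle is: !<INPUT 0>!'
    )
    pb = await llm.convert_prompt_to_batch(prompt_raw, K=2, type='mr')
    pb.set_variables([{0: 'early riser'}, {0: 'night owl'}])
    reply = await llm.atext_request([{'role': 'user', 'content': pb.get_prompt()}])
    # one request now answers both batched sub-prompts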