pycityagent 1.0.0__py3-none-any.whl → 2.0.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +7 -3
- pycityagent/agent.py +180 -284
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +307 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/interact.py +141 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/{brain → environment/sence}/static.py +1 -1
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +38 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +43 -0
- pycityagent/environment/sim/economy_services.py +191 -0
- pycityagent/environment/sim/lane_service.py +110 -0
- pycityagent/environment/sim/light_service.py +120 -0
- pycityagent/environment/sim/person_service.py +294 -0
- pycityagent/environment/sim/road_service.py +38 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +58 -0
- pycityagent/environment/simulator.py +320 -0
- pycityagent/environment/utils/__init__.py +10 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +242 -0
- pycityagent/environment/utils/geojson.py +26 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +39 -0
- pycityagent/llm/__init__.py +6 -0
- pycityagent/llm/embedding.py +136 -0
- pycityagent/llm/llm.py +430 -0
- pycityagent/llm/llmconfig.py +15 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +11 -0
- pycityagent/memory/const.py +41 -0
- pycityagent/memory/memory.py +453 -0
- pycityagent/memory/memory_base.py +168 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +27 -0
- pycityagent/message/__init__.py +0 -0
- pycityagent/simulation/__init__.py +7 -0
- pycityagent/simulation/interview.py +36 -0
- pycityagent/simulation/simulation.py +352 -0
- pycityagent/simulation/survey/__init__.py +9 -0
- pycityagent/simulation/survey/manager.py +67 -0
- pycityagent/simulation/survey/models.py +49 -0
- pycityagent/simulation/ui/__init__.py +3 -0
- pycityagent/simulation/ui/interface.py +602 -0
- pycityagent/utils/__init__.py +0 -0
- pycityagent/utils/decorators.py +89 -0
- pycityagent/utils/parsers/__init__.py +12 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/workflow/__init__.py +24 -0
- pycityagent/workflow/block.py +164 -0
- pycityagent/workflow/prompt.py +72 -0
- pycityagent/workflow/tool.py +246 -0
- pycityagent/workflow/trigger.py +150 -0
- pycityagent-2.0.0a2.dist-info/METADATA +208 -0
- pycityagent-2.0.0a2.dist-info/RECORD +69 -0
- {pycityagent-1.0.0.dist-info → pycityagent-2.0.0a2.dist-info}/WHEEL +1 -2
- pycityagent/ac/__init__.py +0 -6
- pycityagent/ac/ac.py +0 -50
- pycityagent/ac/action.py +0 -14
- pycityagent/ac/controled.py +0 -13
- pycityagent/ac/converse.py +0 -31
- pycityagent/ac/idle.py +0 -17
- pycityagent/ac/shop.py +0 -80
- pycityagent/ac/trip.py +0 -37
- pycityagent/brain/__init__.py +0 -10
- pycityagent/brain/brain.py +0 -52
- pycityagent/brain/brainfc.py +0 -10
- pycityagent/brain/memory.py +0 -541
- pycityagent/brain/persistence/social.py +0 -1
- pycityagent/brain/persistence/spatial.py +0 -14
- pycityagent/brain/reason/shop.py +0 -37
- pycityagent/brain/reason/social.py +0 -148
- pycityagent/brain/reason/trip.py +0 -67
- pycityagent/brain/reason/user.py +0 -122
- pycityagent/brain/retrive/social.py +0 -6
- pycityagent/brain/scheduler.py +0 -408
- pycityagent/brain/sence.py +0 -375
- pycityagent/cc/__init__.py +0 -5
- pycityagent/cc/cc.py +0 -102
- pycityagent/cc/conve.py +0 -6
- pycityagent/cc/idle.py +0 -20
- pycityagent/cc/shop.py +0 -6
- pycityagent/cc/trip.py +0 -13
- pycityagent/cc/user.py +0 -13
- pycityagent/hubconnector/__init__.py +0 -3
- pycityagent/hubconnector/hubconnector.py +0 -137
- pycityagent/image/__init__.py +0 -3
- pycityagent/image/image.py +0 -158
- pycityagent/simulator.py +0 -161
- pycityagent/st/__init__.py +0 -4
- pycityagent/st/st.py +0 -96
- pycityagent/urbanllm/__init__.py +0 -3
- pycityagent/urbanllm/urbanllm.py +0 -132
- pycityagent-1.0.0.dist-info/LICENSE +0 -21
- pycityagent-1.0.0.dist-info/METADATA +0 -181
- pycityagent-1.0.0.dist-info/RECORD +0 -48
- pycityagent-1.0.0.dist-info/top_level.txt +0 -1
- /pycityagent/{brain/persistence/__init__.py → config.py} +0 -0
- /pycityagent/{brain/reason → environment/interact}/__init__.py +0 -0
- /pycityagent/{brain/retrive → environment/message}/__init__.py +0 -0
pycityagent/llm/llm.py
ADDED
@@ -0,0 +1,430 @@
|
|
1
|
+
"""UrbanLLM: 智能能力类及其定义"""
|
2
|
+
|
3
|
+
import json
|
4
|
+
from openai import OpenAI, AsyncOpenAI, APIConnectionError, OpenAIError
|
5
|
+
from zhipuai import ZhipuAI
|
6
|
+
import logging
|
7
|
+
logging.getLogger("zhipuai").setLevel(logging.WARNING)
|
8
|
+
|
9
|
+
import asyncio
|
10
|
+
from http import HTTPStatus
|
11
|
+
import dashscope
|
12
|
+
import requests
|
13
|
+
from dashscope import ImageSynthesis
|
14
|
+
from PIL import Image
|
15
|
+
from io import BytesIO
|
16
|
+
from typing import Any, Optional, Union, List, Dict
|
17
|
+
import aiohttp
|
18
|
+
from .llmconfig import *
|
19
|
+
from .utils import *
|
20
|
+
|
21
|
+
import os
|
22
|
+
os.environ["GRPC_VERBOSITY"] = "ERROR"
|
23
|
+
|
24
|
+
class LLM:
    """
    Large language model client used by an Agent (Soul).

    Wraps the text, image-understanding and image-generation backends selected
    by an ``LLMConfig`` and keeps running totals of the token usage reported
    by the text backends.
    """

    def __init__(self, config: "LLMConfig") -> None:
        """
        Validate the text backend and create the async client it needs.

        Args:
        - config (LLMConfig): configuration object; ``config.text['request_type']``
          must be one of 'openai', 'deepseek', 'qwen' or 'zhipuai'.

        Raises:
        - ValueError: if the text request type is not one of the supported values.
        """
        self.config = config
        if config.text['request_type'] not in ['openai', 'deepseek', 'qwen', 'zhipuai']:
            raise ValueError("Invalid request type for text request")
        self.prompt_tokens_used = 0
        self.completion_tokens_used = 0
        self.request_number = 0
        # Optional asyncio.Semaphore limiting concurrent async text requests.
        self.semaphore = None
        # 'qwen' gets no persistent client: its async path talks HTTP via aiohttp.
        # NOTE(review): the async OpenAI client does not receive 'api_base',
        # while the sync text_request() does -- confirm whether that is intended.
        if self.config.text['request_type'] == 'openai':
            self._aclient = AsyncOpenAI(api_key=self.config.text['api_key'], timeout=300)
        elif self.config.text['request_type'] == 'deepseek':
            self._aclient = AsyncOpenAI(
                api_key=self.config.text['api_key'],
                base_url="https://api.deepseek.com/beta",
                timeout=300,
            )
        elif self.config.text['request_type'] == 'zhipuai':
            self._aclient = ZhipuAI(api_key=self.config.text['api_key'], timeout=300)

    def set_semaphore(self, number_of_coroutine: int):
        """Limit concurrent async text requests to ``number_of_coroutine``."""
        self.semaphore = asyncio.Semaphore(number_of_coroutine)

    def clear_semaphore(self):
        """Remove the concurrency limit installed by ``set_semaphore``."""
        self.semaphore = None

    def clear_used(self):
        """
        Reset the token and request counters to start a new accounting period.

        Only the OpenAI-style backends (OpenAI, Deepseek) and ZhipuAI feed
        these counters; the qwen branches do not record usage.
        """
        self.prompt_tokens_used = 0
        self.completion_tokens_used = 0
        self.request_number = 0

    def show_consumption(self, input_price: Optional[float] = None, output_price: Optional[float] = None):
        """
        Print a usage summary and return the totals.

        Args:
        - input_price (float): optional price per one million prompt tokens.
        - output_price (float): optional price per one million completion tokens.
          A cost estimate is printed only when both prices are given.

        Returns:
        - (dict): keys 'total', 'prompt', 'completion' and 'ratio'; 'ratio' is
          the string 'nan' when no completion tokens have been recorded.
        """
        total_token = self.prompt_tokens_used + self.completion_tokens_used
        if self.completion_tokens_used != 0:
            rate = self.prompt_tokens_used / self.completion_tokens_used
        else:
            rate = 'nan'
        if self.request_number != 0:
            TcA = total_token / self.request_number
        else:
            TcA = 'nan'
        out = "\n".join([
            f"Request Number: {self.request_number}",
            "Token Usage:",
            f"    - Total tokens: {total_token}",
            f"    - Prompt tokens: {self.prompt_tokens_used}",
            f"    - Completion tokens: {self.completion_tokens_used}",
            f"    - Token per request: {TcA}",
            f"    - Prompt:Completion ratio: {rate}:1",
        ])
        if input_price is not None and output_price is not None:
            consumption = (
                self.prompt_tokens_used / 1000000 * input_price
                + self.completion_tokens_used / 1000000 * output_price
            )
            out += f"\n    - Cost Estimation: {consumption}"
        print(out)
        return {
            "total": total_token,
            "prompt": self.prompt_tokens_used,
            "completion": self.completion_tokens_used,
            "ratio": rate,
        }

    def _record_usage(self, response: Any) -> None:
        """Add the token usage reported by ``response`` to the running totals."""
        usage = getattr(response, 'usage', None)
        if usage is not None:
            self.prompt_tokens_used += usage.prompt_tokens
            self.completion_tokens_used += usage.completion_tokens
        self.request_number += 1

    @staticmethod
    def _extract_reply(response: Any, tools: Optional[List[Dict[str, Any]]] = None, parse_arguments: bool = False):
        """
        Pull the useful payload out of a chat-completion style response.

        When ``tools`` was supplied and the first choice contains tool calls,
        the arguments of the first tool call are returned (JSON-decoded when
        ``parse_arguments`` is true); otherwise the message content is returned.
        """
        message = response.choices[0].message
        # Tool calls live on the message of a choice, not on the response itself.
        tool_calls = getattr(message, 'tool_calls', None)
        if tools is not None and tool_calls:
            arguments = tool_calls[0].function.arguments
            return json.loads(arguments) if parse_arguments else arguments
        return message.content

    def text_request(
        self,
        dialog: Any,
        temperature: float = 1,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Synchronous text request.

        Args:
        - dialog (list[dict]): the standard text LLM dialog
        - temperature (float): default 1, used in openai
        - max_tokens (int): default None, used in openai
        - top_p (float): default None, used in openai
        - frequency_penalty (float): default None, used in openai
        - presence_penalty (float): default None, used in openai
        - tools (list[dict]): default None, forwarded by the 'openai' branch only
        - tool_choice (dict): default None, forwarded by the 'openai' branch only

        Returns:
        - (str): the response content; for the 'openai' branch with ``tools``,
          the raw argument string of the first tool call when the model made one.
          The 'qwen' branch returns an "Error: ..." string on a non-OK status.
        """
        api_base = self.config.text.get('api_base')
        request_type = self.config.text['request_type']
        if request_type == 'openai':
            client = OpenAI(
                api_key=self.config.text['api_key'],
                base_url=api_base,
            )
            response = client.chat.completions.create(
                model=self.config.text['model'],
                messages=dialog,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tools=tools,
                tool_choice=tool_choice
            )
            self._record_usage(response)
            return self._extract_reply(response, tools)
        elif request_type == 'qwen':
            response = dashscope.Generation.call(
                model=self.config.text['model'],
                api_key=self.config.text['api_key'],
                messages=dialog,
                result_format='message'
            )
            if response.status_code == HTTPStatus.OK:  # type: ignore
                return response.output.choices[0]['message']['content']  # type: ignore
            else:
                return "Error: {}, {}".format(response.status_code, response.message)  # type: ignore
        elif request_type == 'deepseek':
            client = OpenAI(
                api_key=self.config.text['api_key'],
                base_url="https://api.deepseek.com/beta",
            )
            response = client.chat.completions.create(
                model=self.config.text['model'],
                messages=dialog,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stream=False,
            )
            self._record_usage(response)
            return response.choices[0].message.content
        elif request_type == 'zhipuai':
            client = ZhipuAI(api_key=self.config.text['api_key'])
            response = client.chat.completions.create(
                model=self.config.text['model'],
                messages=dialog,
                temperature=temperature,
                top_p=top_p,
                stream=False
            )
            self._record_usage(response)
            return response.choices[0].message.content  # type: ignore
        else:
            print("ERROR: Wrong Config")
            return "wrong config"

    async def atext_request(
        self,
        dialog: Any,
        temperature: float = 1,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        timeout: int = 300,
        retries=3,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[Dict[str, Any]] = None
    ):
        """
        Asynchronous text request.

        Args:
        - dialog (list[dict]): the standard text LLM dialog
        - temperature, max_tokens, top_p, frequency_penalty, presence_penalty:
          sampling options forwarded to the backend where it accepts them
        - timeout (int): per-request timeout in seconds; for 'zhipuai' it also
          bounds how long the task result is polled
        - retries (int): number of attempts for the retried backends
        - tools (list[dict]): optional tool definitions
        - tool_choice (dict): optional tool selection

        Returns:
        - the message content, or -- when ``tools`` was given and the model
          issued a tool call -- the JSON-decoded arguments of the first call.

        Raises:
        - APIConnectionError / OpenAIError: re-raised after the last attempt.
        - Exception: if a zhipuai task does not finish with status 'SUCCESS',
          or the qwen endpoint answers with an error code.
        """
        request_type = self.config.text['request_type']
        if request_type == 'openai' or request_type == 'deepseek':
            request_kwargs = dict(
                model=self.config.text['model'],
                messages=dialog,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stream=False,
                timeout=timeout,
                tools=tools,
                tool_choice=tool_choice,
            )
            for attempt in range(retries):
                try:
                    if self.semaphore is not None:
                        async with self.semaphore:
                            response = await self._aclient.chat.completions.create(**request_kwargs)  # type: ignore
                    else:
                        response = await self._aclient.chat.completions.create(**request_kwargs)  # type: ignore
                    self._record_usage(response)
                    return self._extract_reply(response, tools, parse_arguments=True)
                except APIConnectionError as e:
                    print("API connection error:", e)
                    if attempt < retries - 1:
                        # Exponential back-off between attempts: 1s, 2s, 4s, ...
                        await asyncio.sleep(2 ** attempt)
                    else:
                        raise
                except OpenAIError as e:
                    if hasattr(e, 'http_status'):
                        print(f"HTTP status code: {e.http_status}")  # type: ignore
                    else:
                        print("An error occurred:", e)
                    if attempt < retries - 1:
                        await asyncio.sleep(2 ** attempt)
                    else:
                        raise
        elif request_type == 'zhipuai':
            for attempt in range(retries):
                try:
                    # Submit the task, then poll for its result every 0.5 seconds.
                    response = self._aclient.chat.asyncCompletions.create(  # type: ignore
                        model=self.config.text['model'],
                        messages=dialog,
                        temperature=temperature,
                        top_p=top_p,
                        timeout=timeout,
                        tools=tools,
                        tool_choice=tool_choice
                    )
                    task_id = response.id
                    task_status = ''
                    result_response = None
                    get_cnt = 0
                    cnt_threshold = int(timeout / 0.5)
                    while task_status != 'SUCCESS' and task_status != 'FAILED' and get_cnt <= cnt_threshold:
                        result_response = self._aclient.chat.asyncCompletions.retrieve_completion_result(id=task_id)  # type: ignore
                        task_status = result_response.task_status
                        await asyncio.sleep(0.5)
                        get_cnt += 1
                    if task_status != 'SUCCESS':
                        raise Exception(f"Task failed with status: {task_status}")

                    self._record_usage(result_response)
                    return self._extract_reply(result_response, tools, parse_arguments=True)
                except APIConnectionError as e:
                    print("API connection error:", e)
                    if attempt < retries - 1:
                        await asyncio.sleep(2 ** attempt)
                    else:
                        raise
        elif request_type == 'qwen':
            async with aiohttp.ClientSession() as session:
                api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
                api_key = f"{self.config.text['api_key']}"
                # The HTTP API expects a Bearer token; tolerate keys that already carry the scheme.
                if not api_key.lower().startswith("bearer "):
                    api_key = f"Bearer {api_key}"
                headers = {"Content-Type": "application/json", "Authorization": api_key}
                payload = {
                    'model': self.config.text['model'],
                    'input': {
                        'messages': dialog
                    }
                }
                async with session.post(api_url, json=payload, headers=headers) as resp:
                    response_json = await resp.json()
                    if 'code' in response_json:
                        raise Exception(f"Error: {response_json['code']}, {response_json['message']}")
                    else:
                        return response_json['output']['text']
        else:
            print("ERROR: Wrong Config")
            return "wrong config"

    @staticmethod
    def _normalize_image_paths(img_path: Any) -> Optional[List[str]]:
        """Return the image paths as a list, or None when the input is neither a str nor a list of str."""
        if isinstance(img_path, str):
            return [img_path]
        if isinstance(img_path, list) and all(isinstance(item, str) for item in img_path):
            return list(img_path)
        return None

    async def img_understand(self, img_path: Union[str, list[str]], prompt: Optional[str] = None) -> str:
        """
        Image understanding.

        Args:
        - img_path (Union[str, list[str]]): path of the target image, or a list of image paths
        - prompt (str): the understanding prompt, e.g. what aspect to describe;
          a built-in default question is used when omitted

        Returns:
        - (str): the understanding content, or "Error" when the backend call
          fails or the configured request type is not supported
        """
        ppt = "如何理解这幅图像?"
        if prompt is not None:
            ppt = prompt
        paths = self._normalize_image_paths(img_path)
        request_type = self.config.image_u['request_type']
        if request_type == 'openai':
            api_base = self.config.image_u.get('api_base')
            client = OpenAI(
                # Prefer the image-understanding key; fall back to the text key
                # for configurations that only define one.
                api_key=self.config.image_u.get('api_key', self.config.text['api_key']),
                base_url=api_base,
            )
            content = [{'type': 'text', 'text': ppt}]
            for item in paths or []:
                base64_image = encode_image(item)
                content.append({
                    'type': 'image_url',
                    'image_url': {
                        'url': f"data:image/jpeg;base64,{base64_image}"
                    }
                })
            response = client.chat.completions.create(
                model=self.config.image_u['model'],
                messages=[{
                    'role': 'user',
                    'content': content
                }]
            )
            return response.choices[0].message.content  # type: ignore
        elif request_type == 'qwen':
            content = []
            if paths is not None:
                # Images first, then the question.
                content.extend({'image': 'file://' + item} for item in paths)
                content.append({'text': ppt})

            dialog = [{
                'role': 'user',
                'content': content
            }]
            response = dashscope.MultiModalConversation.call(
                model=self.config.image_u['model'],
                api_key=self.config.image_u['api_key'],
                messages=dialog
            )
            if response.status_code == HTTPStatus.OK:  # type: ignore
                return response.output.choices[0]['message']['content']  # type: ignore
            else:
                print(response.code)  # type: ignore # The error code.
                return "Error"
        else:
            print("ERROR: wrong image understanding type, only 'openai' and 'qwen' is available")
            return "Error"

    async def img_generate(self, prompt: str, size: str = '512*512', quantity: int = 1):
        """
        Image generation.

        Args:
        - prompt (str): the image generation prompt
        - size (str): the image size, default: '512*512'
        - quantity (int): the quantity of generated images, default: 1

        Returns:
        - (list[PIL.Image.Image]): the list of generated images, or None when
          the generation request does not return an OK status.
        """
        rsp = ImageSynthesis.call(
            model=self.config.image_g['model'],
            api_key=self.config.image_g['api_key'],
            prompt=prompt,
            n=quantity,
            size=size
        )
        if rsp.status_code == HTTPStatus.OK:
            # Each result carries a URL; download it and decode the bytes as an image.
            return [
                Image.open(BytesIO(requests.get(result.url).content))
                for result in rsp.output.results
            ]
        else:
            print('Failed, status_code: %s, code: %s, message: %s' %
                  (rsp.status_code, rsp.code, rsp.message))
            return None
|
@@ -0,0 +1,15 @@
|
|
1
|
+
class LLMConfig:
    """
    Configuration holder for the large language model client.

    Splits the raw configuration mapping into its text, image-understanding
    and image-generation sections.
    """

    def __init__(
        self,
        config: dict
    ) -> None:
        """
        Args:
        - config (dict): raw configuration; must contain the keys
          'text_request', 'img_understand_request' and 'img_generate_request'.

        Raises:
        - KeyError: if one of the three sections is missing.
        """
        self.config = config
        self.text = config['text_request']
        # A literal 'None' string for the endpoint is normalised to a real None.
        # The section is updated in place, so the caller's mapping sees the change too.
        if self.text.get('api_base') == 'None':
            self.text['api_base'] = None
        self.image_u = config['img_understand_request']
        self.image_g = config['img_generate_request']
|
pycityagent/llm/utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
from pycityproto.city.person.v2.motion_pb2 import Status
|
2
|
+
|
3
|
+
# Profile attribute names mapped to an empty value of the expected type.
PROFILE_ATTRIBUTES = {
    "gender": "",
    "age": 0.0,
    "education": "",
    "skill": "",
    "occupation": "",
    "family_consumption": "",
    "consumption": "",
    "personality": "",
    "income": "",
    "residence": "",
    "race": "",
    "religion": "",
    "marital_status": "",
}

# State attribute names mapped to their initial values.
# NOTE(review): the dict/list values are shared mutable objects; presumably
# consumers copy them before mutating -- verify against the callers.
STATE_ATTRIBUTES = {
    # base
    "id": -1,
    "attribute": {},
    "home": {},
    "work": {},
    "schedules": [],
    "vehicle_attribute": {},
    "bus_attribute": {},
    "pedestrian_attribute": {},
    "bike_attribute": {},
    # motion
    "status": Status.STATUS_UNSPECIFIED,
    "position": {},
    "v": 0.0,
    "direction": 0.0,
    "activity": "",
    "l": 0.0,
}

# Name prefix, presumably for user-defined (self-defined) keys -- TODO confirm usage.
SELF_DEFINE_PREFIX = "self_define_"

# Key name, presumably used to store a timestamp alongside a value -- TODO confirm usage.
TIME_STAMP_KEY = "_timestamp"
|