pycityagent 2.0.0a94__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a96__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +5 -5
- pycityagent/agent/agent_base.py +1 -6
- pycityagent/cityagent/__init__.py +6 -5
- pycityagent/cityagent/bankagent.py +2 -2
- pycityagent/cityagent/blocks/__init__.py +4 -4
- pycityagent/cityagent/blocks/cognition_block.py +7 -4
- pycityagent/cityagent/blocks/economy_block.py +227 -135
- pycityagent/cityagent/blocks/mobility_block.py +70 -27
- pycityagent/cityagent/blocks/needs_block.py +11 -12
- pycityagent/cityagent/blocks/other_block.py +2 -2
- pycityagent/cityagent/blocks/plan_block.py +22 -24
- pycityagent/cityagent/blocks/social_block.py +15 -17
- pycityagent/cityagent/blocks/utils.py +3 -2
- pycityagent/cityagent/firmagent.py +1 -1
- pycityagent/cityagent/governmentagent.py +1 -1
- pycityagent/cityagent/initial.py +1 -1
- pycityagent/cityagent/memory_config.py +0 -1
- pycityagent/cityagent/message_intercept.py +7 -8
- pycityagent/cityagent/nbsagent.py +1 -1
- pycityagent/cityagent/societyagent.py +1 -2
- pycityagent/configs/__init__.py +18 -0
- pycityagent/configs/exp_config.py +202 -0
- pycityagent/configs/sim_config.py +251 -0
- pycityagent/configs/utils.py +17 -0
- pycityagent/environment/__init__.py +2 -0
- pycityagent/{economy → environment/economy}/econ_client.py +14 -32
- pycityagent/environment/sim/sim_env.py +17 -24
- pycityagent/environment/simulator.py +36 -113
- pycityagent/llm/__init__.py +1 -2
- pycityagent/llm/llm.py +54 -167
- pycityagent/memory/memory.py +13 -12
- pycityagent/message/message_interceptor.py +5 -4
- pycityagent/message/messager.py +3 -5
- pycityagent/metrics/__init__.py +1 -1
- pycityagent/metrics/mlflow_client.py +20 -17
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +18 -20
- pycityagent/simulation/simulation.py +157 -210
- pycityagent/survey/manager.py +0 -2
- pycityagent/utils/__init__.py +3 -0
- pycityagent/utils/config_const.py +20 -0
- pycityagent/workflow/__init__.py +1 -2
- pycityagent/workflow/block.py +0 -3
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/METADATA +7 -24
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/RECORD +50 -46
- pycityagent/llm/llmconfig.py +0 -18
- /pycityagent/{economy → environment/economy}/__init__.py +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from pycityproto.city.person.v2 import person_service_pb2 as person_service
|
|
17
17
|
from pymongo import MongoClient
|
18
18
|
from shapely.geometry import Point
|
19
19
|
|
20
|
+
from ..configs import SimConfig
|
20
21
|
from .sim import CityClient, ControlSimEnv
|
21
22
|
from .utils.const import *
|
22
23
|
|
@@ -29,19 +30,10 @@ __all__ = [
|
|
29
30
|
|
30
31
|
@ray.remote
|
31
32
|
class CityMap:
|
32
|
-
def __init__(self,
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
)
|
37
|
-
else:
|
38
|
-
mongo_uri, mongo_db, mongo_coll, cache_dir = mongo_input
|
39
|
-
self.map = SimMap(
|
40
|
-
mongo_uri=mongo_uri,
|
41
|
-
mongo_db=mongo_db,
|
42
|
-
mongo_coll=mongo_coll,
|
43
|
-
cache_dir=cache_dir,
|
44
|
-
)
|
33
|
+
def __init__(self, map_cache_path: str):
|
34
|
+
self.map = SimMap(
|
35
|
+
pb_path=map_cache_path,
|
36
|
+
)
|
45
37
|
self.poi_cate = POI_CATG_DICT
|
46
38
|
|
47
39
|
def get_aoi(self, aoi_id: Optional[int] = None):
|
@@ -81,74 +73,41 @@ class Simulator:
|
|
81
73
|
- It reads parameters from a configuration dictionary, initializes map data, and starts or connects to a simulation server as needed.
|
82
74
|
"""
|
83
75
|
|
84
|
-
def __init__(self,
|
85
|
-
self.
|
76
|
+
def __init__(self, sim_config: SimConfig, create_map: bool = False) -> None:
|
77
|
+
self.sim_config = sim_config
|
86
78
|
"""
|
87
79
|
- 模拟器配置
|
88
80
|
- simulator config
|
89
81
|
"""
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
82
|
+
_map_pb_path = sim_config.prop_map_request.file_path
|
83
|
+
config = sim_config.prop_simulator_request
|
84
|
+
if not sim_config.prop_status.simulator_activated:
|
85
|
+
self._sim_env = sim_env = ControlSimEnv(
|
86
|
+
task_name=config.task_name, # type:ignore
|
87
|
+
map_file=_map_pb_path,
|
88
|
+
max_day=config.max_day, # type:ignore
|
89
|
+
start_step=config.start_step, # type:ignore
|
90
|
+
total_step=config.total_step, # type:ignore
|
91
|
+
log_dir=config.log_dir, # type:ignore
|
92
|
+
min_step_time=config.min_step_time, # type:ignore
|
93
|
+
primary_node_ip=config.primary_node_ip, # type:ignore
|
94
|
+
sim_addr=sim_config.simulator_server_address,
|
98
95
|
)
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
else:
|
111
|
-
# from local file
|
112
|
-
_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir = "", "", "", ""
|
113
|
-
_map_pb_path = _map_request["file_path"]
|
114
|
-
|
115
|
-
if "simulator" in config:
|
116
|
-
if config["simulator"] is None:
|
117
|
-
config["simulator"] = {}
|
118
|
-
if not config["simulator"].get("_server_activated", False):
|
119
|
-
self._sim_env = sim_env = ControlSimEnv(
|
120
|
-
task_name=config["simulator"].get("task", "citysim"),
|
121
|
-
map_file=_map_pb_path,
|
122
|
-
max_day=config["simulator"].get("max_day", 1000),
|
123
|
-
start_step=config["simulator"].get("start_step", 28800),
|
124
|
-
total_step=config["simulator"].get(
|
125
|
-
"total_step", 24 * 60 * 60 * 365
|
126
|
-
),
|
127
|
-
log_dir=config["simulator"].get("log_dir", "./log"),
|
128
|
-
min_step_time=config["simulator"].get("min_step_time", 1000),
|
129
|
-
primary_node_ip=config["simulator"].get("primary_node_ip", "localhost"),
|
130
|
-
sim_addr=config["simulator"].get("server", None),
|
131
|
-
)
|
132
|
-
self.server_addr = sim_env.sim_addr
|
133
|
-
config["simulator"]["server"] = self.server_addr
|
134
|
-
config["simulator"]["_server_activated"] = True
|
135
|
-
# using local client
|
136
|
-
self._client = CityClient(
|
137
|
-
sim_env.sim_addr, secure=self.server_addr.startswith("https")
|
138
|
-
)
|
139
|
-
"""
|
140
|
-
- 模拟器grpc客户端
|
141
|
-
- grpc client of simulator
|
142
|
-
"""
|
143
|
-
else:
|
144
|
-
self.server_addr = config["simulator"]["server"]
|
145
|
-
self._client = CityClient(
|
146
|
-
self.server_addr, secure=self.server_addr.startswith("https")
|
147
|
-
)
|
96
|
+
self.server_addr = sim_env.sim_addr
|
97
|
+
sim_config.SetServerAddress(self.server_addr)
|
98
|
+
sim_config.prop_status.simulator_activated = True
|
99
|
+
# using local client
|
100
|
+
self._client = CityClient(
|
101
|
+
sim_env.sim_addr, secure=self.server_addr.startswith("https")
|
102
|
+
)
|
103
|
+
"""
|
104
|
+
- 模拟器grpc客户端
|
105
|
+
- grpc client of simulator
|
106
|
+
"""
|
148
107
|
else:
|
149
|
-
self.server_addr =
|
150
|
-
|
151
|
-
|
108
|
+
self.server_addr: str = sim_config.simulator_server_address # type:ignore
|
109
|
+
self._client = CityClient(
|
110
|
+
self.server_addr, secure=self.server_addr.startswith("https")
|
152
111
|
)
|
153
112
|
self._map = None
|
154
113
|
"""
|
@@ -157,7 +116,6 @@ class Simulator:
|
|
157
116
|
"""
|
158
117
|
if create_map:
|
159
118
|
self._map = CityMap.remote(
|
160
|
-
(_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir),
|
161
119
|
_map_pb_path,
|
162
120
|
)
|
163
121
|
self._create_poi_id_2_aoi_id()
|
@@ -205,8 +163,8 @@ class Simulator:
|
|
205
163
|
"""
|
206
164
|
return self._environment_prompt
|
207
165
|
|
208
|
-
def get_server_addr(self):
|
209
|
-
return self.server_addr
|
166
|
+
def get_server_addr(self) -> str:
|
167
|
+
return self.server_addr # type:ignore
|
210
168
|
|
211
169
|
def set_environment(self, environment: dict[str, str]):
|
212
170
|
"""
|
@@ -239,41 +197,6 @@ class Simulator:
|
|
239
197
|
"""
|
240
198
|
self._environment_prompt[key] = value
|
241
199
|
|
242
|
-
# * Agent相关
|
243
|
-
def find_agents_by_area(self, req: dict, status=None):
|
244
|
-
"""
|
245
|
-
Find agents/persons within a specified area.
|
246
|
-
|
247
|
-
- **Args**:
|
248
|
-
- `req` (`dict`): A dictionary that describes the area. Refer to
|
249
|
-
https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxRequest.
|
250
|
-
- `status` (`Optional[int]`): An integer representing the status of the agents/persons to filter by.
|
251
|
-
If provided, only persons with the given status will be returned.
|
252
|
-
Refer to https://cityproto.sim.fiblab.net/#city.agent.v2.Status.
|
253
|
-
|
254
|
-
- **Returns**:
|
255
|
-
- The response from the GetPersonByLongLatBBox method, possibly filtered by status.
|
256
|
-
Refer to https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxResponse.
|
257
|
-
"""
|
258
|
-
start_time = time.time()
|
259
|
-
log = {"req": "find_agents_by_area", "start_time": start_time, "consumption": 0}
|
260
|
-
loop = asyncio.get_event_loop()
|
261
|
-
resp = loop.run_until_complete(
|
262
|
-
self._client.person_service.GetPersonByLongLatBBox(req=req)
|
263
|
-
)
|
264
|
-
loop.close()
|
265
|
-
if status == None:
|
266
|
-
return resp
|
267
|
-
else:
|
268
|
-
motions = []
|
269
|
-
for agent in resp.motions: # type: ignore
|
270
|
-
if agent.status in status:
|
271
|
-
motions.append(agent)
|
272
|
-
resp.motions = motions # type: ignore
|
273
|
-
log["consumption"] = time.time() - start_time
|
274
|
-
self._log_list.append(log)
|
275
|
-
return resp
|
276
|
-
|
277
200
|
def get_poi_categories(
|
278
201
|
self,
|
279
202
|
center: Optional[Union[tuple[float, float], Point]] = None,
|
pycityagent/llm/__init__.py
CHANGED
pycityagent/llm/llm.py
CHANGED
@@ -3,20 +3,15 @@
|
|
3
3
|
import asyncio
|
4
4
|
import json
|
5
5
|
import logging
|
6
|
-
import time
|
7
6
|
import os
|
8
|
-
|
9
|
-
from io import BytesIO
|
7
|
+
import time
|
10
8
|
from typing import Any, Optional, Union
|
11
9
|
|
12
|
-
import dashscope
|
13
|
-
import requests
|
14
|
-
from dashscope import ImageSynthesis
|
15
10
|
from openai import APIConnectionError, AsyncOpenAI, OpenAI, OpenAIError
|
16
|
-
from PIL import Image
|
17
11
|
from zhipuai import ZhipuAI
|
18
12
|
|
19
|
-
from
|
13
|
+
from ..configs import LLMRequestConfig
|
14
|
+
from ..utils import LLMRequestType
|
20
15
|
from .utils import *
|
21
16
|
|
22
17
|
logging.getLogger("zhipuai").setLevel(logging.WARNING)
|
@@ -36,15 +31,15 @@ class LLM:
|
|
36
31
|
- It initializes clients based on the specified request type and handles token usage and consumption reporting.
|
37
32
|
"""
|
38
33
|
|
39
|
-
def __init__(self, config:
|
34
|
+
def __init__(self, config: LLMRequestConfig) -> None:
|
40
35
|
"""
|
41
36
|
Initializes the LLM instance.
|
42
37
|
|
43
38
|
- **Parameters**:
|
44
|
-
- `config`: An instance of `
|
39
|
+
- `config`: An instance of `LLMRequestConfig` containing configuration settings for the LLM.
|
45
40
|
"""
|
46
41
|
self.config = config
|
47
|
-
if config.
|
42
|
+
if config.request_type not in {t.value for t in LLMRequestType}:
|
48
43
|
raise ValueError("Invalid request type for text request")
|
49
44
|
self.prompt_tokens_used = 0
|
50
45
|
self.completion_tokens_used = 0
|
@@ -53,7 +48,7 @@ class LLM:
|
|
53
48
|
self._current_client_index = 0
|
54
49
|
self._log_list = []
|
55
50
|
|
56
|
-
api_keys = self.config.
|
51
|
+
api_keys = self.config.api_key
|
57
52
|
if not isinstance(api_keys, list):
|
58
53
|
api_keys = [api_keys]
|
59
54
|
|
@@ -61,42 +56,40 @@ class LLM:
|
|
61
56
|
self._client_usage = []
|
62
57
|
|
63
58
|
for api_key in api_keys:
|
64
|
-
if self.config.
|
59
|
+
if self.config.request_type == LLMRequestType.OpenAI:
|
65
60
|
client = AsyncOpenAI(api_key=api_key, timeout=300)
|
66
|
-
elif self.config.
|
61
|
+
elif self.config.request_type == LLMRequestType.DeepSeek:
|
67
62
|
client = AsyncOpenAI(
|
68
63
|
api_key=api_key,
|
69
64
|
base_url="https://api.deepseek.com/v1",
|
70
65
|
timeout=300,
|
71
66
|
)
|
72
|
-
elif self.config.
|
67
|
+
elif self.config.request_type == LLMRequestType.Qwen:
|
73
68
|
client = AsyncOpenAI(
|
74
69
|
api_key=api_key,
|
75
70
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
76
71
|
timeout=300,
|
77
72
|
)
|
78
|
-
elif self.config.
|
73
|
+
elif self.config.request_type == LLMRequestType.SiliconFlow:
|
79
74
|
client = AsyncOpenAI(
|
80
75
|
api_key=api_key,
|
81
76
|
base_url="https://api.siliconflow.cn/v1",
|
82
77
|
timeout=300,
|
83
78
|
)
|
84
|
-
elif self.config.
|
79
|
+
elif self.config.request_type == LLMRequestType.ZhipuAI:
|
85
80
|
client = ZhipuAI(api_key=api_key, timeout=300)
|
86
81
|
else:
|
87
82
|
raise ValueError(
|
88
|
-
f"Unsupported `request_type` {self.config.
|
83
|
+
f"Unsupported `request_type` {self.config.request_type}!"
|
89
84
|
)
|
90
85
|
self._aclients.append(client)
|
91
|
-
self._client_usage.append(
|
92
|
-
"prompt_tokens": 0,
|
93
|
-
|
94
|
-
"request_number": 0
|
95
|
-
})
|
86
|
+
self._client_usage.append(
|
87
|
+
{"prompt_tokens": 0, "completion_tokens": 0, "request_number": 0}
|
88
|
+
)
|
96
89
|
|
97
90
|
def get_log_list(self):
|
98
91
|
return self._log_list
|
99
|
-
|
92
|
+
|
100
93
|
def clear_log_list(self):
|
101
94
|
self._log_list = []
|
102
95
|
|
@@ -122,7 +115,7 @@ class LLM:
|
|
122
115
|
"""
|
123
116
|
for usage in self._client_usage:
|
124
117
|
usage["prompt_tokens"] = 0
|
125
|
-
usage["completion_tokens"] = 0
|
118
|
+
usage["completion_tokens"] = 0
|
126
119
|
usage["request_number"] = 0
|
127
120
|
|
128
121
|
def get_consumption(self):
|
@@ -130,7 +123,7 @@ class LLM:
|
|
130
123
|
for i, usage in enumerate(self._client_usage):
|
131
124
|
consumption[f"api-key-{i+1}"] = {
|
132
125
|
"total_tokens": usage["prompt_tokens"] + usage["completion_tokens"],
|
133
|
-
"request_number": usage["request_number"]
|
126
|
+
"request_number": usage["request_number"],
|
134
127
|
}
|
135
128
|
return consumption
|
136
129
|
|
@@ -147,27 +140,24 @@ class LLM:
|
|
147
140
|
- **Returns**:
|
148
141
|
- A dictionary summarizing the token usage and, if applicable, the estimated cost.
|
149
142
|
"""
|
150
|
-
total_stats = {
|
151
|
-
|
152
|
-
"prompt": 0,
|
153
|
-
"completion": 0,
|
154
|
-
"requests": 0
|
155
|
-
}
|
156
|
-
|
143
|
+
total_stats = {"total": 0, "prompt": 0, "completion": 0, "requests": 0}
|
144
|
+
|
157
145
|
for i, usage in enumerate(self._client_usage):
|
158
146
|
prompt_tokens = usage["prompt_tokens"]
|
159
147
|
completion_tokens = usage["completion_tokens"]
|
160
148
|
requests = usage["request_number"]
|
161
149
|
total_tokens = prompt_tokens + completion_tokens
|
162
|
-
|
150
|
+
|
163
151
|
total_stats["total"] += total_tokens
|
164
152
|
total_stats["prompt"] += prompt_tokens
|
165
153
|
total_stats["completion"] += completion_tokens
|
166
154
|
total_stats["requests"] += requests
|
167
|
-
|
168
|
-
rate =
|
155
|
+
|
156
|
+
rate = (
|
157
|
+
prompt_tokens / completion_tokens if completion_tokens != 0 else "nan"
|
158
|
+
)
|
169
159
|
tokens_per_request = total_tokens / requests if requests != 0 else "nan"
|
170
|
-
|
160
|
+
|
171
161
|
print(f"\nAPI Key #{i+1}:")
|
172
162
|
print(f"Request Number: {requests}")
|
173
163
|
print("Token Usage:")
|
@@ -176,12 +166,14 @@ class LLM:
|
|
176
166
|
print(f" - Completion tokens: {completion_tokens}")
|
177
167
|
print(f" - Token per request: {tokens_per_request}")
|
178
168
|
print(f" - Prompt:Completion ratio: {rate}:1")
|
179
|
-
|
169
|
+
|
180
170
|
if input_price is not None and output_price is not None:
|
181
|
-
consumption = (
|
182
|
-
|
171
|
+
consumption = (
|
172
|
+
prompt_tokens / 1000000 * input_price
|
173
|
+
+ completion_tokens / 1000000 * output_price
|
174
|
+
)
|
183
175
|
print(f" - Cost Estimation: {consumption}")
|
184
|
-
|
176
|
+
|
185
177
|
return total_stats
|
186
178
|
|
187
179
|
def _get_next_client(self):
|
@@ -203,6 +195,7 @@ class LLM:
|
|
203
195
|
async def atext_request(
|
204
196
|
self,
|
205
197
|
dialog: Any,
|
198
|
+
response_format: Optional[dict[str, Any]] = None,
|
206
199
|
temperature: float = 1,
|
207
200
|
max_tokens: Optional[int] = None,
|
208
201
|
top_p: Optional[float] = None,
|
@@ -222,6 +215,7 @@ class LLM:
|
|
222
215
|
|
223
216
|
- **Parameters**:
|
224
217
|
- `dialog`: Messages to send as part of the chat completion request.
|
218
|
+
- `response_format`: JSON schema for the response. Default is None.
|
225
219
|
- `temperature`: Controls randomness in the model's output. Default is 1.
|
226
220
|
- `max_tokens`: Maximum number of tokens to generate in the response. Default is None.
|
227
221
|
- `top_p`: Limits the next token selection to a subset of tokens with a cumulative probability above this value. Default is None.
|
@@ -238,19 +232,23 @@ class LLM:
|
|
238
232
|
"""
|
239
233
|
start_time = time.time()
|
240
234
|
log = {"request_time": start_time}
|
235
|
+
assert (
|
236
|
+
self.semaphore is not None
|
237
|
+
), "Please set semaphore with `set_semaphore` first!"
|
241
238
|
async with self.semaphore:
|
242
239
|
if (
|
243
|
-
self.config.
|
244
|
-
or self.config.
|
245
|
-
or self.config.
|
246
|
-
or self.config.
|
240
|
+
self.config.request_type == "openai"
|
241
|
+
or self.config.request_type == "deepseek"
|
242
|
+
or self.config.request_type == "qwen"
|
243
|
+
or self.config.request_type == "siliconflow"
|
247
244
|
):
|
248
245
|
for attempt in range(retries):
|
249
246
|
try:
|
250
247
|
client = self._get_next_client()
|
251
248
|
response = await client.chat.completions.create(
|
252
|
-
model=self.config.
|
249
|
+
model=self.config.model,
|
253
250
|
messages=dialog,
|
251
|
+
response_format=response_format,
|
254
252
|
temperature=temperature,
|
255
253
|
max_tokens=max_tokens,
|
256
254
|
top_p=top_p,
|
@@ -263,7 +261,9 @@ class LLM:
|
|
263
261
|
) # type: ignore
|
264
262
|
self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
|
265
263
|
self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
|
266
|
-
self._client_usage[self._current_client_index][
|
264
|
+
self._client_usage[self._current_client_index][
|
265
|
+
"request_number"
|
266
|
+
] += 1
|
267
267
|
end_time = time.time()
|
268
268
|
log["consumption"] = end_time - start_time
|
269
269
|
log["input_tokens"] = response.usage.prompt_tokens
|
@@ -295,16 +295,15 @@ class LLM:
|
|
295
295
|
except Exception as e:
|
296
296
|
print("LLM Error (OpenAI):", e)
|
297
297
|
if attempt < retries - 1:
|
298
|
-
print(dialog)
|
299
298
|
await asyncio.sleep(2**attempt)
|
300
299
|
else:
|
301
300
|
raise e
|
302
|
-
elif self.config.
|
301
|
+
elif self.config.request_type == "zhipuai":
|
303
302
|
for attempt in range(retries):
|
304
303
|
try:
|
305
304
|
client = self._get_next_client()
|
306
305
|
response = client.chat.asyncCompletions.create( # type: ignore
|
307
|
-
model=self.config.
|
306
|
+
model=self.config.model,
|
308
307
|
messages=dialog,
|
309
308
|
temperature=temperature,
|
310
309
|
top_p=top_p,
|
@@ -330,10 +329,12 @@ class LLM:
|
|
330
329
|
|
331
330
|
self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
|
332
331
|
self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
|
333
|
-
self._client_usage[self._current_client_index][
|
332
|
+
self._client_usage[self._current_client_index][
|
333
|
+
"request_number"
|
334
|
+
] += 1
|
334
335
|
end_time = time.time()
|
335
336
|
log["used_time"] = end_time - start_time
|
336
|
-
log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
|
337
|
+
log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens # type: ignore
|
337
338
|
self._log_list.append(log)
|
338
339
|
if tools and result_response.choices[0].message.tool_calls: # type: ignore
|
339
340
|
return json.loads(
|
@@ -356,118 +357,4 @@ class LLM:
|
|
356
357
|
else:
|
357
358
|
raise e
|
358
359
|
else:
|
359
|
-
|
360
|
-
return "wrong config"
|
361
|
-
|
362
|
-
async def img_understand(
|
363
|
-
self, img_path: Union[str, list[str]], prompt: Optional[str] = None
|
364
|
-
) -> str:
|
365
|
-
"""
|
366
|
-
Analyzes and understands images using external APIs.
|
367
|
-
|
368
|
-
- **Args**:
|
369
|
-
img_path (Union[str, list[str]]): Path or list of paths to the images for analysis.
|
370
|
-
prompt (Optional[str]): Guidance text for understanding the images.
|
371
|
-
|
372
|
-
- **Returns**:
|
373
|
-
str: The content derived from understanding the images.
|
374
|
-
"""
|
375
|
-
ppt = "如何理解这幅图像?"
|
376
|
-
if prompt != None:
|
377
|
-
ppt = prompt
|
378
|
-
if self.config.image_u["request_type"] == "openai":
|
379
|
-
if "api_base" in self.config.image_u.keys():
|
380
|
-
api_base = self.config.image_u["api_base"]
|
381
|
-
else:
|
382
|
-
api_base = None
|
383
|
-
client = OpenAI(
|
384
|
-
api_key=self.config.text["api_key"],
|
385
|
-
base_url=api_base,
|
386
|
-
)
|
387
|
-
content = []
|
388
|
-
content.append({"type": "text", "text": ppt})
|
389
|
-
if isinstance(img_path, str):
|
390
|
-
base64_image = encode_image(img_path)
|
391
|
-
content.append(
|
392
|
-
{
|
393
|
-
"type": "image_url",
|
394
|
-
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
|
395
|
-
}
|
396
|
-
)
|
397
|
-
elif isinstance(img_path, list) and all(
|
398
|
-
isinstance(item, str) for item in img_path
|
399
|
-
):
|
400
|
-
for item in img_path:
|
401
|
-
base64_image = encode_image(item)
|
402
|
-
content.append(
|
403
|
-
{
|
404
|
-
"type": "image_url",
|
405
|
-
"image_url": {
|
406
|
-
"url": f"data:image/jpeg;base64,{base64_image}"
|
407
|
-
},
|
408
|
-
}
|
409
|
-
)
|
410
|
-
response = client.chat.completions.create(
|
411
|
-
model=self.config.image_u["model"],
|
412
|
-
messages=[{"role": "user", "content": content}],
|
413
|
-
)
|
414
|
-
return response.choices[0].message.content # type: ignore
|
415
|
-
elif self.config.image_u["request_type"] == "qwen":
|
416
|
-
content = []
|
417
|
-
if isinstance(img_path, str):
|
418
|
-
content.append({"image": "file://" + img_path})
|
419
|
-
content.append({"text": ppt})
|
420
|
-
elif isinstance(img_path, list) and all(
|
421
|
-
isinstance(item, str) for item in img_path
|
422
|
-
):
|
423
|
-
for item in img_path:
|
424
|
-
content.append({"image": "file://" + item})
|
425
|
-
content.append({"text": ppt})
|
426
|
-
|
427
|
-
dialog = [{"role": "user", "content": content}]
|
428
|
-
response = dashscope.MultiModalConversation.call(
|
429
|
-
model=self.config.image_u["model"],
|
430
|
-
api_key=self.config.image_u["api_key"],
|
431
|
-
messages=dialog,
|
432
|
-
)
|
433
|
-
if response.status_code == HTTPStatus.OK: # type: ignore
|
434
|
-
return response.output.choices[0]["message"]["content"] # type: ignore
|
435
|
-
else:
|
436
|
-
print(response.code) # type: ignore # The error code.
|
437
|
-
return "Error"
|
438
|
-
else:
|
439
|
-
print(
|
440
|
-
"ERROR: wrong image understanding type, only 'openai' and 'openai' is available"
|
441
|
-
)
|
442
|
-
return "Error"
|
443
|
-
|
444
|
-
async def img_generate(self, prompt: str, size: str = "512*512", quantity: int = 1):
|
445
|
-
"""
|
446
|
-
Generates images based on a given prompt.
|
447
|
-
|
448
|
-
- **Args**:
|
449
|
-
prompt (str): Prompt for generating images.
|
450
|
-
size (str): Size of the generated images, default is '512*512'.
|
451
|
-
quantity (int): Number of images to generate, default is 1.
|
452
|
-
|
453
|
-
- **Returns**:
|
454
|
-
list[PIL.Image.Image]: List of generated PIL Image objects.
|
455
|
-
"""
|
456
|
-
rsp = ImageSynthesis.call(
|
457
|
-
model=self.config.image_g["model"],
|
458
|
-
api_key=self.config.image_g["api_key"],
|
459
|
-
prompt=prompt,
|
460
|
-
n=quantity,
|
461
|
-
size=size,
|
462
|
-
)
|
463
|
-
if rsp.status_code == HTTPStatus.OK:
|
464
|
-
res = []
|
465
|
-
for result in rsp.output.results:
|
466
|
-
res.append(Image.open(BytesIO(requests.get(result.url).content)))
|
467
|
-
return res
|
468
|
-
else:
|
469
|
-
print(
|
470
|
-
"Failed, status_code: %s, code: %s, message: %s"
|
471
|
-
% (rsp.status_code, rsp.code, rsp.message)
|
472
|
-
)
|
473
|
-
return None
|
360
|
+
raise ValueError("ERROR: Wrong Config")
|
pycityagent/memory/memory.py
CHANGED
@@ -335,9 +335,7 @@ class StreamMemory:
|
|
335
335
|
)
|
336
336
|
return "\n".join(formatted_results)
|
337
337
|
|
338
|
-
async def get_by_ids(
|
339
|
-
self, memory_ids: Union[int, list[int]]
|
340
|
-
) -> Coroutine[Any, Any, str]:
|
338
|
+
async def get_by_ids(self, memory_ids: Union[int, list[int]]) -> str:
|
341
339
|
"""获取指定ID的记忆"""
|
342
340
|
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
343
341
|
sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
|
@@ -491,15 +489,18 @@ class StreamMemory:
|
|
491
489
|
- **Returns**:
|
492
490
|
- `list[dict]`: List of all memory nodes as dictionaries.
|
493
491
|
"""
|
494
|
-
return [
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
492
|
+
return [
|
493
|
+
{
|
494
|
+
"id": memory.id,
|
495
|
+
"cognition_id": memory.cognition_id,
|
496
|
+
"tag": memory.tag.value,
|
497
|
+
"location": memory.location,
|
498
|
+
"description": memory.description,
|
499
|
+
"day": memory.day,
|
500
|
+
"t": memory.t,
|
501
|
+
}
|
502
|
+
for memory in self._memories
|
503
|
+
]
|
503
504
|
|
504
505
|
|
505
506
|
class StatusMemory:
|
@@ -10,7 +10,8 @@ from typing import Any, Optional, Union
|
|
10
10
|
import ray
|
11
11
|
from ray.util.queue import Queue
|
12
12
|
|
13
|
-
from ..
|
13
|
+
from ..configs import LLMRequestConfig
|
14
|
+
from ..llm import LLM
|
14
15
|
from ..utils.decorators import lock_decorator
|
15
16
|
|
16
17
|
DEFAULT_ERROR_STRING = """
|
@@ -95,7 +96,7 @@ class MessageInterceptor:
|
|
95
96
|
self,
|
96
97
|
blocks: Optional[list[MessageBlockBase]] = None,
|
97
98
|
black_list: Optional[list[tuple[str, str]]] = None,
|
98
|
-
llm_config: Optional[
|
99
|
+
llm_config: Optional[LLMRequestConfig] = None,
|
99
100
|
queue: Optional[Queue] = None,
|
100
101
|
) -> None:
|
101
102
|
"""
|
@@ -104,7 +105,7 @@ class MessageInterceptor:
|
|
104
105
|
- **Args**:
|
105
106
|
- `blocks` (Optional[list[MessageBlockBase]], optional): Initial list of message interception rules. Defaults to an empty list.
|
106
107
|
- `black_list` (Optional[list[tuple[str, str]]], optional): Initial blacklist of communication pairs. Defaults to an empty list.
|
107
|
-
- `llm_config` (Optional[
|
108
|
+
- `llm_config` (Optional[LLMRequestConfig], optional): Configuration dictionary for initializing the LLM instance. Defaults to None.
|
108
109
|
- `queue` (Optional[Queue], optional): Queue for message processing. Defaults to None.
|
109
110
|
"""
|
110
111
|
if blocks is not None:
|
@@ -119,7 +120,7 @@ class MessageInterceptor:
|
|
119
120
|
[]
|
120
121
|
) # list[tuple(from_uuid, to_uuid)] `None` means forbidden for everyone.
|
121
122
|
if llm_config:
|
122
|
-
self._llm = LLM(
|
123
|
+
self._llm = LLM(llm_config)
|
123
124
|
else:
|
124
125
|
self._llm = None
|
125
126
|
self._queue = queue
|