pycityagent 2.0.0a94__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a96__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. pycityagent/agent/agent.py +5 -5
  2. pycityagent/agent/agent_base.py +1 -6
  3. pycityagent/cityagent/__init__.py +6 -5
  4. pycityagent/cityagent/bankagent.py +2 -2
  5. pycityagent/cityagent/blocks/__init__.py +4 -4
  6. pycityagent/cityagent/blocks/cognition_block.py +7 -4
  7. pycityagent/cityagent/blocks/economy_block.py +227 -135
  8. pycityagent/cityagent/blocks/mobility_block.py +70 -27
  9. pycityagent/cityagent/blocks/needs_block.py +11 -12
  10. pycityagent/cityagent/blocks/other_block.py +2 -2
  11. pycityagent/cityagent/blocks/plan_block.py +22 -24
  12. pycityagent/cityagent/blocks/social_block.py +15 -17
  13. pycityagent/cityagent/blocks/utils.py +3 -2
  14. pycityagent/cityagent/firmagent.py +1 -1
  15. pycityagent/cityagent/governmentagent.py +1 -1
  16. pycityagent/cityagent/initial.py +1 -1
  17. pycityagent/cityagent/memory_config.py +0 -1
  18. pycityagent/cityagent/message_intercept.py +7 -8
  19. pycityagent/cityagent/nbsagent.py +1 -1
  20. pycityagent/cityagent/societyagent.py +1 -2
  21. pycityagent/configs/__init__.py +18 -0
  22. pycityagent/configs/exp_config.py +202 -0
  23. pycityagent/configs/sim_config.py +251 -0
  24. pycityagent/configs/utils.py +17 -0
  25. pycityagent/environment/__init__.py +2 -0
  26. pycityagent/{economy → environment/economy}/econ_client.py +14 -32
  27. pycityagent/environment/sim/sim_env.py +17 -24
  28. pycityagent/environment/simulator.py +36 -113
  29. pycityagent/llm/__init__.py +1 -2
  30. pycityagent/llm/llm.py +54 -167
  31. pycityagent/memory/memory.py +13 -12
  32. pycityagent/message/message_interceptor.py +5 -4
  33. pycityagent/message/messager.py +3 -5
  34. pycityagent/metrics/__init__.py +1 -1
  35. pycityagent/metrics/mlflow_client.py +20 -17
  36. pycityagent/pycityagent-sim +0 -0
  37. pycityagent/simulation/agentgroup.py +18 -20
  38. pycityagent/simulation/simulation.py +157 -210
  39. pycityagent/survey/manager.py +0 -2
  40. pycityagent/utils/__init__.py +3 -0
  41. pycityagent/utils/config_const.py +20 -0
  42. pycityagent/workflow/__init__.py +1 -2
  43. pycityagent/workflow/block.py +0 -3
  44. {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/METADATA +7 -24
  45. {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/RECORD +50 -46
  46. pycityagent/llm/llmconfig.py +0 -18
  47. /pycityagent/{economy → environment/economy}/__init__.py +0 -0
  48. {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/LICENSE +0 -0
  49. {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/WHEEL +0 -0
  50. {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/entry_points.txt +0 -0
  51. {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/top_level.txt +0 -0
pycityagent/environment/simulator.py CHANGED
@@ -17,6 +17,7 @@ from pycityproto.city.person.v2 import person_service_pb2 as person_service
 from pymongo import MongoClient
 from shapely.geometry import Point
 
+from ..configs import SimConfig
 from .sim import CityClient, ControlSimEnv
 from .utils.const import *
 
@@ -29,19 +30,10 @@ __all__ = [
 
 @ray.remote
 class CityMap:
-    def __init__(self, mongo_input: tuple[str, str, str, str], map_cache_path: str):
-        if map_cache_path:
-            self.map = SimMap(
-                pb_path=map_cache_path,
-            )
-        else:
-            mongo_uri, mongo_db, mongo_coll, cache_dir = mongo_input
-            self.map = SimMap(
-                mongo_uri=mongo_uri,
-                mongo_db=mongo_db,
-                mongo_coll=mongo_coll,
-                cache_dir=cache_dir,
-            )
+    def __init__(self, map_cache_path: str):
+        self.map = SimMap(
+            pb_path=map_cache_path,
+        )
         self.poi_cate = POI_CATG_DICT
 
     def get_aoi(self, aoi_id: Optional[int] = None):
@@ -81,74 +73,41 @@ class Simulator:
     - It reads parameters from a configuration dictionary, initializes map data, and starts or connects to a simulation server as needed.
     """
 
-    def __init__(self, config: dict, create_map: bool = False) -> None:
-        self.config = config
+    def __init__(self, sim_config: SimConfig, create_map: bool = False) -> None:
+        self.sim_config = sim_config
         """
         - 模拟器配置
         - simulator config
         """
-        _map_request = config["map_request"]
-        if "file_path" not in _map_request:
-            # from mongo db
-            _mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir = (
-                _map_request["mongo_uri"],
-                _map_request["mongo_db"],
-                _map_request["mongo_coll"],
-                _map_request["cache_dir"],
+        _map_pb_path = sim_config.prop_map_request.file_path
+        config = sim_config.prop_simulator_request
+        if not sim_config.prop_status.simulator_activated:
+            self._sim_env = sim_env = ControlSimEnv(
+                task_name=config.task_name,  # type:ignore
+                map_file=_map_pb_path,
+                max_day=config.max_day,  # type:ignore
+                start_step=config.start_step,  # type:ignore
+                total_step=config.total_step,  # type:ignore
+                log_dir=config.log_dir,  # type:ignore
+                min_step_time=config.min_step_time,  # type:ignore
+                primary_node_ip=config.primary_node_ip,  # type:ignore
+                sim_addr=sim_config.simulator_server_address,
             )
-            _mongo_client = MongoClient(_mongo_uri)
-            os.makedirs(_map_cache_dir, exist_ok=True)
-            _map_pb_path = os.path.join(_map_cache_dir, f"{_mongo_db}.{_mongo_coll}.pb")  # type: ignore
-            _map_pb = map_pb2.Map()
-            if os.path.exists(_map_pb_path):
-                with open(_map_pb_path, "rb") as f:
-                    _map_pb.ParseFromString(f.read())
-            else:
-                _map_pb = coll2pb(_mongo_client[_mongo_db][_mongo_coll], _map_pb)
-                with open(_map_pb_path, "wb") as f:
-                    f.write(_map_pb.SerializeToString())
-        else:
-            # from local file
-            _mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir = "", "", "", ""
-            _map_pb_path = _map_request["file_path"]
-
-        if "simulator" in config:
-            if config["simulator"] is None:
-                config["simulator"] = {}
-            if not config["simulator"].get("_server_activated", False):
-                self._sim_env = sim_env = ControlSimEnv(
-                    task_name=config["simulator"].get("task", "citysim"),
-                    map_file=_map_pb_path,
-                    max_day=config["simulator"].get("max_day", 1000),
-                    start_step=config["simulator"].get("start_step", 28800),
-                    total_step=config["simulator"].get(
-                        "total_step", 24 * 60 * 60 * 365
-                    ),
-                    log_dir=config["simulator"].get("log_dir", "./log"),
-                    min_step_time=config["simulator"].get("min_step_time", 1000),
-                    primary_node_ip=config["simulator"].get("primary_node_ip", "localhost"),
-                    sim_addr=config["simulator"].get("server", None),
-                )
-                self.server_addr = sim_env.sim_addr
-                config["simulator"]["server"] = self.server_addr
-                config["simulator"]["_server_activated"] = True
-                # using local client
-                self._client = CityClient(
-                    sim_env.sim_addr, secure=self.server_addr.startswith("https")
-                )
-                """
-                - 模拟器grpc客户端
-                - grpc client of simulator
-                """
-            else:
-                self.server_addr = config["simulator"]["server"]
-                self._client = CityClient(
-                    self.server_addr, secure=self.server_addr.startswith("https")
-                )
+            self.server_addr = sim_env.sim_addr
+            sim_config.SetServerAddress(self.server_addr)
+            sim_config.prop_status.simulator_activated = True
+            # using local client
+            self._client = CityClient(
+                sim_env.sim_addr, secure=self.server_addr.startswith("https")
+            )
+            """
+            - 模拟器grpc客户端
+            - grpc client of simulator
+            """
         else:
-            self.server_addr = None
-            logger.warning(
-                "No simulator config found, no simulator client will be used"
+            self.server_addr: str = sim_config.simulator_server_address  # type:ignore
+            self._client = CityClient(
+                self.server_addr, secure=self.server_addr.startswith("https")
            )
         self._map = None
         """
@@ -157,7 +116,6 @@ class Simulator:
         """
         if create_map:
             self._map = CityMap.remote(
-                (_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir),
                 _map_pb_path,
             )
         self._create_poi_id_2_aoi_id()
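
Taken together, the hunks above drop the MongoDB map path entirely: map data must now come from a local protobuf cache, and the constructor takes a typed `SimConfig` instead of a raw dict. A minimal sketch of the new call sites, assuming a pre-built map cache — the path is a placeholder, and the `SimConfig` constructor shown here is schematic (its real shape is defined by the new pycityagent/configs package):

import ray
from pycityagent.configs import SimConfig
from pycityagent.environment.simulator import CityMap, Simulator

ray.init()

# CityMap is a @ray.remote actor and now takes only the local .pb map path.
city_map = CityMap.remote("./cache/map.pb")  # placeholder path

# Simulator reads prop_map_request.file_path, prop_simulator_request, and
# prop_status from the config object; field names follow the attributes
# __init__ accesses, but the exact construction lives in configs/sim_config.py.
sim_config = SimConfig()  # assumed default-constructible for this sketch
simulator = Simulator(sim_config, create_map=True)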
@@ -205,8 +163,8 @@ class Simulator:
         """
         return self._environment_prompt
 
-    def get_server_addr(self):
-        return self.server_addr
+    def get_server_addr(self) -> str:
+        return self.server_addr  # type:ignore
 
     def set_environment(self, environment: dict[str, str]):
         """
@@ -239,41 +197,6 @@ class Simulator:
         """
         self._environment_prompt[key] = value
 
-    # * Agent相关
-    def find_agents_by_area(self, req: dict, status=None):
-        """
-        Find agents/persons within a specified area.
-
-        - **Args**:
-            - `req` (`dict`): A dictionary that describes the area. Refer to
-                https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxRequest.
-            - `status` (`Optional[int]`): An integer representing the status of the agents/persons to filter by.
-                If provided, only persons with the given status will be returned.
-                Refer to https://cityproto.sim.fiblab.net/#city.agent.v2.Status.
-
-        - **Returns**:
-            - The response from the GetPersonByLongLatBBox method, possibly filtered by status.
-                Refer to https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxResponse.
-        """
-        start_time = time.time()
-        log = {"req": "find_agents_by_area", "start_time": start_time, "consumption": 0}
-        loop = asyncio.get_event_loop()
-        resp = loop.run_until_complete(
-            self._client.person_service.GetPersonByLongLatBBox(req=req)
-        )
-        loop.close()
-        if status == None:
-            return resp
-        else:
-            motions = []
-            for agent in resp.motions:  # type: ignore
-                if agent.status in status:
-                    motions.append(agent)
-            resp.motions = motions  # type: ignore
-        log["consumption"] = time.time() - start_time
-        self._log_list.append(log)
-        return resp
-
     def get_poi_categories(
         self,
         center: Optional[Union[tuple[float, float], Point]] = None,
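
`find_agents_by_area` is removed outright in a96 with no replacement in this diff. Callers that still need the behavior can approximate it on top of the simulator's gRPC client; the sketch below simply mirrors the deleted body (it reaches into the private `_client`, so treat it as a stopgap, not a supported API):

# Hedged sketch: re-create the removed helper as a plain async function.
async def find_agents_by_area(simulator, req: dict, status=None):
    resp = await simulator._client.person_service.GetPersonByLongLatBBox(req=req)
    if status is not None:
        resp.motions = [m for m in resp.motions if m.status in status]  # type: ignore
    return resp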
pycityagent/llm/__init__.py CHANGED
@@ -1,11 +1,10 @@
 """LLM相关模块"""
 
 from .embeddings import SentenceEmbedding, SimpleEmbedding
-from .llm import LLM, LLMConfig
+from .llm import LLM
 
 __all__ = [
     "LLM",
-    "LLMConfig",
     "SentenceEmbedding",
     "SimpleEmbedding",
 ]
pycityagent/llm/llm.py CHANGED
@@ -3,20 +3,15 @@
 import asyncio
 import json
 import logging
-import time
 import os
-from http import HTTPStatus
-from io import BytesIO
+import time
 from typing import Any, Optional, Union
 
-import dashscope
-import requests
-from dashscope import ImageSynthesis
 from openai import APIConnectionError, AsyncOpenAI, OpenAI, OpenAIError
-from PIL import Image
 from zhipuai import ZhipuAI
 
-from .llmconfig import *
+from ..configs import LLMRequestConfig
+from ..utils import LLMRequestType
 from .utils import *
 
 logging.getLogger("zhipuai").setLevel(logging.WARNING)
@@ -36,15 +31,15 @@ class LLM:
     - It initializes clients based on the specified request type and handles token usage and consumption reporting.
     """
 
-    def __init__(self, config: LLMConfig) -> None:
+    def __init__(self, config: LLMRequestConfig) -> None:
         """
         Initializes the LLM instance.
 
         - **Parameters**:
-            - `config`: An instance of `LLMConfig` containing configuration settings for the LLM.
+            - `config`: An instance of `LLMRequestConfig` containing configuration settings for the LLM.
         """
         self.config = config
-        if config.text["request_type"] not in ["openai", "deepseek", "qwen", "zhipuai", "siliconflow"]:
+        if config.request_type not in {t.value for t in LLMRequestType}:
             raise ValueError("Invalid request type for text request")
         self.prompt_tokens_used = 0
         self.completion_tokens_used = 0
@@ -53,7 +48,7 @@ class LLM:
         self._current_client_index = 0
         self._log_list = []
 
-        api_keys = self.config.text["api_key"]
+        api_keys = self.config.api_key
         if not isinstance(api_keys, list):
             api_keys = [api_keys]
 
@@ -61,42 +56,40 @@
         self._client_usage = []
 
         for api_key in api_keys:
-            if self.config.text["request_type"] == "openai":
+            if self.config.request_type == LLMRequestType.OpenAI:
                 client = AsyncOpenAI(api_key=api_key, timeout=300)
-            elif self.config.text["request_type"] == "deepseek":
+            elif self.config.request_type == LLMRequestType.DeepSeek:
                 client = AsyncOpenAI(
                     api_key=api_key,
                     base_url="https://api.deepseek.com/v1",
                     timeout=300,
                 )
-            elif self.config.text["request_type"] == "qwen":
+            elif self.config.request_type == LLMRequestType.Qwen:
                 client = AsyncOpenAI(
                     api_key=api_key,
                     base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                     timeout=300,
                 )
-            elif self.config.text["request_type"] == "siliconflow":
+            elif self.config.request_type == LLMRequestType.SiliconFlow:
                 client = AsyncOpenAI(
                     api_key=api_key,
                     base_url="https://api.siliconflow.cn/v1",
                     timeout=300,
                 )
-            elif self.config.text["request_type"] == "zhipuai":
+            elif self.config.request_type == LLMRequestType.ZhipuAI:
                 client = ZhipuAI(api_key=api_key, timeout=300)
             else:
                 raise ValueError(
-                    f"Unsupported `request_type` {self.config.text['request_type']}!"
+                    f"Unsupported `request_type` {self.config.request_type}!"
                 )
             self._aclients.append(client)
-            self._client_usage.append({
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "request_number": 0
-            })
+            self._client_usage.append(
+                {"prompt_tokens": 0, "completion_tokens": 0, "request_number": 0}
+            )
 
     def get_log_list(self):
         return self._log_list
-
+
     def clear_log_list(self):
         self._log_list = []
 
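
The nested `config.text[...]` lookups are gone: `LLMRequestConfig` is flat (`request_type`, `api_key`, `model`), and valid request types now come from the `LLMRequestType` enum exported by `pycityagent.utils`. A hedged wiring sketch — the keyword names mirror the attributes `LLM` reads above, but the exact `LLMRequestConfig` constructor lives in the new pycityagent/configs package:

from pycityagent.configs import LLMRequestConfig
from pycityagent.llm import LLM
from pycityagent.utils import LLMRequestType

# Assumed field names, taken from the attributes LLM reads.
config = LLMRequestConfig(
    request_type=LLMRequestType.OpenAI.value,
    api_key=["sk-..."],  # a list of keys round-robins across clients
    model="gpt-4o-mini",  # placeholder model name
)
llm = LLM(config)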
@@ -122,7 +115,7 @@ class LLM:
         """
         for usage in self._client_usage:
             usage["prompt_tokens"] = 0
-            usage["completion_tokens"] = 0 
+            usage["completion_tokens"] = 0
             usage["request_number"] = 0
 
    def get_consumption(self):
@@ -130,7 +123,7 @@ class LLM:
        for i, usage in enumerate(self._client_usage):
            consumption[f"api-key-{i+1}"] = {
                "total_tokens": usage["prompt_tokens"] + usage["completion_tokens"],
-                "request_number": usage["request_number"]
+                "request_number": usage["request_number"],
            }
        return consumption
 
@@ -147,27 +140,24 @@ class LLM:
        - **Returns**:
            - A dictionary summarizing the token usage and, if applicable, the estimated cost.
        """
-        total_stats = {
-            "total": 0,
-            "prompt": 0,
-            "completion": 0,
-            "requests": 0
-        }
-
+        total_stats = {"total": 0, "prompt": 0, "completion": 0, "requests": 0}
+
        for i, usage in enumerate(self._client_usage):
            prompt_tokens = usage["prompt_tokens"]
            completion_tokens = usage["completion_tokens"]
            requests = usage["request_number"]
            total_tokens = prompt_tokens + completion_tokens
-
+
            total_stats["total"] += total_tokens
            total_stats["prompt"] += prompt_tokens
            total_stats["completion"] += completion_tokens
            total_stats["requests"] += requests
-
-            rate = prompt_tokens / completion_tokens if completion_tokens != 0 else "nan"
+
+            rate = (
+                prompt_tokens / completion_tokens if completion_tokens != 0 else "nan"
+            )
            tokens_per_request = total_tokens / requests if requests != 0 else "nan"
-
+
            print(f"\nAPI Key #{i+1}:")
            print(f"Request Number: {requests}")
            print("Token Usage:")
@@ -176,12 +166,14 @@ class LLM:
            print(f" - Completion tokens: {completion_tokens}")
            print(f" - Token per request: {tokens_per_request}")
            print(f" - Prompt:Completion ratio: {rate}:1")
-
+
            if input_price is not None and output_price is not None:
-                consumption = (prompt_tokens / 1000000 * input_price +
-                               completion_tokens / 1000000 * output_price)
+                consumption = (
+                    prompt_tokens / 1000000 * input_price
+                    + completion_tokens / 1000000 * output_price
+                )
                print(f" - Cost Estimation: {consumption}")
-
+
        return total_stats
 
    def _get_next_client(self):
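
Both prices are per million tokens, so the printed estimate is prompt_tokens / 1,000,000 × input_price + completion_tokens / 1,000,000 × output_price. For example, with input_price=0.5 and output_price=1.5, a key that used 200,000 prompt tokens and 50,000 completion tokens reports 0.2 × 0.5 + 0.05 × 1.5 = 0.175.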
@@ -203,6 +195,7 @@
    async def atext_request(
        self,
        dialog: Any,
+        response_format: Optional[dict[str, Any]] = None,
        temperature: float = 1,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
@@ -222,6 +215,7 @@
 
        - **Parameters**:
            - `dialog`: Messages to send as part of the chat completion request.
+            - `response_format`: JSON schema for the response. Default is None.
            - `temperature`: Controls randomness in the model's output. Default is 1.
            - `max_tokens`: Maximum number of tokens to generate in the response. Default is None.
            - `top_p`: Limits the next token selection to a subset of tokens with a cumulative probability above this value. Default is None.
@@ -238,19 +232,23 @@ class LLM:
        """
        start_time = time.time()
        log = {"request_time": start_time}
+        assert (
+            self.semaphore is not None
+        ), "Please set semaphore with `set_semaphore` first!"
        async with self.semaphore:
            if (
-                self.config.text["request_type"] == "openai"
-                or self.config.text["request_type"] == "deepseek"
-                or self.config.text["request_type"] == "qwen"
-                or self.config.text["request_type"] == "siliconflow"
+                self.config.request_type == "openai"
+                or self.config.request_type == "deepseek"
+                or self.config.request_type == "qwen"
+                or self.config.request_type == "siliconflow"
            ):
                for attempt in range(retries):
                    try:
                        client = self._get_next_client()
                        response = await client.chat.completions.create(
-                            model=self.config.text["model"],
+                            model=self.config.model,
                            messages=dialog,
+                            response_format=response_format,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            top_p=top_p,
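
Two behavioral changes land here: `atext_request` now fails fast unless a semaphore has been set (the assertion message points at `set_semaphore`), and an OpenAI-style `response_format` is forwarded to the client. A usage sketch under those assumptions, reusing the `llm` instance from the earlier example; the JSON-mode dict follows the OpenAI chat-completions convention:

import asyncio

async def main():
    llm.set_semaphore(asyncio.Semaphore(8))  # required before atext_request
    reply = await llm.atext_request(
        dialog=[{"role": "user", "content": "Answer as a JSON object."}],
        response_format={"type": "json_object"},  # passed through to the provider
    )
    print(reply)

asyncio.run(main())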
@@ -263,7 +261,9 @@ class LLM:
                        )  # type: ignore
                        self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens  # type: ignore
                        self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens  # type: ignore
-                        self._client_usage[self._current_client_index]["request_number"] += 1
+                        self._client_usage[self._current_client_index][
+                            "request_number"
+                        ] += 1
                        end_time = time.time()
                        log["consumption"] = end_time - start_time
                        log["input_tokens"] = response.usage.prompt_tokens
@@ -295,16 +295,15 @@ class LLM:
                    except Exception as e:
                        print("LLM Error (OpenAI):", e)
                        if attempt < retries - 1:
-                            print(dialog)
                            await asyncio.sleep(2**attempt)
                        else:
                            raise e
-            elif self.config.text["request_type"] == "zhipuai":
+            elif self.config.request_type == "zhipuai":
                for attempt in range(retries):
                    try:
                        client = self._get_next_client()
                        response = client.chat.asyncCompletions.create(  # type: ignore
-                            model=self.config.text["model"],
+                            model=self.config.model,
                            messages=dialog,
                            temperature=temperature,
                            top_p=top_p,
@@ -330,10 +329,12 @@ class LLM:
 
                        self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens  # type: ignore
                        self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens  # type: ignore
-                        self._client_usage[self._current_client_index]["request_number"] += 1
+                        self._client_usage[self._current_client_index][
+                            "request_number"
+                        ] += 1
                        end_time = time.time()
                        log["used_time"] = end_time - start_time
-                        log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
+                        log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens  # type: ignore
                        self._log_list.append(log)
                        if tools and result_response.choices[0].message.tool_calls:  # type: ignore
                            return json.loads(
@@ -356,118 +357,4 @@ class LLM:
                        else:
                            raise e
            else:
-                print("ERROR: Wrong Config")
-                return "wrong config"
-
-    async def img_understand(
-        self, img_path: Union[str, list[str]], prompt: Optional[str] = None
-    ) -> str:
-        """
-        Analyzes and understands images using external APIs.
-
-        - **Args**:
-            img_path (Union[str, list[str]]): Path or list of paths to the images for analysis.
-            prompt (Optional[str]): Guidance text for understanding the images.
-
-        - **Returns**:
-            str: The content derived from understanding the images.
-        """
-        ppt = "如何理解这幅图像?"
-        if prompt != None:
-            ppt = prompt
-        if self.config.image_u["request_type"] == "openai":
-            if "api_base" in self.config.image_u.keys():
-                api_base = self.config.image_u["api_base"]
-            else:
-                api_base = None
-            client = OpenAI(
-                api_key=self.config.text["api_key"],
-                base_url=api_base,
-            )
-            content = []
-            content.append({"type": "text", "text": ppt})
-            if isinstance(img_path, str):
-                base64_image = encode_image(img_path)
-                content.append(
-                    {
-                        "type": "image_url",
-                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
-                    }
-                )
-            elif isinstance(img_path, list) and all(
-                isinstance(item, str) for item in img_path
-            ):
-                for item in img_path:
-                    base64_image = encode_image(item)
-                    content.append(
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{base64_image}"
-                            },
-                        }
-                    )
-            response = client.chat.completions.create(
-                model=self.config.image_u["model"],
-                messages=[{"role": "user", "content": content}],
-            )
-            return response.choices[0].message.content  # type: ignore
-        elif self.config.image_u["request_type"] == "qwen":
-            content = []
-            if isinstance(img_path, str):
-                content.append({"image": "file://" + img_path})
-                content.append({"text": ppt})
-            elif isinstance(img_path, list) and all(
-                isinstance(item, str) for item in img_path
-            ):
-                for item in img_path:
-                    content.append({"image": "file://" + item})
-                content.append({"text": ppt})
-
-            dialog = [{"role": "user", "content": content}]
-            response = dashscope.MultiModalConversation.call(
-                model=self.config.image_u["model"],
-                api_key=self.config.image_u["api_key"],
-                messages=dialog,
-            )
-            if response.status_code == HTTPStatus.OK:  # type: ignore
-                return response.output.choices[0]["message"]["content"]  # type: ignore
-            else:
-                print(response.code)  # type: ignore # The error code.
-                return "Error"
-        else:
-            print(
-                "ERROR: wrong image understanding type, only 'openai' and 'openai' is available"
-            )
-            return "Error"
-
-    async def img_generate(self, prompt: str, size: str = "512*512", quantity: int = 1):
-        """
-        Generates images based on a given prompt.
-
-        - **Args**:
-            prompt (str): Prompt for generating images.
-            size (str): Size of the generated images, default is '512*512'.
-            quantity (int): Number of images to generate, default is 1.
-
-        - **Returns**:
-            list[PIL.Image.Image]: List of generated PIL Image objects.
-        """
-        rsp = ImageSynthesis.call(
-            model=self.config.image_g["model"],
-            api_key=self.config.image_g["api_key"],
-            prompt=prompt,
-            n=quantity,
-            size=size,
-        )
-        if rsp.status_code == HTTPStatus.OK:
-            res = []
-            for result in rsp.output.results:
-                res.append(Image.open(BytesIO(requests.get(result.url).content)))
-            return res
-        else:
-            print(
-                "Failed, status_code: %s, code: %s, message: %s"
-                % (rsp.status_code, rsp.code, rsp.message)
-            )
-            return None
+                raise ValueError("ERROR: Wrong Config")
pycityagent/memory/memory.py CHANGED
@@ -335,9 +335,7 @@ class StreamMemory:
        )
        return "\n".join(formatted_results)
 
-    async def get_by_ids(
-        self, memory_ids: Union[int, list[int]]
-    ) -> Coroutine[Any, Any, str]:
+    async def get_by_ids(self, memory_ids: Union[int, list[int]]) -> str:
        """获取指定ID的记忆"""
        memories = [memory for memory in self._memories if memory.id in memory_ids]
        sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
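
The old annotation declared `Coroutine[Any, Any, str]` on an `async def`, which double-wraps the type; the new one states what `await` actually yields. Usage is unchanged — a one-line sketch, assuming a populated `StreamMemory` instance named `stream` (hypothetical):

formatted = await stream.get_by_ids([1, 2, 3])  # now correctly typed as str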
@@ -491,15 +489,18 @@ class StreamMemory:
        - **Returns**:
            - `list[dict]`: List of all memory nodes as dictionaries.
        """
-        return [{
-            "id": memory.id,
-            "cognition_id": memory.cognition_id,
-            "tag": memory.tag.value,
-            "location": memory.location,
-            "description": memory.description,
-            "day": memory.day,
-            "t": memory.t,
-        } for memory in self._memories]
+        return [
+            {
+                "id": memory.id,
+                "cognition_id": memory.cognition_id,
+                "tag": memory.tag.value,
+                "location": memory.location,
+                "description": memory.description,
+                "day": memory.day,
+                "t": memory.t,
+            }
+            for memory in self._memories
+        ]
 
 
 class StatusMemory:
pycityagent/message/message_interceptor.py CHANGED
@@ -10,7 +10,8 @@ from typing import Any, Optional, Union
 import ray
 from ray.util.queue import Queue
 
-from ..llm import LLM, LLMConfig
+from ..configs import LLMRequestConfig
+from ..llm import LLM
 from ..utils.decorators import lock_decorator
 
 DEFAULT_ERROR_STRING = """
@@ -95,7 +96,7 @@ class MessageInterceptor:
        self,
        blocks: Optional[list[MessageBlockBase]] = None,
        black_list: Optional[list[tuple[str, str]]] = None,
-        llm_config: Optional[dict] = None,
+        llm_config: Optional[LLMRequestConfig] = None,
        queue: Optional[Queue] = None,
    ) -> None:
        """
@@ -104,7 +105,7 @@ class MessageInterceptor:
        - **Args**:
            - `blocks` (Optional[list[MessageBlockBase]], optional): Initial list of message interception rules. Defaults to an empty list.
            - `black_list` (Optional[list[tuple[str, str]]], optional): Initial blacklist of communication pairs. Defaults to an empty list.
-            - `llm_config` (Optional[dict], optional): Configuration dictionary for initializing the LLM instance. Defaults to None.
+            - `llm_config` (Optional[LLMRequestConfig], optional): Configuration dictionary for initializing the LLM instance. Defaults to None.
            - `queue` (Optional[Queue], optional): Queue for message processing. Defaults to None.
        """
        if blocks is not None:
@@ -119,7 +120,7 @@ class MessageInterceptor:
            []
        )  # list[tuple(from_uuid, to_uuid)] `None` means forbidden for everyone.
        if llm_config:
-            self._llm = LLM(LLMConfig(llm_config))
+            self._llm = LLM(llm_config)
        else:
            self._llm = None
        self._queue = queue
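
Call sites change accordingly: `llm_config` is now the typed object itself rather than a dict wrapped in `LLMConfig`. A sketch reusing the `LLMRequestConfig` built in the `LLM` example above (other arguments omitted for brevity):

from pycityagent.message.message_interceptor import MessageInterceptor

# Pass the LLMRequestConfig directly; MessageInterceptor now calls LLM(llm_config).
interceptor = MessageInterceptor(llm_config=config)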