pycityagent 2.0.0a93__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a95__cp311-cp311-macosx_11_0_arm64.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (51)
  1. pycityagent/agent/agent.py +5 -5
  2. pycityagent/agent/agent_base.py +1 -2
  3. pycityagent/cityagent/__init__.py +6 -5
  4. pycityagent/cityagent/bankagent.py +2 -2
  5. pycityagent/cityagent/blocks/__init__.py +4 -4
  6. pycityagent/cityagent/blocks/cognition_block.py +7 -4
  7. pycityagent/cityagent/blocks/economy_block.py +227 -135
  8. pycityagent/cityagent/blocks/mobility_block.py +70 -27
  9. pycityagent/cityagent/blocks/needs_block.py +11 -12
  10. pycityagent/cityagent/blocks/other_block.py +2 -2
  11. pycityagent/cityagent/blocks/plan_block.py +22 -24
  12. pycityagent/cityagent/blocks/social_block.py +15 -17
  13. pycityagent/cityagent/blocks/utils.py +3 -2
  14. pycityagent/cityagent/firmagent.py +1 -1
  15. pycityagent/cityagent/governmentagent.py +1 -1
  16. pycityagent/cityagent/initial.py +1 -1
  17. pycityagent/cityagent/memory_config.py +0 -1
  18. pycityagent/cityagent/message_intercept.py +7 -8
  19. pycityagent/cityagent/nbsagent.py +1 -1
  20. pycityagent/cityagent/societyagent.py +1 -2
  21. pycityagent/configs/__init__.py +18 -0
  22. pycityagent/configs/exp_config.py +202 -0
  23. pycityagent/configs/sim_config.py +251 -0
  24. pycityagent/configs/utils.py +17 -0
  25. pycityagent/environment/__init__.py +2 -0
  26. pycityagent/{economy → environment/economy}/econ_client.py +14 -32
  27. pycityagent/environment/sim/sim_env.py +17 -24
  28. pycityagent/environment/simulator.py +36 -113
  29. pycityagent/llm/__init__.py +1 -2
  30. pycityagent/llm/llm.py +60 -166
  31. pycityagent/memory/memory.py +13 -12
  32. pycityagent/message/message_interceptor.py +5 -4
  33. pycityagent/message/messager.py +3 -5
  34. pycityagent/metrics/__init__.py +1 -1
  35. pycityagent/metrics/mlflow_client.py +20 -17
  36. pycityagent/pycityagent-sim +0 -0
  37. pycityagent/simulation/agentgroup.py +17 -19
  38. pycityagent/simulation/simulation.py +157 -210
  39. pycityagent/survey/manager.py +0 -2
  40. pycityagent/utils/__init__.py +3 -0
  41. pycityagent/utils/config_const.py +20 -0
  42. pycityagent/workflow/__init__.py +1 -2
  43. pycityagent/workflow/block.py +0 -3
  44. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/METADATA +7 -24
  45. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/RECORD +50 -46
  46. pycityagent/llm/llmconfig.py +0 -18
  47. /pycityagent/{economy → environment/economy}/__init__.py +0 -0
  48. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/LICENSE +0 -0
  49. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/WHEEL +0 -0
  50. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/entry_points.txt +0 -0
  51. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/top_level.txt +0 -0
pycityagent/environment/simulator.py CHANGED
@@ -17,6 +17,7 @@ from pycityproto.city.person.v2 import person_service_pb2 as person_service
 from pymongo import MongoClient
 from shapely.geometry import Point
 
+from ..configs import SimConfig
 from .sim import CityClient, ControlSimEnv
 from .utils.const import *
 
@@ -29,19 +30,10 @@ __all__ = [
 
 @ray.remote
 class CityMap:
-    def __init__(self, mongo_input: tuple[str, str, str, str], map_cache_path: str):
-        if map_cache_path:
-            self.map = SimMap(
-                pb_path=map_cache_path,
-            )
-        else:
-            mongo_uri, mongo_db, mongo_coll, cache_dir = mongo_input
-            self.map = SimMap(
-                mongo_uri=mongo_uri,
-                mongo_db=mongo_db,
-                mongo_coll=mongo_coll,
-                cache_dir=cache_dir,
-            )
+    def __init__(self, map_cache_path: str):
+        self.map = SimMap(
+            pb_path=map_cache_path,
+        )
         self.poi_cate = POI_CATG_DICT
 
     def get_aoi(self, aoi_id: Optional[int] = None):
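With the MongoDB fallback removed, `CityMap` is constructed from a cached map protobuf alone. A minimal usage sketch (the import path follows this file, while the `.pb` path is illustrative):

    import ray

    from pycityagent.environment.simulator import CityMap

    ray.init()

    # CityMap is a Ray actor; as of a95 it only accepts the map protobuf path.
    city_map = CityMap.remote("./cache/map.pb")

    # Actor methods are invoked with .remote() and resolved with ray.get().
    aoi = ray.get(city_map.get_aoi.remote())  # no id -> implementation default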
@@ -81,74 +73,41 @@ class Simulator:
     - It reads parameters from a configuration dictionary, initializes map data, and starts or connects to a simulation server as needed.
     """
 
-    def __init__(self, config: dict, create_map: bool = False) -> None:
-        self.config = config
+    def __init__(self, sim_config: SimConfig, create_map: bool = False) -> None:
+        self.sim_config = sim_config
         """
         - 模拟器配置
         - simulator config
         """
-        _map_request = config["map_request"]
-        if "file_path" not in _map_request:
-            # from mongo db
-            _mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir = (
-                _map_request["mongo_uri"],
-                _map_request["mongo_db"],
-                _map_request["mongo_coll"],
-                _map_request["cache_dir"],
+        _map_pb_path = sim_config.prop_map_request.file_path
+        config = sim_config.prop_simulator_request
+        if not sim_config.prop_status.simulator_activated:
+            self._sim_env = sim_env = ControlSimEnv(
+                task_name=config.task_name,  # type:ignore
+                map_file=_map_pb_path,
+                max_day=config.max_day,  # type:ignore
+                start_step=config.start_step,  # type:ignore
+                total_step=config.total_step,  # type:ignore
+                log_dir=config.log_dir,  # type:ignore
+                min_step_time=config.min_step_time,  # type:ignore
+                primary_node_ip=config.primary_node_ip,  # type:ignore
+                sim_addr=sim_config.simulator_server_address,
             )
-            _mongo_client = MongoClient(_mongo_uri)
-            os.makedirs(_map_cache_dir, exist_ok=True)
-            _map_pb_path = os.path.join(_map_cache_dir, f"{_mongo_db}.{_mongo_coll}.pb")  # type: ignore
-            _map_pb = map_pb2.Map()
-            if os.path.exists(_map_pb_path):
-                with open(_map_pb_path, "rb") as f:
-                    _map_pb.ParseFromString(f.read())
-            else:
-                _map_pb = coll2pb(_mongo_client[_mongo_db][_mongo_coll], _map_pb)
-                with open(_map_pb_path, "wb") as f:
-                    f.write(_map_pb.SerializeToString())
-        else:
-            # from local file
-            _mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir = "", "", "", ""
-            _map_pb_path = _map_request["file_path"]
-
-        if "simulator" in config:
-            if config["simulator"] is None:
-                config["simulator"] = {}
-            if not config["simulator"].get("_server_activated", False):
-                self._sim_env = sim_env = ControlSimEnv(
-                    task_name=config["simulator"].get("task", "citysim"),
-                    map_file=_map_pb_path,
-                    max_day=config["simulator"].get("max_day", 1000),
-                    start_step=config["simulator"].get("start_step", 28800),
-                    total_step=config["simulator"].get(
-                        "total_step", 24 * 60 * 60 * 365
-                    ),
-                    log_dir=config["simulator"].get("log_dir", "./log"),
-                    min_step_time=config["simulator"].get("min_step_time", 1000),
-                    primary_node_ip=config["simulator"].get("primary_node_ip", "localhost"),
-                    sim_addr=config["simulator"].get("server", None),
-                )
-                self.server_addr = sim_env.sim_addr
-                config["simulator"]["server"] = self.server_addr
-                config["simulator"]["_server_activated"] = True
-                # using local client
-                self._client = CityClient(
-                    sim_env.sim_addr, secure=self.server_addr.startswith("https")
-                )
-                """
-                - 模拟器grpc客户端
-                - grpc client of simulator
-                """
-            else:
-                self.server_addr = config["simulator"]["server"]
-                self._client = CityClient(
-                    self.server_addr, secure=self.server_addr.startswith("https")
-                )
+            self.server_addr = sim_env.sim_addr
+            sim_config.SetServerAddress(self.server_addr)
+            sim_config.prop_status.simulator_activated = True
+            # using local client
+            self._client = CityClient(
+                sim_env.sim_addr, secure=self.server_addr.startswith("https")
+            )
+            """
+            - 模拟器grpc客户端
+            - grpc client of simulator
+            """
         else:
-            self.server_addr = None
-            logger.warning(
-                "No simulator config found, no simulator client will be used"
+            self.server_addr: str = sim_config.simulator_server_address  # type:ignore
+            self._client = CityClient(
+                self.server_addr, secure=self.server_addr.startswith("https")
             )
         self._map = None
         """
@@ -157,7 +116,6 @@ class Simulator:
         """
         if create_map:
             self._map = CityMap.remote(
-                (_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir),
                 _map_pb_path,
             )
             self._create_poi_id_2_aoi_id()
@@ -205,8 +163,8 @@ class Simulator:
         """
         return self._environment_prompt
 
-    def get_server_addr(self):
-        return self.server_addr
+    def get_server_addr(self) -> str:
+        return self.server_addr  # type:ignore
 
    def set_environment(self, environment: dict[str, str]):
        """
@@ -239,41 +197,6 @@ class Simulator:
         """
         self._environment_prompt[key] = value
 
-    # * Agent相关
-    def find_agents_by_area(self, req: dict, status=None):
-        """
-        Find agents/persons within a specified area.
-
-        - **Args**:
-            - `req` (`dict`): A dictionary that describes the area. Refer to
-                https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxRequest.
-            - `status` (`Optional[int]`): An integer representing the status of the agents/persons to filter by.
-                If provided, only persons with the given status will be returned.
-                Refer to https://cityproto.sim.fiblab.net/#city.agent.v2.Status.
-
-        - **Returns**:
-            - The response from the GetPersonByLongLatBBox method, possibly filtered by status.
-                Refer to https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxResponse.
-        """
-        start_time = time.time()
-        log = {"req": "find_agents_by_area", "start_time": start_time, "consumption": 0}
-        loop = asyncio.get_event_loop()
-        resp = loop.run_until_complete(
-            self._client.person_service.GetPersonByLongLatBBox(req=req)
-        )
-        loop.close()
-        if status == None:
-            return resp
-        else:
-            motions = []
-            for agent in resp.motions:  # type: ignore
-                if agent.status in status:
-                    motions.append(agent)
-            resp.motions = motions  # type: ignore
-        log["consumption"] = time.time() - start_time
-        self._log_list.append(log)
-        return resp
-
     def get_poi_categories(
         self,
         center: Optional[Union[tuple[float, float], Point]] = None,
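Net effect of the simulator.py changes: `Simulator` no longer takes a raw dict and no longer builds maps from MongoDB; it takes a typed `SimConfig` whose fields mirror the attribute accesses above. A hedged sketch of the new call pattern (the `SimConfig` constructor is not part of this diff, so the placeholder below stands in for however it is actually built):

    from pycityagent.configs import SimConfig
    from pycityagent.environment.simulator import Simulator  # module path per this diff

    # Old (a93): Simulator(config={"map_request": {"file_path": "./cache/map.pb"},
    #                              "simulator": {"task": "citysim"}})
    # New (a95): a typed SimConfig; how it is built is outside this diff, so the
    # next line is a placeholder, not a real constructor call.
    sim_config = SimConfig(...)  # must expose prop_map_request, prop_simulator_request, prop_status
    simulator = Simulator(sim_config, create_map=True)
    addr: str = simulator.get_server_addr()  # return type annotated as str in a95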
pycityagent/llm/__init__.py CHANGED
@@ -1,11 +1,10 @@
 """LLM相关模块"""
 
 from .embeddings import SentenceEmbedding, SimpleEmbedding
-from .llm import LLM, LLMConfig
+from .llm import LLM
 
 __all__ = [
     "LLM",
-    "LLMConfig",
     "SentenceEmbedding",
     "SimpleEmbedding",
 ]
pycityagent/llm/llm.py CHANGED
@@ -3,20 +3,15 @@
 import asyncio
 import json
 import logging
-import time
 import os
-from http import HTTPStatus
-from io import BytesIO
+import time
 from typing import Any, Optional, Union
 
-import dashscope
-import requests
-from dashscope import ImageSynthesis
 from openai import APIConnectionError, AsyncOpenAI, OpenAI, OpenAIError
-from PIL import Image
 from zhipuai import ZhipuAI
 
-from .llmconfig import *
+from ..configs import LLMRequestConfig
+from ..utils import LLMRequestType
 from .utils import *
 
 logging.getLogger("zhipuai").setLevel(logging.WARNING)
@@ -36,24 +31,24 @@ class LLM:
     - It initializes clients based on the specified request type and handles token usage and consumption reporting.
     """
 
-    def __init__(self, config: LLMConfig) -> None:
+    def __init__(self, config: LLMRequestConfig) -> None:
         """
         Initializes the LLM instance.
 
         - **Parameters**:
-            - `config`: An instance of `LLMConfig` containing configuration settings for the LLM.
+            - `config`: An instance of `LLMRequestConfig` containing configuration settings for the LLM.
         """
         self.config = config
-        if config.text["request_type"] not in ["openai", "deepseek", "qwen", "zhipuai"]:
+        if config.request_type not in {t.value for t in LLMRequestType}:
             raise ValueError("Invalid request type for text request")
         self.prompt_tokens_used = 0
         self.completion_tokens_used = 0
         self.request_number = 0
-        self.semaphore = None
+        self.semaphore = asyncio.Semaphore(200)
         self._current_client_index = 0
         self._log_list = []
 
-        api_keys = self.config.text["api_key"]
+        api_keys = self.config.api_key
         if not isinstance(api_keys, list):
             api_keys = [api_keys]
 
@@ -61,36 +56,40 @@ class LLM:
         self._client_usage = []
 
         for api_key in api_keys:
-            if self.config.text["request_type"] == "openai":
+            if self.config.request_type == LLMRequestType.OpenAI:
                 client = AsyncOpenAI(api_key=api_key, timeout=300)
-            elif self.config.text["request_type"] == "deepseek":
+            elif self.config.request_type == LLMRequestType.DeepSeek:
                 client = AsyncOpenAI(
                     api_key=api_key,
                     base_url="https://api.deepseek.com/v1",
                     timeout=300,
                 )
-            elif self.config.text["request_type"] == "qwen":
+            elif self.config.request_type == LLMRequestType.Qwen:
                 client = AsyncOpenAI(
                     api_key=api_key,
                     base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                     timeout=300,
                 )
-            elif self.config.text["request_type"] == "zhipuai":
+            elif self.config.request_type == LLMRequestType.SiliconFlow:
+                client = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url="https://api.siliconflow.cn/v1",
+                    timeout=300,
+                )
+            elif self.config.request_type == LLMRequestType.ZhipuAI:
                 client = ZhipuAI(api_key=api_key, timeout=300)
             else:
                 raise ValueError(
-                    f"Unsupported `request_type` {self.config.text['request_type']}!"
+                    f"Unsupported `request_type` {self.config.request_type}!"
                 )
             self._aclients.append(client)
-            self._client_usage.append({
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "request_number": 0
-            })
+            self._client_usage.append(
+                {"prompt_tokens": 0, "completion_tokens": 0, "request_number": 0}
+            )
 
     def get_log_list(self):
         return self._log_list
-
+
     def clear_log_list(self):
         self._log_list = []
 
@@ -116,7 +115,7 @@ class LLM:
         """
         for usage in self._client_usage:
             usage["prompt_tokens"] = 0
-            usage["completion_tokens"] = 0 
+            usage["completion_tokens"] = 0
             usage["request_number"] = 0
 
     def get_consumption(self):
@@ -124,7 +123,7 @@
         for i, usage in enumerate(self._client_usage):
             consumption[f"api-key-{i+1}"] = {
                 "total_tokens": usage["prompt_tokens"] + usage["completion_tokens"],
-                "request_number": usage["request_number"]
+                "request_number": usage["request_number"],
             }
         return consumption
 
@@ -141,27 +140,24 @@ class LLM:
         - **Returns**:
             - A dictionary summarizing the token usage and, if applicable, the estimated cost.
         """
-        total_stats = {
-            "total": 0,
-            "prompt": 0,
-            "completion": 0,
-            "requests": 0
-        }
-
+        total_stats = {"total": 0, "prompt": 0, "completion": 0, "requests": 0}
+
         for i, usage in enumerate(self._client_usage):
             prompt_tokens = usage["prompt_tokens"]
             completion_tokens = usage["completion_tokens"]
             requests = usage["request_number"]
             total_tokens = prompt_tokens + completion_tokens
-
+
             total_stats["total"] += total_tokens
             total_stats["prompt"] += prompt_tokens
             total_stats["completion"] += completion_tokens
             total_stats["requests"] += requests
-
-            rate = prompt_tokens / completion_tokens if completion_tokens != 0 else "nan"
+
+            rate = (
+                prompt_tokens / completion_tokens if completion_tokens != 0 else "nan"
+            )
             tokens_per_request = total_tokens / requests if requests != 0 else "nan"
-
+
             print(f"\nAPI Key #{i+1}:")
             print(f"Request Number: {requests}")
             print("Token Usage:")
@@ -170,12 +166,14 @@ class LLM:
             print(f"  - Completion tokens: {completion_tokens}")
             print(f"  - Token per request: {tokens_per_request}")
             print(f"  - Prompt:Completion ratio: {rate}:1")
-
+
             if input_price is not None and output_price is not None:
-                consumption = (prompt_tokens / 1000000 * input_price +
-                               completion_tokens / 1000000 * output_price)
+                consumption = (
+                    prompt_tokens / 1000000 * input_price
+                    + completion_tokens / 1000000 * output_price
+                )
                 print(f"  - Cost Estimation: {consumption}")
-
+
         return total_stats
 
     def _get_next_client(self):
@@ -197,6 +195,7 @@ class LLM:
     async def atext_request(
         self,
         dialog: Any,
+        response_format: Optional[dict[str, Any]] = None,
         temperature: float = 1,
         max_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
@@ -216,6 +215,7 @@ class LLM:
 
         - **Parameters**:
             - `dialog`: Messages to send as part of the chat completion request.
+            - `response_format`: JSON schema for the response. Default is None.
             - `temperature`: Controls randomness in the model's output. Default is 1.
             - `max_tokens`: Maximum number of tokens to generate in the response. Default is None.
             - `top_p`: Limits the next token selection to a subset of tokens with a cumulative probability above this value. Default is None.
@@ -232,18 +232,23 @@ class LLM:
         """
         start_time = time.time()
         log = {"request_time": start_time}
+        assert (
+            self.semaphore is not None
+        ), "Please set semaphore with `set_semaphore` first!"
         async with self.semaphore:
             if (
-                self.config.text["request_type"] == "openai"
-                or self.config.text["request_type"] == "deepseek"
-                or self.config.text["request_type"] == "qwen"
+                self.config.request_type == "openai"
+                or self.config.request_type == "deepseek"
+                or self.config.request_type == "qwen"
+                or self.config.request_type == "siliconflow"
             ):
                 for attempt in range(retries):
                     try:
                         client = self._get_next_client()
                         response = await client.chat.completions.create(
-                            model=self.config.text["model"],
+                            model=self.config.model,
                             messages=dialog,
+                            response_format=response_format,
                             temperature=temperature,
                             max_tokens=max_tokens,
                             top_p=top_p,
@@ -256,7 +261,9 @@ class LLM:
                         )  # type: ignore
                         self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens  # type: ignore
                         self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens  # type: ignore
-                        self._client_usage[self._current_client_index]["request_number"] += 1
+                        self._client_usage[self._current_client_index][
+                            "request_number"
+                        ] += 1
                         end_time = time.time()
                         log["consumption"] = end_time - start_time
                         log["input_tokens"] = response.usage.prompt_tokens
@@ -288,16 +295,15 @@ class LLM:
                     except Exception as e:
                         print("LLM Error (OpenAI):", e)
                         if attempt < retries - 1:
-                            print(dialog)
                             await asyncio.sleep(2**attempt)
                         else:
                             raise e
-            elif self.config.text["request_type"] == "zhipuai":
+            elif self.config.request_type == "zhipuai":
                 for attempt in range(retries):
                     try:
                         client = self._get_next_client()
                         response = client.chat.asyncCompletions.create(  # type: ignore
-                            model=self.config.text["model"],
+                            model=self.config.model,
                             messages=dialog,
                             temperature=temperature,
                             top_p=top_p,
@@ -323,10 +329,12 @@ class LLM:
 
                         self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens  # type: ignore
                         self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens  # type: ignore
-                        self._client_usage[self._current_client_index]["request_number"] += 1
+                        self._client_usage[self._current_client_index][
+                            "request_number"
+                        ] += 1
                         end_time = time.time()
                         log["used_time"] = end_time - start_time
-                        log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
+                        log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens  # type: ignore
                         self._log_list.append(log)
                         if tools and result_response.choices[0].message.tool_calls:  # type: ignore
                             return json.loads(
@@ -349,118 +357,4 @@ class LLM:
                         else:
                             raise e
             else:
-            print("ERROR: Wrong Config")
-            return "wrong config"
-
-    async def img_understand(
-        self, img_path: Union[str, list[str]], prompt: Optional[str] = None
-    ) -> str:
-        """
-        Analyzes and understands images using external APIs.
-
-        - **Args**:
-            img_path (Union[str, list[str]]): Path or list of paths to the images for analysis.
-            prompt (Optional[str]): Guidance text for understanding the images.
-
-        - **Returns**:
-            str: The content derived from understanding the images.
-        """
-        ppt = "如何理解这幅图像?"
-        if prompt != None:
-            ppt = prompt
-        if self.config.image_u["request_type"] == "openai":
-            if "api_base" in self.config.image_u.keys():
-                api_base = self.config.image_u["api_base"]
-            else:
-                api_base = None
-            client = OpenAI(
-                api_key=self.config.text["api_key"],
-                base_url=api_base,
-            )
-            content = []
-            content.append({"type": "text", "text": ppt})
-            if isinstance(img_path, str):
-                base64_image = encode_image(img_path)
-                content.append(
-                    {
-                        "type": "image_url",
-                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
-                    }
-                )
-            elif isinstance(img_path, list) and all(
-                isinstance(item, str) for item in img_path
-            ):
-                for item in img_path:
-                    base64_image = encode_image(item)
-                    content.append(
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{base64_image}"
-                            },
-                        }
-                    )
-            response = client.chat.completions.create(
-                model=self.config.image_u["model"],
-                messages=[{"role": "user", "content": content}],
-            )
-            return response.choices[0].message.content  # type: ignore
-        elif self.config.image_u["request_type"] == "qwen":
-            content = []
-            if isinstance(img_path, str):
-                content.append({"image": "file://" + img_path})
-                content.append({"text": ppt})
-            elif isinstance(img_path, list) and all(
-                isinstance(item, str) for item in img_path
-            ):
-                for item in img_path:
-                    content.append({"image": "file://" + item})
-                content.append({"text": ppt})
-
-            dialog = [{"role": "user", "content": content}]
-            response = dashscope.MultiModalConversation.call(
-                model=self.config.image_u["model"],
-                api_key=self.config.image_u["api_key"],
-                messages=dialog,
-            )
-            if response.status_code == HTTPStatus.OK:  # type: ignore
-                return response.output.choices[0]["message"]["content"]  # type: ignore
-            else:
-                print(response.code)  # type: ignore # The error code.
-                return "Error"
-        else:
-            print(
-                "ERROR: wrong image understanding type, only 'openai' and 'openai' is available"
-            )
-            return "Error"
-
-    async def img_generate(self, prompt: str, size: str = "512*512", quantity: int = 1):
-        """
-        Generates images based on a given prompt.
-
-        - **Args**:
-            prompt (str): Prompt for generating images.
-            size (str): Size of the generated images, default is '512*512'.
-            quantity (int): Number of images to generate, default is 1.
-
-        - **Returns**:
-            list[PIL.Image.Image]: List of generated PIL Image objects.
-        """
-        rsp = ImageSynthesis.call(
-            model=self.config.image_g["model"],
-            api_key=self.config.image_g["api_key"],
-            prompt=prompt,
-            n=quantity,
-            size=size,
-        )
-        if rsp.status_code == HTTPStatus.OK:
-            res = []
-            for result in rsp.output.results:
-                res.append(Image.open(BytesIO(requests.get(result.url).content)))
-            return res
-        else:
-            print(
-                "Failed, status_code: %s, code: %s, message: %s"
-                % (rsp.status_code, rsp.code, rsp.message)
-            )
-            return None
+                raise ValueError("ERROR: Wrong Config")
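Net effect for callers of llm.py: `LLM` is now built from an `LLMRequestConfig` (the old `config.text[...]` dict lookups become attributes), the request semaphore is created eagerly, and `atext_request` gains an OpenAI-style `response_format`. A minimal sketch, assuming `LLMRequestConfig` takes the fields this file reads (`request_type`, `api_key`, `model`); the keyword-argument constructor itself is an assumption:

    import asyncio

    from pycityagent.configs import LLMRequestConfig
    from pycityagent.llm import LLM

    # Field names mirror the attributes read in llm.py; constructor signature assumed.
    config = LLMRequestConfig(
        request_type="openai",
        api_key=["sk-...", "sk-..."],  # a list round-robins across clients
        model="gpt-4o-mini",
    )
    llm = LLM(config)

    async def main():
        reply = await llm.atext_request(
            [{"role": "user", "content": "Reply with a JSON object."}],
            response_format={"type": "json_object"},  # new in a95
        )
        print(reply)
        llm.show_consumption()

    asyncio.run(main())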
pycityagent/memory/memory.py CHANGED
@@ -335,9 +335,7 @@ class StreamMemory:
         )
         return "\n".join(formatted_results)
 
-    async def get_by_ids(
-        self, memory_ids: Union[int, list[int]]
-    ) -> Coroutine[Any, Any, str]:
+    async def get_by_ids(self, memory_ids: Union[int, list[int]]) -> str:
         """获取指定ID的记忆"""
         memories = [memory for memory in self._memories if memory.id in memory_ids]
         sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
@@ -491,15 +489,18 @@ class StreamMemory:
         - **Returns**:
             - `list[dict]`: List of all memory nodes as dictionaries.
         """
-        return [{
-            "id": memory.id,
-            "cognition_id": memory.cognition_id,
-            "tag": memory.tag.value,
-            "location": memory.location,
-            "description": memory.description,
-            "day": memory.day,
-            "t": memory.t,
-        } for memory in self._memories]
+        return [
+            {
+                "id": memory.id,
+                "cognition_id": memory.cognition_id,
+                "tag": memory.tag.value,
+                "location": memory.location,
+                "description": memory.description,
+                "day": memory.day,
+                "t": memory.t,
+            }
+            for memory in self._memories
+        ]
 
 
 class StatusMemory:
pycityagent/message/message_interceptor.py CHANGED
@@ -10,7 +10,8 @@ from typing import Any, Optional, Union
 import ray
 from ray.util.queue import Queue
 
-from ..llm import LLM, LLMConfig
+from ..configs import LLMRequestConfig
+from ..llm import LLM
 from ..utils.decorators import lock_decorator
 
 DEFAULT_ERROR_STRING = """
@@ -95,7 +96,7 @@ class MessageInterceptor:
         self,
         blocks: Optional[list[MessageBlockBase]] = None,
         black_list: Optional[list[tuple[str, str]]] = None,
-        llm_config: Optional[dict] = None,
+        llm_config: Optional[LLMRequestConfig] = None,
         queue: Optional[Queue] = None,
     ) -> None:
         """
@@ -104,7 +105,7 @@ class MessageInterceptor:
         - **Args**:
             - `blocks` (Optional[list[MessageBlockBase]], optional): Initial list of message interception rules. Defaults to an empty list.
             - `black_list` (Optional[list[tuple[str, str]]], optional): Initial blacklist of communication pairs. Defaults to an empty list.
-            - `llm_config` (Optional[dict], optional): Configuration dictionary for initializing the LLM instance. Defaults to None.
+            - `llm_config` (Optional[LLMRequestConfig], optional): Configuration for initializing the LLM instance. Defaults to None.
             - `queue` (Optional[Queue], optional): Queue for message processing. Defaults to None.
         """
         if blocks is not None:
@@ -119,7 +120,7 @@ class MessageInterceptor:
             []
         )  # list[tuple(from_uuid, to_uuid)] `None` means forbidden for everyone.
         if llm_config:
-            self._llm = LLM(LLMConfig(llm_config))
+            self._llm = LLM(llm_config)
         else:
            self._llm = None
        self._queue = queue
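Downstream, `MessageInterceptor` now takes the same typed config instead of a raw dict and passes it to `LLM` directly. A sketch under the same assumptions as the `LLMRequestConfig` example above (constructed directly here for illustration; in the framework it may be created as a Ray actor):

    from pycityagent.configs import LLMRequestConfig
    from pycityagent.message.message_interceptor import MessageInterceptor

    llm_config = LLMRequestConfig(
        request_type="openai", api_key="sk-...", model="gpt-4o-mini"
    )
    # Previously: MessageInterceptor(llm_config={...}) with a plain dict,
    # wrapped internally as LLM(LLMConfig(llm_config)).
    interceptor = MessageInterceptor(llm_config=llm_config)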