pycityagent 2.0.0a19__py3-none-any.whl → 2.0.0a20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/agent.py CHANGED
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
8
8
  from copy import deepcopy
9
9
  from datetime import datetime
10
10
  from enum import Enum
11
- from typing import Any, Dict, List, Optional
11
+ from typing import Any, Optional
12
12
  from uuid import UUID
13
13
 
14
14
  import fastavro
@@ -55,7 +55,7 @@ class Agent(ABC):
55
55
  simulator: Optional[Simulator] = None,
56
56
  mlflow_client: Optional[MlflowClient] = None,
57
57
  memory: Optional[Memory] = None,
58
- avro_file: Optional[Dict[str, str]] = None,
58
+ avro_file: Optional[dict[str, str]] = None,
59
59
  ) -> None:
60
60
  """
61
61
  Initialize the Agent.
@@ -69,7 +69,7 @@ class Agent(ABC):
69
69
  simulator (Simulator, optional): The simulator object. Defaults to None.
70
70
  mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
71
71
  memory (Memory, optional): The memory of the agent. Defaults to None.
72
- avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
72
+ avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
73
73
  """
74
74
  self._name = name
75
75
  self._type = type
@@ -85,7 +85,7 @@ class Agent(ABC):
85
85
  self._has_bound_to_simulator = False
86
86
  self._has_bound_to_economy = False
87
87
  self._blocked = False
88
- self._interview_history: List[Dict] = [] # 存储采访历史
88
+ self._interview_history: list[dict] = [] # 存储采访历史
89
89
  self._person_template = PersonService.default_dict_person()
90
90
  self._avro_file = avro_file
91
91
 
@@ -137,7 +137,7 @@ class Agent(ABC):
137
137
  """
138
138
  self._exp_id = exp_id
139
139
 
140
- def set_avro_file(self, avro_file: Dict[str, str]):
140
+ def set_avro_file(self, avro_file: dict[str, str]):
141
141
  """
142
142
  Set the avro file of the agent.
143
143
  """
@@ -557,7 +557,7 @@ class InstitutionAgent(Agent):
557
557
  avro_file=avro_file,
558
558
  )
559
559
  # 添加响应收集器
560
- self._gather_responses: Dict[str, asyncio.Future] = {}
560
+ self._gather_responses: dict[str, asyncio.Future] = {}
561
561
 
562
562
  async def bind_to_simulator(self):
563
563
  await self._bind_to_economy()
@@ -659,7 +659,7 @@ class InstitutionAgent(Agent):
659
659
  }
660
660
  )
661
661
 
662
- async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
662
+ async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dict]:
663
663
  """从多个智能体收集消息
664
664
 
665
665
  Args:
@@ -667,7 +667,7 @@ class InstitutionAgent(Agent):
667
667
  target: 要收集的信息类型
668
668
 
669
669
  Returns:
670
- List[dict]: 收集到的所有响应
670
+ list[dict]: 收集到的所有响应
671
671
  """
672
672
  # 为每个agent创建Future
673
673
  futures = {}
@@ -316,3 +316,40 @@ class EconomyClient:
316
316
  await self._aio_stub.GetOrgEntityIds(request)
317
317
  )
318
318
  return list(response.org_ids)
319
+
320
+ async def add_delta_value(
321
+ self,
322
+ id: int,
323
+ key: str,
324
+ value: Any,
325
+ ) -> Any:
326
+ """
327
+ Add key-value pair
328
+
329
+ Args:
330
+ - id (int): the id of `Org` or `Agent`.
331
+ - key (str): the attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
332
+
333
+
334
+ Returns:
335
+ - Any
336
+ """
337
+ pascal_key = _snake_to_pascal(key)
338
+ _request_type = getattr(org_service, f"Add{pascal_key}Request")
339
+ _request_func = getattr(self._aio_stub, f"Add{pascal_key}")
340
+ _available_keys = {
341
+ "inventory",
342
+ "price",
343
+ "interest_rate",
344
+ "currency",
345
+ }
346
+ if key not in _available_keys:
347
+ raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
348
+ return await _request_func(
349
+ _request_type(
350
+ **{
351
+ "org_id": id,
352
+ f"delta_{key}": value,
353
+ }
354
+ )
355
+ )
@@ -1,9 +1,7 @@
1
- from typing import List
2
-
3
1
  __all__ = ["wrap_feature_collection"]
4
2
 
5
3
 
6
- def wrap_feature_collection(features: List[dict], name: str):
4
+ def wrap_feature_collection(features: list[dict], name: str):
7
5
  """
8
6
  将 GeoJSON Feature 集合包装为 FeatureCollection
9
7
  Wrap GeoJSON Feature collection as FeatureCollection
@@ -1,5 +1,5 @@
1
1
  import math
2
- from typing import Dict, List, Literal, Optional, Tuple, Union
2
+ from typing import Literal, Union
3
3
 
4
4
  import numpy as np
5
5
 
@@ -14,12 +14,12 @@ def point_on_line_given_distance(start_node, end_node, distance):
14
14
  return the coordinates of the point reached after traveling s units along the line, starting from start_point.
15
15
 
16
16
  Args:
17
- start_point (tuple): Tuple of (x, y) representing the starting point on the line.
18
- end_point (tuple): Tuple of (x, y) representing the ending point on the line.
17
+ start_point (tuple): tuple of (x, y) representing the starting point on the line.
18
+ end_point (tuple): tuple of (x, y) representing the ending point on the line.
19
19
  distance (float): Distance to travel along the line, starting from start_point.
20
20
 
21
21
  Returns:
22
- tuple: Tuple of (x, y) representing the new point reached after traveling s units along the line.
22
+ tuple: tuple of (x, y) representing the new point reached after traveling s units along the line.
23
23
  """
24
24
 
25
25
  x1, y1 = start_node["x"], start_node["y"]
@@ -49,7 +49,7 @@ def point_on_line_given_distance(start_node, end_node, distance):
49
49
 
50
50
 
51
51
  def get_key_index_in_lane(
52
- nodes: List[Dict[str, float]],
52
+ nodes: list[dict[str, float]],
53
53
  distance: float,
54
54
  direction: Union[Literal["front"], Literal["back"]],
55
55
  ) -> int:
@@ -61,10 +61,10 @@ def get_key_index_in_lane(
61
61
  _index_offset, _index_factor = len(_nodes) - 1, -1
62
62
  else:
63
63
  raise ValueError(f"Invalid direction type {direction}!")
64
- _lane_points: List[Tuple[float, float, float]] = [
64
+ _lane_points: list[tuple[float, float, float]] = [
65
65
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
66
66
  ]
67
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
67
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
68
68
  _s = 0.0
69
69
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
70
70
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -82,20 +82,20 @@ def get_key_index_in_lane(
82
82
 
83
83
 
84
84
  def get_xy_in_lane(
85
- nodes: List[Dict[str, float]],
85
+ nodes: list[dict[str, float]],
86
86
  distance: float,
87
87
  direction: Union[Literal["front"], Literal["back"]],
88
- ) -> Tuple[float, float]:
88
+ ) -> tuple[float, float]:
89
89
  if direction == "front":
90
90
  _nodes = [n for n in nodes]
91
91
  elif direction == "back":
92
92
  _nodes = [n for n in nodes[::-1]]
93
93
  else:
94
94
  raise ValueError(f"Invalid direction type {direction}!")
95
- _lane_points: List[Tuple[float, float, float]] = [
95
+ _lane_points: list[tuple[float, float, float]] = [
96
96
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
97
97
  ]
98
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
98
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
99
99
  _s = 0.0
100
100
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
101
101
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -122,7 +122,7 @@ def get_xy_in_lane(
122
122
 
123
123
 
124
124
  def get_direction_by_s(
125
- nodes: List[Dict[str, float]],
125
+ nodes: list[dict[str, float]],
126
126
  distance: float,
127
127
  direction: Union[Literal["front"], Literal["back"]],
128
128
  ) -> float:
@@ -132,11 +132,11 @@ def get_direction_by_s(
132
132
  _nodes = [n for n in nodes[::-1]]
133
133
  else:
134
134
  raise ValueError(f"Invalid direction type {direction}!")
135
- _lane_points: List[Tuple[float, float, float]] = [
135
+ _lane_points: list[tuple[float, float, float]] = [
136
136
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
137
137
  ]
138
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
139
- _line_directions: List[Tuple[float, float]] = []
138
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
139
+ _line_directions: list[tuple[float, float]] = []
140
140
  _s = 0.0
141
141
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
142
142
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -1,7 +1,6 @@
1
1
  """简单的基于内存的embedding实现"""
2
2
 
3
3
  import numpy as np
4
- from typing import List, Dict, Optional
5
4
  import hashlib
6
5
  import json
7
6
 
@@ -22,34 +21,34 @@ class SimpleEmbedding:
22
21
  """
23
22
  self.vector_dim = vector_dim
24
23
  self.cache_size = cache_size
25
- self._cache: Dict[str, np.ndarray] = {}
26
- self._vocab: Dict[str, int] = {} # 词汇表
27
- self._idf: Dict[str, float] = {} # 逆文档频率
24
+ self._cache: dict[str, np.ndarray] = {}
25
+ self._vocab: dict[str, int] = {} # 词汇表
26
+ self._idf: dict[str, float] = {} # 逆文档频率
28
27
  self._doc_count = 0 # 文档总数
29
28
 
30
29
  def _text_to_hash(self, text: str) -> str:
31
30
  """将文本转换为hash值"""
32
31
  return hashlib.md5(text.encode()).hexdigest()
33
32
 
34
- def _tokenize(self, text: str) -> List[str]:
33
+ def _tokenize(self, text: str) -> list[str]:
35
34
  """简单的分词"""
36
35
  # 这里使用简单的空格分词,实际应用中可以使用更复杂的分词方法
37
36
  return text.lower().split()
38
37
 
39
- def _update_vocab(self, tokens: List[str]):
38
+ def _update_vocab(self, tokens: list[str]):
40
39
  """更新词汇表"""
41
40
  for token in set(tokens): # 使用set去重
42
41
  if token not in self._vocab:
43
42
  self._vocab[token] = len(self._vocab)
44
43
 
45
- def _update_idf(self, tokens: List[str]):
44
+ def _update_idf(self, tokens: list[str]):
46
45
  """更新IDF值"""
47
46
  self._doc_count += 1
48
47
  unique_tokens = set(tokens)
49
48
  for token in unique_tokens:
50
49
  self._idf[token] = self._idf.get(token, 0) + 1
51
50
 
52
- def _calculate_tf(self, tokens: List[str]) -> Dict[str, float]:
51
+ def _calculate_tf(self, tokens: list[str]) -> dict[str, float]:
53
52
  """计算词频(TF)"""
54
53
  tf = {}
55
54
  total_tokens = len(tokens)
@@ -60,7 +59,7 @@ class SimpleEmbedding:
60
59
  tf[token] /= total_tokens
61
60
  return tf
62
61
 
63
- def _calculate_tfidf(self, tokens: List[str]) -> np.ndarray:
62
+ def _calculate_tfidf(self, tokens: list[str]) -> np.ndarray:
64
63
  """计算TF-IDF向量"""
65
64
  vector = np.zeros(self.vector_dim)
66
65
  tf = self._calculate_tf(tokens)
pycityagent/llm/llm.py CHANGED
@@ -14,7 +14,7 @@ import requests
14
14
  from dashscope import ImageSynthesis
15
15
  from PIL import Image
16
16
  from io import BytesIO
17
- from typing import Any, Optional, Union, List, Dict
17
+ from typing import Any, Optional, Union
18
18
  from .llmconfig import *
19
19
  from .utils import *
20
20
 
@@ -117,8 +117,8 @@ Token Usage:
117
117
  presence_penalty: Optional[float] = None,
118
118
  timeout: int = 300,
119
119
  retries=3,
120
- tools: Optional[List[Dict[str, Any]]] = None,
121
- tool_choice: Optional[Dict[str, Any]] = None,
120
+ tools: Optional[list[dict[str, Any]]] = None,
121
+ tool_choice: Optional[dict[str, Any]] = None,
122
122
  ):
123
123
  """
124
124
  异步版文本请求
@@ -227,9 +227,9 @@ Token Usage:
227
227
  self.prompt_tokens_used += result_response.usage.prompt_tokens # type: ignore
228
228
  self.completion_tokens_used += result_response.usage.completion_tokens # type: ignore
229
229
  self.request_number += 1
230
- if tools and result_response.choices[0].message.tool_calls:
230
+ if tools and result_response.choices[0].message.tool_calls: # type: ignore
231
231
  return json.loads(
232
- result_response.choices[0]
232
+ result_response.choices[0] # type: ignore
233
233
  .message.tool_calls[0]
234
234
  .function.arguments
235
235
  )
@@ -2,7 +2,8 @@ import asyncio
2
2
  import logging
3
3
  from copy import deepcopy
4
4
  from datetime import datetime
5
- from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
5
+ from typing import Any, Literal, Optional, Union
6
+ from collections.abc import Sequence,Callable
6
7
 
7
8
  import numpy as np
8
9
  from pyparsing import deque
@@ -27,10 +28,10 @@ class Memory:
27
28
 
28
29
  def __init__(
29
30
  self,
30
- config: Optional[Dict[Any, Any]] = None,
31
- profile: Optional[Dict[Any, Any]] = None,
32
- base: Optional[Dict[Any, Any]] = None,
33
- motion: Optional[Dict[Any, Any]] = None,
31
+ config: Optional[dict[Any, Any]] = None,
32
+ profile: Optional[dict[Any, Any]] = None,
33
+ base: Optional[dict[Any, Any]] = None,
34
+ motion: Optional[dict[Any, Any]] = None,
34
35
  activate_timestamp: bool = False,
35
36
  embedding_model: Any = None,
36
37
  ) -> None:
@@ -38,7 +39,7 @@ class Memory:
38
39
  Initializes the Memory with optional configuration.
39
40
 
40
41
  Args:
41
- config (Optional[Dict[Any, Any]], optional):
42
+ config (Optional[dict[Any, Any]], optional):
42
43
  A configuration dictionary for dynamic memory. The dictionary format is:
43
44
  - Key: The name of the dynamic memory field.
44
45
  - Value: Can be one of two formats:
@@ -46,24 +47,24 @@ class Memory:
46
47
  2. A callable that returns the default value when invoked (useful for complex default values).
47
48
  Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
48
49
  Defaults to None.
49
- profile (Optional[Dict[Any, Any]], optional): profile attribute dict.
50
- base (Optional[Dict[Any, Any]], optional): base attribute dict from City Simulator.
51
- motion (Optional[Dict[Any, Any]], optional): motion attribute dict from City Simulator.
50
+ profile (Optional[dict[Any, Any]], optional): profile attribute dict.
51
+ base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
52
+ motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
52
53
  activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
53
54
  embedding_model (Any): The embedding model for memory search.
54
55
  """
55
- self.watchers: Dict[str, List[Callable]] = {}
56
+ self.watchers: dict[str, list[Callable]] = {}
56
57
  self._lock = asyncio.Lock()
57
58
  self.embedding_model = embedding_model
58
59
 
59
60
  # 初始化embedding存储
60
61
  self._embeddings = {"state": {}, "profile": {}, "dynamic": {}}
61
62
 
62
- _dynamic_config: Dict[Any, Any] = {}
63
- _state_config: Dict[Any, Any] = {}
64
- _profile_config: Dict[Any, Any] = {}
63
+ _dynamic_config: dict[Any, Any] = {}
64
+ _state_config: dict[Any, Any] = {}
65
+ _profile_config: dict[Any, Any] = {}
65
66
  # 记录哪些字段需要embedding
66
- self._embedding_fields: Dict[str, bool] = {}
67
+ self._embedding_fields: dict[str, bool] = {}
67
68
 
68
69
  if config is not None:
69
70
  for k, v in config.items():
@@ -303,7 +304,7 @@ class Memory:
303
304
 
304
305
  async def update_batch(
305
306
  self,
306
- content: Union[Dict, Sequence[Tuple[Any, Any]]],
307
+ content: Union[dict, Sequence[tuple[Any, Any]]],
307
308
  mode: Union[Literal["replace"], Literal["merge"]] = "replace",
308
309
  store_snapshot: bool = False,
309
310
  protect_llm_read_only_fields: bool = True,
@@ -312,7 +313,7 @@ class Memory:
312
313
  Updates multiple values in the memory at once.
313
314
 
314
315
  Args:
315
- content (Union[Dict, Sequence[Tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
316
+ content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
316
317
  mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
317
318
  store_snapshot (bool): Whether to store a snapshot of the memory after the update.
318
319
  protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
@@ -321,9 +322,9 @@ class Memory:
321
322
  TypeError: If the content type is neither a dictionary nor a sequence of tuples.
322
323
  """
323
324
  if isinstance(content, dict):
324
- _list_content: List[Tuple[Any, Any]] = [(k, v) for k, v in content.items()]
325
+ _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
325
326
  elif isinstance(content, Sequence):
326
- _list_content: List[Tuple[Any, Any]] = [(k, v) for k, v in content]
327
+ _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
327
328
  else:
328
329
  raise TypeError(f"Invalid content type `{type(content)}`!")
329
330
  for k, v in _list_content[:1]:
@@ -353,12 +354,12 @@ class Memory:
353
354
  @lock_decorator
354
355
  async def export(
355
356
  self,
356
- ) -> Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]:
357
+ ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
357
358
  """
358
359
  Exports the current state of all memory sections.
359
360
 
360
361
  Returns:
361
- Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
362
+ tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
362
363
  """
363
364
  return (
364
365
  await self._profile.export(),
@@ -369,14 +370,14 @@ class Memory:
369
370
  @lock_decorator
370
371
  async def load(
371
372
  self,
372
- snapshots: Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]],
373
+ snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
373
374
  reset_memory: bool = True,
374
375
  ) -> None:
375
376
  """
376
377
  Import the snapshot memories of all sections.
377
378
 
378
379
  Args:
379
- snapshots (Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]): The exported snapshots.
380
+ snapshots (tuple[Sequence[dict], Sequence[dict], Sequence[dict]]): The exported snapshots.
380
381
  reset_memory (bool): Whether to reset previous memory.
381
382
  """
382
383
  _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
@@ -1,5 +1,6 @@
1
- from .mlflow_client import MlflowClient
1
+ from .mlflow_client import MlflowClient,init_mlflow_connection
2
2
 
3
3
  __all__ = [
4
4
  "MlflowClient",
5
+ "init_mlflow_connection",
5
6
  ]
@@ -18,6 +18,55 @@ from ..utils.decorators import lock_decorator
18
18
  logger = logging.getLogger("mlflow")
19
19
 
20
20
 
21
+ def init_mlflow_connection(
22
+ config: dict,
23
+ mlflow_run_name: Optional[str] = None,
24
+ experiment_name: Optional[str] = None,
25
+ experiment_description: Optional[str] = None,
26
+ experiment_tags: Optional[dict[str, Any]] = None,
27
+ ) -> tuple[str, tuple[str, mlflow.MlflowClient, Run, str]]:
28
+
29
+ os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
30
+ os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
31
+
32
+ run_uuid = str(uuid.uuid4())
33
+ # run name
34
+ if mlflow_run_name is None:
35
+ mlflow_run_name = f"exp_{run_uuid}"
36
+
37
+ # exp name
38
+ if experiment_name is None:
39
+ experiment_name = f"run_{run_uuid}"
40
+
41
+ # tags
42
+ if experiment_tags is None:
43
+ experiment_tags = {}
44
+ if experiment_description is not None:
45
+ experiment_tags["mlflow.note.content"] = experiment_description
46
+
47
+ uri = config["mlflow_uri"]
48
+ client = mlflow.MlflowClient(tracking_uri=uri)
49
+
50
+ # experiment
51
+ try:
52
+ experiment_id = client.create_experiment(
53
+ name=experiment_name,
54
+ tags=experiment_tags,
55
+ )
56
+ except Exception as e:
57
+ experiment = client.get_experiment_by_name(experiment_name)
58
+ if experiment is None:
59
+ raise e
60
+ experiment_id = experiment.experiment_id
61
+
62
+ # run
63
+ run = client.create_run(experiment_id=experiment_id, run_name=mlflow_run_name)
64
+
65
+ run_id = run.info.run_id
66
+
67
+ return run_id, (uri, client, run, run_uuid)
68
+
69
+
21
70
  class MlflowClient:
22
71
  """
23
72
  - Mlflow client
@@ -30,42 +79,30 @@ class MlflowClient:
30
79
  experiment_name: Optional[str] = None,
31
80
  experiment_description: Optional[str] = None,
32
81
  experiment_tags: Optional[dict[str, Any]] = None,
82
+ run_id: Optional[str] = None,
33
83
  ) -> None:
34
- os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
35
- os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
36
- self._mlflow_uri = uri = config["mlflow_uri"]
37
- self._client = client = mlflow.MlflowClient(tracking_uri=uri)
38
- self._run_uuid = run_uuid = str(uuid.uuid4())
39
- self._lock = asyncio.Lock()
40
- # run name
41
- if mlflow_run_name is None:
42
- mlflow_run_name = f"exp_{run_uuid}"
43
-
44
- # exp name
45
- if experiment_name is None:
46
- experiment_name = f"run_{run_uuid}"
47
-
48
- # tags
49
- if experiment_tags is None:
50
- experiment_tags = {}
51
- if experiment_description is not None:
52
- experiment_tags["mlflow.note.content"] = experiment_description
53
-
54
- try:
55
- self._experiment_id = experiment_id = client.create_experiment(
56
- name=experiment_name,
57
- tags=experiment_tags,
84
+ if run_id is None:
85
+ self._run_id, (
86
+ self._mlflow_uri,
87
+ self._client,
88
+ self._run,
89
+ self._run_uuid,
90
+ ) = init_mlflow_connection(
91
+ config=config,
92
+ mlflow_run_name=mlflow_run_name,
93
+ experiment_name=experiment_name,
94
+ experiment_description=experiment_description,
95
+ experiment_tags=experiment_tags,
58
96
  )
59
- except Exception as e:
60
- experiment = client.get_experiment_by_name(experiment_name)
61
- if experiment is None:
62
- raise e
63
- self._experiment_id = experiment_id = experiment.experiment_id
64
-
65
- self._run = run = client.create_run(
66
- experiment_id=experiment_id, run_name=mlflow_run_name
67
- )
68
- self._run_id = run.info.run_id
97
+ else:
98
+ self._mlflow_uri = uri = config["mlflow_uri"]
99
+ os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
100
+ os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
101
+ self._client = client = mlflow.MlflowClient(tracking_uri=uri)
102
+ self._run = client.get_run(run_id=run_id)
103
+ self._run_id = run_id
104
+ self._run_uuid = run_uuid = str(uuid.uuid4())
105
+ self._lock = asyncio.Lock()
69
106
 
70
107
  @property
71
108
  def client(
@@ -77,6 +114,7 @@ class MlflowClient:
77
114
  def run_id(
78
115
  self,
79
116
  ) -> str:
117
+ assert self._run_id is not None
80
118
  return self._run_id
81
119
 
82
120
  @lock_decorator
@@ -35,7 +35,8 @@ class AgentGroup:
35
35
  enable_avro: bool,
36
36
  avro_path: Path,
37
37
  enable_pgsql: bool,
38
- pgsql_args: tuple[str, str, str, str, str],
38
+ pgsql_copy_writer: ray.ObjectRef,
39
+ mlflow_run_id: str,
39
40
  logging_level: int,
40
41
  ):
41
42
  logger.setLevel(logging_level)
@@ -88,6 +89,7 @@ class AgentGroup:
88
89
  config=_mlflow_config,
89
90
  mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
90
91
  experiment_name=exp_name,
92
+ run_id=mlflow_run_id,
91
93
  )
92
94
  else:
93
95
  self.mlflow_client = None
@@ -3,6 +3,7 @@ import json
3
3
  import logging
4
4
  import os
5
5
  import random
6
+ import time
6
7
  import uuid
7
8
  from collections.abc import Callable, Sequence
8
9
  from concurrent.futures import ThreadPoolExecutor
@@ -11,6 +12,7 @@ from pathlib import Path
11
12
  from typing import Any, Optional, Union
12
13
 
13
14
  import pycityproto.city.economy.v2.economy_pb2 as economyv2
15
+ import ray
14
16
  import yaml
15
17
  from mosstool.map._map_util.const import AOI_START_ID
16
18
 
@@ -18,6 +20,7 @@ from ..agent import Agent, InstitutionAgent
18
20
  from ..environment.simulator import Simulator
19
21
  from ..memory.memory import Memory
20
22
  from ..message.messager import Messager
23
+ from ..metrics import init_mlflow_connection
21
24
  from ..survey import Survey
22
25
  from .agentgroup import AgentGroup
23
26
 
@@ -165,7 +168,8 @@ class AgentSimulation:
165
168
  enable_avro: bool,
166
169
  avro_path: Path,
167
170
  enable_pgsql: bool,
168
- pgsql_args: tuple[str, str, str, str, str],
171
+ pgsql_copy_writer: ray.ObjectRef,
172
+ mlflow_run_id: str = None, # type: ignore
169
173
  logging_level: int = logging.WARNING,
170
174
  ):
171
175
  """创建远程组"""
@@ -177,7 +181,8 @@ class AgentSimulation:
177
181
  enable_avro,
178
182
  avro_path,
179
183
  enable_pgsql,
180
- pgsql_args,
184
+ pgsql_copy_writer,
185
+ mlflow_run_id,
181
186
  logging_level,
182
187
  )
183
188
  return group_name, group, agents
@@ -267,6 +272,16 @@ class AgentSimulation:
267
272
 
268
273
  class_init_index += agent_count_i
269
274
 
275
+ # 初始化mlflow连接
276
+ _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
277
+ if _mlflow_config:
278
+ mlflow_run_id, _ = init_mlflow_connection(
279
+ config=_mlflow_config,
280
+ mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
281
+ experiment_name=self.exp_name,
282
+ )
283
+ else:
284
+ mlflow_run_id = None
270
285
  # 收集所有创建组的参数
271
286
  creation_tasks = []
272
287
  for group_name, agents in group_creation_params:
@@ -279,7 +294,10 @@ class AgentSimulation:
279
294
  self.enable_avro,
280
295
  self.avro_path,
281
296
  self.enable_pgsql,
282
- self._pgsql_args,
297
+ # TODO:
298
+ # self._pgsql_copy_writer, # type:ignore
299
+ None,
300
+ mlflow_run_id,
283
301
  self.logging_level,
284
302
  )
285
303
  creation_tasks.append((group_name, group, agents))
@@ -4,7 +4,7 @@ Base class of parser
4
4
 
5
5
  import re
6
6
  from abc import ABC, abstractmethod
7
- from typing import Any, Dict, List, Optional, Tuple, Union
7
+ from typing import Any, Union
8
8
 
9
9
 
10
10
  class ParserBase(ABC):
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable, Dict, List, Optional, Union
1
+ from typing import Optional, Union
2
2
  import re
3
3
 
4
4
 
@@ -10,7 +10,7 @@ class FormatPrompt:
10
10
  Attributes:
11
11
  template (str): The template string containing placeholders.
12
12
  system_prompt (Optional[str]): An optional system prompt to add to the dialog.
13
- variables (List[str]): A list of variable names extracted from the template.
13
+ variables (list[str]): A list of variable names extracted from the template.
14
14
  formatted_string (str): The formatted string derived from the template and provided variables.
15
15
  """
16
16
 
@@ -27,12 +27,12 @@ class FormatPrompt:
27
27
  self.variables = self._extract_variables()
28
28
  self.formatted_string = "" # To store the formatted string
29
29
 
30
- def _extract_variables(self) -> List[str]:
30
+ def _extract_variables(self) -> list[str]:
31
31
  """
32
32
  Extracts variable names from the template string.
33
33
 
34
34
  Returns:
35
- List[str]: A list of variable names found within the template.
35
+ list[str]: A list of variable names found within the template.
36
36
  """
37
37
  return re.findall(r"\{(\w+)\}", self.template)
38
38
 
@@ -51,12 +51,12 @@ class FormatPrompt:
51
51
  ) # Store the formatted string
52
52
  return self.formatted_string
53
53
 
54
- def to_dialog(self) -> List[Dict[str, str]]:
54
+ def to_dialog(self) -> list[dict[str, str]]:
55
55
  """
56
56
  Converts the formatted prompt and optional system prompt into a dialog format.
57
57
 
58
58
  Returns:
59
- List[Dict[str, str]]: A list representing the dialog with roles and content.
59
+ list[dict[str, str]]: A list representing the dialog with roles and content.
60
60
  """
61
61
  dialog = []
62
62
  if self.system_prompt:
@@ -1,6 +1,6 @@
1
1
  import time
2
- from typing import Any, Callable, Dict, List, Optional, Union
3
-
2
+ from typing import Any, Optional, Union
3
+ from collections.abc import Callable
4
4
  from mlflow.entities import Metric
5
5
 
6
6
  from ..agent import Agent
@@ -76,7 +76,7 @@ class SencePOI(Tool):
76
76
  Attributes:
77
77
  radius (int): The radius within which to search for POIs.
78
78
  category_prefix (str): The prefix for the categories of POIs to consider.
79
- variables (List[str]): A list of variables relevant to the tool's operation.
79
+ variables (list[str]): A list of variables relevant to the tool's operation.
80
80
 
81
81
  Args:
82
82
  radius (int, optional): The circular search radius. Defaults to 100.
@@ -1,5 +1,5 @@
1
1
  import asyncio
2
- from typing import Any, Callable, Dict, List, Optional, Union, Type
2
+ from typing import Optional
3
3
  import socket
4
4
  from ..memory import Memory
5
5
  from ..environment import Simulator
@@ -11,7 +11,7 @@ class EventTrigger:
11
11
  """Base class for event triggers that wait for specific conditions to be met."""
12
12
 
13
13
  # 定义该trigger需要的组件类型
14
- required_components: List[Type] = []
14
+ required_components: list[type] = []
15
15
 
16
16
  def __init__(self, block=None):
17
17
  self.block = block
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a19
3
+ Version: 2.0.0a20
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -34,8 +34,9 @@ Requires-Dist: openai (>=1.58.1,<2.0.0)
34
34
  Requires-Dist: pandavro (>=1.8.0,<2.0.0)
35
35
  Requires-Dist: poetry (>=1.2.2)
36
36
  Requires-Dist: protobuf (<=4.24.0)
37
+ Requires-Dist: psycopg[binary] (>=3.2.3,<4.0.0)
37
38
  Requires-Dist: pycitydata (==1.0.0)
38
- Requires-Dist: pycityproto (>=2.1.4,<3.0.0)
39
+ Requires-Dist: pycityproto (>=2.1.5,<3.0.0)
39
40
  Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
40
41
  Requires-Dist: ray (>=2.40.0,<3.0.0)
41
42
  Requires-Dist: sidecar (==0.7.0)
@@ -1,7 +1,7 @@
1
1
  pycityagent/__init__.py,sha256=EDxt3Su3lH1IMh9suNw7GeGL7UrXeWiZTw5KWNznDzc,637
2
- pycityagent/agent.py,sha256=gy9OlFpnEla7wK45k4m7_ySLp67mBKoG50vxUEmxbJE,24280
2
+ pycityagent/agent.py,sha256=HRFAG_iM1q3nvXtV0T-Dz01foOtty4IWka7h4WD97CU,24268
3
3
  pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
4
- pycityagent/economy/econ_client.py,sha256=EZDGxM7K83ucYZQ5qdv6HA-jhRCWbR1u5q-kLMqelKc,11192
4
+ pycityagent/economy/econ_client.py,sha256=GuHK9ZBnhqW3Z7F8ViDJn_iN73yOBbbwFyJv1wLEBDk,12211
5
5
  pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
6
6
  pycityagent/environment/interact/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  pycityagent/environment/interact/interact.py,sha256=ifxPPzuHeqLHIZ_6zvfXMoBOnBsXNIP4bYp7OJ7pnEQ,6588
@@ -25,19 +25,19 @@ pycityagent/environment/simulator.py,sha256=XjcxbyBIbB3Ht9z087z_oWIPAN6pP5Eq1lyf
25
25
  pycityagent/environment/utils/__init__.py,sha256=1m4Q1EfGvNpUsa1bgQzzCyWhfkpElnskNImjjFD3Znc,237
26
26
  pycityagent/environment/utils/base64.py,sha256=hoREzQo3FXMN79pqQLO2jgsDEvudciomyKii7MWljAM,374
27
27
  pycityagent/environment/utils/const.py,sha256=3RMNy7_bE7-23K90j9DFW_tWEzu8s7hSTgKbV-3BFl4,5327
28
- pycityagent/environment/utils/geojson.py,sha256=Ieg8Bzw63kKhJlhDIOVDoh-wQO4Sbtoe47FtIOy5wWg,686
28
+ pycityagent/environment/utils/geojson.py,sha256=LVHAdEhnZM8d0BoUnuPiIL_gaeXBIIglrLrfje5M0b4,661
29
29
  pycityagent/environment/utils/grpc.py,sha256=6EJwKXXktIWb1NcUiJzIRmfrY0S03QAXXGcCDHqAT00,1998
30
- pycityagent/environment/utils/map_utils.py,sha256=oqrRgQICC3SYw6gwjjPe_MAif7_t6dlrQpY8E32Fexs,5777
30
+ pycityagent/environment/utils/map_utils.py,sha256=lYOEoCFFK6-e9N5txLMMq4HUlxMqc8Uw1YrGW5oJmgg,5749
31
31
  pycityagent/environment/utils/port.py,sha256=3OM6kSUt3PxvDUOlgyiendBtETaWU8Mzk_8H0TzTmYg,295
32
32
  pycityagent/environment/utils/protobuf.py,sha256=0jBvK_s96y_n7tuMbG22TOtQmg71SGV4ONDy2IGsU9o,1148
33
33
  pycityagent/llm/__init__.py,sha256=7klKEmCcDWJIu-F4DoAukSuKfDbLhdczrSIhpwow-sY,145
34
- pycityagent/llm/embedding.py,sha256=Y0xhm_Ny6cawqzlendXb-mAS2QAuuEez1UtTR5-Kb2Q,4293
35
- pycityagent/llm/llm.py,sha256=BtxBvPK4tb8QlZIfxO5XJ73lKXwF8L31LqVbejWB8eo,15121
34
+ pycityagent/llm/embedding.py,sha256=2psX_EK67oPlYe77g43EYUYams4M9AiJqxpHTFHG0n8,4253
35
+ pycityagent/llm/llm.py,sha256=vJaaGqVuyV-GlBxrnvGKZnMDlxeTT_sGUTdxz5tYwEE,15141
36
36
  pycityagent/llm/llmconfig.py,sha256=4Ylf4OFSBEFy8jrOneeX0HvPhWEaF5jGvy1HkXK08Ro,436
37
37
  pycityagent/llm/utils.py,sha256=hoNPhvomb1u6lhFX0GctFipw74hVKb7bvUBDqwBzBYw,160
38
38
  pycityagent/memory/__init__.py,sha256=Hs2NhYpIG-lvpwPWwj4DydB1sxtjz7cuA4iDAzCXnjI,243
39
39
  pycityagent/memory/const.py,sha256=6zpJPJXWoH9-yf4RARYYff586agCoud9BRn7sPERB1g,932
40
- pycityagent/memory/memory.py,sha256=FjKVL_MgNBnSc0sox2tuxLqXg9_MQQr9vYdRDHMdDL4,18183
40
+ pycityagent/memory/memory.py,sha256=vJxHOI74aJDGZPFu2LbBr02ASfOYpig66fto6Gjr-6Q,18191
41
41
  pycityagent/memory/memory_base.py,sha256=euKZRCs4dbcKxjlZzpLCTnH066DAtRjj5g1JFKD40qQ,5633
42
42
  pycityagent/memory/profile.py,sha256=s4LnxSPGSjIGZXHXkkd8mMa6uYYZrytgyQdWjcaqGf4,5182
43
43
  pycityagent/memory/self_define.py,sha256=poPiexNhOLq_iTgK8s4mK_xoL_DAAcB8kMvInj7iE5E,5179
@@ -45,12 +45,12 @@ pycityagent/memory/state.py,sha256=5W0c1yJ-aaPpE74B2LEcw3Ygpm77tyooHv8NylyrozE,5
45
45
  pycityagent/memory/utils.py,sha256=wLNlNlZ-AY9VB8kbUIy0UQSYh26FOQABbhmKQkit5o8,850
46
46
  pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
47
47
  pycityagent/message/messager.py,sha256=W_OVlNGcreHSBf6v-DrEnfNCXExB78ySr0w26MSncfU,2541
48
- pycityagent/metrics/__init__.py,sha256=IrlwXs_733b3PlMV99Ogh6R8sGCEalF2qIT2mQwdPjU,75
49
- pycityagent/metrics/mlflow_client.py,sha256=BH-YVMlYr-eTFdrCZjgMZWi8ptUEm7todOXCpYixAgU,3311
48
+ pycityagent/metrics/__init__.py,sha256=X08PaBbGVAd7_PRGLREXWxaqm7nS82WBQpD1zvQzcqc,128
49
+ pycityagent/metrics/mlflow_client.py,sha256=g_tHxWkWTDijtbGL74-HmiYzWVKb1y8-w12QrY9jL30,4449
50
50
  pycityagent/metrics/utils/const.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  pycityagent/simulation/__init__.py,sha256=jYaqaNpzM5M_e_ykISS_M-mIyYdzJXJWhgpfBpA6l5k,111
52
- pycityagent/simulation/agentgroup.py,sha256=H2E_YQ3ir3gQmPWsupHXA_LSBVkffzkXWl4UgZ9qLOc,13327
53
- pycityagent/simulation/simulation.py,sha256=ANf04GqaK7gLT50F_ZcP2UcpS_uGx12pLqkrK6eA2L8,21014
52
+ pycityagent/simulation/agentgroup.py,sha256=6p8-OP2x_syaQ1pWLgll0LHs823OHnappuY-8XzL_LU,13383
53
+ pycityagent/simulation/simulation.py,sha256=pxFdaNMLdJMcpW8gnOcxjfbz4gkpcvWsTR0rNP8rE9k,21675
54
54
  pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
55
55
  pycityagent/survey/manager.py,sha256=S5IkwTdelsdtZETChRcfCEczzwSrry_Fly9MY4s3rbk,1681
56
56
  pycityagent/survey/models.py,sha256=YE50UUt5qJ0O_lIUsSY6XFCGUTkJVNu_L1gAhaCJ2fs,3546
@@ -60,13 +60,13 @@ pycityagent/utils/decorators.py,sha256=Gk3r41hfk6awui40tbwpq3C7wC7jHaRmLRlcJFlLQ
60
60
  pycityagent/utils/parsers/__init__.py,sha256=AN2xgiPxszWK4rpX7zrqRsqNwfGF3WnCA5-PFTvbaKk,281
61
61
  pycityagent/utils/parsers/code_block_parser.py,sha256=Cs2Z_hm9VfNCpPPll1TwteaJF-HAQPs-3RApsOekFm4,1173
62
62
  pycityagent/utils/parsers/json_parser.py,sha256=FZ3XN1g8z4Dr2TFraUOoah1oQcze4fPd2m01hHoX0Mo,2917
63
- pycityagent/utils/parsers/parser_base.py,sha256=k6DVqwAMK3jJdOP4IeLE-aFPm3V2F-St5qRBuRdx4aU,1742
63
+ pycityagent/utils/parsers/parser_base.py,sha256=KBKO4zLZPNdGjPAGqIus8LseZ8W3Tlt2y0QxqeCd25Q,1713
64
64
  pycityagent/utils/survey_util.py,sha256=Be9nptmu2JtesFNemPgORh_2GsN7rcDYGQS9Zfvc5OI,2169
65
65
  pycityagent/workflow/__init__.py,sha256=QNkUV-9mACMrR8c0cSKna2gC1mMZdxXbxWzjE-Uods0,621
66
66
  pycityagent/workflow/block.py,sha256=WkE2On97DCZS_9n8aIgT8wxv9Oaff4Fdf2tLqbKfMtE,6010
67
- pycityagent/workflow/prompt.py,sha256=tY69nDO8fgYfF_dOA-iceR8pAhkYmCqoox8uRPqEuGY,2956
68
- pycityagent/workflow/tool.py,sha256=uaB0dV35jA9v2UqWw9L8iPM-HJW5e9BlgFVgOMf9jvw,8201
69
- pycityagent/workflow/trigger.py,sha256=t5X_i0WtL32bipZSsq_E3UUyYYudYLxQUpvxbgClp2s,5683
70
- pycityagent-2.0.0a19.dist-info/METADATA,sha256=x0DbwNQ4Zj0rswWGIv6MA0zSBBEiExXLIdErYA0bhnE,7800
71
- pycityagent-2.0.0a19.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
72
- pycityagent-2.0.0a19.dist-info/RECORD,,
67
+ pycityagent/workflow/prompt.py,sha256=6jI0Rq54JLv3-IXqZLYug62vse10wTI83xvf4ZX42nk,2929
68
+ pycityagent/workflow/tool.py,sha256=SGY18lT71hBLKagopirFbxRjPY_387Dobo9SUwjHIn0,8215
69
+ pycityagent/workflow/trigger.py,sha256=Df-MOBEDWBbM-v0dFLQLXteLsipymT4n8vqexmK2GiQ,5643
70
+ pycityagent-2.0.0a20.dist-info/METADATA,sha256=aygfImVFM4jG0nXVuunu5LZDNi4JyikBULFce0WGWeM,7848
71
+ pycityagent-2.0.0a20.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
72
+ pycityagent-2.0.0a20.dist-info/RECORD,,