pycityagent 2.0.0a18__py3-none-any.whl → 2.0.0a20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,8 @@ import asyncio
2
2
  import logging
3
3
  from copy import deepcopy
4
4
  from datetime import datetime
5
- from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
5
+ from typing import Any, Literal, Optional, Union
6
+ from collections.abc import Sequence,Callable
6
7
 
7
8
  import numpy as np
8
9
  from pyparsing import deque
@@ -27,10 +28,10 @@ class Memory:
27
28
 
28
29
  def __init__(
29
30
  self,
30
- config: Optional[Dict[Any, Any]] = None,
31
- profile: Optional[Dict[Any, Any]] = None,
32
- base: Optional[Dict[Any, Any]] = None,
33
- motion: Optional[Dict[Any, Any]] = None,
31
+ config: Optional[dict[Any, Any]] = None,
32
+ profile: Optional[dict[Any, Any]] = None,
33
+ base: Optional[dict[Any, Any]] = None,
34
+ motion: Optional[dict[Any, Any]] = None,
34
35
  activate_timestamp: bool = False,
35
36
  embedding_model: Any = None,
36
37
  ) -> None:
@@ -38,7 +39,7 @@ class Memory:
38
39
  Initializes the Memory with optional configuration.
39
40
 
40
41
  Args:
41
- config (Optional[Dict[Any, Any]], optional):
42
+ config (Optional[dict[Any, Any]], optional):
42
43
  A configuration dictionary for dynamic memory. The dictionary format is:
43
44
  - Key: The name of the dynamic memory field.
44
45
  - Value: Can be one of two formats:
@@ -46,24 +47,24 @@ class Memory:
46
47
  2. A callable that returns the default value when invoked (useful for complex default values).
47
48
  Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
48
49
  Defaults to None.
49
- profile (Optional[Dict[Any, Any]], optional): profile attribute dict.
50
- base (Optional[Dict[Any, Any]], optional): base attribute dict from City Simulator.
51
- motion (Optional[Dict[Any, Any]], optional): motion attribute dict from City Simulator.
50
+ profile (Optional[dict[Any, Any]], optional): profile attribute dict.
51
+ base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
52
+ motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
52
53
  activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
53
54
  embedding_model (Any): The embedding model for memory search.
54
55
  """
55
- self.watchers: Dict[str, List[Callable]] = {}
56
+ self.watchers: dict[str, list[Callable]] = {}
56
57
  self._lock = asyncio.Lock()
57
58
  self.embedding_model = embedding_model
58
59
 
59
60
  # 初始化embedding存储
60
61
  self._embeddings = {"state": {}, "profile": {}, "dynamic": {}}
61
62
 
62
- _dynamic_config: Dict[Any, Any] = {}
63
- _state_config: Dict[Any, Any] = {}
64
- _profile_config: Dict[Any, Any] = {}
63
+ _dynamic_config: dict[Any, Any] = {}
64
+ _state_config: dict[Any, Any] = {}
65
+ _profile_config: dict[Any, Any] = {}
65
66
  # 记录哪些字段需要embedding
66
- self._embedding_fields: Dict[str, bool] = {}
67
+ self._embedding_fields: dict[str, bool] = {}
67
68
 
68
69
  if config is not None:
69
70
  for k, v in config.items():
@@ -303,7 +304,7 @@ class Memory:
303
304
 
304
305
  async def update_batch(
305
306
  self,
306
- content: Union[Dict, Sequence[Tuple[Any, Any]]],
307
+ content: Union[dict, Sequence[tuple[Any, Any]]],
307
308
  mode: Union[Literal["replace"], Literal["merge"]] = "replace",
308
309
  store_snapshot: bool = False,
309
310
  protect_llm_read_only_fields: bool = True,
@@ -312,7 +313,7 @@ class Memory:
312
313
  Updates multiple values in the memory at once.
313
314
 
314
315
  Args:
315
- content (Union[Dict, Sequence[Tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
316
+ content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
316
317
  mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
317
318
  store_snapshot (bool): Whether to store a snapshot of the memory after the update.
318
319
  protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
@@ -321,9 +322,9 @@ class Memory:
321
322
  TypeError: If the content type is neither a dictionary nor a sequence of tuples.
322
323
  """
323
324
  if isinstance(content, dict):
324
- _list_content: List[Tuple[Any, Any]] = [(k, v) for k, v in content.items()]
325
+ _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
325
326
  elif isinstance(content, Sequence):
326
- _list_content: List[Tuple[Any, Any]] = [(k, v) for k, v in content]
327
+ _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
327
328
  else:
328
329
  raise TypeError(f"Invalid content type `{type(content)}`!")
329
330
  for k, v in _list_content[:1]:
@@ -353,12 +354,12 @@ class Memory:
353
354
  @lock_decorator
354
355
  async def export(
355
356
  self,
356
- ) -> Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]:
357
+ ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
357
358
  """
358
359
  Exports the current state of all memory sections.
359
360
 
360
361
  Returns:
361
- Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
362
+ tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
362
363
  """
363
364
  return (
364
365
  await self._profile.export(),
@@ -369,14 +370,14 @@ class Memory:
369
370
  @lock_decorator
370
371
  async def load(
371
372
  self,
372
- snapshots: Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]],
373
+ snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
373
374
  reset_memory: bool = True,
374
375
  ) -> None:
375
376
  """
376
377
  Import the snapshot memories of all sections.
377
378
 
378
379
  Args:
379
- snapshots (Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]): The exported snapshots.
380
+ snapshots (tuple[Sequence[dict], Sequence[dict], Sequence[dict]]): The exported snapshots.
380
381
  reset_memory (bool): Whether to reset previous memory.
381
382
  """
382
383
  _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
@@ -0,0 +1,6 @@
1
+ from .mlflow_client import MlflowClient,init_mlflow_connection
2
+
3
# Public API of the metrics package: the MLflow wrapper class and the
# helper that opens a connection / creates a run.
__all__ = ["MlflowClient", "init_mlflow_connection"]
@@ -0,0 +1,147 @@
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import uuid
5
+ from collections.abc import Sequence
6
+ from typing import Any, Optional, Union
7
+
8
+ import mlflow
9
+ from mlflow.entities import (Dataset, DatasetInput, Document, Experiment,
10
+ ExperimentTag, FileInfo, InputTag, LifecycleStage,
11
+ LiveSpan, Metric, NoOpSpan, Param, Run, RunData,
12
+ RunInfo, RunInputs, RunStatus, RunTag, SourceType,
13
+ Span, SpanEvent, SpanStatus, SpanStatusCode,
14
+ SpanType, Trace, TraceData, TraceInfo, ViewType)
15
+
16
+ from ..utils.decorators import lock_decorator
17
+
18
+ logger = logging.getLogger("mlflow")
19
+
20
+
21
def init_mlflow_connection(
    config: dict,
    mlflow_run_name: Optional[str] = None,
    experiment_name: Optional[str] = None,
    experiment_description: Optional[str] = None,
    experiment_tags: Optional[dict[str, Any]] = None,
) -> tuple[str, tuple[str, mlflow.MlflowClient, Run, str]]:
    """Connect to an MLflow tracking server and create a new run.

    The experiment named ``experiment_name`` is created; if creation fails
    and an experiment with that name already exists, it is reused instead.

    Args:
        config: MLflow connection settings. ``config["mlflow_uri"]`` is
            required. Optional ``"username"`` / ``"password"`` entries are
            exported as ``MLFLOW_TRACKING_USERNAME`` /
            ``MLFLOW_TRACKING_PASSWORD``; absent entries are left untouched.
        mlflow_run_name: Name of the new run. Defaults to ``run_<uuid>``.
        experiment_name: Experiment to create or reuse. Defaults to
            ``exp_<uuid>``.
        experiment_description: If given, stored as the
            ``mlflow.note.content`` experiment tag.
        experiment_tags: Extra tags applied when the experiment is created.
            The caller's dict is not modified.

    Returns:
        ``(run_id, (uri, client, run, run_uuid))`` where ``run_uuid`` is a
        freshly generated UUID string (also used for the default names).

    Raises:
        KeyError: If ``config`` has no ``"mlflow_uri"`` entry.
        Exception: Whatever ``create_experiment`` raised, when the experiment
            could not be created and no experiment with that name exists.
    """
    # os.environ only accepts str values; assigning None raises TypeError,
    # so only export credentials that are actually configured.
    for config_key, env_name in (
        ("username", "MLFLOW_TRACKING_USERNAME"),
        ("password", "MLFLOW_TRACKING_PASSWORD"),
    ):
        value = config.get(config_key)
        if value is not None:
            os.environ[env_name] = str(value)

    run_uuid = str(uuid.uuid4())
    # run name
    if mlflow_run_name is None:
        mlflow_run_name = f"run_{run_uuid}"
    # experiment name
    if experiment_name is None:
        experiment_name = f"exp_{run_uuid}"

    # Copy the tags so adding the description does not mutate the caller's dict.
    tags: dict[str, Any] = dict(experiment_tags) if experiment_tags is not None else {}
    if experiment_description is not None:
        tags["mlflow.note.content"] = experiment_description

    uri = config["mlflow_uri"]
    client = mlflow.MlflowClient(tracking_uri=uri)

    # experiment
    try:
        experiment_id = client.create_experiment(
            name=experiment_name,
            tags=tags,
        )
    except Exception as e:
        # Presumably creation failed because the experiment already exists:
        # reuse it if so, otherwise propagate the original error.
        experiment = client.get_experiment_by_name(experiment_name)
        if experiment is None:
            raise e
        experiment_id = experiment.experiment_id

    # run
    run = client.create_run(experiment_id=experiment_id, run_name=mlflow_run_name)
    run_id = run.info.run_id

    return run_id, (uri, client, run, run_uuid)
68
+
69
+
70
class MlflowClient:
    """
    Async-friendly wrapper around ``mlflow.MlflowClient`` bound to one run.

    Either creates a new run via ``init_mlflow_connection`` or, when
    ``run_id`` is given, attaches to that existing run. The logging methods
    are ``async`` and decorated with ``lock_decorator`` (from
    ``..utils.decorators``; presumably it serializes calls through
    ``self._lock``). The synchronous MLflow client calls are executed in a
    worker thread so they do not block the event loop.
    """

    def __init__(
        self,
        config: dict,
        mlflow_run_name: Optional[str] = None,
        experiment_name: Optional[str] = None,
        experiment_description: Optional[str] = None,
        experiment_tags: Optional[dict[str, Any]] = None,
        run_id: Optional[str] = None,
    ) -> None:
        """
        Args:
            config: MLflow connection settings. ``config["mlflow_uri"]`` is
                required; ``"username"`` / ``"password"`` are optional and,
                when present, exported to the MLflow credential env vars.
            mlflow_run_name: Name of the new run. Ignored when ``run_id`` is given.
            experiment_name: Experiment to create/reuse. Ignored when ``run_id`` is given.
            experiment_description: Experiment description. Ignored when ``run_id`` is given.
            experiment_tags: Experiment tags. Ignored when ``run_id`` is given.
            run_id: Existing MLflow run to attach to. When None, a new run
                is created.

        Raises:
            KeyError: If ``run_id`` is given and ``config`` has no ``"mlflow_uri"``.
        """
        if run_id is None:
            self._run_id, (
                self._mlflow_uri,
                self._client,
                self._run,
                self._run_uuid,
            ) = init_mlflow_connection(
                config=config,
                mlflow_run_name=mlflow_run_name,
                experiment_name=experiment_name,
                experiment_description=experiment_description,
                experiment_tags=experiment_tags,
            )
        else:
            self._mlflow_uri = uri = config["mlflow_uri"]
            # os.environ only accepts str values; assigning None raises
            # TypeError, so only export credentials that are configured.
            for config_key, env_name in (
                ("username", "MLFLOW_TRACKING_USERNAME"),
                ("password", "MLFLOW_TRACKING_PASSWORD"),
            ):
                value = config.get(config_key)
                if value is not None:
                    os.environ[env_name] = str(value)
            self._client = client = mlflow.MlflowClient(tracking_uri=uri)
            self._run = client.get_run(run_id=run_id)
            self._run_id = run_id
            # Locally generated identifier; independent of the MLflow run id.
            self._run_uuid = str(uuid.uuid4())
        self._lock = asyncio.Lock()

    @property
    def client(
        self,
    ) -> mlflow.MlflowClient:
        """The underlying synchronous ``mlflow.MlflowClient``."""
        return self._client

    @property
    def run_id(
        self,
    ) -> str:
        """The MLflow run id all logging calls are sent to."""
        assert self._run_id is not None
        return self._run_id

    @lock_decorator
    async def log_batch(
        self,
        metrics: Sequence[Metric] = (),
        params: Sequence[Param] = (),
        tags: Sequence[RunTag] = (),
    ):
        """Log metrics, params and tags to this run in a single call."""
        # The MLflow client call is synchronous (network I/O); run it in a
        # worker thread so the event loop is not blocked while it completes.
        await asyncio.to_thread(
            self.client.log_batch,
            run_id=self.run_id,
            metrics=metrics,
            params=params,
            tags=tags,
        )

    @lock_decorator
    async def log_metric(
        self,
        key: str,
        value: float,
        step: Optional[int] = None,
        timestamp: Optional[int] = None,
    ):
        """
        Log a single metric value to this run.

        Args:
            key: Metric name.
            value: Metric value.
            step: Optional step index.
            timestamp: Optional timestamp; coerced to ``int`` before sending.
        """
        if timestamp is not None:
            timestamp = int(timestamp)
        # Synchronous MLflow call offloaded to a worker thread (see log_batch).
        await asyncio.to_thread(
            self.client.log_metric,
            run_id=self.run_id,
            key=key,
            value=value,
            timestamp=timestamp,
            step=step,
        )
File without changes
@@ -1,34 +1,52 @@
1
1
  import asyncio
2
- from datetime import datetime
3
2
  import json
4
3
  import logging
5
- from pathlib import Path
4
+ import time
6
5
  import uuid
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Any
9
+ from uuid import UUID
10
+
7
11
  import fastavro
8
12
  import ray
9
- from uuid import UUID
10
- from pycityagent.agent import Agent, CitizenAgent, InstitutionAgent
11
- from pycityagent.economy.econ_client import EconomyClient
12
- from pycityagent.environment.simulator import Simulator
13
- from pycityagent.llm.llm import LLM
14
- from pycityagent.llm.llmconfig import LLMConfig
15
- from pycityagent.message import Messager
16
- from pycityagent.utils import STATUS_SCHEMA, PROFILE_SCHEMA, DIALOG_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
17
- from typing import Any
13
+
14
+ from ..agent import Agent, CitizenAgent, InstitutionAgent
15
+ from ..economy.econ_client import EconomyClient
16
+ from ..environment.simulator import Simulator
17
+ from ..llm.llm import LLM
18
+ from ..llm.llmconfig import LLMConfig
19
+ from ..message import Messager
20
+ from ..metrics import MlflowClient
21
+ from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
22
+ STATUS_SCHEMA, SURVEY_SCHEMA)
18
23
 
19
24
  logger = logging.getLogger("pycityagent")
20
25
 
26
+
21
27
  @ray.remote
22
28
  class AgentGroup:
23
- def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int):
29
+ def __init__(
30
+ self,
31
+ agents: list[Agent],
32
+ config: dict,
33
+ exp_id: str | UUID,
34
+ exp_name: str,
35
+ enable_avro: bool,
36
+ avro_path: Path,
37
+ enable_pgsql: bool,
38
+ pgsql_copy_writer: ray.ObjectRef,
39
+ mlflow_run_id: str,
40
+ logging_level: int,
41
+ ):
24
42
  logger.setLevel(logging_level)
25
43
  self._uuid = str(uuid.uuid4())
26
44
  self.agents = agents
27
45
  self.config = config
28
46
  self.exp_id = exp_id
29
47
  self.enable_avro = enable_avro
30
- self.avro_path = avro_path / f"{self._uuid}"
31
48
  if enable_avro:
49
+ self.avro_path = avro_path / f"{self._uuid}"
32
50
  self.avro_path.mkdir(parents=True, exist_ok=True)
33
51
  self.avro_file = {
34
52
  "profile": self.avro_path / f"profile.avro",
@@ -36,7 +54,7 @@ class AgentGroup:
36
54
  "status": self.avro_path / f"status.avro",
37
55
  "survey": self.avro_path / f"survey.avro",
38
56
  }
39
-
57
+
40
58
  self.messager = Messager(
41
59
  hostname=config["simulator_request"]["mqtt"]["server"],
42
60
  port=config["simulator_request"]["mqtt"]["port"],
@@ -63,21 +81,36 @@ class AgentGroup:
63
81
  else:
64
82
  self.economy_client = None
65
83
 
84
+ # Mlflow
85
+ _mlflow_config = config.get("metric_request", {}).get("mlflow")
86
+ if _mlflow_config:
87
+ logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
88
+ self.mlflow_client = MlflowClient(
89
+ config=_mlflow_config,
90
+ mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
91
+ experiment_name=exp_name,
92
+ run_id=mlflow_run_id,
93
+ )
94
+ else:
95
+ self.mlflow_client = None
96
+
66
97
  for agent in self.agents:
67
- agent.set_exp_id(self.exp_id) # type: ignore
98
+ agent.set_exp_id(self.exp_id) # type: ignore
68
99
  agent.set_llm_client(self.llm)
69
100
  agent.set_simulator(self.simulator)
70
101
  if self.economy_client is not None:
71
102
  agent.set_economy_client(self.economy_client)
103
+ if self.mlflow_client is not None:
104
+ agent.set_mlflow_client(self.mlflow_client)
72
105
  agent.set_messager(self.messager)
73
106
  if self.enable_avro:
74
- agent.set_avro_file(self.avro_file) # type: ignore
107
+ agent.set_avro_file(self.avro_file) # type: ignore
75
108
 
76
109
  async def init_agents(self):
77
110
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
78
111
  logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
79
112
  for agent in self.agents:
80
- await agent.bind_to_simulator() # type: ignore
113
+ await agent.bind_to_simulator() # type: ignore
81
114
  self.id2agent = {agent._uuid: agent for agent in self.agents}
82
115
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
83
116
  await self.messager.connect()
@@ -104,7 +137,7 @@ class AgentGroup:
104
137
  for agent in self.agents:
105
138
  profile = await agent.memory._profile.export()
106
139
  profile = profile[0]
107
- profile['id'] = agent._uuid
140
+ profile["id"] = agent._uuid
108
141
  profiles.append(profile)
109
142
  fastavro.writer(f, PROFILE_SCHEMA, profiles)
110
143
 
@@ -139,7 +172,9 @@ class AgentGroup:
139
172
  return results
140
173
 
141
174
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
142
- logger.debug(f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}")
175
+ logger.debug(
176
+ f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
177
+ )
143
178
  agent = self.id2agent[target_agent_uuid]
144
179
  await agent.memory.update(target_key, content)
145
180
 
@@ -147,7 +182,9 @@ class AgentGroup:
147
182
  logger.debug(f"-----Starting message dispatch for group {self._uuid}")
148
183
  while True:
149
184
  if not self.messager.is_connected():
150
- logger.warning("Messager is not connected. Skipping message processing.")
185
+ logger.warning(
186
+ "Messager is not connected. Skipping message processing."
187
+ )
151
188
 
152
189
  # Step 1: 获取消息
153
190
  messages = await self.messager.fetch_messages()
@@ -165,7 +202,7 @@ class AgentGroup:
165
202
 
166
203
  # 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
167
204
  _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
168
-
205
+
169
206
  if agent_uuid in self.id2agent:
170
207
  agent = self.id2agent[agent_uuid]
171
208
  # topic_type: agent-chat, user-chat, user-survey, gather