pycityagent 2.0.0a18__py3-none-any.whl → 2.0.0a20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +126 -100
- pycityagent/economy/econ_client.py +39 -2
- pycityagent/environment/simulator.py +2 -2
- pycityagent/environment/utils/geojson.py +1 -3
- pycityagent/environment/utils/map_utils.py +15 -15
- pycityagent/llm/embedding.py +8 -9
- pycityagent/llm/llm.py +5 -5
- pycityagent/memory/memory.py +23 -22
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/simulation/agentgroup.py +58 -21
- pycityagent/simulation/simulation.py +114 -38
- pycityagent/utils/parsers/parser_base.py +1 -1
- pycityagent/workflow/__init__.py +5 -3
- pycityagent/workflow/block.py +2 -3
- pycityagent/workflow/prompt.py +6 -6
- pycityagent/workflow/tool.py +53 -4
- pycityagent/workflow/trigger.py +2 -2
- {pycityagent-2.0.0a18.dist-info → pycityagent-2.0.0a20.dist-info}/METADATA +4 -2
- {pycityagent-2.0.0a18.dist-info → pycityagent-2.0.0a20.dist-info}/RECORD +22 -19
- {pycityagent-2.0.0a18.dist-info → pycityagent-2.0.0a20.dist-info}/WHEEL +0 -0
pycityagent/memory/memory.py
CHANGED
@@ -2,7 +2,8 @@ import asyncio
 import logging
 from copy import deepcopy
 from datetime import datetime
-from typing import Any,
+from typing import Any, Literal, Optional, Union
+from collections.abc import Sequence,Callable
 
 import numpy as np
 from pyparsing import deque
@@ -27,10 +28,10 @@ class Memory:
 
     def __init__(
         self,
-        config: Optional[
-        profile: Optional[
-        base: Optional[
-        motion: Optional[
+        config: Optional[dict[Any, Any]] = None,
+        profile: Optional[dict[Any, Any]] = None,
+        base: Optional[dict[Any, Any]] = None,
+        motion: Optional[dict[Any, Any]] = None,
         activate_timestamp: bool = False,
         embedding_model: Any = None,
     ) -> None:
@@ -38,7 +39,7 @@ class Memory:
         Initializes the Memory with optional configuration.
 
         Args:
-            config (Optional[
+            config (Optional[dict[Any, Any]], optional):
                 A configuration dictionary for dynamic memory. The dictionary format is:
                 - Key: The name of the dynamic memory field.
                 - Value: Can be one of two formats:
@@ -46,24 +47,24 @@ class Memory:
                     2. A callable that returns the default value when invoked (useful for complex default values).
                 Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
                 Defaults to None.
-            profile (Optional[
-            base (Optional[
-            motion (Optional[
+            profile (Optional[dict[Any, Any]], optional): profile attribute dict.
+            base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
+            motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
             activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
             embedding_model (Any): The embedding model for memory search.
         """
-        self.watchers:
+        self.watchers: dict[str, list[Callable]] = {}
         self._lock = asyncio.Lock()
         self.embedding_model = embedding_model
 
         # Initialize embedding storage
         self._embeddings = {"state": {}, "profile": {}, "dynamic": {}}
 
-        _dynamic_config:
-        _state_config:
-        _profile_config:
+        _dynamic_config: dict[Any, Any] = {}
+        _state_config: dict[Any, Any] = {}
+        _profile_config: dict[Any, Any] = {}
         # Track which fields require embeddings
-        self._embedding_fields:
+        self._embedding_fields: dict[str, bool] = {}
 
         if config is not None:
             for k, v in config.items():
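For orientation, a minimal construction sketch based on the docstring above; it is not part of the diff, and the field names "plan" and "name" are hypothetical:

from pycityagent.memory.memory import Memory

# config maps a dynamic-field name to its default; a callable default is
# documented above and is invoked to produce the initial value.
memory = Memory(
    config={"plan": lambda: []},
    profile={"name": "Alice"},  # hypothetical profile attribute dict
)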
@@ -303,7 +304,7 @@ class Memory:
 
     async def update_batch(
         self,
-        content: Union[
+        content: Union[dict, Sequence[tuple[Any, Any]]],
         mode: Union[Literal["replace"], Literal["merge"]] = "replace",
         store_snapshot: bool = False,
         protect_llm_read_only_fields: bool = True,
@@ -312,7 +313,7 @@ class Memory:
         Updates multiple values in the memory at once.
 
         Args:
-            content (Union[
+            content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
             mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
             store_snapshot (bool): Whether to store a snapshot of the memory after the update.
             protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
@@ -321,9 +322,9 @@ class Memory:
             TypeError: If the content type is neither a dictionary nor a sequence of tuples.
         """
         if isinstance(content, dict):
-            _list_content:
+            _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
         elif isinstance(content, Sequence):
-            _list_content:
+            _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
         else:
             raise TypeError(f"Invalid content type `{type(content)}`!")
         for k, v in _list_content[:1]:
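A short usage sketch of the two accepted `content` forms (the field names are assumptions, not taken from the diff):

async def apply_updates(memory):
    # dict form: keys and values updated in one call
    await memory.update_batch({"mood": "happy"})
    # sequence-of-tuples form, merged into existing values
    await memory.update_batch([("plan", ["eat"])], mode="merge")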
@@ -353,12 +354,12 @@ class Memory:
     @lock_decorator
     async def export(
         self,
-    ) ->
+    ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
         """
         Exports the current state of all memory sections.
 
         Returns:
-
+            tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
         """
         return (
             await self._profile.export(),
@@ -369,14 +370,14 @@ class Memory:
     @lock_decorator
     async def load(
         self,
-        snapshots:
+        snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
         reset_memory: bool = True,
     ) -> None:
         """
         Import the snapshot memories of all sections.
 
         Args:
-            snapshots (
+            snapshots (tuple[Sequence[dict], Sequence[dict], Sequence[dict]]): The exported snapshots.
             reset_memory (bool): Whether to reset previous memory.
         """
         _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
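The export/load pair above suggests a snapshot round trip along these lines (a sketch, assuming two Memory instances):

async def snapshot_roundtrip(source, target):
    # export() returns (profile, state, dynamic) section snapshots
    snapshots = await source.export()
    # load() re-imports them, optionally clearing previous memory first
    await target.load(snapshots, reset_memory=True)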
pycityagent/metrics/mlflow_client.py
ADDED
@@ -0,0 +1,147 @@
+import asyncio
+import logging
+import os
+import uuid
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+import mlflow
+from mlflow.entities import (Dataset, DatasetInput, Document, Experiment,
+                             ExperimentTag, FileInfo, InputTag, LifecycleStage,
+                             LiveSpan, Metric, NoOpSpan, Param, Run, RunData,
+                             RunInfo, RunInputs, RunStatus, RunTag, SourceType,
+                             Span, SpanEvent, SpanStatus, SpanStatusCode,
+                             SpanType, Trace, TraceData, TraceInfo, ViewType)
+
+from ..utils.decorators import lock_decorator
+
+logger = logging.getLogger("mlflow")
+
+
+def init_mlflow_connection(
+    config: dict,
+    mlflow_run_name: Optional[str] = None,
+    experiment_name: Optional[str] = None,
+    experiment_description: Optional[str] = None,
+    experiment_tags: Optional[dict[str, Any]] = None,
+) -> tuple[str, tuple[str, mlflow.MlflowClient, Run, str]]:
+
+    os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
+    os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
+
+    run_uuid = str(uuid.uuid4())
+    # run name
+    if mlflow_run_name is None:
+        mlflow_run_name = f"exp_{run_uuid}"
+
+    # exp name
+    if experiment_name is None:
+        experiment_name = f"run_{run_uuid}"
+
+    # tags
+    if experiment_tags is None:
+        experiment_tags = {}
+    if experiment_description is not None:
+        experiment_tags["mlflow.note.content"] = experiment_description
+
+    uri = config["mlflow_uri"]
+    client = mlflow.MlflowClient(tracking_uri=uri)
+
+    # experiment
+    try:
+        experiment_id = client.create_experiment(
+            name=experiment_name,
+            tags=experiment_tags,
+        )
+    except Exception as e:
+        experiment = client.get_experiment_by_name(experiment_name)
+        if experiment is None:
+            raise e
+        experiment_id = experiment.experiment_id
+
+    # run
+    run = client.create_run(experiment_id=experiment_id, run_name=mlflow_run_name)
+
+    run_id = run.info.run_id
+
+    return run_id, (uri, client, run, run_uuid)
+
+
+class MlflowClient:
+    """
+    - Mlflow client
+    """
+
+    def __init__(
+        self,
+        config: dict,
+        mlflow_run_name: Optional[str] = None,
+        experiment_name: Optional[str] = None,
+        experiment_description: Optional[str] = None,
+        experiment_tags: Optional[dict[str, Any]] = None,
+        run_id: Optional[str] = None,
+    ) -> None:
+        if run_id is None:
+            self._run_id, (
+                self._mlflow_uri,
+                self._client,
+                self._run,
+                self._run_uuid,
+            ) = init_mlflow_connection(
+                config=config,
+                mlflow_run_name=mlflow_run_name,
+                experiment_name=experiment_name,
+                experiment_description=experiment_description,
+                experiment_tags=experiment_tags,
+            )
+        else:
+            self._mlflow_uri = uri = config["mlflow_uri"]
+            os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
+            os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
+            self._client = client = mlflow.MlflowClient(tracking_uri=uri)
+            self._run = client.get_run(run_id=run_id)
+            self._run_id = run_id
+            self._run_uuid = run_uuid = str(uuid.uuid4())
+        self._lock = asyncio.Lock()
+
+    @property
+    def client(
+        self,
+    ) -> mlflow.MlflowClient:
+        return self._client
+
+    @property
+    def run_id(
+        self,
+    ) -> str:
+        assert self._run_id is not None
+        return self._run_id
+
+    @lock_decorator
+    async def log_batch(
+        self,
+        metrics: Sequence[Metric] = (),
+        params: Sequence[Param] = (),
+        tags: Sequence[RunTag] = (),
+    ):
+        self.client.log_batch(
+            run_id=self.run_id, metrics=metrics, params=params, tags=tags
+        )
+
+    @lock_decorator
+    async def log_metric(
+        self,
+        key: str,
+        value: float,
+        step: Optional[int] = None,
+        timestamp: Optional[int] = None,
+    ):
+        if timestamp is not None:
+            timestamp = int(timestamp)
+        self.client.log_metric(
+            run_id=self.run_id,
+            key=key,
+            value=value,
+            timestamp=timestamp,
+            step=step,
+        )
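A hedged usage sketch of the new metrics client; the tracking URI, credentials, and experiment name below are placeholders, and MlflowClient is imported from `pycityagent.metrics` as this release's __init__.py exports it:

import asyncio

from pycityagent.metrics import MlflowClient

config = {
    "mlflow_uri": "http://localhost:5000",  # placeholder tracking server
    "username": "user",                     # placeholder credentials
    "password": "pass",
}

async def main():
    # creates an experiment and a run on construction (no run_id supplied)
    client = MlflowClient(config=config, experiment_name="demo_experiment")
    # log a single metric point to that run
    await client.log_metric(key="reward", value=0.5, step=1)

asyncio.run(main())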
pycityagent/metrics/utils/const.py
File without changes
pycityagent/simulation/agentgroup.py
CHANGED
@@ -1,34 +1,52 @@
 import asyncio
-from datetime import datetime
 import json
 import logging
-
+import time
 import uuid
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from uuid import UUID
+
 import fastavro
 import ray
-
-from
-from
-from
-from
-from
-from
-from
-from
+
+from ..agent import Agent, CitizenAgent, InstitutionAgent
+from ..economy.econ_client import EconomyClient
+from ..environment.simulator import Simulator
+from ..llm.llm import LLM
+from ..llm.llmconfig import LLMConfig
+from ..message import Messager
+from ..metrics import MlflowClient
+from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
+                     STATUS_SCHEMA, SURVEY_SCHEMA)
 
 logger = logging.getLogger("pycityagent")
 
+
 @ray.remote
 class AgentGroup:
-    def __init__(
+    def __init__(
+        self,
+        agents: list[Agent],
+        config: dict,
+        exp_id: str | UUID,
+        exp_name: str,
+        enable_avro: bool,
+        avro_path: Path,
+        enable_pgsql: bool,
+        pgsql_copy_writer: ray.ObjectRef,
+        mlflow_run_id: str,
+        logging_level: int,
+    ):
         logger.setLevel(logging_level)
         self._uuid = str(uuid.uuid4())
         self.agents = agents
         self.config = config
         self.exp_id = exp_id
         self.enable_avro = enable_avro
-        self.avro_path = avro_path / f"{self._uuid}"
         if enable_avro:
+            self.avro_path = avro_path / f"{self._uuid}"
             self.avro_path.mkdir(parents=True, exist_ok=True)
             self.avro_file = {
                 "profile": self.avro_path / f"profile.avro",
@@ -36,7 +54,7 @@ class AgentGroup:
                 "status": self.avro_path / f"status.avro",
                 "survey": self.avro_path / f"survey.avro",
             }
-
+
         self.messager = Messager(
            hostname=config["simulator_request"]["mqtt"]["server"],
            port=config["simulator_request"]["mqtt"]["port"],
@@ -63,21 +81,36 @@ class AgentGroup:
         else:
             self.economy_client = None
 
+        # Mlflow
+        _mlflow_config = config.get("metric_request", {}).get("mlflow")
+        if _mlflow_config:
+            logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
+            self.mlflow_client = MlflowClient(
+                config=_mlflow_config,
+                mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
+                experiment_name=exp_name,
+                run_id=mlflow_run_id,
+            )
+        else:
+            self.mlflow_client = None
+
         for agent in self.agents:
-            agent.set_exp_id(self.exp_id)
+            agent.set_exp_id(self.exp_id)  # type: ignore
             agent.set_llm_client(self.llm)
             agent.set_simulator(self.simulator)
             if self.economy_client is not None:
                 agent.set_economy_client(self.economy_client)
+            if self.mlflow_client is not None:
+                agent.set_mlflow_client(self.mlflow_client)
             agent.set_messager(self.messager)
             if self.enable_avro:
-                agent.set_avro_file(self.avro_file)
+                agent.set_avro_file(self.avro_file)  # type: ignore
 
     async def init_agents(self):
         logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
         logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
         for agent in self.agents:
-            await agent.bind_to_simulator()
+            await agent.bind_to_simulator()  # type: ignore
         self.id2agent = {agent._uuid: agent for agent in self.agents}
         logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
         await self.messager.connect()
@@ -104,7 +137,7 @@ class AgentGroup:
            for agent in self.agents:
                profile = await agent.memory._profile.export()
                profile = profile[0]
-               profile[
+               profile["id"] = agent._uuid
                profiles.append(profile)
            fastavro.writer(f, PROFILE_SCHEMA, profiles)
 
@@ -139,7 +172,9 @@ class AgentGroup:
         return results
 
     async def update(self, target_agent_uuid: str, target_key: str, content: Any):
-        logger.debug(
+        logger.debug(
+            f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
+        )
         agent = self.id2agent[target_agent_uuid]
         await agent.memory.update(target_key, content)
 
@@ -147,7 +182,9 @@ class AgentGroup:
         logger.debug(f"-----Starting message dispatch for group {self._uuid}")
         while True:
             if not self.messager.is_connected():
-                logger.warning(
+                logger.warning(
+                    "Messager is not connected. Skipping message processing."
+                )
 
             # Step 1: fetch messages
             messages = await self.messager.fetch_messages()
@@ -165,7 +202,7 @@ class AgentGroup:
 
            # Extract agent_id (topic format: "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
            _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
-
+
            if agent_uuid in self.id2agent:
                agent = self.id2agent[agent_uuid]
                # topic_type: agent-chat, user-chat, user-survey, gather
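For reference, a sketch of the configuration shape this AgentGroup revision reads from `config`; the values are placeholders and unrelated keys are omitted:

config = {
    "simulator_request": {
        "mqtt": {"server": "localhost", "port": 1883},  # placeholder MQTT broker
    },
    "metric_request": {
        # optional block; when present, each AgentGroup builds an MlflowClient from it
        "mlflow": {
            "mlflow_uri": "http://localhost:5000",
            "username": "user",
            "password": "pass",
        },
    },
}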