pycityagent 2.0.0a43__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a43.dist-info/METADATA +235 -0
- pycityagent-2.0.0a43.dist-info/RECORD +81 -0
- pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
@@ -0,0 +1,184 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import psycopg
|
6
|
+
import psycopg.sql
|
7
|
+
import ray
|
8
|
+
from psycopg.rows import dict_row
|
9
|
+
|
10
|
+
from ...utils.decorators import lock_decorator
|
11
|
+
from ...utils.pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
12
|
+
|
13
|
+
logger = logging.getLogger("pg")


def create_pg_tables(exp_id: str, dsn: str):
    """(Re)create the PostgreSQL tables for experiment ``exp_id``.

    For every ``(table_type, statements)`` entry in ``PGSQL_DICT`` each
    statement is formatted with ``table_name=...`` and executed.

    - Per-experiment tables are named ``socialcity_<exp_id>_<table_type>``,
      with ``-`` in ``exp_id`` replaced by ``_``. They are dropped first, so
      any existing data for this experiment is discarded.
    - The ``experiment`` entry maps to the shared ``socialcity_experiment``
      table, which is never dropped. Its statements are presumably idempotent
      (e.g. ``CREATE TABLE IF NOT EXISTS``); TODO confirm in ``pg_query``.

    Args:
        exp_id: Experiment identifier. It is interpolated into SQL unquoted,
            so it must come from a trusted source.
        dsn: libpq connection string passed to ``psycopg.connect``.
    """
    # A single connection serves all tables; the previous version reconnected
    # once per table type.
    with psycopg.connect(dsn) as conn:
        with conn.cursor() as cur:
            for table_type, exec_strs in PGSQL_DICT.items():
                if table_type != "experiment":
                    table_name = f"socialcity_{exp_id.replace('-', '_')}_{table_type}"
                    drop_sql = f"DROP TABLE IF EXISTS {table_name}"
                    cur.execute(drop_sql)  # type:ignore
                    logger.debug(f"table:{table_name} sql: {drop_sql}")
                    conn.commit()
                else:
                    table_name = f"socialcity_{table_type}"
                # Create the table (and whatever else the statements define).
                for _exec_str in exec_strs:
                    exec_str = _exec_str.format(table_name=table_name)
                    cur.execute(exec_str)
                    logger.debug(f"table:{table_name} sql: {exec_str}")
                # Commit per table type, as before, so earlier tables survive a
                # failure in a later one.
                conn.commit()
|
42
|
+
|
43
|
+
|
44
|
+
@ray.remote
class PgWriter:
    """Ray actor that writes one experiment's simulation records to PostgreSQL.

    Per-experiment tables are named ``socialcity_<exp_id>_<suffix>``, with
    ``-`` in ``exp_id`` replaced by ``_``. Every public method is wrapped in
    ``lock_decorator``, so calls on one actor run one at a time. Each call
    opens its own async connection.
    """

    def __init__(self, exp_id: str, dsn: str):
        self.exp_id = exp_id
        self._dsn = dsn
        # Consumed by ``lock_decorator`` to serialize the public methods.
        self._lock = asyncio.Lock()

    def _table_name(self, suffix: str) -> str:
        """Return the per-experiment table name ending in ``suffix``."""
        return f"socialcity_{self.exp_id.replace('-', '_')}_{suffix}"

    async def _copy_rows(
        self,
        table_name: str,
        copy_sql: psycopg.sql.Composed,
        tuple_types: list[Any],
        rows: list[tuple],
    ) -> None:
        """Stream ``rows`` into ``table_name`` with ``copy_sql`` (a ``COPY ... FROM STDIN``).

        Conversion rules:

        - Each value is converted by the callable at the same position in
          ``tuple_types``.
        - A ``None`` entry passes the value through unchanged (used for
          timestamp columns).
        - ``zip`` truncates to the shorter of the row and the type list.

        Commit behavior follows psycopg 3's connection context manager: the
        transaction is committed when the block exits without an exception.
        An empty ``rows`` is a no-op and opens no connection.
        """
        if not rows:
            return
        # Only keep converted rows around when they will actually be logged.
        keep_for_log = logger.isEnabledFor(logging.DEBUG)
        written: list[Any] = []
        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
            async with aconn.cursor() as cur:
                async with cur.copy(copy_sql) as copy:
                    for row in rows:
                        values = [
                            _type(r) if _type is not None else r
                            for (_type, r) in zip(tuple_types, row)
                        ]
                        await copy.write_row(values)
                        if keep_for_log:
                            written.append(values)
        logger.debug("table:%s sql: %s values: %s", table_name, copy_sql, written)

    @lock_decorator
    async def async_write_dialog(self, rows: list[tuple]):
        """Bulk-insert agent dialog rows.

        Row order: ``(id, day, t, type, speaker, content, created_at)``.
        """
        # One converter per COPY column. created_at is passed through
        # unconverted (None), matching the other writers, so a datetime/None
        # is not turned into text such as 'None'.
        _tuple_types = [str, int, float, int, str, str, None]
        table_name = self._table_name("agent_dialog")
        copy_sql = psycopg.sql.SQL(
            "COPY {} (id, day, t, type, speaker, content, created_at) FROM STDIN"
        ).format(psycopg.sql.Identifier(table_name))
        await self._copy_rows(table_name, copy_sql, _tuple_types, rows)

    @lock_decorator
    async def async_write_status(self, rows: list[tuple]):
        """Bulk-insert agent status rows.

        Row order: ``(id, day, t, lng, lat, parent_id, action, status, created_at)``.
        """
        _tuple_types = [str, int, float, float, float, int, str, str, None]
        table_name = self._table_name("agent_status")
        copy_sql = psycopg.sql.SQL(
            "COPY {} (id, day, t, lng, lat, parent_id, action, status, created_at) FROM STDIN"
        ).format(psycopg.sql.Identifier(table_name))
        await self._copy_rows(table_name, copy_sql, _tuple_types, rows)

    @lock_decorator
    async def async_write_profile(self, rows: list[tuple]):
        """Bulk-insert agent profile rows: ``(id, name, profile)``, all as text."""
        _tuple_types = [str, str, str]
        table_name = self._table_name("agent_profile")
        copy_sql = psycopg.sql.SQL("COPY {} (id, name, profile) FROM STDIN").format(
            psycopg.sql.Identifier(table_name)
        )
        await self._copy_rows(table_name, copy_sql, _tuple_types, rows)

    @lock_decorator
    async def async_write_survey(self, rows: list[tuple]):
        """Bulk-insert agent survey rows.

        Row order: ``(id, day, t, survey_id, result, created_at)``.
        """
        _tuple_types = [str, int, float, str, str, None]
        table_name = self._table_name("agent_survey")
        copy_sql = psycopg.sql.SQL(
            "COPY {} (id, day, t, survey_id, result, created_at) FROM STDIN"
        ).format(psycopg.sql.Identifier(table_name))
        await self._copy_rows(table_name, copy_sql, _tuple_types, rows)

    @lock_decorator
    async def async_update_exp_info(self, exp_info: dict[str, Any]):
        """Insert or update this experiment's row in ``socialcity_experiment``.

        Values are taken from ``exp_info`` for every key in
        ``TO_UPDATE_EXP_INFO_KEYS_AND_TYPES`` and converted with the paired
        type. A ``None`` type means the value (e.g. a timestamp) is passed
        through unconverted. A missing key raises ``KeyError``.
        """
        table_name = "socialcity_experiment"
        table = psycopg.sql.Identifier(table_name)
        params = [
            _type(exp_info[key]) if _type is not None else exp_info[key]
            for key, _type in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
        ]
        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
            async with aconn.cursor(row_factory=dict_row) as cur:
                select_sql = psycopg.sql.SQL("SELECT 1 FROM {} WHERE id=%s").format(table)
                await cur.execute(select_sql, (self.exp_id,))
                logger.debug(
                    "table:%s sql: %s values: %s", table_name, select_sql, (self.exp_id,)
                )
                record_exists = await cur.fetchone() is not None
                if record_exists:
                    # UPDATE. The experiment id is bound as a parameter instead
                    # of being spliced into the SQL text.
                    columns = ", ".join(
                        f"{key} = %s" for key, _ in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
                    )
                    update_sql = psycopg.sql.SQL(
                        f"UPDATE {{}} SET {columns} WHERE id=%s"  # type:ignore
                    ).format(table)
                    update_params = [*params, self.exp_id]
                    logger.debug(
                        "table:%s sql: %s values: %s", table_name, update_sql, update_params
                    )
                    await cur.execute(update_sql, update_params)
                else:
                    # INSERT
                    keys = ", ".join(key for key, _ in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES)
                    placeholders = ", ".join(["%s"] * len(TO_UPDATE_EXP_INFO_KEYS_AND_TYPES))
                    insert_sql = psycopg.sql.SQL(
                        f"INSERT INTO {{}} ({keys}) VALUES ({placeholders})"  # type:ignore
                    ).format(table)
                    logger.debug(
                        "table:%s sql: %s values: %s", table_name, insert_sql, params
                    )
                    await cur.execute(insert_sql, params)
            await aconn.commit()
|
@@ -0,0 +1,54 @@
|
|
1
|
+
import json
|
2
|
+
import uuid
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
from .models import Page, Question, QuestionType, Survey
|
7
|
+
|
8
|
+
|
9
|
+
class SurveyManager:
    """In-memory registry of surveys, indexed by ``str(survey.id)``."""

    def __init__(self):
        self._surveys: dict[str, Survey] = {}

    @staticmethod
    def _build_question(spec: dict) -> Question:
        """Turn one element dict into a ``Question``.

        ``name``, ``title`` and ``type`` are required. Missing optional keys
        default to required=True, empty choices/columns/rows, and a 1-5
        rating range.
        """
        return Question(
            name=spec["name"],
            title=spec["title"],
            type=QuestionType(spec["type"]),
            required=spec.get("required", True),
            choices=spec.get("choices", []),
            columns=spec.get("columns", []),
            rows=spec.get("rows", []),
            min_rating=spec.get("min_rating", 1),
            max_rating=spec.get("max_rating", 5),
        )

    @classmethod
    def _build_page(cls, spec: dict) -> Page:
        """Turn one page dict (``name`` plus ``elements``) into a ``Page``."""
        elements = [cls._build_question(element) for element in spec["elements"]]
        return Page(name=spec["name"], elements=elements)

    def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
        """Create a new survey from page dicts, register it, and return it."""
        new_id = uuid.uuid4()
        survey = Survey(
            id=new_id,
            title=title,
            description=description,
            pages=[self._build_page(spec) for spec in pages],
        )
        self._surveys[str(new_id)] = survey
        return survey

    def get_survey(self, survey_id: str) -> Optional[Survey]:
        """Return the survey registered under ``survey_id``, or ``None``."""
        found = self._surveys.get(survey_id)
        return found

    def get_all_surveys(self) -> list[Survey]:
        """Return every registered survey, in creation order."""
        return [*self._surveys.values()]
|
@@ -0,0 +1,120 @@
|
|
1
|
+
import json
|
2
|
+
import uuid
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
from datetime import datetime
|
5
|
+
from enum import Enum
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
|
9
|
+
class QuestionType(Enum):
    """Supported question kinds; values are the serialized ``type`` strings."""

    TEXT = "text"
    RADIO = "radiogroup"
    CHECKBOX = "checkbox"
    BOOLEAN = "boolean"
    RATING = "rating"
    MATRIX = "matrix"


@dataclass
class Question:
    """A single survey question.

    ``to_dict`` always emits type/name/title/required. It adds only the
    type-specific fields: ``choices`` for radio/checkbox, ``columns``/``rows``
    for matrix, and the rating bounds for rating.
    """

    name: str
    title: str
    type: QuestionType
    choices: list[str] = field(default_factory=list)
    columns: list[str] = field(default_factory=list)
    rows: list[str] = field(default_factory=list)
    required: bool = True
    min_rating: int = 1
    max_rating: int = 5

    def to_dict(self) -> dict:
        """Serialize to a plain dict (inverse of :meth:`from_dict`)."""
        base_dict: dict[str, Any] = {
            "type": self.type.value,
            "name": self.name,
            "title": self.title,
            # Serialized so the to_json/from_json round-trip preserves it;
            # otherwise required=False would come back as True.
            "required": self.required,
        }

        if self.type in [QuestionType.RADIO, QuestionType.CHECKBOX]:
            base_dict["choices"] = self.choices
        elif self.type == QuestionType.MATRIX:
            base_dict["columns"] = self.columns
            base_dict["rows"] = self.rows
        elif self.type == QuestionType.RATING:
            base_dict["min_rating"] = self.min_rating
            base_dict["max_rating"] = self.max_rating

        return base_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Question":
        """Build a Question from a dict.

        ``name``, ``title`` and ``type`` are required; other fields fall back
        to the dataclass defaults.
        """
        return cls(
            name=data["name"],
            title=data["title"],
            type=QuestionType(data["type"]),
            required=data.get("required", True),
            choices=data.get("choices", []),
            columns=data.get("columns", []),
            rows=data.get("rows", []),
            min_rating=data.get("min_rating", 1),
            max_rating=data.get("max_rating", 5),
        )


@dataclass
class Page:
    """A named page holding an ordered list of questions."""

    name: str
    elements: list[Question]

    def to_dict(self) -> dict:
        return {"name": self.name, "elements": [q.to_dict() for q in self.elements]}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Page":
        """Build a Page from a dict with ``name`` and ``elements``."""
        return cls(
            name=data["name"],
            elements=[Question.from_dict(q) for q in data["elements"]],
        )


@dataclass
class Survey:
    """A survey: metadata, pages of questions, and collected responses."""

    id: uuid.UUID
    title: str
    description: str
    pages: list[Page]
    responses: dict[str, dict] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> dict:
        """Summary dict; responses are reduced to ``response_count``."""
        return {
            "id": str(self.id),
            "title": self.title,
            "description": self.description,
            "pages": [p.to_dict() for p in self.pages],
            "response_count": len(self.responses),
        }

    def to_json(self) -> str:
        """Convert the survey to a JSON string for MQTT transmission."""
        survey_dict = {
            "id": str(self.id),
            "title": self.title,
            "description": self.description,
            "pages": [p.to_dict() for p in self.pages],
            "responses": self.responses,
            "created_at": self.created_at.isoformat(),
        }
        return json.dumps(survey_dict)

    @classmethod
    def from_json(cls, json_str: str) -> "Survey":
        """Create a Survey instance from a JSON string produced by :meth:`to_json`."""
        data = json.loads(json_str)
        pages = [Page.from_dict(p) for p in data["pages"]]
        return cls(
            id=uuid.UUID(data["id"]),
            title=data["title"],
            description=data["description"],
            pages=pages,
            responses=data.get("responses", {}),
            created_at=datetime.fromisoformat(data["created_at"]),
        )
|
@@ -0,0 +1,11 @@
|
|
1
|
+
from .avro_schema import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA,
|
2
|
+
PROFILE_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA)
|
3
|
+
from .pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
4
|
+
from .survey_util import process_survey_for_llm
|
5
|
+
|
6
|
+
# Public API re-exported by this utils package.
__all__ = [
    # Avro record schemas (from .avro_schema)
    "PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
    # Survey formatting helper (from .survey_util)
    "process_survey_for_llm",
    # PostgreSQL definitions (from .pg_query)
    "TO_UPDATE_EXP_INFO_KEYS_AND_TYPES",
    "PGSQL_DICT",
]
|
@@ -0,0 +1,109 @@
|
|
1
|
+
# Avro record schema for an agent profile. "doc" reads "Agent attributes".
# Every field is a plain string except `age` and `currency`, which are floats.
PROFILE_SCHEMA = {
    "doc": "Agent属性",
    "name": "AgentProfile",
    "namespace": "com.socialcity",
    "type": "record",
    "fields": [
        {"name": "id", "type": "string"},  # uuid as string
        {"name": "name", "type": "string"},
        {"name": "gender", "type": "string"},
        {"name": "age", "type": "float"},
        {"name": "education", "type": "string"},
        {"name": "skill", "type": "string"},
        {"name": "occupation", "type": "string"},
        {"name": "family_consumption", "type": "string"},
        {"name": "consumption", "type": "string"},
        {"name": "personality", "type": "string"},
        {"name": "income", "type": "string"},  # note: string, unlike `currency`
        {"name": "currency", "type": "float"},
        {"name": "residence", "type": "string"},
        {"name": "race", "type": "string"},
        {"name": "religion", "type": "string"},
        {"name": "marital_status", "type": "string"},
    ],
}
|
25
|
+
|
26
|
+
# Avro record schema for one agent dialog entry. "doc" reads "Agent dialog".
# `created_at` is a long with the timestamp-millis logical type.
DIALOG_SCHEMA = {
    "doc": "Agent对话",
    "name": "AgentDialog",
    "namespace": "com.socialcity",
    "type": "record",
    "fields": [
        {"name": "id", "type": "string"},  # uuid as string
        {"name": "day", "type": "int"},
        {"name": "t", "type": "float"},  # presumably time within `day` — TODO confirm units
        {"name": "type", "type": "int"},  # dialog kind code; meaning defined by producers
        {"name": "speaker", "type": "string"},
        {"name": "content", "type": "string"},
        {
            "name": "created_at",
            "type": {"type": "long", "logicalType": "timestamp-millis"},
        },
    ],
}
|
44
|
+
|
45
|
+
# Avro record schema for an agent status snapshot. "doc" reads "Agent status".
# Includes position (lng/lat as doubles), the current action, four float need
# levels, and a timestamp-millis `created_at`.
STATUS_SCHEMA = {
    "doc": "Agent状态",
    "name": "AgentStatus",
    "namespace": "com.socialcity",
    "type": "record",
    "fields": [
        {"name": "id", "type": "string"},  # uuid as string
        {"name": "day", "type": "int"},
        {"name": "t", "type": "float"},
        {"name": "lng", "type": "double"},
        {"name": "lat", "type": "double"},
        {"name": "parent_id", "type": "int"},  # presumably the containing AOI/lane id — TODO confirm
        {"name": "action", "type": "string"},
        {"name": "hungry", "type": "float"},
        {"name": "tired", "type": "float"},
        {"name": "safe", "type": "float"},
        {"name": "social", "type": "float"},
        {
            "name": "created_at",
            "type": {"type": "long", "logicalType": "timestamp-millis"},
        },
    ],
}
|
68
|
+
|
69
|
+
# Avro record schema for an institution status snapshot. "doc" reads
# "Institution status".
# Array fields accept mixed float/int/string/null items. `inventory`, `price`
# and `interest_rate` are nullable scalars (union with "null"). Unlike the
# agent schemas, there is no `created_at` field.
INSTITUTION_STATUS_SCHEMA = {
    "doc": "Institution状态",
    "name": "InstitutionStatus",
    "namespace": "com.socialcity",
    "type": "record",
    "fields": [
        {"name": "id", "type": "string"},  # uuid as string
        {"name": "day", "type": "int"},
        {"name": "t", "type": "float"},
        {"name": "type", "type": "int"},  # institution kind code; meaning defined by producers
        {"name": "nominal_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
        {"name": "real_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
        {"name": "unemployment", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
        {"name": "wages", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
        {"name": "prices", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
        {"name": "inventory", "type": ["int", "null"]},
        {"name": "price", "type": ["float", "null"]},
        {"name": "interest_rate", "type": ["float", "null"]},
        {"name": "bracket_cutoffs", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
        {"name": "bracket_rates", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
        {"name": "employees", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
    ],
}
|
92
|
+
|
93
|
+
# Avro record schema for one agent's survey answer. "doc" reads
# "Agent questionnaire". `result` is a string (possibly serialized JSON —
# TODO confirm with producers); `created_at` is timestamp-millis.
SURVEY_SCHEMA = {
    "doc": "Agent问卷",
    "name": "AgentSurvey",
    "namespace": "com.socialcity",
    "type": "record",
    "fields": [
        {"name": "id", "type": "string"},  # uuid as string
        {"name": "day", "type": "int"},
        {"name": "t", "type": "float"},
        {"name": "survey_id", "type": "string"},
        {"name": "result", "type": "string"},
        {
            "name": "created_at",
            "type": {"type": "long", "logicalType": "timestamp-millis"},
        },
    ],
}
|
@@ -0,0 +1,99 @@
|
|
1
|
+
import time
|
2
|
+
import functools
|
3
|
+
import inspect
|
4
|
+
|
5
|
+
CALLING_STRING = 'function: `{func_name}` in "{file_path}", line {line_number}, arguments: `{arguments}` start time: `{start_time}` end time: `{end_time}` output: `{output}`'

__all__ = [
    "record_call_aio",
    "record_call",
    "lock_decorator",
]


def _caller_location(frame):
    """Return ``(file_path, line_number)`` of the frame that called ``frame``.

    Falls back to ``("<unknown>", -1)`` when no frame is available, since
    ``inspect.currentframe()`` may return ``None`` on some Python
    implementations. The frame is received as an argument rather than stored in
    the wrapper's locals, which avoids a frame -> locals -> frame reference
    cycle.
    """
    caller = frame.f_back if frame is not None else None
    if caller is None:
        return "<unknown>", -1
    return caller.f_code.co_filename, caller.f_lineno


def _format_arguments(args, kwargs) -> str:
    """Render call arguments as ``repr(a), ..., key=repr(value), ...``."""
    rendered = [repr(a) for a in args]
    rendered.extend(f"{k}={v!r}" for k, v in kwargs.items())
    return ", ".join(rendered)


def _print_call(func, file_path, line_number, arguments, start_time, end_time, output):
    """Print one ``CALLING_STRING`` line describing a finished call of ``func``."""
    print(
        CALLING_STRING.format(
            # Use the qualified name; fall back to the object for callables
            # without one (e.g. functools.partial).
            func_name=getattr(func, "__qualname__", func),
            line_number=line_number,
            file_path=file_path,
            arguments=arguments,
            start_time=start_time,
            end_time=end_time,
            output=output,
        )
    )


def record_call_aio(record_function_calling: bool = True):
    """
    Decorator to log the async function call details if `record_function_calling` is True.

    When enabled, prints the caller's file and line, the call arguments
    (rendered before the call), wall-clock start/end times and the result.
    When disabled, the wrapper just awaits the function.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            if not record_function_calling:
                return await func(*args, **kwargs)
            file_path, line_number = _caller_location(inspect.currentframe())
            arguments = _format_arguments(args, kwargs)
            start_time = time.time()
            result = await func(*args, **kwargs)
            end_time = time.time()
            _print_call(func, file_path, line_number, arguments, start_time, end_time, result)
            return result

        return wrapper

    return decorator


def record_call(record_function_calling: bool = True):
    """
    Decorator to log the function call details if `record_function_calling` is True.

    Synchronous counterpart of ``record_call_aio``, with the same output format.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not record_function_calling:
                return func(*args, **kwargs)
            file_path, line_number = _caller_location(inspect.currentframe())
            arguments = _format_arguments(args, kwargs)
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            _print_call(func, file_path, line_number, arguments, start_time, end_time, result)
            return result

        return wrapper

    return decorator


def lock_decorator(func):
    """Serialize calls to the async method ``func`` on the instance's ``self._lock``.

    The lock is acquired before the call and released in ``finally``, so it is
    freed even when ``func`` raises. ``functools.wraps`` keeps the method's
    name and signature visible to introspection.
    """

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        lock = self._lock
        await lock.acquire()
        try:
            return await func(self, *args, **kwargs)
        finally:
            lock.release()

    return wrapper
|
@@ -0,0 +1,13 @@
|
|
1
|
+
"""Model response parser module."""
|
2
|
+
|
3
|
+
from .parser_base import ParserBase
|
4
|
+
from .json_parser import JsonDictParser, JsonObjectParser
|
5
|
+
from .code_block_parser import CodeBlockParser
|
6
|
+
|
7
|
+
|
8
|
+
# Public parser API of this package: the abstract base plus the concrete
# JSON and fenced-code-block parsers imported above.
__all__ = [
    "ParserBase",
    "JsonDictParser",
    "JsonObjectParser",
    "CodeBlockParser",
]
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from .parser_base import ParserBase
|
5
|
+
|
6
|
+
|
7
|
+
class CodeBlockParser(ParserBase):
    """Pull the body of a fenced code block for one language out of a response.

    Attributes:
        tag_start (str): Opening-fence template; ``{language_name}`` is filled
            with the language given to the constructor (e.g. "```python").
        tag_end (str): Closing fence.
    """

    tag_start = "```{language_name}"
    tag_end = "```"

    def __init__(self, language_name: str) -> None:
        """Store the language whose fenced block ``parse`` should extract."""
        super().__init__()
        self.language_name = language_name

    def parse(self, response: str) -> str:
        """Return the text between this language's opening fence and the closing fence.

        Parameters:
            response (str): Raw response text expected to contain the block.

        Returns:
            str: The text extracted by ``_extract_text_within_tags``. How
            missing fences are handled is defined by ``ParserBase``.
        """
        opening_fence = self.tag_start.format(language_name=self.language_name)
        return self._extract_text_within_tags(
            response=response,
            tag_start=opening_fence,
            tag_end=self.tag_end,
        )
|