pycityagent 2.0.0a43__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. pycityagent/__init__.py +23 -0
  2. pycityagent/agent.py +833 -0
  3. pycityagent/cli/wrapper.py +44 -0
  4. pycityagent/economy/__init__.py +5 -0
  5. pycityagent/economy/econ_client.py +355 -0
  6. pycityagent/environment/__init__.py +7 -0
  7. pycityagent/environment/interact/__init__.py +0 -0
  8. pycityagent/environment/interact/interact.py +198 -0
  9. pycityagent/environment/message/__init__.py +0 -0
  10. pycityagent/environment/sence/__init__.py +0 -0
  11. pycityagent/environment/sence/static.py +416 -0
  12. pycityagent/environment/sidecar/__init__.py +8 -0
  13. pycityagent/environment/sidecar/sidecarv2.py +109 -0
  14. pycityagent/environment/sim/__init__.py +29 -0
  15. pycityagent/environment/sim/aoi_service.py +39 -0
  16. pycityagent/environment/sim/client.py +126 -0
  17. pycityagent/environment/sim/clock_service.py +44 -0
  18. pycityagent/environment/sim/economy_services.py +192 -0
  19. pycityagent/environment/sim/lane_service.py +111 -0
  20. pycityagent/environment/sim/light_service.py +122 -0
  21. pycityagent/environment/sim/person_service.py +295 -0
  22. pycityagent/environment/sim/road_service.py +39 -0
  23. pycityagent/environment/sim/sim_env.py +145 -0
  24. pycityagent/environment/sim/social_service.py +59 -0
  25. pycityagent/environment/simulator.py +331 -0
  26. pycityagent/environment/utils/__init__.py +14 -0
  27. pycityagent/environment/utils/base64.py +16 -0
  28. pycityagent/environment/utils/const.py +244 -0
  29. pycityagent/environment/utils/geojson.py +24 -0
  30. pycityagent/environment/utils/grpc.py +57 -0
  31. pycityagent/environment/utils/map_utils.py +157 -0
  32. pycityagent/environment/utils/port.py +11 -0
  33. pycityagent/environment/utils/protobuf.py +41 -0
  34. pycityagent/llm/__init__.py +11 -0
  35. pycityagent/llm/embeddings.py +231 -0
  36. pycityagent/llm/llm.py +377 -0
  37. pycityagent/llm/llmconfig.py +13 -0
  38. pycityagent/llm/utils.py +6 -0
  39. pycityagent/memory/__init__.py +13 -0
  40. pycityagent/memory/const.py +43 -0
  41. pycityagent/memory/faiss_query.py +302 -0
  42. pycityagent/memory/memory.py +448 -0
  43. pycityagent/memory/memory_base.py +170 -0
  44. pycityagent/memory/profile.py +165 -0
  45. pycityagent/memory/self_define.py +165 -0
  46. pycityagent/memory/state.py +173 -0
  47. pycityagent/memory/utils.py +28 -0
  48. pycityagent/message/__init__.py +3 -0
  49. pycityagent/message/messager.py +88 -0
  50. pycityagent/metrics/__init__.py +6 -0
  51. pycityagent/metrics/mlflow_client.py +147 -0
  52. pycityagent/metrics/utils/const.py +0 -0
  53. pycityagent/pycityagent-sim +0 -0
  54. pycityagent/pycityagent-ui +0 -0
  55. pycityagent/simulation/__init__.py +8 -0
  56. pycityagent/simulation/agentgroup.py +580 -0
  57. pycityagent/simulation/simulation.py +634 -0
  58. pycityagent/simulation/storage/pg.py +184 -0
  59. pycityagent/survey/__init__.py +4 -0
  60. pycityagent/survey/manager.py +54 -0
  61. pycityagent/survey/models.py +120 -0
  62. pycityagent/utils/__init__.py +11 -0
  63. pycityagent/utils/avro_schema.py +109 -0
  64. pycityagent/utils/decorators.py +99 -0
  65. pycityagent/utils/parsers/__init__.py +13 -0
  66. pycityagent/utils/parsers/code_block_parser.py +37 -0
  67. pycityagent/utils/parsers/json_parser.py +86 -0
  68. pycityagent/utils/parsers/parser_base.py +60 -0
  69. pycityagent/utils/pg_query.py +92 -0
  70. pycityagent/utils/survey_util.py +53 -0
  71. pycityagent/workflow/__init__.py +26 -0
  72. pycityagent/workflow/block.py +211 -0
  73. pycityagent/workflow/prompt.py +79 -0
  74. pycityagent/workflow/tool.py +240 -0
  75. pycityagent/workflow/trigger.py +163 -0
  76. pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
  77. pycityagent-2.0.0a43.dist-info/METADATA +235 -0
  78. pycityagent-2.0.0a43.dist-info/RECORD +81 -0
  79. pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
  80. pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
  81. pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
@@ -0,0 +1,184 @@
1
+ import asyncio
2
+ import logging
3
+ from typing import Any
4
+
5
+ import psycopg
6
+ import psycopg.sql
7
+ import ray
8
+ from psycopg.rows import dict_row
9
+
10
+ from ...utils.decorators import lock_decorator
11
+ from ...utils.pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
12
+
13
logger = logging.getLogger("pg")


def create_pg_tables(exp_id: str, dsn: str):
    """Create (or reset) the PostgreSQL tables used by one experiment.

    For every ``(table_type, exec_strs)`` entry in ``PGSQL_DICT``:

    * the shared ``experiment`` table is named ``socialcity_experiment`` and is
      never dropped; its statements are simply executed;
    * every other table is experiment-specific, named
      ``socialcity_<exp_id>_<table_type>`` (``-`` in ``exp_id`` replaced by
      ``_``), and is dropped before its statements are executed.

    Each statement template in ``exec_strs`` is filled with ``table_name`` via
    ``str.format`` before execution.

    Args:
        exp_id: Experiment identifier used to build per-experiment table names.
        dsn: Connection string passed to ``psycopg.connect``.
    """
    # A single connection serves every table; the original reconnected per type.
    with psycopg.connect(dsn) as conn:
        with conn.cursor() as cur:
            for table_type, exec_strs in PGSQL_DICT.items():
                if table_type == "experiment":
                    # Shared across experiments, so it is never dropped here.
                    table_name = f"socialcity_{table_type}"
                else:
                    table_name = (
                        f"socialcity_{exp_id.replace('-', '_')}_{table_type}"
                    )
                    drop_sql = f"DROP TABLE IF EXISTS {table_name}"
                    cur.execute(drop_sql)  # type:ignore
                    logger.debug("table:%s sql: %s", table_name, drop_sql)
                    conn.commit()
                for _exec_str in exec_strs:
                    exec_str = _exec_str.format(table_name=table_name)
                    cur.execute(exec_str)
                    logger.debug("table:%s sql: %s", table_name, exec_str)
                conn.commit()
42
+
43
+
44
@ray.remote
class PgWriter:
    """Ray actor that writes one experiment's simulation records to PostgreSQL.

    Each public method opens a fresh ``psycopg.AsyncConnection`` and relies on
    the connection's context manager to commit when the block exits cleanly.
    All public methods are wrapped by ``lock_decorator``, which awaits
    ``self._lock`` around the call so writes on one actor do not interleave.
    """

    def __init__(self, exp_id: str, dsn: str):
        """
        Args:
            exp_id: Experiment identifier. ``-`` is replaced by ``_`` when it is
                embedded in per-experiment table names.
            dsn: Connection string passed to ``psycopg.AsyncConnection.connect``.
        """
        self.exp_id = exp_id
        self._dsn = dsn
        # Awaited by ``lock_decorator`` to serialize the write methods.
        self._lock = asyncio.Lock()

    def _table_name(self, suffix: str) -> str:
        """Return ``socialcity_<exp_id>_<suffix>`` with ``-`` in the id mapped to ``_``."""
        return f"socialcity_{self.exp_id.replace('-', '_')}_{suffix}"

    async def _copy_rows(
        self,
        suffix: str,
        columns: tuple[str, ...],
        tuple_types: tuple[Any, ...],
        rows: list[tuple],
    ) -> None:
        """Bulk-insert ``rows`` into a per-experiment table via ``COPY ... FROM STDIN``.

        Args:
            suffix: Table suffix appended to ``socialcity_<exp_id>_``.
            columns: Target column names, in the order values appear in a row.
            tuple_types: Per-position converters; ``None`` passes the value
                through unchanged (used for timestamps).
            rows: Row tuples to write. Values are paired with converters via
                ``zip``, so any values beyond ``len(tuple_types)`` are dropped.
        """
        table_name = self._table_name(suffix)
        copy_sql = psycopg.sql.SQL("COPY {} ({}) FROM STDIN").format(
            psycopg.sql.Identifier(table_name),
            psycopg.sql.SQL(", ").join(
                [psycopg.sql.Identifier(column) for column in columns]
            ),
        )
        # Converted rows are only retained when they will actually be logged.
        keep_for_log = logger.isEnabledFor(logging.DEBUG)
        written: list[Any] = []
        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
            async with aconn.cursor() as cur:
                async with cur.copy(copy_sql) as copy:
                    for row in rows:
                        converted = [
                            value if _type is None else _type(value)
                            for _type, value in zip(tuple_types, row)
                        ]
                        await copy.write_row(converted)
                        if keep_for_log:
                            written.append(converted)
            if keep_for_log:
                logger.debug(
                    "table:%s sql: %s values: %s", table_name, copy_sql, written
                )

    @lock_decorator
    async def async_write_dialog(self, rows: list[tuple]):
        """Append dialog rows ``(id, day, t, type, speaker, content, created_at)``."""
        # created_at is passed through unconverted, like in the other writers.
        await self._copy_rows(
            "agent_dialog",
            ("id", "day", "t", "type", "speaker", "content", "created_at"),
            (str, int, float, int, str, str, None),
            rows,
        )

    @lock_decorator
    async def async_write_status(self, rows: list[tuple]):
        """Append status rows ``(id, day, t, lng, lat, parent_id, action, status, created_at)``."""
        await self._copy_rows(
            "agent_status",
            (
                "id",
                "day",
                "t",
                "lng",
                "lat",
                "parent_id",
                "action",
                "status",
                "created_at",
            ),
            (str, int, float, float, float, int, str, str, None),
            rows,
        )

    @lock_decorator
    async def async_write_profile(self, rows: list[tuple]):
        """Append profile rows ``(id, name, profile)``; every value is stringified."""
        await self._copy_rows(
            "agent_profile",
            ("id", "name", "profile"),
            (str, str, str),
            rows,
        )

    @lock_decorator
    async def async_write_survey(self, rows: list[tuple]):
        """Append survey rows ``(id, day, t, survey_id, result, created_at)``."""
        await self._copy_rows(
            "agent_survey",
            ("id", "day", "t", "survey_id", "result", "created_at"),
            (str, int, float, str, str, None),
            rows,
        )

    @lock_decorator
    async def async_update_exp_info(self, exp_info: dict[str, Any]):
        """Insert or update this experiment's row in ``socialcity_experiment``.

        Values are read from ``exp_info`` for every key listed in
        ``TO_UPDATE_EXP_INFO_KEYS_AND_TYPES`` and converted with the paired
        type. A ``None`` type (used for timestamps) leaves the value unchanged.
        An UPDATE is issued when a row with ``id = self.exp_id`` exists;
        otherwise a new row is INSERTed.

        Raises:
            KeyError: if ``exp_info`` lacks one of the listed keys.
        """
        table_name = "socialcity_experiment"
        table = psycopg.sql.Identifier(table_name)
        keys = [key for key, _ in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES]
        params = [
            exp_info[key] if _type is None else _type(exp_info[key])
            for key, _type in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
        ]
        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
            async with aconn.cursor() as cur:
                select_sql = psycopg.sql.SQL("SELECT 1 FROM {} WHERE id = %s").format(
                    table
                )
                await cur.execute(select_sql, (self.exp_id,))
                logger.debug(
                    "table:%s sql: %s values: %s", table_name, select_sql, [self.exp_id]
                )
                record_exists = await cur.fetchone() is not None
                if record_exists:
                    columns = ", ".join(f"{key} = %s" for key in keys)
                    # exp_id is bound as a parameter, never spliced into the SQL
                    # text, so quotes in it cannot break or alter the statement.
                    update_sql = psycopg.sql.SQL(
                        f"UPDATE {{}} SET {columns} WHERE id = %s"  # type:ignore
                    ).format(table)
                    update_params = [*params, self.exp_id]
                    logger.debug(
                        "table:%s sql: %s values: %s",
                        table_name,
                        update_sql,
                        update_params,
                    )
                    await cur.execute(update_sql, update_params)
                else:
                    placeholders = ", ".join(["%s"] * len(keys))
                    insert_sql = psycopg.sql.SQL(
                        f"INSERT INTO {{}} ({', '.join(keys)}) VALUES ({placeholders})"  # type:ignore
                    ).format(table)
                    logger.debug(
                        "table:%s sql: %s values: %s", table_name, insert_sql, params
                    )
                    await cur.execute(insert_sql, params)
            await aconn.commit()
@@ -0,0 +1,4 @@
1
+ from .models import QuestionType, Question, Survey
2
+ from .manager import SurveyManager
3
+
4
# Public API of the survey package.
__all__ = [
    "QuestionType",
    "Question",
    "Survey",
    "SurveyManager",
]
@@ -0,0 +1,54 @@
1
+ import json
2
+ import uuid
3
+ from datetime import datetime
4
+ from typing import Optional
5
+
6
+ from .models import Page, Question, QuestionType, Survey
7
+
8
+
9
class SurveyManager:
    """In-memory registry of surveys, keyed by the string form of their UUID."""

    def __init__(self):
        self._surveys: dict[str, Survey] = {}

    @staticmethod
    def _question_from_dict(raw: dict) -> Question:
        """Build a Question from its dict form; optional fields fall back to defaults."""
        return Question(
            name=raw["name"],
            title=raw["title"],
            type=QuestionType(raw["type"]),
            required=raw.get("required", True),
            choices=raw.get("choices", []),
            columns=raw.get("columns", []),
            rows=raw.get("rows", []),
            min_rating=raw.get("min_rating", 1),
            max_rating=raw.get("max_rating", 5),
        )

    def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
        """Create a new survey with a fresh UUID, register it, and return it.

        Args:
            title: Survey title.
            description: Survey description.
            pages: Page dicts, each with a ``"name"`` and an ``"elements"`` list
                of question dicts (``"name"``, ``"title"``, ``"type"`` required).

        Raises:
            KeyError: if a page or question dict lacks a required key.
            ValueError: if a question ``"type"`` is not a valid QuestionType.
        """
        new_id = uuid.uuid4()

        built_pages = []
        for page in pages:
            elements = [self._question_from_dict(raw) for raw in page["elements"]]
            built_pages.append(Page(name=page["name"], elements=elements))

        survey = Survey(
            id=new_id,
            title=title,
            description=description,
            pages=built_pages,
        )
        self._surveys[str(new_id)] = survey
        return survey

    def get_survey(self, survey_id: str) -> Optional[Survey]:
        """Return the survey registered under ``survey_id``, or None if unknown."""
        return self._surveys.get(survey_id)

    def get_all_surveys(self) -> list[Survey]:
        """Return a new list of every registered survey, in insertion order."""
        return [*self._surveys.values()]
@@ -0,0 +1,120 @@
1
+ import json
2
+ import uuid
3
+ from dataclasses import dataclass, field
4
+ from datetime import datetime
5
+ from enum import Enum
6
+ from typing import Any
7
+
8
+
9
class QuestionType(Enum):
    """Supported question kinds; values are the ``"type"`` strings used in survey dicts."""

    TEXT = "text"
    RADIO = "radiogroup"
    CHECKBOX = "checkbox"
    BOOLEAN = "boolean"
    RATING = "rating"
    MATRIX = "matrix"


@dataclass
class Question:
    """A single survey question.

    ``to_dict`` always emits ``type``, ``name``, ``title`` and ``required``.
    It also emits the type-specific fields: ``choices`` for radio/checkbox,
    ``columns``/``rows`` for matrix, and the rating bounds for rating questions.
    """

    name: str
    title: str
    type: QuestionType
    choices: list[str] = field(default_factory=list)
    columns: list[str] = field(default_factory=list)
    rows: list[str] = field(default_factory=list)
    required: bool = True
    min_rating: int = 1
    max_rating: int = 5

    def to_dict(self) -> dict:
        """Serialize the question to a plain dict."""
        base_dict: dict[str, Any] = {
            "type": self.type.value,
            "name": self.name,
            "title": self.title,
            # Emitted so from_dict / Survey.from_json can round-trip it;
            # otherwise a non-required question comes back as required.
            "required": self.required,
        }

        if self.type in (QuestionType.RADIO, QuestionType.CHECKBOX):
            base_dict["choices"] = self.choices
        elif self.type == QuestionType.MATRIX:
            base_dict["columns"] = self.columns
            base_dict["rows"] = self.rows
        elif self.type == QuestionType.RATING:
            base_dict["min_rating"] = self.min_rating
            base_dict["max_rating"] = self.max_rating

        return base_dict

    @classmethod
    def from_dict(cls, data: dict) -> "Question":
        """Build a Question from a dict; optional fields fall back to the class defaults.

        Raises:
            KeyError: if ``name``, ``title`` or ``type`` is missing.
            ValueError: if ``type`` is not a valid QuestionType value.
        """
        return cls(
            name=data["name"],
            title=data["title"],
            type=QuestionType(data["type"]),
            required=data.get("required", True),
            choices=data.get("choices", []),
            columns=data.get("columns", []),
            rows=data.get("rows", []),
            min_rating=data.get("min_rating", 1),
            max_rating=data.get("max_rating", 5),
        )


@dataclass
class Page:
    """A named group of questions."""

    name: str
    elements: list[Question]

    def to_dict(self) -> dict:
        """Serialize the page and its questions to a plain dict."""
        return {"name": self.name, "elements": [q.to_dict() for q in self.elements]}

    @classmethod
    def from_dict(cls, data: dict) -> "Page":
        """Build a Page from a dict with ``name`` and ``elements`` keys."""
        return cls(
            name=data["name"],
            elements=[Question.from_dict(q) for q in data["elements"]],
        )


@dataclass
class Survey:
    """A survey: metadata, pages of questions, and collected responses."""

    id: uuid.UUID
    title: str
    description: str
    pages: list[Page]
    responses: dict[str, dict] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> dict:
        """Summarize the survey; responses are reported only as a count."""
        return {
            "id": str(self.id),
            "title": self.title,
            "description": self.description,
            "pages": [p.to_dict() for p in self.pages],
            "response_count": len(self.responses),
        }

    def to_json(self) -> str:
        """Convert the survey to a JSON string for MQTT transmission.

        Unlike ``to_dict`` this includes the full ``responses`` and
        ``created_at`` (ISO 8601), so ``from_json`` can rebuild the survey.
        """
        survey_dict = {
            "id": str(self.id),
            "title": self.title,
            "description": self.description,
            "pages": [p.to_dict() for p in self.pages],
            "responses": self.responses,
            "created_at": self.created_at.isoformat(),
        }
        return json.dumps(survey_dict)

    @classmethod
    def from_json(cls, json_str: str) -> "Survey":
        """Create a Survey instance from a JSON string produced by ``to_json``.

        ``responses`` defaults to empty. ``created_at`` defaults to the current
        time when absent.

        Raises:
            KeyError: if ``id``, ``title``, ``description`` or ``pages`` is missing.
        """
        data = json.loads(json_str)
        created_at = data.get("created_at")
        return cls(
            id=uuid.UUID(data["id"]),
            title=data["title"],
            description=data["description"],
            pages=[Page.from_dict(p) for p in data["pages"]],
            responses=data.get("responses", {}),
            created_at=(
                datetime.fromisoformat(created_at)
                if created_at is not None
                else datetime.now()
            ),
        )
@@ -0,0 +1,11 @@
1
+ from .avro_schema import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA,
2
+ PROFILE_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA)
3
+ from .pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
4
+ from .survey_util import process_survey_for_llm
5
+
6
__all__ = [
    # Avro record schemas
    "PROFILE_SCHEMA",
    "DIALOG_SCHEMA",
    "STATUS_SCHEMA",
    "SURVEY_SCHEMA",
    "INSTITUTION_STATUS_SCHEMA",
    # Survey helpers
    "process_survey_for_llm",
    # PostgreSQL helpers
    "TO_UPDATE_EXP_INFO_KEYS_AND_TYPES",
    "PGSQL_DICT",
]
@@ -0,0 +1,109 @@
1
def _field(name: str, avro_type) -> dict:
    """Return an Avro field spec ``{"name": ..., "type": ...}``."""
    return {"name": name, "type": avro_type}


def _created_at_field() -> dict:
    """Return a fresh ``created_at`` field typed as a millisecond timestamp."""
    return _field("created_at", {"type": "long", "logicalType": "timestamp-millis"})


def _scalar_array_field(name: str) -> dict:
    """Return a fresh field spec for an array of float/int/string/null items."""
    return _field(name, {"type": "array", "items": ["float", "int", "string", "null"]})


def _record(doc: str, name: str, fields: list) -> dict:
    """Return an Avro record schema in the ``com.socialcity`` namespace."""
    return {
        "doc": doc,
        "name": name,
        "namespace": "com.socialcity",
        "type": "record",
        "fields": fields,
    }


PROFILE_SCHEMA = _record(
    "Agent属性",
    "AgentProfile",
    [
        _field("id", "string"),  # uuid as string
        _field("name", "string"),
        _field("gender", "string"),
        _field("age", "float"),
        _field("education", "string"),
        _field("skill", "string"),
        _field("occupation", "string"),
        _field("family_consumption", "string"),
        _field("consumption", "string"),
        _field("personality", "string"),
        _field("income", "string"),
        _field("currency", "float"),
        _field("residence", "string"),
        _field("race", "string"),
        _field("religion", "string"),
        _field("marital_status", "string"),
    ],
)

DIALOG_SCHEMA = _record(
    "Agent对话",
    "AgentDialog",
    [
        _field("id", "string"),  # uuid as string
        _field("day", "int"),
        _field("t", "float"),
        _field("type", "int"),
        _field("speaker", "string"),
        _field("content", "string"),
        _created_at_field(),
    ],
)

STATUS_SCHEMA = _record(
    "Agent状态",
    "AgentStatus",
    [
        _field("id", "string"),  # uuid as string
        _field("day", "int"),
        _field("t", "float"),
        _field("lng", "double"),
        _field("lat", "double"),
        _field("parent_id", "int"),
        _field("action", "string"),
        _field("hungry", "float"),
        _field("tired", "float"),
        _field("safe", "float"),
        _field("social", "float"),
        _created_at_field(),
    ],
)

INSTITUTION_STATUS_SCHEMA = _record(
    "Institution状态",
    "InstitutionStatus",
    [
        _field("id", "string"),  # uuid as string
        _field("day", "int"),
        _field("t", "float"),
        _field("type", "int"),
        _scalar_array_field("nominal_gdp"),
        _scalar_array_field("real_gdp"),
        _scalar_array_field("unemployment"),
        _scalar_array_field("wages"),
        _scalar_array_field("prices"),
        _field("inventory", ["int", "null"]),
        _field("price", ["float", "null"]),
        _field("interest_rate", ["float", "null"]),
        _scalar_array_field("bracket_cutoffs"),
        _scalar_array_field("bracket_rates"),
        _scalar_array_field("employees"),
    ],
)

SURVEY_SCHEMA = _record(
    "Agent问卷",
    "AgentSurvey",
    [
        _field("id", "string"),  # uuid as string
        _field("day", "int"),
        _field("t", "float"),
        _field("survey_id", "string"),
        _field("result", "string"),
        _created_at_field(),
    ],
)
@@ -0,0 +1,99 @@
1
+ import time
2
+ import functools
3
+ import inspect
4
+
5
CALLING_STRING = 'function: `{func_name}` in "{file_path}", line {line_number}, arguments: `{arguments}` start time: `{start_time}` end time: `{end_time}` output: `{output}`'

__all__ = [
    "record_call_aio",
    "record_call",
    "lock_decorator",
]


def _capture_call_site(frame, args, kwargs) -> dict:
    """Snapshot the caller's location and the argument reprs of a wrapped call.

    ``frame`` is the wrapper's own frame; its ``f_back`` is the caller. Only the
    file name and line number are kept, so no frame object outlives the call.
    If no caller frame is available, ``"<unknown>"`` is recorded instead.
    """
    caller = frame.f_back if frame is not None else None
    args_repr = [repr(a) for a in args]
    kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
    return {
        "line_number": caller.f_lineno if caller is not None else "<unknown>",
        "file_path": caller.f_code.co_filename if caller is not None else "<unknown>",
        "arguments": ", ".join(args_repr + kwargs_repr),
    }


def _format_call(func, call_site: dict, start_time, end_time, result) -> str:
    """Render one ``CALLING_STRING`` line for a completed call."""
    return CALLING_STRING.format(
        func_name=getattr(func, "__name__", repr(func)),
        start_time=start_time,
        end_time=end_time,
        output=result,
        **call_site,
    )


def record_call_aio(record_function_calling: bool = True):
    """
    Decorator to log the async function call details if `record_function_calling` is True.

    When enabled, each successful call prints one ``CALLING_STRING`` line with
    the function name, caller file/line, argument reprs, start/end
    ``time.time()`` values and the result. When disabled, the wrapper only
    forwards the call; no frame inspection or argument ``repr`` is done.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Captured before the call so the reprs reflect the inputs, not any
            # mutation the wrapped function performs on them.
            call_site = (
                _capture_call_site(inspect.currentframe(), args, kwargs)
                if record_function_calling
                else None
            )
            start_time = time.time()
            result = await func(*args, **kwargs)
            end_time = time.time()
            if call_site is not None:
                print(_format_call(func, call_site, start_time, end_time, result))
            return result

        return wrapper

    return decorator


def record_call(record_function_calling: bool = True):
    """
    Decorator to log the function call details if `record_function_calling` is True.

    Synchronous counterpart of ``record_call_aio``; the printed line has the
    same format.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Captured before the call so the reprs reflect the inputs, not any
            # mutation the wrapped function performs on them.
            call_site = (
                _capture_call_site(inspect.currentframe(), args, kwargs)
                if record_function_calling
                else None
            )
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            if call_site is not None:
                print(_format_call(func, call_site, start_time, end_time, result))
            return result

        return wrapper

    return decorator


def lock_decorator(func):
    """Serialize calls to an async method through ``self._lock``.

    The wrapper awaits ``self._lock.acquire()`` before calling ``func`` and
    always releases the lock afterwards, even if ``func`` raises.
    ``functools.wraps`` keeps the method's name and docstring intact.
    """

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        lock = self._lock
        await lock.acquire()
        try:
            return await func(self, *args, **kwargs)
        finally:
            lock.release()

    return wrapper
@@ -0,0 +1,13 @@
1
+ """Model response parser module."""
2
+
3
+ from .parser_base import ParserBase
4
+ from .json_parser import JsonDictParser, JsonObjectParser
5
+ from .code_block_parser import CodeBlockParser
6
+
7
+
8
# Public parser classes re-exported by this package.
__all__ = ["ParserBase", "JsonDictParser", "JsonObjectParser", "CodeBlockParser"]
@@ -0,0 +1,37 @@
1
+ import logging
2
+ from typing import Any
3
+
4
+ from .parser_base import ParserBase
5
+
6
+
7
class CodeBlockParser(ParserBase):
    """Parser that pulls the body of a fenced code block out of a response.

    Attributes:
        tag_start (str): Opening-fence template; ``{language_name}`` is filled
            with the instance's ``language_name`` at parse time.
        tag_end (str): Closing fence.
    """

    tag_start = "```{language_name}"
    tag_end = "```"

    def __init__(self, language_name: str) -> None:
        """Create a parser for code blocks whose fence is tagged ``language_name``."""
        super().__init__()
        self.language_name = language_name

    def parse(self, response: str) -> str:
        """Return the text between this parser's opening and closing fences.

        Extraction is delegated to ``_extract_text_within_tags`` (inherited from
        ParserBase); its behavior when a fence is missing is defined there.

        Parameters:
            response (str): The response text expected to contain the code block.

        Returns:
            str: The extracted text.
        """
        opening_fence = self.tag_start.format(language_name=self.language_name)
        return self._extract_text_within_tags(
            response=response,
            tag_start=opening_fence,
            tag_end=self.tag_end,
        )
+ return extract_text