pycityagent 2.0.0a66__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a67__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. pycityagent/agent/agent.py +157 -57
  2. pycityagent/agent/agent_base.py +316 -43
  3. pycityagent/cityagent/bankagent.py +49 -9
  4. pycityagent/cityagent/blocks/__init__.py +1 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +54 -31
  6. pycityagent/cityagent/blocks/dispatcher.py +22 -17
  7. pycityagent/cityagent/blocks/economy_block.py +46 -32
  8. pycityagent/cityagent/blocks/mobility_block.py +130 -100
  9. pycityagent/cityagent/blocks/needs_block.py +101 -44
  10. pycityagent/cityagent/blocks/other_block.py +42 -33
  11. pycityagent/cityagent/blocks/plan_block.py +59 -42
  12. pycityagent/cityagent/blocks/social_block.py +167 -116
  13. pycityagent/cityagent/blocks/utils.py +13 -6
  14. pycityagent/cityagent/firmagent.py +17 -35
  15. pycityagent/cityagent/governmentagent.py +3 -3
  16. pycityagent/cityagent/initial.py +79 -44
  17. pycityagent/cityagent/memory_config.py +108 -88
  18. pycityagent/cityagent/message_intercept.py +0 -4
  19. pycityagent/cityagent/metrics.py +41 -0
  20. pycityagent/cityagent/nbsagent.py +24 -36
  21. pycityagent/cityagent/societyagent.py +7 -3
  22. pycityagent/cli/wrapper.py +2 -2
  23. pycityagent/economy/econ_client.py +407 -81
  24. pycityagent/environment/__init__.py +0 -3
  25. pycityagent/environment/sim/__init__.py +0 -3
  26. pycityagent/environment/sim/aoi_service.py +2 -2
  27. pycityagent/environment/sim/client.py +3 -31
  28. pycityagent/environment/sim/clock_service.py +2 -2
  29. pycityagent/environment/sim/lane_service.py +8 -8
  30. pycityagent/environment/sim/light_service.py +8 -8
  31. pycityagent/environment/sim/pause_service.py +9 -10
  32. pycityagent/environment/sim/person_service.py +20 -20
  33. pycityagent/environment/sim/road_service.py +2 -2
  34. pycityagent/environment/sim/sim_env.py +21 -5
  35. pycityagent/environment/sim/social_service.py +4 -4
  36. pycityagent/environment/simulator.py +249 -27
  37. pycityagent/environment/utils/__init__.py +2 -2
  38. pycityagent/environment/utils/geojson.py +2 -2
  39. pycityagent/environment/utils/grpc.py +4 -4
  40. pycityagent/environment/utils/map_utils.py +2 -2
  41. pycityagent/llm/embeddings.py +147 -28
  42. pycityagent/llm/llm.py +122 -77
  43. pycityagent/llm/llmconfig.py +5 -0
  44. pycityagent/llm/utils.py +4 -0
  45. pycityagent/memory/__init__.py +0 -4
  46. pycityagent/memory/const.py +2 -2
  47. pycityagent/memory/faiss_query.py +140 -61
  48. pycityagent/memory/memory.py +393 -90
  49. pycityagent/memory/memory_base.py +140 -34
  50. pycityagent/memory/profile.py +13 -13
  51. pycityagent/memory/self_define.py +13 -13
  52. pycityagent/memory/state.py +14 -14
  53. pycityagent/message/message_interceptor.py +253 -3
  54. pycityagent/message/messager.py +133 -6
  55. pycityagent/metrics/mlflow_client.py +47 -4
  56. pycityagent/pycityagent-sim +0 -0
  57. pycityagent/pycityagent-ui +0 -0
  58. pycityagent/simulation/__init__.py +3 -2
  59. pycityagent/simulation/agentgroup.py +145 -52
  60. pycityagent/simulation/simulation.py +257 -62
  61. pycityagent/survey/manager.py +45 -3
  62. pycityagent/survey/models.py +42 -2
  63. pycityagent/tools/__init__.py +1 -2
  64. pycityagent/tools/tool.py +93 -69
  65. pycityagent/utils/avro_schema.py +2 -2
  66. pycityagent/utils/parsers/code_block_parser.py +1 -1
  67. pycityagent/utils/parsers/json_parser.py +2 -2
  68. pycityagent/utils/parsers/parser_base.py +2 -2
  69. pycityagent/workflow/block.py +64 -13
  70. pycityagent/workflow/prompt.py +31 -23
  71. pycityagent/workflow/trigger.py +91 -24
  72. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/METADATA +2 -2
  73. pycityagent-2.0.0a67.dist-info/RECORD +97 -0
  74. pycityagent/environment/interact/__init__.py +0 -0
  75. pycityagent/environment/interact/interact.py +0 -198
  76. pycityagent/environment/message/__init__.py +0 -0
  77. pycityagent/environment/sence/__init__.py +0 -0
  78. pycityagent/environment/sence/static.py +0 -416
  79. pycityagent/environment/sidecar/__init__.py +0 -8
  80. pycityagent/environment/sidecar/sidecarv2.py +0 -109
  81. pycityagent/environment/sim/economy_services.py +0 -192
  82. pycityagent/metrics/utils/const.py +0 -0
  83. pycityagent-2.0.0a66.dist-info/RECORD +0 -105
  84. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/LICENSE +0 -0
  85. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/WHEEL +0 -0
  86. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/entry_points.txt +0 -0
  87. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/top_level.txt +0 -0
pycityagent/tools/tool.py CHANGED
@@ -7,20 +7,45 @@ from typing import Any, Optional, Union
7
7
  from mlflow.entities import Metric
8
8
 
9
9
  from ..agent import Agent
10
- from ..environment import (LEVEL_ONE_PRE, POI_TYPE_DICT, AoiService,
11
- PersonService)
10
+ from ..environment import AoiService, PersonService
12
11
  from ..utils.decorators import lock_decorator
13
12
  from ..workflow import Block
14
13
 
14
+ __all__ = [
15
+ "Tool",
16
+ "ExportMlflowMetrics",
17
+ "GetMap",
18
+ "UpdateWithSimulator",
19
+ "ResetAgentPosition",
20
+ ]
21
+
15
22
 
16
23
  class Tool:
17
24
  """Abstract tool class for callable tools. Can be bound to an `Agent` or `Block` instance.
18
25
 
19
26
  This class serves as a base for creating various tools that can perform different operations.
20
27
  It is intended to be subclassed by specific tool implementations.
28
+
29
+ - **Attributes**:
30
+ - `_instance`: A reference to the instance (`Agent` or `Block`) this tool is bound to.
21
31
  """
22
32
 
23
33
  def __get__(self, instance, owner):
34
+ """
35
+ Descriptor method for binding the tool to an instance.
36
+
37
+ - **Args**:
38
+ - `instance`: The instance that the tool is being accessed through.
39
+ - `owner`: The type of the owner class.
40
+
41
+ - **Returns**:
42
+ - `Tool`: An instance of the tool bound to the given instance.
43
+
44
+ - **Description**:
45
+ - If accessed via the class rather than an instance, returns the descriptor itself.
46
+ - Otherwise, it checks if the tool has already been instantiated for this instance,
47
+ and if not, creates and stores a new tool instance specifically for this instance.
48
+ """
24
49
  if instance is None:
25
50
  return self
26
51
  subclass = type(self)
@@ -36,11 +61,23 @@ class Tool:
36
61
  """Invoke the tool's functionality.
37
62
 
38
63
  This method must be implemented by subclasses to provide specific behavior.
64
+
65
+ - **Raises**:
66
+ - `NotImplementedError`: When called directly on the base class.
39
67
  """
40
68
  raise NotImplementedError
41
69
 
42
70
  @property
43
71
  def agent(self) -> Agent:
72
+ """
73
+ Access the `Agent` this tool is bound to.
74
+
75
+ - **Returns**:
76
+ - `Agent`: The agent instance.
77
+
78
+ - **Raises**:
79
+ - `RuntimeError`: If the tool is not bound to an `Agent`.
80
+ """
44
81
  instance = self._instance # type:ignore
45
82
  if not isinstance(instance, Agent):
46
83
  raise RuntimeError(
@@ -50,6 +87,15 @@ class Tool:
50
87
 
51
88
  @property
52
89
  def block(self) -> Block:
90
+ """
91
+ Access the `Block` this tool is bound to.
92
+
93
+ - **Returns**:
94
+ - `Block`: The block instance.
95
+
96
+ - **Raises**:
97
+ - `RuntimeError`: If the tool is not bound to a `Block`.
98
+ """
53
99
  instance = self._instance # type:ignore
54
100
  if not isinstance(instance, Block):
55
101
  raise RuntimeError(
@@ -71,72 +117,9 @@ class GetMap(Tool):
71
117
  return agent.simulator.map
72
118
 
73
119
 
74
- class SencePOI(Tool):
75
- """Retrieve the Point of Interest (POI) of the current scene.
76
-
77
- This tool computes the POI based on the current `position` stored in memory and returns
78
- points of interest (POIs) within a specified radius. Can be bound only to an `Agent` instance.
79
-
80
- Attributes:
81
- radius (int): The radius within which to search for POIs.
82
- category_prefix (str): The prefix for the categories of POIs to consider.
83
- variables (list[str]): A list of variables relevant to the tool's operation.
84
-
85
- Args:
86
- radius (int, optional): The circular search radius. Defaults to 100.
87
- category_prefix (str, optional): The category prefix to filter POIs. Defaults to LEVEL_ONE_PRE.
88
-
89
- Methods:
90
- __call__(radius: Optional[int] = None, category_prefix: Optional[str] = None) -> Union[Any, Callable]:
91
- Executes the AOI retrieval operation, returning POIs based on the current state of memory and simulator.
92
- """
93
-
94
- def __init__(self, radius: int = 100, category_prefix=LEVEL_ONE_PRE) -> None:
95
- self.radius = radius
96
- self.category_prefix = category_prefix
97
- self.variables = ["position"]
98
-
99
- async def __call__(
100
- self, radius: Optional[int] = None, category_prefix: Optional[str] = None
101
- ) -> Union[Any, Callable]:
102
- """Retrieve the POIs within the specified radius and category prefix.
103
-
104
- If both `radius` and `category_prefix` are None, the method will use the current position
105
- from memory to query POIs using the simulator. Otherwise, it will return a new instance
106
- of SenceAoi with the specified parameters.
107
-
108
- Args:
109
- radius (Optional[int]): A specific radius for the AOI query. If not provided, defaults to the instance's radius.
110
- category_prefix (Optional[str]): A specific category prefix to filter POIs. If not provided, defaults to the instance's category_prefix.
111
-
112
- Raises:
113
- ValueError: If memory or simulator is not set.
114
-
115
- Returns:
116
- Union[Any, Callable]: The query results or a callable for a new SenceAoi instance.
117
- """
118
- agent = self.agent
119
- if agent.memory is None or agent.simulator is None:
120
- raise ValueError("Memory or Simulator is not set.")
121
- if radius is None and category_prefix is None:
122
- position = await agent.status.get("position")
123
- resp = []
124
- for prefix in self.category_prefix:
125
- resp += agent.simulator.map.query_pois(
126
- center=(position["xy_position"]["x"], position["xy_position"]["y"]),
127
- radius=self.radius,
128
- category_prefix=prefix,
129
- )
130
- # * Map six-digit codes to specific types
131
- for poi in resp:
132
- cate_str = poi[0]["category"]
133
- poi[0]["category"] = POI_TYPE_DICT[cate_str]
134
- else:
135
- radius_ = radius if radius else self.radius
136
- return SencePOI(radius_, category_prefix)
137
-
138
-
139
120
  class UpdateWithSimulator(Tool):
121
+ """Automatically update status memory from simulator"""
122
+
140
123
  def __init__(self) -> None:
141
124
  self._lock = asyncio.Lock()
142
125
 
@@ -180,6 +163,18 @@ class ResetAgentPosition(Tool):
180
163
  lane_id: Optional[int] = None,
181
164
  s: Optional[float] = None,
182
165
  ):
166
+ """
167
+ Reset the position of the agent associated with this tool.
168
+
169
+ - **Args**:
170
+ - `aoi_id` (Optional[int], optional): Area of interest ID. Defaults to None.
171
+ - `poi_id` (Optional[int], optional): Point of interest ID. Defaults to None.
172
+ - `lane_id` (Optional[int], optional): Lane ID. Defaults to None.
173
+ - `s` (Optional[float], optional): Position along the lane. Defaults to None.
174
+
175
+ - **Description**:
176
+ - Resets the agent's position based on the provided parameters using the simulator.
177
+ """
183
178
  agent = self.agent
184
179
  status = agent.status
185
180
  await agent.simulator.reset_person_position(
@@ -192,7 +187,22 @@ class ResetAgentPosition(Tool):
192
187
 
193
188
 
194
189
  class ExportMlflowMetrics(Tool):
190
+ """
191
+ A tool for exporting metrics to MLflow in batches.
192
+
193
+ - **Attributes**:
194
+ - `_log_batch_size` (int): The number of metrics to log in each batch.
195
+ - `metric_log_cache` (Dict[str, List[Metric]]): Cache for storing metrics before batching.
196
+ - `_lock` (asyncio.Lock): Ensures thread-safe operations when logging metrics.
197
+ """
198
+
195
199
  def __init__(self, log_batch_size: int = 100) -> None:
200
+ """
201
+ Initialize the ExportMlflowMetrics tool with a specified batch size and an asynchronous lock.
202
+
203
+ - **Args**:
204
+ - `log_batch_size` (int, optional): Number of metrics per batch. Defaults to 100.
205
+ """
196
206
  self._log_batch_size = log_batch_size
197
207
  # TODO: support other log types
198
208
  self.metric_log_cache: dict[str, list[Metric]] = defaultdict(list)
@@ -204,6 +214,17 @@ class ExportMlflowMetrics(Tool):
204
214
  metric: Union[Sequence[Union[Metric, dict]], Union[Metric, dict]],
205
215
  clear_cache: bool = False,
206
216
  ):
217
+ """
218
+ Add metrics to the cache and export them to MLflow in batches if the batch size limit is reached.
219
+
220
+ - **Args**:
221
+ - `metric` (Union[Sequence[Union[Metric, dict]], Union[Metric, dict]]): A single metric or a sequence of metrics.
222
+ - `clear_cache` (bool, optional): Flag indicating whether to clear the cache after logging. Defaults to False.
223
+
224
+ - **Description**:
225
+ - Adds metrics to the cache. If the cache exceeds the batch size, logs a batch of metrics to MLflow.
226
+ - Optionally clears the entire cache.
227
+ """
207
228
  agent = self.agent
208
229
  batch_size = self._log_batch_size
209
230
  if not isinstance(metric, Sequence):
@@ -223,7 +244,7 @@ class ExportMlflowMetrics(Tool):
223
244
  self.metric_log_cache[metric_key].append(item)
224
245
  for metric_key, _cache in self.metric_log_cache.items():
225
246
  if len(_cache) > batch_size:
226
- client = agent.mlflow_client
247
+ client = agent.mlflow_client # type:ignore
227
248
  await client.log_batch(
228
249
  metrics=_cache[:batch_size],
229
250
  )
@@ -234,8 +255,11 @@ class ExportMlflowMetrics(Tool):
234
255
  async def _clear_cache(
235
256
  self,
236
257
  ):
258
+ """
259
+ Log any remaining metrics from the cache to MLflow and then clear the cache.
260
+ """
237
261
  agent = self.agent
238
- client = agent.mlflow_client
262
+ client = agent.mlflow_client # type:ignore
239
263
  for metric_key, _cache in self.metric_log_cache.items():
240
264
  if len(_cache) > 0:
241
265
  await client.log_batch(
@@ -12,9 +12,9 @@ PROFILE_SCHEMA = {
12
12
  {"name": "skill", "type": "string"},
13
13
  {"name": "occupation", "type": "string"},
14
14
  {"name": "family_consumption", "type": "string"},
15
- {"name": "consumption", "type": "string"},
15
+ {"name": "consumption", "type": "float"},
16
16
  {"name": "personality", "type": "string"},
17
- {"name": "income", "type": "string"},
17
+ {"name": "income", "type": "float"},
18
18
  {"name": "currency", "type": "float"},
19
19
  {"name": "residence", "type": "string"},
20
20
  {"name": "race", "type": "string"},
@@ -26,7 +26,7 @@ class CodeBlockParser(ParserBase):
26
26
  Parameters:
27
27
  response (str): The response string containing the specified language object.
28
28
 
29
- Returns:
29
+ - **Returns**:
30
30
  str: The parsed `str` object.
31
31
  """
32
32
  extract_text = self._extract_text_within_tags(
@@ -26,7 +26,7 @@ class JsonObjectParser(ParserBase):
26
26
  Parameters:
27
27
  response (str): The response string containing the JSON object.
28
28
 
29
- Returns:
29
+ - **Returns**:
30
30
  Any: The parsed JSON object.
31
31
  """
32
32
  extract_text = self._extract_text_within_tags(
@@ -73,7 +73,7 @@ class JsonDictParser(JsonObjectParser):
73
73
  Parameters:
74
74
  response (str): The response string containing the JSON object.
75
75
 
76
- Returns:
76
+ - **Returns**:
77
77
  dict: The parsed JSON object as a dictionary.
78
78
  """
79
79
  parsed_json = super().parse(response)
@@ -21,7 +21,7 @@ class ParserBase(ABC):
21
21
  Parameters:
22
22
  response (str): The raw string returned by the model.
23
23
 
24
- Returns:
24
+ - **Returns**:
25
25
  Any: The converted data, the specific type depends on the parsing result.
26
26
  """
27
27
  pass
@@ -37,7 +37,7 @@ class ParserBase(ABC):
37
37
  tag_start (str): The start tag.
38
38
  tag_end (str): The end tag.
39
39
 
40
- Returns:
40
+ - **Returns**:
41
41
  str: The string between the start and end tags.
42
42
  """
43
43
 
@@ -33,11 +33,11 @@ def log_and_check_with_memory(
33
33
 
34
34
  This decorator is specifically designed to be used with the `block` method. A 'Memory' object is required in method input.
35
35
 
36
- Args:
37
- condition (Callable): A condition function that must be satisfied before the decorated function is executed.
36
+ - **Args**:
37
+ - `condition` (Callable): A condition function that must be satisfied before the decorated function is executed.
38
38
  Can be synchronous or asynchronous.
39
- trigger_interval (float): The interval (in seconds) to wait between condition checks.
40
- record_function_calling (bool): Whether to log the function call information.
39
+ - `trigger_interval` (float): The interval (in seconds) to wait between condition checks.
40
+ - `record_function_calling` (bool): Whether to log the function call information.
41
41
  """
42
42
 
43
43
  def decorator(func):
@@ -93,11 +93,11 @@ def log_and_check(
93
93
 
94
94
  This decorator is specifically designed to be used with the `block` method.
95
95
 
96
- Args:
97
- condition (Callable): A condition function that must be satisfied before the decorated function is executed.
96
+ - **Args**:
97
+ - `condition` (Callable): A condition function that must be satisfied before the decorated function is executed.
98
98
  Can be synchronous or asynchronous.
99
- trigger_interval (float): The interval (in seconds) to wait between condition checks.
100
- record_function_calling (bool): Whether to log the function call information.
99
+ - `trigger_interval` (float): The interval (in seconds) to wait between condition checks.
100
+ - `record_function_calling` (bool): Whether to log the function call information.
101
101
  """
102
102
 
103
103
  def decorator(func):
@@ -144,6 +144,15 @@ def trigger_class():
144
144
 
145
145
  # Define a Block, similar to a layer in PyTorch
146
146
  class Block:
147
+ """
148
+ A foundational component similar to a layer in PyTorch, used for building complex systems.
149
+
150
+ - **Attributes**:
151
+ - `configurable_fields` (list[str]): A list of fields that can be configured.
152
+ - `default_values` (dict[str, Any]): Default values for configurable fields.
153
+ - `fields_description` (dict[str, str]): Descriptions for each configurable field.
154
+ """
155
+
147
156
  configurable_fields: list[str] = []
148
157
  default_values: dict[str, Any] = {}
149
158
  fields_description: dict[str, str] = {}
@@ -156,6 +165,17 @@ class Block:
156
165
  simulator: Optional[Simulator] = None,
157
166
  trigger: Optional[EventTrigger] = None,
158
167
  ):
168
+ """
169
+ - **Description**:
170
+ - Initializes a new instance of the Block class with optional LLM, Memory, Simulator, and Trigger components.
171
+
172
+ - **Args**:
173
+ - `name` (str): The name of the block.
174
+ - `llm` (Optional[LLM], optional): An instance of LLM. Defaults to None.
175
+ - `memory` (Optional[Memory], optional): An instance of Memory. Defaults to None.
176
+ - `simulator` (Optional[Simulator], optional): An instance of Simulator. Defaults to None.
177
+ - `trigger` (Optional[EventTrigger], optional): An event trigger that may be associated with this block. Defaults to None.
178
+ """
159
179
  self.name = name
160
180
  self._llm = llm
161
181
  self._memory = memory
@@ -167,13 +187,27 @@ class Block:
167
187
  self.trigger = trigger
168
188
 
169
189
  def export_config(self) -> dict[str, Optional[str]]:
190
+ """
191
+ - **Description**:
192
+ - Exports the configuration of the block as a dictionary.
193
+
194
+ - **Returns**:
195
+ - `Dict[str, Optional[str]]`: A dictionary containing the configuration of the block.
196
+ """
170
197
  return {
171
198
  field: self.default_values.get(field, "default_value")
172
199
  for field in self.configurable_fields
173
200
  }
174
201
 
175
202
  @classmethod
176
- def export_class_config(cls) -> dict[str, str]:
203
+ def export_class_config(cls) -> tuple[dict[str, Any], dict[str, Any]]:
204
+ """
205
+ - **Description**:
206
+ - Exports the default configuration and descriptions for the configurable fields of the class.
207
+
208
+ - **Returns**:
209
+ - `tuple[Dict[str, Any], Dict[str, Any]]`: A tuple containing two dictionaries, one for default values and one for field descriptions.
210
+ """
177
211
  return (
178
212
  {
179
213
  field: cls.default_values.get(field, "default_value")
@@ -182,11 +216,21 @@ class Block:
182
216
  {
183
217
  field: cls.fields_description.get(field, "")
184
218
  for field in cls.configurable_fields
185
- }
219
+ },
186
220
  )
187
221
 
188
222
  @classmethod
189
223
  def import_config(cls, config: dict[str, Union[str, dict]]) -> Block:
224
+ """
225
+ - **Description**:
226
+ - Creates an instance of the Block from a configuration dictionary.
227
+
228
+ - **Args**:
229
+ - `config` (Dict[str, Union[str, dict]]): Configuration dictionary for creating the block.
230
+
231
+ - **Returns**:
232
+ - `Block`: An instance of the Block created from the provided configuration.
233
+ """
190
234
  instance = cls(name=config["name"]) # type: ignore
191
235
  assert isinstance(config["config"], dict)
192
236
  for field, value in config["config"].items():
@@ -202,7 +246,11 @@ class Block:
202
246
 
203
247
  def load_from_config(self, config: dict[str, list[dict]]) -> None:
204
248
  """
205
- 使用配置更新当前Block实例的参数,并递归更新子Block。
249
+ - **Description**:
250
+ - Updates the current Block instance parameters using a configuration dictionary and recursively updates its children.
251
+
252
+ - **Args**:
253
+ - `config` (Dict[str, List[Dict]]): Configuration dictionary for updating the block.
206
254
  """
207
255
  # 更新当前Block的参数
208
256
  for field in self.configurable_fields:
@@ -233,8 +281,11 @@ class Block:
233
281
 
234
282
  async def forward(self):
235
283
  """
236
- Each block performs a specific reasoning task.
237
- To be overridden by specific block implementations.
284
+ - **Description**:
285
+ - Each block performs a specific reasoning task. This method should be overridden by subclasses.
286
+
287
+ - **Raises**:
288
+ - `NotImplementedError`: Subclasses must implement this method.
238
289
  """
239
290
  raise NotImplementedError("Subclasses should implement this method")
240
291
 
@@ -1,5 +1,5 @@
1
- from typing import Optional, Union
2
1
  import re
2
+ from typing import Optional, Union
3
3
 
4
4
 
5
5
  class FormatPrompt:
@@ -7,20 +7,21 @@ class FormatPrompt:
7
7
  A class to handle the formatting of prompts based on a template,
8
8
  with support for system prompts and variable extraction.
9
9
 
10
- Attributes:
11
- template (str): The template string containing placeholders.
12
- system_prompt (Optional[str]): An optional system prompt to add to the dialog.
13
- variables (list[str]): A list of variable names extracted from the template.
14
- formatted_string (str): The formatted string derived from the template and provided variables.
10
+ - **Attributes**:
11
+ - `template` (str): The template string containing placeholders.
12
+ - `system_prompt` (Optional[str]): An optional system prompt to add to the dialog.
13
+ - `variables` (List[str]): A list of variable names extracted from the template.
14
+ - `formatted_string` (str): The formatted string derived from the template and provided variables.
15
15
  """
16
16
 
17
17
  def __init__(self, template: str, system_prompt: Optional[str] = None) -> None:
18
18
  """
19
- Initializes the FormatPrompt with a template and an optional system prompt.
19
+ - **Description**:
20
+ - Initializes the FormatPrompt with a template and an optional system prompt.
20
21
 
21
- Args:
22
- template (str): The string template with variable placeholders.
23
- system_prompt (Optional[str]): An optional system prompt.
22
+ - **Args**:
23
+ - `template` (str): The string template with variable placeholders.
24
+ - `system_prompt` (Optional[str], optional): An optional system prompt. Defaults to None.
24
25
  """
25
26
  self.template = template
26
27
  self.system_prompt = system_prompt # Store the system prompt
@@ -29,22 +30,27 @@ class FormatPrompt:
29
30
 
30
31
  def _extract_variables(self) -> list[str]:
31
32
  """
32
- Extracts variable names from the template string.
33
+ - **Description**:
34
+ - Extracts variable names from the template string using regular expressions.
33
35
 
34
- Returns:
35
- list[str]: A list of variable names found within the template.
36
+ - **Returns**:
37
+ - `List[str]`: A list of variable names found within the template.
36
38
  """
37
39
  return re.findall(r"\{(\w+)\}", self.template)
38
40
 
39
41
  def format(self, **kwargs) -> str:
40
42
  """
41
- Formats the template string using the provided keyword arguments.
43
+ - **Description**:
44
+ - Formats the template string using the provided keyword arguments.
45
+
46
+ - **Args**:
47
+ - `**kwargs`: Variable names and their corresponding values to format the template.
42
48
 
43
- Args:
44
- **kwargs: Variable names and their corresponding values to format the template.
49
+ - **Returns**:
50
+ - `str`: The formatted string.
45
51
 
46
- Returns:
47
- str: The formatted string.
52
+ - **Raises**:
53
+ - `KeyError`: If a placeholder in the template does not have a corresponding key in kwargs.
48
54
  """
49
55
  self.formatted_string = self.template.format(
50
56
  **kwargs
@@ -53,10 +59,11 @@ class FormatPrompt:
53
59
 
54
60
  def to_dialog(self) -> list[dict[str, str]]:
55
61
  """
56
- Converts the formatted prompt and optional system prompt into a dialog format.
62
+ - **Description**:
63
+ - Converts the formatted prompt and optional system prompt into a dialog format suitable for chat systems.
57
64
 
58
- Returns:
59
- list[dict[str, str]]: A list representing the dialog with roles and content.
65
+ - **Returns**:
66
+ - `List[Dict[str, str]]`: A list representing the dialog with roles and content.
60
67
  """
61
68
  dialog = []
62
69
  if self.system_prompt:
@@ -70,8 +77,9 @@ class FormatPrompt:
70
77
 
71
78
  def log(self) -> None:
72
79
  """
73
- Logs the details of the FormatPrompt, including the template,
74
- system prompt, extracted variables, and formatted string.
80
+ - **Description**:
81
+ - Logs the details of the FormatPrompt instance, including the template,
82
+ system prompt, extracted variables, and formatted string.
75
83
  """
76
84
  print(f"FormatPrompt: {self.template}")
77
85
  print(f"System Prompt: {self.system_prompt}") # Log the system prompt