lionagi 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (153) hide show
  1. lionagi/core/action/function_calling.py +13 -6
  2. lionagi/core/action/tool.py +10 -9
  3. lionagi/core/action/tool_manager.py +18 -9
  4. lionagi/core/agent/README.md +1 -1
  5. lionagi/core/agent/base_agent.py +5 -2
  6. lionagi/core/agent/eval/README.md +1 -1
  7. lionagi/core/collections/README.md +1 -1
  8. lionagi/core/collections/_logger.py +16 -6
  9. lionagi/core/collections/abc/README.md +1 -1
  10. lionagi/core/collections/abc/component.py +35 -11
  11. lionagi/core/collections/abc/concepts.py +5 -3
  12. lionagi/core/collections/abc/exceptions.py +3 -1
  13. lionagi/core/collections/flow.py +16 -5
  14. lionagi/core/collections/model.py +34 -8
  15. lionagi/core/collections/pile.py +65 -28
  16. lionagi/core/collections/progression.py +1 -2
  17. lionagi/core/collections/util.py +11 -2
  18. lionagi/core/director/README.md +1 -1
  19. lionagi/core/engine/branch_engine.py +35 -10
  20. lionagi/core/engine/instruction_map_engine.py +14 -5
  21. lionagi/core/engine/sandbox_.py +3 -1
  22. lionagi/core/engine/script_engine.py +6 -2
  23. lionagi/core/executor/base_executor.py +10 -3
  24. lionagi/core/executor/graph_executor.py +12 -4
  25. lionagi/core/executor/neo4j_executor.py +18 -6
  26. lionagi/core/generic/edge.py +7 -2
  27. lionagi/core/generic/graph.py +23 -7
  28. lionagi/core/generic/node.py +14 -5
  29. lionagi/core/generic/tree_node.py +5 -1
  30. lionagi/core/mail/mail_manager.py +3 -1
  31. lionagi/core/mail/package.py +3 -1
  32. lionagi/core/message/action_request.py +9 -2
  33. lionagi/core/message/action_response.py +9 -3
  34. lionagi/core/message/instruction.py +8 -2
  35. lionagi/core/message/util.py +15 -5
  36. lionagi/core/report/base.py +12 -7
  37. lionagi/core/report/form.py +7 -4
  38. lionagi/core/report/report.py +10 -3
  39. lionagi/core/report/util.py +3 -1
  40. lionagi/core/rule/action.py +4 -1
  41. lionagi/core/rule/base.py +17 -6
  42. lionagi/core/rule/rulebook.py +8 -4
  43. lionagi/core/rule/string.py +3 -1
  44. lionagi/core/session/branch.py +15 -4
  45. lionagi/core/session/directive_mixin.py +11 -3
  46. lionagi/core/session/session.py +6 -2
  47. lionagi/core/unit/parallel_unit.py +9 -3
  48. lionagi/core/unit/template/action.py +1 -1
  49. lionagi/core/unit/template/predict.py +3 -1
  50. lionagi/core/unit/template/select.py +5 -3
  51. lionagi/core/unit/unit.py +38 -4
  52. lionagi/core/unit/unit_form.py +13 -15
  53. lionagi/core/unit/unit_mixin.py +45 -27
  54. lionagi/core/unit/util.py +7 -3
  55. lionagi/core/validator/validator.py +28 -15
  56. lionagi/core/work/work_edge.py +7 -3
  57. lionagi/core/work/work_task.py +11 -5
  58. lionagi/core/work/worker.py +20 -5
  59. lionagi/core/work/worker_engine.py +6 -2
  60. lionagi/core/work/worklog.py +3 -1
  61. lionagi/experimental/compressor/llm_compressor.py +20 -5
  62. lionagi/experimental/directive/README.md +1 -1
  63. lionagi/experimental/directive/parser/base_parser.py +41 -14
  64. lionagi/experimental/directive/parser/base_syntax.txt +23 -23
  65. lionagi/experimental/directive/template/base_template.py +14 -6
  66. lionagi/experimental/directive/tokenizer.py +3 -1
  67. lionagi/experimental/evaluator/README.md +1 -1
  68. lionagi/experimental/evaluator/ast_evaluator.py +6 -2
  69. lionagi/experimental/evaluator/base_evaluator.py +27 -16
  70. lionagi/integrations/bridge/autogen_/autogen_.py +7 -3
  71. lionagi/integrations/bridge/langchain_/documents.py +13 -10
  72. lionagi/integrations/bridge/llamaindex_/llama_pack.py +36 -12
  73. lionagi/integrations/bridge/llamaindex_/node_parser.py +8 -3
  74. lionagi/integrations/bridge/llamaindex_/reader.py +3 -1
  75. lionagi/integrations/bridge/llamaindex_/textnode.py +9 -3
  76. lionagi/integrations/bridge/pydantic_/pydantic_bridge.py +7 -1
  77. lionagi/integrations/bridge/transformers_/install_.py +3 -1
  78. lionagi/integrations/chunker/chunk.py +5 -2
  79. lionagi/integrations/loader/load.py +7 -3
  80. lionagi/integrations/loader/load_util.py +35 -16
  81. lionagi/integrations/provider/oai.py +13 -4
  82. lionagi/integrations/provider/openrouter.py +13 -4
  83. lionagi/integrations/provider/services.py +3 -1
  84. lionagi/integrations/provider/transformers.py +5 -3
  85. lionagi/integrations/storage/neo4j.py +23 -7
  86. lionagi/integrations/storage/storage_util.py +23 -7
  87. lionagi/integrations/storage/structure_excel.py +7 -2
  88. lionagi/integrations/storage/to_csv.py +8 -2
  89. lionagi/integrations/storage/to_excel.py +11 -3
  90. lionagi/libs/ln_api.py +41 -19
  91. lionagi/libs/ln_context.py +4 -4
  92. lionagi/libs/ln_convert.py +35 -14
  93. lionagi/libs/ln_dataframe.py +9 -3
  94. lionagi/libs/ln_func_call.py +53 -18
  95. lionagi/libs/ln_image.py +9 -5
  96. lionagi/libs/ln_knowledge_graph.py +21 -7
  97. lionagi/libs/ln_nested.py +57 -16
  98. lionagi/libs/ln_parse.py +45 -15
  99. lionagi/libs/ln_queue.py +8 -3
  100. lionagi/libs/ln_tokenize.py +19 -6
  101. lionagi/libs/ln_validate.py +14 -3
  102. lionagi/libs/sys_util.py +44 -12
  103. lionagi/lions/coder/coder.py +24 -8
  104. lionagi/lions/coder/util.py +6 -2
  105. lionagi/lions/researcher/data_source/google_.py +12 -4
  106. lionagi/lions/researcher/data_source/wiki_.py +3 -1
  107. lionagi/version.py +1 -1
  108. {lionagi-0.2.11.dist-info → lionagi-0.3.1.dist-info}/METADATA +6 -7
  109. lionagi-0.3.1.dist-info/RECORD +226 -0
  110. lionagi/tests/__init__.py +0 -0
  111. lionagi/tests/api/__init__.py +0 -0
  112. lionagi/tests/api/aws/__init__.py +0 -0
  113. lionagi/tests/api/aws/conftest.py +0 -25
  114. lionagi/tests/api/aws/test_aws_s3.py +0 -6
  115. lionagi/tests/integrations/__init__.py +0 -0
  116. lionagi/tests/libs/__init__.py +0 -0
  117. lionagi/tests/libs/test_api.py +0 -48
  118. lionagi/tests/libs/test_convert.py +0 -89
  119. lionagi/tests/libs/test_field_validators.py +0 -354
  120. lionagi/tests/libs/test_func_call.py +0 -701
  121. lionagi/tests/libs/test_nested.py +0 -382
  122. lionagi/tests/libs/test_parse.py +0 -171
  123. lionagi/tests/libs/test_queue.py +0 -68
  124. lionagi/tests/libs/test_sys_util.py +0 -222
  125. lionagi/tests/test_core/__init__.py +0 -0
  126. lionagi/tests/test_core/collections/__init__.py +0 -0
  127. lionagi/tests/test_core/collections/test_component.py +0 -208
  128. lionagi/tests/test_core/collections/test_exchange.py +0 -139
  129. lionagi/tests/test_core/collections/test_flow.py +0 -146
  130. lionagi/tests/test_core/collections/test_pile.py +0 -172
  131. lionagi/tests/test_core/collections/test_progression.py +0 -130
  132. lionagi/tests/test_core/generic/__init__.py +0 -0
  133. lionagi/tests/test_core/generic/test_edge.py +0 -69
  134. lionagi/tests/test_core/generic/test_graph.py +0 -97
  135. lionagi/tests/test_core/generic/test_node.py +0 -107
  136. lionagi/tests/test_core/generic/test_structure.py +0 -194
  137. lionagi/tests/test_core/generic/test_tree_node.py +0 -74
  138. lionagi/tests/test_core/graph/__init__.py +0 -0
  139. lionagi/tests/test_core/graph/test_graph.py +0 -71
  140. lionagi/tests/test_core/graph/test_tree.py +0 -76
  141. lionagi/tests/test_core/mail/__init__.py +0 -0
  142. lionagi/tests/test_core/mail/test_mail.py +0 -98
  143. lionagi/tests/test_core/test_branch.py +0 -116
  144. lionagi/tests/test_core/test_form.py +0 -47
  145. lionagi/tests/test_core/test_report.py +0 -106
  146. lionagi/tests/test_core/test_structure/__init__.py +0 -0
  147. lionagi/tests/test_core/test_structure/test_base_structure.py +0 -198
  148. lionagi/tests/test_core/test_structure/test_graph.py +0 -55
  149. lionagi/tests/test_core/test_structure/test_tree.py +0 -49
  150. lionagi/tests/test_core/test_validator.py +0 -112
  151. lionagi-0.2.11.dist-info/RECORD +0 -267
  152. {lionagi-0.2.11.dist-info → lionagi-0.3.1.dist-info}/LICENSE +0 -0
  153. {lionagi-0.2.11.dist-info → lionagi-0.3.1.dist-info}/WHEEL +0 -0
@@ -24,7 +24,11 @@ class Graph(Node):
24
24
  def internal_edges(self) -> Pile[Edge]:
25
25
  """Return a pile of all edges in the graph."""
26
26
  return pile(
27
- {edge.ln_id: edge for node in self.internal_nodes for edge in node.edges},
27
+ {
28
+ edge.ln_id: edge
29
+ for node in self.internal_nodes
30
+ for edge in node.edges
31
+ },
28
32
  Edge,
29
33
  )
30
34
 
@@ -70,7 +74,9 @@ class Graph(Node):
70
74
  edge = edge if isinstance(edge, list) else [edge]
71
75
  for i in edge:
72
76
  if i not in self.internal_edges:
73
- raise ItemNotFoundError(f"Edge {i} does not exist in structure.")
77
+ raise ItemNotFoundError(
78
+ f"Edge {i} does not exist in structure."
79
+ )
74
80
  with contextlib.suppress(ItemNotFoundError):
75
81
  self._remove_edge(i)
76
82
 
@@ -105,7 +111,8 @@ class Graph(Node):
105
111
  [
106
112
  edge
107
113
  for edge in edges
108
- if edge.label in to_list(label, dropna=True, flatten=True)
114
+ if edge.label
115
+ in to_list(label, dropna=True, flatten=True)
109
116
  ]
110
117
  )
111
118
  if edges
@@ -124,7 +131,9 @@ class Graph(Node):
124
131
  def _remove_edge(self, edge: Edge | str) -> bool:
125
132
  """Remove a specific edge from the graph."""
126
133
  if edge not in self.internal_edges:
127
- raise ItemNotFoundError(f"Edge {edge} does not exist in structure.")
134
+ raise ItemNotFoundError(
135
+ f"Edge {edge} does not exist in structure."
136
+ )
128
137
 
129
138
  edge = self.internal_edges[edge]
130
139
  head: Node = self.internal_nodes[edge.head]
@@ -139,7 +148,8 @@ class Graph(Node):
139
148
  [
140
149
  node
141
150
  for node in self.internal_nodes
142
- if node.relations["in"].is_empty() and not isinstance(node, Actionable)
151
+ if node.relations["in"].is_empty()
152
+ and not isinstance(node, Actionable)
143
153
  ]
144
154
  )
145
155
 
@@ -147,7 +157,9 @@ class Graph(Node):
147
157
  """Check if the graph is acyclic (contains no cycles)."""
148
158
  node_ids = list(self.internal_nodes.keys())
149
159
  check_deque = deque(node_ids)
150
- check_dict = {key: 0 for key in node_ids} # 0: not visited, 1: temp, 2: perm
160
+ check_dict = {
161
+ key: 0 for key in node_ids
162
+ } # 0: not visited, 1: temp, 2: perm
151
163
 
152
164
  def visit(key):
153
165
  if check_dict[key] == 2:
@@ -202,7 +214,11 @@ class Graph(Node):
202
214
  return g
203
215
 
204
216
  def display(
205
- self, node_label="class_name", edge_label="label", draw_kwargs={}, **kwargs
217
+ self,
218
+ node_label="class_name",
219
+ edge_label="label",
220
+ draw_kwargs={},
221
+ **kwargs,
206
222
  ):
207
223
  """Display the graph using NetworkX and Matplotlib."""
208
224
  from lionagi.libs import SysUtil
@@ -8,7 +8,7 @@ Includes functionality for managing relationships, such as adding,
8
8
  modifying, and removing edges, and querying related nodes and connections.
9
9
  """
10
10
 
11
- from typing import Callable
11
+ from collections.abc import Callable
12
12
 
13
13
  from pandas import Series
14
14
  from pydantic import Field
@@ -62,7 +62,11 @@ class Node(Component, Relatable):
62
62
  List of node IDs related to this node.
63
63
  """
64
64
  all_nodes = set(
65
- to_list([[i.head, i.tail] for i in self.edges], flatten=True, dropna=True)
65
+ to_list(
66
+ [[i.head, i.tail] for i in self.edges],
67
+ flatten=True,
68
+ dropna=True,
69
+ )
66
70
  )
67
71
  all_nodes.discard(self.ln_id)
68
72
  return list(all_nodes)
@@ -101,7 +105,9 @@ class Node(Component, Relatable):
101
105
  List of node IDs that precede this node.
102
106
  """
103
107
  return [
104
- node_id for node_id, edges in self.node_relations["in"].items() if edges
108
+ node_id
109
+ for node_id, edges in self.node_relations["in"].items()
110
+ if edges
105
111
  ]
106
112
 
107
113
  @property
@@ -113,7 +119,9 @@ class Node(Component, Relatable):
113
119
  List of node IDs that succeed this node.
114
120
  """
115
121
  return [
116
- node_id for node_id, edges in self.node_relations["out"].items() if edges
122
+ node_id
123
+ for node_id, edges in self.node_relations["out"].items()
124
+ if edges
117
125
  ]
118
126
 
119
127
  def relate(
@@ -141,7 +149,8 @@ class Node(Component, Relatable):
141
149
  """
142
150
  if direction not in ["in", "out"]:
143
151
  raise ValueError(
144
- f"Invalid value for direction: {direction}, " "must be 'in' or 'out'"
152
+ f"Invalid value for direction: {direction}, "
153
+ "must be 'in' or 'out'"
145
154
  )
146
155
 
147
156
  edge = edge_class(
@@ -28,7 +28,11 @@ class TreeNode(Node):
28
28
  if not self.parent:
29
29
  return list(self.related_nodes)
30
30
  else:
31
- return [node for node in self.related_nodes if node != self.parent.ln_id]
31
+ return [
32
+ node
33
+ for node in self.related_nodes
34
+ if node != self.parent.ln_id
35
+ ]
32
36
 
33
37
  def relate_child(
34
38
  self,
@@ -126,7 +126,9 @@ class MailManager(Element, Executable):
126
126
  mail_id = mailbox.pending_outs.popleft()
127
127
  mail = mailbox.pile.pop(mail_id)
128
128
  if mail.recipient not in self.sources:
129
- raise ValueError(f"Recipient source {mail.recipient} does not exist")
129
+ raise ValueError(
130
+ f"Recipient source {mail.recipient} does not exist"
131
+ )
130
132
  if mail.sender not in self.mails[mail.recipient]:
131
133
  self.mails[mail.recipient].update({mail.sender: deque()})
132
134
  self.mails[mail.recipient][mail.sender].append(mail)
@@ -44,4 +44,6 @@ class Package(Element):
44
44
  try:
45
45
  return PackageCategory(value)
46
46
  except Exception as e:
47
- raise ValueError(f"Invalid value for category: {value}.") from e
47
+ raise ValueError(
48
+ f"Invalid value for category: {value}."
49
+ ) from e
@@ -48,14 +48,21 @@ class ActionRequest(RoledMessage):
48
48
  sender (str, optional): The sender of the request.
49
49
  recipient (str, optional): The recipient of the request.
50
50
  """
51
- function = function.__name__ if inspect.isfunction(function) else function
51
+ function = (
52
+ function.__name__ if inspect.isfunction(function) else function
53
+ )
52
54
  arguments = _prepare_arguments(arguments)
53
55
 
54
56
  super().__init__(
55
57
  role=MessageRole.ASSISTANT,
56
58
  sender=sender,
57
59
  recipient=recipient,
58
- content={"action_request": {"function": function, "arguments": arguments}},
60
+ content={
61
+ "action_request": {
62
+ "function": function,
63
+ "arguments": arguments,
64
+ }
65
+ },
59
66
  **kwargs,
60
67
  )
61
68
  self.function = function
@@ -25,8 +25,12 @@ class ActionResponse(RoledMessage):
25
25
  description="The id of the action request that this response corresponds to",
26
26
  )
27
27
 
28
- function: str | None = Field(None, description="The name of the function called")
29
- arguments: dict | None = Field(None, description="The keyword arguments provided")
28
+ function: str | None = Field(
29
+ None, description="The name of the function called"
30
+ )
31
+ arguments: dict | None = Field(
32
+ None, description="The keyword arguments provided"
33
+ )
30
34
  func_outputs: Any | None = Field(
31
35
  None, description="The output of the function call"
32
36
  )
@@ -108,7 +112,9 @@ class ActionResponse(RoledMessage):
108
112
  action_request = ActionRequest(
109
113
  function=self.function, arguments=json.loads(arguments)
110
114
  )
111
- action_response_copy = ActionResponse(action_request=action_request, **kwargs)
115
+ action_response_copy = ActionResponse(
116
+ action_request=action_request, **kwargs
117
+ )
112
118
  action_response_copy.action_request = self.action_request
113
119
  action_response_copy.func_outputs = self.func_outputs
114
120
  action_response_copy.metadata["origin_ln_id"] = self.ln_id
@@ -109,7 +109,11 @@ class Instruction(RoledMessage):
109
109
  self, context, requested_fields, images, image_detail, **kwargs
110
110
  ):
111
111
  if context:
112
- context = {"context": context} if not isinstance(context, dict) else context
112
+ context = (
113
+ {"context": context}
114
+ if not isinstance(context, dict)
115
+ else context
116
+ )
113
117
  if (
114
118
  additional_context := {
115
119
  k: v for k, v in kwargs.items() if k not in SYSTEM_FIELDS
@@ -124,7 +128,9 @@ class Instruction(RoledMessage):
124
128
  )
125
129
 
126
130
  if images:
127
- self.content["images"] = images if isinstance(images, list) else [images]
131
+ self.content["images"] = (
132
+ images if isinstance(images, list) else [images]
133
+ )
128
134
  self.content["image_detail"] = image_detail
129
135
 
130
136
  def clone(self, **kwargs):
@@ -81,7 +81,9 @@ def create_message(
81
81
 
82
82
  if function:
83
83
  if not arguments:
84
- raise ValueError("Error: please provide arguments for the function.")
84
+ raise ValueError(
85
+ "Error: please provide arguments for the function."
86
+ )
85
87
  return ActionRequest(
86
88
  function=function,
87
89
  arguments=arguments,
@@ -158,7 +160,9 @@ def _parse_action_request(response):
158
160
  content_ = message["content"]["tool_uses"]
159
161
 
160
162
  else:
161
- json_block_pattern = re.compile(r"```json\n({.*?tool_uses.*?})\n```", re.DOTALL)
163
+ json_block_pattern = re.compile(
164
+ r"```json\n({.*?tool_uses.*?})\n```", re.DOTALL
165
+ )
162
166
 
163
167
  # Find the JSON block in the text
164
168
  match = json_block_pattern.search(str(message["content"]))
@@ -179,7 +183,9 @@ def _parse_action_request(response):
179
183
  outs = []
180
184
  for func_calling in content_:
181
185
  if "recipient_name" in func_calling:
182
- func_calling["action"] = func_calling["recipient_name"].split(".")[1]
186
+ func_calling["action"] = func_calling["recipient_name"].split(
187
+ "."
188
+ )[1]
183
189
  func_calling["arguments"] = func_calling["parameters"]
184
190
  elif "function" in func_calling:
185
191
  func_calling["action"] = func_calling["function"]
@@ -212,9 +218,13 @@ def _parse_action_request(response):
212
218
  if "function" in func_calling:
213
219
  func_calling["action"] = func_calling["function"]
214
220
  if "parameters" in func_calling:
215
- func_calling["arguments"] = func_calling["parameters"]
221
+ func_calling["arguments"] = func_calling[
222
+ "parameters"
223
+ ]
216
224
  elif "arguments" in func_calling:
217
- func_calling["arguments"] = func_calling["arguments"]
225
+ func_calling["arguments"] = func_calling[
226
+ "arguments"
227
+ ]
218
228
  msg = ActionRequest(
219
229
  function=func_calling["action"]
220
230
  .replace("action_", "")
@@ -71,12 +71,12 @@ class BaseForm(Component):
71
71
  examples=["input1, input2 -> output"],
72
72
  )
73
73
 
74
- input_fields: List[str] = Field(
74
+ input_fields: list[str] = Field(
75
75
  default_factory=list,
76
76
  description="Fields required to carry out the objective of the form.",
77
77
  )
78
78
 
79
- requested_fields: List[str] = Field(
79
+ requested_fields: list[str] = Field(
80
80
  default_factory=list,
81
81
  description="Fields requested to be filled by the user.",
82
82
  )
@@ -86,14 +86,14 @@ class BaseForm(Component):
86
86
  description="The work to be done by the form, including custom instructions.",
87
87
  )
88
88
 
89
- validation_kwargs: Dict[str, Dict[str, Any]] = Field(
89
+ validation_kwargs: dict[str, dict[str, Any]] = Field(
90
90
  default_factory=dict,
91
91
  description="Additional validation constraints for the form fields.",
92
92
  examples=[{"field": {"config1": "a", "config2": "b"}}],
93
93
  )
94
94
 
95
95
  @property
96
- def work_fields(self) -> Dict[str, Any]:
96
+ def work_fields(self) -> dict[str, Any]:
97
97
  """
98
98
  Get the fields relevant to the current task, including input and
99
99
  requested fields. Must be implemented by subclasses.
@@ -169,8 +169,8 @@ class BaseForm(Component):
169
169
  return True
170
170
 
171
171
  def _get_all_fields(
172
- self, form: List["BaseForm"] = None, **kwargs
173
- ) -> Dict[str, Any]:
172
+ self, form: list["BaseForm"] = None, **kwargs
173
+ ) -> dict[str, Any]:
174
174
  """
175
175
  Given a form or collections of forms, and additional fields, gather
176
176
  all fields together including self fields with valid value.
@@ -187,7 +187,12 @@ class BaseForm(Component):
187
187
  all_form_fields = (
188
188
  {}
189
189
  if not form
190
- else {k: v for i in form for k, v in i.work_fields.items() if v is not None}
190
+ else {
191
+ k: v
192
+ for i in form
193
+ for k, v in i.work_fields.items()
194
+ if v is not None
195
+ }
191
196
  )
192
197
  all_fields.update({**all_form_fields, **kwargs})
193
198
  return all_fields
@@ -84,7 +84,7 @@ class Form(BaseForm):
84
84
  )
85
85
 
86
86
  @property
87
- def work_fields(self) -> Dict[str, Any]:
87
+ def work_fields(self) -> dict[str, Any]:
88
88
  """
89
89
  Retrieves a dictionary of the fields relevant to the current task,
90
90
  excluding any SYSTEM_FIELDS and including only the input and requested
@@ -97,7 +97,8 @@ class Form(BaseForm):
97
97
  return {
98
98
  k: v
99
99
  for k, v in dict_.items()
100
- if k not in SYSTEM_FIELDS and k in self.input_fields + self.requested_fields
100
+ if k not in SYSTEM_FIELDS
101
+ and k in self.input_fields + self.requested_fields
101
102
  }
102
103
 
103
104
  def fill(self, form: "Form" = None, strict: bool = True, **kwargs) -> None:
@@ -134,7 +135,9 @@ class Form(BaseForm):
134
135
  bool: True if the form is workable, otherwise raises ValueError.
135
136
  """
136
137
  if self.filled:
137
- raise ValueError("Form is already filled, cannot be worked on again")
138
+ raise ValueError(
139
+ "Form is already filled, cannot be worked on again"
140
+ )
138
141
 
139
142
  for i in self.input_fields:
140
143
  if not getattr(self, i, None):
@@ -172,7 +175,7 @@ class Form(BaseForm):
172
175
  """
173
176
 
174
177
  @property
175
- def _instruction_requested_fields(self) -> Dict[str, str]:
178
+ def _instruction_requested_fields(self) -> dict[str, str]:
176
179
  """
177
180
  Provides a dictionary mapping requested field names to their
178
181
  descriptions.
@@ -25,7 +25,7 @@ class Report(BaseForm):
25
25
  examples=[["a, b -> c", "a -> e", "b -> f", "c -> g", "e, f, g -> h"]],
26
26
  )
27
27
 
28
- form_template: Type[Form] = Field(
28
+ form_template: type[Form] = Field(
29
29
  Form, description="The template for the forms in the report."
30
30
  )
31
31
 
@@ -72,7 +72,12 @@ class Report(BaseForm):
72
72
  all_fields[k] = v
73
73
  return all_fields
74
74
 
75
- def fill(self, form: Form | list[Form] | dict[Form] = None, strict=True, **kwargs):
75
+ def fill(
76
+ self,
77
+ form: Form | list[Form] | dict[Form] = None,
78
+ strict=True,
79
+ **kwargs,
80
+ ):
76
81
  if self.filled:
77
82
  if strict:
78
83
  raise ValueError("Form is filled, cannot be worked on again")
@@ -105,7 +110,9 @@ class Report(BaseForm):
105
110
  bool: True if the report is workable, otherwise raises ValueError.
106
111
  """
107
112
  if self.filled:
108
- raise ValueError("Form is already filled, cannot be worked on again")
113
+ raise ValueError(
114
+ "Form is already filled, cannot be worked on again"
115
+ )
109
116
 
110
117
  for i in self.input_fields:
111
118
  if not getattr(self, i, None):
@@ -3,7 +3,9 @@ def get_input_output_fields(str_: str) -> list[list[str]]:
3
3
  return [], []
4
4
 
5
5
  if "->" not in str_:
6
- raise ValueError("Invalid assignment format. Expected 'inputs -> outputs'.")
6
+ raise ValueError(
7
+ "Invalid assignment format. Expected 'inputs -> outputs'."
8
+ )
7
9
 
8
10
  inputs, outputs = str_.split("->")
9
11
 
@@ -49,7 +49,10 @@ class ActionRequestRule(MappingRule):
49
49
  Raises:
50
50
  ActionError: If the action request is invalid.
51
51
  """
52
- if isinstance(value, dict) and list(value.keys()) >= ["function", "arguments"]:
52
+ if isinstance(value, dict) and list(value.keys()) >= [
53
+ "function",
54
+ "arguments",
55
+ ]:
53
56
  return value
54
57
  raise ActionError(f"Invalid action request: {value}")
55
58
 
lionagi/core/rule/base.py CHANGED
@@ -3,7 +3,12 @@ from typing import Any, Dict, List
3
3
 
4
4
  from pandas import Series
5
5
 
6
- from lionagi.core.collections.abc import Actionable, Component, Condition, FieldError
6
+ from lionagi.core.collections.abc import (
7
+ Actionable,
8
+ Component,
9
+ Condition,
10
+ FieldError,
11
+ )
7
12
  from lionagi.libs import SysUtil
8
13
 
9
14
  _rule_classes = {}
@@ -37,7 +42,9 @@ class Rule(Component, Condition, Actionable):
37
42
  if cls.__name__ not in _rule_classes:
38
43
  _rule_classes[cls.__name__] = cls
39
44
 
40
- def add_log(self, field: str, form: Any, apply: bool = True, **kwargs) -> None:
45
+ def add_log(
46
+ self, field: str, form: Any, apply: bool = True, **kwargs
47
+ ) -> None:
41
48
  """
42
49
  Adds an entry to the applied or invoked log.
43
50
 
@@ -67,7 +74,7 @@ class Rule(Component, Condition, Actionable):
67
74
  value: Any,
68
75
  form: Any,
69
76
  *args,
70
- annotation: List[str] = None,
77
+ annotation: list[str] = None,
71
78
  use_annotation: bool = True,
72
79
  **kwargs,
73
80
  ) -> bool:
@@ -93,7 +100,9 @@ class Rule(Component, Condition, Actionable):
93
100
 
94
101
  if use_annotation:
95
102
  annotation = annotation or form._get_field_annotation(field)
96
- annotation = [annotation] if isinstance(annotation, str) else annotation
103
+ annotation = (
104
+ [annotation] if isinstance(annotation, str) else annotation
105
+ )
97
106
 
98
107
  for i in annotation:
99
108
  if i in self.apply_type and i not in self.exclude_type:
@@ -132,7 +141,9 @@ class Rule(Component, Condition, Actionable):
132
141
  if self.fix:
133
142
  try:
134
143
  a = await self.perform_fix(value, **self.validation_kwargs)
135
- self.add_log(field, form, apply=False, **self.validation_kwargs)
144
+ self.add_log(
145
+ field, form, apply=False, **self.validation_kwargs
146
+ )
136
147
  return a
137
148
  except Exception as e2:
138
149
  raise FieldError(f"failed to fix field") from e2
@@ -184,7 +195,7 @@ class Rule(Component, Condition, Actionable):
184
195
  """
185
196
  pass
186
197
 
187
- def _to_dict(self) -> Dict[str, Any]:
198
+ def _to_dict(self) -> dict[str, Any]:
188
199
  """
189
200
  Converts the rule's attributes to a dictionary.
190
201
 
@@ -3,12 +3,12 @@ from lionfuncs import lcall
3
3
  from lionagi.core.rule.base import Rule
4
4
 
5
5
  """
6
- rule config schema
6
+ rule config schema
7
7
 
8
8
  {
9
9
  rule_name: {
10
10
  "fields: [],
11
- "config": {},
11
+ "config": {},
12
12
  ...
13
13
  }
14
14
  }
@@ -30,12 +30,16 @@ class RuleBook:
30
30
  @property
31
31
  def _all_applied_log(self):
32
32
  """return all applied logs from all rules in the rulebook"""
33
- return lcall(self.rules.values(), lambda x: x.applied_log, flatten=True)
33
+ return lcall(
34
+ self.rules.values(), lambda x: x.applied_log, flatten=True
35
+ )
34
36
 
35
37
  @property
36
38
  def _all_invoked_log(self):
37
39
  """return all invoked logs from all rules in the rulebook"""
38
- return lcall(self.rules.values(), lambda x: x.invoked_log, flatten=True)
40
+ return lcall(
41
+ self.rules.values(), lambda x: x.invoked_log, flatten=True
42
+ )
39
43
 
40
44
  def __getitem__(self, key: str) -> Rule:
41
45
  return self.rules[key]
@@ -38,4 +38,6 @@ class StringRule(Rule):
38
38
  try:
39
39
  return to_str(value, **self.validation_kwargs)
40
40
  except Exception as e:
41
- raise ValueError(f"Failed to convert {value} into a string value") from e
41
+ raise ValueError(
42
+ f"Failed to convert {value} into a string value"
43
+ ) from e
@@ -165,7 +165,9 @@ class Branch(Node, DirectiveMixin):
165
165
  )
166
166
 
167
167
  if isinstance(_msg, System):
168
- _msg.recipient = self.ln_id # the branch itself, system is to the branch
168
+ _msg.recipient = (
169
+ self.ln_id
170
+ ) # the branch itself, system is to the branch
169
171
  self._remove_system()
170
172
  self.system = _msg
171
173
 
@@ -255,8 +257,13 @@ class Branch(Node, DirectiveMixin):
255
257
  return True
256
258
  elif is_same_dtype(tools, Tool):
257
259
  for act_ in tools:
258
- if act_.schema_["function"]["name"] in self.tool_manager.registry:
259
- self.tool_manager.registry.pop(act_.schema_["function"]["name"])
260
+ if (
261
+ act_.schema_["function"]["name"]
262
+ in self.tool_manager.registry
263
+ ):
264
+ self.tool_manager.registry.pop(
265
+ act_.schema_["function"]["name"]
266
+ )
260
267
  if verbose:
261
268
  print("tools successfully deleted")
262
269
  return True
@@ -330,7 +337,11 @@ class Branch(Node, DirectiveMixin):
330
337
  return isinstance(self.messages[-1], ActionResponse)
331
338
 
332
339
  def send(
333
- self, recipient: str, category: str, package: Any, request_source: str = None
340
+ self,
341
+ recipient: str,
342
+ category: str,
343
+ package: Any,
344
+ request_source: str = None,
334
345
  ) -> None:
335
346
  """
336
347
  Sends a mail to a recipient.
@@ -1,4 +1,5 @@
1
1
  # lionagi/core/session/directive_mixin.py
2
+ from pydantic import BaseModel
2
3
 
3
4
  from lionagi.core.unit import Unit
4
5
 
@@ -33,13 +34,14 @@ class DirectiveMixin:
33
34
  default=None,
34
35
  timeout: float = None,
35
36
  timing: bool = False,
36
- return_branch=False,
37
37
  images=None,
38
38
  image_path=None,
39
39
  template=None,
40
40
  verbose=True,
41
41
  formatter=None,
42
42
  format_kwargs=None,
43
+ pydantic_model: type[BaseModel] = None,
44
+ return_pydantic_model: bool = False,
43
45
  **kwargs,
44
46
  ):
45
47
  """
@@ -120,7 +122,7 @@ class DirectiveMixin:
120
122
 
121
123
  images = ImageUtil.read_image_to_base64(image_path)
122
124
 
123
- return await directive.chat(
125
+ output = await directive.chat(
124
126
  instruction=instruction,
125
127
  context=context,
126
128
  sender=sender,
@@ -139,10 +141,16 @@ class DirectiveMixin:
139
141
  timeout=timeout,
140
142
  timing=timing,
141
143
  clear_messages=clear_messages,
142
- return_branch=return_branch,
143
144
  images=images,
145
+ return_pydantic_model=return_pydantic_model,
146
+ pydantic_model=pydantic_model,
147
+ return_branch=False,
144
148
  **kwargs,
145
149
  )
150
+ while isinstance(output, tuple | list):
151
+ if len(output) == 2 and output[0] == output[1]:
152
+ output = output[0]
153
+ return output
146
154
 
147
155
  async def direct(
148
156
  self,
@@ -45,7 +45,9 @@ class Session:
45
45
  ):
46
46
  self.ln_id = SysUtil.create_id()
47
47
  self.timestamp = SysUtil.get_timestamp(sep=None)[:-6]
48
- system = system or "You are a helpful assistant, let's think step by step"
48
+ system = (
49
+ system or "You are a helpful assistant, let's think step by step"
50
+ )
49
51
  self.system = System(system=system, sender=system_sender)
50
52
  self.system_sender = system_sender
51
53
  self.branches: Pile[Branch] = self._validate_branches(branches)
@@ -77,7 +79,9 @@ class Session:
77
79
  if isinstance(value, Pile):
78
80
  for branch in value:
79
81
  if not isinstance(branch, Branch):
80
- raise ValueError("The branches pile contains non-Branch object")
82
+ raise ValueError(
83
+ "The branches pile contains non-Branch object"
84
+ )
81
85
  return value
82
86
  else:
83
87
  try: