lionagi 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (153)
  1. lionagi/core/action/function_calling.py +13 -6
  2. lionagi/core/action/tool.py +10 -9
  3. lionagi/core/action/tool_manager.py +18 -9
  4. lionagi/core/agent/README.md +1 -1
  5. lionagi/core/agent/base_agent.py +5 -2
  6. lionagi/core/agent/eval/README.md +1 -1
  7. lionagi/core/collections/README.md +1 -1
  8. lionagi/core/collections/_logger.py +16 -6
  9. lionagi/core/collections/abc/README.md +1 -1
  10. lionagi/core/collections/abc/component.py +35 -11
  11. lionagi/core/collections/abc/concepts.py +5 -3
  12. lionagi/core/collections/abc/exceptions.py +3 -1
  13. lionagi/core/collections/flow.py +16 -5
  14. lionagi/core/collections/model.py +34 -8
  15. lionagi/core/collections/pile.py +65 -28
  16. lionagi/core/collections/progression.py +1 -2
  17. lionagi/core/collections/util.py +11 -2
  18. lionagi/core/director/README.md +1 -1
  19. lionagi/core/engine/branch_engine.py +35 -10
  20. lionagi/core/engine/instruction_map_engine.py +14 -5
  21. lionagi/core/engine/sandbox_.py +3 -1
  22. lionagi/core/engine/script_engine.py +6 -2
  23. lionagi/core/executor/base_executor.py +10 -3
  24. lionagi/core/executor/graph_executor.py +12 -4
  25. lionagi/core/executor/neo4j_executor.py +18 -6
  26. lionagi/core/generic/edge.py +7 -2
  27. lionagi/core/generic/graph.py +23 -7
  28. lionagi/core/generic/node.py +14 -5
  29. lionagi/core/generic/tree_node.py +5 -1
  30. lionagi/core/mail/mail_manager.py +3 -1
  31. lionagi/core/mail/package.py +3 -1
  32. lionagi/core/message/action_request.py +9 -2
  33. lionagi/core/message/action_response.py +9 -3
  34. lionagi/core/message/instruction.py +8 -2
  35. lionagi/core/message/util.py +15 -5
  36. lionagi/core/report/base.py +12 -7
  37. lionagi/core/report/form.py +7 -4
  38. lionagi/core/report/report.py +10 -3
  39. lionagi/core/report/util.py +3 -1
  40. lionagi/core/rule/action.py +4 -1
  41. lionagi/core/rule/base.py +17 -6
  42. lionagi/core/rule/rulebook.py +8 -4
  43. lionagi/core/rule/string.py +3 -1
  44. lionagi/core/session/branch.py +15 -4
  45. lionagi/core/session/directive_mixin.py +11 -3
  46. lionagi/core/session/session.py +6 -2
  47. lionagi/core/unit/parallel_unit.py +9 -3
  48. lionagi/core/unit/template/action.py +1 -1
  49. lionagi/core/unit/template/predict.py +3 -1
  50. lionagi/core/unit/template/select.py +5 -3
  51. lionagi/core/unit/unit.py +38 -4
  52. lionagi/core/unit/unit_form.py +13 -15
  53. lionagi/core/unit/unit_mixin.py +45 -27
  54. lionagi/core/unit/util.py +7 -3
  55. lionagi/core/validator/validator.py +28 -15
  56. lionagi/core/work/work_edge.py +7 -3
  57. lionagi/core/work/work_task.py +11 -5
  58. lionagi/core/work/worker.py +20 -5
  59. lionagi/core/work/worker_engine.py +6 -2
  60. lionagi/core/work/worklog.py +3 -1
  61. lionagi/experimental/compressor/llm_compressor.py +20 -5
  62. lionagi/experimental/directive/README.md +1 -1
  63. lionagi/experimental/directive/parser/base_parser.py +41 -14
  64. lionagi/experimental/directive/parser/base_syntax.txt +23 -23
  65. lionagi/experimental/directive/template/base_template.py +14 -6
  66. lionagi/experimental/directive/tokenizer.py +3 -1
  67. lionagi/experimental/evaluator/README.md +1 -1
  68. lionagi/experimental/evaluator/ast_evaluator.py +6 -2
  69. lionagi/experimental/evaluator/base_evaluator.py +27 -16
  70. lionagi/integrations/bridge/autogen_/autogen_.py +7 -3
  71. lionagi/integrations/bridge/langchain_/documents.py +13 -10
  72. lionagi/integrations/bridge/llamaindex_/llama_pack.py +36 -12
  73. lionagi/integrations/bridge/llamaindex_/node_parser.py +8 -3
  74. lionagi/integrations/bridge/llamaindex_/reader.py +3 -1
  75. lionagi/integrations/bridge/llamaindex_/textnode.py +9 -3
  76. lionagi/integrations/bridge/pydantic_/pydantic_bridge.py +7 -1
  77. lionagi/integrations/bridge/transformers_/install_.py +3 -1
  78. lionagi/integrations/chunker/chunk.py +5 -2
  79. lionagi/integrations/loader/load.py +7 -3
  80. lionagi/integrations/loader/load_util.py +35 -16
  81. lionagi/integrations/provider/oai.py +13 -4
  82. lionagi/integrations/provider/openrouter.py +13 -4
  83. lionagi/integrations/provider/services.py +3 -1
  84. lionagi/integrations/provider/transformers.py +5 -3
  85. lionagi/integrations/storage/neo4j.py +23 -7
  86. lionagi/integrations/storage/storage_util.py +23 -7
  87. lionagi/integrations/storage/structure_excel.py +7 -2
  88. lionagi/integrations/storage/to_csv.py +8 -2
  89. lionagi/integrations/storage/to_excel.py +11 -3
  90. lionagi/libs/ln_api.py +41 -19
  91. lionagi/libs/ln_context.py +4 -4
  92. lionagi/libs/ln_convert.py +35 -14
  93. lionagi/libs/ln_dataframe.py +9 -3
  94. lionagi/libs/ln_func_call.py +53 -18
  95. lionagi/libs/ln_image.py +9 -5
  96. lionagi/libs/ln_knowledge_graph.py +21 -7
  97. lionagi/libs/ln_nested.py +57 -16
  98. lionagi/libs/ln_parse.py +45 -15
  99. lionagi/libs/ln_queue.py +8 -3
  100. lionagi/libs/ln_tokenize.py +19 -6
  101. lionagi/libs/ln_validate.py +14 -3
  102. lionagi/libs/sys_util.py +44 -12
  103. lionagi/lions/coder/coder.py +24 -8
  104. lionagi/lions/coder/util.py +6 -2
  105. lionagi/lions/researcher/data_source/google_.py +12 -4
  106. lionagi/lions/researcher/data_source/wiki_.py +3 -1
  107. lionagi/version.py +1 -1
  108. {lionagi-0.2.11.dist-info → lionagi-0.3.1.dist-info}/METADATA +6 -7
  109. lionagi-0.3.1.dist-info/RECORD +226 -0
  110. lionagi/tests/__init__.py +0 -0
  111. lionagi/tests/api/__init__.py +0 -0
  112. lionagi/tests/api/aws/__init__.py +0 -0
  113. lionagi/tests/api/aws/conftest.py +0 -25
  114. lionagi/tests/api/aws/test_aws_s3.py +0 -6
  115. lionagi/tests/integrations/__init__.py +0 -0
  116. lionagi/tests/libs/__init__.py +0 -0
  117. lionagi/tests/libs/test_api.py +0 -48
  118. lionagi/tests/libs/test_convert.py +0 -89
  119. lionagi/tests/libs/test_field_validators.py +0 -354
  120. lionagi/tests/libs/test_func_call.py +0 -701
  121. lionagi/tests/libs/test_nested.py +0 -382
  122. lionagi/tests/libs/test_parse.py +0 -171
  123. lionagi/tests/libs/test_queue.py +0 -68
  124. lionagi/tests/libs/test_sys_util.py +0 -222
  125. lionagi/tests/test_core/__init__.py +0 -0
  126. lionagi/tests/test_core/collections/__init__.py +0 -0
  127. lionagi/tests/test_core/collections/test_component.py +0 -208
  128. lionagi/tests/test_core/collections/test_exchange.py +0 -139
  129. lionagi/tests/test_core/collections/test_flow.py +0 -146
  130. lionagi/tests/test_core/collections/test_pile.py +0 -172
  131. lionagi/tests/test_core/collections/test_progression.py +0 -130
  132. lionagi/tests/test_core/generic/__init__.py +0 -0
  133. lionagi/tests/test_core/generic/test_edge.py +0 -69
  134. lionagi/tests/test_core/generic/test_graph.py +0 -97
  135. lionagi/tests/test_core/generic/test_node.py +0 -107
  136. lionagi/tests/test_core/generic/test_structure.py +0 -194
  137. lionagi/tests/test_core/generic/test_tree_node.py +0 -74
  138. lionagi/tests/test_core/graph/__init__.py +0 -0
  139. lionagi/tests/test_core/graph/test_graph.py +0 -71
  140. lionagi/tests/test_core/graph/test_tree.py +0 -76
  141. lionagi/tests/test_core/mail/__init__.py +0 -0
  142. lionagi/tests/test_core/mail/test_mail.py +0 -98
  143. lionagi/tests/test_core/test_branch.py +0 -116
  144. lionagi/tests/test_core/test_form.py +0 -47
  145. lionagi/tests/test_core/test_report.py +0 -106
  146. lionagi/tests/test_core/test_structure/__init__.py +0 -0
  147. lionagi/tests/test_core/test_structure/test_base_structure.py +0 -198
  148. lionagi/tests/test_core/test_structure/test_graph.py +0 -55
  149. lionagi/tests/test_core/test_structure/test_tree.py +0 -49
  150. lionagi/tests/test_core/test_validator.py +0 -112
  151. lionagi-0.2.11.dist-info/RECORD +0 -267
  152. {lionagi-0.2.11.dist-info → lionagi-0.3.1.dist-info}/LICENSE +0 -0
  153. {lionagi-0.2.11.dist-info → lionagi-0.3.1.dist-info}/WHEEL +0 -0
@@ -59,7 +59,9 @@ class ParallelUnit(Directive):
59
59
  else:
60
60
  self.imodel = session.imodel
61
61
  self.form_template = template or self.default_template
62
- self.validator = Validator(rulebook=rulebook) if rulebook else Validator()
62
+ self.validator = (
63
+ Validator(rulebook=rulebook) if rulebook else Validator()
64
+ )
63
65
 
64
66
  async def pchat(self, *args, **kwargs):
65
67
  """
@@ -174,13 +176,17 @@ class ParallelUnit(Directive):
174
176
 
175
177
  async def _inner_3(i):
176
178
  """different instructions but same context"""
177
- tasks = [_inner_2(i, ins_=ins_) for ins_ in convert.to_list(instruction)]
179
+ tasks = [
180
+ _inner_2(i, ins_=ins_) for ins_ in convert.to_list(instruction)
181
+ ]
178
182
  ress = await AsyncUtil.execute_tasks(*tasks)
179
183
  return convert.to_list(ress)
180
184
 
181
185
  async def _inner_3_b(i):
182
186
  """different context but same instruction"""
183
- tasks = [_inner_2(i, cxt_=cxt_) for cxt_ in convert.to_list(context)]
187
+ tasks = [
188
+ _inner_2(i, cxt_=cxt_) for cxt_ in convert.to_list(context)
189
+ ]
184
190
  ress = await AsyncUtil.execute_tasks(*tasks)
185
191
  return convert.to_list(ress)
186
192
 
@@ -58,7 +58,7 @@ class ActionTemplate(BaseUnitForm):
58
58
 
59
59
  self.task = f"""
60
60
  Perform reasoning and prepare actions with GIVEN TOOLS ONLY.
61
- 1. additional instruction: {instruction or "N/A"}.
61
+ 1. additional instruction: {instruction or "N/A"}.
62
62
  2. additional context: {context or "N/A"}.
63
63
  """
64
64
  if confidence_score:
@@ -34,7 +34,9 @@ class PredictTemplate(BaseUnitForm):
34
34
 
35
35
  template_name: str = "predict_template"
36
36
 
37
- num_sentences: int = Field(2, description="the number of sentences to predict")
37
+ num_sentences: int = Field(
38
+ 2, description="the number of sentences to predict"
39
+ )
38
40
 
39
41
  prediction: None | str | list = Field(
40
42
  None,
@@ -39,7 +39,9 @@ class SelectTemplate(BaseUnitForm):
39
39
  selection: Enum | str | list | None = Field(
40
40
  None, description="selection from given choices"
41
41
  )
42
- choices: list = Field(default_factory=list, description="the given choices")
42
+ choices: list = Field(
43
+ default_factory=list, description="the given choices"
44
+ )
43
45
 
44
46
  assignment: str = "task -> selection"
45
47
 
@@ -78,9 +80,9 @@ class SelectTemplate(BaseUnitForm):
78
80
 
79
81
  self.choices = choices
80
82
  self.task = f"""
81
- select 1 item from the provided choices {choices}.
83
+ select 1 item from the provided choices {choices}.
82
84
  1. additional objective: {instruction or "N/A"}.
83
- 2. additional information: {context or "N/A"}.
85
+ 2. additional information: {context or "N/A"}.
84
86
  """
85
87
  if reason:
86
88
  self.append_to_request("reason")
lionagi/core/unit/unit.py CHANGED
@@ -1,4 +1,8 @@
1
- from typing import Callable
1
+ import logging
2
+ from collections.abc import Callable
3
+
4
+ from lionfuncs import to_dict
5
+ from pydantic import BaseModel
2
6
 
3
7
  from lionagi.core.collections import iModel
4
8
  from lionagi.core.collections.abc import Directive
@@ -71,6 +75,8 @@ class Unit(Directive, DirectiveMixin):
71
75
  return_branch=False,
72
76
  formatter=None,
73
77
  format_kwargs={},
78
+ pydantic_model: type[BaseModel] = None,
79
+ return_pydantic_model: bool = False,
74
80
  **kwargs,
75
81
  ):
76
82
  """
@@ -100,7 +106,17 @@ class Unit(Directive, DirectiveMixin):
100
106
  Any: The processed response.
101
107
  """
102
108
  kwargs = {**retry_kwargs, **kwargs}
103
- return await rcall(
109
+
110
+ if pydantic_model:
111
+ if form:
112
+ raise ValueError("Cannot use both form and pydantic_model.")
113
+ if requested_fields:
114
+ raise ValueError(
115
+ "Cannot use both requested_fields and pydantic_model."
116
+ )
117
+ requested_fields = pydantic_model.model_json_schema()["properties"]
118
+
119
+ output, branch = await rcall(
104
120
  self._chat,
105
121
  instruction=instruction,
106
122
  context=context,
@@ -118,11 +134,27 @@ class Unit(Directive, DirectiveMixin):
118
134
  imodel=imodel,
119
135
  clear_messages=clear_messages,
120
136
  use_annotation=use_annotation,
121
- return_branch=return_branch,
137
+ return_branch=True,
122
138
  formatter=formatter,
123
139
  format_kwargs=format_kwargs,
124
140
  **kwargs,
125
141
  )
142
+ if isinstance(output, tuple | list) and len(output) == 1:
143
+ output = output[0]
144
+
145
+ if isinstance(output, tuple | list) and len(output) == 2:
146
+ if output[0] == output[1]:
147
+ output = output[0]
148
+
149
+ if return_pydantic_model:
150
+ try:
151
+ a_ = to_dict(output, recursive=True, max_recursive_depth=3)
152
+ output = pydantic_model(**a_)
153
+ return output, branch if return_branch else output
154
+ except Exception as e:
155
+ logging.error(f"Error converting to pydantic model: {e}")
156
+
157
+ return output, branch if return_branch else output
126
158
 
127
159
  async def direct(
128
160
  self,
@@ -232,7 +264,9 @@ class Unit(Directive, DirectiveMixin):
232
264
  )
233
265
 
234
266
  if verbose:
235
- print("--------------------------------------------------------------")
267
+ print(
268
+ "--------------------------------------------------------------"
269
+ )
236
270
  print(f"Directive successfully completed!")
237
271
 
238
272
  return out
@@ -38,7 +38,9 @@ class UnitForm(BaseUnitForm):
38
38
  "number, you should provide a number like 1, 23, or 1.1 if float is "
39
39
  "allowed."
40
40
  ),
41
- examples=["{action_1: {function: 'add', arguments: {num1: 1, num2: 2}}}"],
41
+ examples=[
42
+ "{action_1: {function: 'add', arguments: {num1: 1, num2: 2}}}"
43
+ ],
42
44
  )
43
45
 
44
46
  action_required: bool | None = Field(
@@ -188,13 +190,13 @@ class UnitForm(BaseUnitForm):
188
190
 
189
191
  if allow_action:
190
192
  self.append_to_request("actions, action_required, reason")
191
- self.task += "- Reason and prepare actions with GIVEN TOOLS ONLY.\n"
193
+ self.task += (
194
+ "- Reason and prepare actions with GIVEN TOOLS ONLY.\n"
195
+ )
192
196
 
193
197
  if allow_extension:
194
198
  self.append_to_request("extension_required")
195
- self.task += (
196
- f"- Allow auto-extension up to another {max_extension} rounds.\n"
197
- )
199
+ self.task += f"- Allow auto-extension up to another {max_extension} rounds.\n"
198
200
 
199
201
  if tool_schema:
200
202
  self.append_to_input("tool_schema")
@@ -209,21 +211,15 @@ class UnitForm(BaseUnitForm):
209
211
  max_extension = max_extension or plan_num_step
210
212
  allow_extension = True
211
213
  self.append_to_request("plan, extension_required")
212
- self.task += (
213
- f"- Generate a {plan_num_step}-step plan based on the context.\n"
214
- )
214
+ self.task += f"- Generate a {plan_num_step}-step plan based on the context.\n"
215
215
 
216
216
  if predict:
217
217
  self.append_to_request("prediction")
218
- self.task += (
219
- f"- Predict the next {predict_num_sentences or 1} sentence(s).\n"
220
- )
218
+ self.task += f"- Predict the next {predict_num_sentences or 1} sentence(s).\n"
221
219
 
222
220
  if select:
223
221
  self.append_to_request("selection")
224
- self.task += (
225
- f"- Select 1 item from the provided choices: {select_choices}.\n"
226
- )
222
+ self.task += f"- Select 1 item from the provided choices: {select_choices}.\n"
227
223
 
228
224
  if confidence:
229
225
  self.append_to_request("confidence_score")
@@ -238,7 +234,9 @@ class UnitForm(BaseUnitForm):
238
234
  "upper_bound": score_range[1],
239
235
  "lower_bound": score_range[0],
240
236
  "num_type": int if score_num_digits == 0 else float,
241
- "precision": score_num_digits if score_num_digits != 0 else None,
237
+ "precision": (
238
+ score_num_digits if score_num_digits != 0 else None
239
+ ),
242
240
  }
243
241
 
244
242
  self.task += (
@@ -22,16 +22,16 @@ class DirectiveMixin(ABC):
22
22
 
23
23
  def _create_chat_config(
24
24
  self,
25
- system: Optional[str] = None,
26
- instruction: Optional[str] = None,
27
- context: Optional[str] = None,
28
- images: Optional[str] = None,
29
- sender: Optional[str] = None,
30
- recipient: Optional[str] = None,
31
- requested_fields: Optional[list] = None,
25
+ system: str | None = None,
26
+ instruction: str | None = None,
27
+ context: str | None = None,
28
+ images: str | None = None,
29
+ sender: str | None = None,
30
+ recipient: str | None = None,
31
+ requested_fields: list | None = None,
32
32
  form: Form = None,
33
33
  tools: bool = False,
34
- branch: Optional[Any] = None,
34
+ branch: Any | None = None,
35
35
  **kwargs,
36
36
  ) -> Any:
37
37
  """
@@ -87,7 +87,7 @@ class DirectiveMixin(ABC):
87
87
  return config
88
88
 
89
89
  async def _call_chatcompletion(
90
- self, imodel: Optional[Any] = None, branch: Optional[Any] = None, **kwargs
90
+ self, imodel: Any | None = None, branch: Any | None = None, **kwargs
91
91
  ) -> Any:
92
92
  """
93
93
  Calls the chat completion model.
@@ -102,7 +102,9 @@ class DirectiveMixin(ABC):
102
102
  """
103
103
  imodel = imodel or self.imodel
104
104
  branch = branch or self.branch
105
- return await imodel.call_chat_completion(branch.to_chat_messages(), **kwargs)
105
+ return await imodel.call_chat_completion(
106
+ branch.to_chat_messages(), **kwargs
107
+ )
106
108
 
107
109
  async def _process_chatcompletion(
108
110
  self,
@@ -110,8 +112,8 @@ class DirectiveMixin(ABC):
110
112
  completion: dict,
111
113
  sender: str,
112
114
  invoke_tool: bool = True,
113
- branch: Optional[Any] = None,
114
- action_request: Optional[Any] = None,
115
+ branch: Any | None = None,
116
+ action_request: Any | None = None,
115
117
  costs=None,
116
118
  ) -> Any:
117
119
  """
@@ -157,7 +159,9 @@ class DirectiveMixin(ABC):
157
159
  m = completion.get("model", None)
158
160
  if m:
159
161
  ttl = (a * price[0] + b * price[1]) / 1000000
160
- branch.messages[-1]._meta_insert(["extra", "usage", "expense"], ttl)
162
+ branch.messages[-1]._meta_insert(
163
+ ["extra", "usage", "expense"], ttl
164
+ )
161
165
  return msg
162
166
 
163
167
  if _choices and not isinstance(_choices, list):
@@ -180,10 +184,10 @@ class DirectiveMixin(ABC):
180
184
 
181
185
  async def _process_action_request(
182
186
  self,
183
- _msg: Optional[dict] = None,
184
- branch: Optional[Any] = None,
187
+ _msg: dict | None = None,
188
+ branch: Any | None = None,
185
189
  invoke_tool: bool = True,
186
- action_request: Optional[Any] = None,
190
+ action_request: Any | None = None,
187
191
  ) -> Any:
188
192
  """
189
193
  Processes an action request from the assistant response.
@@ -204,9 +208,13 @@ class DirectiveMixin(ABC):
204
208
  if action_request:
205
209
  for i in action_request:
206
210
  if i.function in branch.tool_manager.registry:
207
- i.recipient = branch.tool_manager.registry[i.function].ln_id
211
+ i.recipient = branch.tool_manager.registry[
212
+ i.function
213
+ ].ln_id
208
214
  else:
209
- raise ActionError(f"Tool {i.function} not found in registry")
215
+ raise ActionError(
216
+ f"Tool {i.function} not found in registry"
217
+ )
210
218
  branch.add_message(action_request=i, recipient=i.recipient)
211
219
 
212
220
  if invoke_tool:
@@ -305,7 +313,7 @@ class DirectiveMixin(ABC):
305
313
  requested_fields: dict = None,
306
314
  form: Form = None,
307
315
  tools: Any = False,
308
- images: Optional[str] = None,
316
+ images: str | None = None,
309
317
  invoke_tool: bool = True,
310
318
  return_form: bool = True,
311
319
  strict: bool = False,
@@ -396,7 +404,7 @@ class DirectiveMixin(ABC):
396
404
  return_form=True,
397
405
  strict=False,
398
406
  imodel=None,
399
- images: Optional[str] = None,
407
+ images: str | None = None,
400
408
  clear_messages=False,
401
409
  use_annotation=True,
402
410
  timeout: float = None,
@@ -493,7 +501,7 @@ class DirectiveMixin(ABC):
493
501
  predict_num_sentences=None,
494
502
  clear_messages=False,
495
503
  return_branch=False,
496
- images: Optional[str] = None,
504
+ images: str | None = None,
497
505
  verbose=None,
498
506
  **kwargs,
499
507
  ):
@@ -584,7 +592,7 @@ class DirectiveMixin(ABC):
584
592
  predict_num_sentences=None,
585
593
  clear_messages=False,
586
594
  return_branch=False,
587
- images: Optional[str] = None,
595
+ images: str | None = None,
588
596
  verbose=True,
589
597
  formatter=None,
590
598
  format_kwargs=None,
@@ -736,10 +744,14 @@ class DirectiveMixin(ABC):
736
744
  )
737
745
 
738
746
  if verbose:
739
- print(f"------------------- Extension completed -------------------\n")
747
+ print(
748
+ f"------------------- Extension completed -------------------\n"
749
+ )
740
750
 
741
751
  extension_forms.extend(new_form)
742
- last_form = new_form[-1] if isinstance(new_form, list) else new_form
752
+ last_form = (
753
+ new_form[-1] if isinstance(new_form, list) else new_form
754
+ )
743
755
  ctr += len(form)
744
756
 
745
757
  if extension_forms:
@@ -977,7 +989,9 @@ class DirectiveMixin(ABC):
977
989
  context=context,
978
990
  )
979
991
 
980
- return await self._chat(form=form, return_form=True, branch=branch, **kwargs)
992
+ return await self._chat(
993
+ form=form, return_form=True, branch=branch, **kwargs
994
+ )
981
995
 
982
996
  async def _predict(
983
997
  self,
@@ -1019,7 +1033,9 @@ class DirectiveMixin(ABC):
1019
1033
  reason=reason,
1020
1034
  )
1021
1035
 
1022
- return await self._chat(form=form, return_form=True, branch=branch, **kwargs)
1036
+ return await self._chat(
1037
+ form=form, return_form=True, branch=branch, **kwargs
1038
+ )
1023
1039
 
1024
1040
  async def _score(
1025
1041
  self,
@@ -1066,7 +1082,9 @@ class DirectiveMixin(ABC):
1066
1082
  context=context,
1067
1083
  )
1068
1084
 
1069
- return await self._chat(form=form, return_form=True, branch=branch, **kwargs)
1085
+ return await self._chat(
1086
+ form=form, return_form=True, branch=branch, **kwargs
1087
+ )
1070
1088
 
1071
1089
  async def _plan(
1072
1090
  self,
lionagi/core/unit/util.py CHANGED
@@ -21,7 +21,7 @@ choices_fields = ["index", "message", "logprobs", "finish_reason"]
21
21
 
22
22
  usage_fields = ["prompt_tokens", "completion_tokens", "total_tokens"]
23
23
 
24
- from typing import Callable
24
+ from collections.abc import Callable
25
25
 
26
26
  from lionagi.core.action.tool import Tool
27
27
  from lionagi.core.action.tool_manager import func_to_tool
@@ -40,12 +40,16 @@ def process_tools(tool_obj, branch):
40
40
  def _process_tool(tool_obj, branch):
41
41
  if (
42
42
  isinstance(tool_obj, Tool)
43
- and tool_obj.schema_["function"]["name"] not in branch.tool_manager.registry
43
+ and tool_obj.schema_["function"]["name"]
44
+ not in branch.tool_manager.registry
44
45
  ):
45
46
  branch.register_tools(tool_obj)
46
47
  if isinstance(tool_obj, Callable):
47
48
  tool = func_to_tool(tool_obj)[0]
48
- if tool.schema_["function"]["name"] not in branch.tool_manager.registry:
49
+ if (
50
+ tool.schema_["function"]["name"]
51
+ not in branch.tool_manager.registry
52
+ ):
49
53
  branch.register_tools(tool)
50
54
 
51
55
 
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
- from typing import Any, Callable, Dict, List, Union
2
+ from collections.abc import Callable
3
+ from typing import Any, Dict, List, Union
3
4
 
4
5
  from lionfuncs import lcall
5
6
 
@@ -40,10 +41,10 @@ class Validator:
40
41
  self,
41
42
  *,
42
43
  rulebook: RuleBook = None,
43
- rules: Dict[str, Rule] = None,
44
- order: List[str] = None,
45
- init_config: Dict[str, Dict] = None,
46
- active_rules: Dict[str, Rule] = None,
44
+ rules: dict[str, Rule] = None,
45
+ order: list[str] = None,
46
+ init_config: dict[str, dict] = None,
47
+ active_rules: dict[str, Rule] = None,
47
48
  formatter: Callable = None,
48
49
  format_kwargs: dict = {},
49
50
  ):
@@ -63,12 +64,14 @@ class Validator:
63
64
  self.rulebook = rulebook or RuleBook(
64
65
  rules or _DEFAULT_RULES, order or _DEFAULT_RULEORDER, init_config
65
66
  )
66
- self.active_rules: Dict[str, Rule] = active_rules or self._initiate_rules()
67
+ self.active_rules: dict[str, Rule] = (
68
+ active_rules or self._initiate_rules()
69
+ )
67
70
  self.validation_log = []
68
71
  self.formatter = formatter
69
72
  self.format_kwargs = format_kwargs
70
73
 
71
- def _initiate_rules(self) -> Dict[str, Rule]:
74
+ def _initiate_rules(self) -> dict[str, Rule]:
72
75
  """
73
76
  Initialize rules from the rulebook.
74
77
 
@@ -157,7 +160,7 @@ class Validator:
157
160
  raise FieldError(error_message)
158
161
 
159
162
  async def validate_report(
160
- self, report: Report, forms: List[Form], strict: bool = True
163
+ self, report: Report, forms: list[Form], strict: bool = True
161
164
  ) -> Report:
162
165
  """
163
166
  Validate a report based on active rules.
@@ -176,7 +179,7 @@ class Validator:
176
179
  async def validate_response(
177
180
  self,
178
181
  form: Form,
179
- response: Union[dict, str],
182
+ response: dict | str,
180
183
  strict: bool = True,
181
184
  use_annotation: bool = True,
182
185
  ) -> Form:
@@ -201,21 +204,29 @@ class Validator:
201
204
  else:
202
205
  if self.formatter:
203
206
  if asyncio.iscoroutinefunction(self.formatter):
204
- response = await self.formatter(response, **self.format_kwargs)
207
+ response = await self.formatter(
208
+ response, **self.format_kwargs
209
+ )
205
210
  print("formatter used")
206
211
  else:
207
- response = self.formatter(response, **self.format_kwargs)
212
+ response = self.formatter(
213
+ response, **self.format_kwargs
214
+ )
208
215
  print("formatter used")
209
216
 
210
217
  if not isinstance(response, dict):
211
- raise ValueError(f"The form response format is invalid for filling.")
218
+ raise ValueError(
219
+ f"The form response format is invalid for filling."
220
+ )
212
221
 
213
222
  dict_ = {}
214
223
  for k, v in response.items():
215
224
  if k in form.requested_fields:
216
225
  kwargs = form.validation_kwargs.get(k, {})
217
226
  _annotation = form._field_annotations[k]
218
- if (keys := form._get_field_attr(k, "choices", None)) is not None:
227
+ if (
228
+ keys := form._get_field_attr(k, "choices", None)
229
+ ) is not None:
219
230
  v = await self.validate_field(
220
231
  field=k,
221
232
  value=v,
@@ -227,7 +238,9 @@ class Validator:
227
238
  **kwargs,
228
239
  )
229
240
 
230
- elif (_keys := form._get_field_attr(k, "keys", None)) is not None:
241
+ elif (
242
+ _keys := form._get_field_attr(k, "keys", None)
243
+ ) is not None:
231
244
 
232
245
  v = await self.validate_field(
233
246
  field=k,
@@ -346,7 +359,7 @@ class Validator:
346
359
  }
347
360
  self.validation_log.append(log_entry)
348
361
 
349
- def get_validation_summary(self) -> Dict[str, Any]:
362
+ def get_validation_summary(self) -> dict[str, Any]:
350
363
  """
351
364
  Get a summary of validation results.
352
365
 
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Callable
2
+ from collections.abc import Callable
3
3
 
4
4
  from pydantic import Field, field_validator
5
5
 
@@ -54,7 +54,9 @@ class WorkEdge(Edge, Progressable):
54
54
  getattr(func, "_worklink_decorator_params")
55
55
  return func
56
56
  except:
57
- raise ValueError("convert_function must be a worklink decorated function")
57
+ raise ValueError(
58
+ "convert_function must be a worklink decorated function"
59
+ )
58
60
 
59
61
  @property
60
62
  def name(self):
@@ -94,5 +96,7 @@ class WorkEdge(Edge, Progressable):
94
96
  kwargs = {"from_result": task.current_work.result} | kwargs
95
97
 
96
98
  self.convert_function.auto_schedule = True
97
- next_work = await self.convert_function(self=self.associated_worker, **kwargs)
99
+ next_work = await self.convert_function(
100
+ self=self.associated_worker, **kwargs
101
+ )
98
102
  return next_work
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Callable
2
+ from collections.abc import Callable
3
3
 
4
4
  from pydantic import Field, field_validator
5
5
 
@@ -31,9 +31,13 @@ class WorkTask(Component):
31
31
 
32
32
  work_history: list[Work] = Field([], description="List of works processed")
33
33
 
34
- max_steps: int | None = Field(10, description="Maximum number of works allowed")
34
+ max_steps: int | None = Field(
35
+ 10, description="Maximum number of works allowed"
36
+ )
35
37
 
36
- current_work: Work | None = Field(None, description="The current work in progress")
38
+ current_work: Work | None = Field(
39
+ None, description="The current work in progress"
40
+ )
37
41
 
38
42
  post_processing: Callable | None = Field(
39
43
  None,
@@ -55,7 +59,9 @@ class WorkTask(Component):
55
59
  ValueError: If value is not a positive integer.
56
60
  """
57
61
  if value <= 0:
58
- raise ValueError("Invalid value: max_steps must be a positive integer.")
62
+ raise ValueError(
63
+ "Invalid value: max_steps must be a positive integer."
64
+ )
59
65
  return value
60
66
 
61
67
  @field_validator("post_processing", mode="before")
@@ -72,7 +78,7 @@ class WorkTask(Component):
72
78
  Raises:
73
79
  ValueError: If value is not an asynchronous function.
74
80
  """
75
- if value is not None and not inspect.iscoroutinefunction((value)):
81
+ if value is not None and not inspect.iscoroutinefunction(value):
76
82
  raise ValueError("post_processing must be a async function")
77
83
  return value
78
84
 
@@ -62,7 +62,12 @@ class Worker(ABC):
62
62
  """
63
63
 
64
64
  return (
65
- any([await i.is_progressable() for i in self.work_functions.values()])
65
+ any(
66
+ [
67
+ await i.is_progressable()
68
+ for i in self.work_functions.values()
69
+ ]
70
+ )
66
71
  and not self.stopped
67
72
  )
68
73
 
@@ -261,7 +266,9 @@ def work(
261
266
  **kwargs,
262
267
  ):
263
268
  if not inspect.iscoroutinefunction(func):
264
- raise TypeError(f"{func.__name__} must be an asynchronous function")
269
+ raise TypeError(
270
+ f"{func.__name__} must be an asynchronous function"
271
+ )
265
272
  retry_kwargs = retry_kwargs or {}
266
273
  retry_kwargs["timeout"] = retry_kwargs.get("timeout", timeout)
267
274
  return await self._work_wrapper(
@@ -309,7 +316,9 @@ def worklink(from_: str, to_: str, auto_schedule: bool = True):
309
316
  self: Worker, *args, func=func, from_=from_, to_=to_, **kwargs
310
317
  ):
311
318
  if not inspect.iscoroutinefunction(func):
312
- raise TypeError(f"{func.__name__} must be an asynchronous function")
319
+ raise TypeError(
320
+ f"{func.__name__} must be an asynchronous function"
321
+ )
313
322
 
314
323
  work_funcs = self._get_decorated_functions(
315
324
  decorator_attr="_work_decorator_params"
@@ -363,7 +372,9 @@ def worklink(from_: str, to_: str, auto_schedule: bool = True):
363
372
  next_params[1], dict
364
373
  ):
365
374
  if wrapper.auto_schedule:
366
- return await to_work_func(*next_params[0], **next_params[1])
375
+ return await to_work_func(
376
+ *next_params[0], **next_params[1]
377
+ )
367
378
  else:
368
379
  raise TypeError(f"Invalid return type {func.__name__}")
369
380
  else:
@@ -372,7 +383,11 @@ def worklink(from_: str, to_: str, auto_schedule: bool = True):
372
383
  return next_params
373
384
 
374
385
  wrapper.auto_schedule = auto_schedule
375
- wrapper._worklink_decorator_params = {"func": func, "from_": from_, "to_": to_}
386
+ wrapper._worklink_decorator_params = {
387
+ "func": func,
388
+ "from_": from_,
389
+ "to_": to_,
390
+ }
376
391
 
377
392
  return wrapper
378
393
 
@@ -172,8 +172,12 @@ class WorkerEngine:
172
172
  )
173
173
  for func_name, func, dec_params in work_decorated_function:
174
174
  if func_name not in self.worker.work_functions:
175
- self.worker.work_functions[func_name] = WorkFunctionNode(**dec_params)
176
- self.worker_graph.add_node(self.worker.work_functions[func_name])
175
+ self.worker.work_functions[func_name] = WorkFunctionNode(
176
+ **dec_params
177
+ )
178
+ self.worker_graph.add_node(
179
+ self.worker.work_functions[func_name]
180
+ )
177
181
  else:
178
182
  if not isinstance(
179
183
  self.worker.work_functions[func_name], WorkFunctionNode