lionagi 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. lionagi/__init__.py +60 -5
  2. lionagi/core/__init__.py +0 -25
  3. lionagi/core/_setting/_setting.py +59 -0
  4. lionagi/core/action/__init__.py +14 -0
  5. lionagi/core/action/function_calling.py +136 -0
  6. lionagi/core/action/manual.py +1 -0
  7. lionagi/core/action/node.py +109 -0
  8. lionagi/core/action/tool.py +114 -0
  9. lionagi/core/action/tool_manager.py +356 -0
  10. lionagi/core/agent/base_agent.py +27 -13
  11. lionagi/core/agent/eval/evaluator.py +1 -0
  12. lionagi/core/agent/eval/vote.py +40 -0
  13. lionagi/core/agent/learn/learner.py +59 -0
  14. lionagi/core/agent/plan/unit_template.py +1 -0
  15. lionagi/core/collections/__init__.py +17 -0
  16. lionagi/core/{generic/data_logger.py → collections/_logger.py} +69 -55
  17. lionagi/core/collections/abc/__init__.py +53 -0
  18. lionagi/core/collections/abc/component.py +615 -0
  19. lionagi/core/collections/abc/concepts.py +297 -0
  20. lionagi/core/collections/abc/exceptions.py +150 -0
  21. lionagi/core/collections/abc/util.py +45 -0
  22. lionagi/core/collections/exchange.py +161 -0
  23. lionagi/core/collections/flow.py +426 -0
  24. lionagi/core/collections/model.py +419 -0
  25. lionagi/core/collections/pile.py +913 -0
  26. lionagi/core/collections/progression.py +236 -0
  27. lionagi/core/collections/util.py +64 -0
  28. lionagi/core/director/direct.py +314 -0
  29. lionagi/core/director/director.py +2 -0
  30. lionagi/core/{execute/branch_executor.py → engine/branch_engine.py} +134 -97
  31. lionagi/core/{execute/instruction_map_executor.py → engine/instruction_map_engine.py} +80 -55
  32. lionagi/{experimental/directive/evaluator → core/engine}/script_engine.py +17 -1
  33. lionagi/core/executor/base_executor.py +90 -0
  34. lionagi/core/{execute/structure_executor.py → executor/graph_executor.py} +62 -66
  35. lionagi/core/{execute → executor}/neo4j_executor.py +70 -67
  36. lionagi/core/generic/__init__.py +3 -33
  37. lionagi/core/generic/edge.py +29 -79
  38. lionagi/core/generic/edge_condition.py +16 -0
  39. lionagi/core/generic/graph.py +236 -0
  40. lionagi/core/generic/hyperedge.py +1 -0
  41. lionagi/core/generic/node.py +156 -221
  42. lionagi/core/generic/tree.py +48 -0
  43. lionagi/core/generic/tree_node.py +79 -0
  44. lionagi/core/mail/__init__.py +12 -0
  45. lionagi/core/mail/mail.py +25 -0
  46. lionagi/core/mail/mail_manager.py +139 -58
  47. lionagi/core/mail/package.py +45 -0
  48. lionagi/core/mail/start_mail.py +36 -0
  49. lionagi/core/message/__init__.py +19 -0
  50. lionagi/core/message/action_request.py +133 -0
  51. lionagi/core/message/action_response.py +135 -0
  52. lionagi/core/message/assistant_response.py +95 -0
  53. lionagi/core/message/instruction.py +234 -0
  54. lionagi/core/message/message.py +101 -0
  55. lionagi/core/message/system.py +86 -0
  56. lionagi/core/message/util.py +283 -0
  57. lionagi/core/report/__init__.py +4 -0
  58. lionagi/core/report/base.py +217 -0
  59. lionagi/core/report/form.py +231 -0
  60. lionagi/core/report/report.py +166 -0
  61. lionagi/core/report/util.py +28 -0
  62. lionagi/core/rule/_default.py +16 -0
  63. lionagi/core/rule/action.py +99 -0
  64. lionagi/core/rule/base.py +238 -0
  65. lionagi/core/rule/boolean.py +56 -0
  66. lionagi/core/rule/choice.py +47 -0
  67. lionagi/core/rule/mapping.py +96 -0
  68. lionagi/core/rule/number.py +71 -0
  69. lionagi/core/rule/rulebook.py +109 -0
  70. lionagi/core/rule/string.py +52 -0
  71. lionagi/core/rule/util.py +35 -0
  72. lionagi/core/session/branch.py +431 -0
  73. lionagi/core/session/directive_mixin.py +287 -0
  74. lionagi/core/session/session.py +229 -903
  75. lionagi/core/structure/__init__.py +1 -0
  76. lionagi/core/structure/chain.py +1 -0
  77. lionagi/core/structure/forest.py +1 -0
  78. lionagi/core/structure/graph.py +1 -0
  79. lionagi/core/structure/tree.py +1 -0
  80. lionagi/core/unit/__init__.py +5 -0
  81. lionagi/core/unit/parallel_unit.py +245 -0
  82. lionagi/core/unit/template/action.py +81 -0
  83. lionagi/core/unit/template/base.py +51 -0
  84. lionagi/core/unit/template/plan.py +84 -0
  85. lionagi/core/unit/template/predict.py +109 -0
  86. lionagi/core/unit/template/score.py +124 -0
  87. lionagi/core/unit/template/select.py +104 -0
  88. lionagi/core/unit/unit.py +362 -0
  89. lionagi/core/unit/unit_form.py +305 -0
  90. lionagi/core/unit/unit_mixin.py +1168 -0
  91. lionagi/core/unit/util.py +71 -0
  92. lionagi/core/validator/validator.py +364 -0
  93. lionagi/core/work/work.py +76 -0
  94. lionagi/core/work/work_function.py +101 -0
  95. lionagi/core/work/work_queue.py +103 -0
  96. lionagi/core/work/worker.py +258 -0
  97. lionagi/core/work/worklog.py +120 -0
  98. lionagi/experimental/compressor/base.py +46 -0
  99. lionagi/experimental/compressor/llm_compressor.py +247 -0
  100. lionagi/experimental/compressor/llm_summarizer.py +61 -0
  101. lionagi/experimental/compressor/util.py +70 -0
  102. lionagi/experimental/directive/__init__.py +19 -0
  103. lionagi/experimental/directive/parser/base_parser.py +69 -2
  104. lionagi/experimental/directive/{template_ → template}/base_template.py +17 -1
  105. lionagi/{libs/ln_tokenizer.py → experimental/directive/tokenizer.py} +16 -0
  106. lionagi/experimental/{directive/evaluator → evaluator}/ast_evaluator.py +16 -0
  107. lionagi/experimental/{directive/evaluator → evaluator}/base_evaluator.py +16 -0
  108. lionagi/experimental/knowledge/base.py +10 -0
  109. lionagi/experimental/memory/__init__.py +0 -0
  110. lionagi/experimental/strategies/__init__.py +0 -0
  111. lionagi/experimental/strategies/base.py +1 -0
  112. lionagi/integrations/bridge/langchain_/documents.py +4 -0
  113. lionagi/integrations/bridge/llamaindex_/index.py +30 -0
  114. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +6 -0
  115. lionagi/integrations/chunker/chunk.py +161 -24
  116. lionagi/integrations/config/oai_configs.py +34 -3
  117. lionagi/integrations/config/openrouter_configs.py +14 -2
  118. lionagi/integrations/loader/load.py +122 -21
  119. lionagi/integrations/loader/load_util.py +6 -77
  120. lionagi/integrations/provider/_mapping.py +46 -0
  121. lionagi/integrations/provider/litellm.py +2 -1
  122. lionagi/integrations/provider/mlx_service.py +16 -9
  123. lionagi/integrations/provider/oai.py +91 -4
  124. lionagi/integrations/provider/ollama.py +6 -5
  125. lionagi/integrations/provider/openrouter.py +115 -8
  126. lionagi/integrations/provider/services.py +2 -2
  127. lionagi/integrations/provider/transformers.py +18 -22
  128. lionagi/integrations/storage/__init__.py +3 -3
  129. lionagi/integrations/storage/neo4j.py +52 -60
  130. lionagi/integrations/storage/storage_util.py +44 -46
  131. lionagi/integrations/storage/structure_excel.py +43 -26
  132. lionagi/integrations/storage/to_excel.py +11 -4
  133. lionagi/libs/__init__.py +22 -1
  134. lionagi/libs/ln_api.py +75 -20
  135. lionagi/libs/ln_context.py +37 -0
  136. lionagi/libs/ln_convert.py +21 -9
  137. lionagi/libs/ln_func_call.py +69 -28
  138. lionagi/libs/ln_image.py +107 -0
  139. lionagi/libs/ln_nested.py +26 -11
  140. lionagi/libs/ln_parse.py +82 -23
  141. lionagi/libs/ln_queue.py +16 -0
  142. lionagi/libs/ln_tokenize.py +164 -0
  143. lionagi/libs/ln_validate.py +16 -0
  144. lionagi/libs/special_tokens.py +172 -0
  145. lionagi/libs/sys_util.py +95 -24
  146. lionagi/lions/coder/code_form.py +13 -0
  147. lionagi/lions/coder/coder.py +50 -3
  148. lionagi/lions/coder/util.py +30 -25
  149. lionagi/tests/libs/test_func_call.py +23 -21
  150. lionagi/tests/libs/test_nested.py +36 -21
  151. lionagi/tests/libs/test_parse.py +1 -1
  152. lionagi/tests/test_core/collections/__init__.py +0 -0
  153. lionagi/tests/test_core/collections/test_component.py +206 -0
  154. lionagi/tests/test_core/collections/test_exchange.py +138 -0
  155. lionagi/tests/test_core/collections/test_flow.py +145 -0
  156. lionagi/tests/test_core/collections/test_pile.py +171 -0
  157. lionagi/tests/test_core/collections/test_progression.py +129 -0
  158. lionagi/tests/test_core/generic/test_edge.py +67 -0
  159. lionagi/tests/test_core/generic/test_graph.py +96 -0
  160. lionagi/tests/test_core/generic/test_node.py +106 -0
  161. lionagi/tests/test_core/generic/test_tree_node.py +73 -0
  162. lionagi/tests/test_core/test_branch.py +115 -294
  163. lionagi/tests/test_core/test_form.py +46 -0
  164. lionagi/tests/test_core/test_report.py +105 -0
  165. lionagi/tests/test_core/test_validator.py +111 -0
  166. lionagi/version.py +1 -1
  167. lionagi-0.2.1.dist-info/LICENSE +202 -0
  168. lionagi-0.2.1.dist-info/METADATA +272 -0
  169. lionagi-0.2.1.dist-info/RECORD +240 -0
  170. lionagi/core/branch/base.py +0 -653
  171. lionagi/core/branch/branch.py +0 -474
  172. lionagi/core/branch/flow_mixin.py +0 -96
  173. lionagi/core/branch/util.py +0 -323
  174. lionagi/core/direct/__init__.py +0 -19
  175. lionagi/core/direct/cot.py +0 -123
  176. lionagi/core/direct/plan.py +0 -164
  177. lionagi/core/direct/predict.py +0 -166
  178. lionagi/core/direct/react.py +0 -171
  179. lionagi/core/direct/score.py +0 -279
  180. lionagi/core/direct/select.py +0 -170
  181. lionagi/core/direct/sentiment.py +0 -1
  182. lionagi/core/direct/utils.py +0 -110
  183. lionagi/core/direct/vote.py +0 -64
  184. lionagi/core/execute/base_executor.py +0 -47
  185. lionagi/core/flow/baseflow.py +0 -23
  186. lionagi/core/flow/monoflow/ReAct.py +0 -240
  187. lionagi/core/flow/monoflow/__init__.py +0 -9
  188. lionagi/core/flow/monoflow/chat.py +0 -95
  189. lionagi/core/flow/monoflow/chat_mixin.py +0 -253
  190. lionagi/core/flow/monoflow/followup.py +0 -215
  191. lionagi/core/flow/polyflow/__init__.py +0 -1
  192. lionagi/core/flow/polyflow/chat.py +0 -251
  193. lionagi/core/form/action_form.py +0 -26
  194. lionagi/core/form/field_validator.py +0 -287
  195. lionagi/core/form/form.py +0 -302
  196. lionagi/core/form/mixin.py +0 -214
  197. lionagi/core/form/scored_form.py +0 -13
  198. lionagi/core/generic/action.py +0 -26
  199. lionagi/core/generic/component.py +0 -532
  200. lionagi/core/generic/condition.py +0 -46
  201. lionagi/core/generic/mail.py +0 -90
  202. lionagi/core/generic/mailbox.py +0 -36
  203. lionagi/core/generic/relation.py +0 -70
  204. lionagi/core/generic/signal.py +0 -22
  205. lionagi/core/generic/structure.py +0 -362
  206. lionagi/core/generic/transfer.py +0 -20
  207. lionagi/core/generic/work.py +0 -40
  208. lionagi/core/graph/graph.py +0 -126
  209. lionagi/core/graph/tree.py +0 -190
  210. lionagi/core/mail/schema.py +0 -63
  211. lionagi/core/messages/schema.py +0 -325
  212. lionagi/core/tool/__init__.py +0 -5
  213. lionagi/core/tool/tool.py +0 -28
  214. lionagi/core/tool/tool_manager.py +0 -283
  215. lionagi/experimental/report/form.py +0 -64
  216. lionagi/experimental/report/report.py +0 -138
  217. lionagi/experimental/report/util.py +0 -47
  218. lionagi/experimental/tool/function_calling.py +0 -43
  219. lionagi/experimental/tool/manual.py +0 -66
  220. lionagi/experimental/tool/schema.py +0 -59
  221. lionagi/experimental/tool/tool_manager.py +0 -138
  222. lionagi/experimental/tool/util.py +0 -16
  223. lionagi/experimental/validator/rule.py +0 -139
  224. lionagi/experimental/validator/validator.py +0 -56
  225. lionagi/experimental/work/__init__.py +0 -10
  226. lionagi/experimental/work/async_queue.py +0 -54
  227. lionagi/experimental/work/schema.py +0 -73
  228. lionagi/experimental/work/work_function.py +0 -67
  229. lionagi/experimental/work/worker.py +0 -56
  230. lionagi/experimental/work2/form.py +0 -371
  231. lionagi/experimental/work2/report.py +0 -289
  232. lionagi/experimental/work2/schema.py +0 -30
  233. lionagi/experimental/work2/tests.py +0 -72
  234. lionagi/experimental/work2/work_function.py +0 -89
  235. lionagi/experimental/work2/worker.py +0 -12
  236. lionagi/integrations/bridge/llamaindex_/get_index.py +0 -294
  237. lionagi/tests/test_core/generic/test_component.py +0 -89
  238. lionagi/tests/test_core/test_base_branch.py +0 -426
  239. lionagi/tests/test_core/test_chat_flow.py +0 -63
  240. lionagi/tests/test_core/test_mail_manager.py +0 -75
  241. lionagi/tests/test_core/test_prompts.py +0 -51
  242. lionagi/tests/test_core/test_session.py +0 -254
  243. lionagi/tests/test_core/test_session_base_util.py +0 -313
  244. lionagi/tests/test_core/test_tool_manager.py +0 -95
  245. lionagi-0.1.2.dist-info/LICENSE +0 -9
  246. lionagi-0.1.2.dist-info/METADATA +0 -174
  247. lionagi-0.1.2.dist-info/RECORD +0 -206
  248. /lionagi/core/{branch → _setting}/__init__.py +0 -0
  249. /lionagi/core/{execute → agent/eval}/__init__.py +0 -0
  250. /lionagi/core/{flow → agent/learn}/__init__.py +0 -0
  251. /lionagi/core/{form → agent/plan}/__init__.py +0 -0
  252. /lionagi/core/{branch/executable_branch.py → agent/plan/plan.py} +0 -0
  253. /lionagi/core/{graph → director}/__init__.py +0 -0
  254. /lionagi/core/{messages → engine}/__init__.py +0 -0
  255. /lionagi/{experimental/directive/evaluator → core/engine}/sandbox_.py +0 -0
  256. /lionagi/{experimental/directive/evaluator → core/executor}/__init__.py +0 -0
  257. /lionagi/{experimental/directive/template_ → core/rule}/__init__.py +0 -0
  258. /lionagi/{experimental/report → core/unit/template}/__init__.py +0 -0
  259. /lionagi/{experimental/tool → core/validator}/__init__.py +0 -0
  260. /lionagi/{experimental/validator → core/work}/__init__.py +0 -0
  261. /lionagi/experimental/{work2 → compressor}/__init__.py +0 -0
  262. /lionagi/{core/flow/mono_chat_mixin.py → experimental/directive/template/__init__.py} +0 -0
  263. /lionagi/experimental/directive/{schema.py → template/schema.py} +0 -0
  264. /lionagi/experimental/{work2/util.py → evaluator/__init__.py} +0 -0
  265. /lionagi/experimental/{work2/work.py → knowledge/__init__.py} +0 -0
  266. /lionagi/{tests/libs/test_async.py → experimental/knowledge/graph.py} +0 -0
  267. {lionagi-0.1.2.dist-info → lionagi-0.2.1.dist-info}/WHEEL +0 -0
  268. {lionagi-0.1.2.dist-info → lionagi-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1168 @@
1
+ """
2
+ Copyright 2024 HaiyangLi
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ """
15
+
16
+ """
17
+ The base directive module.
18
+ """
19
+
20
+ import asyncio
21
+ import contextlib
22
+ import re
23
+ from abc import ABC
24
+
25
+ from typing import Any, Optional
26
+
27
+ from lionagi.libs import ParseUtil, StringMatch, to_list
28
+ from lionagi.libs.ln_nested import nmerge
29
+ from lionagi.core.collections.abc import ActionError
30
+ from lionagi.core.message import ActionRequest, ActionResponse, Instruction
31
+ from lionagi.core.message.util import _parse_action_request
32
+ from lionagi.core.report.form import Form
33
+ from lionagi.core.unit.util import process_tools
34
+ from lionagi.core.validator.validator import Validator
35
+
36
+
37
+ class DirectiveMixin(ABC):
38
+ """
39
+ DirectiveMixin is a class for handling chat operations and
40
+ processing responses.
41
+ """
42
+
43
def _create_chat_config(
    self,
    system: Optional[str] = None,
    instruction: Optional[str] = None,
    context: Optional[str] = None,
    images: Optional[str] = None,
    sender: Optional[str] = None,
    recipient: Optional[str] = None,
    requested_fields: Optional[list] = None,
    form: Form = None,
    tools: bool = False,
    branch: Optional[Any] = None,
    **kwargs,
) -> Any:
    """
    Record the prompt messages for one chat turn on a branch and build
    the keyword configuration for the model call.

    Side effects on the branch:
      * a truthy ``system`` is added as a system message;
      * if ``form`` is truthy, an instruction built with
        ``Instruction.from_form(form)`` is added; otherwise an instruction
        is built from ``instruction``/``context``/``sender``/``recipient``/
        ``requested_fields``/``images``.

    Args:
        system: System message to add, if any.
        instruction: Instruction message (ignored when ``form`` is given).
        context: Context for the instruction.
        images: Image payload forwarded to the instruction message.
        sender: Sender identifier; also copied into the config when not None.
        recipient: Recipient identifier. The literal string
            ``"branch.ln_id"`` is replaced with the branch's ``ln_id``.
        requested_fields: Fields requested in the response.
        form: Form whose instruction replaces the free-form fields.
        tools: Tools to expose to the model.
        branch: Branch to write to; defaults to ``self.branch``.
        kwargs: Extra model-call options. If ``tool_parsed`` is present it
            is removed and ``tools`` is forwarded unchanged; otherwise, when
            ``tools`` is truthy and the branch has tools, the options are
            rewritten by ``branch.tool_manager.parse_tool``.

    Returns:
        dict: ``self.imodel.config`` overlaid with the (possibly rewritten)
        kwargs, plus ``"sender"`` when ``sender`` is not None.
    """
    target = branch or self.branch

    if system:
        target.add_message(system=system)

    if form:
        target.add_message(instruction=Instruction.from_form(form))
    else:
        # The literal "branch.ln_id" stands in for the branch's own id.
        if recipient == "branch.ln_id":
            recipient = target.ln_id
        target.add_message(
            instruction=instruction,
            context=context,
            sender=sender,
            recipient=recipient,
            requested_fields=requested_fields,
            images=images,
        )

    if "tool_parsed" in kwargs:
        # Tools were already parsed upstream: drop the marker and forward
        # ``tools`` as-is, ahead of the remaining options.
        del kwargs["tool_parsed"]
        kwargs = {"tools": tools, **kwargs}
    elif tools and target.has_tools:
        kwargs = target.tool_manager.parse_tool(tools=tools, **kwargs)

    config = {**self.imodel.config, **kwargs}
    if sender is not None:
        config["sender"] = sender
    return config
108
+
109
async def _call_chatcompletion(
    self, imodel: Optional[Any] = None, branch: Optional[Any] = None, **kwargs
) -> Any:
    """
    Send the branch's conversation to the chat-completion model.

    Args:
        imodel: Model to call; defaults to ``self.imodel``.
        branch: Branch whose messages are sent; defaults to ``self.branch``.
        kwargs: Options forwarded verbatim to ``call_chat_completion``.

    Returns:
        Any: Whatever ``imodel.call_chat_completion`` returns.
    """
    model = imodel or self.imodel
    source = branch or self.branch
    messages = source.to_chat_messages()
    result = await model.call_chat_completion(messages, **kwargs)
    return result
126
+
127
async def _process_chatcompletion(
    self,
    payload: dict,
    completion: dict,
    sender: str,
    invoke_tool: bool = True,
    branch: Optional[Any] = None,
    action_request: Optional[Any] = None,
    costs=None,
) -> Any:
    """
    Processes the chat completion response.

    Every choice in ``completion["choices"]`` is recorded on the branch as
    an assistant response. Only the message of the last choice is forwarded
    to action-request processing (function calling currently supports only
    the last message).

    Note: this mutates its inputs. ``"messages"`` is popped from ``payload``,
    ``"choices"`` is popped from ``completion``, and ``"message"`` is popped
    from each dict choice.

    Args:
        payload: The request payload; stored (minus messages) as metadata
            of the last instruction.
        completion: The completion data returned by the model.
        sender: The sender identifier for the assistant responses.
        invoke_tool: Flag indicating if tools should be invoked.
        branch: The branch instance; defaults to ``self.branch``.
        action_request: Pre-parsed action request(s), if any.
        costs: Optional ``(prompt_price, completion_price)`` pair applied per
            1,000,000 tokens to compute an ``expense`` usage entry. When
            falsy, no expense is recorded.

    Returns:
        Any: The result of ``self._process_action_request``.
    """
    branch = branch or self.branch
    _msg = None

    def process_completion_choice(choice, price=None):
        # Only dict choices carry a message. Skip anything else instead of
        # reading an unrelated message and returning an unbound ``msg``.
        if not isinstance(choice, dict):
            return None

        msg = choice.pop("message", None)
        _completion = completion.copy()
        _completion.update(choice)
        branch.add_message(
            assistant_response=msg,
            metadata=_completion,
            sender=sender,
        )

        last = branch.messages[-1]
        prompt_tokens = last._meta_get(["extra", "usage", "prompt_tokens"], 0)
        completion_tokens = last._meta_get(
            ["extra", "usage", "completion_tokens"], 0
        )
        # Price the call only when a model is reported AND prices are known;
        # with price=None the original subscript raised TypeError.
        if completion.get("model", None) and price:
            ttl = (
                prompt_tokens * price[0] + completion_tokens * price[1]
            ) / 1000000
            last._meta_insert(["extra", "usage", "expense"], ttl)
        return msg

    if "choices" in completion:
        payload.pop("messages", None)
        branch.update_last_instruction_meta(payload)
        _choices = completion.pop("choices", None)

        # A single choice object is treated as a one-element list.
        if _choices and not isinstance(_choices, list):
            _choices = [_choices]

        if _choices:
            for _choice in _choices:
                _msg = process_completion_choice(_choice, price=costs)

        branch.imodel.status_tracker.num_tasks_succeeded += 1
    else:
        branch.imodel.status_tracker.num_tasks_failed += 1

    return await self._process_action_request(
        _msg=_msg,
        branch=branch,
        invoke_tool=invoke_tool,
        action_request=action_request,
    )
200
+
201
async def _process_action_request(
    self,
    _msg: Optional[dict] = None,
    branch: Optional[Any] = None,
    invoke_tool: bool = True,
    action_request: Optional[Any] = None,
) -> Any:
    """
    Processes an action request from the assistant response.

    Each request is routed to the tool registered under its ``function``
    name. Its ``recipient`` is set to that tool's ``ln_id`` and it is
    recorded on the branch. When ``invoke_tool`` is set, all tools are
    invoked concurrently, and every non-None result is recorded as a
    response message.

    Args:
        _msg: The assistant message; parsed for action requests when
            ``action_request`` is not supplied.
        branch: The branch instance; defaults to ``self.branch``.
        invoke_tool: Flag indicating if tools should be invoked.
        action_request: Pre-parsed action request(s).

    Returns:
        Any: ``_msg`` (or ``False`` if it is empty) when there is no action
        request; otherwise ``None``.

    Raises:
        ActionError: If any requested function is not in the tool registry.
            All requests are validated before the branch is modified.
    """
    # Resolve the branch like the sibling methods do; previously a missing
    # ``branch`` argument caused an AttributeError on None.
    branch = branch or self.branch

    action_request = action_request or _parse_action_request(_msg)
    if action_request is None:
        return _msg if _msg else False
    if not action_request:
        return None

    requests = list(action_request)
    registry = branch.tool_manager.registry

    # Validate everything first, so that an unknown tool does not leave
    # earlier requests recorded on the branch without any response.
    for request in requests:
        if request.function not in registry:
            raise ActionError(f"Tool {request.function} not found in registry")

    for request in requests:
        request.recipient = registry[request.function].ln_id
        branch.add_message(action_request=request, recipient=request.recipient)

    if not invoke_tool:
        return None

    tasks = [
        asyncio.create_task(registry[request.function].invoke(request.arguments))
        for request in requests
    ]
    results = await asyncio.gather(*tasks)

    for request, output in zip(requests, results):
        # Tools that return None produce no response message.
        if output is not None:
            branch.add_message(
                action_request=request,
                func_outputs=output,
                sender=request.recipient,
                recipient=request.sender,
            )

    return None
249
+
250
async def _output(
    self,
    payload: dict,
    completion: dict,
    sender: str,
    invoke_tool: bool,
    requested_fields: dict,
    form: Form = None,
    return_form: bool = True,
    strict: bool = False,
    rulebook: Any = None,
    use_annotation: bool = True,
    template_name: str = None,
    costs=None,
) -> Any:
    """
    Turn a raw chat completion into the caller-facing result.

    The completion is first run through ``self._process_chatcompletion``. If
    that yields None, None is returned. Otherwise the message is parsed with
    ``self._process_model_response``. Without a form, that parsed response
    is returned.

    With a form, the response is validated into the form. The validator is
    a fresh ``Validator(rulebook=rulebook)`` when a rulebook is given, and
    ``self.validator`` otherwise. The form itself is returned, or, if
    ``return_form`` is False, a dict of its requested fields whose work
    values are not None.

    Args:
        payload: The request payload.
        completion: The completion data.
        sender: The sender identifier.
        invoke_tool: Flag indicating if tools should be invoked.
        requested_fields: Fields requested in the response.
        form: Optional form to validate the response into.
        return_form: Return the form (True) or a field dict (False).
        strict: Strict-validation flag passed to the validator.
        rulebook: Optional rulebook for a dedicated validator.
        use_annotation: Annotation flag passed to the validator.
        template_name: If truthy, assigned to the validated form.
        costs: Price pair forwarded to ``_process_chatcompletion``.

    Returns:
        Any: None, the parsed response, the validated form, or a field dict.
    """
    message = await self._process_chatcompletion(
        payload=payload,
        completion=completion,
        sender=sender,
        invoke_tool=invoke_tool,
        costs=costs,
    )
    if message is None:
        return None

    response = self._process_model_response(message, requested_fields)
    if not form:
        return response

    validator = Validator(rulebook=rulebook) if rulebook else self.validator
    validated = await validator.validate_response(
        form=form,
        response=response,
        strict=strict,
        use_annotation=use_annotation,
    )
    if template_name:
        validated.template_name = template_name

    if return_form:
        return validated

    return {
        name: validated.work_fields[name]
        for name in validated.requested_fields
        if validated.work_fields[name] is not None
    }
319
+
320
async def _base_chat(
    self,
    instruction: Any = None,
    *,
    system: Any = None,
    context: Any = None,
    sender: Any = None,
    recipient: Any = None,
    requested_fields: dict = None,
    form: Form = None,
    tools: Any = False,
    images: Optional[str] = None,
    invoke_tool: bool = True,
    return_form: bool = True,
    strict: bool = False,
    rulebook: Any = None,
    imodel: Any = None,
    use_annotation: bool = True,
    branch: Any = None,
    clear_messages: bool = False,
    return_branch: bool = False,
    **kwargs,
) -> Any:
    """
    Run one chat turn: record the prompt, call the model, process the reply.

    Args:
        instruction: Instruction message.
        system: System message.
        context: Context message.
        sender: Sender identifier.
        recipient: Recipient identifier.
        requested_fields: Fields requested in the response.
        form: Optional form to validate the response into.
        tools: Tools to expose to the model.
        images: Image payload for the instruction.
        invoke_tool: Flag indicating if tools should be invoked.
        return_form: Flag indicating if the form should be returned.
        strict: Flag indicating if strict validation should be applied.
        rulebook: Rulebook instance for validation.
        imodel: Model instance; defaults to ``self.imodel``.
        use_annotation: Flag indicating if annotations should be used.
        branch: Branch instance; defaults to ``self.branch``.
        clear_messages: If True, clear the branch and reset its system
            message before this turn.
        return_branch: Selects the second element of the returned pair.
        kwargs: Forwarded to ``_create_chat_config``.

    Returns:
        tuple: Always a 2-tuple: ``(output, branch)`` when ``return_branch``
        is truthy, otherwise ``(output, output)``. ``_chat`` unpacks this
        shape, so it must not be changed independently.
    """
    branch = branch or self.branch
    if clear_messages:
        branch.clear()
        branch.set_system(system)

    chat_config = self._create_chat_config(
        system=system,
        instruction=instruction,
        context=context,
        sender=sender,
        recipient=recipient,
        requested_fields=requested_fields,
        form=form,
        tools=tools,
        branch=branch,
        images=images,
        **kwargs,
    )

    payload, completion = await self._call_chatcompletion(
        imodel=imodel, branch=branch, **chat_config
    )

    model = imodel or self.imodel
    result = await self._output(
        payload=payload,
        completion=completion,
        sender=sender,
        invoke_tool=invoke_tool,
        requested_fields=requested_fields,
        form=form,
        return_form=return_form,
        strict=strict,
        rulebook=rulebook,
        use_annotation=use_annotation,
        costs=model.costs,
    )

    # NOTE: the historical expression ``out_, branch if return_branch else
    # out_`` always built a pair; that pair is spelled out explicitly here.
    return (result, branch) if return_branch else (result, result)
409
+
410
async def _chat(
    self,
    instruction=None,
    context=None,
    system=None,
    sender=None,
    recipient=None,
    branch=None,
    requested_fields=None,
    form: Form = None,
    tools=False,
    invoke_tool=True,
    return_form=True,
    strict=False,
    rulebook=None,
    imodel=None,
    images: Optional[str] = None,
    clear_messages=False,
    use_annotation=True,
    timeout: float = None,
    return_branch=False,
    **kwargs,
):
    """
    Handles the chat operation.

    Args:
        instruction: Instruction message.
        context: Context message.
        system: System message.
        sender: Sender identifier.
        recipient: Recipient identifier.
        branch: Branch instance.
        requested_fields: Fields requested in the response.
        form: Form data.
        tools: Tools to expose to the model.
        invoke_tool: Flag indicating if tools should be invoked.
        return_form: Flag indicating if form should be returned.
        strict: Flag indicating if strict validation should be applied.
        rulebook: Rulebook instance for validation.
        imodel: Model instance.
        images: Image payload for the instruction.
        clear_messages: Flag indicating if messages should be cleared.
        use_annotation: Flag indicating if annotations should be used.
        timeout: Timeout value; forwarded through ``_base_chat``'s kwargs.
        return_branch: If truthy, return ``(output, branch)``.
        kwargs: Additional keyword arguments.

    Returns:
        Any: The processed output (a tuple output is unwrapped to its first
        element), or ``(output, branch)`` when ``return_branch`` is truthy.
    """
    result = await self._base_chat(
        context=context,
        instruction=instruction,
        system=system,
        sender=sender,
        recipient=recipient,
        requested_fields=requested_fields,
        form=form,
        tools=tools,
        images=images,
        invoke_tool=invoke_tool,
        return_form=return_form,
        strict=strict,
        rulebook=rulebook,
        imodel=imodel,
        use_annotation=use_annotation,
        timeout=timeout,
        branch=branch,
        clear_messages=clear_messages,
        return_branch=return_branch,
        **kwargs,
    )

    if isinstance(result, str):
        return result

    result = list(result)
    if not result:
        return None

    output = result[0]
    # ``_base_chat`` yields ``(output, branch)`` when ``return_branch`` is set
    # and ``(output, output)`` otherwise. Dispatch on the flag rather than on
    # ``result[0] == result[1]``, which breaks for values whose equality is
    # non-reflexive or not a plain bool, and could fall through to None.
    if return_branch and len(result) > 1:
        return output, result[1]
    if isinstance(output, tuple) and output:
        return output[0]
    return output
496
+
497
async def _direct(
    self,
    instruction=None,
    context=None,
    form: Form = None,
    branch=None,
    tools=None,
    reason: bool = None,
    predict: bool = None,
    score: bool = None,
    select: bool = None,
    plan: bool = None,
    allow_action: bool = None,
    allow_extension: bool = None,
    confidence: bool = None,
    max_extension: int = None,
    score_num_digits=None,
    score_range=None,
    select_choices=None,
    plan_num_step=None,
    predict_num_sentences=None,
    clear_messages=False,
    return_branch=False,
    images: Optional[str] = None,
    verbose=None,
    **kwargs,
):
    """
    Run a directed operation via ``self._base_direct`` and unpack its result.

    Every argument is forwarded unchanged to ``_base_direct``.

    Args:
        instruction: Instruction message.
        context: Context message.
        form: Form data.
        branch: Branch instance.
        tools: Tools data.
        reason: Whether to include a reason.
        predict: Whether to include a prediction.
        score: Whether to include a score.
        select: Whether to include a selection.
        plan: Whether to include a plan.
        allow_action: Whether actions are allowed.
        allow_extension: Whether extension is allowed.
        confidence: Whether to include confidence.
        max_extension: Maximum extension value.
        score_num_digits: Number of digits for the score.
        score_range: Range for the score.
        select_choices: Choices for the selection.
        plan_num_step: Number of plan steps.
        predict_num_sentences: Number of sentences for the prediction.
        clear_messages: Whether messages should be cleared.
        return_branch: Whether the branch should be returned.
        images: Image payload.
        verbose: Verbosity option forwarded to ``_base_direct``.
        kwargs: Additional keyword arguments.

    Returns:
        Any: If the result has exactly two equal items, that item (unwrapped
        to its first element when it is a tuple). Otherwise, the first two
        items as a tuple.
    """
    outcome = await self._base_direct(
        instruction=instruction,
        context=context,
        form=form,
        branch=branch,
        tools=tools,
        reason=reason,
        predict=predict,
        score=score,
        select=select,
        images=images,
        plan=plan,
        allow_action=allow_action,
        allow_extension=allow_extension,
        confidence=confidence,
        max_extension=max_extension,
        score_num_digits=score_num_digits,
        score_range=score_range,
        select_choices=select_choices,
        plan_num_step=plan_num_step,
        predict_num_sentences=predict_num_sentences,
        clear_messages=clear_messages,
        return_branch=return_branch,
        verbose=verbose,
        **kwargs,
    )

    items = list(outcome)
    if len(items) == 2 and items[0] == items[1]:
        head = items[0]
        return head[0] if isinstance(head, tuple) else head

    return items[0], items[1]
586
+
587
async def _base_direct(
    self,
    instruction=None,
    *,
    context=None,
    form: "Form" = None,
    branch=None,
    tools=None,
    reason: bool = None,
    predict: bool = None,
    score: bool = None,
    select: bool = None,
    plan: bool = None,
    allow_action: bool = None,
    allow_extension: bool = None,
    confidence: bool = None,
    max_extension: int = None,
    score_num_digits=None,
    score_range=None,
    select_choices=None,
    plan_num_step=None,
    predict_num_sentences=None,
    clear_messages=False,
    return_branch=False,
    images: Optional[str] = None,
    verbose=None,
    **kwargs,
):
    """
    Handle one directive round-trip with the model.

    Builds a form via ``self.default_template`` unless one is supplied, chats
    with the model on ``branch``, then optionally executes requested actions
    (``self._act``) and runs follow-up extension rounds (``self._extend``).

    Args:
        instruction: Instruction message.
        context: Context message.
        form: Pre-built form. When given, ``tool_schema`` is attached to it
            instead of building a new form.
        branch: Branch instance; defaults to ``self.branch``.
        tools: Tools handed to ``process_tools(tools, branch)``; their schema
            is exposed to the model.
        reason: Flag indicating if reason should be included.
        predict: Flag indicating if prediction should be included.
        score: Flag indicating if score should be included.
        select: Flag indicating if selection should be included.
        plan: Flag indicating if plan should be included (template only;
            extensions take their plan from the previous form).
        allow_action: Whether action requests in the response are executed.
        allow_extension: Whether extension rounds are allowed.
        confidence: Flag indicating if confidence should be included.
        max_extension: Upper bound on extension rounds; falls back to 3 when
            falsy (with extensions allowed) or not an int.
        score_num_digits: Number of digits for score.
        score_range: Range for score.
        select_choices: Choices for selection.
        plan_num_step: Number of steps for plan.
        predict_num_sentences: Number of sentences for prediction.
        clear_messages: Clear the branch's messages before starting.
        return_branch: Return the branch alongside the form.
        images: Image data forwarded to ``self._chat``.
        verbose: Print progress; any non-bool value falls back to
            ``self.verbose``.
        **kwargs: Extra keyword arguments forwarded to ``self._chat`` and
            ``self._extend``.

    Returns:
        tuple: ``(form, branch)`` if ``return_branch`` else ``(form, form)``.
        The public ``direct`` wrapper unpacks this pair and collapses the
        duplicated form, so the result must always be a 2-tuple.
    """
    # Ensure branch is initialized
    branch = branch or self.branch
    if clear_messages:
        branch.clear()

    # Default recursion limit when extensions are allowed but no limit given
    if allow_extension and not max_extension:
        max_extension = 3

    if tools:
        process_tools(tools, branch)

    # NOTE(review): with allow_action but no explicit tools, `True` is passed
    # to get_tool_schema -- presumably meaning "all registered tools"; confirm.
    if allow_action and not tools:
        tools = True

    tool_schema = None
    if tools:
        tool_schema = branch.tool_manager.get_tool_schema(tools)

    if not form:
        form = self.default_template(
            instruction=instruction,
            context=context,
            reason=reason,
            predict=predict,
            score=score,
            select=select,
            plan=plan,
            tool_schema=tool_schema,
            allow_action=allow_action,
            allow_extension=allow_extension,
            max_extension=max_extension,
            confidence=confidence,
            score_num_digits=score_num_digits,
            score_range=score_range,
            select_choices=select_choices,
            plan_num_step=plan_num_step,
            predict_num_sentences=predict_num_sentences,
        )
    elif "tool_schema" not in form._all_fields:
        form.append_to_input("tool_schema")
        form.tool_schema = tool_schema
    else:
        form.tool_schema = tool_schema

    verbose = verbose if isinstance(verbose, bool) else self.verbose
    if verbose:
        print("Chatting with model...")

    # Call the base chat method
    form = await self._chat(
        form=form,
        branch=branch,
        images=images,
        **kwargs,
    )

    # Handle actions if allowed and required
    if allow_action and getattr(form, "action_required", None):
        actions = getattr(form, "actions", None)
        if actions:
            if verbose:
                print(
                    "Found action requests in model response. Processing actions..."
                )
            form = await self._act(form, branch, actions=actions)
            if verbose:
                print("Actions processed!")

    last_form = form
    ctr = 1  # 1-based number of the next extension; used for progress output

    # Handle extensions if allowed and required
    extension_forms = []
    max_extension = max_extension if isinstance(max_extension, int) else 3
    while (
        allow_extension
        and max_extension > 0
        and getattr(last_form, "extension_required", None)
    ):
        # Extension forms must not be extended again (prevents infinite loops)
        if getattr(last_form, "is_extension", None):
            break
        if verbose:
            print("\nFound extension requests in model response.")
            print(
                f"------------------- Processing extension No.{ctr} -------------------"
            )

        max_extension -= 1

        new_form = await self._extend(
            tools=tools,
            reason=reason,
            predict=predict,
            score=score,
            select=select,
            plan=getattr(last_form, "plan", None),
            allow_action=allow_action,
            confidence=confidence,
            score_num_digits=score_num_digits,
            score_range=score_range,
            select_choices=select_choices,
            predict_num_sentences=predict_num_sentences,
            **kwargs,
        )

        if verbose:
            print("------------------- Extension completed -------------------\n")

        # _extend returns a list; tolerate a single form defensively.
        new_forms = new_form if isinstance(new_form, list) else [new_form]
        if not new_forms:
            break
        extension_forms.extend(new_forms)
        last_form = new_forms[-1]
        # Advance by the number of extension forms just produced.
        ctr += len(new_forms)

    if extension_forms:
        if not getattr(form, "extension_forms", None):
            form._add_field("extension_forms", list, None, [])
        form.extension_forms.extend(extension_forms)

        action_responses = [
            i.action_response
            for i in extension_forms
            if getattr(i, "action_response", None) is not None
        ]
        if not hasattr(form, "action_response"):
            form.add_field("action_response", {})

        # NOTE(review): nmerge is assumed to return the merged structure rather
        # than merging in place (its result used to be discarded); confirm.
        for action_response in action_responses:
            form.action_response = nmerge([form.action_response, action_response])

    current_answer = getattr(form, "answer", None)
    if isinstance(current_answer, str) and "PLEASE_ACTION" in current_answer:
        if verbose:
            print("Analyzing action responses and generating answer...")

        # Ask on the same branch that holds the action results.
        answer = await self._chat(
            "please provide final answer basing on the above"
            " information, provide answer value as a string only"
            " do not return as json, do not include other information",
            branch=branch,
        )

        if isinstance(answer, dict):
            a = answer.get("answer", None)
            if a is not None:
                answer = a

        answer = str(answer).strip()
        if answer.startswith("{") and answer.endswith("}"):
            answer = answer[1:-1].strip()
        # str.replace returns a new string, so the result must be reassigned.
        if '"answer":' in answer:
            answer = answer.replace('"answer":', "").strip()
        elif "'answer':" in answer:
            answer = answer.replace("'answer':", "").strip()

        form.answer = answer

    return (form, branch) if return_branch else (form, form)
811
+
812
async def _extend(
    self,
    tools,
    reason,
    predict,
    score,
    select,
    plan,
    allow_action,
    confidence,
    score_num_digits,
    score_range,
    select_choices,
    predict_num_sentences,
    **kwargs,
):
    """
    Run follow-up directive rounds for a form that requested an extension.

    With a ``plan``, one ``self._direct`` call is made per plan step (in
    ``step_1``, ``step_2``, ... order), stopping early once a step no longer
    requires extension. Without a plan, a single ``self._direct`` call is made.
    Every produced form is marked ``is_extension = True``.

    Args:
        tools: Tools data.
        reason: Flag indicating if reason should be included.
        predict: Flag indicating if prediction should be included.
        score: Flag indicating if score should be included.
        select: Flag indicating if selection should be included.
        plan: Plan whose steps become the instructions of successive rounds.
        allow_action: Flag indicating if action should be allowed.
        confidence: Flag indicating if confidence should be included.
        score_num_digits: Number of digits for score.
        score_range: Range for score.
        select_choices: Choices for selection.
        predict_num_sentences: Number of sentences for prediction.
        **kwargs: Extra keyword arguments forwarded to ``self._direct``.

    Returns:
        list: The extension forms, in the order they were produced.
    """
    extension_forms = []

    directive_kwargs = {
        "tools": tools,
        "reason": reason,
        "predict": predict,
        "score": score,
        "select": select,
        "allow_action": allow_action,
        "confidence": confidence,
        "score_num_digits": score_num_digits,
        "score_range": score_range,
        "select_choices": select_choices,
        "predict_num_sentences": predict_num_sentences,
        **kwargs,
    }

    if plan:
        keys = [f"step_{i+1}" for i in range(len(plan))]
        # NOTE(review): presumably coerces `plan` into a dict keyed by `keys`.
        plan = StringMatch.force_validate_dict(plan, keys)

        for i in keys:
            directive_kwargs["instruction"] = plan[i]
            last_form = await self._direct(**directive_kwargs)
            last_form.is_extension = True
            extension_forms.append(last_form)
            # "max_extension" is only present if a caller passed it via
            # **kwargs; decrementing it unconditionally raised KeyError.
            if isinstance(directive_kwargs.get("max_extension"), int):
                directive_kwargs["max_extension"] -= 1
            if not getattr(last_form, "extension_required", None):
                break

    else:
        # Handle single step extension
        last_form = await self._direct(**directive_kwargs)
        last_form.is_extension = True
        extension_forms.append(last_form)

    return extension_forms
892
+
893
async def _act(self, form, branch, actions=None):
    """
    Execute the action requests found in a model response and record the
    results on the form.

    Args:
        form: Form produced by the model; returned unchanged if its
            ``action_performed`` attribute is already ``True``.
        branch: Branch that receives the request messages and whose
            ``tool_manager.registry`` resolves each requested function.
        actions: Action requests; each validated entry is read for its
            ``"function"`` and ``"arguments"`` keys.

    Returns:
        The same form, with ``action_response`` populated (keys
        ``action_1``, ``action_2``, ...; colliding keys get ``_1`` suffixes)
        and ``action_performed`` set to ``True`` when new responses were added.

    Raises:
        ValueError: If no request could be built or executed, or if any step
            of processing fails; the underlying error is chained as the cause.
    """
    if getattr(form, "action_performed", None) is True:
        return form

    keys = [f"action_{i+1}" for i in range(len(actions))]
    # NOTE(review): presumably coerces `actions` into a dict keyed by `keys`.
    actions = StringMatch.force_validate_dict(actions, keys)

    try:
        requests = []
        for k in keys:
            _func = actions[k]["function"].replace("functions.", "")
            msg = ActionRequest(
                function=_func,
                arguments=actions[k]["arguments"],
                sender=branch.ln_id,
                recipient=branch.tool_manager.registry[_func].ln_id,
            )
            requests.append(msg)
            branch.add_message(action_request=msg)

        # Without this guard, `out` below was referenced before assignment.
        if not requests:
            raise ValueError("No requests found.")

        out = await self._process_action_request(
            branch=branch, invoke_tool=True, action_request=requests
        )
        if out is False:
            raise ValueError("No requests found.")

        # NOTE(review): assumes the responses to these requests are the
        # ActionResponse messages among the last len(actions) branch messages.
        len_actions = len(actions)
        action_responses = [
            i
            for i in branch.messages[-len_actions:]
            if isinstance(i, ActionResponse)
        ]

        _action_responses = {
            f"action_{idx+1}": item._to_dict()
            for idx, item in enumerate(action_responses)
        }

        form.append_to_request("action_response")
        if getattr(form, "action_response", None) is None:
            form.add_field("action_response", {})

        len1 = len(form.action_response)
        for k, v in _action_responses.items():
            # Never overwrite responses recorded by an earlier round.
            while k in form.action_response:
                k = f"{k}_1"
            form.action_response[k] = v

        if len(form.action_response) > len1:
            form.append_to_request("action_performed")
            form.action_performed = True
        return form

    except Exception as e:
        raise ValueError(f"Error processing action request: {e}") from e
961
+
962
async def _select(
    self,
    form=None,
    choices=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    template=None,
    context=None,
    branch=None,
    **kwargs,
):
    """
    Ask the model to pick among ``choices``.

    If ``form`` is falsy, a new form is built by calling ``template`` with the
    selection parameters; no fallback template is applied here, so
    ``template`` must be supplied in that case.

    Args:
        form: Pre-built form to send as-is.
        choices: Candidate options, forwarded to ``template``.
        reason: Forwarded to ``template``.
        confidence_score: Forwarded to ``template``.
        instruction: Forwarded to ``template``.
        template: Callable that builds the form.
        context: Forwarded to ``template``.
        branch: Branch to chat on; defaults to ``self.branch``.
        **kwargs: Extra keyword arguments forwarded to ``self._chat``.

    Returns:
        Whatever ``self._chat`` returns when called with ``return_form=True``.
    """
    active_branch = branch or self.branch

    if not form:
        template_kwargs = dict(
            choices=choices,
            reason=reason,
            confidence_score=confidence_score,
            instruction=instruction,
            context=context,
        )
        form = template(**template_kwargs)

    return await self._chat(
        form=form, return_form=True, branch=active_branch, **kwargs
    )
1003
+
1004
async def _predict(
    self,
    form=None,
    num_sentences=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    context=None,
    branch=None,
    template=None,
    **kwargs,
):
    """
    Ask the model for a prediction.

    A truthy ``form`` is sent as-is; otherwise ``template`` is called with the
    prediction parameters to build one (``template`` must then be provided).

    Args:
        form: Pre-built form.
        num_sentences: Forwarded to ``template``.
        reason: Forwarded to ``template``.
        confidence_score: Forwarded to ``template``.
        instruction: Forwarded to ``template``.
        context: Forwarded to ``template``.
        branch: Branch to chat on; defaults to ``self.branch``.
        template: Callable that builds the form.
        **kwargs: Extra keyword arguments forwarded to ``self._chat``.

    Returns:
        Whatever ``self._chat`` returns when called with ``return_form=True``.
    """
    chat_branch = branch or self.branch

    if form:
        prediction_form = form
    else:
        prediction_form = template(
            instruction=instruction,
            context=context,
            num_sentences=num_sentences,
            confidence_score=confidence_score,
            reason=reason,
        )

    return await self._chat(
        form=prediction_form,
        return_form=True,
        branch=chat_branch,
        **kwargs,
    )
1045
+
1046
async def _score(
    self,
    form=None,
    score_range=None,
    include_endpoints=None,
    num_digit=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    context=None,
    branch=None,
    template=None,
    **kwargs,
):
    """
    Ask the model to score something.

    A truthy ``form`` is sent as-is; otherwise ``template`` is called with the
    scoring parameters to build one (``template`` must then be provided).

    Args:
        form: Pre-built form.
        score_range: Forwarded to ``template``.
        include_endpoints: Forwarded to ``template``.
        num_digit: Forwarded to ``template``.
        reason: Forwarded to ``template``.
        confidence_score: Forwarded to ``template``.
        instruction: Forwarded to ``template``.
        context: Forwarded to ``template``.
        branch: Branch to chat on; defaults to ``self.branch``.
        template: Callable that builds the form.
        **kwargs: Extra keyword arguments forwarded to ``self._chat``.

    Returns:
        Whatever ``self._chat`` returns when called with ``return_form=True``.
    """
    target = branch or self.branch

    scoring_form = form
    if not scoring_form:
        scoring_form = template(
            score_range=score_range,
            include_endpoints=include_endpoints,
            num_digit=num_digit,
            reason=reason,
            confidence_score=confidence_score,
            instruction=instruction,
            context=context,
        )

    chat_kwargs = dict(form=scoring_form, return_form=True, branch=target)
    return await self._chat(**chat_kwargs, **kwargs)
1092
+
1093
async def _plan(
    self,
    form=None,
    num_step=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    context=None,
    branch=None,
    template=None,
    **kwargs,
):
    """
    Ask the model for a plan.

    A truthy ``form`` is sent as-is; otherwise one is built by calling
    ``template`` (falling back to ``self.default_template``) with the planning
    parameters.

    Args:
        form: Pre-built form.
        num_step: Forwarded to the template.
        reason: Forwarded to the template.
        confidence_score: Forwarded to the template.
        instruction: Forwarded to the template.
        context: Forwarded to the template.
        branch: Branch to chat on; defaults to ``self.branch``.
        template: Callable that builds the form; defaults to
            ``self.default_template``.
        **kwargs: Extra keyword arguments forwarded to ``self._chat``.

    Returns:
        Whatever ``self._chat`` returns.
    """
    branch = branch or self.branch
    template = template or self.default_template

    if not form:
        form = template(
            instruction=instruction,
            context=context,
            num_step=num_step,
            reason=reason,
            confidence_score=confidence_score,
        )

    # Chat on the resolved branch; previously `branch` was computed but never
    # passed, so planning always ran on the default branch.
    return await self._chat(form=form, branch=branch, **kwargs)
1135
+
1136
@staticmethod
def _process_model_response(content_, requested_fields):
    """
    Best-effort conversion of a model response into structured data.

    Takes ``content_["content"]`` (or ``content_`` itself when that key is
    missing or empty). It then tries, in order, each parser below. The first
    parser that does not raise wins, and its result is returned even if that
    result is falsy:

    1. ``StringMatch.force_validate_dict`` (only when ``requested_fields``).
    2. ``ParseUtil.fuzzy_parse_json`` (only for string payloads).
    3. ``ParseUtil.extract_json_block``.
    4. A fenced ````json`` object extracted with a regex and fuzzily parsed.

    If every parser fails, the raw payload is returned unchanged.

    Args:
        content_: Mapping holding the model output.
        requested_fields: Fields the response is expected to contain.

    Returns:
        Any: The parsed response, or the raw payload.
    """
    payload = content_.get("content", "")
    if payload == "":
        payload = content_

    parsers = []
    if requested_fields:
        parsers.append(
            lambda: StringMatch.force_validate_dict(payload, requested_fields)
        )
    if isinstance(payload, str):
        parsers.append(lambda: ParseUtil.fuzzy_parse_json(payload))
    parsers.append(lambda: ParseUtil.extract_json_block(payload))

    for parse in parsers:
        try:
            return parse()
        except Exception:
            # Deliberate best-effort: fall through to the next parser.
            continue

    try:
        fenced = re.search(r"```json\n({.*?})\n```", payload, re.DOTALL)
        if fenced:
            return ParseUtil.fuzzy_parse_json(fenced.group(1))
    except Exception:
        pass

    return payload