lionagi 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (268)
  1. lionagi/__init__.py +60 -5
  2. lionagi/core/__init__.py +0 -25
  3. lionagi/core/_setting/_setting.py +59 -0
  4. lionagi/core/action/__init__.py +14 -0
  5. lionagi/core/action/function_calling.py +136 -0
  6. lionagi/core/action/manual.py +1 -0
  7. lionagi/core/action/node.py +109 -0
  8. lionagi/core/action/tool.py +114 -0
  9. lionagi/core/action/tool_manager.py +356 -0
  10. lionagi/core/agent/base_agent.py +27 -13
  11. lionagi/core/agent/eval/evaluator.py +1 -0
  12. lionagi/core/agent/eval/vote.py +40 -0
  13. lionagi/core/agent/learn/learner.py +59 -0
  14. lionagi/core/agent/plan/unit_template.py +1 -0
  15. lionagi/core/collections/__init__.py +17 -0
  16. lionagi/core/{generic/data_logger.py → collections/_logger.py} +69 -55
  17. lionagi/core/collections/abc/__init__.py +53 -0
  18. lionagi/core/collections/abc/component.py +615 -0
  19. lionagi/core/collections/abc/concepts.py +297 -0
  20. lionagi/core/collections/abc/exceptions.py +150 -0
  21. lionagi/core/collections/abc/util.py +45 -0
  22. lionagi/core/collections/exchange.py +161 -0
  23. lionagi/core/collections/flow.py +426 -0
  24. lionagi/core/collections/model.py +419 -0
  25. lionagi/core/collections/pile.py +913 -0
  26. lionagi/core/collections/progression.py +236 -0
  27. lionagi/core/collections/util.py +64 -0
  28. lionagi/core/director/direct.py +314 -0
  29. lionagi/core/director/director.py +2 -0
  30. lionagi/core/{execute/branch_executor.py → engine/branch_engine.py} +134 -97
  31. lionagi/core/{execute/instruction_map_executor.py → engine/instruction_map_engine.py} +80 -55
  32. lionagi/{experimental/directive/evaluator → core/engine}/script_engine.py +17 -1
  33. lionagi/core/executor/base_executor.py +90 -0
  34. lionagi/core/{execute/structure_executor.py → executor/graph_executor.py} +62 -66
  35. lionagi/core/{execute → executor}/neo4j_executor.py +70 -67
  36. lionagi/core/generic/__init__.py +3 -33
  37. lionagi/core/generic/edge.py +29 -79
  38. lionagi/core/generic/edge_condition.py +16 -0
  39. lionagi/core/generic/graph.py +236 -0
  40. lionagi/core/generic/hyperedge.py +1 -0
  41. lionagi/core/generic/node.py +156 -221
  42. lionagi/core/generic/tree.py +48 -0
  43. lionagi/core/generic/tree_node.py +79 -0
  44. lionagi/core/mail/__init__.py +12 -0
  45. lionagi/core/mail/mail.py +25 -0
  46. lionagi/core/mail/mail_manager.py +139 -58
  47. lionagi/core/mail/package.py +45 -0
  48. lionagi/core/mail/start_mail.py +36 -0
  49. lionagi/core/message/__init__.py +19 -0
  50. lionagi/core/message/action_request.py +133 -0
  51. lionagi/core/message/action_response.py +135 -0
  52. lionagi/core/message/assistant_response.py +95 -0
  53. lionagi/core/message/instruction.py +234 -0
  54. lionagi/core/message/message.py +101 -0
  55. lionagi/core/message/system.py +86 -0
  56. lionagi/core/message/util.py +283 -0
  57. lionagi/core/report/__init__.py +4 -0
  58. lionagi/core/report/base.py +217 -0
  59. lionagi/core/report/form.py +231 -0
  60. lionagi/core/report/report.py +166 -0
  61. lionagi/core/report/util.py +28 -0
  62. lionagi/core/rule/_default.py +16 -0
  63. lionagi/core/rule/action.py +99 -0
  64. lionagi/core/rule/base.py +238 -0
  65. lionagi/core/rule/boolean.py +56 -0
  66. lionagi/core/rule/choice.py +47 -0
  67. lionagi/core/rule/mapping.py +96 -0
  68. lionagi/core/rule/number.py +71 -0
  69. lionagi/core/rule/rulebook.py +109 -0
  70. lionagi/core/rule/string.py +52 -0
  71. lionagi/core/rule/util.py +35 -0
  72. lionagi/core/session/branch.py +431 -0
  73. lionagi/core/session/directive_mixin.py +287 -0
  74. lionagi/core/session/session.py +229 -903
  75. lionagi/core/structure/__init__.py +1 -0
  76. lionagi/core/structure/chain.py +1 -0
  77. lionagi/core/structure/forest.py +1 -0
  78. lionagi/core/structure/graph.py +1 -0
  79. lionagi/core/structure/tree.py +1 -0
  80. lionagi/core/unit/__init__.py +5 -0
  81. lionagi/core/unit/parallel_unit.py +245 -0
  82. lionagi/core/unit/template/action.py +81 -0
  83. lionagi/core/unit/template/base.py +51 -0
  84. lionagi/core/unit/template/plan.py +84 -0
  85. lionagi/core/unit/template/predict.py +109 -0
  86. lionagi/core/unit/template/score.py +124 -0
  87. lionagi/core/unit/template/select.py +104 -0
  88. lionagi/core/unit/unit.py +362 -0
  89. lionagi/core/unit/unit_form.py +305 -0
  90. lionagi/core/unit/unit_mixin.py +1168 -0
  91. lionagi/core/unit/util.py +71 -0
  92. lionagi/core/validator/validator.py +364 -0
  93. lionagi/core/work/work.py +74 -0
  94. lionagi/core/work/work_function.py +92 -0
  95. lionagi/core/work/work_queue.py +81 -0
  96. lionagi/core/work/worker.py +195 -0
  97. lionagi/core/work/worklog.py +124 -0
  98. lionagi/experimental/compressor/base.py +46 -0
  99. lionagi/experimental/compressor/llm_compressor.py +247 -0
  100. lionagi/experimental/compressor/llm_summarizer.py +61 -0
  101. lionagi/experimental/compressor/util.py +70 -0
  102. lionagi/experimental/directive/__init__.py +19 -0
  103. lionagi/experimental/directive/parser/base_parser.py +69 -2
  104. lionagi/experimental/directive/{template_ → template}/base_template.py +17 -1
  105. lionagi/{libs/ln_tokenizer.py → experimental/directive/tokenizer.py} +16 -0
  106. lionagi/experimental/{directive/evaluator → evaluator}/ast_evaluator.py +16 -0
  107. lionagi/experimental/{directive/evaluator → evaluator}/base_evaluator.py +16 -0
  108. lionagi/experimental/knowledge/base.py +10 -0
  109. lionagi/experimental/memory/__init__.py +0 -0
  110. lionagi/experimental/strategies/__init__.py +0 -0
  111. lionagi/experimental/strategies/base.py +1 -0
  112. lionagi/integrations/bridge/langchain_/documents.py +4 -0
  113. lionagi/integrations/bridge/llamaindex_/index.py +30 -0
  114. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +6 -0
  115. lionagi/integrations/chunker/chunk.py +161 -24
  116. lionagi/integrations/config/oai_configs.py +34 -3
  117. lionagi/integrations/config/openrouter_configs.py +14 -2
  118. lionagi/integrations/loader/load.py +122 -21
  119. lionagi/integrations/loader/load_util.py +6 -77
  120. lionagi/integrations/provider/_mapping.py +46 -0
  121. lionagi/integrations/provider/litellm.py +2 -1
  122. lionagi/integrations/provider/mlx_service.py +16 -9
  123. lionagi/integrations/provider/oai.py +91 -4
  124. lionagi/integrations/provider/ollama.py +6 -5
  125. lionagi/integrations/provider/openrouter.py +115 -8
  126. lionagi/integrations/provider/services.py +2 -2
  127. lionagi/integrations/provider/transformers.py +18 -22
  128. lionagi/integrations/storage/__init__.py +3 -3
  129. lionagi/integrations/storage/neo4j.py +52 -60
  130. lionagi/integrations/storage/storage_util.py +44 -46
  131. lionagi/integrations/storage/structure_excel.py +43 -26
  132. lionagi/integrations/storage/to_excel.py +11 -4
  133. lionagi/libs/__init__.py +22 -1
  134. lionagi/libs/ln_api.py +75 -20
  135. lionagi/libs/ln_context.py +37 -0
  136. lionagi/libs/ln_convert.py +21 -9
  137. lionagi/libs/ln_func_call.py +69 -28
  138. lionagi/libs/ln_image.py +107 -0
  139. lionagi/libs/ln_nested.py +26 -11
  140. lionagi/libs/ln_parse.py +82 -23
  141. lionagi/libs/ln_queue.py +16 -0
  142. lionagi/libs/ln_tokenize.py +164 -0
  143. lionagi/libs/ln_validate.py +16 -0
  144. lionagi/libs/special_tokens.py +172 -0
  145. lionagi/libs/sys_util.py +95 -24
  146. lionagi/lions/coder/code_form.py +13 -0
  147. lionagi/lions/coder/coder.py +50 -3
  148. lionagi/lions/coder/util.py +30 -25
  149. lionagi/tests/libs/test_func_call.py +23 -21
  150. lionagi/tests/libs/test_nested.py +36 -21
  151. lionagi/tests/libs/test_parse.py +1 -1
  152. lionagi/tests/test_core/collections/__init__.py +0 -0
  153. lionagi/tests/test_core/collections/test_component.py +206 -0
  154. lionagi/tests/test_core/collections/test_exchange.py +138 -0
  155. lionagi/tests/test_core/collections/test_flow.py +145 -0
  156. lionagi/tests/test_core/collections/test_pile.py +171 -0
  157. lionagi/tests/test_core/collections/test_progression.py +129 -0
  158. lionagi/tests/test_core/generic/test_edge.py +67 -0
  159. lionagi/tests/test_core/generic/test_graph.py +96 -0
  160. lionagi/tests/test_core/generic/test_node.py +106 -0
  161. lionagi/tests/test_core/generic/test_tree_node.py +73 -0
  162. lionagi/tests/test_core/test_branch.py +115 -294
  163. lionagi/tests/test_core/test_form.py +46 -0
  164. lionagi/tests/test_core/test_report.py +105 -0
  165. lionagi/tests/test_core/test_validator.py +111 -0
  166. lionagi/version.py +1 -1
  167. lionagi-0.2.0.dist-info/LICENSE +202 -0
  168. lionagi-0.2.0.dist-info/METADATA +272 -0
  169. lionagi-0.2.0.dist-info/RECORD +240 -0
  170. lionagi/core/branch/base.py +0 -653
  171. lionagi/core/branch/branch.py +0 -474
  172. lionagi/core/branch/flow_mixin.py +0 -96
  173. lionagi/core/branch/util.py +0 -323
  174. lionagi/core/direct/__init__.py +0 -19
  175. lionagi/core/direct/cot.py +0 -123
  176. lionagi/core/direct/plan.py +0 -164
  177. lionagi/core/direct/predict.py +0 -166
  178. lionagi/core/direct/react.py +0 -171
  179. lionagi/core/direct/score.py +0 -279
  180. lionagi/core/direct/select.py +0 -170
  181. lionagi/core/direct/sentiment.py +0 -1
  182. lionagi/core/direct/utils.py +0 -110
  183. lionagi/core/direct/vote.py +0 -64
  184. lionagi/core/execute/base_executor.py +0 -47
  185. lionagi/core/flow/baseflow.py +0 -23
  186. lionagi/core/flow/monoflow/ReAct.py +0 -240
  187. lionagi/core/flow/monoflow/__init__.py +0 -9
  188. lionagi/core/flow/monoflow/chat.py +0 -95
  189. lionagi/core/flow/monoflow/chat_mixin.py +0 -253
  190. lionagi/core/flow/monoflow/followup.py +0 -215
  191. lionagi/core/flow/polyflow/__init__.py +0 -1
  192. lionagi/core/flow/polyflow/chat.py +0 -251
  193. lionagi/core/form/action_form.py +0 -26
  194. lionagi/core/form/field_validator.py +0 -287
  195. lionagi/core/form/form.py +0 -302
  196. lionagi/core/form/mixin.py +0 -214
  197. lionagi/core/form/scored_form.py +0 -13
  198. lionagi/core/generic/action.py +0 -26
  199. lionagi/core/generic/component.py +0 -532
  200. lionagi/core/generic/condition.py +0 -46
  201. lionagi/core/generic/mail.py +0 -90
  202. lionagi/core/generic/mailbox.py +0 -36
  203. lionagi/core/generic/relation.py +0 -70
  204. lionagi/core/generic/signal.py +0 -22
  205. lionagi/core/generic/structure.py +0 -362
  206. lionagi/core/generic/transfer.py +0 -20
  207. lionagi/core/generic/work.py +0 -40
  208. lionagi/core/graph/graph.py +0 -126
  209. lionagi/core/graph/tree.py +0 -190
  210. lionagi/core/mail/schema.py +0 -63
  211. lionagi/core/messages/schema.py +0 -325
  212. lionagi/core/tool/__init__.py +0 -5
  213. lionagi/core/tool/tool.py +0 -28
  214. lionagi/core/tool/tool_manager.py +0 -283
  215. lionagi/experimental/report/form.py +0 -64
  216. lionagi/experimental/report/report.py +0 -138
  217. lionagi/experimental/report/util.py +0 -47
  218. lionagi/experimental/tool/function_calling.py +0 -43
  219. lionagi/experimental/tool/manual.py +0 -66
  220. lionagi/experimental/tool/schema.py +0 -59
  221. lionagi/experimental/tool/tool_manager.py +0 -138
  222. lionagi/experimental/tool/util.py +0 -16
  223. lionagi/experimental/validator/rule.py +0 -139
  224. lionagi/experimental/validator/validator.py +0 -56
  225. lionagi/experimental/work/__init__.py +0 -10
  226. lionagi/experimental/work/async_queue.py +0 -54
  227. lionagi/experimental/work/schema.py +0 -73
  228. lionagi/experimental/work/work_function.py +0 -67
  229. lionagi/experimental/work/worker.py +0 -56
  230. lionagi/experimental/work2/form.py +0 -371
  231. lionagi/experimental/work2/report.py +0 -289
  232. lionagi/experimental/work2/schema.py +0 -30
  233. lionagi/experimental/work2/tests.py +0 -72
  234. lionagi/experimental/work2/work_function.py +0 -89
  235. lionagi/experimental/work2/worker.py +0 -12
  236. lionagi/integrations/bridge/llamaindex_/get_index.py +0 -294
  237. lionagi/tests/test_core/generic/test_component.py +0 -89
  238. lionagi/tests/test_core/test_base_branch.py +0 -426
  239. lionagi/tests/test_core/test_chat_flow.py +0 -63
  240. lionagi/tests/test_core/test_mail_manager.py +0 -75
  241. lionagi/tests/test_core/test_prompts.py +0 -51
  242. lionagi/tests/test_core/test_session.py +0 -254
  243. lionagi/tests/test_core/test_session_base_util.py +0 -313
  244. lionagi/tests/test_core/test_tool_manager.py +0 -95
  245. lionagi-0.1.2.dist-info/LICENSE +0 -9
  246. lionagi-0.1.2.dist-info/METADATA +0 -174
  247. lionagi-0.1.2.dist-info/RECORD +0 -206
  248. /lionagi/core/{branch → _setting}/__init__.py +0 -0
  249. /lionagi/core/{execute → agent/eval}/__init__.py +0 -0
  250. /lionagi/core/{flow → agent/learn}/__init__.py +0 -0
  251. /lionagi/core/{form → agent/plan}/__init__.py +0 -0
  252. /lionagi/core/{branch/executable_branch.py → agent/plan/plan.py} +0 -0
  253. /lionagi/core/{graph → director}/__init__.py +0 -0
  254. /lionagi/core/{messages → engine}/__init__.py +0 -0
  255. /lionagi/{experimental/directive/evaluator → core/engine}/sandbox_.py +0 -0
  256. /lionagi/{experimental/directive/evaluator → core/executor}/__init__.py +0 -0
  257. /lionagi/{experimental/directive/template_ → core/rule}/__init__.py +0 -0
  258. /lionagi/{experimental/report → core/unit/template}/__init__.py +0 -0
  259. /lionagi/{experimental/tool → core/validator}/__init__.py +0 -0
  260. /lionagi/{experimental/validator → core/work}/__init__.py +0 -0
  261. /lionagi/experimental/{work2 → compressor}/__init__.py +0 -0
  262. /lionagi/{core/flow/mono_chat_mixin.py → experimental/directive/template/__init__.py} +0 -0
  263. /lionagi/experimental/directive/{schema.py → template/schema.py} +0 -0
  264. /lionagi/experimental/{work2/util.py → evaluator/__init__.py} +0 -0
  265. /lionagi/experimental/{work2/work.py → knowledge/__init__.py} +0 -0
  266. /lionagi/{tests/libs/test_async.py → experimental/knowledge/graph.py} +0 -0
  267. {lionagi-0.1.2.dist-info → lionagi-0.2.0.dist-info}/WHEEL +0 -0
  268. {lionagi-0.1.2.dist-info → lionagi-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1168 @@
1
+ """
2
+ Copyright 2024 HaiyangLi
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ """
15
+
16
+ """
17
+ The base directive module.
18
+ """
19
+
20
+ import asyncio
21
+ import contextlib
22
+ import re
23
+ from abc import ABC
24
+
25
+ from typing import Any, Optional
26
+
27
+ from lionagi.libs import ParseUtil, StringMatch, to_list
28
+ from lionagi.libs.ln_nested import nmerge
29
+ from lionagi.core.collections.abc import ActionError
30
+ from lionagi.core.message import ActionRequest, ActionResponse, Instruction
31
+ from lionagi.core.message.util import _parse_action_request
32
+ from lionagi.core.report.form import Form
33
+ from lionagi.core.unit.util import process_tools
34
+ from lionagi.core.validator.validator import Validator
35
+
36
+
37
+ class DirectiveMixin(ABC):
38
+ """
39
+ DirectiveMixin is a class for handling chat operations and
40
+ processing responses.
41
+ """
42
+
43
def _create_chat_config(
    self,
    system: Optional[str] = None,
    instruction: Optional[str] = None,
    context: Optional[str] = None,
    images: Optional[str] = None,
    sender: Optional[str] = None,
    recipient: Optional[str] = None,
    requested_fields: Optional[list] = None,
    form: "Form" = None,
    tools: bool = False,
    branch: Optional[Any] = None,
    **kwargs,
) -> Any:
    """
    Add the outgoing messages to a branch and assemble the model config.

    Message side effects, in order:
      1. If ``system`` is truthy it is added to the branch as a system message.
      2. Without a ``form``, an instruction message is built from the
         individual arguments; the placeholder recipient ``"branch.ln_id"``
         is replaced by the branch's real ``ln_id``. With a ``form``, the
         instruction comes from ``Instruction.from_form(form)`` and the
         individual instruction arguments are not used.

    Tool handling:
      * If ``tool_parsed`` is present in ``kwargs`` it is dropped and
        ``tools`` is forwarded verbatim (presumably because the caller has
        already parsed it — inferred from the flag name).
      * Otherwise, truthy ``tools`` on a branch that has tools are expanded
        by ``branch.tool_manager.parse_tool``, whose result replaces
        ``kwargs``.

    Args:
        system: System message content.
        instruction: Instruction text.
        context: Extra context attached to the instruction.
        images: Image payload attached to the instruction.
        sender: Sender of the instruction; also copied into the config.
        recipient: Recipient of the instruction.
        requested_fields: Fields the response is asked to contain.
        form: Form from which the instruction is derived.
        tools: Tools to expose to the model (or a flag).
        branch: Target branch; defaults to ``self.branch`` when falsy.
        **kwargs: Extra model parameters.

    Returns:
        dict: a copy of ``self.imodel.config`` updated with the (possibly
        tool-parsed) ``kwargs``, plus ``"sender"`` when ``sender`` is not
        None.
    """
    target = branch or self.branch

    if system:
        target.add_message(system=system)

    if form:
        target.add_message(instruction=Instruction.from_form(form))
    else:
        # "branch.ln_id" is a placeholder meaning "address this branch".
        resolved_recipient = (
            target.ln_id if recipient == "branch.ln_id" else recipient
        )
        target.add_message(
            instruction=instruction,
            context=context,
            sender=sender,
            recipient=resolved_recipient,
            requested_fields=requested_fields,
            images=images,
        )

    if "tool_parsed" in kwargs:
        del kwargs["tool_parsed"]
        kwargs = {"tools": tools, **kwargs}
    elif tools and target.has_tools:
        kwargs = target.tool_manager.parse_tool(tools=tools, **kwargs)

    config = dict(self.imodel.config)
    config.update(kwargs)
    if sender is not None:
        config["sender"] = sender
    return config
108
+
109
async def _call_chatcompletion(
    self, imodel: Optional[Any] = None, branch: Optional[Any] = None, **kwargs
) -> Any:
    """
    Send a branch's conversation history to a chat-completion model.

    Args:
        imodel: Model to call; falls back to ``self.imodel`` when falsy.
        branch: Branch whose ``to_chat_messages()`` output is sent; falls
            back to ``self.branch`` when falsy.
        **kwargs: Forwarded unchanged to ``imodel.call_chat_completion``.

    Returns:
        Any: the awaited result of ``call_chat_completion`` (``_base_chat``
        unpacks it as ``(payload, completion)``).
    """
    model = imodel if imodel else self.imodel
    source = branch if branch else self.branch
    history = source.to_chat_messages()
    return await model.call_chat_completion(history, **kwargs)
126
+
127
async def _process_chatcompletion(
    self,
    payload: dict,
    completion: dict,
    sender: str,
    invoke_tool: bool = True,
    branch: Optional[Any] = None,
    action_request: Optional[Any] = None,
    costs=None,
) -> Any:
    """
    Record a chat-completion response on a branch and handle tool calls.

    When ``completion`` contains ``"choices"``:
      * ``payload`` (minus its ``"messages"`` entry) is stored as metadata
        of the branch's last instruction;
      * every dict choice is added as an assistant response, with the rest
        of ``completion`` merged with the choice as its metadata;
      * if the completion names a ``model`` and ``costs`` is given, the
        expense is written to the new message's ``extra.usage.expense``;
      * the branch model's ``num_tasks_succeeded`` counter is incremented.
    Otherwise ``num_tasks_failed`` is incremented.

    Note that ``payload`` and ``completion`` are mutated in place
    (``"messages"``, ``"choices"`` and each choice's ``"message"`` are
    popped). Non-dict choices carry no message and are skipped.

    Args:
        payload: Request payload that was sent to the model.
        completion: Raw completion returned by the model.
        sender: Sender recorded on the assistant messages.
        invoke_tool: Whether detected action requests are executed.
        branch: Branch to record on; defaults to ``self.branch``.
        action_request: Pre-parsed action requests; when falsy they are
            parsed from the last assistant message.
        costs: Optional ``(prompt_price, completion_price)`` pair; the
            token counts are multiplied by these and the sum divided by one
            million (i.e. per-million-token prices). When None, no expense
            is recorded.

    Returns:
        Any: the result of ``self._process_action_request`` for the last
        recorded assistant message.
    """
    branch = branch or self.branch
    _msg = None

    if "choices" in completion:
        # The full message list is not kept as instruction metadata.
        payload.pop("messages", None)
        branch.update_last_instruction_meta(payload)
        _choices = completion.pop("choices", None)

        def process_completion_choice(choice, price=None):
            # Record one dict choice as an assistant response and, when both
            # a model name and a price pair are available, its expense.
            msg = choice.pop("message", None)
            _completion = completion.copy()
            _completion.update(choice)
            branch.add_message(
                assistant_response=msg,
                metadata=_completion,
                sender=sender,
            )

            last = branch.messages[-1]
            prompt_tokens = last._meta_get(["extra", "usage", "prompt_tokens"], 0)
            completion_tokens = last._meta_get(
                ["extra", "usage", "completion_tokens"], 0
            )
            # ``costs`` defaults to None; without a price pair there is
            # nothing to multiply, so the expense is simply not recorded.
            if completion.get("model", None) and price is not None:
                expense = (
                    prompt_tokens * price[0] + completion_tokens * price[1]
                ) / 1_000_000
                last._meta_insert(["extra", "usage", "expense"], expense)
            return msg

        if _choices and not isinstance(_choices, list):
            _choices = [_choices]

        for _choice in _choices or []:
            # Only dict choices carry a message; anything else is skipped
            # rather than leaving ``msg`` unbound inside the helper.
            if isinstance(_choice, dict):
                _msg = process_completion_choice(_choice, price=costs)

        branch.imodel.status_tracker.num_tasks_succeeded += 1
    else:
        branch.imodel.status_tracker.num_tasks_failed += 1

    return await self._process_action_request(
        _msg=_msg,
        branch=branch,
        invoke_tool=invoke_tool,
        action_request=action_request,
    )
200
+
201
async def _process_action_request(
    self,
    _msg: Optional[dict] = None,
    branch: Optional[Any] = None,
    invoke_tool: bool = True,
    action_request: Optional[Any] = None,
) -> Any:
    """
    Register and optionally execute the tool calls in an assistant reply.

    For each request, ``recipient`` is set to the matching registered
    tool's ``ln_id``, and the request is added to the branch. When
    ``invoke_tool`` is true, all tools are invoked concurrently. Every
    non-None output is then added to the branch together with its request,
    with the tool as sender and the request's sender as recipient.

    Args:
        _msg: Assistant message parsed for action requests when
            ``action_request`` is not supplied.
        branch: Branch holding the tool manager; defaults to
            ``self.branch`` when falsy.
        invoke_tool: Whether to actually run the tools.
        action_request: Pre-parsed action requests.

    Returns:
        Any: ``_msg`` (or ``False`` when ``_msg`` is falsy) if no action
        request was found, otherwise ``None``.

    Raises:
        ActionError: If a requested function is not in the registry.
    """
    # Fall back to the mixin's own branch, like the sibling helpers do.
    branch = branch or self.branch

    action_request = action_request or _parse_action_request(_msg)
    if action_request is None:
        return _msg if _msg else False
    if not action_request:
        # An empty collection of requests: nothing to register or run.
        return None

    registry = branch.tool_manager.registry
    for request in action_request:
        if request.function not in registry:
            raise ActionError(f"Tool {request.function} not found in registry")
        request.recipient = registry[request.function].ln_id
        branch.add_message(action_request=request, recipient=request.recipient)

    if invoke_tool:
        # Run every tool concurrently; gather keeps results in request order.
        results = await asyncio.gather(
            *(
                asyncio.create_task(
                    registry[request.function].invoke(request.arguments)
                )
                for request in action_request
            )
        )
        for request, output in zip(action_request, results):
            if output is not None:
                branch.add_message(
                    action_request=request,
                    func_outputs=output,
                    sender=request.recipient,
                    recipient=request.sender,
                )

    return None
249
+
250
async def _output(
    self,
    payload: dict,
    completion: dict,
    sender: str,
    invoke_tool: bool,
    requested_fields: dict,
    form: "Form" = None,
    return_form: bool = True,
    strict: bool = False,
    rulebook: Any = None,
    use_annotation: bool = True,
    template_name: str = None,
    costs=None,
    branch: Optional[Any] = None,
) -> Any:
    """
    Turn a raw chat completion into the final directive result.

    The completion is first recorded, and any tool calls handled, via
    ``self._process_chatcompletion`` on ``branch``. That helper falls back
    to ``self.branch`` when ``branch`` is None. If it yields ``None``,
    ``None`` is returned. Otherwise the message is post-processed with
    ``self._process_model_response``. When a ``form`` is given, the result
    is validated into it.

    Args:
        payload: Request payload that was sent to the model.
        completion: Raw completion returned by the model.
        sender: Sender identifier for the assistant messages.
        invoke_tool: Whether detected tool calls are executed.
        requested_fields: Fields requested in the response.
        form: Form to validate the response into.
        return_form: Return the form itself (True) or only its non-None
            requested work fields (False).
        strict: Passed to ``validate_response``.
        rulebook: When truthy, a fresh ``Validator(rulebook=...)`` is used
            instead of ``self.validator``.
        use_annotation: Passed to ``validate_response``.
        template_name: Stored on the form when truthy.
        costs: Price pair forwarded to ``_process_chatcompletion``.
        branch: Branch the completion is recorded on. Default None keeps
            the previous behaviour of using ``self.branch``.

    Returns:
        Any: ``None``, the processed response, the validated form, or a
        dict of its non-None requested fields.
    """
    _msg = await self._process_chatcompletion(
        payload=payload,
        completion=completion,
        sender=sender,
        invoke_tool=invoke_tool,
        branch=branch,
        costs=costs,
    )

    if _msg is None:
        return None

    response_ = self._process_model_response(_msg, requested_fields)

    if not form:
        return response_

    validator = Validator(rulebook=rulebook) if rulebook else self.validator
    form = await validator.validate_response(
        form=form,
        response=response_,
        strict=strict,
        use_annotation=use_annotation,
    )
    if template_name:
        form.template_name = template_name

    if return_form:
        return form
    return {
        i: form.work_fields[i]
        for i in form.requested_fields
        if form.work_fields[i] is not None
    }


async def _base_chat(
    self,
    instruction: Any = None,
    *,
    system: Any = None,
    context: Any = None,
    sender: Any = None,
    recipient: Any = None,
    requested_fields: dict = None,
    form: "Form" = None,
    tools: Any = False,
    images: Optional[str] = None,
    invoke_tool: bool = True,
    return_form: bool = True,
    strict: bool = False,
    rulebook: Any = None,
    imodel: Any = None,
    use_annotation: bool = True,
    branch: Any = None,
    clear_messages: bool = False,
    return_branch: bool = False,
    **kwargs,
) -> Any:
    """
    Run one chat turn on a branch: build config, call the model, process.

    The same ``branch`` (defaulting to ``self.branch``) is used for the
    outgoing messages, for the model call, and for recording the reply and
    any tool calls.

    NOTE(review): ``_create_chat_config`` builds the request config from
    ``self.imodel.config`` even when a different ``imodel`` is passed here,
    while ``costs`` are taken from the passed ``imodel``. Confirm this is
    intended.
    NOTE(review): with ``clear_messages`` the system message is set via
    ``set_system`` and ``system`` is also passed on to
    ``_create_chat_config``, which adds it again when truthy. Confirm this
    is intended.

    Args:
        instruction: Instruction message.
        system: System message.
        context: Context message.
        sender: Sender identifier.
        recipient: Recipient identifier.
        requested_fields: Fields requested in the response.
        form: Form to fill.
        tools: Tools to expose to the model (or a flag).
        images: Image payload for the instruction.
        invoke_tool: Whether detected tool calls are executed.
        return_form: Return the form itself rather than its fields.
        strict: Strict validation flag.
        rulebook: Rulebook for validation.
        imodel: Model used for the call; defaults to ``self.imodel``.
        use_annotation: Annotation-based validation flag.
        branch: Branch to operate on; defaults to ``self.branch``.
        clear_messages: Clear the branch (and reset its system) first.
        return_branch: Pair the output with the branch.
        **kwargs: Extra model / config parameters.

    Returns:
        tuple: always a 2-tuple. It is ``(output, branch)`` when
        ``return_branch`` is true and ``(output, output)`` otherwise. The
        duplicated pair is kept for backward compatibility, because callers
        such as ``_chat`` unwrap it.
    """
    branch = branch or self.branch
    if clear_messages:
        branch.clear()
        branch.set_system(system)

    config = self._create_chat_config(
        system=system,
        instruction=instruction,
        context=context,
        sender=sender,
        recipient=recipient,
        requested_fields=requested_fields,
        form=form,
        tools=tools,
        branch=branch,
        images=images,
        **kwargs,
    )

    payload, completion = await self._call_chatcompletion(
        imodel=imodel, branch=branch, **config
    )

    imodel = imodel or self.imodel
    out_ = await self._output(
        payload=payload,
        completion=completion,
        sender=sender,
        invoke_tool=invoke_tool,
        requested_fields=requested_fields,
        form=form,
        return_form=return_form,
        strict=strict,
        rulebook=rulebook,
        use_annotation=use_annotation,
        costs=imodel.costs,
        # Record the reply on the same branch the instruction went to.
        branch=branch,
    )

    # Equivalent to the historical ``return out_, branch if return_branch
    # else out_``, which by precedence always builds a 2-tuple.
    return (out_, branch) if return_branch else (out_, out_)
409
+
410
async def _chat(
    self,
    instruction=None,
    context=None,
    system=None,
    sender=None,
    recipient=None,
    branch=None,
    requested_fields=None,
    form: "Form" = None,
    tools=False,
    invoke_tool=True,
    return_form=True,
    strict=False,
    rulebook=None,
    imodel=None,
    images: Optional[str] = None,
    clear_messages=False,
    use_annotation=True,
    timeout: float = None,
    return_branch=False,
    **kwargs,
):
    """
    Run ``_base_chat`` and normalise the shape of its result.

    Every argument is forwarded to ``self._base_chat``. ``timeout`` is
    passed along as an extra keyword. The result is then unwrapped:

      * a ``str`` is returned unchanged;
      * a pair of equal items collapses to that item (its first element if
        the item is a tuple);
      * a pair of different items is returned as a 2-tuple, e.g.
        ``(output, branch)`` when ``return_branch`` is true;
      * a single item is returned (its first element if it is a tuple);
      * any other length yields ``None``.

    Args:
        instruction: Instruction message.
        context: Context message.
        system: System message.
        sender: Sender identifier.
        recipient: Recipient identifier.
        branch: Branch instance.
        requested_fields: Fields requested in the response.
        form: Form to fill.
        tools: Tools to expose to the model (or a flag).
        invoke_tool: Whether tool calls are executed.
        return_form: Return the form itself rather than its fields.
        strict: Strict validation flag.
        rulebook: Rulebook for validation.
        imodel: Model instance.
        images: Image payload for the instruction.
        clear_messages: Clear the branch first.
        use_annotation: Annotation-based validation flag.
        timeout: Forwarded as a keyword argument.
        return_branch: Also return the branch.
        **kwargs: Additional keyword arguments.

    Returns:
        Any: the normalised response (see above).
    """
    result = await self._base_chat(
        instruction=instruction,
        context=context,
        system=system,
        sender=sender,
        recipient=recipient,
        branch=branch,
        requested_fields=requested_fields,
        form=form,
        tools=tools,
        invoke_tool=invoke_tool,
        return_form=return_form,
        strict=strict,
        rulebook=rulebook,
        imodel=imodel,
        images=images,
        clear_messages=clear_messages,
        use_annotation=use_annotation,
        timeout=timeout,
        return_branch=return_branch,
        **kwargs,
    )

    if isinstance(result, str):
        return result

    items = list(result)
    count = len(items)

    if count == 2:
        first, second = items
        if first == second:
            return first[0] if isinstance(first, tuple) else first
        if first != second:
            return first, second
        return None

    if count == 1:
        only = items[0]
        return only[0] if isinstance(only, tuple) else only

    return None
496
+
497
async def _direct(
    self,
    instruction=None,
    context=None,
    form: "Form" = None,
    branch=None,
    tools=None,
    reason: bool = None,
    predict: bool = None,
    score: bool = None,
    select: bool = None,
    plan: bool = None,
    allow_action: bool = None,
    allow_extension: bool = None,
    confidence: bool = None,
    max_extension: int = None,
    score_num_digits=None,
    score_range=None,
    select_choices=None,
    plan_num_step=None,
    predict_num_sentences=None,
    clear_messages=False,
    return_branch=False,
    images: Optional[str] = None,
    verbose=None,
    **kwargs,
):
    """
    Run ``_base_direct`` and normalise the shape of its result.

    Every argument is forwarded to ``self._base_direct``. Its result is
    turned into a list, and then:

      * a pair of equal items collapses to that item (its first element if
        the item is a tuple);
      * otherwise the first two items are returned as a tuple.

    NOTE(review): a result with fewer than two items raises IndexError.

    Args:
        instruction: Instruction message.
        context: Context message.
        form: Form to fill.
        branch: Branch instance.
        tools: Tools data.
        reason: Include reasoning.
        predict: Include a prediction.
        score: Include a score.
        select: Include a selection.
        plan: Include a plan.
        allow_action: Allow tool actions.
        allow_extension: Allow extension rounds.
        confidence: Include a confidence value.
        max_extension: Maximum number of extensions.
        score_num_digits: Number of digits for the score.
        score_range: Range for the score.
        select_choices: Choices for the selection.
        plan_num_step: Number of plan steps.
        predict_num_sentences: Number of sentences for the prediction.
        clear_messages: Clear the branch first.
        return_branch: Also return the branch.
        images: Image payload for the instruction.
        verbose: Verbosity flag.
        **kwargs: Additional keyword arguments.

    Returns:
        Any: the collapsed response, or a 2-tuple such as
        ``(response, branch)``.
    """
    pieces = list(
        await self._base_direct(
            instruction=instruction,
            context=context,
            form=form,
            branch=branch,
            tools=tools,
            reason=reason,
            predict=predict,
            score=score,
            select=select,
            images=images,
            plan=plan,
            allow_action=allow_action,
            allow_extension=allow_extension,
            confidence=confidence,
            max_extension=max_extension,
            score_num_digits=score_num_digits,
            score_range=score_range,
            select_choices=select_choices,
            plan_num_step=plan_num_step,
            predict_num_sentences=predict_num_sentences,
            clear_messages=clear_messages,
            return_branch=return_branch,
            verbose=verbose,
            **kwargs,
        )
    )

    is_duplicate_pair = len(pieces) == 2 and pieces[0] == pieces[1]
    if not is_duplicate_pair:
        return pieces[0], pieces[1]

    head = pieces[0]
    return head[0] if isinstance(head, tuple) else head
586
+
587
+ async def _base_direct(
588
+ self,
589
+ instruction=None,
590
+ *,
591
+ context=None,
592
+ form: Form = None,
593
+ branch=None,
594
+ tools=None,
595
+ reason: bool = None,
596
+ predict: bool = None,
597
+ score: bool = None,
598
+ select: bool = None,
599
+ plan: bool = None,
600
+ allow_action: bool = None,
601
+ allow_extension: bool = None,
602
+ confidence: bool = None,
603
+ max_extension: int = None,
604
+ score_num_digits=None,
605
+ score_range=None,
606
+ select_choices=None,
607
+ plan_num_step=None,
608
+ predict_num_sentences=None,
609
+ clear_messages=False,
610
+ return_branch=False,
611
+ images: Optional[str] = None,
612
+ verbose=None,
613
+ **kwargs,
614
+ ):
615
+ """
616
+ Handles the base direct operation.
617
+
618
+ Args:
619
+ instruction: Instruction message.
620
+ context: Context message.
621
+ form: Form data.
622
+ branch: Branch instance.
623
+ tools: Tools data.
624
+ reason: Flag indicating if reason should be included.
625
+ predict: Flag indicating if prediction should be included.
626
+ score: Flag indicating if score should be included.
627
+ select: Flag indicating if selection should be included.
628
+ plan: Flag indicating if plan should be included.
629
+ allow_action: Flag indicating if action should be allowed.
630
+ allow_extension: Flag indicating if extension should be allowed.
631
+ confidence: Flag indicating if confidence should be included.
632
+ max_extension: Maximum extension value.
633
+ score_num_digits: Number of digits for score.
634
+ score_range: Range for score.
635
+ select_choices: Choices for selection.
636
+ plan_num_step: Number of steps for plan.
637
+ predict_num_sentences: Number of sentences for prediction.
638
+ clear_messages: Flag indicating if messages should be cleared.
639
+ return_branch: Flag indicating if branch should be returned.
640
+ kwargs: Additional keyword arguments.
641
+
642
+ Returns:
643
+ Any: The processed response and branch.
644
+ """
645
+ # Ensure branch is initialized
646
+ branch = branch or self.branch
647
+ if clear_messages:
648
+ branch.clear()
649
+
650
+ # Set a default max_extension if allow_extension is True and max_extension is None
651
+ if allow_extension and not max_extension:
652
+ max_extension = 3 # Set a default limit for recursion
653
+
654
+ # Process tools if provided
655
+ if tools:
656
+ process_tools(tools, branch)
657
+
658
+ if allow_action and not tools:
659
+ tools = True
660
+
661
+ tool_schema=None
662
+ if tools:
663
+ tool_schema = branch.tool_manager.get_tool_schema(tools)
664
+
665
+ if not form:
666
+ form = self.default_template(
667
+ instruction=instruction,
668
+ context=context,
669
+ reason=reason,
670
+ predict=predict,
671
+ score=score,
672
+ select=select,
673
+ plan=plan,
674
+ tool_schema=tool_schema,
675
+ allow_action=allow_action,
676
+ allow_extension=allow_extension,
677
+ max_extension=max_extension,
678
+ confidence=confidence,
679
+ score_num_digits=score_num_digits,
680
+ score_range=score_range,
681
+ select_choices=select_choices,
682
+ plan_num_step=plan_num_step,
683
+ predict_num_sentences=predict_num_sentences,
684
+ )
685
+
686
+ elif form and "tool_schema" not in form._all_fields:
687
+ form.append_to_input("tool_schema")
688
+ form.tool_schema = tool_schema
689
+
690
+ else:
691
+ form.tool_schema = tool_schema
692
+
693
+ verbose = (
694
+ verbose
695
+ if verbose is not None and isinstance(verbose, bool)
696
+ else self.verbose
697
+ )
698
+ if verbose:
699
+ print("Chatting with model...")
700
+
701
+ # Call the base chat method
702
+ form = await self._chat(
703
+ form=form,
704
+ branch=branch,
705
+ images=images,
706
+ **kwargs,
707
+ )
708
+
709
+ # Handle actions if allowed and required
710
+ if allow_action and getattr(form, "action_required", None):
711
+ actions = getattr(form, "actions", None)
712
+ if actions:
713
+ if verbose:
714
+ print(
715
+ "Found action requests in model response. Processing actions..."
716
+ )
717
+ form = await self._act(form, branch, actions=actions)
718
+ if verbose:
719
+ print("Actions processed!")
720
+
721
+ last_form = form
722
+
723
+ ctr = 1
724
+
725
+ # Handle extensions if allowed and required
726
+ extension_forms = []
727
+ max_extension = max_extension if isinstance(max_extension, int) else 3
728
+ while (
729
+ allow_extension
730
+ and max_extension > 0
731
+ and getattr(last_form, "extension_required", None)
732
+ ):
733
+ if getattr(last_form, "is_extension", None):
734
+ break
735
+ if verbose:
736
+ print(f"\nFound extension requests in model response.")
737
+ print(
738
+ f"------------------- Processing extension No.{ctr} -------------------"
739
+ )
740
+
741
+ max_extension -= 1
742
+
743
+ # new form cannot be extended, otherwise it will be an infinite loop
744
+ new_form = await self._extend(
745
+ tools=tools,
746
+ reason=reason,
747
+ predict=predict,
748
+ score=score,
749
+ select=select,
750
+ plan=getattr(last_form, "plan", None),
751
+ allow_action=allow_action,
752
+ confidence=confidence,
753
+ score_num_digits=score_num_digits,
754
+ score_range=score_range,
755
+ select_choices=select_choices,
756
+ predict_num_sentences=predict_num_sentences,
757
+ **kwargs,
758
+ )
759
+
760
+ if verbose:
761
+ print(f"------------------- Extension completed -------------------\n")
762
+
763
+ extension_forms.extend(new_form)
764
+ last_form = new_form[-1] if isinstance(new_form, list) else new_form
765
+ ctr += len(form)
766
+
767
+ if extension_forms:
768
+ if not getattr(form, "extension_forms", None):
769
+ form._add_field("extension_forms", list, None, [])
770
+ form.extension_forms.extend(extension_forms)
771
+ action_responses = [
772
+ i.action_response
773
+ for i in extension_forms
774
+ if getattr(i, "action_response", None) is not None
775
+ ]
776
+ if not hasattr(form, "action_response"):
777
+ form.add_field("action_response", {})
778
+
779
+ for action_response in action_responses:
780
+ nmerge([form.action_response, action_response])
781
+
782
+ if "PLEASE_ACTION" in form.answer:
783
+ if verbose:
784
+ print("Analyzing action responses and generating answer...")
785
+
786
+ answer = await self._chat(
787
+ "please provide final answer basing on the above"
788
+ " information, provide answer value as a string only"
789
+ " do not return as json, do not include other information",
790
+ )
791
+
792
+ if isinstance(answer, dict):
793
+ a = answer.get("answer", None)
794
+ if a is not None:
795
+ answer = a
796
+
797
+ answer = str(answer).strip()
798
+ if answer.startswith("{") and answer.endswith("}"):
799
+ answer = answer[1:-1]
800
+ answer = answer.strip()
801
+ if '"answer":' in answer:
802
+ answer.replace('"answer":', "")
803
+ answer = answer.strip()
804
+ elif "'answer':" in answer:
805
+ answer.replace("'answer':", "")
806
+ answer = answer.strip()
807
+
808
+ form.answer = answer
809
+
810
+ return form, branch if return_branch else form
811
+
812
+ async def _extend(
813
+ self,
814
+ tools,
815
+ reason,
816
+ predict,
817
+ score,
818
+ select,
819
+ plan,
820
+ # image,
821
+ allow_action,
822
+ confidence,
823
+ score_num_digits,
824
+ score_range,
825
+ select_choices,
826
+ predict_num_sentences,
827
+ **kwargs,
828
+ ):
829
+ """
830
+ Handles the extension of the form based on the provided parameters.
831
+
832
+ Args:
833
+ form: Form data.
834
+ tools: Tools data.
835
+ reason: Flag indicating if reason should be included.
836
+ predict: Flag indicating if prediction should be included.
837
+ score: Flag indicating if score should be included.
838
+ select: Flag indicating if selection should be included.
839
+ plan: Flag indicating if plan should be included.
840
+ allow_action: Flag indicating if action should be allowed.
841
+ confidence: Flag indicating if confidence should be included.
842
+ score_num_digits: Number of digits for score.
843
+ score_range: Range for score.
844
+ select_choices: Choices for selection.
845
+ predict_num_sentences: Number of sentences for prediction.
846
+ allow_extension: Flag indicating if extension should be allowed.
847
+ max_extension: Maximum extension value.
848
+ kwargs: Additional keyword arguments.
849
+
850
+ Returns:
851
+ list: The extended forms.
852
+ """
853
+ extension_forms = []
854
+
855
+ # Ensure the next step in the plan is handled
856
+ directive_kwargs = {
857
+ "tools": tools,
858
+ "reason": reason,
859
+ "predict": predict,
860
+ "score": score,
861
+ "select": select,
862
+ "allow_action": allow_action,
863
+ "confidence": confidence,
864
+ "score_num_digits": score_num_digits,
865
+ "score_range": score_range,
866
+ "select_choices": select_choices,
867
+ "predict_num_sentences": predict_num_sentences,
868
+ **kwargs,
869
+ }
870
+
871
+ if plan:
872
+ keys = [f"step_{i+1}" for i in range(len(plan))]
873
+ plan = StringMatch.force_validate_dict(plan, keys)
874
+
875
+ # If plan is provided, process each step
876
+ for i in keys:
877
+ directive_kwargs["instruction"] = plan[i]
878
+ last_form = await self._direct(**directive_kwargs)
879
+ last_form.is_extension = True
880
+ extension_forms.append(last_form)
881
+ directive_kwargs["max_extension"] -= 1
882
+ if not getattr(last_form, "extension_required", None):
883
+ break
884
+
885
+ else:
886
+ # Handle single step extension
887
+ last_form = await self._direct(**directive_kwargs)
888
+ last_form.is_extension = True
889
+ extension_forms.append(last_form)
890
+
891
+ return extension_forms
892
+
893
+ async def _act(self, form, branch, actions=None):
894
+ """
895
+ Processes actions based on the provided form and actions.
896
+
897
+ Args:
898
+ form: Form data.
899
+ branch: Branch instance.
900
+ actions: Actions data.
901
+
902
+ Returns:
903
+ dict: The updated form.
904
+ """
905
+ if getattr(form, "action_performed", None) is True:
906
+ return form
907
+
908
+ keys = [f"action_{i+1}" for i in range(len(actions))]
909
+ actions = StringMatch.force_validate_dict(actions, keys)
910
+
911
+ try:
912
+ requests = []
913
+ for k in keys:
914
+ _func = actions[k]["function"]
915
+ _func = _func.replace("functions.", "")
916
+ msg = ActionRequest(
917
+ function=_func,
918
+ arguments=actions[k]["arguments"],
919
+ sender=branch.ln_id,
920
+ recipient=branch.tool_manager.registry[_func].ln_id,
921
+ )
922
+ requests.append(msg)
923
+ branch.add_message(action_request=msg)
924
+
925
+ if requests:
926
+ out = await self._process_action_request(
927
+ branch=branch, invoke_tool=True, action_request=requests
928
+ )
929
+
930
+ if out is False:
931
+ raise ValueError("No requests found.")
932
+
933
+ len_actions = len(actions)
934
+ action_responses = [
935
+ i
936
+ for i in branch.messages[-len_actions:]
937
+ if isinstance(i, ActionResponse)
938
+ ]
939
+
940
+ _action_responses = {}
941
+ for idx, item in enumerate(action_responses):
942
+ _action_responses[f"action_{idx+1}"] = item._to_dict()
943
+
944
+ form.append_to_request("action_response")
945
+ if (a := getattr(form, "action_response", None)) is None:
946
+ form.add_field("action_response", {})
947
+
948
+ len1 = len(form.action_response)
949
+ for k, v in _action_responses.items():
950
+ while k in form.action_response:
951
+ k = f"{k}_1"
952
+ form.action_response[k] = v
953
+
954
+ if len(form.action_response) > len1:
955
+ form.append_to_request("action_performed")
956
+ form.action_performed = True
957
+ return form
958
+
959
+ except Exception as e:
960
+ raise ValueError(f"Error processing action request: {e}")
961
+
962
+ async def _select(
963
+ self,
964
+ form=None,
965
+ choices=None,
966
+ reason=False,
967
+ confidence_score=None,
968
+ instruction=None,
969
+ template=None,
970
+ context=None,
971
+ branch=None,
972
+ **kwargs,
973
+ ):
974
+ """
975
+ Selects a response based on the provided parameters.
976
+
977
+ Args:
978
+ form (Any, optional): Form to create instruction from.
979
+ choices (Any, optional): Choices for the selection.
980
+ reason (bool, optional): Whether to include a reason for the selection.
981
+ confidence_score (Any, optional): Confidence score for the selection.
982
+ instruction (Any, optional): Instruction for the selection.
983
+ template (Any, optional): Template for the selection.
984
+ context (Any, optional): Context to perform the selection on.
985
+ branch (Any, optional): Branch to use for the selection.
986
+ **kwargs: Additional arguments for the selection.
987
+
988
+ Returns:
989
+ Any: The selection response.
990
+ """
991
+ branch = branch or self.branch
992
+
993
+ if not form:
994
+ form = template(
995
+ choices=choices,
996
+ reason=reason,
997
+ confidence_score=confidence_score,
998
+ instruction=instruction,
999
+ context=context,
1000
+ )
1001
+
1002
+ return await self._chat(form=form, return_form=True, branch=branch, **kwargs)
1003
+
1004
+ async def _predict(
1005
+ self,
1006
+ form=None,
1007
+ num_sentences=None,
1008
+ reason=False,
1009
+ confidence_score=None,
1010
+ instruction=None,
1011
+ context=None,
1012
+ branch=None,
1013
+ template=None,
1014
+ **kwargs,
1015
+ ):
1016
+ """
1017
+ Predicts a response based on the provided parameters.
1018
+
1019
+ Args:
1020
+ form: Form data.
1021
+ num_sentences: Number of sentences for the prediction.
1022
+ reason: Flag indicating if reason should be included.
1023
+ confidence_score: Confidence score for the prediction.
1024
+ instruction: Instruction for the prediction.
1025
+ context: Context to perform the prediction on.
1026
+ branch: Branch instance.
1027
+ template: Template for the prediction.
1028
+ kwargs: Additional keyword arguments.
1029
+
1030
+ Returns:
1031
+ Any: The prediction response.
1032
+ """
1033
+ branch = branch or self.branch
1034
+
1035
+ if not form:
1036
+ form = template(
1037
+ instruction=instruction,
1038
+ context=context,
1039
+ num_sentences=num_sentences,
1040
+ confidence_score=confidence_score,
1041
+ reason=reason,
1042
+ )
1043
+
1044
+ return await self._chat(form=form, return_form=True, branch=branch, **kwargs)
1045
+
1046
+ async def _score(
1047
+ self,
1048
+ form=None,
1049
+ score_range=None,
1050
+ include_endpoints=None,
1051
+ num_digit=None,
1052
+ reason=False,
1053
+ confidence_score=None,
1054
+ instruction=None,
1055
+ context=None,
1056
+ branch=None,
1057
+ template=None,
1058
+ **kwargs,
1059
+ ):
1060
+ """
1061
+ Scores a response based on the provided parameters.
1062
+
1063
+ Args:
1064
+ form: Form data.
1065
+ score_range: Range for score.
1066
+ include_endpoints: Flag indicating if endpoints should be included.
1067
+ num_digit: Number of digits for score.
1068
+ reason: Flag indicating if reason should be included.
1069
+ confidence_score: Confidence score for the score.
1070
+ instruction: Instruction for the score.
1071
+ context: Context to perform the score on.
1072
+ branch: Branch instance.
1073
+ template: Template for the score.
1074
+ kwargs: Additional keyword arguments.
1075
+
1076
+ Returns:
1077
+ Any: The score response.
1078
+ """
1079
+ branch = branch or self.branch
1080
+ if not form:
1081
+ form = template(
1082
+ score_range=score_range,
1083
+ include_endpoints=include_endpoints,
1084
+ num_digit=num_digit,
1085
+ reason=reason,
1086
+ confidence_score=confidence_score,
1087
+ instruction=instruction,
1088
+ context=context,
1089
+ )
1090
+
1091
+ return await self._chat(form=form, return_form=True, branch=branch, **kwargs)
1092
+
1093
+ async def _plan(
1094
+ self,
1095
+ form=None,
1096
+ num_step=None,
1097
+ reason=False,
1098
+ confidence_score=None,
1099
+ instruction=None,
1100
+ context=None,
1101
+ branch=None,
1102
+ template=None,
1103
+ **kwargs,
1104
+ ):
1105
+ """
1106
+ Plans a response based on the provided parameters.
1107
+
1108
+ Args:
1109
+ form: Form data.
1110
+ num_step: Number of steps for the plan.
1111
+ reason: Flag indicating if reason should be included.
1112
+ confidence_score: Confidence score for the plan.
1113
+ instruction: Instruction for the plan.
1114
+ context: Context to perform the plan on.
1115
+ branch: Branch instance.
1116
+ template: Template for the plan.
1117
+ kwargs: Additional keyword arguments.
1118
+
1119
+ Returns:
1120
+ Any: The plan response.
1121
+ """
1122
+ branch = branch or self.branch
1123
+ template = template or self.default_template
1124
+
1125
+ if not form:
1126
+ form = template(
1127
+ instruction=instruction,
1128
+ context=context,
1129
+ num_step=num_step,
1130
+ reason=reason,
1131
+ confidence_score=confidence_score,
1132
+ )
1133
+
1134
+ return await self._chat(form=form, **kwargs)
1135
+
1136
+ @staticmethod
1137
+ def _process_model_response(content_, requested_fields):
1138
+ """
1139
+ Processes the model response content.
1140
+
1141
+ Args:
1142
+ content_: The content data.
1143
+ requested_fields: Fields requested in the response.
1144
+
1145
+ Returns:
1146
+ Any: The processed response.
1147
+ """
1148
+ out_ = content_.get("content", "")
1149
+ if out_ == "":
1150
+ out_ = content_
1151
+
1152
+ if requested_fields:
1153
+ with contextlib.suppress(Exception):
1154
+ return StringMatch.force_validate_dict(out_, requested_fields)
1155
+
1156
+ if isinstance(out_, str):
1157
+ with contextlib.suppress(Exception):
1158
+ return ParseUtil.fuzzy_parse_json(out_)
1159
+
1160
+ with contextlib.suppress(Exception):
1161
+ return ParseUtil.extract_json_block(out_)
1162
+
1163
+ with contextlib.suppress(Exception):
1164
+ match = re.search(r"```json\n({.*?})\n```", out_, re.DOTALL)
1165
+ if match:
1166
+ return ParseUtil.fuzzy_parse_json(match.group(1))
1167
+
1168
+ return out_