lionagi 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +60 -5
- lionagi/core/__init__.py +0 -25
- lionagi/core/_setting/_setting.py +59 -0
- lionagi/core/action/__init__.py +14 -0
- lionagi/core/action/function_calling.py +136 -0
- lionagi/core/action/manual.py +1 -0
- lionagi/core/action/node.py +109 -0
- lionagi/core/action/tool.py +114 -0
- lionagi/core/action/tool_manager.py +356 -0
- lionagi/core/agent/base_agent.py +27 -13
- lionagi/core/agent/eval/evaluator.py +1 -0
- lionagi/core/agent/eval/vote.py +40 -0
- lionagi/core/agent/learn/learner.py +59 -0
- lionagi/core/agent/plan/unit_template.py +1 -0
- lionagi/core/collections/__init__.py +17 -0
- lionagi/core/{generic/data_logger.py → collections/_logger.py} +69 -55
- lionagi/core/collections/abc/__init__.py +53 -0
- lionagi/core/collections/abc/component.py +615 -0
- lionagi/core/collections/abc/concepts.py +297 -0
- lionagi/core/collections/abc/exceptions.py +150 -0
- lionagi/core/collections/abc/util.py +45 -0
- lionagi/core/collections/exchange.py +161 -0
- lionagi/core/collections/flow.py +426 -0
- lionagi/core/collections/model.py +419 -0
- lionagi/core/collections/pile.py +913 -0
- lionagi/core/collections/progression.py +236 -0
- lionagi/core/collections/util.py +64 -0
- lionagi/core/director/direct.py +314 -0
- lionagi/core/director/director.py +2 -0
- lionagi/core/{execute/branch_executor.py → engine/branch_engine.py} +134 -97
- lionagi/core/{execute/instruction_map_executor.py → engine/instruction_map_engine.py} +80 -55
- lionagi/{experimental/directive/evaluator → core/engine}/script_engine.py +17 -1
- lionagi/core/executor/base_executor.py +90 -0
- lionagi/core/{execute/structure_executor.py → executor/graph_executor.py} +62 -66
- lionagi/core/{execute → executor}/neo4j_executor.py +70 -67
- lionagi/core/generic/__init__.py +3 -33
- lionagi/core/generic/edge.py +29 -79
- lionagi/core/generic/edge_condition.py +16 -0
- lionagi/core/generic/graph.py +236 -0
- lionagi/core/generic/hyperedge.py +1 -0
- lionagi/core/generic/node.py +156 -221
- lionagi/core/generic/tree.py +48 -0
- lionagi/core/generic/tree_node.py +79 -0
- lionagi/core/mail/__init__.py +12 -0
- lionagi/core/mail/mail.py +25 -0
- lionagi/core/mail/mail_manager.py +139 -58
- lionagi/core/mail/package.py +45 -0
- lionagi/core/mail/start_mail.py +36 -0
- lionagi/core/message/__init__.py +19 -0
- lionagi/core/message/action_request.py +133 -0
- lionagi/core/message/action_response.py +135 -0
- lionagi/core/message/assistant_response.py +95 -0
- lionagi/core/message/instruction.py +234 -0
- lionagi/core/message/message.py +101 -0
- lionagi/core/message/system.py +86 -0
- lionagi/core/message/util.py +283 -0
- lionagi/core/report/__init__.py +4 -0
- lionagi/core/report/base.py +217 -0
- lionagi/core/report/form.py +231 -0
- lionagi/core/report/report.py +166 -0
- lionagi/core/report/util.py +28 -0
- lionagi/core/rule/_default.py +16 -0
- lionagi/core/rule/action.py +99 -0
- lionagi/core/rule/base.py +238 -0
- lionagi/core/rule/boolean.py +56 -0
- lionagi/core/rule/choice.py +47 -0
- lionagi/core/rule/mapping.py +96 -0
- lionagi/core/rule/number.py +71 -0
- lionagi/core/rule/rulebook.py +109 -0
- lionagi/core/rule/string.py +52 -0
- lionagi/core/rule/util.py +35 -0
- lionagi/core/session/branch.py +431 -0
- lionagi/core/session/directive_mixin.py +287 -0
- lionagi/core/session/session.py +229 -903
- lionagi/core/structure/__init__.py +1 -0
- lionagi/core/structure/chain.py +1 -0
- lionagi/core/structure/forest.py +1 -0
- lionagi/core/structure/graph.py +1 -0
- lionagi/core/structure/tree.py +1 -0
- lionagi/core/unit/__init__.py +5 -0
- lionagi/core/unit/parallel_unit.py +245 -0
- lionagi/core/unit/template/action.py +81 -0
- lionagi/core/unit/template/base.py +51 -0
- lionagi/core/unit/template/plan.py +84 -0
- lionagi/core/unit/template/predict.py +109 -0
- lionagi/core/unit/template/score.py +124 -0
- lionagi/core/unit/template/select.py +104 -0
- lionagi/core/unit/unit.py +362 -0
- lionagi/core/unit/unit_form.py +305 -0
- lionagi/core/unit/unit_mixin.py +1168 -0
- lionagi/core/unit/util.py +71 -0
- lionagi/core/validator/validator.py +364 -0
- lionagi/core/work/work.py +76 -0
- lionagi/core/work/work_function.py +101 -0
- lionagi/core/work/work_queue.py +103 -0
- lionagi/core/work/worker.py +258 -0
- lionagi/core/work/worklog.py +120 -0
- lionagi/experimental/compressor/base.py +46 -0
- lionagi/experimental/compressor/llm_compressor.py +247 -0
- lionagi/experimental/compressor/llm_summarizer.py +61 -0
- lionagi/experimental/compressor/util.py +70 -0
- lionagi/experimental/directive/__init__.py +19 -0
- lionagi/experimental/directive/parser/base_parser.py +69 -2
- lionagi/experimental/directive/{template_ → template}/base_template.py +17 -1
- lionagi/{libs/ln_tokenizer.py → experimental/directive/tokenizer.py} +16 -0
- lionagi/experimental/{directive/evaluator → evaluator}/ast_evaluator.py +16 -0
- lionagi/experimental/{directive/evaluator → evaluator}/base_evaluator.py +16 -0
- lionagi/experimental/knowledge/base.py +10 -0
- lionagi/experimental/memory/__init__.py +0 -0
- lionagi/experimental/strategies/__init__.py +0 -0
- lionagi/experimental/strategies/base.py +1 -0
- lionagi/integrations/bridge/langchain_/documents.py +4 -0
- lionagi/integrations/bridge/llamaindex_/index.py +30 -0
- lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +6 -0
- lionagi/integrations/chunker/chunk.py +161 -24
- lionagi/integrations/config/oai_configs.py +34 -3
- lionagi/integrations/config/openrouter_configs.py +14 -2
- lionagi/integrations/loader/load.py +122 -21
- lionagi/integrations/loader/load_util.py +6 -77
- lionagi/integrations/provider/_mapping.py +46 -0
- lionagi/integrations/provider/litellm.py +2 -1
- lionagi/integrations/provider/mlx_service.py +16 -9
- lionagi/integrations/provider/oai.py +91 -4
- lionagi/integrations/provider/ollama.py +6 -5
- lionagi/integrations/provider/openrouter.py +115 -8
- lionagi/integrations/provider/services.py +2 -2
- lionagi/integrations/provider/transformers.py +18 -22
- lionagi/integrations/storage/__init__.py +3 -3
- lionagi/integrations/storage/neo4j.py +52 -60
- lionagi/integrations/storage/storage_util.py +44 -46
- lionagi/integrations/storage/structure_excel.py +43 -26
- lionagi/integrations/storage/to_excel.py +11 -4
- lionagi/libs/__init__.py +22 -1
- lionagi/libs/ln_api.py +75 -20
- lionagi/libs/ln_context.py +37 -0
- lionagi/libs/ln_convert.py +21 -9
- lionagi/libs/ln_func_call.py +69 -28
- lionagi/libs/ln_image.py +107 -0
- lionagi/libs/ln_nested.py +26 -11
- lionagi/libs/ln_parse.py +82 -23
- lionagi/libs/ln_queue.py +16 -0
- lionagi/libs/ln_tokenize.py +164 -0
- lionagi/libs/ln_validate.py +16 -0
- lionagi/libs/special_tokens.py +172 -0
- lionagi/libs/sys_util.py +95 -24
- lionagi/lions/coder/code_form.py +13 -0
- lionagi/lions/coder/coder.py +50 -3
- lionagi/lions/coder/util.py +30 -25
- lionagi/tests/libs/test_func_call.py +23 -21
- lionagi/tests/libs/test_nested.py +36 -21
- lionagi/tests/libs/test_parse.py +1 -1
- lionagi/tests/test_core/collections/__init__.py +0 -0
- lionagi/tests/test_core/collections/test_component.py +206 -0
- lionagi/tests/test_core/collections/test_exchange.py +138 -0
- lionagi/tests/test_core/collections/test_flow.py +145 -0
- lionagi/tests/test_core/collections/test_pile.py +171 -0
- lionagi/tests/test_core/collections/test_progression.py +129 -0
- lionagi/tests/test_core/generic/test_edge.py +67 -0
- lionagi/tests/test_core/generic/test_graph.py +96 -0
- lionagi/tests/test_core/generic/test_node.py +106 -0
- lionagi/tests/test_core/generic/test_tree_node.py +73 -0
- lionagi/tests/test_core/test_branch.py +115 -294
- lionagi/tests/test_core/test_form.py +46 -0
- lionagi/tests/test_core/test_report.py +105 -0
- lionagi/tests/test_core/test_validator.py +111 -0
- lionagi/version.py +1 -1
- lionagi-0.2.1.dist-info/LICENSE +202 -0
- lionagi-0.2.1.dist-info/METADATA +272 -0
- lionagi-0.2.1.dist-info/RECORD +240 -0
- lionagi/core/branch/base.py +0 -653
- lionagi/core/branch/branch.py +0 -474
- lionagi/core/branch/flow_mixin.py +0 -96
- lionagi/core/branch/util.py +0 -323
- lionagi/core/direct/__init__.py +0 -19
- lionagi/core/direct/cot.py +0 -123
- lionagi/core/direct/plan.py +0 -164
- lionagi/core/direct/predict.py +0 -166
- lionagi/core/direct/react.py +0 -171
- lionagi/core/direct/score.py +0 -279
- lionagi/core/direct/select.py +0 -170
- lionagi/core/direct/sentiment.py +0 -1
- lionagi/core/direct/utils.py +0 -110
- lionagi/core/direct/vote.py +0 -64
- lionagi/core/execute/base_executor.py +0 -47
- lionagi/core/flow/baseflow.py +0 -23
- lionagi/core/flow/monoflow/ReAct.py +0 -240
- lionagi/core/flow/monoflow/__init__.py +0 -9
- lionagi/core/flow/monoflow/chat.py +0 -95
- lionagi/core/flow/monoflow/chat_mixin.py +0 -253
- lionagi/core/flow/monoflow/followup.py +0 -215
- lionagi/core/flow/polyflow/__init__.py +0 -1
- lionagi/core/flow/polyflow/chat.py +0 -251
- lionagi/core/form/action_form.py +0 -26
- lionagi/core/form/field_validator.py +0 -287
- lionagi/core/form/form.py +0 -302
- lionagi/core/form/mixin.py +0 -214
- lionagi/core/form/scored_form.py +0 -13
- lionagi/core/generic/action.py +0 -26
- lionagi/core/generic/component.py +0 -532
- lionagi/core/generic/condition.py +0 -46
- lionagi/core/generic/mail.py +0 -90
- lionagi/core/generic/mailbox.py +0 -36
- lionagi/core/generic/relation.py +0 -70
- lionagi/core/generic/signal.py +0 -22
- lionagi/core/generic/structure.py +0 -362
- lionagi/core/generic/transfer.py +0 -20
- lionagi/core/generic/work.py +0 -40
- lionagi/core/graph/graph.py +0 -126
- lionagi/core/graph/tree.py +0 -190
- lionagi/core/mail/schema.py +0 -63
- lionagi/core/messages/schema.py +0 -325
- lionagi/core/tool/__init__.py +0 -5
- lionagi/core/tool/tool.py +0 -28
- lionagi/core/tool/tool_manager.py +0 -283
- lionagi/experimental/report/form.py +0 -64
- lionagi/experimental/report/report.py +0 -138
- lionagi/experimental/report/util.py +0 -47
- lionagi/experimental/tool/function_calling.py +0 -43
- lionagi/experimental/tool/manual.py +0 -66
- lionagi/experimental/tool/schema.py +0 -59
- lionagi/experimental/tool/tool_manager.py +0 -138
- lionagi/experimental/tool/util.py +0 -16
- lionagi/experimental/validator/rule.py +0 -139
- lionagi/experimental/validator/validator.py +0 -56
- lionagi/experimental/work/__init__.py +0 -10
- lionagi/experimental/work/async_queue.py +0 -54
- lionagi/experimental/work/schema.py +0 -73
- lionagi/experimental/work/work_function.py +0 -67
- lionagi/experimental/work/worker.py +0 -56
- lionagi/experimental/work2/form.py +0 -371
- lionagi/experimental/work2/report.py +0 -289
- lionagi/experimental/work2/schema.py +0 -30
- lionagi/experimental/work2/tests.py +0 -72
- lionagi/experimental/work2/work_function.py +0 -89
- lionagi/experimental/work2/worker.py +0 -12
- lionagi/integrations/bridge/llamaindex_/get_index.py +0 -294
- lionagi/tests/test_core/generic/test_component.py +0 -89
- lionagi/tests/test_core/test_base_branch.py +0 -426
- lionagi/tests/test_core/test_chat_flow.py +0 -63
- lionagi/tests/test_core/test_mail_manager.py +0 -75
- lionagi/tests/test_core/test_prompts.py +0 -51
- lionagi/tests/test_core/test_session.py +0 -254
- lionagi/tests/test_core/test_session_base_util.py +0 -313
- lionagi/tests/test_core/test_tool_manager.py +0 -95
- lionagi-0.1.2.dist-info/LICENSE +0 -9
- lionagi-0.1.2.dist-info/METADATA +0 -174
- lionagi-0.1.2.dist-info/RECORD +0 -206
- /lionagi/core/{branch → _setting}/__init__.py +0 -0
- /lionagi/core/{execute → agent/eval}/__init__.py +0 -0
- /lionagi/core/{flow → agent/learn}/__init__.py +0 -0
- /lionagi/core/{form → agent/plan}/__init__.py +0 -0
- /lionagi/core/{branch/executable_branch.py → agent/plan/plan.py} +0 -0
- /lionagi/core/{graph → director}/__init__.py +0 -0
- /lionagi/core/{messages → engine}/__init__.py +0 -0
- /lionagi/{experimental/directive/evaluator → core/engine}/sandbox_.py +0 -0
- /lionagi/{experimental/directive/evaluator → core/executor}/__init__.py +0 -0
- /lionagi/{experimental/directive/template_ → core/rule}/__init__.py +0 -0
- /lionagi/{experimental/report → core/unit/template}/__init__.py +0 -0
- /lionagi/{experimental/tool → core/validator}/__init__.py +0 -0
- /lionagi/{experimental/validator → core/work}/__init__.py +0 -0
- /lionagi/experimental/{work2 → compressor}/__init__.py +0 -0
- /lionagi/{core/flow/mono_chat_mixin.py → experimental/directive/template/__init__.py} +0 -0
- /lionagi/experimental/directive/{schema.py → template/schema.py} +0 -0
- /lionagi/experimental/{work2/util.py → evaluator/__init__.py} +0 -0
- /lionagi/experimental/{work2/work.py → knowledge/__init__.py} +0 -0
- /lionagi/{tests/libs/test_async.py → experimental/knowledge/graph.py} +0 -0
- {lionagi-0.1.2.dist-info → lionagi-0.2.1.dist-info}/WHEEL +0 -0
- {lionagi-0.1.2.dist-info → lionagi-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1168 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2024 HaiyangLi
|
3
|
+
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
you may not use this file except in compliance with the License.
|
6
|
+
You may obtain a copy of the License at
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""
|
17
|
+
The base directive module.
|
18
|
+
"""
|
19
|
+
|
20
|
+
import asyncio
|
21
|
+
import contextlib
|
22
|
+
import re
|
23
|
+
from abc import ABC
|
24
|
+
|
25
|
+
from typing import Any, Optional
|
26
|
+
|
27
|
+
from lionagi.libs import ParseUtil, StringMatch, to_list
|
28
|
+
from lionagi.libs.ln_nested import nmerge
|
29
|
+
from lionagi.core.collections.abc import ActionError
|
30
|
+
from lionagi.core.message import ActionRequest, ActionResponse, Instruction
|
31
|
+
from lionagi.core.message.util import _parse_action_request
|
32
|
+
from lionagi.core.report.form import Form
|
33
|
+
from lionagi.core.unit.util import process_tools
|
34
|
+
from lionagi.core.validator.validator import Validator
|
35
|
+
|
36
|
+
|
37
|
+
class DirectiveMixin(ABC):
|
38
|
+
"""
|
39
|
+
DirectiveMixin is a class for handling chat operations and
|
40
|
+
processing responses.
|
41
|
+
"""
|
42
|
+
|
43
|
+
def _create_chat_config(
    self,
    system: Optional[str] = None,
    instruction: Optional[str] = None,
    context: Optional[str] = None,
    images: Optional[str] = None,
    sender: Optional[str] = None,
    recipient: Optional[str] = None,
    requested_fields: Optional[list] = None,
    form: Form = None,
    tools: bool = False,
    branch: Optional[Any] = None,
    **kwargs,
) -> Any:
    """
    Record the outgoing messages on a branch and build the model call config.

    A system message (if given) is added first, followed by either a plain
    instruction message or an instruction derived from ``form``. Tool
    settings are then folded into the keyword arguments, which are layered
    over ``self.imodel.config``.

    Args:
        system: Optional system message to add before the instruction.
        instruction: Instruction text (ignored when ``form`` is given).
        context: Context attached to the instruction.
        images: Images attached to the instruction.
        sender: Sender identifier; also copied into the returned config
            when not None.
        recipient: Recipient identifier; the literal string
            ``"branch.ln_id"`` is replaced with the branch's ``ln_id``.
        requested_fields: Fields requested in the response.
        form: When truthy, the instruction is built with
            ``Instruction.from_form(form)`` instead.
        tools: Tools to expose to the model.
        branch: Target branch; defaults to ``self.branch``.
        kwargs: Extra model parameters. A ``tool_parsed`` key means
            ``tools`` is already in final form and is forwarded unchanged.

    Returns:
        dict: The merged chat configuration.
    """
    target = branch or self.branch

    if system:
        target.add_message(system=system)

    if form:
        target.add_message(instruction=Instruction.from_form(form))
    else:
        # "branch.ln_id" is a placeholder meaning "address this branch".
        resolved_recipient = (
            target.ln_id if recipient == "branch.ln_id" else recipient
        )
        target.add_message(
            instruction=instruction,
            context=context,
            sender=sender,
            recipient=resolved_recipient,
            requested_fields=requested_fields,
            images=images,
        )

    if "tool_parsed" in kwargs:
        # Already-parsed tools go through verbatim; explicit kwargs win.
        del kwargs["tool_parsed"]
        kwargs = {"tools": tools, **kwargs}
    elif tools and target.has_tools:
        kwargs = target.tool_manager.parse_tool(tools=tools, **kwargs)

    settings = {**self.imodel.config, **kwargs}
    if sender is not None:
        settings["sender"] = sender
    return settings
|
108
|
+
|
109
|
+
async def _call_chatcompletion(
    self, imodel: Optional[Any] = None, branch: Optional[Any] = None, **kwargs
) -> Any:
    """
    Send a branch's conversation to a chat-completion model.

    Args:
        imodel: Model to call; falls back to ``self.imodel`` when falsy.
        branch: Branch whose messages are sent; falls back to
            ``self.branch`` when falsy.
        kwargs: Forwarded unchanged to ``call_chat_completion``.

    Returns:
        Any: Whatever ``imodel.call_chat_completion`` returns.
    """
    model = imodel if imodel else self.imodel
    source = branch if branch else self.branch
    messages = source.to_chat_messages()
    return await model.call_chat_completion(messages, **kwargs)
|
126
|
+
|
127
|
+
async def _process_chatcompletion(
    self,
    payload: dict,
    completion: dict,
    sender: str,
    invoke_tool: bool = True,
    branch: Optional[Any] = None,
    action_request: Optional[Any] = None,
    costs=None,
) -> Any:
    """
    Processes the chat completion response.

    Every dict choice in ``completion["choices"]`` is recorded on the branch
    as an assistant response. The message of the last choice is then handed
    to ``_process_action_request``. Currently only the last message is
    supported for function calling.

    Args:
        payload: The request payload. Its "messages" entry is removed (the
            dict is mutated) and the remainder is stored via
            ``branch.update_last_instruction_meta``.
        completion: The completion data. Its "choices" entry is popped
            (the dict is mutated).
        sender: The sender identifier for the assistant messages.
        invoke_tool: Flag indicating if tools should be invoked.
        branch: The branch instance; defaults to ``self.branch``.
        action_request: The action request instance, if already parsed.
        costs: Optional ``(prompt_price, completion_price)`` pair applied per
            1,000,000 tokens. When falsy, no expense is recorded.

    Returns:
        Any: The result of ``_process_action_request``.
    """
    branch = branch or self.branch
    _msg = None

    if "choices" in completion:
        payload.pop("messages", None)
        branch.update_last_instruction_meta(payload)
        _choices = completion.pop("choices", None)

        def process_completion_choice(choice, price=None):
            # Non-dict choices carry no message. Previously `msg` was left
            # unbound here, and usage was read from an unrelated message.
            if not isinstance(choice, dict):
                return None

            msg = choice.pop("message", None)
            _completion = completion.copy()
            _completion.update(choice)
            branch.add_message(
                assistant_response=msg,
                metadata=_completion,
                sender=sender,
            )

            # Expense needs both a model and a price table. `costs` defaults
            # to None, and indexing it used to raise TypeError.
            if completion.get("model", None) and price:
                last = branch.messages[-1]
                a = last._meta_get(["extra", "usage", "prompt_tokens"], 0)
                b = last._meta_get(["extra", "usage", "completion_tokens"], 0)
                ttl = (a * price[0] + b * price[1]) / 1000000
                last._meta_insert(["extra", "usage", "expense"], ttl)
            return msg

        if _choices and not isinstance(_choices, list):
            _choices = [_choices]

        for _choice in _choices or []:
            _msg = process_completion_choice(_choice, price=costs)

        branch.imodel.status_tracker.num_tasks_succeeded += 1
    else:
        branch.imodel.status_tracker.num_tasks_failed += 1

    return await self._process_action_request(
        _msg=_msg,
        branch=branch,
        invoke_tool=invoke_tool,
        action_request=action_request,
    )
|
200
|
+
|
201
|
+
async def _process_action_request(
    self,
    _msg: Optional[dict] = None,
    branch: Optional[Any] = None,
    invoke_tool: bool = True,
    action_request: Optional[Any] = None,
) -> Any:
    """
    Processes an action request from the assistant response.

    Each requested action is routed to its registered tool and logged on the
    branch. When ``invoke_tool`` is set, all tools are run concurrently and
    every non-None result is logged as an action response addressed back to
    the request's sender.

    Args:
        _msg: The message data, parsed for action requests when
            ``action_request`` is not supplied.
        branch: The branch instance; defaults to ``self.branch``.
        invoke_tool: Flag indicating if tools should be invoked.
        action_request: Pre-parsed action requests. Each item is read for
            ``function``, ``arguments``, ``sender`` and ``recipient``.

    Returns:
        Any: ``_msg`` (or ``False`` if it is falsy) when no action request is
        found; otherwise ``None``.

    Raises:
        ActionError: If a requested function is not in the tool registry.
    """
    action_request = action_request or _parse_action_request(_msg)
    if action_request is None:
        return _msg if _msg else False

    # Resolve the default branch like the sibling methods do. It was
    # previously left as None, so the registry lookup below crashed.
    branch = branch or self.branch

    if action_request:
        registry = branch.tool_manager.registry
        for i in action_request:
            if i.function not in registry:
                raise ActionError(f"Tool {i.function} not found in registry")
            i.recipient = registry[i.function].ln_id
            branch.add_message(action_request=i, recipient=i.recipient)

        if invoke_tool:
            results = await asyncio.gather(
                *(registry[i.function].invoke(i.arguments) for i in action_request)
            )

            for request, item in zip(action_request, results):
                if item is not None:
                    branch.add_message(
                        action_request=request,
                        func_outputs=item,
                        sender=request.recipient,
                        recipient=request.sender,
                    )

    return None
|
249
|
+
|
250
|
+
async def _output(
    self,
    payload: dict,
    completion: dict,
    sender: str,
    invoke_tool: bool,
    requested_fields: dict,
    form: Form = None,
    return_form: bool = True,
    strict: bool = False,
    rulebook: Any = None,
    use_annotation: bool = True,
    template_name: str = None,
    costs=None,
    branch: Optional[Any] = None,
) -> Any:
    """
    Outputs the final processed response.

    Args:
        payload: The payload data.
        completion: The completion data.
        sender: The sender identifier.
        invoke_tool: Flag indicating if tools should be invoked.
        requested_fields: Fields requested in the response.
        form: Form to validate the response against.
        return_form: If true return the validated form, otherwise a dict of
            its non-None requested fields.
        strict: Flag indicating if strict validation should be applied.
        rulebook: Rulebook for validation; when falsy ``self.validator`` is
            used.
        use_annotation: Flag indicating if annotations should be used.
        template_name: Template name assigned to the validated form.
        costs: Price pair forwarded to ``_process_chatcompletion``.
        branch: Branch receiving the assistant response; defaults to
            ``self.branch`` (resolved in ``_process_chatcompletion``).

    Returns:
        Any: ``None`` if the completion produced no message to process (for
        example after tool invocation); otherwise the parsed response, the
        validated form, or the form's requested-field dict.
    """
    _msg = await self._process_chatcompletion(
        payload=payload,
        completion=completion,
        sender=sender,
        invoke_tool=invoke_tool,
        branch=branch,
        costs=costs,
    )

    if _msg is None:
        return None

    response_ = self._process_model_response(_msg, requested_fields)

    if not form:
        return response_

    validator = Validator(rulebook=rulebook) if rulebook else self.validator
    form = await validator.validate_response(
        form=form,
        response=response_,
        strict=strict,
        use_annotation=use_annotation,
    )
    if template_name:
        form.template_name = template_name

    if return_form:
        return form
    return {
        i: form.work_fields[i]
        for i in form.requested_fields
        if form.work_fields[i] is not None
    }


async def _base_chat(
    self,
    instruction: Any = None,
    *,
    system: Any = None,
    context: Any = None,
    sender: Any = None,
    recipient: Any = None,
    requested_fields: dict = None,
    form: Form = None,
    tools: Any = False,
    images: Optional[str] = None,
    invoke_tool: bool = True,
    return_form: bool = True,
    strict: bool = False,
    rulebook: Any = None,
    imodel: Any = None,
    use_annotation: bool = True,
    branch: Any = None,
    clear_messages: bool = False,
    return_branch: bool = False,
    **kwargs,
) -> Any:
    """
    Handles the base chat operation by configuring the chat and
    processing the response.

    Args:
        instruction: Instruction message.
        system: System message.
        context: Context message.
        sender: Sender identifier.
        recipient: Recipient identifier.
        requested_fields: Fields requested in the response.
        form: Form data.
        tools: Tools to expose to the model.
        images: Images attached to the instruction.
        invoke_tool: Flag indicating if tools should be invoked.
        return_form: Flag indicating if form should be returned.
        strict: Flag indicating if strict validation should be applied.
        rulebook: Rulebook instance for validation.
        imodel: Model instance; defaults to ``self.imodel``.
        use_annotation: Flag indicating if annotations should be used.
        branch: Branch instance; defaults to ``self.branch``. Both the
            instruction and the assistant response are recorded on it.
        clear_messages: If true, clear the branch and reset its system
            message first.
        return_branch: Flag indicating if branch should be returned.
        kwargs: Additional keyword arguments for the chat config.

    Returns:
        tuple: Always a 2-tuple: ``(response, branch)`` when
        ``return_branch`` is true, otherwise ``(response, response)``.
    """
    branch = branch or self.branch
    if clear_messages:
        branch.clear()
        branch.set_system(system)

    config = self._create_chat_config(
        system=system,
        instruction=instruction,
        context=context,
        sender=sender,
        recipient=recipient,
        requested_fields=requested_fields,
        form=form,
        tools=tools,
        branch=branch,
        images=images,
        **kwargs,
    )

    payload, completion = await self._call_chatcompletion(
        imodel=imodel, branch=branch, **config
    )

    imodel = imodel or self.imodel
    out_ = await self._output(
        payload=payload,
        completion=completion,
        sender=sender,
        invoke_tool=invoke_tool,
        requested_fields=requested_fields,
        form=form,
        return_form=return_form,
        strict=strict,
        rulebook=rulebook,
        use_annotation=use_annotation,
        costs=imodel.costs,
        # Record the assistant response on the same branch that received the
        # instruction. Previously it always went to self.branch.
        branch=branch,
    )

    # NOTE: this is deliberately always a 2-tuple (the conditional only picks
    # the second element); callers such as _chat unwrap it.
    return (out_, branch if return_branch else out_)
|
409
|
+
|
410
|
+
async def _chat(
    self,
    instruction=None,
    context=None,
    system=None,
    sender=None,
    recipient=None,
    branch=None,
    requested_fields=None,
    form: Form = None,
    tools=False,
    invoke_tool=True,
    return_form=True,
    strict=False,
    rulebook=None,
    imodel=None,
    images: Optional[str] = None,
    clear_messages=False,
    use_annotation=True,
    timeout: float = None,
    return_branch=False,
    **kwargs,
):
    """
    Handles the chat operation.

    Args:
        instruction: Instruction message.
        context: Context message.
        system: System message.
        sender: Sender identifier.
        recipient: Recipient identifier.
        branch: Branch instance.
        requested_fields: Fields requested in the response.
        form: Form data.
        tools: Tools to expose to the model.
        invoke_tool: Flag indicating if tools should be invoked.
        return_form: Flag indicating if form should be returned.
        strict: Flag indicating if strict validation should be applied.
        rulebook: Rulebook instance for validation.
        imodel: Model instance.
        images: Images attached to the instruction.
        clear_messages: Flag indicating if messages should be cleared.
        use_annotation: Flag indicating if annotations should be used.
        timeout: Timeout value, forwarded to ``_base_chat`` as a keyword.
        return_branch: Flag indicating if branch should be returned.
        kwargs: Additional keyword arguments.

    Returns:
        Any: The processed response, or ``(response, branch)`` when
        ``return_branch`` is true.
    """
    a = await self._base_chat(
        context=context,
        instruction=instruction,
        system=system,
        sender=sender,
        recipient=recipient,
        requested_fields=requested_fields,
        form=form,
        tools=tools,
        images=images,
        invoke_tool=invoke_tool,
        return_form=return_form,
        strict=strict,
        rulebook=rulebook,
        imodel=imodel,
        use_annotation=use_annotation,
        timeout=timeout,
        branch=branch,
        clear_messages=clear_messages,
        return_branch=return_branch,
        **kwargs,
    )

    if isinstance(a, str):
        return a

    a = list(a)

    if len(a) == 2:
        if return_branch:
            return a[0], a[1]
        # Without return_branch, _base_chat yields (response, response).
        # Decide from the flag instead of comparing the items with `==`,
        # which can raise (non-boolean __eq__) or misfire (values unequal
        # to themselves) for arbitrary response objects.
        return a[0][0] if isinstance(a[0], tuple) else a[0]
    if len(a) == 1:
        return a[0][0] if isinstance(a[0], tuple) else a[0]
    return None
|
496
|
+
|
497
|
+
async def _direct(
    self,
    instruction=None,
    context=None,
    form: Form = None,
    branch=None,
    tools=None,
    reason: bool = None,
    predict: bool = None,
    score: bool = None,
    select: bool = None,
    plan: bool = None,
    allow_action: bool = None,
    allow_extension: bool = None,
    confidence: bool = None,
    max_extension: int = None,
    score_num_digits=None,
    score_range=None,
    select_choices=None,
    plan_num_step=None,
    predict_num_sentences=None,
    clear_messages=False,
    return_branch=False,
    images: Optional[str] = None,
    verbose=None,
    **kwargs,
):
    """
    Run a directive via ``_base_direct`` and unwrap its paired result.

    Every argument is forwarded to ``_base_direct`` by keyword.

    Args:
        instruction: Instruction message.
        context: Context message.
        form: Form data.
        branch: Branch instance.
        tools: Tools data.
        reason / predict / score / select / plan / confidence: Flags selecting
            which extra outputs to request.
        allow_action: Whether actions may be taken.
        allow_extension: Whether the directive may extend itself.
        max_extension: Upper bound on extensions.
        score_num_digits, score_range: Options for scoring.
        select_choices: Options for selection.
        plan_num_step: Options for planning.
        predict_num_sentences: Options for prediction.
        clear_messages: Flag indicating if messages should be cleared.
        return_branch: Flag indicating if branch should be returned.
        images: Images attached to the instruction.
        verbose: Verbosity setting.
        kwargs: Additional keyword arguments.

    Returns:
        Any: When ``_base_direct`` yields exactly two equal items, that item
        (or its first element if it is a tuple); otherwise the first two
        items as a tuple.
    """
    outcome = await self._base_direct(
        instruction=instruction,
        context=context,
        form=form,
        branch=branch,
        tools=tools,
        images=images,
        # which extra outputs to request
        reason=reason,
        predict=predict,
        score=score,
        select=select,
        plan=plan,
        confidence=confidence,
        # action / extension controls
        allow_action=allow_action,
        allow_extension=allow_extension,
        max_extension=max_extension,
        # per-output options
        score_num_digits=score_num_digits,
        score_range=score_range,
        select_choices=select_choices,
        plan_num_step=plan_num_step,
        predict_num_sentences=predict_num_sentences,
        # bookkeeping
        clear_messages=clear_messages,
        return_branch=return_branch,
        verbose=verbose,
        **kwargs,
    )

    parts = list(outcome)
    if len(parts) == 2 and parts[0] == parts[1]:
        only = parts[0]
        return only[0] if isinstance(only, tuple) else only

    return parts[0], parts[1]
|
586
|
+
|
587
|
+
async def _base_direct(
    self,
    instruction=None,
    *,
    context=None,
    form: Form = None,
    branch=None,
    tools=None,
    reason: bool = None,
    predict: bool = None,
    score: bool = None,
    select: bool = None,
    plan: bool = None,
    allow_action: bool = None,
    allow_extension: bool = None,
    confidence: bool = None,
    max_extension: int = None,
    score_num_digits=None,
    score_range=None,
    select_choices=None,
    plan_num_step=None,
    predict_num_sentences=None,
    clear_messages=False,
    return_branch=False,
    images: Optional[str] = None,
    verbose=None,
    **kwargs,
):
    """
    Handles the base direct operation.

    The steps are:

    1. Resolve the branch, optionally clearing it.
    2. Register tools.
    3. Build a form from ``self.default_template``, or attach the tool schema
       to the given form.
    4. Chat with the model.
    5. Run any requested actions.
    6. Run up to ``max_extension`` extension rounds.
    7. Merge the extension results back into the form.
    8. If the answer still contains ``"PLEASE_ACTION"``, ask the model once
       more for a plain final answer.

    Args:
        instruction: Instruction message.
        context: Context message.
        form: Form data; a new form is built via ``self.default_template``
            when falsy.
        branch: Branch instance; defaults to ``self.branch``.
        tools: Tools data, registered on the branch via ``process_tools``.
        reason: Flag indicating if reason should be included.
        predict: Flag indicating if prediction should be included.
        score: Flag indicating if score should be included.
        select: Flag indicating if selection should be included.
        plan: Flag indicating if plan should be included.
        allow_action: Flag indicating if action should be allowed.
        allow_extension: Flag indicating if extension should be allowed.
        confidence: Flag indicating if confidence should be included.
        max_extension: Maximum number of extension rounds; falls back to 3
            when it is None/0 (with extensions allowed) or not an int.
        score_num_digits: Number of digits for score.
        score_range: Range for score.
        select_choices: Choices for selection.
        plan_num_step: Number of steps for plan.
        predict_num_sentences: Number of sentences for prediction.
        clear_messages: Flag indicating if messages should be cleared.
        return_branch: Flag indicating if branch should be returned.
        images: Images forwarded to ``_chat``.
        verbose: Print progress messages; ``self.verbose`` is used when this
            is not a bool.
        kwargs: Additional keyword arguments forwarded to ``_chat`` and
            ``_extend``.

    Returns:
        tuple: ``(form, branch)`` when ``return_branch`` is true, otherwise
        ``(form, form)``. A pair is always returned on purpose: the wrapper
        that calls this method converts the result with ``list()`` and
        collapses an equal pair back into the single form.
    """
    # Ensure branch is initialized
    branch = branch or self.branch
    if clear_messages:
        branch.clear()

    # Set a default recursion limit when extensions are allowed but no limit given
    if allow_extension and not max_extension:
        max_extension = 3

    # Process tools if provided
    if tools:
        process_tools(tools, branch)

    # Allowing actions without explicit tools passes True to get_tool_schema,
    # presumably meaning "all registered tools" -- TODO confirm.
    if allow_action and not tools:
        tools = True

    tool_schema = branch.tool_manager.get_tool_schema(tools) if tools else None

    if not form:
        form = self.default_template(
            instruction=instruction,
            context=context,
            reason=reason,
            predict=predict,
            score=score,
            select=select,
            plan=plan,
            tool_schema=tool_schema,
            allow_action=allow_action,
            allow_extension=allow_extension,
            max_extension=max_extension,
            confidence=confidence,
            score_num_digits=score_num_digits,
            score_range=score_range,
            select_choices=select_choices,
            plan_num_step=plan_num_step,
            predict_num_sentences=predict_num_sentences,
        )
    elif "tool_schema" not in form._all_fields:
        form.append_to_input("tool_schema")
        form.tool_schema = tool_schema
    else:
        form.tool_schema = tool_schema

    verbose = verbose if isinstance(verbose, bool) else self.verbose
    if verbose:
        print("Chatting with model...")

    # Call the base chat method
    form = await self._chat(
        form=form,
        branch=branch,
        images=images,
        **kwargs,
    )

    # Handle actions if allowed and required
    if allow_action and getattr(form, "action_required", None):
        actions = getattr(form, "actions", None)
        if actions:
            if verbose:
                print(
                    "Found action requests in model response. Processing actions..."
                )
            form = await self._act(form, branch, actions=actions)
            if verbose:
                print("Actions processed!")

    last_form = form
    ctr = 1

    # Handle extensions if allowed and required
    extension_forms = []
    max_extension = max_extension if isinstance(max_extension, int) else 3
    while (
        allow_extension
        and max_extension > 0
        and getattr(last_form, "extension_required", None)
    ):
        # Forms produced by an extension must never be extended again,
        # otherwise this loop could recurse indefinitely.
        if getattr(last_form, "is_extension", None):
            break
        if verbose:
            print("\nFound extension requests in model response.")
            print(
                f"------------------- Processing extension No.{ctr} -------------------"
            )

        max_extension -= 1

        new_form = await self._extend(
            tools=tools,
            reason=reason,
            predict=predict,
            score=score,
            select=select,
            plan=getattr(last_form, "plan", None),
            allow_action=allow_action,
            confidence=confidence,
            score_num_digits=score_num_digits,
            score_range=score_range,
            select_choices=select_choices,
            predict_num_sentences=predict_num_sentences,
            **kwargs,
        )

        if verbose:
            print("------------------- Extension completed -------------------\n")

        new_forms = new_form if isinstance(new_form, list) else [new_form]
        if not new_forms:
            break
        extension_forms.extend(new_forms)
        last_form = new_forms[-1]
        # Number the next extension after the forms actually produced.
        ctr += len(new_forms)

    if extension_forms:
        if not getattr(form, "extension_forms", None):
            form._add_field("extension_forms", list, None, [])
        form.extension_forms.extend(extension_forms)
        action_responses = [
            i.action_response
            for i in extension_forms
            if getattr(i, "action_response", None) is not None
        ]
        # Same missing/None handling as in _act, so nmerge never sees None.
        if getattr(form, "action_response", None) is None:
            form.add_field("action_response", {})

        for action_response in action_responses:
            # nmerge returns the merged mapping, so store the result.
            form.action_response = nmerge([form.action_response, action_response])

    current_answer = getattr(form, "answer", None)
    if isinstance(current_answer, str) and "PLEASE_ACTION" in current_answer:
        if verbose:
            print("Analyzing action responses and generating answer...")

        answer = await self._chat(
            "please provide final answer basing on the above"
            " information, provide answer value as a string only"
            " do not return as json, do not include other information",
            branch=branch,
        )

        if isinstance(answer, dict):
            a = answer.get("answer", None)
            if a is not None:
                answer = a

        answer = str(answer).strip()
        if answer.startswith("{") and answer.endswith("}"):
            answer = answer[1:-1].strip()
        # str.replace returns a new string; keep its result.
        if '"answer":' in answer:
            answer = answer.replace('"answer":', "").strip()
        elif "'answer':" in answer:
            answer = answer.replace("'answer':", "").strip()

        form.answer = answer

    return (form, branch) if return_branch else (form, form)
|
811
|
+
|
812
|
+
async def _extend(
    self,
    tools,
    reason,
    predict,
    score,
    select,
    plan,
    allow_action,
    confidence,
    score_num_digits,
    score_range,
    select_choices,
    predict_num_sentences,
    **kwargs,
):
    """
    Run one extension round and return the forms it produced.

    When ``plan`` is truthy, it is validated into a dict keyed ``step_1`` ...
    ``step_<len(plan)>``. Each step becomes the instruction of a nested
    ``self._direct`` call, in order. Execution stops early once a step's form
    does not set ``extension_required``. Without a plan, a single
    ``self._direct`` call is made with the same settings. Every produced form
    gets ``is_extension = True`` so the caller does not extend it again.

    Args:
        tools: Tools data.
        reason: Flag indicating if reason should be included.
        predict: Flag indicating if prediction should be included.
        score: Flag indicating if score should be included.
        select: Flag indicating if selection should be included.
        plan: Plan whose steps drive the extension, or falsy for one step.
        allow_action: Flag indicating if action should be allowed.
        confidence: Flag indicating if confidence should be included.
        score_num_digits: Number of digits for score.
        score_range: Range for score.
        select_choices: Choices for selection.
        predict_num_sentences: Number of sentences for prediction.
        kwargs: Additional keyword arguments forwarded to ``_direct``.

    Returns:
        list: The extension forms, in the order they were produced.
    """
    extension_forms = []

    directive_kwargs = {
        "tools": tools,
        "reason": reason,
        "predict": predict,
        "score": score,
        "select": select,
        "allow_action": allow_action,
        "confidence": confidence,
        "score_num_digits": score_num_digits,
        "score_range": score_range,
        "select_choices": select_choices,
        "predict_num_sentences": predict_num_sentences,
        **kwargs,
    }

    if not plan:
        # Handle single step extension
        last_form = await self._direct(**directive_kwargs)
        last_form.is_extension = True
        extension_forms.append(last_form)
        return extension_forms

    step_keys = [f"step_{n}" for n in range(1, len(plan) + 1)]
    plan = StringMatch.force_validate_dict(plan, step_keys)

    for step_key in step_keys:
        directive_kwargs["instruction"] = plan[step_key]
        last_form = await self._direct(**directive_kwargs)
        last_form.is_extension = True
        extension_forms.append(last_form)
        # Only decrement a budget that was actually supplied (via kwargs):
        # this dict carries no "max_extension" key otherwise, and an
        # unconditional decrement raised KeyError after the first step.
        if isinstance(directive_kwargs.get("max_extension"), int):
            directive_kwargs["max_extension"] -= 1
        if not getattr(last_form, "extension_required", None):
            break

    return extension_forms
|
892
|
+
|
893
|
+
async def _act(self, form, branch, actions=None):
    """
    Execute the action requests carried by ``form`` on ``branch``.

    The steps are:

    1. Validate the actions into a dict keyed ``action_1`` ...
       ``action_<len(actions)>``.
    2. Turn each one into an ``ActionRequest`` addressed to the registered
       tool of the same name and add it to the branch.
    3. Process all requests with tool invocation.
    4. Copy the resulting ``ActionResponse`` messages into
       ``form.action_response``, appending ``_1`` to keys that already exist.

    ``action_performed`` is set to True when at least one new response was
    recorded.

    Args:
        form: Form data; mutated in place.
        branch: Branch instance providing ``tool_manager.registry`` and the
            message log.
        actions: Actions data; each validated entry is expected to provide
            ``"function"`` and ``"arguments"``. When falsy, nothing is done.

    Returns:
        The same form object, updated.

    Raises:
        ValueError: If building or processing any action request fails.
    """
    if getattr(form, "action_performed", None) is True:
        return form
    if not actions:
        # Nothing to execute; avoids len(None) and an unbound result below.
        return form

    keys = [f"action_{i + 1}" for i in range(len(actions))]
    actions = StringMatch.force_validate_dict(actions, keys)

    try:
        requests = []
        for k in keys:
            func_name = actions[k]["function"].replace("functions.", "")
            msg = ActionRequest(
                function=func_name,
                arguments=actions[k]["arguments"],
                sender=branch.ln_id,
                recipient=branch.tool_manager.registry[func_name].ln_id,
            )
            requests.append(msg)
            branch.add_message(action_request=msg)

        # `actions` is non-empty here, so `requests` is non-empty too.
        out = await self._process_action_request(
            branch=branch, invoke_tool=True, action_request=requests
        )
        if out is False:
            raise ValueError("No requests found.")

        # Responses are read from the tail of the branch's message log.
        len_actions = len(actions)
        action_responses = [
            m for m in branch.messages[-len_actions:] if isinstance(m, ActionResponse)
        ]
        new_responses = {
            f"action_{idx}": item._to_dict()
            for idx, item in enumerate(action_responses, start=1)
        }

        form.append_to_request("action_response")
        if getattr(form, "action_response", None) is None:
            form.add_field("action_response", {})

        len1 = len(form.action_response)
        for k, v in new_responses.items():
            # Never overwrite responses recorded by an earlier call.
            while k in form.action_response:
                k = f"{k}_1"
            form.action_response[k] = v

        if len(form.action_response) > len1:
            form.append_to_request("action_performed")
            form.action_performed = True
        return form

    except Exception as e:
        raise ValueError(f"Error processing action request: {e}") from e
|
961
|
+
|
962
|
+
async def _select(
    self,
    form=None,
    choices=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    template=None,
    context=None,
    branch=None,
    **kwargs,
):
    """
    Ask the model to pick among ``choices``.

    A caller-supplied ``form`` is used as is. Otherwise ``template`` is
    called with the selection settings to build one. The form is then sent
    through ``self._chat`` with ``return_form=True`` on the chosen branch.

    Args:
        form (Any, optional): Prepared form; when falsy, one is built.
        choices (Any, optional): Choices for the selection.
        reason (bool, optional): Whether to include a reason for the selection.
        confidence_score (Any, optional): Confidence score for the selection.
        instruction (Any, optional): Instruction for the selection.
        template (Any, optional): Form factory; there is no default, so it
            must be callable whenever ``form`` is falsy.
        context (Any, optional): Context to perform the selection on.
        branch (Any, optional): Branch to use; defaults to ``self.branch``.
        **kwargs: Additional arguments forwarded to ``_chat``.

    Returns:
        Any: Whatever ``self._chat`` returns for the form.
    """
    target_branch = branch if branch else self.branch

    if form:
        return await self._chat(
            form=form, return_form=True, branch=target_branch, **kwargs
        )

    built_form = template(
        choices=choices,
        reason=reason,
        confidence_score=confidence_score,
        instruction=instruction,
        context=context,
    )
    return await self._chat(
        form=built_form, return_form=True, branch=target_branch, **kwargs
    )
|
1003
|
+
|
1004
|
+
async def _predict(
    self,
    form=None,
    num_sentences=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    context=None,
    branch=None,
    template=None,
    **kwargs,
):
    """
    Ask the model for a prediction.

    When no ``form`` is given, ``template`` builds one from the prediction
    settings. The form is then passed to ``self._chat`` with
    ``return_form=True``.

    Args:
        form: Prepared form; when falsy, one is built from ``template``.
        num_sentences: Number of sentences for the prediction.
        reason: Flag indicating if reason should be included.
        confidence_score: Confidence score for the prediction.
        instruction: Instruction for the prediction.
        context: Context to perform the prediction on.
        branch: Branch instance; defaults to ``self.branch``.
        template: Form factory; must be callable whenever ``form`` is falsy.
        kwargs: Additional keyword arguments forwarded to ``_chat``.

    Returns:
        Any: Whatever ``self._chat`` returns for the form.
    """
    active_branch = branch or self.branch

    prediction_form = form
    if not prediction_form:
        prediction_form = template(
            instruction=instruction,
            context=context,
            num_sentences=num_sentences,
            confidence_score=confidence_score,
            reason=reason,
        )

    return await self._chat(
        form=prediction_form,
        return_form=True,
        branch=active_branch,
        **kwargs,
    )
|
1045
|
+
|
1046
|
+
async def _score(
    self,
    form=None,
    score_range=None,
    include_endpoints=None,
    num_digit=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    context=None,
    branch=None,
    template=None,
    **kwargs,
):
    """
    Ask the model to score something.

    The supplied ``form`` is used when truthy. Otherwise ``template`` builds
    one from the scoring settings. The form is sent through ``self._chat``
    with ``return_form=True``.

    Args:
        form: Prepared form; when falsy, one is built from ``template``.
        score_range: Range for score.
        include_endpoints: Flag indicating if endpoints should be included.
        num_digit: Number of digits for score.
        reason: Flag indicating if reason should be included.
        confidence_score: Confidence score for the score.
        instruction: Instruction for the score.
        context: Context to perform the score on.
        branch: Branch instance; defaults to ``self.branch``.
        template: Form factory; must be callable whenever ``form`` is falsy.
        kwargs: Additional keyword arguments forwarded to ``_chat``.

    Returns:
        Any: Whatever ``self._chat`` returns for the form.
    """
    use_branch = branch or self.branch
    scoring_form = form or template(
        score_range=score_range,
        include_endpoints=include_endpoints,
        num_digit=num_digit,
        reason=reason,
        confidence_score=confidence_score,
        instruction=instruction,
        context=context,
    )
    return await self._chat(
        form=scoring_form, return_form=True, branch=use_branch, **kwargs
    )
|
1092
|
+
|
1093
|
+
async def _plan(
    self,
    form=None,
    num_step=None,
    reason=False,
    confidence_score=None,
    instruction=None,
    context=None,
    branch=None,
    template=None,
    **kwargs,
):
    """
    Plans a response based on the provided parameters.

    When no ``form`` is given, one is built with ``template``, which falls
    back to ``self.default_template``. The form is then sent through
    ``self._chat`` on the resolved branch.

    Args:
        form: Prepared form; when falsy, one is built from ``template``.
        num_step: Number of steps for the plan.
        reason: Flag indicating if reason should be included.
        confidence_score: Confidence score for the plan.
        instruction: Instruction for the plan.
        context: Context to perform the plan on.
        branch: Branch instance; defaults to ``self.branch``.
        template: Form factory; defaults to ``self.default_template``.
        kwargs: Additional keyword arguments forwarded to ``_chat``.

    Returns:
        Any: Whatever ``self._chat`` returns for the form.
    """
    branch = branch or self.branch
    template = template or self.default_template

    if not form:
        form = template(
            instruction=instruction,
            context=context,
            num_step=num_step,
            reason=reason,
            confidence_score=confidence_score,
        )

    # Forward the resolved branch, as _select/_predict/_score do, so that a
    # caller-supplied branch is actually used for the chat.
    return await self._chat(form=form, branch=branch, **kwargs)
|
1135
|
+
|
1136
|
+
@staticmethod
def _process_model_response(content_, requested_fields):
    """
    Turn a raw model response into structured output when possible.

    The payload is ``content_["content"]``, or the whole ``content_`` when that
    key is missing or holds an empty string. The following attempts are made
    in order:

    - If ``requested_fields`` is truthy, coerce the payload into a dict with
      those keys via ``StringMatch.force_validate_dict``.
    - Return a non-string payload as is.
    - Try a string payload as fuzzy JSON.
    - Try it as an extracted JSON block.
    - Try it as a fenced ```json block.

    Each failed attempt is silently skipped. The payload is returned unchanged
    when nothing succeeds.

    Args:
        content_: The response mapping; must support ``.get``.
        requested_fields: Fields requested in the response, or falsy.

    Returns:
        Any: The parsed structure, or the unparsed payload.
    """
    payload = content_.get("content", "")
    if payload == "":
        payload = content_

    if requested_fields:
        try:
            return StringMatch.force_validate_dict(payload, requested_fields)
        except Exception:
            pass  # deliberate best-effort: fall through to generic parsing

    if not isinstance(payload, str):
        return payload

    # Attribute lookups stay inside the lambdas so that any failure, including
    # name resolution, is swallowed exactly like every other parse failure.
    parsers = (
        lambda: ParseUtil.fuzzy_parse_json(payload),
        lambda: ParseUtil.extract_json_block(payload),
    )
    for parse in parsers:
        try:
            return parse()
        except Exception:
            continue

    try:
        fenced = re.search(r"```json\n({.*?})\n```", payload, re.DOTALL)
        if fenced:
            return ParseUtil.fuzzy_parse_json(fenced.group(1))
    except Exception:
        pass

    return payload
|