zrb 1.0.0b2__py3-none-any.whl → 1.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/__main__.py +3 -0
- zrb/builtin/llm/llm_chat.py +85 -5
- zrb/builtin/llm/previous-session.js +13 -0
- zrb/builtin/llm/tool/api.py +29 -0
- zrb/builtin/llm/tool/cli.py +1 -1
- zrb/builtin/llm/tool/rag.py +108 -145
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/client_method.py +6 -6
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py +3 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py +88 -44
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py +12 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py +28 -22
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py +6 -6
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py +43 -29
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_repository.py +8 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py +46 -14
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py +158 -20
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py +29 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py +36 -14
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py +14 -14
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py +34 -6
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py +2 -6
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py +41 -2
- zrb/builtin/todo.py +1 -0
- zrb/config.py +23 -4
- zrb/input/any_input.py +5 -0
- zrb/input/base_input.py +6 -0
- zrb/input/bool_input.py +2 -0
- zrb/input/float_input.py +2 -0
- zrb/input/int_input.py +2 -0
- zrb/input/option_input.py +2 -0
- zrb/input/password_input.py +2 -0
- zrb/input/text_input.py +2 -0
- zrb/runner/common_util.py +1 -1
- zrb/runner/web_route/error_page/show_error_page.py +2 -1
- zrb/runner/web_route/static/resources/session/current-session.js +4 -2
- zrb/runner/web_route/static/resources/session/event.js +8 -2
- zrb/runner/web_route/task_session_api_route.py +48 -3
- zrb/task/base_task.py +14 -13
- zrb/task/llm_task.py +214 -84
- zrb/util/llm/tool.py +3 -7
- {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/METADATA +2 -1
- {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/RECORD +45 -43
- {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/WHEEL +0 -0
- {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/entry_points.txt +0 -0
zrb/input/base_input.py
CHANGED
@@ -16,6 +16,7 @@ class BaseInput(AnyInput):
|
|
16
16
|
default_str: StrAttr = "",
|
17
17
|
auto_render: bool = True,
|
18
18
|
allow_empty: bool = False,
|
19
|
+
allow_positional_parsing: bool = True,
|
19
20
|
):
|
20
21
|
self._name = name
|
21
22
|
self._description = description
|
@@ -23,6 +24,7 @@ class BaseInput(AnyInput):
|
|
23
24
|
self._default_str = default_str
|
24
25
|
self._auto_render = auto_render
|
25
26
|
self._allow_empty = allow_empty
|
27
|
+
self._allow_positional_parsing = allow_positional_parsing
|
26
28
|
|
27
29
|
def __repr__(self):
|
28
30
|
return f"<{self.__class__.__name__} name={self._name}>"
|
@@ -39,6 +41,10 @@ class BaseInput(AnyInput):
|
|
39
41
|
def prompt_message(self) -> str:
|
40
42
|
return self._prompt if self._prompt is not None else self.name
|
41
43
|
|
44
|
+
@property
|
45
|
+
def allow_positional_parsing(self) -> bool:
|
46
|
+
return self._allow_positional_parsing
|
47
|
+
|
42
48
|
def to_html(self, ctx: AnySharedContext) -> str:
|
43
49
|
name = self.name
|
44
50
|
description = self.description
|
zrb/input/bool_input.py
CHANGED
@@ -13,6 +13,7 @@ class BoolInput(BaseInput):
|
|
13
13
|
default_str: StrAttr = "False",
|
14
14
|
auto_render: bool = True,
|
15
15
|
allow_empty: bool = False,
|
16
|
+
allow_positional_parsing: bool = True,
|
16
17
|
):
|
17
18
|
super().__init__(
|
18
19
|
name=name,
|
@@ -21,6 +22,7 @@ class BoolInput(BaseInput):
|
|
21
22
|
default_str=default_str,
|
22
23
|
auto_render=auto_render,
|
23
24
|
allow_empty=allow_empty,
|
25
|
+
allow_positional_parsing=allow_positional_parsing,
|
24
26
|
)
|
25
27
|
|
26
28
|
def to_html(self, ctx: AnySharedContext) -> str:
|
zrb/input/float_input.py
CHANGED
@@ -12,6 +12,7 @@ class FloatInput(BaseInput):
|
|
12
12
|
default_str: StrAttr = "0.0",
|
13
13
|
auto_render: bool = True,
|
14
14
|
allow_empty: bool = False,
|
15
|
+
allow_positional_parsing: bool = True,
|
15
16
|
):
|
16
17
|
super().__init__(
|
17
18
|
name=name,
|
@@ -20,6 +21,7 @@ class FloatInput(BaseInput):
|
|
20
21
|
default_str=default_str,
|
21
22
|
auto_render=auto_render,
|
22
23
|
allow_empty=allow_empty,
|
24
|
+
allow_positional_parsing=allow_positional_parsing,
|
23
25
|
)
|
24
26
|
|
25
27
|
def to_html(self, ctx: AnySharedContext) -> str:
|
zrb/input/int_input.py
CHANGED
@@ -12,6 +12,7 @@ class IntInput(BaseInput):
|
|
12
12
|
default_str: StrAttr = "0",
|
13
13
|
auto_render: bool = True,
|
14
14
|
allow_empty: bool = False,
|
15
|
+
allow_positional_parsing: bool = True,
|
15
16
|
):
|
16
17
|
super().__init__(
|
17
18
|
name=name,
|
@@ -20,6 +21,7 @@ class IntInput(BaseInput):
|
|
20
21
|
default_str=default_str,
|
21
22
|
auto_render=auto_render,
|
22
23
|
allow_empty=allow_empty,
|
24
|
+
allow_positional_parsing=allow_positional_parsing,
|
23
25
|
)
|
24
26
|
|
25
27
|
def to_html(self, ctx: AnySharedContext) -> str:
|
zrb/input/option_input.py
CHANGED
@@ -14,6 +14,7 @@ class OptionInput(BaseInput):
|
|
14
14
|
default_str: StrAttr = "",
|
15
15
|
auto_render: bool = True,
|
16
16
|
allow_empty: bool = False,
|
17
|
+
allow_positional_parsing: bool = True,
|
17
18
|
):
|
18
19
|
super().__init__(
|
19
20
|
name=name,
|
@@ -22,6 +23,7 @@ class OptionInput(BaseInput):
|
|
22
23
|
default_str=default_str,
|
23
24
|
auto_render=auto_render,
|
24
25
|
allow_empty=allow_empty,
|
26
|
+
allow_positional_parsing=allow_positional_parsing,
|
25
27
|
)
|
26
28
|
self._options = options
|
27
29
|
|
zrb/input/password_input.py
CHANGED
@@ -14,6 +14,7 @@ class PasswordInput(BaseInput):
|
|
14
14
|
default_str: str | Callable[[AnySharedContext], str] = "",
|
15
15
|
auto_render: bool = True,
|
16
16
|
allow_empty: bool = False,
|
17
|
+
allow_positional_parsing: bool = True,
|
17
18
|
):
|
18
19
|
super().__init__(
|
19
20
|
name=name,
|
@@ -22,6 +23,7 @@ class PasswordInput(BaseInput):
|
|
22
23
|
default_str=default_str,
|
23
24
|
auto_render=auto_render,
|
24
25
|
allow_empty=allow_empty,
|
26
|
+
allow_positional_parsing=allow_positional_parsing,
|
25
27
|
)
|
26
28
|
self._is_secret = True
|
27
29
|
|
zrb/input/text_input.py
CHANGED
@@ -18,6 +18,7 @@ class TextInput(BaseInput):
|
|
18
18
|
default_str: str | Callable[[AnySharedContext], str] = "",
|
19
19
|
auto_render: bool = True,
|
20
20
|
allow_empty: bool = False,
|
21
|
+
allow_positional_parsing: bool = True,
|
21
22
|
editor: str = DEFAULT_EDITOR,
|
22
23
|
extension: str = ".txt",
|
23
24
|
comment_start: str | None = None,
|
@@ -30,6 +31,7 @@ class TextInput(BaseInput):
|
|
30
31
|
default_str=default_str,
|
31
32
|
auto_render=auto_render,
|
32
33
|
allow_empty=allow_empty,
|
34
|
+
allow_positional_parsing=allow_positional_parsing,
|
33
35
|
)
|
34
36
|
self._editor = editor
|
35
37
|
self._extension = extension
|
zrb/runner/common_util.py
CHANGED
@@ -15,7 +15,7 @@ def get_run_kwargs(
|
|
15
15
|
if task_input.name in str_kwargs:
|
16
16
|
# Update shared context for next input default value
|
17
17
|
task_input.update_shared_context(shared_ctx, str_kwargs[task_input.name])
|
18
|
-
elif arg_index < len(args):
|
18
|
+
elif arg_index < len(args) and task_input.allow_positional_parsing:
|
19
19
|
run_kwargs[task_input.name] = args[arg_index]
|
20
20
|
# Update shared context for next input default value
|
21
21
|
task_input.update_shared_context(shared_ctx, run_kwargs[task_input.name])
|
@@ -13,14 +13,16 @@ const CURRENT_SESSION = {
|
|
13
13
|
for (const inputName in dataInputs) {
|
14
14
|
const inputValue = dataInputs[inputName];
|
15
15
|
const input = submitTaskForm.querySelector(`[name="${inputName}"]`);
|
16
|
-
input
|
16
|
+
if (input) {
|
17
|
+
input.value = inputValue;
|
18
|
+
}
|
17
19
|
}
|
18
20
|
resultLineCount = data.final_result.split("\n").length;
|
19
21
|
resultTextarea.rows = resultLineCount <= 5 ? resultLineCount : 5;
|
20
22
|
// update text areas
|
21
23
|
resultTextarea.value = data.final_result;
|
22
24
|
logTextarea.value = data.log.join("\n");
|
23
|
-
logTextarea.scrollTop = logTextarea.scrollHeight;
|
25
|
+
// logTextarea.scrollTop = logTextarea.scrollHeight;
|
24
26
|
// visualize history
|
25
27
|
this.showCurrentSession(data.task_status, data.finished);
|
26
28
|
if (data.finished) {
|
@@ -20,9 +20,9 @@ window.addEventListener("load", async function () {
|
|
20
20
|
|
21
21
|
|
22
22
|
const submitTaskForm = document.getElementById("submit-task-form");
|
23
|
-
submitTaskForm.addEventListener("
|
23
|
+
submitTaskForm.addEventListener("change", async function(event) {
|
24
24
|
const currentInput = event.target;
|
25
|
-
const inputs = Array.from(submitTaskForm.querySelectorAll("input[name]"));
|
25
|
+
const inputs = Array.from(submitTaskForm.querySelectorAll("input[name], textarea[name], select[name]"));
|
26
26
|
const inputMap = {};
|
27
27
|
const fixedInputNames = [];
|
28
28
|
for (const input of inputs) {
|
@@ -53,6 +53,12 @@ submitTaskForm.addEventListener("input", async function(event) {
|
|
53
53
|
return;
|
54
54
|
}
|
55
55
|
const input = submitTaskForm.querySelector(`[name="${key}"]`);
|
56
|
+
if (input === currentInput) {
|
57
|
+
return;
|
58
|
+
}
|
59
|
+
if (value === "") {
|
60
|
+
return;
|
61
|
+
}
|
56
62
|
input.value = value;
|
57
63
|
});
|
58
64
|
} else {
|
@@ -94,9 +94,54 @@ def serve_task_session_api(
|
|
94
94
|
if min_start_query is None
|
95
95
|
else datetime.strptime(min_start_query, "%Y-%m-%d %H:%M:%S")
|
96
96
|
)
|
97
|
-
return
|
98
|
-
|
97
|
+
return sanitize_session_state_log_list(
|
98
|
+
task,
|
99
|
+
session_state_logger.list(
|
100
|
+
task_path, min_start_time, max_start_time, page, limit
|
101
|
+
),
|
99
102
|
)
|
100
103
|
else:
|
101
|
-
return
|
104
|
+
return sanitize_session_state_log(
|
105
|
+
task, session_state_logger.read(residual_args[0])
|
106
|
+
)
|
102
107
|
return JSONResponse(content={"detail": "Not found"}, status_code=404)
|
108
|
+
|
109
|
+
|
110
|
+
def sanitize_session_state_log_list(
|
111
|
+
task: AnyTask, session_state_log_list: SessionStateLogList
|
112
|
+
) -> SessionStateLogList:
|
113
|
+
return SessionStateLogList(
|
114
|
+
total=session_state_log_list.total,
|
115
|
+
data=[
|
116
|
+
sanitize_session_state_log(task, data)
|
117
|
+
for data in session_state_log_list.data
|
118
|
+
],
|
119
|
+
)
|
120
|
+
|
121
|
+
|
122
|
+
def sanitize_session_state_log(
|
123
|
+
task: AnyTask, session_state_log: SessionStateLog
|
124
|
+
) -> SessionStateLog:
|
125
|
+
"""
|
126
|
+
In session, we create snake_case aliases of inputs.
|
127
|
+
The purpose was to increase ergonomics, so that user can use `input.system_prompt`
|
128
|
+
instead of `input["system-prompt"]`
|
129
|
+
However, when we serve the session through HTTP API,
|
130
|
+
we only want to show the original input names.
|
131
|
+
"""
|
132
|
+
enhanced_inputs = session_state_log.input
|
133
|
+
real_inputs = {}
|
134
|
+
for real_input in task.inputs:
|
135
|
+
real_input_name = real_input.name
|
136
|
+
real_inputs[real_input_name] = enhanced_inputs[real_input_name]
|
137
|
+
return SessionStateLog(
|
138
|
+
name=session_state_log.name,
|
139
|
+
start_time=session_state_log.start_time,
|
140
|
+
main_task_name=session_state_log.main_task_name,
|
141
|
+
path=session_state_log.path,
|
142
|
+
input=real_inputs,
|
143
|
+
final_result=session_state_log.final_result,
|
144
|
+
finished=session_state_log.finished,
|
145
|
+
log=session_state_log.log,
|
146
|
+
task_status=session_state_log.task_status,
|
147
|
+
)
|
zrb/task/base_task.py
CHANGED
@@ -242,28 +242,28 @@ class BaseTask(AnyTask):
|
|
242
242
|
def run(
|
243
243
|
self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
|
244
244
|
) -> Any:
|
245
|
-
|
246
|
-
try:
|
247
|
-
return loop.run_until_complete(self._run_and_cleanup(session, str_kwargs))
|
248
|
-
finally:
|
249
|
-
loop.close()
|
245
|
+
return asyncio.run(self._run_and_cleanup(session, str_kwargs))
|
250
246
|
|
251
247
|
async def _run_and_cleanup(
|
252
|
-
self,
|
248
|
+
self,
|
249
|
+
session: AnySession | None = None,
|
250
|
+
str_kwargs: dict[str, str] = {},
|
253
251
|
) -> Any:
|
252
|
+
current_task = asyncio.create_task(self.async_run(session, str_kwargs))
|
254
253
|
try:
|
255
|
-
result = await
|
254
|
+
result = await current_task
|
256
255
|
finally:
|
257
|
-
if not session.is_terminated:
|
256
|
+
if session and not session.is_terminated:
|
258
257
|
session.terminate()
|
259
258
|
# Cancel all running tasks except the current one
|
260
|
-
current_task = asyncio.current_task()
|
261
259
|
pending = [task for task in asyncio.all_tasks() if task is not current_task]
|
262
260
|
for task in pending:
|
263
261
|
task.cancel()
|
264
|
-
# Wait for all tasks to complete with a timeout
|
265
262
|
if pending:
|
266
|
-
|
263
|
+
try:
|
264
|
+
await asyncio.wait(pending, timeout=5)
|
265
|
+
except asyncio.CancelledError:
|
266
|
+
pass
|
267
267
|
return result
|
268
268
|
|
269
269
|
async def async_run(
|
@@ -274,7 +274,8 @@ class BaseTask(AnyTask):
|
|
274
274
|
# Update session
|
275
275
|
self.__fill_shared_context_inputs(session.shared_ctx, str_kwargs)
|
276
276
|
self.__fill_shared_context_envs(session.shared_ctx)
|
277
|
-
|
277
|
+
result = await run_async(self.exec_root_tasks(session))
|
278
|
+
return result
|
278
279
|
|
279
280
|
def __fill_shared_context_inputs(
|
280
281
|
self, shared_context: AnySharedContext, str_kwargs: dict[str, str] = {}
|
@@ -352,7 +353,7 @@ class BaseTask(AnyTask):
|
|
352
353
|
session.get_task_status(self).mark_as_skipped()
|
353
354
|
return
|
354
355
|
# Wait for task to be ready
|
355
|
-
await run_async(self.__exec_action_until_ready(session))
|
356
|
+
return await run_async(self.__exec_action_until_ready(session))
|
356
357
|
|
357
358
|
def __get_execute_condition(self, session: Session) -> bool:
|
358
359
|
ctx = self.get_ctx(session)
|
zrb/task/llm_task.py
CHANGED
@@ -17,6 +17,7 @@ from zrb.util.attr import get_str_attr
|
|
17
17
|
from zrb.util.cli.style import stylize_faint
|
18
18
|
from zrb.util.file import read_file, write_file
|
19
19
|
from zrb.util.llm.tool import callable_to_tool_schema
|
20
|
+
from zrb.util.run import run_async
|
20
21
|
|
21
22
|
ListOfDict = list[dict[str, Any]]
|
22
23
|
|
@@ -24,14 +25,18 @@ ListOfDict = list[dict[str, Any]]
|
|
24
25
|
class AdditionalTool(BaseModel):
|
25
26
|
fn: Callable
|
26
27
|
name: str | None
|
27
|
-
description: str | None
|
28
28
|
|
29
29
|
|
30
30
|
def scratchpad(thought: str) -> str:
|
31
|
-
"""
|
31
|
+
"""Write your thought, analysis, reasoning, and evaluation here."""
|
32
32
|
return thought
|
33
33
|
|
34
34
|
|
35
|
+
def end_conversation(final_answer: str) -> str:
|
36
|
+
"""End conversation with a final answer containing all necessary information"""
|
37
|
+
return final_answer
|
38
|
+
|
39
|
+
|
35
40
|
class LLMTask(BaseTask):
|
36
41
|
def __init__(
|
37
42
|
self,
|
@@ -47,11 +52,17 @@ class LLMTask(BaseTask):
|
|
47
52
|
system_prompt: StrAttr | None = LLM_SYSTEM_PROMPT,
|
48
53
|
render_system_prompt: bool = True,
|
49
54
|
message: StrAttr | None = None,
|
50
|
-
tools:
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
+
tools: list[Callable] | Callable[[AnySharedContext], list[Callable]] = [],
|
56
|
+
conversation_history: (
|
57
|
+
ListOfDict | Callable[[AnySharedContext], ListOfDict]
|
58
|
+
) = [],
|
59
|
+
conversation_history_reader: (
|
60
|
+
Callable[[AnySharedContext], ListOfDict] | None
|
61
|
+
) = None,
|
62
|
+
conversation_history_writer: (
|
63
|
+
Callable[[AnySharedContext, ListOfDict], None] | None
|
64
|
+
) = None,
|
65
|
+
conversation_history_file: StrAttr | None = None,
|
55
66
|
render_history_file: bool = True,
|
56
67
|
model_kwargs: (
|
57
68
|
dict[str, Any] | Callable[[AnySharedContext], dict[str, Any]]
|
@@ -97,84 +108,165 @@ class LLMTask(BaseTask):
|
|
97
108
|
self._render_system_prompt = render_system_prompt
|
98
109
|
self._message = message
|
99
110
|
self._tools = tools
|
100
|
-
self.
|
101
|
-
self.
|
111
|
+
self._conversation_history = conversation_history
|
112
|
+
self._conversation_history_reader = conversation_history_reader
|
113
|
+
self._conversation_history_writer = conversation_history_writer
|
114
|
+
self._conversation_history_file = conversation_history_file
|
102
115
|
self._render_history_file = render_history_file
|
103
|
-
self._additional_tools: list[AdditionalTool] = []
|
104
116
|
|
105
|
-
def add_tool(
|
106
|
-
self
|
107
|
-
):
|
108
|
-
self._additional_tools.append(
|
109
|
-
AdditionalTool(fn=tool, name=name, description=description)
|
110
|
-
)
|
117
|
+
def add_tool(self, tool: Callable):
|
118
|
+
self._tools.append(tool)
|
111
119
|
|
112
120
|
async def _exec_action(self, ctx: AnyContext) -> Any:
|
113
|
-
|
121
|
+
import litellm
|
122
|
+
from litellm.utils import supports_function_calling
|
114
123
|
|
124
|
+
user_message = {"role": "user", "content": self._get_message(ctx)}
|
125
|
+
ctx.print(stylize_faint(f"{user_message}"))
|
115
126
|
model = self._get_model(ctx)
|
116
127
|
try:
|
117
|
-
|
128
|
+
is_function_call_supported = supports_function_calling(model=model)
|
118
129
|
except Exception:
|
119
|
-
|
120
|
-
|
130
|
+
is_function_call_supported = False
|
131
|
+
litellm.add_function_to_prompt = True
|
132
|
+
if not is_function_call_supported:
|
133
|
+
ctx.log_warning(f"Model {model} doesn't support function call")
|
134
|
+
available_tools = self._get_available_tools(
|
135
|
+
ctx, include_end_conversation=not is_function_call_supported
|
136
|
+
)
|
137
|
+
model_kwargs = self._get_model_kwargs(ctx, available_tools)
|
121
138
|
ctx.log_debug("MODEL KWARGS", model_kwargs)
|
122
139
|
system_prompt = self._get_system_prompt(ctx)
|
123
140
|
ctx.log_debug("SYSTEM PROMPT", system_prompt)
|
124
|
-
history = self.
|
141
|
+
history = await self._read_conversation_history(ctx)
|
125
142
|
ctx.log_debug("HISTORY PROMPT", history)
|
126
|
-
|
127
|
-
ctx.print(stylize_faint(f"{user_message}"))
|
128
|
-
messages = history + [user_message]
|
129
|
-
available_tools = self._get_tools(ctx)
|
130
|
-
available_tools["scratchpad"] = scratchpad
|
131
|
-
if allow_function_call:
|
132
|
-
tool_schema = [
|
133
|
-
callable_to_tool_schema(tool, name)
|
134
|
-
for name, tool in available_tools.items()
|
135
|
-
]
|
136
|
-
for additional_tool in self._additional_tools:
|
137
|
-
fn = additional_tool.fn
|
138
|
-
tool_name = additional_tool.name or fn.__name__
|
139
|
-
tool_description = additional_tool.description
|
140
|
-
available_tools[tool_name] = additional_tool.fn
|
141
|
-
tool_schema.append(
|
142
|
-
callable_to_tool_schema(
|
143
|
-
fn, name=tool_name, description=tool_description
|
144
|
-
)
|
145
|
-
)
|
146
|
-
model_kwargs["tools"] = tool_schema
|
147
|
-
ctx.log_debug("TOOL SCHEMA", tool_schema)
|
148
|
-
history_file = self._get_history_file(ctx)
|
143
|
+
conversations = history + [user_message]
|
149
144
|
while True:
|
150
|
-
|
151
|
-
model
|
152
|
-
|
153
|
-
|
145
|
+
llm_response = await self._get_llm_response(
|
146
|
+
model, system_prompt, conversations, model_kwargs
|
147
|
+
)
|
148
|
+
llm_response_dict = llm_response.to_dict()
|
149
|
+
ctx.print(stylize_faint(f"{llm_response_dict}"))
|
150
|
+
conversations.append(llm_response_dict)
|
151
|
+
ctx.log_debug("RESPONSE MESSAGE", llm_response)
|
152
|
+
if is_function_call_supported:
|
153
|
+
if not llm_response.tool_calls:
|
154
|
+
# No tool call, end conversation
|
155
|
+
await self._write_conversation_history(ctx, conversations)
|
156
|
+
return llm_response.content
|
157
|
+
await self._handle_tool_calls(
|
158
|
+
ctx, available_tools, conversations, llm_response
|
159
|
+
)
|
160
|
+
if not is_function_call_supported:
|
161
|
+
try:
|
162
|
+
json_payload = json.loads(llm_response.content)
|
163
|
+
function_name = _get_fallback_function_name(json_payload)
|
164
|
+
function_kwargs = _get_fallback_function_kwargs(json_payload)
|
165
|
+
tool_execution_message = (
|
166
|
+
await self._create_fallback_tool_exec_message(
|
167
|
+
available_tools, function_name, function_kwargs
|
168
|
+
)
|
169
|
+
)
|
170
|
+
ctx.print(stylize_faint(f"{tool_execution_message}"))
|
171
|
+
conversations.append(tool_execution_message)
|
172
|
+
if function_name == "end_conversation":
|
173
|
+
await self._write_conversation_history(ctx, conversations)
|
174
|
+
return function_kwargs.get("final_answer", "")
|
175
|
+
except Exception as e:
|
176
|
+
ctx.log_error(e)
|
177
|
+
tool_execution_message = self._create_exec_scratchpad_message(
|
178
|
+
f"{e}"
|
179
|
+
)
|
180
|
+
conversations.append(tool_execution_message)
|
181
|
+
|
182
|
+
async def _handle_tool_calls(
|
183
|
+
self,
|
184
|
+
ctx: AnyContext,
|
185
|
+
available_tools: dict[str, Callable],
|
186
|
+
conversations: list[dict[str, Any]],
|
187
|
+
llm_response: Any,
|
188
|
+
):
|
189
|
+
# noqa Reference: https://docs.litellm.ai/docs/completion/function_call#full-code---parallel-function-calling-with-gpt-35-turbo-1106
|
190
|
+
for tool_call in llm_response.tool_calls:
|
191
|
+
tool_execution_message = await self._create_tool_exec_message(
|
192
|
+
available_tools, tool_call
|
154
193
|
)
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
194
|
+
ctx.print(stylize_faint(f"{tool_execution_message}"))
|
195
|
+
conversations.append(tool_execution_message)
|
196
|
+
|
197
|
+
async def _write_conversation_history(
|
198
|
+
self, ctx: AnyContext, conversations: list[Any]
|
199
|
+
):
|
200
|
+
if self._conversation_history_writer is not None:
|
201
|
+
await run_async(self._conversation_history_writer(ctx, conversations))
|
202
|
+
history_file = self._get_history_file(ctx)
|
203
|
+
if history_file != "":
|
204
|
+
write_file(history_file, json.dumps(conversations, indent=2))
|
205
|
+
|
206
|
+
async def _get_llm_response(
|
207
|
+
self,
|
208
|
+
model: str,
|
209
|
+
system_prompt: str,
|
210
|
+
conversations: list[Any],
|
211
|
+
model_kwargs: dict[str, Any],
|
212
|
+
) -> Any:
|
213
|
+
from litellm import acompletion
|
214
|
+
|
215
|
+
llm_response = await acompletion(
|
216
|
+
model=model,
|
217
|
+
messages=[{"role": "system", "content": system_prompt}] + conversations,
|
218
|
+
**model_kwargs,
|
219
|
+
)
|
220
|
+
return llm_response.choices[0].message
|
221
|
+
|
222
|
+
async def _create_tool_exec_message(
|
223
|
+
self, available_tools: dict[str, Callable], tool_call: Any
|
224
|
+
) -> dict[str, Any]:
|
225
|
+
function_name = tool_call.function.name
|
226
|
+
function_kwargs = json.loads(tool_call.function.arguments)
|
227
|
+
return {
|
228
|
+
"tool_call_id": tool_call.id,
|
229
|
+
"role": "tool",
|
230
|
+
"name": function_name,
|
231
|
+
"content": await self._get_exec_tool_result(
|
232
|
+
available_tools, function_name, function_kwargs
|
233
|
+
),
|
234
|
+
}
|
235
|
+
|
236
|
+
async def _create_fallback_tool_exec_message(
|
237
|
+
self,
|
238
|
+
available_tools: dict[str, Callable],
|
239
|
+
function_name: str,
|
240
|
+
function_kwargs: dict[str, Any],
|
241
|
+
) -> dict[str, Any]:
|
242
|
+
result = await self._get_exec_tool_result(
|
243
|
+
available_tools, function_name, function_kwargs
|
244
|
+
)
|
245
|
+
return self._create_exec_scratchpad_message(
|
246
|
+
f"Result of {function_name} call: {result}"
|
247
|
+
)
|
248
|
+
|
249
|
+
def _create_exec_scratchpad_message(self, message: str) -> dict[str, Any]:
|
250
|
+
return {
|
251
|
+
"role": "assistant",
|
252
|
+
"content": json.dumps(
|
253
|
+
{"name": "scratchpad", "arguments": {"thought": message}}
|
254
|
+
),
|
255
|
+
}
|
256
|
+
|
257
|
+
async def _get_exec_tool_result(
|
258
|
+
self,
|
259
|
+
available_tools: dict[str, Callable],
|
260
|
+
function_name: str,
|
261
|
+
function_kwargs: dict[str, Any],
|
262
|
+
) -> str:
|
263
|
+
if function_name not in available_tools:
|
264
|
+
return f"[ERROR] Invalid tool: {function_name}"
|
265
|
+
function_to_call = available_tools[function_name]
|
266
|
+
try:
|
267
|
+
return await run_async(function_to_call(**function_kwargs))
|
268
|
+
except Exception as e:
|
269
|
+
return f"[ERROR] {e}"
|
178
270
|
|
179
271
|
def _get_model(self, ctx: AnyContext) -> str:
|
180
272
|
return get_str_attr(
|
@@ -192,29 +284,67 @@ class LLMTask(BaseTask):
|
|
192
284
|
def _get_message(self, ctx: AnyContext) -> str:
|
193
285
|
return get_str_attr(ctx, self._message, "How are you?", auto_render=True)
|
194
286
|
|
195
|
-
def _get_model_kwargs(
|
287
|
+
def _get_model_kwargs(
|
288
|
+
self, ctx: AnyContext, available_tools: dict[str, Callable]
|
289
|
+
) -> dict[str, Any]:
|
290
|
+
model_kwargs = {}
|
196
291
|
if callable(self._model_kwargs):
|
197
|
-
|
198
|
-
|
292
|
+
model_kwargs = self._model_kwargs(ctx)
|
293
|
+
else:
|
294
|
+
model_kwargs = self._model_kwargs
|
295
|
+
model_kwargs["tools"] = [
|
296
|
+
callable_to_tool_schema(tool) for tool in available_tools.values()
|
297
|
+
]
|
298
|
+
return model_kwargs
|
199
299
|
|
200
|
-
def
|
201
|
-
|
202
|
-
|
203
|
-
|
300
|
+
def _get_available_tools(
|
301
|
+
self, ctx: AnyContext, include_end_conversation: bool
|
302
|
+
) -> dict[str, Callable]:
|
303
|
+
tools = {"scratchpad": scratchpad}
|
304
|
+
if include_end_conversation:
|
305
|
+
tools["end_conversation"] = end_conversation
|
306
|
+
tool_list = self._tools(ctx) if callable(self._tools) else self._tools
|
307
|
+
for tool in tool_list:
|
308
|
+
tools[tool.__name__] = tool
|
309
|
+
return tools
|
204
310
|
|
205
|
-
def
|
206
|
-
if
|
207
|
-
return self.
|
311
|
+
async def _read_conversation_history(self, ctx: AnyContext) -> ListOfDict:
|
312
|
+
if self._conversation_history_reader is not None:
|
313
|
+
return await run_async(self._conversation_history_reader(ctx))
|
314
|
+
if callable(self._conversation_history):
|
315
|
+
return self._conversation_history(ctx)
|
208
316
|
history_file = self._get_history_file(ctx)
|
209
317
|
if (
|
210
|
-
len(self.
|
318
|
+
len(self._conversation_history) == 0
|
211
319
|
and history_file != ""
|
212
320
|
and os.path.isfile(history_file)
|
213
321
|
):
|
214
322
|
return json.loads(read_file(history_file))
|
215
|
-
return self.
|
323
|
+
return self._conversation_history
|
216
324
|
|
217
325
|
def _get_history_file(self, ctx: AnyContext) -> str:
|
218
326
|
return get_str_attr(
|
219
|
-
ctx,
|
327
|
+
ctx,
|
328
|
+
self._conversation_history_file,
|
329
|
+
"",
|
330
|
+
auto_render=self._render_history_file,
|
220
331
|
)
|
332
|
+
|
333
|
+
|
334
|
+
def _get_fallback_function_name(json_payload: dict[str, Any]) -> str:
|
335
|
+
for key in ("name",):
|
336
|
+
if key in json_payload:
|
337
|
+
return json_payload[key]
|
338
|
+
raise ValueError("Function name not provided")
|
339
|
+
|
340
|
+
|
341
|
+
def _get_fallback_function_kwargs(json_payload: dict[str, Any]) -> str:
|
342
|
+
for key in (
|
343
|
+
"arguments",
|
344
|
+
"args",
|
345
|
+
"parameters",
|
346
|
+
"params",
|
347
|
+
):
|
348
|
+
if key in json_payload:
|
349
|
+
return json_payload[key]
|
350
|
+
raise ValueError("Function arguments not provided")
|