zrb 1.5.16__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- zrb/__init__.py +2 -2
- zrb/__main__.py +12 -12
- zrb/builtin/__init__.py +2 -2
- zrb/builtin/llm/chat_session.py +202 -0
- zrb/builtin/llm/history.py +6 -6
- zrb/builtin/llm/llm_ask.py +142 -0
- zrb/builtin/llm/tool/rag.py +39 -23
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/view/static/default/script.js +1 -1
- zrb/builtin/todo.py +21 -19
- zrb/callback/any_callback.py +1 -1
- zrb/callback/callback.py +69 -7
- zrb/config.py +261 -91
- zrb/context/shared_context.py +4 -2
- zrb/input/text_input.py +9 -6
- zrb/llm_config.py +65 -74
- zrb/runner/cli.py +13 -4
- zrb/runner/web_app.py +3 -3
- zrb/runner/web_config/config_factory.py +11 -22
- zrb/runner/web_route/error_page/show_error_page.py +16 -6
- zrb/runner/web_route/home_page/home_page_route.py +23 -7
- zrb/runner/web_route/home_page/view.html +19 -33
- zrb/runner/web_route/login_page/login_page_route.py +14 -4
- zrb/runner/web_route/login_page/view.html +33 -51
- zrb/runner/web_route/logout_page/logout_page_route.py +15 -5
- zrb/runner/web_route/logout_page/view.html +23 -41
- zrb/runner/web_route/node_page/group/show_group_page.py +26 -10
- zrb/runner/web_route/node_page/group/view.html +22 -37
- zrb/runner/web_route/node_page/task/show_task_page.py +34 -19
- zrb/runner/web_route/node_page/task/view.html +74 -88
- zrb/runner/web_route/static/global_template.html +27 -0
- zrb/runner/web_route/static/resources/common.css +21 -0
- zrb/runner/web_route/static/resources/common.js +28 -0
- zrb/runner/web_route/task_session_api_route.py +3 -1
- zrb/session_state_logger/session_state_logger_factory.py +2 -2
- zrb/task/base_task.py +4 -1
- zrb/task/base_trigger.py +47 -2
- zrb/task/cmd_task.py +3 -3
- zrb/task/llm/agent.py +10 -1
- zrb/task/llm/print_node.py +5 -6
- zrb/task/llm_task.py +1 -1
- zrb/util/git_subtree.py +1 -1
- {zrb-1.5.16.dist-info → zrb-1.6.0.dist-info}/METADATA +1 -1
- {zrb-1.5.16.dist-info → zrb-1.6.0.dist-info}/RECORD +45 -42
- zrb/builtin/llm/llm_chat.py +0 -124
- {zrb-1.5.16.dist-info → zrb-1.6.0.dist-info}/WHEEL +0 -0
- {zrb-1.5.16.dist-info → zrb-1.6.0.dist-info}/entry_points.txt +0 -0
zrb/__init__.py
CHANGED
@@ -11,7 +11,7 @@ from zrb.callback.any_callback import AnyCallback
 from zrb.callback.callback import Callback
 from zrb.cmd.cmd_result import CmdResult
 from zrb.cmd.cmd_val import Cmd, CmdPath
-from zrb.config import
+from zrb.config import CFG
 from zrb.content_transformer.any_content_transformer import AnyContentTransformer
 from zrb.content_transformer.content_transformer import ContentTransformer
 from zrb.context.any_context import AnyContext
@@ -109,7 +109,7 @@ assert Xcom
 assert web_config
 assert User
 
-if LOAD_BUILTIN:
+if CFG.LOAD_BUILTIN:
     from zrb import builtin
 
     assert builtin
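The recurring change in this release is visible above: module-level configuration constants are replaced by attributes on a single CFG object imported from zrb.config. A minimal sketch of that pattern, not zrb's actual Config class (the ZRB_* variable names here are illustrative), assuming each attribute re-reads its environment variable lazily instead of freezing it at import time:

import logging
import os


class Config:
    # Illustrative stand-in for zrb.config.Config: properties re-read
    # the environment on every access, so late changes still take effect.
    @property
    def LOAD_BUILTIN(self) -> bool:
        return os.getenv("ZRB_LOAD_BUILTIN", "1") != "0"

    @property
    def LOGGING_LEVEL(self) -> int:
        # logging.getLevelName maps a level name back to its number.
        return logging.getLevelName(os.getenv("ZRB_LOGGING_LEVEL", "WARNING"))


CFG = Config()  # one shared instance, imported as `from zrb.config import CFG`

One consequence of the object form is that call sites shrink to a single import instead of ever-growing constant lists, which is exactly the churn the truncated removed import lines above show.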
zrb/__main__.py
CHANGED
@@ -2,7 +2,7 @@ import logging
 import os
 import sys
 
-from zrb.config import
+from zrb.config import CFG
 from zrb.runner.cli import cli
 from zrb.util.cli.style import stylize_error, stylize_faint, stylize_warning
 from zrb.util.group import NodeNotFoundError
@@ -23,29 +23,29 @@ class FaintFormatter(logging.Formatter):
 
 
 def serve_cli():
-    LOGGER.setLevel(LOGGING_LEVEL)
+    CFG.LOGGER.setLevel(CFG.LOGGING_LEVEL)
     # Remove existing handlers to avoid duplicates/default formatting
-    for handler in LOGGER.handlers[:]:
-        LOGGER.removeHandler(handler)
+    for handler in CFG.LOGGER.handlers[:]:
+        CFG.LOGGER.removeHandler(handler)
     handler = logging.StreamHandler()
     handler.setFormatter(FaintFormatter())
-    LOGGER.addHandler(handler)
+    CFG.LOGGER.addHandler(handler)
     # --- End Logging Configuration ---
     try:
        # load init modules
-        for init_module in INIT_MODULES:
-            LOGGER.info(f"Loading {init_module}")
+        for init_module in CFG.INIT_MODULES:
+            CFG.LOGGER.info(f"Loading {init_module}")
            load_module(init_module)
        zrb_init_path_list = _get_zrb_init_path_list()
        # load init scripts
-        for init_script in INIT_SCRIPTS:
+        for init_script in CFG.INIT_SCRIPTS:
            abs_init_script = os.path.abspath(os.path.expanduser(init_script))
            if abs_init_script not in zrb_init_path_list:
-                LOGGER.info(f"Loading {abs_init_script}")
+                CFG.LOGGER.info(f"Loading {abs_init_script}")
                load_file(abs_init_script, -1)
        # load zrb init
        for zrb_init_path in zrb_init_path_list:
-            LOGGER.info(f"Loading {zrb_init_path}")
+            CFG.LOGGER.info(f"Loading {zrb_init_path}")
            load_file(zrb_init_path)
        # run the CLI
        cli.run(sys.argv[1:])
@@ -69,8 +69,8 @@ def _get_zrb_init_path_list() -> list[str]:
         dir_path_list.append(current_path)
     zrb_init_path_list = []
     for current_path in dir_path_list[::-1]:
-        zrb_init_path = os.path.join(current_path,
-        LOGGER.info(f"Finding {zrb_init_path}")
+        zrb_init_path = os.path.join(current_path, CFG.INIT_FILE_NAME)
+        CFG.LOGGER.info(f"Finding {zrb_init_path}")
         if os.path.isfile(zrb_init_path):
             zrb_init_path_list.append(zrb_init_path)
     return zrb_init_path_list
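In serve_cli above, existing handlers are removed before the FaintFormatter handler is attached; iterating over LOGGER.handlers[:] (a copy) matters, because removing items from the list while iterating it directly would skip handlers. A self-contained sketch of the same reset using only the standard logging module (the logger name and format string here are made up for illustration):

import logging

logger = logging.getLogger("zrb-demo")
logger.setLevel(logging.INFO)
# Iterate over a copy of the handler list; mutating logger.handlers
# while looping over it directly would skip every other element.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
logger.addHandler(handler)
logger.info("one handler, one consistently formatted line")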
zrb/builtin/__init__.py
CHANGED
@@ -9,7 +9,7 @@ from zrb.builtin.git import (
 from zrb.builtin.git_subtree import git_add_subtree, git_pull_subtree, git_push_subtree
 from zrb.builtin.http import generate_curl, http_request
 from zrb.builtin.jwt import decode_jwt, encode_jwt, validate_jwt
-from zrb.builtin.llm.
+from zrb.builtin.llm.llm_ask import llm_ask
 from zrb.builtin.md5 import hash_md5, sum_md5, validate_md5
 from zrb.builtin.project.add.fastapp.fastapp_task import add_fastapp_to_project
 from zrb.builtin.project.create.project_task import create_project
@@ -55,7 +55,7 @@ assert validate_base64
 assert encode_jwt
 assert decode_jwt
 assert validate_jwt
-assert
+assert llm_ask
 assert hash_md5
 assert sum_md5
 assert validate_md5
zrb/builtin/llm/chat_session.py
ADDED
@@ -0,0 +1,202 @@
+"""
+This module provides functions for managing interactive chat sessions with an LLM.
+
+It handles reading user input, triggering the LLM task, and managing the
+conversation flow via XCom.
+"""
+
+import asyncio
+
+from zrb.context.any_context import AnyContext
+from zrb.util.cli.style import stylize_bold_yellow, stylize_faint
+
+
+async def read_user_prompt(ctx: AnyContext) -> str:
+    """
+    Reads user input from the CLI for an interactive chat session.
+
+    Handles special commands like /bye, /multi, /end, and /help.
+    Triggers the LLM task and waits for the result.
+
+    Args:
+        ctx: The context object for the task.
+
+    Returns:
+        The final result from the LLM session.
+    """
+    _show_info(ctx)
+    final_result = ""
+    ctx.print(stylize_faint("🧑 >> ") + f"{ctx.input.message}", plain=True)
+    result = await _trigger_ask_and_wait_for_result(
+        ctx,
+        user_prompt=ctx.input.message,
+        previous_session_name=ctx.input.previous_session,
+        start_new=ctx.input.start_new,
+    )
+    if ctx.env.get("_ZRB_WEB_ENV", "0") != "0":
+        # Don't run in web environment
+        if result is not None:
+            final_result = result
+        return final_result
+    multiline_mode = False
+    user_inputs = []
+    while True:
+        await asyncio.sleep(0.01)
+        ctx.print(stylize_faint("🧑 >> "), end="", plain=True)
+        user_input = input()
+        # Handle special input
+        if user_input.strip().lower() in ("/bye", "/quit"):
+            user_prompt = "\n".join(user_inputs)
+            user_inputs = []
+            result = await _trigger_ask_and_wait_for_result(ctx, user_prompt)
+            if result is not None:
+                final_result = result
+            break
+        elif user_input.strip().lower() in ("/multi"):
+            multiline_mode = True
+        elif user_input.strip().lower() in ("/end"):
+            multiline_mode = False
+            user_prompt = "\n".join(user_inputs)
+            user_inputs = []
+            result = await _trigger_ask_and_wait_for_result(ctx, user_prompt)
+            if result is not None:
+                final_result = result
+        elif user_input.strip().lower() in ("/help", "/info"):
+            _show_info(ctx)
+            continue
+        else:
+            user_inputs.append(user_input)
+            if multiline_mode:
+                continue
+            user_prompt = "\n".join(user_inputs)
+            user_inputs = []
+            result = await _trigger_ask_and_wait_for_result(ctx, user_prompt)
+            if result is not None:
+                final_result = result
+    return final_result
+
+
+async def _trigger_ask_and_wait_for_result(
+    ctx: AnyContext,
+    user_prompt: str,
+    previous_session_name: str | None = None,
+    start_new: bool = False,
+) -> str | None:
+    """
+    Triggers the LLM ask task and waits for the result via XCom.
+
+    Args:
+        ctx: The context object for the task.
+        user_prompt: The user's message to send to the LLM.
+        previous_session_name: The name of the previous chat session (optional).
+        start_new: Whether to start a new conversation (optional).
+
+    Returns:
+        The result from the LLM task, or None if the user prompt is empty.
+    """
+    if user_prompt.strip() == "":
+        return None
+    await _trigger_ask(ctx, user_prompt, previous_session_name, start_new)
+    result = await _wait_ask_result(ctx)
+    ctx.print(stylize_faint("🤖 >> ") + result, plain=True)
+    return result
+
+
+def get_llm_ask_input_mapping(callback_ctx: AnyContext):
+    """
+    Generates the input mapping for the LLM ask task from the callback context.
+
+    Args:
+        callback_ctx: The context object for the callback.
+
+    Returns:
+        A dictionary containing the input mapping for the LLM ask task.
+    """
+    data = callback_ctx.xcom.ask_trigger.pop()
+    return {
+        "model": callback_ctx.input.model,
+        "base-url": callback_ctx.input.base_url,
+        "api-key": callback_ctx.input.api_key,
+        "system-prompt": callback_ctx.input.system_prompt,
+        "start-new": data.get("start_new"),
+        "previous-session": data.get("previous_session_name"),
+        "message": data.get("message"),
+    }
+
+
+async def _trigger_ask(
+    ctx: AnyContext,
+    user_prompt: str,
+    previous_session_name: str | None = None,
+    start_new: bool = False,
+):
+    """
+    Triggers the LLM ask task by pushing data to the 'ask_trigger' XCom queue.
+
+    Args:
+        ctx: The context object for the task.
+        user_prompt: The user's message to send to the LLM.
+        previous_session_name: The name of the previous chat session (optional).
+        start_new: Whether to start a new conversation (optional).
+    """
+    if previous_session_name is None:
+        previous_session_name = await _wait_ask_session_name(ctx)
+    ctx.xcom["ask_trigger"].push(
+        {
+            "previous_session_name": previous_session_name,
+            "start_new": start_new,
+            "message": user_prompt,
+        }
+    )
+
+
+async def _wait_ask_result(ctx: AnyContext) -> str:
+    """
+    Waits for and retrieves the LLM task result from the 'ask_result' XCom queue.
+
+    Args:
+        ctx: The context object for the task.
+
+    Returns:
+        The result string from the LLM task.
+    """
+    while "ask_result" not in ctx.xcom or len(ctx.xcom.ask_result) == 0:
+        await asyncio.sleep(0.1)
+    return ctx.xcom.ask_result.pop()
+
+
+async def _wait_ask_session_name(ctx: AnyContext) -> str:
+    """
+    Waits for and retrieves the LLM chat session name from the 'ask_session_name' XCom queue.
+
+    Args:
+        ctx: The context object for the task.
+
+    Returns:
+        The session name string.
+    """
+    while "ask_session_name" not in ctx.xcom or len(ctx.xcom.ask_session_name) == 0:
+        await asyncio.sleep(0.1)
+    return ctx.xcom.ask_session_name.pop()
+
+
+def _show_info(ctx: AnyContext):
+    """
+    Displays the available chat session commands to the user.
+
+    Args:
+        ctx: The context object for the task.
+    """
+    ctx.print(
+        stylize_bold_yellow(
+            "\n".join(
+                [
+                    "/bye: Quit from chat session",
+                    "/multi: Start multiline input",
+                    "/end: End multiline input",
+                    "/help: Show this message",
+                ]
+            )
+        ),
+        plain=True,
+    )
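The chat session and the LLM task communicate through named XCom queues: _trigger_ask pushes a request onto ask_trigger, and _wait_ask_result polls ask_result until the answer arrives. A self-contained sketch of that rendezvous, with a plain dict standing in for zrb's XCom store (the queue names match the code above; everything else is illustrative):

import asyncio

# Stand-in for zrb's XCom: named queues shared between two coroutines.
xcom: dict[str, list] = {"ask_trigger": [], "ask_result": []}


async def chat_loop() -> None:
    # Mirrors _trigger_ask + _wait_ask_result: push a request, then poll.
    xcom["ask_trigger"].append({"message": "hello"})
    while len(xcom["ask_result"]) == 0:
        await asyncio.sleep(0.1)
    print("result:", xcom["ask_result"].pop())


async def llm_worker() -> None:
    # Mirrors the triggered llm-ask task: pop a request, push a result.
    while len(xcom["ask_trigger"]) == 0:
        await asyncio.sleep(0.1)
    request = xcom["ask_trigger"].pop()
    xcom["ask_result"].append(f"echo: {request['message']}")


async def main() -> None:
    await asyncio.gather(chat_loop(), llm_worker())


asyncio.run(main())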
zrb/builtin/llm/history.py
CHANGED
@@ -2,7 +2,7 @@ import json
 import os
 from typing import Any
 
-from zrb.config import
+from zrb.config import CFG
 from zrb.context.any_shared_context import AnySharedContext
 from zrb.task.llm.history import ConversationHistoryData
 from zrb.util.file import read_file, write_file
@@ -17,7 +17,7 @@ def read_chat_conversation(ctx: AnySharedContext) -> dict[str, Any] | list | Non
         return None  # Indicate no history to load
     previous_session_name = ctx.input.previous_session
     if not previous_session_name:  # Check for empty string or None
-        last_session_file_path = os.path.join(LLM_HISTORY_DIR, "last-session")
+        last_session_file_path = os.path.join(CFG.LLM_HISTORY_DIR, "last-session")
         if os.path.isfile(last_session_file_path):
             previous_session_name = read_file(last_session_file_path).strip()
             if not previous_session_name:  # Handle empty last-session file
@@ -25,7 +25,7 @@ def read_chat_conversation(ctx: AnySharedContext) -> dict[str, Any] | list | Non
         else:
             return None  # No previous session specified and no last session found
     conversation_file_path = os.path.join(
-        LLM_HISTORY_DIR, f"{previous_session_name}.json"
+        CFG.LLM_HISTORY_DIR, f"{previous_session_name}.json"
     )
     if not os.path.isfile(conversation_file_path):
         ctx.log_warning(f"History file not found: {conversation_file_path}")
@@ -55,19 +55,19 @@ def write_chat_conversation(
     ctx: AnySharedContext, history_data: ConversationHistoryData
 ):
     """Writes the conversation history data (including context) to a session file."""
-    os.makedirs(LLM_HISTORY_DIR, exist_ok=True)
+    os.makedirs(CFG.LLM_HISTORY_DIR, exist_ok=True)
     current_session_name = ctx.session.name
     if not current_session_name:
         ctx.log_warning("Cannot write history: Session name is empty.")
         return
     conversation_file_path = os.path.join(
-        LLM_HISTORY_DIR, f"{current_session_name}.json"
+        CFG.LLM_HISTORY_DIR, f"{current_session_name}.json"
     )
     try:
         # Use model_dump_json to serialize the Pydantic model
         write_file(conversation_file_path, history_data.model_dump_json(indent=2))
         # Update the last-session pointer
-        last_session_file_path = os.path.join(LLM_HISTORY_DIR, "last-session")
+        last_session_file_path = os.path.join(CFG.LLM_HISTORY_DIR, "last-session")
         write_file(last_session_file_path, current_session_name)
     except Exception as e:
         ctx.log_error(f"Error writing history file '{conversation_file_path}': {e}")
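The history module keeps one JSON file per session plus a "last-session" pointer file, so a chat can resume where it left off without naming the session explicitly. A sketch of the resolution order the code above implements (history_dir and the open() calls are illustrative; zrb uses CFG.LLM_HISTORY_DIR and its own read_file helper):

import os

history_dir = os.path.expanduser("~/.zrb-llm-history")  # assumed location


def resolve_session_file(requested: str | None) -> str | None:
    # An explicit session name wins; otherwise fall back to the pointer.
    name = requested
    if not name:
        pointer = os.path.join(history_dir, "last-session")
        if os.path.isfile(pointer):
            with open(pointer) as f:
                name = f.read().strip()
    if not name:
        return None  # nothing requested and no previous session recorded
    path = os.path.join(history_dir, f"{name}.json")
    return path if os.path.isfile(path) else None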
zrb/builtin/llm/llm_ask.py
ADDED
@@ -0,0 +1,142 @@
+from zrb.builtin.group import llm_group
+from zrb.builtin.llm.chat_session import get_llm_ask_input_mapping, read_user_prompt
+from zrb.builtin.llm.history import read_chat_conversation, write_chat_conversation
+from zrb.builtin.llm.input import PreviousSessionInput
+from zrb.builtin.llm.tool.api import get_current_location, get_current_weather
+from zrb.builtin.llm.tool.cli import run_shell_command
+from zrb.builtin.llm.tool.file import (
+    apply_diff,
+    list_files,
+    read_from_file,
+    search_files,
+    write_to_file,
+)
+from zrb.builtin.llm.tool.web import (
+    create_search_internet_tool,
+    open_web_page,
+    search_arxiv,
+    search_wikipedia,
+)
+from zrb.callback.callback import Callback
+from zrb.config import CFG
+from zrb.input.bool_input import BoolInput
+from zrb.input.str_input import StrInput
+from zrb.input.text_input import TextInput
+from zrb.task.base_trigger import BaseTrigger
+from zrb.task.llm_task import LLMTask
+
+_llm_ask_inputs = [
+    StrInput(
+        "model",
+        description="LLM Model",
+        prompt="LLM Model",
+        default="",
+        allow_positional_parsing=False,
+        always_prompt=False,
+        allow_empty=True,
+    ),
+    StrInput(
+        "base-url",
+        description="LLM API Base URL",
+        prompt="LLM API Base URL",
+        default="",
+        allow_positional_parsing=False,
+        always_prompt=False,
+        allow_empty=True,
+    ),
+    StrInput(
+        "api-key",
+        description="LLM API Key",
+        prompt="LLM API Key",
+        default="",
+        allow_positional_parsing=False,
+        always_prompt=False,
+        allow_empty=True,
+    ),
+    TextInput(
+        "system-prompt",
+        description="System prompt",
+        prompt="System prompt",
+        default="",
+        allow_positional_parsing=False,
+        always_prompt=False,
+    ),
+    BoolInput(
+        "start-new",
+        description="Start new conversation (LLM will forget everything)",
+        prompt="Start new conversation (LLM will forget everything)",
+        default=False,
+        allow_positional_parsing=False,
+        always_prompt=False,
+    ),
+    TextInput("message", description="User message", prompt="Your message"),
+    PreviousSessionInput(
+        "previous-session",
+        description="Previous conversation session",
+        prompt="Previous conversation session (can be empty)",
+        allow_positional_parsing=False,
+        allow_empty=True,
+        always_prompt=False,
+    ),
+]
+
+llm_ask: LLMTask = llm_group.add_task(
+    LLMTask(
+        name="llm-ask",
+        input=_llm_ask_inputs,
+        description="❓ Ask LLM",
+        model=lambda ctx: None if ctx.input.model.strip() == "" else ctx.input.model,
+        model_base_url=lambda ctx: (
+            None if ctx.input.base_url.strip() == "" else ctx.input.base_url
+        ),
+        model_api_key=lambda ctx: (
+            None if ctx.input.api_key.strip() == "" else ctx.input.api_key
+        ),
+        conversation_history_reader=read_chat_conversation,
+        conversation_history_writer=write_chat_conversation,
+        system_prompt=lambda ctx: (
+            None if ctx.input.system_prompt.strip() == "" else ctx.input.system_prompt
+        ),
+        message="{ctx.input.message}",
+        retries=0,
+    ),
+    alias="ask",
+)
+
+llm_group.add_task(
+    BaseTrigger(
+        name="llm-chat",
+        input=_llm_ask_inputs,
+        description="💬 Chat with LLM",
+        queue_name="ask_trigger",
+        action=read_user_prompt,
+        callback=Callback(
+            task=llm_ask,
+            input_mapping=get_llm_ask_input_mapping,
+            result_queue="ask_result",
+            session_name_queue="ask_session_name",
+        ),
+        retries=0,
+        cli_only=True,
+    ),
+    alias="chat",
+)
+
+if CFG.LLM_ALLOW_ACCESS_LOCAL_FILE:
+    llm_ask.add_tool(list_files)
+    llm_ask.add_tool(read_from_file)
+    llm_ask.add_tool(write_to_file)
+    llm_ask.add_tool(search_files)
+    llm_ask.add_tool(apply_diff)
+
+if CFG.LLM_ALLOW_ACCESS_SHELL:
+    llm_ask.add_tool(run_shell_command)
+
+if CFG.LLM_ALLOW_ACCESS_INTERNET:
+    llm_ask.add_tool(open_web_page)
+    llm_ask.add_tool(search_wikipedia)
+    llm_ask.add_tool(search_arxiv)
+    if CFG.SERP_API_KEY != "":
+        llm_ask.add_tool(create_search_internet_tool(CFG.SERP_API_KEY))
+    llm_ask.add_tool(get_current_location)
+    llm_ask.add_tool(get_current_weather)
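Because the registrations above pass plain callables to llm_ask.add_tool, a project's own zrb_init.py can extend the task the same way. A hedged sketch (get_disk_usage is hypothetical, not part of zrb; it assumes, as the built-in tools suggest, that a docstring and type hints are enough to describe a tool to the model):

# zrb_init.py (project level)
import shutil

from zrb.builtin.llm.llm_ask import llm_ask


def get_disk_usage(path: str = ".") -> str:
    """Report total, used, and free bytes for a filesystem path."""
    usage = shutil.disk_usage(path)
    return f"total={usage.total} used={usage.used} free={usage.free}"


llm_ask.add_tool(get_disk_usage)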
zrb/builtin/llm/tool/rag.py
CHANGED
@@ -8,14 +8,7 @@ from textwrap import dedent
 
 import ulid
 
-from zrb.config import (
-    RAG_CHUNK_SIZE,
-    RAG_EMBEDDING_API_KEY,
-    RAG_EMBEDDING_BASE_URL,
-    RAG_EMBEDDING_MODEL,
-    RAG_MAX_RESULT_COUNT,
-    RAG_OVERLAP,
-)
+from zrb.config import CFG
 from zrb.util.cli.style import stylize_error, stylize_faint
 from zrb.util.file import read_file
 
@@ -42,13 +35,13 @@ def create_rag_from_directory(
     document_dir_path: str = "./documents",
     vector_db_path: str = "./chroma",
     vector_db_collection: str = "documents",
-    chunk_size: int =
-    overlap: int =
-    max_result_count: int =
+    chunk_size: int | None = None,
+    overlap: int | None = None,
+    max_result_count: int | None = None,
     file_reader: list[RAGFileReader] = [],
-    openai_api_key: str =
-    openai_base_url: str =
-    openai_embedding_model: str =
+    openai_api_key: str | None = None,
+    openai_base_url: str | None = None,
+    openai_embedding_model: str | None = None,
 ):
     """Create a RAG retrieval tool function for LLM use.
     This factory configures and returns an async function that takes a query,
@@ -64,10 +57,33 @@ def create_rag_from_directory(
 
     # Initialize OpenAI client with custom URL if provided
     client_args = {}
-    if
-
-
-
+    # Initialize OpenAI client with custom URL if provided
+    client_args = {}
+    api_key_val = (
+        openai_api_key if openai_api_key is not None else CFG.RAG_EMBEDDING_API_KEY
+    )
+    base_url_val = (
+        openai_base_url
+        if openai_base_url is not None
+        else CFG.RAG_EMBEDDING_BASE_URL
+    )
+    embedding_model_val = (
+        openai_embedding_model
+        if openai_embedding_model is not None
+        else CFG.RAG_EMBEDDING_MODEL
+    )
+    chunk_size_val = chunk_size if chunk_size is not None else CFG.RAG_CHUNK_SIZE
+    overlap_val = overlap if overlap is not None else CFG.RAG_OVERLAP
+    max_result_count_val = (
+        max_result_count
+        if max_result_count is not None
+        else CFG.RAG_MAX_RESULT_COUNT
+    )
+
+    if api_key_val:
+        client_args["api_key"] = api_key_val
+    if base_url_val:
+        client_args["base_url"] = base_url_val
     # Initialize OpenAI client for embeddings
     openai_client = OpenAI(**client_args)
     # Initialize ChromaDB client
@@ -101,8 +117,8 @@ def create_rag_from_directory(
             collection.delete(where={"file_path": relative_path})
             content = _read_txt_content(file_path, file_reader)
             file_id = ulid.new().str
-            for i in range(0, len(content),
-                chunk = content[i : i +
+            for i in range(0, len(content), chunk_size_val - overlap_val):
+                chunk = content[i : i + chunk_size_val]
             if chunk:
                 chunk_id = ulid.new().str
                 print(
@@ -113,7 +129,7 @@ def create_rag_from_directory(
                 )
                 # Get embeddings using OpenAI
                 embedding_response = openai_client.embeddings.create(
-                    input=chunk, model=
+                    input=chunk, model=embedding_model_val
                 )
                 vector = embedding_response.data[0].embedding
                 collection.upsert(
@@ -140,13 +156,13 @@ def create_rag_from_directory(
         print(stylize_faint("Vectorizing query"), file=sys.stderr)
         # Get embeddings using OpenAI
         embedding_response = openai_client.embeddings.create(
-            input=query, model=
+            input=query, model=embedding_model_val
        )
        query_vector = embedding_response.data[0].embedding
        print(stylize_faint("Searching documents"), file=sys.stderr)
        results = collection.query(
            query_embeddings=query_vector,
-            n_results=
+            n_results=max_result_count_val,
        )
        return json.dumps(results)
 
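The re-indexing loop above advances through each document with stride chunk_size_val - overlap_val, so consecutive chunks share overlap_val characters of context; this keeps text that straddles a chunk boundary retrievable from at least one embedding. A standalone sketch of that windowing arithmetic:

def chunk_text(content: str, chunk_size: int, overlap: int) -> list[str]:
    # Window starts are (chunk_size - overlap) apart, matching the
    # `for i in range(0, len(content), chunk_size_val - overlap_val)` loop.
    return [
        content[i : i + chunk_size]
        for i in range(0, len(content), chunk_size - overlap)
    ]


# chunk_size=10, overlap=4 -> stride 6; each chunk repeats the last
# 4 characters of its predecessor:
print(chunk_text("abcdefghijklmnop", 10, 4))
# ['abcdefghij', 'ghijklmnop', 'mnop']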