zrb 1.5.17__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. zrb/__init__.py +2 -2
  2. zrb/__main__.py +12 -12
  3. zrb/builtin/__init__.py +2 -2
  4. zrb/builtin/llm/chat_session.py +202 -0
  5. zrb/builtin/llm/history.py +6 -6
  6. zrb/builtin/llm/llm_ask.py +142 -0
  7. zrb/builtin/llm/tool/rag.py +39 -23
  8. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/view/static/default/script.js +1 -1
  9. zrb/builtin/todo.py +21 -19
  10. zrb/callback/any_callback.py +1 -1
  11. zrb/callback/callback.py +69 -7
  12. zrb/config.py +261 -91
  13. zrb/context/shared_context.py +4 -2
  14. zrb/input/text_input.py +9 -6
  15. zrb/llm_config.py +65 -74
  16. zrb/runner/cli.py +13 -4
  17. zrb/runner/web_app.py +3 -3
  18. zrb/runner/web_config/config_factory.py +11 -22
  19. zrb/runner/web_route/error_page/show_error_page.py +16 -6
  20. zrb/runner/web_route/home_page/home_page_route.py +23 -7
  21. zrb/runner/web_route/home_page/view.html +19 -33
  22. zrb/runner/web_route/login_page/login_page_route.py +14 -4
  23. zrb/runner/web_route/login_page/view.html +33 -51
  24. zrb/runner/web_route/logout_page/logout_page_route.py +15 -5
  25. zrb/runner/web_route/logout_page/view.html +23 -41
  26. zrb/runner/web_route/node_page/group/show_group_page.py +26 -10
  27. zrb/runner/web_route/node_page/group/view.html +22 -37
  28. zrb/runner/web_route/node_page/task/show_task_page.py +34 -19
  29. zrb/runner/web_route/node_page/task/view.html +74 -88
  30. zrb/runner/web_route/static/global_template.html +27 -0
  31. zrb/runner/web_route/static/resources/common.css +21 -0
  32. zrb/runner/web_route/static/resources/common.js +28 -0
  33. zrb/runner/web_route/task_session_api_route.py +3 -1
  34. zrb/session_state_logger/session_state_logger_factory.py +2 -2
  35. zrb/task/base_task.py +4 -1
  36. zrb/task/base_trigger.py +47 -2
  37. zrb/task/cmd_task.py +3 -3
  38. zrb/task/llm/agent.py +10 -1
  39. zrb/task/llm/print_node.py +5 -6
  40. zrb/task/llm_task.py +1 -1
  41. {zrb-1.5.17.dist-info → zrb-1.6.0.dist-info}/METADATA +1 -1
  42. {zrb-1.5.17.dist-info → zrb-1.6.0.dist-info}/RECORD +44 -41
  43. zrb/builtin/llm/llm_chat.py +0 -124
  44. {zrb-1.5.17.dist-info → zrb-1.6.0.dist-info}/WHEEL +0 -0
  45. {zrb-1.5.17.dist-info → zrb-1.6.0.dist-info}/entry_points.txt +0 -0
zrb/__init__.py CHANGED
@@ -11,7 +11,7 @@ from zrb.callback.any_callback import AnyCallback
  from zrb.callback.callback import Callback
  from zrb.cmd.cmd_result import CmdResult
  from zrb.cmd.cmd_val import Cmd, CmdPath
- from zrb.config import LOAD_BUILTIN
+ from zrb.config import CFG
  from zrb.content_transformer.any_content_transformer import AnyContentTransformer
  from zrb.content_transformer.content_transformer import ContentTransformer
  from zrb.context.any_context import AnyContext
@@ -109,7 +109,7 @@ assert Xcom
  assert web_config
  assert User

- if LOAD_BUILTIN:
+ if CFG.LOAD_BUILTIN:
      from zrb import builtin

      assert builtin
zrb/__main__.py CHANGED
@@ -2,7 +2,7 @@ import logging
  import os
  import sys

- from zrb.config import INIT_MODULES, INIT_SCRIPTS, LOGGER, LOGGING_LEVEL
+ from zrb.config import CFG
  from zrb.runner.cli import cli
  from zrb.util.cli.style import stylize_error, stylize_faint, stylize_warning
  from zrb.util.group import NodeNotFoundError
@@ -23,29 +23,29 @@ class FaintFormatter(logging.Formatter):


  def serve_cli():
-     LOGGER.setLevel(LOGGING_LEVEL)
+     CFG.LOGGER.setLevel(CFG.LOGGING_LEVEL)
      # Remove existing handlers to avoid duplicates/default formatting
-     for handler in LOGGER.handlers[:]:
-         LOGGER.removeHandler(handler)
+     for handler in CFG.LOGGER.handlers[:]:
+         CFG.LOGGER.removeHandler(handler)
      handler = logging.StreamHandler()
      handler.setFormatter(FaintFormatter())
-     LOGGER.addHandler(handler)
+     CFG.LOGGER.addHandler(handler)
      # --- End Logging Configuration ---
      try:
          # load init modules
-         for init_module in INIT_MODULES:
-             LOGGER.info(f"Loading {init_module}")
+         for init_module in CFG.INIT_MODULES:
+             CFG.LOGGER.info(f"Loading {init_module}")
              load_module(init_module)
          zrb_init_path_list = _get_zrb_init_path_list()
          # load init scripts
-         for init_script in INIT_SCRIPTS:
+         for init_script in CFG.INIT_SCRIPTS:
              abs_init_script = os.path.abspath(os.path.expanduser(init_script))
              if abs_init_script not in zrb_init_path_list:
-                 LOGGER.info(f"Loading {abs_init_script}")
+                 CFG.LOGGER.info(f"Loading {abs_init_script}")
                  load_file(abs_init_script, -1)
          # load zrb init
          for zrb_init_path in zrb_init_path_list:
-             LOGGER.info(f"Loading {zrb_init_path}")
+             CFG.LOGGER.info(f"Loading {zrb_init_path}")
              load_file(zrb_init_path)
          # run the CLI
          cli.run(sys.argv[1:])
@@ -69,8 +69,8 @@ def _get_zrb_init_path_list() -> list[str]:
          dir_path_list.append(current_path)
      zrb_init_path_list = []
      for current_path in dir_path_list[::-1]:
-         zrb_init_path = os.path.join(current_path, "zrb_init.py")
-         LOGGER.info(f"Finding {zrb_init_path}")
+         zrb_init_path = os.path.join(current_path, CFG.INIT_FILE_NAME)
+         CFG.LOGGER.info(f"Finding {zrb_init_path}")
          if os.path.isfile(zrb_init_path):
              zrb_init_path_list.append(zrb_init_path)
      return zrb_init_path_list
zrb/builtin/__init__.py CHANGED
@@ -9,7 +9,7 @@ from zrb.builtin.git import (
  from zrb.builtin.git_subtree import git_add_subtree, git_pull_subtree, git_push_subtree
  from zrb.builtin.http import generate_curl, http_request
  from zrb.builtin.jwt import decode_jwt, encode_jwt, validate_jwt
- from zrb.builtin.llm.llm_chat import llm_chat
+ from zrb.builtin.llm.llm_ask import llm_ask
  from zrb.builtin.md5 import hash_md5, sum_md5, validate_md5
  from zrb.builtin.project.add.fastapp.fastapp_task import add_fastapp_to_project
  from zrb.builtin.project.create.project_task import create_project
@@ -55,7 +55,7 @@ assert validate_base64
  assert encode_jwt
  assert decode_jwt
  assert validate_jwt
- assert llm_chat
+ assert llm_ask
  assert hash_md5
  assert sum_md5
  assert validate_md5
zrb/builtin/llm/chat_session.py ADDED
@@ -0,0 +1,202 @@
+ """
+ This module provides functions for managing interactive chat sessions with an LLM.
+
+ It handles reading user input, triggering the LLM task, and managing the
+ conversation flow via XCom.
+ """
+
+ import asyncio
+
+ from zrb.context.any_context import AnyContext
+ from zrb.util.cli.style import stylize_bold_yellow, stylize_faint
+
+
+ async def read_user_prompt(ctx: AnyContext) -> str:
+     """
+     Reads user input from the CLI for an interactive chat session.
+
+     Handles special commands like /bye, /multi, /end, and /help.
+     Triggers the LLM task and waits for the result.
+
+     Args:
+         ctx: The context object for the task.
+
+     Returns:
+         The final result from the LLM session.
+     """
+     _show_info(ctx)
+     final_result = ""
+     ctx.print(stylize_faint("🧑 >> ") + f"{ctx.input.message}", plain=True)
+     result = await _trigger_ask_and_wait_for_result(
+         ctx,
+         user_prompt=ctx.input.message,
+         previous_session_name=ctx.input.previous_session,
+         start_new=ctx.input.start_new,
+     )
+     if ctx.env.get("_ZRB_WEB_ENV", "0") != "0":
+         # Don't run in web environment
+         if result is not None:
+             final_result = result
+         return final_result
+     multiline_mode = False
+     user_inputs = []
+     while True:
+         await asyncio.sleep(0.01)
+         ctx.print(stylize_faint("🧑 >> "), end="", plain=True)
+         user_input = input()
+         # Handle special input
+         if user_input.strip().lower() in ("/bye", "/quit"):
+             user_prompt = "\n".join(user_inputs)
+             user_inputs = []
+             result = await _trigger_ask_and_wait_for_result(ctx, user_prompt)
+             if result is not None:
+                 final_result = result
+             break
+         elif user_input.strip().lower() in ("/multi"):
+             multiline_mode = True
+         elif user_input.strip().lower() in ("/end"):
+             multiline_mode = False
+             user_prompt = "\n".join(user_inputs)
+             user_inputs = []
+             result = await _trigger_ask_and_wait_for_result(ctx, user_prompt)
+             if result is not None:
+                 final_result = result
+         elif user_input.strip().lower() in ("/help", "/info"):
+             _show_info(ctx)
+             continue
+         else:
+             user_inputs.append(user_input)
+             if multiline_mode:
+                 continue
+             user_prompt = "\n".join(user_inputs)
+             user_inputs = []
+             result = await _trigger_ask_and_wait_for_result(ctx, user_prompt)
+             if result is not None:
+                 final_result = result
+     return final_result
+
+
+ async def _trigger_ask_and_wait_for_result(
+     ctx: AnyContext,
+     user_prompt: str,
+     previous_session_name: str | None = None,
+     start_new: bool = False,
+ ) -> str | None:
+     """
+     Triggers the LLM ask task and waits for the result via XCom.
+
+     Args:
+         ctx: The context object for the task.
+         user_prompt: The user's message to send to the LLM.
+         previous_session_name: The name of the previous chat session (optional).
+         start_new: Whether to start a new conversation (optional).
+
+     Returns:
+         The result from the LLM task, or None if the user prompt is empty.
+     """
+     if user_prompt.strip() == "":
+         return None
+     await _trigger_ask(ctx, user_prompt, previous_session_name, start_new)
+     result = await _wait_ask_result(ctx)
+     ctx.print(stylize_faint("🤖 >> ") + result, plain=True)
+     return result
+
+
+ def get_llm_ask_input_mapping(callback_ctx: AnyContext):
+     """
+     Generates the input mapping for the LLM ask task from the callback context.
+
+     Args:
+         callback_ctx: The context object for the callback.
+
+     Returns:
+         A dictionary containing the input mapping for the LLM ask task.
+     """
+     data = callback_ctx.xcom.ask_trigger.pop()
+     return {
+         "model": callback_ctx.input.model,
+         "base-url": callback_ctx.input.base_url,
+         "api-key": callback_ctx.input.api_key,
+         "system-prompt": callback_ctx.input.system_prompt,
+         "start-new": data.get("start_new"),
+         "previous-session": data.get("previous_session_name"),
+         "message": data.get("message"),
+     }
+
+
+ async def _trigger_ask(
+     ctx: AnyContext,
+     user_prompt: str,
+     previous_session_name: str | None = None,
+     start_new: bool = False,
+ ):
+     """
+     Triggers the LLM ask task by pushing data to the 'ask_trigger' XCom queue.
+
+     Args:
+         ctx: The context object for the task.
+         user_prompt: The user's message to send to the LLM.
+         previous_session_name: The name of the previous chat session (optional).
+         start_new: Whether to start a new conversation (optional).
+     """
+     if previous_session_name is None:
+         previous_session_name = await _wait_ask_session_name(ctx)
+     ctx.xcom["ask_trigger"].push(
+         {
+             "previous_session_name": previous_session_name,
+             "start_new": start_new,
+             "message": user_prompt,
+         }
+     )
+
+
+ async def _wait_ask_result(ctx: AnyContext) -> str:
+     """
+     Waits for and retrieves the LLM task result from the 'ask_result' XCom queue.
+
+     Args:
+         ctx: The context object for the task.
+
+     Returns:
+         The result string from the LLM task.
+     """
+     while "ask_result" not in ctx.xcom or len(ctx.xcom.ask_result) == 0:
+         await asyncio.sleep(0.1)
+     return ctx.xcom.ask_result.pop()
+
+
+ async def _wait_ask_session_name(ctx: AnyContext) -> str:
+     """
+     Waits for and retrieves the LLM chat session name from the 'ask_session_name' XCom queue.
+
+     Args:
+         ctx: The context object for the task.
+
+     Returns:
+         The session name string.
+     """
+     while "ask_session_name" not in ctx.xcom or len(ctx.xcom.ask_session_name) == 0:
+         await asyncio.sleep(0.1)
+     return ctx.xcom.ask_session_name.pop()
+
+
+ def _show_info(ctx: AnyContext):
+     """
+     Displays the available chat session commands to the user.
+
+     Args:
+         ctx: The context object for the task.
+     """
+     ctx.print(
+         stylize_bold_yellow(
+             "\n".join(
+                 [
+                     "/bye: Quit from chat session",
+                     "/multi: Start multiline input",
+                     "/end: End multiline input",
+                     "/help: Show this message",
+                 ]
+             )
+         ),
+         plain=True,
+     )
zrb/builtin/llm/history.py CHANGED
@@ -2,7 +2,7 @@ import json
  import os
  from typing import Any

- from zrb.config import LLM_HISTORY_DIR
+ from zrb.config import CFG
  from zrb.context.any_shared_context import AnySharedContext
  from zrb.task.llm.history import ConversationHistoryData
  from zrb.util.file import read_file, write_file
@@ -17,7 +17,7 @@ def read_chat_conversation(ctx: AnySharedContext) -> dict[str, Any] | list | Non
          return None  # Indicate no history to load
      previous_session_name = ctx.input.previous_session
      if not previous_session_name:  # Check for empty string or None
-         last_session_file_path = os.path.join(LLM_HISTORY_DIR, "last-session")
+         last_session_file_path = os.path.join(CFG.LLM_HISTORY_DIR, "last-session")
          if os.path.isfile(last_session_file_path):
              previous_session_name = read_file(last_session_file_path).strip()
              if not previous_session_name:  # Handle empty last-session file
@@ -25,7 +25,7 @@ def read_chat_conversation(ctx: AnySharedContext) -> dict[str, Any] | list | Non
          else:
              return None  # No previous session specified and no last session found
      conversation_file_path = os.path.join(
-         LLM_HISTORY_DIR, f"{previous_session_name}.json"
+         CFG.LLM_HISTORY_DIR, f"{previous_session_name}.json"
      )
      if not os.path.isfile(conversation_file_path):
          ctx.log_warning(f"History file not found: {conversation_file_path}")
@@ -55,19 +55,19 @@ def write_chat_conversation(
      ctx: AnySharedContext, history_data: ConversationHistoryData
  ):
      """Writes the conversation history data (including context) to a session file."""
-     os.makedirs(LLM_HISTORY_DIR, exist_ok=True)
+     os.makedirs(CFG.LLM_HISTORY_DIR, exist_ok=True)
      current_session_name = ctx.session.name
      if not current_session_name:
          ctx.log_warning("Cannot write history: Session name is empty.")
          return
      conversation_file_path = os.path.join(
-         LLM_HISTORY_DIR, f"{current_session_name}.json"
+         CFG.LLM_HISTORY_DIR, f"{current_session_name}.json"
      )
      try:
          # Use model_dump_json to serialize the Pydantic model
          write_file(conversation_file_path, history_data.model_dump_json(indent=2))
          # Update the last-session pointer
-         last_session_file_path = os.path.join(LLM_HISTORY_DIR, "last-session")
+         last_session_file_path = os.path.join(CFG.LLM_HISTORY_DIR, "last-session")
          write_file(last_session_file_path, current_session_name)
      except Exception as e:
          ctx.log_error(f"Error writing history file '{conversation_file_path}': {e}")
zrb/builtin/llm/llm_ask.py ADDED
@@ -0,0 +1,142 @@
+ from zrb.builtin.group import llm_group
+ from zrb.builtin.llm.chat_session import get_llm_ask_input_mapping, read_user_prompt
+ from zrb.builtin.llm.history import read_chat_conversation, write_chat_conversation
+ from zrb.builtin.llm.input import PreviousSessionInput
+ from zrb.builtin.llm.tool.api import get_current_location, get_current_weather
+ from zrb.builtin.llm.tool.cli import run_shell_command
+ from zrb.builtin.llm.tool.file import (
+     apply_diff,
+     list_files,
+     read_from_file,
+     search_files,
+     write_to_file,
+ )
+ from zrb.builtin.llm.tool.web import (
+     create_search_internet_tool,
+     open_web_page,
+     search_arxiv,
+     search_wikipedia,
+ )
+ from zrb.callback.callback import Callback
+ from zrb.config import CFG
+ from zrb.input.bool_input import BoolInput
+ from zrb.input.str_input import StrInput
+ from zrb.input.text_input import TextInput
+ from zrb.task.base_trigger import BaseTrigger
+ from zrb.task.llm_task import LLMTask
+
+ _llm_ask_inputs = [
+     StrInput(
+         "model",
+         description="LLM Model",
+         prompt="LLM Model",
+         default="",
+         allow_positional_parsing=False,
+         always_prompt=False,
+         allow_empty=True,
+     ),
+     StrInput(
+         "base-url",
+         description="LLM API Base URL",
+         prompt="LLM API Base URL",
+         default="",
+         allow_positional_parsing=False,
+         always_prompt=False,
+         allow_empty=True,
+     ),
+     StrInput(
+         "api-key",
+         description="LLM API Key",
+         prompt="LLM API Key",
+         default="",
+         allow_positional_parsing=False,
+         always_prompt=False,
+         allow_empty=True,
+     ),
+     TextInput(
+         "system-prompt",
+         description="System prompt",
+         prompt="System prompt",
+         default="",
+         allow_positional_parsing=False,
+         always_prompt=False,
+     ),
+     BoolInput(
+         "start-new",
+         description="Start new conversation (LLM will forget everything)",
+         prompt="Start new conversation (LLM will forget everything)",
+         default=False,
+         allow_positional_parsing=False,
+         always_prompt=False,
+     ),
+     TextInput("message", description="User message", prompt="Your message"),
+     PreviousSessionInput(
+         "previous-session",
+         description="Previous conversation session",
+         prompt="Previous conversation session (can be empty)",
+         allow_positional_parsing=False,
+         allow_empty=True,
+         always_prompt=False,
+     ),
+ ]
+
+ llm_ask: LLMTask = llm_group.add_task(
+     LLMTask(
+         name="llm-ask",
+         input=_llm_ask_inputs,
+         description="❓ Ask LLM",
+         model=lambda ctx: None if ctx.input.model.strip() == "" else ctx.input.model,
+         model_base_url=lambda ctx: (
+             None if ctx.input.base_url.strip() == "" else ctx.input.base_url
+         ),
+         model_api_key=lambda ctx: (
+             None if ctx.input.api_key.strip() == "" else ctx.input.api_key
+         ),
+         conversation_history_reader=read_chat_conversation,
+         conversation_history_writer=write_chat_conversation,
+         system_prompt=lambda ctx: (
+             None if ctx.input.system_prompt.strip() == "" else ctx.input.system_prompt
+         ),
+         message="{ctx.input.message}",
+         retries=0,
+     ),
+     alias="ask",
+ )
+
+ llm_group.add_task(
+     BaseTrigger(
+         name="llm-chat",
+         input=_llm_ask_inputs,
+         description="💬 Chat with LLM",
+         queue_name="ask_trigger",
+         action=read_user_prompt,
+         callback=Callback(
+             task=llm_ask,
+             input_mapping=get_llm_ask_input_mapping,
+             result_queue="ask_result",
+             session_name_queue="ask_session_name",
+         ),
+         retries=0,
+         cli_only=True,
+     ),
+     alias="chat",
+ )
+
+ if CFG.LLM_ALLOW_ACCESS_LOCAL_FILE:
+     llm_ask.add_tool(list_files)
+     llm_ask.add_tool(read_from_file)
+     llm_ask.add_tool(write_to_file)
+     llm_ask.add_tool(search_files)
+     llm_ask.add_tool(apply_diff)
+
+ if CFG.LLM_ALLOW_ACCESS_SHELL:
+     llm_ask.add_tool(run_shell_command)
+
+ if CFG.LLM_ALLOW_ACCESS_INTERNET:
+     llm_ask.add_tool(open_web_page)
+     llm_ask.add_tool(search_wikipedia)
+     llm_ask.add_tool(search_arxiv)
+     if CFG.SERP_API_KEY != "":
+         llm_ask.add_tool(create_search_internet_tool(CFG.SERP_API_KEY))
+     llm_ask.add_tool(get_current_location)
+     llm_ask.add_tool(get_current_weather)
zrb/builtin/llm/tool/rag.py CHANGED
@@ -8,14 +8,7 @@ from textwrap import dedent

  import ulid

- from zrb.config import (
-     RAG_CHUNK_SIZE,
-     RAG_EMBEDDING_API_KEY,
-     RAG_EMBEDDING_BASE_URL,
-     RAG_EMBEDDING_MODEL,
-     RAG_MAX_RESULT_COUNT,
-     RAG_OVERLAP,
- )
+ from zrb.config import CFG
  from zrb.util.cli.style import stylize_error, stylize_faint
  from zrb.util.file import read_file

@@ -42,13 +35,13 @@ def create_rag_from_directory(
      document_dir_path: str = "./documents",
      vector_db_path: str = "./chroma",
      vector_db_collection: str = "documents",
-     chunk_size: int = RAG_CHUNK_SIZE,
-     overlap: int = RAG_OVERLAP,
-     max_result_count: int = RAG_MAX_RESULT_COUNT,
+     chunk_size: int | None = None,
+     overlap: int | None = None,
+     max_result_count: int | None = None,
      file_reader: list[RAGFileReader] = [],
-     openai_api_key: str = RAG_EMBEDDING_API_KEY,
-     openai_base_url: str = RAG_EMBEDDING_BASE_URL,
-     openai_embedding_model: str = RAG_EMBEDDING_MODEL,
+     openai_api_key: str | None = None,
+     openai_base_url: str | None = None,
+     openai_embedding_model: str | None = None,
  ):
      """Create a RAG retrieval tool function for LLM use.
      This factory configures and returns an async function that takes a query,
@@ -64,10 +57,33 @@ def create_rag_from_directory(

      # Initialize OpenAI client with custom URL if provided
      client_args = {}
-     if openai_api_key:
-         client_args["api_key"] = openai_api_key
-     if openai_base_url:
-         client_args["base_url"] = openai_base_url
+     # Initialize OpenAI client with custom URL if provided
+     client_args = {}
+     api_key_val = (
+         openai_api_key if openai_api_key is not None else CFG.RAG_EMBEDDING_API_KEY
+     )
+     base_url_val = (
+         openai_base_url
+         if openai_base_url is not None
+         else CFG.RAG_EMBEDDING_BASE_URL
+     )
+     embedding_model_val = (
+         openai_embedding_model
+         if openai_embedding_model is not None
+         else CFG.RAG_EMBEDDING_MODEL
+     )
+     chunk_size_val = chunk_size if chunk_size is not None else CFG.RAG_CHUNK_SIZE
+     overlap_val = overlap if overlap is not None else CFG.RAG_OVERLAP
+     max_result_count_val = (
+         max_result_count
+         if max_result_count is not None
+         else CFG.RAG_MAX_RESULT_COUNT
+     )
+
+     if api_key_val:
+         client_args["api_key"] = api_key_val
+     if base_url_val:
+         client_args["base_url"] = base_url_val
      # Initialize OpenAI client for embeddings
      openai_client = OpenAI(**client_args)
      # Initialize ChromaDB client
@@ -101,8 +117,8 @@
              collection.delete(where={"file_path": relative_path})
              content = _read_txt_content(file_path, file_reader)
              file_id = ulid.new().str
-             for i in range(0, len(content), chunk_size - overlap):
-                 chunk = content[i : i + chunk_size]
+             for i in range(0, len(content), chunk_size_val - overlap_val):
+                 chunk = content[i : i + chunk_size_val]
                  if chunk:
                      chunk_id = ulid.new().str
                      print(
@@ -113,7 +129,7 @@
                      )
                      # Get embeddings using OpenAI
                      embedding_response = openai_client.embeddings.create(
-                         input=chunk, model=openai_embedding_model
+                         input=chunk, model=embedding_model_val
                      )
                      vector = embedding_response.data[0].embedding
                      collection.upsert(
@@ -140,13 +156,13 @@
          print(stylize_faint("Vectorizing query"), file=sys.stderr)
          # Get embeddings using OpenAI
          embedding_response = openai_client.embeddings.create(
-             input=query, model=openai_embedding_model
+             input=query, model=embedding_model_val
          )
          query_vector = embedding_response.data[0].embedding
          print(stylize_faint("Searching documents"), file=sys.stderr)
          results = collection.query(
              query_embeddings=query_vector,
-             n_results=max_result_count,
+             n_results=max_result_count_val,
          )
          return json.dumps(results)

zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/view/static/default/script.js CHANGED
@@ -41,4 +41,4 @@ function updateAutoTheme() {
  }

  updateAutoTheme();
- window.matchMedia('(prefers-color-scheme: dark)').addListener(updateAutoTheme);
+ window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', updateAutoTheme);