zrb 1.0.0b2__py3-none-any.whl → 1.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. zrb/__main__.py +3 -0
  2. zrb/builtin/llm/llm_chat.py +85 -5
  3. zrb/builtin/llm/previous-session.js +13 -0
  4. zrb/builtin/llm/tool/api.py +29 -0
  5. zrb/builtin/llm/tool/cli.py +1 -1
  6. zrb/builtin/llm/tool/rag.py +108 -145
  7. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/client_method.py +6 -6
  8. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py +3 -1
  9. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py +88 -44
  10. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py +12 -0
  11. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py +28 -22
  12. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py +6 -6
  13. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py +43 -29
  14. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_repository.py +8 -0
  15. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py +46 -14
  16. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py +158 -20
  17. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py +29 -0
  18. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py +36 -14
  19. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py +14 -14
  20. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py +1 -1
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py +34 -6
  22. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py +2 -6
  23. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py +41 -2
  24. zrb/builtin/todo.py +1 -0
  25. zrb/config.py +23 -4
  26. zrb/input/any_input.py +5 -0
  27. zrb/input/base_input.py +6 -0
  28. zrb/input/bool_input.py +2 -0
  29. zrb/input/float_input.py +2 -0
  30. zrb/input/int_input.py +2 -0
  31. zrb/input/option_input.py +2 -0
  32. zrb/input/password_input.py +2 -0
  33. zrb/input/text_input.py +2 -0
  34. zrb/runner/common_util.py +1 -1
  35. zrb/runner/web_route/error_page/show_error_page.py +2 -1
  36. zrb/runner/web_route/static/resources/session/current-session.js +4 -2
  37. zrb/runner/web_route/static/resources/session/event.js +8 -2
  38. zrb/runner/web_route/task_session_api_route.py +48 -3
  39. zrb/task/base_task.py +14 -13
  40. zrb/task/llm_task.py +214 -84
  41. zrb/util/llm/tool.py +3 -7
  42. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/METADATA +2 -1
  43. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/RECORD +45 -43
  44. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/WHEEL +0 -0
  45. {zrb-1.0.0b2.dist-info → zrb-1.0.0b4.dist-info}/entry_points.txt +0 -0
zrb/__main__.py CHANGED
@@ -16,5 +16,8 @@ def serve_cli():
16
16
  cli.run(sys.argv[1:])
17
17
  except KeyboardInterrupt:
18
18
  print(stylize_warning("\nStopped"), file=sys.stderr)
19
+ except RuntimeError as e:
20
+ if f"{e}".lower() != "event loop is closed":
21
+ raise e
19
22
  except NodeNotFoundError as e:
20
23
  print(stylize_error(f"{e}"), file=sys.stderr)
@@ -1,16 +1,76 @@
1
+ import json
2
+ import os
3
+ from typing import Any
4
+
1
5
  from zrb.builtin.group import llm_group
6
+ from zrb.builtin.llm.tool.api import get_current_location, get_current_weather
2
7
  from zrb.builtin.llm.tool.cli import run_shell_command
3
8
  from zrb.builtin.llm.tool.web import open_web_route, query_internet
4
9
  from zrb.config import (
10
+ LLM_ALLOW_ACCESS_INTERNET,
5
11
  LLM_ALLOW_ACCESS_SHELL,
6
- LLM_ALLOW_ACCESS_WEB,
7
- LLM_HISTORY_FILE,
12
+ LLM_HISTORY_DIR,
8
13
  LLM_MODEL,
9
14
  LLM_SYSTEM_PROMPT,
10
15
  )
16
+ from zrb.context.any_shared_context import AnySharedContext
17
+ from zrb.input.bool_input import BoolInput
11
18
  from zrb.input.str_input import StrInput
12
19
  from zrb.input.text_input import TextInput
13
20
  from zrb.task.llm_task import LLMTask
21
+ from zrb.util.file import read_file, write_file
22
+ from zrb.util.string.conversion import to_pascal_case
23
+
24
+
25
+ class PreviousSessionInput(StrInput):
26
+
27
+ def to_html(self, ctx: AnySharedContext) -> str:
28
+ name = self.name
29
+ description = self.description
30
+ default = self.get_default_str(ctx)
31
+ script = read_file(
32
+ file_path=os.path.join(os.path.dirname(__file__), "previous-session.js"),
33
+ replace_map={
34
+ "CURRENT_INPUT_NAME": name,
35
+ "CurrentPascalInputName": to_pascal_case(name),
36
+ },
37
+ )
38
+ return "\n".join(
39
+ [
40
+ f'<input name="{name}" placeholder="{description}" value="{default}" />',
41
+ f"<script>{script}</script>",
42
+ ]
43
+ )
44
+
45
+
46
+ def _read_chat_conversation(ctx: AnySharedContext) -> list[dict[str, Any]]:
47
+ if ctx.input.start_new:
48
+ return []
49
+ previous_session_name = ctx.input.previous_session
50
+ if previous_session_name == "" or previous_session_name is None:
51
+ last_session_file_path = os.path.join(LLM_HISTORY_DIR, "last-session")
52
+ if os.path.isfile(last_session_file_path):
53
+ previous_session_name = read_file(last_session_file_path).strip()
54
+ conversation_file_path = os.path.join(
55
+ LLM_HISTORY_DIR, f"{previous_session_name}.json"
56
+ )
57
+ if not os.path.isfile(conversation_file_path):
58
+ return []
59
+ return json.loads(read_file(conversation_file_path))
60
+
61
+
62
+ def _write_chat_conversation(
63
+ ctx: AnySharedContext, conversations: list[dict[str, Any]]
64
+ ):
65
+ os.makedirs(LLM_HISTORY_DIR, exist_ok=True)
66
+ current_session_name = ctx.session.name
67
+ conversation_file_path = os.path.join(
68
+ LLM_HISTORY_DIR, f"{current_session_name}.json"
69
+ )
70
+ write_file(conversation_file_path, json.dumps(conversations, indent=2))
71
+ last_session_file_path = os.path.join(LLM_HISTORY_DIR, "last-session")
72
+ write_file(last_session_file_path, current_session_name)
73
+
14
74
 
15
75
  llm_chat: LLMTask = llm_group.add_task(
16
76
  LLMTask(
@@ -21,20 +81,38 @@ llm_chat: LLMTask = llm_group.add_task(
21
81
  description="LLM Model",
22
82
  prompt="LLM Model",
23
83
  default_str=LLM_MODEL,
84
+ allow_positional_parsing=False,
24
85
  ),
25
- StrInput(
86
+ TextInput(
26
87
  "system-prompt",
27
88
  description="System prompt",
28
89
  prompt="System prompt",
29
90
  default_str=LLM_SYSTEM_PROMPT,
91
+ allow_positional_parsing=False,
92
+ ),
93
+ BoolInput(
94
+ "start-new",
95
+ description="Start new conversation session",
96
+ prompt="Forget everything and start new conversation session",
97
+ default_str="false",
98
+ allow_positional_parsing=False,
30
99
  ),
31
100
  TextInput("message", description="User message", prompt="Your message"),
101
+ PreviousSessionInput(
102
+ "previous-session",
103
+ description="Previous conversation session",
104
+ prompt="Previous conversation session (can be empty)",
105
+ allow_positional_parsing=False,
106
+ allow_empty=True,
107
+ ),
32
108
  ],
33
- history_file=LLM_HISTORY_FILE,
109
+ conversation_history_reader=_read_chat_conversation,
110
+ conversation_history_writer=_write_chat_conversation,
34
111
  description="Chat with LLM",
35
112
  model="{ctx.input.model}",
36
113
  system_prompt="{ctx.input['system-prompt']}",
37
114
  message="{ctx.input.message}",
115
+ retries=0,
38
116
  ),
39
117
  alias="chat",
40
118
  )
@@ -42,6 +120,8 @@ llm_chat: LLMTask = llm_group.add_task(
42
120
  if LLM_ALLOW_ACCESS_SHELL:
43
121
  llm_chat.add_tool(run_shell_command)
44
122
 
45
- if LLM_ALLOW_ACCESS_WEB:
123
+ if LLM_ALLOW_ACCESS_INTERNET:
46
124
  llm_chat.add_tool(open_web_route)
47
125
  llm_chat.add_tool(query_internet)
126
+ llm_chat.add_tool(get_current_location)
127
+ llm_chat.add_tool(get_current_weather)
@@ -0,0 +1,13 @@
1
+ let hasUpdateCurrentPascalInputName = false;
2
+ document.getElementById("submit-task-form").addEventListener("change", async function(event) {
3
+ const currentInput = event.target;
4
+ if (hasUpdateCurrentPascalInputName || currentInput.name === "CURRENT_INPUT_NAME") {
5
+ return
6
+ }
7
+ const previousSessionInput = submitTaskForm.querySelector('[name="CURRENT_INPUT_NAME"]');
8
+ if (previousSessionInput) {
9
+ const currentSessionName = cfg.SESSION_NAME
10
+ previousSessionInput.value = currentSessionName;
11
+ }
12
+ hasUpdateCurrentPascalInputName = true;
13
+ });
@@ -0,0 +1,29 @@
1
+ import json
2
+ from typing import Annotated, Literal
3
+
4
+ import requests
5
+
6
+
7
+ def get_current_location() -> (
8
+ Annotated[str, "JSON string representing latitude and longitude"]
9
+ ): # noqa
10
+ """Get the user's current location."""
11
+ return json.dumps(requests.get("http://ip-api.com/json?fields=lat,lon").json())
12
+
13
+
14
+ def get_current_weather(
15
+ latitude: float,
16
+ longitude: float,
17
+ temperature_unit: Literal["celsius", "fahrenheit"],
18
+ ) -> str:
19
+ """Get the current weather in a given location."""
20
+ resp = requests.get(
21
+ "https://api.open-meteo.com/v1/forecast",
22
+ params={
23
+ "latitude": latitude,
24
+ "longitude": longitude,
25
+ "temperature_unit": temperature_unit,
26
+ "current_weather": True,
27
+ },
28
+ )
29
+ return json.dumps(resp.json())
@@ -2,7 +2,7 @@ import subprocess
2
2
 
3
3
 
4
4
  def run_shell_command(command: str) -> str:
5
- """Running a shell command"""
5
+ """Running an actual shell command on user's computer."""
6
6
  output = subprocess.check_output(
7
7
  command, shell=True, stderr=subprocess.STDOUT, text=True
8
8
  )
@@ -1,9 +1,10 @@
1
+ import hashlib
1
2
  import json
2
3
  import os
3
4
  import sys
4
- from collections.abc import Callable, Iterable
5
5
 
6
6
  import litellm
7
+ import ulid
7
8
 
8
9
  from zrb.config import (
9
10
  RAG_CHUNK_SIZE,
@@ -13,10 +14,6 @@ from zrb.config import (
13
14
  )
14
15
  from zrb.util.cli.style import stylize_error, stylize_faint
15
16
  from zrb.util.file import read_file
16
- from zrb.util.run import run_async
17
-
18
- Document = str | Callable[[], str]
19
- Documents = Callable[[], Iterable[Document]] | Iterable[Document]
20
17
 
21
18
 
22
19
  def create_rag_from_directory(
@@ -30,86 +27,87 @@ def create_rag_from_directory(
30
27
  overlap: int = RAG_OVERLAP,
31
28
  max_result_count: int = RAG_MAX_RESULT_COUNT,
32
29
  ):
33
- return create_rag(
34
- tool_name=tool_name,
35
- tool_description=tool_description,
36
- documents=get_rag_documents(os.path.expanduser(document_dir_path)),
37
- model=model,
38
- vector_db_path=vector_db_path,
39
- vector_db_collection=vector_db_collection,
40
- reset_db=get_rag_reset_db(
41
- document_dir_path=os.path.expanduser(document_dir_path),
42
- vector_db_path=os.path.expanduser(vector_db_path),
43
- ),
44
- chunk_size=chunk_size,
45
- overlap=overlap,
46
- max_result_count=max_result_count,
47
- )
48
-
49
-
50
- def create_rag(
51
- tool_name: str,
52
- tool_description: str,
53
- documents: Documents = [],
54
- model: str = RAG_EMBEDDING_MODEL,
55
- vector_db_path: str = "./chroma",
56
- vector_db_collection: str = "documents",
57
- reset_db: Callable[[], bool] | bool = False,
58
- chunk_size: int = RAG_CHUNK_SIZE,
59
- overlap: int = RAG_OVERLAP,
60
- max_result_count: int = RAG_MAX_RESULT_COUNT,
61
- ) -> Callable[[str], str]:
62
30
  async def retrieve(query: str) -> str:
63
- import chromadb
31
+ from chromadb import PersistentClient
64
32
  from chromadb.config import Settings
65
33
 
66
- is_db_exist = os.path.isdir(vector_db_path)
67
- client = chromadb.PersistentClient(
34
+ client = PersistentClient(
68
35
  path=vector_db_path, settings=Settings(allow_reset=True)
69
36
  )
70
- should_reset_db = (
71
- await run_async(reset_db()) if callable(reset_db) else reset_db
72
- )
73
- if (not is_db_exist) or should_reset_db:
74
- client.reset()
75
- collection = client.get_or_create_collection(vector_db_collection)
76
- chunk_index = 0
77
- print(stylize_faint("Scanning documents"), file=sys.stderr)
78
- docs = await run_async(documents()) if callable(documents) else documents
79
- for document in docs:
80
- if callable(document):
81
- try:
82
- document = await run_async(document())
83
- except Exception as error:
84
- print(stylize_error(f"Error: {error}"), file=sys.stderr)
85
- continue
86
- for i in range(0, len(document), chunk_size - overlap):
87
- chunk = document[i : i + chunk_size]
88
- if len(chunk) > 0:
89
- print(
90
- stylize_faint(f"Vectorize chunk {chunk_index}"),
91
- file=sys.stderr,
92
- )
93
- response = await litellm.aembedding(model=model, input=[chunk])
94
- vector = response["data"][0]["embedding"]
95
- print(
96
- stylize_faint(f"Adding chunk {chunk_index} to db"),
97
- file=sys.stderr,
98
- )
99
- collection.upsert(
100
- ids=[f"id{chunk_index}"],
101
- embeddings=[vector],
102
- documents=[chunk],
103
- )
104
- chunk_index += 1
105
37
  collection = client.get_or_create_collection(vector_db_collection)
106
- # Generate embedding for the query
107
- print(stylize_faint("Vectorize query"), file=sys.stderr)
38
+
39
+ # Track file changes using a hash-based approach
40
+ hash_file_path = os.path.join(vector_db_path, "file_hashes.json")
41
+ previous_hashes = _load_hashes(hash_file_path)
42
+ current_hashes = {}
43
+
44
+ updated_files = []
45
+
46
+ for root, _, files in os.walk(document_dir_path):
47
+ for file in files:
48
+ file_path = os.path.join(root, file)
49
+ file_hash = _compute_file_hash(file_path)
50
+ relative_path = os.path.relpath(file_path, document_dir_path)
51
+ current_hashes[relative_path] = file_hash
52
+
53
+ if previous_hashes.get(relative_path) != file_hash:
54
+ updated_files.append(file_path)
55
+
56
+ if updated_files:
57
+ print(
58
+ stylize_faint(f"Updating {len(updated_files)} changed files"),
59
+ file=sys.stderr,
60
+ )
61
+
62
+ for file_path in updated_files:
63
+ try:
64
+ relative_path = os.path.relpath(file_path, document_dir_path)
65
+ collection.delete(where={"file_path": relative_path})
66
+ content = _read_file_content(file_path)
67
+ file_id = ulid.new().str
68
+ for i in range(0, len(content), chunk_size - overlap):
69
+ chunk = content[i : i + chunk_size]
70
+ if chunk:
71
+ chunk_id = ulid.new().str
72
+ print(
73
+ stylize_faint(
74
+ f"Vectorizing {relative_path} chunk {chunk_id}"
75
+ ),
76
+ file=sys.stderr,
77
+ )
78
+ response = await litellm.aembedding(
79
+ model=model, input=[chunk]
80
+ )
81
+ vector = response["data"][0]["embedding"]
82
+ collection.upsert(
83
+ ids=[chunk_id],
84
+ embeddings=[vector],
85
+ documents=[chunk],
86
+ metadatas={
87
+ "file_path": relative_path,
88
+ "file_id": file_id,
89
+ },
90
+ )
91
+ except Exception as e:
92
+ print(
93
+ stylize_error(f"Error processing {file_path}: {e}"),
94
+ file=sys.stderr,
95
+ )
96
+
97
+ _save_hashes(hash_file_path, current_hashes)
98
+ else:
99
+ print(
100
+ stylize_faint("No changes detected. Skipping database update."),
101
+ file=sys.stderr,
102
+ )
103
+
104
+ print(stylize_faint("Vectorizing query"), file=sys.stderr)
108
105
  query_response = await litellm.aembedding(model=model, input=[query])
109
- print(stylize_faint("Search documents"), file=sys.stderr)
110
- # Search for the top_k most similar documents
106
+ query_vector = query_response["data"][0]["embedding"]
107
+
108
+ print(stylize_faint("Searching documents"), file=sys.stderr)
111
109
  results = collection.query(
112
- query_embeddings=query_response["data"][0]["embedding"],
110
+ query_embeddings=query_vector,
113
111
  n_results=max_result_count,
114
112
  )
115
113
  return json.dumps(results)
@@ -119,71 +117,36 @@ def create_rag(
119
117
  return retrieve
120
118
 
121
119
 
122
- def get_rag_documents(document_dir_path: str) -> Callable[[], list[Callable[[], str]]]:
123
- def get_documents() -> list[Callable[[], str]]:
124
- # Walk through the directory
125
- readers = []
126
- for root, _, files in os.walk(document_dir_path):
127
- for file in files:
128
- file_path = os.path.join(root, file)
129
- if file_path.lower().endswith(".pdf"):
130
- readers.append(_get_pdf_reader(file_path))
131
- continue
132
- readers.append(_get_text_reader(file_path))
133
- return readers
134
-
135
- return get_documents
136
-
137
-
138
- def _get_text_reader(file_path: str):
139
- def read():
140
- print(stylize_faint(f"Start reading {file_path}"), file=sys.stderr)
141
- content = read_file(file_path)
142
- print(stylize_faint(f"Complete reading {file_path}"), file=sys.stderr)
143
- return content
144
-
145
- return read
146
-
147
-
148
- def _get_pdf_reader(file_path):
149
- def read():
150
- import pdfplumber
151
-
152
- print(stylize_faint(f"Start reading {file_path}"), file=sys.stderr)
153
- contents = []
154
- with pdfplumber.open(file_path) as pdf:
155
- for page in pdf.pages:
156
- contents.append(page.extract_text())
157
- print(stylize_faint(f"Complete reading {file_path}"), file=sys.stderr)
158
- return "\n".join(contents)
159
-
160
- return read
161
-
162
-
163
- def get_rag_reset_db(
164
- document_dir_path: str, vector_db_path: str = "./chroma"
165
- ) -> Callable[[], bool]:
166
- def should_reset_db() -> bool:
167
- document_exist = os.path.isdir(document_dir_path)
168
- if not document_exist:
169
- raise ValueError(f"Document directory not exists: {document_dir_path}")
170
- vector_db_exist = os.path.isdir(vector_db_path)
171
- if not vector_db_exist:
172
- return True
173
- document_mtime = _get_most_recent_mtime(document_dir_path)
174
- vector_db_mtime = _get_most_recent_mtime(vector_db_path)
175
- return document_mtime > vector_db_mtime
176
-
177
- return should_reset_db
178
-
179
-
180
- def _get_most_recent_mtime(directory):
181
- most_recent_mtime = 0
182
- for root, dirs, files in os.walk(directory):
183
- # Check mtime for directories
184
- for name in dirs + files:
185
- file_path = os.path.join(root, name)
186
- mtime = os.path.getmtime(file_path)
187
- if mtime > most_recent_mtime:
188
- most_recent_mtime = mtime
189
- return most_recent_mtime
120
+ def _compute_file_hash(file_path: str) -> str:
121
+ hash_md5 = hashlib.md5()
122
+ with open(file_path, "rb") as f:
123
+ for chunk in iter(lambda: f.read(4096), b""):
124
+ hash_md5.update(chunk)
125
+ return hash_md5.hexdigest()
126
+
127
+
128
+ def _read_file_content(file_path: str) -> str:
129
+ if file_path.lower().endswith(".pdf"):
130
+ return _read_pdf(file_path)
131
+ return read_file(file_path)
132
+
133
+
134
+ def _read_pdf(file_path: str) -> str:
135
+ import pdfplumber
136
+
137
+ with pdfplumber.open(file_path) as pdf:
138
+ return "\n".join(
139
+ page.extract_text() for page in pdf.pages if page.extract_text()
140
+ )
141
+
142
+
143
+ def _load_hashes(file_path: str) -> dict:
144
+ if os.path.exists(file_path):
145
+ with open(file_path, "r") as f:
146
+ return json.load(f)
147
+ return {}
148
+
149
+
150
+ def _save_hashes(file_path: str, hashes: dict):
151
+ with open(file_path, "w") as f:
152
+ json.dump(hashes, f)
@@ -18,17 +18,17 @@ async def get_my_entities(
18
18
 
19
19
 
20
20
  @abstractmethod
21
- async def create_my_entity(self, data: MyEntityCreateWithAudit) -> MyEntityResponse:
22
- """Create a new my entities"""
23
-
24
-
25
- @abstractmethod
26
- async def create_my_entity(
21
+ async def create_my_entity_bulk(
27
22
  self, data: list[MyEntityCreateWithAudit]
28
23
  ) -> list[MyEntityResponse]:
29
24
  """Create new my entities"""
30
25
 
31
26
 
27
+ @abstractmethod
28
+ async def create_my_entity(self, data: MyEntityCreateWithAudit) -> MyEntityResponse:
29
+ """Create a new my entities"""
30
+
31
+
32
32
  @abstractmethod
33
33
  async def update_my_entity_bulk(
34
34
  self, my_entity_ids: list[str], data: MyEntityUpdateWithAudit
@@ -51,7 +51,9 @@ async def update_my_entity_bulk(my_entity_ids: list[str], data: MyEntityUpdate):
51
51
  response_model=MyEntityResponse,
52
52
  )
53
53
  async def update_my_entity(my_entity_id: str, data: MyEntityUpdate):
54
- return await my_module_client.update_my_entity(data.with_audit(updated_by="system"))
54
+ return await my_module_client.update_my_entity(
55
+ my_entity_id, data.with_audit(updated_by="system")
56
+ )
55
57
 
56
58
 
57
59
  @app.delete(