jehoctor-rag-demo 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jehoctor_rag_demo-0.2.0.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/METADATA +56 -31
- jehoctor_rag_demo-0.2.1.dist-info/RECORD +31 -0
- rag_demo/__main__.py +15 -4
- rag_demo/agents/__init__.py +4 -0
- rag_demo/agents/base.py +40 -0
- rag_demo/agents/hugging_face.py +116 -0
- rag_demo/agents/llama_cpp.py +113 -0
- rag_demo/agents/ollama.py +91 -0
- rag_demo/app.py +1 -1
- rag_demo/app_protocol.py +101 -0
- rag_demo/constants.py +11 -0
- rag_demo/logic.py +90 -176
- rag_demo/modes/_logic_provider.py +3 -2
- rag_demo/modes/chat.py +10 -8
- rag_demo/probe.py +129 -0
- jehoctor_rag_demo-0.2.0.dist-info/RECORD +0 -23
- {jehoctor_rag_demo-0.2.0.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/WHEEL +0 -0
- {jehoctor_rag_demo-0.2.0.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/entry_points.txt +0 -0
{jehoctor_rag_demo-0.2.0.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/METADATA
CHANGED

@@ -1,10 +1,11 @@
 Metadata-Version: 2.3
 Name: jehoctor-rag-demo
-Version: 0.2.0
+Version: 0.2.1
 Summary: Chat with Wikipedia
 Author: James Hoctor
 Author-email: James Hoctor <JEHoctor@protonmail.com>
 Requires-Dist: aiosqlite==0.21.0
+Requires-Dist: bitsandbytes>=0.49.1
 Requires-Dist: chromadb>=1.3.4
 Requires-Dist: datasets>=4.4.1
 Requires-Dist: httpx>=0.28.1
@@ -16,7 +17,6 @@ Requires-Dist: langchain-huggingface>=1.1.0
 Requires-Dist: langchain-ollama>=1.0.0
 Requires-Dist: langchain-openai>=1.0.2
 Requires-Dist: langgraph-checkpoint-sqlite>=3.0.1
-Requires-Dist: llama-cpp-python>=0.3.16
 Requires-Dist: nvidia-ml-py>=13.590.44
 Requires-Dist: ollama>=0.6.0
 Requires-Dist: platformdirs>=4.5.0
@@ -24,9 +24,13 @@ Requires-Dist: psutil>=7.1.3
 Requires-Dist: py-cpuinfo>=9.0.0
 Requires-Dist: pydantic>=2.12.4
 Requires-Dist: pyperclip>=1.11.0
+Requires-Dist: sentence-transformers>=5.2.2
 Requires-Dist: textual>=6.5.0
+Requires-Dist: transformers[torch]>=4.57.6
 Requires-Dist: typer>=0.20.0
-Requires-
+Requires-Dist: llama-cpp-python>=0.3.16 ; extra == 'llamacpp'
+Requires-Python: ~=3.12.0
+Provides-Extra: llamacpp
 Description-Content-Type: text/markdown
 
 # RAG-demo
@@ -35,50 +39,43 @@ Chat with (a small portion of) Wikipedia
 
 ⚠️ RAG functionality is still under development. ⚠️
 
-
 
 ## Requirements
 
-1. [uv](https://docs.astral.sh/uv/)
-
--
-
+1. The [uv](https://docs.astral.sh/uv/) Python package manager
+   - Installing and updating `uv` is easy by following [the docs](https://docs.astral.sh/uv/getting-started/installation/).
+   - As of 2026-01-25, I'm developing using `uv` version 0.9.26, and using the new experimental `--pytorch-backend` option.
+2. A terminal emulator or web browser
+   - Any common web browser will work.
+   - Some terminal emulators will work better than others.
+     See [Notes on terminal emulators](#notes-on-terminal-emulators) below.
 
-
+### Notes on terminal emulators
+
+Certain terminal emulators will not work with some features of this program.
+In particular, on macOS consider using [iTerm2](https://iterm2.com/) instead of the default Terminal.app ([explanation](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos)).
+On Linux you might want to try [kitty](https://sw.kovidgoyal.net/kitty/), [wezterm](https://wezterm.org/), [alacritty](https://alacritty.org/), or [ghostty](https://ghostty.org/), instead of the terminal that came with your desktop environment ([reason](https://darren.codes/posts/textual-copy-paste/)).
+Windows Terminal should be fine as far as I know.
+
+### Optional dependencies
 
 1. [Hugging Face login](https://huggingface.co/docs/huggingface_hub/quick-start#login)
 2. API key for your favorite LLM provider (support coming soon)
 3. Ollama installed on your system if you have a GPU
 4. Run RAG-demo on a more capable (bigger GPU) machine over SSH if you can. It is a terminal app after all.
+5. A C compiler if you want to build Llama.cpp from source.
 
-
-## Run from the repository
-
-First, clone this repository. Then, run one of the options below.
+## Run the latest version
 
 Run in a terminal:
 ```bash
-
+uvx --torch-backend=auto --from=jehoctor-rag-demo@latest chat
 ```
 
 Or run in a web browser:
 ```bash
-
-```
-
-## Run from the latest version on PyPI
-
-TODO: test uv automatic torch backend selection:
-https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection
-
-Run in a terminal:
-```bash
-uvx --from=jehoctor-rag-demo chat
-```
-
-Or run in a web browser:
-```bash
-uvx --from=jehoctor-rag-demo textual serve chat
+uvx --torch-backend=auto --from=jehoctor-rag-demo@latest textual serve chat
 ```
@@ -86,15 +83,43 @@ uvx --from=jehoctor-rag-demo textual serve chat
 If you have an NVIDIA GPU with CUDA and build tools installed, you might be able to get CUDA acceleration without installing Ollama.
 
 ```bash
-CMAKE_ARGS="-DGGML_CUDA=on" uv run chat
+CMAKE_ARGS="-DGGML_CUDA=on" uv run --extra=llamacpp chat
 ```
 
 ## Metal acceleration via Llama.cpp (on Apple Silicon)
 
 On an Apple Silicon machine, make sure `uv` runs an ARM interpreter as this should cause it to install Llama.cpp with Metal support.
+Also, run with the extra group `llamacpp`.
+Try this:
+
+```bash
+uvx --python-platform=aarch64-apple-darwin --torch-backend=auto --from=jehoctor-rag-demo[llamacpp]@latest chat
+```
 
 ## Ollama on Linux
 
 Remember that you have to keep Ollama up-to-date manually on Linux.
 A recent version of Ollama (v0.11.10 or later) is required to run the [embedding model we use](https://ollama.com/library/embeddinggemma).
 See this FAQ: https://docs.ollama.com/faq#how-can-i-upgrade-ollama.
+
+## Project feature roadmap
+
+- ❌ RAG functionality
+- ❌ torch inference via the Langchain local Hugging Face inference integration
+- ❌ uv automatic torch backend selection (see [the docs](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection))
+- ❌ OpenAI integration
+- ❌ Anthropic integration
+
+## Run from the repository
+
+First, clone this repository. Then, run one of the options below.
+
+Run in a terminal:
+```bash
+uv run chat
+```
+
+Or run in a web browser:
+```bash
+uv run textual serve chat
+```

jehoctor_rag_demo-0.2.1.dist-info/RECORD
ADDED

@@ -0,0 +1,31 @@
+rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/__main__.py,sha256=S0UlQj3EldcXRk3rhH3ONdSOPmeyITYXIZ0o2JSxWbg,1618
+rag_demo/agents/__init__.py,sha256=dsuO3AGcn2yGDq4gkAsZ32pjeTOqAudOL14G_AsEUyc,221
+rag_demo/agents/base.py,sha256=gib6bC8nVKN1s1KPZd1dJVGRXnu7gFQwf3I3_7TSjQo,1312
+rag_demo/agents/hugging_face.py,sha256=VrbGOlMO2z357LmU3sO5aM_yI5P-xbsfKmTAzH9_lFo,4225
+rag_demo/agents/llama_cpp.py,sha256=C0hInc24sXmt5407_k4mP2Y6svqgfUPemhzAL1N6jY0,4272
+rag_demo/agents/ollama.py,sha256=Fmtu8MSPPz91eT7HKvwvbQnA_xGPaD5HBrHEPPtomZA,3317
+rag_demo/app.py,sha256=AVCJjlQ60y5J0v50TcJ3zZoa0ubhd_yKVDfu1ERsMVU,1807
+rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/app_protocol.py,sha256=P__Q3KT41uonYazpYmLWmOh1MeoBaCBx4xEydkwK5tk,3292
+rag_demo/constants.py,sha256=5EpyAD6p5Qb0vB5ASMtFSLVVsq3RG5cBn_AYihhDCPc,235
+rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
+rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
+rag_demo/logic.py,sha256=SkF_Hqu1WSLHzwvSd_mJiCMSxZYqDnteYFRpc6oCREY,8236
+rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
+rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
+rag_demo/modes/_logic_provider.py,sha256=U3J8Fgq8MbNYd92FqENW-5YP_jXqKG3xmMmYoSUzhHo,1343
+rag_demo/modes/chat.py,sha256=2pmKhQ2uYZdjezNnNBINViBMcuTVE5YCom_HEbdJeXg,13607
+rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
+rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
+rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
+rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/probe.py,sha256=aDD-smNauEXXoBKVgx5xsMawM5tL0QAEBFl07ZGrddc,5101
+rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
+rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
+jehoctor_rag_demo-0.2.1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
+jehoctor_rag_demo-0.2.1.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
+jehoctor_rag_demo-0.2.1.dist-info/METADATA,sha256=nCXuy3TYPPf67DPFryndW2P6CRdSCcJkXxFXZ-UN4vs,4650
+jehoctor_rag_demo-0.2.1.dist-info/RECORD,,

rag_demo/__main__.py
CHANGED

@@ -3,21 +3,32 @@ import time
 # Measure the application start time.
 APPLICATION_START_TIME = time.time()
 
-# Disable "module import not at top of file" (aka E402) when importing Typer. This is necessary
-# initialization is included in the application startup time.
+# Disable "module import not at top of file" (aka E402) when importing Typer and other early imports. This is necessary
+# so that the initialization of these modules is included in the application startup time.
+from typing import Annotated  # noqa: E402
+
 import typer  # noqa: E402
 
+from rag_demo.constants import LocalProviderType  # noqa: E402
+
 
 def _main(
-    name: str | None
+    name: Annotated[str | None, typer.Option(help="The name you want to want the AI to use with you.")] = None,
+    provider: Annotated[LocalProviderType | None, typer.Option(help="The local provider to prefer.")] = None,
 ) -> None:
     """Talk to Wikipedia."""
     # Import here so that imports run within the typer.run context for prettier stack traces if errors occur.
     # We ignore PLC0415 because we do not want these imports to be at the top of the module as is usually preferred.
+    import transformers  # noqa: PLC0415
+
     from rag_demo.app import RAGDemo  # noqa: PLC0415
     from rag_demo.logic import Logic  # noqa: PLC0415
 
-
+    # The transformers library likes to print text that interferes with the TUI. Disable it.
+    transformers.logging.set_verbosity(verbosity=transformers.logging.CRITICAL)
+    transformers.logging.disable_progress_bar()
+
+    logic = Logic(username=name, preferred_provider_type=provider, application_start_time=APPLICATION_START_TIME)
     app = RAGDemo(logic)
     app.run()
 

rag_demo/agents/base.py
ADDED

@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Final, Protocol
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from contextlib import AbstractAsyncContextManager
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+    from rag_demo.constants import LocalProviderType
+
+
+class Agent(Protocol):
+    """An LLM agent that supports streaming responses asynchronously."""
+
+    def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+
+
+class AgentProvider(Protocol):
+    """A strategy for creating LLM agents."""
+
+    type: Final[LocalProviderType]
+
+    def get_agent(self, checkpoints_sqlite_db: str | Path) -> AbstractAsyncContextManager[Agent | None]:
+        """Attempt to create an agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """

rag_demo/agents/hugging_face.py
ADDED

@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import asyncio
+import sqlite3
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Final
+
+from huggingface_hub import hf_hub_download
+from langchain.agents import create_agent
+from langchain.messages import AIMessageChunk, HumanMessage
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFacePipeline
+from langgraph.checkpoint.sqlite import SqliteSaver
+
+from rag_demo.constants import LocalProviderType
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+
+
+class HuggingFaceAgent:
+    """An LLM agent powered by Hugging Face local pipelines."""
+
+    def __init__(
+        self,
+        checkpoints_sqlite_db: str | Path,
+        model_id: str,
+        embedding_model_id: str,
+    ) -> None:
+        """Initialize the HuggingFaceAgent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+            model_id (str): Hugging Face model ID for the LLM.
+            embedding_model_id (str): Hugging Face model ID for the embedding model.
+        """
+        self.checkpoints_sqlite_db = checkpoints_sqlite_db
+        self.model_id = model_id
+        self.embedding_model_id = embedding_model_id
+
+        self.llm = ChatHuggingFace(
+            llm=HuggingFacePipeline.from_model_id(
+                model_id=model_id,
+                task="text-generation",
+                device_map="auto",
+                pipeline_kwargs={"max_new_tokens": 4096},
+            ),
+        )
+        self.embed = HuggingFaceEmbeddings(model_name=embedding_model_id)
+        self.agent = create_agent(
+            model=self.llm,
+            system_prompt="You are a helpful assistant.",
+            checkpointer=SqliteSaver(sqlite3.Connection(self.checkpoints_sqlite_db, check_same_thread=False)),
+        )
+
+    async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+        agent_stream = self.agent.stream(
+            {"messages": [HumanMessage(content=user_message)]},
+            {"configurable": {"thread_id": thread_id}},
+            stream_mode="messages",
+        )
+        for message_chunk, _ in agent_stream:
+            if isinstance(message_chunk, AIMessageChunk):
+                token = message_chunk.content
+                if isinstance(token, str):
+                    yield token
+                else:
+                    app.log.error("Received message content of type", type(token))
+            else:
+                app.log.error("Received message chunk of type", type(message_chunk))
+
+
+def _hf_downloads() -> None:
+    hf_hub_download(
+        repo_id="Qwen/Qwen3-0.6B",  # 1.5GB
+        filename="model.safetensors",
+        revision="c1899de289a04d12100db370d81485cdf75e47ca",
+    )
+    hf_hub_download(
+        repo_id="unsloth/embeddinggemma-300m",  # 1.21GB
+        filename="model.safetensors",
+        revision="bfa3c846ac738e62aa61806ef9112d34acb1dc5a",
+    )
+
+
+class HuggingFaceAgentProvider:
+    """Create LLM agents using Hugging Face local pipelines."""
+
+    type: Final[LocalProviderType] = LocalProviderType.HUGGING_FACE
+
+    @asynccontextmanager
+    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[HuggingFaceAgent]:
+        """Create a Hugging Face local pipeline agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """
+        loop = asyncio.get_running_loop()
+        await loop.run_in_executor(None, _hf_downloads)
+        yield HuggingFaceAgent(
+            checkpoints_sqlite_db,
+            model_id="Qwen/Qwen3-0.6B",
+            embedding_model_id="unsloth/embeddinggemma-300m",
+        )

rag_demo/agents/llama_cpp.py
ADDED

@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Final
+
+import aiosqlite
+from huggingface_hub import hf_hub_download
+from langchain.agents import create_agent
+from langchain.messages import AIMessageChunk, HumanMessage
+from langchain_community.chat_models import ChatLlamaCpp
+from langchain_community.embeddings import LlamaCppEmbeddings
+from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+
+from rag_demo import probe
+from rag_demo.constants import LocalProviderType
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+
+
+class LlamaCppAgent:
+    """An LLM agent powered by Llama.cpp."""
+
+    def __init__(
+        self,
+        checkpoints_conn: aiosqlite.Connection,
+        model_path: str,
+        embedding_model_path: str,
+    ) -> None:
+        """Initialize the LlamaCppAgent.
+
+        Args:
+            checkpoints_conn (aiosqlite.Connection): Connection to SQLite checkpoint database.
+            model_path (str): Path to Llama.cpp model.
+            embedding_model_path (str): Path to Llama.cpp embedding model.
+        """
+        self.checkpoints_conn = checkpoints_conn
+        self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
+        self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False)
+        self.agent = create_agent(
+            model=self.llm,
+            system_prompt="You are a helpful assistant.",
+            checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
+        )
+
+    async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+        agent_stream = self.agent.astream(
+            {"messages": [HumanMessage(content=user_message)]},
+            {"configurable": {"thread_id": thread_id}},
+            stream_mode="messages",
+        )
+        async for message_chunk, _ in agent_stream:
+            if isinstance(message_chunk, AIMessageChunk):
+                token = message_chunk.content
+                if isinstance(token, str):
+                    yield token
+                else:
+                    app.log.error("Received message content of type", type(token))
+            else:
+                app.log.error("Received message chunk of type", type(message_chunk))
+
+
+def _hf_downloads() -> tuple[str, str]:
+    model_path = hf_hub_download(
+        repo_id="bartowski/google_gemma-3-4b-it-GGUF",
+        filename="google_gemma-3-4b-it-Q6_K_L.gguf",  # 3.35GB
+        revision="71506238f970075ca85125cd749c28b1b0eee84e",
+    )
+    embedding_model_path = hf_hub_download(
+        repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
+        filename="bge-small-en-v1.5-q8_0.gguf",  # 36.8MB
+        revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
+    )
+    return model_path, embedding_model_path
+
+
+class LlamaCppAgentProvider:
+    """Create LLM agents using Llama.cpp."""
+
+    type: Final[LocalProviderType] = LocalProviderType.LLAMA_CPP
+
+    @asynccontextmanager
+    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[LlamaCppAgent | None]:
+        """Attempt to create a Llama.cpp agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """
+        if probe.probe_llama_available():
+            loop = asyncio.get_running_loop()
+            model_path, embedding_model_path = await loop.run_in_executor(None, _hf_downloads)
+            async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
+                yield LlamaCppAgent(
+                    checkpoints_conn=checkpoints_conn,
+                    model_path=model_path,
+                    embedding_model_path=embedding_model_path,
+                )
+        else:
+            yield None

rag_demo/agents/ollama.py
ADDED

@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Final
+
+import aiosqlite
+import ollama
+from langchain.agents import create_agent
+from langchain.messages import AIMessageChunk, HumanMessage
+from langchain_ollama import ChatOllama, OllamaEmbeddings
+from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+
+from rag_demo import probe
+from rag_demo.constants import LocalProviderType
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+
+
+class OllamaAgent:
+    """An LLM agent powered by Ollama."""
+
+    def __init__(self, checkpoints_conn: aiosqlite.Connection) -> None:
+        """Initialize the OllamaAgent.
+
+        Args:
+            checkpoints_conn (aiosqlite.Connection): Asynchronous connection to SQLite db for checkpoints.
+        """
+        self.checkpoints_conn = checkpoints_conn
+        ollama.pull("gemma3:latest")  # 3.3GB
+        ollama.pull("embeddinggemma:latest")  # 621MB
+        self.llm = ChatOllama(
+            model="gemma3:latest",
+            validate_model_on_init=True,
+            temperature=0.5,
+            num_predict=4096,
+        )
+        self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
+        self.agent = create_agent(
+            model=self.llm,
+            system_prompt="You are a helpful assistant.",
+            checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
+        )
+
+    async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+        agent_stream = self.agent.astream(
+            {"messages": [HumanMessage(content=user_message)]},
+            {"configurable": {"thread_id": thread_id}},
+            stream_mode="messages",
+        )
+        async for message_chunk, _ in agent_stream:
+            if isinstance(message_chunk, AIMessageChunk):
+                token = message_chunk.content
+                if isinstance(token, str):
+                    yield token
+                else:
+                    app.log.error("Received message content of type", type(token))
+            else:
+                app.log.error("Received message chunk of type", type(message_chunk))
+
+
+class OllamaAgentProvider:
+    """Create LLM agents using Ollama."""
+
+    type: Final[LocalProviderType] = LocalProviderType.OLLAMA
+
+    @asynccontextmanager
+    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[OllamaAgent | None]:
+        """Attempt to create an Ollama agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """
+        if probe.probe_ollama() is not None:
+            async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
+                yield OllamaAgent(checkpoints_conn=checkpoints_conn)
+        else:
+            yield None

rag_demo/app.py
CHANGED

@@ -48,7 +48,7 @@ class RAGDemo(App):
         self.run_worker(self._hold_runtime())
 
     async def _hold_runtime(self) -> None:
-        async with self.logic.runtime(
+        async with self.logic.runtime(app=self) as runtime:
             self._runtime_future.set_result(runtime)
             # Pause the task until Textual cancels it when the application closes.
             await asyncio.Event().wait()

rag_demo/app_protocol.py
ADDED

@@ -0,0 +1,101 @@
+"""Interface for the logic to call back into the app code.
+
+This is necessary to make the logic code testable. We don't want to have to run all the app code to test the logic. And,
+we want to have a high degree of confidence when mocking out the app code in logic tests. The basic pattern is that each
+piece of functionality that the logic depends on will have a protocol and an implementation of that protocol using the
+Textual App. In the tests, we create a mock implementation of the same protocol. Correctness of the logic is defined by
+its ability to work correctly with any implementation of the protocol, not just the implementation backed by the app.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Protocol, TypeVar
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable
+
+    from textual.worker import Worker
+
+
+class LoggerProtocol(Protocol):
+    """Protocol that mimics textual.Logger."""
+
+    def __call__(self, *args: object, **kwargs: object) -> None:
+        """Log a message.
+
+        Args:
+            *args (object): Logged directly to the message separated by spaces.
+            **kwargs (object): Logged to the message as f"{key}={value!r}", separated by spaces.
+        """
+
+    def verbosity(self, *, verbose: bool) -> LoggerProtocol:
+        """Get a new logger with selective verbosity.
+
+        Note that unlike when using this method on a Textual logger directly, the type system will enforce that you use
+        `verbose` as a keyword argument (not a positional argument). I made this change to address ruff's FBT001 rule.
+        Put simply, this requirement makes the calling code easier to read.
+        https://docs.astral.sh/ruff/rules/boolean-type-hint-positional-argument/
+
+        Args:
+            verbose: True to use HIGH verbosity, otherwise NORMAL.
+
+        Returns:
+            New logger.
+        """
+
+    @property
+    def verbose(self) -> LoggerProtocol:
+        """A verbose logger."""
+
+    @property
+    def event(self) -> LoggerProtocol:
+        """Logs events."""
+
+    @property
+    def debug(self) -> LoggerProtocol:
+        """Logs debug messages."""
+
+    @property
+    def info(self) -> LoggerProtocol:
+        """Logs information."""
+
+    @property
+    def warning(self) -> LoggerProtocol:
+        """Logs warnings."""
+
+    @property
+    def error(self) -> LoggerProtocol:
+        """Logs errors."""
+
+    @property
+    def system(self) -> LoggerProtocol:
+        """Logs system information."""
+
+    @property
+    def logging(self) -> LoggerProtocol:
+        """Logs from stdlib logging module."""
+
+    @property
+    def worker(self) -> LoggerProtocol:
+        """Logs worker information."""
+
+
+ResultType = TypeVar("ResultType")
+
+
+class AppProtocol(Protocol):
+    """Protocol for the subset of what the main App can do that the runtime needs."""
+
+    def run_worker(self, work: Awaitable[ResultType], *, thread: bool = False) -> Worker[ResultType]:
+        """Run a coroutine in the background.
+
+        See https://textual.textualize.io/guide/workers/.
+
+        Args:
+            work (Awaitable[ResultType]): The coroutine to run.
+            thread (bool): Mark the worker as a thread worker.
+        """
+
+    @property
+    def log(self) -> LoggerProtocol:
+        """Returns the application logger."""

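The module docstring above spells out the testing pattern: the logic depends only on these protocols, and tests supply their own implementations. As a rough illustration of that pattern — a minimal sketch, not code from this package; the `FakeLogger` and `FakeApp` names and the usage note are hypothetical — a test double might look like this:

```python
# Hypothetical test double for AppProtocol / LoggerProtocol (not part of the package).
# It only needs to satisfy the members that the logic under test actually uses;
# run_worker could be added in the same style if a test exercises it.
from __future__ import annotations


class FakeLogger:
    """Collects log calls instead of writing to the Textual dev console."""

    def __init__(self) -> None:
        self.messages: list[tuple[object, ...]] = []

    def __call__(self, *args: object, **kwargs: object) -> None:
        # Mirror textual.Logger: positional args plus key=value pairs become one message.
        self.messages.append(args + tuple(f"{k}={v!r}" for k, v in kwargs.items()))

    # The severity accessors all return a logger, mirroring textual.Logger.
    @property
    def error(self) -> FakeLogger:
        return self

    @property
    def info(self) -> FakeLogger:
        return self


class FakeApp:
    """Minimal stand-in for AppProtocol used in logic tests."""

    def __init__(self) -> None:
        self._log = FakeLogger()

    @property
    def log(self) -> FakeLogger:
        return self._log
```

A logic test could then pass `FakeApp()` wherever an `AppProtocol` is expected (for example to `Logic(...).runtime(app=...)`) and assert on `fake_app.log.messages` afterwards.
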
rag_demo/constants.py
ADDED
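
The body of this diff is not rendered above; the file summary lists `rag_demo/constants.py` at +11 lines. Based on how it is used elsewhere in this release — `LocalProviderType` is imported in `__main__.py`, `agents/*.py`, and `logic.py`, compared against `LocalProviderType.HUGGING_FACE`, `LocalProviderType.LLAMA_CPP`, and `LocalProviderType.OLLAMA`, and accepted as a Typer CLI option — it most likely defines a small enum. The sketch below is an assumption, not the published file; the member values in particular are guesses:

```python
# Hypothetical reconstruction of rag_demo/constants.py (the actual diff body is not shown).
# Member names come from their usage in agents/*.py and logic.py; the string values are guesses.
from enum import Enum


class LocalProviderType(str, Enum):
    """Local LLM providers the app can try, in user-selectable preference order."""

    HUGGING_FACE = "hugging_face"
    LLAMA_CPP = "llama_cpp"
    OLLAMA = "ollama"
```
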
rag_demo/logic.py
CHANGED

@@ -1,57 +1,44 @@
 from __future__ import annotations
 
-import contextlib
-import platform
 import time
 from contextlib import asynccontextmanager
-from
-from typing import TYPE_CHECKING, Protocol, TypeVar, cast
+from typing import TYPE_CHECKING, cast
 
-import aiosqlite
-import cpuinfo
-import httpx
-import huggingface_hub
-import llama_cpp
-import ollama
-import psutil
-import pynvml
 from datasets import Dataset, load_dataset
-from huggingface_hub import hf_hub_download
-from huggingface_hub.constants import HF_HUB_CACHE
-from langchain.agents import create_agent
-from langchain.messages import AIMessageChunk, HumanMessage
-from langchain_community.chat_models import ChatLlamaCpp
-from langchain_community.embeddings import LlamaCppEmbeddings
 from langchain_core.exceptions import LangChainException
-from langchain_ollama import ChatOllama, OllamaEmbeddings
-from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
 
 from rag_demo import dirs
+from rag_demo.agents import (
+    Agent,
+    AgentProvider,
+    HuggingFaceAgentProvider,
+    LlamaCppAgentProvider,
+    OllamaAgentProvider,
+)
 from rag_demo.db import AtomicIDManager
 from rag_demo.modes.chat import Response, StoppedStreamError
 
 if TYPE_CHECKING:
-    from collections.abc import AsyncIterator,
-
-    from textual.worker import Worker
+    from collections.abc import AsyncIterator, Sequence
+    from pathlib import Path
 
+    from rag_demo.app_protocol import AppProtocol
+    from rag_demo.constants import LocalProviderType
     from rag_demo.modes import ChatScreen
 
-ResultType = TypeVar("ResultType")
 
+class UnknownPreferredProviderError(ValueError):
+    """Raised when the preferred provider cannot be checked first due to being unknown."""
 
-
-
+    def __init__(self, preferred_provider: LocalProviderType) -> None:  # noqa: D107
+        super().__init__(f"Unknown preferred provider: {preferred_provider}")
 
-    def run_worker(self, work: Awaitable[ResultType]) -> Worker[ResultType]:
-        """Run a coroutine in the background.
 
-
+class NoProviderError(RuntimeError):
+    """Raised when no provider could provide an agent."""
 
-
-
-        """
-        ...
+    def __init__(self) -> None:  # noqa: D107
+        super().__init__("No provider could provide an agent.")
 
 
 class Runtime:
@@ -60,50 +47,28 @@ class Runtime:
     def __init__(
         self,
         logic: Logic,
-
+        app: AppProtocol,
+        agent: Agent,
         thread_id_manager: AtomicIDManager,
-        app_like: AppLike,
     ) -> None:
+        """Initialize the runtime.
+
+        Args:
+            logic (Logic): The application logic.
+            app (AppProtocol): The application interface.
+            agent (Agent): The agent to use.
+            thread_id_manager (AtomicIDManager): The thread ID manager.
+        """
         self.runtime_start_time = time.time()
         self.logic = logic
-        self.
+        self.app = app
+        self.agent = agent
         self.thread_id_manager = thread_id_manager
-        self.app_like = app_like
 
         self.current_thread: int | None = None
         self.generating = False
 
-
-            ollama.pull("gemma3:latest")  # 3.3GB
-            ollama.pull("embeddinggemma:latest")  # 621MB
-            self.llm = ChatOllama(
-                model="gemma3:latest",
-                validate_model_on_init=True,
-                temperature=0.5,
-                num_predict=4096,
-            )
-            self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
-        else:
-            model_path = hf_hub_download(
-                repo_id="bartowski/google_gemma-3-4b-it-GGUF",
-                filename="google_gemma-3-4b-it-Q6_K_L.gguf",  # 3.35GB
-                revision="71506238f970075ca85125cd749c28b1b0eee84e",
-            )
-            embedding_model_path = hf_hub_download(
-                repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
-                filename="bge-small-en-v1.5-q8_0.gguf",  # 36.8MB
-                revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
-            )
-            self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
-            self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False)  # pyright: ignore[reportCallIssue]
-
-        self.agent = create_agent(
-            model=self.llm,
-            system_prompt="You are a helpful assistant.",
-            checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
-        )
-
-    def get_rag_datasets(self) -> None:
+    def _get_rag_datasets(self) -> None:
         self.qa_test: Dataset = cast(
             "Dataset",
             load_dataset("rag-datasets/rag-mini-wikipedia", "question-answer", split="test"),
@@ -123,21 +88,9 @@
         """
         self.generating = True
         async with response_widget.stream_writer() as writer:
-            agent_stream = self.agent.astream(
-                {"messages": [HumanMessage(content=request_text)]},
-                {"configurable": {"thread_id": thread}},
-                stream_mode="messages",
-            )
             try:
-                async for message_chunk,
-
-                    token = cast("AIMessageChunk", message_chunk).content
-                    if isinstance(token, str):
-                        await writer.write(token)
-                    else:
-                        response_widget.log.error(f"Received message content of type {type(token)}")
-                else:
-                    response_widget.log.error(f"Received message chunk of type {type(message_chunk)}")
+                async for message_chunk in self.agent.astream(request_text, thread, self.app):
+                    await writer.write(message_chunk)
             except StoppedStreamError as e:
                 response_widget.set_shown_object(e)
             except LangChainException as e:
@@ -145,10 +98,24 @@
         self.generating = False
 
     def new_conversation(self, chat_screen: ChatScreen) -> None:
+        """Clear the screen and start a new conversation with the agent.
+
+        Args:
+            chat_screen (ChatScreen): The chat screen to clear.
+        """
        self.current_thread = None
         chat_screen.clear_chats()
 
     async def submit_request(self, chat_screen: ChatScreen, request_text: str) -> bool:
+        """Submit a new user request in the current conversation.
+
+        Args:
+            chat_screen (ChatScreen): The chat screen in which the request is submitted.
+            request_text (str): The text of the request.
+
+        Returns:
+            bool: True if the request was accepted for immediate processing, False otherwise.
+        """
         if self.generating:
             return False
         self.generating = True
@@ -168,120 +135,67 @@ class Logic:
     def __init__(
         self,
         username: str | None = None,
+        preferred_provider_type: LocalProviderType | None = None,
         application_start_time: float | None = None,
         checkpoints_sqlite_db: str | Path = dirs.DATA_DIR / "checkpoints.sqlite3",
         app_sqlite_db: str | Path = dirs.DATA_DIR / "app.sqlite3",
+        agent_providers: Sequence[AgentProvider] = (
+            LlamaCppAgentProvider(),
+            OllamaAgentProvider(),
+            HuggingFaceAgentProvider(),
+        ),
     ) -> None:
         """Initialize the application logic.
 
         Args:
             username (str | None, optional): The username provided as a command line argument. Defaults to None.
+            preferred_provider_type (LocalProviderType | None, optional): Provider type to prefer. Defaults to None.
             application_start_time (float | None, optional): The time when the application started. Defaults to None.
             checkpoints_sqlite_db (str | Path, optional): The connection string for the SQLite database used for
                 Langchain checkpointing. Defaults to (dirs.DATA_DIR / "checkpoints.sqlite3").
             app_sqlite_db (str | Path, optional): The connection string for the SQLite database used for application
                 state such a thread metadata. Defaults to (dirs.DATA_DIR / "app.sqlite3").
+            agent_providers (Sequence[AgentProvider], optional): Sequence of agent providers in default preference
+                order. If preferred_provider_type is not None, this sequence will be reordered to bring providers of
+                that type to the front, using the original order to break ties. Defaults to (
+                    LlamaCppAgentProvider(),
+                    OllamaAgentProvider(),
+                    HuggingFaceAgentProvider(),
+                ).
         """
         self.logic_start_time = time.time()
         self.username = username
+        self.preferred_provider_type = preferred_provider_type
         self.application_start_time = application_start_time
         self.checkpoints_sqlite_db = checkpoints_sqlite_db
         self.app_sqlite_db = app_sqlite_db
+        self.agent_providers: Sequence[AgentProvider] = agent_providers
 
     @asynccontextmanager
-    async def runtime(self,
+    async def runtime(self, app: AppProtocol) -> AsyncIterator[Runtime]:
         """Returns a runtime context for the application."""
-
-
-        id_manager = AtomicIDManager(self.app_sqlite_db)
-        await id_manager.initialize()
-        yield Runtime(self, checkpoints_conn, id_manager, app_like)
-
-    def probe_os(self) -> str:
-        """Returns the OS name (eg 'Linux' or 'Windows'), the system name (eg 'Java'), or an empty string if unknown."""
-        return platform.system()
-
-    def probe_architecture(self) -> str:
-        """Returns the machine architecture, such as 'i386'."""
-        return platform.machine()
-
-    def probe_cpu(self) -> str:
-        """Returns the name of the CPU, e.g. "Intel(R) Core(TM) i7-10610U CPU @ 1.80GHz"."""
-        return cpuinfo.get_cpu_info()["brand_raw"]
-
-    def probe_ram(self) -> int:
-        """Returns the total amount of RAM in bytes."""
-        return psutil.virtual_memory().total
-
-    def probe_disk_space(self) -> int:
-        """Returns the amount of free space in the root directory (in bytes)."""
-        return psutil.disk_usage("/").free
-
-    def probe_llamacpp_gpu_support(self) -> bool:
-        """Returns True if LlamaCpp supports GPU offloading, False otherwise."""
-        return llama_cpp.llama_supports_gpu_offload()
-
-    def probe_huggingface_free_cache_space(self) -> int | None:
-        """Returns the amount of free space in the Hugging Face cache (in bytes), or None if it can't be determined."""
-        with contextlib.suppress(FileNotFoundError):
-            return psutil.disk_usage(HF_HUB_CACHE).free
-        for parent_dir in Path(HF_HUB_CACHE).parents:
-            with contextlib.suppress(FileNotFoundError):
-                return psutil.disk_usage(str(parent_dir)).free
-        return None
+        thread_id_manager = AtomicIDManager(self.app_sqlite_db)
+        await thread_id_manager.initialize()
 
[... 22 more removed lines (old 233-254) are not rendered in the source diff view ...]
-            pynvml.nvmlInit()
-        except pynvml.NVMLError:
-            return -1, []
-        cuda_version = -1
-        nv_gpus = []
-        try:
-            cuda_version = pynvml.nvmlSystemGetCudaDriverVersion()
-            for i in range(pynvml.nvmlDeviceGetCount()):
-                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                nv_gpus.append(pynvml.nvmlDeviceGetName(handle))
-        except pynvml.NVMLError:
-            pass
-        finally:
-            with contextlib.suppress(pynvml.NVMLError):
-                pynvml.nvmlShutdown()
-        return cuda_version, nv_gpus
-
-    def probe_ollama(self) -> list[ollama.ListResponse.Model] | None:
-        """Returns a list of models installed in Ollama, or None if connecting to Ollama fails."""
-        with contextlib.suppress(ConnectionError):
-            return list(ollama.list().models)
-        return None
-
-    def probe_ollama_version(self) -> str | None:
-        """Returns the Ollama version string (e.g. "0.13.5"), or None if connecting to Ollama fails."""
-        # Yes, this uses private attributes, but that lets me use the Ollama Python lib's env var logic. If you use env
-        # vars to direct the app to a different Ollama server, this will query the same Ollama endpoint as the
-        # ollama.list() call above. Therefore I silence SLF001 here.
-        with contextlib.suppress(httpx.HTTPError, KeyError, ValueError):
-            response: httpx.Response = ollama._client._client.request("GET", "/api/version")  # noqa: SLF001
-            response.raise_for_status()
-            return response.json()["version"]
-        return None
+        agent_providers: Sequence[AgentProvider] = self.agent_providers
+        if self.preferred_provider_type is not None:
+            preferred_providers: Sequence[AgentProvider] = tuple(
+                ap for ap in agent_providers if ap.type == self.preferred_provider_type
+            )
+            if len(preferred_providers) == 0:
+                raise UnknownPreferredProviderError(self.preferred_provider_type)
+            agent_providers = (
+                *preferred_providers,
+                *(ap for ap in agent_providers if ap.type != self.preferred_provider_type),
+            )
+        for agent_provider in agent_providers:
+            async with agent_provider.get_agent(checkpoints_sqlite_db=self.checkpoints_sqlite_db) as agent:
+                if agent is not None:
+                    yield Runtime(
+                        logic=self,
+                        app=app,
+                        agent=agent,
+                        thread_id_manager=thread_id_manager,
+                    )
+                    return
+        raise NoProviderError

rag_demo/modes/_logic_provider.py
CHANGED

@@ -10,11 +10,12 @@ if TYPE_CHECKING:
 
 
 class LogicProvider(Protocol):
-    """
+    """Protocol for classes that contain application logic."""
 
     logic: Logic
 
-    async def runtime(self) -> Runtime:
+    async def runtime(self) -> Runtime:
+        """Returns the application runtime of the parent app."""
 
 
 class LogicProviderScreen(Screen):

rag_demo/modes/chat.py
CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations
 import time
 from contextlib import asynccontextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 import pyperclip
 from textual.containers import HorizontalGroup, VerticalGroup, VerticalScroll
@@ -116,7 +116,7 @@ class Response(LogicProviderWidget):
         self.set_reactive(Response.content, content)
         self._stream: ResponseWriter | None = None
         self.__object_to_show_sentinel = object()
-        self._object_to_show:
+        self._object_to_show: object = self.__object_to_show_sentinel
 
     def compose(self) -> ComposeResult:
         """Compose the initial content of the widget."""
@@ -137,7 +137,8 @@ class Response(LogicProviderWidget):
         self.query_one("#object-view", Pretty).display = False
         self.query_one("#stop", Button).display = False
 
-    def set_shown_object(self, obj:
+    def set_shown_object(self, obj: object) -> None:
+        """Show an object using a Pretty Widget instead of showing markdown or raw response content."""
         self._object_to_show = obj
         self.query_one("#markdown-view", Markdown).display = False
         self.query_one("#raw-view", Label).display = False
@@ -146,6 +147,7 @@ class Response(LogicProviderWidget):
         self.query_one("#object-view", Pretty).display = True
 
     def clear_shown_object(self) -> None:
+        """Stop showing an object in the widget."""
         self._object_to_show = self.__object_to_show_sentinel
         self.query_one("#object-view", Pretty).display = False
         if self.show_raw:
@@ -192,14 +194,14 @@ class Response(LogicProviderWidget):
         try:
             pyperclip.copy(self.content)
         except pyperclip.PyperclipException as e:
-            self.app.log.error(
+            self.app.log.error("Error copying to clipboard with Pyperclip:", e)
         checkpoint2 = time.time()
         self.notify(f"Copied {len(self.content.splitlines())} lines of text to clipboard")
         end = time.time()
-        self.app.log.info(
-        self.app.log.info(
-        self.app.log.info(
-        self.app.log.info(
+        self.app.log.info("Textual copy took", f"{checkpoint - start:.6f}", "seconds")
+        self.app.log.info("Pyperclip copy took", f"{checkpoint2 - checkpoint:.6f}", "seconds")
+        self.app.log.info("Notify took", f"{end - checkpoint2:.6f}", "seconds")
+        self.app.log.info("Total of", f"{end - start:.6f}", "seconds")
 
     def watch_show_raw(self) -> None:
         """Handle reactive updates to the show_raw attribute by changing the visibility of the child widgets.

rag_demo/probe.py
ADDED

@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import contextlib
+import platform
+from pathlib import Path
+
+import cpuinfo
+import httpx
+import huggingface_hub
+import ollama
+import psutil
+import pynvml
+from huggingface_hub.constants import HF_HUB_CACHE
+
+try:
+    # llama-cpp-python is an optional dependency. If it is not installed in the dev environment then we need to ignore
+    # unresolved-import. If it is installed, then we need to ignore unused-ignore-comment (because there is no need to
+    # ignore unresolved-import in this case).
+    import llama_cpp  # ty:ignore[unresolved-import, unused-ignore-comment]
+
+    LLAMA_AVAILABLE = True
+except ImportError:
+    LLAMA_AVAILABLE = False
+
+
+def probe_os() -> str:
+    """Returns the OS name (eg 'Linux' or 'Windows'), the system name (eg 'Java'), or an empty string if unknown."""
+    return platform.system()
+
+
+def probe_architecture() -> str:
+    """Returns the machine architecture, such as 'i386'."""
+    return platform.machine()
+
+
+def probe_cpu() -> str:
+    """Returns the name of the CPU, e.g. "Intel(R) Core(TM) i7-10610U CPU @ 1.80GHz"."""
+    return cpuinfo.get_cpu_info()["brand_raw"]
+
+
+def probe_ram() -> int:
+    """Returns the total amount of RAM in bytes."""
+    return psutil.virtual_memory().total
+
+
+def probe_disk_space() -> int:
+    """Returns the amount of free space in the root directory (in bytes)."""
+    return psutil.disk_usage("/").free
+
+
+def probe_llama_available() -> bool:
+    """Returns True if llama-cpp-python is installed, False otherwise."""
+    return LLAMA_AVAILABLE
+
+
+def probe_llamacpp_gpu_support() -> bool:
+    """Returns True if the installed version of llama-cpp-python supports GPU offloading, False otherwise."""
+    return LLAMA_AVAILABLE and llama_cpp.llama_supports_gpu_offload()
+
+
+def probe_huggingface_free_cache_space() -> int | None:
+    """Returns the amount of free space in the Hugging Face cache (in bytes), or None if it can't be determined."""
+    with contextlib.suppress(FileNotFoundError):
+        return psutil.disk_usage(HF_HUB_CACHE).free
+    for parent_dir in Path(HF_HUB_CACHE).parents:
+        with contextlib.suppress(FileNotFoundError):
+            return psutil.disk_usage(str(parent_dir)).free
+    return None
+
+
+def probe_huggingface_cached_models() -> list[huggingface_hub.CachedRepoInfo] | None:
+    """Returns a list of models in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
+    # The docstring for huggingface_hub.scan_cache_dir says it raises CacheNotFound "if the cache directory does not
+    # exist," and ValueError "if the cache directory is a file, instead of a directory."
+    with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
+        return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "model"]
+    return None  # Isn't it nice to be explicit?
+
+
+def probe_huggingface_cached_datasets() -> list[huggingface_hub.CachedRepoInfo] | None:
+    """Returns a list of datasets in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
+    with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
+        return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "dataset"]
+    return None
+
+
+def probe_nvidia() -> tuple[int, list[str]]:
+    """Detect available NVIDIA GPUs and CUDA driver version.
+
+    Returns:
+        tuple[int, list[str]]: A tuple (cuda_version, nv_gpus) where cuda_version is the installed CUDA driver
+            version and nv_gpus is a list of GPU models corresponding to installed NVIDIA GPUs
+    """
+    try:
+        pynvml.nvmlInit()
+    except pynvml.NVMLError:
+        return -1, []
+    cuda_version = -1
+    nv_gpus = []
+    try:
+        cuda_version = pynvml.nvmlSystemGetCudaDriverVersion()
+        for i in range(pynvml.nvmlDeviceGetCount()):
+            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+            nv_gpus.append(pynvml.nvmlDeviceGetName(handle))
+    except pynvml.NVMLError:
+        pass
+    finally:
+        with contextlib.suppress(pynvml.NVMLError):
+            pynvml.nvmlShutdown()
+    return cuda_version, nv_gpus
+
+
+def probe_ollama() -> list[ollama.ListResponse.Model] | None:
+    """Returns a list of models installed in Ollama, or None if connecting to Ollama fails."""
+    with contextlib.suppress(ConnectionError):
+        return list(ollama.list().models)
+    return None
+
+
+def probe_ollama_version() -> str | None:
+    """Returns the Ollama version string (e.g. "0.13.5"), or None if connecting to Ollama fails."""
+    # Yes, this uses private attributes, but that lets me use the Ollama Python lib's env var logic. If you use env
+    # vars to direct the app to a different Ollama server, this will query the same Ollama endpoint as the
+    # ollama.list() call above. Therefore I silence SLF001 here.
+    with contextlib.suppress(httpx.HTTPError, KeyError, ValueError):
+        response: httpx.Response = ollama._client._client.request("GET", "/api/version")  # noqa: SLF001
+        response.raise_for_status()
+        return response.json()["version"]
+    return None

jehoctor_rag_demo-0.2.0.dist-info/RECORD
REMOVED

@@ -1,23 +0,0 @@
-rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rag_demo/__main__.py,sha256=Kak0eQWBRHVGDoWgoHs9j-Tvf_9DMzdurMxD7EM4Jr0,1054
-rag_demo/app.py,sha256=xejrtFApeTeyOQvWDq1H0XPyZEr8cQPn7q9KRwnV660,1812
-rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
-rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
-rag_demo/logic.py,sha256=7PTWPs9xZJ7bbEtNDMQTX6SX4JKG8HMiq2H_YUfM-CI,12602
-rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
-rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
-rag_demo/modes/_logic_provider.py,sha256=__eO4XVbyRHkjV_D8OHsPJX5f2R8JoJPcNXhi-w_xFY,1277
-rag_demo/modes/chat.py,sha256=VigWSkw6R2ea95-wZ8tgtKIccev9A-ByzJj7nzglsog,13444
-rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
-rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
-rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
-rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
-rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
-jehoctor_rag_demo-0.2.0.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
-jehoctor_rag_demo-0.2.0.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
-jehoctor_rag_demo-0.2.0.dist-info/METADATA,sha256=wp1mdAqjB0be_1Uly4hwAoz0bjRUDI6gb6gK5SdrHRU,3531
-jehoctor_rag_demo-0.2.0.dist-info/RECORD,,

{jehoctor_rag_demo-0.2.0.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/WHEEL
File without changes

{jehoctor_rag_demo-0.2.0.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/entry_points.txt
File without changes