jehoctor-rag-demo 0.1.1.dev1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jehoctor_rag_demo-0.2.1.dist-info/METADATA +125 -0
- jehoctor_rag_demo-0.2.1.dist-info/RECORD +31 -0
- jehoctor_rag_demo-0.2.1.dist-info/entry_points.txt +3 -0
- rag_demo/__init__.py +0 -2
- rag_demo/__main__.py +42 -0
- rag_demo/agents/__init__.py +4 -0
- rag_demo/agents/base.py +40 -0
- rag_demo/agents/hugging_face.py +116 -0
- rag_demo/agents/llama_cpp.py +113 -0
- rag_demo/agents/ollama.py +91 -0
- rag_demo/app.py +58 -0
- rag_demo/app.tcss +0 -0
- rag_demo/app_protocol.py +101 -0
- rag_demo/constants.py +11 -0
- rag_demo/db.py +87 -0
- rag_demo/dirs.py +14 -0
- rag_demo/logic.py +201 -0
- rag_demo/markdown.py +17 -0
- rag_demo/modes/__init__.py +3 -0
- rag_demo/modes/_logic_provider.py +44 -0
- rag_demo/modes/chat.py +317 -0
- rag_demo/modes/chat.tcss +75 -0
- rag_demo/modes/config.py +77 -0
- rag_demo/modes/config.tcss +0 -0
- rag_demo/modes/help.py +26 -0
- rag_demo/modes/help.tcss +0 -0
- rag_demo/probe.py +129 -0
- rag_demo/widgets/__init__.py +1 -0
- rag_demo/widgets/escapable_input.py +110 -0
- jehoctor_rag_demo-0.1.1.dev1.dist-info/METADATA +0 -11
- jehoctor_rag_demo-0.1.1.dev1.dist-info/RECORD +0 -6
- jehoctor_rag_demo-0.1.1.dev1.dist-info/entry_points.txt +0 -3
- {jehoctor_rag_demo-0.1.1.dev1.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/WHEEL +0 -0
jehoctor_rag_demo-0.2.1.dist-info/METADATA ADDED
@@ -0,0 +1,125 @@
+Metadata-Version: 2.3
+Name: jehoctor-rag-demo
+Version: 0.2.1
+Summary: Chat with Wikipedia
+Author: James Hoctor
+Author-email: James Hoctor <JEHoctor@protonmail.com>
+Requires-Dist: aiosqlite==0.21.0
+Requires-Dist: bitsandbytes>=0.49.1
+Requires-Dist: chromadb>=1.3.4
+Requires-Dist: datasets>=4.4.1
+Requires-Dist: httpx>=0.28.1
+Requires-Dist: huggingface-hub>=0.36.0
+Requires-Dist: langchain>=1.0.5
+Requires-Dist: langchain-anthropic>=1.0.2
+Requires-Dist: langchain-community>=0.4.1
+Requires-Dist: langchain-huggingface>=1.1.0
+Requires-Dist: langchain-ollama>=1.0.0
+Requires-Dist: langchain-openai>=1.0.2
+Requires-Dist: langgraph-checkpoint-sqlite>=3.0.1
+Requires-Dist: nvidia-ml-py>=13.590.44
+Requires-Dist: ollama>=0.6.0
+Requires-Dist: platformdirs>=4.5.0
+Requires-Dist: psutil>=7.1.3
+Requires-Dist: py-cpuinfo>=9.0.0
+Requires-Dist: pydantic>=2.12.4
+Requires-Dist: pyperclip>=1.11.0
+Requires-Dist: sentence-transformers>=5.2.2
+Requires-Dist: textual>=6.5.0
+Requires-Dist: transformers[torch]>=4.57.6
+Requires-Dist: typer>=0.20.0
+Requires-Dist: llama-cpp-python>=0.3.16 ; extra == 'llamacpp'
+Requires-Python: ~=3.12.0
+Provides-Extra: llamacpp
+Description-Content-Type: text/markdown
+
+# RAG-demo
+
+Chat with (a small portion of) Wikipedia
+
+⚠️ RAG functionality is still under development. ⚠️
+
+
+
+## Requirements
+
+1. The [uv](https://docs.astral.sh/uv/) Python package manager
+   - Installing and updating `uv` is easy by following [the docs](https://docs.astral.sh/uv/getting-started/installation/).
+   - As of 2026-01-25, I'm developing with `uv` version 0.9.26, using the new experimental `--torch-backend` option.
+2. A terminal emulator or web browser
+   - Any common web browser will work.
+   - Some terminal emulators will work better than others.
+     See [Notes on terminal emulators](#notes-on-terminal-emulators) below.
+
+### Notes on terminal emulators
+
+Certain terminal emulators will not work with some features of this program.
+In particular, on macOS consider using [iTerm2](https://iterm2.com/) instead of the default Terminal.app ([explanation](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos)).
+On Linux you might want to try [kitty](https://sw.kovidgoyal.net/kitty/), [wezterm](https://wezterm.org/), [alacritty](https://alacritty.org/), or [ghostty](https://ghostty.org/), instead of the terminal that came with your desktop environment ([reason](https://darren.codes/posts/textual-copy-paste/)).
+Windows Terminal should be fine as far as I know.
+
+### Optional dependencies
+
+1. [Hugging Face login](https://huggingface.co/docs/huggingface_hub/quick-start#login)
+2. API key for your favorite LLM provider (support coming soon)
+3. Ollama installed on your system if you have a GPU
+4. Run RAG-demo on a more capable (bigger GPU) machine over SSH if you can. It is a terminal app after all.
+5. A C compiler if you want to build Llama.cpp from source.
+
+## Run the latest version
+
+Run in a terminal:
+```bash
+uvx --torch-backend=auto --from=jehoctor-rag-demo@latest chat
+```
+
+Or run in a web browser:
+```bash
+uvx --torch-backend=auto --from=jehoctor-rag-demo@latest textual serve chat
+```
+
+## CUDA acceleration via Llama.cpp
+
+If you have an NVIDIA GPU with CUDA and build tools installed, you might be able to get CUDA acceleration without installing Ollama.
+
+```bash
+CMAKE_ARGS="-DGGML_CUDA=on" uv run --extra=llamacpp chat
+```
+
+## Metal acceleration via Llama.cpp (on Apple Silicon)
+
+On an Apple Silicon machine, make sure `uv` runs an ARM interpreter, as this should cause it to install Llama.cpp with Metal support.
+Also, run with the extra group `llamacpp`.
+Try this:
+
+```bash
+uvx --python-platform=aarch64-apple-darwin --torch-backend=auto --from=jehoctor-rag-demo[llamacpp]@latest chat
+```
+
+## Ollama on Linux
+
+Remember that you have to keep Ollama up-to-date manually on Linux.
+A recent version of Ollama (v0.11.10 or later) is required to run the [embedding model we use](https://ollama.com/library/embeddinggemma).
+See this FAQ: https://docs.ollama.com/faq#how-can-i-upgrade-ollama.
+
+## Project feature roadmap
+
+- ❌ RAG functionality
+- ❌ torch inference via the Langchain local Hugging Face inference integration
+- ❌ uv automatic torch backend selection (see [the docs](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection))
+- ❌ OpenAI integration
+- ❌ Anthropic integration
+
+## Run from the repository
+
+First, clone this repository. Then, run one of the options below.
+
+Run in a terminal:
+```bash
+uv run chat
+```
+
+Or run in a web browser:
+```bash
+uv run textual serve chat
+```
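The README above requires Ollama v0.11.10 or later for the embedding model but only links the upgrade FAQ. A quick way to check the version of a locally running Ollama server is to query it with `httpx`, which is already a dependency of this package; the default port 11434 and the `/api/version` endpoint are assumptions about a standard local install.

```python
# Check the version of a locally running Ollama server (assumes the default
# port 11434 and Ollama's /api/version HTTP endpoint).
import httpx

resp = httpx.get("http://localhost:11434/api/version", timeout=5.0)
resp.raise_for_status()
print(resp.json()["version"])  # should be "0.11.10" or newer for embeddinggemma
```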
jehoctor_rag_demo-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,31 @@
+rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/__main__.py,sha256=S0UlQj3EldcXRk3rhH3ONdSOPmeyITYXIZ0o2JSxWbg,1618
+rag_demo/agents/__init__.py,sha256=dsuO3AGcn2yGDq4gkAsZ32pjeTOqAudOL14G_AsEUyc,221
+rag_demo/agents/base.py,sha256=gib6bC8nVKN1s1KPZd1dJVGRXnu7gFQwf3I3_7TSjQo,1312
+rag_demo/agents/hugging_face.py,sha256=VrbGOlMO2z357LmU3sO5aM_yI5P-xbsfKmTAzH9_lFo,4225
+rag_demo/agents/llama_cpp.py,sha256=C0hInc24sXmt5407_k4mP2Y6svqgfUPemhzAL1N6jY0,4272
+rag_demo/agents/ollama.py,sha256=Fmtu8MSPPz91eT7HKvwvbQnA_xGPaD5HBrHEPPtomZA,3317
+rag_demo/app.py,sha256=AVCJjlQ60y5J0v50TcJ3zZoa0ubhd_yKVDfu1ERsMVU,1807
+rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/app_protocol.py,sha256=P__Q3KT41uonYazpYmLWmOh1MeoBaCBx4xEydkwK5tk,3292
+rag_demo/constants.py,sha256=5EpyAD6p5Qb0vB5ASMtFSLVVsq3RG5cBn_AYihhDCPc,235
+rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
+rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
+rag_demo/logic.py,sha256=SkF_Hqu1WSLHzwvSd_mJiCMSxZYqDnteYFRpc6oCREY,8236
+rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
+rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
+rag_demo/modes/_logic_provider.py,sha256=U3J8Fgq8MbNYd92FqENW-5YP_jXqKG3xmMmYoSUzhHo,1343
+rag_demo/modes/chat.py,sha256=2pmKhQ2uYZdjezNnNBINViBMcuTVE5YCom_HEbdJeXg,13607
+rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
+rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
+rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
+rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/probe.py,sha256=aDD-smNauEXXoBKVgx5xsMawM5tL0QAEBFl07ZGrddc,5101
+rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
+rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
+jehoctor_rag_demo-0.2.1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
+jehoctor_rag_demo-0.2.1.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
+jehoctor_rag_demo-0.2.1.dist-info/METADATA,sha256=nCXuy3TYPPf67DPFryndW2P6CRdSCcJkXxFXZ-UN4vs,4650
+jehoctor_rag_demo-0.2.1.dist-info/RECORD,,
rag_demo/__init__.py CHANGED
rag_demo/__main__.py ADDED
@@ -0,0 +1,42 @@
+import time
+
+# Measure the application start time.
+APPLICATION_START_TIME = time.time()
+
+# Disable "module import not at top of file" (aka E402) when importing Typer and other early imports. This is necessary
+# so that the initialization of these modules is included in the application startup time.
+from typing import Annotated  # noqa: E402
+
+import typer  # noqa: E402
+
+from rag_demo.constants import LocalProviderType  # noqa: E402
+
+
+def _main(
+    name: Annotated[str | None, typer.Option(help="The name you want the AI to use with you.")] = None,
+    provider: Annotated[LocalProviderType | None, typer.Option(help="The local provider to prefer.")] = None,
+) -> None:
+    """Talk to Wikipedia."""
+    # Import here so that imports run within the typer.run context for prettier stack traces if errors occur.
+    # We ignore PLC0415 because we do not want these imports to be at the top of the module as is usually preferred.
+    import transformers  # noqa: PLC0415
+
+    from rag_demo.app import RAGDemo  # noqa: PLC0415
+    from rag_demo.logic import Logic  # noqa: PLC0415
+
+    # The transformers library likes to print text that interferes with the TUI. Disable it.
+    transformers.logging.set_verbosity(verbosity=transformers.logging.CRITICAL)
+    transformers.logging.disable_progress_bar()
+
+    logic = Logic(username=name, preferred_provider_type=provider, application_start_time=APPLICATION_START_TIME)
+    app = RAGDemo(logic)
+    app.run()
+
+
+def main() -> None:
+    """Entrypoint for the rag demo, specifically the `chat` command."""
+    typer.run(_main)
+
+
+if __name__ == "__main__":
+    main()
rag_demo/agents/base.py ADDED
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Final, Protocol
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from contextlib import AbstractAsyncContextManager
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+    from rag_demo.constants import LocalProviderType
+
+
+class Agent(Protocol):
+    """An LLM agent that supports streaming responses asynchronously."""
+
+    def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+
+
+class AgentProvider(Protocol):
+    """A strategy for creating LLM agents."""
+
+    type: Final[LocalProviderType]
+
+    def get_agent(self, checkpoints_sqlite_db: str | Path) -> AbstractAsyncContextManager[Agent | None]:
+        """Attempt to create an agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """
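To make the contract above concrete, here is a minimal sketch of a conforming implementation: `get_agent` is an async context manager and `astream` is an async generator of tokens. `EchoAgent` and `EchoAgentProvider` are hypothetical names invented for illustration and are not part of the package; `LocalProviderType.OLLAMA` is reused only as a placeholder member.

```python
# Hypothetical in-memory provider/agent pair conforming to the Agent and
# AgentProvider protocols above (illustration only, not package code).
from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Final

from rag_demo.app_protocol import AppProtocol
from rag_demo.constants import LocalProviderType


class EchoAgent:
    """Toy agent that echoes the prompt back instead of calling an LLM."""

    async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
        for word in user_message.split():
            yield word + " "


class EchoAgentProvider:
    """Toy provider; always succeeds, so it never yields None."""

    type: Final[LocalProviderType] = LocalProviderType.OLLAMA  # placeholder value

    @asynccontextmanager
    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[EchoAgent | None]:
        # No setup or teardown is needed for the echo agent.
        yield EchoAgent()
```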
rag_demo/agents/hugging_face.py ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import asyncio
+import sqlite3
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Final
+
+from huggingface_hub import hf_hub_download
+from langchain.agents import create_agent
+from langchain.messages import AIMessageChunk, HumanMessage
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFacePipeline
+from langgraph.checkpoint.sqlite import SqliteSaver
+
+from rag_demo.constants import LocalProviderType
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+
+
+class HuggingFaceAgent:
+    """An LLM agent powered by Hugging Face local pipelines."""
+
+    def __init__(
+        self,
+        checkpoints_sqlite_db: str | Path,
+        model_id: str,
+        embedding_model_id: str,
+    ) -> None:
+        """Initialize the HuggingFaceAgent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+            model_id (str): Hugging Face model ID for the LLM.
+            embedding_model_id (str): Hugging Face model ID for the embedding model.
+        """
+        self.checkpoints_sqlite_db = checkpoints_sqlite_db
+        self.model_id = model_id
+        self.embedding_model_id = embedding_model_id
+
+        self.llm = ChatHuggingFace(
+            llm=HuggingFacePipeline.from_model_id(
+                model_id=model_id,
+                task="text-generation",
+                device_map="auto",
+                pipeline_kwargs={"max_new_tokens": 4096},
+            ),
+        )
+        self.embed = HuggingFaceEmbeddings(model_name=embedding_model_id)
+        self.agent = create_agent(
+            model=self.llm,
+            system_prompt="You are a helpful assistant.",
+            checkpointer=SqliteSaver(sqlite3.Connection(self.checkpoints_sqlite_db, check_same_thread=False)),
+        )
+
+    async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+        agent_stream = self.agent.stream(
+            {"messages": [HumanMessage(content=user_message)]},
+            {"configurable": {"thread_id": thread_id}},
+            stream_mode="messages",
+        )
+        for message_chunk, _ in agent_stream:
+            if isinstance(message_chunk, AIMessageChunk):
+                token = message_chunk.content
+                if isinstance(token, str):
+                    yield token
+                else:
+                    app.log.error("Received message content of type", type(token))
+            else:
+                app.log.error("Received message chunk of type", type(message_chunk))
+
+
+def _hf_downloads() -> None:
+    hf_hub_download(
+        repo_id="Qwen/Qwen3-0.6B",  # 1.5GB
+        filename="model.safetensors",
+        revision="c1899de289a04d12100db370d81485cdf75e47ca",
+    )
+    hf_hub_download(
+        repo_id="unsloth/embeddinggemma-300m",  # 1.21GB
+        filename="model.safetensors",
+        revision="bfa3c846ac738e62aa61806ef9112d34acb1dc5a",
+    )
+
+
+class HuggingFaceAgentProvider:
+    """Create LLM agents using Hugging Face local pipelines."""
+
+    type: Final[LocalProviderType] = LocalProviderType.HUGGING_FACE
+
+    @asynccontextmanager
+    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[HuggingFaceAgent]:
+        """Create a Hugging Face local pipeline agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """
+        loop = asyncio.get_running_loop()
+        await loop.run_in_executor(None, _hf_downloads)
+        yield HuggingFaceAgent(
+            checkpoints_sqlite_db,
+            model_id="Qwen/Qwen3-0.6B",
+            embedding_model_id="unsloth/embeddinggemma-300m",
+        )
rag_demo/agents/llama_cpp.py ADDED
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Final
+
+import aiosqlite
+from huggingface_hub import hf_hub_download
+from langchain.agents import create_agent
+from langchain.messages import AIMessageChunk, HumanMessage
+from langchain_community.chat_models import ChatLlamaCpp
+from langchain_community.embeddings import LlamaCppEmbeddings
+from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+
+from rag_demo import probe
+from rag_demo.constants import LocalProviderType
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+
+
+class LlamaCppAgent:
+    """An LLM agent powered by Llama.cpp."""
+
+    def __init__(
+        self,
+        checkpoints_conn: aiosqlite.Connection,
+        model_path: str,
+        embedding_model_path: str,
+    ) -> None:
+        """Initialize the LlamaCppAgent.
+
+        Args:
+            checkpoints_conn (aiosqlite.Connection): Connection to SQLite checkpoint database.
+            model_path (str): Path to Llama.cpp model.
+            embedding_model_path (str): Path to Llama.cpp embedding model.
+        """
+        self.checkpoints_conn = checkpoints_conn
+        self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
+        self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False)
+        self.agent = create_agent(
+            model=self.llm,
+            system_prompt="You are a helpful assistant.",
+            checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
+        )
+
+    async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+        agent_stream = self.agent.astream(
+            {"messages": [HumanMessage(content=user_message)]},
+            {"configurable": {"thread_id": thread_id}},
+            stream_mode="messages",
+        )
+        async for message_chunk, _ in agent_stream:
+            if isinstance(message_chunk, AIMessageChunk):
+                token = message_chunk.content
+                if isinstance(token, str):
+                    yield token
+                else:
+                    app.log.error("Received message content of type", type(token))
+            else:
+                app.log.error("Received message chunk of type", type(message_chunk))
+
+
+def _hf_downloads() -> tuple[str, str]:
+    model_path = hf_hub_download(
+        repo_id="bartowski/google_gemma-3-4b-it-GGUF",
+        filename="google_gemma-3-4b-it-Q6_K_L.gguf",  # 3.35GB
+        revision="71506238f970075ca85125cd749c28b1b0eee84e",
+    )
+    embedding_model_path = hf_hub_download(
+        repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
+        filename="bge-small-en-v1.5-q8_0.gguf",  # 36.8MB
+        revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
+    )
+    return model_path, embedding_model_path
+
+
+class LlamaCppAgentProvider:
+    """Create LLM agents using Llama.cpp."""
+
+    type: Final[LocalProviderType] = LocalProviderType.LLAMA_CPP
+
+    @asynccontextmanager
+    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[LlamaCppAgent | None]:
+        """Attempt to create a Llama.cpp agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """
+        if probe.probe_llama_available():
+            loop = asyncio.get_running_loop()
+            model_path, embedding_model_path = await loop.run_in_executor(None, _hf_downloads)
+            async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
+                yield LlamaCppAgent(
+                    checkpoints_conn=checkpoints_conn,
+                    model_path=model_path,
+                    embedding_model_path=embedding_model_path,
+                )
+        else:
+            yield None
rag_demo/agents/ollama.py ADDED
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Final
+
+import aiosqlite
+import ollama
+from langchain.agents import create_agent
+from langchain.messages import AIMessageChunk, HumanMessage
+from langchain_ollama import ChatOllama, OllamaEmbeddings
+from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+
+from rag_demo import probe
+from rag_demo.constants import LocalProviderType
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from pathlib import Path
+
+    from rag_demo.app_protocol import AppProtocol
+
+
+class OllamaAgent:
+    """An LLM agent powered by Ollama."""
+
+    def __init__(self, checkpoints_conn: aiosqlite.Connection) -> None:
+        """Initialize the OllamaAgent.
+
+        Args:
+            checkpoints_conn (aiosqlite.Connection): Asynchronous connection to SQLite db for checkpoints.
+        """
+        self.checkpoints_conn = checkpoints_conn
+        ollama.pull("gemma3:latest")  # 3.3GB
+        ollama.pull("embeddinggemma:latest")  # 621MB
+        self.llm = ChatOllama(
+            model="gemma3:latest",
+            validate_model_on_init=True,
+            temperature=0.5,
+            num_predict=4096,
+        )
+        self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
+        self.agent = create_agent(
+            model=self.llm,
+            system_prompt="You are a helpful assistant.",
+            checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
+        )
+
+    async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
+        """Stream a response from the agent.
+
+        Args:
+            user_message (str): User's next prompt in the conversation.
+            thread_id (str): Identifier for the current thread/conversation.
+            app (AppProtocol): Application interface, commonly used for logging.
+
+        Yields:
+            str: A token from the agent's response.
+        """
+        agent_stream = self.agent.astream(
+            {"messages": [HumanMessage(content=user_message)]},
+            {"configurable": {"thread_id": thread_id}},
+            stream_mode="messages",
+        )
+        async for message_chunk, _ in agent_stream:
+            if isinstance(message_chunk, AIMessageChunk):
+                token = message_chunk.content
+                if isinstance(token, str):
+                    yield token
+                else:
+                    app.log.error("Received message content of type", type(token))
+            else:
+                app.log.error("Received message chunk of type", type(message_chunk))
+
+
+class OllamaAgentProvider:
+    """Create LLM agents using Ollama."""
+
+    type: Final[LocalProviderType] = LocalProviderType.OLLAMA
+
+    @asynccontextmanager
+    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[OllamaAgent | None]:
+        """Attempt to create an Ollama agent.
+
+        Args:
+            checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
+        """
+        if probe.probe_ollama() is not None:
+            async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
+                yield OllamaAgent(checkpoints_conn=checkpoints_conn)
+        else:
+            yield None
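All three agent providers above are consumed the same way: enter `get_agent` as an async context manager, fall back when it yields `None`, and iterate `astream`. Below is a rough sketch of that usage pattern with the Ollama provider; the `fake_app` stub, database filename, and `thread_id` are illustrative assumptions only, since the real application passes its Textual app and its own managed paths.

```python
import asyncio
from types import SimpleNamespace

from rag_demo.agents.ollama import OllamaAgentProvider

# Stand-in for the Textual app: the agents above only call app.log.error().
fake_app = SimpleNamespace(log=SimpleNamespace(error=print))


async def demo() -> None:
    provider = OllamaAgentProvider()
    # get_agent yields None when Ollama is not installed or not running.
    async with provider.get_agent("checkpoints.sqlite3") as agent:
        if agent is None:
            print("Ollama unavailable; another provider would be tried.")
            return
        async for token in agent.astream("Hello!", thread_id="demo-thread", app=fake_app):
            print(token, end="", flush=True)


asyncio.run(demo())
```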
rag_demo/app.py ADDED
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import asyncio
+from pathlib import Path
+from typing import TYPE_CHECKING, ClassVar
+
+from textual.app import App
+from textual.binding import Binding
+
+from rag_demo.modes import ChatScreen, ConfigScreen, HelpScreen
+
+if TYPE_CHECKING:
+    from rag_demo.logic import Logic, Runtime
+
+
+class RAGDemo(App):
+    """Main application UI.
+
+    This class is responsible for creating the modes of the application, which are defined in :mod:`rag_demo.modes`.
+    """
+
+    TITLE = "RAG Demo"
+    CSS_PATH = Path(__file__).parent / "app.tcss"
+    BINDINGS: ClassVar = [
+        Binding("z", "switch_mode('chat')", "chat"),
+        Binding("c", "switch_mode('config')", "configure"),
+        Binding("h", "switch_mode('help')", "help"),
+    ]
+    MODES: ClassVar = {
+        "chat": ChatScreen,
+        "config": ConfigScreen,
+        "help": HelpScreen,
+    }
+
+    def __init__(self, logic: Logic) -> None:
+        """Initialize the main app.
+
+        Args:
+            logic (Logic): Object implementing the application logic.
+        """
+        super().__init__()
+        self.logic = logic
+        self._runtime_future: asyncio.Future[Runtime] = asyncio.Future()
+
+    async def on_mount(self) -> None:
+        """Set the initial mode to chat and initialize async parts of the logic."""
+        self.switch_mode("chat")
+        self.run_worker(self._hold_runtime())
+
+    async def _hold_runtime(self) -> None:
+        async with self.logic.runtime(app=self) as runtime:
+            self._runtime_future.set_result(runtime)
+            # Pause the task until Textual cancels it when the application closes.
+            await asyncio.Event().wait()
+
+    async def runtime(self) -> Runtime:
+        """Returns the application runtime logic."""
+        return await self._runtime_future
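One detail in `RAGDemo` worth calling out is how `_hold_runtime` keeps the logic's runtime context open for the whole application lifetime: a worker enters the async context manager, publishes the entered value through an `asyncio.Future`, and then parks until Textual cancels it on shutdown. A stripped-down sketch of the same pattern outside Textual (all names here are illustrative, not package code):

```python
import asyncio
from contextlib import asynccontextmanager, suppress


@asynccontextmanager
async def runtime():
    print("runtime set up")
    try:
        yield "runtime-object"
    finally:
        print("runtime torn down")


async def main() -> None:
    loop = asyncio.get_running_loop()
    runtime_future: asyncio.Future[str] = loop.create_future()

    async def hold_runtime() -> None:
        async with runtime() as rt:
            runtime_future.set_result(rt)
            await asyncio.Event().wait()  # park here until cancelled

    holder = asyncio.create_task(hold_runtime())
    rt = await runtime_future  # consumers simply await the published runtime
    print("using", rt)

    holder.cancel()  # shutdown: cancelling the holder exits the context manager
    with suppress(asyncio.CancelledError):
        await holder


asyncio.run(main())
```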
rag_demo/app.tcss ADDED
File without changes