jehoctor-rag-demo 0.1.1.dev1__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jehoctor_rag_demo-0.2.1.dist-info/METADATA ADDED
@@ -0,0 +1,125 @@
1
+ Metadata-Version: 2.3
2
+ Name: jehoctor-rag-demo
3
+ Version: 0.2.1
4
+ Summary: Chat with Wikipedia
5
+ Author: James Hoctor
6
+ Author-email: James Hoctor <JEHoctor@protonmail.com>
7
+ Requires-Dist: aiosqlite==0.21.0
8
+ Requires-Dist: bitsandbytes>=0.49.1
9
+ Requires-Dist: chromadb>=1.3.4
10
+ Requires-Dist: datasets>=4.4.1
11
+ Requires-Dist: httpx>=0.28.1
12
+ Requires-Dist: huggingface-hub>=0.36.0
13
+ Requires-Dist: langchain>=1.0.5
14
+ Requires-Dist: langchain-anthropic>=1.0.2
15
+ Requires-Dist: langchain-community>=0.4.1
16
+ Requires-Dist: langchain-huggingface>=1.1.0
17
+ Requires-Dist: langchain-ollama>=1.0.0
18
+ Requires-Dist: langchain-openai>=1.0.2
19
+ Requires-Dist: langgraph-checkpoint-sqlite>=3.0.1
20
+ Requires-Dist: nvidia-ml-py>=13.590.44
21
+ Requires-Dist: ollama>=0.6.0
22
+ Requires-Dist: platformdirs>=4.5.0
23
+ Requires-Dist: psutil>=7.1.3
24
+ Requires-Dist: py-cpuinfo>=9.0.0
25
+ Requires-Dist: pydantic>=2.12.4
26
+ Requires-Dist: pyperclip>=1.11.0
27
+ Requires-Dist: sentence-transformers>=5.2.2
28
+ Requires-Dist: textual>=6.5.0
29
+ Requires-Dist: transformers[torch]>=4.57.6
30
+ Requires-Dist: typer>=0.20.0
31
+ Requires-Dist: llama-cpp-python>=0.3.16 ; extra == 'llamacpp'
32
+ Requires-Python: ~=3.12.0
33
+ Provides-Extra: llamacpp
34
+ Description-Content-Type: text/markdown
35
+
36
+ # RAG-demo
37
+
38
+ Chat with (a small portion of) Wikipedia
39
+
40
+ ⚠️ RAG functionality is still under development. ⚠️
41
+
42
+ ![app screenshot](screenshots/screenshot_0.2.0.png "App screenshot")
43
+
44
+ ## Requirements
45
+
46
+ 1. The [uv](https://docs.astral.sh/uv/) Python package manager
47
+ - Installing and updating `uv` is easy; just follow [the docs](https://docs.astral.sh/uv/getting-started/installation/).
48
+ - As of 2026-01-25, I'm developing with `uv` version 0.9.26 and using the new experimental `--torch-backend` option.
49
+ 2. A terminal emulator or web browser
50
+ - Any common web browser will work.
51
+ - Some terminal emulators will work better than others.
52
+ See [Notes on terminal emulators](#notes-on-terminal-emulators) below.
53
+
54
+ ### Notes on terminal emulators
55
+
56
+ Certain terminal emulators will not work with some features of this program.
57
+ In particular, on macOS consider using [iTerm2](https://iterm2.com/) instead of the default Terminal.app ([explanation](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos)).
58
+ On Linux you might want to try [kitty](https://sw.kovidgoyal.net/kitty/), [wezterm](https://wezterm.org/), [alacritty](https://alacritty.org/), or [ghostty](https://ghostty.org/), instead of the terminal that came with your desktop environment ([reason](https://darren.codes/posts/textual-copy-paste/)).
59
+ Windows Terminal should be fine as far as I know.
60
+
61
+ ### Optional dependencies
62
+
63
+ 1. [Hugging Face login](https://huggingface.co/docs/huggingface_hub/quick-start#login)
64
+ 2. API key for your favorite LLM provider (support coming soon)
65
+ 3. Ollama installed on your system if you have a GPU
66
+ 4. Access to a more capable machine (bigger GPU) over SSH, so you can run RAG-demo remotely. It is a terminal app, after all.
67
+ 5. A C/C++ compiler if you want to build Llama.cpp from source.
68
+
69
+ ## Run the latest version
70
+
71
+ Run in a terminal:
72
+ ```bash
73
+ uvx --torch-backend=auto --from=jehoctor-rag-demo@latest chat
74
+ ```
75
+
76
+ Or run in a web browser:
77
+ ```bash
78
+ uvx --torch-backend=auto --from=jehoctor-rag-demo@latest textual serve chat
79
+ ```
80
+
81
+ ## CUDA acceleration via Llama.cpp
82
+
83
+ If you have an NVIDIA GPU with CUDA and build tools installed, you might be able to get CUDA acceleration without installing Ollama.
84
+
85
+ ```bash
86
+ CMAKE_ARGS="-DGGML_CUDA=on" uv run --extra=llamacpp chat
87
+ ```
88
+
89
+ ## Metal acceleration via Llama.cpp (on Apple Silicon)
90
+
91
+ On an Apple Silicon machine, make sure `uv` uses an ARM Python interpreter, as this should cause it to install Llama.cpp with Metal support.
92
+ Also, run with the `llamacpp` extra.
93
+ Try this:
94
+
95
+ ```bash
96
+ uvx --python-platform=aarch64-apple-darwin --torch-backend=auto --from=jehoctor-rag-demo[llamacpp]@latest chat
97
+ ```
98
+
99
+ ## Ollama on Linux
100
+
101
+ Remember that on Linux you have to keep Ollama up to date manually.
102
+ A recent version of Ollama (v0.11.10 or later) is required to run the [embedding model we use](https://ollama.com/library/embeddinggemma).
103
+ See this FAQ: https://docs.ollama.com/faq#how-can-i-upgrade-ollama.
104
+
105
+ ## Project feature roadmap
106
+
107
+ - ❌ RAG functionality
108
+ - ❌ torch inference via the LangChain local Hugging Face inference integration
109
+ - ❌ uv automatic torch backend selection (see [the docs](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection))
110
+ - ❌ OpenAI integration
111
+ - ❌ Anthropic integration
112
+
113
+ ## Run from the repository
114
+
115
+ First, clone this repository. Then, run one of the options below.
116
+
117
+ Run in a terminal:
118
+ ```bash
119
+ uv run chat
120
+ ```
121
+
122
+ Or run in a web browser:
123
+ ```bash
124
+ uv run textual serve chat
125
+ ```
jehoctor_rag_demo-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,31 @@
1
+ rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ rag_demo/__main__.py,sha256=S0UlQj3EldcXRk3rhH3ONdSOPmeyITYXIZ0o2JSxWbg,1618
3
+ rag_demo/agents/__init__.py,sha256=dsuO3AGcn2yGDq4gkAsZ32pjeTOqAudOL14G_AsEUyc,221
4
+ rag_demo/agents/base.py,sha256=gib6bC8nVKN1s1KPZd1dJVGRXnu7gFQwf3I3_7TSjQo,1312
5
+ rag_demo/agents/hugging_face.py,sha256=VrbGOlMO2z357LmU3sO5aM_yI5P-xbsfKmTAzH9_lFo,4225
6
+ rag_demo/agents/llama_cpp.py,sha256=C0hInc24sXmt5407_k4mP2Y6svqgfUPemhzAL1N6jY0,4272
7
+ rag_demo/agents/ollama.py,sha256=Fmtu8MSPPz91eT7HKvwvbQnA_xGPaD5HBrHEPPtomZA,3317
8
+ rag_demo/app.py,sha256=AVCJjlQ60y5J0v50TcJ3zZoa0ubhd_yKVDfu1ERsMVU,1807
9
+ rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ rag_demo/app_protocol.py,sha256=P__Q3KT41uonYazpYmLWmOh1MeoBaCBx4xEydkwK5tk,3292
11
+ rag_demo/constants.py,sha256=5EpyAD6p5Qb0vB5ASMtFSLVVsq3RG5cBn_AYihhDCPc,235
12
+ rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
13
+ rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
14
+ rag_demo/logic.py,sha256=SkF_Hqu1WSLHzwvSd_mJiCMSxZYqDnteYFRpc6oCREY,8236
15
+ rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
16
+ rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
17
+ rag_demo/modes/_logic_provider.py,sha256=U3J8Fgq8MbNYd92FqENW-5YP_jXqKG3xmMmYoSUzhHo,1343
18
+ rag_demo/modes/chat.py,sha256=2pmKhQ2uYZdjezNnNBINViBMcuTVE5YCom_HEbdJeXg,13607
19
+ rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
20
+ rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
21
+ rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
23
+ rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
+ rag_demo/probe.py,sha256=aDD-smNauEXXoBKVgx5xsMawM5tL0QAEBFl07ZGrddc,5101
25
+ rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
+ rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
27
+ rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
28
+ jehoctor_rag_demo-0.2.1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
29
+ jehoctor_rag_demo-0.2.1.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
30
+ jehoctor_rag_demo-0.2.1.dist-info/METADATA,sha256=nCXuy3TYPPf67DPFryndW2P6CRdSCcJkXxFXZ-UN4vs,4650
31
+ jehoctor_rag_demo-0.2.1.dist-info/RECORD,,
jehoctor_rag_demo-0.2.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ chat = rag_demo.__main__:main
3
+
rag_demo/__init__.py CHANGED
@@ -1,2 +0,0 @@
1
- def main() -> None:
2
- print("Hello from rag-demo!")
rag_demo/__main__.py ADDED
@@ -0,0 +1,42 @@
1
+ import time
2
+
3
+ # Measure the application start time.
4
+ APPLICATION_START_TIME = time.time()
5
+
6
+ # Suppress "module import not at top of file" (E402) for Typer and the other early imports: they must come after the
7
+ # start-time measurement above so that their import time is counted in the measured application startup time.
8
+ from typing import Annotated # noqa: E402
9
+
10
+ import typer # noqa: E402
11
+
12
+ from rag_demo.constants import LocalProviderType # noqa: E402
13
+
14
+
15
+ def _main(
16
+ name: Annotated[str | None, typer.Option(help="The name you want the AI to use with you.")] = None,
17
+ provider: Annotated[LocalProviderType | None, typer.Option(help="The local provider to prefer.")] = None,
18
+ ) -> None:
19
+ """Talk to Wikipedia."""
20
+ # Import here so that imports run within the typer.run context for prettier stack traces if errors occur.
21
+ # We ignore PLC0415 because we do not want these imports to be at the top of the module as is usually preferred.
22
+ import transformers # noqa: PLC0415
23
+
24
+ from rag_demo.app import RAGDemo # noqa: PLC0415
25
+ from rag_demo.logic import Logic # noqa: PLC0415
26
+
27
+ # The transformers library likes to print text that interferes with the TUI. Disable it.
28
+ transformers.logging.set_verbosity(verbosity=transformers.logging.CRITICAL)
29
+ transformers.logging.disable_progress_bar()
30
+
31
+ logic = Logic(username=name, preferred_provider_type=provider, application_start_time=APPLICATION_START_TIME)
32
+ app = RAGDemo(logic)
33
+ app.run()
34
+
35
+
36
+ def main() -> None:
37
+ """Entrypoint for the rag demo, specifically the `chat` command."""
38
+ typer.run(_main)
39
+
40
+
41
+ if __name__ == "__main__":
42
+ main()
rag_demo/agents/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .base import Agent, AgentProvider
2
+ from .hugging_face import HuggingFaceAgent, HuggingFaceAgentProvider
3
+ from .llama_cpp import LlamaCppAgent, LlamaCppAgentProvider
4
+ from .ollama import OllamaAgent, OllamaAgentProvider
rag_demo/agents/base.py ADDED
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Final, Protocol
4
+
5
+ if TYPE_CHECKING:
6
+ from collections.abc import AsyncIterator
7
+ from contextlib import AbstractAsyncContextManager
8
+ from pathlib import Path
9
+
10
+ from rag_demo.app_protocol import AppProtocol
11
+ from rag_demo.constants import LocalProviderType
12
+
13
+
14
+ class Agent(Protocol):
15
+ """An LLM agent that supports streaming responses asynchronously."""
16
+
17
+ def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
18
+ """Stream a response from the agent.
19
+
20
+ Args:
21
+ user_message (str): User's next prompt in the conversation.
22
+ thread_id (str): Identifier for the current thread/conversation.
23
+ app (AppProtocol): Application interface, commonly used for logging.
24
+
25
+ Yields:
26
+ str: A token from the agent's response.
27
+ """
28
+
29
+
30
+ class AgentProvider(Protocol):
31
+ """A strategy for creating LLM agents."""
32
+
33
+ type: Final[LocalProviderType]
34
+
35
+ def get_agent(self, checkpoints_sqlite_db: str | Path) -> AbstractAsyncContextManager[Agent | None]:
36
+ """Attempt to create an agent.
37
+
38
+ Args:
39
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
40
+ """
rag_demo/agents/hugging_face.py ADDED
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import sqlite3
5
+ from contextlib import asynccontextmanager
6
+ from typing import TYPE_CHECKING, Final
7
+
8
+ from huggingface_hub import hf_hub_download
9
+ from langchain.agents import create_agent
10
+ from langchain.messages import AIMessageChunk, HumanMessage
11
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFacePipeline
12
+ from langgraph.checkpoint.sqlite import SqliteSaver
13
+
14
+ from rag_demo.constants import LocalProviderType
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import AsyncIterator
18
+ from pathlib import Path
19
+
20
+ from rag_demo.app_protocol import AppProtocol
21
+
22
+
23
+ class HuggingFaceAgent:
24
+ """An LLM agent powered by Hugging Face local pipelines."""
25
+
26
+ def __init__(
27
+ self,
28
+ checkpoints_sqlite_db: str | Path,
29
+ model_id: str,
30
+ embedding_model_id: str,
31
+ ) -> None:
32
+ """Initialize the HuggingFaceAgent.
33
+
34
+ Args:
35
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
36
+ model_id (str): Hugging Face model ID for the LLM.
37
+ embedding_model_id (str): Hugging Face model ID for the embedding model.
38
+ """
39
+ self.checkpoints_sqlite_db = checkpoints_sqlite_db
40
+ self.model_id = model_id
41
+ self.embedding_model_id = embedding_model_id
42
+
43
+ self.llm = ChatHuggingFace(
44
+ llm=HuggingFacePipeline.from_model_id(
45
+ model_id=model_id,
46
+ task="text-generation",
47
+ device_map="auto",
48
+ pipeline_kwargs={"max_new_tokens": 4096},
49
+ ),
50
+ )
51
+ self.embed = HuggingFaceEmbeddings(model_name=embedding_model_id)
52
+ self.agent = create_agent(
53
+ model=self.llm,
54
+ system_prompt="You are a helpful assistant.",
55
+ checkpointer=SqliteSaver(sqlite3.Connection(self.checkpoints_sqlite_db, check_same_thread=False)),
56
+ )
57
+
58
+ async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
59
+ """Stream a response from the agent.
60
+
61
+ Args:
62
+ user_message (str): User's next prompt in the conversation.
63
+ thread_id (str): Identifier for the current thread/conversation.
64
+ app (AppProtocol): Application interface, commonly used for logging.
65
+
66
+ Yields:
67
+ str: A token from the agent's response.
68
+ """
69
+ agent_stream = self.agent.stream(
70
+ {"messages": [HumanMessage(content=user_message)]},
71
+ {"configurable": {"thread_id": thread_id}},
72
+ stream_mode="messages",
73
+ )
74
+ for message_chunk, _ in agent_stream:
75
+ if isinstance(message_chunk, AIMessageChunk):
76
+ token = message_chunk.content
77
+ if isinstance(token, str):
78
+ yield token
79
+ else:
80
+ app.log.error("Received message content of type", type(token))
81
+ else:
82
+ app.log.error("Received message chunk of type", type(message_chunk))
83
+
84
+
85
+ def _hf_downloads() -> None:
86
+ hf_hub_download(
87
+ repo_id="Qwen/Qwen3-0.6B", # 1.5GB
88
+ filename="model.safetensors",
89
+ revision="c1899de289a04d12100db370d81485cdf75e47ca",
90
+ )
91
+ hf_hub_download(
92
+ repo_id="unsloth/embeddinggemma-300m", # 1.21GB
93
+ filename="model.safetensors",
94
+ revision="bfa3c846ac738e62aa61806ef9112d34acb1dc5a",
95
+ )
96
+
97
+
98
+ class HuggingFaceAgentProvider:
99
+ """Create LLM agents using Hugging Face local pipelines."""
100
+
101
+ type: Final[LocalProviderType] = LocalProviderType.HUGGING_FACE
102
+
103
+ @asynccontextmanager
104
+ async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[HuggingFaceAgent]:
105
+ """Create a Hugging Face local pipeline agent.
106
+
107
+ Args:
108
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
109
+ """
110
+ loop = asyncio.get_running_loop()
111
+ await loop.run_in_executor(None, _hf_downloads)
112
+ yield HuggingFaceAgent(
113
+ checkpoints_sqlite_db,
114
+ model_id="Qwen/Qwen3-0.6B",
115
+ embedding_model_id="unsloth/embeddinggemma-300m",
116
+ )
rag_demo/agents/llama_cpp.py ADDED
@@ -0,0 +1,113 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from contextlib import asynccontextmanager
5
+ from typing import TYPE_CHECKING, Final
6
+
7
+ import aiosqlite
8
+ from huggingface_hub import hf_hub_download
9
+ from langchain.agents import create_agent
10
+ from langchain.messages import AIMessageChunk, HumanMessage
11
+ from langchain_community.chat_models import ChatLlamaCpp
12
+ from langchain_community.embeddings import LlamaCppEmbeddings
13
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
14
+
15
+ from rag_demo import probe
16
+ from rag_demo.constants import LocalProviderType
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import AsyncIterator
20
+ from pathlib import Path
21
+
22
+ from rag_demo.app_protocol import AppProtocol
23
+
24
+
25
+ class LlamaCppAgent:
26
+ """An LLM agent powered by Llama.cpp."""
27
+
28
+ def __init__(
29
+ self,
30
+ checkpoints_conn: aiosqlite.Connection,
31
+ model_path: str,
32
+ embedding_model_path: str,
33
+ ) -> None:
34
+ """Initialize the LlamaCppAgent.
35
+
36
+ Args:
37
+ checkpoints_conn (aiosqlite.Connection): Connection to SQLite checkpoint database.
38
+ model_path (str): Path to Llama.cpp model.
39
+ embedding_model_path (str): Path to Llama.cpp embedding model.
40
+ """
41
+ self.checkpoints_conn = checkpoints_conn
42
+ self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
43
+ self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False)
44
+ self.agent = create_agent(
45
+ model=self.llm,
46
+ system_prompt="You are a helpful assistant.",
47
+ checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
48
+ )
49
+
50
+ async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
51
+ """Stream a response from the agent.
52
+
53
+ Args:
54
+ user_message (str): User's next prompt in the conversation.
55
+ thread_id (str): Identifier for the current thread/conversation.
56
+ app (AppProtocol): Application interface, commonly used for logging.
57
+
58
+ Yields:
59
+ str: A token from the agent's response.
60
+ """
61
+ agent_stream = self.agent.astream(
62
+ {"messages": [HumanMessage(content=user_message)]},
63
+ {"configurable": {"thread_id": thread_id}},
64
+ stream_mode="messages",
65
+ )
66
+ async for message_chunk, _ in agent_stream:
67
+ if isinstance(message_chunk, AIMessageChunk):
68
+ token = message_chunk.content
69
+ if isinstance(token, str):
70
+ yield token
71
+ else:
72
+ app.log.error("Received message content of type", type(token))
73
+ else:
74
+ app.log.error("Received message chunk of type", type(message_chunk))
75
+
76
+
77
+ def _hf_downloads() -> tuple[str, str]:
78
+ model_path = hf_hub_download(
79
+ repo_id="bartowski/google_gemma-3-4b-it-GGUF",
80
+ filename="google_gemma-3-4b-it-Q6_K_L.gguf", # 3.35GB
81
+ revision="71506238f970075ca85125cd749c28b1b0eee84e",
82
+ )
83
+ embedding_model_path = hf_hub_download(
84
+ repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
85
+ filename="bge-small-en-v1.5-q8_0.gguf", # 36.8MB
86
+ revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
87
+ )
88
+ return model_path, embedding_model_path
89
+
90
+
91
+ class LlamaCppAgentProvider:
92
+ """Create LLM agents using Llama.cpp."""
93
+
94
+ type: Final[LocalProviderType] = LocalProviderType.LLAMA_CPP
95
+
96
+ @asynccontextmanager
97
+ async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[LlamaCppAgent | None]:
98
+ """Attempt to create a Llama.cpp agent.
99
+
100
+ Args:
101
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
102
+ """
103
+ if probe.probe_llama_available():
104
+ loop = asyncio.get_running_loop()
105
+ model_path, embedding_model_path = await loop.run_in_executor(None, _hf_downloads)
106
+ async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
107
+ yield LlamaCppAgent(
108
+ checkpoints_conn=checkpoints_conn,
109
+ model_path=model_path,
110
+ embedding_model_path=embedding_model_path,
111
+ )
112
+ else:
113
+ yield None
rag_demo/agents/ollama.py ADDED
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import asynccontextmanager
4
+ from typing import TYPE_CHECKING, Final
5
+
6
+ import aiosqlite
7
+ import ollama
8
+ from langchain.agents import create_agent
9
+ from langchain.messages import AIMessageChunk, HumanMessage
10
+ from langchain_ollama import ChatOllama, OllamaEmbeddings
11
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
12
+
13
+ from rag_demo import probe
14
+ from rag_demo.constants import LocalProviderType
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import AsyncIterator
18
+ from pathlib import Path
19
+
20
+ from rag_demo.app_protocol import AppProtocol
21
+
22
+
23
+ class OllamaAgent:
24
+ """An LLM agent powered by Ollama."""
25
+
26
+ def __init__(self, checkpoints_conn: aiosqlite.Connection) -> None:
27
+ """Initialize the OllamaAgent.
28
+
29
+ Args:
30
+ checkpoints_conn (aiosqlite.Connection): Asynchronous connection to SQLite db for checkpoints.
31
+ """
32
+ self.checkpoints_conn = checkpoints_conn
33
+ ollama.pull("gemma3:latest") # 3.3GB
34
+ ollama.pull("embeddinggemma:latest") # 621MB
35
+ self.llm = ChatOllama(
36
+ model="gemma3:latest",
37
+ validate_model_on_init=True,
38
+ temperature=0.5,
39
+ num_predict=4096,
40
+ )
41
+ self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
42
+ self.agent = create_agent(
43
+ model=self.llm,
44
+ system_prompt="You are a helpful assistant.",
45
+ checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
46
+ )
47
+
48
+ async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
49
+ """Stream a response from the agent.
50
+
51
+ Args:
52
+ user_message (str): User's next prompt in the conversation.
53
+ thread_id (str): Identifier for the current thread/conversation.
54
+ app (AppProtocol): Application interface, commonly used for logging.
55
+
56
+ Yields:
57
+ str: A token from the agent's response.
58
+ """
59
+ agent_stream = self.agent.astream(
60
+ {"messages": [HumanMessage(content=user_message)]},
61
+ {"configurable": {"thread_id": thread_id}},
62
+ stream_mode="messages",
63
+ )
64
+ async for message_chunk, _ in agent_stream:
65
+ if isinstance(message_chunk, AIMessageChunk):
66
+ token = message_chunk.content
67
+ if isinstance(token, str):
68
+ yield token
69
+ else:
70
+ app.log.error("Received message content of type", type(token))
71
+ else:
72
+ app.log.error("Received message chunk of type", type(message_chunk))
73
+
74
+
75
+ class OllamaAgentProvider:
76
+ """Create LLM agents using Ollama."""
77
+
78
+ type: Final[LocalProviderType] = LocalProviderType.OLLAMA
79
+
80
+ @asynccontextmanager
81
+ async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[OllamaAgent | None]:
82
+ """Attempt to create an Ollama agent.
83
+
84
+ Args:
85
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
86
+ """
87
+ if probe.probe_ollama() is not None:
88
+ async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
89
+ yield OllamaAgent(checkpoints_conn=checkpoints_conn)
90
+ else:
91
+ yield None
rag_demo/app.py ADDED
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, ClassVar
6
+
7
+ from textual.app import App
8
+ from textual.binding import Binding
9
+
10
+ from rag_demo.modes import ChatScreen, ConfigScreen, HelpScreen
11
+
12
+ if TYPE_CHECKING:
13
+ from rag_demo.logic import Logic, Runtime
14
+
15
+
16
+ class RAGDemo(App):
17
+ """Main application UI.
18
+
19
+ This class is responsible for creating the modes of the application, which are defined in :mod:`rag_demo.modes`.
20
+ """
21
+
22
+ TITLE = "RAG Demo"
23
+ CSS_PATH = Path(__file__).parent / "app.tcss"
24
+ BINDINGS: ClassVar = [
25
+ Binding("z", "switch_mode('chat')", "chat"),
26
+ Binding("c", "switch_mode('config')", "configure"),
27
+ Binding("h", "switch_mode('help')", "help"),
28
+ ]
29
+ MODES: ClassVar = {
30
+ "chat": ChatScreen,
31
+ "config": ConfigScreen,
32
+ "help": HelpScreen,
33
+ }
34
+
35
+ def __init__(self, logic: Logic) -> None:
36
+ """Initialize the main app.
37
+
38
+ Args:
39
+ logic (Logic): Object implementing the application logic.
40
+ """
41
+ super().__init__()
42
+ self.logic = logic
43
+ self._runtime_future: asyncio.Future[Runtime] = asyncio.Future()
44
+
45
+ async def on_mount(self) -> None:
46
+ """Set the initial mode to chat and initialize async parts of the logic."""
47
+ self.switch_mode("chat")
48
+ self.run_worker(self._hold_runtime())
49
+
50
+ async def _hold_runtime(self) -> None:
51
+ async with self.logic.runtime(app=self) as runtime:
52
+ self._runtime_future.set_result(runtime)
53
+ # Pause the task until Textual cancels it when the application closes.
54
+ await asyncio.Event().wait()
55
+
56
+ async def runtime(self) -> Runtime:
57
+ """Returns the application runtime logic."""
58
+ return await self._runtime_future
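One detail of `RAGDemo` worth calling out is how `_hold_runtime` exposes the result of `Logic.runtime()`: the worker enters the async context manager once, publishes the resulting `Runtime` through an `asyncio.Future`, and then parks on an `asyncio.Event` so the context stays open until Textual cancels the worker at shutdown. A stand-alone sketch of that pattern, with hypothetical names and plain asyncio in place of Textual workers, might look like this:

```python
# Hypothetical stand-alone sketch of the "hold a context manager open and publish its value" pattern.
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def open_runtime():
    print("runtime opened")
    try:
        yield "runtime-object"
    finally:
        print("runtime closed")  # runs when the holder task is cancelled


async def main() -> None:
    runtime_future: asyncio.Future[str] = asyncio.get_running_loop().create_future()

    async def hold_runtime() -> None:
        async with open_runtime() as runtime:
            runtime_future.set_result(runtime)
            await asyncio.Event().wait()  # park until cancelled, keeping the context open

    holder = asyncio.create_task(hold_runtime())
    runtime = await runtime_future  # any consumer can now await the published runtime
    print(f"got {runtime!r}")

    holder.cancel()  # in the app, Textual cancels the worker when the application closes
    try:
        await holder
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```

Any number of consumers can await the future (as `RAGDemo.runtime()` does) without caring when, or in which worker, the runtime finished initializing.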
rag_demo/app.tcss ADDED
File without changes