jehoctor-rag-demo 0.1.1.dev1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,100 @@
1
+ Metadata-Version: 2.3
2
+ Name: jehoctor-rag-demo
3
+ Version: 0.2.0
4
+ Summary: Chat with Wikipedia
5
+ Author: James Hoctor
6
+ Author-email: James Hoctor <JEHoctor@protonmail.com>
7
+ Requires-Dist: aiosqlite==0.21.0
8
+ Requires-Dist: chromadb>=1.3.4
9
+ Requires-Dist: datasets>=4.4.1
10
+ Requires-Dist: httpx>=0.28.1
11
+ Requires-Dist: huggingface-hub>=0.36.0
12
+ Requires-Dist: langchain>=1.0.5
13
+ Requires-Dist: langchain-anthropic>=1.0.2
14
+ Requires-Dist: langchain-community>=0.4.1
15
+ Requires-Dist: langchain-huggingface>=1.1.0
16
+ Requires-Dist: langchain-ollama>=1.0.0
17
+ Requires-Dist: langchain-openai>=1.0.2
18
+ Requires-Dist: langgraph-checkpoint-sqlite>=3.0.1
19
+ Requires-Dist: llama-cpp-python>=0.3.16
20
+ Requires-Dist: nvidia-ml-py>=13.590.44
21
+ Requires-Dist: ollama>=0.6.0
22
+ Requires-Dist: platformdirs>=4.5.0
23
+ Requires-Dist: psutil>=7.1.3
24
+ Requires-Dist: py-cpuinfo>=9.0.0
25
+ Requires-Dist: pydantic>=2.12.4
26
+ Requires-Dist: pyperclip>=1.11.0
27
+ Requires-Dist: textual>=6.5.0
28
+ Requires-Dist: typer>=0.20.0
29
+ Requires-Python: >=3.12
30
+ Description-Content-Type: text/markdown
31
+
32
+ # RAG-demo
33
+
34
+ Chat with (a small portion of) Wikipedia
35
+
36
+ ⚠️ RAG functionality is still under development. ⚠️
37
+
38
+ ![app screenshot](screenshots/screenshot_062f205a.png "App screenshot (this AI response is not accurate)")
39
+
40
+ ## Requirements
41
+
42
+ 1. [uv](https://docs.astral.sh/uv/)
43
+ 2. At least one of the following:
44
+ - A suitable terminal emulator. In particular, on macOS consider using [iTerm2](https://iterm2.com/) instead of the default Terminal.app ([explanation](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos)). On Linux, you might want to try [kitty](https://sw.kovidgoyal.net/kitty/), [wezterm](https://wezterm.org/), [alacritty](https://alacritty.org/), or [ghostty](https://ghostty.org/) instead of the terminal that came with your desktop environment ([reason](https://darren.codes/posts/textual-copy-paste/)). Windows Terminal should be fine as far as I know.
45
+ - Any common web browser
46
+
47
+ ## Optional extras that can improve your experience
48
+
49
+ 1. [Hugging Face login](https://huggingface.co/docs/huggingface_hub/quick-start#login)
50
+ 2. API key for your favorite LLM provider (support coming soon)
51
+ 3. Ollama installed on your system if you have a GPU
52
+ 4. If you can, run RAG-demo over SSH on a more capable machine (one with a bigger GPU). It is a terminal app, after all.
53
+
54
+
55
+ ## Run from the repository
56
+
57
+ First, clone this repository. Then, run one of the options below.
58
+
59
+ Run in a terminal:
60
+ ```bash
61
+ uv run chat
62
+ ```
63
+
64
+ Or run in a web browser:
65
+ ```bash
66
+ uv run textual serve chat
67
+ ```
68
+
69
+ ## Run from the latest version on PyPI
70
+
71
+ TODO: test uv automatic torch backend selection:
72
+ https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection
73
+
74
+ Run in a terminal:
75
+ ```bash
76
+ uvx --from=jehoctor-rag-demo chat
77
+ ```
78
+
79
+ Or run in a web browser:
80
+ ```bash
81
+ uvx --from=jehoctor-rag-demo textual serve chat
82
+ ```
83
+
84
+ ## CUDA acceleration via Llama.cpp
85
+
86
+ If you have an NVIDIA GPU with CUDA and build tools installed, you might be able to get CUDA acceleration without installing Ollama.
87
+
88
+ ```bash
89
+ CMAKE_ARGS="-DGGML_CUDA=on" uv run chat
90
+ ```
91
+
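To confirm that `llama-cpp-python` was actually built with GPU offload support, a quick check like the sketch below can help (run it with `uv run python` inside the project environment); `llama_supports_gpu_offload` is the same probe the app itself uses.

```python
# Minimal check that the installed llama-cpp-python build can offload layers to the GPU.
import llama_cpp

print(llama_cpp.llama_supports_gpu_offload())  # True if the wheel was built with CUDA (or Metal) support
```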
92
+ ## Metal acceleration via Llama.cpp (on Apple Silicon)
93
+
94
+ On an Apple Silicon machine, make sure `uv` uses an ARM (arm64) Python interpreter, as this should cause it to install Llama.cpp with Metal support.
95
+
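One way to confirm which interpreter `uv` picked is a quick `platform` check (a minimal sketch; run it with `uv run python`):

```python
# On Apple Silicon this should print "arm64"; "x86_64" means an Intel/Rosetta interpreter was selected.
import platform

print(platform.machine())
```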
96
+ ## Ollama on Linux
97
+
98
+ Remember that you have to keep Ollama up-to-date manually on Linux.
99
+ A recent version of Ollama (v0.11.10 or later) is required to run the [embedding model we use](https://ollama.com/library/embeddinggemma).
100
+ See this FAQ: https://docs.ollama.com/faq#how-can-i-upgrade-ollama.
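If you are unsure whether your local Ollama is new enough, you can query its version endpoint directly. The sketch below assumes the default server address `http://localhost:11434`; it mirrors the version probe the app performs internally.

```python
# Ask the running Ollama server for its version (adjust the address if you changed OLLAMA_HOST).
import httpx

response = httpx.get("http://localhost:11434/api/version")
response.raise_for_status()
print(response.json()["version"])  # should be 0.11.10 or later for embeddinggemma
```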
@@ -0,0 +1,23 @@
1
+ rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ rag_demo/__main__.py,sha256=Kak0eQWBRHVGDoWgoHs9j-Tvf_9DMzdurMxD7EM4Jr0,1054
3
+ rag_demo/app.py,sha256=xejrtFApeTeyOQvWDq1H0XPyZEr8cQPn7q9KRwnV660,1812
4
+ rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
6
+ rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
7
+ rag_demo/logic.py,sha256=7PTWPs9xZJ7bbEtNDMQTX6SX4JKG8HMiq2H_YUfM-CI,12602
8
+ rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
9
+ rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
10
+ rag_demo/modes/_logic_provider.py,sha256=__eO4XVbyRHkjV_D8OHsPJX5f2R8JoJPcNXhi-w_xFY,1277
11
+ rag_demo/modes/chat.py,sha256=VigWSkw6R2ea95-wZ8tgtKIccev9A-ByzJj7nzglsog,13444
12
+ rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
13
+ rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
14
+ rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
16
+ rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
+ rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
+ rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
19
+ rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
20
+ jehoctor_rag_demo-0.2.0.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
21
+ jehoctor_rag_demo-0.2.0.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
22
+ jehoctor_rag_demo-0.2.0.dist-info/METADATA,sha256=wp1mdAqjB0be_1Uly4hwAoz0bjRUDI6gb6gK5SdrHRU,3531
23
+ jehoctor_rag_demo-0.2.0.dist-info/RECORD,,
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ chat = rag_demo.__main__:main
3
+
rag_demo/__init__.py CHANGED
@@ -1,2 +0,0 @@
1
- def main() -> None:
2
- print("Hello from rag-demo!")
rag_demo/__main__.py ADDED
@@ -0,0 +1,31 @@
1
+ import time
2
+
3
+ # Measure the application start time.
4
+ APPLICATION_START_TIME = time.time()
5
+
6
+ # Disable "module import not at top of file" (aka E402) when importing Typer. This is necessary so that Typer's
7
+ # initialization is included in the application startup time.
8
+ import typer # noqa: E402
9
+
10
+
11
+ def _main(
12
+ name: str | None = typer.Option(None, help="The name you want the AI to use with you."),
13
+ ) -> None:
14
+ """Talk to Wikipedia."""
15
+ # Import here so that imports run within the typer.run context for prettier stack traces if errors occur.
16
+ # We ignore PLC0415 because we do not want these imports to be at the top of the module as is usually preferred.
17
+ from rag_demo.app import RAGDemo # noqa: PLC0415
18
+ from rag_demo.logic import Logic # noqa: PLC0415
19
+
20
+ logic = Logic(username=name, application_start_time=APPLICATION_START_TIME)
21
+ app = RAGDemo(logic)
22
+ app.run()
23
+
24
+
25
+ def main() -> None:
26
+ """Entrypoint for the rag demo, specifically the `chat` command."""
27
+ typer.run(_main)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ main()
rag_demo/app.py ADDED
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, ClassVar
6
+
7
+ from textual.app import App
8
+ from textual.binding import Binding
9
+
10
+ from rag_demo.modes import ChatScreen, ConfigScreen, HelpScreen
11
+
12
+ if TYPE_CHECKING:
13
+ from rag_demo.logic import Logic, Runtime
14
+
15
+
16
+ class RAGDemo(App):
17
+ """Main application UI.
18
+
19
+ This class is responsible for creating the modes of the application, which are defined in :mod:`rag_demo.modes`.
20
+ """
21
+
22
+ TITLE = "RAG Demo"
23
+ CSS_PATH = Path(__file__).parent / "app.tcss"
24
+ BINDINGS: ClassVar = [
25
+ Binding("z", "switch_mode('chat')", "chat"),
26
+ Binding("c", "switch_mode('config')", "configure"),
27
+ Binding("h", "switch_mode('help')", "help"),
28
+ ]
29
+ MODES: ClassVar = {
30
+ "chat": ChatScreen,
31
+ "config": ConfigScreen,
32
+ "help": HelpScreen,
33
+ }
34
+
35
+ def __init__(self, logic: Logic) -> None:
36
+ """Initialize the main app.
37
+
38
+ Args:
39
+ logic (Logic): Object implementing the application logic.
40
+ """
41
+ super().__init__()
42
+ self.logic = logic
43
+ self._runtime_future: asyncio.Future[Runtime] = asyncio.Future()
44
+
45
+ async def on_mount(self) -> None:
46
+ """Set the initial mode to chat and initialize async parts of the logic."""
47
+ self.switch_mode("chat")
48
+ self.run_worker(self._hold_runtime())
49
+
50
+ async def _hold_runtime(self) -> None:
51
+ async with self.logic.runtime(app_like=self) as runtime:
52
+ self._runtime_future.set_result(runtime)
53
+ # Pause the task until Textual cancels it when the application closes.
54
+ await asyncio.Event().wait()
55
+
56
+ async def runtime(self) -> Runtime:
57
+ """Returns the application runtime logic."""
58
+ return await self._runtime_future
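The `_hold_runtime` worker above relies on a small pattern worth spelling out: keep an async context manager open inside a long-lived task and hand its value to other coroutines through a `Future`. A minimal sketch of the same idea outside Textual (all names here are illustrative):

```python
import asyncio
import contextlib


@contextlib.asynccontextmanager
async def open_resource():
    # Stand-in for Logic.runtime(); yields the value other coroutines will wait for.
    yield "runtime"


async def main() -> None:
    future: asyncio.Future[str] = asyncio.get_running_loop().create_future()

    async def hold() -> None:
        async with open_resource() as value:
            future.set_result(value)
            await asyncio.Event().wait()  # park here; the context manager stays open until cancellation

    task = asyncio.create_task(hold())
    print(await future)  # prints "runtime" once the resource is ready
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task


asyncio.run(main())
```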
rag_demo/app.tcss ADDED
File without changes
rag_demo/db.py ADDED
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import aiosqlite
6
+
7
+ if TYPE_CHECKING:
8
+ from pathlib import Path
9
+
10
+
11
+ class AtomicIDManager:
12
+ """A database manager for managing thread IDs.
13
+
14
+ This was written by Claude, and I fixed it up with feedback from Ruff and Flake8.
15
+ Maybe one day the app logic database will require something fancier, but this gets the job done now.
16
+
17
+ As you can see from the conversation with Claude, this was quite a simple task for it:
18
+ https://claude.ai/share/227d08ff-96a3-495a-9f56-509a1fd528f7
19
+ """
20
+
21
+ def __init__(self, db_path: str | Path) -> None:
22
+ """Initialize the database manager."""
23
+ self.db_path = db_path
24
+
25
+ async def initialize(self) -> None:
26
+ """Initialize the database and create the table if it doesn't exist."""
27
+ async with aiosqlite.connect(self.db_path) as db:
28
+ # Enable WAL mode for better concurrent access
29
+ await db.execute("PRAGMA journal_mode=WAL")
30
+
31
+ await db.execute("""
32
+ CREATE TABLE IF NOT EXISTS claimed_ids (
33
+ id INTEGER PRIMARY KEY
34
+ )
35
+ """)
36
+ await db.commit()
37
+
38
+ async def claim_next_id(self) -> int:
39
+ """Atomically find the max id, increment it, and claim it. Returns the newly claimed ID.
40
+
41
+ This operation is atomic and multiprocess-safe because:
42
+ 1. SQLite serializes writes by default
43
+ 2. We use IMMEDIATE transaction to acquire write lock immediately
44
+ 3. The entire operation happens in a single transaction
45
+ """
46
+ async with aiosqlite.connect(self.db_path) as db:
47
+ # Start an IMMEDIATE transaction to get write lock right away
48
+ await db.execute("BEGIN IMMEDIATE")
49
+
50
+ try:
51
+ # Find the current max ID
52
+ async with db.execute("SELECT MAX(id) FROM claimed_ids") as cursor:
53
+ row = await cursor.fetchone()
54
+ max_id = row[0] if row is not None and row[0] is not None else 0
55
+
56
+ # Calculate next ID
57
+ next_id = max_id + 1
58
+
59
+ # Insert the new ID
60
+ await db.execute("INSERT INTO claimed_ids (id) VALUES (?)", (next_id,))
61
+
62
+ # Commit the transaction
63
+ await db.commit()
64
+
65
+ except Exception:
66
+ await db.rollback()
67
+ raise
68
+
69
+ else:
70
+ return next_id
71
+
72
+ async def get_all_claimed_ids(self) -> list[int]:
73
+ """Retrieve all claimed IDs."""
74
+ async with (
75
+ aiosqlite.connect(self.db_path) as db,
76
+ db.execute("SELECT id FROM claimed_ids ORDER BY id") as cursor,
77
+ ):
78
+ rows = await cursor.fetchall()
79
+ return [row[0] for row in rows]
80
+
81
+ async def get_count(self) -> int:
82
+ """Get the total number of claimed IDs."""
83
+ async with aiosqlite.connect(self.db_path) as db, db.execute("SELECT COUNT(*) FROM claimed_ids") as cursor:
84
+ row = await cursor.fetchone()
85
+ if row is None:
86
+ raise ValueError("A SQL COUNT query should always return at least one row") # noqa: EM101, TRY003
87
+ return row[0]
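A minimal usage sketch of the class above (the database path is illustrative); as the docstring notes, `claim_next_id` stays atomic even with several processes pointing at the same file:

```python
import asyncio

from rag_demo.db import AtomicIDManager


async def main() -> None:
    manager = AtomicIDManager("app.sqlite3")  # any writable path works
    await manager.initialize()
    new_id = await manager.claim_next_id()  # serialized by the IMMEDIATE transaction
    print(new_id, await manager.get_all_claimed_ids(), await manager.get_count())


asyncio.run(main())
```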
rag_demo/dirs.py ADDED
@@ -0,0 +1,14 @@
1
+ from pathlib import Path
2
+
3
+ from platformdirs import PlatformDirs
4
+
5
+ _appdirs = PlatformDirs(appname="jehoctor-rag-demo", ensure_exists=True)
6
+
7
+
8
+ def _ensure(dir_: Path) -> Path:
9
+ dir_.mkdir(parents=True, exist_ok=True)
10
+ return dir_
11
+
12
+
13
+ DATA_DIR = _appdirs.user_data_path
14
+ CONFIG_DIR = _appdirs.user_config_path
rag_demo/logic.py ADDED
@@ -0,0 +1,287 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import platform
5
+ import time
6
+ from contextlib import asynccontextmanager
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Protocol, TypeVar, cast
9
+
10
+ import aiosqlite
11
+ import cpuinfo
12
+ import httpx
13
+ import huggingface_hub
14
+ import llama_cpp
15
+ import ollama
16
+ import psutil
17
+ import pynvml
18
+ from datasets import Dataset, load_dataset
19
+ from huggingface_hub import hf_hub_download
20
+ from huggingface_hub.constants import HF_HUB_CACHE
21
+ from langchain.agents import create_agent
22
+ from langchain.messages import AIMessageChunk, HumanMessage
23
+ from langchain_community.chat_models import ChatLlamaCpp
24
+ from langchain_community.embeddings import LlamaCppEmbeddings
25
+ from langchain_core.exceptions import LangChainException
26
+ from langchain_ollama import ChatOllama, OllamaEmbeddings
27
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
28
+
29
+ from rag_demo import dirs
30
+ from rag_demo.db import AtomicIDManager
31
+ from rag_demo.modes.chat import Response, StoppedStreamError
32
+
33
+ if TYPE_CHECKING:
34
+ from collections.abc import AsyncIterator, Awaitable
35
+
36
+ from textual.worker import Worker
37
+
38
+ from rag_demo.modes import ChatScreen
39
+
40
+ ResultType = TypeVar("ResultType")
41
+
42
+
43
+ class AppLike(Protocol):
44
+ """Protocol for the subset of what the main App can do that the runtime needs."""
45
+
46
+ def run_worker(self, work: Awaitable[ResultType]) -> Worker[ResultType]:
47
+ """Run a coroutine in the background.
48
+
49
+ See https://textual.textualize.io/guide/workers/.
50
+
51
+ Args:
52
+ work (Awaitable[ResultType]): The coroutine to run.
53
+ """
54
+ ...
55
+
56
+
57
+ class Runtime:
58
+ """The application logic with asynchronously initialized resources."""
59
+
60
+ def __init__(
61
+ self,
62
+ logic: Logic,
63
+ checkpoints_conn: aiosqlite.Connection,
64
+ thread_id_manager: AtomicIDManager,
65
+ app_like: AppLike,
66
+ ) -> None:
67
+ self.runtime_start_time = time.time()
68
+ self.logic = logic
69
+ self.checkpoints_conn = checkpoints_conn
70
+ self.thread_id_manager = thread_id_manager
71
+ self.app_like = app_like
72
+
73
+ self.current_thread: int | None = None
74
+ self.generating = False
75
+
76
+ if self.logic.probe_ollama() is not None:
77
+ ollama.pull("gemma3:latest") # 3.3GB
78
+ ollama.pull("embeddinggemma:latest") # 621MB
79
+ self.llm = ChatOllama(
80
+ model="gemma3:latest",
81
+ validate_model_on_init=True,
82
+ temperature=0.5,
83
+ num_predict=4096,
84
+ )
85
+ self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
86
+ else:
87
+ model_path = hf_hub_download(
88
+ repo_id="bartowski/google_gemma-3-4b-it-GGUF",
89
+ filename="google_gemma-3-4b-it-Q6_K_L.gguf", # 3.35GB
90
+ revision="71506238f970075ca85125cd749c28b1b0eee84e",
91
+ )
92
+ embedding_model_path = hf_hub_download(
93
+ repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
94
+ filename="bge-small-en-v1.5-q8_0.gguf", # 36.8MB
95
+ revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
96
+ )
97
+ self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
98
+ self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False) # pyright: ignore[reportCallIssue]
99
+
100
+ self.agent = create_agent(
101
+ model=self.llm,
102
+ system_prompt="You are a helpful assistant.",
103
+ checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
104
+ )
105
+
106
+ def get_rag_datasets(self) -> None:
107
+ self.qa_test: Dataset = cast(
108
+ "Dataset",
109
+ load_dataset("rag-datasets/rag-mini-wikipedia", "question-answer", split="test"),
110
+ )
111
+ self.corpus: Dataset = cast(
112
+ "Dataset",
113
+ load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages"),
114
+ )
115
+
116
+ async def stream_response(self, response_widget: Response, request_text: str, thread: str) -> None:
117
+ """Worker method for streaming tokens from the active agent to a response widget.
118
+
119
+ Args:
120
+ response_widget (Response): Target response widget for streamed tokens.
121
+ request_text (str): Text of the user request.
122
+ thread (str): ID of the current thread.
123
+ """
124
+ self.generating = True
125
+ async with response_widget.stream_writer() as writer:
126
+ agent_stream = self.agent.astream(
127
+ {"messages": [HumanMessage(content=request_text)]},
128
+ {"configurable": {"thread_id": thread}},
129
+ stream_mode="messages",
130
+ )
131
+ try:
132
+ async for message_chunk, _ in agent_stream:
133
+ if isinstance(message_chunk, AIMessageChunk):
134
+ token = cast("AIMessageChunk", message_chunk).content
135
+ if isinstance(token, str):
136
+ await writer.write(token)
137
+ else:
138
+ response_widget.log.error(f"Received message content of type {type(token)}")
139
+ else:
140
+ response_widget.log.error(f"Received message chunk of type {type(message_chunk)}")
141
+ except StoppedStreamError as e:
142
+ response_widget.set_shown_object(e)
143
+ except LangChainException as e:
144
+ response_widget.set_shown_object(e)
145
+ self.generating = False
146
+
147
+ def new_conversation(self, chat_screen: ChatScreen) -> None:
148
+ self.current_thread = None
149
+ chat_screen.clear_chats()
150
+
151
+ async def submit_request(self, chat_screen: ChatScreen, request_text: str) -> bool:
152
+ if self.generating:
153
+ return False
154
+ self.generating = True
155
+ if self.current_thread is None:
156
+ chat_screen.log.info("Starting new thread")
157
+ self.current_thread = await self.thread_id_manager.claim_next_id()
158
+ chat_screen.log.info("Claimed thread id", self.current_thread)
159
+ chat_screen.new_request(request_text)
160
+ response = chat_screen.new_response()
161
+ chat_screen.run_worker(self.stream_response(response, request_text, str(self.current_thread)))
162
+ return True
163
+
164
+
165
+ class Logic:
166
+ """Top-level application logic."""
167
+
168
+ def __init__(
169
+ self,
170
+ username: str | None = None,
171
+ application_start_time: float | None = None,
172
+ checkpoints_sqlite_db: str | Path = dirs.DATA_DIR / "checkpoints.sqlite3",
173
+ app_sqlite_db: str | Path = dirs.DATA_DIR / "app.sqlite3",
174
+ ) -> None:
175
+ """Initialize the application logic.
176
+
177
+ Args:
178
+ username (str | None, optional): The username provided as a command line argument. Defaults to None.
179
+ application_start_time (float | None, optional): The time when the application started. Defaults to None.
180
+ checkpoints_sqlite_db (str | Path, optional): The connection string for the SQLite database used for
181
+ Langchain checkpointing. Defaults to (dirs.DATA_DIR / "checkpoints.sqlite3").
182
+ app_sqlite_db (str | Path, optional): The connection string for the SQLite database used for application
183
+ state such as thread metadata. Defaults to (dirs.DATA_DIR / "app.sqlite3").
184
+ """
185
+ self.logic_start_time = time.time()
186
+ self.username = username
187
+ self.application_start_time = application_start_time
188
+ self.checkpoints_sqlite_db = checkpoints_sqlite_db
189
+ self.app_sqlite_db = app_sqlite_db
190
+
191
+ @asynccontextmanager
192
+ async def runtime(self, app_like: AppLike) -> AsyncIterator[Runtime]:
193
+ """Returns a runtime context for the application."""
194
+ # TODO: Do I need to set check_same_thread=False in aiosqlite.connect?
195
+ async with aiosqlite.connect(database=self.checkpoints_sqlite_db) as checkpoints_conn:
196
+ id_manager = AtomicIDManager(self.app_sqlite_db)
197
+ await id_manager.initialize()
198
+ yield Runtime(self, checkpoints_conn, id_manager, app_like)
199
+
200
+ def probe_os(self) -> str:
201
+ """Returns the OS name (eg 'Linux' or 'Windows'), the system name (eg 'Java'), or an empty string if unknown."""
202
+ return platform.system()
203
+
204
+ def probe_architecture(self) -> str:
205
+ """Returns the machine architecture, such as 'i386'."""
206
+ return platform.machine()
207
+
208
+ def probe_cpu(self) -> str:
209
+ """Returns the name of the CPU, e.g. "Intel(R) Core(TM) i7-10610U CPU @ 1.80GHz"."""
210
+ return cpuinfo.get_cpu_info()["brand_raw"]
211
+
212
+ def probe_ram(self) -> int:
213
+ """Returns the total amount of RAM in bytes."""
214
+ return psutil.virtual_memory().total
215
+
216
+ def probe_disk_space(self) -> int:
217
+ """Returns the amount of free space in the root directory (in bytes)."""
218
+ return psutil.disk_usage("/").free
219
+
220
+ def probe_llamacpp_gpu_support(self) -> bool:
221
+ """Returns True if LlamaCpp supports GPU offloading, False otherwise."""
222
+ return llama_cpp.llama_supports_gpu_offload()
223
+
224
+ def probe_huggingface_free_cache_space(self) -> int | None:
225
+ """Returns the amount of free space in the Hugging Face cache (in bytes), or None if it can't be determined."""
226
+ with contextlib.suppress(FileNotFoundError):
227
+ return psutil.disk_usage(HF_HUB_CACHE).free
228
+ for parent_dir in Path(HF_HUB_CACHE).parents:
229
+ with contextlib.suppress(FileNotFoundError):
230
+ return psutil.disk_usage(str(parent_dir)).free
231
+ return None
232
+
233
+ def probe_huggingface_cached_models(self) -> list[huggingface_hub.CachedRepoInfo] | None:
234
+ """Returns a list of models in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
235
+ # The docstring for huggingface_hub.scan_cache_dir says it raises CacheNotFound "if the cache directory does not
236
+ # exist," and ValueError "if the cache directory is a file, instead of a directory."
237
+ with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
238
+ return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "model"]
239
+ return None # Isn't it nice to be explicit?
240
+
241
+ def probe_huggingface_cached_datasets(self) -> list[huggingface_hub.CachedRepoInfo] | None:
242
+ """Returns a list of datasets in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
243
+ with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
244
+ return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "dataset"]
245
+ return None
246
+
247
+ def probe_nvidia(self) -> tuple[int, list[str]]:
248
+ """Detect available NVIDIA GPUs and CUDA driver version.
249
+
250
+ Returns:
251
+ tuple[int, list[str]]: A tuple (cuda_version, nv_gpus) where cuda_version is the installed CUDA driver
252
+ version and nv_gpus is a list of the installed NVIDIA GPU models. Returns (-1, []) if NVML is unavailable.
253
+ """
254
+ try:
255
+ pynvml.nvmlInit()
256
+ except pynvml.NVMLError:
257
+ return -1, []
258
+ cuda_version = -1
259
+ nv_gpus = []
260
+ try:
261
+ cuda_version = pynvml.nvmlSystemGetCudaDriverVersion()
262
+ for i in range(pynvml.nvmlDeviceGetCount()):
263
+ handle = pynvml.nvmlDeviceGetHandleByIndex(i)
264
+ nv_gpus.append(pynvml.nvmlDeviceGetName(handle))
265
+ except pynvml.NVMLError:
266
+ pass
267
+ finally:
268
+ with contextlib.suppress(pynvml.NVMLError):
269
+ pynvml.nvmlShutdown()
270
+ return cuda_version, nv_gpus
271
+
272
+ def probe_ollama(self) -> list[ollama.ListResponse.Model] | None:
273
+ """Returns a list of models installed in Ollama, or None if connecting to Ollama fails."""
274
+ with contextlib.suppress(ConnectionError):
275
+ return list(ollama.list().models)
276
+ return None
277
+
278
+ def probe_ollama_version(self) -> str | None:
279
+ """Returns the Ollama version string (e.g. "0.13.5"), or None if connecting to Ollama fails."""
280
+ # Yes, this uses private attributes, but that lets me use the Ollama Python lib's env var logic. If you use env
281
+ # vars to direct the app to a different Ollama server, this will query the same Ollama endpoint as the
282
+ # ollama.list() call above. Therefore I silence SLF001 here.
283
+ with contextlib.suppress(httpx.HTTPError, KeyError, ValueError):
284
+ response: httpx.Response = ollama._client._client.request("GET", "/api/version") # noqa: SLF001
285
+ response.raise_for_status()
286
+ return response.json()["version"]
287
+ return None
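The `probe_*` methods above can be exercised on their own, without starting the UI. A quick sketch (output depends entirely on the machine it runs on):

```python
from rag_demo.logic import Logic

logic = Logic()
print(logic.probe_os(), logic.probe_architecture(), logic.probe_cpu())
print(f"RAM: {logic.probe_ram() / 1e9:.1f} GB, free disk: {logic.probe_disk_space() / 1e9:.1f} GB")
print("LlamaCpp GPU offload:", logic.probe_llamacpp_gpu_support())
print("CUDA driver version and NVIDIA GPUs:", logic.probe_nvidia())
print("Ollama models:", logic.probe_ollama())
```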
rag_demo/markdown.py ADDED
@@ -0,0 +1,17 @@
1
+ from markdown_it import MarkdownIt
2
+ from markdown_it.rules_inline import StateInline
3
+
4
+
5
+ def soft2hard_break_plugin(md: MarkdownIt) -> None:
6
+ md.inline.ruler2.push("soft2hard_break", _soft2hard_break_plugin)
7
+
8
+
9
+ def _soft2hard_break_plugin(state: StateInline) -> None:
10
+ for token in state.tokens:
11
+ if token.type == "softbreak":
12
+ token.type = "hardbreak"
13
+
14
+
15
+ def parser_factory() -> MarkdownIt:
16
+ """Modified parser that handles newlines according to LLM conventions."""
17
+ return MarkdownIt("gfm-like").use(soft2hard_break_plugin)
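The practical effect of the plugin is that single newlines in model output render as hard line breaks, which matches how LLMs typically format their replies. A quick check (the input string is illustrative):

```python
from rag_demo.markdown import parser_factory

md = parser_factory()
print(md.render("first line\nsecond line"))  # the single newline now renders as a hard line break
```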
@@ -0,0 +1,3 @@
1
+ from .chat import ChatScreen
2
+ from .config import ConfigScreen
3
+ from .help import HelpScreen
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Protocol, cast
4
+
5
+ from textual.screen import Screen
6
+ from textual.widget import Widget
7
+
8
+ if TYPE_CHECKING:
9
+ from rag_demo.logic import Logic, Runtime
10
+
11
+
12
+ class LogicProvider(Protocol):
13
+ """ABC for classes that contain application logic."""
14
+
15
+ logic: Logic
16
+
17
+ async def runtime(self) -> Runtime: ...
18
+
19
+
20
+ class LogicProviderScreen(Screen):
21
+ """A Screen that provides access to the application logic via its parent app."""
22
+
23
+ @property
24
+ def logic(self) -> Logic:
25
+ """Returns the application logic of the parent app."""
26
+ return cast("LogicProvider", self.app).logic
27
+
28
+ async def runtime(self) -> Runtime:
29
+ """Returns the application runtime of the parent app."""
30
+ return await cast("LogicProvider", self.app).runtime()
31
+
32
+
33
+ class LogicProviderWidget(Widget):
34
+ """A Widget that provides access to the application logic via its parent app."""
35
+
36
+ @property
37
+ def logic(self) -> Logic:
38
+ """Returns the application logic of the parent app."""
39
+ return cast("LogicProvider", self.app).logic
40
+
41
+ async def runtime(self) -> Runtime:
42
+ """Returns the application runtime of the parent app."""
43
+ return await cast("LogicProvider", self.app).runtime()
rag_demo/modes/chat.py ADDED
@@ -0,0 +1,315 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from contextlib import asynccontextmanager
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import pyperclip
9
+ from textual.containers import HorizontalGroup, VerticalGroup, VerticalScroll
10
+ from textual.reactive import reactive
11
+ from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Pretty, Static
12
+ from textual.widgets.markdown import MarkdownStream
13
+
14
+ from rag_demo.markdown import parser_factory
15
+ from rag_demo.modes._logic_provider import LogicProviderScreen, LogicProviderWidget
16
+ from rag_demo.widgets import EscapableInput
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import AsyncIterator
20
+
21
+ from textual.app import ComposeResult
22
+
23
+
24
+ class ResponseStreamInProgressError(ValueError):
25
+ """Exception raised when a Response widget already has an open writer stream."""
26
+
27
+ def __init__(self) -> None: # noqa: D107
28
+ super().__init__("This Response widget already has an open writer stream.")
29
+
30
+
31
+ class StoppedStreamError(ValueError):
32
+ """Exception raised when a ResponseWriter is asked to write after it has been stopped."""
33
+
34
+ def __init__(self) -> None: # noqa: D107
35
+ super().__init__("Can't write to a stopped ResponseWriter stream.")
36
+
37
+
38
+ class ResponseWriter:
39
+ """Stream markdown to a Response widget as if it were a simple Markdown widget.
40
+
41
+ This handles streaming to the Markdown widget and updating the raw text widget and the generation rate label.
42
+
43
+ This class is based on the MarkdownStream class from the Textual library.
44
+ """
45
+
46
+ def __init__(self, response_widget: Response) -> None:
47
+ """Initialize a new ResponseWriter.
48
+
49
+ Args:
50
+ response_widget (Response): The Response widget to write to.
51
+ """
52
+ self.response_widget = response_widget
53
+ self._markdown_widget = response_widget.query_one("#markdown-view", Markdown)
54
+ self._markdown_stream = MarkdownStream(self._markdown_widget)
55
+ self._raw_widget = response_widget.query_one("#raw-view", Label)
56
+ self._start_time: float | None = None
57
+ self._n_chunks = 0
58
+ self._response_text = ""
59
+ self._stopped = False
60
+
61
+ async def stop(self) -> None:
62
+ """Stop this ResponseWriter, particularly its underlying MarkdownStream."""
63
+ self._stopped = True
64
+ # This is safe even if the MarkdownStream has not been started, or has already been stopped.
65
+ await self._markdown_stream.stop()
66
+ # Because of the markdown parsing tweaks I made in src/rag_demo/markdown.py, we need to reparse the final
67
+ # markdown and rerender one more time to clean up small issues with newlines.
68
+ self._markdown_widget.update(self._response_text)
69
+
70
+ async def write(self, markdown_fragment: str) -> None:
71
+ """Stream a single chunk/fragment to the corresponding Response widget.
72
+
73
+ Args:
74
+ markdown_fragment (str): The new markdown fragment to append to the existing markdown.
75
+
76
+ Raises:
77
+ StoppedStreamError: Raised if the new markdown chunk cannot be accepted because the stream has been stopped.
78
+ """
79
+ if self._stopped:
80
+ raise StoppedStreamError
81
+ write_time = time.time()
82
+ self._response_text += markdown_fragment
83
+ self.response_widget.set_reactive(Response.content, self._response_text)
84
+ self._raw_widget.update(self._response_text)
85
+ if self._start_time is None:
86
+ # The first chunk. Can't set the rate label until we have a second chunk.
87
+ self._start_time = write_time
88
+
89
+ self._markdown_widget.update(markdown_fragment)
90
+ self._markdown_stream.start()
91
+ else:
92
+ # The second and subsequent chunks. Note that self._n_chunks has not been incremented yet because
93
+ # we are calculating the generation rate excluding the first chunk, which may have required loading a
94
+ # large model.
95
+ rate = self._n_chunks / (write_time - self._start_time)
96
+ self.response_widget.update_rate_label(rate)
97
+
98
+ await self._markdown_stream.write(markdown_fragment)
99
+ self._n_chunks += 1
100
+
101
+
102
+ class Response(LogicProviderWidget):
103
+ """Allow toggling between raw and rendered versions of markdown text."""
104
+
105
+ show_raw = reactive(False, layout=True)
106
+ content = reactive("", layout=True)
107
+
108
+ def __init__(self, *, content: str = "", classes: str | None = None) -> None:
109
+ """Initialize a new Response widget.
110
+
111
+ Args:
112
+ content (str, optional): Initial response text. Defaults to empty string.
113
+ classes (str | None, optional): Optional widget classes for use with TCSS. Defaults to None.
114
+ """
115
+ super().__init__(classes=classes)
116
+ self.set_reactive(Response.content, content)
117
+ self._stream: ResponseWriter | None = None
118
+ self.__object_to_show_sentinel = object()
119
+ self._object_to_show: Any = self.__object_to_show_sentinel
120
+
121
+ def compose(self) -> ComposeResult:
122
+ """Compose the initial content of the widget."""
123
+ with VerticalGroup():
124
+ with HorizontalGroup(id="header"):
125
+ yield Label("Chunks/s: ???", id="token-rate")
126
+ with HorizontalGroup(id="buttons"):
127
+ yield Button("Stop", id="stop", variant="primary")
128
+ yield Button("Show Raw", id="show-raw", variant="primary")
129
+ yield Button("Copy", id="copy", variant="primary")
130
+ yield Markdown(self.content, id="markdown-view", parser_factory=parser_factory)
131
+ yield Label(self.content, id="raw-view", markup=False)
132
+ yield Pretty(None, id="object-view")
133
+
134
+ def on_mount(self) -> None:
135
+ """Hide certain elements until they are needed."""
136
+ self.query_one("#raw-view", Label).display = False
137
+ self.query_one("#object-view", Pretty).display = False
138
+ self.query_one("#stop", Button).display = False
139
+
140
+ def set_shown_object(self, obj: Any) -> None: # noqa: ANN401
141
+ self._object_to_show = obj
142
+ self.query_one("#markdown-view", Markdown).display = False
143
+ self.query_one("#raw-view", Label).display = False
144
+ self.query_one("#show-raw", Button).display = False
145
+ self.query_one("#object-view", Pretty).update(obj)
146
+ self.query_one("#object-view", Pretty).display = True
147
+
148
+ def clear_shown_object(self) -> None:
149
+ self._object_to_show = self.__object_to_show_sentinel
150
+ self.query_one("#object-view", Pretty).display = False
151
+ if self.show_raw:
152
+ self.query_one("#raw-view", Label).display = True
153
+ else:
154
+ self.query_one("#markdown-view", Markdown).display = True
155
+ self.query_one("#show-raw", Button).display = True
156
+
157
+ @asynccontextmanager
158
+ async def stream_writer(self) -> AsyncIterator[ResponseWriter]:
159
+ """Open an exclusive stream to write markdown in chunks.
160
+
161
+ Raises:
162
+ ResponseStreamInProgressError: Raised when there is already an open stream.
163
+
164
+ Yields:
165
+ ResponseWriter: The new stream writer.
166
+ """
167
+ if self._stream is not None:
168
+ raise ResponseStreamInProgressError
169
+ self._stream = ResponseWriter(self)
170
+ self.query_one("#stop", Button).display = True
171
+ try:
172
+ yield self._stream
173
+ finally:
174
+ await self._stream.stop()
175
+ self.query_one("#stop", Button).display = False
176
+ self._stream = None
177
+
178
+ async def on_button_pressed(self, event: Button.Pressed) -> None:
179
+ """Handle button press events."""
180
+ if event.button.id == "stop":
181
+ if self._stream is not None:
182
+ await self._stream.stop()
183
+ elif event.button.id == "show-raw":
184
+ self.show_raw = not self.show_raw
185
+ elif event.button.id == "copy":
186
+ # Textual and Pyperclip use different methods to copy text to the clipboard. Textual uses ANSI escape
187
+ # sequence magic that is not supported by all terminals. Pyperclip uses OS-specific clipboard APIs, but it
188
+ # does not work over SSH.
189
+ start = time.time()
190
+ self.app.copy_to_clipboard(self.content)
191
+ checkpoint = time.time()
192
+ try:
193
+ pyperclip.copy(self.content)
194
+ except pyperclip.PyperclipException as e:
195
+ self.app.log.error(f"Error copying to clipboard with Pyperclip: {e}")
196
+ checkpoint2 = time.time()
197
+ self.notify(f"Copied {len(self.content.splitlines())} lines of text to clipboard")
198
+ end = time.time()
199
+ self.app.log.info(f"Textual copy took {checkpoint - start:.6f} seconds")
200
+ self.app.log.info(f"Pyperclip copy took {checkpoint2 - checkpoint:.6f} seconds")
201
+ self.app.log.info(f"Notify took {end - checkpoint2:.6f} seconds")
202
+ self.app.log.info(f"Total of {end - start:.6f} seconds")
203
+
204
+ def watch_show_raw(self) -> None:
205
+ """Handle reactive updates to the show_raw attribute by changing the visibility of the child widgets.
206
+
207
+ This also keeps the text on the visibility toggle button up-to-date.
208
+ """
209
+ if self._object_to_show is not self.__object_to_show_sentinel:
210
+ return
211
+ button = self.query_one("#show-raw", Button)
212
+ markdown_view = self.query_one("#markdown-view", Markdown)
213
+ raw_view = self.query_one("#raw-view", Label)
214
+
215
+ if self.show_raw:
216
+ button.label = "Show Rendered"
217
+ markdown_view.display = False
218
+ raw_view.display = True
219
+ else:
220
+ button.label = "Show Raw"
221
+ markdown_view.display = True
222
+ raw_view.display = False
223
+
224
+ def watch_content(self, content: str) -> None:
225
+ """Handle reactive updates to the content attribute by updating the markdown and raw views.
226
+
227
+ Args:
228
+ content (str): New content for the widget.
229
+ """
230
+ self.query_one("#markdown-view", Markdown).update(content)
231
+ self.query_one("#raw-view", Label).update(content)
232
+
233
+ def update_rate_label(self, rate: float | None) -> None:
234
+ """Update or reset the generation rate indicator.
235
+
236
+ Args:
237
+ rate (float | None): Generation rate in chunks per second, or None to reset the label.
238
+ """
239
+ label_text = "Chunks/s: ???" if rate is None else f"Chunks/s: {rate:.2f}"
240
+ self.query_one("#token-rate", Label).update(label_text)
241
+
242
+
243
+ class ChatScreen(LogicProviderScreen):
244
+ """Main mode of the app. Talk to the AI agent."""
245
+
246
+ SUB_TITLE = "Chat"
247
+ CSS_PATH = Path(__file__).parent / "chat.tcss"
248
+
249
+ def compose(self) -> ComposeResult:
250
+ """Compose the initial content of the chat screen."""
251
+ yield Header()
252
+ chats = VerticalScroll(id="chats")
253
+ with chats:
254
+ yield HorizontalGroup(id="top-chat-separator")
255
+ with HorizontalGroup(id="new-request-bar"):
256
+ yield Static()
257
+ yield Button("New Conversation", id="new-conversation")
258
+ yield EscapableInput(placeholder=" What do you want to know?", id="new-request", focus_on_escape=chats)
259
+ yield Static()
260
+ yield Footer()
261
+
262
+ def on_mount(self) -> None:
263
+ """When the screen is mounted, focus the input field and enable bottom anchoring for the message view."""
264
+ self.query_one("#new-request", Input).focus()
265
+ self.query_one("#chats", VerticalScroll).anchor()
266
+
267
+ async def on_button_pressed(self, event: Button.Pressed) -> None:
268
+ """Handle button press events."""
269
+ if event.button.id == "new-conversation":
270
+ (await self.runtime()).new_conversation(self)
271
+
272
+ async def on_input_submitted(self, event: Input.Submitted) -> None:
273
+ """Handle submission of new requests."""
274
+ if event.input.id == "new-request":
275
+ accepted = await (await self.runtime()).submit_request(self, event.value)
276
+ if accepted:
277
+ self.query_one("#new-request", Input).value = ""
278
+
279
+ def clear_chats(self) -> None:
280
+ """Clear the chat scroll area."""
281
+ chats = self.query_one("#chats", VerticalScroll)
282
+ for child in chats.children:
283
+ if child.id != "top-chat-separator":
284
+ child.remove()
285
+
286
+ def new_request(self, request_text: str) -> Label:
287
+ """Create a new request element in the chat area.
288
+
289
+ Args:
290
+ request_text (str): The text of the request.
291
+
292
+ Returns:
293
+ Label: The request element.
294
+ """
295
+ chats = self.query_one("#chats", VerticalScroll)
296
+ request = Label(request_text, classes="request")
297
+ chats.mount(HorizontalGroup(request, classes="request-container"))
298
+ chats.anchor()
299
+ return request
300
+
301
+ def new_response(self, response_text: str = "Waiting for AI to respond...") -> Response:
302
+ """Create a new response element in the chat area.
303
+
304
+ Args:
305
+ response_text (str, optional): Initial response text. Usually this is a default message shown before
306
+ streaming the actual response. Defaults to "Waiting for AI to respond...".
307
+
308
+ Returns:
309
+ Response: The response widget/element.
310
+ """
311
+ chats = self.query_one("#chats", VerticalScroll)
312
+ response = Response(content=response_text, classes="response")
313
+ chats.mount(HorizontalGroup(response, classes="response-container"))
314
+ chats.anchor()
315
+ return response
@@ -0,0 +1,75 @@
1
+ ToastRack {
2
+ margin-bottom: 5;
3
+ }
4
+
5
+ #chats {
6
+ width: 100%;
7
+ hatch: left $primary-darken-2 100%;
8
+ }
9
+
10
+ #top-chat-separator {
11
+ hatch: left $primary-darken-2 100%;
12
+ height: 1;
13
+ }
14
+
15
+ .request {
16
+ min-width: 30%;
17
+ max-width: 90%;
18
+ padding: 1 2 1 2;
19
+ margin: 0 2 1 0;
20
+ background: $boost;
21
+ }
22
+
23
+ .request-container {
24
+ align-horizontal: right;
25
+ hatch: left $primary-darken-2 100%;
26
+ }
27
+
28
+ .response-container {
29
+ align-horizontal: left;
30
+ hatch: left $primary-darken-2 100%;
31
+ }
32
+
33
+ .response {
34
+ margin: 0 0 1 2;
35
+ width: 90%;
36
+ height: auto;
37
+ }
38
+
39
+ .response #header {
40
+ background: $primary-darken-3;
41
+ margin-bottom: 1;
42
+ }
43
+
44
+ .response #buttons {
45
+ align-horizontal: right;
46
+ }
47
+
48
+ .response Button {
49
+ border: none;
50
+ min-width: 18;
51
+ margin-left: 1;
52
+ }
53
+
54
+ #token-rate {
55
+ padding-left: 1;
56
+ }
57
+
58
+ #raw-view {
59
+ padding: 0 2 1 2;
60
+ }
61
+
62
+ #new-request-bar > Static {
63
+ width: 5%;
64
+ height: 3;
65
+ hatch: cross $surface 100%;
66
+ }
67
+
68
+ #new-conversation {
69
+ min-width: 20;
70
+ }
71
+
72
+ #new-request {
73
+ width: 1fr;
74
+ height: 3;
75
+ }
@@ -0,0 +1,77 @@
1
+ from pathlib import Path
2
+
3
+ from textual.app import ComposeResult
4
+ from textual.containers import Container, Horizontal
5
+ from textual.widgets import (
6
+ Button,
7
+ Footer,
8
+ Header,
9
+ Input,
10
+ Label,
11
+ RadioButton,
12
+ RadioSet,
13
+ Static,
14
+ )
15
+
16
+ from rag_demo.modes._logic_provider import LogicProviderScreen
17
+
18
+
19
+ class ConfigScreen(LogicProviderScreen):
20
+ SUB_TITLE = "Configure"
21
+ CSS_PATH = Path(__file__).parent / "config.tcss"
22
+
23
+ def compose(self) -> ComposeResult:
24
+ yield Header()
25
+ yield Container(
26
+ Static("🤖 LLM Configuration", classes="title"),
27
+ Label("Select your LLM provider:"),
28
+ RadioSet(
29
+ RadioButton("OpenAI (API)", id="openai"),
30
+ RadioButton("Anthropic Claude (API)", id="anthropic"),
31
+ RadioButton("Ollama (Local)", id="ollama"),
32
+ RadioButton("LlamaCpp (Local)", id="llamacpp"),
33
+ id="provider",
34
+ ),
35
+ Label("Model name:"),
36
+ Input(placeholder="e.g., gpt-4, claude-3-sonnet-20240229", id="model"),
37
+ Label("API Key (if applicable):"),
38
+ Input(placeholder="sk-...", password=True, id="api-key"),
39
+ Label("Base URL (for Ollama):"),
40
+ Input(placeholder="http://localhost:11434", id="base-url"),
41
+ Label("Model Path (for LlamaCpp):"),
42
+ Input(placeholder="/path/to/model.gguf", id="model-path"),
43
+ Horizontal(
44
+ Button("Save & Continue", variant="primary", id="save"),
45
+ Button("Cancel", id="cancel"),
46
+ ),
47
+ )
48
+ yield Footer()
49
+
50
+ def on_button_pressed(self, event: Button.Pressed) -> None:
51
+ if event.button.id == "save":
52
+ config = self.collect_config()
53
+ self.app.config_manager.save_config(config)
54
+ self.app.pop_screen() # Return to main app
55
+ elif event.button.id == "cancel":
56
+ self.app.exit()
57
+
58
+ def collect_config(self) -> dict:
59
+ provider = self.query_one("#provider", RadioSet).pressed_button.id
60
+ model = self.query_one("#model", Input).value
61
+ api_key = self.query_one("#api-key", Input).value
62
+ base_url = self.query_one("#base-url", Input).value
63
+ model_path = self.query_one("#model-path", Input).value
64
+
65
+ config = {
66
+ "provider": provider,
67
+ "model": model,
68
+ }
69
+
70
+ if api_key:
71
+ config["api_key"] = api_key
72
+ if base_url:
73
+ config["base_url"] = base_url
74
+ if model_path:
75
+ config["model_path"] = model_path
76
+
77
+ return config
File without changes
rag_demo/modes/help.py ADDED
@@ -0,0 +1,26 @@
1
+ from pathlib import Path
2
+
3
+ from textual.app import ComposeResult
4
+ from textual.widgets import Footer, Header, Label
5
+
6
+ from rag_demo.modes._logic_provider import LogicProviderScreen
7
+
8
+
9
+ class HelpScreen(LogicProviderScreen):
10
+ """Display information about the application."""
11
+
12
+ SUB_TITLE = "Help"
13
+ CSS_PATH = Path(__file__).parent / "help.tcss"
14
+
15
+ def compose(self) -> ComposeResult:
16
+ """Create the widgets of the help screen.
17
+
18
+ Returns:
19
+ ComposeResult: composition of the help screen
20
+
21
+ Yields:
22
+ Iterator[ComposeResult]: composition of the help screen
23
+ """
24
+ yield Header()
25
+ yield Label("Help Screen (under construction)")
26
+ yield Footer()
File without changes
@@ -0,0 +1 @@
1
+ from .escapable_input import EscapableInput
@@ -0,0 +1,110 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from textual.widgets import Input
6
+
7
+ if TYPE_CHECKING:
8
+ from collections.abc import Iterable
9
+
10
+ from rich.console import RenderableType
11
+ from rich.highlighter import Highlighter
12
+ from textual.events import Key
13
+ from textual.suggester import Suggester
14
+ from textual.validation import Validator
15
+ from textual.widget import Widget
16
+ from textual.widgets._input import InputType, InputValidationOn
17
+
18
+
19
+ class EscapableInput(Input):
20
+ """An input widget that deselects itself when the user presses escape.
21
+
22
+ Inherits all properties and methods from the :class:`textual.widgets.Input` class.
23
+ """
24
+
25
+ def __init__( # noqa: PLR0913
26
+ self,
27
+ value: str | None = None,
28
+ placeholder: str = "",
29
+ highlighter: Highlighter | None = None,
30
+ password: bool = False, # noqa: FBT001, FBT002
31
+ *,
32
+ restrict: str | None = None,
33
+ type: InputType = "text", # noqa: A002
34
+ max_length: int = 0,
35
+ suggester: Suggester | None = None,
36
+ validators: Validator | Iterable[Validator] | None = None,
37
+ validate_on: Iterable[InputValidationOn] | None = None,
38
+ valid_empty: bool = False,
39
+ select_on_focus: bool = True,
40
+ name: str | None = None,
41
+ id: str | None = None, # noqa: A002
42
+ classes: str | None = None,
43
+ disabled: bool = False,
44
+ tooltip: RenderableType | None = None,
45
+ compact: bool = False,
46
+ focus_on_escape: Widget | None = None,
47
+ ) -> None:
48
+ """Initialise the `EscapableInput` widget.
49
+
50
+ Args:
51
+ value: An optional default value for the input.
52
+ placeholder: Optional placeholder text for the input.
53
+ highlighter: An optional highlighter for the input.
54
+ password: Flag to say if the field should obfuscate its content.
55
+ restrict: A regex to restrict character inputs.
56
+ type: The type of the input.
57
+ max_length: The maximum length of the input, or 0 for no maximum length.
58
+ suggester: [`Suggester`][textual.suggester.Suggester] associated with this
59
+ input instance.
60
+ validators: An iterable of validators that the Input value will be checked against.
61
+ validate_on: Zero or more of the values "blur", "changed", and "submitted",
62
+ which determine when to do input validation. The default is to do
63
+ validation for all messages.
64
+ valid_empty: Empty values are valid.
65
+ select_on_focus: Whether to select all text on focus.
66
+ name: Optional name for the input widget.
67
+ id: Optional ID for the widget.
68
+ classes: Optional initial classes for the widget.
69
+ disabled: Whether the input is disabled or not.
70
+ tooltip: Optional tooltip.
71
+ compact: Enable compact style (without borders).
72
+ focus_on_escape: An optional widget to focus on when escape is pressed. Defaults to `None`.
73
+ """
74
+ super().__init__(
75
+ value=value,
76
+ placeholder=placeholder,
77
+ highlighter=highlighter,
78
+ password=password,
79
+ restrict=restrict,
80
+ type=type,
81
+ max_length=max_length,
82
+ suggester=suggester,
83
+ validators=validators,
84
+ validate_on=validate_on,
85
+ valid_empty=valid_empty,
86
+ select_on_focus=select_on_focus,
87
+ name=name,
88
+ id=id,
89
+ classes=classes,
90
+ disabled=disabled,
91
+ tooltip=tooltip,
92
+ compact=compact,
93
+ )
94
+ self.focus_on_escape = focus_on_escape
95
+
96
+ def on_key(self, event: Key) -> None:
97
+ """Deselect the input if the event is the escape key.
98
+
99
+ This method overrides the base :meth:`textual.widgets.Input.on_key` implementation.
100
+
101
+ Args:
102
+ event (Key): Event details, including the key pressed.
103
+ """
104
+ if event.key == "escape":
105
+ if self.focus_on_escape is not None:
106
+ self.focus_on_escape.focus()
107
+ else:
108
+ self.blur()
109
+ event.prevent_default()
110
+ event.stop()
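A minimal sketch of wiring `EscapableInput` into a Textual app (ids and placeholder text are illustrative); this mirrors how the chat screen hands focus back to its scroll area when the input is escaped:

```python
from textual.app import App, ComposeResult
from textual.containers import VerticalScroll

from rag_demo.widgets import EscapableInput


class Demo(App):
    def compose(self) -> ComposeResult:
        body = VerticalScroll(id="body")
        yield body
        # Pressing Escape inside the input moves focus to the scroll area instead of leaving it in the input.
        yield EscapableInput(placeholder="Type here, press Escape to leave", focus_on_escape=body)


if __name__ == "__main__":
    Demo().run()
```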
@@ -1,11 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: jehoctor-rag-demo
3
- Version: 0.1.1.dev1
4
- Summary: Chat with Wikipedia
5
- Author: James Hoctor
6
- Author-email: James Hoctor <JEHoctor@protonmail.com>
7
- Requires-Python: >=3.13
8
- Description-Content-Type: text/markdown
9
-
10
- # RAG-demo
11
- Chat with Wikipedia
@@ -1,6 +0,0 @@
1
- rag_demo/__init__.py,sha256=STIqC0dNmRNCbDeEDuODLXSsOmuQsB5MC9Ii3el0TDk,54
2
- rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- jehoctor_rag_demo-0.1.1.dev1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
4
- jehoctor_rag_demo-0.1.1.dev1.dist-info/entry_points.txt,sha256=3XjzSTMUH0sKTfuLk5yj8ilz0M-SF4_UuZAu-MVH8qM,40
5
- jehoctor_rag_demo-0.1.1.dev1.dist-info/METADATA,sha256=drDoEBvOOWd5XW3DCedSPqLq3HXkcJquTJ0A4HLi1EU,265
6
- jehoctor_rag_demo-0.1.1.dev1.dist-info/RECORD,,
@@ -1,3 +0,0 @@
1
- [console_scripts]
2
- chat = rag_demo:main
3
-