jehoctor-rag-demo 0.1.1.dev1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jehoctor_rag_demo-0.2.0.dist-info/METADATA +100 -0
- jehoctor_rag_demo-0.2.0.dist-info/RECORD +23 -0
- jehoctor_rag_demo-0.2.0.dist-info/entry_points.txt +3 -0
- rag_demo/__init__.py +0 -2
- rag_demo/__main__.py +31 -0
- rag_demo/app.py +58 -0
- rag_demo/app.tcss +0 -0
- rag_demo/db.py +87 -0
- rag_demo/dirs.py +14 -0
- rag_demo/logic.py +287 -0
- rag_demo/markdown.py +17 -0
- rag_demo/modes/__init__.py +3 -0
- rag_demo/modes/_logic_provider.py +43 -0
- rag_demo/modes/chat.py +315 -0
- rag_demo/modes/chat.tcss +75 -0
- rag_demo/modes/config.py +77 -0
- rag_demo/modes/config.tcss +0 -0
- rag_demo/modes/help.py +26 -0
- rag_demo/modes/help.tcss +0 -0
- rag_demo/widgets/__init__.py +1 -0
- rag_demo/widgets/escapable_input.py +110 -0
- jehoctor_rag_demo-0.1.1.dev1.dist-info/METADATA +0 -11
- jehoctor_rag_demo-0.1.1.dev1.dist-info/RECORD +0 -6
- jehoctor_rag_demo-0.1.1.dev1.dist-info/entry_points.txt +0 -3
- {jehoctor_rag_demo-0.1.1.dev1.dist-info → jehoctor_rag_demo-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: jehoctor-rag-demo
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Chat with Wikipedia
|
|
5
|
+
Author: James Hoctor
|
|
6
|
+
Author-email: James Hoctor <JEHoctor@protonmail.com>
|
|
7
|
+
Requires-Dist: aiosqlite==0.21.0
|
|
8
|
+
Requires-Dist: chromadb>=1.3.4
|
|
9
|
+
Requires-Dist: datasets>=4.4.1
|
|
10
|
+
Requires-Dist: httpx>=0.28.1
|
|
11
|
+
Requires-Dist: huggingface-hub>=0.36.0
|
|
12
|
+
Requires-Dist: langchain>=1.0.5
|
|
13
|
+
Requires-Dist: langchain-anthropic>=1.0.2
|
|
14
|
+
Requires-Dist: langchain-community>=0.4.1
|
|
15
|
+
Requires-Dist: langchain-huggingface>=1.1.0
|
|
16
|
+
Requires-Dist: langchain-ollama>=1.0.0
|
|
17
|
+
Requires-Dist: langchain-openai>=1.0.2
|
|
18
|
+
Requires-Dist: langgraph-checkpoint-sqlite>=3.0.1
|
|
19
|
+
Requires-Dist: llama-cpp-python>=0.3.16
|
|
20
|
+
Requires-Dist: nvidia-ml-py>=13.590.44
|
|
21
|
+
Requires-Dist: ollama>=0.6.0
|
|
22
|
+
Requires-Dist: platformdirs>=4.5.0
|
|
23
|
+
Requires-Dist: psutil>=7.1.3
|
|
24
|
+
Requires-Dist: py-cpuinfo>=9.0.0
|
|
25
|
+
Requires-Dist: pydantic>=2.12.4
|
|
26
|
+
Requires-Dist: pyperclip>=1.11.0
|
|
27
|
+
Requires-Dist: textual>=6.5.0
|
|
28
|
+
Requires-Dist: typer>=0.20.0
|
|
29
|
+
Requires-Python: >=3.12
|
|
30
|
+
Description-Content-Type: text/markdown
|
|
31
|
+
|
|
32
|
+
# RAG-demo
|
|
33
|
+
|
|
34
|
+
Chat with (a small portion of) Wikipedia
|
|
35
|
+
|
|
36
|
+
⚠️ RAG functionality is still under development. ⚠️
|
|
37
|
+
|
|
38
|
+
")
|
|
39
|
+
|
|
40
|
+
## Requirements
|
|
41
|
+
|
|
42
|
+
1. [uv](https://docs.astral.sh/uv/)
|
|
43
|
+
2. At least one of the following:
|
|
44
|
+
- A suitable terminal emulator. In particular, on macOS consider using [iTerm2](https://iterm2.com/) instead of the default Terminal.app ([explanation](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos)). On Linux, you might want to try [kitty](https://sw.kovidgoyal.net/kitty/), [wezterm](https://wezterm.org/), [alacritty](https://alacritty.org/), or [ghostty](https://ghostty.org/) instead of the terminal that came with your DE ([reason](https://darren.codes/posts/textual-copy-paste/)). Windows Terminal should be fine as far as I know.
|
|
45
|
+
- Any common web browser
|
|
46
|
+
|
|
47
|
+
## Optional stuff that could make your experience better
|
|
48
|
+
|
|
49
|
+
1. [Hugging Face login](https://huggingface.co/docs/huggingface_hub/quick-start#login)
|
|
50
|
+
2. API key for your favorite LLM provider (support coming soon)
|
|
51
|
+
3. Ollama installed on your system if you have a GPU
|
|
52
|
+
4. Run RAG-demo on a more capable (bigger GPU) machine over SSH if you can. It is a terminal app after all.
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
## Run from the repository
|
|
56
|
+
|
|
57
|
+
First, clone this repository. Then, run one of the options below.
|
|
58
|
+
|
|
59
|
+
Run in a terminal:
|
|
60
|
+
```bash
|
|
61
|
+
uv run chat
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
Or run in a web browser:
|
|
65
|
+
```bash
|
|
66
|
+
uv run textual serve chat
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Run from the latest version on PyPI
|
|
70
|
+
|
|
71
|
+
TODO: test uv automatic torch backend selection:
|
|
72
|
+
https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection
|
|
73
|
+
|
|
74
|
+
Run in a terminal:
|
|
75
|
+
```bash
|
|
76
|
+
uvx --from=jehoctor-rag-demo chat
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
Or run in a web browser:
|
|
80
|
+
```bash
|
|
81
|
+
uvx --from=jehoctor-rag-demo textual serve chat
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
## CUDA acceleration via Llama.cpp
|
|
85
|
+
|
|
86
|
+
If you have an NVIDIA GPU with CUDA and build tools installed, you might be able to get CUDA acceleration without installing Ollama.
|
|
87
|
+
|
|
88
|
+
```bash
|
|
89
|
+
CMAKE_ARGS="-DGGML_CUDA=on" uv run chat
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
## Metal acceleration via Llama.cpp (on Apple Silicon)
|
|
93
|
+
|
|
94
|
+
On an Apple Silicon machine, make sure `uv` runs an ARM interpreter as this should cause it to install Llama.cpp with Metal support.
|
|
95
|
+
|
|
96
|
+
## Ollama on Linux
|
|
97
|
+
|
|
98
|
+
Remember that you have to keep Ollama up-to-date manually on Linux.
|
|
99
|
+
A recent version of Ollama (v0.11.10 or later) is required to run the [embedding model we use](https://ollama.com/library/embeddinggemma).
|
|
100
|
+
See this FAQ: https://docs.ollama.com/faq#how-can-i-upgrade-ollama.
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
rag_demo/__main__.py,sha256=Kak0eQWBRHVGDoWgoHs9j-Tvf_9DMzdurMxD7EM4Jr0,1054
|
|
3
|
+
rag_demo/app.py,sha256=xejrtFApeTeyOQvWDq1H0XPyZEr8cQPn7q9KRwnV660,1812
|
|
4
|
+
rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
+
rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
|
|
6
|
+
rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
|
|
7
|
+
rag_demo/logic.py,sha256=7PTWPs9xZJ7bbEtNDMQTX6SX4JKG8HMiq2H_YUfM-CI,12602
|
|
8
|
+
rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
|
|
9
|
+
rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
|
|
10
|
+
rag_demo/modes/_logic_provider.py,sha256=__eO4XVbyRHkjV_D8OHsPJX5f2R8JoJPcNXhi-w_xFY,1277
|
|
11
|
+
rag_demo/modes/chat.py,sha256=VigWSkw6R2ea95-wZ8tgtKIccev9A-ByzJj7nzglsog,13444
|
|
12
|
+
rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
|
|
13
|
+
rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
|
|
14
|
+
rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
+
rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
|
|
16
|
+
rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
|
+
rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
+
rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
|
|
19
|
+
rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
|
|
20
|
+
jehoctor_rag_demo-0.2.0.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
|
|
21
|
+
jehoctor_rag_demo-0.2.0.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
|
|
22
|
+
jehoctor_rag_demo-0.2.0.dist-info/METADATA,sha256=wp1mdAqjB0be_1Uly4hwAoz0bjRUDI6gb6gK5SdrHRU,3531
|
|
23
|
+
jehoctor_rag_demo-0.2.0.dist-info/RECORD,,
|
rag_demo/__init__.py
CHANGED
rag_demo/__main__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import time

# Measure the application start time as early as possible so startup
# reporting covers everything after process launch, including imports below.
APPLICATION_START_TIME = time.time()

# Disable "module import not at top of file" (aka E402) when importing Typer. This is necessary so that Typer's
# initialization is included in the application startup time.
import typer  # noqa: E402


def _main(
    # Fixed garbled help text ("you want to want ... with you").
    name: str | None = typer.Option(None, help="The name you want the AI to use for you."),
) -> None:
    """Talk to Wikipedia."""
    # Import here so that imports run within the typer.run context for prettier stack traces if errors occur.
    # We ignore PLC0415 because we do not want these imports to be at the top of the module as is usually preferred.
    from rag_demo.app import RAGDemo  # noqa: PLC0415
    from rag_demo.logic import Logic  # noqa: PLC0415

    logic = Logic(username=name, application_start_time=APPLICATION_START_TIME)
    app = RAGDemo(logic)
    app.run()


def main() -> None:
    """Entrypoint for the rag demo, specifically the `chat` command."""
    typer.run(_main)


if __name__ == "__main__":
    main()
|
rag_demo/app.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
6
|
+
|
|
7
|
+
from textual.app import App
|
|
8
|
+
from textual.binding import Binding
|
|
9
|
+
|
|
10
|
+
from rag_demo.modes import ChatScreen, ConfigScreen, HelpScreen
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from rag_demo.logic import Logic, Runtime
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RAGDemo(App):
    """Main application UI.

    This class is responsible for creating the modes of the application, which are defined in :mod:`rag_demo.modes`.
    """

    TITLE = "RAG Demo"
    # Stylesheet lives next to this module.
    CSS_PATH = Path(__file__).parent / "app.tcss"
    BINDINGS: ClassVar = [
        Binding("z", "switch_mode('chat')", "chat"),
        Binding("c", "switch_mode('config')", "configure"),
        Binding("h", "switch_mode('help')", "help"),
    ]
    # Mode name -> screen class, targeted by the switch_mode bindings above.
    MODES: ClassVar = {
        "chat": ChatScreen,
        "config": ConfigScreen,
        "help": HelpScreen,
    }

    def __init__(self, logic: Logic) -> None:
        """Initialize the main app.

        Args:
            logic (Logic): Object implementing the application logic.
        """
        super().__init__()
        self.logic = logic
        # Resolved by _hold_runtime once async initialization completes;
        # runtime() below awaits this future.
        self._runtime_future: asyncio.Future[Runtime] = asyncio.Future()

    async def on_mount(self) -> None:
        """Set the initial mode to chat and initialize async parts of the logic."""
        self.switch_mode("chat")
        self.run_worker(self._hold_runtime())

    async def _hold_runtime(self) -> None:
        # Keep the runtime context manager open for the app's whole lifetime.
        async with self.logic.runtime(app_like=self) as runtime:
            self._runtime_future.set_result(runtime)
            # Pause the task until Textual cancels it when the application closes.
            await asyncio.Event().wait()

    async def runtime(self) -> Runtime:
        """Returns the application runtime logic."""
        return await self._runtime_future
|
rag_demo/app.tcss
ADDED
|
File without changes
|
rag_demo/db.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
import aiosqlite
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AtomicIDManager:
    """A database manager for managing thread IDs.

    This was written by Claude, and I fixed it up with feedback from Ruff and Flake8.
    Maybe one day the app logic database will require something fancier, but this gets the job done now.

    As you can see from the conversation with Claude, this was quite a simple task for it:
    https://claude.ai/share/227d08ff-96a3-495a-9f56-509a1fd528f7
    """

    def __init__(self, db_path: str | Path) -> None:
        """Initialize the database manager."""
        self.db_path = db_path

    async def initialize(self) -> None:
        """Initialize the database and create the table if it doesn't exist."""
        async with aiosqlite.connect(self.db_path) as conn:
            # WAL mode gives readers and the single writer better concurrency.
            await conn.execute("PRAGMA journal_mode=WAL")

            await conn.execute("""
                CREATE TABLE IF NOT EXISTS claimed_ids (
                    id INTEGER PRIMARY KEY
                )
            """)
            await conn.commit()

    async def claim_next_id(self) -> int:
        """Atomically find the max id, increment it, and claim it. Returns the newly claimed ID.

        This operation is atomic and multiprocess-safe because:
        1. SQLite serializes writes by default
        2. We use IMMEDIATE transaction to acquire write lock immediately
        3. The entire operation happens in a single transaction
        """
        async with aiosqlite.connect(self.db_path) as conn:
            # Take the write lock up front so no other writer can interleave
            # between our SELECT and INSERT.
            await conn.execute("BEGIN IMMEDIATE")

            try:
                async with conn.execute("SELECT MAX(id) FROM claimed_ids") as cursor:
                    first_row = await cursor.fetchone()
                # MAX(id) is NULL when the table is empty; treat both "no row"
                # and NULL as zero.
                highest = first_row[0] if first_row is not None and first_row[0] is not None else 0

                claimed = highest + 1
                await conn.execute("INSERT INTO claimed_ids (id) VALUES (?)", (claimed,))
                await conn.commit()
            except Exception:
                # Release the lock on any failure before propagating.
                await conn.rollback()
                raise
            else:
                return claimed

    async def get_all_claimed_ids(self) -> list[int]:
        """Retrieve all claimed IDs."""
        async with (
            aiosqlite.connect(self.db_path) as conn,
            conn.execute("SELECT id FROM claimed_ids ORDER BY id") as cursor,
        ):
            return [claimed_id for (claimed_id,) in await cursor.fetchall()]

    async def get_count(self) -> int:
        """Get the total number of claimed IDs."""
        async with aiosqlite.connect(self.db_path) as conn, conn.execute("SELECT COUNT(*) FROM claimed_ids") as cursor:
            count_row = await cursor.fetchone()
            if count_row is None:
                raise ValueError("A SQL COUNT query should always return at least one row")  # noqa: EM101, TRY003
            return count_row[0]
|
rag_demo/dirs.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from platformdirs import PlatformDirs
|
|
4
|
+
|
|
5
|
+
# Shared platformdirs handle for the whole app. ensure_exists=True asks
# platformdirs to create each user_* directory when its path is accessed
# (see the platformdirs docs — confirm this matches the installed version).
_appdirs = PlatformDirs(appname="jehoctor-rag-demo", ensure_exists=True)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _ensure(dir_: Path) -> Path:
    """Create *dir_* (and any missing parents) if needed, then return it unchanged."""
    # exist_ok makes repeated calls idempotent.
    dir_.mkdir(parents=True, exist_ok=True)
    return dir_
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Public directory constants consumed elsewhere in the package
# (e.g. rag_demo.logic uses DATA_DIR for its SQLite databases).
DATA_DIR = _appdirs.user_data_path
CONFIG_DIR = _appdirs.user_config_path
|
rag_demo/logic.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import platform
|
|
5
|
+
import time
|
|
6
|
+
from contextlib import asynccontextmanager
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Protocol, TypeVar, cast
|
|
9
|
+
|
|
10
|
+
import aiosqlite
|
|
11
|
+
import cpuinfo
|
|
12
|
+
import httpx
|
|
13
|
+
import huggingface_hub
|
|
14
|
+
import llama_cpp
|
|
15
|
+
import ollama
|
|
16
|
+
import psutil
|
|
17
|
+
import pynvml
|
|
18
|
+
from datasets import Dataset, load_dataset
|
|
19
|
+
from huggingface_hub import hf_hub_download
|
|
20
|
+
from huggingface_hub.constants import HF_HUB_CACHE
|
|
21
|
+
from langchain.agents import create_agent
|
|
22
|
+
from langchain.messages import AIMessageChunk, HumanMessage
|
|
23
|
+
from langchain_community.chat_models import ChatLlamaCpp
|
|
24
|
+
from langchain_community.embeddings import LlamaCppEmbeddings
|
|
25
|
+
from langchain_core.exceptions import LangChainException
|
|
26
|
+
from langchain_ollama import ChatOllama, OllamaEmbeddings
|
|
27
|
+
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
|
28
|
+
|
|
29
|
+
from rag_demo import dirs
|
|
30
|
+
from rag_demo.db import AtomicIDManager
|
|
31
|
+
from rag_demo.modes.chat import Response, StoppedStreamError
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from collections.abc import AsyncIterator, Awaitable
|
|
35
|
+
|
|
36
|
+
from textual.worker import Worker
|
|
37
|
+
|
|
38
|
+
from rag_demo.modes import ChatScreen
|
|
39
|
+
|
|
40
|
+
ResultType = TypeVar("ResultType")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AppLike(Protocol):
    """Protocol for the subset of what the main App can do that the runtime needs."""

    # Structural typing: anything with a compatible run_worker satisfies this,
    # so Runtime does not need a real textual App.
    def run_worker(self, work: Awaitable[ResultType]) -> Worker[ResultType]:
        """Run a coroutine in the background.

        See https://textual.textualize.io/guide/workers/.

        Args:
            work (Awaitable[ResultType]): The coroutine to run.
        """
        ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Runtime:
    """The application logic with asynchronously initialized resources."""

    def __init__(
        self,
        logic: Logic,
        checkpoints_conn: aiosqlite.Connection,
        thread_id_manager: AtomicIDManager,
        app_like: AppLike,
    ) -> None:
        """Initialize the runtime.

        Args:
            logic (Logic): The top-level application logic that created this runtime.
            checkpoints_conn (aiosqlite.Connection): Open connection used for Langchain checkpointing.
            thread_id_manager (AtomicIDManager): Allocator of unique conversation thread IDs.
            app_like (AppLike): Object able to run background workers (normally the main App).
        """
        self.runtime_start_time = time.time()
        self.logic = logic
        self.checkpoints_conn = checkpoints_conn
        self.thread_id_manager = thread_id_manager
        self.app_like = app_like

        # No conversation yet; a thread id is claimed lazily on first request.
        self.current_thread: int | None = None
        self.generating = False

        if self.logic.probe_ollama() is not None:
            # Ollama is reachable: use it for both chat and embeddings.
            ollama.pull("gemma3:latest")  # 3.3GB
            ollama.pull("embeddinggemma:latest")  # 621MB
            self.llm = ChatOllama(
                model="gemma3:latest",
                validate_model_on_init=True,
                temperature=0.5,
                num_predict=4096,
            )
            self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
        else:
            # Fall back to in-process llama.cpp with GGUF models from the HF Hub.
            model_path = hf_hub_download(
                repo_id="bartowski/google_gemma-3-4b-it-GGUF",
                filename="google_gemma-3-4b-it-Q6_K_L.gguf",  # 3.35GB
                revision="71506238f970075ca85125cd749c28b1b0eee84e",
            )
            embedding_model_path = hf_hub_download(
                repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
                filename="bge-small-en-v1.5-q8_0.gguf",  # 36.8MB
                revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
            )
            self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
            self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False)  # pyright: ignore[reportCallIssue]

        self.agent = create_agent(
            model=self.llm,
            system_prompt="You are a helpful assistant.",
            checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
        )

    def get_rag_datasets(self) -> None:
        """Load the RAG question-answer and corpus datasets from the HF Hub onto self."""
        self.qa_test: Dataset = cast(
            "Dataset",
            load_dataset("rag-datasets/rag-mini-wikipedia", "question-answer", split="test"),
        )
        self.corpus: Dataset = cast(
            "Dataset",
            load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages"),
        )

    async def stream_response(self, response_widget: Response, request_text: str, thread: str) -> None:
        """Worker method for streaming tokens from the active agent to a response widget.

        Args:
            response_widget (Response): Target response widget for streamed tokens.
            request_text (str): Text of the user request.
            thread (str): ID of the current thread.
        """
        self.generating = True
        # BUGFIX: reset self.generating in a finally block. Previously an
        # unexpected exception (anything other than StoppedStreamError or
        # LangChainException) left the flag stuck at True, after which
        # submit_request() rejected every subsequent request.
        try:
            async with response_widget.stream_writer() as writer:
                agent_stream = self.agent.astream(
                    {"messages": [HumanMessage(content=request_text)]},
                    {"configurable": {"thread_id": thread}},
                    stream_mode="messages",
                )
                try:
                    async for message_chunk, _ in agent_stream:
                        if isinstance(message_chunk, AIMessageChunk):
                            # isinstance already narrowed the type; no cast needed.
                            token = message_chunk.content
                            if isinstance(token, str):
                                await writer.write(token)
                            else:
                                response_widget.log.error(f"Received message content of type {type(token)}")
                        else:
                            response_widget.log.error(f"Received message chunk of type {type(message_chunk)}")
                except StoppedStreamError as e:
                    # The user stopped generation; show that in the widget.
                    response_widget.set_shown_object(e)
                except LangChainException as e:
                    response_widget.set_shown_object(e)
        finally:
            self.generating = False

    def new_conversation(self, chat_screen: ChatScreen) -> None:
        """Forget the current thread and clear the chat screen."""
        self.current_thread = None
        chat_screen.clear_chats()

    async def submit_request(self, chat_screen: ChatScreen, request_text: str) -> bool:
        """Submit a user request, claiming a new thread ID if this is a new conversation.

        Args:
            chat_screen (ChatScreen): Screen to display the request and streamed response on.
            request_text (str): Text of the user request.

        Returns:
            bool: True if the request was accepted, False if a response is already being generated.
        """
        if self.generating:
            return False
        self.generating = True
        if self.current_thread is None:
            chat_screen.log.info("Starting new thread")
            self.current_thread = await self.thread_id_manager.claim_next_id()
            chat_screen.log.info("Claimed thread id", self.current_thread)
        chat_screen.new_request(request_text)
        response = chat_screen.new_response()
        chat_screen.run_worker(self.stream_response(response, request_text, str(self.current_thread)))
        return True
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class Logic:
    """Top-level application logic."""

    def __init__(
        self,
        username: str | None = None,
        application_start_time: float | None = None,
        checkpoints_sqlite_db: str | Path = dirs.DATA_DIR / "checkpoints.sqlite3",
        app_sqlite_db: str | Path = dirs.DATA_DIR / "app.sqlite3",
    ) -> None:
        """Initialize the application logic.

        Args:
            username (str | None, optional): The username provided as a command line argument. Defaults to None.
            application_start_time (float | None, optional): The time when the application started. Defaults to None.
            checkpoints_sqlite_db (str | Path, optional): The connection string for the SQLite database used for
                Langchain checkpointing. Defaults to (dirs.DATA_DIR / "checkpoints.sqlite3").
            app_sqlite_db (str | Path, optional): The connection string for the SQLite database used for application
                state such as thread metadata. Defaults to (dirs.DATA_DIR / "app.sqlite3").
        """
        self.logic_start_time = time.time()
        self.username = username
        self.application_start_time = application_start_time
        self.checkpoints_sqlite_db = checkpoints_sqlite_db
        self.app_sqlite_db = app_sqlite_db

    @asynccontextmanager
    async def runtime(self, app_like: AppLike) -> AsyncIterator[Runtime]:
        """Returns a runtime context for the application."""
        # TODO: Do I need to set check_same_thread=False in aiosqlite.connect?
        async with aiosqlite.connect(database=self.checkpoints_sqlite_db) as checkpoints_conn:
            id_manager = AtomicIDManager(self.app_sqlite_db)
            await id_manager.initialize()
            # The Runtime borrows checkpoints_conn; it is closed when this
            # context exits.
            yield Runtime(self, checkpoints_conn, id_manager, app_like)

    def probe_os(self) -> str:
        """Returns the OS name (eg 'Linux' or 'Windows'), the system name (eg 'Java'), or an empty string if unknown."""
        return platform.system()

    def probe_architecture(self) -> str:
        """Returns the machine architecture, such as 'i386'."""
        return platform.machine()

    def probe_cpu(self) -> str:
        """Returns the name of the CPU, e.g. "Intel(R) Core(TM) i7-10610U CPU @ 1.80GHz"."""
        return cpuinfo.get_cpu_info()["brand_raw"]

    def probe_ram(self) -> int:
        """Returns the total amount of RAM in bytes."""
        return psutil.virtual_memory().total

    def probe_disk_space(self) -> int:
        """Returns the amount of free space in the root directory (in bytes)."""
        return psutil.disk_usage("/").free

    def probe_llamacpp_gpu_support(self) -> bool:
        """Returns True if LlamaCpp supports GPU offloading, False otherwise."""
        return llama_cpp.llama_supports_gpu_offload()

    def probe_huggingface_free_cache_space(self) -> int | None:
        """Returns the amount of free space in the Hugging Face cache (in bytes), or None if it can't be determined."""
        with contextlib.suppress(FileNotFoundError):
            return psutil.disk_usage(HF_HUB_CACHE).free
        # The cache directory may not exist yet; fall back to the nearest
        # existing ancestor so we still report the relevant filesystem.
        for parent_dir in Path(HF_HUB_CACHE).parents:
            with contextlib.suppress(FileNotFoundError):
                return psutil.disk_usage(str(parent_dir)).free
        return None

    def probe_huggingface_cached_models(self) -> list[huggingface_hub.CachedRepoInfo] | None:
        """Returns a list of models in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
        # The docstring for huggingface_hub.scan_cache_dir says it raises CacheNotFound "if the cache directory does not
        # exist," and ValueError "if the cache directory is a file, instead of a directory."
        with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
            return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "model"]
        return None  # Isn't it nice to be explicit?

    def probe_huggingface_cached_datasets(self) -> list[huggingface_hub.CachedRepoInfo] | None:
        """Returns a list of datasets in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
        with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
            return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "dataset"]
        return None

    def probe_nvidia(self) -> tuple[int, list[str]]:
        """Detect available NVIDIA GPUs and CUDA driver version.

        Returns:
            tuple[int, list[str]]: A tuple (cuda_version, nv_gpus) where cuda_version is the installed CUDA driver
                version and nv_gpus is a list of GPU models corresponding to installed NVIDIA GPUs
        """
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            # NVML unavailable (no driver / no NVIDIA GPU): -1 is the
            # "unknown" sentinel for the CUDA version.
            return -1, []
        cuda_version = -1
        nv_gpus = []
        try:
            cuda_version = pynvml.nvmlSystemGetCudaDriverVersion()
            for i in range(pynvml.nvmlDeviceGetCount()):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                nv_gpus.append(pynvml.nvmlDeviceGetName(handle))
        except pynvml.NVMLError:
            # Partial results are fine; report whatever we managed to read.
            pass
        finally:
            with contextlib.suppress(pynvml.NVMLError):
                pynvml.nvmlShutdown()
        return cuda_version, nv_gpus

    def probe_ollama(self) -> list[ollama.ListResponse.Model] | None:
        """Returns a list of models installed in Ollama, or None if connecting to Ollama fails."""
        # NOTE(review): only ConnectionError is suppressed here; confirm the
        # ollama client cannot raise other errors (e.g. HTTP errors) on list().
        with contextlib.suppress(ConnectionError):
            return list(ollama.list().models)
        return None

    def probe_ollama_version(self) -> str | None:
        """Returns the Ollama version string (e.g. "0.13.5"), or None if connecting to Ollama fails."""
        # Yes, this uses private attributes, but that lets me use the Ollama Python lib's env var logic. If you use env
        # vars to direct the app to a different Ollama server, this will query the same Ollama endpoint as the
        # ollama.list() call above. Therefore I silence SLF001 here.
        with contextlib.suppress(httpx.HTTPError, KeyError, ValueError):
            response: httpx.Response = ollama._client._client.request("GET", "/api/version")  # noqa: SLF001
            response.raise_for_status()
            return response.json()["version"]
        return None
|
rag_demo/markdown.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from markdown_it import MarkdownIt
|
|
2
|
+
from markdown_it.rules_inline import StateInline
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def soft2hard_break_plugin(md: MarkdownIt) -> None:
    """Register the soft-to-hard break post-processing rule on *md*."""
    md.inline.ruler2.push("soft2hard_break", _soft2hard_break_plugin)


def _soft2hard_break_plugin(state: StateInline) -> None:
    # Promote every soft line break to a hard break, so a single newline in
    # the source renders as an actual line break.
    for tok in state.tokens:
        if tok.type == "softbreak":
            tok.type = "hardbreak"


def parser_factory() -> MarkdownIt:
    """Modified parser that handles newlines according to LLM conventions."""
    return MarkdownIt("gfm-like").use(soft2hard_break_plugin)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Protocol, cast
|
|
4
|
+
|
|
5
|
+
from textual.screen import Screen
|
|
6
|
+
from textual.widget import Widget
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from rag_demo.logic import Logic, Runtime
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LogicProvider(Protocol):
    """Protocol for classes that contain application logic.

    Note: this is a typing.Protocol (structural typing), not an ABC — any
    object exposing `logic` and an async `runtime()` satisfies it.
    """

    # The top-level application logic object.
    logic: Logic

    async def runtime(self) -> Runtime: ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LogicProviderScreen(Screen):
    """A Screen that provides access to the application logic via its parent app."""

    @property
    def logic(self) -> Logic:
        """Returns the application logic of the parent app."""
        provider = cast("LogicProvider", self.app)
        return provider.logic

    async def runtime(self) -> Runtime:
        """Returns the application runtime of the parent app."""
        provider = cast("LogicProvider", self.app)
        return await provider.runtime()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LogicProviderWidget(Widget):
    """A Widget that provides access to the application logic via its parent app."""

    @property
    def logic(self) -> Logic:
        """Returns the application logic of the parent app."""
        provider = cast("LogicProvider", self.app)
        return provider.logic

    async def runtime(self) -> Runtime:
        """Returns the application runtime of the parent app."""
        provider = cast("LogicProvider", self.app)
        return await provider.runtime()
|
rag_demo/modes/chat.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import pyperclip
|
|
9
|
+
from textual.containers import HorizontalGroup, VerticalGroup, VerticalScroll
|
|
10
|
+
from textual.reactive import reactive
|
|
11
|
+
from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Pretty, Static
|
|
12
|
+
from textual.widgets.markdown import MarkdownStream
|
|
13
|
+
|
|
14
|
+
from rag_demo.markdown import parser_factory
|
|
15
|
+
from rag_demo.modes._logic_provider import LogicProviderScreen, LogicProviderWidget
|
|
16
|
+
from rag_demo.widgets import EscapableInput
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from collections.abc import AsyncIterator
|
|
20
|
+
|
|
21
|
+
from textual.app import ComposeResult
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ResponseStreamInProgressError(ValueError):
    """Raised when a Response widget already has an open writer stream."""

    def __init__(self) -> None:  # noqa: D107
        message = "This Response widget already has an open writer stream."
        super().__init__(message)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class StoppedStreamError(ValueError):
    """Raised when a ResponseWriter is asked to write after it has been stopped."""

    def __init__(self) -> None:  # noqa: D107
        message = "Can't write to a stopped ResponseWriter stream."
        super().__init__(message)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ResponseWriter:
    """Stream markdown to a Response widget as if it were a simple Markdown widget.

    This handles streaming to the Markdown widget and updating the raw text widget and the generation rate label.

    This class is based on the MarkdownStream class from the Textual library.
    """

    def __init__(self, response_widget: Response) -> None:
        """Initialize a new ResponseWriter.

        Args:
            response_widget (Response): The Response widget to write to.
        """
        self.response_widget = response_widget
        # Cache the child widgets up front so every write avoids repeated DOM queries.
        self._markdown_widget = response_widget.query_one("#markdown-view", Markdown)
        self._markdown_stream = MarkdownStream(self._markdown_widget)
        self._raw_widget = response_widget.query_one("#raw-view", Label)
        # Timestamp of the first write; None until the first chunk arrives.
        self._start_time: float | None = None
        # Count of chunks written so far (used for the chunks/second rate).
        self._n_chunks = 0
        # Accumulated full markdown text of the response.
        self._response_text = ""
        # Once True, further write() calls raise StoppedStreamError.
        self._stopped = False

    async def stop(self) -> None:
        """Stop this ResponseWriter, particularly its underlying MarkdownStream."""
        self._stopped = True
        # This is safe even if the MarkdownStream has not been started, or has already been stopped.
        await self._markdown_stream.stop()
        # Because of the markdown parsing tweaks I made in src/rag_demo/markdown.py, we need to reparse the final
        # markdown and rerender one more time to clean up small issues with newlines.
        self._markdown_widget.update(self._response_text)

    async def write(self, markdown_fragment: str) -> None:
        """Stream a single chunk/fragment to the corresponding Response widget.

        Args:
            markdown_fragment (str): The new markdown fragment to append to the existing markdown.

        Raises:
            StoppedStreamError: Raised if the new markdown chunk cannot be accepted because the stream has been stopped.
        """
        if self._stopped:
            raise StoppedStreamError
        write_time = time.time()
        self._response_text += markdown_fragment
        # set_reactive updates the reactive's value without running its watcher —
        # presumably to avoid watch_content re-rendering the whole Markdown widget
        # on every chunk (the stream handles incremental rendering). TODO confirm.
        self.response_widget.set_reactive(Response.content, self._response_text)
        self._raw_widget.update(self._response_text)
        if self._start_time is None:
            # The first chunk. Can't set the rate label until we have a second chunk.
            self._start_time = write_time

            self._markdown_widget.update(markdown_fragment)
            self._markdown_stream.start()
        else:
            # The second and subsequent chunks. Note that self._n_chunks has not been incremented yet because
            # we are calculating the generation rate excluding the first chunk, which may have required loading a
            # large model.
            rate = self._n_chunks / (write_time - self._start_time)
            self.response_widget.update_rate_label(rate)

        await self._markdown_stream.write(markdown_fragment)
        self._n_chunks += 1
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Response(LogicProviderWidget):
    """Allow toggling between raw and rendered versions of markdown text."""

    # When True, show the raw markdown text instead of the rendered view.
    show_raw = reactive(False, layout=True)
    # The full markdown text of the response.
    content = reactive("", layout=True)

    def __init__(self, *, content: str = "", classes: str | None = None) -> None:
        """Initialize a new Response widget.

        Args:
            content (str, optional): Initial response text. Defaults to empty string.
            classes (str | None, optional): Optional widget classes for use with TCSS. Defaults to None.
        """
        super().__init__(classes=classes)
        # set_reactive assigns the initial value without triggering watch_content
        # (the child widgets do not exist yet before compose/mount).
        self.set_reactive(Response.content, content)
        # The currently open writer stream, if any (at most one at a time).
        self._stream: ResponseWriter | None = None
        # Sentinel distinguishing "no object to show" from showing None itself.
        self.__object_to_show_sentinel = object()
        self._object_to_show: Any = self.__object_to_show_sentinel

    def compose(self) -> ComposeResult:
        """Compose the initial content of the widget."""
        with VerticalGroup():
            with HorizontalGroup(id="header"):
                yield Label("Chunks/s: ???", id="token-rate")
                with HorizontalGroup(id="buttons"):
                    yield Button("Stop", id="stop", variant="primary")
                    yield Button("Show Raw", id="show-raw", variant="primary")
                    yield Button("Copy", id="copy", variant="primary")
            yield Markdown(self.content, id="markdown-view", parser_factory=parser_factory)
            yield Label(self.content, id="raw-view", markup=False)
            yield Pretty(None, id="object-view")

    def on_mount(self) -> None:
        """Hide certain elements until they are needed."""
        self.query_one("#raw-view", Label).display = False
        self.query_one("#object-view", Pretty).display = False
        self.query_one("#stop", Button).display = False

    def set_shown_object(self, obj: Any) -> None:  # noqa: ANN401
        """Show a pretty-printed object in place of the markdown/raw views.

        Args:
            obj (Any): The object to display in the Pretty view.
        """
        self._object_to_show = obj
        self.query_one("#markdown-view", Markdown).display = False
        self.query_one("#raw-view", Label).display = False
        self.query_one("#show-raw", Button).display = False
        self.query_one("#object-view", Pretty).update(obj)
        self.query_one("#object-view", Pretty).display = True

    def clear_shown_object(self) -> None:
        """Stop showing an object and restore the markdown or raw view."""
        self._object_to_show = self.__object_to_show_sentinel
        self.query_one("#object-view", Pretty).display = False
        # Restore whichever text view matches the current show_raw state.
        if self.show_raw:
            self.query_one("#raw-view", Label).display = True
        else:
            self.query_one("#markdown-view", Markdown).display = True
        self.query_one("#show-raw", Button).display = True

    @asynccontextmanager
    async def stream_writer(self) -> AsyncIterator[ResponseWriter]:
        """Open an exclusive stream to write markdown in chunks.

        Raises:
            ResponseStreamInProgressError: Raised when there is already an open stream.

        Yields:
            ResponseWriter: The new stream writer.
        """
        if self._stream is not None:
            raise ResponseStreamInProgressError
        self._stream = ResponseWriter(self)
        self.query_one("#stop", Button).display = True
        try:
            yield self._stream
        finally:
            # Always stop the stream and hide the stop button, even if the
            # caller's body raised.
            await self._stream.stop()
            self.query_one("#stop", Button).display = False
            self._stream = None

    async def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press events."""
        if event.button.id == "stop":
            if self._stream is not None:
                await self._stream.stop()
        elif event.button.id == "show-raw":
            self.show_raw = not self.show_raw
        elif event.button.id == "copy":
            # Textual and Pyperclip use different methods to copy text to the clipboard. Textual uses ANSI escape
            # sequence magic that is not supported by all terminals. Pyperclip uses OS-specific clipboard APIs, but it
            # does not work over SSH.
            start = time.time()
            self.app.copy_to_clipboard(self.content)
            checkpoint = time.time()
            try:
                pyperclip.copy(self.content)
            except pyperclip.PyperclipException as e:
                self.app.log.error(f"Error copying to clipboard with Pyperclip: {e}")
            checkpoint2 = time.time()
            self.notify(f"Copied {len(self.content.splitlines())} lines of text to clipboard")
            end = time.time()
            self.app.log.info(f"Textual copy took {checkpoint - start:.6f} seconds")
            self.app.log.info(f"Pyperclip copy took {checkpoint2 - checkpoint:.6f} seconds")
            self.app.log.info(f"Notify took {end - checkpoint2:.6f} seconds")
            self.app.log.info(f"Total of {end - start:.6f} seconds")

    def watch_show_raw(self) -> None:
        """Handle reactive updates to the show_raw attribute by changing the visibility of the child widgets.

        This also keeps the text on the visibility toggle button up-to-date.
        """
        # While an object is being shown, both text views stay hidden; do nothing.
        if self._object_to_show is not self.__object_to_show_sentinel:
            return
        button = self.query_one("#show-raw", Button)
        markdown_view = self.query_one("#markdown-view", Markdown)
        raw_view = self.query_one("#raw-view", Label)

        if self.show_raw:
            button.label = "Show Rendered"
            markdown_view.display = False
            raw_view.display = True
        else:
            button.label = "Show Raw"
            markdown_view.display = True
            raw_view.display = False

    def watch_content(self, content: str) -> None:
        """Handle reactive updates to the content attribute by updating the markdown and raw views.

        Args:
            content (str): New content for the widget.
        """
        self.query_one("#markdown-view", Markdown).update(content)
        self.query_one("#raw-view", Label).update(content)

    def update_rate_label(self, rate: float | None) -> None:
        """Update or reset the generation rate indicator.

        Args:
            rate (float | None): Generation rate, or None to reset. Defaults to None.
        """
        label_text = "Chunks/s: ???" if rate is None else f"Chunks/s: {rate:.2f}"
        self.query_one("#token-rate", Label).update(label_text)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class ChatScreen(LogicProviderScreen):
    """Main mode of the app. Talk to the AI agent."""

    SUB_TITLE = "Chat"
    CSS_PATH = Path(__file__).parent / "chat.tcss"

    def compose(self) -> ComposeResult:
        """Compose the initial content of the chat screen."""
        yield Header()
        chats = VerticalScroll(id="chats")
        with chats:
            yield HorizontalGroup(id="top-chat-separator")
        with HorizontalGroup(id="new-request-bar"):
            yield Static()
            yield Button("New Conversation", id="new-conversation")
            yield EscapableInput(placeholder=" What do you want to know?", id="new-request", focus_on_escape=chats)
            yield Static()
        yield Footer()

    def on_mount(self) -> None:
        """When the screen is mounted, focus the input field and enable bottom anchoring for the message view."""
        self.query_one("#new-request", Input).focus()
        self.query_one("#chats", VerticalScroll).anchor()

    async def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press events."""
        if event.button.id == "new-conversation":
            (await self.runtime()).new_conversation(self)

    async def on_input_submitted(self, event: Input.Submitted) -> None:
        """Handle submission of new requests."""
        if event.input.id == "new-request":
            accepted = await (await self.runtime()).submit_request(self, event.value)
            # Only clear the input if the runtime accepted the request, so the
            # user doesn't lose their text when submission is rejected.
            if accepted:
                self.query_one("#new-request", Input).value = ""

    def clear_chats(self) -> None:
        """Clear the chat scroll area, keeping the top separator."""
        chats = self.query_one("#chats", VerticalScroll)
        # Snapshot the children before removing: removing nodes while iterating
        # the live children list can skip siblings.
        for child in list(chats.children):
            if child.id != "top-chat-separator":
                child.remove()

    def new_request(self, request_text: str) -> Label:
        """Create a new request element in the chat area.

        Args:
            request_text (str): The text of the request.

        Returns:
            Label: The request element.
        """
        chats = self.query_one("#chats", VerticalScroll)
        request = Label(request_text, classes="request")
        chats.mount(HorizontalGroup(request, classes="request-container"))
        chats.anchor()
        return request

    def new_response(self, response_text: str = "Waiting for AI to respond...") -> Response:
        """Create a new response element in the chat area.

        Args:
            response_text (str, optional): Initial response text. Usually this is a default message shown before
                streaming the actual response. Defaults to "Waiting for AI to respond...".

        Returns:
            Response: The response widget/element.
        """
        chats = self.query_one("#chats", VerticalScroll)
        response = Response(content=response_text, classes="response")
        chats.mount(HorizontalGroup(response, classes="response-container"))
        chats.anchor()
        return response
|
rag_demo/modes/chat.tcss
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
ToastRack {
|
|
2
|
+
margin-bottom: 5;
|
|
3
|
+
}
|
|
4
|
+
|
|
5
|
+
#chats {
|
|
6
|
+
width: 100%;
|
|
7
|
+
hatch: left $primary-darken-2 100%;
|
|
8
|
+
}
|
|
9
|
+
|
|
10
|
+
#top-chat-separator {
|
|
11
|
+
hatch: left $primary-darken-2 100%;
|
|
12
|
+
height: 1;
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
.request {
|
|
16
|
+
min-width: 30%;
|
|
17
|
+
max-width: 90%;
|
|
18
|
+
padding: 1 2 1 2;
|
|
19
|
+
margin: 0 2 1 0;
|
|
20
|
+
background: $boost;
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
.request-container {
|
|
24
|
+
align-horizontal: right;
|
|
25
|
+
hatch: left $primary-darken-2 100%;
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
.response-container {
|
|
29
|
+
align-horizontal: left;
|
|
30
|
+
hatch: left $primary-darken-2 100%;
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
.response {
|
|
34
|
+
margin: 0 0 1 2;
|
|
35
|
+
width: 90%;
|
|
36
|
+
height: auto;
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
.response #header {
|
|
40
|
+
background: $primary-darken-3;
|
|
41
|
+
margin-bottom: 1;
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
.response #buttons {
|
|
45
|
+
align-horizontal: right;
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
.response Button {
|
|
49
|
+
border: none;
|
|
50
|
+
min-width: 18;
|
|
51
|
+
margin-left: 1;
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
#token-rate {
|
|
55
|
+
padding-left: 1;
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
#raw-view {
|
|
59
|
+
padding: 0 2 1 2;
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
#new-request-bar > Static {
|
|
63
|
+
width: 5%;
|
|
64
|
+
height: 3;
|
|
65
|
+
hatch: cross $surface 100%;
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
#new-conversation {
|
|
69
|
+
min-width: 20;
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
#new-request {
|
|
73
|
+
width: 1fr;
|
|
74
|
+
height: 3;
|
|
75
|
+
}
|
rag_demo/modes/config.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from textual.app import ComposeResult
|
|
4
|
+
from textual.containers import Container, Horizontal
|
|
5
|
+
from textual.widgets import (
|
|
6
|
+
Button,
|
|
7
|
+
Footer,
|
|
8
|
+
Header,
|
|
9
|
+
Input,
|
|
10
|
+
Label,
|
|
11
|
+
RadioButton,
|
|
12
|
+
RadioSet,
|
|
13
|
+
Static,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from rag_demo.modes._logic_provider import LogicProviderScreen
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ConfigScreen(LogicProviderScreen):
    """Screen for configuring the LLM provider, model, and credentials."""

    SUB_TITLE = "Configure"
    CSS_PATH = Path(__file__).parent / "config.tcss"

    def compose(self) -> ComposeResult:
        """Compose the configuration form."""
        yield Header()
        yield Container(
            Static("🤖 LLM Configuration", classes="title"),
            Label("Select your LLM provider:"),
            RadioSet(
                RadioButton("OpenAI (API)", id="openai"),
                RadioButton("Anthropic Claude (API)", id="anthropic"),
                RadioButton("Ollama (Local)", id="ollama"),
                RadioButton("LlamaCpp (Local)", id="llamacpp"),
                id="provider",
            ),
            Label("Model name:"),
            Input(placeholder="e.g., gpt-4, claude-3-sonnet-20240229", id="model"),
            Label("API Key (if applicable):"),
            Input(placeholder="sk-...", password=True, id="api-key"),
            Label("Base URL (for Ollama):"),
            Input(placeholder="http://localhost:11434", id="base-url"),
            Label("Model Path (for LlamaCpp):"),
            Input(placeholder="/path/to/model.gguf", id="model-path"),
            Horizontal(
                Button("Save & Continue", variant="primary", id="save"),
                Button("Cancel", id="cancel"),
            ),
        )
        yield Footer()

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle the Save & Continue and Cancel buttons."""
        if event.button.id == "save":
            config = self.collect_config()
            # NOTE(review): assumes the app exposes a config_manager attribute —
            # confirm against the App subclass.
            self.app.config_manager.save_config(config)
            self.app.pop_screen()  # Return to main app
        elif event.button.id == "cancel":
            self.app.exit()

    def collect_config(self) -> dict:
        """Collect the current form values into a configuration dict.

        Returns:
            dict: Always contains "provider" and "model"; "api_key",
                "base_url", and "model_path" are included only when non-empty.
                "provider" is None when no radio button has been selected.
        """
        pressed = self.query_one("#provider", RadioSet).pressed_button
        # pressed_button is None until the user selects a provider; guard so
        # saving an untouched form doesn't raise AttributeError.
        provider = pressed.id if pressed is not None else None
        model = self.query_one("#model", Input).value
        api_key = self.query_one("#api-key", Input).value
        base_url = self.query_one("#base-url", Input).value
        model_path = self.query_one("#model-path", Input).value

        config = {
            "provider": provider,
            "model": model,
        }

        # Optional fields are only recorded when the user filled them in.
        if api_key:
            config["api_key"] = api_key
        if base_url:
            config["base_url"] = base_url
        if model_path:
            config["model_path"] = model_path

        return config
|
|
File without changes
|
rag_demo/modes/help.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from textual.app import ComposeResult
|
|
4
|
+
from textual.widgets import Footer, Header, Label
|
|
5
|
+
|
|
6
|
+
from rag_demo.modes._logic_provider import LogicProviderScreen
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class HelpScreen(LogicProviderScreen):
    """Display information about the application."""

    SUB_TITLE = "Help"
    CSS_PATH = Path(__file__).parent / "help.tcss"

    def compose(self) -> ComposeResult:
        """Create the widgets of the help screen.

        Yields:
            The header, a placeholder label, and the footer of the help screen.
        """
        yield Header()
        yield Label("Help Screen (under construction)")
        yield Footer()
|
rag_demo/modes/help.tcss
ADDED
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .escapable_input import EscapableInput
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from textual.widgets import Input
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from collections.abc import Iterable
|
|
9
|
+
|
|
10
|
+
from rich.console import RenderableType
|
|
11
|
+
from rich.highlighter import Highlighter
|
|
12
|
+
from textual.events import Key
|
|
13
|
+
from textual.suggester import Suggester
|
|
14
|
+
from textual.validation import Validator
|
|
15
|
+
from textual.widget import Widget
|
|
16
|
+
from textual.widgets._input import InputType, InputValidationOn
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class EscapableInput(Input):
    """An input widget that deselects itself when the user presses escape.

    Inherits all properties and methods from the :class:`textual.widgets.Input` class.
    """

    def __init__(  # noqa: PLR0913
        self,
        value: str | None = None,
        placeholder: str = "",
        highlighter: Highlighter | None = None,
        password: bool = False,  # noqa: FBT001, FBT002
        *,
        restrict: str | None = None,
        type: InputType = "text",  # noqa: A002
        max_length: int = 0,
        suggester: Suggester | None = None,
        validators: Validator | Iterable[Validator] | None = None,
        validate_on: Iterable[InputValidationOn] | None = None,
        valid_empty: bool = False,
        select_on_focus: bool = True,
        name: str | None = None,
        id: str | None = None,  # noqa: A002
        classes: str | None = None,
        disabled: bool = False,
        tooltip: RenderableType | None = None,
        compact: bool = False,
        focus_on_escape: Widget | None = None,
    ) -> None:
        """Initialise the `EscapableInput` widget.

        Args:
            value: An optional default value for the input.
            placeholder: Optional placeholder text for the input.
            highlighter: An optional highlighter for the input.
            password: Flag to say if the field should obfuscate its content.
            restrict: A regex to restrict character inputs.
            type: The type of the input.
            max_length: The maximum length of the input, or 0 for no maximum length.
            suggester: [`Suggester`][textual.suggester.Suggester] associated with this
                input instance.
            validators: An iterable of validators that the Input value will be checked against.
            validate_on: Zero or more of the values "blur", "changed", and "submitted",
                which determine when to do input validation. The default is to do
                validation for all messages.
            valid_empty: Empty values are valid.
            select_on_focus: Whether to select all text on focus.
            name: Optional name for the input widget.
            id: Optional ID for the widget.
            classes: Optional initial classes for the widget.
            disabled: Whether the input is disabled or not.
            tooltip: Optional tooltip.
            compact: Enable compact style (without borders).
            focus_on_escape: An optional widget to focus on when escape is pressed. Defaults to `None`.
        """
        # Pass everything except focus_on_escape straight through to Input.
        super().__init__(
            value=value,
            placeholder=placeholder,
            highlighter=highlighter,
            password=password,
            restrict=restrict,
            type=type,
            max_length=max_length,
            suggester=suggester,
            validators=validators,
            validate_on=validate_on,
            valid_empty=valid_empty,
            select_on_focus=select_on_focus,
            name=name,
            id=id,
            classes=classes,
            disabled=disabled,
            tooltip=tooltip,
            compact=compact,
        )
        # Widget to focus when escape is pressed; None means just drop focus.
        self.focus_on_escape = focus_on_escape

    def on_key(self, event: Key) -> None:
        """Deselect the input if the event is the escape key.

        This method overrides the base :meth:`textual.widgets.Input.on_key` implementation.

        Args:
            event (Key): Event details, including the key pressed.
        """
        if event.key == "escape":
            # Either hand focus to the configured target or simply blur.
            if self.focus_on_escape is not None:
                self.focus_on_escape.focus()
            else:
                self.blur()
            # Swallow the event so default Input handling and ancestors never
            # see the escape key.
            event.prevent_default()
            event.stop()
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.3
|
|
2
|
-
Name: jehoctor-rag-demo
|
|
3
|
-
Version: 0.1.1.dev1
|
|
4
|
-
Summary: Chat with Wikipedia
|
|
5
|
-
Author: James Hoctor
|
|
6
|
-
Author-email: James Hoctor <JEHoctor@protonmail.com>
|
|
7
|
-
Requires-Python: >=3.13
|
|
8
|
-
Description-Content-Type: text/markdown
|
|
9
|
-
|
|
10
|
-
# RAG-demo
|
|
11
|
-
Chat with Wikipedia
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
rag_demo/__init__.py,sha256=STIqC0dNmRNCbDeEDuODLXSsOmuQsB5MC9Ii3el0TDk,54
|
|
2
|
-
rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
jehoctor_rag_demo-0.1.1.dev1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
|
|
4
|
-
jehoctor_rag_demo-0.1.1.dev1.dist-info/entry_points.txt,sha256=3XjzSTMUH0sKTfuLk5yj8ilz0M-SF4_UuZAu-MVH8qM,40
|
|
5
|
-
jehoctor_rag_demo-0.1.1.dev1.dist-info/METADATA,sha256=drDoEBvOOWd5XW3DCedSPqLq3HXkcJquTJ0A4HLi1EU,265
|
|
6
|
-
jehoctor_rag_demo-0.1.1.dev1.dist-info/RECORD,,
|
|
File without changes
|