jehoctor-rag-demo 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

@@ -1,10 +1,11 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: jehoctor-rag-demo
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Chat with Wikipedia
5
5
  Author: James Hoctor
6
6
  Author-email: James Hoctor <JEHoctor@protonmail.com>
7
7
  Requires-Dist: aiosqlite==0.21.0
8
+ Requires-Dist: bitsandbytes>=0.49.1
8
9
  Requires-Dist: chromadb>=1.3.4
9
10
  Requires-Dist: datasets>=4.4.1
10
11
  Requires-Dist: httpx>=0.28.1
@@ -16,7 +17,6 @@ Requires-Dist: langchain-huggingface>=1.1.0
16
17
  Requires-Dist: langchain-ollama>=1.0.0
17
18
  Requires-Dist: langchain-openai>=1.0.2
18
19
  Requires-Dist: langgraph-checkpoint-sqlite>=3.0.1
19
- Requires-Dist: llama-cpp-python>=0.3.16
20
20
  Requires-Dist: nvidia-ml-py>=13.590.44
21
21
  Requires-Dist: ollama>=0.6.0
22
22
  Requires-Dist: platformdirs>=4.5.0
@@ -24,9 +24,13 @@ Requires-Dist: psutil>=7.1.3
24
24
  Requires-Dist: py-cpuinfo>=9.0.0
25
25
  Requires-Dist: pydantic>=2.12.4
26
26
  Requires-Dist: pyperclip>=1.11.0
27
+ Requires-Dist: sentence-transformers>=5.2.2
27
28
  Requires-Dist: textual>=6.5.0
29
+ Requires-Dist: transformers[torch]>=4.57.6
28
30
  Requires-Dist: typer>=0.20.0
29
- Requires-Python: >=3.12
31
+ Requires-Dist: llama-cpp-python>=0.3.16 ; extra == 'llamacpp'
32
+ Requires-Python: ~=3.12.0
33
+ Provides-Extra: llamacpp
30
34
  Description-Content-Type: text/markdown
31
35
 
32
36
  # RAG-demo
@@ -35,50 +39,43 @@ Chat with (a small portion of) Wikipedia
35
39
 
36
40
  ⚠️ RAG functionality is still under development. ⚠️
37
41
 
38
- ![app screenshot](screenshots/screenshot_062f205a.png "App screenshot (this AI response is not accurate)")
42
+ ![app screenshot](screenshots/screenshot_0.2.0.png "App screenshot")
39
43
 
40
44
  ## Requirements
41
45
 
42
- 1. [uv](https://docs.astral.sh/uv/)
43
- 2. At least one of the following:
44
- - A suitable terminal emulator. In particular, on macOS consider using [iTerm2](https://iterm2.com/) instead of the default Terminal.app ([explanation](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos)). On Linux, you might want to try [kitty](https://sw.kovidgoyal.net/kitty/), [wezterm](https://wezterm.org/), [alacritty](https://alacritty.org/), or [ghostty](https://ghostty.org/) instead of the terminal that came with your DE ([reason](https://darren.codes/posts/textual-copy-paste/)). Windows Terminal should be fine as far as I know.
45
- - Any common web browser
46
+ 1. The [uv](https://docs.astral.sh/uv/) Python package manager
47
+ - Installing and updating `uv` is easy by following [the docs](https://docs.astral.sh/uv/getting-started/installation/); see the sketch just after this list.
48
+ - As of 2026-01-25, I'm developing with `uv` version 0.9.26 and the new experimental `--torch-backend` option.
49
+ 2. A terminal emulator or web browser
50
+ - Any common web browser will work.
51
+ - Some terminal emulators will work better than others.
52
+ See [Notes on terminal emulators](#notes-on-terminal-emulators) below.
46
53
 
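For reference, installing or updating `uv` with the standalone installer looks roughly like this (one of several methods covered by the linked docs; a sketch, not project-specific tooling):

```bash
# Install uv with the standalone installer from the uv docs
curl -LsSf https://astral.sh/uv/install.sh | sh

# Later, update a standalone install in place
uv self update
```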
47
- ## Optional stuff that could make your experience better
54
+ ### Notes on terminal emulators
55
+
56
+ Certain terminal emulators will not work with some features of this program.
57
+ In particular, on macOS consider using [iTerm2](https://iterm2.com/) instead of the default Terminal.app ([explanation](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos)).
58
+ On Linux you might want to try [kitty](https://sw.kovidgoyal.net/kitty/), [wezterm](https://wezterm.org/), [alacritty](https://alacritty.org/), or [ghostty](https://ghostty.org/) instead of the terminal that came with your desktop environment ([reason](https://darren.codes/posts/textual-copy-paste/)).
59
+ Windows Terminal should be fine as far as I know.
60
+
61
+ ### Optional dependencies
48
62
 
49
63
  1. [Hugging Face login](https://huggingface.co/docs/huggingface_hub/quick-start#login)
50
64
  2. API key for your favorite LLM provider (support coming soon)
51
65
  3. Ollama installed on your system if you have a GPU
52
66
  4. Run RAG-demo on a more capable (bigger GPU) machine over SSH if you can. It is a terminal app after all.
67
+ 5. A C compiler if you want to build Llama.cpp from source.
53
68
 
54
-
55
- ## Run from the repository
56
-
57
- First, clone this repository. Then, run one of the options below.
69
+ ## Run the latest version
58
70
 
59
71
  Run in a terminal:
60
72
  ```bash
61
- uv run chat
73
+ uvx --torch-backend=auto --from=jehoctor-rag-demo@latest chat
62
74
  ```
63
75
 
64
76
  Or run in a web browser:
65
77
  ```bash
66
- uv run textual serve chat
67
- ```
68
-
69
- ## Run from the latest version on PyPI
70
-
71
- TODO: test uv automatic torch backend selection:
72
- https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection
73
-
74
- Run in a terminal:
75
- ```bash
76
- uvx --from=jehoctor-rag-demo chat
77
- ```
78
-
79
- Or run in a web browser:
80
- ```bash
81
- uvx --from=jehoctor-rag-demo textual serve chat
78
+ uvx --torch-backend=auto --from=jehoctor-rag-demo@latest textual serve chat
82
79
  ```
83
80
 
84
81
  ## CUDA acceleration via Llama.cpp
@@ -86,15 +83,43 @@ uvx --from=jehoctor-rag-demo textual serve chat
86
83
  If you have an NVIDIA GPU with CUDA and build tools installed, you might be able to get CUDA acceleration without installing Ollama.
87
84
 
88
85
  ```bash
89
- CMAKE_ARGS="-DGGML_CUDA=on" uv run chat
86
+ CMAKE_ARGS="-DGGML_CUDA=on" uv run --extra=llamacpp chat
90
87
  ```
91
88
 
92
89
  ## Metal acceleration via Llama.cpp (on Apple Silicon)
93
90
 
94
91
  On an Apple Silicon machine, make sure `uv` runs an ARM interpreter as this should cause it to install Llama.cpp with Metal support.
92
+ Also, run with the `llamacpp` extra enabled.
93
+ Try this:
94
+
95
+ ```bash
96
+ uvx --python-platform=aarch64-apple-darwin --torch-backend=auto --from=jehoctor-rag-demo[llamacpp]@latest chat
97
+ ```
95
98
 
96
99
  ## Ollama on Linux
97
100
 
98
101
  Remember that you have to keep Ollama up-to-date manually on Linux.
99
102
  A recent version of Ollama (v0.11.10 or later) is required to run the [embedding model we use](https://ollama.com/library/embeddinggemma).
100
103
  See this FAQ: https://docs.ollama.com/faq#how-can-i-upgrade-ollama.
104
+
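On Linux, checking and upgrading Ollama looks roughly like this (a sketch; the authoritative upgrade command is the one in the FAQ linked above):

```bash
# Check the installed version (v0.11.10 or later is needed for embeddinggemma)
ollama --version

# Re-run the official install script to upgrade in place
curl -fsSL https://ollama.com/install.sh | sh
```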
105
+ ## Project feature roadmap
106
+
107
+ - ❌ RAG functionality
108
+ - ❌ torch inference via the Langchain local Hugging Face inference integration
109
+ - ❌ uv automatic torch backend selection (see [the docs](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection))
110
+ - ❌ OpenAI integration
111
+ - ❌ Anthropic integration
112
+
113
+ ## Run from the repository
114
+
115
+ First, clone this repository. Then, run one of the options below.
116
+
117
+ Run in a terminal:
118
+ ```bash
119
+ uv run chat
120
+ ```
121
+
122
+ Or run in a web browser:
123
+ ```bash
124
+ uv run textual serve chat
125
+ ```
@@ -0,0 +1,31 @@
1
+ rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ rag_demo/__main__.py,sha256=S0UlQj3EldcXRk3rhH3ONdSOPmeyITYXIZ0o2JSxWbg,1618
3
+ rag_demo/agents/__init__.py,sha256=dsuO3AGcn2yGDq4gkAsZ32pjeTOqAudOL14G_AsEUyc,221
4
+ rag_demo/agents/base.py,sha256=gib6bC8nVKN1s1KPZd1dJVGRXnu7gFQwf3I3_7TSjQo,1312
5
+ rag_demo/agents/hugging_face.py,sha256=VrbGOlMO2z357LmU3sO5aM_yI5P-xbsfKmTAzH9_lFo,4225
6
+ rag_demo/agents/llama_cpp.py,sha256=C0hInc24sXmt5407_k4mP2Y6svqgfUPemhzAL1N6jY0,4272
7
+ rag_demo/agents/ollama.py,sha256=Fmtu8MSPPz91eT7HKvwvbQnA_xGPaD5HBrHEPPtomZA,3317
8
+ rag_demo/app.py,sha256=AVCJjlQ60y5J0v50TcJ3zZoa0ubhd_yKVDfu1ERsMVU,1807
9
+ rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ rag_demo/app_protocol.py,sha256=P__Q3KT41uonYazpYmLWmOh1MeoBaCBx4xEydkwK5tk,3292
11
+ rag_demo/constants.py,sha256=5EpyAD6p5Qb0vB5ASMtFSLVVsq3RG5cBn_AYihhDCPc,235
12
+ rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
13
+ rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
14
+ rag_demo/logic.py,sha256=SkF_Hqu1WSLHzwvSd_mJiCMSxZYqDnteYFRpc6oCREY,8236
15
+ rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
16
+ rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
17
+ rag_demo/modes/_logic_provider.py,sha256=U3J8Fgq8MbNYd92FqENW-5YP_jXqKG3xmMmYoSUzhHo,1343
18
+ rag_demo/modes/chat.py,sha256=2pmKhQ2uYZdjezNnNBINViBMcuTVE5YCom_HEbdJeXg,13607
19
+ rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
20
+ rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
21
+ rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
23
+ rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
+ rag_demo/probe.py,sha256=aDD-smNauEXXoBKVgx5xsMawM5tL0QAEBFl07ZGrddc,5101
25
+ rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
+ rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
27
+ rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
28
+ jehoctor_rag_demo-0.2.1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
29
+ jehoctor_rag_demo-0.2.1.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
30
+ jehoctor_rag_demo-0.2.1.dist-info/METADATA,sha256=nCXuy3TYPPf67DPFryndW2P6CRdSCcJkXxFXZ-UN4vs,4650
31
+ jehoctor_rag_demo-0.2.1.dist-info/RECORD,,
rag_demo/__main__.py CHANGED
@@ -3,21 +3,32 @@ import time
3
3
  # Measure the application start time.
4
4
  APPLICATION_START_TIME = time.time()
5
5
 
6
- # Disable "module import not at top of file" (aka E402) when importing Typer. This is necessary so that Typer's
7
- # initialization is included in the application startup time.
6
+ # Disable "module import not at top of file" (aka E402) when importing Typer and other early imports. This is necessary
7
+ # so that the initialization of these modules is included in the application startup time.
8
+ from typing import Annotated # noqa: E402
9
+
8
10
  import typer # noqa: E402
9
11
 
12
+ from rag_demo.constants import LocalProviderType # noqa: E402
13
+
10
14
 
11
15
  def _main(
12
- name: str | None = typer.Option(None, help="The name you want the AI to use with you."),
16
+ name: Annotated[str | None, typer.Option(help="The name you want the AI to use with you.")] = None,
17
+ provider: Annotated[LocalProviderType | None, typer.Option(help="The local provider to prefer.")] = None,
13
18
  ) -> None:
14
19
  """Talk to Wikipedia."""
15
20
  # Import here so that imports run within the typer.run context for prettier stack traces if errors occur.
16
21
  # We ignore PLC0415 because we do not want these imports to be at the top of the module as is usually preferred.
22
+ import transformers # noqa: PLC0415
23
+
17
24
  from rag_demo.app import RAGDemo # noqa: PLC0415
18
25
  from rag_demo.logic import Logic # noqa: PLC0415
19
26
 
20
- logic = Logic(username=name, application_start_time=APPLICATION_START_TIME)
27
+ # The transformers library likes to print text that interferes with the TUI. Disable it.
28
+ transformers.logging.set_verbosity(verbosity=transformers.logging.CRITICAL)
29
+ transformers.logging.disable_progress_bar()
30
+
31
+ logic = Logic(username=name, preferred_provider_type=provider, application_start_time=APPLICATION_START_TIME)
21
32
  app = RAGDemo(logic)
22
33
  app.run()
23
34
 
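As a hedged usage sketch of the new `--provider` option (assuming the `LocalProviderType` members serialize to their lowercase names, which is what `StrEnum` with `auto()` produces), hypothetical invocations would look like:

```bash
# Prefer the Ollama backend; other providers are still tried if it is unavailable
uv run chat --provider ollama

# Same idea when running the published package with uvx
uvx --torch-backend=auto --from=jehoctor-rag-demo@latest chat --provider hugging_face
```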
rag_demo/agents/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .base import Agent, AgentProvider
2
+ from .hugging_face import HuggingFaceAgent, HuggingFaceAgentProvider
3
+ from .llama_cpp import LlamaCppAgent, LlamaCppAgentProvider
4
+ from .ollama import OllamaAgent, OllamaAgentProvider
rag_demo/agents/base.py ADDED
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Final, Protocol
4
+
5
+ if TYPE_CHECKING:
6
+ from collections.abc import AsyncIterator
7
+ from contextlib import AbstractAsyncContextManager
8
+ from pathlib import Path
9
+
10
+ from rag_demo.app_protocol import AppProtocol
11
+ from rag_demo.constants import LocalProviderType
12
+
13
+
14
+ class Agent(Protocol):
15
+ """An LLM agent that supports streaming responses asynchronously."""
16
+
17
+ def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
18
+ """Stream a response from the agent.
19
+
20
+ Args:
21
+ user_message (str): User's next prompt in the conversation.
22
+ thread_id (str): Identifier for the current thread/conversation.
23
+ app (AppProtocol): Application interface, commonly used for logging.
24
+
25
+ Yields:
26
+ str: A token from the agent's response.
27
+ """
28
+
29
+
30
+ class AgentProvider(Protocol):
31
+ """A strategy for creating LLM agents."""
32
+
33
+ type: Final[LocalProviderType]
34
+
35
+ def get_agent(self, checkpoints_sqlite_db: str | Path) -> AbstractAsyncContextManager[Agent | None]:
36
+ """Attempt to create an agent.
37
+
38
+ Args:
39
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
40
+ """
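To make these Protocol contracts concrete, here is a hypothetical minimal implementation pair, e.g. for tests; `EchoAgent` and `EchoAgentProvider` are illustrative names, not part of the package:

```python
from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path

from rag_demo.constants import LocalProviderType


class EchoAgent:
    """Toy Agent that streams the user's message back word by word."""

    async def astream(self, user_message: str, thread_id: str, app: object) -> AsyncIterator[str]:
        # `app` is an AppProtocol in the real signature; it is unused here.
        for word in user_message.split():
            yield word + " "


class EchoAgentProvider:
    """Toy AgentProvider that always succeeds."""

    type = LocalProviderType.OLLAMA  # stand-in; a real provider declares its own type

    @asynccontextmanager
    async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[EchoAgent]:
        # A real provider would probe for its backend and yield None when unavailable.
        yield EchoAgent()
```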
rag_demo/agents/hugging_face.py ADDED
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import sqlite3
5
+ from contextlib import asynccontextmanager
6
+ from typing import TYPE_CHECKING, Final
7
+
8
+ from huggingface_hub import hf_hub_download
9
+ from langchain.agents import create_agent
10
+ from langchain.messages import AIMessageChunk, HumanMessage
11
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFacePipeline
12
+ from langgraph.checkpoint.sqlite import SqliteSaver
13
+
14
+ from rag_demo.constants import LocalProviderType
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import AsyncIterator
18
+ from pathlib import Path
19
+
20
+ from rag_demo.app_protocol import AppProtocol
21
+
22
+
23
+ class HuggingFaceAgent:
24
+ """An LLM agent powered by Hugging Face local pipelines."""
25
+
26
+ def __init__(
27
+ self,
28
+ checkpoints_sqlite_db: str | Path,
29
+ model_id: str,
30
+ embedding_model_id: str,
31
+ ) -> None:
32
+ """Initialize the HuggingFaceAgent.
33
+
34
+ Args:
35
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
36
+ model_id (str): Hugging Face model ID for the LLM.
37
+ embedding_model_id (str): Hugging Face model ID for the embedding model.
38
+ """
39
+ self.checkpoints_sqlite_db = checkpoints_sqlite_db
40
+ self.model_id = model_id
41
+ self.embedding_model_id = embedding_model_id
42
+
43
+ self.llm = ChatHuggingFace(
44
+ llm=HuggingFacePipeline.from_model_id(
45
+ model_id=model_id,
46
+ task="text-generation",
47
+ device_map="auto",
48
+ pipeline_kwargs={"max_new_tokens": 4096},
49
+ ),
50
+ )
51
+ self.embed = HuggingFaceEmbeddings(model_name=embedding_model_id)
52
+ self.agent = create_agent(
53
+ model=self.llm,
54
+ system_prompt="You are a helpful assistant.",
55
+ checkpointer=SqliteSaver(sqlite3.Connection(self.checkpoints_sqlite_db, check_same_thread=False)),
56
+ )
57
+
58
+ async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
59
+ """Stream a response from the agent.
60
+
61
+ Args:
62
+ user_message (str): User's next prompt in the conversation.
63
+ thread_id (str): Identifier for the current thread/conversation.
64
+ app (AppProtocol): Application interface, commonly used for logging.
65
+
66
+ Yields:
67
+ str: A token from the agent's response.
68
+ """
69
+ agent_stream = self.agent.stream(
70
+ {"messages": [HumanMessage(content=user_message)]},
71
+ {"configurable": {"thread_id": thread_id}},
72
+ stream_mode="messages",
73
+ )
74
+ for message_chunk, _ in agent_stream:
75
+ if isinstance(message_chunk, AIMessageChunk):
76
+ token = message_chunk.content
77
+ if isinstance(token, str):
78
+ yield token
79
+ else:
80
+ app.log.error("Received message content of type", type(token))
81
+ else:
82
+ app.log.error("Received message chunk of type", type(message_chunk))
83
+
84
+
85
+ def _hf_downloads() -> None:
86
+ hf_hub_download(
87
+ repo_id="Qwen/Qwen3-0.6B", # 1.5GB
88
+ filename="model.safetensors",
89
+ revision="c1899de289a04d12100db370d81485cdf75e47ca",
90
+ )
91
+ hf_hub_download(
92
+ repo_id="unsloth/embeddinggemma-300m", # 1.21GB
93
+ filename="model.safetensors",
94
+ revision="bfa3c846ac738e62aa61806ef9112d34acb1dc5a",
95
+ )
96
+
97
+
98
+ class HuggingFaceAgentProvider:
99
+ """Create LLM agents using Hugging Face local pipelines."""
100
+
101
+ type: Final[LocalProviderType] = LocalProviderType.HUGGING_FACE
102
+
103
+ @asynccontextmanager
104
+ async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[HuggingFaceAgent]:
105
+ """Create a Hugging Face local pipeline agent.
106
+
107
+ Args:
108
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
109
+ """
110
+ loop = asyncio.get_running_loop()
111
+ await loop.run_in_executor(None, _hf_downloads)
112
+ yield HuggingFaceAgent(
113
+ checkpoints_sqlite_db,
114
+ model_id="Qwen/Qwen3-0.6B",
115
+ embedding_model_id="unsloth/embeddinggemma-300m",
116
+ )
rag_demo/agents/llama_cpp.py ADDED
@@ -0,0 +1,113 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from contextlib import asynccontextmanager
5
+ from typing import TYPE_CHECKING, Final
6
+
7
+ import aiosqlite
8
+ from huggingface_hub import hf_hub_download
9
+ from langchain.agents import create_agent
10
+ from langchain.messages import AIMessageChunk, HumanMessage
11
+ from langchain_community.chat_models import ChatLlamaCpp
12
+ from langchain_community.embeddings import LlamaCppEmbeddings
13
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
14
+
15
+ from rag_demo import probe
16
+ from rag_demo.constants import LocalProviderType
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import AsyncIterator
20
+ from pathlib import Path
21
+
22
+ from rag_demo.app_protocol import AppProtocol
23
+
24
+
25
+ class LlamaCppAgent:
26
+ """An LLM agent powered by Llama.cpp."""
27
+
28
+ def __init__(
29
+ self,
30
+ checkpoints_conn: aiosqlite.Connection,
31
+ model_path: str,
32
+ embedding_model_path: str,
33
+ ) -> None:
34
+ """Initialize the LlamaCppAgent.
35
+
36
+ Args:
37
+ checkpoints_conn (aiosqlite.Connection): Connection to SQLite checkpoint database.
38
+ model_path (str): Path to Llama.cpp model.
39
+ embedding_model_path (str): Path to Llama.cpp embedding model.
40
+ """
41
+ self.checkpoints_conn = checkpoints_conn
42
+ self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
43
+ self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False)
44
+ self.agent = create_agent(
45
+ model=self.llm,
46
+ system_prompt="You are a helpful assistant.",
47
+ checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
48
+ )
49
+
50
+ async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
51
+ """Stream a response from the agent.
52
+
53
+ Args:
54
+ user_message (str): User's next prompt in the conversation.
55
+ thread_id (str): Identifier for the current thread/conversation.
56
+ app (AppProtocol): Application interface, commonly used for logging.
57
+
58
+ Yields:
59
+ str: A token from the agent's response.
60
+ """
61
+ agent_stream = self.agent.astream(
62
+ {"messages": [HumanMessage(content=user_message)]},
63
+ {"configurable": {"thread_id": thread_id}},
64
+ stream_mode="messages",
65
+ )
66
+ async for message_chunk, _ in agent_stream:
67
+ if isinstance(message_chunk, AIMessageChunk):
68
+ token = message_chunk.content
69
+ if isinstance(token, str):
70
+ yield token
71
+ else:
72
+ app.log.error("Received message content of type", type(token))
73
+ else:
74
+ app.log.error("Received message chunk of type", type(message_chunk))
75
+
76
+
77
+ def _hf_downloads() -> tuple[str, str]:
78
+ model_path = hf_hub_download(
79
+ repo_id="bartowski/google_gemma-3-4b-it-GGUF",
80
+ filename="google_gemma-3-4b-it-Q6_K_L.gguf", # 3.35GB
81
+ revision="71506238f970075ca85125cd749c28b1b0eee84e",
82
+ )
83
+ embedding_model_path = hf_hub_download(
84
+ repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
85
+ filename="bge-small-en-v1.5-q8_0.gguf", # 36.8MB
86
+ revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
87
+ )
88
+ return model_path, embedding_model_path
89
+
90
+
91
+ class LlamaCppAgentProvider:
92
+ """Create LLM agents using Llama.cpp."""
93
+
94
+ type: Final[LocalProviderType] = LocalProviderType.LLAMA_CPP
95
+
96
+ @asynccontextmanager
97
+ async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[LlamaCppAgent | None]:
98
+ """Attempt to create a Llama.cpp agent.
99
+
100
+ Args:
101
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
102
+ """
103
+ if probe.probe_llama_available():
104
+ loop = asyncio.get_running_loop()
105
+ model_path, embedding_model_path = await loop.run_in_executor(None, _hf_downloads)
106
+ async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
107
+ yield LlamaCppAgent(
108
+ checkpoints_conn=checkpoints_conn,
109
+ model_path=model_path,
110
+ embedding_model_path=embedding_model_path,
111
+ )
112
+ else:
113
+ yield None
rag_demo/agents/ollama.py ADDED
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import asynccontextmanager
4
+ from typing import TYPE_CHECKING, Final
5
+
6
+ import aiosqlite
7
+ import ollama
8
+ from langchain.agents import create_agent
9
+ from langchain.messages import AIMessageChunk, HumanMessage
10
+ from langchain_ollama import ChatOllama, OllamaEmbeddings
11
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
12
+
13
+ from rag_demo import probe
14
+ from rag_demo.constants import LocalProviderType
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import AsyncIterator
18
+ from pathlib import Path
19
+
20
+ from rag_demo.app_protocol import AppProtocol
21
+
22
+
23
+ class OllamaAgent:
24
+ """An LLM agent powered by Ollama."""
25
+
26
+ def __init__(self, checkpoints_conn: aiosqlite.Connection) -> None:
27
+ """Initialize the OllamaAgent.
28
+
29
+ Args:
30
+ checkpoints_conn (aiosqlite.Connection): Asynchronous connection to SQLite db for checkpoints.
31
+ """
32
+ self.checkpoints_conn = checkpoints_conn
33
+ ollama.pull("gemma3:latest") # 3.3GB
34
+ ollama.pull("embeddinggemma:latest") # 621MB
35
+ self.llm = ChatOllama(
36
+ model="gemma3:latest",
37
+ validate_model_on_init=True,
38
+ temperature=0.5,
39
+ num_predict=4096,
40
+ )
41
+ self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
42
+ self.agent = create_agent(
43
+ model=self.llm,
44
+ system_prompt="You are a helpful assistant.",
45
+ checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
46
+ )
47
+
48
+ async def astream(self, user_message: str, thread_id: str, app: AppProtocol) -> AsyncIterator[str]:
49
+ """Stream a response from the agent.
50
+
51
+ Args:
52
+ user_message (str): User's next prompt in the conversation.
53
+ thread_id (str): Identifier for the current thread/conversation.
54
+ app (AppProtocol): Application interface, commonly used for logging.
55
+
56
+ Yields:
57
+ str: A token from the agent's response.
58
+ """
59
+ agent_stream = self.agent.astream(
60
+ {"messages": [HumanMessage(content=user_message)]},
61
+ {"configurable": {"thread_id": thread_id}},
62
+ stream_mode="messages",
63
+ )
64
+ async for message_chunk, _ in agent_stream:
65
+ if isinstance(message_chunk, AIMessageChunk):
66
+ token = message_chunk.content
67
+ if isinstance(token, str):
68
+ yield token
69
+ else:
70
+ app.log.error("Received message content of type", type(token))
71
+ else:
72
+ app.log.error("Received message chunk of type", type(message_chunk))
73
+
74
+
75
+ class OllamaAgentProvider:
76
+ """Create LLM agents using Ollama."""
77
+
78
+ type: Final[LocalProviderType] = LocalProviderType.OLLAMA
79
+
80
+ @asynccontextmanager
81
+ async def get_agent(self, checkpoints_sqlite_db: str | Path) -> AsyncIterator[OllamaAgent | None]:
82
+ """Attempt to create an Ollama agent.
83
+
84
+ Args:
85
+ checkpoints_sqlite_db (str | Path): Connection string for SQLite database used for LangChain checkpoints.
86
+ """
87
+ if probe.probe_ollama() is not None:
88
+ async with aiosqlite.connect(database=checkpoints_sqlite_db) as checkpoints_conn:
89
+ yield OllamaAgent(checkpoints_conn=checkpoints_conn)
90
+ else:
91
+ yield None
rag_demo/app.py CHANGED
@@ -48,7 +48,7 @@ class RAGDemo(App):
48
48
  self.run_worker(self._hold_runtime())
49
49
 
50
50
  async def _hold_runtime(self) -> None:
51
- async with self.logic.runtime(app_like=self) as runtime:
51
+ async with self.logic.runtime(app=self) as runtime:
52
52
  self._runtime_future.set_result(runtime)
53
53
  # Pause the task until Textual cancels it when the application closes.
54
54
  await asyncio.Event().wait()
rag_demo/app_protocol.py ADDED
@@ -0,0 +1,101 @@
1
+ """Interface for the logic to call back into the app code.
2
+
3
+ This is necessary to make the logic code testable. We don't want to have to run all the app code to test the logic. And,
4
+ we want to have a high degree of confidence when mocking out the app code in logic tests. The basic pattern is that each
5
+ piece of functionality that the logic depends on will have a protocol and an implementation of that protocol using the
6
+ Textual App. In the tests, we create a mock implementation of the same protocol. Correctness of the logic is defined by
7
+ its ability to work correctly with any implementation of the protocol, not just the implementation backed by the app.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import TYPE_CHECKING, Protocol, TypeVar
13
+
14
+ if TYPE_CHECKING:
15
+ from collections.abc import Awaitable
16
+
17
+ from textual.worker import Worker
18
+
19
+
20
+ class LoggerProtocol(Protocol):
21
+ """Protocol that mimics textual.Logger."""
22
+
23
+ def __call__(self, *args: object, **kwargs: object) -> None:
24
+ """Log a message.
25
+
26
+ Args:
27
+ *args (object): Logged directly to the message separated by spaces.
28
+ **kwargs (object): Logged to the message as f"{key}={value!r}", separated by spaces.
29
+ """
30
+
31
+ def verbosity(self, *, verbose: bool) -> LoggerProtocol:
32
+ """Get a new logger with selective verbosity.
33
+
34
+ Note that unlike when using this method on a Textual logger directly, the type system will enforce that you use
35
+ `verbose` as a keyword argument (not a positional argument). I made this change to address ruff's FBT001 rule.
36
+ Put simply, this requirement makes the calling code easier to read.
37
+ https://docs.astral.sh/ruff/rules/boolean-type-hint-positional-argument/
38
+
39
+ Args:
40
+ verbose: True to use HIGH verbosity, otherwise NORMAL.
41
+
42
+ Returns:
43
+ New logger.
44
+ """
45
+
46
+ @property
47
+ def verbose(self) -> LoggerProtocol:
48
+ """A verbose logger."""
49
+
50
+ @property
51
+ def event(self) -> LoggerProtocol:
52
+ """Logs events."""
53
+
54
+ @property
55
+ def debug(self) -> LoggerProtocol:
56
+ """Logs debug messages."""
57
+
58
+ @property
59
+ def info(self) -> LoggerProtocol:
60
+ """Logs information."""
61
+
62
+ @property
63
+ def warning(self) -> LoggerProtocol:
64
+ """Logs warnings."""
65
+
66
+ @property
67
+ def error(self) -> LoggerProtocol:
68
+ """Logs errors."""
69
+
70
+ @property
71
+ def system(self) -> LoggerProtocol:
72
+ """Logs system information."""
73
+
74
+ @property
75
+ def logging(self) -> LoggerProtocol:
76
+ """Logs from stdlib logging module."""
77
+
78
+ @property
79
+ def worker(self) -> LoggerProtocol:
80
+ """Logs worker information."""
81
+
82
+
83
+ ResultType = TypeVar("ResultType")
84
+
85
+
86
+ class AppProtocol(Protocol):
87
+ """Protocol for the subset of what the main App can do that the runtime needs."""
88
+
89
+ def run_worker(self, work: Awaitable[ResultType], *, thread: bool = False) -> Worker[ResultType]:
90
+ """Run a coroutine in the background.
91
+
92
+ See https://textual.textualize.io/guide/workers/.
93
+
94
+ Args:
95
+ work (Awaitable[ResultType]): The coroutine to run.
96
+ thread (bool): Mark the worker as a thread worker.
97
+ """
98
+
99
+ @property
100
+ def log(self) -> LoggerProtocol:
101
+ """Returns the application logger."""
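As the module docstring above suggests, tests can supply their own implementation of these protocols; a hypothetical minimal test double (illustrative names, not part of the package) might look like:

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable


class RecordingLogger:
    """Collects log calls instead of writing to the Textual console."""

    def __init__(self) -> None:
        self.calls: list[tuple[object, ...]] = []

    def __call__(self, *args: object, **kwargs: object) -> None:
        self.calls.append(args + tuple(f"{key}={value!r}" for key, value in kwargs.items()))

    def verbosity(self, *, verbose: bool) -> RecordingLogger:
        return self

    def __getattr__(self, name: str) -> RecordingLogger:
        # verbose, event, debug, info, warning, error, system, logging, worker
        # all funnel back into the same recorder.
        return self


class FakeApp:
    """Duck-typed stand-in for AppProtocol in logic tests."""

    def __init__(self) -> None:
        self.log = RecordingLogger()

    def run_worker(self, work: Awaitable[object], *, thread: bool = False) -> asyncio.Future[object]:
        # The real App returns a Textual Worker; a Task is close enough for tests.
        return asyncio.ensure_future(work)
```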
rag_demo/constants.py ADDED
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import StrEnum, auto
4
+
5
+
6
+ class LocalProviderType(StrEnum):
7
+ """Enum of supported local LLM backend provider types."""
8
+
9
+ HUGGING_FACE = auto()
10
+ LLAMA_CPP = auto()
11
+ OLLAMA = auto()
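A quick illustration of how these members behave; because `StrEnum` with `auto()` lowercases the member name, these are also the strings accepted by the new `--provider` command-line option:

```python
from rag_demo.constants import LocalProviderType

# StrEnum members compare equal to their lowercase names.
assert LocalProviderType.HUGGING_FACE == "hugging_face"
assert LocalProviderType.LLAMA_CPP == "llama_cpp"
assert LocalProviderType.OLLAMA == "ollama"
```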
rag_demo/logic.py CHANGED
@@ -1,57 +1,44 @@
1
1
  from __future__ import annotations
2
2
 
3
- import contextlib
4
- import platform
5
3
  import time
6
4
  from contextlib import asynccontextmanager
7
- from pathlib import Path
8
- from typing import TYPE_CHECKING, Protocol, TypeVar, cast
5
+ from typing import TYPE_CHECKING, cast
9
6
 
10
- import aiosqlite
11
- import cpuinfo
12
- import httpx
13
- import huggingface_hub
14
- import llama_cpp
15
- import ollama
16
- import psutil
17
- import pynvml
18
7
  from datasets import Dataset, load_dataset
19
- from huggingface_hub import hf_hub_download
20
- from huggingface_hub.constants import HF_HUB_CACHE
21
- from langchain.agents import create_agent
22
- from langchain.messages import AIMessageChunk, HumanMessage
23
- from langchain_community.chat_models import ChatLlamaCpp
24
- from langchain_community.embeddings import LlamaCppEmbeddings
25
8
  from langchain_core.exceptions import LangChainException
26
- from langchain_ollama import ChatOllama, OllamaEmbeddings
27
- from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
28
9
 
29
10
  from rag_demo import dirs
11
+ from rag_demo.agents import (
12
+ Agent,
13
+ AgentProvider,
14
+ HuggingFaceAgentProvider,
15
+ LlamaCppAgentProvider,
16
+ OllamaAgentProvider,
17
+ )
30
18
  from rag_demo.db import AtomicIDManager
31
19
  from rag_demo.modes.chat import Response, StoppedStreamError
32
20
 
33
21
  if TYPE_CHECKING:
34
- from collections.abc import AsyncIterator, Awaitable
35
-
36
- from textual.worker import Worker
22
+ from collections.abc import AsyncIterator, Sequence
23
+ from pathlib import Path
37
24
 
25
+ from rag_demo.app_protocol import AppProtocol
26
+ from rag_demo.constants import LocalProviderType
38
27
  from rag_demo.modes import ChatScreen
39
28
 
40
- ResultType = TypeVar("ResultType")
41
29
 
30
+ class UnknownPreferredProviderError(ValueError):
31
+ """Raised when the preferred provider cannot be checked first due to being unknown."""
42
32
 
43
- class AppLike(Protocol):
44
- """Protocol for the subset of what the main App can do that the runtime needs."""
33
+ def __init__(self, preferred_provider: LocalProviderType) -> None: # noqa: D107
34
+ super().__init__(f"Unknown preferred provider: {preferred_provider}")
45
35
 
46
- def run_worker(self, work: Awaitable[ResultType]) -> Worker[ResultType]:
47
- """Run a coroutine in the background.
48
36
 
49
- See https://textual.textualize.io/guide/workers/.
37
+ class NoProviderError(RuntimeError):
38
+ """Raised when no provider could provide an agent."""
50
39
 
51
- Args:
52
- work (Awaitable[ResultType]): The coroutine to run.
53
- """
54
- ...
40
+ def __init__(self) -> None: # noqa: D107
41
+ super().__init__("No provider could provide an agent.")
55
42
 
56
43
 
57
44
  class Runtime:
@@ -60,50 +47,28 @@ class Runtime:
60
47
  def __init__(
61
48
  self,
62
49
  logic: Logic,
63
- checkpoints_conn: aiosqlite.Connection,
50
+ app: AppProtocol,
51
+ agent: Agent,
64
52
  thread_id_manager: AtomicIDManager,
65
- app_like: AppLike,
66
53
  ) -> None:
54
+ """Initialize the runtime.
55
+
56
+ Args:
57
+ logic (Logic): The application logic.
58
+ app (AppProtocol): The application interface.
59
+ agent (Agent): The agent to use.
60
+ thread_id_manager (AtomicIDManager): The thread ID manager.
61
+ """
67
62
  self.runtime_start_time = time.time()
68
63
  self.logic = logic
69
- self.checkpoints_conn = checkpoints_conn
64
+ self.app = app
65
+ self.agent = agent
70
66
  self.thread_id_manager = thread_id_manager
71
- self.app_like = app_like
72
67
 
73
68
  self.current_thread: int | None = None
74
69
  self.generating = False
75
70
 
76
- if self.logic.probe_ollama() is not None:
77
- ollama.pull("gemma3:latest") # 3.3GB
78
- ollama.pull("embeddinggemma:latest") # 621MB
79
- self.llm = ChatOllama(
80
- model="gemma3:latest",
81
- validate_model_on_init=True,
82
- temperature=0.5,
83
- num_predict=4096,
84
- )
85
- self.embed = OllamaEmbeddings(model="embeddinggemma:latest")
86
- else:
87
- model_path = hf_hub_download(
88
- repo_id="bartowski/google_gemma-3-4b-it-GGUF",
89
- filename="google_gemma-3-4b-it-Q6_K_L.gguf", # 3.35GB
90
- revision="71506238f970075ca85125cd749c28b1b0eee84e",
91
- )
92
- embedding_model_path = hf_hub_download(
93
- repo_id="CompendiumLabs/bge-small-en-v1.5-gguf",
94
- filename="bge-small-en-v1.5-q8_0.gguf", # 36.8MB
95
- revision="d32f8c040ea3b516330eeb75b72bcc2d3a780ab7",
96
- )
97
- self.llm = ChatLlamaCpp(model_path=model_path, verbose=False)
98
- self.embed = LlamaCppEmbeddings(model_path=embedding_model_path, verbose=False) # pyright: ignore[reportCallIssue]
99
-
100
- self.agent = create_agent(
101
- model=self.llm,
102
- system_prompt="You are a helpful assistant.",
103
- checkpointer=AsyncSqliteSaver(self.checkpoints_conn),
104
- )
105
-
106
- def get_rag_datasets(self) -> None:
71
+ def _get_rag_datasets(self) -> None:
107
72
  self.qa_test: Dataset = cast(
108
73
  "Dataset",
109
74
  load_dataset("rag-datasets/rag-mini-wikipedia", "question-answer", split="test"),
@@ -123,21 +88,9 @@ class Runtime:
123
88
  """
124
89
  self.generating = True
125
90
  async with response_widget.stream_writer() as writer:
126
- agent_stream = self.agent.astream(
127
- {"messages": [HumanMessage(content=request_text)]},
128
- {"configurable": {"thread_id": thread}},
129
- stream_mode="messages",
130
- )
131
91
  try:
132
- async for message_chunk, _ in agent_stream:
133
- if isinstance(message_chunk, AIMessageChunk):
134
- token = cast("AIMessageChunk", message_chunk).content
135
- if isinstance(token, str):
136
- await writer.write(token)
137
- else:
138
- response_widget.log.error(f"Received message content of type {type(token)}")
139
- else:
140
- response_widget.log.error(f"Received message chunk of type {type(message_chunk)}")
92
+ async for message_chunk in self.agent.astream(request_text, thread, self.app):
93
+ await writer.write(message_chunk)
141
94
  except StoppedStreamError as e:
142
95
  response_widget.set_shown_object(e)
143
96
  except LangChainException as e:
@@ -145,10 +98,24 @@ class Runtime:
145
98
  self.generating = False
146
99
 
147
100
  def new_conversation(self, chat_screen: ChatScreen) -> None:
101
+ """Clear the screen and start a new conversation with the agent.
102
+
103
+ Args:
104
+ chat_screen (ChatScreen): The chat screen to clear.
105
+ """
148
106
  self.current_thread = None
149
107
  chat_screen.clear_chats()
150
108
 
151
109
  async def submit_request(self, chat_screen: ChatScreen, request_text: str) -> bool:
110
+ """Submit a new user request in the current conversation.
111
+
112
+ Args:
113
+ chat_screen (ChatScreen): The chat screen in which the request is submitted.
114
+ request_text (str): The text of the request.
115
+
116
+ Returns:
117
+ bool: True if the request was accepted for immediate processing, False otherwise.
118
+ """
152
119
  if self.generating:
153
120
  return False
154
121
  self.generating = True
@@ -168,120 +135,67 @@ class Logic:
168
135
  def __init__(
169
136
  self,
170
137
  username: str | None = None,
138
+ preferred_provider_type: LocalProviderType | None = None,
171
139
  application_start_time: float | None = None,
172
140
  checkpoints_sqlite_db: str | Path = dirs.DATA_DIR / "checkpoints.sqlite3",
173
141
  app_sqlite_db: str | Path = dirs.DATA_DIR / "app.sqlite3",
142
+ agent_providers: Sequence[AgentProvider] = (
143
+ LlamaCppAgentProvider(),
144
+ OllamaAgentProvider(),
145
+ HuggingFaceAgentProvider(),
146
+ ),
174
147
  ) -> None:
175
148
  """Initialize the application logic.
176
149
 
177
150
  Args:
178
151
  username (str | None, optional): The username provided as a command line argument. Defaults to None.
152
+ preferred_provider_type (LocalProviderType | None, optional): Provider type to prefer. Defaults to None.
179
153
  application_start_time (float | None, optional): The time when the application started. Defaults to None.
180
154
  checkpoints_sqlite_db (str | Path, optional): The connection string for the SQLite database used for
181
155
  Langchain checkpointing. Defaults to (dirs.DATA_DIR / "checkpoints.sqlite3").
182
156
  app_sqlite_db (str | Path, optional): The connection string for the SQLite database used for application
183
157
  state such as thread metadata. Defaults to (dirs.DATA_DIR / "app.sqlite3").
158
+ agent_providers (Sequence[AgentProvider], optional): Sequence of agent providers in default preference
159
+ order. If preferred_provider_type is not None, this sequence will be reordered to bring providers of
160
+ that type to the front, using the original order to break ties. Defaults to (
161
+ LlamaCppAgentProvider(),
162
+ OllamaAgentProvider(),
163
+ HuggingFaceAgentProvider(),
164
+ ).
184
165
  """
185
166
  self.logic_start_time = time.time()
186
167
  self.username = username
168
+ self.preferred_provider_type = preferred_provider_type
187
169
  self.application_start_time = application_start_time
188
170
  self.checkpoints_sqlite_db = checkpoints_sqlite_db
189
171
  self.app_sqlite_db = app_sqlite_db
172
+ self.agent_providers: Sequence[AgentProvider] = agent_providers
190
173
 
191
174
  @asynccontextmanager
192
- async def runtime(self, app_like: AppLike) -> AsyncIterator[Runtime]:
175
+ async def runtime(self, app: AppProtocol) -> AsyncIterator[Runtime]:
193
176
  """Returns a runtime context for the application."""
194
- # TODO: Do I need to set check_same_thread=False in aiosqlite.connect?
195
- async with aiosqlite.connect(database=self.checkpoints_sqlite_db) as checkpoints_conn:
196
- id_manager = AtomicIDManager(self.app_sqlite_db)
197
- await id_manager.initialize()
198
- yield Runtime(self, checkpoints_conn, id_manager, app_like)
199
-
200
- def probe_os(self) -> str:
201
- """Returns the OS name (eg 'Linux' or 'Windows'), the system name (eg 'Java'), or an empty string if unknown."""
202
- return platform.system()
203
-
204
- def probe_architecture(self) -> str:
205
- """Returns the machine architecture, such as 'i386'."""
206
- return platform.machine()
207
-
208
- def probe_cpu(self) -> str:
209
- """Returns the name of the CPU, e.g. "Intel(R) Core(TM) i7-10610U CPU @ 1.80GHz"."""
210
- return cpuinfo.get_cpu_info()["brand_raw"]
211
-
212
- def probe_ram(self) -> int:
213
- """Returns the total amount of RAM in bytes."""
214
- return psutil.virtual_memory().total
215
-
216
- def probe_disk_space(self) -> int:
217
- """Returns the amount of free space in the root directory (in bytes)."""
218
- return psutil.disk_usage("/").free
219
-
220
- def probe_llamacpp_gpu_support(self) -> bool:
221
- """Returns True if LlamaCpp supports GPU offloading, False otherwise."""
222
- return llama_cpp.llama_supports_gpu_offload()
223
-
224
- def probe_huggingface_free_cache_space(self) -> int | None:
225
- """Returns the amount of free space in the Hugging Face cache (in bytes), or None if it can't be determined."""
226
- with contextlib.suppress(FileNotFoundError):
227
- return psutil.disk_usage(HF_HUB_CACHE).free
228
- for parent_dir in Path(HF_HUB_CACHE).parents:
229
- with contextlib.suppress(FileNotFoundError):
230
- return psutil.disk_usage(str(parent_dir)).free
231
- return None
177
+ thread_id_manager = AtomicIDManager(self.app_sqlite_db)
178
+ await thread_id_manager.initialize()
232
179
 
233
- def probe_huggingface_cached_models(self) -> list[huggingface_hub.CachedRepoInfo] | None:
234
- """Returns a list of models in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
235
- # The docstring for huggingface_hub.scan_cache_dir says it raises CacheNotFound "if the cache directory does not
236
- # exist," and ValueError "if the cache directory is a file, instead of a directory."
237
- with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
238
- return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "model"]
239
- return None # Isn't it nice to be explicit?
240
-
241
- def probe_huggingface_cached_datasets(self) -> list[huggingface_hub.CachedRepoInfo] | None:
242
- """Returns a list of datasets in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
243
- with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
244
- return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "dataset"]
245
- return None
246
-
247
- def probe_nvidia(self) -> tuple[int, list[str]]:
248
- """Detect available NVIDIA GPUs and CUDA driver version.
249
-
250
- Returns:
251
- tuple[int, list[str]]: A tuple (cuda_version, nv_gpus) where cuda_version is the installed CUDA driver
252
- version and nv_gpus is a list of GPU models corresponding to installed NVIDIA GPUs
253
- """
254
- try:
255
- pynvml.nvmlInit()
256
- except pynvml.NVMLError:
257
- return -1, []
258
- cuda_version = -1
259
- nv_gpus = []
260
- try:
261
- cuda_version = pynvml.nvmlSystemGetCudaDriverVersion()
262
- for i in range(pynvml.nvmlDeviceGetCount()):
263
- handle = pynvml.nvmlDeviceGetHandleByIndex(i)
264
- nv_gpus.append(pynvml.nvmlDeviceGetName(handle))
265
- except pynvml.NVMLError:
266
- pass
267
- finally:
268
- with contextlib.suppress(pynvml.NVMLError):
269
- pynvml.nvmlShutdown()
270
- return cuda_version, nv_gpus
271
-
272
- def probe_ollama(self) -> list[ollama.ListResponse.Model] | None:
273
- """Returns a list of models installed in Ollama, or None if connecting to Ollama fails."""
274
- with contextlib.suppress(ConnectionError):
275
- return list(ollama.list().models)
276
- return None
277
-
278
- def probe_ollama_version(self) -> str | None:
279
- """Returns the Ollama version string (e.g. "0.13.5"), or None if connecting to Ollama fails."""
280
- # Yes, this uses private attributes, but that lets me use the Ollama Python lib's env var logic. If you use env
281
- # vars to direct the app to a different Ollama server, this will query the same Ollama endpoint as the
282
- # ollama.list() call above. Therefore I silence SLF001 here.
283
- with contextlib.suppress(httpx.HTTPError, KeyError, ValueError):
284
- response: httpx.Response = ollama._client._client.request("GET", "/api/version") # noqa: SLF001
285
- response.raise_for_status()
286
- return response.json()["version"]
287
- return None
180
+ agent_providers: Sequence[AgentProvider] = self.agent_providers
181
+ if self.preferred_provider_type is not None:
182
+ preferred_providers: Sequence[AgentProvider] = tuple(
183
+ ap for ap in agent_providers if ap.type == self.preferred_provider_type
184
+ )
185
+ if len(preferred_providers) == 0:
186
+ raise UnknownPreferredProviderError(self.preferred_provider_type)
187
+ agent_providers = (
188
+ *preferred_providers,
189
+ *(ap for ap in agent_providers if ap.type != self.preferred_provider_type),
190
+ )
191
+ for agent_provider in agent_providers:
192
+ async with agent_provider.get_agent(checkpoints_sqlite_db=self.checkpoints_sqlite_db) as agent:
193
+ if agent is not None:
194
+ yield Runtime(
195
+ logic=self,
196
+ app=app,
197
+ agent=agent,
198
+ thread_id_manager=thread_id_manager,
199
+ )
200
+ return
201
+ raise NoProviderError
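To spell out the selection behavior implemented above, a hedged walk-through assuming the default provider tuple from `Logic.__init__`:

```python
from rag_demo.constants import LocalProviderType
from rag_demo.logic import Logic

# Default order of attempts: LlamaCpp -> Ollama -> HuggingFace.
logic = Logic()

# Preferring Ollama reorders the attempts to: Ollama -> LlamaCpp -> HuggingFace.
logic = Logic(preferred_provider_type=LocalProviderType.OLLAMA)

# Inside Logic.runtime(), the first provider whose get_agent() yields a non-None
# agent wins; if every provider yields None, NoProviderError is raised, and a
# preference that matches no provider raises UnknownPreferredProviderError.
```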
rag_demo/modes/_logic_provider.py CHANGED
@@ -10,11 +10,12 @@ if TYPE_CHECKING:
10
10
 
11
11
 
12
12
  class LogicProvider(Protocol):
13
- """ABC for classes that contain application logic."""
13
+ """Protocol for classes that contain application logic."""
14
14
 
15
15
  logic: Logic
16
16
 
17
- async def runtime(self) -> Runtime: ...
17
+ async def runtime(self) -> Runtime:
18
+ """Returns the application runtime of the parent app."""
18
19
 
19
20
 
20
21
  class LogicProviderScreen(Screen):
rag_demo/modes/chat.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import time
4
4
  from contextlib import asynccontextmanager
5
5
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Any
6
+ from typing import TYPE_CHECKING
7
7
 
8
8
  import pyperclip
9
9
  from textual.containers import HorizontalGroup, VerticalGroup, VerticalScroll
@@ -116,7 +116,7 @@ class Response(LogicProviderWidget):
116
116
  self.set_reactive(Response.content, content)
117
117
  self._stream: ResponseWriter | None = None
118
118
  self.__object_to_show_sentinel = object()
119
- self._object_to_show: Any = self.__object_to_show_sentinel
119
+ self._object_to_show: object = self.__object_to_show_sentinel
120
120
 
121
121
  def compose(self) -> ComposeResult:
122
122
  """Compose the initial content of the widget."""
@@ -137,7 +137,8 @@ class Response(LogicProviderWidget):
137
137
  self.query_one("#object-view", Pretty).display = False
138
138
  self.query_one("#stop", Button).display = False
139
139
 
140
- def set_shown_object(self, obj: Any) -> None: # noqa: ANN401
140
+ def set_shown_object(self, obj: object) -> None:
141
+ """Show an object using a Pretty Widget instead of showing markdown or raw response content."""
141
142
  self._object_to_show = obj
142
143
  self.query_one("#markdown-view", Markdown).display = False
143
144
  self.query_one("#raw-view", Label).display = False
@@ -146,6 +147,7 @@ class Response(LogicProviderWidget):
146
147
  self.query_one("#object-view", Pretty).display = True
147
148
 
148
149
  def clear_shown_object(self) -> None:
150
+ """Stop showing an object in the widget."""
149
151
  self._object_to_show = self.__object_to_show_sentinel
150
152
  self.query_one("#object-view", Pretty).display = False
151
153
  if self.show_raw:
@@ -192,14 +194,14 @@ class Response(LogicProviderWidget):
192
194
  try:
193
195
  pyperclip.copy(self.content)
194
196
  except pyperclip.PyperclipException as e:
195
- self.app.log.error(f"Error copying to clipboard with Pyperclip: {e}")
197
+ self.app.log.error("Error copying to clipboard with Pyperclip:", e)
196
198
  checkpoint2 = time.time()
197
199
  self.notify(f"Copied {len(self.content.splitlines())} lines of text to clipboard")
198
200
  end = time.time()
199
- self.app.log.info(f"Textual copy took {checkpoint - start:.6f} seconds")
200
- self.app.log.info(f"Pyperclip copy took {checkpoint2 - checkpoint:.6f} seconds")
201
- self.app.log.info(f"Notify took {end - checkpoint2:.6f} seconds")
202
- self.app.log.info(f"Total of {end - start:.6f} seconds")
201
+ self.app.log.info("Textual copy took", f"{checkpoint - start:.6f}", "seconds")
202
+ self.app.log.info("Pyperclip copy took", f"{checkpoint2 - checkpoint:.6f}", "seconds")
203
+ self.app.log.info("Notify took", f"{end - checkpoint2:.6f}", "seconds")
204
+ self.app.log.info("Total of", f"{end - start:.6f}", "seconds")
203
205
 
204
206
  def watch_show_raw(self) -> None:
205
207
  """Handle reactive updates to the show_raw attribute by changing the visibility of the child widgets.
rag_demo/probe.py ADDED
@@ -0,0 +1,129 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import platform
5
+ from pathlib import Path
6
+
7
+ import cpuinfo
8
+ import httpx
9
+ import huggingface_hub
10
+ import ollama
11
+ import psutil
12
+ import pynvml
13
+ from huggingface_hub.constants import HF_HUB_CACHE
14
+
15
+ try:
16
+ # llama-cpp-python is an optional dependency. If it is not installed in the dev environment then we need to ignore
17
+ # unresolved-import. If it is installed, then we need to ignore unused-ignore-comment (because there is no need to
18
+ # ignore unresolved-import in this case).
19
+ import llama_cpp # ty:ignore[unresolved-import, unused-ignore-comment]
20
+
21
+ LLAMA_AVAILABLE = True
22
+ except ImportError:
23
+ LLAMA_AVAILABLE = False
24
+
25
+
26
+ def probe_os() -> str:
27
+ """Returns the OS name (eg 'Linux' or 'Windows'), the system name (eg 'Java'), or an empty string if unknown."""
28
+ return platform.system()
29
+
30
+
31
+ def probe_architecture() -> str:
32
+ """Returns the machine architecture, such as 'i386'."""
33
+ return platform.machine()
34
+
35
+
36
+ def probe_cpu() -> str:
37
+ """Returns the name of the CPU, e.g. "Intel(R) Core(TM) i7-10610U CPU @ 1.80GHz"."""
38
+ return cpuinfo.get_cpu_info()["brand_raw"]
39
+
40
+
41
+ def probe_ram() -> int:
42
+ """Returns the total amount of RAM in bytes."""
43
+ return psutil.virtual_memory().total
44
+
45
+
46
+ def probe_disk_space() -> int:
47
+ """Returns the amount of free space in the root directory (in bytes)."""
48
+ return psutil.disk_usage("/").free
49
+
50
+
51
+ def probe_llama_available() -> bool:
52
+ """Returns True if llama-cpp-python is installed, False otherwise."""
53
+ return LLAMA_AVAILABLE
54
+
55
+
56
+ def probe_llamacpp_gpu_support() -> bool:
57
+ """Returns True if the installed version of llama-cpp-python supports GPU offloading, False otherwise."""
58
+ return LLAMA_AVAILABLE and llama_cpp.llama_supports_gpu_offload()
59
+
60
+
61
+ def probe_huggingface_free_cache_space() -> int | None:
62
+ """Returns the amount of free space in the Hugging Face cache (in bytes), or None if it can't be determined."""
63
+ with contextlib.suppress(FileNotFoundError):
64
+ return psutil.disk_usage(HF_HUB_CACHE).free
65
+ for parent_dir in Path(HF_HUB_CACHE).parents:
66
+ with contextlib.suppress(FileNotFoundError):
67
+ return psutil.disk_usage(str(parent_dir)).free
68
+ return None
69
+
70
+
71
+ def probe_huggingface_cached_models() -> list[huggingface_hub.CachedRepoInfo] | None:
72
+ """Returns a list of models in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
73
+ # The docstring for huggingface_hub.scan_cache_dir says it raises CacheNotFound "if the cache directory does not
74
+ # exist," and ValueError "if the cache directory is a file, instead of a directory."
75
+ with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
76
+ return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "model"]
77
+ return None # Isn't it nice to be explicit?
78
+
79
+
80
+ def probe_huggingface_cached_datasets() -> list[huggingface_hub.CachedRepoInfo] | None:
81
+ """Returns a list of datasets in the Hugging Face cache (possibly empty), or None if the cache doesn't exist."""
82
+ with contextlib.suppress(ValueError, huggingface_hub.CacheNotFound):
83
+ return [repo for repo in huggingface_hub.scan_cache_dir().repos if repo.repo_type == "dataset"]
84
+ return None
85
+
86
+
87
+ def probe_nvidia() -> tuple[int, list[str]]:
88
+ """Detect available NVIDIA GPUs and CUDA driver version.
89
+
90
+ Returns:
91
+ tuple[int, list[str]]: A tuple (cuda_version, nv_gpus) where cuda_version is the installed CUDA driver
92
+ version and nv_gpus is a list of GPU models corresponding to installed NVIDIA GPUs
93
+ """
94
+ try:
95
+ pynvml.nvmlInit()
96
+ except pynvml.NVMLError:
97
+ return -1, []
98
+ cuda_version = -1
99
+ nv_gpus = []
100
+ try:
101
+ cuda_version = pynvml.nvmlSystemGetCudaDriverVersion()
102
+ for i in range(pynvml.nvmlDeviceGetCount()):
103
+ handle = pynvml.nvmlDeviceGetHandleByIndex(i)
104
+ nv_gpus.append(pynvml.nvmlDeviceGetName(handle))
105
+ except pynvml.NVMLError:
106
+ pass
107
+ finally:
108
+ with contextlib.suppress(pynvml.NVMLError):
109
+ pynvml.nvmlShutdown()
110
+ return cuda_version, nv_gpus
111
+
112
+
113
+ def probe_ollama() -> list[ollama.ListResponse.Model] | None:
114
+ """Returns a list of models installed in Ollama, or None if connecting to Ollama fails."""
115
+ with contextlib.suppress(ConnectionError):
116
+ return list(ollama.list().models)
117
+ return None
118
+
119
+
120
+ def probe_ollama_version() -> str | None:
121
+ """Returns the Ollama version string (e.g. "0.13.5"), or None if connecting to Ollama fails."""
122
+ # Yes, this uses private attributes, but that lets me use the Ollama Python lib's env var logic. If you use env
123
+ # vars to direct the app to a different Ollama server, this will query the same Ollama endpoint as the
124
+ # ollama.list() call above. Therefore I silence SLF001 here.
125
+ with contextlib.suppress(httpx.HTTPError, KeyError, ValueError):
126
+ response: httpx.Response = ollama._client._client.request("GET", "/api/version") # noqa: SLF001
127
+ response.raise_for_status()
128
+ return response.json()["version"]
129
+ return None
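For example, the probes can be combined into a quick environment summary (an illustrative snippet, not part of the package):

```python
from rag_demo import probe

print("OS:", probe.probe_os(), probe.probe_architecture())
print("CPU:", probe.probe_cpu())
print("RAM (GiB):", round(probe.probe_ram() / 2**30, 1))

cuda_version, gpus = probe.probe_nvidia()
print("CUDA driver:", cuda_version, "NVIDIA GPUs:", gpus)

# llama-cpp-python is optional (the 'llamacpp' extra), so check before relying on it.
if probe.probe_llama_available():
    print("Llama.cpp GPU offload:", probe.probe_llamacpp_gpu_support())

print("Ollama models:", probe.probe_ollama())
```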
@@ -1,23 +0,0 @@
1
- rag_demo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- rag_demo/__main__.py,sha256=Kak0eQWBRHVGDoWgoHs9j-Tvf_9DMzdurMxD7EM4Jr0,1054
3
- rag_demo/app.py,sha256=xejrtFApeTeyOQvWDq1H0XPyZEr8cQPn7q9KRwnV660,1812
4
- rag_demo/app.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- rag_demo/db.py,sha256=53n662Hj9sTqPNcCI2Q-6Ca_HXv3kBQdAhXU4DLhwBM,3226
6
- rag_demo/dirs.py,sha256=b0VR76kXRHSRWzaXzmAhfPr3-8WKY3ZLW8aLlaPI3Do,309
7
- rag_demo/logic.py,sha256=7PTWPs9xZJ7bbEtNDMQTX6SX4JKG8HMiq2H_YUfM-CI,12602
8
- rag_demo/markdown.py,sha256=CxzshWfANeiieZkzMlLzpRaz7tBY2_tZQxhs7b2ImKM,551
9
- rag_demo/modes/__init__.py,sha256=ccvURDWz51_IotzzlO2OH3i4_Ih_MgnGlOK_JCh45dY,91
10
- rag_demo/modes/_logic_provider.py,sha256=__eO4XVbyRHkjV_D8OHsPJX5f2R8JoJPcNXhi-w_xFY,1277
11
- rag_demo/modes/chat.py,sha256=VigWSkw6R2ea95-wZ8tgtKIccev9A-ByzJj7nzglsog,13444
12
- rag_demo/modes/chat.tcss,sha256=YANlgYygiOr-e61N9HaGGdRPM36pdr-l4u72G0ozt4o,1032
13
- rag_demo/modes/config.py,sha256=0A8IdY-GOeqCd0kMs2KMgQEsFFeVXEcnowOugtR_Q84,2609
14
- rag_demo/modes/config.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- rag_demo/modes/help.py,sha256=riV8o4WDtsim09R4cRi0xkpYLgj4CL38IrjEz_mrRmk,713
16
- rag_demo/modes/help.tcss,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- rag_demo/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- rag_demo/widgets/__init__.py,sha256=JQ1KQjdYQ4texHw2iT4IyBKgTW0SzNYbNoHAbrdCwtk,44
19
- rag_demo/widgets/escapable_input.py,sha256=VfFij4NOtQ4uX3YFETg5YPd0_nBMky9Xz-02oRdHu-w,4240
20
- jehoctor_rag_demo-0.2.0.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
21
- jehoctor_rag_demo-0.2.0.dist-info/entry_points.txt,sha256=-nDSFVcIqdTxzYM4fdveDk3xUKRhmlr_cRuqQechYh4,49
22
- jehoctor_rag_demo-0.2.0.dist-info/METADATA,sha256=wp1mdAqjB0be_1Uly4hwAoz0bjRUDI6gb6gK5SdrHRU,3531
23
- jehoctor_rag_demo-0.2.0.dist-info/RECORD,,