jehoctor-rag-demo 0.1.1.dev1__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,101 @@
1
+ """Interface for the logic to call back into the app code.
2
+
3
+ This is necessary to make the logic code testable. We don't want to have to run all the app code to test the logic. And,
4
+ we want to have a high degree of confidence when mocking out the app code in logic tests. The basic pattern is that each
5
+ piece of functionality that the logic depends on will have a protocol and an implementation of that protocol using the
6
+ Textual App. In the tests, we create a mock implementation of the same protocol. Correctness of the logic is defined by
7
+ its ability to work correctly with any implementation of the protocol, not just the implementation backed by the app.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import TYPE_CHECKING, Protocol, TypeVar
13
+
14
+ if TYPE_CHECKING:
15
+ from collections.abc import Awaitable
16
+
17
+ from textual.worker import Worker
18
+
19
+
20
class LoggerProtocol(Protocol):
    """Structural type for loggers that behave like textual.Logger."""

    def __call__(self, *args: object, **kwargs: object) -> None:
        """Emit one log message.

        Args:
            *args (object): Values written to the message in order, separated by spaces.
            **kwargs (object): Values written to the message as f"{key}={value!r}", separated by spaces.
        """

    def verbosity(self, *, verbose: bool) -> LoggerProtocol:
        """Derive a logger with the requested verbosity.

        `verbose` is keyword-only here, which is stricter than calling this method on a Textual logger directly.
        The restriction exists to satisfy ruff's FBT001 rule and keeps call sites readable:
        https://docs.astral.sh/ruff/rules/boolean-type-hint-positional-argument/

        Args:
            verbose: True to use HIGH verbosity, otherwise NORMAL.

        Returns:
            New logger.
        """

    @property
    def verbose(self) -> LoggerProtocol:
        """Logger variant that is verbose."""

    @property
    def event(self) -> LoggerProtocol:
        """Logger for events."""

    @property
    def debug(self) -> LoggerProtocol:
        """Logger for debug messages."""

    @property
    def info(self) -> LoggerProtocol:
        """Logger for informational messages."""

    @property
    def warning(self) -> LoggerProtocol:
        """Logger for warnings."""

    @property
    def error(self) -> LoggerProtocol:
        """Logger for errors."""

    @property
    def system(self) -> LoggerProtocol:
        """Logger for system information."""

    @property
    def logging(self) -> LoggerProtocol:
        """Logger for messages coming from the stdlib logging module."""

    @property
    def worker(self) -> LoggerProtocol:
        """Logger for worker information."""
81
+
82
+
83
+ ResultType = TypeVar("ResultType")
84
+
85
+
86
class AppProtocol(Protocol):
    """Structural type covering the part of the main App that the runtime depends on."""

    def run_worker(self, work: Awaitable[ResultType], *, thread: bool = False) -> Worker[ResultType]:
        """Schedule a coroutine to run in the background.

        See https://textual.textualize.io/guide/workers/.

        Args:
            work (Awaitable[ResultType]): The coroutine to run.
            thread (bool): Mark the worker as a thread worker.

        Returns:
            Worker[ResultType]: The worker created for `work`.
        """

    @property
    def log(self) -> LoggerProtocol:
        """The logger belonging to the application."""
rag_demo/constants.py ADDED
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import StrEnum, auto
4
+
5
+
6
class LocalProviderType(StrEnum):
    """Enum of supported local LLM backend provider types.

    Members are strings. With StrEnum, auto() gives each member its lower-cased name as its value.
    """

    HUGGING_FACE = auto()  # value: "hugging_face"
    LLAMA_CPP = auto()  # value: "llama_cpp"
    OLLAMA = auto()  # value: "ollama"
rag_demo/db.py ADDED
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import aiosqlite
6
+
7
+ if TYPE_CHECKING:
8
+ from pathlib import Path
9
+
10
+
11
class AtomicIDManager:
    """Hands out unique, increasing integer IDs (used for thread IDs) backed by a SQLite file.

    Every method opens its own short-lived connection to `db_path`. The claimed IDs live in a single
    `claimed_ids` table, which `initialize` creates on demand.

    Provenance: the first version was written by Claude and then cleaned up using Ruff and Flake8 feedback.
    The app logic database may need something fancier one day, but this is sufficient for now.
    Original conversation: https://claude.ai/share/227d08ff-96a3-495a-9f56-509a1fd528f7
    """

    def __init__(self, db_path: str | Path) -> None:
        """Remember where the SQLite database lives; no connection is opened here."""
        self.db_path = db_path

    async def initialize(self) -> None:
        """Switch the database to WAL journaling and make sure the `claimed_ids` table exists."""
        create_table_sql = """
                CREATE TABLE IF NOT EXISTS claimed_ids (
                    id INTEGER PRIMARY KEY
                )
            """
        async with aiosqlite.connect(self.db_path) as db:
            # WAL mode gives better concurrent access.
            await db.execute("PRAGMA journal_mode=WAL")
            await db.execute(create_table_sql)
            await db.commit()

    async def claim_next_id(self) -> int:
        """Claim and return the ID one greater than the current maximum (1 when the table is empty).

        The read of the maximum and the insert of the new row happen inside one IMMEDIATE transaction,
        so the write lock is taken before the maximum is read. Any failure rolls the transaction back
        and re-raises.
        """
        async with aiosqlite.connect(self.db_path) as db:
            # Take the write lock up front, before reading the current maximum.
            await db.execute("BEGIN IMMEDIATE")
            try:
                async with db.execute("SELECT MAX(id) FROM claimed_ids") as cursor:
                    row = await cursor.fetchone()
                # MAX() over an empty table yields NULL, which counts as 0 here.
                current_max = 0 if row is None or row[0] is None else row[0]
                claimed_id = current_max + 1
                await db.execute("INSERT INTO claimed_ids (id) VALUES (?)", (claimed_id,))
                await db.commit()
            except Exception:
                await db.rollback()
                raise
            return claimed_id

    async def get_all_claimed_ids(self) -> list[int]:
        """Return every claimed ID in ascending order."""
        async with (
            aiosqlite.connect(self.db_path) as db,
            db.execute("SELECT id FROM claimed_ids ORDER BY id") as cursor,
        ):
            claimed_rows = await cursor.fetchall()
        claimed_ids = []
        for claimed_row in claimed_rows:
            claimed_ids.append(claimed_row[0])
        return claimed_ids

    async def get_count(self) -> int:
        """Return how many IDs have been claimed.

        Raises:
            ValueError: If the COUNT query unexpectedly returns no row.
        """
        async with aiosqlite.connect(self.db_path) as db, db.execute("SELECT COUNT(*) FROM claimed_ids") as cursor:
            row = await cursor.fetchone()
        if row is not None:
            return row[0]
        raise ValueError("A SQL COUNT query should always return at least one row")  # noqa: EM101, TRY003
rag_demo/dirs.py ADDED
@@ -0,0 +1,14 @@
1
+ from pathlib import Path
2
+
3
+ from platformdirs import PlatformDirs
4
+
5
+ _appdirs = PlatformDirs(appname="jehoctor-rag-demo", ensure_exists=True)
6
+
7
+
8
def _ensure(dir_: Path) -> Path:
    """Create `dir_` (including missing parents) if needed and return the same path.

    NOTE(review): no caller is visible in this module; confirm whether this helper is still used.
    """
    dir_.mkdir(parents=True, exist_ok=True)
    return dir_
11
+
12
+
13
# Per-user data directory for this application, taken from the module-level PlatformDirs instance.
DATA_DIR = _appdirs.user_data_path
# Per-user configuration directory for this application, taken from the same PlatformDirs instance.
CONFIG_DIR = _appdirs.user_config_path
rag_demo/logic.py ADDED
@@ -0,0 +1,201 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from contextlib import asynccontextmanager
5
+ from typing import TYPE_CHECKING, cast
6
+
7
+ from datasets import Dataset, load_dataset
8
+ from langchain_core.exceptions import LangChainException
9
+
10
+ from rag_demo import dirs
11
+ from rag_demo.agents import (
12
+ Agent,
13
+ AgentProvider,
14
+ HuggingFaceAgentProvider,
15
+ LlamaCppAgentProvider,
16
+ OllamaAgentProvider,
17
+ )
18
+ from rag_demo.db import AtomicIDManager
19
+ from rag_demo.modes.chat import Response, StoppedStreamError
20
+
21
+ if TYPE_CHECKING:
22
+ from collections.abc import AsyncIterator, Sequence
23
+ from pathlib import Path
24
+
25
+ from rag_demo.app_protocol import AppProtocol
26
+ from rag_demo.constants import LocalProviderType
27
+ from rag_demo.modes import ChatScreen
28
+
29
+
30
class UnknownPreferredProviderError(ValueError):
    """Signals that the preferred provider is unknown and so cannot be checked first."""

    def __init__(self, preferred_provider: LocalProviderType) -> None:
        """Build the error message from the offending provider type.

        Args:
            preferred_provider (LocalProviderType): The preferred provider type that could not be matched.
        """
        message = f"Unknown preferred provider: {preferred_provider}"
        super().__init__(message)
35
+
36
+
37
class NoProviderError(RuntimeError):
    """Signals that none of the providers was able to supply an agent."""

    def __init__(self) -> None:
        """Initialize the error with its fixed message."""
        message = "No provider could provide an agent."
        super().__init__(message)
42
+
43
+
44
class Runtime:
    """The application logic with asynchronously initialized resources."""

    def __init__(
        self,
        logic: Logic,
        app: AppProtocol,
        agent: Agent,
        thread_id_manager: AtomicIDManager,
    ) -> None:
        """Initialize the runtime.

        Args:
            logic (Logic): The application logic.
            app (AppProtocol): The application interface.
            agent (Agent): The agent to use.
            thread_id_manager (AtomicIDManager): The thread ID manager.
        """
        self.runtime_start_time = time.time()
        self.logic = logic
        self.app = app
        self.agent = agent
        self.thread_id_manager = thread_id_manager

        # No thread is claimed until the first request of a conversation is submitted.
        self.current_thread: int | None = None
        # True while a request is being processed; submit_request rejects new requests meanwhile.
        self.generating = False

    def _get_rag_datasets(self) -> None:
        """Load the rag-mini-wikipedia test questions and passage corpus onto this runtime."""
        self.qa_test: Dataset = cast(
            "Dataset",
            load_dataset("rag-datasets/rag-mini-wikipedia", "question-answer", split="test"),
        )
        self.corpus: Dataset = cast(
            "Dataset",
            load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages"),
        )

    async def stream_response(self, response_widget: Response, request_text: str, thread: str) -> None:
        """Worker method for streaming tokens from the active agent to a response widget.

        StoppedStreamError and LangChainException are shown in the response widget instead of being
        raised. Any other exception (including cancellation) propagates to the caller.

        Args:
            response_widget (Response): Target response widget for streamed tokens.
            request_text (str): Text of the user request.
            thread (str): ID of the current thread.
        """
        self.generating = True
        try:
            async with response_widget.stream_writer() as writer:
                try:
                    async for message_chunk in self.agent.astream(request_text, thread, self.app):
                        await writer.write(message_chunk)
                except (StoppedStreamError, LangChainException) as e:
                    response_widget.set_shown_object(e)
        finally:
            # Always release the flag: if an unexpected error or a cancelled worker left it set,
            # every later submit_request call would be rejected.
            self.generating = False

    def new_conversation(self, chat_screen: ChatScreen) -> None:
        """Clear the screen and start a new conversation with the agent.

        Args:
            chat_screen (ChatScreen): The chat screen to clear.
        """
        self.current_thread = None
        chat_screen.clear_chats()

    async def submit_request(self, chat_screen: ChatScreen, request_text: str) -> bool:
        """Submit a new user request in the current conversation.

        A thread ID is claimed on the first request of a conversation. If anything fails before the
        streaming worker has been started, the `generating` flag is released again and the error is
        re-raised, so that later requests are not blocked.

        Args:
            chat_screen (ChatScreen): The chat screen in which the request is submitted.
            request_text (str): The text of the request.

        Returns:
            bool: True if the request was accepted for immediate processing, False otherwise.
        """
        if self.generating:
            return False
        self.generating = True
        try:
            if self.current_thread is None:
                chat_screen.log.info("Starting new thread")
                self.current_thread = await self.thread_id_manager.claim_next_id()
                chat_screen.log.info("Claimed thread id", self.current_thread)
            chat_screen.new_request(request_text)
            response = chat_screen.new_response()
            chat_screen.run_worker(self.stream_response(response, request_text, str(self.current_thread)))
        except BaseException:
            # The worker never started, so stream_response will not reset the flag for us.
            self.generating = False
            raise
        return True
130
+
131
+
132
class Logic:
    """Top-level application logic."""

    def __init__(
        self,
        username: str | None = None,
        preferred_provider_type: LocalProviderType | None = None,
        application_start_time: float | None = None,
        checkpoints_sqlite_db: str | Path = dirs.DATA_DIR / "checkpoints.sqlite3",
        app_sqlite_db: str | Path = dirs.DATA_DIR / "app.sqlite3",
        agent_providers: Sequence[AgentProvider] = (
            LlamaCppAgentProvider(),
            OllamaAgentProvider(),
            HuggingFaceAgentProvider(),
        ),
    ) -> None:
        """Store the configuration of the application logic.

        Args:
            username (str | None, optional): The username provided as a command line argument. Defaults to None.
            preferred_provider_type (LocalProviderType | None, optional): Provider type to prefer. Defaults to None.
            application_start_time (float | None, optional): The time when the application started. Defaults to None.
            checkpoints_sqlite_db (str | Path, optional): The connection string for the SQLite database used for
                Langchain checkpointing. Defaults to (dirs.DATA_DIR / "checkpoints.sqlite3").
            app_sqlite_db (str | Path, optional): The connection string for the SQLite database used for application
                state such as thread metadata. Defaults to (dirs.DATA_DIR / "app.sqlite3").
            agent_providers (Sequence[AgentProvider], optional): Sequence of agent providers in default preference
                order. If preferred_provider_type is not None, this sequence will be reordered to bring providers of
                that type to the front, using the original order to break ties. Defaults to (
                    LlamaCppAgentProvider(),
                    OllamaAgentProvider(),
                    HuggingFaceAgentProvider(),
                ).
        """
        self.logic_start_time = time.time()
        self.username = username
        self.preferred_provider_type = preferred_provider_type
        self.application_start_time = application_start_time
        self.checkpoints_sqlite_db = checkpoints_sqlite_db
        self.app_sqlite_db = app_sqlite_db
        self.agent_providers: Sequence[AgentProvider] = agent_providers

    def _providers_in_preference_order(self) -> Sequence[AgentProvider]:
        """Return the configured providers with those of the preferred type moved to the front.

        Raises:
            UnknownPreferredProviderError: If a preferred type is set but no configured provider has it.
        """
        wanted_type = self.preferred_provider_type
        if wanted_type is None:
            return self.agent_providers
        preferred: Sequence[AgentProvider] = tuple(ap for ap in self.agent_providers if ap.type == wanted_type)
        if not preferred:
            raise UnknownPreferredProviderError(wanted_type)
        # Relative order is kept inside both groups.
        remaining = tuple(ap for ap in self.agent_providers if ap.type != wanted_type)
        return (*preferred, *remaining)

    @asynccontextmanager
    async def runtime(self, app: AppProtocol) -> AsyncIterator[Runtime]:
        """Open a runtime context backed by the first provider that supplies an agent.

        The thread ID database is initialized first, then the providers are tried in preference order.

        Args:
            app (AppProtocol): The application interface handed to the runtime.

        Yields:
            Runtime: The runtime built around the first agent that is not None.

        Raises:
            UnknownPreferredProviderError: If the preferred provider type matches no configured provider.
            NoProviderError: If no provider supplied an agent.
        """
        thread_id_manager = AtomicIDManager(self.app_sqlite_db)
        await thread_id_manager.initialize()

        for candidate in self._providers_in_preference_order():
            async with candidate.get_agent(checkpoints_sqlite_db=self.checkpoints_sqlite_db) as agent:
                if agent is None:
                    continue
                yield Runtime(
                    logic=self,
                    app=app,
                    agent=agent,
                    thread_id_manager=thread_id_manager,
                )
                return
        raise NoProviderError
rag_demo/markdown.py ADDED
@@ -0,0 +1,17 @@
1
+ from markdown_it import MarkdownIt
2
+ from markdown_it.rules_inline import StateInline
3
+
4
+
5
def soft2hard_break_plugin(md: MarkdownIt) -> None:
    """Register the soft-to-hard break rewrite on the inline `ruler2` of the given parser."""
    second_pass_rules = md.inline.ruler2
    second_pass_rules.push("soft2hard_break", _soft2hard_break_plugin)
7
+
8
+
9
def _soft2hard_break_plugin(state: StateInline) -> None:
    """Retype every `softbreak` token in `state.tokens` as `hardbreak`, in place."""
    soft_breaks = (tok for tok in state.tokens if tok.type == "softbreak")
    for tok in soft_breaks:
        tok.type = "hardbreak"
13
+
14
+
15
def parser_factory() -> MarkdownIt:
    """Build a "gfm-like" parser that handles newlines according to LLM conventions.

    The soft2hard_break plugin is applied, so soft line breaks become hard breaks.
    """
    parser = MarkdownIt("gfm-like")
    return parser.use(soft2hard_break_plugin)
@@ -0,0 +1,3 @@
1
+ from .chat import ChatScreen
2
+ from .config import ConfigScreen
3
+ from .help import HelpScreen
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Protocol, cast
4
+
5
+ from textual.screen import Screen
6
+ from textual.widget import Widget
7
+
8
+ if TYPE_CHECKING:
9
+ from rag_demo.logic import Logic, Runtime
10
+
11
+
12
class LogicProvider(Protocol):
    """Structural type for objects that hold the application logic and can hand out its runtime."""

    # The application logic held by the provider.
    logic: Logic

    async def runtime(self) -> Runtime:
        """Return the application runtime."""
19
+
20
+
21
class LogicProviderScreen(Screen):
    """Screen base class exposing the application logic and runtime of its parent app."""

    @property
    def logic(self) -> Logic:
        """The application logic, read from the parent app."""
        provider = cast("LogicProvider", self.app)
        return provider.logic

    async def runtime(self) -> Runtime:
        """Await and return the application runtime of the parent app."""
        provider = cast("LogicProvider", self.app)
        return await provider.runtime()
32
+
33
+
34
class LogicProviderWidget(Widget):
    """Widget base class exposing the application logic and runtime of its parent app."""

    @property
    def logic(self) -> Logic:
        """The application logic, read from the parent app."""
        provider = cast("LogicProvider", self.app)
        return provider.logic

    async def runtime(self) -> Runtime:
        """Await and return the application runtime of the parent app."""
        provider = cast("LogicProvider", self.app)
        return await provider.runtime()