jehoctor-rag-demo 0.1.1.dev1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jehoctor_rag_demo-0.2.1.dist-info/METADATA +125 -0
- jehoctor_rag_demo-0.2.1.dist-info/RECORD +31 -0
- jehoctor_rag_demo-0.2.1.dist-info/entry_points.txt +3 -0
- rag_demo/__init__.py +0 -2
- rag_demo/__main__.py +42 -0
- rag_demo/agents/__init__.py +4 -0
- rag_demo/agents/base.py +40 -0
- rag_demo/agents/hugging_face.py +116 -0
- rag_demo/agents/llama_cpp.py +113 -0
- rag_demo/agents/ollama.py +91 -0
- rag_demo/app.py +58 -0
- rag_demo/app.tcss +0 -0
- rag_demo/app_protocol.py +101 -0
- rag_demo/constants.py +11 -0
- rag_demo/db.py +87 -0
- rag_demo/dirs.py +14 -0
- rag_demo/logic.py +201 -0
- rag_demo/markdown.py +17 -0
- rag_demo/modes/__init__.py +3 -0
- rag_demo/modes/_logic_provider.py +44 -0
- rag_demo/modes/chat.py +317 -0
- rag_demo/modes/chat.tcss +75 -0
- rag_demo/modes/config.py +77 -0
- rag_demo/modes/config.tcss +0 -0
- rag_demo/modes/help.py +26 -0
- rag_demo/modes/help.tcss +0 -0
- rag_demo/probe.py +129 -0
- rag_demo/widgets/__init__.py +1 -0
- rag_demo/widgets/escapable_input.py +110 -0
- jehoctor_rag_demo-0.1.1.dev1.dist-info/METADATA +0 -11
- jehoctor_rag_demo-0.1.1.dev1.dist-info/RECORD +0 -6
- jehoctor_rag_demo-0.1.1.dev1.dist-info/entry_points.txt +0 -3
- {jehoctor_rag_demo-0.1.1.dev1.dist-info → jehoctor_rag_demo-0.2.1.dist-info}/WHEEL +0 -0
rag_demo/app_protocol.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""Interface for the logic to call back into the app code.
|
|
2
|
+
|
|
3
|
+
This is necessary to make the logic code testable: we shouldn't have to run all of the app code just to test the logic, and
|
|
4
|
+
we want to have a high degree of confidence when mocking out the app code in logic tests. The basic pattern is that each
|
|
5
|
+
piece of functionality that the logic depends on will have a protocol and an implementation of that protocol using the
|
|
6
|
+
Textual App. In the tests, we create a mock implementation of the same protocol. Correctness of the logic is defined by
|
|
7
|
+
its ability to work correctly with any implementation of the protocol, not just the implementation backed by the app.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import TYPE_CHECKING, Protocol, TypeVar
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from collections.abc import Awaitable
|
|
16
|
+
|
|
17
|
+
from textual.worker import Worker
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LoggerProtocol(Protocol):
    """Structural stand-in for ``textual.Logger``.

    Covers calling the logger directly, picking a verbosity, and one property per log group, so logic code can log
    without depending on the real Textual logger.
    """

    def __call__(self, *args: object, **kwargs: object) -> None:
        """Emit a log message.

        Args:
            *args (object): Values written into the message as-is, joined with spaces.
            **kwargs (object): Values written into the message as f"{key}={value!r}", joined with spaces.
        """

    def verbosity(self, *, verbose: bool) -> LoggerProtocol:
        """Return a new logger whose verbosity is selected by ``verbose``.

        Unlike calling this method on a Textual logger directly, ``verbose`` is keyword-only here, so the type
        checker rejects a bare positional boolean. This follows ruff's FBT001 rule and keeps call sites readable:
        https://docs.astral.sh/ruff/rules/boolean-type-hint-positional-argument/

        Args:
            verbose: True selects HIGH verbosity; False selects NORMAL.

        Returns:
            A new logger.
        """

    # One read-only property per log group, mirroring textual.Logger.
    @property
    def verbose(self) -> LoggerProtocol:
        """Logger for verbose output."""

    @property
    def event(self) -> LoggerProtocol:
        """Logger for events."""

    @property
    def debug(self) -> LoggerProtocol:
        """Logger for debug messages."""

    @property
    def info(self) -> LoggerProtocol:
        """Logger for informational messages."""

    @property
    def warning(self) -> LoggerProtocol:
        """Logger for warnings."""

    @property
    def error(self) -> LoggerProtocol:
        """Logger for errors."""

    @property
    def system(self) -> LoggerProtocol:
        """Logger for system information."""

    @property
    def logging(self) -> LoggerProtocol:
        """Logger for records coming from the stdlib ``logging`` module."""

    @property
    def worker(self) -> LoggerProtocol:
        """Logger for worker information."""
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Type produced by the awaitable passed to AppProtocol.run_worker; it links that argument to the Worker type returned.
ResultType = TypeVar("ResultType")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class AppProtocol(Protocol):
    """The slice of the main Textual App that the runtime depends on.

    The real app satisfies this protocol structurally; logic tests can substitute any object with the same members.
    """

    @property
    def log(self) -> LoggerProtocol:
        """The application logger."""

    def run_worker(self, work: Awaitable[ResultType], *, thread: bool = False) -> Worker[ResultType]:
        """Run ``work`` in the background as a worker.

        See https://textual.textualize.io/guide/workers/.

        Args:
            work (Awaitable[ResultType]): The coroutine to run.
            thread (bool): When True, mark the worker as a thread worker.

        Returns:
            Worker[ResultType]: The worker running ``work``.
        """
|
rag_demo/constants.py
ADDED
rag_demo/db.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
import aiosqlite
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AtomicIDManager:
    """Allocates unique, increasing integer thread IDs recorded in a SQLite table.

    Every claimed ID is inserted into the ``claimed_ids`` table of the database at ``db_path``. Claimed IDs therefore
    outlive the process, and every process using the same file sees the same set. Each method opens its own
    short-lived connection.

    Provenance: this was written by Claude, and I fixed it up with feedback from Ruff and Flake8. Maybe one day the
    app logic database will require something fancier, but this gets the job done now. The (short) conversation with
    Claude: https://claude.ai/share/227d08ff-96a3-495a-9f56-509a1fd528f7
    """

    def __init__(self, db_path: str | Path) -> None:
        """Remember the database location; no connection is opened here.

        Args:
            db_path (str | Path): Path of the SQLite database file.
        """
        self.db_path = db_path

    async def initialize(self) -> None:
        """Switch the database to WAL journaling and create the ``claimed_ids`` table if it is missing."""
        async with aiosqlite.connect(self.db_path) as conn:
            # WAL journaling allows better concurrent access.
            await conn.execute("PRAGMA journal_mode=WAL")
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS claimed_ids (
                    id INTEGER PRIMARY KEY
                )
            """)
            await conn.commit()

    async def claim_next_id(self) -> int:
        """Atomically claim the next ID (current maximum + 1, starting from 1) and return it.

        The read-increment-insert sequence is atomic and multiprocess-safe because:
        1. SQLite serializes writes by default.
        2. ``BEGIN IMMEDIATE`` takes the write lock before the current maximum is read.
        3. The whole sequence runs inside that one transaction.

        Returns:
            int: The newly claimed ID.
        """
        async with aiosqlite.connect(self.db_path) as conn:
            # Take the write lock up front so no other writer can claim the same ID between our SELECT and INSERT.
            await conn.execute("BEGIN IMMEDIATE")
            try:
                async with conn.execute("SELECT MAX(id) FROM claimed_ids") as cursor:
                    row = await cursor.fetchone()
                # MAX() of an empty table is NULL, which makes the first claimed ID 1.
                current_max = 0 if row is None or row[0] is None else row[0]
                claimed = current_max + 1
                await conn.execute("INSERT INTO claimed_ids (id) VALUES (?)", (claimed,))
                await conn.commit()
            except Exception:
                # Drop the partial transaction before letting the error propagate.
                await conn.rollback()
                raise
            return claimed

    async def get_all_claimed_ids(self) -> list[int]:
        """Return every claimed ID in ascending order."""
        async with (
            aiosqlite.connect(self.db_path) as conn,
            conn.execute("SELECT id FROM claimed_ids ORDER BY id") as cursor,
        ):
            fetched = await cursor.fetchall()
            return [record[0] for record in fetched]

    async def get_count(self) -> int:
        """Return how many IDs have been claimed.

        Raises:
            ValueError: If the COUNT query unexpectedly returns no row.
        """
        async with (
            aiosqlite.connect(self.db_path) as conn,
            conn.execute("SELECT COUNT(*) FROM claimed_ids") as cursor,
        ):
            record = await cursor.fetchone()
            if record is None:
                raise ValueError("A SQL COUNT query should always return at least one row")  # noqa: EM101, TRY003
            return record[0]
|
rag_demo/dirs.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from platformdirs import PlatformDirs
|
|
4
|
+
|
|
5
|
+
# Per-user, platform-specific directory locations for this app. ensure_exists=True asks platformdirs to create the
# directories when their paths are requested (per the platformdirs docs — verify against the pinned version).
_appdirs = PlatformDirs(appname="jehoctor-rag-demo", ensure_exists=True)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _ensure(dir_: Path) -> Path:
    """Create ``dir_`` together with any missing parents, then return the same path object.

    A directory that already exists is accepted without error.
    """
    target = dir_
    target.mkdir(exist_ok=True, parents=True)
    return target
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Per-user data directory for this app (logic.py places its default SQLite databases here).
DATA_DIR = _appdirs.user_data_path
|
|
14
|
+
# Per-user configuration directory for this app.
CONFIG_DIR = _appdirs.user_config_path
|
rag_demo/logic.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from typing import TYPE_CHECKING, cast
|
|
6
|
+
|
|
7
|
+
from datasets import Dataset, load_dataset
|
|
8
|
+
from langchain_core.exceptions import LangChainException
|
|
9
|
+
|
|
10
|
+
from rag_demo import dirs
|
|
11
|
+
from rag_demo.agents import (
|
|
12
|
+
Agent,
|
|
13
|
+
AgentProvider,
|
|
14
|
+
HuggingFaceAgentProvider,
|
|
15
|
+
LlamaCppAgentProvider,
|
|
16
|
+
OllamaAgentProvider,
|
|
17
|
+
)
|
|
18
|
+
from rag_demo.db import AtomicIDManager
|
|
19
|
+
from rag_demo.modes.chat import Response, StoppedStreamError
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from collections.abc import AsyncIterator, Sequence
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
from rag_demo.app_protocol import AppProtocol
|
|
26
|
+
from rag_demo.constants import LocalProviderType
|
|
27
|
+
from rag_demo.modes import ChatScreen
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class UnknownPreferredProviderError(ValueError):
    """Signals that the preferred provider type matches none of the configured providers, so it cannot go first."""

    def __init__(self, preferred_provider: LocalProviderType) -> None:  # noqa: D107
        message = f"Unknown preferred provider: {preferred_provider}"
        super().__init__(message)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class NoProviderError(RuntimeError):
    """Signals that every configured provider declined to supply an agent."""

    def __init__(self) -> None:  # noqa: D107
        message = "No provider could provide an agent."
        super().__init__(message)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Runtime:
    """The application logic with asynchronously initialized resources.

    A Runtime holds the active agent and the thread-ID manager. It tracks the current conversation thread and whether
    a response is being generated, and it uses that busy flag to reject overlapping requests.
    """

    def __init__(
        self,
        logic: Logic,
        app: AppProtocol,
        agent: Agent,
        thread_id_manager: AtomicIDManager,
    ) -> None:
        """Initialize the runtime.

        Args:
            logic (Logic): The application logic.
            app (AppProtocol): The application interface.
            agent (Agent): The agent to use.
            thread_id_manager (AtomicIDManager): The thread ID manager.
        """
        self.runtime_start_time = time.time()
        self.logic = logic
        self.app = app
        self.agent = agent
        self.thread_id_manager = thread_id_manager

        # None until the first request of a conversation claims a thread ID.
        self.current_thread: int | None = None
        # True while a request is being processed; submit_request() refuses new requests while it is set.
        self.generating = False

    def _get_rag_datasets(self) -> None:
        """Load the rag-mini-wikipedia QA test split and passage corpus into ``self.qa_test`` and ``self.corpus``."""
        self.qa_test: Dataset = cast(
            "Dataset",
            load_dataset("rag-datasets/rag-mini-wikipedia", "question-answer", split="test"),
        )
        self.corpus: Dataset = cast(
            "Dataset",
            load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages"),
        )

    async def stream_response(self, response_widget: Response, request_text: str, thread: str) -> None:
        """Worker method for streaming tokens from the active agent to a response widget.

        Stopped streams and LangChain errors are shown in the response widget. Any other exception propagates, but
        the busy flag is always cleared so later requests are not locked out.

        Args:
            response_widget (Response): Target response widget for streamed tokens.
            request_text (str): Text of the user request.
            thread (str): ID of the current thread.
        """
        self.generating = True
        try:
            async with response_widget.stream_writer() as writer:
                try:
                    async for message_chunk in self.agent.astream(request_text, thread, self.app):
                        await writer.write(message_chunk)
                except (StoppedStreamError, LangChainException) as e:
                    # Display the error in the response instead of failing the worker.
                    response_widget.set_shown_object(e)
        finally:
            # Clear the flag even on unexpected errors or worker cancellation. Otherwise every later
            # submit_request() call would be rejected.
            self.generating = False

    def new_conversation(self, chat_screen: ChatScreen) -> None:
        """Clear the screen and start a new conversation with the agent.

        Args:
            chat_screen (ChatScreen): The chat screen to clear.
        """
        self.current_thread = None
        chat_screen.clear_chats()

    async def submit_request(self, chat_screen: ChatScreen, request_text: str) -> bool:
        """Submit a new user request in the current conversation.

        Args:
            chat_screen (ChatScreen): The chat screen in which the request is submitted.
            request_text (str): The text of the request.

        Returns:
            bool: True if the request was accepted for immediate processing, False otherwise.
        """
        if self.generating:
            return False
        # Set the flag before the first await so a second submission arriving during the await is rejected.
        self.generating = True
        try:
            if self.current_thread is None:
                chat_screen.log.info("Starting new thread")
                self.current_thread = await self.thread_id_manager.claim_next_id()
                chat_screen.log.info("Claimed thread id", self.current_thread)
            chat_screen.new_request(request_text)
            response = chat_screen.new_response()
            chat_screen.run_worker(self.stream_response(response, request_text, str(self.current_thread)))
        except BaseException:
            # No worker was started, so nothing else will clear the flag; release it before propagating.
            self.generating = False
            raise
        return True
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class Logic:
    """Top-level application logic.

    Holds the startup configuration: user, provider preference, database locations and candidate agent providers.
    The :meth:`runtime` async context manager builds a :class:`Runtime` from that configuration.
    """

    def __init__(
        self,
        username: str | None = None,
        preferred_provider_type: LocalProviderType | None = None,
        application_start_time: float | None = None,
        checkpoints_sqlite_db: str | Path = dirs.DATA_DIR / "checkpoints.sqlite3",
        app_sqlite_db: str | Path = dirs.DATA_DIR / "app.sqlite3",
        agent_providers: Sequence[AgentProvider] = (
            LlamaCppAgentProvider(),
            OllamaAgentProvider(),
            HuggingFaceAgentProvider(),
        ),
    ) -> None:
        """Record the application configuration.

        Args:
            username (str | None, optional): The username provided as a command line argument. Defaults to None.
            preferred_provider_type (LocalProviderType | None, optional): Provider type to try first. Defaults to
                None.
            application_start_time (float | None, optional): The time when the application started. Defaults to None.
            checkpoints_sqlite_db (str | Path, optional): The connection string for the SQLite database used for
                LangChain checkpointing. Defaults to (dirs.DATA_DIR / "checkpoints.sqlite3").
            app_sqlite_db (str | Path, optional): The connection string for the SQLite database used for application
                state such as thread metadata. Defaults to (dirs.DATA_DIR / "app.sqlite3").
            agent_providers (Sequence[AgentProvider], optional): Agent providers in default preference order. If
                preferred_provider_type is not None, providers of that type are moved to the front. The original
                order is kept within each group. Defaults to (LlamaCppAgentProvider(), OllamaAgentProvider(),
                HuggingFaceAgentProvider()). That default tuple is built once, when the class is defined, so every
                instance relying on it shares the same provider objects.
        """
        self.logic_start_time = time.time()
        self.username = username
        self.application_start_time = application_start_time
        self.preferred_provider_type = preferred_provider_type
        self.app_sqlite_db = app_sqlite_db
        self.checkpoints_sqlite_db = checkpoints_sqlite_db
        self.agent_providers: Sequence[AgentProvider] = agent_providers

    @asynccontextmanager
    async def runtime(self, app: AppProtocol) -> AsyncIterator[Runtime]:
        """Provide a runtime backed by the first provider that yields an agent.

        The thread-ID database is initialized first. Providers are then tried in preference order. The first one whose
        ``get_agent`` context yields a non-None agent backs the yielded :class:`Runtime`. That provider's context stays
        open until the caller leaves this context manager.

        Args:
            app (AppProtocol): Application interface handed to the runtime.

        Yields:
            Runtime: The runtime for this application session.

        Raises:
            UnknownPreferredProviderError: If preferred_provider_type matches none of the configured providers.
            NoProviderError: If no provider yields an agent.
        """
        id_manager = AtomicIDManager(self.app_sqlite_db)
        await id_manager.initialize()

        candidates: Sequence[AgentProvider] = self.agent_providers
        wanted = self.preferred_provider_type
        if wanted is not None:
            preferred = tuple(provider for provider in candidates if provider.type == wanted)
            if not preferred:
                raise UnknownPreferredProviderError(wanted)
            others = tuple(provider for provider in candidates if provider.type != wanted)
            candidates = preferred + others

        for provider in candidates:
            async with provider.get_agent(checkpoints_sqlite_db=self.checkpoints_sqlite_db) as agent:
                if agent is None:
                    # This provider cannot serve an agent; fall through to the next one.
                    continue
                yield Runtime(
                    logic=self,
                    app=app,
                    agent=agent,
                    thread_id_manager=id_manager,
                )
                return
        raise NoProviderError
|
rag_demo/markdown.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from markdown_it import MarkdownIt
|
|
2
|
+
from markdown_it.rules_inline import StateInline
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def soft2hard_break_plugin(md: MarkdownIt) -> None:
    """markdown-it plugin: register ``_soft2hard_break_plugin`` on ``md.inline.ruler2`` as ``soft2hard_break``."""
    inline_rules = md.inline.ruler2
    inline_rules.push("soft2hard_break", _soft2hard_break_plugin)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _soft2hard_break_plugin(state: StateInline) -> None:
    """Retag every ``softbreak`` token in the inline state as a ``hardbreak`` token, in place."""
    soft_breaks = (tok for tok in state.tokens if tok.type == "softbreak")
    for tok in soft_breaks:
        tok.type = "hardbreak"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def parser_factory() -> MarkdownIt:
    """Build a "gfm-like" markdown-it parser in which single newlines become hard line breaks, as LLM output expects."""
    parser = MarkdownIt("gfm-like")
    parser.use(soft2hard_break_plugin)
    return parser
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Protocol, cast
|
|
4
|
+
|
|
5
|
+
from textual.screen import Screen
|
|
6
|
+
from textual.widget import Widget
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from rag_demo.logic import Logic, Runtime
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LogicProvider(Protocol):
    """Structural type for objects (normally the app) that own the application logic and runtime."""

    # The application logic instance.
    logic: Logic

    async def runtime(self) -> Runtime:
        """Await and return the application runtime of this provider."""
        ...
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LogicProviderScreen(Screen):
    """A Screen that reaches the application logic and runtime through its parent app."""

    @property
    def logic(self) -> Logic:
        """The application logic held by the parent app."""
        provider = cast("LogicProvider", self.app)
        return provider.logic

    async def runtime(self) -> Runtime:
        """Await and return the application runtime of the parent app."""
        provider = cast("LogicProvider", self.app)
        return await provider.runtime()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class LogicProviderWidget(Widget):
    """A Widget that reaches the application logic and runtime through its parent app."""

    @property
    def logic(self) -> Logic:
        """The application logic held by the parent app."""
        provider = cast("LogicProvider", self.app)
        return provider.logic

    async def runtime(self) -> Runtime:
        """Await and return the application runtime of the parent app."""
        provider = cast("LogicProvider", self.app)
        return await provider.runtime()
|