ragit 0.7.5__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragit/__init__.py +36 -9
- ragit/assistant.py +106 -23
- ragit/config.py +15 -6
- ragit/core/experiment/experiment.py +85 -20
- ragit/providers/__init__.py +30 -3
- ragit/providers/function_adapter.py +237 -0
- ragit/providers/ollama.py +1 -1
- ragit/providers/sentence_transformers.py +225 -0
- ragit/version.py +1 -1
- ragit-0.8.1.dist-info/METADATA +166 -0
- ragit-0.8.1.dist-info/RECORD +20 -0
- ragit-0.7.5.dist-info/METADATA +0 -553
- ragit-0.7.5.dist-info/RECORD +0 -18
- {ragit-0.7.5.dist-info → ragit-0.8.1.dist-info}/WHEEL +0 -0
- {ragit-0.7.5.dist-info → ragit-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {ragit-0.7.5.dist-info → ragit-0.8.1.dist-info}/top_level.txt +0 -0
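The headline change in 0.8.x is the removal of the implicit Ollama default: `RAGAssistant` and `RagitExperiment` now require either plain `embed_fn`/`generate_fn` callables or an explicit provider object. A minimal migration sketch based on the signatures visible in this diff; `my_embed` and `my_generate` are placeholder names, not part of the package:

```python
from ragit import Document, RAGAssistant
from ragit.providers import OllamaProvider

docs = [Document(id="doc1", content="FastAPI is a Python web framework.")]

# 0.7.5: RAGAssistant(docs) silently assumed an Ollama backend.
# 0.8.1, option 1: name the provider explicitly (still requires a running Ollama server).
assistant = RAGAssistant(docs, provider=OllamaProvider())


# 0.8.1, option 2: pass plain functions; ragit wraps them in a FunctionProvider.
def my_embed(text: str) -> list[float]:
    # Placeholder: a real implementation would call an embedding model.
    return [float(len(text)), float(text.count(" "))]


def my_generate(prompt: str, system_prompt: str | None = None) -> str:
    # Placeholder: a real implementation would call an LLM.
    return "stub answer"


assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_generate)
```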
ragit/__init__.py
CHANGED
```diff
@@ -9,14 +9,21 @@ Quick Start
 -----------
 >>> from ragit import RAGAssistant
 >>>
->>> #
->>>
-
-
+>>> # With custom embedding function (retrieval-only)
+>>> def my_embed(text: str) -> list[float]:
+...     # Your embedding implementation
+...     pass
+>>> assistant = RAGAssistant("docs/", embed_fn=my_embed)
+>>> results = assistant.retrieve("How do I create a REST API?")
+>>>
+>>> # With SentenceTransformers (offline, requires ragit[transformers])
+>>> from ragit.providers import SentenceTransformersProvider
+>>> assistant = RAGAssistant("docs/", provider=SentenceTransformersProvider())
 >>>
->>> #
->>>
->>>
+>>> # With Ollama (explicit)
+>>> from ragit.providers import OllamaProvider
+>>> assistant = RAGAssistant("docs/", provider=OllamaProvider())
+>>> answer = assistant.ask("How do I create a REST API?")
 
 Optimization
 ------------
@@ -25,7 +32,8 @@ Optimization
 >>> docs = [Document(id="doc1", content="...")]
 >>> benchmark = [BenchmarkQuestion(question="What is X?", ground_truth="...")]
 >>>
->>>
+>>> # With explicit provider
+>>> experiment = RagitExperiment(docs, benchmark, provider=OllamaProvider())
 >>> results = experiment.run()
 >>> print(results[0]) # Best configuration
 """
@@ -63,7 +71,12 @@ from ragit.loaders import ( # noqa: E402
     load_directory,
     load_text,
 )
-from ragit.providers import
+from ragit.providers import ( # noqa: E402
+    BaseEmbeddingProvider,
+    BaseLLMProvider,
+    FunctionProvider,
+    OllamaProvider,
+)
 
 __all__ = [
     "__version__",
@@ -79,7 +92,11 @@ __all__ = [
     # Core classes
     "Document",
    "Chunk",
+    # Providers
     "OllamaProvider",
+    "FunctionProvider",
+    "BaseLLMProvider",
+    "BaseEmbeddingProvider",
     # Optimization
     "RagitExperiment",
     "BenchmarkQuestion",
@@ -87,3 +104,13 @@ __all__ = [
     "EvaluationResult",
     "ExperimentResults",
 ]
+
+# Conditionally add SentenceTransformersProvider if available
+try:
+    from ragit.providers import ( # noqa: E402
+        SentenceTransformersProvider as SentenceTransformersProvider,
+    )
+
+    __all__ += ["SentenceTransformersProvider"]
+except ImportError:
+    pass
```
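Since `SentenceTransformersProvider` is appended to `__all__` only when its optional dependency imports cleanly, downstream code can feature-detect it with the same try/except pattern. A sketch, assuming nothing beyond the exports shown above:

```python
try:
    from ragit import SentenceTransformersProvider

    provider = SentenceTransformersProvider()  # offline embeddings
except ImportError:
    from ragit import OllamaProvider

    provider = OllamaProvider()  # fall back to a local Ollama server
```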
ragit/assistant.py
CHANGED
```diff
@@ -10,16 +10,17 @@ Provides a simple interface for RAG-based tasks.
 Note: This class is NOT thread-safe. Do not share instances across threads.
 """
 
+from collections.abc import Callable
 from pathlib import Path
 from typing import TYPE_CHECKING
 
 import numpy as np
 from numpy.typing import NDArray
 
-from ragit.config import config
 from ragit.core.experiment.experiment import Chunk, Document
 from ragit.loaders import chunk_document, chunk_rst_sections, load_directory, load_text
-from ragit.providers import
+from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
+from ragit.providers.function_adapter import FunctionProvider
 
 if TYPE_CHECKING:
     from numpy.typing import NDArray
@@ -38,48 +39,100 @@ class RAGAssistant:
         - List of Document objects
         - Path to a single file
         - Path to a directory (will load all .txt, .md, .rst files)
-
-
+    embed_fn : Callable[[str], list[float]], optional
+        Function that takes text and returns an embedding vector.
+        If provided, creates a FunctionProvider internally.
+    generate_fn : Callable, optional
+        Function for text generation. Supports (prompt) or (prompt, system_prompt).
+        If provided without embed_fn, must also provide embed_fn.
+    provider : BaseEmbeddingProvider, optional
+        Provider for embeddings (and optionally LLM). If embed_fn is provided,
+        this is ignored for embeddings.
     embedding_model : str, optional
-        Embedding model name
+        Embedding model name (used with provider).
     llm_model : str, optional
-        LLM model name
+        LLM model name (used with provider).
     chunk_size : int, optional
         Chunk size for splitting documents (default: 512).
     chunk_overlap : int, optional
         Overlap between chunks (default: 50).
 
+    Raises
+    ------
+    ValueError
+        If neither embed_fn nor provider is provided.
+
     Note
     ----
     This class is NOT thread-safe. Each thread should have its own instance.
 
     Examples
    --------
-    >>> #
-    >>> assistant = RAGAssistant(
+    >>> # With custom embedding function (retrieval-only)
+    >>> assistant = RAGAssistant(docs, embed_fn=my_embed)
+    >>> results = assistant.retrieve("query")
+    >>>
+    >>> # With custom embedding and LLM functions (full RAG)
+    >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
     >>> answer = assistant.ask("What is X?")
-
-    >>> #
-    >>>
-    >>>
-
-    >>> #
-    >>>
-    >>>
+    >>>
+    >>> # With explicit provider
+    >>> from ragit.providers import OllamaProvider
+    >>> assistant = RAGAssistant(docs, provider=OllamaProvider())
+    >>>
+    >>> # With SentenceTransformers (offline)
+    >>> from ragit.providers import SentenceTransformersProvider
+    >>> assistant = RAGAssistant(docs, provider=SentenceTransformersProvider())
     """
 
     def __init__(
         self,
         documents: list[Document] | str | Path,
-
+        embed_fn: Callable[[str], list[float]] | None = None,
+        generate_fn: Callable[..., str] | None = None,
+        provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
         embedding_model: str | None = None,
         llm_model: str | None = None,
         chunk_size: int = 512,
         chunk_overlap: int = 50,
     ):
-
-        self.
-        self.
+        # Resolve provider from embed_fn/generate_fn or explicit provider
+        self._embedding_provider: BaseEmbeddingProvider
+        self._llm_provider: BaseLLMProvider | None = None
+
+        if embed_fn is not None:
+            # Create FunctionProvider from provided functions
+            function_provider = FunctionProvider(
+                embed_fn=embed_fn,
+                generate_fn=generate_fn,
+            )
+            self._embedding_provider = function_provider
+            if generate_fn is not None:
+                self._llm_provider = function_provider
+            elif provider is not None and isinstance(provider, BaseLLMProvider):
+                # Use explicit provider for LLM if function_provider doesn't have LLM
+                self._llm_provider = provider
+        elif provider is not None:
+            # Use explicit provider
+            if not isinstance(provider, BaseEmbeddingProvider):
+                raise ValueError(
+                    "Provider must implement BaseEmbeddingProvider for embeddings. "
+                    "Alternatively, provide embed_fn."
+                )
+            self._embedding_provider = provider
+            if isinstance(provider, BaseLLMProvider):
+                self._llm_provider = provider
+        else:
+            raise ValueError(
+                "Must provide embed_fn or provider for embeddings. "
+                "Examples:\n"
+                "  RAGAssistant(docs, embed_fn=my_embed_function)\n"
+                "  RAGAssistant(docs, provider=OllamaProvider())\n"
+                "  RAGAssistant(docs, provider=SentenceTransformersProvider())"
+            )
+
+        self.embedding_model = embedding_model or "default"
+        self.llm_model = llm_model or "default"
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap
 
@@ -128,7 +181,7 @@ class RAGAssistant:
 
         # Batch embed all chunks at once (single API call)
         texts = [chunk.content for chunk in all_chunks]
-        responses = self.
+        responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
 
         # Build embedding matrix directly (skip storing in chunks to avoid duplication)
         embedding_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
@@ -169,7 +222,7 @@ class RAGAssistant:
             return []
 
         # Get query embedding and normalize
-        query_response = self.
+        query_response = self._embedding_provider.embed(query, self.embedding_model)
         query_vec = np.array(query_response.embedding, dtype=np.float64)
         query_norm = np.linalg.norm(query_vec)
         if query_norm == 0:
@@ -209,6 +262,15 @@ class RAGAssistant:
         results = self.retrieve(query, top_k)
         return "\n\n---\n\n".join(chunk.content for chunk, _ in results)
 
+    def _ensure_llm(self) -> BaseLLMProvider:
+        """Ensure LLM provider is available."""
+        if self._llm_provider is None:
+            raise NotImplementedError(
+                "No LLM configured. Provide generate_fn or a provider with LLM support "
+                "to use ask(), generate(), or generate_code() methods."
+            )
+        return self._llm_provider
+
     def generate(
         self,
         prompt: str,
@@ -231,8 +293,14 @@ class RAGAssistant:
         -------
         str
             Generated text.
+
+        Raises
+        ------
+        NotImplementedError
+            If no LLM is configured.
         """
-
+        llm = self._ensure_llm()
+        response = llm.generate(
             prompt=prompt,
             model=self.llm_model,
             system_prompt=system_prompt,
@@ -266,6 +334,11 @@ class RAGAssistant:
         str
             Generated answer.
 
+        Raises
+        ------
+        NotImplementedError
+            If no LLM is configured.
+
         Examples
         --------
         >>> answer = assistant.ask("How do I create a REST API?")
@@ -315,6 +388,11 @@ Answer:"""
         str
             Generated code (cleaned, without markdown).
 
+        Raises
+        ------
+        NotImplementedError
+            If no LLM is configured.
+
         Examples
         --------
         >>> code = assistant.generate_code("create a REST API with user endpoints")
@@ -357,3 +435,8 @@ Generate the {language} code:"""
     def num_documents(self) -> int:
         """Return number of loaded documents."""
         return len(self.documents)
+
+    @property
+    def has_llm(self) -> bool:
+        """Check if LLM is configured."""
+        return self._llm_provider is not None
```
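One consequence of the new `_ensure_llm` guard is that a retrieval-only assistant now fails loudly on generation, and the `has_llm` property lets callers branch instead of catching. A sketch reusing `docs` and `my_embed` from the migration example above:

```python
assistant = RAGAssistant(docs, embed_fn=my_embed)  # no generate_fn: retrieval-only

results = assistant.retrieve("How do I create a REST API?")  # works, embeddings only

if assistant.has_llm:
    print(assistant.ask("How do I create a REST API?"))
else:
    # ask(), generate(), and generate_code() would raise NotImplementedError here.
    print("\n\n---\n\n".join(chunk.content for chunk, _ in results))
```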
ragit/config.py
CHANGED
```diff
@@ -6,6 +6,9 @@
 Ragit configuration management.
 
 Loads configuration from environment variables and .env files.
+
+Note: As of v0.8.0, ragit no longer has default LLM or embedding models.
+Users must explicitly configure providers.
 """
 
 import os
@@ -27,21 +30,27 @@ else:
 
 
 class Config:
-    """Ragit configuration loaded from environment variables.
+    """Ragit configuration loaded from environment variables.
+
+    Note: As of v0.8.0, DEFAULT_LLM_MODEL and DEFAULT_EMBEDDING_MODEL are
+    no longer used as defaults. They are only read from environment variables
+    for backwards compatibility with user configurations.
+    """
 
-    # Ollama LLM API Configuration (
+    # Ollama LLM API Configuration (used when explicitly using OllamaProvider)
     OLLAMA_BASE_URL: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
     OLLAMA_API_KEY: str | None = os.getenv("OLLAMA_API_KEY")
     OLLAMA_TIMEOUT: int = int(os.getenv("OLLAMA_TIMEOUT", "120"))
 
-    # Ollama Embedding API Configuration
+    # Ollama Embedding API Configuration
     OLLAMA_EMBEDDING_URL: str = os.getenv(
         "OLLAMA_EMBEDDING_URL", os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
     )
 
-    #
-
-
+    # Model settings (only used if explicitly requested, no defaults)
+    # These can still be set via environment variables for convenience
+    DEFAULT_LLM_MODEL: str | None = os.getenv("RAGIT_DEFAULT_LLM_MODEL")
+    DEFAULT_EMBEDDING_MODEL: str | None = os.getenv("RAGIT_DEFAULT_EMBEDDING_MODEL")
 
     # Logging
     LOG_LEVEL: str = os.getenv("RAGIT_LOG_LEVEL", "INFO")
```
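The model settings are now plain environment lookups that resolve to `None` unless set, so opting back into default model names is done per deployment. A sketch, assuming `ragit.config` still exposes a module-level `config` instance as in 0.7.x; the model names are illustrative values, not package defaults:

```python
import os

# Must be set before ragit.config is imported, since Config reads os.getenv
# at class-definition time.
os.environ["RAGIT_DEFAULT_LLM_MODEL"] = "llama3.2"                # illustrative
os.environ["RAGIT_DEFAULT_EMBEDDING_MODEL"] = "nomic-embed-text"  # illustrative
os.environ["OLLAMA_TIMEOUT"] = "300"

from ragit.config import config

assert config.DEFAULT_LLM_MODEL == "llama3.2"
assert config.DEFAULT_EMBEDDING_MODEL == "nomic-embed-text"
```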
ragit/core/experiment/experiment.py
CHANGED
```diff
@@ -9,6 +9,7 @@ This module provides the main experiment class for optimizing RAG hyperparameter
 """
 
 import time
+from collections.abc import Callable
 from dataclasses import dataclass, field
 from itertools import product
 from typing import Any
@@ -16,9 +17,9 @@ from typing import Any
 import numpy as np
 from tqdm import tqdm
 
-from ragit.config import config
 from ragit.core.experiment.results import EvaluationResult
-from ragit.providers import
+from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
+from ragit.providers.function_adapter import FunctionProvider
 
 
 @dataclass
@@ -145,14 +146,28 @@ class RagitExperiment:
         Documents to use as the knowledge base.
     benchmark : list[BenchmarkQuestion]
         Benchmark questions for evaluation.
-
-
+    embed_fn : Callable[[str], list[float]], optional
+        Function that takes text and returns an embedding vector.
+    generate_fn : Callable, optional
+        Function for text generation.
+    provider : BaseEmbeddingProvider, optional
+        Provider for embeddings and LLM. If embed_fn is provided, this is
+        ignored for embeddings but can be used for LLM.
+
+    Raises
+    ------
+    ValueError
+        If neither embed_fn nor provider is provided.
 
     Examples
     --------
-    >>>
-    >>>
-    >>>
+    >>> # With custom functions
+    >>> experiment = RagitExperiment(docs, benchmark, embed_fn=my_embed, generate_fn=my_llm)
+    >>>
+    >>> # With explicit provider
+    >>> from ragit.providers import OllamaProvider
+    >>> experiment = RagitExperiment(docs, benchmark, provider=OllamaProvider())
+    >>>
     >>> results = experiment.run()
     >>> print(results[0].config) # Best configuration
     """
@@ -161,14 +176,59 @@ class RagitExperiment:
         self,
         documents: list[Document],
         benchmark: list[BenchmarkQuestion],
-
+        embed_fn: Callable[[str], list[float]] | None = None,
+        generate_fn: Callable[..., str] | None = None,
+        provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
     ):
         self.documents = documents
         self.benchmark = benchmark
-        self.provider = provider or OllamaProvider()
         self.vector_store = SimpleVectorStore()
         self.results: list[EvaluationResult] = []
 
+        # Resolve provider from functions or explicit provider
+        self._embedding_provider: BaseEmbeddingProvider
+        self._llm_provider: BaseLLMProvider | None = None
+
+        if embed_fn is not None:
+            # Create FunctionProvider from provided functions
+            function_provider = FunctionProvider(
+                embed_fn=embed_fn,
+                generate_fn=generate_fn,
+            )
+            self._embedding_provider = function_provider
+            if generate_fn is not None:
+                self._llm_provider = function_provider
+            elif provider is not None and isinstance(provider, BaseLLMProvider):
+                self._llm_provider = provider
+        elif provider is not None:
+            if not isinstance(provider, BaseEmbeddingProvider):
+                raise ValueError(
+                    "Provider must implement BaseEmbeddingProvider for embeddings. "
+                    "Alternatively, provide embed_fn."
+                )
+            self._embedding_provider = provider
+            if isinstance(provider, BaseLLMProvider):
+                self._llm_provider = provider
+        else:
+            raise ValueError(
+                "Must provide embed_fn or provider for embeddings. "
+                "Examples:\n"
+                "  RagitExperiment(docs, benchmark, embed_fn=my_embed, generate_fn=my_llm)\n"
+                "  RagitExperiment(docs, benchmark, provider=OllamaProvider())"
+            )
+
+        # LLM is required for evaluation
+        if self._llm_provider is None:
+            raise ValueError(
+                "RagitExperiment requires LLM for evaluation. "
+                "Provide generate_fn or a provider with LLM support."
+            )
+
+    @property
+    def provider(self) -> BaseEmbeddingProvider:
+        """Return the embedding provider (for backwards compatibility)."""
+        return self._embedding_provider
+
     def define_search_space(
         self,
         chunk_sizes: list[int] | None = None,
@@ -187,11 +247,11 @@ class RagitExperiment:
         chunk_overlaps : list[int], optional
             Chunk overlaps to test. Default: [50, 100]
         num_chunks_options : list[int], optional
-            Number of chunks to retrieve. Default: [2, 3
+            Number of chunks to retrieve. Default: [2, 3]
         embedding_models : list[str], optional
-            Embedding models to test. Default:
+            Embedding models to test. Default: ["default"]
         llm_models : list[str], optional
-            LLM models to test. Default:
+            LLM models to test. Default: ["default"]
 
         Returns
         -------
@@ -201,8 +261,8 @@ class RagitExperiment:
         chunk_sizes = chunk_sizes or [256, 512]
         chunk_overlaps = chunk_overlaps or [50, 100]
         num_chunks_options = num_chunks_options or [2, 3]
-        embedding_models = embedding_models or [
-        llm_models = llm_models or [
+        embedding_models = embedding_models or ["default"]
+        llm_models = llm_models or ["default"]
 
         configs = []
         pattern_num = 1
@@ -270,7 +330,7 @@ class RagitExperiment:
 
         # Batch embed all chunks at once (single API call)
         texts = [chunk.content for chunk in all_chunks]
-        responses = self.
+        responses = self._embedding_provider.embed_batch(texts, config.embedding_model)
 
         for chunk, response in zip(all_chunks, responses, strict=True):
             chunk.embedding = response.embedding
@@ -279,12 +339,15 @@ class RagitExperiment:
 
     def _retrieve(self, query: str, config: RAGConfig) -> list[Chunk]:
         """Retrieve relevant chunks for a query."""
-        query_response = self.
+        query_response = self._embedding_provider.embed(query, config.embedding_model)
        results = self.vector_store.search(query_response.embedding, top_k=config.num_chunks)
         return [chunk for chunk, _ in results]
 
     def _generate(self, question: str, context: str, config: RAGConfig) -> str:
         """Generate answer using RAG."""
+        if self._llm_provider is None:
+            raise ValueError("LLM provider is required for generation")
+
         system_prompt = """You are a helpful assistant. Answer questions based ONLY on the provided context.
 If the context doesn't contain enough information, say so. Be concise and accurate."""
 
@@ -295,7 +358,7 @@ Question: {question}
 
 Answer:"""
 
-        response = self.
+        response = self._llm_provider.generate(
             prompt=prompt,
             model=config.llm_model,
             system_prompt=system_prompt,
@@ -312,6 +375,8 @@ Answer:"""
         config: RAGConfig,
     ) -> EvaluationScores:
         """Evaluate a RAG response using LLM-as-judge."""
+        if self._llm_provider is None:
+            raise ValueError("LLM provider is required for evaluation")
 
         def extract_score(response: str) -> float:
             """Extract numeric score from LLM response."""
@@ -334,7 +399,7 @@ Generated Answer: {generated}
 
 Respond with ONLY a number 0-100."""
 
-        resp = self.
+        resp = self._llm_provider.generate(correctness_prompt, config.llm_model)
         correctness = extract_score(resp.text)
 
         # Evaluate context relevance
@@ -345,7 +410,7 @@ Context: {context[:1000]}
 
 Respond with ONLY a number 0-100."""
 
-        resp = self.
+        resp = self._llm_provider.generate(relevance_prompt, config.llm_model)
         relevance = extract_score(resp.text)
 
         # Evaluate faithfulness
@@ -356,7 +421,7 @@ Answer: {generated}
 
 Respond with ONLY a number 0-100."""
 
-        resp = self.
+        resp = self._llm_provider.generate(faithfulness_prompt, config.llm_model)
         faithfulness = extract_score(resp.text)
 
         return EvaluationScores(
```
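Tying the experiment changes together: model names now default to the sentinel string "default", which custom functions can simply ignore, so a grid search over chunking parameters works without naming any real model. A sketch reusing the placeholder functions from the migration example; whether `define_search_space` must be called before `run()` is not shown in this diff, so it is called explicitly here:

```python
from ragit import BenchmarkQuestion, Document, RagitExperiment

docs = [Document(id="doc1", content="FastAPI is a Python web framework for REST APIs.")]
benchmark = [
    BenchmarkQuestion(question="What is FastAPI?", ground_truth="A Python web framework.")
]

experiment = RagitExperiment(docs, benchmark, embed_fn=my_embed, generate_fn=my_generate)
experiment.define_search_space(
    chunk_sizes=[256, 512],
    chunk_overlaps=[50],
    num_chunks_options=[2, 3],
    # embedding_models / llm_models omitted: both fall back to ["default"],
    # which is passed straight through to the wrapped functions.
)
results = experiment.run()
print(results[0].config)  # results are ordered with the best configuration first
```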
ragit/providers/__init__.py
CHANGED
```diff
@@ -6,15 +6,42 @@
 Ragit Providers - LLM and Embedding providers for RAG optimization.
 
 Supported providers:
-- Ollama
--
+- OllamaProvider: Connect to local or remote Ollama servers
+- FunctionProvider: Wrap custom embedding/LLM functions
+- SentenceTransformersProvider: Offline embedding (requires ragit[transformers])
+
+Base classes for implementing custom providers:
+- BaseLLMProvider: Abstract base for LLM providers
+- BaseEmbeddingProvider: Abstract base for embedding providers
 """
 
-from ragit.providers.base import
+from ragit.providers.base import (
+    BaseEmbeddingProvider,
+    BaseLLMProvider,
+    EmbeddingResponse,
+    LLMResponse,
+)
+from ragit.providers.function_adapter import FunctionProvider
 from ragit.providers.ollama import OllamaProvider
 
 __all__ = [
+    # Base classes
     "BaseLLMProvider",
     "BaseEmbeddingProvider",
+    "LLMResponse",
+    "EmbeddingResponse",
+    # Built-in providers
     "OllamaProvider",
+    "FunctionProvider",
 ]
+
+# Conditionally export SentenceTransformersProvider if available
+try:
+    from ragit.providers.sentence_transformers import (
+        SentenceTransformersProvider as SentenceTransformersProvider,
+    )
+
+    __all__ += ["SentenceTransformersProvider"]
+except ImportError:
+    # sentence-transformers not installed, SentenceTransformersProvider not available
+    pass
```