ragit-0.8-py3-none-any.whl → ragit-0.8.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragit/__init__.py +116 -2
- ragit/assistant.py +577 -0
- ragit/config.py +60 -0
- ragit/core/__init__.py +5 -0
- ragit/core/experiment/__init__.py +22 -0
- ragit/core/experiment/experiment.py +571 -0
- ragit/core/experiment/results.py +131 -0
- ragit/loaders.py +245 -0
- ragit/providers/__init__.py +47 -0
- ragit/providers/base.py +147 -0
- ragit/providers/function_adapter.py +237 -0
- ragit/providers/ollama.py +446 -0
- ragit/providers/sentence_transformers.py +225 -0
- ragit/utils/__init__.py +105 -0
- ragit/version.py +5 -0
- ragit-0.8.2.dist-info/METADATA +166 -0
- ragit-0.8.2.dist-info/RECORD +20 -0
- {ragit-0.8.dist-info → ragit-0.8.2.dist-info}/WHEEL +1 -1
- ragit-0.8.2.dist-info/licenses/LICENSE +201 -0
- {ragit-0.8.dist-info → ragit-0.8.2.dist-info}/top_level.txt +0 -0
- ragit/main.py +0 -354
- ragit-0.8.dist-info/LICENSE +0 -21
- ragit-0.8.dist-info/METADATA +0 -176
- ragit-0.8.dist-info/RECORD +0 -7
ragit/providers/function_adapter.py
@@ -0,0 +1,237 @@
+#
+# Copyright RODMENA LIMITED 2025
+# SPDX-License-Identifier: Apache-2.0
+#
+"""
+Function-based provider adapter for pluggable embedding and LLM functions.
+
+This module provides a simple adapter that wraps user-provided functions
+into the provider interface, enabling easy integration with custom
+embedding and LLM implementations.
+"""
+
+import inspect
+from collections.abc import Callable
+
+from ragit.providers.base import (
+    BaseEmbeddingProvider,
+    BaseLLMProvider,
+    EmbeddingResponse,
+    LLMResponse,
+)
+
+
+class FunctionProvider(BaseLLMProvider, BaseEmbeddingProvider):
+    """
+    Adapter that wraps user-provided embedding and generation functions.
+
+    This provider allows users to bring their own embedding and/or LLM functions
+    without implementing the full provider interface.
+
+    Parameters
+    ----------
+    embed_fn : Callable[[str], list[float]], optional
+        Function that takes text and returns an embedding vector.
+        Example: `lambda text: openai.embeddings.create(input=text).data[0].embedding`
+    generate_fn : Callable, optional
+        Function for text generation. Supports two signatures:
+        - (prompt: str) -> str
+        - (prompt: str, system_prompt: str) -> str
+    embedding_dimensions : int, optional
+        Embedding dimensions. Auto-detected on first call if not provided.
+
+    Examples
+    --------
+    >>> # Simple embedding function
+    >>> def my_embed(text: str) -> list[float]:
+    ...     return openai.embeddings.create(input=text).data[0].embedding
+    >>>
+    >>> # Use with RAGAssistant (retrieval-only)
+    >>> assistant = RAGAssistant(docs, embed_fn=my_embed)
+    >>> results = assistant.retrieve("query")
+    >>>
+    >>> # With LLM for full RAG
+    >>> def my_llm(prompt: str, system_prompt: str = None) -> str:
+    ...     return openai.chat.completions.create(
+    ...         messages=[{"role": "user", "content": prompt}]
+    ...     ).choices[0].message.content
+    >>>
+    >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
+    >>> answer = assistant.ask("What is X?")
+    """
+
+    def __init__(
+        self,
+        embed_fn: Callable[[str], list[float]] | None = None,
+        generate_fn: Callable[..., str] | None = None,
+        embedding_dimensions: int | None = None,
+    ) -> None:
+        self._embed_fn = embed_fn
+        self._generate_fn = generate_fn
+        self._embedding_dimensions = embedding_dimensions
+        self._generate_fn_signature: int | None = None  # Number of args (1 or 2)
+
+        # Detect generate_fn signature if provided
+        if generate_fn is not None:
+            self._detect_generate_signature()
+
+    def _detect_generate_signature(self) -> None:
+        """Detect whether generate_fn accepts 1 or 2 arguments."""
+        if self._generate_fn is None:
+            return
+
+        sig = inspect.signature(self._generate_fn)
+        params = [
+            p
+            for p in sig.parameters.values()
+            if p.default is inspect.Parameter.empty and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
+        ]
+        # Count required parameters
+        required_count = len(params)
+
+        if required_count == 1:
+            self._generate_fn_signature = 1
+        else:
+            # Assume 2 args if more than 1 required or if has optional args
+            self._generate_fn_signature = 2
+
+    @property
+    def provider_name(self) -> str:
+        return "function"
+
+    @property
+    def dimensions(self) -> int:
+        if self._embedding_dimensions is None:
+            raise ValueError("Embedding dimensions not yet determined. Call embed() first or provide dimensions.")
+        return self._embedding_dimensions
+
+    @property
+    def has_embedding(self) -> bool:
+        """Check if embedding function is configured."""
+        return self._embed_fn is not None
+
+    @property
+    def has_llm(self) -> bool:
+        """Check if LLM generation function is configured."""
+        return self._generate_fn is not None
+
+    def is_available(self) -> bool:
+        """Check if the provider has at least one function configured."""
+        return self._embed_fn is not None or self._generate_fn is not None
+
+    def embed(self, text: str, model: str = "") -> EmbeddingResponse:
+        """
+        Generate embedding using the provided function.
+
+        Parameters
+        ----------
+        text : str
+            Text to embed.
+        model : str
+            Model identifier (ignored, kept for interface compatibility).
+
+        Returns
+        -------
+        EmbeddingResponse
+            The embedding response.
+
+        Raises
+        ------
+        ValueError
+            If no embedding function was provided.
+        """
+        if self._embed_fn is None:
+            raise ValueError("No embedding function configured. Provide embed_fn to use embeddings.")
+
+        raw_embedding = self._embed_fn(text)
+
+        # Convert to tuple for immutability
+        embedding_tuple: tuple[float, ...] = tuple(raw_embedding)
+
+        # Auto-detect dimensions on first call
+        if self._embedding_dimensions is None:
+            self._embedding_dimensions = len(embedding_tuple)
+
+        return EmbeddingResponse(
+            embedding=embedding_tuple,
+            model=model or "function",
+            provider=self.provider_name,
+            dimensions=len(embedding_tuple),
+        )
+
+    def embed_batch(self, texts: list[str], model: str = "") -> list[EmbeddingResponse]:
+        """
+        Generate embeddings for multiple texts.
+
+        Iterates over embed_fn for each text. For providers with native batch
+        support, users should implement their own BatchEmbeddingProvider.
+
+        Parameters
+        ----------
+        texts : list[str]
+            Texts to embed.
+        model : str
+            Model identifier (ignored).
+
+        Returns
+        -------
+        list[EmbeddingResponse]
+            List of embedding responses.
+        """
+        return [self.embed(text, model) for text in texts]
+
+    def generate(
+        self,
+        prompt: str,
+        model: str = "",
+        system_prompt: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """
+        Generate text using the provided function.
+
+        Parameters
+        ----------
+        prompt : str
+            The user prompt.
+        model : str
+            Model identifier (ignored, kept for interface compatibility).
+        system_prompt : str, optional
+            System prompt for context.
+        temperature : float
+            Sampling temperature (ignored if function doesn't support it).
+        max_tokens : int, optional
+            Maximum tokens (ignored if function doesn't support it).
+
+        Returns
+        -------
+        LLMResponse
+            The generated response.
+
+        Raises
+        ------
+        NotImplementedError
+            If no generation function was provided.
+        """
+        if self._generate_fn is None:
+            raise NotImplementedError(
+                "No LLM configured. Provide generate_fn or a provider with LLM support "
+                "to use ask(), generate(), or generate_code() methods."
+            )
+
+        # Call with appropriate signature
+        if self._generate_fn_signature == 1:
+            # Single argument - prepend system prompt to prompt if provided
+            full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
+            text = self._generate_fn(full_prompt)
+        else:
+            # Two arguments - pass separately
+            text = self._generate_fn(prompt, system_prompt)
+
+        return LLMResponse(
+            text=text,
+            model=model or "function",
+            provider=self.provider_name,
+            usage=None,
+        )
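For context, a minimal sketch of how the new `FunctionProvider` could be driven directly, rather than through `RAGAssistant` as in the docstring above. The `char_count_embed` and `echo_llm` functions are hypothetical stand-ins, not part of the package:

```python
# Minimal sketch, assuming FunctionProvider is importable from the new
# ragit.providers.function_adapter module added in this release.
from ragit.providers.function_adapter import FunctionProvider


def char_count_embed(text: str) -> list[float]:
    # Hypothetical embedding: vowel counts, just to produce a fixed-size vector.
    return [float(text.count(c)) for c in "aeiou"]


def echo_llm(prompt: str) -> str:
    # Hypothetical single-argument generator; FunctionProvider.generate()
    # prepends any system prompt to the prompt for 1-arg functions.
    return f"echo: {prompt[:60]}"


provider = FunctionProvider(embed_fn=char_count_embed, generate_fn=echo_llm)
print(provider.embed("hello world").dimensions)  # 5, auto-detected on first call
print(provider.generate("What is X?", system_prompt="Be brief.").text)
```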
ragit/providers/ollama.py
@@ -0,0 +1,446 @@
+#
+# Copyright RODMENA LIMITED 2025
+# SPDX-License-Identifier: Apache-2.0
+#
+"""
+Ollama provider for LLM and Embedding operations.
+
+This provider connects to a local or remote Ollama server.
+Configuration is loaded from environment variables.
+
+Performance optimizations:
+- Connection pooling via requests.Session()
+- Async parallel embedding via trio + httpx
+- LRU cache for repeated embedding queries
+"""
+
+from functools import lru_cache
+from typing import Any
+
+import httpx
+import requests
+
+from ragit.config import config
+from ragit.providers.base import (
+    BaseEmbeddingProvider,
+    BaseLLMProvider,
+    EmbeddingResponse,
+    LLMResponse,
+)
+
+
+# Module-level cache for embeddings (shared across instances)
+@lru_cache(maxsize=2048)
+def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
+    """Cache embedding results to avoid redundant API calls."""
+    # Truncate oversized inputs
+    if len(text) > OllamaProvider.MAX_EMBED_CHARS:
+        text = text[: OllamaProvider.MAX_EMBED_CHARS]
+
+    response = requests.post(
+        f"{embedding_url}/api/embed",
+        headers={"Content-Type": "application/json"},
+        json={"model": model, "input": text},
+        timeout=timeout,
+    )
+    response.raise_for_status()
+    data = response.json()
+    embeddings = data.get("embeddings", [])
+    if not embeddings or not embeddings[0]:
+        raise ValueError("Empty embedding returned from Ollama")
+    return tuple(embeddings[0])
+
+
+class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
+    """
+    Ollama provider for both LLM and Embedding operations.
+
+    Performance features:
+    - Connection pooling via requests.Session() for faster sequential requests
+    - Native batch embedding via /api/embed endpoint (single API call)
+    - LRU cache for repeated embedding queries (2048 entries)
+
+    Parameters
+    ----------
+    base_url : str, optional
+        Ollama server URL (default: from OLLAMA_BASE_URL env var)
+    api_key : str, optional
+        API key for authentication (default: from OLLAMA_API_KEY env var)
+    timeout : int, optional
+        Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
+    use_cache : bool, optional
+        Enable embedding cache (default: True)
+
+    Examples
+    --------
+    >>> provider = OllamaProvider()
+    >>> response = provider.generate("What is RAG?", model="llama3")
+    >>> print(response.text)
+
+    >>> # Batch embedding (single API call)
+    >>> embeddings = provider.embed_batch(texts, "mxbai-embed-large")
+    """
+
+    # Known embedding model dimensions
+    EMBEDDING_DIMENSIONS: dict[str, int] = {
+        "nomic-embed-text": 768,
+        "nomic-embed-text:latest": 768,
+        "mxbai-embed-large": 1024,
+        "all-minilm": 384,
+        "snowflake-arctic-embed": 1024,
+        "qwen3-embedding": 4096,
+        "qwen3-embedding:0.6b": 1024,
+        "qwen3-embedding:4b": 2560,
+        "qwen3-embedding:8b": 4096,
+    }
+
+    # Max characters per embedding request (safe limit for 512 token models)
+    MAX_EMBED_CHARS = 2000
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        embedding_url: str | None = None,
+        api_key: str | None = None,
+        timeout: int | None = None,
+        use_cache: bool = True,
+    ) -> None:
+        self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
+        self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
+        self.api_key = api_key or config.OLLAMA_API_KEY
+        self.timeout = timeout or config.OLLAMA_TIMEOUT
+        self.use_cache = use_cache
+        self._current_embed_model: str | None = None
+        self._current_dimensions: int = 768  # default
+
+        # Connection pooling via session
+        self._session: requests.Session | None = None
+
+    @property
+    def session(self) -> requests.Session:
+        """Lazy-initialized session for connection pooling."""
+        if self._session is None:
+            self._session = requests.Session()
+            self._session.headers.update({"Content-Type": "application/json"})
+            if self.api_key:
+                self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+        return self._session
+
+    def close(self) -> None:
+        """Close the session and release resources."""
+        if self._session is not None:
+            self._session.close()
+            self._session = None
+
+    def __del__(self) -> None:
+        """Cleanup on garbage collection."""
+        self.close()
+
+    def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
+        """Get request headers including authentication if API key is set."""
+        headers = {"Content-Type": "application/json"}
+        if include_auth and self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    @property
+    def provider_name(self) -> str:
+        return "ollama"
+
+    @property
+    def dimensions(self) -> int:
+        return self._current_dimensions
+
+    def is_available(self) -> bool:
+        """Check if Ollama server is reachable."""
+        try:
+            response = self.session.get(
+                f"{self.base_url}/api/tags",
+                timeout=5,
+            )
+            return bool(response.status_code == 200)
+        except requests.RequestException:
+            return False
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List available models on the Ollama server."""
+        try:
+            response = self.session.get(
+                f"{self.base_url}/api/tags",
+                timeout=10,
+            )
+            response.raise_for_status()
+            data = response.json()
+            return list(data.get("models", []))
+        except requests.RequestException as e:
+            raise ConnectionError(f"Failed to list Ollama models: {e}") from e
+
+    def generate(
+        self,
+        prompt: str,
+        model: str,
+        system_prompt: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """Generate text using Ollama."""
+        options: dict[str, float | int] = {"temperature": temperature}
+        if max_tokens:
+            options["num_predict"] = max_tokens
+
+        payload: dict[str, str | bool | dict[str, float | int]] = {
+            "model": model,
+            "prompt": prompt,
+            "stream": False,
+            "options": options,
+        }
+
+        if system_prompt:
+            payload["system"] = system_prompt
+
+        try:
+            response = self.session.post(
+                f"{self.base_url}/api/generate",
+                json=payload,
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            return LLMResponse(
+                text=data.get("response", ""),
+                model=model,
+                provider=self.provider_name,
+                usage={
+                    "prompt_tokens": data.get("prompt_eval_count"),
+                    "completion_tokens": data.get("eval_count"),
+                    "total_duration": data.get("total_duration"),
+                },
+            )
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama generate failed: {e}") from e
+
+    def embed(self, text: str, model: str) -> EmbeddingResponse:
+        """Generate embedding using Ollama with optional caching."""
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        try:
+            if self.use_cache:
+                # Use cached version
+                embedding = _cached_embedding(text, model, self.embedding_url, self.timeout)
+            else:
+                # Direct call without cache
+                truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
+                response = self.session.post(
+                    f"{self.embedding_url}/api/embed",
+                    json={"model": model, "input": truncated},
+                    timeout=self.timeout,
+                )
+                response.raise_for_status()
+                data = response.json()
+                embeddings = data.get("embeddings", [])
+                if not embeddings or not embeddings[0]:
+                    raise ValueError("Empty embedding returned from Ollama")
+                embedding = tuple(embeddings[0])
+
+            # Update dimensions from actual response
+            self._current_dimensions = len(embedding)
+
+            return EmbeddingResponse(
+                embedding=embedding,
+                model=model,
+                provider=self.provider_name,
+                dimensions=len(embedding),
+            )
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama embed failed: {e}") from e
+
+    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+        """Generate embeddings for multiple texts in a single API call.
+
+        The /api/embed endpoint supports batch inputs natively.
+        """
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        # Truncate oversized inputs
+        truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+        try:
+            response = self.session.post(
+                f"{self.embedding_url}/api/embed",
+                json={"model": model, "input": truncated_texts},
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            data = response.json()
+            embeddings_list = data.get("embeddings", [])
+
+            if not embeddings_list:
+                raise ValueError("Empty embeddings returned from Ollama")
+
+            results = []
+            for embedding_data in embeddings_list:
+                embedding = tuple(embedding_data) if embedding_data else ()
+                if embedding:
+                    self._current_dimensions = len(embedding)
+
+                results.append(
+                    EmbeddingResponse(
+                        embedding=embedding,
+                        model=model,
+                        provider=self.provider_name,
+                        dimensions=len(embedding),
+                    )
+                )
+            return results
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama batch embed failed: {e}") from e
+
+    async def embed_batch_async(
+        self,
+        texts: list[str],
+        model: str,
+        max_concurrent: int = 10,  # kept for API compatibility, no longer used
+    ) -> list[EmbeddingResponse]:
+        """Generate embeddings for multiple texts asynchronously.
+
+        The /api/embed endpoint supports batch inputs natively, so this
+        makes a single async HTTP request for all texts.
+
+        Parameters
+        ----------
+        texts : list[str]
+            Texts to embed.
+        model : str
+            Embedding model name.
+        max_concurrent : int
+            Deprecated, kept for API compatibility. No longer used since
+            the API now supports native batching.
+
+        Returns
+        -------
+        list[EmbeddingResponse]
+            Embeddings in the same order as input texts.
+
+        Examples
+        --------
+        >>> import trio
+        >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
+        """
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        # Truncate oversized inputs
+        truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    f"{self.embedding_url}/api/embed",
+                    json={"model": model, "input": truncated_texts},
+                    timeout=self.timeout,
+                )
+                response.raise_for_status()
+                data = response.json()
+
+                embeddings_list = data.get("embeddings", [])
+                if not embeddings_list:
+                    raise ValueError("Empty embeddings returned from Ollama")
+
+                results = []
+                for embedding_data in embeddings_list:
+                    embedding = tuple(embedding_data) if embedding_data else ()
+                    if embedding:
+                        self._current_dimensions = len(embedding)
+
+                    results.append(
+                        EmbeddingResponse(
+                            embedding=embedding,
+                            model=model,
+                            provider=self.provider_name,
+                            dimensions=len(embedding),
+                        )
+                    )
+                return results
+        except httpx.HTTPError as e:
+            raise ConnectionError(f"Ollama async batch embed failed: {e}") from e
+
+    def chat(
+        self,
+        messages: list[dict[str, str]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """
+        Chat completion using Ollama.
+
+        Parameters
+        ----------
+        messages : list[dict]
+            List of messages with 'role' and 'content' keys.
+        model : str
+            Model identifier.
+        temperature : float
+            Sampling temperature.
+        max_tokens : int, optional
+            Maximum tokens to generate.
+
+        Returns
+        -------
+        LLMResponse
+            The generated response.
+        """
+        options: dict[str, float | int] = {"temperature": temperature}
+        if max_tokens:
+            options["num_predict"] = max_tokens
+
+        payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
+            "model": model,
+            "messages": messages,
+            "stream": False,
+            "options": options,
+        }
+
+        try:
+            response = self.session.post(
+                f"{self.base_url}/api/chat",
+                json=payload,
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            return LLMResponse(
+                text=data.get("message", {}).get("content", ""),
+                model=model,
+                provider=self.provider_name,
+                usage={
+                    "prompt_tokens": data.get("prompt_eval_count"),
+                    "completion_tokens": data.get("eval_count"),
+                },
+            )
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama chat failed: {e}") from e
+
+    @staticmethod
+    def clear_embedding_cache() -> None:
+        """Clear the embedding cache."""
+        _cached_embedding.cache_clear()
+
+    @staticmethod
+    def embedding_cache_info() -> dict[str, int]:
+        """Get embedding cache statistics."""
+        info = _cached_embedding.cache_info()
+        return {
+            "hits": info.hits,
+            "misses": info.misses,
+            "maxsize": info.maxsize or 0,
+            "currsize": info.currsize,
+        }
+
+
+# Export the EMBEDDING_DIMENSIONS for external use
+EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS
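And a similarly hedged sketch of the reworked `OllamaProvider`, exercising the session-pooled client, the native batch `/api/embed` call, and the module-level LRU cache. The server URL and model names are assumptions (a local Ollama with `nomic-embed-text` and `llama3` pulled); adjust to your environment:

```python
# Minimal sketch, assuming a local Ollama server and the models named below.
from ragit.providers.ollama import OllamaProvider

provider = OllamaProvider(
    base_url="http://localhost:11434",       # assumption; defaults come from config/env vars
    embedding_url="http://localhost:11434",  # assumption
    use_cache=True,
)

if provider.is_available():
    # embed_batch() sends all inputs in a single /api/embed request.
    docs = ["RAG pairs retrieval with generation.", "Ollama serves models locally."]
    batch = provider.embed_batch(docs, model="nomic-embed-text")
    print([r.dimensions for r in batch])

    # Repeated single-text embeds are served from the shared LRU cache.
    provider.embed(docs[0], model="nomic-embed-text")
    provider.embed(docs[0], model="nomic-embed-text")
    print(OllamaProvider.embedding_cache_info())  # hits should be >= 1

    reply = provider.generate("What is RAG?", model="llama3", temperature=0.2)
    print(reply.text)

provider.close()
```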