ragit 0.7-py3-none-any.whl → 0.7.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragit/__init__.py +89 -2
- ragit/assistant.py +359 -0
- ragit/config.py +51 -0
- ragit/core/__init__.py +5 -0
- ragit/core/experiment/__init__.py +22 -0
- ragit/core/experiment/experiment.py +507 -0
- ragit/core/experiment/results.py +131 -0
- ragit/loaders.py +219 -0
- ragit/providers/__init__.py +20 -0
- ragit/providers/base.py +147 -0
- ragit/providers/ollama.py +284 -0
- ragit/utils/__init__.py +105 -0
- ragit/version.py +5 -0
- ragit-0.7.2.dist-info/METADATA +480 -0
- ragit-0.7.2.dist-info/RECORD +18 -0
- {ragit-0.7.dist-info → ragit-0.7.2.dist-info}/WHEEL +1 -1
- ragit-0.7.2.dist-info/licenses/LICENSE +201 -0
- ragit/main.py +0 -354
- ragit-0.7.dist-info/METADATA +0 -170
- ragit-0.7.dist-info/RECORD +0 -6
- {ragit-0.7.dist-info → ragit-0.7.2.dist-info}/top_level.txt +0 -0
ragit/loaders.py
ADDED
@@ -0,0 +1,219 @@
#
# Copyright RODMENA LIMITED 2025
# SPDX-License-Identifier: Apache-2.0
#
"""
Document loading and chunking utilities.

Provides simple functions to load documents from files and chunk text.
"""

import re
from pathlib import Path

from ragit.core.experiment.experiment import Chunk, Document


def load_text(path: str | Path) -> Document:
    """
    Load a single text file as a Document.

    Parameters
    ----------
    path : str or Path
        Path to the text file (.txt, .md, .rst, etc.)

    Returns
    -------
    Document
        Document with file content and metadata.

    Examples
    --------
    >>> doc = load_text("docs/tutorial.rst")
    >>> print(doc.id, len(doc.content))
    """
    path = Path(path)
    content = path.read_text(encoding="utf-8")
    return Document(id=path.stem, content=content, metadata={"source": str(path), "filename": path.name})


def load_directory(path: str | Path, pattern: str = "*.txt", recursive: bool = False) -> list[Document]:
    """
    Load all matching files from a directory as Documents.

    Parameters
    ----------
    path : str or Path
        Directory path.
    pattern : str
        Glob pattern for files (default: "*.txt").
    recursive : bool
        If True, search recursively (default: False).

    Returns
    -------
    list[Document]
        List of loaded documents.

    Examples
    --------
    >>> docs = load_directory("docs/", "*.rst")
    >>> docs = load_directory("docs/", "**/*.md", recursive=True)
    """
    path = Path(path)
    glob_method = path.rglob if recursive else path.glob
    documents = []

    for file_path in sorted(glob_method(pattern)):
        if file_path.is_file():
            documents.append(load_text(file_path))

    return documents


def chunk_text(text: str, chunk_size: int = 512, chunk_overlap: int = 50, doc_id: str = "doc") -> list[Chunk]:
    """
    Split text into overlapping chunks.

    Parameters
    ----------
    text : str
        Text to chunk.
    chunk_size : int
        Maximum characters per chunk (default: 512).
    chunk_overlap : int
        Overlap between chunks (default: 50).
    doc_id : str
        Document ID for the chunks (default: "doc").

    Returns
    -------
    list[Chunk]
        List of text chunks.

    Examples
    --------
    >>> chunks = chunk_text("Long document...", chunk_size=256, chunk_overlap=50)
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be less than chunk_size")

    chunks = []
    start = 0
    chunk_idx = 0

    while start < len(text):
        end = start + chunk_size
        chunk_text = text[start:end].strip()

        if chunk_text:
            chunks.append(Chunk(content=chunk_text, doc_id=doc_id, chunk_index=chunk_idx))
            chunk_idx += 1

        start = end - chunk_overlap
        if start >= len(text) - chunk_overlap:
            break

    return chunks


def chunk_document(doc: Document, chunk_size: int = 512, chunk_overlap: int = 50) -> list[Chunk]:
    """
    Split a Document into overlapping chunks.

    Parameters
    ----------
    doc : Document
        Document to chunk.
    chunk_size : int
        Maximum characters per chunk.
    chunk_overlap : int
        Overlap between chunks.

    Returns
    -------
    list[Chunk]
        List of chunks from the document.
    """
    return chunk_text(doc.content, chunk_size, chunk_overlap, doc.id)


def chunk_by_separator(text: str, separator: str = "\n\n", doc_id: str = "doc") -> list[Chunk]:
    """
    Split text by a separator (e.g., paragraphs, sections).

    Parameters
    ----------
    text : str
        Text to split.
    separator : str
        Separator string (default: double newline for paragraphs).
    doc_id : str
        Document ID for the chunks.

    Returns
    -------
    list[Chunk]
        List of chunks.

    Examples
    --------
    >>> chunks = chunk_by_separator(text, separator="\\n---\\n")
    """
    parts = text.split(separator)
    chunks = []

    for idx, part in enumerate(parts):
        content = part.strip()
        if content:
            chunks.append(Chunk(content=content, doc_id=doc_id, chunk_index=idx))

    return chunks


def chunk_rst_sections(text: str, doc_id: str = "doc") -> list[Chunk]:
    """
    Split RST document by section headers.

    Parameters
    ----------
    text : str
        RST document text.
    doc_id : str
        Document ID for the chunks.

    Returns
    -------
    list[Chunk]
        List of section chunks.
    """
    # Match RST section headers (title followed by underline of =, -, ~, etc.)
    pattern = r"\n([^\n]+)\n([=\-~`\'\"^_*+#]+)\n"

    # Find all section positions
    matches = list(re.finditer(pattern, text))

    if not matches:
        # No sections found, return whole text as one chunk
        return [Chunk(content=text.strip(), doc_id=doc_id, chunk_index=0)] if text.strip() else []

    chunks = []

    # Handle content before first section
    first_pos = matches[0].start()
    if first_pos > 0:
        pre_content = text[:first_pos].strip()
        if pre_content:
            chunks.append(Chunk(content=pre_content, doc_id=doc_id, chunk_index=0))

    # Extract each section
    for i, match in enumerate(matches):
        start = match.start()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)

        section_content = text[start:end].strip()
        if section_content:
            chunks.append(Chunk(content=section_content, doc_id=doc_id, chunk_index=len(chunks)))

    return chunks
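Taken together, these loaders form a small ingestion pipeline: load files as Document objects, then split each one into Chunk objects. Below is a minimal sketch of that flow using only the functions added in this file; the docs/ path, glob pattern, and chunk sizes are illustrative, not part of the package.

from ragit.loaders import chunk_document, load_directory

# Hypothetical ingestion flow: load every .rst file under docs/ and chunk it.
documents = load_directory("docs/", pattern="**/*.rst", recursive=True)

chunks = []
for doc in documents:
    # Fixed-size chunks with overlap; chunk_rst_sections(doc.content, doc.id)
    # would split on RST section headers instead.
    chunks.extend(chunk_document(doc, chunk_size=512, chunk_overlap=50))

print(f"{len(documents)} documents -> {len(chunks)} chunks")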
ragit/providers/__init__.py
ADDED
@@ -0,0 +1,20 @@
#
# Copyright RODMENA LIMITED 2025
# SPDX-License-Identifier: Apache-2.0
#
"""
Ragit Providers - LLM and Embedding providers for RAG optimization.

Supported providers:
- Ollama (local)
- Future: Gemini, Claude, OpenAI
"""

from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
from ragit.providers.ollama import OllamaProvider

__all__ = [
    "BaseLLMProvider",
    "BaseEmbeddingProvider",
    "OllamaProvider",
]
ragit/providers/base.py
ADDED
@@ -0,0 +1,147 @@
#
# Copyright RODMENA LIMITED 2025
# SPDX-License-Identifier: Apache-2.0
#
"""
Base provider interfaces for LLM and Embedding providers.

These abstract classes define the interface that all providers must implement,
making it easy to add new providers (Gemini, Claude, OpenAI, etc.)
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class LLMResponse:
    """Response from an LLM call."""

    text: str
    model: str
    provider: str
    usage: dict[str, int] | None = None


@dataclass(frozen=True)
class EmbeddingResponse:
    """Response from an embedding call (immutable)."""

    embedding: tuple[float, ...]
    model: str
    provider: str
    dimensions: int


class BaseLLMProvider(ABC):
    """
    Abstract base class for LLM providers.

    Implement this to add support for new LLM providers like Gemini, Claude, etc.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name (e.g., 'ollama', 'gemini', 'claude')."""
        pass

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """
        Generate text from the LLM.

        Parameters
        ----------
        prompt : str
            The user prompt/query.
        model : str
            Model identifier (e.g., 'llama3', 'qwen3-vl:235b-instruct-cloud').
        system_prompt : str, optional
            System prompt for context/instructions.
        temperature : float
            Sampling temperature (0.0 to 1.0).
        max_tokens : int, optional
            Maximum tokens to generate.

        Returns
        -------
        LLMResponse
            The generated response.
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and configured."""
        pass


class BaseEmbeddingProvider(ABC):
    """
    Abstract base class for embedding providers.

    Implement this to add support for new embedding providers.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name."""
        pass

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Return the embedding dimensions for the current model."""
        pass

    @abstractmethod
    def embed(self, text: str, model: str) -> EmbeddingResponse:
        """
        Generate embedding for text.

        Parameters
        ----------
        text : str
            Text to embed.
        model : str
            Model identifier (e.g., 'nomic-embed-text').

        Returns
        -------
        EmbeddingResponse
            The embedding response.
        """
        pass

    @abstractmethod
    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
        """
        Generate embeddings for multiple texts.

        Parameters
        ----------
        texts : list[str]
            Texts to embed.
        model : str
            Model identifier.

        Returns
        -------
        list[EmbeddingResponse]
            List of embedding responses.
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and configured."""
        pass
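These abstract classes are the extension point for new backends. As an illustration only (EchoProvider is invented here and is not part of ragit), a minimal BaseLLMProvider subclass could look like the sketch below; it implements the three abstract members the interface requires.

from ragit.providers.base import BaseLLMProvider, LLMResponse


class EchoProvider(BaseLLMProvider):
    """Toy provider that echoes the prompt back; shows the required interface."""

    @property
    def provider_name(self) -> str:
        return "echo"

    def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        # A real provider would call its backend API here.
        return LLMResponse(text=prompt, model=model, provider=self.provider_name)

    def is_available(self) -> bool:
        # A real provider would check connectivity/credentials here.
        return True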
ragit/providers/ollama.py
ADDED
@@ -0,0 +1,284 @@
#
# Copyright RODMENA LIMITED 2025
# SPDX-License-Identifier: Apache-2.0
#
"""
Ollama provider for LLM and Embedding operations.

This provider connects to a local or remote Ollama server.
Configuration is loaded from environment variables.
"""

import requests

from ragit.config import config
from ragit.providers.base import (
    BaseEmbeddingProvider,
    BaseLLMProvider,
    EmbeddingResponse,
    LLMResponse,
)


class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
    """
    Ollama provider for both LLM and Embedding operations.

    Parameters
    ----------
    base_url : str, optional
        Ollama server URL (default: from OLLAMA_BASE_URL env var)
    api_key : str, optional
        API key for authentication (default: from OLLAMA_API_KEY env var)
    timeout : int, optional
        Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)

    Examples
    --------
    >>> provider = OllamaProvider()
    >>> response = provider.generate("What is RAG?", model="llama3")
    >>> print(response.text)

    >>> embedding = provider.embed("Hello world", model="nomic-embed-text")
    >>> print(len(embedding.embedding))
    """

    # Known embedding model dimensions
    EMBEDDING_DIMENSIONS = {
        "nomic-embed-text": 768,
        "nomic-embed-text:latest": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "qwen3-embedding": 4096,
        "qwen3-embedding:0.6b": 1024,
        "qwen3-embedding:4b": 2560,
        "qwen3-embedding:8b": 4096,
    }

    def __init__(
        self,
        base_url: str | None = None,
        embedding_url: str | None = None,
        api_key: str | None = None,
        timeout: int | None = None,
    ) -> None:
        self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
        self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
        self.api_key = api_key or config.OLLAMA_API_KEY
        self.timeout = timeout or config.OLLAMA_TIMEOUT
        self._current_embed_model: str | None = None
        self._current_dimensions: int = 768  # default

    def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
        """Get request headers including authentication if API key is set."""
        headers = {"Content-Type": "application/json"}
        if include_auth and self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @property
    def provider_name(self) -> str:
        return "ollama"

    @property
    def dimensions(self) -> int:
        return self._current_dimensions

    def is_available(self) -> bool:
        """Check if Ollama server is reachable."""
        try:
            response = requests.get(
                f"{self.base_url}/api/tags",
                headers=self._get_headers(),
                timeout=5,
            )
            return response.status_code == 200
        except requests.RequestException:
            return False

    def list_models(self) -> list[dict[str, str]]:
        """List available models on the Ollama server."""
        try:
            response = requests.get(
                f"{self.base_url}/api/tags",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()
            data = response.json()
            return list(data.get("models", []))
        except requests.RequestException as e:
            raise ConnectionError(f"Failed to list Ollama models: {e}") from e

    def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Generate text using Ollama."""
        options: dict[str, float | int] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload: dict[str, str | bool | dict[str, float | int]] = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": options,
        }

        if system_prompt:
            payload["system"] = system_prompt

        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                headers=self._get_headers(),
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            return LLMResponse(
                text=data.get("response", ""),
                model=model,
                provider=self.provider_name,
                usage={
                    "prompt_tokens": data.get("prompt_eval_count"),
                    "completion_tokens": data.get("eval_count"),
                    "total_duration": data.get("total_duration"),
                },
            )
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama generate failed: {e}") from e

    def embed(self, text: str, model: str) -> EmbeddingResponse:
        """Generate embedding using Ollama (uses embedding_url, no auth for local)."""
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        try:
            response = requests.post(
                f"{self.embedding_url}/api/embeddings",
                headers=self._get_headers(include_auth=False),
                json={"model": model, "prompt": text},
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            embedding = data.get("embedding", [])
            if not embedding:
                raise ValueError("Empty embedding returned from Ollama")

            # Update dimensions from actual response
            self._current_dimensions = len(embedding)

            return EmbeddingResponse(
                embedding=tuple(embedding),
                model=model,
                provider=self.provider_name,
                dimensions=len(embedding),
            )
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama embed failed: {e}") from e

    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
        """Generate embeddings for multiple texts (uses embedding_url, no auth for local).

        Note: Ollama /api/embeddings only supports single prompts, so we loop.
        """
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        results = []
        try:
            for text in texts:
                response = requests.post(
                    f"{self.embedding_url}/api/embeddings",
                    headers=self._get_headers(include_auth=False),
                    json={"model": model, "prompt": text},
                    timeout=self.timeout,
                )
                response.raise_for_status()
                data = response.json()

                embedding = data.get("embedding", [])
                if embedding:
                    self._current_dimensions = len(embedding)

                results.append(
                    EmbeddingResponse(
                        embedding=tuple(embedding),
                        model=model,
                        provider=self.provider_name,
                        dimensions=len(embedding),
                    )
                )
            return results
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama batch embed failed: {e}") from e

    def chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """
        Chat completion using Ollama.

        Parameters
        ----------
        messages : list[dict]
            List of messages with 'role' and 'content' keys.
        model : str
            Model identifier.
        temperature : float
            Sampling temperature.
        max_tokens : int, optional
            Maximum tokens to generate.

        Returns
        -------
        LLMResponse
            The generated response.
        """
        options: dict[str, float | int] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
            "model": model,
            "messages": messages,
            "stream": False,
            "options": options,
        }

        try:
            response = requests.post(
                f"{self.base_url}/api/chat",
                headers=self._get_headers(),
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            return LLMResponse(
                text=data.get("message", {}).get("content", ""),
                model=model,
                provider=self.provider_name,
                usage={
                    "prompt_tokens": data.get("prompt_eval_count"),
                    "completion_tokens": data.get("eval_count"),
                },
            )
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama chat failed: {e}") from e
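In use, the provider follows its docstring examples. The short sketch below assumes an Ollama server is reachable at the configured URL and that the named models (llama3, nomic-embed-text) are pulled locally; both names are examples, not requirements of the package.

from ragit.providers import OllamaProvider

provider = OllamaProvider()  # reads OLLAMA_* settings via ragit.config

if provider.is_available():
    reply = provider.generate("What is RAG?", model="llama3", temperature=0.2)
    print(reply.text)

    vectors = provider.embed_batch(["hello", "world"], model="nomic-embed-text")
    # e.g. 2 embeddings of 768 dimensions for nomic-embed-text
    print(len(vectors), vectors[0].dimensions)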