ebk-0.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ebk/__init__.py +35 -0
- ebk/ai/__init__.py +23 -0
- ebk/ai/knowledge_graph.py +450 -0
- ebk/ai/llm_providers/__init__.py +26 -0
- ebk/ai/llm_providers/anthropic.py +209 -0
- ebk/ai/llm_providers/base.py +295 -0
- ebk/ai/llm_providers/gemini.py +285 -0
- ebk/ai/llm_providers/ollama.py +294 -0
- ebk/ai/metadata_enrichment.py +394 -0
- ebk/ai/question_generator.py +328 -0
- ebk/ai/reading_companion.py +224 -0
- ebk/ai/semantic_search.py +433 -0
- ebk/ai/text_extractor.py +393 -0
- ebk/calibre_import.py +66 -0
- ebk/cli.py +6433 -0
- ebk/config.py +230 -0
- ebk/db/__init__.py +37 -0
- ebk/db/migrations.py +507 -0
- ebk/db/models.py +725 -0
- ebk/db/session.py +144 -0
- ebk/decorators.py +1 -0
- ebk/exports/__init__.py +0 -0
- ebk/exports/base_exporter.py +218 -0
- ebk/exports/echo_export.py +279 -0
- ebk/exports/html_library.py +1743 -0
- ebk/exports/html_utils.py +87 -0
- ebk/exports/hugo.py +59 -0
- ebk/exports/jinja_export.py +286 -0
- ebk/exports/multi_facet_export.py +159 -0
- ebk/exports/opds_export.py +232 -0
- ebk/exports/symlink_dag.py +479 -0
- ebk/exports/zip.py +25 -0
- ebk/extract_metadata.py +341 -0
- ebk/ident.py +89 -0
- ebk/library_db.py +1440 -0
- ebk/opds.py +748 -0
- ebk/plugins/__init__.py +42 -0
- ebk/plugins/base.py +502 -0
- ebk/plugins/hooks.py +442 -0
- ebk/plugins/registry.py +499 -0
- ebk/repl/__init__.py +9 -0
- ebk/repl/find.py +126 -0
- ebk/repl/grep.py +173 -0
- ebk/repl/shell.py +1677 -0
- ebk/repl/text_utils.py +320 -0
- ebk/search_parser.py +413 -0
- ebk/server.py +3608 -0
- ebk/services/__init__.py +28 -0
- ebk/services/annotation_extraction.py +351 -0
- ebk/services/annotation_service.py +380 -0
- ebk/services/export_service.py +577 -0
- ebk/services/import_service.py +447 -0
- ebk/services/personal_metadata_service.py +347 -0
- ebk/services/queue_service.py +253 -0
- ebk/services/tag_service.py +281 -0
- ebk/services/text_extraction.py +317 -0
- ebk/services/view_service.py +12 -0
- ebk/similarity/__init__.py +77 -0
- ebk/similarity/base.py +154 -0
- ebk/similarity/core.py +471 -0
- ebk/similarity/extractors.py +168 -0
- ebk/similarity/metrics.py +376 -0
- ebk/skills/SKILL.md +182 -0
- ebk/skills/__init__.py +1 -0
- ebk/vfs/__init__.py +101 -0
- ebk/vfs/base.py +298 -0
- ebk/vfs/library_vfs.py +122 -0
- ebk/vfs/nodes/__init__.py +54 -0
- ebk/vfs/nodes/authors.py +196 -0
- ebk/vfs/nodes/books.py +480 -0
- ebk/vfs/nodes/files.py +155 -0
- ebk/vfs/nodes/metadata.py +385 -0
- ebk/vfs/nodes/root.py +100 -0
- ebk/vfs/nodes/similar.py +165 -0
- ebk/vfs/nodes/subjects.py +184 -0
- ebk/vfs/nodes/tags.py +371 -0
- ebk/vfs/resolver.py +228 -0
- ebk/vfs_router.py +275 -0
- ebk/views/__init__.py +32 -0
- ebk/views/dsl.py +668 -0
- ebk/views/service.py +619 -0
- ebk-0.4.4.dist-info/METADATA +755 -0
- ebk-0.4.4.dist-info/RECORD +87 -0
- ebk-0.4.4.dist-info/WHEEL +5 -0
- ebk-0.4.4.dist-info/entry_points.txt +2 -0
- ebk-0.4.4.dist-info/licenses/LICENSE +21 -0
- ebk-0.4.4.dist-info/top_level.txt +1 -0
ebk/ai/llm_providers/gemini.py
@@ -0,0 +1,285 @@
"""
Google Gemini LLM Provider.

Supports Gemini models via the Google AI API.
"""

import json
from typing import Dict, Any, List, Optional, AsyncIterator
import httpx

from .base import BaseLLMProvider, LLMConfig, LLMResponse, ModelCapability


class GeminiProvider(BaseLLMProvider):
    """
    Google Gemini LLM provider.

    Supports:
    - Gemini 1.5 models (flash, pro)
    - Gemini 2.0 models
    - Streaming completions
    - JSON mode
    - Embeddings
    """

    DEFAULT_MODEL = "gemini-1.5-flash"

    def __init__(self, config: LLMConfig):
        """
        Initialize Gemini provider.

        Args:
            config: LLM configuration with api_key for Google AI
        """
        super().__init__(config)
        if not config.api_key:
            raise ValueError("Google AI API key is required")

    @property
    def name(self) -> str:
        return "gemini"

    @property
    def supported_capabilities(self) -> List[ModelCapability]:
        return [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.JSON_MODE,
            ModelCapability.STREAMING,
            ModelCapability.EMBEDDINGS,
            ModelCapability.VISION,
        ]

    @classmethod
    def create(
        cls,
        api_key: str,
        model: str = DEFAULT_MODEL,
        **kwargs
    ) -> 'GeminiProvider':
        """
        Create a Gemini provider.

        Args:
            api_key: Google AI API key
            model: Model name (e.g., 'gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-2.0-flash-exp')
            **kwargs: Additional config parameters

        Returns:
            Configured GeminiProvider

        Example:
            >>> provider = GeminiProvider.create(
            ...     api_key="AIza...",
            ...     model="gemini-1.5-flash"
            ... )
        """
        config = LLMConfig(
            base_url="https://generativelanguage.googleapis.com",
            api_key=api_key,
            model=model,
            **kwargs
        )
        return cls(config)

    async def initialize(self) -> None:
        """Initialize HTTP client."""
        self._client = httpx.AsyncClient(
            base_url=self.config.base_url,
            timeout=self.config.timeout,
        )

    async def cleanup(self) -> None:
        """Close HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None

    def _build_endpoint(self, action: str = "generateContent") -> str:
        """Build API endpoint with model and API key."""
        return f"/v1beta/models/{self.config.model}:{action}?key={self.config.api_key}"

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> LLMResponse:
        """
        Generate completion using Gemini.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional parameters (temperature, max_tokens, etc.)

        Returns:
            LLMResponse with generated text
        """
        if not self._client:
            await self.initialize()

        # Build request payload - Gemini uses contents format
        contents = []

        # Add system prompt as first user message if provided
        if system_prompt:
            contents.append({
                "role": "user",
                "parts": [{"text": f"System instruction: {system_prompt}"}]
            })
            contents.append({
                "role": "model",
                "parts": [{"text": "Understood. I will follow those instructions."}]
            })

        contents.append({
            "role": "user",
            "parts": [{"text": prompt}]
        })

        data = {
            "contents": contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.config.temperature),
                "topP": kwargs.get("top_p", self.config.top_p),
            }
        }

        if self.config.max_tokens or kwargs.get("max_tokens"):
            data["generationConfig"]["maxOutputTokens"] = kwargs.get("max_tokens", self.config.max_tokens)

        # Make request
        endpoint = self._build_endpoint()
        response = await self._client.post(endpoint, json=data)
        response.raise_for_status()

        result = response.json()

        # Extract content from response
        content = ""
        candidates = result.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]

        # Get usage info
        usage_metadata = result.get("usageMetadata", {})

        return LLMResponse(
            content=content,
            model=self.config.model,
            finish_reason=candidates[0].get("finishReason") if candidates else None,
            usage={
                "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
                "completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
            },
            raw_response=result,
        )

    async def complete_streaming(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> AsyncIterator[str]:
        """
        Generate streaming completion.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional parameters

        Yields:
            Text chunks as they are generated
        """
        if not self._client:
            await self.initialize()

        # Build request payload
        contents = []

        if system_prompt:
            contents.append({
                "role": "user",
                "parts": [{"text": f"System instruction: {system_prompt}"}]
            })
            contents.append({
                "role": "model",
                "parts": [{"text": "Understood. I will follow those instructions."}]
            })

        contents.append({
            "role": "user",
            "parts": [{"text": prompt}]
        })

        data = {
            "contents": contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.config.temperature),
                "topP": kwargs.get("top_p", self.config.top_p),
            }
        }

        if self.config.max_tokens or kwargs.get("max_tokens"):
            data["generationConfig"]["maxOutputTokens"] = kwargs.get("max_tokens", self.config.max_tokens)

        endpoint = self._build_endpoint("streamGenerateContent")
        endpoint += "&alt=sse"  # Enable SSE streaming

        async with self._client.stream("POST", endpoint, json=data) as response:
            response.raise_for_status()

            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    try:
                        chunk = json.loads(line[6:])
                        candidates = chunk.get("candidates", [])
                        if candidates:
                            parts = candidates[0].get("content", {}).get("parts", [])
                            for part in parts:
                                if "text" in part:
                                    yield part["text"]
                    except json.JSONDecodeError:
                        continue

    async def get_embeddings(
        self,
        texts: List[str],
        **kwargs
    ) -> List[List[float]]:
        """
        Get embeddings using Gemini.

        Args:
            texts: List of texts to embed
            **kwargs: Additional parameters

        Returns:
            List of embedding vectors
        """
        if not self._client:
            await self.initialize()

        embeddings = []
        embed_model = kwargs.get("embed_model", "text-embedding-004")

        for text in texts:
            data = {
                "content": {
                    "parts": [{"text": text}]
                }
            }

            endpoint = f"/v1beta/models/{embed_model}:embedContent?key={self.config.api_key}"
            response = await self._client.post(endpoint, json=data)
            response.raise_for_status()

            result = response.json()
            embeddings.append(result.get("embedding", {}).get("values", []))

        return embeddings
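A minimal usage sketch for the provider above. It assumes the usual asyncio entry point, that LLMResponse fields (content, usage) are read as attributes as their construction here suggests, and that the module is importable from its path in the wheel; the API key value and prompts are placeholders, not part of the package.

import asyncio

from ebk.ai.llm_providers.gemini import GeminiProvider


async def main() -> None:
    # Placeholder key; substitute a real Google AI API key.
    provider = GeminiProvider.create(api_key="AIza-your-key", model="gemini-1.5-flash")
    try:
        # One-shot completion with an optional system prompt
        reply = await provider.complete(
            "Summarize this book blurb in one sentence: ...",
            system_prompt="You are a concise librarian.",
        )
        print(reply.content, reply.usage)

        # Streaming completion: chunks are yielded as the model generates them
        async for chunk in provider.complete_streaming("List three classic sci-fi novels."):
            print(chunk, end="", flush=True)

        # Embeddings: one vector per input text (text-embedding-004 by default)
        vectors = await provider.get_embeddings(["space opera", "hard science fiction"])
        print(len(vectors), len(vectors[0]))
    finally:
        await provider.cleanup()


asyncio.run(main())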
ebk/ai/llm_providers/ollama.py
@@ -0,0 +1,294 @@
"""
Ollama LLM Provider.

Supports both local and remote Ollama instances.
"""

import json
from typing import Dict, Any, List, Optional, AsyncIterator
import httpx

from .base import BaseLLMProvider, LLMConfig, LLMResponse, ModelCapability


class OllamaProvider(BaseLLMProvider):
    """
    Ollama LLM provider.

    Supports:
    - Local Ollama (default: http://localhost:11434)
    - Remote Ollama (e.g., basement GPU server)
    - Streaming completions
    - JSON mode
    - Embeddings
    """

    def __init__(self, config: LLMConfig):
        """
        Initialize Ollama provider.

        Args:
            config: LLM configuration with base_url pointing to Ollama
        """
        super().__init__(config)

    @property
    def name(self) -> str:
        return "ollama"

    @property
    def supported_capabilities(self) -> List[ModelCapability]:
        return [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.JSON_MODE,
            ModelCapability.STREAMING,
            ModelCapability.EMBEDDINGS,
        ]

    @classmethod
    def local(cls, model: str = "llama3.2", **kwargs) -> 'OllamaProvider':
        """
        Create provider for local Ollama instance.

        Args:
            model: Model name (e.g., 'llama3.2', 'mistral', 'codellama')
            **kwargs: Additional config parameters

        Returns:
            Configured OllamaProvider
        """
        config = LLMConfig(
            base_url="http://localhost:11434",
            model=model,
            **kwargs
        )
        return cls(config)

    @classmethod
    def remote(
        cls,
        host: str,
        port: int = 11434,
        model: str = "llama3.2",
        **kwargs
    ) -> 'OllamaProvider':
        """
        Create provider for remote Ollama instance.

        Args:
            host: Remote host (e.g., '192.168.1.100', 'basement-gpu.local')
            port: Ollama port (default: 11434)
            model: Model name
            **kwargs: Additional config parameters

        Returns:
            Configured OllamaProvider

        Example:
            >>> provider = OllamaProvider.remote(
            ...     host='192.168.1.100',
            ...     model='llama3.2'
            ... )
        """
        config = LLMConfig(
            base_url=f"http://{host}:{port}",
            model=model,
            **kwargs
        )
        return cls(config)

    async def initialize(self) -> None:
        """Initialize HTTP client."""
        self._client = httpx.AsyncClient(
            base_url=self.config.base_url,
            timeout=self.config.timeout,
        )

    async def cleanup(self) -> None:
        """Close HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> LLMResponse:
        """
        Generate completion using Ollama.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional parameters (temperature, max_tokens, etc.)

        Returns:
            LLMResponse with generated text
        """
        if not self._client:
            await self.initialize()

        # Build request payload
        data = {
            "model": self.config.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": kwargs.get("temperature", self.config.temperature),
                "top_p": kwargs.get("top_p", self.config.top_p),
            }
        }

        if system_prompt:
            data["system"] = system_prompt

        if self.config.max_tokens:
            data["options"]["num_predict"] = self.config.max_tokens

        # Make request
        response = await self._client.post("/api/generate", json=data)
        response.raise_for_status()

        result = response.json()

        return LLMResponse(
            content=result["response"],
            model=result.get("model", self.config.model),
            finish_reason=result.get("done_reason"),
            usage={
                "prompt_tokens": result.get("prompt_eval_count", 0),
                "completion_tokens": result.get("eval_count", 0),
            },
            raw_response=result,
        )

    async def complete_json(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        schema: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate JSON completion using Ollama's JSON format mode.

        Overrides base to set Ollama's format="json" parameter
        before delegating to the shared prompt-building logic.
        """
        # Use Ollama's JSON format mode if available
        kwargs["format"] = "json"
        return await super().complete_json(prompt, system_prompt=system_prompt, schema=schema, **kwargs)

    async def complete_streaming(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> AsyncIterator[str]:
        """
        Generate streaming completion.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional parameters

        Yields:
            Text chunks as they are generated
        """
        if not self._client:
            await self.initialize()

        data = {
            "model": self.config.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "temperature": kwargs.get("temperature", self.config.temperature),
                "top_p": kwargs.get("top_p", self.config.top_p),
            }
        }

        if system_prompt:
            data["system"] = system_prompt

        async with self._client.stream("POST", "/api/generate", json=data) as response:
            response.raise_for_status()

            async for line in response.aiter_lines():
                if line.strip():
                    try:
                        chunk = json.loads(line)
                        if "response" in chunk:
                            yield chunk["response"]
                        if chunk.get("done", False):
                            break
                    except json.JSONDecodeError:
                        continue

    async def get_embeddings(
        self,
        texts: List[str],
        **kwargs
    ) -> List[List[float]]:
        """
        Get embeddings using Ollama.

        Args:
            texts: List of texts to embed
            **kwargs: Additional parameters

        Returns:
            List of embedding vectors
        """
        if not self._client:
            await self.initialize()

        embeddings = []

        for text in texts:
            data = {
                "model": self.config.model,
                "prompt": text,
            }

            response = await self._client.post("/api/embeddings", json=data)
            response.raise_for_status()

            result = response.json()
            embeddings.append(result["embedding"])

        return embeddings

    async def list_models(self) -> List[str]:
        """
        List available models on Ollama server.

        Returns:
            List of model names
        """
        if not self._client:
            await self.initialize()

        response = await self._client.get("/api/tags")
        response.raise_for_status()

        result = response.json()
        return [model["name"] for model in result.get("models", [])]

    async def pull_model(self, model_name: str) -> None:
        """
        Pull a model from Ollama registry.

        Args:
            model_name: Name of model to pull (e.g., 'llama3.2', 'mistral')
        """
        if not self._client:
            await self.initialize()

        data = {"name": model_name, "stream": False}

        response = await self._client.post("/api/pull", json=data)
        response.raise_for_status()
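A comparable sketch for the Ollama provider, assuming an Ollama daemon is reachable at the default port. The host address, model names, and prompts are illustrative, and the exact way the base class turns the schema argument of complete_json into prompt text is defined in base.py, which is not shown in this hunk.

import asyncio

from ebk.ai.llm_providers.ollama import OllamaProvider


async def main() -> None:
    # Local daemon on http://localhost:11434, or point at a remote box by host/port.
    provider = OllamaProvider.local(model="llama3.2")
    # provider = OllamaProvider.remote(host="192.168.1.100", model="llama3.2")
    try:
        # Models already pulled on the server
        print(await provider.list_models())

        reply = await provider.complete(
            "Suggest three tags for a book about compilers.",
            system_prompt="Answer tersely.",
        )
        print(reply.content)

        # JSON completion via the complete_json override (sets format="json" in kwargs)
        record = await provider.complete_json(
            "Return the title and author of 'Dune' as JSON.",
            schema={"type": "object", "properties": {
                "title": {"type": "string"},
                "author": {"type": "string"},
            }},
        )
        print(record)
    finally:
        await provider.cleanup()


asyncio.run(main())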