parishad-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parishad/__init__.py +70 -0
- parishad/__main__.py +10 -0
- parishad/checker/__init__.py +25 -0
- parishad/checker/deterministic.py +644 -0
- parishad/checker/ensemble.py +496 -0
- parishad/checker/retrieval.py +546 -0
- parishad/cli/__init__.py +6 -0
- parishad/cli/code.py +3254 -0
- parishad/cli/main.py +1158 -0
- parishad/cli/prarambh.py +99 -0
- parishad/cli/sthapana.py +368 -0
- parishad/config/modes.py +139 -0
- parishad/config/pipeline.core.yaml +128 -0
- parishad/config/pipeline.extended.yaml +172 -0
- parishad/config/pipeline.fast.yaml +89 -0
- parishad/config/user_config.py +115 -0
- parishad/data/catalog.py +118 -0
- parishad/data/models.json +108 -0
- parishad/memory/__init__.py +79 -0
- parishad/models/__init__.py +181 -0
- parishad/models/backends/__init__.py +247 -0
- parishad/models/backends/base.py +211 -0
- parishad/models/backends/huggingface.py +318 -0
- parishad/models/backends/llama_cpp.py +239 -0
- parishad/models/backends/mlx_lm.py +141 -0
- parishad/models/backends/ollama.py +253 -0
- parishad/models/backends/openai_api.py +193 -0
- parishad/models/backends/transformers_hf.py +198 -0
- parishad/models/costs.py +385 -0
- parishad/models/downloader.py +1557 -0
- parishad/models/optimizations.py +871 -0
- parishad/models/profiles.py +610 -0
- parishad/models/reliability.py +876 -0
- parishad/models/runner.py +651 -0
- parishad/models/tokenization.py +287 -0
- parishad/orchestrator/__init__.py +24 -0
- parishad/orchestrator/config_loader.py +210 -0
- parishad/orchestrator/engine.py +1113 -0
- parishad/orchestrator/exceptions.py +14 -0
- parishad/roles/__init__.py +71 -0
- parishad/roles/base.py +712 -0
- parishad/roles/dandadhyaksha.py +163 -0
- parishad/roles/darbari.py +246 -0
- parishad/roles/majumdar.py +274 -0
- parishad/roles/pantapradhan.py +150 -0
- parishad/roles/prerak.py +357 -0
- parishad/roles/raja.py +345 -0
- parishad/roles/sacheev.py +203 -0
- parishad/roles/sainik.py +427 -0
- parishad/roles/sar_senapati.py +164 -0
- parishad/roles/vidushak.py +69 -0
- parishad/tools/__init__.py +7 -0
- parishad/tools/base.py +57 -0
- parishad/tools/fs.py +110 -0
- parishad/tools/perception.py +96 -0
- parishad/tools/retrieval.py +74 -0
- parishad/tools/shell.py +103 -0
- parishad/utils/__init__.py +7 -0
- parishad/utils/hardware.py +122 -0
- parishad/utils/logging.py +79 -0
- parishad/utils/scanner.py +164 -0
- parishad/utils/text.py +61 -0
- parishad/utils/tracing.py +133 -0
- parishad-0.1.0.dist-info/METADATA +256 -0
- parishad-0.1.0.dist-info/RECORD +68 -0
- parishad-0.1.0.dist-info/WHEEL +4 -0
- parishad-0.1.0.dist-info/entry_points.txt +2 -0
- parishad-0.1.0.dist-info/licenses/LICENSE +21 -0
parishad/models/backends/llama_cpp.py
@@ -0,0 +1,239 @@
"""
Llama.cpp backend for GGUF models.
"""

from __future__ import annotations

import gc
import logging
import os
import time
from pathlib import Path
from typing import Optional

from .base import BackendConfig, BackendError, BackendResult, BaseBackend

logger = logging.getLogger(__name__)

# Suppress verbose llama.cpp logging (ggml_metal_init messages)
os.environ.setdefault("GGML_METAL_LOG_LEVEL", "0")
os.environ.setdefault("LLAMA_CPP_LOG_LEVEL", "0")

# Lazy imports
_llama_cpp = None
_suppress_output = None


def _get_llama_cpp():
    """Lazy import of llama-cpp-python."""
    global _llama_cpp, _suppress_output
    if _llama_cpp is None:
        try:
            import llama_cpp
            from llama_cpp import suppress_stdout_stderr
            _llama_cpp = llama_cpp
            _suppress_output = suppress_stdout_stderr
        except ImportError:
            raise ImportError(
                "llama-cpp-python is required for LlamaCppBackend. "
                "Install with: pip install llama-cpp-python"
            )
    return _llama_cpp


def resolve_model_path(model_id: str) -> Optional[Path]:
    """Resolve a model ID to a file path."""
    # Direct path
    direct = Path(model_id)
    if direct.exists():
        return direct

    # Try model manager
    try:
        from ..downloader import ModelManager, get_default_model_dir
        manager = ModelManager()

        # Check registry
        path = manager.get_model_path(model_id)
        if path and path.exists():
            return path

        # Check ollama symlinks
        ollama_dir = get_default_model_dir() / "ollama"
        if ollama_dir.exists():
            safe_name = model_id.replace(":", "_").replace("/", "_")
            for suffix in [".gguf", ""]:
                candidate = ollama_dir / f"{safe_name}{suffix}"
                if candidate.exists():
                    return candidate

        # Search for matching GGUF files
        for model in manager.list_models():
            if model.format.value == "gguf":
                if model_id in model.name or model_id in str(model.path):
                    return model.path
    except Exception as e:
        logger.debug(f"Model manager lookup failed: {e}")

    # Common locations
    search_paths = [
        Path.cwd() / "models",
        Path.home() / ".cache" / "parishad" / "models",
        Path.home() / ".local" / "share" / "parishad" / "models",
    ]

    for search_dir in search_paths:
        if search_dir.exists():
            candidate = search_dir / model_id
            if candidate.exists():
                return candidate
            for gguf_file in search_dir.rglob("*.gguf"):
                if model_id in gguf_file.name:
                    return gguf_file

    return None


class LlamaCppBackend(BaseBackend):
    """Backend for GGUF models using llama-cpp-python."""

    _name = "llama_cpp"

    def __init__(self):
        """Initialize LlamaCppBackend."""
        super().__init__()
        self._llm = None

    def load(self, config: BackendConfig) -> None:
        """Load a GGUF model."""
        llama_cpp = _get_llama_cpp()

        model_path = resolve_model_path(config.model_id)

        if model_path is None:
            raise BackendError(
                f"Model not found: {config.model_id}. "
                "Download with: parishad download <model_name>",
                backend_name=self._name,
                model_id=config.model_id,
            )

        extra = config.extra or {}
        n_gpu_layers = extra.get("n_gpu_layers", -1)
        n_ctx = extra.get("n_ctx", config.context_length)
        n_batch = extra.get("n_batch", 512)
        verbose = extra.get("verbose", False)
        chat_format = extra.get("chat_format", None)

        try:
            suppress_ctx = _suppress_output(disable=False) if _suppress_output else None
            if suppress_ctx:
                with suppress_ctx:
                    self._llm = llama_cpp.Llama(
                        model_path=str(model_path),
                        n_gpu_layers=n_gpu_layers,
                        n_ctx=n_ctx,
                        n_batch=n_batch,
                        verbose=verbose,
                        chat_format=chat_format,
                    )
            else:
                self._llm = llama_cpp.Llama(
                    model_path=str(model_path),
                    n_gpu_layers=n_gpu_layers,
                    n_ctx=n_ctx,
                    n_batch=n_batch,
                    verbose=verbose,
                    chat_format=chat_format,
                )
            self._config = config
            self._model_id = config.model_id
            self._loaded = True

        except Exception as e:
            raise BackendError(
                f"Failed to load model: {e}",
                backend_name=self._name,
                model_id=config.model_id,
                original_error=e,
            )

    def generate(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        stop: list[str] | None = None,
    ) -> BackendResult:
        """Generate text using llama.cpp."""
        if not self._loaded or self._llm is None:
            raise BackendError(
                "Model not loaded",
                backend_name=self._name,
                model_id=self._model_id,
            )

        start_time = time.perf_counter()

        try:
            logger.debug(f"Calling llama_cpp with prompt len={len(prompt)}, max_tokens={max_tokens}, temp={temperature}")

            result = self._llm(
                prompt,
                max_tokens=max_tokens,
                temperature=max(temperature, 0.01),
                top_p=top_p,
                stop=stop or [],
                echo=False,
            )

            logger.debug(f"llama_cpp raw result keys: {result.keys()}")
            if "choices" in result and result["choices"]:
                logger.debug(f"First choice keys: {result['choices'][0].keys()}")
                logger.debug(f"Finish reason: {result['choices'][0].get('finish_reason')}")
            else:
                logger.error(f"No choices in result: {result}")

            text = result["choices"][0]["text"]
            finish_reason = result["choices"][0].get("finish_reason", "stop")

            usage = result.get("usage", {})
            tokens_in = usage.get("prompt_tokens", self._estimate_tokens(prompt))
            tokens_out = usage.get("completion_tokens", self._estimate_tokens(text))

            latency_ms = (time.perf_counter() - start_time) * 1000

            return BackendResult(
                text=text,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                model_id=self._model_id,
                finish_reason=finish_reason,
                latency_ms=latency_ms,
                extra={"total_tokens": usage.get("total_tokens", tokens_in + tokens_out)},
            )

        except Exception as e:
            raise BackendError(
                f"Generation failed: {e}",
                backend_name=self._name,
                model_id=self._model_id,
                original_error=e,
            )

    def unload(self) -> None:
        """Unload the model to free memory."""
        if self._llm is not None:
            del self._llm
            self._llm = None

        super().unload()
        gc.collect()

    @property
    def context_length(self) -> int:
        """Get the model's context length."""
        if self._llm is not None:
            return self._llm.n_ctx()
        return self._config.context_length if self._config else 4096
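A minimal usage sketch for the backend above. It assumes `BackendConfig` (defined in base.py, which is not shown in this diff) accepts `model_id`, `context_length`, and `extra` as keyword arguments; the GGUF path and sampling values are placeholders.

# Sketch only: BackendConfig's exact constructor lives in base.py and is not part
# of this diff, so the keyword arguments below are assumptions; the model path is
# a placeholder.
from parishad.models.backends.base import BackendConfig
from parishad.models.backends.llama_cpp import LlamaCppBackend

config = BackendConfig(
    model_id="models/example-q4_k_m.gguf",        # placeholder path
    context_length=4096,
    extra={"n_gpu_layers": -1, "n_batch": 512},   # forwarded to llama_cpp.Llama in load()
)

backend = LlamaCppBackend()
backend.load(config)
result = backend.generate(
    prompt="Summarize llama.cpp in one sentence.",
    max_tokens=128,
    temperature=0.2,
    top_p=0.9,
    stop=["\n\n"],
)
print(result.text, result.tokens_in, result.tokens_out, result.latency_ms)
backend.unload()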
parishad/models/backends/mlx_lm.py
@@ -0,0 +1,141 @@
"""
MLX backend for Apple Silicon.
"""

from __future__ import annotations

import logging
import time
from pathlib import Path

from .base import BackendConfig, BackendError, BackendResult, BaseBackend

logger = logging.getLogger(__name__)

# Lazy imports
_mlx_lm = None


def _get_mlx_lm():
    """Lazy import of mlx-lm."""
    global _mlx_lm
    if _mlx_lm is None:
        try:
            import mlx_lm
            _mlx_lm = mlx_lm
        except ImportError:
            raise ImportError(
                "mlx-lm is required for MlxBackend. "
                "Install with: pip install mlx-lm"
            )
    return _mlx_lm


class MlxBackend(BaseBackend):
    """Backend for MLX models on Apple Silicon."""

    _name = "mlx"

    def __init__(self):
        """Initialize MlxBackend."""
        super().__init__()
        self._model = None
        self._tokenizer = None

    def load(self, config: BackendConfig) -> None:
        """Load an MLX model."""
        mlx_lm = _get_mlx_lm()

        model_id_or_path = config.model_id

        p = Path(model_id_or_path)
        if p.exists():
            if p.is_file():
                model_id_or_path = str(p.parent)
            else:
                model_id_or_path = str(p)

        try:
            self._model, self._tokenizer = mlx_lm.load(
                model_id_or_path,
                tokenizer_config={"trust_remote_code": True}
            )

            self._config = config
            self._model_id = config.model_id
            self._loaded = True

        except Exception as e:
            raise BackendError(
                f"Failed to load MLX model: {e}",
                backend_name=self._name,
                model_id=config.model_id,
                original_error=e,
            )

    def generate(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        stop: list[str] | None = None,
    ) -> BackendResult:
        """Generate text using MLX."""
        if not self._loaded or self._model is None:
            raise BackendError(
                "Model not loaded",
                backend_name=self._name,
                model_id=self._model_id,
            )

        mlx_lm = _get_mlx_lm()
        start_time = time.perf_counter()

        try:
            tokens_in = len(self._tokenizer.encode(prompt))

            text = mlx_lm.generate(
                model=self._model,
                tokenizer=self._tokenizer,
                prompt=prompt,
                max_tokens=max_tokens,
            )

            finish_reason = "length"
            if stop:
                for s in stop:
                    if s in text:
                        text = text.split(s)[0]
                        finish_reason = "stop"
                        break

            tokens_out = len(self._tokenizer.encode(text))
            latency_ms = (time.perf_counter() - start_time) * 1000

            return BackendResult(
                text=text,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                model_id=self._model_id,
                finish_reason=finish_reason,
                latency_ms=latency_ms,
            )

        except Exception as e:
            raise BackendError(
                f"MLX generation failed: {e}",
                backend_name=self._name,
                model_id=self._model_id,
                original_error=e,
            )

    def unload(self) -> None:
        """Unload model."""
        if self._model is not None:
            del self._model
            self._model = None
        if self._tokenizer is not None:
            del self._tokenizer
            self._tokenizer = None
        super().unload()
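For comparison, a hypothetical call sequence for the MLX backend. Note that `generate()` derives `finish_reason` itself: it defaults to "length" and only reports "stop" when one of the caller-supplied stop strings is found and trimmed, because this version does not forward temperature, top_p, or stop sequences to `mlx_lm.generate`. The `BackendConfig` keywords and the model repo id are the same kind of assumption as in the previous sketch.

# Sketch only: the model id is hypothetical and the BackendConfig keywords are assumed.
from parishad.models.backends.base import BackendConfig
from parishad.models.backends.mlx_lm import MlxBackend

backend = MlxBackend()
backend.load(BackendConfig(
    model_id="mlx-community/example-4bit",  # hypothetical HF repo id or local MLX directory
    context_length=4096,
))
result = backend.generate(
    prompt="Name three MLX features.",
    max_tokens=64,
    temperature=0.0,   # accepted but not forwarded to mlx_lm.generate in this version
    top_p=1.0,
    stop=["\n\n"],
)
print(result.finish_reason)  # "stop" if a stop string was trimmed, otherwise "length"
backend.unload()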
parishad/models/backends/ollama.py
@@ -0,0 +1,253 @@
"""
Ollama backend using native API.

Supports both OpenAI-compatible mode (via inherited OllamaBackend in openai_api.py)
and native Ollama API mode (this file).
"""

from __future__ import annotations

import logging
import time
from typing import Any

from .base import BackendConfig, BackendError, BackendResult, BaseBackend

logger = logging.getLogger(__name__)

# Lazy import
_requests = None


def _get_requests():
    """Lazy import of requests."""
    global _requests
    if _requests is None:
        try:
            import requests
            _requests = requests
        except ImportError:
            raise ImportError(
                "requests package is required for OllamaNativeBackend. "
                "Install with: pip install requests"
            )
    return _requests


class OllamaNativeBackend(BaseBackend):
    """
    Backend for Ollama using native API.

    Uses Ollama's /api/generate endpoint directly instead of OpenAI compatibility layer.
    This provides access to Ollama-specific features like:
    - Raw mode for exact prompts
    - System prompts
    - Streaming
    - Context management
    """

    _name = "ollama_native"

    def __init__(self):
        """Initialize OllamaNativeBackend."""
        super().__init__()
        self._base_url = "http://localhost:11434"
        self._session = None

    def load(self, config: BackendConfig) -> None:
        """Initialize Ollama connection."""
        requests = _get_requests()

        extra = config.extra or {}
        self._base_url = extra.get("base_url", "http://localhost:11434")

        # Strip ollama: prefix if present
        model_id = config.model_id
        if model_id.startswith("ollama:"):
            model_id = model_id.replace("ollama:", "", 1)

        try:
            # Test connection
            self._session = requests.Session()
            response = self._session.get(f"{self._base_url}/api/tags", timeout=5)

            if response.status_code != 200:
                raise BackendError(
                    f"Ollama server not responding at {self._base_url}",
                    backend_name=self._name
                )

            # Check if model is available
            tags = response.json()
            available_models = [m["name"] for m in tags.get("models", [])]

            # Check for exact match or partial match
            model_found = False
            for m in available_models:
                if model_id in m or m in model_id:
                    model_found = True
                    break

            if not model_found and available_models:
                logger.warning(
                    f"Model '{model_id}' not found in Ollama. "
                    f"Available: {available_models[:5]}..."
                )

            self._config = config
            self._model_id = model_id
            self._loaded = True

            logger.info(f"✅ Connected to Ollama at {self._base_url}")

        except Exception as e:
            raise BackendError(
                f"Failed to connect to Ollama: {e}",
                backend_name=self._name,
                model_id=config.model_id,
                original_error=e,
            )

    def generate(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        stop: list[str] | None = None,
    ) -> BackendResult:
        """Generate text using Ollama native API."""
        if not self._loaded or self._session is None:
            raise BackendError(
                "Backend not loaded",
                backend_name=self._name,
                model_id=self._model_id,
            )

        start_time = time.perf_counter()

        try:
            # Parse system/user if present in prompt
            system_prompt = ""
            user_prompt = prompt

            if "<|start_header_id|>system<|end_header_id|>" in prompt:
                # Llama-3 format - extract parts
                parts = prompt.split("<|start_header_id|>")
                for part in parts:
                    if part.startswith("system"):
                        system_prompt = part.split("<|end_header_id|>")[1].split("<|eot_id|>")[0].strip()
                    elif part.startswith("user"):
                        user_prompt = part.split("<|end_header_id|>")[1].split("<|eot_id|>")[0].strip()

            # Build request payload
            payload: dict[str, Any] = {
                "model": self._model_id,
                "prompt": user_prompt,
                "stream": False,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    # Increase context window for large documents (default 32k)
                    "num_ctx": self._config.extra.get("num_ctx", 32768) if self._config and self._config.extra else 32768,
                }
            }

            if system_prompt:
                payload["system"] = system_prompt

            if stop:
                payload["options"]["stop"] = stop

            # Make request
            response = self._session.post(
                f"{self._base_url}/api/generate",
                json=payload,
                # Increase default timeout to 300s (5m) for slower models/hardware
                timeout=self._config.timeout if self._config and self._config.timeout else 300
            )

            if response.status_code != 200:
                raise BackendError(
                    f"Ollama API error: {response.status_code} - {response.text}",
                    backend_name=self._name,
                    model_id=self._model_id
                )

            result = response.json()

            text = result.get("response", "")

            # Get token counts from Ollama response
            tokens_in = result.get("prompt_eval_count", self._estimate_tokens(prompt))
            tokens_out = result.get("eval_count", self._estimate_tokens(text))

            latency_ms = (time.perf_counter() - start_time) * 1000

            # Determine finish reason
            done_reason = result.get("done_reason", "stop")
            if done_reason == "length":
                finish_reason = "length"
            else:
                finish_reason = "stop"

            return BackendResult(
                text=text,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                model_id=self._model_id,
                finish_reason=finish_reason,
                latency_ms=latency_ms,
                extra={
                    "total_duration": result.get("total_duration"),
                    "load_duration": result.get("load_duration"),
                    "eval_duration": result.get("eval_duration"),
                }
            )

        except Exception as e:
            if isinstance(e, BackendError):
                raise
            raise BackendError(
                f"Ollama generation failed: {e}",
                backend_name=self._name,
                model_id=self._model_id,
                original_error=e,
            )

    def unload(self) -> None:
        """Close the session."""
        if self._session:
            self._session.close()
            self._session = None
        super().unload()

    def list_models(self) -> list[str]:
        """List available Ollama models."""
        if not self._session:
            return []

        try:
            response = self._session.get(f"{self._base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                tags = response.json()
                return [m["name"] for m in tags.get("models", [])]
        except Exception:
            pass
        return []

    def pull_model(self, model_name: str) -> bool:
        """Pull a model from Ollama registry."""
        if not self._session:
            return False

        try:
            response = self._session.post(
                f"{self._base_url}/api/pull",
                json={"name": model_name, "stream": False},
                timeout=600  # Models can take a while to download
            )
            return response.status_code == 200
        except Exception:
            return False
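A hedged sketch of driving the native Ollama backend. It assumes a local Ollama server is running and that `BackendConfig` accepts the same keywords as in the earlier sketches (plus a `timeout` field, which `generate()` reads via `self._config.timeout`); the model name is a placeholder. `base_url` and `num_ctx` map onto the `extra` handling visible in `load()` and `generate()`, and the "ollama:" prefix is stripped during `load()`.

# Sketch only: assumes a running Ollama server; model name and BackendConfig
# keywords are assumptions, not confirmed by this diff.
from parishad.models.backends.base import BackendConfig
from parishad.models.backends.ollama import OllamaNativeBackend

backend = OllamaNativeBackend()
backend.load(BackendConfig(
    model_id="ollama:example-model",  # "ollama:" prefix is removed in load()
    context_length=8192,
    extra={"base_url": "http://localhost:11434", "num_ctx": 8192},
    timeout=300,
))
print(backend.list_models())          # model names reported by /api/tags
result = backend.generate(
    prompt="Reply with OK.",
    max_tokens=8,
    temperature=0.0,
    top_p=1.0,
)
print(result.text, result.finish_reason)
backend.unload()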