parishad 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parishad/__init__.py +70 -0
- parishad/__main__.py +10 -0
- parishad/checker/__init__.py +25 -0
- parishad/checker/deterministic.py +644 -0
- parishad/checker/ensemble.py +496 -0
- parishad/checker/retrieval.py +546 -0
- parishad/cli/__init__.py +6 -0
- parishad/cli/code.py +3254 -0
- parishad/cli/main.py +1158 -0
- parishad/cli/prarambh.py +99 -0
- parishad/cli/sthapana.py +368 -0
- parishad/config/modes.py +139 -0
- parishad/config/pipeline.core.yaml +128 -0
- parishad/config/pipeline.extended.yaml +172 -0
- parishad/config/pipeline.fast.yaml +89 -0
- parishad/config/user_config.py +115 -0
- parishad/data/catalog.py +118 -0
- parishad/data/models.json +108 -0
- parishad/memory/__init__.py +79 -0
- parishad/models/__init__.py +181 -0
- parishad/models/backends/__init__.py +247 -0
- parishad/models/backends/base.py +211 -0
- parishad/models/backends/huggingface.py +318 -0
- parishad/models/backends/llama_cpp.py +239 -0
- parishad/models/backends/mlx_lm.py +141 -0
- parishad/models/backends/ollama.py +253 -0
- parishad/models/backends/openai_api.py +193 -0
- parishad/models/backends/transformers_hf.py +198 -0
- parishad/models/costs.py +385 -0
- parishad/models/downloader.py +1557 -0
- parishad/models/optimizations.py +871 -0
- parishad/models/profiles.py +610 -0
- parishad/models/reliability.py +876 -0
- parishad/models/runner.py +651 -0
- parishad/models/tokenization.py +287 -0
- parishad/orchestrator/__init__.py +24 -0
- parishad/orchestrator/config_loader.py +210 -0
- parishad/orchestrator/engine.py +1113 -0
- parishad/orchestrator/exceptions.py +14 -0
- parishad/roles/__init__.py +71 -0
- parishad/roles/base.py +712 -0
- parishad/roles/dandadhyaksha.py +163 -0
- parishad/roles/darbari.py +246 -0
- parishad/roles/majumdar.py +274 -0
- parishad/roles/pantapradhan.py +150 -0
- parishad/roles/prerak.py +357 -0
- parishad/roles/raja.py +345 -0
- parishad/roles/sacheev.py +203 -0
- parishad/roles/sainik.py +427 -0
- parishad/roles/sar_senapati.py +164 -0
- parishad/roles/vidushak.py +69 -0
- parishad/tools/__init__.py +7 -0
- parishad/tools/base.py +57 -0
- parishad/tools/fs.py +110 -0
- parishad/tools/perception.py +96 -0
- parishad/tools/retrieval.py +74 -0
- parishad/tools/shell.py +103 -0
- parishad/utils/__init__.py +7 -0
- parishad/utils/hardware.py +122 -0
- parishad/utils/logging.py +79 -0
- parishad/utils/scanner.py +164 -0
- parishad/utils/text.py +61 -0
- parishad/utils/tracing.py +133 -0
- parishad-0.1.0.dist-info/METADATA +256 -0
- parishad-0.1.0.dist-info/RECORD +68 -0
- parishad-0.1.0.dist-info/WHEEL +4 -0
- parishad-0.1.0.dist-info/entry_points.txt +2 -0
- parishad-0.1.0.dist-info/licenses/LICENSE +21 -0
parishad/models/runner.py
@@ -0,0 +1,651 @@
"""
Unified model abstraction layer for Parishad.

Provides a consistent interface for different LLM backends:
- llama.cpp (local GGUF models)
- OpenAI API
- HuggingFace Transformers
- Ollama (native HTTP API)
- MLX (mlx-lm on Apple Silicon)

This module serves as the router that dispatches to the appropriate backend
based on configuration.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Optional, Tuple
import logging
import time

from ..roles.base import Slot
from .backends import (
    BackendConfig,
    BackendError,
    BackendResult,
    is_backend_available,
    ModelBackend as BackendProtocol,
    LlamaCppBackend,
    OpenAIBackend,
    TransformersBackend,
    MlxBackend,
    OllamaNativeBackend,
)


logger = logging.getLogger(__name__)


# =============================================================================
# Exceptions
# =============================================================================


class ModelRunnerError(Exception):
    """Base exception for ModelRunner errors."""
    pass


class UnknownSlotError(ModelRunnerError):
    """Raised when an unknown slot is requested."""
    pass


class ModelBackendError(ModelRunnerError):
    """Raised when a backend model call fails."""
    pass


class BackendNotAvailableError(ModelRunnerError):
    """Raised when a required backend is not installed."""
    pass


# =============================================================================
# Enums and Configuration
# =============================================================================


class Backend(Enum):
    """Supported LLM backends."""
    LLAMA_CPP = "llama_cpp"
    OPENAI = "openai"
    OLLAMA = "ollama"
    TRANSFORMERS = "transformers"
    MLX = "mlx"
    ANTHROPIC = "anthropic"
    GRPC = "grpc"
    PRP = "prp"
    NATIVE = "native"
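    # Of the values above, only LLAMA_CPP, OPENAI, TRANSFORMERS, MLX and OLLAMA
    # are constructed by _create_backend() below; the remaining members can be
    # carried in configuration but fall through to that factory's ValueError.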


@dataclass
class SlotConfig:
    """
    Configuration for a model slot.

    Defines how a slot (SMALL/MID/BIG) should be configured,
    including the backend to use and generation parameters.
    """
    model_id: str
    backend: Backend | str = Backend.TRANSFORMERS

    # Context and generation settings
    context_length: int = 8192
    default_max_tokens: int = 1024
    default_temperature: float = 0.5
    top_p: float = 0.9

    # Stop sequences
    stop: list[str] | None = None

    # Timeout for generation
    timeout: float = 120.0

    # Backend-specific options (passed to backend as 'extra')
    # For llama_cpp: n_gpu_layers, n_batch, verbose, chat_format
    # For transformers: quantization, device_map, torch_dtype
    # For openai: api_key_env, base_url, organization
    quantization: Optional[str] = None
    device_map: str = "auto"
    model_file: Optional[str] = None
    n_gpu_layers: int = -1
    api_key_env: Optional[str] = None

    # Generic extra args (for backend-specific settings like host/port)
    extra: dict[str, Any] = field(default_factory=dict)

    # Legacy fields for backward compatibility
    max_context: int = 8192  # Alias for context_length
    top_k: int = 50
    repetition_penalty: float = 1.0

    def __post_init__(self):
        """Normalize backend to enum."""
        if isinstance(self.backend, str):
            try:
                self.backend = Backend(self.backend)
            except ValueError:
                # Keep as string for unknown backends
                pass

        # Sync max_context with context_length
        if self.max_context != 8192:
            self.context_length = self.max_context

    def to_backend_config(self) -> BackendConfig:
        """Convert to BackendConfig for the backends package."""
        extra = {
            "quantization": self.quantization,
            "device_map": self.device_map,
            "n_gpu_layers": self.n_gpu_layers,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
        }

        if self.model_file:
            extra["model_file"] = self.model_file
        if self.api_key_env:
            extra["api_key_env"] = self.api_key_env

        # Merge generic extra args
        if self.extra:
            extra.update(self.extra)

        return BackendConfig(
            model_id=self.model_file or self.model_id,
            context_length=self.context_length,
            temperature=self.default_temperature,
            top_p=self.top_p,
            max_tokens=self.default_max_tokens,
            stop=self.stop,
            timeout=self.timeout,
            extra=extra,
        )


@dataclass
class ModelConfig:
    """
    Complete model configuration for Parishad.

    Can be loaded from YAML or constructed programmatically.
    """
    slots: dict[str, SlotConfig] = field(default_factory=dict)

    # Cost tracking weights per slot
    token_weights: dict[str, float] = field(default_factory=lambda: {
        "small": 1.0,
        "mid": 2.5,
        "big": 5.0
    })

    # Legacy attributes for backward compatibility
    small: SlotConfig | None = None
    mid: SlotConfig | None = None
    big: SlotConfig | None = None
    default_temperature: float = 0.5

    def __post_init__(self):
        """Initialize slots from legacy attributes if provided."""
        if self.small and "small" not in self.slots:
            self.slots["small"] = self.small
        if self.mid and "mid" not in self.slots:
            self.slots["mid"] = self.mid
        if self.big and "big" not in self.slots:
            self.slots["big"] = self.big

    @classmethod
    def from_yaml(cls, path: str | Path) -> "ModelConfig":
        """Load configuration from YAML file."""
        import yaml

        with open(path) as f:
            data = yaml.safe_load(f)

        slots = {}
        for slot_name, slot_data in data.get("slots", {}).items():
            slots[slot_name] = cls._parse_slot_config(slot_data)

        token_weights = data.get("cost", {}).get("token_weights", {
            "small": 1.0, "mid": 2.5, "big": 5.0
        })

        return cls(slots=slots, token_weights=token_weights)

    @classmethod
    def from_profile(cls, profile_name: str, path: str | Path) -> "ModelConfig":
        """
        Load a specific profile from YAML file.

        Args:
            profile_name: Name of the profile to load (e.g., 'stub', 'local_small')
            path: Path to the YAML config file

        Returns:
            ModelConfig with the profile's slot configurations
        """
        import yaml

        with open(path) as f:
            data = yaml.safe_load(f)

        profiles = data.get("profiles", {})

        if profile_name not in profiles:
            available = list(profiles.keys())
            raise ValueError(
                f"Profile '{profile_name}' not found. Available: {available}"
            )

        profile_data = profiles[profile_name]

        slots = {}
        slots_data = profile_data.get("slots", profile_data)
        for slot_name, slot_data in slots_data.items():
            slots[slot_name] = cls._parse_slot_config(slot_data)

        token_weights = data.get("cost", {}).get("token_weights", {
            "small": 1.0, "mid": 2.5, "big": 5.0
        })

        return cls(slots=slots, token_weights=token_weights)

    @staticmethod
    def _parse_slot_config(slot_data: dict) -> SlotConfig:
        """Parse a slot configuration dictionary."""
        backend_str = slot_data.get("backend", "transformers")
        try:
            backend = Backend(backend_str)
        except ValueError:
            backend = backend_str  # Keep as string

        return SlotConfig(
            model_id=slot_data.get("model_id", ""),
            backend=backend,
            context_length=slot_data.get("context_length", slot_data.get("max_context", 8192)),
            default_max_tokens=slot_data.get("max_tokens", slot_data.get("default_max_tokens", 1024)),
            default_temperature=slot_data.get("temperature", slot_data.get("default_temperature", 0.5)),
            top_p=slot_data.get("top_p", 0.9),
            stop=slot_data.get("stop"),
            timeout=slot_data.get("timeout", 120.0),
            quantization=slot_data.get("quantization"),
            device_map=slot_data.get("device_map", "auto"),
            model_file=slot_data.get("model_file"),
            n_gpu_layers=slot_data.get("n_gpu_layers", -1),
            api_key_env=slot_data.get("api_key_env"),
            max_context=slot_data.get("max_context", 8192),
            top_k=slot_data.get("top_k", 50),
            repetition_penalty=slot_data.get("repetition_penalty", 1.0),
            extra=slot_data.get("extra", {}),  # Pass through extra config
        )
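
    # For reference, from_profile() above reads roughly the following YAML
    # shape (illustrative sketch only; the profile, model and file names are
    # placeholders, not defaults shipped with the package). from_yaml() expects
    # the same `slots:` mapping at the top level instead of under a profile.
    #
    #   profiles:
    #     local_small:
    #       slots:
    #         small:
    #           model_id: some-org/some-small-model
    #           backend: llama_cpp
    #           model_file: some-small-model.Q4_K_M.gguf
    #           max_tokens: 1024
    #           temperature: 0.5
    #   cost:
    #     token_weights:
    #       small: 1.0
    #       mid: 2.5
    #       big: 5.0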


# =============================================================================
# Backend Factory
# =============================================================================


def _create_backend(backend_type: Backend | str) -> BackendProtocol:
    """
    Create a backend instance based on type using the unified factory.

    Args:
        backend_type: Backend enum or string name

    Returns:
        Backend instance (not yet loaded/configured)

    Raises:
        BackendNotAvailableError: If backend dependencies are not installed
    """
    if isinstance(backend_type, str):
        backend_name = backend_type
    else:
        backend_name = backend_type.value

    if backend_name == "llama_cpp":
        if not is_backend_available("llama_cpp"):
            raise BackendNotAvailableError(
                "llama-cpp-python is not installed. "
                "Install with: pip install llama-cpp-python"
            )
        return LlamaCppBackend()

    if backend_name == "openai":
        if not is_backend_available("openai"):
            raise BackendNotAvailableError(
                "openai package is not installed. "
                "Install with: pip install openai"
            )
        return OpenAIBackend()

    if backend_name == "transformers":
        if not is_backend_available("transformers"):
            raise BackendNotAvailableError(
                "transformers and torch are not installed. "
                "Install with: pip install transformers torch"
            )
        return TransformersBackend()

    if backend_name == "mlx":
        if not is_backend_available("mlx"):
            raise BackendNotAvailableError(
                "mlx-lm is not installed. "
                "Install with: pip install mlx-lm"
            )
        return MlxBackend()

    if backend_name == "ollama":
        if not is_backend_available("ollama_native"):
            raise BackendNotAvailableError(
                "requests package is not installed. "
                "Install with: pip install requests"
            )
        return OllamaNativeBackend()

    raise ValueError(f"Unknown backend type: {backend_name}")


# =============================================================================
# ModelRunner
# =============================================================================


class ModelRunner:
    """
    Unified interface for running models across different backends.

    Manages multiple model slots (SMALL/MID/BIG) and routes generation
    requests to the appropriate backend based on configuration.
    """

    def __init__(
        self,
        config: Optional[ModelConfig] = None,
    ):
        """
        Initialize ModelRunner.

        Args:
            config: Model configuration. If None, uses defaults.
        """
        self.config = config or ModelConfig()
        self._backends: dict[str, BackendProtocol] = {}
        self._loaded_slots: set[str] = set()

    @classmethod
    def from_profile(
        cls,
        profile_name: str,
        config_path: str | Path,
    ) -> "ModelRunner":
        """
        Create ModelRunner from a named profile in a config file.

        Args:
            profile_name: Profile to load (e.g., 'local_small')
            config_path: Path to models.yaml configuration

        Returns:
            Configured ModelRunner
        """
        config = ModelConfig.from_profile(profile_name, config_path)
        return cls(config=config)

    def _get_backend(self, slot: Slot) -> BackendProtocol:
        """
        Get or create the backend for a slot.

        Args:
            slot: Slot to get backend for

        Returns:
            Backend instance for the slot
        """
        slot_name = slot.value

        if slot_name in self._backends:
            return self._backends[slot_name]

        slot_config = self.config.slots.get(slot_name)
        if not slot_config:
            raise UnknownSlotError(f"No configuration for slot: {slot_name}")

        # Resolve the backend name from the slot configuration
        backend_type = slot_config.backend
        if isinstance(backend_type, Backend):
            backend_name = backend_type.value
        else:
            backend_name = backend_type

        # Create backend using the unified factory
        try:
            backend = _create_backend(backend_name)
            self._backends[slot_name] = backend
            return backend
        except BackendNotAvailableError:
            raise
        except Exception as e:
            raise BackendError(
                f"Failed to create backend '{backend_name}' for slot '{slot_name}': {e}"
            ) from e

    def load_slot(self, slot: Slot) -> None:
        """
        Load a model slot.

        Args:
            slot: Slot to load
        """
        slot_name = slot.value

        if slot_name in self._loaded_slots:
            return

        slot_config = self.config.slots.get(slot_name)

        if not slot_config:
            raise UnknownSlotError(f"No configuration for slot: {slot_name}")

        # Get backend (created on demand via the factory)
        backend = self._get_backend(slot)

        # Build backend config
        backend_config = slot_config.to_backend_config()

        # Log what we're loading
        logger.info(f"Loading slot {slot_name}: model_id={backend_config.model_id}")

        # Load the backend with configuration
        backend.load(backend_config)

        # Mark slot as loaded
        self._loaded_slots.add(slot_name)

        backend_name = backend.name if hasattr(backend, 'name') else str(type(backend).__name__)
        logger.info(f"Slot {slot_name} loaded with backend {backend_name}")

    def unload_slot(self, slot: Slot) -> None:
        """
        Unload a model slot to free memory.

        Args:
            slot: Slot to unload
        """
        slot_name = slot.value

        if slot_name not in self._loaded_slots:
            return

        if slot_name in self._backends:
            self._backends[slot_name].unload()
            del self._backends[slot_name]

        self._loaded_slots.discard(slot_name)
        logger.info(f"Slot {slot_name} unloaded")

    def generate(
        self,
        system_prompt: str,
        user_message: str,
        slot: Slot,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        stop: list[str] | None = None,
        **kwargs,
    ) -> Tuple[str, int, str]:
        """
        Generate a response using the specified slot.

        Args:
            system_prompt: System prompt for the model
            user_message: User message to respond to
            slot: Which model slot to use (SMALL, MID, BIG)
            max_tokens: Maximum tokens to generate (overrides config)
            temperature: Sampling temperature (overrides config)
            stop: Stop sequences (overrides config)
            **kwargs: Additional backend-specific parameters

        Returns:
            Tuple of (response_text, tokens_used, model_id)

        Raises:
            UnknownSlotError: If slot is not configured
            ModelBackendError: If backend call fails
        """
        # Validate slot
        if not isinstance(slot, Slot):
            raise UnknownSlotError(
                f"Invalid slot type: {type(slot)}. Must be Slot enum."
            )

        slot_name = slot.value

        # Check if slot exists in config
        if slot_name not in self.config.slots:
            raise UnknownSlotError(
                f"Slot '{slot_name}' not found in configuration. "
                f"Available slots: {list(self.config.slots.keys())}"
            )

        # Validate parameters
        if max_tokens is not None and max_tokens <= 0:
            raise ValueError(f"max_tokens must be positive, got {max_tokens}")

        if temperature is not None and not (0.0 <= temperature <= 2.0):
            raise ValueError(
                f"temperature must be between 0.0 and 2.0, got {temperature}"
            )

        try:
            # Ensure slot is loaded
            if slot_name not in self._loaded_slots:
                self.load_slot(slot)

            backend = self._backends.get(slot_name)
            if not backend:
                raise UnknownSlotError(
                    f"Backend for slot '{slot_name}' not initialized"
                )

            # Get defaults from config
            slot_config = self.config.slots.get(slot_name)
            if slot_config:
                max_tokens = max_tokens or slot_config.default_max_tokens
                temperature = (
                    temperature if temperature is not None
                    else slot_config.default_temperature
                )
                stop = stop or slot_config.stop
                top_p = slot_config.top_p
            else:
                max_tokens = max_tokens or 1024
                temperature = temperature if temperature is not None else 0.5
                top_p = 0.9

            # Format prompt (combine system and user for backends)
            full_prompt = self._format_prompt(system_prompt, user_message)

            # Generate
            start_time = time.perf_counter()

            result: BackendResult = backend.generate(
                prompt=full_prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop=stop,
            )

            elapsed = time.perf_counter() - start_time

            # Calculate total tokens used
            tokens_used = result.tokens_in + result.tokens_out

            logger.debug(
                f"Generated {result.tokens_out} tokens with {result.model_id} "
                f"in {elapsed:.2f}s (latency: {result.latency_ms:.1f}ms)"
            )

            return result.text, tokens_used, result.model_id

        except (UnknownSlotError, ValueError):
            # Re-raise validation errors as-is
            raise
        except BackendError as e:
            # Convert backend errors
            logger.error(f"Backend error for slot {slot_name}: {e}")
            raise ModelBackendError(
                f"Backend error in slot '{slot_name}': {e}"
            ) from e
        except Exception as e:
            # Wrap all other exceptions
            logger.error(f"Model generation failed for slot {slot_name}: {e}")
            raise ModelBackendError(
                f"Backend error in slot '{slot_name}': {type(e).__name__}: {e}"
            ) from e

    def _format_prompt(self, system_prompt: str, user_message: str) -> str:
        """
        Format system and user prompts into a single string.
        Uses the Llama-3 chat template for better Llama-3.2-1B control.
        """
        return (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"{system_prompt}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n"
            f"{user_message}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
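
    # Note: the Llama-3 chat template above is applied for every slot and
    # backend; backends receive the fully formatted string via
    # generate(prompt=...) rather than separate system/user messages.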

    def get_token_weight(self, slot: Slot) -> float:
        """
        Get the token cost weight for a slot.

        Args:
            slot: Slot to get weight for

        Returns:
            Token weight multiplier
        """
        return self.config.token_weights.get(slot.value, 1.0)

    def unload_all(self) -> None:
        """Unload all loaded model slots."""
        for slot_name in list(self._loaded_slots):
            try:
                self.unload_slot(Slot(slot_name))
            except ValueError:
                # Skip invalid slot names
                pass

    def __del__(self):
        """Cleanup on deletion."""
        try:
            self.unload_all()
        except Exception:
            pass  # Ignore errors during cleanup
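
A minimal usage sketch of the ModelRunner API above, for orientation. It is illustrative only: the profile name 'local_small' is taken from the docstring examples, the models.yaml path is a placeholder, and Slot.SMALL is assumed to be a member of parishad.roles.base.Slot with value "small", as the SMALL/MID/BIG naming in the docstrings suggests.

from pathlib import Path

from parishad.models.runner import ModelRunner
from parishad.roles.base import Slot

# Build a runner from a named profile (path and profile name are placeholders).
runner = ModelRunner.from_profile("local_small", Path("models.yaml"))

# Route a request to the small slot; the backend is created and loaded lazily
# on the first generate() call for that slot.
text, tokens_used, model_id = runner.generate(
    system_prompt="You are a concise assistant.",
    user_message="Summarise what a model slot is in one sentence.",
    slot=Slot.SMALL,
    max_tokens=128,
)
print(model_id, tokens_used)
print(text)

# Release loaded backends when finished.
runner.unload_all()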