parishad 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parishad/__init__.py +70 -0
- parishad/__main__.py +10 -0
- parishad/checker/__init__.py +25 -0
- parishad/checker/deterministic.py +644 -0
- parishad/checker/ensemble.py +496 -0
- parishad/checker/retrieval.py +546 -0
- parishad/cli/__init__.py +6 -0
- parishad/cli/code.py +3254 -0
- parishad/cli/main.py +1158 -0
- parishad/cli/prarambh.py +99 -0
- parishad/cli/sthapana.py +368 -0
- parishad/config/modes.py +139 -0
- parishad/config/pipeline.core.yaml +128 -0
- parishad/config/pipeline.extended.yaml +172 -0
- parishad/config/pipeline.fast.yaml +89 -0
- parishad/config/user_config.py +115 -0
- parishad/data/catalog.py +118 -0
- parishad/data/models.json +108 -0
- parishad/memory/__init__.py +79 -0
- parishad/models/__init__.py +181 -0
- parishad/models/backends/__init__.py +247 -0
- parishad/models/backends/base.py +211 -0
- parishad/models/backends/huggingface.py +318 -0
- parishad/models/backends/llama_cpp.py +239 -0
- parishad/models/backends/mlx_lm.py +141 -0
- parishad/models/backends/ollama.py +253 -0
- parishad/models/backends/openai_api.py +193 -0
- parishad/models/backends/transformers_hf.py +198 -0
- parishad/models/costs.py +385 -0
- parishad/models/downloader.py +1557 -0
- parishad/models/optimizations.py +871 -0
- parishad/models/profiles.py +610 -0
- parishad/models/reliability.py +876 -0
- parishad/models/runner.py +651 -0
- parishad/models/tokenization.py +287 -0
- parishad/orchestrator/__init__.py +24 -0
- parishad/orchestrator/config_loader.py +210 -0
- parishad/orchestrator/engine.py +1113 -0
- parishad/orchestrator/exceptions.py +14 -0
- parishad/roles/__init__.py +71 -0
- parishad/roles/base.py +712 -0
- parishad/roles/dandadhyaksha.py +163 -0
- parishad/roles/darbari.py +246 -0
- parishad/roles/majumdar.py +274 -0
- parishad/roles/pantapradhan.py +150 -0
- parishad/roles/prerak.py +357 -0
- parishad/roles/raja.py +345 -0
- parishad/roles/sacheev.py +203 -0
- parishad/roles/sainik.py +427 -0
- parishad/roles/sar_senapati.py +164 -0
- parishad/roles/vidushak.py +69 -0
- parishad/tools/__init__.py +7 -0
- parishad/tools/base.py +57 -0
- parishad/tools/fs.py +110 -0
- parishad/tools/perception.py +96 -0
- parishad/tools/retrieval.py +74 -0
- parishad/tools/shell.py +103 -0
- parishad/utils/__init__.py +7 -0
- parishad/utils/hardware.py +122 -0
- parishad/utils/logging.py +79 -0
- parishad/utils/scanner.py +164 -0
- parishad/utils/text.py +61 -0
- parishad/utils/tracing.py +133 -0
- parishad-0.1.0.dist-info/METADATA +256 -0
- parishad-0.1.0.dist-info/RECORD +68 -0
- parishad-0.1.0.dist-info/WHEEL +4 -0
- parishad-0.1.0.dist-info/entry_points.txt +2 -0
- parishad-0.1.0.dist-info/licenses/LICENSE +21 -0
parishad/models/optimizations.py
@@ -0,0 +1,871 @@
"""
Performance optimizations for Parishad model inference.

Provides:
- ResponseCache: LRU cache for model responses
- RequestBatcher: Batch multiple requests for efficiency
- ConnectionPool: Reuse backend connections
- RateLimiter: Token bucket rate limiting

These optimizations are optional and can be enabled via configuration.
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import sqlite3
import threading
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue, Empty
from typing import Any, Callable, Optional, TypeVar
from contextlib import contextmanager

from .backends import BackendConfig, BackendResult, ModelBackend


logger = logging.getLogger(__name__)

T = TypeVar("T")


# =============================================================================
# Response Cache
# =============================================================================

@dataclass
class CacheEntry:
    """Entry in the response cache."""
    key: str
    response: BackendResult
    created_at: float
    access_count: int = 0
    last_accessed: float = field(default_factory=time.time)

    @property
    def age_seconds(self) -> float:
        """Age of the entry in seconds."""
        return time.time() - self.created_at


class ResponseCache:
    """
    LRU cache for model responses.

    Caches responses based on prompt hash to avoid redundant model calls.
    Thread-safe for concurrent access.

    Usage:
        cache = ResponseCache(max_size=1000, ttl_seconds=3600)

        key = cache.make_key(prompt, model_id, temperature)
        if cached := cache.get(key):
            return cached

        result = model.generate(prompt)
        cache.put(key, result)
    """

    def __init__(
        self,
        max_size: int = 1000,
        ttl_seconds: float = 3600,
        enabled: bool = True,
    ):
        """
        Initialize cache.

        Args:
            max_size: Maximum number of entries
            ttl_seconds: Time-to-live for entries
            enabled: Whether caching is enabled
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.enabled = enabled

        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()

        # Statistics
        self._hits = 0
        self._misses = 0

    def make_key(
        self,
        prompt: str,
        model_id: str,
        temperature: float = 0.0,
        max_tokens: int = 0,
        **kwargs,
    ) -> str:
        """
        Create a cache key from request parameters.

        Note: Only caches deterministic requests (temperature=0).
        """
        # Only cache deterministic requests
        if temperature > 0.01:
            return ""  # Empty key means don't cache

        key_data = json.dumps({
            "prompt": prompt,
            "model_id": model_id,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        }, sort_keys=True)

        return hashlib.sha256(key_data.encode()).hexdigest()[:32]

    def get(self, key: str) -> Optional[BackendResult]:
        """
        Get cached response.

        Args:
            key: Cache key

        Returns:
            Cached BackendResult or None
        """
        if not self.enabled or not key:
            return None

        with self._lock:
            entry = self._cache.get(key)

            if entry is None:
                self._misses += 1
                return None

            # Check TTL
            if entry.age_seconds > self.ttl_seconds:
                del self._cache[key]
                self._misses += 1
                return None

            # Update access stats and move to end (LRU)
            entry.access_count += 1
            entry.last_accessed = time.time()
            self._cache.move_to_end(key)

            self._hits += 1
            return entry.response

    def put(self, key: str, response: BackendResult) -> None:
        """
        Store response in cache.

        Args:
            key: Cache key
            response: Response to cache
        """
        if not self.enabled or not key:
            return

        with self._lock:
            # Evict oldest if at capacity
            while len(self._cache) >= self.max_size:
                self._cache.popitem(last=False)

            self._cache[key] = CacheEntry(
                key=key,
                response=response,
                created_at=time.time(),
            )

    def invalidate(self, key: str) -> bool:
        """Remove specific key from cache."""
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> int:
        """Clear all cache entries. Returns count cleared."""
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            return count

    @property
    def size(self) -> int:
        """Current cache size."""
        return len(self._cache)

    @property
    def hit_rate(self) -> float:
        """Cache hit rate."""
        total = self._hits + self._misses
        return self._hits / total if total > 0 else 0.0

    def get_stats(self) -> dict:
        """Get cache statistics."""
        return {
            "size": self.size,
            "max_size": self.max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": self.hit_rate,
            "enabled": self.enabled,
        }

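# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the released
# optimizations.py). It exercises ResponseCache exactly as the class docstring
# describes: deterministic requests (temperature ~ 0) are keyed and cached,
# while sampled requests get an empty key and bypass the cache. The
# BackendResult fields used below are the same five this module itself
# serializes in PersistentCache.put(); it is assumed here that BackendResult
# has no other required fields.
# ---------------------------------------------------------------------------
def _example_response_cache() -> None:
    cache = ResponseCache(max_size=100, ttl_seconds=60)

    stub_result = BackendResult(
        text="4",
        tokens_in=5,
        tokens_out=1,
        model_id="stub-model",
        latency_ms=12,
    )

    # Deterministic request: stable 32-char key, cached on first put().
    key = cache.make_key("What is 2 + 2?", model_id="stub-model", temperature=0.0)
    assert cache.get(key) is None          # first lookup is a miss
    cache.put(key, stub_result)
    assert cache.get(key) is stub_result   # second lookup is a hit

    # Sampled request: empty key, so get()/put() become no-ops.
    sampled_key = cache.make_key("What is 2 + 2?", model_id="stub-model", temperature=0.7)
    assert sampled_key == ""

    logger.debug("cache stats: %s", cache.get_stats())
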
class PersistentCache(ResponseCache):
    """
    SQLite-backed persistent cache.

    Survives process restarts. Uses same interface as ResponseCache.
    """

    def __init__(
        self,
        path: str | Path,
        max_size: int = 10000,
        ttl_seconds: float = 86400,  # 24 hours
        enabled: bool = True,
    ):
        super().__init__(max_size=max_size, ttl_seconds=ttl_seconds, enabled=enabled)
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)

        self._init_db()

    def _init_db(self):
        """Initialize SQLite database."""
        with self._get_conn() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    response_json TEXT,
                    created_at REAL,
                    access_count INTEGER DEFAULT 0,
                    last_accessed REAL
                )
            """)
            conn.execute("CREATE INDEX IF NOT EXISTS idx_created ON cache(created_at)")

    @contextmanager
    def _get_conn(self):
        """Get database connection."""
        conn = sqlite3.connect(str(self.path))
        try:
            yield conn
            conn.commit()
        finally:
            conn.close()

    def get(self, key: str) -> Optional[BackendResult]:
        """Get from persistent cache."""
        if not self.enabled or not key:
            return None

        with self._get_conn() as conn:
            cursor = conn.execute(
                "SELECT response_json, created_at FROM cache WHERE key = ?",
                (key,)
            )
            row = cursor.fetchone()

            if not row:
                self._misses += 1
                return None

            response_json, created_at = row

            # Check TTL
            if time.time() - created_at > self.ttl_seconds:
                conn.execute("DELETE FROM cache WHERE key = ?", (key,))
                self._misses += 1
                return None

            # Update access stats
            conn.execute(
                "UPDATE cache SET access_count = access_count + 1, last_accessed = ? WHERE key = ?",
                (time.time(), key)
            )

            self._hits += 1
            data = json.loads(response_json)
            return BackendResult(**data)

    def put(self, key: str, response: BackendResult) -> None:
        """Store in persistent cache."""
        if not self.enabled or not key:
            return

        with self._get_conn() as conn:
            # Evict old entries if needed
            cursor = conn.execute("SELECT COUNT(*) FROM cache")
            count = cursor.fetchone()[0]

            if count >= self.max_size:
                # Delete oldest 10%
                delete_count = max(1, self.max_size // 10)
                conn.execute(
                    "DELETE FROM cache WHERE key IN (SELECT key FROM cache ORDER BY last_accessed LIMIT ?)",
                    (delete_count,)
                )

            response_json = json.dumps({
                "text": response.text,
                "tokens_in": response.tokens_in,
                "tokens_out": response.tokens_out,
                "model_id": response.model_id,
                "latency_ms": response.latency_ms,
            })

            conn.execute(
                """INSERT OR REPLACE INTO cache
                   (key, response_json, created_at, access_count, last_accessed)
                   VALUES (?, ?, ?, 0, ?)""",
                (key, response_json, time.time(), time.time())
            )

    def clear(self) -> int:
        """Clear all entries."""
        with self._get_conn() as conn:
            cursor = conn.execute("SELECT COUNT(*) FROM cache")
            count = cursor.fetchone()[0]
            conn.execute("DELETE FROM cache")
            return count

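# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the released
# optimizations.py). PersistentCache keeps the ResponseCache interface but
# round-trips entries through SQLite, so a key written by one process is
# visible to the next one that opens the same file. The cache path below is
# arbitrary.
# ---------------------------------------------------------------------------
def _example_persistent_cache(cache_path: str = "/tmp/parishad_cache.sqlite3") -> None:
    cache = PersistentCache(path=cache_path, ttl_seconds=3600)

    key = cache.make_key("deterministic prompt", model_id="stub-model", temperature=0.0)
    if cache.get(key) is None:
        cache.put(
            key,
            BackendResult(
                text="cached answer",
                tokens_in=3,
                tokens_out=2,
                model_id="stub-model",
                latency_ms=0,
            ),
        )

    # A fresh instance pointed at the same file sees the entry: this is the
    # "survives process restarts" property from the class docstring.
    reopened = PersistentCache(path=cache_path, ttl_seconds=3600)
    hit = reopened.get(key)
    assert hit is not None and hit.text == "cached answer"
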
# =============================================================================
# Request Batcher
# =============================================================================


@dataclass
class BatchRequest:
    """A request in the batch queue."""
    prompt: str
    config: BackendConfig
    future: "asyncio.Future[BackendResult]"
    submitted_at: float = field(default_factory=time.time)


class RequestBatcher:
    """
    Batch multiple requests for efficient processing.

    Collects requests over a time window and processes them together.
    Useful for backends that support batch inference.

    Usage:
        batcher = RequestBatcher(backend, batch_size=8, wait_ms=50)
        result = await batcher.submit(prompt, config)
    """

    def __init__(
        self,
        backend: ModelBackend,
        batch_size: int = 8,
        wait_ms: float = 50.0,
        enabled: bool = True,
    ):
        """
        Initialize batcher.

        Args:
            backend: Backend to use for generation
            batch_size: Maximum batch size
            wait_ms: Maximum wait time before processing
            enabled: Whether batching is enabled
        """
        self.backend = backend
        self.batch_size = batch_size
        self.wait_ms = wait_ms
        self.enabled = enabled

        self._queue: list[BatchRequest] = []
        self._lock = threading.Lock()
        self._processing = False

        # Statistics
        self._batches_processed = 0
        self._requests_processed = 0

    async def submit(
        self,
        prompt: str,
        config: BackendConfig,
    ) -> BackendResult:
        """
        Submit a request for batched processing.

        Args:
            prompt: Input prompt
            config: Backend configuration

        Returns:
            BackendResult from generation
        """
        if not self.enabled:
            # Direct processing if batching disabled
            return self.backend.generate(
                prompt=prompt,
                max_tokens=config.max_tokens,
                temperature=config.temperature,
                top_p=config.top_p,
                stop=config.stop,
            )

        loop = asyncio.get_running_loop()
        future: asyncio.Future[BackendResult] = loop.create_future()

        request = BatchRequest(prompt=prompt, config=config, future=future)

        flush_now = False
        with self._lock:
            self._queue.append(request)

            if len(self._queue) >= self.batch_size:
                # Batch is full: flush after releasing the lock, because
                # _schedule_processing re-acquires this non-reentrant lock
                flush_now = True
            elif len(self._queue) == 1:
                # Schedule delayed processing
                loop.call_later(self.wait_ms / 1000, self._schedule_processing)

        if flush_now:
            self._schedule_processing()

        return await future

    def _schedule_processing(self) -> None:
        """Schedule batch processing."""
        with self._lock:
            if self._processing or not self._queue:
                return

            self._processing = True
            batch = self._queue[:self.batch_size]
            self._queue = self._queue[self.batch_size:]

        # Process the drained batch outside the lock
        try:
            self._process_batch(batch)
        finally:
            with self._lock:
                self._processing = False

    def _process_batch(self, batch: list[BatchRequest]) -> None:
        """Process a batch of requests."""
        for request in batch:
            try:
                result = self.backend.generate(
                    prompt=request.prompt,
                    max_tokens=request.config.max_tokens,
                    temperature=request.config.temperature,
                    top_p=request.config.top_p,
                    stop=request.config.stop,
                )

                if not request.future.done():
                    request.future.get_loop().call_soon_threadsafe(
                        request.future.set_result, result
                    )

            except Exception as e:
                if not request.future.done():
                    request.future.get_loop().call_soon_threadsafe(
                        request.future.set_exception, e
                    )

        self._batches_processed += 1
        self._requests_processed += len(batch)

    def get_stats(self) -> dict:
        """Get batcher statistics."""
        return {
            "batches_processed": self._batches_processed,
            "requests_processed": self._requests_processed,
            "avg_batch_size": (
                self._requests_processed / self._batches_processed
                if self._batches_processed > 0 else 0
            ),
            "queue_size": len(self._queue),
            "enabled": self.enabled,
        }

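# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the released
# optimizations.py). RequestBatcher only duck-types its collaborators:
# submit() reads config.max_tokens / temperature / top_p / stop and calls
# backend.generate(...) with keyword arguments, so a SimpleNamespace config
# and a tiny stand-in backend are enough to show the flow. A real deployment
# would pass an actual ModelBackend and BackendConfig instead.
# ---------------------------------------------------------------------------
def _example_request_batcher() -> None:
    from types import SimpleNamespace

    class _EchoBackend:
        """Stand-in backend that echoes the prompt back as a BackendResult."""

        def generate(self, prompt, max_tokens, temperature, top_p, stop):
            return BackendResult(
                text=f"echo: {prompt}",
                tokens_in=len(prompt.split()),
                tokens_out=2,
                model_id="echo-backend",
                latency_ms=1,
            )

    async def _run() -> list[BackendResult]:
        batcher = RequestBatcher(_EchoBackend(), batch_size=4, wait_ms=10)
        config = SimpleNamespace(max_tokens=32, temperature=0.0, top_p=1.0, stop=None)
        prompts = [f"request {i}" for i in range(4)]
        # Four concurrent submissions fill one batch and are flushed together.
        return await asyncio.gather(*(batcher.submit(p, config) for p in prompts))

    results = asyncio.run(_run())
    assert [r.text for r in results] == [f"echo: request {i}" for i in range(4)]
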
# =============================================================================
# Connection Pool
# =============================================================================


class ConnectionPool:
    """
    Pool of reusable backend connections.

    Reduces overhead of creating new connections for each request.
    Thread-safe for concurrent access.

    Usage:
        pool = ConnectionPool(backend_factory, max_size=4)

        with pool.acquire() as backend:
            result = backend.generate(prompt)
    """

    def __init__(
        self,
        backend_factory: Callable[[], ModelBackend],
        max_size: int = 4,
        min_size: int = 1,
    ):
        """
        Initialize pool.

        Args:
            backend_factory: Factory function to create backends
            max_size: Maximum pool size
            min_size: Minimum backends to keep ready
        """
        self.backend_factory = backend_factory
        self.max_size = max_size
        self.min_size = min_size

        self._available: Queue[ModelBackend] = Queue()
        self._in_use: set[int] = set()
        self._lock = threading.Lock()
        self._total_created = 0

        # Pre-create minimum backends
        for _ in range(min_size):
            self._create_backend()

    def _create_backend(self) -> ModelBackend:
        """Create a new backend instance."""
        backend = self.backend_factory()
        self._available.put(backend)
        self._total_created += 1
        return backend

    @contextmanager
    def acquire(self, timeout: float = 30.0):
        """
        Acquire a backend from the pool.

        Args:
            timeout: Maximum time to wait for a backend

        Yields:
            ModelBackend instance
        """
        backend = None

        try:
            # Try to get from available
            try:
                backend = self._available.get(timeout=timeout)
            except Empty:
                # Create new if under limit
                with self._lock:
                    current_size = self._total_created
                    if current_size < self.max_size:
                        backend = self.backend_factory()
                        self._total_created += 1
                    else:
                        raise TimeoutError("No backends available in pool")

            with self._lock:
                self._in_use.add(id(backend))

            yield backend

        finally:
            if backend is not None:
                with self._lock:
                    self._in_use.discard(id(backend))
                self._available.put(backend)

    def get_stats(self) -> dict:
        """Get pool statistics."""
        return {
            "total_created": self._total_created,
            "available": self._available.qsize(),
            "in_use": len(self._in_use),
            "max_size": self.max_size,
        }

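# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the released
# optimizations.py). ConnectionPool only needs a zero-argument factory, so a
# stand-in backend class is used here in place of a real ModelBackend; the
# point is the acquire()/release cycle and the pool statistics.
# ---------------------------------------------------------------------------
def _example_connection_pool() -> None:
    class _StubBackend:
        """Stand-in backend; a real factory would build a ModelBackend."""

        def generate(self, prompt: str) -> str:
            return f"stub reply to {prompt!r}"

    pool = ConnectionPool(_StubBackend, max_size=2, min_size=1)

    # acquire() hands out an existing backend (or creates one, up to max_size)
    # and returns it to the pool when the with-block exits.
    with pool.acquire(timeout=5.0) as backend:
        reply = backend.generate("hello")

    stats = pool.get_stats()
    assert stats["in_use"] == 0 and stats["available"] >= 1
    logger.debug("pool reply=%r stats=%s", reply, stats)
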
# =============================================================================
# Rate Limiter
# =============================================================================


class RateLimiter:
    """
    Token bucket rate limiter.

    Controls request rate to avoid overwhelming backends or hitting API limits.

    Usage:
        limiter = RateLimiter(tokens_per_second=10, burst_size=20)

        limiter.acquire()  # Blocks until a token is available
        result = model.generate(prompt)

        # In async code: await limiter.acquire_async()
    """

    def __init__(
        self,
        tokens_per_second: float = 10.0,
        burst_size: int = 20,
    ):
        """
        Initialize rate limiter.

        Args:
            tokens_per_second: Token refill rate
            burst_size: Maximum tokens (burst capacity)
        """
        self.tokens_per_second = tokens_per_second
        self.burst_size = burst_size

        self._tokens = float(burst_size)
        self._last_refill = time.time()
        self._lock = threading.Lock()

        # Statistics
        self._requests = 0
        self._waits = 0
        self._total_wait_time = 0.0

    def _refill(self) -> None:
        """Refill tokens based on elapsed time."""
        now = time.time()
        elapsed = now - self._last_refill
        self._tokens = min(
            self.burst_size,
            self._tokens + elapsed * self.tokens_per_second
        )
        self._last_refill = now

    def acquire(self, tokens: int = 1) -> float:
        """
        Acquire tokens, blocking if necessary.

        Args:
            tokens: Number of tokens to acquire

        Returns:
            Wait time in seconds
        """
        wait_time = 0.0

        with self._lock:
            self._refill()

            while self._tokens < tokens:
                # Calculate wait time
                needed = tokens - self._tokens
                wait = needed / self.tokens_per_second
                wait_time += wait

                self._lock.release()
                time.sleep(wait)
                self._lock.acquire()

                self._refill()

            self._tokens -= tokens
            self._requests += 1

            if wait_time > 0:
                self._waits += 1
                self._total_wait_time += wait_time

        return wait_time

    async def acquire_async(self, tokens: int = 1) -> float:
        """Async version of acquire."""
        wait_time = 0.0

        with self._lock:
            self._refill()

            if self._tokens < tokens:
                needed = tokens - self._tokens
                wait_time = needed / self.tokens_per_second

        if wait_time > 0:
            await asyncio.sleep(wait_time)
            self._waits += 1
            self._total_wait_time += wait_time

        with self._lock:
            self._refill()
            self._tokens -= tokens
            self._requests += 1

        return wait_time

    def get_stats(self) -> dict:
        """Get rate limiter statistics."""
        return {
            "requests": self._requests,
            "waits": self._waits,
            "total_wait_time": self._total_wait_time,
            "avg_wait_time": (
                self._total_wait_time / self._waits if self._waits > 0 else 0
            ),
            "current_tokens": self._tokens,
            "tokens_per_second": self.tokens_per_second,
        }

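# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the released
# optimizations.py). With burst_size=2 and tokens_per_second=5, the first two
# acquisitions are immediate and the third has to wait roughly 1/5 s for the
# bucket to refill; acquire() reports that wait so callers can log or meter it.
# ---------------------------------------------------------------------------
def _example_rate_limiter() -> None:
    limiter = RateLimiter(tokens_per_second=5.0, burst_size=2)

    waits = [limiter.acquire() for _ in range(3)]

    assert waits[0] == 0.0 and waits[1] == 0.0   # burst capacity covers these
    assert waits[2] > 0.0                         # third call waited ~0.2 s

    logger.debug("rate limiter stats: %s", limiter.get_stats())
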
# =============================================================================
# Optimized Runner Wrapper
# =============================================================================


class OptimizedRunner:
    """
    Wrapper that adds caching and rate limiting to a ModelRunner.

    Usage:
        from parishad.models.runner import ModelRunner

        runner = ModelRunner(stub=True)
        optimized = OptimizedRunner(
            runner,
            cache_enabled=True,
            rate_limit=10.0,
        )

        text, tokens, model = optimized.generate(
            system_prompt="You are helpful.",
            user_message="Hello!",
            slot=Slot.SMALL,
        )
    """

    def __init__(
        self,
        runner: "ModelRunner",  # type: ignore
        cache_enabled: bool = False,
        cache_max_size: int = 1000,
        cache_ttl: float = 3600,
        rate_limit: Optional[float] = None,
        rate_burst: int = 20,
    ):
        """
        Initialize optimized runner.

        Args:
            runner: Base ModelRunner to wrap
            cache_enabled: Enable response caching
            cache_max_size: Maximum cache entries
            cache_ttl: Cache TTL in seconds
            rate_limit: Rate limit (requests per second)
            rate_burst: Rate limit burst size
        """
        self.runner = runner

        self.cache = ResponseCache(
            max_size=cache_max_size,
            ttl_seconds=cache_ttl,
            enabled=cache_enabled,
        )

        self.rate_limiter: Optional[RateLimiter] = None
        if rate_limit is not None:
            self.rate_limiter = RateLimiter(
                tokens_per_second=rate_limit,
                burst_size=rate_burst,
            )

    def generate(
        self,
        system_prompt: str,
        user_message: str,
        slot: "Slot",  # type: ignore
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs,
    ) -> tuple[str, int, str]:
        """
        Generate with optimizations applied.

        Args:
            system_prompt: System prompt
            user_message: User message
            slot: Model slot
            max_tokens: Maximum tokens
            temperature: Sampling temperature
            **kwargs: Additional arguments

        Returns:
            Tuple of (text, tokens, model_id)
        """
        # Build cache key
        prompt = f"{system_prompt}\n{user_message}"
        cache_key = self.cache.make_key(
            prompt=prompt,
            model_id=slot.value,
            temperature=temperature or 0.0,
            max_tokens=max_tokens or 0,
        )

        # Check cache
        if cached := self.cache.get(cache_key):
            logger.debug("Cache hit for request")
            return cached.text, cached.tokens_in + cached.tokens_out, cached.model_id

        # Apply rate limiting
        if self.rate_limiter:
            self.rate_limiter.acquire()

        # Generate
        text, tokens, model_id = self.runner.generate(
            system_prompt=system_prompt,
            user_message=user_message,
            slot=slot,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )

        # Cache result
        from .backends.base import BackendResult
        result = BackendResult(
            text=text,
            tokens_in=tokens // 2,  # Approximate
            tokens_out=tokens - tokens // 2,
            model_id=model_id,
            latency_ms=0,
        )
        self.cache.put(cache_key, result)

        return text, tokens, model_id

    def get_stats(self) -> dict:
        """Get optimization statistics."""
        stats = {
            "cache": self.cache.get_stats(),
        }
        if self.rate_limiter:
            stats["rate_limiter"] = self.rate_limiter.get_stats()
        return stats

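# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the released
# optimizations.py). OptimizedRunner only relies on runner.generate(...)
# returning a (text, tokens, model_id) tuple and on slot.value, so stand-in
# objects are used here instead of the real ModelRunner and Slot. The point is
# that a repeated deterministic request is answered from the cache without a
# second runner call, while the rate limiter meters the calls that do go out.
# ---------------------------------------------------------------------------
def _example_optimized_runner() -> None:
    from types import SimpleNamespace

    class _CountingRunner:
        """Stand-in for ModelRunner; counts how often it is actually invoked."""

        def __init__(self) -> None:
            self.calls = 0

        def generate(self, system_prompt, user_message, slot, max_tokens, temperature, **kwargs):
            self.calls += 1
            return f"reply to {user_message!r}", 8, "stub-model"

    runner = _CountingRunner()
    optimized = OptimizedRunner(runner, cache_enabled=True, rate_limit=10.0)
    slot = SimpleNamespace(value="small")   # stand-in for a real Slot member

    first = optimized.generate(
        system_prompt="You are helpful.",
        user_message="Hello!",
        slot=slot,
        temperature=0.0,
    )
    second = optimized.generate(
        system_prompt="You are helpful.",
        user_message="Hello!",
        slot=slot,
        temperature=0.0,
    )

    assert first[0] == second[0]
    assert runner.calls == 1   # second call was served from the cache
    logger.debug("optimized runner stats: %s", optimized.get_stats())
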
__all__ = [
    # Cache
    "CacheEntry",
    "ResponseCache",
    "PersistentCache",
    # Batching
    "BatchRequest",
    "RequestBatcher",
    # Connection pool
    "ConnectionPool",
    # Rate limiting
    "RateLimiter",
    # Optimized wrapper
    "OptimizedRunner",
]