parishad 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. parishad/__init__.py +70 -0
  2. parishad/__main__.py +10 -0
  3. parishad/checker/__init__.py +25 -0
  4. parishad/checker/deterministic.py +644 -0
  5. parishad/checker/ensemble.py +496 -0
  6. parishad/checker/retrieval.py +546 -0
  7. parishad/cli/__init__.py +6 -0
  8. parishad/cli/code.py +3254 -0
  9. parishad/cli/main.py +1158 -0
  10. parishad/cli/prarambh.py +99 -0
  11. parishad/cli/sthapana.py +368 -0
  12. parishad/config/modes.py +139 -0
  13. parishad/config/pipeline.core.yaml +128 -0
  14. parishad/config/pipeline.extended.yaml +172 -0
  15. parishad/config/pipeline.fast.yaml +89 -0
  16. parishad/config/user_config.py +115 -0
  17. parishad/data/catalog.py +118 -0
  18. parishad/data/models.json +108 -0
  19. parishad/memory/__init__.py +79 -0
  20. parishad/models/__init__.py +181 -0
  21. parishad/models/backends/__init__.py +247 -0
  22. parishad/models/backends/base.py +211 -0
  23. parishad/models/backends/huggingface.py +318 -0
  24. parishad/models/backends/llama_cpp.py +239 -0
  25. parishad/models/backends/mlx_lm.py +141 -0
  26. parishad/models/backends/ollama.py +253 -0
  27. parishad/models/backends/openai_api.py +193 -0
  28. parishad/models/backends/transformers_hf.py +198 -0
  29. parishad/models/costs.py +385 -0
  30. parishad/models/downloader.py +1557 -0
  31. parishad/models/optimizations.py +871 -0
  32. parishad/models/profiles.py +610 -0
  33. parishad/models/reliability.py +876 -0
  34. parishad/models/runner.py +651 -0
  35. parishad/models/tokenization.py +287 -0
  36. parishad/orchestrator/__init__.py +24 -0
  37. parishad/orchestrator/config_loader.py +210 -0
  38. parishad/orchestrator/engine.py +1113 -0
  39. parishad/orchestrator/exceptions.py +14 -0
  40. parishad/roles/__init__.py +71 -0
  41. parishad/roles/base.py +712 -0
  42. parishad/roles/dandadhyaksha.py +163 -0
  43. parishad/roles/darbari.py +246 -0
  44. parishad/roles/majumdar.py +274 -0
  45. parishad/roles/pantapradhan.py +150 -0
  46. parishad/roles/prerak.py +357 -0
  47. parishad/roles/raja.py +345 -0
  48. parishad/roles/sacheev.py +203 -0
  49. parishad/roles/sainik.py +427 -0
  50. parishad/roles/sar_senapati.py +164 -0
  51. parishad/roles/vidushak.py +69 -0
  52. parishad/tools/__init__.py +7 -0
  53. parishad/tools/base.py +57 -0
  54. parishad/tools/fs.py +110 -0
  55. parishad/tools/perception.py +96 -0
  56. parishad/tools/retrieval.py +74 -0
  57. parishad/tools/shell.py +103 -0
  58. parishad/utils/__init__.py +7 -0
  59. parishad/utils/hardware.py +122 -0
  60. parishad/utils/logging.py +79 -0
  61. parishad/utils/scanner.py +164 -0
  62. parishad/utils/text.py +61 -0
  63. parishad/utils/tracing.py +133 -0
  64. parishad-0.1.0.dist-info/METADATA +256 -0
  65. parishad-0.1.0.dist-info/RECORD +68 -0
  66. parishad-0.1.0.dist-info/WHEEL +4 -0
  67. parishad-0.1.0.dist-info/entry_points.txt +2 -0
  68. parishad-0.1.0.dist-info/licenses/LICENSE +21 -0
parishad/models/optimizations.py
@@ -0,0 +1,871 @@
"""
Performance optimizations for Parishad model inference.

Provides:
- ResponseCache: LRU cache for model responses
- RequestBatcher: Batch multiple requests for efficiency
- ConnectionPool: Reuse backend connections
- RateLimiter: Token bucket rate limiting

These optimizations are optional and can be enabled via configuration.
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import sqlite3
import threading
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue, Empty
from typing import Any, Callable, Optional, TypeVar
from contextlib import contextmanager

from .backends import BackendConfig, BackendResult, ModelBackend


logger = logging.getLogger(__name__)

T = TypeVar("T")


# =============================================================================
# Response Cache
# =============================================================================


@dataclass
class CacheEntry:
    """Entry in the response cache."""
    key: str
    response: BackendResult
    created_at: float
    access_count: int = 0
    last_accessed: float = field(default_factory=time.time)

    @property
    def age_seconds(self) -> float:
        """Age of the entry in seconds."""
        return time.time() - self.created_at


class ResponseCache:
    """
    LRU cache for model responses.

    Caches responses based on prompt hash to avoid redundant model calls.
    Thread-safe for concurrent access.

    Usage:
        cache = ResponseCache(max_size=1000, ttl_seconds=3600)

        key = cache.make_key(prompt, model_id, temperature)
        if cached := cache.get(key):
            return cached

        result = model.generate(prompt)
        cache.put(key, result)
    """

    def __init__(
        self,
        max_size: int = 1000,
        ttl_seconds: float = 3600,
        enabled: bool = True,
    ):
        """
        Initialize cache.

        Args:
            max_size: Maximum number of entries
            ttl_seconds: Time-to-live for entries
            enabled: Whether caching is enabled
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.enabled = enabled

        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()

        # Statistics
        self._hits = 0
        self._misses = 0

    def make_key(
        self,
        prompt: str,
        model_id: str,
        temperature: float = 0.0,
        max_tokens: int = 0,
        **kwargs,
    ) -> str:
        """
        Create a cache key from request parameters.

        Note: Only caches deterministic requests (temperature=0).
        """
        # Only cache deterministic requests
        if temperature > 0.01:
            return ""  # Empty key means don't cache

        key_data = json.dumps({
            "prompt": prompt,
            "model_id": model_id,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        }, sort_keys=True)

        return hashlib.sha256(key_data.encode()).hexdigest()[:32]

    def get(self, key: str) -> Optional[BackendResult]:
        """
        Get cached response.

        Args:
            key: Cache key

        Returns:
            Cached BackendResult or None
        """
        if not self.enabled or not key:
            return None

        with self._lock:
            entry = self._cache.get(key)

            if entry is None:
                self._misses += 1
                return None

            # Check TTL
            if entry.age_seconds > self.ttl_seconds:
                del self._cache[key]
                self._misses += 1
                return None

            # Update access stats and move to end (LRU)
            entry.access_count += 1
            entry.last_accessed = time.time()
            self._cache.move_to_end(key)

            self._hits += 1
            return entry.response

    def put(self, key: str, response: BackendResult) -> None:
        """
        Store response in cache.

        Args:
            key: Cache key
            response: Response to cache
        """
        if not self.enabled or not key:
            return

        with self._lock:
            # Evict oldest if at capacity
            while len(self._cache) >= self.max_size:
                self._cache.popitem(last=False)

            self._cache[key] = CacheEntry(
                key=key,
                response=response,
                created_at=time.time(),
            )

    def invalidate(self, key: str) -> bool:
        """Remove specific key from cache."""
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> int:
        """Clear all cache entries. Returns count cleared."""
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            return count

    @property
    def size(self) -> int:
        """Current cache size."""
        return len(self._cache)

    @property
    def hit_rate(self) -> float:
        """Cache hit rate."""
        total = self._hits + self._misses
        return self._hits / total if total > 0 else 0.0

    def get_stats(self) -> dict:
        """Get cache statistics."""
        return {
            "size": self.size,
            "max_size": self.max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": self.hit_rate,
            "enabled": self.enabled,
        }
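

# --- Illustrative usage sketch (not part of optimizations.py) -----------------
# Shows the empty-key convention: make_key() returns "" for sampled requests
# (temperature > 0), which get() and put() treat as "do not cache", while
# deterministic requests are cached and counted in the hit statistics. The
# BackendResult field names below mirror the fields PersistentCache serializes
# and are assumed to be accepted by its constructor.

def _response_cache_demo() -> None:
    cache = ResponseCache(max_size=100, ttl_seconds=60)

    sampled_key = cache.make_key("2 + 2 = ?", model_id="small", temperature=0.7)
    assert sampled_key == ""                 # sampled requests are never cached

    det_key = cache.make_key("2 + 2 = ?", model_id="small", temperature=0.0)
    result = BackendResult(text="4", tokens_in=6, tokens_out=1,
                           model_id="small", latency_ms=12.0)
    cache.put(det_key, result)

    assert cache.get(det_key) is result      # LRU hit
    assert cache.get_stats()["hits"] == 1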


class PersistentCache(ResponseCache):
    """
    SQLite-backed persistent cache.

    Survives process restarts. Uses same interface as ResponseCache.
    """

    def __init__(
        self,
        path: str | Path,
        max_size: int = 10000,
        ttl_seconds: float = 86400,  # 24 hours
        enabled: bool = True,
    ):
        super().__init__(max_size=max_size, ttl_seconds=ttl_seconds, enabled=enabled)
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)

        self._init_db()

    def _init_db(self):
        """Initialize SQLite database."""
        with self._get_conn() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    response_json TEXT,
                    created_at REAL,
                    access_count INTEGER DEFAULT 0,
                    last_accessed REAL
                )
            """)
            conn.execute("CREATE INDEX IF NOT EXISTS idx_created ON cache(created_at)")

    @contextmanager
    def _get_conn(self):
        """Get database connection."""
        conn = sqlite3.connect(str(self.path))
        try:
            yield conn
            conn.commit()
        finally:
            conn.close()

    def get(self, key: str) -> Optional[BackendResult]:
        """Get from persistent cache."""
        if not self.enabled or not key:
            return None

        with self._get_conn() as conn:
            cursor = conn.execute(
                "SELECT response_json, created_at FROM cache WHERE key = ?",
                (key,)
            )
            row = cursor.fetchone()

            if not row:
                self._misses += 1
                return None

            response_json, created_at = row

            # Check TTL
            if time.time() - created_at > self.ttl_seconds:
                conn.execute("DELETE FROM cache WHERE key = ?", (key,))
                self._misses += 1
                return None

            # Update access stats
            conn.execute(
                "UPDATE cache SET access_count = access_count + 1, last_accessed = ? WHERE key = ?",
                (time.time(), key)
            )

            self._hits += 1
            data = json.loads(response_json)
            return BackendResult(**data)

    def put(self, key: str, response: BackendResult) -> None:
        """Store in persistent cache."""
        if not self.enabled or not key:
            return

        with self._get_conn() as conn:
            # Evict old entries if needed
            cursor = conn.execute("SELECT COUNT(*) FROM cache")
            count = cursor.fetchone()[0]

            if count >= self.max_size:
                # Delete oldest 10%
                delete_count = max(1, self.max_size // 10)
                conn.execute(
                    "DELETE FROM cache WHERE key IN (SELECT key FROM cache ORDER BY last_accessed LIMIT ?)",
                    (delete_count,)
                )

            response_json = json.dumps({
                "text": response.text,
                "tokens_in": response.tokens_in,
                "tokens_out": response.tokens_out,
                "model_id": response.model_id,
                "latency_ms": response.latency_ms,
            })

            conn.execute(
                """INSERT OR REPLACE INTO cache
                   (key, response_json, created_at, access_count, last_accessed)
                   VALUES (?, ?, ?, 0, ?)""",
                (key, response_json, time.time(), time.time())
            )

    def clear(self) -> int:
        """Clear all entries."""
        with self._get_conn() as conn:
            cursor = conn.execute("SELECT COUNT(*) FROM cache")
            count = cursor.fetchone()[0]
            conn.execute("DELETE FROM cache")
            return count
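

# --- Illustrative usage sketch (not part of optimizations.py) -----------------
# PersistentCache exposes the same make_key/get/put interface as ResponseCache
# but stores entries in SQLite, so they survive process restarts. The cache
# path is an arbitrary example, not a location the package itself uses.

def _persistent_cache_demo() -> None:
    cache = PersistentCache(path="/tmp/parishad_demo_cache.db", ttl_seconds=86400)

    key = cache.make_key("summarize README", model_id="medium", temperature=0.0)
    cached = cache.get(key)
    if cached is not None:
        print("served from disk:", cached.text)
    else:
        result = BackendResult(text="...", tokens_in=120, tokens_out=40,
                               model_id="medium", latency_ms=850.0)
        cache.put(key, result)               # persisted; visible after restart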
# =============================================================================
# Request Batcher
# =============================================================================


@dataclass
class BatchRequest:
    """A request in the batch queue."""
    prompt: str
    config: BackendConfig
    future: "asyncio.Future[BackendResult]"
    submitted_at: float = field(default_factory=time.time)


class RequestBatcher:
    """
    Batch multiple requests for efficient processing.

    Collects requests over a time window and processes them together.
    Useful for backends that support batch inference.

    Usage:
        batcher = RequestBatcher(backend, batch_size=8, wait_ms=50)
        result = await batcher.submit(prompt, config)
    """

    def __init__(
        self,
        backend: ModelBackend,
        batch_size: int = 8,
        wait_ms: float = 50.0,
        enabled: bool = True,
    ):
        """
        Initialize batcher.

        Args:
            backend: Backend to use for generation
            batch_size: Maximum batch size
            wait_ms: Maximum wait time before processing
            enabled: Whether batching is enabled
        """
        self.backend = backend
        self.batch_size = batch_size
        self.wait_ms = wait_ms
        self.enabled = enabled

        self._queue: list[BatchRequest] = []
        self._lock = threading.Lock()
        self._processing = False

        # Statistics
        self._batches_processed = 0
        self._requests_processed = 0

    async def submit(
        self,
        prompt: str,
        config: BackendConfig,
    ) -> BackendResult:
        """
        Submit a request for batched processing.

        Args:
            prompt: Input prompt
            config: Backend configuration

        Returns:
            BackendResult from generation
        """
        if not self.enabled:
            # Direct processing if batching disabled
            return self.backend.generate(
                prompt=prompt,
                max_tokens=config.max_tokens,
                temperature=config.temperature,
                top_p=config.top_p,
                stop=config.stop,
            )

        loop = asyncio.get_event_loop()
        future: asyncio.Future[BackendResult] = loop.create_future()

        request = BatchRequest(prompt=prompt, config=config, future=future)

        with self._lock:
            self._queue.append(request)
            queue_size = len(self._queue)

        # Decide outside the lock: _schedule_processing() re-acquires it.
        if queue_size >= self.batch_size:
            # Process immediately if batch is full
            self._schedule_processing()
        elif queue_size == 1:
            # Schedule delayed processing
            loop.call_later(self.wait_ms / 1000, self._schedule_processing)

        return await future

    def _schedule_processing(self) -> None:
        """Schedule batch processing."""
        with self._lock:
            if self._processing or not self._queue:
                return

            self._processing = True
            batch = self._queue[:self.batch_size]
            self._queue = self._queue[self.batch_size:]

        # Process the drained batch in the calling thread
        try:
            self._process_batch(batch)
        finally:
            with self._lock:
                self._processing = False

    def _process_batch(self, batch: list[BatchRequest]) -> None:
        """Process a batch of requests."""
        for request in batch:
            try:
                result = self.backend.generate(
                    prompt=request.prompt,
                    max_tokens=request.config.max_tokens,
                    temperature=request.config.temperature,
                    top_p=request.config.top_p,
                    stop=request.config.stop,
                )

                if not request.future.done():
                    request.future.get_loop().call_soon_threadsafe(
                        request.future.set_result, result
                    )

            except Exception as e:
                if not request.future.done():
                    request.future.get_loop().call_soon_threadsafe(
                        request.future.set_exception, e
                    )

        self._batches_processed += 1
        self._requests_processed += len(batch)

    def get_stats(self) -> dict:
        """Get batcher statistics."""
        return {
            "batches_processed": self._batches_processed,
            "requests_processed": self._requests_processed,
            "avg_batch_size": (
                self._requests_processed / self._batches_processed
                if self._batches_processed > 0 else 0
            ),
            "queue_size": len(self._queue),
            "enabled": self.enabled,
        }
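

# --- Illustrative usage sketch (not part of optimizations.py) -----------------
# Concurrent callers share one RequestBatcher; submissions that arrive inside
# the wait window (or fill the batch) are drained together. "backend" and
# "config" stand in for an already-constructed ModelBackend and BackendConfig;
# their construction is not shown here.

async def _batched_calls(backend: ModelBackend, config: BackendConfig) -> list[BackendResult]:
    batcher = RequestBatcher(backend, batch_size=8, wait_ms=50.0)
    prompts = [f"Summarize document {i}" for i in range(8)]
    # Eight submits land in the queue before the 50 ms window closes,
    # so _process_batch() sees them as a single batch of eight.
    return await asyncio.gather(*(batcher.submit(p, config) for p in prompts))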
# =============================================================================
# Connection Pool
# =============================================================================


class ConnectionPool:
    """
    Pool of reusable backend connections.

    Reduces overhead of creating new connections for each request.
    Thread-safe for concurrent access.

    Usage:
        pool = ConnectionPool(backend_factory, max_size=4)

        with pool.acquire() as backend:
            result = backend.generate(prompt)
    """

    def __init__(
        self,
        backend_factory: Callable[[], ModelBackend],
        max_size: int = 4,
        min_size: int = 1,
    ):
        """
        Initialize pool.

        Args:
            backend_factory: Factory function to create backends
            max_size: Maximum pool size
            min_size: Minimum backends to keep ready
        """
        self.backend_factory = backend_factory
        self.max_size = max_size
        self.min_size = min_size

        self._available: Queue[ModelBackend] = Queue()
        self._in_use: set[int] = set()
        self._lock = threading.Lock()
        self._total_created = 0

        # Pre-create minimum backends
        for _ in range(min_size):
            self._create_backend()

    def _create_backend(self) -> ModelBackend:
        """Create a new backend instance."""
        backend = self.backend_factory()
        self._available.put(backend)
        self._total_created += 1
        return backend

    @contextmanager
    def acquire(self, timeout: float = 30.0):
        """
        Acquire a backend from the pool.

        Args:
            timeout: Maximum time to wait for a backend

        Yields:
            ModelBackend instance
        """
        backend = None

        try:
            # Try to get from available
            try:
                backend = self._available.get(timeout=timeout)
            except Empty:
                # Create new if under limit
                with self._lock:
                    current_size = self._total_created
                    if current_size < self.max_size:
                        backend = self.backend_factory()
                        self._total_created += 1
                    else:
                        raise TimeoutError("No backends available in pool")

            with self._lock:
                self._in_use.add(id(backend))

            yield backend

        finally:
            if backend is not None:
                with self._lock:
                    self._in_use.discard(id(backend))
                self._available.put(backend)

    def get_stats(self) -> dict:
        """Get pool statistics."""
        return {
            "total_created": self._total_created,
            "available": self._available.qsize(),
            "in_use": len(self._in_use),
            "max_size": self.max_size,
        }
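

# --- Illustrative usage sketch (not part of optimizations.py) -----------------
# The pool takes a zero-argument factory and hands backends out through a
# context manager; the backend returns to the pool when the with-block exits.
# "make_backend" is a placeholder for whatever builds a concrete ModelBackend
# in the caller's code, and a real pool would be created once and reused rather
# than built inside the helper as it is here for self-containment.

def _pooled_generate(make_backend: Callable[[], ModelBackend], prompt: str) -> BackendResult:
    pool = ConnectionPool(backend_factory=make_backend, max_size=4, min_size=1)
    with pool.acquire(timeout=30.0) as backend:
        return backend.generate(prompt=prompt, max_tokens=256, temperature=0.0)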


# =============================================================================
# Rate Limiter
# =============================================================================


class RateLimiter:
    """
    Token bucket rate limiter.

    Controls request rate to avoid overwhelming backends or hitting API limits.

    Usage:
        limiter = RateLimiter(tokens_per_second=10, burst_size=20)

        limiter.acquire()                  # blocks until a token is available
        # or: await limiter.acquire_async()
        result = model.generate(prompt)
    """

    def __init__(
        self,
        tokens_per_second: float = 10.0,
        burst_size: int = 20,
    ):
        """
        Initialize rate limiter.

        Args:
            tokens_per_second: Token refill rate
            burst_size: Maximum tokens (burst capacity)
        """
        self.tokens_per_second = tokens_per_second
        self.burst_size = burst_size

        self._tokens = float(burst_size)
        self._last_refill = time.time()
        self._lock = threading.Lock()

        # Statistics
        self._requests = 0
        self._waits = 0
        self._total_wait_time = 0.0

    def _refill(self) -> None:
        """Refill tokens based on elapsed time."""
        now = time.time()
        elapsed = now - self._last_refill
        self._tokens = min(
            self.burst_size,
            self._tokens + elapsed * self.tokens_per_second
        )
        self._last_refill = now

    def acquire(self, tokens: int = 1) -> float:
        """
        Acquire tokens, blocking if necessary.

        Args:
            tokens: Number of tokens to acquire

        Returns:
            Wait time in seconds
        """
        wait_time = 0.0

        with self._lock:
            self._refill()

            while self._tokens < tokens:
                # Calculate wait time
                needed = tokens - self._tokens
                wait = needed / self.tokens_per_second
                wait_time += wait

                self._lock.release()
                time.sleep(wait)
                self._lock.acquire()

                self._refill()

            self._tokens -= tokens
            self._requests += 1

            if wait_time > 0:
                self._waits += 1
                self._total_wait_time += wait_time

        return wait_time

    async def acquire_async(self, tokens: int = 1) -> float:
        """Async version of acquire."""
        wait_time = 0.0

        with self._lock:
            self._refill()

            if self._tokens < tokens:
                needed = tokens - self._tokens
                wait_time = needed / self.tokens_per_second

        if wait_time > 0:
            await asyncio.sleep(wait_time)
            self._waits += 1
            self._total_wait_time += wait_time

        with self._lock:
            self._refill()
            self._tokens -= tokens
            self._requests += 1

        return wait_time

    def get_stats(self) -> dict:
        """Get rate limiter statistics."""
        return {
            "requests": self._requests,
            "waits": self._waits,
            "total_wait_time": self._total_wait_time,
            "avg_wait_time": (
                self._total_wait_time / self._waits if self._waits > 0 else 0
            ),
            "current_tokens": self._tokens,
            "tokens_per_second": self.tokens_per_second,
        }
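

# --- Illustrative sketch (not part of optimizations.py) ------------------------
# Token-bucket arithmetic: with tokens_per_second=10 and burst_size=20 the
# bucket starts full, so the first 20 acquisitions return immediately; each
# acquisition beyond the burst waits roughly 1 / 10 = 0.1 s for a refill.

def _rate_limiter_demo() -> None:
    limiter = RateLimiter(tokens_per_second=10.0, burst_size=20)
    for i in range(25):
        waited = limiter.acquire()           # blocks only once the burst is spent
        if waited > 0:
            logger.debug("request %d waited %.3f s", i, waited)
    stats = limiter.get_stats()              # roughly: waits ~ 5, avg_wait_time ~ 0.1
    print(stats["requests"], stats["waits"])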
# =============================================================================
# Optimized Runner Wrapper
# =============================================================================


class OptimizedRunner:
    """
    Wrapper that adds caching, batching, and rate limiting to a ModelRunner.

    Usage:
        from parishad.models.runner import ModelRunner

        runner = ModelRunner(stub=True)
        optimized = OptimizedRunner(
            runner,
            cache_enabled=True,
            rate_limit=10.0,
        )

        text, tokens, model = optimized.generate(
            system_prompt="You are helpful.",
            user_message="Hello!",
            slot=Slot.SMALL,
        )
    """

    def __init__(
        self,
        runner: "ModelRunner",  # type: ignore
        cache_enabled: bool = False,
        cache_max_size: int = 1000,
        cache_ttl: float = 3600,
        rate_limit: Optional[float] = None,
        rate_burst: int = 20,
    ):
        """
        Initialize optimized runner.

        Args:
            runner: Base ModelRunner to wrap
            cache_enabled: Enable response caching
            cache_max_size: Maximum cache entries
            cache_ttl: Cache TTL in seconds
            rate_limit: Rate limit (requests per second)
            rate_burst: Rate limit burst size
        """
        self.runner = runner

        self.cache = ResponseCache(
            max_size=cache_max_size,
            ttl_seconds=cache_ttl,
            enabled=cache_enabled,
        )

        self.rate_limiter: Optional[RateLimiter] = None
        if rate_limit is not None:
            self.rate_limiter = RateLimiter(
                tokens_per_second=rate_limit,
                burst_size=rate_burst,
            )

    def generate(
        self,
        system_prompt: str,
        user_message: str,
        slot: "Slot",  # type: ignore
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs,
    ) -> tuple[str, int, str]:
        """
        Generate with optimizations applied.

        Args:
            system_prompt: System prompt
            user_message: User message
            slot: Model slot
            max_tokens: Maximum tokens
            temperature: Sampling temperature
            **kwargs: Additional arguments

        Returns:
            Tuple of (text, tokens, model_id)
        """
        # Build cache key
        prompt = f"{system_prompt}\n{user_message}"
        cache_key = self.cache.make_key(
            prompt=prompt,
            model_id=slot.value,
            temperature=temperature or 0.0,
            max_tokens=max_tokens or 0,
        )

        # Check cache
        if cached := self.cache.get(cache_key):
            logger.debug("Cache hit for request")
            return cached.text, cached.tokens_in + cached.tokens_out, cached.model_id

        # Apply rate limiting
        if self.rate_limiter:
            self.rate_limiter.acquire()

        # Generate
        text, tokens, model_id = self.runner.generate(
            system_prompt=system_prompt,
            user_message=user_message,
            slot=slot,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )

        # Cache result
        from .backends.base import BackendResult
        result = BackendResult(
            text=text,
            tokens_in=tokens // 2,  # Approximate
            tokens_out=tokens - tokens // 2,
            model_id=model_id,
            latency_ms=0,
        )
        self.cache.put(cache_key, result)

        return text, tokens, model_id

    def get_stats(self) -> dict:
        """Get optimization statistics."""
        stats = {
            "cache": self.cache.get_stats(),
        }
        if self.rate_limiter:
            stats["rate_limiter"] = self.rate_limiter.get_stats()
        return stats
__all__ = [
    # Cache
    "CacheEntry",
    "ResponseCache",
    "PersistentCache",
    # Batching
    "BatchRequest",
    "RequestBatcher",
    # Connection pool
    "ConnectionPool",
    # Rate limiting
    "RateLimiter",
    # Optimized wrapper
    "OptimizedRunner",
]
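
Taken together, OptimizedRunner is the intended entry point for these helpers. A minimal wiring sketch, following the class docstring above: ModelRunner(stub=True) comes from that docstring, while the import location of Slot is an assumption and may differ in the actual package.

    from parishad.models.optimizations import OptimizedRunner
    from parishad.models.runner import ModelRunner, Slot  # Slot location assumed

    optimized = OptimizedRunner(
        ModelRunner(stub=True),
        cache_enabled=True,      # dedupe identical deterministic prompts
        rate_limit=10.0,         # roughly 10 requests/second, burst of 20
    )

    text, tokens, model_id = optimized.generate(
        system_prompt="You are helpful.",
        user_message="Hello!",
        slot=Slot.SMALL,
    )
    print(optimized.get_stats())  # {"cache": {...}, "rate_limiter": {...}}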