codeshield-ai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,605 @@
"""
Token Efficiency Module - Optimize LLM token usage

Implements:
- Response caching (avoid duplicate API calls)
- Prompt compression (reduce input tokens)
- Smart truncation (limit context size)
- Token budgeting (track and limit usage)
- Semantic similarity caching (fuzzy matching)
- Model tiering (cheap models for simple tasks)
- Local-first processing (skip LLM when possible)
"""

import hashlib
import json
import re
import sqlite3
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List, Tuple
from dataclasses import dataclass
from threading import Lock
from difflib import SequenceMatcher


CACHE_DB_PATH = Path.home() / ".codeshield" / "token_cache.sqlite"
_cache_lock = Lock()


@dataclass
class CachedResponse:
    """A cached LLM response"""
    content: str
    provider: str
    model: str
    tokens_saved: int
    cached_at: str
    hits: int = 0


class TokenOptimizer:
    """
    Optimizes token usage through caching and compression.

    Features:
    - LRU cache for identical prompts
    - Prompt compression for common patterns
    - Smart context truncation
    - Token budget enforcement
    """

    _instance = None

    # Token budget settings
    DEFAULT_BUDGET = 100000  # 100K tokens per session

    # Cache settings
    CACHE_TTL_HOURS = 24
    MAX_CACHE_ENTRIES = 1000

61
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        self._session_tokens = 0
        self._token_budget = self.DEFAULT_BUDGET
        self._cache_hits = 0
        self._cache_misses = 0
        self._tokens_saved = 0
        self._local_saves = 0
        self._compression_saves = 0
        self._ensure_db()
        self._initialized = True

79
    def _ensure_db(self):
        """Initialize cache database"""
        CACHE_DB_PATH.parent.mkdir(parents=True, exist_ok=True)

        conn = sqlite3.connect(str(CACHE_DB_PATH))
        cursor = conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS response_cache (
                prompt_hash TEXT PRIMARY KEY,
                content TEXT NOT NULL,
                provider TEXT NOT NULL,
                model TEXT NOT NULL,
                tokens_used INTEGER NOT NULL,
                cached_at TEXT NOT NULL,
                hits INTEGER DEFAULT 0,
                last_hit TEXT
            )
        """)

        conn.commit()
        conn.close()

    def _hash_prompt(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate cache key from prompt"""
        combined = f"{system_prompt or ''}||{prompt}"
        return hashlib.sha256(combined.encode()).hexdigest()[:32]

107
    def get_cached(self, prompt: str, system_prompt: Optional[str] = None) -> Optional[CachedResponse]:
        """Check cache for existing response"""
        prompt_hash = self._hash_prompt(prompt, system_prompt)

        with _cache_lock:
            conn = sqlite3.connect(str(CACHE_DB_PATH))
            cursor = conn.cursor()

            # Check for a cache entry that is still inside the TTL window
            # (ISO timestamps compare correctly as strings)
            cutoff = (datetime.now() - timedelta(hours=self.CACHE_TTL_HOURS)).isoformat()
            cursor.execute("""
                SELECT content, provider, model, tokens_used, cached_at, hits
                FROM response_cache
                WHERE prompt_hash = ? AND cached_at >= ?
            """, (prompt_hash, cutoff))

            row = cursor.fetchone()

            if row:
                # Update hit count
                cursor.execute("""
                    UPDATE response_cache
                    SET hits = hits + 1, last_hit = ?
                    WHERE prompt_hash = ?
                """, (datetime.now().isoformat(), prompt_hash))
                conn.commit()

                self._cache_hits += 1
                self._tokens_saved += row[3]  # tokens_used

                conn.close()
                return CachedResponse(
                    content=row[0],
                    provider=row[1],
                    model=row[2],
                    tokens_saved=row[3],
                    cached_at=row[4],
                    hits=row[5] + 1
                )

            conn.close()
            self._cache_misses += 1
            return None

150
    def cache_response(self, prompt: str, response: Any,
                       system_prompt: Optional[str] = None):
        """Cache an LLM response"""
        prompt_hash = self._hash_prompt(prompt, system_prompt)

        with _cache_lock:
            conn = sqlite3.connect(str(CACHE_DB_PATH))
            cursor = conn.cursor()

            cursor.execute("""
                INSERT OR REPLACE INTO response_cache
                (prompt_hash, content, provider, model, tokens_used, cached_at, hits)
                VALUES (?, ?, ?, ?, ?, ?, 0)
            """, (
                prompt_hash,
                response.content,
                response.provider,
                response.model,
                response.tokens_used,
                datetime.now().isoformat()
            ))

            # Cleanup old entries if over limit
            cursor.execute("""
                DELETE FROM response_cache
                WHERE prompt_hash NOT IN (
                    SELECT prompt_hash FROM response_cache
                    ORDER BY last_hit DESC, cached_at DESC
                    LIMIT ?
                )
            """, (self.MAX_CACHE_ENTRIES,))

            conn.commit()
            conn.close()

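    # Read-through usage sketch (illustrative only; `call_llm` is a
    # hypothetical provider client that returns an object with .content,
    # .provider, .model and .tokens_used, matching cache_response above):
    #
    #     opt = get_token_optimizer()
    #     cached = opt.get_cached(prompt, system_prompt)
    #     if cached is not None:
    #         content = cached.content               # zero new tokens spent
    #     else:
    #         response = call_llm(prompt)            # hypothetical API call
    #         opt.cache_response(prompt, response, system_prompt)
    #         content = response.content
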
185
    def compress_prompt(self, prompt: str) -> str:
        """
        Compress prompt to reduce tokens.

        Techniques:
        - Remove excessive whitespace
        - Shorten common phrases
        - Remove redundant instructions
        """
        # Collapse runs of blank lines and repeated spaces
        prompt = re.sub(r'\n{3,}', '\n\n', prompt)
        prompt = re.sub(r' {2,}', ' ', prompt)
        prompt = prompt.strip()

        # Common filler phrases; longest first so that e.g. "Please note
        # that " is not broken by the shorter "Please " rule before it can
        # match
        compressions = [
            ("Please note that ", ""),
            ("Could you please ", ""),
            ("I would like you to ", ""),
            ("It's important that ", ""),
            ("As a reminder, ", ""),
            ("Don't forget to ", ""),
            ("Make sure to ", ""),
            ("Be sure to ", ""),
            ("Remember to ", ""),
            ("Note that ", ""),
            ("Please ", ""),
        ]

        for old, new in compressions:
            prompt = prompt.replace(old, new)

        return prompt

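    # Example (values illustrative): compress_prompt(
    #     "Please review the diff.\n\n\n\nNote that tests must pass.")
    # returns "review the diff.\n\ntests must pass." -- the filler phrases
    # are dropped and the blank-line run collapses to one.
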
220
    def truncate_code(self, code: str, max_lines: int = 100) -> str:
        """
        Smart truncation of code to reduce tokens.

        Preserves:
        - Function signatures
        - Class definitions
        - Import statements
        - First/last lines of functions
        """
        lines = code.split('\n')

        if len(lines) <= max_lines:
            return code

        # Keep important lines
        important_patterns = [
            'import ', 'from ', 'class ', 'def ', 'async def ',
            'return ', 'raise ', '@', 'if __name__'
        ]

        result = []
        skipped = 0

        for i, line in enumerate(lines):
            stripped = line.strip()

            # Always keep important lines
            is_important = any(stripped.startswith(p) for p in important_patterns)

            # Keep first and last 20 lines always
            is_boundary = i < 20 or i >= len(lines) - 20

            if is_important or is_boundary:
                if skipped > 0:
                    result.append(f" # ... ({skipped} lines omitted)")
                    skipped = 0
                result.append(line)
            else:
                skipped += 1

        if skipped > 0:
            result.append(f" # ... ({skipped} lines omitted)")

        return '\n'.join(result)

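    # Truncation sketch (doctest-style, illustrative): 300 plain assignment
    # lines collapse to the 20-line head, one omission marker, and the
    # 20-line tail.
    #
    #     >>> opt = get_token_optimizer()
    #     >>> code = "\n".join(f"x{i} = {i}" for i in range(300))
    #     >>> len(opt.truncate_code(code, max_lines=100).splitlines())
    #     41
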
266
    def estimate_tokens(self, text: str) -> int:
        """
        Estimate token count (rough approximation).

        Rule of thumb: ~4 chars per token for English prose. Code tends to
        run closer to ~3 chars per token because of symbols, so this
        slightly underestimates for code.
        """
        # Rough estimate: 1 token ≈ 4 characters
        return len(text) // 4

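    # Quick check: a 2,000-character prompt estimates to 2000 // 4 = 500
    # tokens; real tokenizers will differ, especially on symbol-heavy code.
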
276
    def check_budget(self, estimated_tokens: int) -> bool:
        """Check if request is within budget"""
        return (self._session_tokens + estimated_tokens) <= self._token_budget

    def record_usage(self, tokens: int):
        """Record token usage"""
        self._session_tokens += tokens

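    # Budget-gating sketch (illustrative; `call_llm` is hypothetical):
    #
    #     opt = get_token_optimizer()
    #     est = opt.estimate_tokens(prompt)
    #     if not opt.check_budget(est):
    #         raise RuntimeError("session token budget exhausted")
    #     response = call_llm(prompt)
    #     opt.record_usage(response.tokens_used)
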
284
    def get_stats(self) -> dict:
        """Get optimization statistics"""
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (self._cache_hits / total_requests * 100) if total_requests > 0 else 0

        return {
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "cache_hit_rate": round(hit_rate, 2),
            "tokens_saved_by_cache": self._tokens_saved,
            # _local_saves and _compression_saves count avoided/shrunk calls
            "tokens_saved_by_local": getattr(self, '_local_saves', 0),
            "tokens_saved_by_compression": getattr(self, '_compression_saves', 0),
            "session_tokens_used": self._session_tokens,
            "token_budget": self._token_budget,
            "budget_remaining": self._token_budget - self._session_tokens,
            "budget_used_percent": round(self._session_tokens / self._token_budget * 100, 2),
            "llm_calls_avoided": self._cache_hits + getattr(self, '_local_saves', 0),
        }

    def set_budget(self, tokens: int):
        """Set token budget for session"""
        self._token_budget = tokens

    def reset_session(self):
        """Reset session token counter"""
        self._session_tokens = 0
        self._cache_hits = 0
        self._cache_misses = 0
        self._tokens_saved = 0
        self._local_saves = 0
        self._compression_saves = 0


# Singleton accessor
def get_token_optimizer() -> TokenOptimizer:
    """Get the global token optimizer instance"""
    return TokenOptimizer()


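# Stats example (numbers illustrative): after 3 cache hits in 10 requests,
# get_token_optimizer().get_stats() reports cache_hit_rate == 30.0, and
# llm_calls_avoided counts those hits plus any zero-token local fixes.

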
323
# =============================================================================
# LOCAL-FIRST PROCESSING - Skip LLM entirely when possible
# =============================================================================

class LocalProcessor:
    """
    Handle simple tasks locally without LLM calls.

    HUGE token savings - 100% reduction for supported tasks.
    """

    # Common import fixes (no LLM needed)
    IMPORT_FIXES = {
        'json': 'import json',
        'os': 'import os',
        'sys': 'import sys',
        're': 'import re',
        'math': 'import math',
        'random': 'import random',
        'datetime': 'from datetime import datetime',
        'time': 'import time',
        'pathlib': 'from pathlib import Path',
        'typing': 'from typing import Optional, List, Dict, Any',
        'dataclasses': 'from dataclasses import dataclass',
        'collections': 'from collections import defaultdict, Counter',
        'itertools': 'import itertools',
        'functools': 'import functools',
        'requests': 'import requests',
        'httpx': 'import httpx',
        'asyncio': 'import asyncio',
        'logging': 'import logging',
        'subprocess': 'import subprocess',
        'tempfile': 'import tempfile',
        'shutil': 'import shutil',
        'glob': 'import glob',
        'csv': 'import csv',
        'sqlite3': 'import sqlite3',
        'hashlib': 'import hashlib',
        'base64': 'import base64',
        'copy': 'import copy',
        'io': 'import io',
        'threading': 'import threading',
        'uuid': 'import uuid',
        'enum': 'from enum import Enum',
        'abc': 'from abc import ABC, abstractmethod',
        'contextlib': 'from contextlib import contextmanager',
        'pydantic': 'from pydantic import BaseModel',
        'fastapi': 'from fastapi import FastAPI, HTTPException',
        'flask': 'from flask import Flask, request, jsonify',
        'numpy': 'import numpy as np',
        'pandas': 'import pandas as pd',
        'pytest': 'import pytest',
    }

    @classmethod
    def can_fix_locally(cls, code: str, issues: List[str]) -> bool:
        """Check if issues can be fixed without LLM"""
        for issue in issues:
            issue_lower = issue.lower()
            # Only handle simple missing imports locally
            if 'missing import' not in issue_lower:
                return False  # other issue types need the LLM
            module = cls._extract_module(issue)
            if not (module and module in cls.IMPORT_FIXES):
                return False
        return len(issues) > 0

392
    @classmethod
    def fix_locally(cls, code: str, issues: List[str]) -> Optional[str]:
        """
        Fix code locally without LLM.

        Returns fixed code, or None if it can't be fixed locally.
        """
        if not cls.can_fix_locally(code, issues):
            return None

        imports_to_add = []
        for issue in issues:
            if 'missing import' in issue.lower():
                module = cls._extract_module(issue)
                if module and module in cls.IMPORT_FIXES:
                    imports_to_add.append(cls.IMPORT_FIXES[module])

        if not imports_to_add:
            return None

        # Add imports at the top
        lines = code.split('\n')
        insert_pos = 0

        # Skip the module docstring and existing imports
        in_docstring = False
        for i, line in enumerate(lines):
            stripped = line.strip()
            if stripped.startswith('"""') or stripped.startswith("'''"):
                quote = '"""' if stripped.startswith('"""') else "'''"
                # A one-line docstring opens and closes here and leaves the
                # state unchanged; otherwise the delimiter toggles it
                if in_docstring or stripped.count(quote) < 2:
                    in_docstring = not in_docstring
                insert_pos = i + 1
            elif not in_docstring and (stripped.startswith('import ') or stripped.startswith('from ')):
                insert_pos = i + 1
            elif not in_docstring and stripped and not stripped.startswith('#'):
                break

        # Deduplicate imports
        existing_imports = set()
        for line in lines:
            if line.strip().startswith('import ') or line.strip().startswith('from '):
                existing_imports.add(line.strip())

        new_imports = [imp for imp in imports_to_add if imp not in existing_imports]

        if not new_imports:
            return code  # Nothing to add

        for imp in reversed(new_imports):
            lines.insert(insert_pos, imp)

        return '\n'.join(lines)

    @classmethod
    def _extract_module(cls, issue: str) -> Optional[str]:
        """Extract module name from issue message"""
        # "Missing import: json" -> "json"
        # "Missing import: json (pip install json)" -> "json"
        match = re.search(r'missing import[:\s]+(\w+)', issue.lower())
        if match:
            return match.group(1)
        return None


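# Local-fix example (doctest-style; issue strings follow the
# "Missing import: <module>" shape that _extract_module parses):
#
#     >>> LocalProcessor.fix_locally("data = json.loads(raw)",
#     ...                            ["Missing import: json"])
#     'import json\ndata = json.loads(raw)'

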
458
# =============================================================================
# SEMANTIC CACHING - Match similar prompts
# =============================================================================

def normalize_code(code: str) -> str:
    """Normalize code for semantic comparison"""
    # Remove comments
    code = re.sub(r'#.*$', '', code, flags=re.MULTILINE)
    # Normalize whitespace
    code = re.sub(r'\s+', ' ', code)
    # Remove string contents (keep structure)
    code = re.sub(r'"[^"]*"', '""', code)
    code = re.sub(r"'[^']*'", "''", code)
    return code.strip().lower()


def code_similarity(code1: str, code2: str) -> float:
    """Calculate similarity between two code snippets"""
    norm1 = normalize_code(code1)
    norm2 = normalize_code(code2)
    return SequenceMatcher(None, norm1, norm2).ratio()


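# Fuzzy-match sketch (not wired into TokenOptimizer above; a caller could
# layer it on top of the exact-hash cache): treat a stored prompt as a hit
# when its code payload is near-identical, e.g.
#
#     if code_similarity(new_code, cached_code) >= 0.95:  # threshold is a guess
#         ...  # reuse the cached response instead of calling the LLM

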
481
# =============================================================================
# MODEL TIERING - Use cheaper models for simple tasks
# =============================================================================

class ModelTier:
    """Select optimal model based on task complexity"""

    # Task complexity thresholds
    SIMPLE_MAX_LINES = 20
    SIMPLE_MAX_ISSUES = 3

    # Model recommendations per provider
    MODELS = {
        "cometapi": {
            "simple": "deepseek-chat",   # Free, fast
            "complex": "gpt-4o-mini",    # Smarter but costs more
        },
        "novita": {
            "simple": "meta-llama/llama-3-8b-instruct",  # Fast, cheap
            "complex": "deepseek/deepseek-r1",           # Better reasoning
        },
        "aiml": {
            "simple": "gpt-4o-mini",
            "complex": "gpt-4o",
        }
    }

    @classmethod
    def select_model(cls, code: str, issues: List[str], provider: str) -> str:
        """Select optimal model based on task complexity"""
        complexity = cls._assess_complexity(code, issues)

        default_model = "deepseek-chat"
        models = cls.MODELS.get(provider, {"simple": default_model, "complex": default_model})

        if complexity == "simple":
            return models.get("simple", default_model)
        return models.get("complex", default_model)

    @classmethod
    def _assess_complexity(cls, code: str, issues: List[str]) -> str:
        """Assess task complexity"""
        lines = code.count('\n') + 1

        # Simple: short code, few issues, only import/syntax issues
        if lines <= cls.SIMPLE_MAX_LINES and len(issues) <= cls.SIMPLE_MAX_ISSUES:
            simple_issues = all(
                'import' in i.lower() or 'syntax' in i.lower() or 'indent' in i.lower()
                for i in issues
            )
            if simple_issues:
                return "simple"

        return "complex"


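# Tiering example: a 10-line snippet with issues ["Missing import: json"]
# assesses as "simple", so ModelTier.select_model(code, issues, "novita")
# returns "meta-llama/llama-3-8b-instruct"; code over 20 lines, or any
# non-import/syntax/indent issue, falls through to the provider's
# "complex" model.

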
537
# =============================================================================
# OPTIMIZED PROMPTS - Maximum compression
# =============================================================================

OPTIMIZED_PROMPTS = {
    # Ultra-short fix prompt (~60% smaller than verbose)
    "fix_code": "Fix:\n{issues}\n\nCode:\n```\n{code}\n```\nReturn fixed code only.",

    # Minimal context briefing
    "context_briefing": "Summarize work state (2 sentences):\nFiles: {files}\nLast: {last_edited}\nTime: {time_ago}",

    # Style suggestion
    "style_suggest": "Suggest {convention} names for:\n{names}",

    # Ultra minimal for simple fixes (when we must use LLM)
    "simple_fix": "Add imports and fix:\n```\n{code}\n```",
}


556
def optimize_fix_prompt(code: str, issues: List[str]) -> str:
    """Optimized prompt for code fixing - uses ~60% fewer tokens"""
    optimizer = get_token_optimizer()

    # Try local fix first (0 tokens!)
    local_fix = LocalProcessor.fix_locally(code, issues)
    if local_fix is not None:
        if not hasattr(optimizer, '_local_saves'):
            optimizer._local_saves = 0
        optimizer._local_saves += 1
        return "__LOCAL_FIX__"  # Signal to skip LLM

    # Use ultra-minimal prompt for simple issues
    if all('import' in i.lower() for i in issues) and code.count('\n') < 30:
        code = optimizer.truncate_code(code, max_lines=30)
        return OPTIMIZED_PROMPTS["simple_fix"].format(code=code)

    # Standard optimized prompt
    issues_text = "; ".join(issues)  # Semicolons instead of bullets
    code = optimizer.truncate_code(code, max_lines=50)  # Reduced from 80

    return OPTIMIZED_PROMPTS["fix_code"].format(
        issues=issues_text,
        code=code
    )


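# Caller-side handling of the sentinel (sketch; `call_llm` is hypothetical):
#
#     prompt = optimize_fix_prompt(code, issues)
#     if prompt == "__LOCAL_FIX__":
#         fixed = LocalProcessor.fix_locally(code, issues)  # no API call
#     else:
#         fixed = call_llm(prompt)

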
583
def optimize_context_prompt(context: dict) -> str:
    """Optimized prompt for context briefing"""
    return OPTIMIZED_PROMPTS["context_briefing"].format(
        files=", ".join(context.get("files", [])[:3]),  # Limit to 3 files (was 5)
        last_edited=context.get("last_edited", "?"),
        time_ago=context.get("time_ago", "?")
    )


# =============================================================================
# RESPONSE OPTIMIZATION
# =============================================================================

596
def get_optimal_max_tokens(task: str, code_length: int) -> int:
    """Calculate minimum max_tokens needed for task"""
    if task == "fix":
        # Output is usually similar size to input
        return min(500, max(100, code_length // 3))
    elif task == "briefing":
        return 50  # Just 2 sentences
    elif task == "style":
        return 100  # Short suggestions
    return 500  # Default
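
# Worked example: for task="fix" on a 1,200-character snippet,
# min(500, max(100, 1200 // 3)) == 400 max_tokens; very short inputs
# floor at 100 and long ones cap at 500.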