cite-agent 1.3.9-py3-none-any.whl → 1.4.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cite_agent/__init__.py +13 -13
- cite_agent/__version__.py +1 -1
- cite_agent/action_first_mode.py +150 -0
- cite_agent/adaptive_providers.py +413 -0
- cite_agent/archive_api_client.py +186 -0
- cite_agent/auth.py +0 -1
- cite_agent/auto_expander.py +70 -0
- cite_agent/cache.py +379 -0
- cite_agent/circuit_breaker.py +370 -0
- cite_agent/citation_network.py +377 -0
- cite_agent/cli.py +8 -16
- cite_agent/cli_conversational.py +113 -3
- cite_agent/confidence_calibration.py +381 -0
- cite_agent/deduplication.py +325 -0
- cite_agent/enhanced_ai_agent.py +689 -371
- cite_agent/error_handler.py +228 -0
- cite_agent/execution_safety.py +329 -0
- cite_agent/full_paper_reader.py +239 -0
- cite_agent/observability.py +398 -0
- cite_agent/offline_mode.py +348 -0
- cite_agent/paper_comparator.py +368 -0
- cite_agent/paper_summarizer.py +420 -0
- cite_agent/pdf_extractor.py +350 -0
- cite_agent/proactive_boundaries.py +266 -0
- cite_agent/quality_gate.py +442 -0
- cite_agent/request_queue.py +390 -0
- cite_agent/response_enhancer.py +257 -0
- cite_agent/response_formatter.py +458 -0
- cite_agent/response_pipeline.py +295 -0
- cite_agent/response_style_enhancer.py +259 -0
- cite_agent/self_healing.py +418 -0
- cite_agent/similarity_finder.py +524 -0
- cite_agent/streaming_ui.py +13 -9
- cite_agent/thinking_blocks.py +308 -0
- cite_agent/tool_orchestrator.py +416 -0
- cite_agent/trend_analyzer.py +540 -0
- cite_agent/unpaywall_client.py +226 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/METADATA +15 -1
- cite_agent-1.4.3.dist-info/RECORD +62 -0
- cite_agent-1.3.9.dist-info/RECORD +0 -32
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/WHEEL +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/entry_points.txt +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/licenses/LICENSE +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/top_level.txt +0 -0
cite_agent/archive_api_client.py
ADDED

@@ -0,0 +1,186 @@
"""
Archive API Client - Wrapper for Semantic Scholar and OpenAlex APIs

Provides a unified interface for:
- Getting paper details
- Getting paper citations
- Getting paper references
- Searching papers
"""

import logging
import requests
from typing import Dict, Any, List, Optional
from urllib.parse import quote

logger = logging.getLogger(__name__)


class ArchiveAPIClient:
    """Client for accessing academic paper APIs (Semantic Scholar, OpenAlex)"""

    def __init__(self, timeout: int = 10):
        """
        Initialize API client

        Args:
            timeout: Request timeout in seconds
        """
        self.timeout = timeout
        self.s2_base_url = "https://api.semanticscholar.org/graph/v1"
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Cite-Agent/1.0 (Academic Research Tool)'
        })

    def get_paper(self, paper_id: str, fields: Optional[List[str]] = None) -> Optional[Dict[str, Any]]:
        """
        Get paper details from Semantic Scholar

        Args:
            paper_id: Paper ID (DOI, arXiv ID, or Semantic Scholar ID)
            fields: Fields to retrieve (default: basic metadata)

        Returns:
            Paper data or None if not found
        """
        if not fields:
            fields = ['paperId', 'title', 'authors', 'year', 'citationCount', 'abstract']

        fields_param = ','.join(fields)
        url = f"{self.s2_base_url}/paper/{quote(paper_id)}?fields={fields_param}"

        try:
            response = self.session.get(url, timeout=self.timeout)

            if response.status_code == 200:
                return response.json()
            elif response.status_code == 404:
                logger.warning(f"Paper not found: {paper_id}")
                return None
            else:
                logger.error(f"S2 API error {response.status_code}: {response.text}")
                return None

        except requests.RequestException as e:
            logger.error(f"Request failed for paper {paper_id}: {e}")
            return None

    def get_paper_citations(self, paper_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """
        Get papers that cite this paper

        Args:
            paper_id: Paper ID
            limit: Maximum citations to return

        Returns:
            List of citing papers
        """
        url = f"{self.s2_base_url}/paper/{quote(paper_id)}/citations"
        params = {
            'limit': min(limit, 100),  # S2 API max is 100
            'fields': 'paperId,title,authors,year,citationCount'
        }

        try:
            response = self.session.get(url, params=params, timeout=self.timeout)

            if response.status_code == 200:
                data = response.json()
                citations = data.get('data', [])
                # Extract citing paper from each citation
                return [c.get('citingPaper', {}) for c in citations if 'citingPaper' in c]
            else:
                logger.warning(f"Citations request failed: {response.status_code}")
                return []

        except requests.RequestException as e:
            logger.error(f"Request failed for citations of {paper_id}: {e}")
            return []

    def get_paper_references(self, paper_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """
        Get papers referenced by this paper

        Args:
            paper_id: Paper ID
            limit: Maximum references to return

        Returns:
            List of referenced papers
        """
        url = f"{self.s2_base_url}/paper/{quote(paper_id)}/references"
        params = {
            'limit': min(limit, 100),
            'fields': 'paperId,title,authors,year,citationCount'
        }

        try:
            response = self.session.get(url, params=params, timeout=self.timeout)

            if response.status_code == 200:
                data = response.json()
                references = data.get('data', [])
                # Extract cited paper from each reference
                return [r.get('citedPaper', {}) for r in references if 'citedPaper' in r]
            else:
                logger.warning(f"References request failed: {response.status_code}")
                return []

        except requests.RequestException as e:
            logger.error(f"Request failed for references of {paper_id}: {e}")
            return []

    def search_papers(self, query: str, limit: int = 10, fields: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """
        Search for papers

        Args:
            query: Search query
            limit: Maximum papers to return
            fields: Fields to retrieve

        Returns:
            List of papers matching query
        """
        if not fields:
            fields = ['paperId', 'title', 'authors', 'year', 'citationCount', 'abstract']

        url = f"{self.s2_base_url}/paper/search"
        params = {
            'query': query,
            'limit': min(limit, 100),
            'fields': ','.join(fields)
        }

        try:
            response = self.session.get(url, params=params, timeout=self.timeout)

            if response.status_code == 200:
                data = response.json()
                return data.get('data', [])
            else:
                logger.warning(f"Search request failed: {response.status_code}")
                return []

        except requests.RequestException as e:
            logger.error(f"Search request failed for '{query}': {e}")
            return []

    def close(self):
        """Close the session"""
        self.session.close()


def get_archive_client(timeout: int = 10) -> ArchiveAPIClient:
    """
    Get an ArchiveAPIClient instance

    Args:
        timeout: Request timeout in seconds

    Returns:
        ArchiveAPIClient instance
    """
    return ArchiveAPIClient(timeout=timeout)
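For orientation, a minimal usage sketch of this new client follows. It is not part of the package diff: it assumes the module is importable as cite_agent.archive_api_client (as suggested by the RECORD above), that the Semantic Scholar endpoints are reachable, and the identifier strings are placeholders only.

# Hypothetical usage sketch, not part of the published wheel.
from cite_agent.archive_api_client import get_archive_client

client = get_archive_client(timeout=10)
try:
    # Placeholder identifier; any S2-supported ID (DOI, arXiv ID, S2 ID) would work.
    paper = client.get_paper("10.1000/example-doi", fields=["title", "year", "citationCount"])
    if paper:
        print(paper["title"], paper.get("year"))

    # Keyword search; results are plain dicts with the requested fields.
    for hit in client.search_papers("retrieval augmented generation", limit=5):
        print(hit.get("title"), "-", hit.get("citationCount"), "citations")
finally:
    client.close()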
cite_agent/auth.py
CHANGED

cite_agent/auto_expander.py
ADDED

@@ -0,0 +1,70 @@
"""
Automatic Response Expansion

When agent returns minimal info, automatically fetch and show more detail

Examples:
- List of files → Show preview of main file
- List of papers → Show abstracts
- Data query → Show breakdown/visualization
"""

import logging
import re
from typing import Dict, Any, Optional

from .proactive_boundaries import ProactiveBoundaries

logger = logging.getLogger(__name__)


class AutoExpander:
    """
    Automatically expands minimal responses with useful detail

    PHILOSOPHY: Don't make user ask twice for obvious next step
    """

    @classmethod
    def should_expand(cls, response: str, query: str, context: Dict[str, Any]) -> bool:
        """
        Check if response should be automatically expanded

        Returns True if response is minimal and expansion would be useful
        """
        expansion_info = ProactiveBoundaries.get_auto_expansion_for_query(query, response)
        return expansion_info['should_expand']

    @classmethod
    def expand(cls, response: str, query: str, context: Dict[str, Any]) -> str:
        """
        Detect when expansion is needed and log it

        NOTE: With our action-first prompt changes, the LLM should already
        be providing expanded responses. This function mainly serves as a
        quality check - if it detects expansion is needed, it means the
        LLM didn't follow the action-first guidelines.

        In production, this could trigger a second LLM call to expand,
        but for now we just log the issue.
        """
        expansion_info = ProactiveBoundaries.get_auto_expansion_for_query(query, response)

        if not expansion_info['should_expand']:
            return response  # No expansion needed - response is already good

        # Response needs expansion - this is a problem!
        logger.warning(f"⚠️ Response needs expansion but LLM didn't provide it")
        logger.warning(f"   Reason: {expansion_info['reason']}")
        logger.warning(f"   Missing actions: {expansion_info['expansion_actions']}")
        logger.warning(f"   This suggests prompt needs improvement or LLM isn't following guidelines")

        # For now, return original response
        # In production, we could trigger a second LLM call here to expand
        return response


# Convenience function
def auto_expand(response: str, query: str, context: Dict[str, Any] = None) -> str:
    """Quick auto-expansion"""
    return AutoExpander.expand(response, query, context or {})
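A short sketch of how this module would be called, again outside the diff itself: it assumes the import path cite_agent.auto_expander, and the query/response strings are invented for illustration.

# Hypothetical check, not part of the published wheel; strings are made up.
from cite_agent.auto_expander import AutoExpander, auto_expand

query = "list the files in my project"
response = "Found 12 files."

if AutoExpander.should_expand(response, query, context={}):
    # expand() currently only logs the missing detail and returns the response unchanged.
    response = auto_expand(response, query)

print(response)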
cite_agent/cache.py
ADDED

@@ -0,0 +1,379 @@
#!/usr/bin/env python3
"""
Disk Cache for API Responses

Caches API responses to disk for:
- Faster repeated queries
- Offline access to previous results
- Reduced API costs
- Better performance

Uses SQLite for indexed storage
"""

import logging
import sqlite3
import json
import hashlib
from pathlib import Path
from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class DiskCache:
    """
    SQLite-based cache for API responses

    Features:
    - Automatic expiration (TTL)
    - LRU eviction when cache is full
    - Indexed lookups by query hash
    - Compression for large responses
    """

    def __init__(
        self,
        cache_dir: str = "~/.cite_agent/cache",
        max_size_mb: int = 500,
        default_ttl_hours: int = 24
    ):
        """
        Initialize cache

        Args:
            cache_dir: Directory for cache database
            max_size_mb: Maximum cache size in MB
            default_ttl_hours: Default time-to-live in hours
        """
        self.cache_dir = Path(cache_dir).expanduser()
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.db_path = self.cache_dir / "cache.db"
        self.max_size_mb = max_size_mb
        self.default_ttl_hours = default_ttl_hours

        self._init_db()

    def _init_db(self):
        """Initialize SQLite database"""
        with self._get_connection() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    query_type TEXT NOT NULL,
                    query_text TEXT,
                    created_at TEXT NOT NULL,
                    expires_at TEXT NOT NULL,
                    access_count INTEGER DEFAULT 0,
                    last_accessed TEXT,
                    size_bytes INTEGER
                )
            """)

            # Create indexes for fast lookups
            conn.execute("CREATE INDEX IF NOT EXISTS idx_expires_at ON cache(expires_at)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_query_type ON cache(query_type)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_last_accessed ON cache(last_accessed)")

            conn.commit()

    @contextmanager
    def _get_connection(self):
        """Get database connection"""
        conn = sqlite3.connect(str(self.db_path))
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def _make_key(self, query_type: str, **params) -> str:
        """
        Generate cache key from query parameters

        Args:
            query_type: Type of query (search, financial, etc.)
            **params: Query parameters

        Returns:
            Cache key (hash)
        """
        # Create deterministic string from params
        param_str = json.dumps(params, sort_keys=True)
        combined = f"{query_type}:{param_str}"

        # Hash to fixed-length key
        return hashlib.sha256(combined.encode()).hexdigest()

    def get(self, query_type: str, **params) -> Optional[Dict[str, Any]]:
        """
        Get cached value

        Args:
            query_type: Type of query
            **params: Query parameters

        Returns:
            Cached value if exists and not expired, None otherwise
        """
        key = self._make_key(query_type, **params)

        with self._get_connection() as conn:
            cursor = conn.execute(
                "SELECT value, expires_at, access_count FROM cache WHERE key = ?",
                (key,)
            )
            row = cursor.fetchone()

            if not row:
                return None

            # Check expiration
            expires_at = datetime.fromisoformat(row["expires_at"])
            if datetime.now() > expires_at:
                logger.debug(f"Cache expired for {query_type}")
                self._delete(conn, key)
                return None

            # Update access stats
            conn.execute("""
                UPDATE cache
                SET access_count = access_count + 1,
                    last_accessed = ?
                WHERE key = ?
            """, (datetime.now().isoformat(), key))
            conn.commit()

            # Deserialize value
            try:
                value = json.loads(row["value"])
                logger.debug(f"Cache HIT for {query_type} (accessed {row['access_count']} times)")
                return value
            except json.JSONDecodeError:
                logger.warning(f"Corrupted cache entry for {query_type}")
                self._delete(conn, key)
                return None

    def set(
        self,
        query_type: str,
        value: Dict[str, Any],
        ttl_hours: Optional[int] = None,
        **params
    ):
        """
        Cache a value

        Args:
            query_type: Type of query
            value: Value to cache
            ttl_hours: Time to live in hours (None = use default)
            **params: Query parameters
        """
        key = self._make_key(query_type, **params)
        ttl = ttl_hours if ttl_hours is not None else self.default_ttl_hours

        # Serialize value
        value_json = json.dumps(value)
        size_bytes = len(value_json.encode('utf-8'))

        # Calculate expiration
        created_at = datetime.now()
        expires_at = created_at + timedelta(hours=ttl)

        # Check if we need to evict old entries
        self._maybe_evict()

        with self._get_connection() as conn:
            # Insert or replace
            conn.execute("""
                INSERT OR REPLACE INTO cache
                (key, value, query_type, query_text, created_at, expires_at, size_bytes, last_accessed)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                key,
                value_json,
                query_type,
                json.dumps(params),
                created_at.isoformat(),
                expires_at.isoformat(),
                size_bytes,
                created_at.isoformat()
            ))
            conn.commit()

        logger.debug(f"Cached {query_type} ({size_bytes} bytes, TTL: {ttl}h)")

    def _delete(self, conn: sqlite3.Connection, key: str):
        """Delete cache entry"""
        conn.execute("DELETE FROM cache WHERE key = ?", (key,))
        conn.commit()

    def _maybe_evict(self):
        """Evict old entries if cache is too large"""
        with self._get_connection() as conn:
            # Check current size
            cursor = conn.execute("SELECT SUM(size_bytes) as total FROM cache")
            row = cursor.fetchone()
            total_bytes = row["total"] or 0
            total_mb = total_bytes / (1024 * 1024)

            if total_mb > self.max_size_mb:
                # Evict least recently used entries
                evict_count = int(self.get_stats()["total_entries"] * 0.2)  # Evict 20%

                conn.execute("""
                    DELETE FROM cache
                    WHERE key IN (
                        SELECT key FROM cache
                        ORDER BY last_accessed ASC
                        LIMIT ?
                    )
                """, (evict_count,))
                conn.commit()

                logger.info(f"Evicted {evict_count} cache entries (cache was {total_mb:.1f}MB)")

    def clear_expired(self):
        """Remove all expired entries"""
        with self._get_connection() as conn:
            cursor = conn.execute(
                "DELETE FROM cache WHERE expires_at < ?",
                (datetime.now().isoformat(),)
            )
            count = cursor.rowcount
            conn.commit()

            if count > 0:
                logger.info(f"Cleared {count} expired cache entries")

            return count

    def clear_all(self):
        """Clear entire cache"""
        with self._get_connection() as conn:
            cursor = conn.execute("DELETE FROM cache")
            count = cursor.rowcount
            conn.commit()

            logger.info(f"Cleared all cache ({count} entries)")
            return count

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        Returns:
            Statistics dictionary
        """
        with self._get_connection() as conn:
            # Total entries
            cursor = conn.execute("SELECT COUNT(*) as count FROM cache")
            total = cursor.fetchone()["count"]

            # Total size
            cursor = conn.execute("SELECT SUM(size_bytes) as total FROM cache")
            total_bytes = cursor.fetchone()["total"] or 0

            # By query type
            cursor = conn.execute("""
                SELECT query_type, COUNT(*) as count
                FROM cache
                GROUP BY query_type
            """)
            by_type = {row["query_type"]: row["count"] for row in cursor.fetchall()}

            # Expired count
            cursor = conn.execute(
                "SELECT COUNT(*) as count FROM cache WHERE expires_at < ?",
                (datetime.now().isoformat(),)
            )
            expired_count = cursor.fetchone()["count"]

            return {
                "total_entries": total,
                "total_size_mb": total_bytes / (1024 * 1024),
                "max_size_mb": self.max_size_mb,
                "usage_percent": (total_bytes / (self.max_size_mb * 1024 * 1024)) * 100,
                "by_type": by_type,
                "expired_count": expired_count
            }

    def get_recent_queries(self, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Get recently cached queries

        Args:
            limit: Maximum number to return

        Returns:
            List of recent queries
        """
        with self._get_connection() as conn:
            cursor = conn.execute("""
                SELECT query_type, query_text, created_at, access_count
                FROM cache
                ORDER BY created_at DESC
                LIMIT ?
            """, (limit,))

            return [
                {
                    "query_type": row["query_type"],
                    "query": json.loads(row["query_text"]) if row["query_text"] else {},
                    "cached_at": row["created_at"],
                    "access_count": row["access_count"]
                }
                for row in cursor.fetchall()
            ]


# Global cache instance
_cache = None


def get_cache() -> DiskCache:
    """Get global cache instance"""
    global _cache
    if _cache is None:
        _cache = DiskCache()
    return _cache


def cached_api_call(query_type: str, ttl_hours: int = 24):
    """
    Decorator for caching API calls

    Usage:
        @cached_api_call("academic_search", ttl_hours=24)
        async def search_papers(query: str, limit: int):
            # ... API call ...
            return results

    Args:
        query_type: Type of query
        ttl_hours: Cache TTL in hours
    """
    def decorator(func):
        async def wrapper(*args, **kwargs):
            cache = get_cache()

            # Try cache first
            cached_result = cache.get(query_type, **kwargs)
            if cached_result is not None:
                return cached_result

            # Call function
            result = await func(*args, **kwargs)

            # Cache result
            cache.set(query_type, result, ttl_hours, **kwargs)

            return result

        return wrapper
    return decorator
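To make the cache semantics concrete, a hedged usage sketch follows; it is not part of the diff. It assumes the import path cite_agent.cache, uses an illustrative query_type of "academic_search", and the payloads are placeholders.

# Hypothetical usage sketch, not part of the published wheel.
import asyncio
from cite_agent.cache import DiskCache, cached_api_call

cache = DiskCache(cache_dir="~/.cite_agent/cache", default_ttl_hours=24)
# The keyword arguments after the value become the cache key for this query_type.
cache.set("academic_search", {"results": []}, ttl_hours=6, query="cache eviction", limit=5)
print(cache.get("academic_search", query="cache eviction", limit=5))  # hit until the 6h TTL lapses
print(cache.get_stats()["total_entries"])

@cached_api_call("academic_search", ttl_hours=6)
async def search_papers(query: str, limit: int = 5):
    # Placeholder for a real API call.
    return {"query": query, "results": []}

# Note: the wrapper builds the key from **kwargs only, so call with keyword arguments
# if repeated calls should hit the cache.
asyncio.run(search_papers(query="cache eviction", limit=5))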