crossref-local 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crossref_local/__init__.py +86 -22
- crossref_local/__main__.py +6 -0
- crossref_local/aio.py +0 -0
- crossref_local/api.py +148 -5
- crossref_local/cache.py +466 -0
- crossref_local/cache_export.py +83 -0
- crossref_local/cache_viz.py +296 -0
- crossref_local/citations.py +0 -0
- crossref_local/cli.py +358 -97
- crossref_local/cli_cache.py +179 -0
- crossref_local/cli_completion.py +245 -0
- crossref_local/cli_main.py +20 -0
- crossref_local/cli_mcp.py +275 -0
- crossref_local/config.py +99 -3
- crossref_local/db.py +3 -1
- crossref_local/fts.py +38 -4
- crossref_local/impact_factor/__init__.py +0 -0
- crossref_local/impact_factor/calculator.py +0 -0
- crossref_local/impact_factor/journal_lookup.py +0 -0
- crossref_local/mcp_server.py +413 -0
- crossref_local/models.py +0 -0
- crossref_local/remote.py +269 -0
- crossref_local/server.py +352 -0
- {crossref_local-0.3.0.dist-info → crossref_local-0.4.0.dist-info}/METADATA +152 -7
- crossref_local-0.4.0.dist-info/RECORD +27 -0
- crossref_local-0.4.0.dist-info/entry_points.txt +3 -0
- crossref_local-0.3.0.dist-info/RECORD +0 -16
- crossref_local-0.3.0.dist-info/entry_points.txt +0 -2
- {crossref_local-0.3.0.dist-info → crossref_local-0.4.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""MCP server management commands for crossref-local CLI."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
import click
|
|
7
|
+
|
|
8
|
+
# Shared click settings: accept both -h and --help on every command.
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@click.group(name="mcp", context_settings=CONTEXT_SETTINGS)
def mcp():
    """MCP (Model Context Protocol) server management.

    \b
    Commands for running and managing the MCP server that enables
    AI assistants like Claude to search academic papers.

    \b
    Quick start:
      crossref-local mcp start # Start stdio server
      crossref-local mcp start -t http # Start HTTP server
      crossref-local mcp doctor # Check dependencies
      crossref-local mcp installation # Show config snippets
    """
    # Pure command group: the docstring above is its help text and the
    # subcommands registered below do all of the work.
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@mcp.command("start", context_settings=CONTEXT_SETTINGS)
@click.option(
    "-t",
    "--transport",
    type=click.Choice(["stdio", "sse", "http"]),
    default="stdio",
    help="Transport protocol (http recommended for remote)",
)
@click.option(
    "--host",
    default="localhost",
    envvar="CROSSREF_LOCAL_MCP_HOST",
    help="Host for HTTP/SSE transport",
)
@click.option(
    "--port",
    default=8082,
    type=int,
    envvar="CROSSREF_LOCAL_MCP_PORT",
    help="Port for HTTP/SSE transport",
)
def start_cmd(transport: str, host: str, port: int):
    """Start the MCP server.

    \b
    Transports:
      stdio - Standard I/O (default, for Claude Desktop local)
      http - Streamable HTTP (recommended for remote/persistent)
      sse - Server-Sent Events (deprecated as of MCP spec 2025-03-26)

    \b
    Examples:
      crossref-local mcp start # stdio for Claude Desktop
      crossref-local mcp start -t http # HTTP on localhost:8082
      crossref-local mcp start -t http --port 9000 # Custom port
    """
    # Import lazily so the rest of the CLI keeps working when the optional
    # "mcp" extra (fastmcp) is not installed.
    try:
        from .mcp_server import run_server
    except ImportError as exc:
        # Quote the requirement so shells such as zsh do not try to glob
        # the square brackets.
        click.echo(
            "MCP server requires fastmcp. Install with:\n"
            " pip install 'crossref-local[mcp]'",
            err=True,
        )
        # The ImportError may originate from something other than fastmcp
        # (e.g. a broken module imported by mcp_server); show it rather than
        # hiding the real cause behind the install hint.
        click.echo(f"Import error: {exc}", err=True)
        sys.exit(1)

    run_server(transport=transport, host=host, port=port)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@mcp.command("doctor", context_settings=CONTEXT_SETTINGS)
def doctor_cmd():
    """Check MCP server dependencies and configuration.

    Verifies that all required packages are installed and
    the database is accessible.
    """
    click.echo("MCP Server Health Check")
    click.echo("=" * 40)

    issues = []

    # Check fastmcp
    try:
        import fastmcp
    except ImportError:
        click.echo("[FAIL] fastmcp not installed")
        issues.append("Install fastmcp: pip install 'crossref-local[mcp]'")
    else:
        # A missing __version__ attribute must not crash the health check.
        version = getattr(fastmcp, "__version__", "unknown")
        click.echo(f"[OK] fastmcp {version}")

    # Check database. info() is called once, and its result is reused for the
    # FTS check below instead of querying the database a second time.
    db_info = None
    try:
        from . import info

        db_info = info()
        works = db_info.get("works", 0)
        click.echo(f"[OK] Database: {works:,} works")
    except Exception as e:  # any failure is reported as a health-check result
        click.echo(f"[FAIL] Database: {e}")
        issues.append("Configure database: export CROSSREF_LOCAL_DB=/path/to/db")

    # Check FTS index (only meaningful when the database answered)
    if db_info is not None:
        try:
            fts = db_info.get("fts_indexed", 0) or 0
            if fts > 0:
                click.echo(f"[OK] FTS index: {fts:,} indexed")
            else:
                click.echo("[WARN] FTS index: not built")
                issues.append("Build FTS index: make fts-build")
        except (AttributeError, TypeError, ValueError) as e:
            # Unexpected shape or type for fts_indexed: report it instead of
            # silently dropping the check.
            click.echo(f"[WARN] FTS index: unable to read status ({e})")

    click.echo()
    if issues:
        click.echo("Issues found:")
        for issue in issues:
            click.echo(f" - {issue}")
        sys.exit(1)
    else:
        click.echo("All checks passed!")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@mcp.command("installation", context_settings=CONTEXT_SETTINGS)
@click.option(
    "-t",
    "--transport",
    type=click.Choice(["stdio", "http"]),
    default="stdio",
    help="Transport type for config",
)
# Honour the same environment variables as `mcp start`, so the printed URL
# matches the address the server will actually listen on.
@click.option(
    "--host",
    default="localhost",
    envvar="CROSSREF_LOCAL_MCP_HOST",
    help="Host for HTTP transport",
)
@click.option(
    "--port",
    default=8082,
    type=int,
    envvar="CROSSREF_LOCAL_MCP_PORT",
    help="Port for HTTP transport",
)
def installation_cmd(transport: str, host: str, port: int):
    """Show MCP client configuration snippets.

    Outputs JSON configuration for Claude Desktop or other MCP clients.

    \b
    Examples:
      crossref-local mcp installation # stdio config
      crossref-local mcp installation -t http # HTTP config
    """
    if transport == "stdio":
        # The client launches the CLI itself and talks to it over stdio.
        config = {
            "mcpServers": {
                "crossref-local": {
                    "command": "crossref-local",
                    "args": ["mcp", "start"],
                }
            }
        }
        click.echo("Claude Desktop configuration (stdio):")
        click.echo()
        click.echo(
            "Add to ~/Library/Application Support/Claude/claude_desktop_config.json"
        )
        click.echo("or ~/.config/claude/claude_desktop_config.json:")
        click.echo()
    else:
        # The client connects to an already-running HTTP server.
        url = f"http://{host}:{port}/mcp"
        config = {"mcpServers": {"crossref-local": {"url": url}}}
        click.echo(f"Claude Desktop configuration (HTTP at {url}):")
        click.echo()
        click.echo("First start the server:")
        click.echo(f" crossref-local mcp start -t http --host {host} --port {port}")
        click.echo()
        click.echo("Then add to claude_desktop_config.json:")
        click.echo()

    click.echo(json.dumps(config, indent=2))
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@mcp.command("list-tools", context_settings=CONTEXT_SETTINGS)
@click.option("--json", "as_json", is_flag=True, help="Output as JSON")
def list_tools_cmd(as_json: bool):
    """List available MCP tools.

    Shows all tools exposed by the MCP server with their descriptions.
    """
    # (name, description, parameter names) for each advertised tool.
    catalogue = (
        ("search", "Search for academic works by title, abstract, or authors",
         ("query", "limit", "offset", "with_abstracts")),
        ("search_by_doi", "Get detailed information about a work by DOI",
         ("doi", "as_citation")),
        ("status", "Get database statistics and status", ()),
        ("enrich_dois", "Enrich DOIs with full metadata including citations",
         ("dois",)),
        ("cache_create", "Create a paper cache from search query",
         ("name", "query", "limit")),
        ("cache_query", "Query cached papers with field filtering",
         ("name", "fields", "year_min", "year_max", "limit")),
        ("cache_stats", "Get cache statistics", ("name",)),
        ("cache_list", "List all available caches", ()),
        ("cache_top_cited", "Get top cited papers from cache",
         ("name", "n", "year_min", "year_max")),
        ("cache_citation_summary", "Get citation statistics for cached papers",
         ("name",)),
        ("cache_plot_scatter", "Generate year vs citations scatter plot",
         ("name", "output", "top_n")),
        ("cache_plot_network", "Generate citation network visualization",
         ("name", "output", "max_nodes")),
        ("cache_export", "Export cache to file (json, csv, bibtex, dois)",
         ("name", "output_path", "format", "fields")),
    )
    tools = [
        {"name": name, "description": description, "parameters": list(params)}
        for name, description, params in catalogue
    ]

    if as_json:
        click.echo(json.dumps(tools, indent=2))
        return

    click.echo("CrossRef Local MCP Tools")
    click.echo("=" * 50)
    click.echo()
    for tool in tools:
        click.echo(f" {tool['name']}")
        click.echo(f" {tool['description']}")
        if tool["parameters"]:
            click.echo(f" Parameters: {', '.join(tool['parameters'])}")
        click.echo()
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def register_mcp_commands(cli_group):
    """Attach the ``mcp`` command group to the top-level CLI group.

    Args:
        cli_group: The click group that should gain the ``mcp`` subcommand.
    """
    attach = cli_group.add_command
    attach(mcp)
|
crossref_local/config.py
CHANGED
|
@@ -6,11 +6,15 @@ from typing import Optional
|
|
|
6
6
|
|
|
7
7
|
# Candidate database files, probed in list order.
DEFAULT_DB_PATHS = [
    Path.cwd().joinpath("data", "crossref.db"),
    Path.home().joinpath(".crossref_local", "crossref.db"),
]

# Remote API endpoints to try; the default assumes an SSH tunnel to the NAS.
DEFAULT_API_URLS = ["http://localhost:8333"]
DEFAULT_API_URL = DEFAULT_API_URLS[0]
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
def get_db_path() -> Path:
|
|
@@ -50,6 +54,46 @@ class Config:
|
|
|
50
54
|
"""Configuration container."""
|
|
51
55
|
|
|
52
56
|
_db_path: Optional[Path] = None
|
|
57
|
+
_api_url: Optional[str] = None
|
|
58
|
+
_mode: str = "auto" # "auto", "db", or "http"
|
|
59
|
+
|
|
60
|
+
@classmethod
def get_mode(cls) -> str:
    """
    Resolve the access mode currently in effect.

    An explicitly set mode ("db" or "http") wins. In "auto" mode the
    decision falls through, in order:
    CROSSREF_LOCAL_MODE, a configured API URL, a locatable local database,
    and finally HTTP.

    Returns:
        "db" if using direct database access
        "http" if using HTTP API
    """
    if cls._mode != "auto":
        return cls._mode

    # 1. Explicit override through the environment.
    requested = os.environ.get("CROSSREF_LOCAL_MODE", "").lower()
    aliases = {
        "http": "http",
        "remote": "http",
        "api": "http",
        "db": "db",
        "local": "db",
    }
    if requested in aliases:
        return aliases[requested]

    # 2. Any configured API URL implies HTTP access.
    if cls._api_url or os.environ.get("CROSSREF_LOCAL_API_URL"):
        return "http"

    # 3. Prefer a local database when one can be found.
    try:
        get_db_path()
    except FileNotFoundError:
        # No local DB, try http
        return "http"
    return "db"
|
|
90
|
+
|
|
91
|
+
@classmethod
def set_mode(cls, mode: str) -> None:
    """Force the access mode to 'db', 'http', or back to 'auto'.

    Raises:
        ValueError: If ``mode`` is not one of the three accepted values.
    """
    accepted = ("auto", "db", "http")
    if mode in accepted:
        cls._mode = mode
        return
    raise ValueError(f"Invalid mode: {mode}. Use 'auto', 'db', or 'http'")
|
|
53
97
|
|
|
54
98
|
@classmethod
|
|
55
99
|
def get_db_path(cls) -> Path:
|
|
@@ -65,8 +109,60 @@ class Config:
|
|
|
65
109
|
if not path.exists():
|
|
66
110
|
raise FileNotFoundError(f"Database not found: {path}")
|
|
67
111
|
cls._db_path = path
|
|
112
|
+
cls._mode = "db"
|
|
113
|
+
|
|
114
|
+
@classmethod
def get_api_url(cls, auto_detect: bool = True) -> str:
    """
    Get API URL for remote mode.

    Resolution order: an explicitly set URL, then the
    ``CROSSREF_LOCAL_API_URL`` environment variable, then (if
    ``auto_detect``) the first reachable entry of ``DEFAULT_API_URLS``,
    and finally ``DEFAULT_API_URL``.

    Args:
        auto_detect: If True, test each URL and use first working one

    Returns:
        API URL string. A URL taken from the environment has surrounding
        whitespace and trailing slashes removed.
    """
    if cls._api_url:
        return cls._api_url

    # Normalize the environment value the same way set_api_url() does, so
    # callers can safely append paths such as "/health". A blank value
    # counts as unset.
    env_url = os.environ.get("CROSSREF_LOCAL_API_URL", "").strip().rstrip("/")
    if env_url:
        return env_url

    if auto_detect:
        working_url = cls._find_working_api()
        if working_url:
            # Cache the successful probe so later calls skip the network check.
            cls._api_url = working_url
            return working_url

    return DEFAULT_API_URL
|
|
139
|
+
|
|
140
|
+
@classmethod
|
|
141
|
+
def _find_working_api(cls) -> Optional[str]:
|
|
142
|
+
"""Try each default API URL and return first working one."""
|
|
143
|
+
import urllib.request
|
|
144
|
+
import urllib.error
|
|
145
|
+
|
|
146
|
+
for url in DEFAULT_API_URLS:
|
|
147
|
+
try:
|
|
148
|
+
req = urllib.request.Request(f"{url}/health", method="GET")
|
|
149
|
+
req.add_header("Accept", "application/json")
|
|
150
|
+
with urllib.request.urlopen(req, timeout=3) as response:
|
|
151
|
+
if response.status == 200:
|
|
152
|
+
return url
|
|
153
|
+
except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError):
|
|
154
|
+
continue
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
@classmethod
def set_api_url(cls, url: str) -> None:
    """Pin the HTTP API endpoint and switch the configuration to HTTP mode.

    Trailing slashes are dropped so paths can be appended as ``f"{url}/..."``.
    """
    normalized = url.rstrip("/")
    cls._api_url = normalized
    cls._mode = "http"
|
|
68
162
|
|
|
69
163
|
@classmethod
def reset(cls) -> None:
    """Restore the pristine configuration state (intended for tests)."""
    for attribute, value in (
        ("_db_path", None),
        ("_api_url", None),
        ("_mode", "auto"),
    ):
        setattr(cls, attribute, value)
|
crossref_local/db.py
CHANGED
|
@@ -34,7 +34,9 @@ class Database:
|
|
|
34
34
|
|
|
35
35
|
def _connect(self) -> None:
|
|
36
36
|
"""Establish database connection."""
|
|
37
|
-
|
|
37
|
+
# check_same_thread=False allows connection to be used across threads
|
|
38
|
+
# Safe for read-only operations (which is our use case)
|
|
39
|
+
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
|
38
40
|
self.conn.row_factory = sqlite3.Row
|
|
39
41
|
|
|
40
42
|
def close(self) -> None:
|
crossref_local/fts.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Full-text search using FTS5."""
|
|
2
2
|
|
|
3
|
+
import re
|
|
3
4
|
import time
|
|
4
5
|
from typing import List, Optional
|
|
5
6
|
|
|
@@ -7,6 +8,34 @@ from .db import Database, get_db
|
|
|
7
8
|
from .models import Work, SearchResult
|
|
8
9
|
|
|
9
10
|
|
|
11
|
+
def _sanitize_query(query: str) -> str:
|
|
12
|
+
"""
|
|
13
|
+
Sanitize query for FTS5.
|
|
14
|
+
|
|
15
|
+
Handles special characters that FTS5 interprets as operators:
|
|
16
|
+
- Hyphens in words like "RS-1" or "CRISPR-Cas9"
|
|
17
|
+
- Other special characters
|
|
18
|
+
|
|
19
|
+
If query contains problematic characters, wrap each term in quotes.
|
|
20
|
+
"""
|
|
21
|
+
# If already quoted, return as-is
|
|
22
|
+
if query.startswith('"') and query.endswith('"'):
|
|
23
|
+
return query
|
|
24
|
+
|
|
25
|
+
# Check for problematic patterns (hyphenated words, special chars)
|
|
26
|
+
# But allow explicit FTS5 operators: AND, OR, NOT, NEAR
|
|
27
|
+
has_hyphenated_word = re.search(r'\w+-\w+', query)
|
|
28
|
+
has_special = re.search(r'[/\\@#$%^&]', query)
|
|
29
|
+
|
|
30
|
+
if has_hyphenated_word or has_special:
|
|
31
|
+
# Quote each word to treat as literal
|
|
32
|
+
words = query.split()
|
|
33
|
+
quoted = ' '.join(f'"{w}"' for w in words)
|
|
34
|
+
return quoted
|
|
35
|
+
|
|
36
|
+
return query
|
|
37
|
+
|
|
38
|
+
|
|
10
39
|
def search(
|
|
11
40
|
query: str,
|
|
12
41
|
limit: int = 10,
|
|
@@ -38,10 +67,13 @@ def search(
|
|
|
38
67
|
|
|
39
68
|
start = time.perf_counter()
|
|
40
69
|
|
|
70
|
+
# Sanitize query for FTS5
|
|
71
|
+
safe_query = _sanitize_query(query)
|
|
72
|
+
|
|
41
73
|
# Get total count
|
|
42
74
|
count_row = db.fetchone(
|
|
43
75
|
"SELECT COUNT(*) as total FROM works_fts WHERE works_fts MATCH ?",
|
|
44
|
-
(
|
|
76
|
+
(safe_query,)
|
|
45
77
|
)
|
|
46
78
|
total = count_row["total"] if count_row else 0
|
|
47
79
|
|
|
@@ -54,7 +86,7 @@ def search(
|
|
|
54
86
|
WHERE works_fts MATCH ?
|
|
55
87
|
LIMIT ? OFFSET ?
|
|
56
88
|
""",
|
|
57
|
-
(
|
|
89
|
+
(safe_query, limit, offset)
|
|
58
90
|
)
|
|
59
91
|
|
|
60
92
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
|
@@ -87,9 +119,10 @@ def count(query: str, db: Optional[Database] = None) -> int:
|
|
|
87
119
|
if db is None:
|
|
88
120
|
db = get_db()
|
|
89
121
|
|
|
122
|
+
safe_query = _sanitize_query(query)
|
|
90
123
|
row = db.fetchone(
|
|
91
124
|
"SELECT COUNT(*) as total FROM works_fts WHERE works_fts MATCH ?",
|
|
92
|
-
(
|
|
125
|
+
(safe_query,)
|
|
93
126
|
)
|
|
94
127
|
return row["total"] if row else 0
|
|
95
128
|
|
|
@@ -113,6 +146,7 @@ def search_dois(
|
|
|
113
146
|
if db is None:
|
|
114
147
|
db = get_db()
|
|
115
148
|
|
|
149
|
+
safe_query = _sanitize_query(query)
|
|
116
150
|
rows = db.fetchall(
|
|
117
151
|
"""
|
|
118
152
|
SELECT w.doi
|
|
@@ -121,7 +155,7 @@ def search_dois(
|
|
|
121
155
|
WHERE works_fts MATCH ?
|
|
122
156
|
LIMIT ?
|
|
123
157
|
""",
|
|
124
|
-
(
|
|
158
|
+
(safe_query, limit)
|
|
125
159
|
)
|
|
126
160
|
|
|
127
161
|
return [row["doi"] for row in rows]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|