crossref-local 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crossref_local/__init__.py +24 -10
- crossref_local/_aio/__init__.py +30 -0
- crossref_local/_aio/_impl.py +238 -0
- crossref_local/_cache/__init__.py +15 -0
- crossref_local/{cache_export.py → _cache/export.py} +27 -10
- crossref_local/_cache/utils.py +93 -0
- crossref_local/_cli/__init__.py +9 -0
- crossref_local/_cli/cli.py +389 -0
- crossref_local/_cli/mcp.py +351 -0
- crossref_local/_cli/mcp_server.py +457 -0
- crossref_local/_cli/search.py +199 -0
- crossref_local/_core/__init__.py +62 -0
- crossref_local/{api.py → _core/api.py} +26 -5
- crossref_local/{citations.py → _core/citations.py} +55 -26
- crossref_local/{config.py → _core/config.py} +40 -22
- crossref_local/{db.py → _core/db.py} +32 -26
- crossref_local/_core/export.py +344 -0
- crossref_local/{fts.py → _core/fts.py} +37 -14
- crossref_local/{models.py → _core/models.py} +120 -6
- crossref_local/_remote/__init__.py +56 -0
- crossref_local/_remote/base.py +378 -0
- crossref_local/_remote/collections.py +175 -0
- crossref_local/_server/__init__.py +140 -0
- crossref_local/_server/middleware.py +25 -0
- crossref_local/_server/models.py +143 -0
- crossref_local/_server/routes_citations.py +98 -0
- crossref_local/_server/routes_collections.py +282 -0
- crossref_local/_server/routes_compat.py +102 -0
- crossref_local/_server/routes_works.py +178 -0
- crossref_local/_server/server.py +19 -0
- crossref_local/aio.py +30 -206
- crossref_local/cache.py +100 -100
- crossref_local/cli.py +5 -515
- crossref_local/jobs.py +169 -0
- crossref_local/mcp_server.py +5 -410
- crossref_local/remote.py +5 -266
- crossref_local/server.py +5 -349
- {crossref_local-0.4.0.dist-info → crossref_local-0.5.1.dist-info}/METADATA +36 -11
- crossref_local-0.5.1.dist-info/RECORD +49 -0
- {crossref_local-0.4.0.dist-info → crossref_local-0.5.1.dist-info}/entry_points.txt +1 -1
- crossref_local/cli_mcp.py +0 -275
- crossref_local-0.4.0.dist-info/RECORD +0 -27
- /crossref_local/{cache_viz.py → _cache/viz.py} +0 -0
- /crossref_local/{cli_cache.py → _cli/cache.py} +0 -0
- /crossref_local/{cli_completion.py → _cli/completion.py} +0 -0
- /crossref_local/{cli_main.py → _cli/main.py} +0 -0
- /crossref_local/{impact_factor → _impact_factor}/__init__.py +0 -0
- /crossref_local/{impact_factor → _impact_factor}/calculator.py +0 -0
- /crossref_local/{impact_factor → _impact_factor}/journal_lookup.py +0 -0
- {crossref_local-0.4.0.dist-info → crossref_local-0.5.1.dist-info}/WHEEL +0 -0
|
@@ -12,15 +12,34 @@ Mode is auto-detected or can be set explicitly via:
|
|
|
12
12
|
|
|
13
13
|
from typing import List, Optional
|
|
14
14
|
|
|
15
|
-
from .config import Config
|
|
16
|
-
from .db import get_db, close_db
|
|
17
|
-
from .models import Work, SearchResult
|
|
18
15
|
from . import fts
|
|
16
|
+
from .config import Config
|
|
17
|
+
from .db import close_db, get_db
|
|
18
|
+
from .models import SearchResult, Work
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"search",
|
|
22
|
+
"count",
|
|
23
|
+
"get",
|
|
24
|
+
"get_many",
|
|
25
|
+
"exists",
|
|
26
|
+
"configure",
|
|
27
|
+
"configure_http",
|
|
28
|
+
"configure_remote",
|
|
29
|
+
"enrich",
|
|
30
|
+
"enrich_dois",
|
|
31
|
+
"get_mode",
|
|
32
|
+
"info",
|
|
33
|
+
# Re-exported for convenience
|
|
34
|
+
"Work",
|
|
35
|
+
"SearchResult",
|
|
36
|
+
"Config",
|
|
37
|
+
]
|
|
19
38
|
|
|
20
39
|
|
|
21
40
|
def _get_http_client():
|
|
22
41
|
"""Get HTTP client (lazy import to avoid circular dependency)."""
|
|
23
|
-
from
|
|
42
|
+
from .._remote import RemoteClient # Uses enhanced client with collections
|
|
24
43
|
|
|
25
44
|
return RemoteClient(Config.get_api_url())
|
|
26
45
|
|
|
@@ -29,6 +48,7 @@ def search(
|
|
|
29
48
|
query: str,
|
|
30
49
|
limit: int = 10,
|
|
31
50
|
offset: int = 0,
|
|
51
|
+
with_if: bool = False,
|
|
32
52
|
) -> SearchResult:
|
|
33
53
|
"""
|
|
34
54
|
Full-text search across works.
|
|
@@ -39,6 +59,7 @@ def search(
|
|
|
39
59
|
query: Search query (supports FTS5 syntax)
|
|
40
60
|
limit: Maximum results to return
|
|
41
61
|
offset: Skip first N results (for pagination)
|
|
62
|
+
with_if: Include impact factor data (OpenAlex)
|
|
42
63
|
|
|
43
64
|
Returns:
|
|
44
65
|
SearchResult with matching works
|
|
@@ -50,7 +71,7 @@ def search(
|
|
|
50
71
|
"""
|
|
51
72
|
if Config.get_mode() == "http":
|
|
52
73
|
client = _get_http_client()
|
|
53
|
-
return client.search(query=query, limit=limit)
|
|
74
|
+
return client.search(query=query, limit=limit, offset=offset, with_if=with_if)
|
|
54
75
|
return fts.search(query, limit, offset)
|
|
55
76
|
|
|
56
77
|
|
|
@@ -16,13 +16,22 @@ Usage:
|
|
|
16
16
|
network.save_html("citation_network.html")
|
|
17
17
|
"""
|
|
18
18
|
|
|
19
|
-
from dataclasses import dataclass
|
|
20
|
-
from
|
|
21
|
-
from
|
|
19
|
+
from dataclasses import dataclass as _dataclass
|
|
20
|
+
from dataclasses import field as _field
|
|
21
|
+
from typing import Dict, List, Optional, Set, Tuple
|
|
22
22
|
|
|
23
|
-
from .db import
|
|
23
|
+
from .db import Database, get_db
|
|
24
24
|
from .models import Work
|
|
25
25
|
|
|
26
|
+
__all__ = [
|
|
27
|
+
"get_citing",
|
|
28
|
+
"get_cited",
|
|
29
|
+
"get_citation_count",
|
|
30
|
+
"CitationNode",
|
|
31
|
+
"CitationEdge",
|
|
32
|
+
"CitationNetwork",
|
|
33
|
+
]
|
|
34
|
+
|
|
26
35
|
|
|
27
36
|
def get_citing(doi: str, limit: int = 100, db: Optional[Database] = None) -> List[str]:
|
|
28
37
|
"""
|
|
@@ -46,7 +55,7 @@ def get_citing(doi: str, limit: int = 100, db: Optional[Database] = None) -> Lis
|
|
|
46
55
|
WHERE cited_doi = ?
|
|
47
56
|
LIMIT ?
|
|
48
57
|
""",
|
|
49
|
-
(doi, limit)
|
|
58
|
+
(doi, limit),
|
|
50
59
|
)
|
|
51
60
|
return [row["citing_doi"] for row in rows]
|
|
52
61
|
|
|
@@ -73,7 +82,7 @@ def get_cited(doi: str, limit: int = 100, db: Optional[Database] = None) -> List
|
|
|
73
82
|
WHERE citing_doi = ?
|
|
74
83
|
LIMIT ?
|
|
75
84
|
""",
|
|
76
|
-
(doi, limit)
|
|
85
|
+
(doi, limit),
|
|
77
86
|
)
|
|
78
87
|
return [row["cited_doi"] for row in rows]
|
|
79
88
|
|
|
@@ -93,18 +102,18 @@ def get_citation_count(doi: str, db: Optional[Database] = None) -> int:
|
|
|
93
102
|
db = get_db()
|
|
94
103
|
|
|
95
104
|
row = db.fetchone(
|
|
96
|
-
"SELECT COUNT(*) as count FROM citations WHERE cited_doi = ?",
|
|
97
|
-
(doi,)
|
|
105
|
+
"SELECT COUNT(*) as count FROM citations WHERE cited_doi = ?", (doi,)
|
|
98
106
|
)
|
|
99
107
|
return row["count"] if row else 0
|
|
100
108
|
|
|
101
109
|
|
|
102
|
-
@
|
|
110
|
+
@_dataclass
|
|
103
111
|
class CitationNode:
|
|
104
112
|
"""A node in the citation network."""
|
|
113
|
+
|
|
105
114
|
doi: str
|
|
106
115
|
title: str = ""
|
|
107
|
-
authors: List[str] =
|
|
116
|
+
authors: List[str] = _field(default_factory=list)
|
|
108
117
|
year: Optional[int] = None
|
|
109
118
|
journal: str = ""
|
|
110
119
|
citation_count: int = 0
|
|
@@ -122,9 +131,10 @@ class CitationNode:
|
|
|
122
131
|
}
|
|
123
132
|
|
|
124
133
|
|
|
125
|
-
@
|
|
134
|
+
@_dataclass
|
|
126
135
|
class CitationEdge:
|
|
127
136
|
"""An edge in the citation network (citing -> cited)."""
|
|
137
|
+
|
|
128
138
|
citing_doi: str
|
|
129
139
|
cited_doi: str
|
|
130
140
|
year: Optional[int] = None
|
|
@@ -272,6 +282,8 @@ class CitationNetwork:
|
|
|
272
282
|
Raises:
|
|
273
283
|
ImportError: If pyvis is not installed
|
|
274
284
|
"""
|
|
285
|
+
import math as _math
|
|
286
|
+
|
|
275
287
|
try:
|
|
276
288
|
from pyvis.network import Network
|
|
277
289
|
except ImportError:
|
|
@@ -284,7 +296,7 @@ class CitationNetwork:
|
|
|
284
296
|
directed=True,
|
|
285
297
|
bgcolor="#ffffff",
|
|
286
298
|
font_color="#333333",
|
|
287
|
-
**kwargs
|
|
299
|
+
**kwargs,
|
|
288
300
|
)
|
|
289
301
|
|
|
290
302
|
# Configure physics
|
|
@@ -298,15 +310,16 @@ class CitationNetwork:
|
|
|
298
310
|
# Add nodes with styling based on depth and citation count
|
|
299
311
|
for doi, node in self.nodes.items():
|
|
300
312
|
# Size based on citation count (log scale)
|
|
301
|
-
|
|
302
|
-
size = 10 + min(30, math.log1p(node.citation_count) * 5)
|
|
313
|
+
size = 10 + min(30, _math.log1p(node.citation_count) * 5)
|
|
303
314
|
|
|
304
315
|
# Color based on depth
|
|
305
316
|
colors = ["#e74c3c", "#3498db", "#2ecc71", "#9b59b6", "#f39c12"]
|
|
306
317
|
color = colors[min(node.depth, len(colors) - 1)]
|
|
307
318
|
|
|
308
319
|
# Label
|
|
309
|
-
title_short = (
|
|
320
|
+
title_short = (
|
|
321
|
+
(node.title[:50] + "...") if len(node.title) > 50 else node.title
|
|
322
|
+
)
|
|
310
323
|
label = f"{title_short}\n({node.year or 'N/A'})"
|
|
311
324
|
|
|
312
325
|
# Tooltip
|
|
@@ -316,7 +329,7 @@ class CitationNetwork:
|
|
|
316
329
|
tooltip = f"""
|
|
317
330
|
<b>{node.title}</b><br>
|
|
318
331
|
{authors_str}<br>
|
|
319
|
-
{node.journal} ({node.year or
|
|
332
|
+
{node.journal} ({node.year or "N/A"})<br>
|
|
320
333
|
Citations: {node.citation_count}<br>
|
|
321
334
|
DOI: {doi}
|
|
322
335
|
"""
|
|
@@ -340,7 +353,9 @@ class CitationNetwork:
|
|
|
340
353
|
net.save_graph(path)
|
|
341
354
|
return path
|
|
342
355
|
|
|
343
|
-
def save_png(
|
|
356
|
+
def save_png(
|
|
357
|
+
self, path: str = "citation_network.png", figsize: Tuple[int, int] = (12, 10)
|
|
358
|
+
):
|
|
344
359
|
"""
|
|
345
360
|
Save static PNG visualization using matplotlib.
|
|
346
361
|
|
|
@@ -351,6 +366,8 @@ class CitationNetwork:
|
|
|
351
366
|
Raises:
|
|
352
367
|
ImportError: If matplotlib is not installed
|
|
353
368
|
"""
|
|
369
|
+
import math as _math
|
|
370
|
+
|
|
354
371
|
try:
|
|
355
372
|
import matplotlib.pyplot as plt
|
|
356
373
|
import networkx as nx
|
|
@@ -365,24 +382,34 @@ class CitationNetwork:
|
|
|
365
382
|
pos = nx.spring_layout(G, k=2, iterations=50)
|
|
366
383
|
|
|
367
384
|
# Node sizes based on citation count
|
|
368
|
-
|
|
369
|
-
|
|
385
|
+
sizes = [
|
|
386
|
+
100 + min(500, _math.log1p(self.nodes[n].citation_count) * 50)
|
|
387
|
+
for n in G.nodes()
|
|
388
|
+
]
|
|
370
389
|
|
|
371
390
|
# Node colors based on depth
|
|
372
391
|
colors = [self.nodes[n].depth for n in G.nodes()]
|
|
373
392
|
|
|
374
393
|
# Draw
|
|
375
|
-
nx.draw_networkx_nodes(
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
394
|
+
nx.draw_networkx_nodes(
|
|
395
|
+
G,
|
|
396
|
+
pos,
|
|
397
|
+
node_size=sizes,
|
|
398
|
+
node_color=colors,
|
|
399
|
+
cmap=plt.cm.RdYlBu_r,
|
|
400
|
+
alpha=0.8,
|
|
401
|
+
ax=ax,
|
|
402
|
+
)
|
|
403
|
+
nx.draw_networkx_edges(G, pos, alpha=0.3, arrows=True, arrowsize=10, ax=ax)
|
|
379
404
|
|
|
380
405
|
# Labels for important nodes (high citation count)
|
|
381
406
|
labels = {}
|
|
382
407
|
for doi in G.nodes():
|
|
383
408
|
node = self.nodes[doi]
|
|
384
409
|
if node.citation_count > 10 or doi == self.center_doi:
|
|
385
|
-
short_title = (
|
|
410
|
+
short_title = (
|
|
411
|
+
(node.title[:30] + "...") if len(node.title) > 30 else node.title
|
|
412
|
+
)
|
|
386
413
|
labels[doi] = f"{short_title}\n({node.year or 'N/A'})"
|
|
387
414
|
|
|
388
415
|
nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)
|
|
@@ -402,11 +429,13 @@ class CitationNetwork:
|
|
|
402
429
|
"center_doi": self.center_doi,
|
|
403
430
|
"depth": self.depth,
|
|
404
431
|
"nodes": [n.to_dict() for n in self.nodes.values()],
|
|
405
|
-
"edges": [
|
|
432
|
+
"edges": [
|
|
433
|
+
{"citing": e.citing_doi, "cited": e.cited_doi} for e in self.edges
|
|
434
|
+
],
|
|
406
435
|
"stats": {
|
|
407
436
|
"total_nodes": len(self.nodes),
|
|
408
437
|
"total_edges": len(self.edges),
|
|
409
|
-
}
|
|
438
|
+
},
|
|
410
439
|
}
|
|
411
440
|
|
|
412
441
|
def __repr__(self):
|
|
@@ -1,29 +1,42 @@
|
|
|
1
1
|
"""Configuration for crossref_local."""
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
|
-
from pathlib import Path
|
|
3
|
+
import os as _os
|
|
4
|
+
from pathlib import Path as _Path
|
|
5
5
|
from typing import Optional
|
|
6
6
|
|
|
7
|
+
__all__ = [
|
|
8
|
+
"Config",
|
|
9
|
+
"get_db_path",
|
|
10
|
+
"DEFAULT_PORT",
|
|
11
|
+
"DEFAULT_API_URL",
|
|
12
|
+
]
|
|
13
|
+
|
|
7
14
|
# Default database locations (checked in order)
|
|
8
15
|
DEFAULT_DB_PATHS = [
|
|
9
|
-
|
|
10
|
-
|
|
16
|
+
_Path.cwd() / "data" / "crossref.db",
|
|
17
|
+
_Path.home() / ".crossref_local" / "crossref.db",
|
|
11
18
|
]
|
|
12
19
|
|
|
13
|
-
# Default
|
|
20
|
+
# Default port: SCITEX convention (3129X scheme)
|
|
21
|
+
# 31290: scitex-cloud, 31291: crossref-local, 31292: openalex-local, 31293: audio relay
|
|
22
|
+
DEFAULT_PORT = 31291
|
|
23
|
+
|
|
24
|
+
# Default remote API URLs (checked in order)
|
|
14
25
|
DEFAULT_API_URLS = [
|
|
15
|
-
"http://localhost:
|
|
26
|
+
f"http://localhost:{DEFAULT_PORT}", # SCITEX default
|
|
27
|
+
"http://localhost:8333", # Legacy port (backwards compatibility)
|
|
16
28
|
]
|
|
17
29
|
DEFAULT_API_URL = DEFAULT_API_URLS[0]
|
|
18
30
|
|
|
19
31
|
|
|
20
|
-
def get_db_path() ->
|
|
32
|
+
def get_db_path() -> _Path:
|
|
21
33
|
"""
|
|
22
34
|
Get database path from environment or auto-detect.
|
|
23
35
|
|
|
24
36
|
Priority:
|
|
25
|
-
1.
|
|
26
|
-
2.
|
|
37
|
+
1. SCITEX_SCHOLAR_CROSSREF_DB environment variable
|
|
38
|
+
2. CROSSREF_LOCAL_DB environment variable
|
|
39
|
+
3. First existing path from DEFAULT_DB_PATHS
|
|
27
40
|
|
|
28
41
|
Returns:
|
|
29
42
|
Path to the database file
|
|
@@ -31,13 +44,15 @@ def get_db_path() -> Path:
|
|
|
31
44
|
Raises:
|
|
32
45
|
FileNotFoundError: If no database found
|
|
33
46
|
"""
|
|
34
|
-
# Check environment variable first
|
|
35
|
-
env_path =
|
|
47
|
+
# Check SCITEX environment variable first (takes priority)
|
|
48
|
+
env_path = _os.environ.get("SCITEX_SCHOLAR_CROSSREF_DB")
|
|
49
|
+
if not env_path:
|
|
50
|
+
env_path = _os.environ.get("CROSSREF_LOCAL_DB")
|
|
36
51
|
if env_path:
|
|
37
|
-
path =
|
|
52
|
+
path = _Path(env_path)
|
|
38
53
|
if path.exists():
|
|
39
54
|
return path
|
|
40
|
-
raise FileNotFoundError(f"
|
|
55
|
+
raise FileNotFoundError(f"Database path not found: {env_path}")
|
|
41
56
|
|
|
42
57
|
# Auto-detect from default locations
|
|
43
58
|
for path in DEFAULT_DB_PATHS:
|
|
@@ -53,7 +68,7 @@ def get_db_path() -> Path:
|
|
|
53
68
|
class Config:
|
|
54
69
|
"""Configuration container."""
|
|
55
70
|
|
|
56
|
-
_db_path: Optional[
|
|
71
|
+
_db_path: Optional[_Path] = None
|
|
57
72
|
_api_url: Optional[str] = None
|
|
58
73
|
_mode: str = "auto" # "auto", "db", or "http"
|
|
59
74
|
|
|
@@ -67,15 +82,18 @@ class Config:
|
|
|
67
82
|
"http" if using HTTP API
|
|
68
83
|
"""
|
|
69
84
|
if cls._mode == "auto":
|
|
70
|
-
# Check environment
|
|
71
|
-
env_mode =
|
|
85
|
+
# Check environment variables (SCITEX takes priority)
|
|
86
|
+
env_mode = _os.environ.get(
|
|
87
|
+
"SCITEX_SCHOLAR_CROSSREF_MODE",
|
|
88
|
+
_os.environ.get("CROSSREF_LOCAL_MODE", ""),
|
|
89
|
+
).lower()
|
|
72
90
|
if env_mode in ("http", "remote", "api"):
|
|
73
91
|
return "http"
|
|
74
92
|
if env_mode in ("db", "local"):
|
|
75
93
|
return "db"
|
|
76
94
|
|
|
77
95
|
# Check if API URL is set
|
|
78
|
-
if cls._api_url or
|
|
96
|
+
if cls._api_url or _os.environ.get("CROSSREF_LOCAL_API_URL"):
|
|
79
97
|
return "http"
|
|
80
98
|
|
|
81
99
|
# Check if local database exists
|
|
@@ -96,16 +114,16 @@ class Config:
|
|
|
96
114
|
cls._mode = mode
|
|
97
115
|
|
|
98
116
|
@classmethod
|
|
99
|
-
def get_db_path(cls) ->
|
|
117
|
+
def get_db_path(cls) -> _Path:
|
|
100
118
|
"""Get or auto-detect database path."""
|
|
101
119
|
if cls._db_path is None:
|
|
102
120
|
cls._db_path = get_db_path()
|
|
103
121
|
return cls._db_path
|
|
104
122
|
|
|
105
123
|
@classmethod
|
|
106
|
-
def set_db_path(cls, path: str |
|
|
124
|
+
def set_db_path(cls, path: str | _Path) -> None:
|
|
107
125
|
"""Set database path explicitly."""
|
|
108
|
-
path =
|
|
126
|
+
path = _Path(path)
|
|
109
127
|
if not path.exists():
|
|
110
128
|
raise FileNotFoundError(f"Database not found: {path}")
|
|
111
129
|
cls._db_path = path
|
|
@@ -125,7 +143,7 @@ class Config:
|
|
|
125
143
|
if cls._api_url:
|
|
126
144
|
return cls._api_url
|
|
127
145
|
|
|
128
|
-
env_url =
|
|
146
|
+
env_url = _os.environ.get("CROSSREF_LOCAL_API_URL")
|
|
129
147
|
if env_url:
|
|
130
148
|
return env_url
|
|
131
149
|
|
|
@@ -140,8 +158,8 @@ class Config:
|
|
|
140
158
|
@classmethod
|
|
141
159
|
def _find_working_api(cls) -> Optional[str]:
|
|
142
160
|
"""Try each default API URL and return first working one."""
|
|
143
|
-
import urllib.request
|
|
144
161
|
import urllib.error
|
|
162
|
+
import urllib.request
|
|
145
163
|
|
|
146
164
|
for url in DEFAULT_API_URLS:
|
|
147
165
|
try:
|
|
@@ -1,13 +1,20 @@
|
|
|
1
1
|
"""Database connection handling for crossref_local."""
|
|
2
2
|
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
import zlib
|
|
6
|
-
from contextlib import contextmanager
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
from typing import
|
|
3
|
+
import json as _json
|
|
4
|
+
import sqlite3 as _sqlite3
|
|
5
|
+
import zlib as _zlib
|
|
6
|
+
from contextlib import contextmanager as _contextmanager
|
|
7
|
+
from pathlib import Path as _Path
|
|
8
|
+
from typing import Generator, Optional
|
|
9
9
|
|
|
10
|
-
from .config import Config
|
|
10
|
+
from .config import Config as _Config
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"Database",
|
|
14
|
+
"get_db",
|
|
15
|
+
"close_db",
|
|
16
|
+
"connection",
|
|
17
|
+
]
|
|
11
18
|
|
|
12
19
|
|
|
13
20
|
class Database:
|
|
@@ -17,7 +24,7 @@ class Database:
|
|
|
17
24
|
Supports both direct usage and context manager pattern.
|
|
18
25
|
"""
|
|
19
26
|
|
|
20
|
-
def __init__(self, db_path: Optional[str |
|
|
27
|
+
def __init__(self, db_path: Optional[str | _Path] = None):
|
|
21
28
|
"""
|
|
22
29
|
Initialize database connection.
|
|
23
30
|
|
|
@@ -25,19 +32,19 @@ class Database:
|
|
|
25
32
|
db_path: Path to database. If None, auto-detects.
|
|
26
33
|
"""
|
|
27
34
|
if db_path:
|
|
28
|
-
self.db_path =
|
|
35
|
+
self.db_path = _Path(db_path)
|
|
29
36
|
else:
|
|
30
|
-
self.db_path =
|
|
37
|
+
self.db_path = _Config.get_db_path()
|
|
31
38
|
|
|
32
|
-
self.conn: Optional[
|
|
39
|
+
self.conn: Optional[_sqlite3.Connection] = None
|
|
33
40
|
self._connect()
|
|
34
41
|
|
|
35
42
|
def _connect(self) -> None:
|
|
36
43
|
"""Establish database connection."""
|
|
37
44
|
# check_same_thread=False allows connection to be used across threads
|
|
38
45
|
# Safe for read-only operations (which is our use case)
|
|
39
|
-
self.conn =
|
|
40
|
-
self.conn.row_factory =
|
|
46
|
+
self.conn = _sqlite3.connect(self.db_path, check_same_thread=False)
|
|
47
|
+
self.conn.row_factory = _sqlite3.Row
|
|
41
48
|
|
|
42
49
|
def close(self) -> None:
|
|
43
50
|
"""Close database connection."""
|
|
@@ -51,11 +58,11 @@ class Database:
|
|
|
51
58
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
52
59
|
self.close()
|
|
53
60
|
|
|
54
|
-
def execute(self, query: str, params: tuple = ()) ->
|
|
61
|
+
def execute(self, query: str, params: tuple = ()) -> _sqlite3.Cursor:
|
|
55
62
|
"""Execute SQL query."""
|
|
56
63
|
return self.conn.execute(query, params)
|
|
57
64
|
|
|
58
|
-
def fetchone(self, query: str, params: tuple = ()) -> Optional[
|
|
65
|
+
def fetchone(self, query: str, params: tuple = ()) -> Optional[_sqlite3.Row]:
|
|
59
66
|
"""Execute query and fetch one result."""
|
|
60
67
|
cursor = self.execute(query, params)
|
|
61
68
|
return cursor.fetchone()
|
|
@@ -75,10 +82,7 @@ class Database:
|
|
|
75
82
|
Returns:
|
|
76
83
|
Metadata dictionary or None
|
|
77
84
|
"""
|
|
78
|
-
row = self.fetchone(
|
|
79
|
-
"SELECT metadata FROM works WHERE doi = ?",
|
|
80
|
-
(doi,)
|
|
81
|
-
)
|
|
85
|
+
row = self.fetchone("SELECT metadata FROM works WHERE doi = ?", (doi,))
|
|
82
86
|
if row and row["metadata"]:
|
|
83
87
|
return self._decompress_metadata(row["metadata"])
|
|
84
88
|
return None
|
|
@@ -87,15 +91,15 @@ class Database:
|
|
|
87
91
|
"""Decompress and parse metadata (handles both compressed and plain JSON)."""
|
|
88
92
|
# If it's already a string, parse directly
|
|
89
93
|
if isinstance(data, str):
|
|
90
|
-
return
|
|
94
|
+
return _json.loads(data)
|
|
91
95
|
|
|
92
96
|
# If bytes, try decompression
|
|
93
97
|
if isinstance(data, bytes):
|
|
94
98
|
try:
|
|
95
|
-
decompressed =
|
|
96
|
-
return
|
|
97
|
-
except
|
|
98
|
-
return
|
|
99
|
+
decompressed = _zlib.decompress(data)
|
|
100
|
+
return _json.loads(decompressed)
|
|
101
|
+
except _zlib.error:
|
|
102
|
+
return _json.loads(data.decode("utf-8"))
|
|
99
103
|
|
|
100
104
|
return data
|
|
101
105
|
|
|
@@ -120,8 +124,10 @@ def close_db() -> None:
|
|
|
120
124
|
_db = None
|
|
121
125
|
|
|
122
126
|
|
|
123
|
-
@
|
|
124
|
-
def connection(
|
|
127
|
+
@_contextmanager
|
|
128
|
+
def connection(
|
|
129
|
+
db_path: Optional[str | _Path] = None,
|
|
130
|
+
) -> Generator[Database, None, None]:
|
|
125
131
|
"""
|
|
126
132
|
Context manager for database connection.
|
|
127
133
|
|