crossref-local 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,413 @@
1
+ """Citation network analysis and visualization.
2
+
3
+ Build citation graphs like Connected Papers from the local database.
4
+
5
+ Usage:
6
+ from crossref_local.citations import get_citing, get_cited, CitationNetwork
7
+
8
+ # Get papers citing a DOI
9
+ citing_dois = get_citing("10.1038/nature12373")
10
+
11
+ # Get papers a DOI cites
12
+ cited_dois = get_cited("10.1038/nature12373")
13
+
14
+ # Build and visualize network
15
+ network = CitationNetwork("10.1038/nature12373", depth=2)
16
+ network.save_html("citation_network.html")
17
+ """
18
+
19
+ from dataclasses import dataclass, field
20
+ from typing import List, Dict, Optional, Set, Tuple
21
+ from pathlib import Path
22
+
23
+ from .db import get_db, Database
24
+ from .models import Work
25
+
26
+
27
def get_citing(doi: str, limit: int = 100, db: Optional[Database] = None) -> List[str]:
    """
    Look up the papers that cite a given DOI.

    Args:
        doi: DOI whose incoming citations are wanted
        limit: Upper bound on the number of citing DOIs returned
        db: Database handle; the shared one from ``get_db()`` is used when omitted

    Returns:
        DOIs of the citing papers, in the order the database returns them
    """
    database = get_db() if db is None else db

    query = """
        SELECT citing_doi
        FROM citations
        WHERE cited_doi = ?
        LIMIT ?
        """
    citing_dois: List[str] = []
    for record in database.fetchall(query, (doi, limit)):
        citing_dois.append(record["citing_doi"])
    return citing_dois
52
+
53
+
54
def get_cited(doi: str, limit: int = 100, db: Optional[Database] = None) -> List[str]:
    """
    Look up the references of a given DOI (the papers it cites).

    Args:
        doi: DOI whose outgoing citations are wanted
        limit: Upper bound on the number of referenced DOIs returned
        db: Database handle; the shared one from ``get_db()`` is used when omitted

    Returns:
        DOIs of the referenced papers, in the order the database returns them
    """
    database = get_db() if db is None else db

    query = """
        SELECT cited_doi
        FROM citations
        WHERE citing_doi = ?
        LIMIT ?
        """
    cited_dois: List[str] = []
    for record in database.fetchall(query, (doi, limit)):
        cited_dois.append(record["cited_doi"])
    return cited_dois
79
+
80
+
81
def get_citation_count(doi: str, db: Optional[Database] = None) -> int:
    """
    Count how many rows of the ``citations`` table point at a DOI.

    Args:
        doi: DOI whose incoming citations are counted
        db: Database handle; the shared one from ``get_db()`` is used when omitted

    Returns:
        Number of papers citing this DOI (0 when the query yields no row)
    """
    database = get_db() if db is None else db

    result = database.fetchone(
        "SELECT COUNT(*) as count FROM citations WHERE cited_doi = ?",
        (doi,),
    )
    if not result:
        return 0
    return result["count"]
100
+
101
+
102
@dataclass
class CitationNode:
    """One paper in the citation network, keyed by its DOI."""

    doi: str
    title: str = ""
    authors: List[str] = field(default_factory=list)
    year: Optional[int] = None
    journal: str = ""
    citation_count: int = 0
    depth: int = 0  # Number of hops from the network's center node

    def to_dict(self) -> dict:
        """Return the node's fields as a plain dict, in declaration order."""
        exported = (
            "doi",
            "title",
            "authors",
            "year",
            "journal",
            "citation_count",
            "depth",
        )
        # getattr keeps the original objects (e.g. the same authors list),
        # i.e. this is a shallow export, not a deep copy.
        return {name: getattr(self, name) for name in exported}
123
+
124
+
125
@dataclass
class CitationEdge:
    """A directed link in the citation network, pointing from citing to cited paper."""

    citing_doi: str  # source of the edge: the paper doing the citing
    cited_doi: str  # target of the edge: the paper being cited
    year: Optional[int] = None
131
+
132
+
133
class CitationNetwork:
    """
    Citation network builder and visualizer.

    Builds a graph of papers connected by citations, similar to Connected Papers.
    The graph is built eagerly in ``__init__`` by a breadth-first walk that starts
    at the center DOI and follows both incoming and outgoing citations.

    Attributes:
        nodes: Mapping of DOI -> CitationNode for every paper reached
        edges: Unique citing -> cited links found during the walk

    Example:
        >>> network = CitationNetwork("10.1038/nature12373", depth=2)
        >>> print(f"Nodes: {len(network.nodes)}, Edges: {len(network.edges)}")
        >>> network.save_html("network.html")
    """

    # Node fill colors indexed by depth; deeper nodes reuse the last color.
    _DEPTH_COLORS = ("#e74c3c", "#3498db", "#2ecc71", "#9b59b6", "#f39c12")

    def __init__(
        self,
        center_doi: str,
        depth: int = 1,
        max_citing: int = 50,
        max_cited: int = 50,
        db: Optional[Database] = None,
    ):
        """
        Build a citation network around a central paper.

        Args:
            center_doi: The DOI to build the network around
            depth: How many levels of citations to include (1 = direct only)
            max_citing: Max papers citing each node to include
            max_cited: Max papers each node cites to include
            db: Database connection (uses the singleton if not provided)
        """
        self.center_doi = center_doi
        self.depth = depth
        self.max_citing = max_citing
        self.max_cited = max_cited
        # Explicit None check, matching get_citing/get_cited/get_citation_count,
        # so a database object that happens to be falsy is not replaced.
        self.db = db if db is not None else get_db()

        self.nodes: Dict[str, CitationNode] = {}
        self.edges: List[CitationEdge] = []

        self._build_network()

    def _build_network(self):
        """Build the citation network with a breadth-first walk from the center."""
        from collections import deque

        # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
        queue = deque([(self.center_doi, 0)])
        processed: Set[str] = set()
        # The same citing -> cited pair is reported from both of its endpoints
        # once both are expanded, so track pairs to keep self.edges unique.
        seen_edges: Set[Tuple[str, str]] = set()

        def record_edge(citing_doi: str, cited_doi: str) -> None:
            pair = (citing_doi, cited_doi)
            if pair in seen_edges:
                return
            seen_edges.add(pair)
            self.edges.append(CitationEdge(citing_doi=citing_doi, cited_doi=cited_doi))

        while queue:
            doi, current_depth = queue.popleft()

            if doi in processed:
                continue
            processed.add(doi)

            self._add_node(doi, current_depth)

            # Nodes at the maximum depth are kept but not expanded further.
            if current_depth >= self.depth:
                continue

            # Papers that cite this one
            for citing_doi in get_citing(doi, limit=self.max_citing, db=self.db):
                record_edge(citing_doi, doi)
                if citing_doi not in processed:
                    queue.append((citing_doi, current_depth + 1))

            # Papers this one cites
            for cited_doi in get_cited(doi, limit=self.max_cited, db=self.db):
                record_edge(doi, cited_doi)
                if cited_doi not in processed:
                    queue.append((cited_doi, current_depth + 1))

    def _add_node(self, doi: str, depth: int):
        """Add a node for ``doi``, filling in metadata when the database has it."""
        if doi in self.nodes:
            return

        metadata = self.db.get_metadata(doi)
        citation_count = get_citation_count(doi, db=self.db)

        if metadata:
            work = Work.from_metadata(doi, metadata)
            self.nodes[doi] = CitationNode(
                doi=doi,
                title=work.title or "",
                authors=work.authors,
                year=work.year,
                journal=work.journal or "",
                citation_count=citation_count,
                depth=depth,
            )
        else:
            # DOI not in our database, create minimal node
            self.nodes[doi] = CitationNode(
                doi=doi,
                citation_count=citation_count,
                depth=depth,
            )

    @staticmethod
    def _node_label(node: CitationNode, max_title: int) -> str:
        """Return a two-line label: truncated title (or DOI if untitled) and year."""
        # Nodes without metadata have an empty title; show the DOI instead of
        # a blank label so they stay identifiable in the rendering.
        text = node.title or node.doi
        if len(text) > max_title:
            text = text[:max_title] + "..."
        return f"{text}\n({node.year or 'N/A'})"

    def to_networkx(self):
        """
        Convert to a NetworkX DiGraph.

        Returns:
            networkx.DiGraph with nodes and edges

        Raises:
            ImportError: If networkx is not installed
        """
        try:
            import networkx as nx
        except ImportError as exc:
            raise ImportError("networkx required: pip install networkx") from exc

        G = nx.DiGraph()

        # Add nodes with attributes
        for doi, node in self.nodes.items():
            G.add_node(doi, **node.to_dict())

        # Add edges whose endpoints are both known nodes
        for edge in self.edges:
            if edge.citing_doi in self.nodes and edge.cited_doi in self.nodes:
                G.add_edge(edge.citing_doi, edge.cited_doi)

        return G

    def save_html(self, path: str = "citation_network.html", **kwargs):
        """
        Save interactive HTML visualization using pyvis.

        Args:
            path: Output file path
            **kwargs: Additional options for pyvis Network; these override the
                defaults used here (height, width, directed, bgcolor, font_color)

        Returns:
            The output path

        Raises:
            ImportError: If pyvis is not installed
        """
        try:
            from pyvis.network import Network
        except ImportError as exc:
            raise ImportError("pyvis required: pip install pyvis") from exc
        import math

        # Merge defaults with caller options so that passing e.g. height=...
        # overrides the default instead of raising a duplicate-keyword TypeError.
        options = {
            "height": "800px",
            "width": "100%",
            "directed": True,
            "bgcolor": "#ffffff",
            "font_color": "#333333",
        }
        options.update(kwargs)
        net = Network(**options)

        # Configure physics
        net.barnes_hut(
            gravity=-3000,
            central_gravity=0.3,
            spring_length=200,
            spring_strength=0.05,
        )

        palette = self._DEPTH_COLORS

        # Add nodes with styling based on depth and citation count
        for doi, node in self.nodes.items():
            # Size based on citation count (log scale), capped at 10 + 30
            size = 10 + min(30, math.log1p(node.citation_count) * 5)

            # Color based on depth
            color = palette[min(node.depth, len(palette) - 1)]

            label = self._node_label(node, max_title=50)

            # Tooltip
            authors_str = ", ".join(node.authors[:3])
            if len(node.authors) > 3:
                authors_str += " et al."
            tooltip = f"""
            <b>{node.title}</b><br>
            {authors_str}<br>
            {node.journal} ({node.year or 'N/A'})<br>
            Citations: {node.citation_count}<br>
            DOI: {doi}
            """

            net.add_node(
                doi,
                label=label,
                title=tooltip,
                size=size,
                color=color,
                borderWidth=2 if doi == self.center_doi else 1,
                borderWidthSelected=4,
            )

        # Add edges
        for edge in self.edges:
            if edge.citing_doi in self.nodes and edge.cited_doi in self.nodes:
                net.add_edge(edge.citing_doi, edge.cited_doi, arrows="to")

        net.save_graph(path)
        return path

    def save_png(self, path: str = "citation_network.png", figsize: Tuple[int, int] = (12, 10)):
        """
        Save static PNG visualization using matplotlib.

        Args:
            path: Output file path
            figsize: Figure size (width, height)

        Returns:
            The output path

        Raises:
            ImportError: If matplotlib or networkx is not installed
        """
        try:
            import matplotlib.pyplot as plt
            import networkx as nx
        except ImportError as exc:
            raise ImportError("matplotlib and networkx required") from exc
        import math

        G = self.to_networkx()

        fig, ax = plt.subplots(figsize=figsize)
        try:
            # Layout
            pos = nx.spring_layout(G, k=2, iterations=50)

            # Node sizes based on citation count (log scale, capped)
            sizes = [
                100 + min(500, math.log1p(self.nodes[n].citation_count) * 50)
                for n in G.nodes()
            ]

            # Node colors based on depth
            colors = [self.nodes[n].depth for n in G.nodes()]

            nx.draw_networkx_nodes(G, pos, node_size=sizes, node_color=colors,
                                   cmap=plt.cm.RdYlBu_r, alpha=0.8, ax=ax)
            nx.draw_networkx_edges(G, pos, alpha=0.3, arrows=True,
                                   arrowsize=10, ax=ax)

            # Only label the center and highly cited nodes to limit clutter
            labels = {}
            for doi in G.nodes():
                node = self.nodes[doi]
                if node.citation_count > 10 or doi == self.center_doi:
                    labels[doi] = self._node_label(node, max_title=30)

            nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)

            ax.set_title(f"Citation Network: {self.center_doi}")
            ax.axis("off")

            fig.tight_layout()
            fig.savefig(path, dpi=150, bbox_inches="tight")
        finally:
            # Close this specific figure, also on error, so it is not leaked
            # and no unrelated "current" figure is closed by accident.
            plt.close(fig)

        return path

    def to_dict(self) -> dict:
        """Export network as dictionary."""
        return {
            "center_doi": self.center_doi,
            "depth": self.depth,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [{"citing": e.citing_doi, "cited": e.cited_doi} for e in self.edges],
            "stats": {
                "total_nodes": len(self.nodes),
                "total_edges": len(self.edges),
            },
        }

    def __repr__(self):
        return f"CitationNetwork(center={self.center_doi}, nodes={len(self.nodes)}, edges={len(self.edges)})"
crossref_local/cli.py ADDED
@@ -0,0 +1,257 @@
1
+ """Command-line interface for crossref_local."""
2
+
3
+ import click
4
+ import json
5
+ import re
6
+ import sys
7
+ from typing import Optional
8
+
9
+ from . import search, get, count, info, __version__
10
+
11
+
12
+ from .impact_factor import ImpactFactorCalculator
13
+
14
+
15
+ def _strip_xml_tags(text: str) -> str:
16
+ """Strip XML/JATS tags from abstract text."""
17
+ if not text:
18
+ return text
19
+ # Remove XML tags
20
+ text = re.sub(r"<[^>]+>", " ", text)
21
+ # Collapse multiple spaces
22
+ text = re.sub(r"\s+", " ", text)
23
+ return text.strip()
24
+
25
+
26
class AliasedGroup(click.Group):
    """A ``click.Group`` whose commands can also be invoked by short aliases."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps alias -> name of the command it stands for.
        self._aliases = {}

    def command(self, *args, aliases=None, **kwargs):
        """Like ``Group.command``, plus an optional ``aliases`` list to register."""
        # Zero-argument super() is not usable inside the nested function
        # (its first argument is not self), so bind the parent's decorator here.
        register = super().command(*args, **kwargs)

        def decorator(f):
            cmd = register(f)
            for alias in aliases or ():
                self._aliases[alias] = cmd.name
            return cmd

        return decorator

    def get_command(self, ctx, cmd_name):
        """Look up a command, translating an alias to its real name first."""
        return super().get_command(ctx, self._aliases.get(cmd_name, cmd_name))

    def format_commands(self, ctx, formatter):
        """Write the command list, showing each command's aliases next to its name."""
        # Invert alias -> command into command -> [aliases], keeping
        # registration order within each command.
        aliases_for = {}
        for alias, target in self._aliases.items():
            aliases_for.setdefault(target, []).append(alias)

        rows = []
        for name in self.list_commands(ctx):
            cmd = self.get_command(ctx, name)
            if cmd is None or cmd.hidden:
                continue

            nicknames = aliases_for.get(name)
            shown = f"{name} ({', '.join(nicknames)})" if nicknames else name
            rows.append((shown, cmd.get_short_help_str(limit=50)))

        if rows:
            with formatter.section("Commands"):
                formatter.write_dl(rows)
69
+
70
+
71
# Shared click context settings: accept -h as well as --help.
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


# Root command group; subcommands attach themselves via @cli.command(...).
@click.group(cls=AliasedGroup, context_settings=CONTEXT_SETTINGS)
@click.version_option(version=__version__, prog_name="crossref-local")
def cli():
    """Local CrossRef database with 167M+ works and full-text search."""
79
+
80
+
81
@cli.command(aliases=["s"], context_settings=CONTEXT_SETTINGS)
@click.argument("query")
@click.option("-n", "--limit", default=10, help="Number of results")
@click.option("-o", "--offset", default=0, help="Skip first N results")
@click.option("-a", "--with-abstracts", is_flag=True, help="Show abstracts")
@click.option("--json", "as_json", is_flag=True, help="Output as JSON")
def search_cmd(query: str, limit: int, offset: int, with_abstracts: bool, as_json: bool):
    """Search for works by title, abstract, or authors."""
    results = search(query, limit=limit, offset=offset)

    # Machine-readable mode: dump the whole result set and stop.
    if as_json:
        payload = {
            "query": results.query,
            "total": results.total,
            "elapsed_ms": results.elapsed_ms,
            "works": [w.to_dict() for w in results.works],
        }
        click.echo(json.dumps(payload, indent=2))
        return

    click.echo(f"Found {results.total:,} matches in {results.elapsed_ms:.1f}ms\n")

    # Number results by their absolute rank, so paging with --offset continues
    # the count instead of restarting at 1.
    rank = offset
    for work in results.works:
        rank += 1
        title = _strip_xml_tags(work.title) if work.title else "Untitled"
        year = f"({work.year})" if work.year else ""
        click.echo(f"{rank}. {title} {year}")
        click.echo(f" DOI: {work.doi}")
        if work.journal:
            click.echo(f" Journal: {work.journal}")
        if with_abstracts and work.abstract:
            # Abstracts may carry markup; show plain text capped at 500 chars.
            abstract = _strip_xml_tags(work.abstract)
            if len(abstract) > 500:
                abstract = abstract[:500] + "..."
            click.echo(f" Abstract: {abstract}")
        click.echo()
115
+
116
+
117
@cli.command("get", aliases=["g"], context_settings=CONTEXT_SETTINGS)
@click.argument("doi")
@click.option("--json", "as_json", is_flag=True, help="Output as JSON")
@click.option("--citation", is_flag=True, help="Output as citation")
def get_cmd(doi: str, as_json: bool, citation: bool):
    """Get a work by DOI."""
    work = get(doi)

    # Unknown DOI: report on stderr and exit with a failing status.
    if work is None:
        click.echo(f"DOI not found: {doi}", err=True)
        sys.exit(1)

    # --json takes precedence over --citation when both are given.
    if as_json:
        click.echo(json.dumps(work.to_dict(), indent=2))
        return
    if citation:
        click.echo(work.citation())
        return

    # Default: human-readable summary, one field per line.
    click.echo(f"Title: {work.title}")
    author_list = ", ".join(work.authors)
    click.echo(f"Authors: {author_list}")
    click.echo(f"Year: {work.year}")
    click.echo(f"Journal: {work.journal}")
    click.echo(f"DOI: {work.doi}")
    if work.citation_count:
        click.echo(f"Citations: {work.citation_count}")
141
+
142
+
143
@cli.command(aliases=["c"], context_settings=CONTEXT_SETTINGS)
@click.argument("query")
def count_cmd(query: str):
    """Count matching works."""
    # Print the match count with thousands separators.
    click.echo(f"{count(query):,}")
149
+
150
+
151
@cli.command(aliases=["i"], context_settings=CONTEXT_SETTINGS)
@click.option("--json", "as_json", is_flag=True, help="Output as JSON")
def info_cmd(as_json: bool):
    """Show database information."""
    db_info = info()

    if as_json:
        click.echo(json.dumps(db_info, indent=2))
        return

    # Plain-text report: heading, rule, then one line per statistic.
    report = (
        "CrossRef Local Database",
        "-" * 40,
        f"Database: {db_info['db_path']}",
        f"Works: {db_info['works']:,}",
        f"FTS indexed: {db_info['fts_indexed']:,}",
        f"Citations: {db_info['citations']:,}",
    )
    for line in report:
        click.echo(line)
166
+
167
+
168
@cli.command("impact-factor", aliases=["if"], context_settings=CONTEXT_SETTINGS)
@click.argument("journal")
@click.option("-y", "--year", default=2023, help="Target year")
@click.option("-w", "--window", default=2, help="Citation window years")
@click.option("--json", "as_json", is_flag=True, help="Output as JSON")
def impact_factor_cmd(journal: str, year: int, window: int, as_json: bool):
    """Calculate impact factor for a journal."""
    # The calculator is a context manager; it is only held for the calculation
    # itself, and the result is printed after it has been released.
    with ImpactFactorCalculator() as calc:
        result = calc.calculate_impact_factor(
            journal_identifier=journal,
            target_year=year,
            window_years=window,
        )

    if as_json:
        click.echo(json.dumps(result, indent=2))
        return

    report = (
        f"Journal: {result['journal']}",
        f"Year: {result['target_year']}",
        f"Window: {result['window_range']}",
        f"Articles: {result['total_articles']:,}",
        f"Citations: {result['total_citations']:,}",
        f"Impact Factor: {result['impact_factor']:.3f}",
    )
    for line in report:
        click.echo(line)
191
+
192
+
193
@cli.command(context_settings=CONTEXT_SETTINGS)
def setup():
    """Check setup status and configuration."""
    from .config import DEFAULT_DB_PATHS
    import os

    click.echo("CrossRef Local - Setup Status")
    click.echo("=" * 50)
    click.echo()

    # Check environment variable
    env_db = os.environ.get("CROSSREF_LOCAL_DB")
    env_db_ok = False
    if env_db:
        click.echo(f"CROSSREF_LOCAL_DB: {env_db}")
        env_db_ok = os.path.exists(env_db)
        click.echo(" Status: OK" if env_db_ok else " Status: NOT FOUND")
    else:
        click.echo("CROSSREF_LOCAL_DB: (not set)")

    click.echo()

    # A database that CROSSREF_LOCAL_DB points to counts as found. Otherwise a
    # valid environment setting outside the default locations would be
    # reported as "No database found!" right after "Status: OK".
    db_found = env_db if env_db_ok else None

    # Check default paths
    click.echo("Checking default database locations:")
    for path in DEFAULT_DB_PATHS:
        if path.exists():
            click.echo(f" [OK] {path}")
            if db_found is None:
                db_found = path
        else:
            click.echo(f" [ ] {path}")

    click.echo()

    if db_found:
        click.echo(f"Database found: {db_found}")
        click.echo()

        # Best-effort statistics: a database that cannot be read is reported
        # on stderr instead of aborting the status check with a traceback.
        try:
            db_info = info()
            click.echo(f" Works: {db_info['works']:,}")
            click.echo(f" FTS indexed: {db_info['fts_indexed']:,}")
            click.echo(f" Citations: {db_info['citations']:,}")
            click.echo()
            click.echo("Setup complete! Try:")
            click.echo(' crossref-local search "machine learning"')
        except Exception as e:
            click.echo(f" Error reading database: {e}", err=True)
    else:
        click.echo("No database found!")
        click.echo()
        click.echo("To set up:")
        click.echo(" export CROSSREF_LOCAL_DB=/path/to/crossref.db")
        click.echo(" See: make db-build-info")
249
+
250
+
251
def main():
    """Run the ``cli`` command group; used as the console entry point."""
    cli()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()