crossref-local 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,413 @@
1
+ """Citation network analysis and visualization.
2
+
3
+ Build citation graphs like Connected Papers from the local database.
4
+
5
+ Usage:
6
+ from crossref_local.citations import get_citing, get_cited, CitationNetwork
7
+
8
+ # Get papers citing a DOI
9
+ citing_dois = get_citing("10.1038/nature12373")
10
+
11
+ # Get papers a DOI cites
12
+ cited_dois = get_cited("10.1038/nature12373")
13
+
14
+ # Build and visualize network
15
+ network = CitationNetwork("10.1038/nature12373", depth=2)
16
+ network.save_html("citation_network.html")
17
+ """
18
+
19
+ from dataclasses import dataclass, field
20
+ from typing import List, Dict, Optional, Set, Tuple
21
+ from pathlib import Path
22
+
23
+ from .db import get_db, Database
24
+ from .models import Work
25
+
26
+
27
def get_citing(doi: str, limit: int = 100, db: Optional[Database] = None) -> List[str]:
    """
    Return the DOIs of works that reference ``doi``.

    Args:
        doi: DOI of the cited (target) paper
        limit: Upper bound on how many citing DOIs are returned
        db: Database handle; the shared ``get_db()`` instance is used when omitted

    Returns:
        Citing DOIs, in whatever order the database yields them (no ORDER BY)
    """
    conn = db if db is not None else get_db()

    query = """
        SELECT citing_doi
        FROM citations
        WHERE cited_doi = ?
        LIMIT ?
    """
    records = conn.fetchall(query, (doi, limit))
    return [record["citing_doi"] for record in records]
52
+
53
+
54
def get_cited(doi: str, limit: int = 100, db: Optional[Database] = None) -> List[str]:
    """
    Return the DOIs that ``doi`` itself references (its bibliography).

    Args:
        doi: DOI of the citing (source) paper
        limit: Upper bound on how many referenced DOIs are returned
        db: Database handle; the shared ``get_db()`` instance is used when omitted

    Returns:
        Referenced DOIs, in whatever order the database yields them (no ORDER BY)
    """
    conn = db if db is not None else get_db()

    query = """
        SELECT cited_doi
        FROM citations
        WHERE citing_doi = ?
        LIMIT ?
    """
    records = conn.fetchall(query, (doi, limit))
    return [record["cited_doi"] for record in records]
79
+
80
+
81
def get_citation_count(doi: str, db: Optional[Database] = None) -> int:
    """
    Count the rows in ``citations`` that name ``doi`` as the cited work.

    Args:
        doi: DOI whose incoming citations are counted
        db: Database handle; falls back to ``get_db()`` when omitted

    Returns:
        The count, or 0 if the query yields no row
    """
    conn = db if db is not None else get_db()

    result = conn.fetchone(
        "SELECT COUNT(*) as count FROM citations WHERE cited_doi = ?",
        (doi,),
    )
    if not result:
        return 0
    return result["count"]
100
+
101
+
102
@dataclass
class CitationNode:
    """A node (paper) in the citation network."""

    doi: str
    title: str = ""
    authors: List[str] = field(default_factory=list)
    year: Optional[int] = None
    journal: str = ""
    citation_count: int = 0
    depth: int = 0  # Distance (in citation hops) from the center node

    def to_dict(self) -> dict:
        """Return the node's fields as a plain dictionary.

        The ``authors`` list is copied so callers (e.g. graph attribute
        stores or JSON exporters) cannot mutate this node through the result.
        """
        return {
            "doi": self.doi,
            "title": self.title,
            "authors": list(self.authors),
            "year": self.year,
            "journal": self.journal,
            "citation_count": self.citation_count,
            "depth": self.depth,
        }
123
+
124
+
125
@dataclass
class CitationEdge:
    """Directed citation link: the paper ``citing_doi`` references ``cited_doi``."""

    citing_doi: str  # DOI of the paper that contains the reference
    cited_doi: str  # DOI of the paper being referenced
    year: Optional[int] = field(default=None)  # optional year; None unless supplied
131
+
132
+
133
class CitationNetwork:
    """
    Citation network builder and visualizer.

    Builds a graph of papers connected by citations, similar to Connected Papers.
    Starting from ``center_doi``, the graph is explored breadth-first in both
    directions (papers citing a node, and papers it cites) up to ``depth`` hops.
    Each distinct citing -> cited pair is stored once in ``edges``, even though
    it is usually discovered from both of its endpoints.

    Example:
        >>> network = CitationNetwork("10.1038/nature12373", depth=2)
        >>> print(f"Nodes: {len(network.nodes)}, Edges: {len(network.edges)}")
        >>> network.save_html("network.html")
    """

    def __init__(
        self,
        center_doi: str,
        depth: int = 1,
        max_citing: int = 50,
        max_cited: int = 50,
        db: Optional[Database] = None,
    ):
        """
        Build a citation network around a central paper.

        The network is built eagerly: the database is queried and ``nodes`` /
        ``edges`` are populated before the constructor returns.

        Args:
            center_doi: The DOI to build the network around
            depth: How many levels of citations to include (1 = direct only,
                0 = the center node alone)
            max_citing: Max papers citing each node to include
            max_cited: Max papers each node cites to include
            db: Database connection (uses singleton if not provided)
        """
        self.center_doi = center_doi
        self.depth = depth
        self.max_citing = max_citing
        self.max_cited = max_cited
        # Explicit None check (like get_citing/get_cited) so a supplied
        # connection is never replaced merely because it evaluates falsy.
        self.db = db if db is not None else get_db()

        self.nodes: Dict[str, CitationNode] = {}
        self.edges: List[CitationEdge] = []
        # (citing_doi, cited_doi) pairs already present in self.edges.
        self._edge_keys: Set[Tuple[str, str]] = set()

        self._build_network()

    def _build_network(self):
        """Populate nodes and edges by breadth-first traversal from the center.

        Each DOI is queued at most once. Because the queue is FIFO, the depth a
        DOI is first queued with is its shortest hop distance from the center.
        Nodes at ``self.depth`` are added but not expanded further.
        """
        from collections import deque

        queue = deque([(self.center_doi, 0)])
        seen: Set[str] = {self.center_doi}

        def enqueue(neighbor: str, neighbor_depth: int):
            # Mark at enqueue time so the same DOI never sits in the queue twice.
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append((neighbor, neighbor_depth))

        while queue:
            doi, current_depth = queue.popleft()
            self._add_node(doi, current_depth)

            # Stop expanding at max depth
            if current_depth >= self.depth:
                continue

            # Papers that cite this one: edge neighbor -> doi
            for citing_doi in get_citing(doi, limit=self.max_citing, db=self.db):
                self._add_edge(citing_doi, doi)
                enqueue(citing_doi, current_depth + 1)

            # Papers this one cites: edge doi -> neighbor
            for cited_doi in get_cited(doi, limit=self.max_cited, db=self.db):
                self._add_edge(doi, cited_doi)
                enqueue(cited_doi, current_depth + 1)

    def _add_edge(self, citing_doi: str, cited_doi: str):
        """Record a citing -> cited edge unless the same pair is already stored.

        Without this check, an edge between two expanded nodes would be added
        twice: once via get_cited on the citing side and once via get_citing
        on the cited side.
        """
        key = (citing_doi, cited_doi)
        if key in self._edge_keys:
            return
        self._edge_keys.add(key)
        self.edges.append(CitationEdge(citing_doi=citing_doi, cited_doi=cited_doi))

    def _add_node(self, doi: str, depth: int):
        """Add ``doi`` as a node, filling in metadata from the database if present.

        Does nothing if the node already exists. DOIs without metadata in the
        local database still get a minimal node (DOI, citation count, depth).
        """
        if doi in self.nodes:
            return

        metadata = self.db.get_metadata(doi)
        citation_count = get_citation_count(doi, db=self.db)

        if metadata:
            work = Work.from_metadata(doi, metadata)
            self.nodes[doi] = CitationNode(
                doi=doi,
                title=work.title or "",
                # Copy (so the node doesn't share the Work's list) and guard a
                # missing author list the same way title/journal are guarded.
                authors=list(work.authors or []),
                year=work.year,
                journal=work.journal or "",
                citation_count=citation_count,
                depth=depth,
            )
        else:
            # DOI not in our database, create minimal node
            self.nodes[doi] = CitationNode(
                doi=doi,
                citation_count=citation_count,
                depth=depth,
            )

    def to_networkx(self):
        """
        Convert to a NetworkX DiGraph.

        Returns:
            networkx.DiGraph with one node per paper (attributes from
            ``CitationNode.to_dict``) and one edge per citing -> cited pair

        Raises:
            ImportError: If networkx is not installed
        """
        try:
            import networkx as nx
        except ImportError as err:
            raise ImportError("networkx required: pip install networkx") from err

        G = nx.DiGraph()

        # Add nodes with attributes
        for doi, node in self.nodes.items():
            G.add_node(doi, **node.to_dict())

        # Add edges (only between known nodes)
        for edge in self.edges:
            if edge.citing_doi in self.nodes and edge.cited_doi in self.nodes:
                G.add_edge(edge.citing_doi, edge.cited_doi)

        return G

    def save_html(self, path: str = "citation_network.html", **kwargs):
        """
        Save interactive HTML visualization using pyvis.

        Args:
            path: Output file path
            **kwargs: Additional options for pyvis ``Network``. These may also
                override the defaults used here (``height``, ``width``,
                ``directed``, ``bgcolor``, ``font_color``).

        Returns:
            The ``path`` that was written.

        Raises:
            ImportError: If pyvis is not installed
        """
        try:
            from pyvis.network import Network
        except ImportError as err:
            raise ImportError("pyvis required: pip install pyvis") from err

        import math

        # Defaults first, then caller options, so passing e.g. height= in
        # kwargs overrides the default instead of raising a duplicate-keyword
        # TypeError.
        options = {
            "height": "800px",
            "width": "100%",
            "directed": True,
            "bgcolor": "#ffffff",
            "font_color": "#333333",
        }
        options.update(kwargs)
        net = Network(**options)

        # Configure physics
        net.barnes_hut(
            gravity=-3000,
            central_gravity=0.3,
            spring_length=200,
            spring_strength=0.05,
        )

        # Colour by depth from the center; deeper levels reuse the last colour.
        colors = ["#e74c3c", "#3498db", "#2ecc71", "#9b59b6", "#f39c12"]

        for doi, node in self.nodes.items():
            # Size grows with citation count on a log scale, capped at +30.
            size = 10 + min(30, math.log1p(node.citation_count) * 5)
            color = colors[min(node.depth, len(colors) - 1)]

            # Papers without local metadata have no title; show the DOI instead.
            display_title = node.title or doi
            if len(display_title) > 50:
                title_short = display_title[:50] + "..."
            else:
                title_short = display_title
            label = f"{title_short}\n({node.year or 'N/A'})"

            # Tooltip
            authors_str = ", ".join(node.authors[:3])
            if len(node.authors) > 3:
                authors_str += " et al."
            tooltip = (
                f"<b>{display_title}</b><br>"
                f"{authors_str}<br>"
                f"{node.journal} ({node.year or 'N/A'})<br>"
                f"Citations: {node.citation_count}<br>"
                f"DOI: {doi}"
            )

            net.add_node(
                doi,
                label=label,
                title=tooltip,
                size=size,
                color=color,
                borderWidth=2 if doi == self.center_doi else 1,
                borderWidthSelected=4,
            )

        # Add edges (only between known nodes)
        for edge in self.edges:
            if edge.citing_doi in self.nodes and edge.cited_doi in self.nodes:
                net.add_edge(edge.citing_doi, edge.cited_doi, arrows="to")

        net.save_graph(path)
        return path

    def save_png(self, path: str = "citation_network.png", figsize: Tuple[int, int] = (12, 10)):
        """
        Save static PNG visualization using matplotlib.

        Args:
            path: Output file path
            figsize: Figure size (width, height)

        Returns:
            The ``path`` that was written.

        Raises:
            ImportError: If matplotlib or networkx is not installed
        """
        try:
            import matplotlib.pyplot as plt
            import networkx as nx
        except ImportError as err:
            raise ImportError("matplotlib and networkx required") from err

        import math

        G = self.to_networkx()

        fig, ax = plt.subplots(figsize=figsize)
        try:
            # Layout
            pos = nx.spring_layout(G, k=2, iterations=50)

            # Node sizes based on citation count (log scale, capped)
            sizes = [
                100 + min(500, math.log1p(self.nodes[n].citation_count) * 50)
                for n in G.nodes()
            ]

            # Node colors based on depth
            colors = [self.nodes[n].depth for n in G.nodes()]

            nx.draw_networkx_nodes(G, pos, node_size=sizes, node_color=colors,
                                   cmap=plt.cm.RdYlBu_r, alpha=0.8, ax=ax)
            nx.draw_networkx_edges(G, pos, alpha=0.3, arrows=True,
                                   arrowsize=10, ax=ax)

            # Label only well-cited nodes and the center, to limit clutter.
            labels = {}
            for doi in G.nodes():
                node = self.nodes[doi]
                if node.citation_count > 10 or doi == self.center_doi:
                    # Fall back to the DOI for papers without local metadata.
                    display_title = node.title or doi
                    if len(display_title) > 30:
                        short_title = display_title[:30] + "..."
                    else:
                        short_title = display_title
                    labels[doi] = f"{short_title}\n({node.year or 'N/A'})"

            nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)

            ax.set_title(f"Citation Network: {self.center_doi}")
            ax.axis("off")

            fig.tight_layout()
            fig.savefig(path, dpi=150, bbox_inches="tight")
        finally:
            # Close this specific figure (not just the "current" one), even if
            # drawing or saving failed, so repeated calls don't leak figures.
            plt.close(fig)

        return path

    def to_dict(self) -> dict:
        """Export the network as a plain dictionary (nodes, edges and counts)."""
        return {
            "center_doi": self.center_doi,
            "depth": self.depth,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [{"citing": e.citing_doi, "cited": e.cited_doi} for e in self.edges],
            "stats": {
                "total_nodes": len(self.nodes),
                "total_edges": len(self.edges),
            },
        }

    def __repr__(self):
        return f"CitationNetwork(center={self.center_doi}, nodes={len(self.nodes)}, edges={len(self.edges)})"