cite-agent 1.3.9-py3-none-any.whl → 1.4.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cite_agent/__init__.py +13 -13
- cite_agent/__version__.py +1 -1
- cite_agent/action_first_mode.py +150 -0
- cite_agent/adaptive_providers.py +413 -0
- cite_agent/archive_api_client.py +186 -0
- cite_agent/auth.py +0 -1
- cite_agent/auto_expander.py +70 -0
- cite_agent/cache.py +379 -0
- cite_agent/circuit_breaker.py +370 -0
- cite_agent/citation_network.py +377 -0
- cite_agent/cli.py +8 -16
- cite_agent/cli_conversational.py +113 -3
- cite_agent/confidence_calibration.py +381 -0
- cite_agent/deduplication.py +325 -0
- cite_agent/enhanced_ai_agent.py +689 -371
- cite_agent/error_handler.py +228 -0
- cite_agent/execution_safety.py +329 -0
- cite_agent/full_paper_reader.py +239 -0
- cite_agent/observability.py +398 -0
- cite_agent/offline_mode.py +348 -0
- cite_agent/paper_comparator.py +368 -0
- cite_agent/paper_summarizer.py +420 -0
- cite_agent/pdf_extractor.py +350 -0
- cite_agent/proactive_boundaries.py +266 -0
- cite_agent/quality_gate.py +442 -0
- cite_agent/request_queue.py +390 -0
- cite_agent/response_enhancer.py +257 -0
- cite_agent/response_formatter.py +458 -0
- cite_agent/response_pipeline.py +295 -0
- cite_agent/response_style_enhancer.py +259 -0
- cite_agent/self_healing.py +418 -0
- cite_agent/similarity_finder.py +524 -0
- cite_agent/streaming_ui.py +13 -9
- cite_agent/thinking_blocks.py +308 -0
- cite_agent/tool_orchestrator.py +416 -0
- cite_agent/trend_analyzer.py +540 -0
- cite_agent/unpaywall_client.py +226 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/METADATA +15 -1
- cite_agent-1.4.3.dist-info/RECORD +62 -0
- cite_agent-1.3.9.dist-info/RECORD +0 -32
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/WHEEL +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/entry_points.txt +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/licenses/LICENSE +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/top_level.txt +0 -0
cite_agent/confidence_calibration.py (new file)

@@ -0,0 +1,381 @@
+"""
+Confidence Calibration - Know When Uncertain
+Assesses confidence in responses and adds appropriate caveats
+
+This prevents overconfident wrong answers
+"""
+
+import logging
+import re
+from typing import Dict, Any, Optional
+from dataclasses import dataclass
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ConfidenceAssessment:
+    """Result of confidence assessment"""
+    confidence_score: float  # 0.0-1.0
+    confidence_level: str  # "high", "medium", "low"
+    should_add_caveat: bool
+    caveat_text: Optional[str]
+    factors: Dict[str, float]  # What contributed to confidence
+    reasoning: str
+
+
+class ConfidenceCalibrator:
+    """
+    Assesses confidence in responses
+
+    Factors that affect confidence:
+    1. Data quality (Do we have good data?)
+    2. Query clarity (Is the question clear?)
+    3. Answer completeness (Did we answer fully?)
+    4. Source reliability (Are sources trustworthy?)
+    5. Response consistency (Does answer make sense?)
+    """
+
+    # Confidence thresholds
+    HIGH_CONFIDENCE = 0.8
+    MEDIUM_CONFIDENCE = 0.6
+    LOW_CONFIDENCE = 0.4
+
+    @classmethod
+    def assess_confidence(
+        cls,
+        response: str,
+        query: str,
+        context: Dict[str, Any]
+    ) -> ConfidenceAssessment:
+        """
+        Assess confidence in a response
+
+        Args:
+            response: The generated response
+            query: Original user query
+            context: Context including tools used, data sources, etc.
+
+        Returns:
+            Confidence assessment with score and caveat if needed
+        """
+        factors = {}
+
+        # Factor 1: Data quality (40% weight)
+        factors['data_quality'] = cls._assess_data_quality(context)
+
+        # Factor 2: Query clarity (20% weight)
+        factors['query_clarity'] = cls._assess_query_clarity(query)
+
+        # Factor 3: Answer completeness (25% weight)
+        factors['answer_completeness'] = cls._assess_completeness(response, query)
+
+        # Factor 4: Source reliability (10% weight)
+        factors['source_reliability'] = cls._assess_source_reliability(context)
+
+        # Factor 5: Response consistency (5% weight)
+        factors['response_consistency'] = cls._assess_consistency(response)
+
+        # Calculate weighted confidence score
+        weights = {
+            'data_quality': 0.40,
+            'query_clarity': 0.20,
+            'answer_completeness': 0.25,
+            'source_reliability': 0.10,
+            'response_consistency': 0.05
+        }
+
+        confidence_score = sum(factors[k] * weights[k] for k in weights)
+
+        # Determine confidence level
+        if confidence_score >= cls.HIGH_CONFIDENCE:
+            confidence_level = "high"
+        elif confidence_score >= cls.MEDIUM_CONFIDENCE:
+            confidence_level = "medium"
+        else:
+            confidence_level = "low"
+
+        # Determine if we should add a caveat
+        should_add_caveat = confidence_score < cls.MEDIUM_CONFIDENCE
+
+        # Generate caveat text if needed
+        caveat_text = None
+        if should_add_caveat:
+            caveat_text = cls._generate_caveat(confidence_score, factors, context)
+
+        # Generate reasoning
+        reasoning = cls._explain_confidence(confidence_score, factors)
+
+        return ConfidenceAssessment(
+            confidence_score=confidence_score,
+            confidence_level=confidence_level,
+            should_add_caveat=should_add_caveat,
+            caveat_text=caveat_text,
+            factors=factors,
+            reasoning=reasoning
+        )
+
+    @classmethod
+    def _assess_data_quality(cls, context: Dict[str, Any]) -> float:
+        """
+        Assess quality of data used
+
+        High quality: Multiple reliable sources, recent data
+        Low quality: No sources, old data, incomplete data
+        """
+        score = 0.5  # Neutral start
+
+        tools_used = context.get('tools_used', [])
+        api_results = context.get('api_results', {})
+
+        # Check if we have data sources
+        has_data = bool(api_results) or bool(tools_used)
+
+        if not has_data:
+            return 0.3  # Low confidence if no data
+
+        # Check quality of data sources
+        reliable_sources = ['archive_api', 'finsight_api', 'shell_execution']
+        unreliable_sources = ['web_search']
+
+        reliable_count = sum(1 for tool in tools_used if tool in reliable_sources)
+        unreliable_count = sum(1 for tool in tools_used if tool in unreliable_sources)
+
+        if reliable_count > 0:
+            score += 0.3
+
+        if unreliable_count > 0 and reliable_count == 0:
+            score -= 0.2
+
+        # Check for empty results
+        if api_results:
+            # Check if results are empty
+            for key, value in api_results.items():
+                if isinstance(value, dict) and value.get('results') == []:
+                    score -= 0.4  # Major penalty for empty results
+                elif isinstance(value, list) and not value:
+                    score -= 0.3
+
+        return min(1.0, max(0.0, score))
+
+    @classmethod
+    def _assess_query_clarity(cls, query: str) -> float:
+        """
+        Assess how clear the query is
+
+        Clear query: Specific, unambiguous
+        Unclear query: Vague, pronouns without context, too short
+        """
+        score = 1.0  # Start optimistic
+
+        # Too short is often ambiguous
+        word_count = len(query.split())
+        if word_count < 3:
+            score -= 0.3
+
+        # Pronouns without context
+        pronouns = ['it', 'that', 'those', 'this', 'them']
+        has_pronoun = any(pronoun in query.lower().split() for pronoun in pronouns)
+
+        if has_pronoun and word_count < 8:
+            score -= 0.2
+
+        # Very vague terms
+        vague_terms = ['something', 'stuff', 'things', 'anything']
+        if any(term in query.lower() for term in vague_terms):
+            score -= 0.25
+
+        # Question marks without clear question
+        if '?' in query and word_count < 4:
+            score -= 0.15
+
+        return min(1.0, max(0.0, score))
+
+    @classmethod
+    def _assess_completeness(cls, response: str, query: str) -> float:
+        """
+        Assess if response fully answers the query
+
+        Complete: Addresses all aspects of query
+        Incomplete: Partial answer, deflects, says "I don't know"
+        """
+        score = 0.7  # Assume mostly complete
+
+        response_lower = response.lower()
+
+        # Check for deflection
+        deflection_phrases = [
+            "i don't know",
+            "i'm not sure",
+            "i don't have",
+            "i can't",
+            "unclear"
+        ]
+
+        if any(phrase in response_lower for phrase in deflection_phrases):
+            score -= 0.4
+
+        # Check response length relative to query complexity
+        query_words = len(query.split())
+        response_words = len(response.split())
+
+        if query_words > 10 and response_words < 30:
+            score -= 0.2  # Too brief for complex query
+
+        # Check if response addresses key terms from query
+        query_terms = set(query.lower().split())
+        response_terms = set(response_lower.split())
+
+        # Remove stop words
+        stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'what', 'how', 'why'}
+        query_terms -= stop_words
+
+        if query_terms:
+            overlap = len(query_terms & response_terms) / len(query_terms)
+            if overlap < 0.3:
+                score -= 0.25
+
+        return min(1.0, max(0.0, score))
+
+    @classmethod
+    def _assess_source_reliability(cls, context: Dict[str, Any]) -> float:
+        """
+        Assess reliability of sources used
+
+        Reliable: Official APIs, verified databases
+        Unreliable: Web scraping, unverified sources
+        """
+        tools_used = context.get('tools_used', [])
+
+        if not tools_used:
+            return 0.5  # Neutral if no tools used
+
+        # Reliability scores for different tools
+        reliability_map = {
+            'archive_api': 0.95,  # Academic papers - very reliable
+            'finsight_api': 0.90,  # SEC filings - very reliable
+            'shell_execution': 0.85,  # Direct file access - reliable
+            'web_search': 0.60,  # Web - less reliable
+            'fallback': 0.30  # Fallback responses - low reliability
+        }
+
+        # Average reliability of tools used
+        reliabilities = [reliability_map.get(tool, 0.5) for tool in tools_used]
+        avg_reliability = sum(reliabilities) / len(reliabilities) if reliabilities else 0.5
+
+        return avg_reliability
+
+    @classmethod
+    def _assess_consistency(cls, response: str) -> float:
+        """
+        Assess internal consistency of response
+
+        Consistent: No contradictions, logical flow
+        Inconsistent: Contradictions, illogical statements
+        """
+        score = 1.0  # Assume consistent
+
+        # Check for hedge words that indicate uncertainty
+        hedge_words = ['maybe', 'possibly', 'might', 'could', 'perhaps']
+        hedge_count = sum(1 for word in hedge_words if f' {word} ' in response.lower())
+
+        if hedge_count > 3:
+            score -= 0.2  # Too many hedges suggests uncertainty
+
+        # Check for contradictions
+        contradiction_patterns = [
+            ('but ', 'however '),
+            ('although ', 'though '),
+            ('not ', 'no ')
+        ]
+
+        contradiction_count = sum(
+            1 for pattern_words in contradiction_patterns
+            if all(word in response.lower() for word in pattern_words)
+        )
+
+        if contradiction_count > 2:
+            score -= 0.15
+
+        return min(1.0, max(0.0, score))
+
+    @classmethod
+    def _generate_caveat(cls, confidence_score: float, factors: Dict[str, float], context: Dict[str, Any]) -> str:
+        """
+        Generate appropriate caveat text based on confidence factors
+        """
+        # Identify the weakest factor
+        weakest_factor = min(factors.items(), key=lambda x: x[1])
+        factor_name, factor_score = weakest_factor
+
+        caveats = {
+            'data_quality': "Based on limited data available, ",
+            'query_clarity': "I interpreted your question to mean: ",
+            'answer_completeness': "Based on what I could find, ",
+            'source_reliability': "According to available sources, ",
+            'response_consistency': "To the best of my understanding, "
+        }
+
+        caveat_prefix = caveats.get(factor_name, "Based on available information, ")
+
+        # Add severity based on overall confidence
+        if confidence_score < cls.LOW_CONFIDENCE:
+            caveat_prefix = "⚠️ Low confidence: " + caveat_prefix
+
+        return caveat_prefix
+
+    @classmethod
+    def _explain_confidence(cls, confidence_score: float, factors: Dict[str, float]) -> str:
+        """Generate explanation of confidence level"""
+        level = "high" if confidence_score >= cls.HIGH_CONFIDENCE else \
+                "medium" if confidence_score >= cls.MEDIUM_CONFIDENCE else "low"
+
+        # Identify top contributing factors
+        sorted_factors = sorted(factors.items(), key=lambda x: x[1], reverse=True)
+        top_factor = sorted_factors[0][0]
+        bottom_factor = sorted_factors[-1][0]
+
+        reasoning = (
+            f"Confidence level: {level} ({confidence_score:.2f}). "
+            f"Strongest factor: {top_factor} ({factors[top_factor]:.2f}). "
+            f"Weakest factor: {bottom_factor} ({factors[bottom_factor]:.2f})."
+        )
+
+        return reasoning
+
+    @classmethod
+    def add_caveat_to_response(cls, response: str, caveat: str) -> str:
+        """
+        Add caveat to response in a natural way
+
+        Inserts caveat at the beginning of the response
+        """
+        # If response already starts with a caveat-like phrase, don't add another
+        caveat_starters = ['based on', 'according to', 'to the best', 'from what']
+
+        response_lower = response.lower()
+        if any(starter in response_lower[:50] for starter in caveat_starters):
+            return response  # Already has a caveat
+
+        # Add caveat at the start
+        return caveat + response[0].lower() + response[1:]
+
+
+# Convenience function
+def assess_and_apply_caveat(response: str, query: str, context: Dict[str, Any]) -> tuple:
+    """
+    Assess confidence and apply caveat if needed
+
+    Returns:
+        (final_response, confidence_assessment)
+    """
+    assessment = ConfidenceCalibrator.assess_confidence(response, query, context)
+
+    final_response = response
+    if assessment.should_add_caveat and assessment.caveat_text:
+        final_response = ConfidenceCalibrator.add_caveat_to_response(
+            response,
+            assessment.caveat_text
+        )
+
+    return final_response, assessment
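To make the new confidence-calibration hook concrete, here is a minimal usage sketch. It is not taken from the package itself: the import path is inferred from the file layout above, and the `context` keys (`tools_used`, `api_results`) are assumptions based on what `_assess_data_quality` reads.

```python
# Hypothetical usage sketch for cite_agent/confidence_calibration.py.
# Import path and context keys are assumptions inferred from the diff above.
from cite_agent.confidence_calibration import assess_and_apply_caveat

response = "Transformer architectures account for most highly cited ML papers since 2020."
query = "Which architectures are most cited in recent ML papers?"
context = {
    "tools_used": ["archive_api"],  # scored as a reliable source in reliability_map
    "api_results": {
        "archive_api": {"results": [{"title": "Attention Is All You Need"}]},
    },
}

# Returns the (possibly caveated) response plus a ConfidenceAssessment holding
# the weighted score, the level ("high"/"medium"/"low"), and per-factor scores.
final_response, assessment = assess_and_apply_caveat(response, query, context)
print(assessment.confidence_level, f"{assessment.confidence_score:.2f}")
print(assessment.reasoning)
print(final_response)
```

Because the weights concentrate on data quality (0.40) and answer completeness (0.25), an empty `results` list is the quickest way to push a response below the 0.6 threshold that triggers a caveat.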
cite_agent/deduplication.py (new file)

@@ -0,0 +1,325 @@
+#!/usr/bin/env python3
+"""
+Paper Deduplication - Remove Duplicate Papers
+
+When searching multiple sources (Semantic Scholar, OpenAlex, PubMed),
+the same paper often appears multiple times.
+
+This module provides intelligent deduplication:
+- By DOI (most reliable)
+- By title similarity (fuzzy matching)
+- By arXiv ID
+- Merge metadata from multiple sources
+"""
+
+import logging
+from typing import List, Dict, Any, Set, Optional
+from dataclasses import dataclass
+from difflib import SequenceMatcher
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class PaperIdentifiers:
+    """All possible identifiers for a paper"""
+    doi: Optional[str] = None
+    arxiv_id: Optional[str] = None
+    pmid: Optional[str] = None  # PubMed ID
+    semantic_scholar_id: Optional[str] = None
+    openalex_id: Optional[str] = None
+    title: Optional[str] = None
+
+
+class PaperDeduplicator:
+    """
+    Intelligent paper deduplication
+
+    Strategy:
+    1. Exact DOI match (highest confidence)
+    2. arXiv ID match
+    3. PubMed ID match
+    4. Title fuzzy match (>90% similarity)
+    5. Keep best version (most complete metadata)
+    """
+
+    def __init__(self, title_similarity_threshold: float = 0.9):
+        """
+        Initialize deduplicator
+
+        Args:
+            title_similarity_threshold: Minimum similarity for title matching (0.0-1.0)
+        """
+        self.title_threshold = title_similarity_threshold
+
+    def deduplicate(self, papers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Deduplicate list of papers
+
+        Args:
+            papers: List of papers from multiple sources
+
+        Returns:
+            Deduplicated list with merged metadata
+        """
+        if not papers:
+            return []
+
+        # Group papers by identifiers
+        groups = self._group_duplicates(papers)
+
+        # Merge each group into best representation
+        deduplicated = []
+        for group in groups:
+            merged = self._merge_papers(group)
+            deduplicated.append(merged)
+
+        original_count = len(papers)
+        final_count = len(deduplicated)
+        removed = original_count - final_count
+
+        if removed > 0:
+            logger.info(f"🔍 Deduplicated: {original_count} → {final_count} papers ({removed} duplicates removed)")
+
+        return deduplicated
+
+    def _group_duplicates(self, papers: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
+        """
+        Group duplicate papers together
+
+        Args:
+            papers: List of papers
+
+        Returns:
+            List of groups, where each group contains duplicate papers
+        """
+        # Track which papers are already grouped
+        grouped_indices: Set[int] = set()
+        groups: List[List[Dict[str, Any]]] = []
+
+        for i, paper1 in enumerate(papers):
+            if i in grouped_indices:
+                continue
+
+            # Start new group with this paper
+            group = [paper1]
+            grouped_indices.add(i)
+
+            # Find all duplicates
+            for j, paper2 in enumerate(papers[i+1:], start=i+1):
+                if j in grouped_indices:
+                    continue
+
+                if self._are_duplicates(paper1, paper2):
+                    group.append(paper2)
+                    grouped_indices.add(j)
+
+            groups.append(group)
+
+        return groups
+
+    def _are_duplicates(self, paper1: Dict[str, Any], paper2: Dict[str, Any]) -> bool:
+        """
+        Check if two papers are duplicates
+
+        Args:
+            paper1: First paper
+            paper2: Second paper
+
+        Returns:
+            True if papers are duplicates
+        """
+        # Extract identifiers
+        id1 = self._extract_identifiers(paper1)
+        id2 = self._extract_identifiers(paper2)
+
+        # Check DOI match (most reliable)
+        if id1.doi and id2.doi:
+            # Normalize DOIs (remove https://doi.org/ prefix)
+            doi1 = id1.doi.lower().replace("https://doi.org/", "").strip()
+            doi2 = id2.doi.lower().replace("https://doi.org/", "").strip()
+            if doi1 == doi2:
+                logger.debug(f"DOI match: {doi1}")
+                return True
+
+        # Check arXiv ID match
+        if id1.arxiv_id and id2.arxiv_id:
+            if id1.arxiv_id.lower() == id2.arxiv_id.lower():
+                logger.debug(f"arXiv match: {id1.arxiv_id}")
+                return True
+
+        # Check PubMed ID match
+        if id1.pmid and id2.pmid:
+            if id1.pmid == id2.pmid:
+                logger.debug(f"PMID match: {id1.pmid}")
+                return True
+
+        # Check title similarity (fuzzy matching)
+        if id1.title and id2.title:
+            similarity = self._title_similarity(id1.title, id2.title)
+            if similarity >= self.title_threshold:
+                logger.debug(f"Title match: {similarity:.2%} - '{id1.title[:50]}...'")
+                return True
+
+        return False
+
+    def _extract_identifiers(self, paper: Dict[str, Any]) -> PaperIdentifiers:
+        """
+        Extract all identifiers from paper
+
+        Args:
+            paper: Paper dictionary
+
+        Returns:
+            PaperIdentifiers object
+        """
+        # Handle different API formats
+        return PaperIdentifiers(
+            doi=paper.get("doi") or paper.get("DOI"),
+            arxiv_id=paper.get("arxivId") or paper.get("arxiv_id"),
+            pmid=paper.get("pmid") or paper.get("pubmed_id"),
+            semantic_scholar_id=paper.get("paperId") or paper.get("semantic_scholar_id"),
+            openalex_id=paper.get("id") if isinstance(paper.get("id"), str) and paper.get("id", "").startswith("https://openalex.org/") else None,
+            title=paper.get("title")
+        )
+
+    def _title_similarity(self, title1: str, title2: str) -> float:
+        """
+        Calculate similarity between two titles
+
+        Args:
+            title1: First title
+            title2: Second title
+
+        Returns:
+            Similarity score (0.0-1.0)
+        """
+        # Normalize titles
+        t1 = self._normalize_title(title1)
+        t2 = self._normalize_title(title2)
+
+        # Calculate similarity using SequenceMatcher
+        return SequenceMatcher(None, t1, t2).ratio()
+
+    def _normalize_title(self, title: str) -> str:
+        """
+        Normalize title for comparison
+
+        Args:
+            title: Original title
+
+        Returns:
+            Normalized title
+        """
+        # Convert to lowercase
+        normalized = title.lower()
+
+        # Remove common punctuation
+        for char in [".", ",", ":", ";", "!", "?", "-", "(", ")", "[", "]", "{", "}"]:
+            normalized = normalized.replace(char, " ")
+
+        # Normalize whitespace
+        normalized = " ".join(normalized.split())
+
+        return normalized
+
+    def _merge_papers(self, papers: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """
+        Merge duplicate papers into best representation
+
+        Strategy:
+        - Keep most complete metadata
+        - Prefer DOI > arXiv > Semantic Scholar
+        - Merge citation counts (take max)
+        - Merge sources
+
+        Args:
+            papers: List of duplicate papers
+
+        Returns:
+            Merged paper with best metadata
+        """
+        if len(papers) == 1:
+            return papers[0]
+
+        # Start with paper that has most fields
+        merged = max(papers, key=lambda p: len([v for v in p.values() if v is not None]))
+
+        # Merge citation counts (take maximum)
+        citation_counts = [p.get("citationCount", 0) for p in papers]
+        if citation_counts:
+            merged["citationCount"] = max(citation_counts)
+
+        # Merge sources (track which APIs returned this paper)
+        sources = set()
+        for paper in papers:
+            if paper.get("paperId"):  # Semantic Scholar
+                sources.add("semantic_scholar")
+            if paper.get("id", "").startswith("https://openalex.org/"):
+                sources.add("openalex")
+            if paper.get("pmid"):
+                sources.add("pubmed")
+
+        merged["_sources"] = list(sources)
+        merged["_duplicate_count"] = len(papers)
+
+        # Fill in any missing fields from other papers
+        for paper in papers:
+            for key, value in paper.items():
+                if value is not None and merged.get(key) is None:
+                    merged[key] = value
+
+        return merged
+
+    def get_deduplication_stats(self, original: List[Dict[str, Any]], deduplicated: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """
+        Get statistics about deduplication
+
+        Args:
+            original: Original paper list
+            deduplicated: Deduplicated list
+
+        Returns:
+            Deduplication statistics
+        """
+        removed = len(original) - len(deduplicated)
+        removal_rate = (removed / len(original) * 100) if original else 0
+
+        # Count by source
+        sources = {}
+        for paper in deduplicated:
+            for source in paper.get("_sources", []):
+                sources[source] = sources.get(source, 0) + 1
+
+        return {
+            "original_count": len(original),
+            "deduplicated_count": len(deduplicated),
+            "removed_count": removed,
+            "removal_rate": removal_rate,
+            "sources": sources
+        }
+
+
+# Global deduplicator instance
+_deduplicator = None
+
+
+def get_deduplicator() -> PaperDeduplicator:
+    """Get global deduplicator instance"""
+    global _deduplicator
+    if _deduplicator is None:
+        _deduplicator = PaperDeduplicator()
+    return _deduplicator
+
+
+def deduplicate_papers(papers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Convenience function to deduplicate papers
+
+    Args:
+        papers: List of papers
+
+    Returns:
+        Deduplicated list
+    """
+    return get_deduplicator().deduplicate(papers)
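A similarly hedged sketch of the deduplication path follows. The two records below are invented for illustration, but their field names mirror the Semantic Scholar (`paperId`) and OpenAlex (`id`) shapes that `_extract_identifiers` looks for.

```python
# Hypothetical usage sketch for cite_agent/deduplication.py.
# The records are invented; field names mirror what _extract_identifiers reads.
from cite_agent.deduplication import deduplicate_papers

papers = [
    {   # Semantic Scholar-style record
        "paperId": "abc123",
        "title": "Deep Residual Learning for Image Recognition",
        "doi": "10.1109/CVPR.2016.90",
        "citationCount": 150_000,
    },
    {   # OpenAlex-style record: same paper, DOI carries the https prefix
        "id": "https://openalex.org/W0000000000",  # placeholder OpenAlex ID
        "title": "Deep residual learning for image recognition.",
        "doi": "https://doi.org/10.1109/cvpr.2016.90",
        "citationCount": 148_000,
    },
]

unique = deduplicate_papers(papers)
assert len(unique) == 1                    # DOIs match after normalization
print(unique[0]["citationCount"])          # 150000 - the max across duplicates
print(unique[0]["_sources"], unique[0]["_duplicate_count"])
```

The title-similarity fallback (SequenceMatcher at a 0.9 default threshold) only comes into play when the records do not share a DOI, arXiv ID, or PMID.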