academic_refchecker-2.0.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- academic_refchecker-2.0.7.dist-info/METADATA +738 -0
- academic_refchecker-2.0.7.dist-info/RECORD +64 -0
- academic_refchecker-2.0.7.dist-info/WHEEL +5 -0
- academic_refchecker-2.0.7.dist-info/entry_points.txt +3 -0
- academic_refchecker-2.0.7.dist-info/licenses/LICENSE +21 -0
- academic_refchecker-2.0.7.dist-info/top_level.txt +2 -0
- backend/__init__.py +21 -0
- backend/__main__.py +11 -0
- backend/cli.py +64 -0
- backend/concurrency.py +100 -0
- backend/database.py +711 -0
- backend/main.py +1367 -0
- backend/models.py +99 -0
- backend/refchecker_wrapper.py +1126 -0
- backend/static/assets/index-2P6L_39v.css +1 -0
- backend/static/assets/index-hk21nqxR.js +25 -0
- backend/static/favicon.svg +6 -0
- backend/static/index.html +15 -0
- backend/static/vite.svg +1 -0
- backend/thumbnail.py +517 -0
- backend/websocket_manager.py +104 -0
- refchecker/__init__.py +13 -0
- refchecker/__main__.py +11 -0
- refchecker/__version__.py +3 -0
- refchecker/checkers/__init__.py +17 -0
- refchecker/checkers/crossref.py +541 -0
- refchecker/checkers/enhanced_hybrid_checker.py +563 -0
- refchecker/checkers/github_checker.py +326 -0
- refchecker/checkers/local_semantic_scholar.py +540 -0
- refchecker/checkers/openalex.py +513 -0
- refchecker/checkers/openreview_checker.py +984 -0
- refchecker/checkers/pdf_paper_checker.py +493 -0
- refchecker/checkers/semantic_scholar.py +764 -0
- refchecker/checkers/webpage_checker.py +938 -0
- refchecker/config/__init__.py +1 -0
- refchecker/config/logging.conf +36 -0
- refchecker/config/settings.py +170 -0
- refchecker/core/__init__.py +7 -0
- refchecker/core/db_connection_pool.py +141 -0
- refchecker/core/parallel_processor.py +415 -0
- refchecker/core/refchecker.py +5838 -0
- refchecker/database/__init__.py +6 -0
- refchecker/database/download_semantic_scholar_db.py +1725 -0
- refchecker/llm/__init__.py +0 -0
- refchecker/llm/base.py +376 -0
- refchecker/llm/providers.py +911 -0
- refchecker/scripts/__init__.py +1 -0
- refchecker/scripts/start_vllm_server.py +121 -0
- refchecker/services/__init__.py +8 -0
- refchecker/services/pdf_processor.py +268 -0
- refchecker/utils/__init__.py +27 -0
- refchecker/utils/arxiv_utils.py +462 -0
- refchecker/utils/author_utils.py +179 -0
- refchecker/utils/biblatex_parser.py +584 -0
- refchecker/utils/bibliography_utils.py +332 -0
- refchecker/utils/bibtex_parser.py +411 -0
- refchecker/utils/config_validator.py +262 -0
- refchecker/utils/db_utils.py +210 -0
- refchecker/utils/doi_utils.py +190 -0
- refchecker/utils/error_utils.py +482 -0
- refchecker/utils/mock_objects.py +211 -0
- refchecker/utils/text_utils.py +5057 -0
- refchecker/utils/unicode_utils.py +335 -0
- refchecker/utils/url_utils.py +307 -0
refchecker/llm/providers.py
@@ -0,0 +1,911 @@
"""
LLM provider implementations for reference extraction
"""

import json
import os
import subprocess
from typing import List, Dict, Any, Optional
import logging

from .base import LLMProvider

logger = logging.getLogger(__name__)


class LLMProviderMixin:
    """Common functionality for all LLM providers"""

    def _clean_bibtex_for_llm(self, bibliography_text: str) -> str:
        """Clean BibTeX text before sending to LLM to remove formatting artifacts"""
        if not bibliography_text:
            return bibliography_text

        import re

        # First, protect LaTeX commands from being stripped
        protected_commands = []
        command_pattern = r'\{\\[a-zA-Z]+(?:\s+[^{}]*?)?\}'

        def protect_command(match):
            protected_commands.append(match.group(0))
            return f"__PROTECTED_LATEX_{len(protected_commands)-1}__"

        text = re.sub(command_pattern, protect_command, bibliography_text)

        # Clean up LaTeX math expressions in titles (but preserve the math content)
        # Convert $expression$ to expression and ${expression}$ to expression
        text = re.sub(r'\$\{([^{}]+)\}\$', r'\1', text)  # ${expr}$ -> expr
        text = re.sub(r'\$([^$]+)\$', r'\1', text)  # $expr$ -> expr

        # Remove curly braces around titles and other fields
        # Match { content } where content doesn't contain unmatched braces
        text = re.sub(r'\{([^{}]+)\}', r'\1', text)

        # Clean up DOI and URL field contamination
        # Fix cases where DOI field contains both DOI and URL separated by *
        # Pattern: DOI*URL -> separate them properly
        text = re.sub(r'(doi\s*=\s*\{?)([^}*,]+)\*http([^},\s]*)\}?', r'\1\2},\n url = {http\3}', text)
        text = re.sub(r'(\d+\.\d+/[^*\s,]+)\*http', r'\1,\n url = {http', text)

        # Clean up asterisk contamination in DOI values within the text
        text = re.sub(r'(10\.[0-9]+/[A-Za-z0-9\-.:()/_]+)\*http', r'\1', text)

        # Restore protected LaTeX commands
        for i, command in enumerate(protected_commands):
            text = text.replace(f"__PROTECTED_LATEX_{i}__", command)

        return text

    def _create_extraction_prompt(self, bibliography_text: str) -> str:
        """Create prompt for reference extraction"""
        # Clean BibTeX formatting before sending to LLM
        cleaned_bibliography = self._clean_bibtex_for_llm(bibliography_text)

        return f"""
Please extract individual references from the following bibliography text. Each reference should be a complete bibliographic entry.

Instructions:
1. Split the bibliography into individual references based on numbered markers like [1], [2], etc.
2. IMPORTANT: References may span multiple lines. A single reference includes everything from one number marker (e.g., [37]) until the next number marker (e.g., [38])
3. For each reference, extract: authors, title, publication venue, year, and any URLs/DOIs
   - For BibTeX entries, extract fields correctly:
     * title = the actual paper title from "title" field
     * venue = from "journal", "booktitle", "conference" fields
     * Do NOT confuse journal names like "arXiv preprint arXiv:1234.5678" with paper titles
4. Include references that are incomplete, like only author names and titles, but ignore ones that are just a URL without other details
5. Place a hashmark (#) rather than period between fields of a reference, but asterisks (*) between individual authors
   e.g. Author1*Author2*Author3#Title#Venue#Year#URL
6. CRITICAL: When extracting authors, understand BibTeX author field format correctly
   - In BibTeX, the "author" field contains author names separated by " and " (not commas)
   - Individual author names may be in "Last, First" format (e.g., "Smith, John")
   - Multiple authors are separated by " and " (e.g., "Smith, John and Doe, Jane")
   - SPECIAL CASE for collaborations: Handle "Last, First and others" pattern correctly
     * author = {"Khachatryan, Vardan and others"} → ONE explicit author plus et al: "Vardan Khachatryan*et al"
     * author = {"Smith, John and others"} → ONE explicit author plus et al: "John Smith*et al"
     * The "Last, First and others" pattern indicates a collaboration paper where only the first author is listed explicitly
   - EXAMPLES:
     * author = {"Dolan, Brian P."} → ONE author: "Dolan, Brian P."
     * author = {"Smith, John and Doe, Jane"} → TWO authors: "Smith, John*Doe, Jane"
     * author = {"Arnab, Anurag and Dehghani, Mostafa and Heigold, Georg"} → THREE authors: "Arnab, Anurag*Dehghani, Mostafa*Heigold, Georg"
     * author = {"Khachatryan, Vardan and others"} → ONE explicit author plus et al: "Vardan Khachatryan*et al"
   - Use asterisks (*) to separate individual authors in your output
   - For "Last, First" format, convert to "First Last" for readability (e.g., "Smith, John" → "John Smith")
   - If a BibTeX entry has NO author field, output an empty author field (nothing before the first #)
   - Do NOT infer or guess authors based on title or context - only use what is explicitly stated
7. CRITICAL: When extracting authors, preserve "et al" and similar indicators exactly as they appear
   - If the original says "John Smith, Jane Doe, et al" then output "John Smith, Jane Doe, et al"
   - If the original says "John Smith et al." then output "John Smith et al."
   - Also preserve variations like "and others", "etc.", "..." when used to indicate additional authors
   - Do NOT expand "et al" into individual author names, even if you know them
8. Return ONLY the references, one per line
9. Do not include reference numbers like [1], [2], etc. in your output
10. Do not add any additional text or explanations
11. Ensure that URLs and DOIs are from the specific reference only
    - When extracting URLs, preserve the complete URL including protocol
    - For BibTeX howpublished fields, extract the full URL from the field value
12. When parsing multi-line references, combine all authors from all lines before the title
13. CRITICAL: If the text contains no valid bibliographic references (e.g., only figures, appendix material, or explanatory text), simply return nothing - do NOT explain why you cannot extract references

Bibliography text:
{cleaned_bibliography}
"""

    def _parse_llm_response(self, content: str) -> List[str]:
        """Parse LLM response into list of references"""
        if not content:
            return []

        # Ensure content is a string
        if not isinstance(content, str):
            content = str(content)

        # Clean the content - remove leading/trailing whitespace
        content = content.strip()

        # Split by double newlines first to handle paragraph-style formatting
        # then fall back to single newlines
        references = []

        # Try double newline splitting first (paragraph style)
        if '\n\n' in content:
            potential_refs = content.split('\n\n')
        else:
            # Fall back to single newline splitting
            potential_refs = content.split('\n')

        for ref in potential_refs:
            ref = ref.strip()

            # Skip empty lines, headers, and explanatory text
            if not ref:
                continue
            if ref.lower().startswith(('reference', 'here are', 'below are', 'extracted', 'bibliography')):
                continue
            if ref.startswith('#'):
                continue
            if 'extracted from the bibliography' in ref.lower():
                continue
            if 'formatted as a complete' in ref.lower():
                continue
            # Skip verbose LLM explanatory responses
            if 'cannot extract' in ref.lower() and ('references' in ref.lower() or 'bibliographic' in ref.lower()):
                continue
            if 'appears to be from' in ref.lower() and 'appendix' in ref.lower():
                continue
            if 'no numbered reference markers' in ref.lower():
                continue
            if 'only figures' in ref.lower() and 'learning curves' in ref.lower():
                continue
            if ref.lower().startswith('i cannot'):
                continue

            # Remove common prefixes (bullets, numbers, etc.)
            ref = ref.lstrip('- *•')
            ref = ref.strip()

            # Remove reference numbers like "1.", "[1]", "(1)" from the beginning
            import re
            ref = re.sub(r'^(\d+\.|\[\d+\]|\(\d+\))\s*', '', ref)

            # Filter out very short lines (likely not complete references)
            if len(ref) > 30:  # Increased minimum length for academic references
                references.append(ref)

        return references


class OpenAIProvider(LLMProvider, LLMProviderMixin):
    """OpenAI GPT provider for reference extraction"""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.api_key = config.get("api_key") or os.getenv("REFCHECKER_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
        self.client = None

        if self.api_key:
            try:
                import openai
                self.client = openai.OpenAI(api_key=self.api_key)
            except ImportError:
                logger.error("OpenAI library not installed. Install with: pip install openai")

    def is_available(self) -> bool:
        return self.client is not None and self.api_key is not None

    def extract_references(self, bibliography_text: str) -> List[str]:
        return self.extract_references_with_chunking(bibliography_text)

    def _create_extraction_prompt(self, bibliography_text: str) -> str:
        """Create prompt for reference extraction"""
        return LLMProviderMixin._create_extraction_prompt(self, bibliography_text)

    def _call_llm(self, prompt: str) -> str:
        """Make the actual OpenAI API call and return the response text"""
        try:
            response = self.client.chat.completions.create(
                model=self.model or "gpt-4.1",
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=self.max_tokens,
                temperature=self.temperature
            )

            return response.choices[0].message.content or ""

        except Exception as e:
            logger.error(f"OpenAI API call failed: {e}")
            raise


class AnthropicProvider(LLMProvider, LLMProviderMixin):
    """Anthropic Claude provider for reference extraction"""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.api_key = config.get("api_key") or os.getenv("REFCHECKER_ANTHROPIC_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
        self.client = None

        if self.api_key:
            try:
                import anthropic
                self.client = anthropic.Anthropic(api_key=self.api_key)
            except ImportError:
                logger.error("Anthropic library not installed. Install with: pip install anthropic")

    def is_available(self) -> bool:
        return self.client is not None and self.api_key is not None

    def extract_references(self, bibliography_text: str) -> List[str]:
        return self.extract_references_with_chunking(bibliography_text)

    def _create_extraction_prompt(self, bibliography_text: str) -> str:
        """Create prompt for reference extraction"""
        return LLMProviderMixin._create_extraction_prompt(self, bibliography_text)

    def _call_llm(self, prompt: str) -> str:
        """Make the actual Anthropic API call and return the response text"""
        try:
            response = self.client.messages.create(
                model=self.model or "claude-sonnet-4-20250514",
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                messages=[
                    {"role": "user", "content": prompt}
                ]
            )

            logger.debug(f"Anthropic response type: {type(response.content)}")
            logger.debug(f"Anthropic response content: {response.content}")

            # Handle different response formats
            if hasattr(response.content[0], 'text'):
                content = response.content[0].text
            elif isinstance(response.content[0], dict) and 'text' in response.content[0]:
                content = response.content[0]['text']
            elif hasattr(response.content[0], 'content'):
                content = response.content[0].content
            else:
                content = str(response.content[0])

            # Ensure content is a string
            if not isinstance(content, str):
                content = str(content)

            return content

        except Exception as e:
            logger.error(f"Anthropic API call failed: {e}")
            raise


class GoogleProvider(LLMProvider, LLMProviderMixin):
    """Google Gemini provider for reference extraction"""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.api_key = config.get("api_key") or os.getenv("REFCHECKER_GOOGLE_API_KEY") or os.getenv("GOOGLE_API_KEY")
        self.client = None

        if self.api_key:
            try:
                import google.generativeai as genai
                genai.configure(api_key=self.api_key)
                self.client = genai.GenerativeModel(self.model or "gemini-1.5-flash")
            except ImportError:
                logger.error("Google Generative AI library not installed. Install with: pip install google-generativeai")

    def is_available(self) -> bool:
        return self.client is not None and self.api_key is not None

    def extract_references(self, bibliography_text: str) -> List[str]:
        return self.extract_references_with_chunking(bibliography_text)

    def _create_extraction_prompt(self, bibliography_text: str) -> str:
        """Create prompt for reference extraction"""
        return LLMProviderMixin._create_extraction_prompt(self, bibliography_text)

    def _call_llm(self, prompt: str) -> str:
        """Make the actual Google API call and return the response text"""
        try:
            response = self.client.generate_content(
                prompt,
                generation_config={
                    "max_output_tokens": self.max_tokens,
                    "temperature": self.temperature,
                }
            )

            # Handle empty responses (content safety filter or other issues)
            if not response.candidates:
                logger.warning("Google API returned empty candidates (possibly content filtered)")
                return ""

            # Safely access the text
            try:
                return response.text or ""
            except (ValueError, AttributeError) as e:
                # response.text raises ValueError if multiple candidates or no text
                logger.warning(f"Could not get text from Google response: {e}")
                # Try to extract text from first candidate manually
                if response.candidates and hasattr(response.candidates[0], 'content'):
                    content = response.candidates[0].content
                    if hasattr(content, 'parts') and content.parts:
                        return content.parts[0].text or ""
                return ""

        except Exception as e:
            logger.error(f"Google API call failed: {e}")
            raise


class AzureProvider(LLMProvider, LLMProviderMixin):
    """Azure OpenAI provider for reference extraction"""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.api_key = config.get("api_key") or os.getenv("REFCHECKER_AZURE_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")
        self.endpoint = config.get("endpoint") or os.getenv("REFCHECKER_AZURE_ENDPOINT") or os.getenv("AZURE_OPENAI_ENDPOINT")
        self.client = None

        logger.debug(f"Azure provider initialized - API key present: {self.api_key is not None}, Endpoint present: {self.endpoint is not None}")

        if self.api_key and self.endpoint:
            try:
                import openai
                self.client = openai.AzureOpenAI(
                    api_key=self.api_key,
                    api_version="2024-02-01",
                    azure_endpoint=self.endpoint
                )
                logger.debug("Azure OpenAI client created successfully")
            except ImportError:
                logger.error("OpenAI library not installed. Install with: pip install openai")
        else:
            logger.warning(f"Azure provider not available - missing {'API key' if not self.api_key else 'endpoint'}")

    def is_available(self) -> bool:
        available = self.client is not None and self.api_key is not None and self.endpoint is not None
        if not available:
            logger.debug(f"Azure provider not available: client={self.client is not None}, api_key={self.api_key is not None}, endpoint={self.endpoint is not None}")
        return available

    def extract_references(self, bibliography_text: str) -> List[str]:
        return self.extract_references_with_chunking(bibliography_text)

    def _create_extraction_prompt(self, bibliography_text: str) -> str:
        """Create prompt for reference extraction"""
        return LLMProviderMixin._create_extraction_prompt(self, bibliography_text)

    def _call_llm(self, prompt: str) -> str:
        """Make the actual Azure OpenAI API call and return the response text"""
        try:
            response = self.client.chat.completions.create(
                model=self.model or "gpt-4o",
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=self.max_tokens,
                temperature=self.temperature
            )

            return response.choices[0].message.content or ""

        except Exception as e:
            logger.error(f"Azure API call failed: {e}")
            raise


class vLLMProvider(LLMProvider, LLMProviderMixin):
    """vLLM provider using OpenAI-compatible server mode for local Hugging Face models"""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model_name = config.get("model") or "microsoft/DialoGPT-medium"
        self.server_url = config.get("server_url") or os.getenv("REFCHECKER_VLLM_SERVER_URL") or "http://localhost:8000"
        self.auto_start_server = config.get("auto_start_server", os.getenv("REFCHECKER_VLLM_AUTO_START", "true").lower() == "true")
        self.server_timeout = config.get("server_timeout", int(os.getenv("REFCHECKER_VLLM_TIMEOUT", "300")))

        # Allow skipping initialization for testing
        self.skip_initialization = config.get("skip_initialization", False)

        self.client = None
        self.server_process = None

        logger.info(f"vLLM provider initialized - Server URL: {self.server_url}, Model: {self.model_name}, Auto start: {self.auto_start_server}")

        # Only initialize if not skipping
        if not self.skip_initialization:
            # Clean debugger environment variables early
            self._clean_debugger_environment()

            if self.auto_start_server:
                if self._ensure_server_running() == False:
                    logger.error("Failed to start vLLM server, provider will not be available")
                    # this is a fatal error that shouldn't create the object
                    raise Exception("vLLM server failed to start")

            try:
                import openai
                # vLLM provides OpenAI-compatible API
                self.client = openai.OpenAI(
                    api_key="EMPTY",  # vLLM doesn't require API key
                    base_url=f"{self.server_url}/v1"
                )
                logger.info("OpenAI client configured for vLLM server")
            except ImportError:
                logger.error("OpenAI library not installed. Install with: pip install openai")

    def _clean_debugger_environment(self):
        """Clean debugger environment variables that interfere with vLLM"""
        debugger_vars = [
            'DEBUGPY_LAUNCHER_PORT',
            'PYDEVD_LOAD_VALUES_ASYNC',
            'PYDEVD_USE_FRAME_EVAL',
            'PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT'
        ]

        for var in debugger_vars:
            if var in os.environ:
                logger.debug(f"Removing debugger variable: {var}")
                del os.environ[var]

        # Clean PYTHONPATH of debugger modules
        if 'PYTHONPATH' in os.environ:
            pythonpath_parts = os.environ['PYTHONPATH'].split(':')
            clean_pythonpath = [p for p in pythonpath_parts if 'debugpy' not in p and 'pydevd' not in p]
            if clean_pythonpath != pythonpath_parts:
                logger.debug("Cleaned PYTHONPATH of debugger modules")
                os.environ['PYTHONPATH'] = ':'.join(clean_pythonpath)

    def _get_optimal_tensor_parallel_size(self):
        """Determine optimal tensor parallel size based on available GPUs"""
        try:
            import torch

            available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

            if available_gpus <= 1:
                return 1

            # For most models, use up to 4 GPUs for stability
            return min(available_gpus, 4)

        except Exception as e:
            logger.debug(f"Error determining tensor parallel size: {e}, defaulting to 1")
            return 1

    def _kill_existing_server(self):
        """Kill any existing vLLM server processes"""
        try:
            import subprocess
            # Use a more specific pattern to only kill vLLM server processes, not any process containing "vllm"
            subprocess.run(["pkill", "-f", "vllm.entrypoints.openai.api_server"], timeout=10, capture_output=True)
            import time
            time.sleep(2)  # Wait for cleanup
        except Exception as e:
            logger.debug(f"Error killing existing server: {e}")

    def _start_server(self):
        """Start vLLM server using standalone launcher"""
        try:
            import subprocess
            import torch

            # Kill any existing server
            self._kill_existing_server()

            # Determine optimal tensor parallel size
            tensor_parallel_size = self._get_optimal_tensor_parallel_size()

            # Always use standalone server launcher for reliability
            return self._start_server_standalone(tensor_parallel_size)

        except Exception as e:
            logger.error(f"Failed to start vLLM server: {e}")
            return False

    def _find_vllm_launcher_script(self):
        """Find the vLLM launcher script, supporting both development and PyPI installs"""
        import pkg_resources

        # First try to find it as a package resource (for PyPI installs)
        try:
            script_path = pkg_resources.resource_filename('refchecker', 'scripts/start_vllm_server.py')
            if os.path.exists(script_path):
                logger.debug(f"Found vLLM launcher script via pkg_resources: {script_path}")
                return script_path
        except Exception as e:
            logger.debug(f"Could not find script via pkg_resources: {e}")

        # Try relative path for development installs
        current_dir = os.path.dirname(os.path.dirname(__file__))  # src/llm -> src
        project_root = os.path.dirname(current_dir)  # src -> project root
        script_path = os.path.join(project_root, "scripts", "start_vllm_server.py")

        if os.path.exists(script_path):
            logger.debug(f"Found vLLM launcher script via relative path: {script_path}")
            return script_path

        # Try looking in the same directory structure as this file (for src-based installs)
        src_dir = os.path.dirname(os.path.dirname(__file__))  # src/llm -> src
        script_path = os.path.join(src_dir, "scripts", "start_vllm_server.py")

        if os.path.exists(script_path):
            logger.debug(f"Found vLLM launcher script in src directory: {script_path}")
            return script_path

        # If all else fails, try to create a temporary script
        logger.warning("Could not find standalone vLLM launcher script, creating temporary one")
        return self._create_temporary_launcher_script()

    def _create_temporary_launcher_script(self):
        """Create a temporary launcher script if the packaged one cannot be found"""
        import tempfile
        import textwrap

        # Create a temporary file with the launcher script content
        fd, temp_script_path = tempfile.mkstemp(suffix='.py', prefix='vllm_launcher_')

        launcher_code = textwrap.dedent('''
            #!/usr/bin/env python3
            """
            Temporary vLLM server launcher script
            """

            import sys
            import subprocess
            import os
            import time
            import argparse
            import signal

            def start_vllm_server(model_name, port=8000, tensor_parallel_size=1, max_model_len=None, gpu_memory_util=0.9):
                """Start vLLM server with specified parameters"""

                # Kill any existing server on the port
                try:
                    subprocess.run(["pkill", "-f", "vllm.entrypoints.openai.api_server"],
                                   timeout=10, capture_output=True)
                    time.sleep(2)
                except:
                    pass

                # Build command
                cmd = [
                    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
                    "--model", model_name,
                    "--host", "0.0.0.0",
                    "--port", str(port),
                    "--tensor-parallel-size", str(tensor_parallel_size),
                    "--gpu-memory-utilization", str(gpu_memory_util)
                ]

                if max_model_len:
                    cmd.extend(["--max-model-len", str(max_model_len)])

                print(f"Starting vLLM server: {' '.join(cmd)}")

                # Create clean environment without debugger variables
                clean_env = {}
                for key, value in os.environ.items():
                    if not any(debug_key in key.upper() for debug_key in ['DEBUGPY', 'PYDEVD']):
                        clean_env[key] = value

                # Remove debugger paths from PYTHONPATH if present
                if 'PYTHONPATH' in clean_env:
                    pythonpath_parts = clean_env['PYTHONPATH'].split(':')
                    clean_pythonpath = [p for p in pythonpath_parts if 'debugpy' not in p and 'pydevd' not in p]
                    if clean_pythonpath:
                        clean_env['PYTHONPATH'] = ':'.join(clean_pythonpath)
                    else:
                        del clean_env['PYTHONPATH']

                # Start server as daemon if requested
                if '--daemon' in sys.argv:
                    # Start server in background
                    process = subprocess.Popen(cmd, env=clean_env, start_new_session=True,
                                               stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                    print(f"Started vLLM server as daemon with PID: {process.pid}")
                else:
                    # Start server in foreground
                    subprocess.run(cmd, env=clean_env)

            if __name__ == "__main__":
                parser = argparse.ArgumentParser(description="Start vLLM server")
                parser.add_argument("--model", required=True, help="Model name")
                parser.add_argument("--port", type=int, default=8000, help="Port number")
                parser.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallel size")
                parser.add_argument("--max-model-len", type=int, help="Maximum model length")
                parser.add_argument("--gpu-memory-util", type=float, default=0.9, help="GPU memory utilization")
                parser.add_argument("--daemon", action="store_true", help="Run as daemon")

                args = parser.parse_args()

                start_vllm_server(
                    model_name=args.model,
                    port=args.port,
                    tensor_parallel_size=args.tensor_parallel_size,
                    max_model_len=args.max_model_len,
                    gpu_memory_util=args.gpu_memory_util
                )
            ''')

        try:
            with os.fdopen(fd, 'w') as f:
                f.write(launcher_code)

            # Make the script executable
            os.chmod(temp_script_path, 0o755)

            logger.info(f"Created temporary vLLM launcher script: {temp_script_path}")
            return temp_script_path

        except Exception as e:
            os.close(fd)  # Clean up if writing failed
            os.unlink(temp_script_path)
            raise Exception(f"Failed to create temporary launcher script: {e}")

    def _start_server_standalone(self, tensor_parallel_size):
        """Start server using standalone script"""
        import subprocess
        import torch
        import os

        # Find the standalone launcher script - support both development and PyPI installs
        script_path = self._find_vllm_launcher_script()

        # Build command for standalone launcher
        cmd = [
            "python", script_path,
            "--model", self.model_name,
            "--port", "8000",
            "--tensor-parallel-size", str(tensor_parallel_size)
        ]

        # Add daemon flag unless explicitly disabled via environment variable or debug mode
        # Check if we're in debug mode by examining the current logging level
        import logging
        current_logger = logging.getLogger()
        is_debug_mode = current_logger.getEffectiveLevel() <= logging.DEBUG

        if not (os.getenv('VLLM_NO_DAEMON', '').lower() in ('1', 'true', 'yes') or is_debug_mode):
            cmd.append("--daemon")

        # Add memory optimization for smaller GPUs
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # GB
            if gpu_memory < 40:  # Less than 40GB VRAM
                cmd.extend([
                    "--gpu-memory-util", "0.8",
                    "--max-model-len", "4096"
                ])

        logger.info(f"Starting vLLM server via standalone launcher: {' '.join(cmd)}")

        # Check if daemon mode is disabled
        daemon_mode = "--daemon" in cmd

        if daemon_mode:
            # Daemon mode: start launcher and wait for it to complete
            launcher_timeout = 120  # 2 minutes for launcher to complete

            try:
                result = subprocess.run(cmd, capture_output=True, text=True, timeout=launcher_timeout)

                if result.returncode == 0:
                    logger.info("vLLM server launcher completed successfully")
                    logger.debug(f"Launcher stdout: {result.stdout}")
                    # The actual server process is running as daemon, we don't have direct handle
                    self.server_process = None  # We don't manage the daemon directly
                    return True
                else:
                    logger.error(f"vLLM server launcher failed with return code {result.returncode}")
                    logger.error(f"Launcher stderr: {result.stderr}")
                    logger.error(f"Launcher stdout: {result.stdout}")
                    return False

            except subprocess.TimeoutExpired:
                logger.error(f"vLLM server launcher timed out after {launcher_timeout} seconds")
                logger.error("This may happen if the model is large and takes time to download/load")
                return False

        else:
            # Non-daemon mode: start launcher and let it stream output
            logger.info("Starting vLLM server in non-daemon mode (output will be visible)")
            try:
                # Start the launcher without capturing output so logs are visible
                process = subprocess.Popen(cmd, stdout=None, stderr=None)
                self.server_process = process

                # Give the server a moment to start
                import time
                time.sleep(5)

                # Check if the process is still running (hasn't crashed immediately)
                if process.poll() is None:
                    logger.info("vLLM server launcher started successfully in foreground mode")
                    return True
                else:
                    logger.error(f"vLLM server launcher exited immediately with code {process.returncode}")
                    return False

            except Exception as e:
                logger.error(f"Failed to start vLLM server launcher: {e}")
                return False

    def _wait_for_server(self, timeout=300):
        """Wait for vLLM server to be ready"""
        import time
        import requests

        start_time = time.time()

        logger.info(f"Waiting for vLLM server to start (timeout: {timeout}s)...")

        while (time.time() - start_time) < timeout:
            try:
                # Check health endpoint
                response = requests.get(f"{self.server_url}/health", timeout=5)
                if response.status_code == 200:
                    logger.info("vLLM server health check passed")

                    # Check models endpoint
                    response = requests.get(f"{self.server_url}/v1/models", timeout=5)
                    if response.status_code == 200:
                        models_data = response.json()
                        loaded_models = [model["id"] for model in models_data.get("data", [])]
                        logger.info(f"vLLM server is ready with models: {loaded_models}")
                        return True

            except requests.exceptions.RequestException as e:
                logger.debug(f"Server not ready yet: {e}")
                pass

            elapsed = time.time() - start_time
            if elapsed % 30 == 0:  # Log every 30 seconds
                logger.info(f"Still waiting for server... ({elapsed:.0f}s elapsed)")

            time.sleep(2)

        logger.error(f"vLLM server failed to start within {timeout} seconds")
        return False

    def _ensure_server_running(self):
        """Ensure vLLM server is running, start if necessary"""
        # First check if server is already running
        if self._check_server_health():
            logger.info("vLLM server is already running and healthy")
            return True

        logger.info("Starting vLLM server...")

        # Try to start the server
        if self._start_server():
            if self._wait_for_server(self.server_timeout):
                return True

        # If we get here, server failed to start
        logger.error("Server startup failed")
        return False

    def _check_server_health(self):
        """Check if vLLM server is healthy and has the correct model"""
        try:
            import requests

            # First check if server is responding
            response = requests.get(f"{self.server_url}/health", timeout=10)
            if response.status_code != 200:
                logger.debug(f"Health check failed: {response.status_code}")
                return False

            # Check if the correct model is loaded
            response = requests.get(f"{self.server_url}/v1/models", timeout=10)
            if response.status_code == 200:
                models_data = response.json()
                loaded_models = [model["id"] for model in models_data.get("data", [])]
                if self.model_name in loaded_models:
                    logger.debug(f"Correct model {self.model_name} is loaded")
                    return True
                else:
                    logger.info(f"Wrong model loaded. Expected: {self.model_name}, Found: {loaded_models}")
                    return False

            return False

        except requests.exceptions.RequestException as e:
            logger.debug(f"Server health check failed: {e}")
            return False

    def is_available(self) -> bool:
        """Check if vLLM server is available"""
        if not self.client:
            return False

        # Check server health
        if self._check_server_health():
            return True

        # If auto_start_server is enabled, try to start it
        if self.auto_start_server:
            logger.info("vLLM server not responding, attempting to restart...")
            return self._ensure_server_running()

        return False

    def extract_references(self, bibliography_text: str) -> List[str]:
        return self.extract_references_with_chunking(bibliography_text)

    def _create_extraction_prompt(self, bibliography_text: str) -> str:
        """Create prompt for reference extraction"""
        return LLMProviderMixin._create_extraction_prompt(self, bibliography_text)

    def _call_llm(self, prompt: str) -> str:
        """Make the actual vLLM API call and return the response text"""
        try:
            logger.debug(f"Sending prompt to vLLM server (length: {len(prompt)})")

            # Use chat completions API - vLLM will automatically apply chat templates
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                stop=None  # Let the model use its default stop tokens
            )

            content = response.choices[0].message.content

            logger.debug(f"Received response from vLLM server:")
            logger.debug(f" Length: {len(content)}")
            logger.debug(f" First 200 chars: {content[:200]}...")
            logger.debug(f" Finish reason: {response.choices[0].finish_reason}")

            return content or ""

        except Exception as e:
            logger.error(f"vLLM server API call failed: {e}")
            raise

    def test_server_response(self):
        """Test method to verify server is responding correctly"""
        if not self.is_available():
            print("Server not available")
            return

        test_prompt = "What is 2+2? Answer briefly."

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": test_prompt}
                ],
                max_tokens=50,
                temperature=0.1
            )

            content = response.choices[0].message.content
            print(f"Test successful!")
            print(f"Prompt: {test_prompt}")
            print(f"Response: {content}")
            print(f"Finish reason: {response.choices[0].finish_reason}")

        except Exception as e:
            print(f"Test failed: {e}")

    def cleanup(self):
        """Cleanup vLLM server resources"""
        logger.info("Shutting down vLLM server...")
        try:
            self._kill_existing_server()
        except Exception as e:
            logger.error(f"Error during vLLM server cleanup: {e}")

    def __del__(self):
        """Cleanup on deletion"""
        self.cleanup()
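
For orientation, here is a minimal usage sketch of the providers defined in this file. It assumes the `LLMProvider` base class in `refchecker/llm/base.py` reads `model`, `max_tokens`, and `temperature` from the config dict and supplies `extract_references_with_chunking()`; those config keys, the example values, and the input file name are illustrative assumptions, not something this diff confirms.

# Hypothetical usage sketch (not part of the package): extract references with the
# OpenAI-backed provider. Assumes OPENAI_API_KEY (or REFCHECKER_OPENAI_API_KEY) is
# set in the environment and that the base class consumes the config keys below.
from refchecker.llm.providers import OpenAIProvider

config = {
    "model": "gpt-4.1",    # default model used by OpenAIProvider._call_llm
    "max_tokens": 4096,    # assumed to be stored by the LLMProvider base class
    "temperature": 0.1,    # assumed to be stored by the LLMProvider base class
}

provider = OpenAIProvider(config)
if provider.is_available():
    with open("bibliography.txt") as f:  # hypothetical input file
        refs = provider.extract_references(f.read())
    for ref in refs:
        # Per the extraction prompt, each line is Author1*Author2#Title#Venue#Year#URL
        print(ref)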