lyrics-transcriber 0.43.1__py3-none-any.whl → 0.45.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lyrics_transcriber/core/controller.py +58 -24
- lyrics_transcriber/correction/anchor_sequence.py +22 -8
- lyrics_transcriber/correction/corrector.py +47 -3
- lyrics_transcriber/correction/handlers/llm.py +15 -12
- lyrics_transcriber/correction/handlers/llm_providers.py +60 -0
- lyrics_transcriber/frontend/.yarn/install-state.gz +0 -0
- lyrics_transcriber/frontend/dist/assets/{index-D0Gr3Ep7.js → index-ZCT0s9MG.js} +10174 -6197
- lyrics_transcriber/frontend/dist/assets/index-ZCT0s9MG.js.map +1 -0
- lyrics_transcriber/frontend/dist/index.html +1 -1
- lyrics_transcriber/frontend/src/App.tsx +5 -5
- lyrics_transcriber/frontend/src/api.ts +37 -0
- lyrics_transcriber/frontend/src/components/AddLyricsModal.tsx +114 -0
- lyrics_transcriber/frontend/src/components/AudioPlayer.tsx +14 -10
- lyrics_transcriber/frontend/src/components/CorrectionMetrics.tsx +62 -56
- lyrics_transcriber/frontend/src/components/EditActionBar.tsx +68 -0
- lyrics_transcriber/frontend/src/components/EditModal.tsx +467 -399
- lyrics_transcriber/frontend/src/components/EditTimelineSection.tsx +373 -0
- lyrics_transcriber/frontend/src/components/EditWordList.tsx +308 -0
- lyrics_transcriber/frontend/src/components/FindReplaceModal.tsx +467 -0
- lyrics_transcriber/frontend/src/components/Header.tsx +141 -101
- lyrics_transcriber/frontend/src/components/LyricsAnalyzer.tsx +569 -107
- lyrics_transcriber/frontend/src/components/ModeSelector.tsx +22 -13
- lyrics_transcriber/frontend/src/components/PreviewVideoSection.tsx +1 -0
- lyrics_transcriber/frontend/src/components/ReferenceView.tsx +29 -12
- lyrics_transcriber/frontend/src/components/ReviewChangesModal.tsx +21 -4
- lyrics_transcriber/frontend/src/components/TimelineEditor.tsx +29 -15
- lyrics_transcriber/frontend/src/components/TranscriptionView.tsx +36 -18
- lyrics_transcriber/frontend/src/components/WordDivider.tsx +187 -0
- lyrics_transcriber/frontend/src/components/shared/components/HighlightedText.tsx +89 -41
- lyrics_transcriber/frontend/src/components/shared/components/SourceSelector.tsx +9 -2
- lyrics_transcriber/frontend/src/components/shared/components/Word.tsx +27 -3
- lyrics_transcriber/frontend/src/components/shared/types.ts +17 -2
- lyrics_transcriber/frontend/src/components/shared/utils/keyboardHandlers.ts +90 -19
- lyrics_transcriber/frontend/src/components/shared/utils/segmentOperations.ts +192 -0
- lyrics_transcriber/frontend/src/hooks/useManualSync.ts +267 -0
- lyrics_transcriber/frontend/src/main.tsx +7 -1
- lyrics_transcriber/frontend/src/theme.ts +177 -0
- lyrics_transcriber/frontend/src/types.ts +1 -1
- lyrics_transcriber/frontend/tsconfig.tsbuildinfo +1 -1
- lyrics_transcriber/lyrics/base_lyrics_provider.py +2 -2
- lyrics_transcriber/lyrics/user_input_provider.py +44 -0
- lyrics_transcriber/output/generator.py +40 -12
- lyrics_transcriber/review/server.py +238 -8
- {lyrics_transcriber-0.43.1.dist-info → lyrics_transcriber-0.45.0.dist-info}/METADATA +3 -2
- {lyrics_transcriber-0.43.1.dist-info → lyrics_transcriber-0.45.0.dist-info}/RECORD +48 -40
- lyrics_transcriber/frontend/dist/assets/index-D0Gr3Ep7.js.map +0 -1
- lyrics_transcriber/frontend/src/components/DetailsModal.tsx +0 -252
- lyrics_transcriber/frontend/src/components/WordEditControls.tsx +0 -110
- {lyrics_transcriber-0.43.1.dist-info → lyrics_transcriber-0.45.0.dist-info}/LICENSE +0 -0
- {lyrics_transcriber-0.43.1.dist-info → lyrics_transcriber-0.45.0.dist-info}/WHEEL +0 -0
- {lyrics_transcriber-0.43.1.dist-info → lyrics_transcriber-0.45.0.dist-info}/entry_points.txt +0 -0
@@ -1,5 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import logging
|
3
|
+
import json
|
3
4
|
from dataclasses import dataclass, field
|
4
5
|
from typing import Dict, Optional, List
|
5
6
|
from lyrics_transcriber.types import LyricsData, TranscriptionResult, CorrectionResult
|
@@ -224,7 +225,29 @@ class LyricsTranscriber:
|
|
224
225
|
|
225
226
|
self.logger.info(f"LyricsTranscriber controller beginning processing for {self.artist} - {self.title}")
|
226
227
|
|
227
|
-
#
|
228
|
+
# Check for existing corrections JSON
|
229
|
+
corrections_json_path = os.path.join(self.output_config.output_dir, f"{self.output_prefix} (Lyrics Corrections).json")
|
230
|
+
|
231
|
+
if os.path.exists(corrections_json_path):
|
232
|
+
self.logger.info(f"Found existing corrections JSON: {corrections_json_path}")
|
233
|
+
try:
|
234
|
+
with open(corrections_json_path, "r", encoding="utf-8") as f:
|
235
|
+
corrections_data = json.load(f)
|
236
|
+
|
237
|
+
# Reconstruct CorrectionResult from JSON
|
238
|
+
self.results.transcription_corrected = CorrectionResult.from_dict(corrections_data)
|
239
|
+
self.logger.info("Successfully loaded existing corrections data")
|
240
|
+
|
241
|
+
# Skip to output generation
|
242
|
+
self.generate_outputs()
|
243
|
+
self.logger.info("Processing completed successfully using existing corrections")
|
244
|
+
return self.results
|
245
|
+
|
246
|
+
except Exception as e:
|
247
|
+
self.logger.error(f"Failed to load existing corrections JSON: {str(e)}")
|
248
|
+
# Continue with normal processing if loading fails
|
249
|
+
|
250
|
+
# Normal processing flow continues...
|
228
251
|
if self.output_config.fetch_lyrics and self.artist and self.title:
|
229
252
|
self.fetch_lyrics()
|
230
253
|
else:
|
@@ -298,6 +321,9 @@ class LyricsTranscriber:
|
|
298
321
|
sorted_results = sorted(self.results.transcription_results, key=lambda x: x.priority)
|
299
322
|
best_transcription = sorted_results[0]
|
300
323
|
|
324
|
+
# Count total words in the transcription
|
325
|
+
total_words = sum(len(segment.words) for segment in best_transcription.result.segments)
|
326
|
+
|
301
327
|
# Create a CorrectionResult with no corrections
|
302
328
|
self.results.transcription_corrected = CorrectionResult(
|
303
329
|
original_segments=best_transcription.result.segments,
|
@@ -308,39 +334,47 @@ class LyricsTranscriber:
|
|
308
334
|
reference_lyrics={},
|
309
335
|
anchor_sequences=[],
|
310
336
|
gap_sequences=[],
|
311
|
-
resized_segments=[],
|
337
|
+
resized_segments=[],
|
338
|
+
correction_steps=[],
|
339
|
+
word_id_map={},
|
340
|
+
segment_id_map={},
|
312
341
|
metadata={
|
313
342
|
"correction_type": "none",
|
314
343
|
"reason": "no_reference_lyrics",
|
315
|
-
"audio_filepath": self.audio_filepath,
|
344
|
+
"audio_filepath": self.audio_filepath,
|
345
|
+
"anchor_sequences_count": 0,
|
346
|
+
"gap_sequences_count": 0,
|
347
|
+
"total_words": total_words,
|
348
|
+
"correction_ratio": 0.0,
|
349
|
+
"available_handlers": [],
|
350
|
+
"enabled_handlers": [],
|
316
351
|
},
|
317
352
|
)
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
}
|
353
|
+
else:
|
354
|
+
# Create metadata dict with song info
|
355
|
+
metadata = {
|
356
|
+
"artist": self.artist,
|
357
|
+
"title": self.title,
|
358
|
+
"full_reference_texts": {source: lyrics.get_full_text() for source, lyrics in self.results.lyrics_results.items()},
|
359
|
+
}
|
326
360
|
|
327
|
-
|
328
|
-
|
361
|
+
# Get enabled handlers from metadata if available
|
362
|
+
enabled_handlers = metadata.get("enabled_handlers", None)
|
329
363
|
|
330
|
-
|
331
|
-
|
364
|
+
# Create corrector with enabled handlers
|
365
|
+
corrector = LyricsCorrector(cache_dir=self.output_config.cache_dir, enabled_handlers=enabled_handlers, logger=self.logger)
|
332
366
|
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
367
|
+
corrected_data = corrector.run(
|
368
|
+
transcription_results=self.results.transcription_results,
|
369
|
+
lyrics_results=self.results.lyrics_results,
|
370
|
+
metadata=metadata,
|
371
|
+
)
|
338
372
|
|
339
|
-
|
340
|
-
|
341
|
-
|
373
|
+
# Store corrected results
|
374
|
+
self.results.transcription_corrected = corrected_data
|
375
|
+
self.logger.info("Lyrics correction completed")
|
342
376
|
|
343
|
-
# Add human review step
|
377
|
+
# Add human review step (moved outside the else block)
|
344
378
|
if self.output_config.enable_review:
|
345
379
|
from lyrics_transcriber.review.server import ReviewServer
|
346
380
|
|
@@ -96,14 +96,17 @@ class AnchorSequenceFinder:
|
|
96
96
|
|
97
97
|
def _get_cache_key(self, transcribed: str, references: Dict[str, LyricsData], transcription_result: TranscriptionResult) -> str:
|
98
98
|
"""Generate a unique cache key for the input combination."""
|
99
|
-
# Create a string that uniquely identifies the inputs,
|
100
|
-
# Use only the text content, not IDs or other potentially varying metadata
|
99
|
+
# Create a string that uniquely identifies the inputs, including word IDs
|
101
100
|
ref_texts = []
|
102
101
|
for source, lyrics in sorted(references.items()):
|
103
|
-
|
104
|
-
|
102
|
+
# Include both text and ID for each word to ensure cache uniqueness
|
103
|
+
words_with_ids = [f"{w.text}:{w.id}" for s in lyrics.segments for w in s.words]
|
104
|
+
ref_texts.append(f"{source}:{','.join(words_with_ids)}")
|
105
105
|
|
106
|
-
|
106
|
+
# Also include transcription word IDs to ensure complete matching
|
107
|
+
trans_words_with_ids = [f"{w.text}:{w.id}" for s in transcription_result.segments for w in s.words]
|
108
|
+
|
109
|
+
input_str = f"{transcribed}|" f"{','.join(trans_words_with_ids)}|" f"{','.join(ref_texts)}"
|
107
110
|
return hashlib.md5(input_str.encode()).hexdigest()
|
108
111
|
|
109
112
|
def _save_to_cache(self, cache_path: Path, anchors: List[ScoredAnchor]) -> None:
|
@@ -259,15 +262,26 @@ class AnchorSequenceFinder:
|
|
259
262
|
for segment in transcription_result.segments:
|
260
263
|
all_words.extend(segment.words)
|
261
264
|
|
262
|
-
# Clean and split texts
|
263
|
-
trans_words = [w.text.lower().strip('.,?!"\n') for w in all_words]
|
265
|
+
# Clean and split texts
|
266
|
+
trans_words = [w.text.lower().strip('.,?!"\n') for w in all_words]
|
264
267
|
ref_texts_clean = {
|
265
268
|
source: self._clean_text(" ".join(w.text for s in lyrics.segments for w in s.words)).split()
|
266
269
|
for source, lyrics in references.items()
|
267
270
|
}
|
268
271
|
ref_words = {source: [w for s in lyrics.segments for w in s.words] for source, lyrics in references.items()}
|
269
272
|
|
270
|
-
|
273
|
+
# Filter out very short reference sources for n-gram length calculation
|
274
|
+
valid_ref_lengths = [
|
275
|
+
len(words) for words in ref_texts_clean.values()
|
276
|
+
if len(words) >= self.min_sequence_length
|
277
|
+
]
|
278
|
+
|
279
|
+
if not valid_ref_lengths:
|
280
|
+
self.logger.warning("No reference sources long enough for anchor detection")
|
281
|
+
return []
|
282
|
+
|
283
|
+
# Calculate max length using only valid reference sources
|
284
|
+
max_length = min(len(trans_words), min(valid_ref_lengths))
|
271
285
|
n_gram_lengths = range(max_length, self.min_sequence_length - 1, -1)
|
272
286
|
|
273
287
|
# Process n-gram lengths in parallel
|
@@ -2,6 +2,7 @@ from typing import List, Optional, Tuple, Union, Dict, Any
|
|
2
2
|
import logging
|
3
3
|
from pathlib import Path
|
4
4
|
from copy import deepcopy
|
5
|
+
import os
|
5
6
|
|
6
7
|
from lyrics_transcriber.correction.handlers.levenshtein import LevenshteinHandler
|
7
8
|
from lyrics_transcriber.correction.handlers.llm import LLMHandler
|
@@ -25,6 +26,7 @@ from lyrics_transcriber.correction.anchor_sequence import AnchorSequenceFinder
|
|
25
26
|
from lyrics_transcriber.correction.handlers.base import GapCorrectionHandler
|
26
27
|
from lyrics_transcriber.correction.handlers.extend_anchor import ExtendAnchorHandler
|
27
28
|
from lyrics_transcriber.utils.word_utils import WordUtils
|
29
|
+
from lyrics_transcriber.correction.handlers.llm_providers import OllamaProvider, OpenAIProvider
|
28
30
|
|
29
31
|
|
30
32
|
class LyricsCorrector:
|
@@ -60,12 +62,54 @@ class LyricsCorrector:
|
|
60
62
|
("SyllablesMatchHandler", SyllablesMatchHandler(logger=self.logger)),
|
61
63
|
("RelaxedWordCountMatchHandler", RelaxedWordCountMatchHandler(logger=self.logger)),
|
62
64
|
("NoSpacePunctuationMatchHandler", NoSpacePunctuationMatchHandler(logger=self.logger)),
|
63
|
-
(
|
65
|
+
(
|
66
|
+
"LLMHandler_Ollama_R17B",
|
67
|
+
LLMHandler(
|
68
|
+
provider=OllamaProvider(model="deepseek-r1:7b", logger=self.logger),
|
69
|
+
name="LLMHandler_Ollama_R17B",
|
70
|
+
logger=self.logger,
|
71
|
+
cache_dir=self._cache_dir,
|
72
|
+
),
|
73
|
+
),
|
64
74
|
("RepeatCorrectionHandler", RepeatCorrectionHandler(logger=self.logger)),
|
65
75
|
("SoundAlikeHandler", SoundAlikeHandler(logger=self.logger)),
|
66
76
|
("LevenshteinHandler", LevenshteinHandler(logger=self.logger)),
|
67
77
|
]
|
68
78
|
|
79
|
+
# Add OpenRouter handlers only if API key is available
|
80
|
+
if os.getenv("OPENROUTER_API_KEY"):
|
81
|
+
openrouter_handlers = [
|
82
|
+
(
|
83
|
+
"LLMHandler_OpenRouter_Sonnet",
|
84
|
+
LLMHandler(
|
85
|
+
provider=OpenAIProvider(
|
86
|
+
model="anthropic/claude-3-sonnet",
|
87
|
+
api_key=os.getenv("OPENROUTER_API_KEY"),
|
88
|
+
base_url="https://openrouter.ai/api/v1",
|
89
|
+
logger=self.logger,
|
90
|
+
),
|
91
|
+
name="LLMHandler_OpenRouter_Sonnet",
|
92
|
+
logger=self.logger,
|
93
|
+
cache_dir=self._cache_dir,
|
94
|
+
),
|
95
|
+
),
|
96
|
+
(
|
97
|
+
"LLMHandler_OpenRouter_R1",
|
98
|
+
LLMHandler(
|
99
|
+
provider=OpenAIProvider(
|
100
|
+
model="deepseek/deepseek-r1",
|
101
|
+
api_key=os.getenv("OPENROUTER_API_KEY"),
|
102
|
+
base_url="https://openrouter.ai/api/v1",
|
103
|
+
logger=self.logger,
|
104
|
+
),
|
105
|
+
name="LLMHandler_OpenRouter_R1",
|
106
|
+
logger=self.logger,
|
107
|
+
cache_dir=self._cache_dir,
|
108
|
+
),
|
109
|
+
),
|
110
|
+
]
|
111
|
+
all_handlers.extend(openrouter_handlers)
|
112
|
+
|
69
113
|
# Store all handler information
|
70
114
|
self.all_handlers = [
|
71
115
|
{
|
@@ -127,8 +171,8 @@ class LyricsCorrector:
|
|
127
171
|
corrections_made = len(corrections)
|
128
172
|
correction_ratio = 1 - (corrections_made / total_words if total_words > 0 else 0)
|
129
173
|
|
130
|
-
# Get the currently enabled handler IDs using
|
131
|
-
enabled_handlers = [handler.__class__.__name__ for handler in self.handlers]
|
174
|
+
# Get the currently enabled handler IDs using the handler's name attribute if available
|
175
|
+
enabled_handlers = [getattr(handler, "name", handler.__class__.__name__) for handler in self.handlers]
|
132
176
|
|
133
177
|
return CorrectionResult(
|
134
178
|
original_segments=primary_transcription.segments,
|
@@ -1,22 +1,25 @@
|
|
1
1
|
from typing import List, Optional, Tuple, Dict, Any, Union
|
2
2
|
import logging
|
3
3
|
import json
|
4
|
-
from ollama import chat
|
5
4
|
from datetime import datetime
|
6
5
|
from pathlib import Path
|
7
6
|
|
8
7
|
from lyrics_transcriber.types import GapSequence, WordCorrection
|
9
8
|
from lyrics_transcriber.correction.handlers.base import GapCorrectionHandler
|
10
9
|
from lyrics_transcriber.correction.handlers.word_operations import WordOperations
|
10
|
+
from lyrics_transcriber.correction.handlers.llm_providers import LLMProvider
|
11
11
|
|
12
12
|
|
13
13
|
class LLMHandler(GapCorrectionHandler):
|
14
14
|
"""Uses an LLM to analyze and correct gaps by comparing with reference lyrics."""
|
15
15
|
|
16
|
-
def __init__(
|
16
|
+
def __init__(
|
17
|
+
self, provider: LLMProvider, name: str, logger: Optional[logging.Logger] = None, cache_dir: Optional[Union[str, Path]] = None
|
18
|
+
):
|
17
19
|
super().__init__(logger)
|
18
20
|
self.logger = logger or logging.getLogger(__name__)
|
19
|
-
self.
|
21
|
+
self.provider = provider
|
22
|
+
self.name = name
|
20
23
|
self.cache_dir = Path(cache_dir) if cache_dir else None
|
21
24
|
|
22
25
|
def _format_prompt(self, gap: GapSequence, data: Optional[Dict[str, Any]] = None) -> str:
|
@@ -160,16 +163,16 @@ class LLMHandler(GapCorrectionHandler):
|
|
160
163
|
self.logger.debug(f"Processing gap words: {transcribed_words}")
|
161
164
|
self.logger.debug(f"Reference word IDs: {gap.reference_word_ids}")
|
162
165
|
|
163
|
-
response =
|
166
|
+
response = self.provider.generate_response(prompt)
|
164
167
|
|
165
168
|
# Write debug info to files
|
166
|
-
self._write_debug_info(prompt, response
|
169
|
+
self._write_debug_info(prompt, response, gap_index, audio_file_hash=data.get("audio_file_hash"))
|
167
170
|
|
168
171
|
try:
|
169
|
-
corrections_data = json.loads(response
|
172
|
+
corrections_data = json.loads(response)
|
170
173
|
except json.JSONDecodeError as e:
|
171
174
|
self.logger.error(f"Failed to parse LLM response as JSON: {e}")
|
172
|
-
self.logger.error(f"Raw response content: {response
|
175
|
+
self.logger.error(f"Raw response content: {response}")
|
173
176
|
return []
|
174
177
|
|
175
178
|
# Check if corrections exist and are non-empty
|
@@ -202,7 +205,7 @@ class LLMHandler(GapCorrectionHandler):
|
|
202
205
|
source="LLM",
|
203
206
|
confidence=correction["confidence"],
|
204
207
|
reason=correction["reason"],
|
205
|
-
handler=
|
208
|
+
handler=self.name,
|
206
209
|
reference_positions=reference_positions,
|
207
210
|
original_word_id=correction["word_id"],
|
208
211
|
corrected_word_id=correction.get("reference_word_id"),
|
@@ -223,7 +226,7 @@ class LLMHandler(GapCorrectionHandler):
|
|
223
226
|
source="LLM",
|
224
227
|
confidence=correction["confidence"],
|
225
228
|
reason=correction["reason"],
|
226
|
-
handler=
|
229
|
+
handler=self.name,
|
227
230
|
reference_positions=reference_positions,
|
228
231
|
original_word_id=correction["word_id"],
|
229
232
|
corrected_word_ids=reference_word_ids,
|
@@ -256,7 +259,7 @@ class LLMHandler(GapCorrectionHandler):
|
|
256
259
|
confidence=correction["confidence"],
|
257
260
|
combine_reason=correction["reason"],
|
258
261
|
delete_reason=f"Part of combining words: {correction['reason']}",
|
259
|
-
handler=
|
262
|
+
handler=self.name,
|
260
263
|
reference_positions=reference_positions,
|
261
264
|
original_word_ids=word_ids_to_combine,
|
262
265
|
corrected_word_id=correction.get("reference_word_id"),
|
@@ -275,10 +278,10 @@ class LLMHandler(GapCorrectionHandler):
|
|
275
278
|
reason=correction["reason"],
|
276
279
|
alternatives={},
|
277
280
|
is_deletion=True,
|
278
|
-
handler=
|
281
|
+
handler=self.name,
|
279
282
|
reference_positions=reference_positions,
|
280
283
|
word_id=correction["word_id"],
|
281
|
-
corrected_word_id=None,
|
284
|
+
corrected_word_id=None,
|
282
285
|
)
|
283
286
|
)
|
284
287
|
|
@@ -0,0 +1,60 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Optional
|
3
|
+
import logging
|
4
|
+
from ollama import chat as ollama_chat
|
5
|
+
import openai
|
6
|
+
|
7
|
+
|
8
|
+
class LLMProvider(ABC):
|
9
|
+
"""Abstract base class for LLM providers."""
|
10
|
+
|
11
|
+
def __init__(self, logger: Optional[logging.Logger] = None):
|
12
|
+
self.logger = logger or logging.getLogger(__name__)
|
13
|
+
|
14
|
+
@abstractmethod
|
15
|
+
def generate_response(self, prompt: str, **kwargs) -> str:
|
16
|
+
"""Generate a response from the LLM.
|
17
|
+
|
18
|
+
Args:
|
19
|
+
prompt: The prompt to send to the LLM
|
20
|
+
**kwargs: Additional provider-specific parameters
|
21
|
+
|
22
|
+
Returns:
|
23
|
+
str: The LLM's response
|
24
|
+
"""
|
25
|
+
pass
|
26
|
+
|
27
|
+
|
28
|
+
class OllamaProvider(LLMProvider):
|
29
|
+
"""Provider for local Ollama models."""
|
30
|
+
|
31
|
+
def __init__(self, model: str, logger: Optional[logging.Logger] = None):
|
32
|
+
super().__init__(logger)
|
33
|
+
self.model = model
|
34
|
+
|
35
|
+
def generate_response(self, prompt: str, **kwargs) -> str:
|
36
|
+
try:
|
37
|
+
response = ollama_chat(model=self.model, messages=[{"role": "user", "content": prompt}], format="json")
|
38
|
+
return response.message.content
|
39
|
+
except Exception as e:
|
40
|
+
self.logger.error(f"Error generating Ollama response: {e}")
|
41
|
+
raise
|
42
|
+
|
43
|
+
|
44
|
+
class OpenAIProvider(LLMProvider):
|
45
|
+
"""Provider for OpenAI-compatible APIs (including OpenRouter)."""
|
46
|
+
|
47
|
+
def __init__(self, model: str, api_key: str, base_url: Optional[str] = None, logger: Optional[logging.Logger] = None):
|
48
|
+
super().__init__(logger)
|
49
|
+
self.model = model
|
50
|
+
self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
|
51
|
+
|
52
|
+
def generate_response(self, prompt: str, **kwargs) -> str:
|
53
|
+
try:
|
54
|
+
response = self.client.chat.completions.create(
|
55
|
+
model=self.model, messages=[{"role": "user", "content": prompt}], response_format={"type": "json_object"}, **kwargs
|
56
|
+
)
|
57
|
+
return response.choices[0].message.content
|
58
|
+
except Exception as e:
|
59
|
+
self.logger.error(f"Error generating OpenAI response: {e}")
|
60
|
+
raise
|
Binary file
|