karaoke-gen 0.81.1__py3-none-any.whl → 0.86.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- karaoke_gen/instrumental_review/static/index.html +2 -2
- karaoke_gen/lyrics_processor.py +9 -5
- {karaoke_gen-0.81.1.dist-info → karaoke_gen-0.86.5.dist-info}/METADATA +2 -2
- {karaoke_gen-0.81.1.dist-info → karaoke_gen-0.86.5.dist-info}/RECORD +20 -21
- lyrics_transcriber/core/controller.py +16 -5
- lyrics_transcriber/correction/agentic/agent.py +19 -7
- lyrics_transcriber/correction/agentic/observability/langfuse_integration.py +178 -5
- lyrics_transcriber/correction/agentic/prompts/__init__.py +23 -0
- lyrics_transcriber/correction/agentic/prompts/classifier.py +66 -6
- lyrics_transcriber/correction/agentic/prompts/langfuse_prompts.py +298 -0
- lyrics_transcriber/correction/agentic/providers/config.py +19 -6
- lyrics_transcriber/correction/agentic/providers/constants.py +1 -1
- lyrics_transcriber/correction/agentic/providers/langchain_bridge.py +125 -20
- lyrics_transcriber/correction/agentic/providers/model_factory.py +58 -25
- lyrics_transcriber/correction/agentic/providers/response_parser.py +18 -6
- lyrics_transcriber/correction/agentic/router.py +2 -1
- lyrics_transcriber/correction/corrector.py +44 -49
- lyrics_transcriber/correction/handlers/llm.py +0 -293
- lyrics_transcriber/correction/handlers/llm_providers.py +0 -60
- {karaoke_gen-0.81.1.dist-info → karaoke_gen-0.86.5.dist-info}/WHEEL +0 -0
- {karaoke_gen-0.81.1.dist-info → karaoke_gen-0.86.5.dist-info}/entry_points.txt +0 -0
- {karaoke_gen-0.81.1.dist-info → karaoke_gen-0.86.5.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ b/lyrics_transcriber/correction/agentic/prompts/langfuse_prompts.py
@@ -0,0 +1,298 @@
+"""LangFuse prompt management for agentic correction.
+
+This module provides prompt fetching from LangFuse, enabling dynamic prompt
+iteration without code redeployment.
+"""
+
+from typing import Dict, List, Optional, Any
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+class LangFusePromptError(Exception):
+    """Raised when LangFuse prompt fetching fails."""
+    pass
+
+
+class LangFuseDatasetError(Exception):
+    """Raised when LangFuse dataset fetching fails."""
+    pass
+
+
+class LangFusePromptService:
+    """Fetches prompts and datasets from LangFuse for agentic correction.
+
+    This service handles:
+    - Fetching prompt templates from LangFuse
+    - Fetching few-shot examples from LangFuse datasets
+    - Compiling prompts with dynamic variables
+    - Fail-fast behavior when LangFuse is configured but unavailable
+
+    When LangFuse keys are not configured, falls back to hardcoded prompts
+    for local development.
+    """
+
+    # Prompt and dataset names in LangFuse
+    CLASSIFIER_PROMPT_NAME = "gap-classifier"
+    EXAMPLES_DATASET_NAME = "gap-classifier-examples"
+
+    def __init__(self, client: Optional[Any] = None):
+        """Initialize the prompt service.
+
+        Args:
+            client: Optional pre-initialized Langfuse client (for testing).
+                If None, will initialize from environment variables.
+        """
+        self._client = client
+        self._initialized = False
+        self._use_langfuse = self._should_use_langfuse()
+
+        if self._use_langfuse and client is None:
+            self._init_client()
+
+    def _should_use_langfuse(self) -> bool:
+        """Check if LangFuse credentials are configured."""
+        public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
+        secret_key = os.getenv("LANGFUSE_SECRET_KEY")
+        return bool(public_key and secret_key)
+
+    def _init_client(self) -> None:
+        """Initialize the Langfuse client using the shared singleton."""
+        from ..observability.langfuse_integration import get_langfuse_client, LangFuseConfigError
+
+        try:
+            self._client = get_langfuse_client()
+            if self._client:
+                self._initialized = True
+                logger.info("LangFuse prompt service initialized")
+            else:
+                logger.debug("LangFuse keys not configured, will use hardcoded prompts")
+        except LangFuseConfigError as e:
+            # Re-raise as RuntimeError for consistent error handling
+            raise RuntimeError(str(e)) from e
+
+    def get_classification_prompt(
+        self,
+        gap_text: str,
+        preceding_words: str,
+        following_words: str,
+        reference_contexts: Dict[str, str],
+        artist: Optional[str] = None,
+        title: Optional[str] = None,
+        gap_id: Optional[str] = None
+    ) -> str:
+        """Fetch and compile the gap classification prompt.
+
+        If LangFuse is configured, fetches the prompt template and examples
+        from LangFuse. Otherwise, falls back to hardcoded prompts.
+
+        Args:
+            gap_text: The text of the gap that needs classification
+            preceding_words: Text immediately before the gap
+            following_words: Text immediately after the gap
+            reference_contexts: Dictionary of reference lyrics from each source
+            artist: Song artist name for context
+            title: Song title for context
+            gap_id: Identifier for the gap
+
+        Returns:
+            Compiled prompt string ready for LLM
+
+        Raises:
+            LangFusePromptError: If LangFuse is configured but prompt fetch fails
+        """
+        if not self._use_langfuse:
+            # Fall back to hardcoded prompt for development
+            from .classifier import build_classification_prompt_hardcoded
+            return build_classification_prompt_hardcoded(
+                gap_text=gap_text,
+                preceding_words=preceding_words,
+                following_words=following_words,
+                reference_contexts=reference_contexts,
+                artist=artist,
+                title=title,
+                gap_id=gap_id
+            )
+
+        # Fetch from LangFuse
+        try:
+            prompt_template = self._fetch_prompt(self.CLASSIFIER_PROMPT_NAME)
+            examples = self._fetch_examples()
+
+            # Build component strings
+            song_context = self._build_song_context(artist, title)
+            examples_text = self._format_examples(examples)
+            references_text = self._format_references(reference_contexts)
+
+            # Compile the prompt with variables
+            compiled = prompt_template.compile(
+                song_context=song_context,
+                examples_text=examples_text,
+                gap_id=gap_id or "unknown",
+                preceding_words=preceding_words,
+                gap_text=gap_text,
+                following_words=following_words,
+                references_text=references_text
+            )
+
+            logger.debug(f"Compiled LangFuse prompt for gap {gap_id}")
+            return compiled
+
+        except Exception as e:
+            raise LangFusePromptError(
+                f"Failed to fetch/compile prompt from LangFuse: {e}"
+            ) from e
+
+    def _fetch_prompt(self, name: str, label: str = "production") -> Any:
+        """Fetch a prompt template from LangFuse.
+
+        Args:
+            name: The prompt name in LangFuse
+            label: Prompt label to fetch (default: "production"). Falls back to
+                version 1 if labeled version not found.
+
+        Returns:
+            LangFuse prompt object
+
+        Raises:
+            LangFusePromptError: If fetch fails
+        """
+        if not self._client:
+            raise LangFusePromptError("LangFuse client not initialized")
+
+        try:
+            # Try to fetch with the specified label (default: production)
+            prompt = self._client.get_prompt(name, label=label)
+            logger.debug(f"Fetched prompt '{name}' (label={label}) from LangFuse")
+            return prompt
+        except Exception as label_error:
+            # If labeled version not found, try fetching version 1 as fallback
+            # This handles newly created prompts that haven't been promoted yet
+            try:
+                prompt = self._client.get_prompt(name, version=1)
+                logger.warning(
+                    f"Prompt '{name}' label '{label}' not found, using version 1. "
+                    f"Consider promoting this prompt in LangFuse UI."
+                )
+                return prompt
+            except Exception as version_error:
+                raise LangFusePromptError(
+                    f"Failed to fetch prompt '{name}' from LangFuse: "
+                    f"Label '{label}' error: {label_error}, "
+                    f"Version 1 fallback error: {version_error}"
+                ) from version_error
+
+    def _fetch_examples(self) -> List[Dict[str, Any]]:
+        """Fetch few-shot examples from LangFuse dataset.
+
+        Returns:
+            List of example dictionaries
+
+        Raises:
+            LangFuseDatasetError: If fetch fails
+        """
+        if not self._client:
+            raise LangFuseDatasetError("LangFuse client not initialized")
+
+        try:
+            dataset = self._client.get_dataset(self.EXAMPLES_DATASET_NAME)
+            examples = []
+            for item in dataset.items:
+                # Dataset items have 'input' field with the example data
+                if hasattr(item, 'input') and item.input:
+                    examples.append(item.input)
+
+            logger.debug(f"Fetched {len(examples)} examples from LangFuse dataset")
+            return examples
+        except Exception as e:
+            raise LangFuseDatasetError(
+                f"Failed to fetch dataset '{self.EXAMPLES_DATASET_NAME}' from LangFuse: {e}"
+            ) from e
+
+    def _build_song_context(self, artist: Optional[str], title: Optional[str]) -> str:
+        """Build song context section for the prompt."""
+        if artist and title:
+            return (
+                f"\n## Song Context\n\n"
+                f"**Artist:** {artist}\n"
+                f"**Title:** {title}\n\n"
+                f"Note: The song title and artist name may help identify proper nouns "
+                f"or unusual words that could be mis-heard.\n"
+            )
+        return ""
+
+    def _format_examples(self, examples: List[Dict[str, Any]]) -> str:
+        """Format few-shot examples for inclusion in prompt.
+
+        Args:
+            examples: List of example dictionaries from LangFuse dataset
+
+        Returns:
+            Formatted examples string
+        """
+        if not examples:
+            return ""
+
+        # Group examples by category
+        examples_by_category: Dict[str, List[Dict]] = {}
+        for ex in examples:
+            category = ex.get("category", "unknown")
+            if category not in examples_by_category:
+                examples_by_category[category] = []
+            examples_by_category[category].append(ex)
+
+        # Build formatted text
+        text = "## Example Classifications\n\n"
+        for category, category_examples in examples_by_category.items():
+            text += f"### {category.upper().replace('_', ' ')}\n\n"
+            for ex in category_examples[:2]:  # Limit to 2 examples per category
+                text += f"**Gap:** {ex.get('gap_text', '')}\n"
+                text += f"**Context:** ...{ex.get('preceding', '')}... [GAP] ...{ex.get('following', '')}...\n"
+                if 'reference' in ex:
+                    text += f"**Reference:** {ex['reference']}\n"
+                text += f"**Reasoning:** {ex.get('reasoning', '')}\n"
+                text += f"**Action:** {ex.get('action', '')}\n\n"
+
+        return text
+
+    def _format_references(self, reference_contexts: Dict[str, str]) -> str:
+        """Format reference lyrics for inclusion in prompt.
+
+        Args:
+            reference_contexts: Dictionary of reference lyrics from each source
+
+        Returns:
+            Formatted references string
+        """
+        if not reference_contexts:
+            return ""
+
+        text = "## Available Reference Lyrics\n\n"
+        for source, context in reference_contexts.items():
+            text += f"**{source.upper()}:** {context}\n\n"
+
+        return text
+
+
+# Module-level singleton for convenience
+_prompt_service: Optional[LangFusePromptService] = None
+
+
+def get_prompt_service() -> LangFusePromptService:
+    """Get or create the global prompt service instance.
+
+    Returns:
+        LangFusePromptService singleton instance
+    """
+    global _prompt_service
+    if _prompt_service is None:
+        _prompt_service = LangFusePromptService()
+    return _prompt_service
+
+
+def reset_prompt_service() -> None:
+    """Reset the global prompt service instance (for testing)."""
+    global _prompt_service
+    _prompt_service = None
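For orientation, here is a minimal sketch of how a caller might exercise the new module's public API (get_prompt_service and get_classification_prompt). The gap text, reference source, artist/title, and gap identifier below are illustrative placeholders, not values from the package:

from lyrics_transcriber.correction.agentic.prompts.langfuse_prompts import (
    LangFusePromptError,
    get_prompt_service,
)

# Falls back to the hardcoded prompt builder when LANGFUSE_PUBLIC_KEY /
# LANGFUSE_SECRET_KEY are not set; otherwise fetches "gap-classifier" from LangFuse.
service = get_prompt_service()
try:
    prompt = service.get_classification_prompt(
        gap_text="wildest dreams",          # illustrative values only
        preceding_words="in my",
        following_words="tonight",
        reference_contexts={"genius": "... in my wildest dreams tonight ..."},
        artist="Example Artist",
        title="Example Song",
        gap_id="gap-42",
    )
except LangFusePromptError as exc:
    # Raised only when LangFuse is configured but the fetch/compile fails
    raise SystemExit(f"Prompt fetch failed: {exc}")

print(prompt[:200])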
--- a/lyrics_transcriber/correction/agentic/providers/config.py
+++ b/lyrics_transcriber/correction/agentic/providers/config.py
@@ -19,16 +19,27 @@ class ProviderConfig:
     cache_dir: str
 
     # GCP/Vertex AI settings
+    # Note: Gemini 3 models require 'global' location (not regional like us-central1)
     gcp_project_id: Optional[str] = None
-    gcp_location: str = "
+    gcp_location: str = "global"
 
-
+    # Timeout increased to 120s to handle Vertex AI connection establishment
+    # and potential network latency. The 499 "operation cancelled" errors seen
+    # at ~60s suggest internal timeouts; 120s provides headroom.
+    request_timeout_seconds: float = 120.0
     max_retries: int = 2
-
+    # Backoff increased from 0.2s to 2.0s base - if a request times out,
+    # retrying immediately is unlikely to help. Give the service time to recover.
+    retry_backoff_base_seconds: float = 2.0
     retry_backoff_factor: float = 2.0
     circuit_breaker_failure_threshold: int = 3
     circuit_breaker_open_seconds: int = 60
 
+    # Initialization timeouts - fail fast instead of hanging forever
+    # These are separate from request_timeout to catch connection establishment issues
+    initialization_timeout_seconds: float = 30.0  # Model creation + warm-up
+    warmup_timeout_seconds: float = 15.0  # Just the warm-up call
+
     @staticmethod
     def from_env(cache_dir: Optional[str] = None) -> "ProviderConfig":
         """Create config from environment variables.
@@ -51,13 +62,15 @@
             privacy_mode=os.getenv("PRIVACY_MODE", "false").lower() in {"1", "true", "yes"},
             cache_dir=cache_dir,
             gcp_project_id=os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCP_PROJECT_ID"),
-            gcp_location=os.getenv("GCP_LOCATION", "
-            request_timeout_seconds=float(os.getenv("AGENTIC_TIMEOUT_SECONDS", "
+            gcp_location=os.getenv("GCP_LOCATION", "global"),
+            request_timeout_seconds=float(os.getenv("AGENTIC_TIMEOUT_SECONDS", "120.0")),
             max_retries=int(os.getenv("AGENTIC_MAX_RETRIES", "2")),
-            retry_backoff_base_seconds=float(os.getenv("AGENTIC_BACKOFF_BASE_SECONDS", "0
+            retry_backoff_base_seconds=float(os.getenv("AGENTIC_BACKOFF_BASE_SECONDS", "2.0")),
             retry_backoff_factor=float(os.getenv("AGENTIC_BACKOFF_FACTOR", "2.0")),
             circuit_breaker_failure_threshold=int(os.getenv("AGENTIC_CIRCUIT_THRESHOLD", "3")),
             circuit_breaker_open_seconds=int(os.getenv("AGENTIC_CIRCUIT_OPEN_SECONDS", "60")),
+            initialization_timeout_seconds=float(os.getenv("AGENTIC_INIT_TIMEOUT_SECONDS", "30.0")),
+            warmup_timeout_seconds=float(os.getenv("AGENTIC_WARMUP_TIMEOUT_SECONDS", "15.0")),
         )
 
     def validate_environment(self, logger: Optional[object] = None) -> None:
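The new timeout and backoff defaults can all be overridden through the environment variables read by from_env, as the hunk above shows. A small sketch under those assumptions (the cache directory path is a placeholder):

import os
from lyrics_transcriber.correction.agentic.providers.config import ProviderConfig

# Override the new defaults shown above before building the config.
os.environ["GCP_LOCATION"] = "global"                  # Gemini 3 requires 'global'
os.environ["AGENTIC_TIMEOUT_SECONDS"] = "120.0"        # per-request timeout
os.environ["AGENTIC_BACKOFF_BASE_SECONDS"] = "2.0"     # retry backoff base
os.environ["AGENTIC_INIT_TIMEOUT_SECONDS"] = "30.0"    # model creation + warm-up
os.environ["AGENTIC_WARMUP_TIMEOUT_SECONDS"] = "15.0"  # warm-up call only

config = ProviderConfig.from_env(cache_dir="/tmp/agentic-cache")  # placeholder path
assert config.gcp_location == "global"
assert config.request_timeout_seconds == 120.0
assert config.warmup_timeout_seconds == 15.0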
--- a/lyrics_transcriber/correction/agentic/providers/constants.py
+++ b/lyrics_transcriber/correction/agentic/providers/constants.py
@@ -8,7 +8,7 @@ RESPONSE_LOG_LENGTH = 500  # Characters to log from responses
 MODEL_SPEC_FORMAT = "provider/model"  # Expected format for model identifiers
 
 # Default Langfuse host
-DEFAULT_LANGFUSE_HOST = "https://cloud.langfuse.com"
+DEFAULT_LANGFUSE_HOST = "https://us.cloud.langfuse.com"
 
 # Raw response indicator
 RAW_RESPONSE_KEY = "raw"  # Key used to wrap unparsed responses
--- a/lyrics_transcriber/correction/agentic/providers/langchain_bridge.py
+++ b/lyrics_transcriber/correction/agentic/providers/langchain_bridge.py
@@ -13,6 +13,8 @@ from __future__ import annotations
 
 import logging
 import os
+import time
+from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
 from typing import List, Dict, Any, Optional
 from datetime import datetime
 
@@ -33,6 +35,14 @@ from .constants import (
 
 logger = logging.getLogger(__name__)
 
+# Error constant for initialization timeout
+INIT_TIMEOUT_ERROR = "initialization_timeout"
+
+
+class InitializationTimeoutError(Exception):
+    """Raised when model initialization exceeds the configured timeout."""
+    pass
+
 
 class LangChainBridge(BaseAIProvider):
     """Provider bridge using LangChain ChatModels with reliability patterns.
@@ -87,6 +97,7 @@ class LangChainBridge(BaseAIProvider):
 
         # Lazy-initialized chat model
         self._chat_model: Optional[Any] = None
+        self._warmed_up: bool = False
 
     def name(self) -> str:
         """Return provider name for logging."""
@@ -130,13 +141,45 @@
                 "until": open_until
             }]
 
-        # Step 2: Get or create chat model
+        # Step 2: Get or create chat model with initialization timeout
         if not self._chat_model:
+            timeout = self._config.initialization_timeout_seconds
+            logger.info(f"🤖 Initializing model {self._model} with {timeout}s timeout...")
+            init_start = time.time()
+
             try:
-
-
-
-
+                # Use ThreadPoolExecutor for cross-platform timeout
+                with ThreadPoolExecutor(max_workers=1) as executor:
+                    future = executor.submit(
+                        self._factory.create_chat_model,
+                        self._model,
+                        self._config
+                    )
+                    try:
+                        self._chat_model = future.result(timeout=timeout)
+                    except FuturesTimeoutError:
+                        raise InitializationTimeoutError(
+                            f"Model initialization timed out after {timeout}s. "
+                            f"This may indicate network issues or service unavailability."
+                        ) from None
+
+                init_elapsed = time.time() - init_start
+                logger.info(f"🤖 Model created in {init_elapsed:.2f}s, starting warm-up...")
+
+                # Warm up the model to establish connection before real work
+                self._warm_up_model()
+
+                total_elapsed = time.time() - init_start
+                logger.info(f"🤖 Model initialization complete in {total_elapsed:.2f}s")
+
+            except InitializationTimeoutError as e:
+                self._circuit_breaker.record_failure(self._model)
+                logger.exception("🤖 Model initialization timeout")
+                return [{
+                    "error": INIT_TIMEOUT_ERROR,
+                    "message": str(e),
+                    "timeout_seconds": timeout
+                }]
             except Exception as e:
                 self._circuit_breaker.record_failure(self._model)
                 logger.error(f"🤖 Failed to initialize chat model: {e}")
@@ -146,24 +189,27 @@
             }]
 
         # Step 3: Execute with retry logic
-        logger.
-            f"🤖 [LangChain] Sending prompt to {self._model}
-            f"{prompt[:PROMPT_LOG_LENGTH]}..."
+        logger.info(
+            f"🤖 [LangChain] Sending prompt to {self._model} ({len(prompt)} chars)"
         )
-
+        logger.debug(f"🤖 [LangChain] Prompt preview: {prompt[:PROMPT_LOG_LENGTH]}...")
+
+        invoke_start = time.time()
         result = self._executor.execute_with_retry(
             operation=lambda: self._invoke_model(prompt),
             operation_name=f"invoke_{self._model}"
         )
-
+        invoke_elapsed = time.time() - invoke_start
+
         # Step 4: Handle result and update circuit breaker
         if result.success:
            self._circuit_breaker.record_success(self._model)
-
+
             logger.info(
-                f"🤖 [LangChain] Got response from {self._model}
-                f"{result.value
+                f"🤖 [LangChain] Got response from {self._model} in {invoke_elapsed:.2f}s "
+                f"({len(result.value)} chars)"
             )
+            logger.debug(f"🤖 [LangChain] Response preview: {result.value[:RESPONSE_LOG_LENGTH]}...")
 
             # Step 5: Cache the raw response for future use
             self._cache.set(
@@ -187,26 +233,85 @@
 
     def _invoke_model(self, prompt: str) -> str:
         """Invoke the chat model with a prompt.
-
+
         This is a simple wrapper that can be passed to the retry executor.
-
+
         Args:
             prompt: The prompt to send
-
+
         Returns:
             Response content as string
-
+
         Raises:
             Exception: Any error from the model invocation
         """
         from langchain_core.messages import HumanMessage
-
+
         # Prepare config with session_id in metadata (Langfuse format)
         config = {}
         if hasattr(self, '_session_id') and self._session_id:
             config["metadata"] = {"langfuse_session_id": self._session_id}
             logger.debug(f"🤖 [LangChain] Invoking with session_id: {self._session_id}")
-
+
         response = self._chat_model.invoke([HumanMessage(content=prompt)], config=config)
-
+        content = response.content
+
+        # Handle multimodal response format from Gemini 3+ models
+        # Response can be a list of content parts: [{'type': 'text', 'text': '...'}]
+        if isinstance(content, list):
+            # Extract text from the first text content part
+            for part in content:
+                if isinstance(part, dict) and part.get('type') == 'text':
+                    return part.get('text', '')
+            # Fallback: concatenate all text parts
+            return ''.join(
+                part.get('text', '') if isinstance(part, dict) else str(part)
+                for part in content
+            )
+
+        return content
+
+    def _warm_up_model(self) -> None:
+        """Send a lightweight request to warm up the model connection.
+
+        This helps establish the REST connection and potentially warm up any
+        server-side resources before processing real correction requests.
+        The warm-up uses a timeout to fail fast if the service is unresponsive.
+        """
+        if self._warmed_up:
+            return
+
+        timeout = self._config.warmup_timeout_seconds
+        # Use print with flush=True for visibility when output is redirected
+        print(f"🔥 Warming up {self._model} connection (timeout: {timeout}s)...", flush=True)
+        logger.info(f"🔥 Warming up {self._model} connection (timeout: {timeout}s)...")
+
+        warmup_start = time.time()
+        try:
+            from langchain_core.messages import HumanMessage
+
+            # Minimal prompt that requires almost no processing
+            warm_up_prompt = 'Respond with exactly: {"status":"ready"}'
+
+            # Use ThreadPoolExecutor for timeout on warm-up call
+            with ThreadPoolExecutor(max_workers=1) as executor:
+                future = executor.submit(
+                    self._chat_model.invoke,
+                    [HumanMessage(content=warm_up_prompt)]
+                )
+                try:
+                    future.result(timeout=timeout)
+                except FuturesTimeoutError:
+                    raise TimeoutError(f"Warm-up timed out after {timeout}s") from None
+
+            elapsed = time.time() - warmup_start
+            self._warmed_up = True
+            print(f"🔥 Warm-up complete for {self._model} in {elapsed:.2f}s", flush=True)
+            logger.info(f"🔥 Warm-up complete for {self._model} in {elapsed:.2f}s")
+        except Exception as e:
+            elapsed = time.time() - warmup_start
+            # Don't fail the actual request if warm-up fails
+            # Just log and continue - the real request might still work
+            print(f"🔥 Warm-up failed for {self._model} after {elapsed:.2f}s: {e} (continuing anyway)", flush=True)
+            logger.warning(f"🔥 Warm-up failed for {self._model} after {elapsed:.2f}s: {e} (continuing anyway)")
 
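The content normalization added to _invoke_model is easy to miss inside the hunk; restated as a standalone helper (the function name is ours, purely for illustration), the behaviour is:

from typing import Any, List, Union

def extract_text(content: Union[str, List[Any]]) -> str:
    """Mirror of the list-handling added in _invoke_model for Gemini 3+ responses."""
    if isinstance(content, list):
        # Prefer the first {'type': 'text'} part, as the bridge does
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                return part.get("text", "")
        # Otherwise concatenate whatever text is present
        return "".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )
    return content

assert extract_text("plain string") == "plain string"
assert extract_text([{"type": "text", "text": "hello"}]) == "hello"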
--- a/lyrics_transcriber/correction/agentic/providers/model_factory.py
+++ b/lyrics_transcriber/correction/agentic/providers/model_factory.py
@@ -3,12 +3,19 @@ from __future__ import annotations
 
 import logging
 import os
+import time
 from typing import Any, Optional, List
 
 from .config import ProviderConfig
 
 logger = logging.getLogger(__name__)
 
+# Error message constant for TRY003 compliance
+GOOGLE_API_KEY_MISSING_ERROR = (
+    "GOOGLE_API_KEY environment variable is required for Google/Gemini models. "
+    "Get an API key from https://aistudio.google.com/app/apikey"
+)
+
 
 class ModelFactory:
     """Creates and configures LangChain ChatModels with observability.
@@ -100,19 +107,10 @@
             return
 
         try:
-            from langfuse import Langfuse
             from langfuse.langchain import CallbackHandler
-
-            #
-
-                public_key=public_key,
-                secret_key=secret_key,
-                host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com"),
-            )
-
-            # Then create callback handler with the same public_key
-            # The handler will use the initialized client
-            self._langfuse_handler = CallbackHandler(public_key=public_key)
+
+            # CallbackHandler auto-discovers credentials from environment variables
+            self._langfuse_handler = CallbackHandler()
             logger.info(f"🤖 Langfuse callback handler initialized for {model_spec}")
         except Exception as e:
             # If Langfuse keys are set, we MUST fail fast
@@ -212,21 +210,56 @@
     def _create_vertexai_model(
         self, model_name: str, callbacks: List[Any], config: ProviderConfig
     ) -> Any:
-        """Create
+        """Create ChatGoogleGenerativeAI model for Google Gemini.
 
-        Uses
-
-
+        Uses the unified langchain-google-genai package which supports both:
+        - Vertex AI backend (service account / ADC auth) - when project is set
+        - Google AI Studio backend (API key auth) - when only api_key is set
+
+        On Cloud Run, ADC (Application Default Credentials) are used automatically
+        when the project parameter is provided, using the service account attached
+        to the Cloud Run service.
+
+        This is a REST-based API that avoids the gRPC connection issues
+        seen with the deprecated langchain-google-vertexai package.
         """
-        from
+        from langchain_google_genai import ChatGoogleGenerativeAI
 
-
-
-
-
-
-
-        )
-
+        start_time = time.time()
+
+        # Determine authentication method
+        api_key = config.google_api_key
+        project = config.gcp_project_id
+
+        # Prefer Vertex AI (service account) if project is set, otherwise require API key
+        if not project and not api_key:
+            raise ValueError(GOOGLE_API_KEY_MISSING_ERROR)
+
+        if project:
+            logger.info(f"🤖 Creating Google Gemini model via Vertex AI (project={project}): {model_name}")
+        else:
+            logger.info(f"🤖 Creating Google Gemini model via AI Studio API: {model_name}")
+
+        # Build kwargs - only include api_key if set (otherwise ADC is used)
+        model_kwargs = {
+            "model": model_name,
+            "convert_system_message_to_human": True,  # Gemini doesn't support system messages
+            "max_retries": config.max_retries,
+            "timeout": config.request_timeout_seconds,
+            "callbacks": callbacks,
+        }
+
+        # Add project to trigger Vertex AI backend with ADC
+        if project:
+            model_kwargs["project"] = project
+
+        # Add API key if available (can be used with or without project)
+        if api_key:
+            model_kwargs["google_api_key"] = api_key
+
+        model = ChatGoogleGenerativeAI(**model_kwargs)
+
+        elapsed = time.time() - start_time
+        logger.info(f"🤖 Google Gemini model created in {elapsed:.2f}s: {model_name}")
         return model
 
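The authentication selection in the rewritten _create_vertexai_model reduces to a small decision table; a sketch of that logic in isolation (the helper name and model string are illustrative only, not part of the package):

from typing import Any, Dict, Optional

def build_gemini_kwargs(model_name: str, project: Optional[str], api_key: Optional[str]) -> Dict[str, Any]:
    """project set -> Vertex AI via ADC; api_key only -> AI Studio; neither -> error."""
    if not project and not api_key:
        raise ValueError("GOOGLE_API_KEY or a GCP project is required")
    kwargs: Dict[str, Any] = {"model": model_name}
    if project:
        kwargs["project"] = project          # triggers the Vertex AI backend with ADC
    if api_key:
        kwargs["google_api_key"] = api_key   # AI Studio key, usable with or without project
    return kwargs

# Illustrative model name; ChatGoogleGenerativeAI(**kwargs) would consume the result.
print(build_gemini_kwargs("example-gemini-model", project="my-gcp-project", api_key=None))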