karaoke-gen 0.81.1__py3-none-any.whl → 0.86.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,298 @@
+ """LangFuse prompt management for agentic correction.
+
+ This module provides prompt fetching from LangFuse, enabling dynamic prompt
+ iteration without code redeployment.
+ """
+
+ from typing import Dict, List, Optional, Any
+ import logging
+ import os
+
+ logger = logging.getLogger(__name__)
+
+
+ class LangFusePromptError(Exception):
+     """Raised when LangFuse prompt fetching fails."""
+     pass
+
+
+ class LangFuseDatasetError(Exception):
+     """Raised when LangFuse dataset fetching fails."""
+     pass
+
+
+ class LangFusePromptService:
+     """Fetches prompts and datasets from LangFuse for agentic correction.
+
+     This service handles:
+     - Fetching prompt templates from LangFuse
+     - Fetching few-shot examples from LangFuse datasets
+     - Compiling prompts with dynamic variables
+     - Fail-fast behavior when LangFuse is configured but unavailable
+
+     When LangFuse keys are not configured, falls back to hardcoded prompts
+     for local development.
+     """
+
+     # Prompt and dataset names in LangFuse
+     CLASSIFIER_PROMPT_NAME = "gap-classifier"
+     EXAMPLES_DATASET_NAME = "gap-classifier-examples"
+
+     def __init__(self, client: Optional[Any] = None):
+         """Initialize the prompt service.
+
+         Args:
+             client: Optional pre-initialized Langfuse client (for testing).
+                 If None, will initialize from environment variables.
+         """
+         self._client = client
+         self._initialized = False
+         self._use_langfuse = self._should_use_langfuse()
+
+         if self._use_langfuse and client is None:
+             self._init_client()
+
+     def _should_use_langfuse(self) -> bool:
+         """Check if LangFuse credentials are configured."""
+         public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
+         secret_key = os.getenv("LANGFUSE_SECRET_KEY")
+         return bool(public_key and secret_key)
+
+     def _init_client(self) -> None:
+         """Initialize the Langfuse client using the shared singleton."""
+         from ..observability.langfuse_integration import get_langfuse_client, LangFuseConfigError
+
+         try:
+             self._client = get_langfuse_client()
+             if self._client:
+                 self._initialized = True
+                 logger.info("LangFuse prompt service initialized")
+             else:
+                 logger.debug("LangFuse keys not configured, will use hardcoded prompts")
+         except LangFuseConfigError as e:
+             # Re-raise as RuntimeError for consistent error handling
+             raise RuntimeError(str(e)) from e
+
+     def get_classification_prompt(
+         self,
+         gap_text: str,
+         preceding_words: str,
+         following_words: str,
+         reference_contexts: Dict[str, str],
+         artist: Optional[str] = None,
+         title: Optional[str] = None,
+         gap_id: Optional[str] = None
+     ) -> str:
+         """Fetch and compile the gap classification prompt.
+
+         If LangFuse is configured, fetches the prompt template and examples
+         from LangFuse. Otherwise, falls back to hardcoded prompts.
+
+         Args:
+             gap_text: The text of the gap that needs classification
+             preceding_words: Text immediately before the gap
+             following_words: Text immediately after the gap
+             reference_contexts: Dictionary of reference lyrics from each source
+             artist: Song artist name for context
+             title: Song title for context
+             gap_id: Identifier for the gap
+
+         Returns:
+             Compiled prompt string ready for LLM
+
+         Raises:
+             LangFusePromptError: If LangFuse is configured but prompt fetch fails
+         """
+         if not self._use_langfuse:
+             # Fall back to hardcoded prompt for development
+             from .classifier import build_classification_prompt_hardcoded
+             return build_classification_prompt_hardcoded(
+                 gap_text=gap_text,
+                 preceding_words=preceding_words,
+                 following_words=following_words,
+                 reference_contexts=reference_contexts,
+                 artist=artist,
+                 title=title,
+                 gap_id=gap_id
+             )
+
+         # Fetch from LangFuse
+         try:
+             prompt_template = self._fetch_prompt(self.CLASSIFIER_PROMPT_NAME)
+             examples = self._fetch_examples()
+
+             # Build component strings
+             song_context = self._build_song_context(artist, title)
+             examples_text = self._format_examples(examples)
+             references_text = self._format_references(reference_contexts)
+
+             # Compile the prompt with variables
+             compiled = prompt_template.compile(
+                 song_context=song_context,
+                 examples_text=examples_text,
+                 gap_id=gap_id or "unknown",
+                 preceding_words=preceding_words,
+                 gap_text=gap_text,
+                 following_words=following_words,
+                 references_text=references_text
+             )
+
+             logger.debug(f"Compiled LangFuse prompt for gap {gap_id}")
+             return compiled
+
+         except Exception as e:
+             raise LangFusePromptError(
+                 f"Failed to fetch/compile prompt from LangFuse: {e}"
+             ) from e
+
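The compile() call substitutes variables into a template managed in LangFuse; Langfuse text prompts use {{variable}} placeholders. The real gap-classifier template is not shipped in this package, so the following is only a sketch of a template that would satisfy the seven variables passed above:

```python
# Illustrative only - the real template is managed in the LangFuse UI.
GAP_CLASSIFIER_TEMPLATE_SKETCH = """\
You are reviewing a low-confidence gap in a karaoke transcription.
{{song_context}}
{{examples_text}}
## Gap {{gap_id}}

...{{preceding_words}}... [{{gap_text}}] ...{{following_words}}...
{{references_text}}
Classify the gap and respond with your reasoning and a proposed action.
"""
```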
+     def _fetch_prompt(self, name: str, label: str = "production") -> Any:
+         """Fetch a prompt template from LangFuse.
+
+         Args:
+             name: The prompt name in LangFuse
+             label: Prompt label to fetch (default: "production"). Falls back to
+                 version 1 if labeled version not found.
+
+         Returns:
+             LangFuse prompt object
+
+         Raises:
+             LangFusePromptError: If fetch fails
+         """
+         if not self._client:
+             raise LangFusePromptError("LangFuse client not initialized")
+
+         try:
+             # Try to fetch with the specified label (default: production)
+             prompt = self._client.get_prompt(name, label=label)
+             logger.debug(f"Fetched prompt '{name}' (label={label}) from LangFuse")
+             return prompt
+         except Exception as label_error:
+             # If labeled version not found, try fetching version 1 as fallback
+             # This handles newly created prompts that haven't been promoted yet
+             try:
+                 prompt = self._client.get_prompt(name, version=1)
+                 logger.warning(
+                     f"Prompt '{name}' label '{label}' not found, using version 1. "
+                     f"Consider promoting this prompt in LangFuse UI."
+                 )
+                 return prompt
+             except Exception as version_error:
+                 raise LangFusePromptError(
+                     f"Failed to fetch prompt '{name}' from LangFuse: "
+                     f"Label '{label}' error: {label_error}, "
+                     f"Version 1 fallback error: {version_error}"
+                 ) from version_error
+
+     def _fetch_examples(self) -> List[Dict[str, Any]]:
+         """Fetch few-shot examples from LangFuse dataset.
+
+         Returns:
+             List of example dictionaries
+
+         Raises:
+             LangFuseDatasetError: If fetch fails
+         """
+         if not self._client:
+             raise LangFuseDatasetError("LangFuse client not initialized")
+
+         try:
+             dataset = self._client.get_dataset(self.EXAMPLES_DATASET_NAME)
+             examples = []
+             for item in dataset.items:
+                 # Dataset items have 'input' field with the example data
+                 if hasattr(item, 'input') and item.input:
+                     examples.append(item.input)
+
+             logger.debug(f"Fetched {len(examples)} examples from LangFuse dataset")
+             return examples
+         except Exception as e:
+             raise LangFuseDatasetError(
+                 f"Failed to fetch dataset '{self.EXAMPLES_DATASET_NAME}' from LangFuse: {e}"
+             ) from e
+
+     def _build_song_context(self, artist: Optional[str], title: Optional[str]) -> str:
+         """Build song context section for the prompt."""
+         if artist and title:
+             return (
+                 f"\n## Song Context\n\n"
+                 f"**Artist:** {artist}\n"
+                 f"**Title:** {title}\n\n"
+                 f"Note: The song title and artist name may help identify proper nouns "
+                 f"or unusual words that could be mis-heard.\n"
+             )
+         return ""
+
+     def _format_examples(self, examples: List[Dict[str, Any]]) -> str:
+         """Format few-shot examples for inclusion in prompt.
+
+         Args:
+             examples: List of example dictionaries from LangFuse dataset
+
+         Returns:
+             Formatted examples string
+         """
+         if not examples:
+             return ""
+
+         # Group examples by category
+         examples_by_category: Dict[str, List[Dict]] = {}
+         for ex in examples:
+             category = ex.get("category", "unknown")
+             if category not in examples_by_category:
+                 examples_by_category[category] = []
+             examples_by_category[category].append(ex)
+
+         # Build formatted text
+         text = "## Example Classifications\n\n"
+         for category, category_examples in examples_by_category.items():
+             text += f"### {category.upper().replace('_', ' ')}\n\n"
+             for ex in category_examples[:2]:  # Limit to 2 examples per category
+                 text += f"**Gap:** {ex.get('gap_text', '')}\n"
+                 text += f"**Context:** ...{ex.get('preceding', '')}... [GAP] ...{ex.get('following', '')}...\n"
+                 if 'reference' in ex:
+                     text += f"**Reference:** {ex['reference']}\n"
+                 text += f"**Reasoning:** {ex.get('reasoning', '')}\n"
+                 text += f"**Action:** {ex.get('action', '')}\n\n"
+
+         return text
+
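Each dataset item's input must carry the keys that _format_examples reads above. A hypothetical item illustrating that shape (all values invented):

```python
# Hypothetical dataset item "input" payload, shaped to match _format_examples.
example_item_input = {
    "category": "transcription_error",   # grouping key; defaults to "unknown"
    "gap_text": "wrecking nice",         # the uncertain words in the gap
    "preceding": "it was",               # words before the gap
    "following": "to meet you",          # words after the gap
    "reference": "it was really nice to meet you",  # optional reference line
    "reasoning": "All reference sources agree on the corrected phrase.",
    "action": "replace_with_reference",
}
```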
+     def _format_references(self, reference_contexts: Dict[str, str]) -> str:
+         """Format reference lyrics for inclusion in prompt.
+
+         Args:
+             reference_contexts: Dictionary of reference lyrics from each source
+
+         Returns:
+             Formatted references string
+         """
+         if not reference_contexts:
+             return ""
+
+         text = "## Available Reference Lyrics\n\n"
+         for source, context in reference_contexts.items():
+             text += f"**{source.upper()}:** {context}\n\n"
+
+         return text
+
+
+ # Module-level singleton for convenience
+ _prompt_service: Optional[LangFusePromptService] = None
+
+
+ def get_prompt_service() -> LangFusePromptService:
+     """Get or create the global prompt service instance.
+
+     Returns:
+         LangFusePromptService singleton instance
+     """
+     global _prompt_service
+     if _prompt_service is None:
+         _prompt_service = LangFusePromptService()
+     return _prompt_service
+
+
+ def reset_prompt_service() -> None:
+     """Reset the global prompt service instance (for testing)."""
+     global _prompt_service
+     _prompt_service = None
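A minimal usage sketch for the new module (the import path inside karaoke-gen is assumed; argument values are illustrative):

```python
# Hypothetical import path - the module's location in the package is assumed.
from karaoke_gen.agentic.prompts import get_prompt_service

service = get_prompt_service()  # singleton; uses hardcoded prompts without LangFuse keys
prompt = service.get_classification_prompt(
    gap_text="wrecking nice",
    preceding_words="it was",
    following_words="to meet you",
    reference_contexts={"genius": "it was really nice to meet you"},
    artist="Example Artist",
    title="Example Title",
    gap_id="gap-3",
)
```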
@@ -19,16 +19,27 @@ class ProviderConfig:
      cache_dir: str

      # GCP/Vertex AI settings
+     # Note: Gemini 3 models require 'global' location (not regional like us-central1)
      gcp_project_id: Optional[str] = None
-     gcp_location: str = "us-central1"
+     gcp_location: str = "global"

-     request_timeout_seconds: float = 30.0
+     # Timeout increased to 120s to handle Vertex AI connection establishment
+     # and potential network latency. The 499 "operation cancelled" errors seen
+     # at ~60s suggest internal timeouts; 120s provides headroom.
+     request_timeout_seconds: float = 120.0
      max_retries: int = 2
-     retry_backoff_base_seconds: float = 0.2
+     # Backoff increased from 0.2s to 2.0s base - if a request times out,
+     # retrying immediately is unlikely to help. Give the service time to recover.
+     retry_backoff_base_seconds: float = 2.0
      retry_backoff_factor: float = 2.0
      circuit_breaker_failure_threshold: int = 3
      circuit_breaker_open_seconds: int = 60

+     # Initialization timeouts - fail fast instead of hanging forever
+     # These are separate from request_timeout to catch connection establishment issues
+     initialization_timeout_seconds: float = 30.0  # Model creation + warm-up
+     warmup_timeout_seconds: float = 15.0  # Just the warm-up call
+
      @staticmethod
      def from_env(cache_dir: Optional[str] = None) -> "ProviderConfig":
          """Create config from environment variables.
@@ -51,13 +62,15 @@ class ProviderConfig:
              privacy_mode=os.getenv("PRIVACY_MODE", "false").lower() in {"1", "true", "yes"},
              cache_dir=cache_dir,
              gcp_project_id=os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCP_PROJECT_ID"),
-             gcp_location=os.getenv("GCP_LOCATION", "us-central1"),
-             request_timeout_seconds=float(os.getenv("AGENTIC_TIMEOUT_SECONDS", "30.0")),
+             gcp_location=os.getenv("GCP_LOCATION", "global"),
+             request_timeout_seconds=float(os.getenv("AGENTIC_TIMEOUT_SECONDS", "120.0")),
              max_retries=int(os.getenv("AGENTIC_MAX_RETRIES", "2")),
-             retry_backoff_base_seconds=float(os.getenv("AGENTIC_BACKOFF_BASE_SECONDS", "0.2")),
+             retry_backoff_base_seconds=float(os.getenv("AGENTIC_BACKOFF_BASE_SECONDS", "2.0")),
              retry_backoff_factor=float(os.getenv("AGENTIC_BACKOFF_FACTOR", "2.0")),
              circuit_breaker_failure_threshold=int(os.getenv("AGENTIC_CIRCUIT_THRESHOLD", "3")),
              circuit_breaker_open_seconds=int(os.getenv("AGENTIC_CIRCUIT_OPEN_SECONDS", "60")),
+             initialization_timeout_seconds=float(os.getenv("AGENTIC_INIT_TIMEOUT_SECONDS", "30.0")),
+             warmup_timeout_seconds=float(os.getenv("AGENTIC_WARMUP_TIMEOUT_SECONDS", "15.0")),
          )

      def validate_environment(self, logger: Optional[object] = None) -> None:
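All of the new knobs remain environment-driven; a sketch of overriding them before calling from_env() (variable names come from the hunk above, the import path is assumed):

```python
import os

# Names come from from_env() above; the values shown match the new defaults.
os.environ.setdefault("GCP_LOCATION", "global")
os.environ.setdefault("AGENTIC_TIMEOUT_SECONDS", "120.0")
os.environ.setdefault("AGENTIC_BACKOFF_BASE_SECONDS", "2.0")
os.environ.setdefault("AGENTIC_INIT_TIMEOUT_SECONDS", "30.0")
os.environ.setdefault("AGENTIC_WARMUP_TIMEOUT_SECONDS", "15.0")

from karaoke_gen.agentic.config import ProviderConfig  # hypothetical import path

config = ProviderConfig.from_env(cache_dir="/tmp/agentic-cache")
assert config.request_timeout_seconds == 120.0
```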
@@ -8,7 +8,7 @@ RESPONSE_LOG_LENGTH = 500  # Characters to log from responses
  MODEL_SPEC_FORMAT = "provider/model"  # Expected format for model identifiers

  # Default Langfuse host
- DEFAULT_LANGFUSE_HOST = "https://cloud.langfuse.com"
+ DEFAULT_LANGFUSE_HOST = "https://us.cloud.langfuse.com"

  # Raw response indicator
  RAW_RESPONSE_KEY = "raw"  # Key used to wrap unparsed responses
@@ -13,6 +13,8 @@ from __future__ import annotations

  import logging
  import os
+ import time
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
  from typing import List, Dict, Any, Optional
  from datetime import datetime

@@ -33,6 +35,14 @@ from .constants import (

  logger = logging.getLogger(__name__)

+ # Error constant for initialization timeout
+ INIT_TIMEOUT_ERROR = "initialization_timeout"
+
+
+ class InitializationTimeoutError(Exception):
+     """Raised when model initialization exceeds the configured timeout."""
+     pass
+

  class LangChainBridge(BaseAIProvider):
      """Provider bridge using LangChain ChatModels with reliability patterns.
@@ -87,6 +97,7 @@ class LangChainBridge(BaseAIProvider):

          # Lazy-initialized chat model
          self._chat_model: Optional[Any] = None
+         self._warmed_up: bool = False

      def name(self) -> str:
          """Return provider name for logging."""
@@ -130,13 +141,45 @@ class LangChainBridge(BaseAIProvider):
                  "until": open_until
              }]

-         # Step 2: Get or create chat model
+         # Step 2: Get or create chat model with initialization timeout
          if not self._chat_model:
+             timeout = self._config.initialization_timeout_seconds
+             logger.info(f"🤖 Initializing model {self._model} with {timeout}s timeout...")
+             init_start = time.time()
+
              try:
-                 self._chat_model = self._factory.create_chat_model(
-                     self._model,
-                     self._config
-                 )
+                 # Use ThreadPoolExecutor for cross-platform timeout
+                 with ThreadPoolExecutor(max_workers=1) as executor:
+                     future = executor.submit(
+                         self._factory.create_chat_model,
+                         self._model,
+                         self._config
+                     )
+                     try:
+                         self._chat_model = future.result(timeout=timeout)
+                     except FuturesTimeoutError:
+                         raise InitializationTimeoutError(
+                             f"Model initialization timed out after {timeout}s. "
+                             f"This may indicate network issues or service unavailability."
+                         ) from None
+
+                 init_elapsed = time.time() - init_start
+                 logger.info(f"🤖 Model created in {init_elapsed:.2f}s, starting warm-up...")
+
+                 # Warm up the model to establish connection before real work
+                 self._warm_up_model()
+
+                 total_elapsed = time.time() - init_start
+                 logger.info(f"🤖 Model initialization complete in {total_elapsed:.2f}s")
+
+             except InitializationTimeoutError as e:
+                 self._circuit_breaker.record_failure(self._model)
+                 logger.exception("🤖 Model initialization timeout")
+                 return [{
+                     "error": INIT_TIMEOUT_ERROR,
+                     "message": str(e),
+                     "timeout_seconds": timeout
+                 }]
              except Exception as e:
                  self._circuit_breaker.record_failure(self._model)
                  logger.error(f"🤖 Failed to initialize chat model: {e}")
@@ -146,24 +189,27 @@ class LangChainBridge(BaseAIProvider):
                  }]

          # Step 3: Execute with retry logic
-         logger.debug(
-             f"🤖 [LangChain] Sending prompt to {self._model}: "
-             f"{prompt[:PROMPT_LOG_LENGTH]}..."
+         logger.info(
+             f"🤖 [LangChain] Sending prompt to {self._model} ({len(prompt)} chars)"
          )
-
+         logger.debug(f"🤖 [LangChain] Prompt preview: {prompt[:PROMPT_LOG_LENGTH]}...")
+
+         invoke_start = time.time()
          result = self._executor.execute_with_retry(
              operation=lambda: self._invoke_model(prompt),
              operation_name=f"invoke_{self._model}"
          )
-
+         invoke_elapsed = time.time() - invoke_start
+
          # Step 4: Handle result and update circuit breaker
          if result.success:
              self._circuit_breaker.record_success(self._model)
-
+
              logger.info(
-                 f"🤖 [LangChain] Got response from {self._model}: "
-                 f"{result.value[:RESPONSE_LOG_LENGTH]}..."
+                 f"🤖 [LangChain] Got response from {self._model} in {invoke_elapsed:.2f}s "
+                 f"({len(result.value)} chars)"
              )
+             logger.debug(f"🤖 [LangChain] Response preview: {result.value[:RESPONSE_LOG_LENGTH]}...")

          # Step 5: Cache the raw response for future use
          self._cache.set(
@@ -187,26 +233,85 @@ class LangChainBridge(BaseAIProvider):

      def _invoke_model(self, prompt: str) -> str:
          """Invoke the chat model with a prompt.
-
+
          This is a simple wrapper that can be passed to the retry executor.
-
+
          Args:
              prompt: The prompt to send
-
+
          Returns:
              Response content as string
-
+
          Raises:
              Exception: Any error from the model invocation
          """
          from langchain_core.messages import HumanMessage
-
+
          # Prepare config with session_id in metadata (Langfuse format)
          config = {}
          if hasattr(self, '_session_id') and self._session_id:
              config["metadata"] = {"langfuse_session_id": self._session_id}
              logger.debug(f"🤖 [LangChain] Invoking with session_id: {self._session_id}")
-
+
          response = self._chat_model.invoke([HumanMessage(content=prompt)], config=config)
-         return response.content
+         content = response.content
+
+         # Handle multimodal response format from Gemini 3+ models
+         # Response can be a list of content parts: [{'type': 'text', 'text': '...'}]
+         if isinstance(content, list):
+             # Extract text from the first text content part
+             for part in content:
+                 if isinstance(part, dict) and part.get('type') == 'text':
+                     return part.get('text', '')
+             # Fallback: concatenate all text parts
+             return ''.join(
+                 part.get('text', '') if isinstance(part, dict) else str(part)
+                 for part in content
+             )
+
+         return content
+
+     def _warm_up_model(self) -> None:
+         """Send a lightweight request to warm up the model connection.
+
+         This helps establish the REST connection and potentially warm up any
+         server-side resources before processing real correction requests.
+         The warm-up uses a timeout to fail fast if the service is unresponsive.
+         """
+         if self._warmed_up:
+             return
+
+         timeout = self._config.warmup_timeout_seconds
+         # Use print with flush=True for visibility when output is redirected
+         print(f"🔥 Warming up {self._model} connection (timeout: {timeout}s)...", flush=True)
+         logger.info(f"🔥 Warming up {self._model} connection (timeout: {timeout}s)...")
+
+         warmup_start = time.time()
+         try:
+             from langchain_core.messages import HumanMessage
+
+             # Minimal prompt that requires almost no processing
+             warm_up_prompt = 'Respond with exactly: {"status":"ready"}'
+
+             # Use ThreadPoolExecutor for timeout on warm-up call
+             with ThreadPoolExecutor(max_workers=1) as executor:
+                 future = executor.submit(
+                     self._chat_model.invoke,
+                     [HumanMessage(content=warm_up_prompt)]
+                 )
+                 try:
+                     future.result(timeout=timeout)
+                 except FuturesTimeoutError:
+                     raise TimeoutError(f"Warm-up timed out after {timeout}s") from None
+
+             elapsed = time.time() - warmup_start
+             self._warmed_up = True
+             print(f"🔥 Warm-up complete for {self._model} in {elapsed:.2f}s", flush=True)
+             logger.info(f"🔥 Warm-up complete for {self._model} in {elapsed:.2f}s")
+         except Exception as e:
+             elapsed = time.time() - warmup_start
+             # Don't fail the actual request if warm-up fails
+             # Just log and continue - the real request might still work
+             print(f"🔥 Warm-up failed for {self._model} after {elapsed:.2f}s: {e} (continuing anyway)", flush=True)
+             logger.warning(f"🔥 Warm-up failed for {self._model} after {elapsed:.2f}s: {e} (continuing anyway)")

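To make the new branch concrete: response.content may be a plain string or, for Gemini 3+ models, a list of typed parts. A small sketch mirroring the normalization above (sample values are illustrative; note that the first text part wins when one exists):

```python
# Two shapes _invoke_model now normalizes (values are illustrative).
plain = "it was nice to meet you"                       # str -> returned as-is
parts = [{"type": "text", "text": "it was nice"},       # list -> first text part
         {"type": "text", "text": " to meet you"}]


def normalize(content):
    # Mirrors the list-handling branch in _invoke_model above.
    if isinstance(content, list):
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                return part.get("text", "")
        return "".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )
    return content


assert normalize(plain) == "it was nice to meet you"
assert normalize(parts) == "it was nice"  # only the first text part is returned
```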
@@ -3,12 +3,19 @@ from __future__ import annotations

  import logging
  import os
+ import time
  from typing import Any, Optional, List

  from .config import ProviderConfig

  logger = logging.getLogger(__name__)

+ # Error message constant for TRY003 compliance
+ GOOGLE_API_KEY_MISSING_ERROR = (
+     "GOOGLE_API_KEY environment variable is required for Google/Gemini models. "
+     "Get an API key from https://aistudio.google.com/app/apikey"
+ )
+

  class ModelFactory:
      """Creates and configures LangChain ChatModels with observability.
@@ -100,19 +107,10 @@ class ModelFactory:
              return

          try:
-             from langfuse import Langfuse
              from langfuse.langchain import CallbackHandler
-
-             # Initialize Langfuse client first (this is required!)
-             langfuse_client = Langfuse(
-                 public_key=public_key,
-                 secret_key=secret_key,
-                 host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com"),
-             )
-
-             # Then create callback handler with the same public_key
-             # The handler will use the initialized client
-             self._langfuse_handler = CallbackHandler(public_key=public_key)
+
+             # CallbackHandler auto-discovers credentials from environment variables
+             self._langfuse_handler = CallbackHandler()
              logger.info(f"🤖 Langfuse callback handler initialized for {model_spec}")
          except Exception as e:
              # If Langfuse keys are set, we MUST fail fast
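The simplified handler relies on the Langfuse SDK discovering its configuration from the environment. A sketch of the variables involved (names match those used elsewhere in this diff; key values are placeholders):

```python
import os
from langfuse.langchain import CallbackHandler

# CallbackHandler picks these up automatically; no explicit client needed.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."   # placeholder
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."   # placeholder
os.environ["LANGFUSE_HOST"] = "https://us.cloud.langfuse.com"

handler = CallbackHandler()  # auto-discovers the credentials above
```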
@@ -212,21 +210,56 @@ class ModelFactory:
      def _create_vertexai_model(
          self, model_name: str, callbacks: List[Any], config: ProviderConfig
      ) -> Any:
-         """Create ChatVertexAI model for Google Gemini via Vertex AI.
+         """Create ChatGoogleGenerativeAI model for Google Gemini.

-         Uses Application Default Credentials (ADC) for authentication.
-         In Cloud Run, this uses the service account automatically.
-         Locally, run: gcloud auth application-default login
+         Uses the unified langchain-google-genai package which supports both:
+         - Vertex AI backend (service account / ADC auth) - when project is set
+         - Google AI Studio backend (API key auth) - when only api_key is set
+
+         On Cloud Run, ADC (Application Default Credentials) are used automatically
+         when the project parameter is provided, using the service account attached
+         to the Cloud Run service.
+
+         This is a REST-based API that avoids the gRPC connection issues
+         seen with the deprecated langchain-google-vertexai package.
          """
-         from langchain_google_vertexai import ChatVertexAI
+         from langchain_google_genai import ChatGoogleGenerativeAI

-         model = ChatVertexAI(
-             model=model_name,
-             project=config.gcp_project_id,
-             location=config.gcp_location,
-             max_retries=config.max_retries,
-             callbacks=callbacks,
-         )
-         logger.debug(f"🤖 Created Vertex AI model: {model_name} (project={config.gcp_project_id})")
+         start_time = time.time()
+
+         # Determine authentication method
+         api_key = config.google_api_key
+         project = config.gcp_project_id
+
+         # Prefer Vertex AI (service account) if project is set, otherwise require API key
+         if not project and not api_key:
+             raise ValueError(GOOGLE_API_KEY_MISSING_ERROR)
+
+         if project:
+             logger.info(f"🤖 Creating Google Gemini model via Vertex AI (project={project}): {model_name}")
+         else:
+             logger.info(f"🤖 Creating Google Gemini model via AI Studio API: {model_name}")
+
+         # Build kwargs - only include api_key if set (otherwise ADC is used)
+         model_kwargs = {
+             "model": model_name,
+             "convert_system_message_to_human": True,  # Gemini doesn't support system messages
+             "max_retries": config.max_retries,
+             "timeout": config.request_timeout_seconds,
+             "callbacks": callbacks,
+         }
+
+         # Add project to trigger Vertex AI backend with ADC
+         if project:
+             model_kwargs["project"] = project
+
+         # Add API key if available (can be used with or without project)
+         if api_key:
+             model_kwargs["google_api_key"] = api_key
+
+         model = ChatGoogleGenerativeAI(**model_kwargs)
+
+         elapsed = time.time() - start_time
+         logger.info(f"🤖 Google Gemini model created in {elapsed:.2f}s: {model_name}")
          return model
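Auth selection now reduces to which setting is present. A sketch of the two resulting configurations, built directly with ChatGoogleGenerativeAI (model name and values are illustrative; per the hunk above, passing a project triggers the Vertex AI/ADC path):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

# Path A: Vertex AI backend via ADC (e.g. the Cloud Run service account).
vertex_model = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",       # illustrative model name
    project="my-gcp-project",       # placeholder; enables ADC / Vertex AI
    timeout=120.0,                  # matches the new request timeout default
    max_retries=2,
)

# Path B: AI Studio backend via API key (e.g. local development).
studio_model = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    google_api_key="AIza...",       # placeholder key from aistudio.google.com
    timeout=120.0,
    max_retries=2,
)
```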