pdd-cli 0.0.25-py3-none-any.whl → 0.0.26-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pdd-cli might be problematic.

pdd/llm_invoke.py CHANGED
@@ -1,389 +1,1202 @@
1
- #!/usr/bin/env python
2
- """
3
- llm_invoke.py
4
-
5
- This module provides a single function, llm_invoke, that runs a prompt with a given input
6
- against a language model (LLM) using Langchain and returns the output, cost, and model name.
7
- The function supports model selection based on cost/ELO interpolation controlled by the
8
- "strength" parameter. It also implements a retry mechanism: if a model invocation fails,
9
- it falls back to the next candidate (cheaper for strength < 0.5, or higher ELO for strength ≥ 0.5).
10
-
11
- Usage:
12
- from llm_invoke import llm_invoke
13
- result = llm_invoke(prompt, input_json, strength, temperature, verbose=True, output_pydantic=MyPydanticClass)
14
- # result is a dict with keys: 'result', 'cost', 'model_name'
15
-
16
- Environment:
17
- - PDD_MODEL_DEFAULT: if set, used as the base model name. Otherwise defaults to "gpt-4.1-nano".
18
- - PDD_PATH: if set, models are loaded from $PDD_PATH/data/llm_model.csv; otherwise from ./data/llm_model.csv.
19
- - Models that require an API key will check the corresponding environment variable (name provided in the CSV).
20
- """
1
+ # Corrected code_under_test (llm_invoke.py)
2
+ # Added optional debugging prints in _select_model_candidates
21
3
 
22
4
  import os
23
- import csv
24
- import json
5
+ import pandas as pd
6
+ import litellm
7
+ import logging # ADDED FOR DETAILED LOGGING
8
+
9
+ # --- Configure Detailed Logging for LiteLLM --- MODIFIED SECTION
10
+ PROJECT_ROOT_FOR_LOG = '/Users/gregtanaka/Documents/pdd_cloud/pdd' # Explicit project root
11
+ LOG_FILE_PATH = os.path.join(PROJECT_ROOT_FOR_LOG, 'litellm_debug.log')
12
+
13
+ # Get the litellm logger specifically
14
+ litellm_logger = logging.getLogger("litellm")
15
+ litellm_logger.setLevel(logging.DEBUG) # Set its level to DEBUG
16
+
17
+ # Create a file handler
18
+ file_handler = logging.FileHandler(LOG_FILE_PATH, mode='w')
19
+ file_handler.setLevel(logging.DEBUG)
20
+
21
+ # Create a formatter and add it to the handler
22
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
23
+ file_handler.setFormatter(formatter)
24
+
25
+ # Add the handler to the litellm logger
26
+ # Check if handlers are already present to avoid duplication if module is reloaded
27
+ if not litellm_logger.handlers:
28
+ litellm_logger.addHandler(file_handler)
29
+
30
+ # Also ensure the root logger has a basic handler if nothing else is configured
31
+ # This can help catch messages if litellm logs to root or other unnamed loggers
32
+ if not logging.getLogger().handlers: # Check root logger
33
+ logging.basicConfig(level=logging.DEBUG) # Default to console for other logs
34
+ # --- End Detailed Logging Configuration ---
25
35
 
26
- from pydantic import BaseModel, Field
36
+
37
+ import json
27
38
  from rich import print as rprint
28
- from rich.errors import MarkupError
29
-
30
- # Langchain core and community imports
31
- from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
32
- from langchain_community.cache import SQLiteCache
33
- from langchain.globals import set_llm_cache
34
- from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
35
- from langchain_core.runnables import RunnablePassthrough, ConfigurableField
36
-
37
- # LLM provider imports
38
- from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI
39
- from langchain_fireworks import Fireworks
40
- from langchain_anthropic import ChatAnthropic
41
- from langchain_google_genai import ChatGoogleGenerativeAI
42
- from langchain_google_vertexai import ChatVertexAI
43
- from langchain_groq import ChatGroq
44
- from langchain_together import Together
45
- from langchain_ollama.llms import OllamaLLM
46
-
47
- from langchain.callbacks.base import BaseCallbackHandler
48
- from langchain.schema import LLMResult
49
-
50
- # ---------------- Internal Helper Classes and Functions ---------------- #
51
-
52
- class CompletionStatusHandler(BaseCallbackHandler):
53
- """
54
- Callback handler to capture LLM token usage and completion metadata.
55
- """
56
- def __init__(self):
57
- self.is_complete = False
58
- self.finish_reason = None
59
- self.input_tokens = None
60
- self.output_tokens = None
61
-
62
- def on_llm_end(self, response: LLMResult, **kwargs) -> None:
63
- self.is_complete = True
64
- if response.generations and response.generations[0]:
65
- generation = response.generations[0][0]
66
- # Safely get generation_info; if it's None, default to {}
67
- generation_info = generation.generation_info or {}
68
- self.finish_reason = (generation_info.get('finish_reason') or "").lower()
69
-
70
- # Attempt to get token usage from generation.message if available.
71
- if (
72
- hasattr(generation, "message")
73
- and generation.message is not None
74
- and hasattr(generation.message, "usage_metadata")
75
- and generation.message.usage_metadata
76
- ):
77
- usage_metadata = generation.message.usage_metadata
78
- else:
79
- usage_metadata = generation_info.get("usage_metadata", {})
80
-
81
- self.input_tokens = usage_metadata.get('input_tokens', 0)
82
- self.output_tokens = usage_metadata.get('output_tokens', 0)
39
+ from dotenv import load_dotenv
40
+ from pathlib import Path
41
+ from typing import Optional, Dict, List, Any, Type, Union
42
+ from pydantic import BaseModel, ValidationError
43
+ import openai # Import openai for exception handling as LiteLLM maps to its types
44
+ from langchain_core.prompts import PromptTemplate
45
+ import warnings
46
+ import time as time_module # Alias to avoid conflict with 'time' parameter
47
+ # Import the default model constant
48
+ from pdd import DEFAULT_LLM_MODEL
83
49
 
84
- class ModelInfo:
85
- """
86
- Represents information about an LLM model as loaded from the CSV.
87
- """
88
- def __init__(self, provider, model, input_cost, output_cost, coding_arena_elo,
89
- base_url, api_key, counter, encoder, max_tokens, max_completion_tokens,
90
- structured_output):
91
- self.provider = provider.strip() if provider else ""
92
- self.model = model.strip() if model else ""
93
- self.input_cost = float(input_cost) if input_cost else 0.0
94
- self.output_cost = float(output_cost) if output_cost else 0.0
95
- self.average_cost = (self.input_cost + self.output_cost) / 2
96
- self.coding_arena_elo = float(coding_arena_elo) if coding_arena_elo else 0.0
97
- self.base_url = base_url.strip() if base_url else None
98
- self.api_key = api_key.strip() if api_key else None
99
- self.counter = counter.strip() if counter else None
100
- self.encoder = encoder.strip() if encoder else None
101
- self.max_tokens = int(max_tokens) if max_tokens else None
102
- self.max_completion_tokens = int(max_completion_tokens) if max_completion_tokens else None
103
- self.structured_output = (str(structured_output).lower() == 'true') if structured_output else False
104
-
105
- def load_models():
50
+ # Opt-in to future pandas behavior regarding downcasting
51
+ pd.set_option('future.no_silent_downcasting', True)
52
+
53
+ # <<< SET LITELLM DEBUG LOGGING >>>
54
+ # os.environ['LITELLM_LOG'] = 'DEBUG' # Keep commented out unless debugging LiteLLM itself
55
+
56
+ # --- Constants and Configuration ---
57
+
58
+ # Determine project root: 1. PDD_PATH env var, 2. Search upwards from script, 3. CWD
59
+ PROJECT_ROOT = None
60
+ PDD_PATH_ENV = os.getenv("PDD_PATH")
61
+
62
+ if PDD_PATH_ENV:
63
+ _path_from_env = Path(PDD_PATH_ENV)
64
+ if _path_from_env.is_dir():
65
+ PROJECT_ROOT = _path_from_env.resolve()
66
+ # print(f"[DEBUG] Using PROJECT_ROOT from PDD_PATH: {PROJECT_ROOT}") # Optional debug
67
+ else:
68
+ warnings.warn(f"PDD_PATH environment variable ('{PDD_PATH_ENV}') is set but not a valid directory. Attempting auto-detection.")
69
+
70
+ if PROJECT_ROOT is None: # If PDD_PATH wasn't set or was invalid
71
+ try:
72
+ # Start from the directory containing this script
73
+ current_dir = Path(__file__).resolve().parent
74
+ # Look for project markers (e.g., .git, pyproject.toml, data/, .env)
75
+ # Go up a maximum of 5 levels to prevent infinite loops
76
+ for _ in range(5):
77
+ has_git = (current_dir / ".git").exists()
78
+ has_pyproject = (current_dir / "pyproject.toml").exists()
79
+ has_data = (current_dir / "data").is_dir()
80
+ has_dotenv = (current_dir / ".env").exists()
81
+
82
+ if has_git or has_pyproject or has_data or has_dotenv:
83
+ PROJECT_ROOT = current_dir
84
+ # print(f"[DEBUG] Determined PROJECT_ROOT by marker search: {PROJECT_ROOT}") # Optional debug
85
+ break
86
+
87
+ parent_dir = current_dir.parent
88
+ if parent_dir == current_dir: # Reached filesystem root
89
+ break
90
+ current_dir = parent_dir
91
+
92
+ except NameError: # __file__ might not be defined (e.g., interactive session)
93
+ warnings.warn("__file__ not defined. Cannot automatically detect project root from script location.")
94
+ except Exception as e: # Catch potential permission errors etc.
95
+ warnings.warn(f"Error during project root auto-detection: {e}")
96
+
97
+ if PROJECT_ROOT is None: # Fallback to CWD if no method succeeded
98
+ PROJECT_ROOT = Path.cwd().resolve()
99
+ warnings.warn(f"Could not determine project root automatically. Using current working directory: {PROJECT_ROOT}. Ensure this is the intended root or set the PDD_PATH environment variable.")
100
+
101
+
102
+ ENV_PATH = PROJECT_ROOT / ".env"
103
+ # --- Determine LLM_MODEL_CSV_PATH ---
104
+ # Prioritize ~/.pdd/llm_model.csv
105
+ user_pdd_dir = Path.home() / ".pdd"
106
+ user_model_csv_path = user_pdd_dir / "llm_model.csv"
107
+
108
+ if user_model_csv_path.is_file():
109
+ LLM_MODEL_CSV_PATH = user_model_csv_path
110
+ print(f"[INFO] Using user-specific LLM model CSV: {LLM_MODEL_CSV_PATH}")
111
+ else:
112
+ LLM_MODEL_CSV_PATH = PROJECT_ROOT / "data" / "llm_model.csv"
113
+ print(f"[INFO] Using project LLM model CSV: {LLM_MODEL_CSV_PATH}")
114
+ # ---------------------------------
115
+
116
+ # Load environment variables from .env file
117
+ # print(f"[DEBUG] Attempting to load .env from: {ENV_PATH}") # Optional debug
118
+ if ENV_PATH.exists():
119
+ load_dotenv(dotenv_path=ENV_PATH)
120
+ # print(f"[DEBUG] Loaded .env file from: {ENV_PATH}") # Optional debug
121
+ else:
122
+ # Reduce verbosity if .env is optional or often missing
123
+ # warnings.warn(f".env file not found at {ENV_PATH}. API keys might be missing.")
124
+ pass # Silently proceed if .env is optional
125
+
126
+ # Default model if PDD_MODEL_DEFAULT is not set
127
+ # Use the imported constant as the default
128
+ DEFAULT_BASE_MODEL = os.getenv("PDD_MODEL_DEFAULT", DEFAULT_LLM_MODEL)
129
+
130
+ # --- LiteLLM Cache Configuration (S3 compatible for GCS, with SQLite fallback) ---
131
+ GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
132
+ GCS_ENDPOINT_URL = "https://storage.googleapis.com" # GCS S3 compatibility endpoint
133
+ GCS_REGION_NAME = os.getenv("GCS_REGION_NAME", "auto") # Often 'auto' works for GCS
134
+ GCS_HMAC_ACCESS_KEY_ID = os.getenv("GCS_HMAC_ACCESS_KEY_ID") # Load HMAC Key ID
135
+ GCS_HMAC_SECRET_ACCESS_KEY = os.getenv("GCS_HMAC_SECRET_ACCESS_KEY") # Load HMAC Secret
136
+
137
+ cache_configured = False
138
+
139
+ if GCS_BUCKET_NAME and GCS_HMAC_ACCESS_KEY_ID and GCS_HMAC_SECRET_ACCESS_KEY:
140
+ # Store original AWS credentials before overwriting for GCS cache setup
141
+ original_aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
142
+ original_aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
143
+ original_aws_region_name = os.environ.get('AWS_REGION_NAME')
144
+
145
+ try:
146
+ # Temporarily set AWS env vars to GCS HMAC keys for S3 compatible cache
147
+ os.environ['AWS_ACCESS_KEY_ID'] = GCS_HMAC_ACCESS_KEY_ID
148
+ os.environ['AWS_SECRET_ACCESS_KEY'] = GCS_HMAC_SECRET_ACCESS_KEY
149
+ # os.environ['AWS_REGION_NAME'] = GCS_REGION_NAME # Uncomment if needed
150
+
151
+ litellm.cache = litellm.Cache(
152
+ type="s3",
153
+ s3_bucket_name=GCS_BUCKET_NAME,
154
+ s3_region_name=GCS_REGION_NAME, # Pass region explicitly to cache
155
+ s3_endpoint_url=GCS_ENDPOINT_URL,
156
+ )
157
+ print(f"[INFO] LiteLLM cache configured for GCS bucket (S3 compatible): {GCS_BUCKET_NAME}")
158
+ cache_configured = True
159
+
160
+ except Exception as e:
161
+ warnings.warn(f"Failed to configure LiteLLM S3/GCS cache: {e}. Attempting SQLite cache fallback.")
162
+ litellm.cache = None # Explicitly disable cache on failure (will try SQLite next)
163
+
164
+ finally:
165
+ # Restore original AWS credentials after cache setup attempt
166
+ if original_aws_access_key_id is not None:
167
+ os.environ['AWS_ACCESS_KEY_ID'] = original_aws_access_key_id
168
+ elif 'AWS_ACCESS_KEY_ID' in os.environ:
169
+ del os.environ['AWS_ACCESS_KEY_ID']
170
+
171
+ if original_aws_secret_access_key is not None:
172
+ os.environ['AWS_SECRET_ACCESS_KEY'] = original_aws_secret_access_key
173
+ elif 'AWS_SECRET_ACCESS_KEY' in os.environ:
174
+ del os.environ['AWS_SECRET_ACCESS_KEY']
175
+
176
+ if original_aws_region_name is not None:
177
+ os.environ['AWS_REGION_NAME'] = original_aws_region_name
178
+ elif 'AWS_REGION_NAME' in os.environ:
179
+ pass # Or just leave it if the temporary setting wasn't done/needed
180
+
181
+ if not cache_configured:
182
+ try:
183
+ # Try SQLite-based cache as a fallback
184
+ sqlite_cache_path = PROJECT_ROOT / "litellm_cache.sqlite"
185
+ litellm.cache = litellm.Cache(type="sqlite", cache_path=str(sqlite_cache_path))
186
+ print(f"[INFO] LiteLLM SQLite cache configured at {sqlite_cache_path}")
187
+ cache_configured = True
188
+ except Exception as e2:
189
+ warnings.warn(f"Failed to configure LiteLLM SQLite cache: {e2}. Caching is disabled.")
190
+ litellm.cache = None
191
+
192
+ if not cache_configured:
193
+ warnings.warn("All LiteLLM cache configuration attempts failed. Caching is disabled.")
194
+ litellm.cache = None
195
+
196
+ # --- LiteLLM Callback for Success Logging ---
197
+
198
+ # Module-level storage for last callback data (Use with caution in concurrent environments)
199
+ _LAST_CALLBACK_DATA = {
200
+ "input_tokens": 0,
201
+ "output_tokens": 0,
202
+ "finish_reason": None,
203
+ "cost": 0.0,
204
+ }
205
+
206
+ def _litellm_success_callback(
207
+ kwargs: Dict[str, Any], # kwargs passed to completion
208
+ completion_response: Any, # response object from completion
209
+ start_time: float, end_time: float # start/end time
210
+ ):
106
211
  """
107
- Loads model information from llm_model.csv located in either $PDD_PATH/data or ./data.
212
+ LiteLLM success callback to capture usage and finish reason.
213
+ Stores data in a module-level variable for potential retrieval.
108
214
  """
109
- pdd_path = os.environ.get('PDD_PATH', '.')
110
- models_file = os.path.join(pdd_path, 'data', 'llm_model.csv')
111
- models = []
215
+ global _LAST_CALLBACK_DATA
216
+ usage = getattr(completion_response, 'usage', None)
217
+ input_tokens = getattr(usage, 'prompt_tokens', 0)
218
+ output_tokens = getattr(usage, 'completion_tokens', 0)
219
+ finish_reason = getattr(completion_response.choices[0], 'finish_reason', None)
220
+
221
+ calculated_cost = 0.0
112
222
  try:
113
- with open(models_file, newline='') as csvfile:
114
- reader = csv.DictReader(csvfile)
115
- for row in reader:
116
- model_info = ModelInfo(
117
- provider=row.get('provider',''),
118
- model=row.get('model',''),
119
- input_cost=row.get('input','0'),
120
- output_cost=row.get('output','0'),
121
- coding_arena_elo=row.get('coding_arena_elo','0'),
122
- base_url=row.get('base_url',''),
123
- api_key=row.get('api_key',''),
124
- counter=row.get('counter',''),
125
- encoder=row.get('encoder',''),
126
- max_tokens=row.get('max_tokens',''),
127
- max_completion_tokens=row.get('max_completion_tokens',''),
128
- structured_output=row.get('structured_output','False')
223
+ # Attempt 1: Use the response object directly (works for most single calls)
224
+ cost_val = litellm.completion_cost(completion_response=completion_response)
225
+ calculated_cost = cost_val if cost_val is not None else 0.0
226
+ except Exception as e1:
227
+ # Attempt 2: If response object failed (e.g., missing provider in model name),
228
+ # try again using explicit model from kwargs and tokens from usage.
229
+ # This is often needed for batch completion items.
230
+ print(f"[DEBUG] Attempting cost calculation with fallback method: {e1}")
231
+ try:
232
+ model_name = kwargs.get("model") # Get original model name from input kwargs
233
+ if model_name and usage:
234
+ prompt_tokens = getattr(usage, 'prompt_tokens', 0)
235
+ completion_tokens = getattr(usage, 'completion_tokens', 0)
236
+ cost_val = litellm.completion_cost(
237
+ model=model_name,
238
+ prompt_tokens=prompt_tokens,
239
+ completion_tokens=completion_tokens
129
240
  )
130
- models.append(model_info)
131
- except FileNotFoundError:
132
- raise FileNotFoundError(f"llm_model.csv not found at {models_file}")
133
- return models
241
+ calculated_cost = cost_val if cost_val is not None else 0.0
242
+ else:
243
+ # If we can't get model name or usage, fallback to 0
244
+ calculated_cost = 0.0
245
+ # Optional: Log the original error e1 if needed
246
+ # print(f"[Callback WARN] Failed to calculate cost with response object ({e1}) and fallback failed.")
247
+ except Exception as e2:
248
+ # Optional: Log secondary error e2 if needed
249
+ # print(f"[Callback WARN] Failed to calculate cost with fallback method: {e2}")
250
+ calculated_cost = 0.0 # Default to 0 on any error
251
+ print(f"[DEBUG] Cost calculation failed with fallback method: {e2}")
134
252
 
135
- def select_model(models, base_model_name):
136
- """
137
- Retrieve the base model whose name matches base_model_name. Raises an error if not found.
138
- """
139
- for model in models:
140
- if model.model == base_model_name:
141
- return model
142
- raise ValueError(f"Base model '{base_model_name}' not found in the models list.")
253
+ _LAST_CALLBACK_DATA["input_tokens"] = input_tokens
254
+ _LAST_CALLBACK_DATA["output_tokens"] = output_tokens
255
+ _LAST_CALLBACK_DATA["finish_reason"] = finish_reason
256
+ _LAST_CALLBACK_DATA["cost"] = calculated_cost # Store the calculated cost
257
+
258
+ # Callback doesn't need to return a value now
259
+ # return calculated_cost
260
+
261
+ # Example of logging within the callback (can be expanded)
262
+ # print(f"[Callback] Tokens: In={input_tokens}, Out={output_tokens}. Reason: {finish_reason}. Cost: ${calculated_cost:.6f}")
263
+
264
+ # Register the callback with LiteLLM
265
+ litellm.success_callback = [_litellm_success_callback]
266
+
267
+ # --- Helper Functions ---
268
+
269
+ def _load_model_data(csv_path: Path) -> pd.DataFrame:
270
+ """Loads and preprocesses the LLM model data from CSV."""
271
+ if not csv_path.exists():
272
+ raise FileNotFoundError(f"LLM model CSV not found at {csv_path}")
273
+ try:
274
+ df = pd.read_csv(csv_path)
275
+ # Basic validation and type conversion
276
+ required_cols = ['provider', 'model', 'input', 'output', 'coding_arena_elo', 'api_key', 'structured_output', 'reasoning_type']
277
+ for col in required_cols:
278
+ if col not in df.columns:
279
+ raise ValueError(f"Missing required column in CSV: {col}")
280
+
281
+ # Convert numeric columns, handling potential errors
282
+ numeric_cols = ['input', 'output', 'coding_arena_elo', 'max_tokens',
283
+ 'max_completion_tokens', 'max_reasoning_tokens']
284
+ for col in numeric_cols:
285
+ if col in df.columns:
286
+ # Use errors='coerce' to turn unparseable values into NaN
287
+ df[col] = pd.to_numeric(df[col], errors='coerce')
288
+
289
+ # Fill NaN in critical numeric columns used for selection/interpolation
290
+ df['input'] = df['input'].fillna(0.0)
291
+ df['output'] = df['output'].fillna(0.0)
292
+ df['coding_arena_elo'] = df['coding_arena_elo'].fillna(0) # Use 0 ELO for missing
293
+ # Ensure max_reasoning_tokens is numeric, fillna with 0
294
+ df['max_reasoning_tokens'] = df['max_reasoning_tokens'].fillna(0).astype(int) # Ensure int
295
+
296
+ # Calculate average cost (handle potential division by zero if needed, though unlikely with fillna)
297
+ df['avg_cost'] = (df['input'] + df['output']) / 2
298
+
299
+ # Ensure boolean interpretation for structured_output
300
+ if 'structured_output' in df.columns:
301
+ df['structured_output'] = df['structured_output'].fillna(False).astype(bool)
302
+ else:
303
+ df['structured_output'] = False # Assume false if column missing
304
+
305
+ # Ensure reasoning_type is string, fillna with 'none' and lowercase
306
+ df['reasoning_type'] = df['reasoning_type'].fillna('none').astype(str).str.lower()
307
+
308
+ # Ensure api_key is treated as string, fill NaN with empty string ''
309
+ # This handles cases where read_csv might interpret empty fields as NaN
310
+ df['api_key'] = df['api_key'].fillna('').astype(str)
311
+
312
+ return df
313
+ except Exception as e:
314
+ raise RuntimeError(f"Error loading or processing LLM model CSV {csv_path}: {e}") from e
315
+
316
+ def _select_model_candidates(
317
+ strength: float,
318
+ base_model_name: str,
319
+ model_df: pd.DataFrame
320
+ ) -> List[Dict[str, Any]]:
321
+ """Selects and sorts candidate models based on strength and availability."""
322
+
323
+ # 1. Filter by API Key Name Presence (initial availability check)
324
+ # Keep models with a non-empty api_key field in the CSV.
325
+ # The actual key value check happens later.
326
+ # Allow models with empty api_key (e.g., Bedrock using AWS creds, local models)
327
+ available_df = model_df[model_df['api_key'].notna()].copy()
328
+
329
+ # --- Check if the initial DataFrame itself was empty ---
330
+ if model_df.empty:
331
+ raise ValueError("Loaded model data is empty. Check CSV file.")
332
+
333
+ # --- Check if filtering resulted in empty (might indicate all models had NaN api_key) ---
334
+ if available_df.empty:
335
+ # This case is less likely if notna() is the only filter, but good to check.
336
+ rprint("[WARN] No models found after filtering for non-NaN api_key. Check CSV 'api_key' column.")
337
+ # Decide if this should be a hard error or allow proceeding if logic permits
338
+ # For now, let's raise an error as it likely indicates a CSV issue.
339
+ raise ValueError("No models available after initial filtering (all had NaN 'api_key'?).")
340
+
341
+ # 2. Find Base Model
342
+ base_model_row = available_df[available_df['model'] == base_model_name]
343
+ if base_model_row.empty:
344
+ # Try finding base model in the *original* df in case it was filtered out
345
+ original_base = model_df[model_df['model'] == base_model_name]
346
+ if not original_base.empty:
347
+ raise ValueError(f"Base model '{base_model_name}' found in CSV but requires API key '{original_base.iloc[0]['api_key']}' which might be missing or invalid configuration.")
348
+ else:
349
+ raise ValueError(f"Specified base model '{base_model_name}' not found in the LLM model CSV.")
350
+
351
+ base_model = base_model_row.iloc[0]
352
+
353
+ # 3. Determine Target and Sort
354
+ candidates = []
355
+ target_metric_value = None # For debugging print
143
356
 
144
- def get_candidate_models(strength, models, base_model):
145
- """
146
- Returns ordered list of candidate models based on strength parameter.
147
- Only includes models with available API keys.
148
- """
149
- # Filter for models with valid API keys (including test environment)
150
- available_models = [m for m in models
151
- if not m.api_key or
152
- os.environ.get(m.api_key) or
153
- m.api_key == "EXISTING_KEY"]
154
-
155
- if not available_models:
156
- raise RuntimeError("No models available with valid API keys")
157
-
158
- # For base model case (strength = 0.5), use base model if available
159
357
  if strength == 0.5:
160
- base_candidates = [m for m in available_models if m.model == base_model.model]
161
- if base_candidates:
162
- return base_candidates
163
- return [available_models[0]]
164
-
165
- # For strength < 0.5, prioritize cheaper models
166
- if strength < 0.5:
167
- # Get models cheaper than or equal to base model
168
- cheaper_models = [m for m in available_models
169
- if m.average_cost <= base_model.average_cost]
170
- if not cheaper_models:
171
- return [available_models[0]]
172
-
173
- # For test environment, honor the mock model setup
174
- test_models = [m for m in cheaper_models if m.api_key == "EXISTING_KEY"]
175
- if test_models:
176
- return test_models
177
-
178
- # Production path: interpolate based on cost
179
- cheapest = min(cheaper_models, key=lambda m: m.average_cost)
180
- cost_range = base_model.average_cost - cheapest.average_cost
181
- target_cost = cheapest.average_cost + (strength / 0.5) * cost_range
182
- return sorted(cheaper_models, key=lambda m: abs(m.average_cost - target_cost))
183
-
184
- # For strength > 0.5, prioritize higher ELO models
185
- # Get models with higher or equal ELO than base_model
186
- better_models = [m for m in available_models
187
- if m.coding_arena_elo >= base_model.coding_arena_elo]
188
- if not better_models:
189
- return [available_models[0]]
190
-
191
- # For test environment, honor the mock model setup
192
- test_models = [m for m in better_models if m.api_key == "EXISTING_KEY"]
193
- if test_models:
194
- return test_models
195
-
196
- # Production path: interpolate based on ELO
197
- highest = max(better_models, key=lambda m: m.coding_arena_elo)
198
- elo_range = highest.coding_arena_elo - base_model.coding_arena_elo
199
- target_elo = base_model.coding_arena_elo + ((strength - 0.5) / 0.5) * elo_range
200
- return sorted(better_models, key=lambda m: abs(m.coding_arena_elo - target_elo))
201
-
202
- def create_llm_instance(selected_model, temperature, handler):
203
- """
204
- Creates an instance of the LLM using the selected_model parameters.
205
- Handles provider-specific settings and token limit configurations.
206
- """
207
- provider = selected_model.provider.lower()
208
- model_name = selected_model.model
209
- base_url = selected_model.base_url
210
- api_key_env = selected_model.api_key
211
- max_completion_tokens = selected_model.max_completion_tokens
212
- max_tokens = selected_model.max_tokens
213
-
214
- api_key = os.environ.get(api_key_env) if api_key_env else None
215
-
216
- if provider == 'openai':
217
- if base_url:
218
- llm = ChatOpenAI(model=model_name, temperature=temperature,
219
- openai_api_key=api_key, callbacks=[handler],
220
- openai_api_base=base_url)
358
+ # target_model = base_model
359
+ # Sort remaining by ELO descending as fallback
360
+ available_df['sort_metric'] = -available_df['coding_arena_elo'] # Negative for descending sort
361
+ candidates = available_df.sort_values(by='sort_metric').to_dict('records')
362
+ # Ensure base model is first if it exists
363
+ if any(c['model'] == base_model_name for c in candidates):
364
+ candidates.sort(key=lambda x: 0 if x['model'] == base_model_name else 1)
365
+ target_metric_value = f"Base Model ELO: {base_model['coding_arena_elo']}"
366
+
367
+ elif strength < 0.5:
368
+ # Interpolate by Cost (downwards from base)
369
+ base_cost = base_model['avg_cost']
370
+ cheapest_model = available_df.loc[available_df['avg_cost'].idxmin()]
371
+ cheapest_cost = cheapest_model['avg_cost']
372
+
373
+ if base_cost <= cheapest_cost: # Handle edge case where base is cheapest
374
+ target_cost = cheapest_cost + strength * (base_cost - cheapest_cost) # Will be <= base_cost
221
375
  else:
222
- if model_name.startswith('o'):
223
- llm = ChatOpenAI(model=model_name, temperature=temperature,
224
- openai_api_key=api_key, callbacks=[handler],
225
- reasoning={"effort": "high","summary": "auto"})
226
- else:
227
- llm = ChatOpenAI(model=model_name, temperature=temperature,
228
- openai_api_key=api_key, callbacks=[handler])
229
- elif provider == 'anthropic':
230
- # Special case for Claude 3.7 Sonnet with thinking token budget
231
- if 'claude-3-7-sonnet' in model_name:
232
- llm = ChatAnthropic(
233
- model=model_name,
234
- temperature=temperature,
235
- callbacks=[handler],
236
- thinking={"type": "enabled", "budget_tokens": 4000} # 32K thinking token budget
237
- )
376
+ # Interpolate between cheapest and base
377
+ target_cost = cheapest_cost + (strength / 0.5) * (base_cost - cheapest_cost)
378
+
379
+ available_df['sort_metric'] = abs(available_df['avg_cost'] - target_cost)
380
+ candidates = available_df.sort_values(by='sort_metric').to_dict('records')
381
+ target_metric_value = f"Target Cost: {target_cost:.6f}"
382
+
383
+ else: # strength > 0.5
384
+ # Interpolate by ELO (upwards from base)
385
+ base_elo = base_model['coding_arena_elo']
386
+ highest_elo_model = available_df.loc[available_df['coding_arena_elo'].idxmax()]
387
+ highest_elo = highest_elo_model['coding_arena_elo']
388
+
389
+ if highest_elo <= base_elo: # Handle edge case where base has highest ELO
390
+ target_elo = base_elo + (strength - 0.5) * (highest_elo - base_elo) # Will be >= base_elo
238
391
  else:
239
- llm = ChatAnthropic(model=model_name, temperature=temperature, callbacks=[handler])
240
- elif provider == 'google':
241
- llm = ChatGoogleGenerativeAI(model=model_name, temperature=temperature, callbacks=[handler])
242
- elif provider == 'googlevertexai':
243
- llm = ChatVertexAI(model=model_name, temperature=temperature, callbacks=[handler])
244
- elif provider == 'ollama':
245
- llm = OllamaLLM(model=model_name, temperature=temperature, callbacks=[handler])
246
- elif provider == 'azure':
247
- llm = AzureChatOpenAI(model=model_name, temperature=temperature,
248
- callbacks=[handler], openai_api_key=api_key, openai_api_base=base_url)
249
- elif provider == 'fireworks':
250
- llm = Fireworks(model=model_name, temperature=temperature, callbacks=[handler])
251
- elif provider == 'together':
252
- llm = Together(model=model_name, temperature=temperature, callbacks=[handler])
253
- elif provider == 'groq':
254
- llm = ChatGroq(model_name=model_name, temperature=temperature, callbacks=[handler])
392
+ # Interpolate between base and highest
393
+ target_elo = base_elo + ((strength - 0.5) / 0.5) * (highest_elo - base_elo)
394
+
395
+ available_df['sort_metric'] = abs(available_df['coding_arena_elo'] - target_elo)
396
+ candidates = available_df.sort_values(by='sort_metric').to_dict('records')
397
+ target_metric_value = f"Target ELO: {target_elo:.2f}"
398
+
399
+
400
+ if not candidates:
401
+ # This should ideally not happen if available_df was not empty
402
+ raise RuntimeError("Model selection resulted in an empty candidate list.")
403
+
404
+ # --- DEBUGGING PRINT ---
405
+ if os.getenv("PDD_DEBUG_SELECTOR"): # Add env var check for debug prints
406
+ print("\n--- DEBUG: _select_model_candidates ---")
407
+ print(f"Strength: {strength}, Base Model: {base_model_name}")
408
+ print(f"Metric: {target_metric_value}")
409
+ print("Available DF (Sorted by metric):")
410
+ # Select columns relevant to the sorting metric
411
+ sort_cols = ['model', 'avg_cost', 'coding_arena_elo', 'sort_metric']
412
+ print(available_df.sort_values(by='sort_metric')[sort_cols])
413
+ print("Final Candidates List (Model Names):")
414
+ print([c['model'] for c in candidates])
415
+ print("---------------------------------------\n")
416
+ # --- END DEBUGGING PRINT ---
417
+
418
+ return candidates
419
+
420
+
421
+ def _ensure_api_key(model_info: Dict[str, Any], newly_acquired_keys: Dict[str, bool], verbose: bool) -> bool:
422
+ """Checks for API key in env, prompts user if missing, and updates .env."""
423
+ key_name = model_info.get('api_key')
424
+
425
+ if not key_name or key_name == "EXISTING_KEY":
426
+ if verbose:
427
+ rprint(f"[INFO] Skipping API key check for model {model_info.get('model')} (key name: {key_name})")
428
+ return True # Assume key is handled elsewhere or not needed
429
+
430
+ key_value = os.getenv(key_name)
431
+
432
+ if key_value:
433
+ if verbose:
434
+ rprint(f"[INFO] API key '{key_name}' found in environment.")
435
+ newly_acquired_keys[key_name] = False # Mark as existing
436
+ return True
255
437
  else:
256
- raise ValueError(f"Unsupported provider: {selected_model.provider}")
438
+ rprint(f"[WARN] API key environment variable '{key_name}' for model '{model_info.get('model')}' is not set.")
439
+ try:
440
+ # Interactive prompt
441
+ user_provided_key = input(f"Please enter the API key for {key_name}: ").strip()
442
+ if not user_provided_key:
443
+ rprint("[ERROR] No API key provided. Cannot proceed with this model.")
444
+ return False
445
+
446
+ # Set environment variable for the current process
447
+ os.environ[key_name] = user_provided_key
448
+ rprint(f"[INFO] API key '{key_name}' set for the current session.")
449
+ newly_acquired_keys[key_name] = True # Mark as newly acquired
450
+
451
+ # Update .env file
452
+ try:
453
+ lines = []
454
+ if ENV_PATH.exists():
455
+ with open(ENV_PATH, 'r') as f:
456
+ lines = f.readlines()
457
+
458
+ new_lines = []
459
+ # key_updated = False
460
+ prefix = f"{key_name}="
461
+ prefix_spaced = f"{key_name} =" # Handle potential spaces
462
+
463
+ for line in lines:
464
+ stripped_line = line.strip()
465
+ if stripped_line.startswith(prefix) or stripped_line.startswith(prefix_spaced):
466
+ # Comment out the old key
467
+ new_lines.append(f"# {line}")
468
+ # key_updated = True # Indicates we found an old line to comment
469
+ elif stripped_line.startswith(f"# {prefix}") or stripped_line.startswith(f"# {prefix_spaced}"):
470
+ # Keep already commented lines as they are
471
+ new_lines.append(line)
472
+ else:
473
+ new_lines.append(line)
474
+
475
+ # Append the new key, ensuring quotes for robustness
476
+ new_key_line = f'{key_name}="{user_provided_key}"\n'
477
+ # Add newline before if file not empty and doesn't end with newline
478
+ if new_lines and not new_lines[-1].endswith('\n'):
479
+ new_lines.append('\n')
480
+ new_lines.append(new_key_line)
257
481
 
258
- if max_completion_tokens:
259
- llm.model_kwargs = {"max_completion_tokens": max_completion_tokens}
260
- elif max_tokens:
261
- if provider == 'google' or provider == 'googlevertexai':
262
- llm.max_output_tokens = max_tokens
482
+
483
+ with open(ENV_PATH, 'w') as f:
484
+ f.writelines(new_lines)
485
+
486
+ rprint(f"[INFO] API key '{key_name}' saved to {ENV_PATH}.")
487
+ rprint("[bold yellow]SECURITY WARNING:[/bold yellow] The API key has been saved to your .env file. "
488
+ "Ensure this file is kept secure and is included in your .gitignore.")
489
+
490
+ except IOError as e:
491
+ rprint(f"[ERROR] Failed to update .env file at {ENV_PATH}: {e}")
492
+ # Continue since the key is set in the environment for this session
493
+
494
+ return True
495
+
496
+ except EOFError: # Handle non-interactive environments
497
+ rprint(f"[ERROR] Cannot prompt for API key '{key_name}' in a non-interactive environment.")
498
+ return False
499
+ except Exception as e:
500
+ rprint(f"[ERROR] An unexpected error occurred during API key acquisition: {e}")
501
+ return False
502
+
503
+
504
+ def _format_messages(prompt: str, input_data: Union[Dict[str, Any], List[Dict[str, Any]]], use_batch_mode: bool) -> Union[List[Dict[str, str]], List[List[Dict[str, str]]]]:
505
+ """Formats prompt and input into LiteLLM message format."""
506
+ try:
507
+ prompt_template = PromptTemplate.from_template(prompt)
508
+ if use_batch_mode:
509
+ if not isinstance(input_data, list):
510
+ raise ValueError("input_json must be a list of dictionaries when use_batch_mode is True.")
511
+ all_messages = []
512
+ for item in input_data:
513
+ if not isinstance(item, dict):
514
+ raise ValueError("Each item in input_json list must be a dictionary for batch mode.")
515
+ formatted_prompt = prompt_template.format(**item)
516
+ all_messages.append([{"role": "user", "content": formatted_prompt}])
517
+ return all_messages
263
518
  else:
264
- llm.max_tokens = max_tokens
265
- return llm
519
+ if not isinstance(input_data, dict):
520
+ raise ValueError("input_json must be a dictionary when use_batch_mode is False.")
521
+ formatted_prompt = prompt_template.format(**input_data)
522
+ return [{"role": "user", "content": formatted_prompt}]
523
+ except KeyError as e:
524
+ raise ValueError(f"Prompt formatting error: Missing key {e} in input_json for prompt template.") from e
525
+ except Exception as e:
526
+ raise ValueError(f"Error formatting prompt: {e}") from e
266
527
 
267
- def calculate_cost(handler, selected_model):
268
- """
269
- Calculates the cost of the invoke run based on token usage.
528
+ # --- Main Function ---
529
+
530
+ def llm_invoke(
531
+ prompt: Optional[str] = None,
532
+ input_json: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
533
+ strength: float = 0.5, # Use pdd.DEFAULT_STRENGTH if available, else 0.5
534
+ temperature: float = 0.1,
535
+ verbose: bool = False,
536
+ output_pydantic: Optional[Type[BaseModel]] = None,
537
+ time: float = 0.25,
538
+ use_batch_mode: bool = False,
539
+ messages: Optional[Union[List[Dict[str, str]], List[List[Dict[str, str]]]]] = None,
540
+ ) -> Dict[str, Any]:
270
541
  """
271
- input_tokens = handler.input_tokens or 0
272
- output_tokens = handler.output_tokens or 0
273
- input_cost = selected_model.input_cost
274
- output_cost = selected_model.output_cost
275
- total_cost = (input_tokens / 1_000_000) * input_cost + (output_tokens / 1_000_000) * output_cost
276
- return total_cost
542
+ Runs a prompt with given input using LiteLLM, handling model selection,
543
+ API key acquisition, structured output, batching, and reasoning time.
544
+ The maximum completion token length defaults to the provider's maximum.
277
545
 
278
- # ---------------- Main Function ---------------- #
546
+ Args:
547
+ prompt: Prompt template string (required if messages is None).
548
+ input_json: Dictionary or list of dictionaries for prompt variables (required if messages is None).
549
+ strength: Model selection strength (0=cheapest, 0.5=base, 1=highest ELO).
550
+ temperature: LLM temperature.
551
+ verbose: Print detailed logs.
552
+ output_pydantic: Optional Pydantic model for structured output.
553
+ time: Relative thinking time (0-1, default 0.25).
554
+ use_batch_mode: Use batch completion if True.
555
+ messages: Pre-formatted list of messages (or list of lists for batch). If provided, ignores prompt and input_json.
279
556
 
280
- def llm_invoke(prompt, input_json, strength, temperature, verbose=False, output_pydantic=None):
281
- """
282
- Invokes an LLM chain with the provided prompt and input_json, using a model selected based on the strength parameter.
283
-
284
- Inputs:
285
- prompt (str): The prompt template as a string.
286
- input_json (dict): JSON object containing inputs for the prompt.
287
- strength (float): 0 (cheapest) to 1 (highest ELO); 0.5 uses the base model.
288
- temperature (float): Temperature for the LLM invocation.
289
- verbose (bool): When True, prints detailed information.
290
- output_pydantic (Optional): A Pydantic model class for structured output.
291
-
292
- Output (dict): Contains:
293
- 'result' - LLM output (string or parsed Pydantic object).
294
- 'cost' - Calculated cost of the invoke run.
295
- 'model_name' - Name of the selected model that succeeded.
557
+ Returns:
558
+ Dictionary containing 'result', 'cost', 'model_name', 'thinking_output'.
559
+
560
+ Raises:
561
+ ValueError: For invalid inputs or prompt formatting errors.
562
+ FileNotFoundError: If llm_model.csv is missing.
563
+ RuntimeError: If all candidate models fail.
564
+ openai.*Error: If LiteLLM encounters API errors after retries.
296
565
  """
297
- if prompt is None or not isinstance(prompt, str):
298
- raise ValueError("Prompt is required.")
299
- if input_json is None:
300
- raise ValueError("Input JSON is required.")
301
- if not isinstance(input_json, dict):
302
- raise ValueError("Input JSON must be a dictionary.")
303
-
304
- set_llm_cache(SQLiteCache(database_path=".langchain.db"))
305
- base_model_name = os.environ.get('PDD_MODEL_DEFAULT', 'gpt-4.1-nano')
306
- models = load_models()
307
-
308
- try:
309
- base_model = select_model(models, base_model_name)
310
- except ValueError as e:
311
- raise RuntimeError(f"Base model error: {str(e)}") from e
566
+ if verbose: # Print args early if verbose
567
+ rprint("[DEBUG llm_invoke start] Arguments received:")
568
+ rprint(f" prompt: {'provided' if prompt else 'None'}")
569
+ rprint(f" input_json: {'provided' if input_json is not None else 'None'}")
570
+ rprint(f" strength: {strength}")
571
+ rprint(f" temperature: {temperature}")
572
+ rprint(f" verbose: {verbose}")
573
+ rprint(f" output_pydantic: {output_pydantic.__name__ if output_pydantic else 'None'}")
574
+ rprint(f" time: {time}")
575
+ rprint(f" use_batch_mode: {use_batch_mode}")
576
+ rprint(f" messages: {'provided' if messages else 'None'}")
312
577
 
313
- candidate_models = get_candidate_models(strength, models, base_model)
578
+ # --- 1. Load Environment & Validate Inputs ---
579
+ # .env loading happens at module level
314
580
 
315
- if verbose:
316
- rprint(f"[bold cyan]Candidate models (in order):[/bold cyan] {[m.model for m in candidate_models]}")
581
+ if messages:
582
+ if verbose:
583
+ rprint("[INFO] Using provided 'messages' input.")
584
+ # Basic validation of messages format
585
+ if use_batch_mode:
586
+ if not isinstance(messages, list) or not all(isinstance(m_list, list) for m_list in messages):
587
+ raise ValueError("'messages' must be a list of lists when use_batch_mode is True.")
588
+ if not all(isinstance(msg, dict) and 'role' in msg and 'content' in msg for m_list in messages for msg in m_list):
589
+ raise ValueError("Each message in the lists within 'messages' must be a dictionary with 'role' and 'content'.")
590
+ else:
591
+ if not isinstance(messages, list) or not all(isinstance(msg, dict) and 'role' in msg and 'content' in msg for msg in messages):
592
+ raise ValueError("'messages' must be a list of dictionaries with 'role' and 'content'.")
593
+ formatted_messages = messages
594
+ elif prompt and input_json is not None:
595
+ if not isinstance(prompt, str) or not prompt:
596
+ raise ValueError("'prompt' must be a non-empty string when 'messages' is not provided.")
597
+ formatted_messages = _format_messages(prompt, input_json, use_batch_mode)
598
+ else:
599
+ raise ValueError("Either 'messages' or both 'prompt' and 'input_json' must be provided.")
600
+
601
+ if not (0.0 <= strength <= 1.0):
602
+ raise ValueError("'strength' must be between 0.0 and 1.0.")
603
+ if not (0.0 <= temperature <= 2.0): # Common range for temperature
604
+ warnings.warn("'temperature' is outside the typical range (0.0-2.0).")
605
+ if not (0.0 <= time <= 1.0):
606
+ raise ValueError("'time' must be between 0.0 and 1.0.")
607
+
608
+ # --- 2. Load Model Data & Select Candidates ---
609
+ try:
610
+ model_df = _load_model_data(LLM_MODEL_CSV_PATH)
611
+ candidate_models = _select_model_candidates(strength, DEFAULT_BASE_MODEL, model_df)
612
+ except (FileNotFoundError, ValueError, RuntimeError) as e:
613
+ rprint(f"[ERROR] Failed during model loading or selection: {e}")
614
+ raise
317
615
 
318
- last_error = None
319
- for model in candidate_models:
320
- handler = CompletionStatusHandler()
616
+ if verbose:
617
+ # This print statement is crucial for the verbose test
618
+ # Calculate and print strength for each candidate model
619
+ # Find min/max for cost and ELO
620
+ min_cost = model_df['avg_cost'].min()
621
+ max_elo = model_df['coding_arena_elo'].max()
622
+ base_cost = model_df[model_df['model'] == DEFAULT_BASE_MODEL]['avg_cost'].iloc[0] if not model_df[model_df['model'] == DEFAULT_BASE_MODEL].empty else min_cost
623
+ base_elo = model_df[model_df['model'] == DEFAULT_BASE_MODEL]['coding_arena_elo'].iloc[0] if not model_df[model_df['model'] == DEFAULT_BASE_MODEL].empty else max_elo
624
+
625
+ def calc_strength(candidate):
626
+ # If strength < 0.5, interpolate by cost (cheaper = 0, base = 0.5)
627
+ # If strength > 0.5, interpolate by ELO (base = 0.5, highest = 1.0)
628
+ avg_cost = candidate.get('avg_cost', min_cost)
629
+ elo = candidate.get('coding_arena_elo', base_elo)
630
+ if strength < 0.5:
631
+ # Map cost to [0, 0.5]
632
+ if base_cost == min_cost:
633
+ return 0.5 # Avoid div by zero
634
+ rel = (avg_cost - min_cost) / (base_cost - min_cost)
635
+ return max(0.0, min(0.5, rel * 0.5))
636
+ elif strength > 0.5:
637
+ # Map ELO to [0.5, 1.0]
638
+ if max_elo == base_elo:
639
+ return 0.5 # Avoid div by zero
640
+ rel = (elo - base_elo) / (max_elo - base_elo)
641
+ return max(0.5, min(1.0, 0.5 + rel * 0.5))
642
+ else:
643
+ return 0.5
644
+
645
+ model_strengths_formatted = [(c['model'], f"{float(calc_strength(c)):.3f}") for c in candidate_models]
646
+ rprint("[INFO] Candidate models selected and ordered (with strength):", model_strengths_formatted)
647
+ rprint(f"[INFO] Strength: {strength}, Temperature: {temperature}, Time: {time}")
648
+ if use_batch_mode:
649
+ rprint("[INFO] Batch mode enabled.")
650
+ if output_pydantic:
651
+ rprint(f"[INFO] Pydantic output requested: {output_pydantic.__name__}")
321
652
  try:
322
- try:
323
- prompt_template = PromptTemplate.from_template(prompt)
324
- except ValueError:
325
- raise ValueError("Invalid prompt template")
653
+ # Only print input_json if it was actually provided (not when messages were used)
654
+ if input_json is not None:
655
+ rprint("[INFO] Input JSON:")
656
+ rprint(input_json)
657
+ else:
658
+ rprint("[INFO] Input: Using pre-formatted 'messages'.")
659
+ except Exception:
660
+ print("[INFO] Input JSON/Messages (fallback print):") # Fallback for complex objects rich might fail on
661
+ print(input_json if input_json is not None else "[Messages provided directly]")
326
662
 
327
- llm = create_llm_instance(model, temperature, handler)
663
+
664
+ # --- 3. Iterate Through Candidates and Invoke LLM ---
665
+ last_exception = None
666
+ newly_acquired_keys: Dict[str, bool] = {} # Track keys obtained in this run
667
+
668
+ for model_info in candidate_models:
669
+ model_name_litellm = model_info['model']
670
+ api_key_name = model_info.get('api_key')
671
+ provider = model_info.get('provider', '').lower()
672
+
673
+ if verbose:
674
+ rprint(f"\n[ATTEMPT] Trying model: {model_name_litellm} (Provider: {provider})")
675
+
676
+ retry_with_same_model = True
677
+ while retry_with_same_model:
678
+ retry_with_same_model = False # Assume success unless auth error on new key
679
+
680
+ # --- 4. API Key Check & Acquisition ---
681
+ if not _ensure_api_key(model_info, newly_acquired_keys, verbose):
682
+ # Problem getting key, break inner loop, try next model candidate
683
+ if verbose:
684
+ rprint(f"[SKIP] Skipping {model_name_litellm} due to API key/credentials issue after prompt.")
685
+ break # Breaks the 'while retry_with_same_model' loop
686
+
687
+ # --- 5. Prepare LiteLLM Arguments ---
688
+ litellm_kwargs: Dict[str, Any] = {
689
+ "model": model_name_litellm,
690
+ "messages": formatted_messages,
691
+ "temperature": temperature,
692
+ }
693
+
694
+ api_key_name_from_csv = model_info.get('api_key') # From CSV
695
+ # Determine if it's a Vertex AI model for special handling
696
+ is_vertex_model = (provider.lower() == 'google') or \
697
+ (provider.lower() == 'googlevertexai') or \
698
+ (provider.lower() == 'vertex_ai') or \
699
+ model_name_litellm.startswith('vertex_ai/')
700
+
701
+ if is_vertex_model and api_key_name_from_csv == 'VERTEX_CREDENTIALS':
702
+ credentials_file_path = os.getenv("VERTEX_CREDENTIALS") # Path from env var
703
+ vertex_project_env = os.getenv("VERTEX_PROJECT")
704
+ vertex_location_env = os.getenv("VERTEX_LOCATION")
705
+
706
+ if credentials_file_path and vertex_project_env and vertex_location_env:
707
+ try:
708
+ with open(credentials_file_path, 'r') as f:
709
+ loaded_credentials = json.load(f)
710
+ vertex_credentials_json_string = json.dumps(loaded_credentials)
711
+
712
+ litellm_kwargs["vertex_credentials"] = vertex_credentials_json_string
713
+ litellm_kwargs["vertex_project"] = vertex_project_env
714
+ litellm_kwargs["vertex_location"] = vertex_location_env
715
+ if verbose:
716
+ rprint(f"[INFO] For Vertex AI: using vertex_credentials from '{credentials_file_path}', project '{vertex_project_env}', location '{vertex_location_env}'.")
717
+ except FileNotFoundError:
718
+ if verbose:
719
+ rprint(f"[ERROR] Vertex credentials file not found at path specified by VERTEX_CREDENTIALS env var: '{credentials_file_path}'. LiteLLM may try ADC or fail.")
720
+ except json.JSONDecodeError:
721
+ if verbose:
722
+ rprint(f"[ERROR] Failed to decode JSON from Vertex credentials file: '{credentials_file_path}'. Check file content. LiteLLM may try ADC or fail.")
723
+ except Exception as e:
724
+ if verbose:
725
+ rprint(f"[ERROR] Failed to load or process Vertex credentials from '{credentials_file_path}': {e}. LiteLLM may try ADC or fail.")
726
+ else:
727
+ if verbose:
728
+ rprint(f"[WARN] For Vertex AI (using '{api_key_name_from_csv}'): One or more required environment variables (VERTEX_CREDENTIALS, VERTEX_PROJECT, VERTEX_LOCATION) are missing.")
729
+ if not credentials_file_path: rprint(f" Reason: VERTEX_CREDENTIALS (path to JSON file) env var not set or empty.")
730
+ if not vertex_project_env: rprint(f" Reason: VERTEX_PROJECT env var not set or empty.")
731
+ if not vertex_location_env: rprint(f" Reason: VERTEX_LOCATION env var not set or empty.")
732
+ rprint(f" LiteLLM may attempt to use Application Default Credentials or the call may fail.")
733
+
734
+ elif api_key_name_from_csv: # For other api_key_names specified in CSV (e.g., OPENAI_API_KEY, or a direct VERTEX_AI_API_KEY string)
735
+ key_value = os.getenv(api_key_name_from_csv)
736
+ if key_value:
737
+ litellm_kwargs["api_key"] = key_value
738
+ if verbose:
739
+ rprint(f"[INFO] Explicitly passing API key from env var '{api_key_name_from_csv}' as 'api_key' parameter to LiteLLM.")
740
+
741
+ # If this model is Vertex AI AND uses a direct API key string (not VERTEX_CREDENTIALS from CSV),
742
+ # also pass project and location from env vars.
743
+ if is_vertex_model:
744
+ vertex_project_env = os.getenv("VERTEX_PROJECT")
745
+ vertex_location_env = os.getenv("VERTEX_LOCATION")
746
+ if vertex_project_env and vertex_location_env:
747
+ litellm_kwargs["vertex_project"] = vertex_project_env
748
+ litellm_kwargs["vertex_location"] = vertex_location_env
749
+ if verbose:
750
+ rprint(f"[INFO] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), also passing vertex_project='{vertex_project_env}' and vertex_location='{vertex_location_env}' from env vars.")
751
+ elif verbose:
752
+ rprint(f"[WARN] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), VERTEX_PROJECT or VERTEX_LOCATION env vars not set. This might be required by LiteLLM.")
753
+ elif verbose: # api_key_name_from_csv was in CSV, but corresponding env var was not set/empty
754
+ rprint(f"[WARN] API key name '{api_key_name_from_csv}' found in CSV, but the environment variable '{api_key_name_from_csv}' is not set or empty. LiteLLM will use default authentication if applicable (e.g., other standard env vars or ADC).")
755
+
756
+ elif verbose: # No api_key_name_from_csv in CSV for this model
757
+ rprint(f"[INFO] No API key name specified in CSV for model '{model_name_litellm}'. LiteLLM will use its default authentication mechanisms (e.g., standard provider env vars or ADC for Vertex AI).")
758
+
759
+ # Add api_base if present in CSV
760
+ api_base = model_info.get('base_url')
761
+ if pd.notna(api_base) and api_base:
762
+ litellm_kwargs["api_base"] = str(api_base)
763
+
764
+ # Handle Structured Output (JSON Mode / Pydantic)
328
765
  if output_pydantic:
329
- if model.structured_output:
330
- llm = llm.with_structured_output(output_pydantic)
331
- chain = prompt_template | llm
766
+ # Check if model supports structured output based on CSV flag or LiteLLM check
767
+ supports_structured = model_info.get('structured_output', False)
768
+ # Optional: Add litellm.supports_response_schema check if CSV flag is unreliable
769
+ # if not supports_structured:
770
+ # try: supports_structured = litellm.supports_response_schema(model=model_name_litellm)
771
+ # except: pass # Ignore errors in supports_response_schema check
772
+
773
+ if supports_structured:
774
+ if verbose:
775
+ rprint(f"[INFO] Requesting structured output (Pydantic: {output_pydantic.__name__}) for {model_name_litellm}")
776
+ # Pass the Pydantic model directly if supported, else use json_object
777
+ # LiteLLM handles passing Pydantic models for supported providers
778
+ litellm_kwargs["response_format"] = output_pydantic
779
+ # As a fallback, one could use:
780
+ # litellm_kwargs["response_format"] = {"type": "json_object"}
781
+ # And potentially enable client-side validation:
782
+ # litellm.enable_json_schema_validation = True # Enable globally if needed
332
783
  else:
333
- parser = PydanticOutputParser(pydantic_object=output_pydantic)
334
- chain = prompt_template | llm | parser
335
- else:
336
- chain = prompt_template | llm | StrOutputParser()
337
-
338
- result_output = chain.invoke(input_json)
339
- cost = calculate_cost(handler, model)
340
-
341
- if verbose:
342
- rprint(f"[bold green]Selected model: {model.model}[/bold green]")
343
- rprint(f"Per input token cost: ${model.input_cost} per million tokens")
344
- rprint(f"Per output token cost: ${model.output_cost} per million tokens")
345
- rprint(f"Number of input tokens: {handler.input_tokens}")
346
- rprint(f"Number of output tokens: {handler.output_tokens}")
347
- rprint(f"Cost of invoke run: ${cost:.0e}")
348
- rprint(f"Strength used: {strength}")
349
- rprint(f"Temperature used: {temperature}")
350
- try:
351
- # Try printing with rich formatting first
352
- rprint(f"Input JSON: {str(input_json)}")
353
- except MarkupError:
354
- # Fallback to standard print if rich markup fails
355
- print(f"Input JSON: {str(input_json)}")
356
- except Exception:
357
- print(f"Input JSON: {input_json}")
358
- if output_pydantic:
359
- rprint(f"Output Pydantic format: {output_pydantic}")
360
- try:
361
- # Try printing with rich formatting first
362
- rprint(f"Result: {result_output}")
363
- except MarkupError as me:
364
- # Fallback to standard print if rich markup fails
365
- print(f"[bold yellow]Warning:[/bold yellow] Failed to render result with rich markup: {me}")
366
- print(f"Raw Result: {str(result_output)}") # Use standard print
367
-
368
- return {'result': result_output, 'cost': cost, 'model_name': model.model}
784
+ if verbose:
785
+ rprint(f"[WARN] Model {model_name_litellm} does not support structured output via CSV flag. Output might not be valid {output_pydantic.__name__}.")
786
+ # Proceed without forcing JSON mode, parsing will be attempted later
369
787
 
370
- except Exception as e:
371
- last_error = e
372
- if verbose:
373
- rprint(f"[red]Error with model {model.model}: {str(e)}[/red]")
374
- continue
788
+ # --- NEW REASONING LOGIC ---
789
+ reasoning_type = model_info.get('reasoning_type', 'none') # Defaults to 'none'
790
+ max_reasoning_tokens_val = model_info.get('max_reasoning_tokens', 0) # Defaults to 0
791
+
792
+ if time > 0: # Only apply reasoning if time is requested
793
+ if reasoning_type == 'budget':
794
+ if max_reasoning_tokens_val > 0:
795
+ budget = int(time * max_reasoning_tokens_val)
796
+ if budget > 0:
797
+ # Currently known: Anthropic uses 'thinking'
798
+ # Model name comparison is more robust than provider string
799
+ if provider == 'anthropic': # Check provider column instead of model prefix
+ litellm_kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
+ if verbose:
+ rprint(f"[INFO] Requesting Anthropic thinking (budget type) with budget: {budget} tokens for {model_name_litellm}")
+ else:
+ # If other providers adopt a budget param recognized by LiteLLM, add here
+ if verbose:
+ rprint(f"[WARN] Reasoning type is 'budget' for {model_name_litellm}, but no specific LiteLLM budget parameter known for this provider. Parameter not sent.")
+ elif verbose:
+ rprint(f"[INFO] Calculated reasoning budget is 0 for {model_name_litellm}, skipping reasoning parameter.")
+ elif verbose:
+ rprint(f"[WARN] Reasoning type is 'budget' for {model_name_litellm}, but 'max_reasoning_tokens' is missing or zero in CSV. Reasoning parameter not sent.")
+
+ elif reasoning_type == 'effort':
+ effort = "low"
+ if time > 0.7:
+ effort = "high"
+ elif time > 0.3:
+ effort = "medium"
+ # Use the common 'reasoning_effort' param LiteLLM provides
+ litellm_kwargs["reasoning_effort"] = effort
+ if verbose:
+ rprint(f"[INFO] Requesting reasoning_effort='{effort}' (effort type) for {model_name_litellm} based on time={time}")
+
+ elif reasoning_type == 'none':
+ if verbose:
+ rprint(f"[INFO] Model {model_name_litellm} has reasoning_type='none'. No reasoning parameter sent.")
+
+ else: # Unknown reasoning_type in CSV
+ if verbose:
+ rprint(f"[WARN] Unknown reasoning_type '{reasoning_type}' for model {model_name_litellm} in CSV. No reasoning parameter sent.")
+
+ # --- END NEW REASONING LOGIC ---
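For readers tracing the reasoning logic added above, the mapping from the `time` parameter to provider-specific reasoning settings can be sketched in isolation. This is a minimal illustration only: `reasoning_type`, `max_reasoning_tokens`, and the 0.3/0.7 thresholds come from the code above, while the helper name and plain-argument signature are assumed.

def reasoning_kwargs(time: float, reasoning_type: str, max_reasoning_tokens: int, provider: str) -> dict:
    """Hypothetical helper mirroring the time -> reasoning-parameter mapping above."""
    kwargs = {}
    if time <= 0:
        return kwargs  # no reasoning requested
    if reasoning_type == 'budget':
        budget = int(time * max_reasoning_tokens)
        if budget > 0 and provider == 'anthropic':
            # Anthropic-style token budget, as in the code above
            kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
    elif reasoning_type == 'effort':
        # 0.0-0.3 -> low, 0.3-0.7 -> medium, above 0.7 -> high
        kwargs["reasoning_effort"] = "high" if time > 0.7 else ("medium" if time > 0.3 else "low")
    return kwargs

# e.g. reasoning_kwargs(0.5, 'budget', 32000, 'anthropic')
#      -> {'thinking': {'type': 'enabled', 'budget_tokens': 16000}}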
+
+ # Add caching control per call if needed (example: force refresh)
+ # litellm_kwargs["cache"] = {"no-cache": True}
+
+ # --- 6. LLM Invocation ---
+ try:
+ start_time = time_module.time()
+
+ # --- ADDED CACHE STATUS DEBUGGING (NOW UNCONDITIONAL) ---
+ print(f"[DEBUG llm_invoke] Cache Check: litellm.cache is None: {litellm.cache is None}") # MODIFIED: unconditional print
+ if litellm.cache is not None:
+ print(f"[DEBUG llm_invoke] litellm.cache type: {type(litellm.cache)}, ID: {id(litellm.cache)}") # MODIFIED: unconditional print
+ # --- END ADDED CACHE STATUS DEBUGGING ---

- if isinstance(last_error, ValueError) and "Invalid prompt template" in str(last_error):
- raise ValueError("Invalid prompt template")
- if last_error:
- raise RuntimeError(f"Error during LLM invocation: {str(last_error)}")
- raise RuntimeError("No available models could process the request")
+ # <<< EXPLICITLY ENABLE CACHING >>>
+ # Only add if litellm.cache is configured
+ if litellm.cache is not None:
+ litellm_kwargs["caching"] = True
+ else: # MODIFIED: unconditional print for this path too
+ print(f"[DEBUG llm_invoke] NOT ENABLING CACHING: litellm.cache is None at call time.")

+
+ if use_batch_mode:
+ if verbose:
+ rprint(f"[INFO] Calling litellm.batch_completion for {model_name_litellm}...")
+ response = litellm.batch_completion(**litellm_kwargs)
+
+
+ else:
+ if verbose:
+ rprint(f"[INFO] Calling litellm.completion for {model_name_litellm}...")
+ response = litellm.completion(**litellm_kwargs)
+
+ end_time = time_module.time()
+
+ if verbose:
+ rprint(f"[SUCCESS] Invocation successful for {model_name_litellm} (took {end_time - start_time:.2f}s)")
+
+ # --- 7. Process Response ---
+ results = []
+ thinking_outputs = []
+
+ response_list = response if use_batch_mode else [response]
+
+ for i, resp_item in enumerate(response_list):
+ # Cost calculation is handled entirely by the success callback
+
+ # Thinking Output
+ thinking = None
+ try:
+ # Attempt 1: Check _hidden_params based on isolated test script
+ if hasattr(resp_item, '_hidden_params') and resp_item._hidden_params and 'thinking' in resp_item._hidden_params:
+ thinking = resp_item._hidden_params['thinking']
+ if verbose:
+ rprint("[DEBUG] Extracted thinking output from response._hidden_params['thinking']")
+ # Attempt 2: Fallback to reasoning_content in message
+ # Use .get() for safer access
+ elif hasattr(resp_item, 'choices') and resp_item.choices and hasattr(resp_item.choices[0], 'message') and hasattr(resp_item.choices[0].message, 'get') and resp_item.choices[0].message.get('reasoning_content'):
+ thinking = resp_item.choices[0].message.get('reasoning_content')
+ if verbose:
+ rprint("[DEBUG] Extracted thinking output from response.choices[0].message.get('reasoning_content')")
+
+ except (AttributeError, IndexError, KeyError, TypeError):
+ if verbose:
+ rprint("[DEBUG] Failed to extract thinking output from known locations.")
+ pass # Ignore if structure doesn't match or errors occur
+ thinking_outputs.append(thinking)
+
+ # Result (String or Pydantic)
+ try:
+ raw_result = resp_item.choices[0].message.content
+ if output_pydantic:
+ parsed_result = None
+ json_string_to_parse = None
+
+ try:
+ # Attempt 1: Check if LiteLLM already parsed it
+ if isinstance(raw_result, output_pydantic):
+ parsed_result = raw_result
+ if verbose:
+ rprint("[DEBUG] Pydantic object received directly from LiteLLM.")
+
+ # Attempt 2: Check if raw_result is dict-like and validate
+ elif isinstance(raw_result, dict):
+ parsed_result = output_pydantic.model_validate(raw_result)
+ if verbose:
+ rprint("[DEBUG] Validated dictionary-like object directly.")
+
+ # Attempt 3: Process as string (if not already parsed/validated)
+ elif isinstance(raw_result, str):
+ json_string_to_parse = raw_result # Start with the raw string
+ try:
+ # Look for first { and last }
+ start_brace = json_string_to_parse.find('{')
+ end_brace = json_string_to_parse.rfind('}')
+ if start_brace != -1 and end_brace != -1 and end_brace > start_brace:
+ potential_json = json_string_to_parse[start_brace:end_brace+1]
+ # Basic check if it looks like JSON
+ if potential_json.strip().startswith('{') and potential_json.strip().endswith('}'):
+ if verbose:
+ rprint(f"[DEBUG] Attempting to parse extracted JSON block: '{potential_json}'")
+ parsed_result = output_pydantic.model_validate_json(potential_json)
+ else:
+ # If block extraction fails, try cleaning markdown next
+ raise ValueError("Extracted block doesn't look like JSON")
+ else:
+ # If no braces found, try cleaning markdown next
+ raise ValueError("Could not find enclosing {}")
+ except (json.JSONDecodeError, ValidationError, ValueError) as extraction_error:
+ if verbose:
+ rprint(f"[DEBUG] JSON block extraction/validation failed ('{extraction_error}'). Trying markdown cleaning.")
+ # Fallback: Clean markdown fences and retry JSON validation
+ cleaned_result_str = raw_result.strip()
+ if cleaned_result_str.startswith("```json"):
+ cleaned_result_str = cleaned_result_str[7:]
+ elif cleaned_result_str.startswith("```"):
+ cleaned_result_str = cleaned_result_str[3:]
+ if cleaned_result_str.endswith("```"):
+ cleaned_result_str = cleaned_result_str[:-3]
+ cleaned_result_str = cleaned_result_str.strip()
+ # Check again if it looks like JSON before parsing
+ if cleaned_result_str.startswith('{') and cleaned_result_str.endswith('}'):
+ if verbose:
+ rprint(f"[DEBUG] Attempting parse after cleaning markdown fences. Cleaned string: '{cleaned_result_str}'")
+ json_string_to_parse = cleaned_result_str # Update string for error reporting
+ parsed_result = output_pydantic.model_validate_json(json_string_to_parse)
+ else:
+ # If still doesn't look like JSON, raise error
+ raise ValueError("Content after cleaning markdown doesn't look like JSON")
961
+
962
+
963
+ # Check if any parsing attempt succeeded
964
+ if parsed_result is None:
965
+ # This case should ideally be caught by exceptions above, but as a safeguard:
966
+ raise TypeError(f"Raw result type {type(raw_result)} or content could not be validated/parsed against {output_pydantic.__name__}.")
967
+
968
+ except (ValidationError, json.JSONDecodeError, TypeError, ValueError) as parse_error:
969
+ rprint(f"[ERROR] Failed to parse response into Pydantic model {output_pydantic.__name__} for item {i}: {parse_error}")
970
+ # Use the string that was last attempted for parsing in the error message
971
+ error_content = json_string_to_parse if json_string_to_parse is not None else raw_result
972
+ rprint("[ERROR] Content attempted for parsing:", repr(error_content)) # Use repr for clarity
973
+ results.append(f"ERROR: Failed to parse Pydantic. Raw: {repr(raw_result)}")
974
+ continue # Skip appending result below if parsing failed
975
+
976
+ # If parsing succeeded, append the parsed_result
977
+ results.append(parsed_result)
978
+
979
+ else:
980
+ # If output_pydantic was not requested, append the raw result
981
+ results.append(raw_result)
982
+
983
+ except (AttributeError, IndexError) as e:
984
+ rprint(f"[ERROR] Could not extract result content from response item {i}: {e}")
985
+ results.append(f"ERROR: Could not extract result content. Response: {resp_item}")
+
+ # --- Retrieve Cost from Callback Data --- (Reinstated)
+ # For batch, this will reflect the cost associated with the *last* item processed by the callback.
+ # A fully accurate batch total would require a more complex callback class to aggregate.
+ total_cost = _LAST_CALLBACK_DATA.get("cost", 0.0)
+ # ----------------------------------------
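For context, `_LAST_CALLBACK_DATA` is populated elsewhere in this module by a LiteLLM success callback. A rough sketch of that pattern is shown below; the dictionary keys mirror the lookups used in this file, but the callback body itself is an assumption based on LiteLLM's documented custom-callback signature and `litellm.completion_cost`, not the package's actual implementation.

import litellm

_LAST_CALLBACK_DATA = {}

def _track_cost_callback(kwargs, completion_response, start_time, end_time):
    # Record cost and token usage of the most recent successful call.
    usage = getattr(completion_response, "usage", None)
    _LAST_CALLBACK_DATA["cost"] = litellm.completion_cost(completion_response=completion_response)
    _LAST_CALLBACK_DATA["input_tokens"] = getattr(usage, "prompt_tokens", 0)
    _LAST_CALLBACK_DATA["output_tokens"] = getattr(usage, "completion_tokens", 0)

litellm.success_callback = [_track_cost_callback]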
+
+ final_result = results if use_batch_mode else results[0]
+ final_thinking = thinking_outputs if use_batch_mode else thinking_outputs[0]
+
+ # --- Verbose Output for Success ---
+ if verbose:
+ # Get token usage from the *last* callback data (might not be accurate for batch)
+ input_tokens = _LAST_CALLBACK_DATA.get("input_tokens", 0)
+ output_tokens = _LAST_CALLBACK_DATA.get("output_tokens", 0)
+
+ cost_input_pm = model_info.get('input', 0.0) if pd.notna(model_info.get('input')) else 0.0
+ cost_output_pm = model_info.get('output', 0.0) if pd.notna(model_info.get('output')) else 0.0
+
+ rprint(f"[RESULT] Model Used: {model_name_litellm}")
+ rprint(f"[RESULT] Cost (Input): ${cost_input_pm:.2f}/M tokens")
+ rprint(f"[RESULT] Cost (Output): ${cost_output_pm:.2f}/M tokens")
+ rprint(f"[RESULT] Tokens (Prompt): {input_tokens}")
+ rprint(f"[RESULT] Tokens (Completion): {output_tokens}")
+ # Display the cost captured by the callback
+ rprint(f"[RESULT] Total Cost (from callback): ${total_cost:.6g}") # Renamed label for clarity
+ rprint("[RESULT] Max Completion Tokens: Provider Default") # Indicate default limit
+ if final_thinking:
+ rprint("[RESULT] Thinking Output:")
+ rprint(final_thinking) # Rich print should handle the thinking output format
+
+ # --- Print raw output before returning if verbose ---
+ if verbose:
+ rprint("[DEBUG] Raw output before return:")
+ print(f" Raw Result (repr): {repr(final_result)}")
+ print(f" Raw Thinking (repr): {repr(final_thinking)}")
+ rprint("-" * 20) # Separator
+
+ # --- Return Success ---
+ return {
+ 'result': final_result,
+ 'cost': total_cost,
+ 'model_name': model_name_litellm, # Actual model used
+ 'thinking_output': final_thinking if final_thinking else None
+ }
+
+ # --- 6b. Handle Invocation Errors ---
+ except openai.AuthenticationError as e:
+ last_exception = e
+ if newly_acquired_keys.get(api_key_name):
+ rprint(f"[AUTH ERROR] Authentication failed for {model_name_litellm} with the newly provided key for '{api_key_name}'. Please check the key and try again.")
+ # Invalidate the key in env for this session to force re-prompt on retry
+ if api_key_name in os.environ:
+ del os.environ[api_key_name]
+ # Clear the 'newly acquired' status for this key so the next attempt doesn't trigger immediate retry loop
+ newly_acquired_keys[api_key_name] = False
+ retry_with_same_model = True # Set flag to retry the same model after re-prompt
+ # Go back to the start of the 'while retry_with_same_model' loop
+ else:
+ rprint(f"[AUTH ERROR] Authentication failed for {model_name_litellm} using existing key '{api_key_name}'. Trying next model.")
+ break # Break inner loop, try next model candidate
+
+ except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError,
+ openai.APIStatusError, openai.BadRequestError, openai.InternalServerError,
+ Exception) as e:
+ last_exception = e
+ error_type = type(e).__name__
+ rprint(f"[ERROR] Invocation failed for {model_name_litellm} ({error_type}): {e}. Trying next model.")
+ # Log more details in verbose mode
+ if verbose:
+ import traceback
+ traceback.print_exc()
+ break # Break inner loop, try next model candidate
+
+ # If the inner loop was broken (not by success), continue to the next candidate model
+ continue
+
+ # --- 8. Handle Failure of All Candidates ---
+ error_message = "All candidate models failed."
+ if last_exception:
+ error_message += f" Last error ({type(last_exception).__name__}): {last_exception}"
+ rprint(f"[FATAL] {error_message}")
+ raise RuntimeError(error_message) from last_exception
+
+ # --- Example Usage (Optional) ---
  if __name__ == "__main__":
- example_prompt = "Tell me a joke about {topic}"
- example_input = {"topic": "programming"}
+ # This block allows running the file directly for testing.
+ # Ensure you have a ./data/llm_model.csv file and potentially a .env file.
+
+ # Set PDD_DEBUG_SELECTOR=1 to see model selection details
+ # os.environ["PDD_DEBUG_SELECTOR"] = "1"
+
+ # Example 1: Simple text generation
+ print("\n--- Example 1: Simple Text Generation (Strength 0.5) ---")
  try:
- output = llm_invoke(example_prompt, example_input, strength=0.5, temperature=0.7, verbose=True)
- rprint("[bold magenta]Invocation succeeded:[/bold magenta]", output)
- except Exception as err:
- rprint(f"[bold red]Invocation failed:[/bold red] {err}")
+ response = llm_invoke(
+ prompt="Tell me a short joke about {topic}.",
+ input_json={"topic": "programmers"},
+ strength=0.5, # Use base model (gpt-4.1-nano)
+ temperature=0.7,
+ verbose=True
+ )
+ rprint("\nExample 1 Response:")
+ rprint(response)
+ except Exception as e:
+ rprint(f"\nExample 1 Failed: {e}")
+
+ # Example 1b: Simple text generation (Strength 0.3)
+ print("\n--- Example 1b: Simple Text Generation (Strength 0.3) ---")
+ try:
+ response = llm_invoke(
+ prompt="Tell me a short joke about {topic}.",
+ input_json={"topic": "keyboards"},
+ strength=0.3, # Should select gemini-pro based on cost interpolation
+ temperature=0.7,
+ verbose=True
+ )
+ rprint("\nExample 1b Response:")
+ rprint(response)
+ except Exception as e:
+ rprint(f"\nExample 1b Failed: {e}")
+
+ # Example 2: Structured output (requires a Pydantic model)
+ print("\n--- Example 2: Structured Output (Pydantic, Strength 0.8) ---")
+ class JokeStructure(BaseModel):
+ setup: str
+ punchline: str
+ rating: Optional[int] = None
+
+ try:
+ # Use a model known to support structured output (check your CSV)
+ # Strength 0.8 should select gemini-pro based on ELO interpolation
+ response_structured = llm_invoke(
+ prompt="Create a joke about {topic}. Output ONLY the JSON object with 'setup' and 'punchline'.",
+ input_json={"topic": "data science"},
+ strength=0.8, # Try a higher ELO model (gemini-pro expected)
+ temperature=1,
+ output_pydantic=JokeStructure,
+ verbose=True
+ )
+ rprint("\nExample 2 Response:")
+ rprint(response_structured)
+ if isinstance(response_structured.get('result'), JokeStructure):
+ rprint("\nPydantic object received successfully:", response_structured['result'].model_dump())
+ else:
+ rprint("\nResult was not the expected Pydantic object:", response_structured.get('result'))
+
+ except Exception as e:
+ rprint(f"\nExample 2 Failed: {e}")
+
+
+ # Example 3: Batch processing
+ print("\n--- Example 3: Batch Processing (Strength 0.3) ---")
+ try:
+ batch_input = [
+ {"animal": "cat", "adjective": "lazy"},
+ {"animal": "dog", "adjective": "energetic"},
+ ]
+ # Strength 0.3 should select gemini-pro
+ response_batch = llm_invoke(
+ prompt="Describe a {adjective} {animal} in one sentence.",
+ input_json=batch_input,
+ strength=0.3, # Cheaper model maybe (gemini-pro expected)
+ temperature=0.5,
+ use_batch_mode=True,
+ verbose=True
+ )
+ rprint("\nExample 3 Response:")
+ rprint(response_batch)
+ except Exception as e:
+ rprint(f"\nExample 3 Failed: {e}")
+
+ # Example 4: Using 'messages' input
+ print("\n--- Example 4: Using 'messages' input (Strength 0.5) ---")
+ try:
+ custom_messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"}
+ ]
+ # Strength 0.5 should select gpt-4.1-nano
+ response_messages = llm_invoke(
+ messages=custom_messages,
+ strength=0.5,
+ temperature=0.1,
+ verbose=True
+ )
+ rprint("\nExample 4 Response:")
+ rprint(response_messages)
+ except Exception as e:
+ rprint(f"\nExample 4 Failed: {e}")
+
+ # Example 5: Requesting thinking time (e.g., for Anthropic)
+ print("\n--- Example 5: Requesting Thinking Time (Strength 1.0, Time 0.5) ---")
+ try:
+ # Ensure your CSV has max_reasoning_tokens for an Anthropic model
+ # Strength 1.0 should select claude-3 (highest ELO)
+ # Time 0.5 with budget type should request thinking
+ response_thinking = llm_invoke(
+ prompt="Explain the theory of relativity simply, taking some time to think.",
+ input_json={},
+ strength=1.0, # Try to get highest ELO model (claude-3)
+ temperature=1,
+ time=0.5, # Request moderate thinking time
+ verbose=True
+ )
+ rprint("\nExample 5 Response:")
+ rprint(response_thinking)
+ except Exception as e:
+ rprint(f"\nExample 5 Failed: {e}")
+
+ # Example 6: Pydantic Fallback Parsing (Strength 0.3)
+ print("\n--- Example 6: Pydantic Fallback Parsing (Strength 0.3) ---")
+ # This requires mocking litellm.completion to return a JSON string
+ # even when gemini-pro (which supports structured output) is selected.
+ # This is hard to demonstrate cleanly in the __main__ block without mocks.
+ # The unit test test_llm_invoke_output_pydantic_unsupported_parses covers this.
+ print("(Covered by unit tests)")