pdd-cli 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

pdd/llm_invoke.py CHANGED
@@ -1,389 +1,1234 @@
- #!/usr/bin/env python
- """
- llm_invoke.py
-
- This module provides a single function, llm_invoke, that runs a prompt with a given input
- against a language model (LLM) using Langchain and returns the output, cost, and model name.
- The function supports model selection based on cost/ELO interpolation controlled by the
- "strength" parameter. It also implements a retry mechanism: if a model invocation fails,
- it falls back to the next candidate (cheaper for strength < 0.5, or higher ELO for strength ≥ 0.5).
-
- Usage:
-     from llm_invoke import llm_invoke
-     result = llm_invoke(prompt, input_json, strength, temperature, verbose=True, output_pydantic=MyPydanticClass)
-     # result is a dict with keys: 'result', 'cost', 'model_name'
-
- Environment:
-     - PDD_MODEL_DEFAULT: if set, used as the base model name. Otherwise defaults to "gpt-4.1-nano".
-     - PDD_PATH: if set, models are loaded from $PDD_PATH/data/llm_model.csv; otherwise from ./data/llm_model.csv.
-     - Models that require an API key will check the corresponding environment variable (name provided in the CSV).
- """
+ # Corrected code_under_test (llm_invoke.py)
+ # Added optional debugging prints in _select_model_candidates
 
  import os
- import csv
+ import pandas as pd
+ import litellm
+ import logging # ADDED FOR DETAILED LOGGING
+
+ # --- Configure Standard Python Logging ---
+ logger = logging.getLogger("pdd.llm_invoke")
+
+ # Environment variable to control log level
+ PDD_LOG_LEVEL = os.getenv("PDD_LOG_LEVEL", "INFO")
+ PRODUCTION_MODE = os.getenv("PDD_ENVIRONMENT") == "production"
+
+ # Set default level based on environment
+ if PRODUCTION_MODE:
+     logger.setLevel(logging.WARNING)
+ else:
+     logger.setLevel(getattr(logging, PDD_LOG_LEVEL, logging.INFO))
+
+ # Configure LiteLLM logger separately
+ litellm_logger = logging.getLogger("litellm")
+ litellm_log_level = os.getenv("LITELLM_LOG_LEVEL", "WARNING" if PRODUCTION_MODE else "INFO")
+ litellm_logger.setLevel(getattr(logging, litellm_log_level, logging.WARNING))
+
+ # Add a console handler if none exists
+ if not logger.handlers:
+     console_handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     console_handler.setFormatter(formatter)
+     logger.addHandler(console_handler)
+
+     # Only add handler to litellm logger if it doesn't have any
+     if not litellm_logger.handlers:
+         litellm_logger.addHandler(console_handler)
+
+ # Function to set up file logging if needed
+ def setup_file_logging(log_file_path=None):
+     """Configure rotating file handler for logging"""
+     if not log_file_path:
+         return
+
+     try:
+         from logging.handlers import RotatingFileHandler
+         file_handler = RotatingFileHandler(
+             log_file_path, maxBytes=10*1024*1024, backupCount=5
+         )
+         file_handler.setFormatter(logging.Formatter(
+             '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+         ))
+         logger.addHandler(file_handler)
+         litellm_logger.addHandler(file_handler)
+         logger.info(f"File logging configured to: {log_file_path}")
+     except Exception as e:
+         logger.warning(f"Failed to set up file logging: {e}")
+
+ # Function to set verbose logging
+ def set_verbose_logging(verbose=False):
+     """Set verbose logging based on flag or environment variable"""
+     if verbose or os.getenv("PDD_VERBOSE_LOGGING") == "1":
+         logger.setLevel(logging.DEBUG)
+         litellm_logger.setLevel(logging.DEBUG)
+         logger.debug("Verbose logging enabled")
+
+ # --- End Logging Configuration ---
+
  import json
+ # from rich import print as rprint # Replaced with logger
+ from dotenv import load_dotenv
+ from pathlib import Path
+ from typing import Optional, Dict, List, Any, Type, Union
+ from pydantic import BaseModel, ValidationError
+ import openai # Import openai for exception handling as LiteLLM maps to its types
+ from langchain_core.prompts import PromptTemplate
+ import warnings
+ import time as time_module # Alias to avoid conflict with 'time' parameter
+ # Import the default model constant
+ from pdd import DEFAULT_LLM_MODEL
 
- from pydantic import BaseModel, Field
- from rich import print as rprint
- from rich.errors import MarkupError
-
- # Langchain core and community imports
- from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
- from langchain_community.cache import SQLiteCache
- from langchain.globals import set_llm_cache
- from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
- from langchain_core.runnables import RunnablePassthrough, ConfigurableField
-
- # LLM provider imports
- from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI
- from langchain_fireworks import Fireworks
- from langchain_anthropic import ChatAnthropic
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_google_vertexai import ChatVertexAI
- from langchain_groq import ChatGroq
- from langchain_together import Together
- from langchain_ollama.llms import OllamaLLM
-
- from langchain.callbacks.base import BaseCallbackHandler
- from langchain.schema import LLMResult
-
- # ---------------- Internal Helper Classes and Functions ---------------- #
-
- class CompletionStatusHandler(BaseCallbackHandler):
-     """
-     Callback handler to capture LLM token usage and completion metadata.
-     """
-     def __init__(self):
-         self.is_complete = False
-         self.finish_reason = None
-         self.input_tokens = None
-         self.output_tokens = None
-
-     def on_llm_end(self, response: LLMResult, **kwargs) -> None:
-         self.is_complete = True
-         if response.generations and response.generations[0]:
-             generation = response.generations[0][0]
-             # Safely get generation_info; if it's None, default to {}
-             generation_info = generation.generation_info or {}
-             self.finish_reason = (generation_info.get('finish_reason') or "").lower()
-
-             # Attempt to get token usage from generation.message if available.
-             if (
-                 hasattr(generation, "message")
-                 and generation.message is not None
-                 and hasattr(generation.message, "usage_metadata")
-                 and generation.message.usage_metadata
-             ):
-                 usage_metadata = generation.message.usage_metadata
-             else:
-                 usage_metadata = generation_info.get("usage_metadata", {})
-
-             self.input_tokens = usage_metadata.get('input_tokens', 0)
-             self.output_tokens = usage_metadata.get('output_tokens', 0)
+ # Opt-in to future pandas behavior regarding downcasting
+ pd.set_option('future.no_silent_downcasting', True)
 
- class ModelInfo:
-     """
-     Represents information about an LLM model as loaded from the CSV.
-     """
-     def __init__(self, provider, model, input_cost, output_cost, coding_arena_elo,
-                  base_url, api_key, counter, encoder, max_tokens, max_completion_tokens,
-                  structured_output):
-         self.provider = provider.strip() if provider else ""
-         self.model = model.strip() if model else ""
-         self.input_cost = float(input_cost) if input_cost else 0.0
-         self.output_cost = float(output_cost) if output_cost else 0.0
-         self.average_cost = (self.input_cost + self.output_cost) / 2
-         self.coding_arena_elo = float(coding_arena_elo) if coding_arena_elo else 0.0
-         self.base_url = base_url.strip() if base_url else None
-         self.api_key = api_key.strip() if api_key else None
-         self.counter = counter.strip() if counter else None
-         self.encoder = encoder.strip() if encoder else None
-         self.max_tokens = int(max_tokens) if max_tokens else None
-         self.max_completion_tokens = int(max_completion_tokens) if max_completion_tokens else None
-         self.structured_output = (str(structured_output).lower() == 'true') if structured_output else False
-
- def load_models():
84
+ # <<< SET LITELLM DEBUG LOGGING >>>
85
+ # os.environ['LITELLM_LOG'] = 'DEBUG' # Keep commented out unless debugging LiteLLM itself
86
+
87
+ # --- Constants and Configuration ---
88
+
89
+ # Determine project root: 1. PDD_PATH env var, 2. Search upwards from script, 3. CWD
90
+ PROJECT_ROOT = None
91
+ PDD_PATH_ENV = os.getenv("PDD_PATH")
92
+
93
+ if PDD_PATH_ENV:
94
+ _path_from_env = Path(PDD_PATH_ENV)
95
+ if _path_from_env.is_dir():
96
+ PROJECT_ROOT = _path_from_env.resolve()
97
+ logger.debug(f"Using PROJECT_ROOT from PDD_PATH: {PROJECT_ROOT}")
98
+ else:
99
+ warnings.warn(f"PDD_PATH environment variable ('{PDD_PATH_ENV}') is set but not a valid directory. Attempting auto-detection.")
100
+
101
+ if PROJECT_ROOT is None: # If PDD_PATH wasn't set or was invalid
102
+ try:
103
+ # Start from the directory containing this script
104
+ current_dir = Path(__file__).resolve().parent
105
+ # Look for project markers (e.g., .git, pyproject.toml, data/, .env)
106
+ # Go up a maximum of 5 levels to prevent infinite loops
107
+ for _ in range(5):
108
+ has_git = (current_dir / ".git").exists()
109
+ has_pyproject = (current_dir / "pyproject.toml").exists()
110
+ has_data = (current_dir / "data").is_dir()
111
+ has_dotenv = (current_dir / ".env").exists()
112
+
113
+ if has_git or has_pyproject or has_data or has_dotenv:
114
+ PROJECT_ROOT = current_dir
115
+ logger.debug(f"Determined PROJECT_ROOT by marker search: {PROJECT_ROOT}")
116
+ break
117
+
118
+ parent_dir = current_dir.parent
119
+ if parent_dir == current_dir: # Reached filesystem root
120
+ break
121
+ current_dir = parent_dir
122
+
123
+ except NameError: # __file__ might not be defined (e.g., interactive session)
124
+ warnings.warn("__file__ not defined. Cannot automatically detect project root from script location.")
125
+ except Exception as e: # Catch potential permission errors etc.
126
+ warnings.warn(f"Error during project root auto-detection: {e}")
127
+
128
+ if PROJECT_ROOT is None: # Fallback to CWD if no method succeeded
129
+ PROJECT_ROOT = Path.cwd().resolve()
130
+ warnings.warn(f"Could not determine project root automatically. Using current working directory: {PROJECT_ROOT}. Ensure this is the intended root or set the PDD_PATH environment variable.")
131
+
132
+
133
+ ENV_PATH = PROJECT_ROOT / ".env"
134
+ # --- Determine LLM_MODEL_CSV_PATH ---
135
+ # Prioritize ~/.pdd/llm_model.csv
136
+ user_pdd_dir = Path.home() / ".pdd"
137
+ user_model_csv_path = user_pdd_dir / "llm_model.csv"
138
+
139
+ if user_model_csv_path.is_file():
140
+ LLM_MODEL_CSV_PATH = user_model_csv_path
141
+ logger.info(f"Using user-specific LLM model CSV: {LLM_MODEL_CSV_PATH}")
142
+ else:
143
+ LLM_MODEL_CSV_PATH = PROJECT_ROOT / "data" / "llm_model.csv"
144
+ logger.info(f"Using project LLM model CSV: {LLM_MODEL_CSV_PATH}")
145
+ # ---------------------------------
146
+
147
+ # Load environment variables from .env file
148
+ logger.debug(f"Attempting to load .env from: {ENV_PATH}")
149
+ if ENV_PATH.exists():
150
+ load_dotenv(dotenv_path=ENV_PATH)
151
+ logger.debug(f"Loaded .env file from: {ENV_PATH}")
152
+ else:
153
+ # Silently proceed if .env is optional
154
+ logger.debug(f".env file not found at {ENV_PATH}. API keys might need to be provided manually.")
155
+
156
+ # Default model if PDD_MODEL_DEFAULT is not set
157
+ # Use the imported constant as the default
158
+ DEFAULT_BASE_MODEL = os.getenv("PDD_MODEL_DEFAULT", DEFAULT_LLM_MODEL)
159
+
160
+ # --- LiteLLM Cache Configuration (S3 compatible for GCS, with SQLite fallback) ---
161
+ GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
162
+ GCS_ENDPOINT_URL = "https://storage.googleapis.com" # GCS S3 compatibility endpoint
163
+ GCS_REGION_NAME = os.getenv("GCS_REGION_NAME", "auto") # Often 'auto' works for GCS
164
+ GCS_HMAC_ACCESS_KEY_ID = os.getenv("GCS_HMAC_ACCESS_KEY_ID") # Load HMAC Key ID
165
+ GCS_HMAC_SECRET_ACCESS_KEY = os.getenv("GCS_HMAC_SECRET_ACCESS_KEY") # Load HMAC Secret
166
+
167
+ cache_configured = False
168
+
169
+ if GCS_BUCKET_NAME and GCS_HMAC_ACCESS_KEY_ID and GCS_HMAC_SECRET_ACCESS_KEY:
170
+ # Store original AWS credentials before overwriting for GCS cache setup
171
+ original_aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
172
+ original_aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
173
+ original_aws_region_name = os.environ.get('AWS_REGION_NAME')
174
+
175
+ try:
176
+ # Temporarily set AWS env vars to GCS HMAC keys for S3 compatible cache
177
+ os.environ['AWS_ACCESS_KEY_ID'] = GCS_HMAC_ACCESS_KEY_ID
178
+ os.environ['AWS_SECRET_ACCESS_KEY'] = GCS_HMAC_SECRET_ACCESS_KEY
179
+ # os.environ['AWS_REGION_NAME'] = GCS_REGION_NAME # Uncomment if needed
180
+
181
+ litellm.cache = litellm.Cache(
182
+ type="s3",
183
+ s3_bucket_name=GCS_BUCKET_NAME,
184
+ s3_region_name=GCS_REGION_NAME, # Pass region explicitly to cache
185
+ s3_endpoint_url=GCS_ENDPOINT_URL,
186
+ )
187
+ logger.info(f"LiteLLM cache configured for GCS bucket (S3 compatible): {GCS_BUCKET_NAME}")
188
+ cache_configured = True
189
+
190
+ except Exception as e:
191
+ warnings.warn(f"Failed to configure LiteLLM S3/GCS cache: {e}. Attempting SQLite cache fallback.")
192
+ litellm.cache = None # Explicitly disable cache on failure (will try SQLite next)
193
+
194
+ finally:
195
+ # Restore original AWS credentials after cache setup attempt
196
+ if original_aws_access_key_id is not None:
197
+ os.environ['AWS_ACCESS_KEY_ID'] = original_aws_access_key_id
198
+ elif 'AWS_ACCESS_KEY_ID' in os.environ:
199
+ del os.environ['AWS_ACCESS_KEY_ID']
200
+
201
+ if original_aws_secret_access_key is not None:
202
+ os.environ['AWS_SECRET_ACCESS_KEY'] = original_aws_secret_access_key
203
+ elif 'AWS_SECRET_ACCESS_KEY' in os.environ:
204
+ del os.environ['AWS_SECRET_ACCESS_KEY']
205
+
206
+ if original_aws_region_name is not None:
207
+ os.environ['AWS_REGION_NAME'] = original_aws_region_name
208
+ elif 'AWS_REGION_NAME' in os.environ:
209
+ pass # Or just leave it if the temporary setting wasn't done/needed
210
+
211
+ if not cache_configured:
212
+ try:
213
+ # Try SQLite-based cache as a fallback
214
+ sqlite_cache_path = PROJECT_ROOT / "litellm_cache.sqlite"
215
+ litellm.cache = litellm.Cache(type="sqlite", cache_path=str(sqlite_cache_path))
216
+ logger.info(f"LiteLLM SQLite cache configured at {sqlite_cache_path}")
217
+ cache_configured = True
218
+ except Exception as e2:
219
+ warnings.warn(f"Failed to configure LiteLLM SQLite cache: {e2}. Caching is disabled.")
220
+ litellm.cache = None
221
+
222
+ if not cache_configured:
223
+ warnings.warn("All LiteLLM cache configuration attempts failed. Caching is disabled.")
224
+ litellm.cache = None
225
+
226
+ # --- LiteLLM Callback for Success Logging ---
227
+
228
+ # Module-level storage for last callback data (Use with caution in concurrent environments)
229
+ _LAST_CALLBACK_DATA = {
230
+ "input_tokens": 0,
231
+ "output_tokens": 0,
232
+ "finish_reason": None,
233
+ "cost": 0.0,
234
+ }
235
+
236
+ def _litellm_success_callback(
237
+ kwargs: Dict[str, Any], # kwargs passed to completion
238
+ completion_response: Any, # response object from completion
239
+ start_time: float, end_time: float # start/end time
240
+ ):
106
241
  """
107
- Loads model information from llm_model.csv located in either $PDD_PATH/data or ./data.
242
+ LiteLLM success callback to capture usage and finish reason.
243
+ Stores data in a module-level variable for potential retrieval.
108
244
  """
109
- pdd_path = os.environ.get('PDD_PATH', '.')
110
- models_file = os.path.join(pdd_path, 'data', 'llm_model.csv')
111
- models = []
245
+ global _LAST_CALLBACK_DATA
246
+ usage = getattr(completion_response, 'usage', None)
247
+ input_tokens = getattr(usage, 'prompt_tokens', 0)
248
+ output_tokens = getattr(usage, 'completion_tokens', 0)
249
+ finish_reason = getattr(completion_response.choices[0], 'finish_reason', None)
250
+
251
+ calculated_cost = 0.0
112
252
  try:
113
- with open(models_file, newline='') as csvfile:
114
- reader = csv.DictReader(csvfile)
115
- for row in reader:
116
- model_info = ModelInfo(
117
- provider=row.get('provider',''),
118
- model=row.get('model',''),
119
- input_cost=row.get('input','0'),
120
- output_cost=row.get('output','0'),
121
- coding_arena_elo=row.get('coding_arena_elo','0'),
122
- base_url=row.get('base_url',''),
123
- api_key=row.get('api_key',''),
124
- counter=row.get('counter',''),
125
- encoder=row.get('encoder',''),
126
- max_tokens=row.get('max_tokens',''),
127
- max_completion_tokens=row.get('max_completion_tokens',''),
128
- structured_output=row.get('structured_output','False')
253
+ # Attempt 1: Use the response object directly (works for most single calls)
254
+ cost_val = litellm.completion_cost(completion_response=completion_response)
255
+ calculated_cost = cost_val if cost_val is not None else 0.0
256
+ except Exception as e1:
257
+ # Attempt 2: If response object failed (e.g., missing provider in model name),
258
+ # try again using explicit model from kwargs and tokens from usage.
259
+ # This is often needed for batch completion items.
260
+ logger.debug(f"Attempting cost calculation with fallback method: {e1}")
261
+ try:
262
+ model_name = kwargs.get("model") # Get original model name from input kwargs
263
+ if model_name and usage:
264
+ prompt_tokens = getattr(usage, 'prompt_tokens', 0)
265
+ completion_tokens = getattr(usage, 'completion_tokens', 0)
266
+ cost_val = litellm.completion_cost(
267
+ model=model_name,
268
+ prompt_tokens=prompt_tokens,
269
+ completion_tokens=completion_tokens
129
270
  )
130
- models.append(model_info)
131
- except FileNotFoundError:
132
- raise FileNotFoundError(f"llm_model.csv not found at {models_file}")
133
- return models
271
+ calculated_cost = cost_val if cost_val is not None else 0.0
272
+ else:
273
+ # If we can't get model name or usage, fallback to 0
274
+ calculated_cost = 0.0
275
+ # Optional: Log the original error e1 if needed
276
+ # logger.warning(f"[Callback WARN] Failed to calculate cost with response object ({e1}) and fallback failed.")
277
+ except Exception as e2:
278
+ # Optional: Log secondary error e2 if needed
279
+ # logger.warning(f"[Callback WARN] Failed to calculate cost with fallback method: {e2}")
280
+ calculated_cost = 0.0 # Default to 0 on any error
281
+ logger.debug(f"Cost calculation failed with fallback method: {e2}")
134
282
 
135
- def select_model(models, base_model_name):
136
- """
137
- Retrieve the base model whose name matches base_model_name. Raises an error if not found.
138
- """
139
- for model in models:
140
- if model.model == base_model_name:
141
- return model
142
- raise ValueError(f"Base model '{base_model_name}' not found in the models list.")
283
+ _LAST_CALLBACK_DATA["input_tokens"] = input_tokens
284
+ _LAST_CALLBACK_DATA["output_tokens"] = output_tokens
285
+ _LAST_CALLBACK_DATA["finish_reason"] = finish_reason
286
+ _LAST_CALLBACK_DATA["cost"] = calculated_cost # Store the calculated cost
143
287
 
144
- def get_candidate_models(strength, models, base_model):
145
- """
146
- Returns ordered list of candidate models based on strength parameter.
147
- Only includes models with available API keys.
148
- """
149
- # Filter for models with valid API keys (including test environment)
150
- available_models = [m for m in models
151
- if not m.api_key or
152
- os.environ.get(m.api_key) or
153
- m.api_key == "EXISTING_KEY"]
154
-
155
- if not available_models:
156
- raise RuntimeError("No models available with valid API keys")
288
+ # Callback doesn't need to return a value now
289
+ # return calculated_cost
290
+
291
+ # Example of logging within the callback (can be expanded)
292
+ # logger.info(f"[Callback] Tokens: In={input_tokens}, Out={output_tokens}. Reason: {finish_reason}. Cost: ${calculated_cost:.6f}")
293
+
294
+ # Register the callback with LiteLLM
295
+ litellm.success_callback = [_litellm_success_callback]
296
+
297
+ # --- Helper Functions ---
298
+
299
+ def _load_model_data(csv_path: Path) -> pd.DataFrame:
300
+ """Loads and preprocesses the LLM model data from CSV."""
301
+ if not csv_path.exists():
302
+ raise FileNotFoundError(f"LLM model CSV not found at {csv_path}")
303
+ try:
304
+ df = pd.read_csv(csv_path)
305
+ # Basic validation and type conversion
306
+ required_cols = ['provider', 'model', 'input', 'output', 'coding_arena_elo', 'api_key', 'structured_output', 'reasoning_type']
307
+ for col in required_cols:
308
+ if col not in df.columns:
309
+ raise ValueError(f"Missing required column in CSV: {col}")
310
+
311
+ # Convert numeric columns, handling potential errors
312
+ numeric_cols = ['input', 'output', 'coding_arena_elo', 'max_tokens',
313
+ 'max_completion_tokens', 'max_reasoning_tokens']
314
+ for col in numeric_cols:
315
+ if col in df.columns:
316
+ # Use errors='coerce' to turn unparseable values into NaN
317
+ df[col] = pd.to_numeric(df[col], errors='coerce')
318
+
319
+ # Fill NaN in critical numeric columns used for selection/interpolation
320
+ df['input'] = df['input'].fillna(0.0)
321
+ df['output'] = df['output'].fillna(0.0)
322
+ df['coding_arena_elo'] = df['coding_arena_elo'].fillna(0) # Use 0 ELO for missing
323
+ # Ensure max_reasoning_tokens is numeric, fillna with 0
324
+ df['max_reasoning_tokens'] = df['max_reasoning_tokens'].fillna(0).astype(int) # Ensure int
325
+
326
+ # Calculate average cost (handle potential division by zero if needed, though unlikely with fillna)
327
+ df['avg_cost'] = (df['input'] + df['output']) / 2
328
+
329
+ # Ensure boolean interpretation for structured_output
330
+ if 'structured_output' in df.columns:
331
+ df['structured_output'] = df['structured_output'].fillna(False).astype(bool)
332
+ else:
333
+ df['structured_output'] = False # Assume false if column missing
334
+
335
+ # Ensure reasoning_type is string, fillna with 'none' and lowercase
336
+ df['reasoning_type'] = df['reasoning_type'].fillna('none').astype(str).str.lower()
337
+
338
+ # Ensure api_key is treated as string, fill NaN with empty string ''
339
+ # This handles cases where read_csv might interpret empty fields as NaN
340
+ df['api_key'] = df['api_key'].fillna('').astype(str)
341
+
342
+ return df
343
+ except Exception as e:
344
+ raise RuntimeError(f"Error loading or processing LLM model CSV {csv_path}: {e}") from e
345
+
346
+ def _select_model_candidates(
347
+ strength: float,
348
+ base_model_name: str,
349
+ model_df: pd.DataFrame
350
+ ) -> List[Dict[str, Any]]:
351
+ """Selects and sorts candidate models based on strength and availability."""
352
+
353
+ # 1. Filter by API Key Name Presence (initial availability check)
354
+ # Keep models with a non-empty api_key field in the CSV.
355
+ # The actual key value check happens later.
356
+ # Allow models with empty api_key (e.g., Bedrock using AWS creds, local models)
357
+ available_df = model_df[model_df['api_key'].notna()].copy()
358
+
359
+ # --- Check if the initial DataFrame itself was empty ---
360
+ if model_df.empty:
361
+ raise ValueError("Loaded model data is empty. Check CSV file.")
362
+
363
+ # --- Check if filtering resulted in empty (might indicate all models had NaN api_key) ---
364
+ if available_df.empty:
365
+ # This case is less likely if notna() is the only filter, but good to check.
366
+ logger.warning("No models found after filtering for non-NaN api_key. Check CSV 'api_key' column.")
367
+ # Decide if this should be a hard error or allow proceeding if logic permits
368
+ # For now, let's raise an error as it likely indicates a CSV issue.
369
+ raise ValueError("No models available after initial filtering (all had NaN 'api_key'?).")
370
+
371
+ # 2. Find Base Model
372
+ base_model_row = available_df[available_df['model'] == base_model_name]
373
+ if base_model_row.empty:
374
+ # Try finding base model in the *original* df in case it was filtered out
375
+ original_base = model_df[model_df['model'] == base_model_name]
376
+ if not original_base.empty:
377
+ raise ValueError(f"Base model '{base_model_name}' found in CSV but requires API key '{original_base.iloc[0]['api_key']}' which might be missing or invalid configuration.")
378
+ else:
379
+ raise ValueError(f"Specified base model '{base_model_name}' not found in the LLM model CSV.")
380
+
381
+ base_model = base_model_row.iloc[0]
382
+
383
+ # 3. Determine Target and Sort
384
+ candidates = []
385
+ target_metric_value = None # For debugging print
157
386
 
158
- # For base model case (strength = 0.5), use base model if available
159
387
  if strength == 0.5:
160
- base_candidates = [m for m in available_models if m.model == base_model.model]
161
- if base_candidates:
162
- return base_candidates
163
- return [available_models[0]]
164
-
165
- # For strength < 0.5, prioritize cheaper models
166
- if strength < 0.5:
167
- # Get models cheaper than or equal to base model
168
- cheaper_models = [m for m in available_models
169
- if m.average_cost <= base_model.average_cost]
170
- if not cheaper_models:
171
- return [available_models[0]]
172
-
173
- # For test environment, honor the mock model setup
174
- test_models = [m for m in cheaper_models if m.api_key == "EXISTING_KEY"]
175
- if test_models:
176
- return test_models
177
-
178
- # Production path: interpolate based on cost
179
- cheapest = min(cheaper_models, key=lambda m: m.average_cost)
180
- cost_range = base_model.average_cost - cheapest.average_cost
181
- target_cost = cheapest.average_cost + (strength / 0.5) * cost_range
182
- return sorted(cheaper_models, key=lambda m: abs(m.average_cost - target_cost))
183
-
184
- # For strength > 0.5, prioritize higher ELO models
185
- # Get models with higher or equal ELO than base_model
186
- better_models = [m for m in available_models
187
- if m.coding_arena_elo >= base_model.coding_arena_elo]
188
- if not better_models:
189
- return [available_models[0]]
190
-
191
- # For test environment, honor the mock model setup
192
- test_models = [m for m in better_models if m.api_key == "EXISTING_KEY"]
193
- if test_models:
194
- return test_models
195
-
196
- # Production path: interpolate based on ELO
197
- highest = max(better_models, key=lambda m: m.coding_arena_elo)
198
- elo_range = highest.coding_arena_elo - base_model.coding_arena_elo
199
- target_elo = base_model.coding_arena_elo + ((strength - 0.5) / 0.5) * elo_range
200
- return sorted(better_models, key=lambda m: abs(m.coding_arena_elo - target_elo))
201
-
202
- def create_llm_instance(selected_model, temperature, handler):
203
- """
204
- Creates an instance of the LLM using the selected_model parameters.
205
- Handles provider-specific settings and token limit configurations.
206
- """
207
- provider = selected_model.provider.lower()
208
- model_name = selected_model.model
209
- base_url = selected_model.base_url
210
- api_key_env = selected_model.api_key
211
- max_completion_tokens = selected_model.max_completion_tokens
212
- max_tokens = selected_model.max_tokens
213
-
214
- api_key = os.environ.get(api_key_env) if api_key_env else None
215
-
216
- if provider == 'openai':
217
- if base_url:
218
- llm = ChatOpenAI(model=model_name, temperature=temperature,
219
- openai_api_key=api_key, callbacks=[handler],
220
- openai_api_base=base_url)
388
+ # target_model = base_model
389
+ # Sort remaining by ELO descending as fallback
390
+ available_df['sort_metric'] = -available_df['coding_arena_elo'] # Negative for descending sort
391
+ candidates = available_df.sort_values(by='sort_metric').to_dict('records')
392
+ # Ensure base model is first if it exists
393
+ if any(c['model'] == base_model_name for c in candidates):
394
+ candidates.sort(key=lambda x: 0 if x['model'] == base_model_name else 1)
395
+ target_metric_value = f"Base Model ELO: {base_model['coding_arena_elo']}"
396
+
397
+ elif strength < 0.5:
398
+ # Interpolate by Cost (downwards from base)
399
+ base_cost = base_model['avg_cost']
400
+ cheapest_model = available_df.loc[available_df['avg_cost'].idxmin()]
401
+ cheapest_cost = cheapest_model['avg_cost']
402
+
403
+ if base_cost <= cheapest_cost: # Handle edge case where base is cheapest
404
+ target_cost = cheapest_cost + strength * (base_cost - cheapest_cost) # Will be <= base_cost
221
405
  else:
222
- if model_name.startswith('o'):
223
- llm = ChatOpenAI(model=model_name, temperature=temperature,
224
- openai_api_key=api_key, callbacks=[handler],
225
- reasoning={"effort": "high","summary": "auto"})
226
- else:
227
- llm = ChatOpenAI(model=model_name, temperature=temperature,
228
- openai_api_key=api_key, callbacks=[handler])
229
- elif provider == 'anthropic':
230
- # Special case for Claude 3.7 Sonnet with thinking token budget
231
- if 'claude-3-7-sonnet' in model_name:
232
- llm = ChatAnthropic(
233
- model=model_name,
234
- temperature=temperature,
235
- callbacks=[handler],
236
- thinking={"type": "enabled", "budget_tokens": 4000} # 32K thinking token budget
237
- )
406
+ # Interpolate between cheapest and base
407
+ target_cost = cheapest_cost + (strength / 0.5) * (base_cost - cheapest_cost)
408
+
409
+ available_df['sort_metric'] = abs(available_df['avg_cost'] - target_cost)
410
+ candidates = available_df.sort_values(by='sort_metric').to_dict('records')
411
+ target_metric_value = f"Target Cost: {target_cost:.6f}"
412
+
413
+ else: # strength > 0.5
414
+ # Interpolate by ELO (upwards from base)
415
+ base_elo = base_model['coding_arena_elo']
416
+ highest_elo_model = available_df.loc[available_df['coding_arena_elo'].idxmax()]
417
+ highest_elo = highest_elo_model['coding_arena_elo']
418
+
419
+ if highest_elo <= base_elo: # Handle edge case where base has highest ELO
420
+ target_elo = base_elo + (strength - 0.5) * (highest_elo - base_elo) # Will be >= base_elo
238
421
  else:
239
- llm = ChatAnthropic(model=model_name, temperature=temperature, callbacks=[handler])
240
- elif provider == 'google':
241
- llm = ChatGoogleGenerativeAI(model=model_name, temperature=temperature, callbacks=[handler])
242
- elif provider == 'googlevertexai':
243
- llm = ChatVertexAI(model=model_name, temperature=temperature, callbacks=[handler])
244
- elif provider == 'ollama':
245
- llm = OllamaLLM(model=model_name, temperature=temperature, callbacks=[handler])
246
- elif provider == 'azure':
247
- llm = AzureChatOpenAI(model=model_name, temperature=temperature,
248
- callbacks=[handler], openai_api_key=api_key, openai_api_base=base_url)
249
- elif provider == 'fireworks':
250
- llm = Fireworks(model=model_name, temperature=temperature, callbacks=[handler])
251
- elif provider == 'together':
252
- llm = Together(model=model_name, temperature=temperature, callbacks=[handler])
253
- elif provider == 'groq':
254
- llm = ChatGroq(model_name=model_name, temperature=temperature, callbacks=[handler])
422
+ # Interpolate between base and highest
423
+ target_elo = base_elo + ((strength - 0.5) / 0.5) * (highest_elo - base_elo)
424
+
425
+ available_df['sort_metric'] = abs(available_df['coding_arena_elo'] - target_elo)
426
+ candidates = available_df.sort_values(by='sort_metric').to_dict('records')
427
+ target_metric_value = f"Target ELO: {target_elo:.2f}"
428
+
429
+
430
+ if not candidates:
431
+ # This should ideally not happen if available_df was not empty
432
+ raise RuntimeError("Model selection resulted in an empty candidate list.")
433
+
434
+ # --- DEBUGGING PRINT ---
435
+ if os.getenv("PDD_DEBUG_SELECTOR"): # Add env var check for debug prints
436
+ logger.debug("\n--- DEBUG: _select_model_candidates ---")
437
+ logger.debug(f"Strength: {strength}, Base Model: {base_model_name}")
438
+ logger.debug(f"Metric: {target_metric_value}")
439
+ logger.debug("Available DF (Sorted by metric):")
440
+ # Select columns relevant to the sorting metric
441
+ sort_cols = ['model', 'avg_cost', 'coding_arena_elo', 'sort_metric']
442
+ logger.debug(available_df.sort_values(by='sort_metric')[sort_cols])
443
+ logger.debug("Final Candidates List (Model Names):")
444
+ logger.debug([c['model'] for c in candidates])
445
+ logger.debug("---------------------------------------\n")
446
+ # --- END DEBUGGING PRINT ---
447
+
448
+ return candidates
449
+
450
+
451
+ def _ensure_api_key(model_info: Dict[str, Any], newly_acquired_keys: Dict[str, bool], verbose: bool) -> bool:
452
+ """Checks for API key in env, prompts user if missing, and updates .env."""
453
+ key_name = model_info.get('api_key')
454
+
455
+ if not key_name or key_name == "EXISTING_KEY":
456
+ if verbose:
457
+ logger.info(f"Skipping API key check for model {model_info.get('model')} (key name: {key_name})")
458
+ return True # Assume key is handled elsewhere or not needed
459
+
460
+ key_value = os.getenv(key_name)
461
+
462
+ if key_value:
463
+ if verbose:
464
+ logger.info(f"API key '{key_name}' found in environment.")
465
+ newly_acquired_keys[key_name] = False # Mark as existing
466
+ return True
255
467
  else:
256
- raise ValueError(f"Unsupported provider: {selected_model.provider}")
468
+ logger.warning(f"API key environment variable '{key_name}' for model '{model_info.get('model')}' is not set.")
469
+ try:
470
+ # Interactive prompt
471
+ user_provided_key = input(f"Please enter the API key for {key_name}: ").strip()
472
+ if not user_provided_key:
473
+ logger.error("No API key provided. Cannot proceed with this model.")
474
+ return False
475
+
476
+ # Set environment variable for the current process
477
+ os.environ[key_name] = user_provided_key
478
+ logger.info(f"API key '{key_name}' set for the current session.")
479
+ newly_acquired_keys[key_name] = True # Mark as newly acquired
480
+
481
+ # Update .env file
482
+ try:
483
+ lines = []
484
+ if ENV_PATH.exists():
485
+ with open(ENV_PATH, 'r') as f:
486
+ lines = f.readlines()
487
+
488
+ new_lines = []
489
+ # key_updated = False
490
+ prefix = f"{key_name}="
491
+ prefix_spaced = f"{key_name} =" # Handle potential spaces
492
+
493
+ for line in lines:
494
+ stripped_line = line.strip()
495
+ if stripped_line.startswith(prefix) or stripped_line.startswith(prefix_spaced):
496
+ # Comment out the old key
497
+ new_lines.append(f"# {line}")
498
+ # key_updated = True # Indicates we found an old line to comment
499
+ elif stripped_line.startswith(f"# {prefix}") or stripped_line.startswith(f"# {prefix_spaced}"):
500
+ # Keep already commented lines as they are
501
+ new_lines.append(line)
502
+ else:
503
+ new_lines.append(line)
504
+
505
+ # Append the new key, ensuring quotes for robustness
506
+ new_key_line = f'{key_name}="{user_provided_key}"\n'
507
+ # Add newline before if file not empty and doesn't end with newline
508
+ if new_lines and not new_lines[-1].endswith('\n'):
509
+ new_lines.append('\n')
510
+ new_lines.append(new_key_line)
257
511
 
258
- if max_completion_tokens:
259
- llm.model_kwargs = {"max_completion_tokens": max_completion_tokens}
260
- elif max_tokens:
261
- if provider == 'google' or provider == 'googlevertexai':
262
- llm.max_output_tokens = max_tokens
512
+
513
+ with open(ENV_PATH, 'w') as f:
514
+ f.writelines(new_lines)
515
+
516
+ logger.info(f"API key '{key_name}' saved to {ENV_PATH}.")
517
+ logger.warning("SECURITY WARNING: The API key has been saved to your .env file. "
518
+ "Ensure this file is kept secure and is included in your .gitignore.")
519
+
520
+ except IOError as e:
521
+ logger.error(f"Failed to update .env file at {ENV_PATH}: {e}")
522
+ # Continue since the key is set in the environment for this session
523
+
524
+ return True
525
+
526
+ except EOFError: # Handle non-interactive environments
527
+ logger.error(f"Cannot prompt for API key '{key_name}' in a non-interactive environment.")
528
+ return False
529
+ except Exception as e:
530
+ logger.error(f"An unexpected error occurred during API key acquisition: {e}")
531
+ return False
532
+
533
+
534
+ def _format_messages(prompt: str, input_data: Union[Dict[str, Any], List[Dict[str, Any]]], use_batch_mode: bool) -> Union[List[Dict[str, str]], List[List[Dict[str, str]]]]:
535
+ """Formats prompt and input into LiteLLM message format."""
536
+ try:
537
+ prompt_template = PromptTemplate.from_template(prompt)
538
+ if use_batch_mode:
539
+ if not isinstance(input_data, list):
540
+ raise ValueError("input_json must be a list of dictionaries when use_batch_mode is True.")
541
+ all_messages = []
542
+ for item in input_data:
543
+ if not isinstance(item, dict):
544
+ raise ValueError("Each item in input_json list must be a dictionary for batch mode.")
545
+ formatted_prompt = prompt_template.format(**item)
546
+ all_messages.append([{"role": "user", "content": formatted_prompt}])
547
+ return all_messages
263
548
  else:
264
- llm.max_tokens = max_tokens
265
- return llm
549
+ if not isinstance(input_data, dict):
550
+ raise ValueError("input_json must be a dictionary when use_batch_mode is False.")
551
+ formatted_prompt = prompt_template.format(**input_data)
552
+ return [{"role": "user", "content": formatted_prompt}]
553
+ except KeyError as e:
554
+ raise ValueError(f"Prompt formatting error: Missing key {e} in input_json for prompt template.") from e
555
+ except Exception as e:
556
+ raise ValueError(f"Error formatting prompt: {e}") from e
266
557
 
267
- def calculate_cost(handler, selected_model):
268
- """
269
- Calculates the cost of the invoke run based on token usage.
558
+ # --- Main Function ---
559
+
560
+ def llm_invoke(
561
+ prompt: Optional[str] = None,
562
+ input_json: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
563
+ strength: float = 0.5, # Use pdd.DEFAULT_STRENGTH if available, else 0.5
564
+ temperature: float = 0.1,
565
+ verbose: bool = False,
566
+ output_pydantic: Optional[Type[BaseModel]] = None,
567
+ time: float = 0.25,
568
+ use_batch_mode: bool = False,
569
+ messages: Optional[Union[List[Dict[str, str]], List[List[Dict[str, str]]]]] = None,
570
+ ) -> Dict[str, Any]:
270
571
  """
271
- input_tokens = handler.input_tokens or 0
272
- output_tokens = handler.output_tokens or 0
273
- input_cost = selected_model.input_cost
274
- output_cost = selected_model.output_cost
275
- total_cost = (input_tokens / 1_000_000) * input_cost + (output_tokens / 1_000_000) * output_cost
276
- return total_cost
572
+ Runs a prompt with given input using LiteLLM, handling model selection,
573
+ API key acquisition, structured output, batching, and reasoning time.
574
+ The maximum completion token length defaults to the provider's maximum.
277
575
 
278
- # ---------------- Main Function ---------------- #
576
+ Args:
577
+ prompt: Prompt template string (required if messages is None).
578
+ input_json: Dictionary or list of dictionaries for prompt variables (required if messages is None).
579
+ strength: Model selection strength (0=cheapest, 0.5=base, 1=highest ELO).
580
+ temperature: LLM temperature.
581
+ verbose: Print detailed logs.
582
+ output_pydantic: Optional Pydantic model for structured output.
583
+ time: Relative thinking time (0-1, default 0.25).
584
+ use_batch_mode: Use batch completion if True.
585
+ messages: Pre-formatted list of messages (or list of lists for batch). If provided, ignores prompt and input_json.
279
586
 
280
- def llm_invoke(prompt, input_json, strength, temperature, verbose=False, output_pydantic=None):
281
- """
282
- Invokes an LLM chain with the provided prompt and input_json, using a model selected based on the strength parameter.
283
-
284
- Inputs:
285
- prompt (str): The prompt template as a string.
286
- input_json (dict): JSON object containing inputs for the prompt.
287
- strength (float): 0 (cheapest) to 1 (highest ELO); 0.5 uses the base model.
288
- temperature (float): Temperature for the LLM invocation.
289
- verbose (bool): When True, prints detailed information.
290
- output_pydantic (Optional): A Pydantic model class for structured output.
291
-
292
- Output (dict): Contains:
293
- 'result' - LLM output (string or parsed Pydantic object).
294
- 'cost' - Calculated cost of the invoke run.
295
- 'model_name' - Name of the selected model that succeeded.
587
+ Returns:
588
+ Dictionary containing 'result', 'cost', 'model_name', 'thinking_output'.
589
+
590
+ Raises:
591
+ ValueError: For invalid inputs or prompt formatting errors.
592
+ FileNotFoundError: If llm_model.csv is missing.
593
+ RuntimeError: If all candidate models fail.
594
+ openai.*Error: If LiteLLM encounters API errors after retries.
296
595
  """
297
- if prompt is None or not isinstance(prompt, str):
298
- raise ValueError("Prompt is required.")
299
- if input_json is None:
300
- raise ValueError("Input JSON is required.")
301
- if not isinstance(input_json, dict):
302
- raise ValueError("Input JSON must be a dictionary.")
303
-
304
- set_llm_cache(SQLiteCache(database_path=".langchain.db"))
305
- base_model_name = os.environ.get('PDD_MODEL_DEFAULT', 'gpt-4.1-nano')
306
- models = load_models()
596
+ # Set verbose logging if requested
597
+ set_verbose_logging(verbose)
307
598
 
308
- try:
309
- base_model = select_model(models, base_model_name)
310
- except ValueError as e:
311
- raise RuntimeError(f"Base model error: {str(e)}") from e
599
+ if verbose:
600
+ logger.debug("llm_invoke start - Arguments received:")
601
+ logger.debug(f" prompt: {'provided' if prompt else 'None'}")
602
+ logger.debug(f" input_json: {'provided' if input_json is not None else 'None'}")
603
+ logger.debug(f" strength: {strength}")
604
+ logger.debug(f" temperature: {temperature}")
605
+ logger.debug(f" verbose: {verbose}")
606
+ logger.debug(f" output_pydantic: {output_pydantic.__name__ if output_pydantic else 'None'}")
607
+ logger.debug(f" time: {time}")
608
+ logger.debug(f" use_batch_mode: {use_batch_mode}")
609
+ logger.debug(f" messages: {'provided' if messages else 'None'}")
312
610
 
313
- candidate_models = get_candidate_models(strength, models, base_model)
611
+ # --- 1. Load Environment & Validate Inputs ---
612
+ # .env loading happens at module level
314
613
 
315
- if verbose:
316
- rprint(f"[bold cyan]Candidate models (in order):[/bold cyan] {[m.model for m in candidate_models]}")
614
+ if messages:
615
+ if verbose:
616
+ logger.info("Using provided 'messages' input.")
617
+ # Basic validation of messages format
618
+ if use_batch_mode:
619
+ if not isinstance(messages, list) or not all(isinstance(m_list, list) for m_list in messages):
620
+ raise ValueError("'messages' must be a list of lists when use_batch_mode is True.")
621
+ if not all(isinstance(msg, dict) and 'role' in msg and 'content' in msg for m_list in messages for msg in m_list):
622
+ raise ValueError("Each message in the lists within 'messages' must be a dictionary with 'role' and 'content'.")
623
+ else:
624
+ if not isinstance(messages, list) or not all(isinstance(msg, dict) and 'role' in msg and 'content' in msg for msg in messages):
625
+ raise ValueError("'messages' must be a list of dictionaries with 'role' and 'content'.")
626
+ formatted_messages = messages
627
+ elif prompt and input_json is not None:
628
+ if not isinstance(prompt, str) or not prompt:
629
+ raise ValueError("'prompt' must be a non-empty string when 'messages' is not provided.")
630
+ formatted_messages = _format_messages(prompt, input_json, use_batch_mode)
631
+ else:
632
+ raise ValueError("Either 'messages' or both 'prompt' and 'input_json' must be provided.")
633
+
634
+ if not (0.0 <= strength <= 1.0):
635
+ raise ValueError("'strength' must be between 0.0 and 1.0.")
636
+ if not (0.0 <= temperature <= 2.0): # Common range for temperature
637
+ warnings.warn("'temperature' is outside the typical range (0.0-2.0).")
638
+ if not (0.0 <= time <= 1.0):
639
+ raise ValueError("'time' must be between 0.0 and 1.0.")
317
640
 
318
- last_error = None
319
- for model in candidate_models:
320
- handler = CompletionStatusHandler()
641
+ # --- 2. Load Model Data & Select Candidates ---
642
+ try:
643
+ model_df = _load_model_data(LLM_MODEL_CSV_PATH)
644
+ candidate_models = _select_model_candidates(strength, DEFAULT_BASE_MODEL, model_df)
645
+ except (FileNotFoundError, ValueError, RuntimeError) as e:
646
+ logger.error(f"Failed during model loading or selection: {e}")
647
+ raise
648
+
649
+ if verbose:
650
+ # This print statement is crucial for the verbose test
651
+ # Calculate and print strength for each candidate model
652
+ # Find min/max for cost and ELO
653
+ min_cost = model_df['avg_cost'].min()
654
+ max_elo = model_df['coding_arena_elo'].max()
655
+ base_cost = model_df[model_df['model'] == DEFAULT_BASE_MODEL]['avg_cost'].iloc[0] if not model_df[model_df['model'] == DEFAULT_BASE_MODEL].empty else min_cost
656
+ base_elo = model_df[model_df['model'] == DEFAULT_BASE_MODEL]['coding_arena_elo'].iloc[0] if not model_df[model_df['model'] == DEFAULT_BASE_MODEL].empty else max_elo
657
+
658
+ def calc_strength(candidate):
659
+ # If strength < 0.5, interpolate by cost (cheaper = 0, base = 0.5)
660
+ # If strength > 0.5, interpolate by ELO (base = 0.5, highest = 1.0)
661
+ avg_cost = candidate.get('avg_cost', min_cost)
662
+ elo = candidate.get('coding_arena_elo', base_elo)
663
+ if strength < 0.5:
664
+ # Map cost to [0, 0.5]
665
+ if base_cost == min_cost:
666
+ return 0.5 # Avoid div by zero
667
+ rel = (avg_cost - min_cost) / (base_cost - min_cost)
668
+ return max(0.0, min(0.5, rel * 0.5))
669
+ elif strength > 0.5:
670
+ # Map ELO to [0.5, 1.0]
671
+ if max_elo == base_elo:
672
+ return 0.5 # Avoid div by zero
673
+ rel = (elo - base_elo) / (max_elo - base_elo)
674
+ return max(0.5, min(1.0, 0.5 + rel * 0.5))
675
+ else:
676
+ return 0.5
677
+
678
+ model_strengths_formatted = [(c['model'], f"{float(calc_strength(c)):.3f}") for c in candidate_models]
679
+ logger.info("Candidate models selected and ordered (with strength): %s", model_strengths_formatted) # CORRECTED
680
+ logger.info(f"Strength: {strength}, Temperature: {temperature}, Time: {time}")
681
+ if use_batch_mode:
682
+ logger.info("Batch mode enabled.")
683
+ if output_pydantic:
684
+ logger.info(f"Pydantic output requested: {output_pydantic.__name__}")
321
685
  try:
322
- try:
323
- prompt_template = PromptTemplate.from_template(prompt)
324
- except ValueError:
325
- raise ValueError("Invalid prompt template")
686
+ # Only print input_json if it was actually provided (not when messages were used)
687
+ if input_json is not None:
688
+ logger.info("Input JSON:")
689
+ logger.info(input_json)
690
+ else:
691
+ logger.info("Input: Using pre-formatted 'messages'.")
692
+ except Exception:
693
+ logger.info("Input JSON/Messages (fallback print):")
694
+ logger.info(input_json if input_json is not None else "[Messages provided directly]")
695
+
696
+
697
+ # --- 3. Iterate Through Candidates and Invoke LLM ---
698
+ last_exception = None
699
+ newly_acquired_keys: Dict[str, bool] = {} # Track keys obtained in this run
700
+
701
+ for model_info in candidate_models:
702
+ model_name_litellm = model_info['model']
703
+ api_key_name = model_info.get('api_key')
704
+ provider = model_info.get('provider', '').lower()
705
+
706
+ if verbose:
707
+ logger.info(f"\n[ATTEMPT] Trying model: {model_name_litellm} (Provider: {provider})")
708
+
709
+ retry_with_same_model = True
710
+ while retry_with_same_model:
711
+ retry_with_same_model = False # Assume success unless auth error on new key
326
712
 
327
- llm = create_llm_instance(model, temperature, handler)
713
+ # --- 4. API Key Check & Acquisition ---
714
+ if not _ensure_api_key(model_info, newly_acquired_keys, verbose):
715
+ # Problem getting key, break inner loop, try next model candidate
716
+ if verbose:
717
+ logger.info(f"[SKIP] Skipping {model_name_litellm} due to API key/credentials issue after prompt.")
718
+ break # Breaks the 'while retry_with_same_model' loop
719
+
720
+ # --- 5. Prepare LiteLLM Arguments ---
721
+ litellm_kwargs: Dict[str, Any] = {
722
+ "model": model_name_litellm,
723
+ "messages": formatted_messages,
724
+ "temperature": temperature,
725
+ }
726
+
727
+ api_key_name_from_csv = model_info.get('api_key') # From CSV
728
+ # Determine if it's a Vertex AI model for special handling
729
+ is_vertex_model = (provider.lower() == 'google') or \
730
+ (provider.lower() == 'googlevertexai') or \
731
+ (provider.lower() == 'vertex_ai') or \
732
+ model_name_litellm.startswith('vertex_ai/')
733
+
734
+ if is_vertex_model and api_key_name_from_csv == 'VERTEX_CREDENTIALS':
735
+ credentials_file_path = os.getenv("VERTEX_CREDENTIALS") # Path from env var
736
+ vertex_project_env = os.getenv("VERTEX_PROJECT")
737
+ vertex_location_env = os.getenv("VERTEX_LOCATION")
738
+
739
+ if credentials_file_path and vertex_project_env and vertex_location_env:
740
+ try:
741
+ with open(credentials_file_path, 'r') as f:
742
+ loaded_credentials = json.load(f)
743
+ vertex_credentials_json_string = json.dumps(loaded_credentials)
744
+
745
+ litellm_kwargs["vertex_credentials"] = vertex_credentials_json_string
746
+ litellm_kwargs["vertex_project"] = vertex_project_env
747
+ litellm_kwargs["vertex_location"] = vertex_location_env
748
+ if verbose:
749
+ logger.info(f"[INFO] For Vertex AI: using vertex_credentials from '{credentials_file_path}', project '{vertex_project_env}', location '{vertex_location_env}'.")
750
+ except FileNotFoundError:
751
+ if verbose:
752
+ logger.error(f"[ERROR] Vertex credentials file not found at path specified by VERTEX_CREDENTIALS env var: '{credentials_file_path}'. LiteLLM may try ADC or fail.")
753
+ except json.JSONDecodeError:
754
+ if verbose:
755
+ logger.error(f"[ERROR] Failed to decode JSON from Vertex credentials file: '{credentials_file_path}'. Check file content. LiteLLM may try ADC or fail.")
756
+ except Exception as e:
757
+ if verbose:
758
+ logger.error(f"[ERROR] Failed to load or process Vertex credentials from '{credentials_file_path}': {e}. LiteLLM may try ADC or fail.")
759
+ else:
760
+ if verbose:
761
+ logger.warning(f"[WARN] For Vertex AI (using '{api_key_name_from_csv}'): One or more required environment variables (VERTEX_CREDENTIALS, VERTEX_PROJECT, VERTEX_LOCATION) are missing.")
762
+ if not credentials_file_path: logger.warning(f" Reason: VERTEX_CREDENTIALS (path to JSON file) env var not set or empty.")
763
+ if not vertex_project_env: logger.warning(f" Reason: VERTEX_PROJECT env var not set or empty.")
764
+ if not vertex_location_env: logger.warning(f" Reason: VERTEX_LOCATION env var not set or empty.")
765
+ logger.warning(f" LiteLLM may attempt to use Application Default Credentials or the call may fail.")
766
+
767
+ elif api_key_name_from_csv: # For other api_key_names specified in CSV (e.g., OPENAI_API_KEY, or a direct VERTEX_AI_API_KEY string)
768
+ key_value = os.getenv(api_key_name_from_csv)
769
+ if key_value:
770
+ litellm_kwargs["api_key"] = key_value
771
+ if verbose:
772
+ logger.info(f"[INFO] Explicitly passing API key from env var '{api_key_name_from_csv}' as 'api_key' parameter to LiteLLM.")
773
+
774
+ # If this model is Vertex AI AND uses a direct API key string (not VERTEX_CREDENTIALS from CSV),
775
+ # also pass project and location from env vars.
776
+ if is_vertex_model:
777
+ vertex_project_env = os.getenv("VERTEX_PROJECT")
778
+ vertex_location_env = os.getenv("VERTEX_LOCATION")
779
+ if vertex_project_env and vertex_location_env:
780
+ litellm_kwargs["vertex_project"] = vertex_project_env
781
+ litellm_kwargs["vertex_location"] = vertex_location_env
782
+ if verbose:
783
+ logger.info(f"[INFO] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), also passing vertex_project='{vertex_project_env}' and vertex_location='{vertex_location_env}' from env vars.")
784
+ elif verbose:
785
+ logger.warning(f"[WARN] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), VERTEX_PROJECT or VERTEX_LOCATION env vars not set. This might be required by LiteLLM.")
786
+ elif verbose: # api_key_name_from_csv was in CSV, but corresponding env var was not set/empty
787
+ logger.warning(f"[WARN] API key name '{api_key_name_from_csv}' found in CSV, but the environment variable '{api_key_name_from_csv}' is not set or empty. LiteLLM will use default authentication if applicable (e.g., other standard env vars or ADC).")
788
+
789
+ elif verbose: # No api_key_name_from_csv in CSV for this model
790
+ logger.info(f"[INFO] No API key name specified in CSV for model '{model_name_litellm}'. LiteLLM will use its default authentication mechanisms (e.g., standard provider env vars or ADC for Vertex AI).")
791
+
792
+ # Add api_base if present in CSV
793
+ api_base = model_info.get('base_url')
794
+ if pd.notna(api_base) and api_base:
795
+ litellm_kwargs["api_base"] = str(api_base)
796
+
797
+ # Handle Structured Output (JSON Mode / Pydantic)
328
798
  if output_pydantic:
329
- if model.structured_output:
330
- llm = llm.with_structured_output(output_pydantic)
331
- chain = prompt_template | llm
799
+ # Check if model supports structured output based on CSV flag or LiteLLM check
800
+ supports_structured = model_info.get('structured_output', False)
801
+ # Optional: Add litellm.supports_response_schema check if CSV flag is unreliable
802
+ # if not supports_structured:
803
+ # try: supports_structured = litellm.supports_response_schema(model=model_name_litellm)
804
+ # except: pass # Ignore errors in supports_response_schema check
805
+
806
+ if supports_structured:
807
+ if verbose:
808
+ logger.info(f"[INFO] Requesting structured output (Pydantic: {output_pydantic.__name__}) for {model_name_litellm}")
809
+ # Pass the Pydantic model directly if supported, else use json_object
810
+ # LiteLLM handles passing Pydantic models for supported providers
811
+ litellm_kwargs["response_format"] = output_pydantic
812
+ # As a fallback, one could use:
813
+ # litellm_kwargs["response_format"] = {"type": "json_object"}
814
+ # And potentially enable client-side validation:
815
+ # litellm.enable_json_schema_validation = True # Enable globally if needed
332
816
  else:
333
- parser = PydanticOutputParser(pydantic_object=output_pydantic)
334
- chain = prompt_template | llm | parser
335
- else:
336
- chain = prompt_template | llm | StrOutputParser()
337
-
338
- result_output = chain.invoke(input_json)
339
- cost = calculate_cost(handler, model)
340
-
341
- if verbose:
342
- rprint(f"[bold green]Selected model: {model.model}[/bold green]")
343
- rprint(f"Per input token cost: ${model.input_cost} per million tokens")
344
- rprint(f"Per output token cost: ${model.output_cost} per million tokens")
345
- rprint(f"Number of input tokens: {handler.input_tokens}")
346
- rprint(f"Number of output tokens: {handler.output_tokens}")
347
- rprint(f"Cost of invoke run: ${cost:.0e}")
348
- rprint(f"Strength used: {strength}")
349
- rprint(f"Temperature used: {temperature}")
350
- try:
351
- # Try printing with rich formatting first
352
- rprint(f"Input JSON: {str(input_json)}")
353
- except MarkupError:
354
- # Fallback to standard print if rich markup fails
355
- print(f"Input JSON: {str(input_json)}")
356
- except Exception:
357
- print(f"Input JSON: {input_json}")
358
- if output_pydantic:
359
- rprint(f"Output Pydantic format: {output_pydantic}")
360
- try:
361
- # Try printing with rich formatting first
362
- rprint(f"Result: {result_output}")
363
- except MarkupError as me:
364
- # Fallback to standard print if rich markup fails
365
- print(f"[bold yellow]Warning:[/bold yellow] Failed to render result with rich markup: {me}")
366
- print(f"Raw Result: {str(result_output)}") # Use standard print
367
-
368
- return {'result': result_output, 'cost': cost, 'model_name': model.model}
817
+ if verbose:
818
+ logger.warning(f"[WARN] Model {model_name_litellm} does not support structured output via CSV flag. Output might not be valid {output_pydantic.__name__}.")
819
+ # Proceed without forcing JSON mode, parsing will be attempted later
369
820
 
370
- except Exception as e:
371
- last_error = e
372
- if verbose:
373
- rprint(f"[red]Error with model {model.model}: {str(e)}[/red]")
374
- continue
821
+ # --- NEW REASONING LOGIC ---
822
+ reasoning_type = model_info.get('reasoning_type', 'none') # Defaults to 'none'
823
+ max_reasoning_tokens_val = model_info.get('max_reasoning_tokens', 0) # Defaults to 0
824
+
825
+ if time > 0: # Only apply reasoning if time is requested
826
+ if reasoning_type == 'budget':
827
+ if max_reasoning_tokens_val > 0:
828
+ budget = int(time * max_reasoning_tokens_val)
829
+ if budget > 0:
830
+ # Currently known: Anthropic uses 'thinking'
831
+ # Checking the 'provider' column is more robust than matching model-name prefixes
832
+ if provider == 'anthropic': # Check provider column instead of model prefix
833
+ litellm_kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
834
+ if verbose:
835
+ logger.info(f"[INFO] Requesting Anthropic thinking (budget type) with budget: {budget} tokens for {model_name_litellm}")
836
+ else:
837
+ # If other providers adopt a budget param recognized by LiteLLM, add here
838
+ if verbose:
839
+ logger.warning(f"[WARN] Reasoning type is 'budget' for {model_name_litellm}, but no specific LiteLLM budget parameter known for this provider. Parameter not sent.")
840
+ elif verbose:
841
+ logger.info(f"[INFO] Calculated reasoning budget is 0 for {model_name_litellm}, skipping reasoning parameter.")
842
+ elif verbose:
843
+ logger.warning(f"[WARN] Reasoning type is 'budget' for {model_name_litellm}, but 'max_reasoning_tokens' is missing or zero in CSV. Reasoning parameter not sent.")
844
+
845
+ elif reasoning_type == 'effort':
846
+ effort = "low"
847
+ if time > 0.7:
848
+ effort = "high"
849
+ elif time > 0.3:
850
+ effort = "medium"
851
+ # Use the common 'reasoning_effort' param LiteLLM provides
852
+ litellm_kwargs["reasoning_effort"] = effort
853
+ if verbose:
854
+ logger.info(f"[INFO] Requesting reasoning_effort='{effort}' (effort type) for {model_name_litellm} based on time={time}")
855
+
856
+ elif reasoning_type == 'none':
857
+ if verbose:
858
+ logger.info(f"[INFO] Model {model_name_litellm} has reasoning_type='none'. No reasoning parameter sent.")
859
+
860
+ else: # Unknown reasoning_type in CSV
861
+ if verbose:
862
+ logger.warning(f"[WARN] Unknown reasoning_type '{reasoning_type}' for model {model_name_litellm} in CSV. No reasoning parameter sent.")
863
+
864
+ # --- END NEW REASONING LOGIC ---
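The reasoning branch above can be condensed into a small helper for illustration. The helper name is hypothetical; the kwargs it returns ('thinking' for budget-style Anthropic reasoning, 'reasoning_effort' for effort-style models) are exactly the ones set above and would be merged into litellm_kwargs:

    def build_reasoning_kwargs(time: float, reasoning_type: str,
                               max_reasoning_tokens: int, provider: str) -> dict:
        # Mirrors the logic above: no reasoning parameters unless time > 0.
        kwargs = {}
        if time <= 0:
            return kwargs
        if reasoning_type == 'budget' and provider == 'anthropic' and max_reasoning_tokens > 0:
            budget = int(time * max_reasoning_tokens)
            if budget > 0:
                kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
        elif reasoning_type == 'effort':
            kwargs["reasoning_effort"] = "high" if time > 0.7 else ("medium" if time > 0.3 else "low")
        return kwargs

    # e.g. build_reasoning_kwargs(0.5, 'budget', 8192, 'anthropic')
    # -> {'thinking': {'type': 'enabled', 'budget_tokens': 4096}}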
865
+
866
+ # Add caching control per call if needed (example: force refresh)
867
+ # litellm_kwargs["cache"] = {"no-cache": True}
868
+
869
+ # --- 6. LLM Invocation ---
870
+ try:
871
+ start_time = time_module.time()
375
872
 
376
- if isinstance(last_error, ValueError) and "Invalid prompt template" in str(last_error):
377
- raise ValueError("Invalid prompt template")
378
- if last_error:
379
- raise RuntimeError(f"Error during LLM invocation: {str(last_error)}")
380
- raise RuntimeError("No available models could process the request")
873
+ # Log cache status with proper logging
874
+ logger.debug(f"Cache Check: litellm.cache is None: {litellm.cache is None}")
875
+ if litellm.cache is not None:
876
+ logger.debug(f"litellm.cache type: {type(litellm.cache)}, ID: {id(litellm.cache)}")
381
877
 
878
+ # Only add if litellm.cache is configured
879
+ if litellm.cache is not None:
880
+ litellm_kwargs["caching"] = True
881
+ logger.debug("Caching enabled for this request")
882
+ else:
883
+ logger.debug("NOT ENABLING CACHING: litellm.cache is None at call time")
884
+
885
+
886
+ if use_batch_mode:
887
+ if verbose:
888
+ logger.info(f"[INFO] Calling litellm.batch_completion for {model_name_litellm}...")
889
+ response = litellm.batch_completion(**litellm_kwargs)
890
+
891
+
892
+ else:
893
+ if verbose:
894
+ logger.info(f"[INFO] Calling litellm.completion for {model_name_litellm}...")
895
+ response = litellm.completion(**litellm_kwargs)
896
+
897
+ end_time = time_module.time()
898
+
899
+ if verbose:
900
+ logger.info(f"[SUCCESS] Invocation successful for {model_name_litellm} (took {end_time - start_time:.2f}s)")
901
+
902
+ # --- 7. Process Response ---
903
+ results = []
904
+ thinking_outputs = []
905
+
906
+ response_list = response if use_batch_mode else [response]
907
+
908
+ for i, resp_item in enumerate(response_list):
909
+ # Cost calculation is handled entirely by the success callback
910
+
911
+ # Thinking Output
912
+ thinking = None
913
+ try:
914
+ # Attempt 1: Check _hidden_params based on isolated test script
915
+ if hasattr(resp_item, '_hidden_params') and resp_item._hidden_params and 'thinking' in resp_item._hidden_params:
916
+ thinking = resp_item._hidden_params['thinking']
917
+ if verbose:
918
+ logger.debug("[DEBUG] Extracted thinking output from response._hidden_params['thinking']")
919
+ # Attempt 2: Fallback to reasoning_content in message
920
+ # Use .get() for safer access
921
+ elif hasattr(resp_item, 'choices') and resp_item.choices and hasattr(resp_item.choices[0], 'message') and hasattr(resp_item.choices[0].message, 'get') and resp_item.choices[0].message.get('reasoning_content'):
922
+ thinking = resp_item.choices[0].message.get('reasoning_content')
923
+ if verbose:
924
+ logger.debug("[DEBUG] Extracted thinking output from response.choices[0].message.get('reasoning_content')")
925
+
926
+ except (AttributeError, IndexError, KeyError, TypeError):
927
+ if verbose:
928
+ logger.debug("[DEBUG] Failed to extract thinking output from known locations.")
929
+ pass # Ignore if structure doesn't match or errors occur
930
+ thinking_outputs.append(thinking)
931
+
932
+ # Result (String or Pydantic)
933
+ try:
934
+ raw_result = resp_item.choices[0].message.content
935
+ if output_pydantic:
936
+ parsed_result = None
937
+ json_string_to_parse = None
938
+
939
+ try:
940
+ # Attempt 1: Check if LiteLLM already parsed it
941
+ if isinstance(raw_result, output_pydantic):
942
+ parsed_result = raw_result
943
+ if verbose:
944
+ logger.debug("[DEBUG] Pydantic object received directly from LiteLLM.")
945
+
946
+ # Attempt 2: Check if raw_result is dict-like and validate
947
+ elif isinstance(raw_result, dict):
948
+ parsed_result = output_pydantic.model_validate(raw_result)
949
+ if verbose:
950
+ logger.debug("[DEBUG] Validated dictionary-like object directly.")
951
+
952
+ # Attempt 3: Process as string (if not already parsed/validated)
953
+ elif isinstance(raw_result, str):
954
+ json_string_to_parse = raw_result # Start with the raw string
955
+ try:
956
+ # Look for first { and last }
957
+ start_brace = json_string_to_parse.find('{')
958
+ end_brace = json_string_to_parse.rfind('}')
959
+ if start_brace != -1 and end_brace != -1 and end_brace > start_brace:
960
+ potential_json = json_string_to_parse[start_brace:end_brace+1]
961
+ # Basic check if it looks like JSON
962
+ if potential_json.strip().startswith('{') and potential_json.strip().endswith('}'):
963
+ if verbose:
964
+ logger.debug(f"[DEBUG] Attempting to parse extracted JSON block: '{potential_json}'")
965
+ parsed_result = output_pydantic.model_validate_json(potential_json)
966
+ else:
967
+ # If block extraction fails, try cleaning markdown next
968
+ raise ValueError("Extracted block doesn't look like JSON")
969
+ else:
970
+ # If no braces found, try cleaning markdown next
971
+ raise ValueError("Could not find enclosing {}")
972
+ except (json.JSONDecodeError, ValidationError, ValueError) as extraction_error:
973
+ if verbose:
974
+ logger.debug(f"[DEBUG] JSON block extraction/validation failed ('{extraction_error}'). Trying markdown cleaning.")
975
+ # Fallback: Clean markdown fences and retry JSON validation
976
+ cleaned_result_str = raw_result.strip()
977
+ if cleaned_result_str.startswith("```json"):
978
+ cleaned_result_str = cleaned_result_str[7:]
979
+ elif cleaned_result_str.startswith("```"):
980
+ cleaned_result_str = cleaned_result_str[3:]
981
+ if cleaned_result_str.endswith("```"):
982
+ cleaned_result_str = cleaned_result_str[:-3]
983
+ cleaned_result_str = cleaned_result_str.strip()
984
+ # Check again if it looks like JSON before parsing
985
+ if cleaned_result_str.startswith('{') and cleaned_result_str.endswith('}'):
986
+ if verbose:
987
+ logger.debug(f"[DEBUG] Attempting parse after cleaning markdown fences. Cleaned string: '{cleaned_result_str}'")
988
+ json_string_to_parse = cleaned_result_str # Update string for error reporting
989
+ parsed_result = output_pydantic.model_validate_json(json_string_to_parse)
990
+ else:
991
+ # If still doesn't look like JSON, raise error
992
+ raise ValueError("Content after cleaning markdown doesn't look like JSON")
993
+
994
+
995
+ # Check if any parsing attempt succeeded
996
+ if parsed_result is None:
997
+ # This case should ideally be caught by exceptions above, but as a safeguard:
998
+ raise TypeError(f"Raw result type {type(raw_result)} or content could not be validated/parsed against {output_pydantic.__name__}.")
999
+
1000
+ except (ValidationError, json.JSONDecodeError, TypeError, ValueError) as parse_error:
1001
+ logger.error(f"[ERROR] Failed to parse response into Pydantic model {output_pydantic.__name__} for item {i}: {parse_error}")
1002
+ # Use the string that was last attempted for parsing in the error message
1003
+ error_content = json_string_to_parse if json_string_to_parse is not None else raw_result
1004
+ logger.error("[ERROR] Content attempted for parsing: %s", repr(error_content)) # CORRECTED (or use f-string)
1005
+ results.append(f"ERROR: Failed to parse Pydantic. Raw: {repr(raw_result)}")
1006
+ continue # Skip appending result below if parsing failed
1007
+
1008
+ # If parsing succeeded, append the parsed_result
1009
+ results.append(parsed_result)
1010
+
1011
+ else:
1012
+ # If output_pydantic was not requested, append the raw result
1013
+ results.append(raw_result)
1014
+
1015
+ except (AttributeError, IndexError) as e:
1016
+ logger.error(f"[ERROR] Could not extract result content from response item {i}: {e}")
1017
+ results.append(f"ERROR: Could not extract result content. Response: {resp_item}")
1018
+
1019
+ # --- Retrieve Cost from Callback Data --- (Reinstated)
1020
+ # For batch, this will reflect the cost associated with the *last* item processed by the callback.
1021
+ # A fully accurate batch total would require a more complex callback class to aggregate.
1022
+ total_cost = _LAST_CALLBACK_DATA.get("cost", 0.0)
1023
+ # ----------------------------------------
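_LAST_CALLBACK_DATA is populated by a LiteLLM success callback registered earlier in this module. The exact callback body is not shown here; the sketch below is an assumption of how such a callback typically captures cost and token counts, using LiteLLM's (kwargs, completion_response, start_time, end_time) callback signature and the response_cost value it supplies:

    import litellm

    _LAST_CALLBACK_DATA = {}

    def _track_cost_callback(kwargs, completion_response, start_time, end_time):
        # LiteLLM populates kwargs["response_cost"] when it can price the model.
        usage = getattr(completion_response, "usage", None)
        _LAST_CALLBACK_DATA["cost"] = kwargs.get("response_cost") or 0.0
        _LAST_CALLBACK_DATA["input_tokens"] = getattr(usage, "prompt_tokens", 0) if usage else 0
        _LAST_CALLBACK_DATA["output_tokens"] = getattr(usage, "completion_tokens", 0) if usage else 0

    litellm.success_callback = [_track_cost_callback]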
1024
+
1025
+ final_result = results if use_batch_mode else results[0]
1026
+ final_thinking = thinking_outputs if use_batch_mode else thinking_outputs[0]
1027
+
1028
+ # --- Verbose Output for Success ---
1029
+ if verbose:
1030
+ # Get token usage from the *last* callback data (might not be accurate for batch)
1031
+ input_tokens = _LAST_CALLBACK_DATA.get("input_tokens", 0)
1032
+ output_tokens = _LAST_CALLBACK_DATA.get("output_tokens", 0)
1033
+
1034
+ cost_input_pm = model_info.get('input', 0.0) if pd.notna(model_info.get('input')) else 0.0
1035
+ cost_output_pm = model_info.get('output', 0.0) if pd.notna(model_info.get('output')) else 0.0
1036
+
1037
+ logger.info(f"[RESULT] Model Used: {model_name_litellm}")
1038
+ logger.info(f"[RESULT] Cost (Input): ${cost_input_pm:.2f}/M tokens")
1039
+ logger.info(f"[RESULT] Cost (Output): ${cost_output_pm:.2f}/M tokens")
1040
+ logger.info(f"[RESULT] Tokens (Prompt): {input_tokens}")
1041
+ logger.info(f"[RESULT] Tokens (Completion): {output_tokens}")
1042
+ # Display the cost captured by the callback
1043
+ logger.info(f"[RESULT] Total Cost (from callback): ${total_cost:.6g}") # Renamed label for clarity
1044
+ logger.info("[RESULT] Max Completion Tokens: Provider Default") # Indicate default limit
1045
+ if final_thinking:
1046
+ logger.info("[RESULT] Thinking Output:")
1047
+ logger.info(final_thinking)
1048
+
1049
+ # --- Print raw output before returning if verbose ---
1050
+ if verbose:
1051
+ logger.debug("[DEBUG] Raw output before return:")
1052
+ logger.debug(f" Raw Result (repr): {repr(final_result)}")
1053
+ logger.debug(f" Raw Thinking (repr): {repr(final_thinking)}")
1054
+ logger.debug("-" * 20) # Separator
1055
+
1056
+ # --- Return Success ---
1057
+ return {
1058
+ 'result': final_result,
1059
+ 'cost': total_cost,
1060
+ 'model_name': model_name_litellm, # Actual model used
1061
+ 'thinking_output': final_thinking if final_thinking else None
1062
+ }
1063
+
1064
+ # --- 6b. Handle Invocation Errors ---
1065
+ except openai.AuthenticationError as e:
1066
+ last_exception = e
1067
+ if newly_acquired_keys.get(api_key_name):
1068
+ logger.warning(f"[AUTH ERROR] Authentication failed for {model_name_litellm} with the newly provided key for '{api_key_name}'. Please check the key and try again.")
1069
+ # Invalidate the key in env for this session to force re-prompt on retry
1070
+ if api_key_name in os.environ:
1071
+ del os.environ[api_key_name]
1072
+ # Clear the 'newly acquired' status for this key so the next attempt doesn't trigger immediate retry loop
1073
+ newly_acquired_keys[api_key_name] = False
1074
+ retry_with_same_model = True # Set flag to retry the same model after re-prompt
1075
+ # Go back to the start of the 'while retry_with_same_model' loop
1076
+ else:
1077
+ logger.warning(f"[AUTH ERROR] Authentication failed for {model_name_litellm} using existing key '{api_key_name}'. Trying next model.")
1078
+ break # Break inner loop, try next model candidate
1079
+
1080
+ except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError,
1081
+ openai.APIStatusError, openai.BadRequestError, openai.InternalServerError,
1082
+ Exception) as e: # Catch generic Exception last
1083
+ last_exception = e
1084
+ error_type = type(e).__name__
1085
+ logger.error(f"[ERROR] Invocation failed for {model_name_litellm} ({error_type}): {e}. Trying next model.")
1086
+ # Log more details in verbose mode
1087
+ if verbose:
1088
+ # import traceback # Not needed if using exc_info=True
1089
+ logger.debug(f"Detailed exception traceback for {model_name_litellm}:", exc_info=True)
1090
+ break # Break inner loop, try next model candidate
1091
+
1092
+ # If the inner loop was broken (not by success), continue to the next candidate model
1093
+ continue
1094
+
1095
+ # --- 8. Handle Failure of All Candidates ---
1096
+ error_message = "All candidate models failed."
1097
+ if last_exception:
1098
+ error_message += f" Last error ({type(last_exception).__name__}): {last_exception}"
1099
+ logger.error(f"[FATAL] {error_message}")
1100
+ raise RuntimeError(error_message) from last_exception
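From the caller's side, exhausting every candidate surfaces as a single RuntimeError with the last provider error chained on __cause__. A short usage sketch (the prompt text is illustrative):

    try:
        out = llm_invoke(
            prompt="Summarize the following text: {text}",
            input_json={"text": "LiteLLM routes requests across providers."},
            strength=0.5,
            temperature=0.0,
        )
    except RuntimeError as exc:
        # exc.__cause__ holds the final provider exception, if any.
        logger.error("llm_invoke exhausted all candidate models: %s", exc)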
1101
+
1102
+ # --- Example Usage (Optional) ---
382
1103
  if __name__ == "__main__":
383
- example_prompt = "Tell me a joke about {topic}"
384
- example_input = {"topic": "programming"}
1104
+ # This block allows running the file directly for testing.
1105
+ # Ensure you have a ./data/llm_model.csv file and potentially a .env file.
1106
+
1107
+ # Set PDD_DEBUG_SELECTOR=1 to see model selection details
1108
+ # os.environ["PDD_DEBUG_SELECTOR"] = "1"
1109
+
1110
+ # Example 1: Simple text generation
1111
+ logger.info("\n--- Example 1: Simple Text Generation (Strength 0.5) ---")
1112
+ try:
1113
+ response = llm_invoke(
1114
+ prompt="Tell me a short joke about {topic}.",
1115
+ input_json={"topic": "programmers"},
1116
+ strength=0.5, # Use base model (gpt-4.1-nano)
1117
+ temperature=0.7,
1118
+ verbose=True
1119
+ )
1120
+ logger.info("\nExample 1 Response:")
1121
+ logger.info(response)
1122
+ except Exception as e:
1123
+ logger.error(f"\nExample 1 Failed: {e}", exc_info=True)
1124
+
1125
+ # Example 1b: Simple text generation (Strength 0.3)
1126
+ logger.info("\n--- Example 1b: Simple Text Generation (Strength 0.3) ---")
385
1127
  try:
386
- output = llm_invoke(example_prompt, example_input, strength=0.5, temperature=0.7, verbose=True)
387
- rprint("[bold magenta]Invocation succeeded:[/bold magenta]", output)
388
- except Exception as err:
389
- rprint(f"[bold red]Invocation failed:[/bold red] {err}")
1128
+ response = llm_invoke(
1129
+ prompt="Tell me a short joke about {topic}.",
1130
+ input_json={"topic": "keyboards"},
1131
+ strength=0.3, # Should select gemini-pro based on cost interpolation
1132
+ temperature=0.7,
1133
+ verbose=True
1134
+ )
1135
+ logger.info("\nExample 1b Response:")
1136
+ logger.info(response)
1137
+ except Exception as e:
1138
+ logger.error(f"\nExample 1b Failed: {e}", exc_info=True)
1139
+
1140
+ # Example 2: Structured output (requires a Pydantic model)
1141
+ logger.info("\n--- Example 2: Structured Output (Pydantic, Strength 0.8) ---")
1142
+ class JokeStructure(BaseModel):
1143
+ setup: str
1144
+ punchline: str
1145
+ rating: Optional[int] = None
1146
+
1147
+ try:
1148
+ # Use a model known to support structured output (check your CSV)
1149
+ # Strength 0.8 should select gemini-pro based on ELO interpolation
1150
+ response_structured = llm_invoke(
1151
+ prompt="Create a joke about {topic}. Output ONLY the JSON object with 'setup' and 'punchline'.",
1152
+ input_json={"topic": "data science"},
1153
+ strength=0.8, # Try a higher ELO model (gemini-pro expected)
1154
+ temperature=1,
1155
+ output_pydantic=JokeStructure,
1156
+ verbose=True
1157
+ )
1158
+ logger.info("\nExample 2 Response:")
1159
+ logger.info(response_structured)
1160
+ if isinstance(response_structured.get('result'), JokeStructure):
1161
+ logger.info("\nPydantic object received successfully: %s", response_structured['result'].model_dump())
1162
+ else:
1163
+ logger.info("\nResult was not the expected Pydantic object: %s", response_structured.get('result'))
1164
+
1165
+ except Exception as e:
1166
+ logger.error(f"\nExample 2 Failed: {e}", exc_info=True)
1167
+
1168
+
1169
+ # Example 3: Batch processing
1170
+ logger.info("\n--- Example 3: Batch Processing (Strength 0.3) ---")
1171
+ try:
1172
+ batch_input = [
1173
+ {"animal": "cat", "adjective": "lazy"},
1174
+ {"animal": "dog", "adjective": "energetic"},
1175
+ ]
1176
+ # Strength 0.3 should select gemini-pro
1177
+ response_batch = llm_invoke(
1178
+ prompt="Describe a {adjective} {animal} in one sentence.",
1179
+ input_json=batch_input,
1180
+ strength=0.3, # Cheaper model maybe (gemini-pro expected)
1181
+ temperature=0.5,
1182
+ use_batch_mode=True,
1183
+ verbose=True
1184
+ )
1185
+ logger.info("\nExample 3 Response:")
1186
+ logger.info(response_batch)
1187
+ except Exception as e:
1188
+ logger.error(f"\nExample 3 Failed: {e}", exc_info=True)
1189
+
1190
+ # Example 4: Using 'messages' input
1191
+ logger.info("\n--- Example 4: Using 'messages' input (Strength 0.5) ---")
1192
+ try:
1193
+ custom_messages = [
1194
+ {"role": "system", "content": "You are a helpful assistant."},
1195
+ {"role": "user", "content": "What is the capital of France?"}
1196
+ ]
1197
+ # Strength 0.5 should select gpt-4.1-nano
1198
+ response_messages = llm_invoke(
1199
+ messages=custom_messages,
1200
+ strength=0.5,
1201
+ temperature=0.1,
1202
+ verbose=True
1203
+ )
1204
+ logger.info("\nExample 4 Response:")
1205
+ logger.info(response_messages)
1206
+ except Exception as e:
1207
+ logger.error(f"\nExample 4 Failed: {e}", exc_info=True)
1208
+
1209
+ # Example 5: Requesting thinking time (e.g., for Anthropic)
1210
+ logger.info("\n--- Example 5: Requesting Thinking Time (Strength 1.0, Time 0.5) ---")
1211
+ try:
1212
+ # Ensure your CSV has max_reasoning_tokens for an Anthropic model
1213
+ # Strength 1.0 should select claude-3 (highest ELO)
1214
+ # Time 0.5 with budget type should request thinking
1215
+ response_thinking = llm_invoke(
1216
+ prompt="Explain the theory of relativity simply, taking some time to think.",
1217
+ input_json={},
1218
+ strength=1.0, # Try to get highest ELO model (claude-3)
1219
+ temperature=1,
1220
+ time=0.5, # Request moderate thinking time
1221
+ verbose=True
1222
+ )
1223
+ logger.info("\nExample 5 Response:")
1224
+ logger.info(response_thinking)
1225
+ except Exception as e:
1226
+ logger.error(f"\nExample 5 Failed: {e}", exc_info=True)
1227
+
1228
+ # Example 6: Pydantic Fallback Parsing (Strength 0.3)
1229
+ logger.info("\n--- Example 6: Pydantic Fallback Parsing (Strength 0.3) ---")
1230
+ # This requires mocking litellm.completion to return a JSON string
1231
+ # even when gemini-pro (which supports structured output) is selected.
1232
+ # This is hard to demonstrate cleanly in the __main__ block without mocks.
1233
+ # The unit test test_llm_invoke_output_pydantic_unsupported_parses covers this.
1234
+ logger.info("(Covered by unit tests)")