pdd-cli 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pdd-cli might be problematic.
- pdd/__init__.py +10 -3
- pdd/bug_to_unit_test.py +1 -1
- pdd/cli.py +12 -3
- pdd/cli_1_0_2_0_20250510_000314.py +1054 -0
- pdd/cli_2_0_1_0_20250510_000314.py +1054 -0
- pdd/cli_3_0_1_0_20250510_000314.py +1054 -0
- pdd/cli_4_0_1_0_20250510_000314.py +1054 -0
- pdd/continue_generation.py +3 -1
- pdd/data/llm_model.csv +18 -17
- pdd/fix_main.py +3 -2
- pdd/fix_verification_errors.py +154 -109
- pdd/fix_verification_errors_loop.py +5 -1
- pdd/fix_verification_main.py +21 -1
- pdd/generate_output_paths.py +43 -2
- pdd/llm_invoke.py +1198 -353
- pdd/prompts/bug_to_unit_test_LLM.prompt +11 -11
- pdd/prompts/find_verification_errors_LLM.prompt +31 -18
- pdd/prompts/fix_verification_errors_LLM.prompt +25 -6
- pdd/prompts/trim_results_start_LLM.prompt +1 -1
- pdd/update_model_costs.py +446 -0
- {pdd_cli-0.0.25.dist-info → pdd_cli-0.0.27.dist-info}/METADATA +8 -16
- {pdd_cli-0.0.25.dist-info → pdd_cli-0.0.27.dist-info}/RECORD +26 -21
- {pdd_cli-0.0.25.dist-info → pdd_cli-0.0.27.dist-info}/WHEEL +1 -1
- {pdd_cli-0.0.25.dist-info → pdd_cli-0.0.27.dist-info}/entry_points.txt +0 -0
- {pdd_cli-0.0.25.dist-info → pdd_cli-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {pdd_cli-0.0.25.dist-info → pdd_cli-0.0.27.dist-info}/top_level.txt +0 -0
pdd/llm_invoke.py
CHANGED
@@ -1,389 +1,1234 @@
-
-
-llm_invoke.py
-
-This module provides a single function, llm_invoke, that runs a prompt with a given input
-against a language model (LLM) using Langchain and returns the output, cost, and model name.
-The function supports model selection based on cost/ELO interpolation controlled by the
-"strength" parameter. It also implements a retry mechanism: if a model invocation fails,
-it falls back to the next candidate (cheaper for strength < 0.5, or higher ELO for strength ≥ 0.5).
-
-Usage:
-from llm_invoke import llm_invoke
-result = llm_invoke(prompt, input_json, strength, temperature, verbose=True, output_pydantic=MyPydanticClass)
-# result is a dict with keys: 'result', 'cost', 'model_name'
-
-Environment:
-- PDD_MODEL_DEFAULT: if set, used as the base model name. Otherwise defaults to "gpt-4.1-nano".
-- PDD_PATH: if set, models are loaded from $PDD_PATH/data/llm_model.csv; otherwise from ./data/llm_model.csv.
-- Models that require an API key will check the corresponding environment variable (name provided in the CSV).
-"""
+# Corrected code_under_test (llm_invoke.py)
+# Added optional debugging prints in _select_model_candidates

 import os
-import
+import pandas as pd
+import litellm
+import logging # ADDED FOR DETAILED LOGGING
+
+# --- Configure Standard Python Logging ---
+logger = logging.getLogger("pdd.llm_invoke")
+
+# Environment variable to control log level
+PDD_LOG_LEVEL = os.getenv("PDD_LOG_LEVEL", "INFO")
+PRODUCTION_MODE = os.getenv("PDD_ENVIRONMENT") == "production"
+
+# Set default level based on environment
+if PRODUCTION_MODE:
+    logger.setLevel(logging.WARNING)
+else:
+    logger.setLevel(getattr(logging, PDD_LOG_LEVEL, logging.INFO))
+
+# Configure LiteLLM logger separately
+litellm_logger = logging.getLogger("litellm")
+litellm_log_level = os.getenv("LITELLM_LOG_LEVEL", "WARNING" if PRODUCTION_MODE else "INFO")
+litellm_logger.setLevel(getattr(logging, litellm_log_level, logging.WARNING))
+
+# Add a console handler if none exists
+if not logger.handlers:
+    console_handler = logging.StreamHandler()
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    console_handler.setFormatter(formatter)
+    logger.addHandler(console_handler)
+
+    # Only add handler to litellm logger if it doesn't have any
+    if not litellm_logger.handlers:
+        litellm_logger.addHandler(console_handler)
+
+# Function to set up file logging if needed
+def setup_file_logging(log_file_path=None):
+    """Configure rotating file handler for logging"""
+    if not log_file_path:
+        return
+
+    try:
+        from logging.handlers import RotatingFileHandler
+        file_handler = RotatingFileHandler(
+            log_file_path, maxBytes=10*1024*1024, backupCount=5
+        )
+        file_handler.setFormatter(logging.Formatter(
+            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        ))
+        logger.addHandler(file_handler)
+        litellm_logger.addHandler(file_handler)
+        logger.info(f"File logging configured to: {log_file_path}")
+    except Exception as e:
+        logger.warning(f"Failed to set up file logging: {e}")
+
+# Function to set verbose logging
+def set_verbose_logging(verbose=False):
+    """Set verbose logging based on flag or environment variable"""
+    if verbose or os.getenv("PDD_VERBOSE_LOGGING") == "1":
+        logger.setLevel(logging.DEBUG)
+        litellm_logger.setLevel(logging.DEBUG)
+        logger.debug("Verbose logging enabled")
+
+# --- End Logging Configuration ---
+
 import json
+# from rich import print as rprint # Replaced with logger
+from dotenv import load_dotenv
+from pathlib import Path
+from typing import Optional, Dict, List, Any, Type, Union
+from pydantic import BaseModel, ValidationError
+import openai # Import openai for exception handling as LiteLLM maps to its types
+from langchain_core.prompts import PromptTemplate
+import warnings
+import time as time_module # Alias to avoid conflict with 'time' parameter
+# Import the default model constant
+from pdd import DEFAULT_LLM_MODEL

-
-
-from rich.errors import MarkupError
-
-# Langchain core and community imports
-from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
-from langchain_community.cache import SQLiteCache
-from langchain.globals import set_llm_cache
-from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
-from langchain_core.runnables import RunnablePassthrough, ConfigurableField
-
-# LLM provider imports
-from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI
-from langchain_fireworks import Fireworks
-from langchain_anthropic import ChatAnthropic
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_google_vertexai import ChatVertexAI
-from langchain_groq import ChatGroq
-from langchain_together import Together
-from langchain_ollama.llms import OllamaLLM
-
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult
-
-# ---------------- Internal Helper Classes and Functions ---------------- #
-
-class CompletionStatusHandler(BaseCallbackHandler):
-    """
-    Callback handler to capture LLM token usage and completion metadata.
-    """
-    def __init__(self):
-        self.is_complete = False
-        self.finish_reason = None
-        self.input_tokens = None
-        self.output_tokens = None
-
-    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
-        self.is_complete = True
-        if response.generations and response.generations[0]:
-            generation = response.generations[0][0]
-            # Safely get generation_info; if it's None, default to {}
-            generation_info = generation.generation_info or {}
-            self.finish_reason = (generation_info.get('finish_reason') or "").lower()
-
-            # Attempt to get token usage from generation.message if available.
-            if (
-                hasattr(generation, "message")
-                and generation.message is not None
-                and hasattr(generation.message, "usage_metadata")
-                and generation.message.usage_metadata
-            ):
-                usage_metadata = generation.message.usage_metadata
-            else:
-                usage_metadata = generation_info.get("usage_metadata", {})
-
-            self.input_tokens = usage_metadata.get('input_tokens', 0)
-            self.output_tokens = usage_metadata.get('output_tokens', 0)
+# Opt-in to future pandas behavior regarding downcasting
+pd.set_option('future.no_silent_downcasting', True)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# <<< SET LITELLM DEBUG LOGGING >>>
+# os.environ['LITELLM_LOG'] = 'DEBUG' # Keep commented out unless debugging LiteLLM itself
+
+# --- Constants and Configuration ---
+
+# Determine project root: 1. PDD_PATH env var, 2. Search upwards from script, 3. CWD
+PROJECT_ROOT = None
+PDD_PATH_ENV = os.getenv("PDD_PATH")
+
+if PDD_PATH_ENV:
+    _path_from_env = Path(PDD_PATH_ENV)
+    if _path_from_env.is_dir():
+        PROJECT_ROOT = _path_from_env.resolve()
+        logger.debug(f"Using PROJECT_ROOT from PDD_PATH: {PROJECT_ROOT}")
+    else:
+        warnings.warn(f"PDD_PATH environment variable ('{PDD_PATH_ENV}') is set but not a valid directory. Attempting auto-detection.")
+
+if PROJECT_ROOT is None: # If PDD_PATH wasn't set or was invalid
+    try:
+        # Start from the directory containing this script
+        current_dir = Path(__file__).resolve().parent
+        # Look for project markers (e.g., .git, pyproject.toml, data/, .env)
+        # Go up a maximum of 5 levels to prevent infinite loops
+        for _ in range(5):
+            has_git = (current_dir / ".git").exists()
+            has_pyproject = (current_dir / "pyproject.toml").exists()
+            has_data = (current_dir / "data").is_dir()
+            has_dotenv = (current_dir / ".env").exists()
+
+            if has_git or has_pyproject or has_data or has_dotenv:
+                PROJECT_ROOT = current_dir
+                logger.debug(f"Determined PROJECT_ROOT by marker search: {PROJECT_ROOT}")
+                break
+
+            parent_dir = current_dir.parent
+            if parent_dir == current_dir: # Reached filesystem root
+                break
+            current_dir = parent_dir
+
+    except NameError: # __file__ might not be defined (e.g., interactive session)
+        warnings.warn("__file__ not defined. Cannot automatically detect project root from script location.")
+    except Exception as e: # Catch potential permission errors etc.
+        warnings.warn(f"Error during project root auto-detection: {e}")
+
+if PROJECT_ROOT is None: # Fallback to CWD if no method succeeded
+    PROJECT_ROOT = Path.cwd().resolve()
+    warnings.warn(f"Could not determine project root automatically. Using current working directory: {PROJECT_ROOT}. Ensure this is the intended root or set the PDD_PATH environment variable.")
+
+
+ENV_PATH = PROJECT_ROOT / ".env"
+# --- Determine LLM_MODEL_CSV_PATH ---
+# Prioritize ~/.pdd/llm_model.csv
+user_pdd_dir = Path.home() / ".pdd"
+user_model_csv_path = user_pdd_dir / "llm_model.csv"
+
+if user_model_csv_path.is_file():
+    LLM_MODEL_CSV_PATH = user_model_csv_path
+    logger.info(f"Using user-specific LLM model CSV: {LLM_MODEL_CSV_PATH}")
+else:
+    LLM_MODEL_CSV_PATH = PROJECT_ROOT / "data" / "llm_model.csv"
+    logger.info(f"Using project LLM model CSV: {LLM_MODEL_CSV_PATH}")
+# ---------------------------------
+
+# Load environment variables from .env file
+logger.debug(f"Attempting to load .env from: {ENV_PATH}")
+if ENV_PATH.exists():
+    load_dotenv(dotenv_path=ENV_PATH)
+    logger.debug(f"Loaded .env file from: {ENV_PATH}")
+else:
+    # Silently proceed if .env is optional
+    logger.debug(f".env file not found at {ENV_PATH}. API keys might need to be provided manually.")
+
+# Default model if PDD_MODEL_DEFAULT is not set
+# Use the imported constant as the default
+DEFAULT_BASE_MODEL = os.getenv("PDD_MODEL_DEFAULT", DEFAULT_LLM_MODEL)
+
+# --- LiteLLM Cache Configuration (S3 compatible for GCS, with SQLite fallback) ---
+GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
+GCS_ENDPOINT_URL = "https://storage.googleapis.com" # GCS S3 compatibility endpoint
+GCS_REGION_NAME = os.getenv("GCS_REGION_NAME", "auto") # Often 'auto' works for GCS
+GCS_HMAC_ACCESS_KEY_ID = os.getenv("GCS_HMAC_ACCESS_KEY_ID") # Load HMAC Key ID
+GCS_HMAC_SECRET_ACCESS_KEY = os.getenv("GCS_HMAC_SECRET_ACCESS_KEY") # Load HMAC Secret
+
+cache_configured = False
+
+if GCS_BUCKET_NAME and GCS_HMAC_ACCESS_KEY_ID and GCS_HMAC_SECRET_ACCESS_KEY:
+    # Store original AWS credentials before overwriting for GCS cache setup
+    original_aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
+    original_aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
+    original_aws_region_name = os.environ.get('AWS_REGION_NAME')
+
+    try:
+        # Temporarily set AWS env vars to GCS HMAC keys for S3 compatible cache
+        os.environ['AWS_ACCESS_KEY_ID'] = GCS_HMAC_ACCESS_KEY_ID
+        os.environ['AWS_SECRET_ACCESS_KEY'] = GCS_HMAC_SECRET_ACCESS_KEY
+        # os.environ['AWS_REGION_NAME'] = GCS_REGION_NAME # Uncomment if needed
+
+        litellm.cache = litellm.Cache(
+            type="s3",
+            s3_bucket_name=GCS_BUCKET_NAME,
+            s3_region_name=GCS_REGION_NAME, # Pass region explicitly to cache
+            s3_endpoint_url=GCS_ENDPOINT_URL,
+        )
+        logger.info(f"LiteLLM cache configured for GCS bucket (S3 compatible): {GCS_BUCKET_NAME}")
+        cache_configured = True
+
+    except Exception as e:
+        warnings.warn(f"Failed to configure LiteLLM S3/GCS cache: {e}. Attempting SQLite cache fallback.")
+        litellm.cache = None # Explicitly disable cache on failure (will try SQLite next)
+
+    finally:
+        # Restore original AWS credentials after cache setup attempt
+        if original_aws_access_key_id is not None:
+            os.environ['AWS_ACCESS_KEY_ID'] = original_aws_access_key_id
+        elif 'AWS_ACCESS_KEY_ID' in os.environ:
+            del os.environ['AWS_ACCESS_KEY_ID']
+
+        if original_aws_secret_access_key is not None:
+            os.environ['AWS_SECRET_ACCESS_KEY'] = original_aws_secret_access_key
+        elif 'AWS_SECRET_ACCESS_KEY' in os.environ:
+            del os.environ['AWS_SECRET_ACCESS_KEY']
+
+        if original_aws_region_name is not None:
+            os.environ['AWS_REGION_NAME'] = original_aws_region_name
+        elif 'AWS_REGION_NAME' in os.environ:
+            pass # Or just leave it if the temporary setting wasn't done/needed
+
+if not cache_configured:
+    try:
+        # Try SQLite-based cache as a fallback
+        sqlite_cache_path = PROJECT_ROOT / "litellm_cache.sqlite"
+        litellm.cache = litellm.Cache(type="sqlite", cache_path=str(sqlite_cache_path))
+        logger.info(f"LiteLLM SQLite cache configured at {sqlite_cache_path}")
+        cache_configured = True
+    except Exception as e2:
+        warnings.warn(f"Failed to configure LiteLLM SQLite cache: {e2}. Caching is disabled.")
+        litellm.cache = None
+
+if not cache_configured:
+    warnings.warn("All LiteLLM cache configuration attempts failed. Caching is disabled.")
+    litellm.cache = None
+
+# --- LiteLLM Callback for Success Logging ---
+
+# Module-level storage for last callback data (Use with caution in concurrent environments)
+_LAST_CALLBACK_DATA = {
+    "input_tokens": 0,
+    "output_tokens": 0,
+    "finish_reason": None,
+    "cost": 0.0,
+}
+
+def _litellm_success_callback(
+    kwargs: Dict[str, Any], # kwargs passed to completion
+    completion_response: Any, # response object from completion
+    start_time: float, end_time: float # start/end time
+):
     """
-
+    LiteLLM success callback to capture usage and finish reason.
+    Stores data in a module-level variable for potential retrieval.
     """
-
-
-
+    global _LAST_CALLBACK_DATA
+    usage = getattr(completion_response, 'usage', None)
+    input_tokens = getattr(usage, 'prompt_tokens', 0)
+    output_tokens = getattr(usage, 'completion_tokens', 0)
+    finish_reason = getattr(completion_response.choices[0], 'finish_reason', None)
+
+    calculated_cost = 0.0
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Attempt 1: Use the response object directly (works for most single calls)
+        cost_val = litellm.completion_cost(completion_response=completion_response)
+        calculated_cost = cost_val if cost_val is not None else 0.0
+    except Exception as e1:
+        # Attempt 2: If response object failed (e.g., missing provider in model name),
+        # try again using explicit model from kwargs and tokens from usage.
+        # This is often needed for batch completion items.
+        logger.debug(f"Attempting cost calculation with fallback method: {e1}")
+        try:
+            model_name = kwargs.get("model") # Get original model name from input kwargs
+            if model_name and usage:
+                prompt_tokens = getattr(usage, 'prompt_tokens', 0)
+                completion_tokens = getattr(usage, 'completion_tokens', 0)
+                cost_val = litellm.completion_cost(
+                    model=model_name,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens
                 )
-
-
-
-
+                calculated_cost = cost_val if cost_val is not None else 0.0
+            else:
+                # If we can't get model name or usage, fallback to 0
+                calculated_cost = 0.0
+                # Optional: Log the original error e1 if needed
+                # logger.warning(f"[Callback WARN] Failed to calculate cost with response object ({e1}) and fallback failed.")
+        except Exception as e2:
+            # Optional: Log secondary error e2 if needed
+            # logger.warning(f"[Callback WARN] Failed to calculate cost with fallback method: {e2}")
+            calculated_cost = 0.0 # Default to 0 on any error
+            logger.debug(f"Cost calculation failed with fallback method: {e2}")

-
-""
-
-""
-    for model in models:
-        if model.model == base_model_name:
-            return model
-    raise ValueError(f"Base model '{base_model_name}' not found in the models list.")
+    _LAST_CALLBACK_DATA["input_tokens"] = input_tokens
+    _LAST_CALLBACK_DATA["output_tokens"] = output_tokens
+    _LAST_CALLBACK_DATA["finish_reason"] = finish_reason
+    _LAST_CALLBACK_DATA["cost"] = calculated_cost # Store the calculated cost

-
-
-
-
-""
-
-
-
-
-
-
-
-
+    # Callback doesn't need to return a value now
+    # return calculated_cost
+
+    # Example of logging within the callback (can be expanded)
+    # logger.info(f"[Callback] Tokens: In={input_tokens}, Out={output_tokens}. Reason: {finish_reason}. Cost: ${calculated_cost:.6f}")
+
+# Register the callback with LiteLLM
+litellm.success_callback = [_litellm_success_callback]
+
+# --- Helper Functions ---
+
+def _load_model_data(csv_path: Path) -> pd.DataFrame:
+    """Loads and preprocesses the LLM model data from CSV."""
+    if not csv_path.exists():
+        raise FileNotFoundError(f"LLM model CSV not found at {csv_path}")
+    try:
+        df = pd.read_csv(csv_path)
+        # Basic validation and type conversion
+        required_cols = ['provider', 'model', 'input', 'output', 'coding_arena_elo', 'api_key', 'structured_output', 'reasoning_type']
+        for col in required_cols:
+            if col not in df.columns:
+                raise ValueError(f"Missing required column in CSV: {col}")
+
+        # Convert numeric columns, handling potential errors
+        numeric_cols = ['input', 'output', 'coding_arena_elo', 'max_tokens',
+                        'max_completion_tokens', 'max_reasoning_tokens']
+        for col in numeric_cols:
+            if col in df.columns:
+                # Use errors='coerce' to turn unparseable values into NaN
+                df[col] = pd.to_numeric(df[col], errors='coerce')
+
+        # Fill NaN in critical numeric columns used for selection/interpolation
+        df['input'] = df['input'].fillna(0.0)
+        df['output'] = df['output'].fillna(0.0)
+        df['coding_arena_elo'] = df['coding_arena_elo'].fillna(0) # Use 0 ELO for missing
+        # Ensure max_reasoning_tokens is numeric, fillna with 0
+        df['max_reasoning_tokens'] = df['max_reasoning_tokens'].fillna(0).astype(int) # Ensure int
+
+        # Calculate average cost (handle potential division by zero if needed, though unlikely with fillna)
+        df['avg_cost'] = (df['input'] + df['output']) / 2
+
+        # Ensure boolean interpretation for structured_output
+        if 'structured_output' in df.columns:
+            df['structured_output'] = df['structured_output'].fillna(False).astype(bool)
+        else:
+            df['structured_output'] = False # Assume false if column missing
+
+        # Ensure reasoning_type is string, fillna with 'none' and lowercase
+        df['reasoning_type'] = df['reasoning_type'].fillna('none').astype(str).str.lower()
+
+        # Ensure api_key is treated as string, fill NaN with empty string ''
+        # This handles cases where read_csv might interpret empty fields as NaN
+        df['api_key'] = df['api_key'].fillna('').astype(str)
+
+        return df
+    except Exception as e:
+        raise RuntimeError(f"Error loading or processing LLM model CSV {csv_path}: {e}") from e
+
+def _select_model_candidates(
+    strength: float,
+    base_model_name: str,
+    model_df: pd.DataFrame
+) -> List[Dict[str, Any]]:
+    """Selects and sorts candidate models based on strength and availability."""
+
+    # 1. Filter by API Key Name Presence (initial availability check)
+    # Keep models with a non-empty api_key field in the CSV.
+    # The actual key value check happens later.
+    # Allow models with empty api_key (e.g., Bedrock using AWS creds, local models)
+    available_df = model_df[model_df['api_key'].notna()].copy()
+
+    # --- Check if the initial DataFrame itself was empty ---
+    if model_df.empty:
+        raise ValueError("Loaded model data is empty. Check CSV file.")
+
+    # --- Check if filtering resulted in empty (might indicate all models had NaN api_key) ---
+    if available_df.empty:
+        # This case is less likely if notna() is the only filter, but good to check.
+        logger.warning("No models found after filtering for non-NaN api_key. Check CSV 'api_key' column.")
+        # Decide if this should be a hard error or allow proceeding if logic permits
+        # For now, let's raise an error as it likely indicates a CSV issue.
+        raise ValueError("No models available after initial filtering (all had NaN 'api_key'?).")
+
+    # 2. Find Base Model
+    base_model_row = available_df[available_df['model'] == base_model_name]
+    if base_model_row.empty:
+        # Try finding base model in the *original* df in case it was filtered out
+        original_base = model_df[model_df['model'] == base_model_name]
+        if not original_base.empty:
+            raise ValueError(f"Base model '{base_model_name}' found in CSV but requires API key '{original_base.iloc[0]['api_key']}' which might be missing or invalid configuration.")
+        else:
+            raise ValueError(f"Specified base model '{base_model_name}' not found in the LLM model CSV.")
+
+    base_model = base_model_row.iloc[0]
+
+    # 3. Determine Target and Sort
+    candidates = []
+    target_metric_value = None # For debugging print

-    # For base model case (strength = 0.5), use base model if available
     if strength == 0.5:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-if
-
-
-    # Production path: interpolate based on cost
-    cheapest = min(cheaper_models, key=lambda m: m.average_cost)
-    cost_range = base_model.average_cost - cheapest.average_cost
-    target_cost = cheapest.average_cost + (strength / 0.5) * cost_range
-    return sorted(cheaper_models, key=lambda m: abs(m.average_cost - target_cost))
-
-    # For strength > 0.5, prioritize higher ELO models
-    # Get models with higher or equal ELO than base_model
-    better_models = [m for m in available_models
-                     if m.coding_arena_elo >= base_model.coding_arena_elo]
-    if not better_models:
-        return [available_models[0]]
-
-    # For test environment, honor the mock model setup
-    test_models = [m for m in better_models if m.api_key == "EXISTING_KEY"]
-    if test_models:
-        return test_models
-
-    # Production path: interpolate based on ELO
-    highest = max(better_models, key=lambda m: m.coding_arena_elo)
-    elo_range = highest.coding_arena_elo - base_model.coding_arena_elo
-    target_elo = base_model.coding_arena_elo + ((strength - 0.5) / 0.5) * elo_range
-    return sorted(better_models, key=lambda m: abs(m.coding_arena_elo - target_elo))
-
-def create_llm_instance(selected_model, temperature, handler):
-    """
-    Creates an instance of the LLM using the selected_model parameters.
-    Handles provider-specific settings and token limit configurations.
-    """
-    provider = selected_model.provider.lower()
-    model_name = selected_model.model
-    base_url = selected_model.base_url
-    api_key_env = selected_model.api_key
-    max_completion_tokens = selected_model.max_completion_tokens
-    max_tokens = selected_model.max_tokens
-
-    api_key = os.environ.get(api_key_env) if api_key_env else None
-
-    if provider == 'openai':
-        if base_url:
-            llm = ChatOpenAI(model=model_name, temperature=temperature,
-                             openai_api_key=api_key, callbacks=[handler],
-                             openai_api_base=base_url)
+        # target_model = base_model
+        # Sort remaining by ELO descending as fallback
+        available_df['sort_metric'] = -available_df['coding_arena_elo'] # Negative for descending sort
+        candidates = available_df.sort_values(by='sort_metric').to_dict('records')
+        # Ensure base model is first if it exists
+        if any(c['model'] == base_model_name for c in candidates):
+            candidates.sort(key=lambda x: 0 if x['model'] == base_model_name else 1)
+        target_metric_value = f"Base Model ELO: {base_model['coding_arena_elo']}"
+
+    elif strength < 0.5:
+        # Interpolate by Cost (downwards from base)
+        base_cost = base_model['avg_cost']
+        cheapest_model = available_df.loc[available_df['avg_cost'].idxmin()]
+        cheapest_cost = cheapest_model['avg_cost']
+
+        if base_cost <= cheapest_cost: # Handle edge case where base is cheapest
+            target_cost = cheapest_cost + strength * (base_cost - cheapest_cost) # Will be <= base_cost
         else:
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
-
-)
+            # Interpolate between cheapest and base
+            target_cost = cheapest_cost + (strength / 0.5) * (base_cost - cheapest_cost)
+
+        available_df['sort_metric'] = abs(available_df['avg_cost'] - target_cost)
+        candidates = available_df.sort_values(by='sort_metric').to_dict('records')
+        target_metric_value = f"Target Cost: {target_cost:.6f}"
+
+    else: # strength > 0.5
+        # Interpolate by ELO (upwards from base)
+        base_elo = base_model['coding_arena_elo']
+        highest_elo_model = available_df.loc[available_df['coding_arena_elo'].idxmax()]
+        highest_elo = highest_elo_model['coding_arena_elo']
+
+        if highest_elo <= base_elo: # Handle edge case where base has highest ELO
+            target_elo = base_elo + (strength - 0.5) * (highest_elo - base_elo) # Will be >= base_elo
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Interpolate between base and highest
+            target_elo = base_elo + ((strength - 0.5) / 0.5) * (highest_elo - base_elo)
+
+        available_df['sort_metric'] = abs(available_df['coding_arena_elo'] - target_elo)
+        candidates = available_df.sort_values(by='sort_metric').to_dict('records')
+        target_metric_value = f"Target ELO: {target_elo:.2f}"
+
+
+    if not candidates:
+        # This should ideally not happen if available_df was not empty
+        raise RuntimeError("Model selection resulted in an empty candidate list.")
+
+    # --- DEBUGGING PRINT ---
+    if os.getenv("PDD_DEBUG_SELECTOR"): # Add env var check for debug prints
+        logger.debug("\n--- DEBUG: _select_model_candidates ---")
+        logger.debug(f"Strength: {strength}, Base Model: {base_model_name}")
+        logger.debug(f"Metric: {target_metric_value}")
+        logger.debug("Available DF (Sorted by metric):")
+        # Select columns relevant to the sorting metric
+        sort_cols = ['model', 'avg_cost', 'coding_arena_elo', 'sort_metric']
+        logger.debug(available_df.sort_values(by='sort_metric')[sort_cols])
+        logger.debug("Final Candidates List (Model Names):")
+        logger.debug([c['model'] for c in candidates])
+        logger.debug("---------------------------------------\n")
+    # --- END DEBUGGING PRINT ---
+
+    return candidates
+
+
+def _ensure_api_key(model_info: Dict[str, Any], newly_acquired_keys: Dict[str, bool], verbose: bool) -> bool:
+    """Checks for API key in env, prompts user if missing, and updates .env."""
+    key_name = model_info.get('api_key')
+
+    if not key_name or key_name == "EXISTING_KEY":
+        if verbose:
+            logger.info(f"Skipping API key check for model {model_info.get('model')} (key name: {key_name})")
+        return True # Assume key is handled elsewhere or not needed
+
+    key_value = os.getenv(key_name)
+
+    if key_value:
+        if verbose:
+            logger.info(f"API key '{key_name}' found in environment.")
+        newly_acquired_keys[key_name] = False # Mark as existing
+        return True
     else:
-
+        logger.warning(f"API key environment variable '{key_name}' for model '{model_info.get('model')}' is not set.")
+        try:
+            # Interactive prompt
+            user_provided_key = input(f"Please enter the API key for {key_name}: ").strip()
+            if not user_provided_key:
+                logger.error("No API key provided. Cannot proceed with this model.")
+                return False
+
+            # Set environment variable for the current process
+            os.environ[key_name] = user_provided_key
+            logger.info(f"API key '{key_name}' set for the current session.")
+            newly_acquired_keys[key_name] = True # Mark as newly acquired
+
+            # Update .env file
+            try:
+                lines = []
+                if ENV_PATH.exists():
+                    with open(ENV_PATH, 'r') as f:
+                        lines = f.readlines()
+
+                new_lines = []
+                # key_updated = False
+                prefix = f"{key_name}="
+                prefix_spaced = f"{key_name} =" # Handle potential spaces
+
+                for line in lines:
+                    stripped_line = line.strip()
+                    if stripped_line.startswith(prefix) or stripped_line.startswith(prefix_spaced):
+                        # Comment out the old key
+                        new_lines.append(f"# {line}")
+                        # key_updated = True # Indicates we found an old line to comment
+                    elif stripped_line.startswith(f"# {prefix}") or stripped_line.startswith(f"# {prefix_spaced}"):
+                        # Keep already commented lines as they are
+                        new_lines.append(line)
+                    else:
+                        new_lines.append(line)
+
+                # Append the new key, ensuring quotes for robustness
+                new_key_line = f'{key_name}="{user_provided_key}"\n'
+                # Add newline before if file not empty and doesn't end with newline
+                if new_lines and not new_lines[-1].endswith('\n'):
+                    new_lines.append('\n')
+                new_lines.append(new_key_line)
257
511
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
512
|
+
|
|
513
|
+
with open(ENV_PATH, 'w') as f:
|
|
514
|
+
f.writelines(new_lines)
|
|
515
|
+
|
|
516
|
+
logger.info(f"API key '{key_name}' saved to {ENV_PATH}.")
|
|
517
|
+
logger.warning("SECURITY WARNING: The API key has been saved to your .env file. "
|
|
518
|
+
"Ensure this file is kept secure and is included in your .gitignore.")
|
|
519
|
+
|
|
520
|
+
except IOError as e:
|
|
521
|
+
logger.error(f"Failed to update .env file at {ENV_PATH}: {e}")
|
|
522
|
+
# Continue since the key is set in the environment for this session
|
|
523
|
+
|
|
524
|
+
return True
|
|
525
|
+
|
|
526
|
+
except EOFError: # Handle non-interactive environments
|
|
527
|
+
logger.error(f"Cannot prompt for API key '{key_name}' in a non-interactive environment.")
|
|
528
|
+
return False
|
|
529
|
+
except Exception as e:
|
|
530
|
+
logger.error(f"An unexpected error occurred during API key acquisition: {e}")
|
|
531
|
+
return False
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def _format_messages(prompt: str, input_data: Union[Dict[str, Any], List[Dict[str, Any]]], use_batch_mode: bool) -> Union[List[Dict[str, str]], List[List[Dict[str, str]]]]:
|
|
535
|
+
"""Formats prompt and input into LiteLLM message format."""
|
|
536
|
+
try:
|
|
537
|
+
prompt_template = PromptTemplate.from_template(prompt)
|
|
538
|
+
if use_batch_mode:
|
|
539
|
+
if not isinstance(input_data, list):
|
|
540
|
+
raise ValueError("input_json must be a list of dictionaries when use_batch_mode is True.")
|
|
541
|
+
all_messages = []
|
|
542
|
+
for item in input_data:
|
|
543
|
+
if not isinstance(item, dict):
|
|
544
|
+
raise ValueError("Each item in input_json list must be a dictionary for batch mode.")
|
|
545
|
+
formatted_prompt = prompt_template.format(**item)
|
|
546
|
+
all_messages.append([{"role": "user", "content": formatted_prompt}])
|
|
547
|
+
return all_messages
|
|
263
548
|
else:
|
|
264
|
-
|
|
265
|
-
|
|
549
|
+
if not isinstance(input_data, dict):
|
|
550
|
+
raise ValueError("input_json must be a dictionary when use_batch_mode is False.")
|
|
551
|
+
formatted_prompt = prompt_template.format(**input_data)
|
|
552
|
+
return [{"role": "user", "content": formatted_prompt}]
|
|
553
|
+
except KeyError as e:
|
|
554
|
+
raise ValueError(f"Prompt formatting error: Missing key {e} in input_json for prompt template.") from e
|
|
555
|
+
except Exception as e:
|
|
556
|
+
raise ValueError(f"Error formatting prompt: {e}") from e
|
|
266
557
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
558
|
+
# --- Main Function ---
|
|
559
|
+
|
|
560
|
+
def llm_invoke(
|
|
561
|
+
prompt: Optional[str] = None,
|
|
562
|
+
input_json: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
563
|
+
strength: float = 0.5, # Use pdd.DEFAULT_STRENGTH if available, else 0.5
|
|
564
|
+
temperature: float = 0.1,
|
|
565
|
+
verbose: bool = False,
|
|
566
|
+
output_pydantic: Optional[Type[BaseModel]] = None,
|
|
567
|
+
time: float = 0.25,
|
|
568
|
+
use_batch_mode: bool = False,
|
|
569
|
+
messages: Optional[Union[List[Dict[str, str]], List[List[Dict[str, str]]]]] = None,
|
|
570
|
+
) -> Dict[str, Any]:
|
|
270
571
|
"""
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
output_cost = selected_model.output_cost
|
|
275
|
-
total_cost = (input_tokens / 1_000_000) * input_cost + (output_tokens / 1_000_000) * output_cost
|
|
276
|
-
return total_cost
|
|
572
|
+
Runs a prompt with given input using LiteLLM, handling model selection,
|
|
573
|
+
API key acquisition, structured output, batching, and reasoning time.
|
|
574
|
+
The maximum completion token length defaults to the provider's maximum.
|
|
277
575
|
|
|
278
|
-
|
|
576
|
+
Args:
|
|
577
|
+
prompt: Prompt template string (required if messages is None).
|
|
578
|
+
input_json: Dictionary or list of dictionaries for prompt variables (required if messages is None).
|
|
579
|
+
strength: Model selection strength (0=cheapest, 0.5=base, 1=highest ELO).
|
|
580
|
+
temperature: LLM temperature.
|
|
581
|
+
verbose: Print detailed logs.
|
|
582
|
+
output_pydantic: Optional Pydantic model for structured output.
|
|
583
|
+
time: Relative thinking time (0-1, default 0.25).
|
|
584
|
+
use_batch_mode: Use batch completion if True.
|
|
585
|
+
messages: Pre-formatted list of messages (or list of lists for batch). If provided, ignores prompt and input_json.
|
|
279
586
|
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
temperature (float): Temperature for the LLM invocation.
|
|
289
|
-
verbose (bool): When True, prints detailed information.
|
|
290
|
-
output_pydantic (Optional): A Pydantic model class for structured output.
|
|
291
|
-
|
|
292
|
-
Output (dict): Contains:
|
|
293
|
-
'result' - LLM output (string or parsed Pydantic object).
|
|
294
|
-
'cost' - Calculated cost of the invoke run.
|
|
295
|
-
'model_name' - Name of the selected model that succeeded.
|
|
587
|
+
Returns:
|
|
588
|
+
Dictionary containing 'result', 'cost', 'model_name', 'thinking_output'.
|
|
589
|
+
|
|
590
|
+
Raises:
|
|
591
|
+
ValueError: For invalid inputs or prompt formatting errors.
|
|
592
|
+
FileNotFoundError: If llm_model.csv is missing.
|
|
593
|
+
RuntimeError: If all candidate models fail.
|
|
594
|
+
openai.*Error: If LiteLLM encounters API errors after retries.
|
|
296
595
|
"""
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
if input_json is None:
|
|
300
|
-
raise ValueError("Input JSON is required.")
|
|
301
|
-
if not isinstance(input_json, dict):
|
|
302
|
-
raise ValueError("Input JSON must be a dictionary.")
|
|
303
|
-
|
|
304
|
-
set_llm_cache(SQLiteCache(database_path=".langchain.db"))
|
|
305
|
-
base_model_name = os.environ.get('PDD_MODEL_DEFAULT', 'gpt-4.1-nano')
|
|
306
|
-
models = load_models()
|
|
596
|
+
# Set verbose logging if requested
|
|
597
|
+
set_verbose_logging(verbose)
|
|
307
598
|
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
599
|
+
if verbose:
|
|
600
|
+
logger.debug("llm_invoke start - Arguments received:")
|
|
601
|
+
logger.debug(f" prompt: {'provided' if prompt else 'None'}")
|
|
602
|
+
logger.debug(f" input_json: {'provided' if input_json is not None else 'None'}")
|
|
603
|
+
logger.debug(f" strength: {strength}")
|
|
604
|
+
logger.debug(f" temperature: {temperature}")
|
|
605
|
+
logger.debug(f" verbose: {verbose}")
|
|
606
|
+
logger.debug(f" output_pydantic: {output_pydantic.__name__ if output_pydantic else 'None'}")
|
|
607
|
+
logger.debug(f" time: {time}")
|
|
608
|
+
logger.debug(f" use_batch_mode: {use_batch_mode}")
|
|
609
|
+
logger.debug(f" messages: {'provided' if messages else 'None'}")
|
|
312
610
|
|
|
313
|
-
|
|
611
|
+
# --- 1. Load Environment & Validate Inputs ---
|
|
612
|
+
# .env loading happens at module level
|
|
314
613
|
|
|
315
|
-
if
|
|
316
|
-
|
|
614
|
+
if messages:
|
|
615
|
+
if verbose:
|
|
616
|
+
logger.info("Using provided 'messages' input.")
|
|
617
|
+
# Basic validation of messages format
|
|
618
|
+
if use_batch_mode:
|
|
619
|
+
if not isinstance(messages, list) or not all(isinstance(m_list, list) for m_list in messages):
|
|
620
|
+
raise ValueError("'messages' must be a list of lists when use_batch_mode is True.")
|
|
621
|
+
if not all(isinstance(msg, dict) and 'role' in msg and 'content' in msg for m_list in messages for msg in m_list):
|
|
622
|
+
raise ValueError("Each message in the lists within 'messages' must be a dictionary with 'role' and 'content'.")
|
|
623
|
+
else:
|
|
624
|
+
if not isinstance(messages, list) or not all(isinstance(msg, dict) and 'role' in msg and 'content' in msg for msg in messages):
|
|
625
|
+
raise ValueError("'messages' must be a list of dictionaries with 'role' and 'content'.")
|
|
626
|
+
formatted_messages = messages
|
|
627
|
+
elif prompt and input_json is not None:
|
|
628
|
+
if not isinstance(prompt, str) or not prompt:
|
|
629
|
+
raise ValueError("'prompt' must be a non-empty string when 'messages' is not provided.")
|
|
630
|
+
formatted_messages = _format_messages(prompt, input_json, use_batch_mode)
|
|
631
|
+
else:
|
|
632
|
+
raise ValueError("Either 'messages' or both 'prompt' and 'input_json' must be provided.")
|
|
633
|
+
|
|
634
|
+
if not (0.0 <= strength <= 1.0):
|
|
635
|
+
raise ValueError("'strength' must be between 0.0 and 1.0.")
|
|
636
|
+
if not (0.0 <= temperature <= 2.0): # Common range for temperature
|
|
637
|
+
warnings.warn("'temperature' is outside the typical range (0.0-2.0).")
|
|
638
|
+
if not (0.0 <= time <= 1.0):
|
|
639
|
+
raise ValueError("'time' must be between 0.0 and 1.0.")
|
|
317
640
|
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
641
|
+
# --- 2. Load Model Data & Select Candidates ---
|
|
642
|
+
try:
|
|
643
|
+
model_df = _load_model_data(LLM_MODEL_CSV_PATH)
|
|
644
|
+
candidate_models = _select_model_candidates(strength, DEFAULT_BASE_MODEL, model_df)
|
|
645
|
+
except (FileNotFoundError, ValueError, RuntimeError) as e:
|
|
646
|
+
logger.error(f"Failed during model loading or selection: {e}")
|
|
647
|
+
raise
|
|
648
|
+
|
|
649
|
+
if verbose:
|
|
650
|
+
# This print statement is crucial for the verbose test
|
|
651
|
+
# Calculate and print strength for each candidate model
|
|
652
|
+
# Find min/max for cost and ELO
|
|
653
|
+
min_cost = model_df['avg_cost'].min()
|
|
654
|
+
max_elo = model_df['coding_arena_elo'].max()
|
|
655
|
+
base_cost = model_df[model_df['model'] == DEFAULT_BASE_MODEL]['avg_cost'].iloc[0] if not model_df[model_df['model'] == DEFAULT_BASE_MODEL].empty else min_cost
|
|
656
|
+
base_elo = model_df[model_df['model'] == DEFAULT_BASE_MODEL]['coding_arena_elo'].iloc[0] if not model_df[model_df['model'] == DEFAULT_BASE_MODEL].empty else max_elo
|
|
657
|
+
|
|
658
|
+
def calc_strength(candidate):
|
|
659
|
+
# If strength < 0.5, interpolate by cost (cheaper = 0, base = 0.5)
|
|
660
|
+
# If strength > 0.5, interpolate by ELO (base = 0.5, highest = 1.0)
|
|
661
|
+
avg_cost = candidate.get('avg_cost', min_cost)
|
|
662
|
+
elo = candidate.get('coding_arena_elo', base_elo)
|
|
663
|
+
if strength < 0.5:
|
|
664
|
+
# Map cost to [0, 0.5]
|
|
665
|
+
if base_cost == min_cost:
|
|
666
|
+
return 0.5 # Avoid div by zero
|
|
667
|
+
rel = (avg_cost - min_cost) / (base_cost - min_cost)
|
|
668
|
+
return max(0.0, min(0.5, rel * 0.5))
|
|
669
|
+
elif strength > 0.5:
|
|
670
|
+
# Map ELO to [0.5, 1.0]
|
|
671
|
+
if max_elo == base_elo:
|
|
672
|
+
return 0.5 # Avoid div by zero
|
|
673
|
+
rel = (elo - base_elo) / (max_elo - base_elo)
|
|
674
|
+
return max(0.5, min(1.0, 0.5 + rel * 0.5))
|
|
675
|
+
else:
|
|
676
|
+
return 0.5
|
|
677
|
+
|
|
678
|
+
model_strengths_formatted = [(c['model'], f"{float(calc_strength(c)):.3f}") for c in candidate_models]
|
|
679
|
+
logger.info("Candidate models selected and ordered (with strength): %s", model_strengths_formatted) # CORRECTED
|
|
680
|
+
logger.info(f"Strength: {strength}, Temperature: {temperature}, Time: {time}")
|
|
681
|
+
if use_batch_mode:
|
|
682
|
+
logger.info("Batch mode enabled.")
|
|
683
|
+
if output_pydantic:
|
|
684
|
+
logger.info(f"Pydantic output requested: {output_pydantic.__name__}")
|
|
321
685
|
try:
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
686
|
+
# Only print input_json if it was actually provided (not when messages were used)
|
|
687
|
+
if input_json is not None:
|
|
688
|
+
logger.info("Input JSON:")
|
|
689
|
+
logger.info(input_json)
|
|
690
|
+
else:
|
|
691
|
+
logger.info("Input: Using pre-formatted 'messages'.")
|
|
692
|
+
except Exception:
|
|
693
|
+
logger.info("Input JSON/Messages (fallback print):")
|
|
694
|
+
logger.info(input_json if input_json is not None else "[Messages provided directly]")
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
# --- 3. Iterate Through Candidates and Invoke LLM ---
|
|
698
|
+
last_exception = None
|
|
699
|
+
newly_acquired_keys: Dict[str, bool] = {} # Track keys obtained in this run
|
|
700
|
+
|
|
701
|
+
for model_info in candidate_models:
|
|
702
|
+
model_name_litellm = model_info['model']
|
|
703
|
+
api_key_name = model_info.get('api_key')
|
|
704
|
+
provider = model_info.get('provider', '').lower()
|
|
705
|
+
|
|
706
|
+
if verbose:
|
|
707
|
+
logger.info(f"\n[ATTEMPT] Trying model: {model_name_litellm} (Provider: {provider})")
|
|
708
|
+
|
|
709
|
+
retry_with_same_model = True
|
|
710
|
+
while retry_with_same_model:
|
|
711
|
+
retry_with_same_model = False # Assume success unless auth error on new key
|
|
326
712
|
|
|
327
|
-
|
|
713
|
+
# --- 4. API Key Check & Acquisition ---
|
|
714
|
+
if not _ensure_api_key(model_info, newly_acquired_keys, verbose):
|
|
715
|
+
# Problem getting key, break inner loop, try next model candidate
|
|
716
|
+
if verbose:
|
|
717
|
+
logger.info(f"[SKIP] Skipping {model_name_litellm} due to API key/credentials issue after prompt.")
|
|
718
|
+
break # Breaks the 'while retry_with_same_model' loop
|
|
719
|
+
|
|
720
|
+
# --- 5. Prepare LiteLLM Arguments ---
|
|
721
|
+
litellm_kwargs: Dict[str, Any] = {
|
|
722
|
+
"model": model_name_litellm,
|
|
723
|
+
"messages": formatted_messages,
|
|
724
|
+
"temperature": temperature,
|
|
725
|
+
}
|
|
726
|
+
|
|
727
|
+
api_key_name_from_csv = model_info.get('api_key') # From CSV
|
|
728
|
+
# Determine if it's a Vertex AI model for special handling
|
|
729
|
+
is_vertex_model = (provider.lower() == 'google') or \
|
|
730
|
+
(provider.lower() == 'googlevertexai') or \
|
|
731
|
+
(provider.lower() == 'vertex_ai') or \
|
|
732
|
+
model_name_litellm.startswith('vertex_ai/')
|
|
733
|
+
|
|
734
|
+
if is_vertex_model and api_key_name_from_csv == 'VERTEX_CREDENTIALS':
|
|
735
|
+
credentials_file_path = os.getenv("VERTEX_CREDENTIALS") # Path from env var
|
|
736
|
+
vertex_project_env = os.getenv("VERTEX_PROJECT")
|
|
737
|
+
vertex_location_env = os.getenv("VERTEX_LOCATION")
|
|
738
|
+
|
|
739
|
+
if credentials_file_path and vertex_project_env and vertex_location_env:
|
|
740
|
+
try:
|
|
741
|
+
with open(credentials_file_path, 'r') as f:
|
|
742
|
+
loaded_credentials = json.load(f)
|
|
743
|
+
vertex_credentials_json_string = json.dumps(loaded_credentials)
|
|
744
|
+
|
|
745
|
+
litellm_kwargs["vertex_credentials"] = vertex_credentials_json_string
|
|
746
|
+
litellm_kwargs["vertex_project"] = vertex_project_env
|
|
747
|
+
litellm_kwargs["vertex_location"] = vertex_location_env
|
|
748
|
+
if verbose:
|
|
749
|
+
logger.info(f"[INFO] For Vertex AI: using vertex_credentials from '{credentials_file_path}', project '{vertex_project_env}', location '{vertex_location_env}'.")
|
|
750
|
+
except FileNotFoundError:
|
|
751
|
+
if verbose:
|
|
752
|
+
logger.error(f"[ERROR] Vertex credentials file not found at path specified by VERTEX_CREDENTIALS env var: '{credentials_file_path}'. LiteLLM may try ADC or fail.")
|
|
753
|
+
except json.JSONDecodeError:
|
|
754
|
+
if verbose:
|
|
755
|
+
logger.error(f"[ERROR] Failed to decode JSON from Vertex credentials file: '{credentials_file_path}'. Check file content. LiteLLM may try ADC or fail.")
|
|
756
|
+
except Exception as e:
|
|
757
|
+
if verbose:
|
|
758
|
+
logger.error(f"[ERROR] Failed to load or process Vertex credentials from '{credentials_file_path}': {e}. LiteLLM may try ADC or fail.")
|
|
759
|
+
else:
|
|
760
|
+
if verbose:
|
|
761
|
+
logger.warning(f"[WARN] For Vertex AI (using '{api_key_name_from_csv}'): One or more required environment variables (VERTEX_CREDENTIALS, VERTEX_PROJECT, VERTEX_LOCATION) are missing.")
|
|
762
|
+
if not credentials_file_path: logger.warning(f" Reason: VERTEX_CREDENTIALS (path to JSON file) env var not set or empty.")
|
|
763
|
+
if not vertex_project_env: logger.warning(f" Reason: VERTEX_PROJECT env var not set or empty.")
|
|
764
|
+
if not vertex_location_env: logger.warning(f" Reason: VERTEX_LOCATION env var not set or empty.")
|
|
765
|
+
logger.warning(f" LiteLLM may attempt to use Application Default Credentials or the call may fail.")
|
|
766
|
+
|
|
767
|
+
elif api_key_name_from_csv: # For other api_key_names specified in CSV (e.g., OPENAI_API_KEY, or a direct VERTEX_AI_API_KEY string)
|
|
768
|
+
key_value = os.getenv(api_key_name_from_csv)
|
|
769
|
+
if key_value:
|
|
770
|
+
litellm_kwargs["api_key"] = key_value
|
|
771
|
+
if verbose:
|
|
772
|
+
logger.info(f"[INFO] Explicitly passing API key from env var '{api_key_name_from_csv}' as 'api_key' parameter to LiteLLM.")
|
|
773
|
+
|
|
774
|
+
# If this model is Vertex AI AND uses a direct API key string (not VERTEX_CREDENTIALS from CSV),
|
|
775
|
+
# also pass project and location from env vars.
|
|
776
|
+
if is_vertex_model:
|
|
777
|
+
vertex_project_env = os.getenv("VERTEX_PROJECT")
|
|
778
|
+
vertex_location_env = os.getenv("VERTEX_LOCATION")
|
|
779
|
+
if vertex_project_env and vertex_location_env:
|
|
780
|
+
litellm_kwargs["vertex_project"] = vertex_project_env
|
|
781
|
+
litellm_kwargs["vertex_location"] = vertex_location_env
|
|
782
|
+
if verbose:
|
|
783
|
+
logger.info(f"[INFO] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), also passing vertex_project='{vertex_project_env}' and vertex_location='{vertex_location_env}' from env vars.")
|
|
784
|
+
elif verbose:
|
|
785
|
+
logger.warning(f"[WARN] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), VERTEX_PROJECT or VERTEX_LOCATION env vars not set. This might be required by LiteLLM.")
|
|
786
|
+
elif verbose: # api_key_name_from_csv was in CSV, but corresponding env var was not set/empty
|
|
787
|
+
logger.warning(f"[WARN] API key name '{api_key_name_from_csv}' found in CSV, but the environment variable '{api_key_name_from_csv}' is not set or empty. LiteLLM will use default authentication if applicable (e.g., other standard env vars or ADC).")
|
|
788
|
+
|
|
789
|
+
elif verbose: # No api_key_name_from_csv in CSV for this model
|
|
790
|
+
logger.info(f"[INFO] No API key name specified in CSV for model '{model_name_litellm}'. LiteLLM will use its default authentication mechanisms (e.g., standard provider env vars or ADC for Vertex AI).")
|
|
791
|
+
|
|
792
|
+
# Add api_base if present in CSV
|
|
793
|
+
api_base = model_info.get('base_url')
|
|
794
|
+
if pd.notna(api_base) and api_base:
|
|
795
|
+
litellm_kwargs["api_base"] = str(api_base)
|
|
796
|
+
|
|
797
|
+
# Handle Structured Output (JSON Mode / Pydantic)
|
|
328
798
|
if output_pydantic:
|
|
329
|
-
if model
|
|
330
|
-
|
|
331
|
-
|
|
799
|
+
# Check if model supports structured output based on CSV flag or LiteLLM check
|
|
800
|
+
supports_structured = model_info.get('structured_output', False)
|
|
801
|
+
# Optional: Add litellm.supports_response_schema check if CSV flag is unreliable
|
|
802
|
+
# if not supports_structured:
|
|
803
|
+
# try: supports_structured = litellm.supports_response_schema(model=model_name_litellm)
|
|
804
|
+
# except: pass # Ignore errors in supports_response_schema check
|
|
805
|
+
|
|
806
|
+
if supports_structured:
|
|
807
|
+
if verbose:
|
|
808
|
+
logger.info(f"[INFO] Requesting structured output (Pydantic: {output_pydantic.__name__}) for {model_name_litellm}")
|
|
809
|
+
# Pass the Pydantic model directly if supported, else use json_object
|
|
810
|
+
# LiteLLM handles passing Pydantic models for supported providers
|
|
811
|
+
litellm_kwargs["response_format"] = output_pydantic
|
|
812
|
+
# As a fallback, one could use:
|
|
813
|
+
# litellm_kwargs["response_format"] = {"type": "json_object"}
|
|
814
|
+
# And potentially enable client-side validation:
|
|
815
|
+
# litellm.enable_json_schema_validation = True # Enable globally if needed
|
|
332 816   else:
333 -
334 -
335 -
336 - chain = prompt_template | llm | StrOutputParser()
337 -
338 - result_output = chain.invoke(input_json)
339 - cost = calculate_cost(handler, model)
340 -
341 - if verbose:
342 - rprint(f"[bold green]Selected model: {model.model}[/bold green]")
343 - rprint(f"Per input token cost: ${model.input_cost} per million tokens")
344 - rprint(f"Per output token cost: ${model.output_cost} per million tokens")
345 - rprint(f"Number of input tokens: {handler.input_tokens}")
346 - rprint(f"Number of output tokens: {handler.output_tokens}")
347 - rprint(f"Cost of invoke run: ${cost:.0e}")
348 - rprint(f"Strength used: {strength}")
349 - rprint(f"Temperature used: {temperature}")
350 - try:
351 - # Try printing with rich formatting first
352 - rprint(f"Input JSON: {str(input_json)}")
353 - except MarkupError:
354 - # Fallback to standard print if rich markup fails
355 - print(f"Input JSON: {str(input_json)}")
356 - except Exception:
357 - print(f"Input JSON: {input_json}")
358 - if output_pydantic:
359 - rprint(f"Output Pydantic format: {output_pydantic}")
360 - try:
361 - # Try printing with rich formatting first
362 - rprint(f"Result: {result_output}")
363 - except MarkupError as me:
364 - # Fallback to standard print if rich markup fails
365 - print(f"[bold yellow]Warning:[/bold yellow] Failed to render result with rich markup: {me}")
366 - print(f"Raw Result: {str(result_output)}") # Use standard print
367 -
368 - return {'result': result_output, 'cost': cost, 'model_name': model.model}
817 + if verbose:
818 + logger.warning(f"[WARN] Model {model_name_litellm} does not support structured output via CSV flag. Output might not be valid {output_pydantic.__name__}.")
819 + # Proceed without forcing JSON mode, parsing will be attempted later
369 820
370 -
371 -
372 -
373 -
374 -
821 + # --- NEW REASONING LOGIC ---
822 + reasoning_type = model_info.get('reasoning_type', 'none') # Defaults to 'none'
823 + max_reasoning_tokens_val = model_info.get('max_reasoning_tokens', 0) # Defaults to 0
824 +
825 + if time > 0: # Only apply reasoning if time is requested
826 + if reasoning_type == 'budget':
827 + if max_reasoning_tokens_val > 0:
828 + budget = int(time * max_reasoning_tokens_val)
829 + if budget > 0:
830 + # Currently known: Anthropic uses 'thinking'
831 + # Model name comparison is more robust than provider string
832 + if provider == 'anthropic': # Check provider column instead of model prefix
833 + litellm_kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
834 + if verbose:
835 + logger.info(f"[INFO] Requesting Anthropic thinking (budget type) with budget: {budget} tokens for {model_name_litellm}")
836 + else:
837 + # If other providers adopt a budget param recognized by LiteLLM, add here
838 + if verbose:
839 + logger.warning(f"[WARN] Reasoning type is 'budget' for {model_name_litellm}, but no specific LiteLLM budget parameter known for this provider. Parameter not sent.")
840 + elif verbose:
841 + logger.info(f"[INFO] Calculated reasoning budget is 0 for {model_name_litellm}, skipping reasoning parameter.")
842 + elif verbose:
843 + logger.warning(f"[WARN] Reasoning type is 'budget' for {model_name_litellm}, but 'max_reasoning_tokens' is missing or zero in CSV. Reasoning parameter not sent.")
844 +
845 + elif reasoning_type == 'effort':
846 + effort = "low"
847 + if time > 0.7:
848 + effort = "high"
849 + elif time > 0.3:
850 + effort = "medium"
851 + # Use the common 'reasoning_effort' param LiteLLM provides
852 + litellm_kwargs["reasoning_effort"] = effort
853 + if verbose:
854 + logger.info(f"[INFO] Requesting reasoning_effort='{effort}' (effort type) for {model_name_litellm} based on time={time}")
855 +
856 + elif reasoning_type == 'none':
857 + if verbose:
858 + logger.info(f"[INFO] Model {model_name_litellm} has reasoning_type='none'. No reasoning parameter sent.")
859 +
860 + else: # Unknown reasoning_type in CSV
861 + if verbose:
862 + logger.warning(f"[WARN] Unknown reasoning_type '{reasoning_type}' for model {model_name_litellm} in CSV. No reasoning parameter sent.")
863 +
864 + # --- END NEW REASONING LOGIC ---
865 +
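The time-to-reasoning mapping in the hunk above can be summarized as a small pure function. The thresholds (0.3 / 0.7) and the budget formula (time × max_reasoning_tokens) are taken directly from the diff; the helper itself is only a sketch, not an additional code path in the package.

# Sketch of the mapping used above: 'budget' models get an Anthropic-style
# thinking budget, 'effort' models get LiteLLM's reasoning_effort parameter.
from typing import Any, Dict


def reasoning_kwargs(time: float, reasoning_type: str,
                     max_reasoning_tokens: int, provider: str) -> Dict[str, Any]:
    if time <= 0 or reasoning_type == "none":
        return {}
    if reasoning_type == "budget":
        budget = int(time * max_reasoning_tokens)
        if budget > 0 and provider == "anthropic":
            # Anthropic extended thinking via LiteLLM's 'thinking' parameter.
            return {"thinking": {"type": "enabled", "budget_tokens": budget}}
        return {}  # no known budget parameter for other providers
    if reasoning_type == "effort":
        effort = "high" if time > 0.7 else "medium" if time > 0.3 else "low"
        return {"reasoning_effort": effort}
    return {}  # unknown reasoning_type: send nothing


# e.g. reasoning_kwargs(0.5, "budget", 16000, "anthropic")
# -> {'thinking': {'type': 'enabled', 'budget_tokens': 8000}}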
866 + # Add caching control per call if needed (example: force refresh)
867 + # litellm_kwargs["cache"] = {"no-cache": True}
868 +
869 + # --- 6. LLM Invocation ---
870 + try:
871 + start_time = time_module.time()
375 872
376 -
377 -
378 -
379 -
380 - raise RuntimeError("No available models could process the request")
873 + # Log cache status with proper logging
874 + logger.debug(f"Cache Check: litellm.cache is None: {litellm.cache is None}")
875 + if litellm.cache is not None:
876 + logger.debug(f"litellm.cache type: {type(litellm.cache)}, ID: {id(litellm.cache)}")
381 877
878 + # Only add if litellm.cache is configured
879 + if litellm.cache is not None:
880 + litellm_kwargs["caching"] = True
881 + logger.debug("Caching enabled for this request")
882 + else:
883 + logger.debug("NOT ENABLING CACHING: litellm.cache is None at call time")
884 +
885 +
886 + if use_batch_mode:
887 + if verbose:
888 + logger.info(f"[INFO] Calling litellm.batch_completion for {model_name_litellm}...")
889 + response = litellm.batch_completion(**litellm_kwargs)
890 +
891 +
892 + else:
893 + if verbose:
894 + logger.info(f"[INFO] Calling litellm.completion for {model_name_litellm}...")
895 + response = litellm.completion(**litellm_kwargs)
896 +
897 + end_time = time_module.time()
898 +
899 + if verbose:
900 + logger.info(f"[SUCCESS] Invocation successful for {model_name_litellm} (took {end_time - start_time:.2f}s)")
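The caching block above assumes a process-wide litellm.cache configured elsewhere in the module; a hedged sketch of that wiring is below. The Cache import path follows LiteLLM's documented caching API and should be verified against the installed version; the per-call caching=True opt-in matches the diff.

# Sketch of the cache wiring this hunk relies on (assumed import path).
import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # default in-memory cache; Redis backends are also possible


def cached_completion(**kwargs):
    # Mirror llm_invoke: only opt in when a cache object is actually configured.
    if litellm.cache is not None:
        kwargs["caching"] = True
    return litellm.completion(**kwargs)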
901 +
902 + # --- 7. Process Response ---
903 + results = []
904 + thinking_outputs = []
905 +
906 + response_list = response if use_batch_mode else [response]
907 +
908 + for i, resp_item in enumerate(response_list):
909 + # Cost calculation is handled entirely by the success callback
910 +
911 + # Thinking Output
912 + thinking = None
913 + try:
914 + # Attempt 1: Check _hidden_params based on isolated test script
915 + if hasattr(resp_item, '_hidden_params') and resp_item._hidden_params and 'thinking' in resp_item._hidden_params:
916 + thinking = resp_item._hidden_params['thinking']
917 + if verbose:
918 + logger.debug("[DEBUG] Extracted thinking output from response._hidden_params['thinking']")
919 + # Attempt 2: Fallback to reasoning_content in message
920 + # Use .get() for safer access
921 + elif hasattr(resp_item, 'choices') and resp_item.choices and hasattr(resp_item.choices[0], 'message') and hasattr(resp_item.choices[0].message, 'get') and resp_item.choices[0].message.get('reasoning_content'):
922 + thinking = resp_item.choices[0].message.get('reasoning_content')
923 + if verbose:
924 + logger.debug("[DEBUG] Extracted thinking output from response.choices[0].message.get('reasoning_content')")
925 +
926 + except (AttributeError, IndexError, KeyError, TypeError):
927 + if verbose:
928 + logger.debug("[DEBUG] Failed to extract thinking output from known locations.")
929 + pass # Ignore if structure doesn't match or errors occur
930 + thinking_outputs.append(thinking)
931 +
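The thinking-extraction logic above can be read as a small helper; _hidden_params is a private LiteLLM attribute, so both lookups are best-effort. This is a sketch restating the same two attempts, not an additional code path in the package.

# Sketch of the two-step thinking extraction: _hidden_params first, then the
# message's reasoning_content field; any unexpected shape yields None.
from typing import Any, Optional


def extract_thinking(resp_item: Any) -> Optional[str]:
    try:
        hidden = getattr(resp_item, "_hidden_params", None) or {}
        if "thinking" in hidden:
            return hidden["thinking"]
        message = resp_item.choices[0].message
        if hasattr(message, "get"):
            return message.get("reasoning_content")
    except (AttributeError, IndexError, KeyError, TypeError):
        pass  # unknown response shape: treat as "no thinking output"
    return None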
932 + # Result (String or Pydantic)
933 + try:
934 + raw_result = resp_item.choices[0].message.content
935 + if output_pydantic:
936 + parsed_result = None
937 + json_string_to_parse = None
938 +
939 + try:
940 + # Attempt 1: Check if LiteLLM already parsed it
941 + if isinstance(raw_result, output_pydantic):
942 + parsed_result = raw_result
943 + if verbose:
944 + logger.debug("[DEBUG] Pydantic object received directly from LiteLLM.")
945 +
946 + # Attempt 2: Check if raw_result is dict-like and validate
947 + elif isinstance(raw_result, dict):
948 + parsed_result = output_pydantic.model_validate(raw_result)
949 + if verbose:
950 + logger.debug("[DEBUG] Validated dictionary-like object directly.")
951 +
952 + # Attempt 3: Process as string (if not already parsed/validated)
953 + elif isinstance(raw_result, str):
954 + json_string_to_parse = raw_result # Start with the raw string
955 + try:
956 + # Look for first { and last }
957 + start_brace = json_string_to_parse.find('{')
958 + end_brace = json_string_to_parse.rfind('}')
959 + if start_brace != -1 and end_brace != -1 and end_brace > start_brace:
960 + potential_json = json_string_to_parse[start_brace:end_brace+1]
961 + # Basic check if it looks like JSON
962 + if potential_json.strip().startswith('{') and potential_json.strip().endswith('}'):
963 + if verbose:
964 + logger.debug(f"[DEBUG] Attempting to parse extracted JSON block: '{potential_json}'")
965 + parsed_result = output_pydantic.model_validate_json(potential_json)
966 + else:
967 + # If block extraction fails, try cleaning markdown next
968 + raise ValueError("Extracted block doesn't look like JSON")
969 + else:
970 + # If no braces found, try cleaning markdown next
971 + raise ValueError("Could not find enclosing {}")
972 + except (json.JSONDecodeError, ValidationError, ValueError) as extraction_error:
973 + if verbose:
974 + logger.debug(f"[DEBUG] JSON block extraction/validation failed ('{extraction_error}'). Trying markdown cleaning.")
975 + # Fallback: Clean markdown fences and retry JSON validation
976 + cleaned_result_str = raw_result.strip()
977 + if cleaned_result_str.startswith("```json"):
978 + cleaned_result_str = cleaned_result_str[7:]
979 + elif cleaned_result_str.startswith("```"):
980 + cleaned_result_str = cleaned_result_str[3:]
981 + if cleaned_result_str.endswith("```"):
982 + cleaned_result_str = cleaned_result_str[:-3]
983 + cleaned_result_str = cleaned_result_str.strip()
984 + # Check again if it looks like JSON before parsing
985 + if cleaned_result_str.startswith('{') and cleaned_result_str.endswith('}'):
986 + if verbose:
987 + logger.debug(f"[DEBUG] Attempting parse after cleaning markdown fences. Cleaned string: '{cleaned_result_str}'")
988 + json_string_to_parse = cleaned_result_str # Update string for error reporting
989 + parsed_result = output_pydantic.model_validate_json(json_string_to_parse)
990 + else:
991 + # If still doesn't look like JSON, raise error
992 + raise ValueError("Content after cleaning markdown doesn't look like JSON")
993 +
994 +
995 + # Check if any parsing attempt succeeded
996 + if parsed_result is None:
997 + # This case should ideally be caught by exceptions above, but as a safeguard:
998 + raise TypeError(f"Raw result type {type(raw_result)} or content could not be validated/parsed against {output_pydantic.__name__}.")
999 +
1000 + except (ValidationError, json.JSONDecodeError, TypeError, ValueError) as parse_error:
1001 + logger.error(f"[ERROR] Failed to parse response into Pydantic model {output_pydantic.__name__} for item {i}: {parse_error}")
1002 + # Use the string that was last attempted for parsing in the error message
1003 + error_content = json_string_to_parse if json_string_to_parse is not None else raw_result
1004 + logger.error("[ERROR] Content attempted for parsing: %s", repr(error_content)) # CORRECTED (or use f-string)
1005 + results.append(f"ERROR: Failed to parse Pydantic. Raw: {repr(raw_result)}")
1006 + continue # Skip appending result below if parsing failed
1007 +
1008 + # If parsing succeeded, append the parsed_result
1009 + results.append(parsed_result)
1010 +
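Outside the loop above, the same string-parsing fallback can be expressed as a standalone helper: try the outermost {...} block first, then strip markdown fences, then validate with Pydantic. The sketch below mirrors the diff's logic under the assumption of Pydantic v2 (model_validate_json); the helper name is illustrative.

# Sketch of the fallback parsing strategy used above, for any BaseModel subclass.
import json
from typing import Type, TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)


def parse_llm_json(raw: str, model_cls: Type[T]) -> T:
    # Attempt 1: the outermost {...} block.
    start, end = raw.find("{"), raw.rfind("}")
    if start != -1 and end > start:
        try:
            return model_cls.model_validate_json(raw[start:end + 1])
        except (json.JSONDecodeError, ValidationError):
            pass  # fall through to markdown-fence cleaning
    # Attempt 2: strip ``` / ```json fences and retry.
    cleaned = raw.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned[7:]
    elif cleaned.startswith("```"):
        cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return model_cls.model_validate_json(cleaned.strip())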
1011 + else:
1012 + # If output_pydantic was not requested, append the raw result
1013 + results.append(raw_result)
1014 +
1015 + except (AttributeError, IndexError) as e:
1016 + logger.error(f"[ERROR] Could not extract result content from response item {i}: {e}")
1017 + results.append(f"ERROR: Could not extract result content. Response: {resp_item}")
1018 +
1019 + # --- Retrieve Cost from Callback Data --- (Reinstated)
1020 + # For batch, this will reflect the cost associated with the *last* item processed by the callback.
1021 + # A fully accurate batch total would require a more complex callback class to aggregate.
1022 + total_cost = _LAST_CALLBACK_DATA.get("cost", 0.0)
1023 + # ----------------------------------------
1024 +
1025 + final_result = results if use_batch_mode else results[0]
1026 + final_thinking = thinking_outputs if use_batch_mode else thinking_outputs[0]
1027 +
1028 + # --- Verbose Output for Success ---
1029 + if verbose:
1030 + # Get token usage from the *last* callback data (might not be accurate for batch)
1031 + input_tokens = _LAST_CALLBACK_DATA.get("input_tokens", 0)
1032 + output_tokens = _LAST_CALLBACK_DATA.get("output_tokens", 0)
1033 +
1034 + cost_input_pm = model_info.get('input', 0.0) if pd.notna(model_info.get('input')) else 0.0
1035 + cost_output_pm = model_info.get('output', 0.0) if pd.notna(model_info.get('output')) else 0.0
1036 +
1037 + logger.info(f"[RESULT] Model Used: {model_name_litellm}")
1038 + logger.info(f"[RESULT] Cost (Input): ${cost_input_pm:.2f}/M tokens")
1039 + logger.info(f"[RESULT] Cost (Output): ${cost_output_pm:.2f}/M tokens")
1040 + logger.info(f"[RESULT] Tokens (Prompt): {input_tokens}")
1041 + logger.info(f"[RESULT] Tokens (Completion): {output_tokens}")
1042 + # Display the cost captured by the callback
1043 + logger.info(f"[RESULT] Total Cost (from callback): ${total_cost:.6g}") # Renamed label for clarity
1044 + logger.info("[RESULT] Max Completion Tokens: Provider Default") # Indicate default limit
1045 + if final_thinking:
1046 + logger.info("[RESULT] Thinking Output:")
1047 + logger.info(final_thinking)
1048 +
1049 + # --- Print raw output before returning if verbose ---
1050 + if verbose:
1051 + logger.debug("[DEBUG] Raw output before return:")
1052 + logger.debug(f"  Raw Result (repr): {repr(final_result)}")
1053 + logger.debug(f"  Raw Thinking (repr): {repr(final_thinking)}")
1054 + logger.debug("-" * 20) # Separator
1055 +
1056 + # --- Return Success ---
1057 + return {
1058 + 'result': final_result,
1059 + 'cost': total_cost,
1060 + 'model_name': model_name_litellm, # Actual model used
1061 + 'thinking_output': final_thinking if final_thinking else None
1062 + }
1063 +
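The total_cost and token counts above come from a module-level success callback that populates _LAST_CALLBACK_DATA earlier in the file. A hedged sketch of how such a callback might be registered is shown below; the callback signature and the 'response_cost' key follow LiteLLM's documented custom-callback interface, and the exact callback in pdd may differ.

# Hedged sketch only: a custom LiteLLM success callback capturing cost/tokens.
import litellm

_LAST_CALLBACK_DATA = {"cost": 0.0, "input_tokens": 0, "output_tokens": 0}


def _track_cost(kwargs, completion_response, start_time, end_time):
    # 'response_cost' is assumed to be present in the callback kwargs.
    _LAST_CALLBACK_DATA["cost"] = kwargs.get("response_cost") or 0.0
    usage = getattr(completion_response, "usage", None)
    if usage is not None:
        _LAST_CALLBACK_DATA["input_tokens"] = usage.prompt_tokens
        _LAST_CALLBACK_DATA["output_tokens"] = usage.completion_tokens


litellm.success_callback = [_track_cost]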
1064 + # --- 6b. Handle Invocation Errors ---
1065 + except openai.AuthenticationError as e:
1066 + last_exception = e
1067 + if newly_acquired_keys.get(api_key_name):
1068 + logger.warning(f"[AUTH ERROR] Authentication failed for {model_name_litellm} with the newly provided key for '{api_key_name}'. Please check the key and try again.")
1069 + # Invalidate the key in env for this session to force re-prompt on retry
1070 + if api_key_name in os.environ:
1071 + del os.environ[api_key_name]
1072 + # Clear the 'newly acquired' status for this key so the next attempt doesn't trigger immediate retry loop
1073 + newly_acquired_keys[api_key_name] = False
1074 + retry_with_same_model = True # Set flag to retry the same model after re-prompt
1075 + # Go back to the start of the 'while retry_with_same_model' loop
1076 + else:
1077 + logger.warning(f"[AUTH ERROR] Authentication failed for {model_name_litellm} using existing key '{api_key_name}'. Trying next model.")
1078 + break # Break inner loop, try next model candidate
1079 +
1080 + except (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError,
1081 + openai.APIStatusError, openai.BadRequestError, openai.InternalServerError,
1082 + Exception) as e: # Catch generic Exception last
1083 + last_exception = e
1084 + error_type = type(e).__name__
1085 + logger.error(f"[ERROR] Invocation failed for {model_name_litellm} ({error_type}): {e}. Trying next model.")
1086 + # Log more details in verbose mode
1087 + if verbose:
1088 + # import traceback # Not needed if using exc_info=True
1089 + logger.debug(f"Detailed exception traceback for {model_name_litellm}:", exc_info=True)
1090 + break # Break inner loop, try next model candidate
1091 +
1092 + # If the inner loop was broken (not by success), continue to the next candidate model
1093 + continue
1094 +
1095 + # --- 8. Handle Failure of All Candidates ---
1096 + error_message = "All candidate models failed."
1097 + if last_exception:
1098 + error_message += f" Last error ({type(last_exception).__name__}): {last_exception}"
1099 + logger.error(f"[FATAL] {error_message}")
1100 + raise RuntimeError(error_message) from last_exception
1101 +
1102 + # --- Example Usage (Optional) ---
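The error handlers above feed a simple outer strategy: try each candidate model in order, remember the last exception, and raise only after every candidate fails. A compact sketch of that pattern follows, with a hypothetical invoke_one callable standing in for the per-model invocation.

# Sketch of the outer fallback strategy; 'invoke_one' is a placeholder.
def invoke_with_fallback(candidates, invoke_one):
    last_exception = None
    for model_name in candidates:
        try:
            return invoke_one(model_name)
        except Exception as exc:  # auth, rate-limit, timeout, bad request, ...
            last_exception = exc
            continue  # move on to the next candidate model
    message = "All candidate models failed."
    if last_exception is not None:
        message += f" Last error ({type(last_exception).__name__}): {last_exception}"
    raise RuntimeError(message) from last_exception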
382 1103   if __name__ == "__main__":
383 -
384 -
1104 + # This block allows running the file directly for testing.
1105 + # Ensure you have a ./data/llm_model.csv file and potentially a .env file.
1106 +
1107 + # Set PDD_DEBUG_SELECTOR=1 to see model selection details
1108 + # os.environ["PDD_DEBUG_SELECTOR"] = "1"
1109 +
1110 + # Example 1: Simple text generation
1111 + logger.info("\n--- Example 1: Simple Text Generation (Strength 0.5) ---")
1112 + try:
1113 + response = llm_invoke(
1114 + prompt="Tell me a short joke about {topic}.",
1115 + input_json={"topic": "programmers"},
1116 + strength=0.5, # Use base model (gpt-4.1-nano)
1117 + temperature=0.7,
1118 + verbose=True
1119 + )
1120 + logger.info("\nExample 1 Response:")
1121 + logger.info(response)
1122 + except Exception as e:
1123 + logger.error(f"\nExample 1 Failed: {e}", exc_info=True)
1124 +
1125 + # Example 1b: Simple text generation (Strength 0.3)
1126 + logger.info("\n--- Example 1b: Simple Text Generation (Strength 0.3) ---")
385 1127   try:
386 -
387 -
388 -
389 -
1128 + response = llm_invoke(
1129 + prompt="Tell me a short joke about {topic}.",
1130 + input_json={"topic": "keyboards"},
1131 + strength=0.3, # Should select gemini-pro based on cost interpolation
1132 + temperature=0.7,
1133 + verbose=True
1134 + )
1135 + logger.info("\nExample 1b Response:")
1136 + logger.info(response)
1137 + except Exception as e:
1138 + logger.error(f"\nExample 1b Failed: {e}", exc_info=True)
1139 +
1140 + # Example 2: Structured output (requires a Pydantic model)
1141 + logger.info("\n--- Example 2: Structured Output (Pydantic, Strength 0.8) ---")
1142 + class JokeStructure(BaseModel):
1143 + setup: str
1144 + punchline: str
1145 + rating: Optional[int] = None
1146 +
1147 + try:
1148 + # Use a model known to support structured output (check your CSV)
1149 + # Strength 0.8 should select gemini-pro based on ELO interpolation
1150 + response_structured = llm_invoke(
1151 + prompt="Create a joke about {topic}. Output ONLY the JSON object with 'setup' and 'punchline'.",
1152 + input_json={"topic": "data science"},
1153 + strength=0.8, # Try a higher ELO model (gemini-pro expected)
1154 + temperature=1,
1155 + output_pydantic=JokeStructure,
1156 + verbose=True
1157 + )
1158 + logger.info("\nExample 2 Response:")
1159 + logger.info(response_structured)
1160 + if isinstance(response_structured.get('result'), JokeStructure):
1161 + logger.info("\nPydantic object received successfully: %s", response_structured['result'].model_dump())
1162 + else:
1163 + logger.info("\nResult was not the expected Pydantic object: %s", response_structured.get('result'))
1164 +
1165 + except Exception as e:
1166 + logger.error(f"\nExample 2 Failed: {e}", exc_info=True)
1167 +
1168 +
1169 + # Example 3: Batch processing
1170 + logger.info("\n--- Example 3: Batch Processing (Strength 0.3) ---")
1171 + try:
1172 + batch_input = [
1173 + {"animal": "cat", "adjective": "lazy"},
1174 + {"animal": "dog", "adjective": "energetic"},
1175 + ]
1176 + # Strength 0.3 should select gemini-pro
1177 + response_batch = llm_invoke(
1178 + prompt="Describe a {adjective} {animal} in one sentence.",
1179 + input_json=batch_input,
1180 + strength=0.3, # Cheaper model maybe (gemini-pro expected)
1181 + temperature=0.5,
1182 + use_batch_mode=True,
1183 + verbose=True
1184 + )
1185 + logger.info("\nExample 3 Response:")
1186 + logger.info(response_batch)
1187 + except Exception as e:
1188 + logger.error(f"\nExample 3 Failed: {e}", exc_info=True)
1189 +
1190 + # Example 4: Using 'messages' input
1191 + logger.info("\n--- Example 4: Using 'messages' input (Strength 0.5) ---")
1192 + try:
1193 + custom_messages = [
1194 + {"role": "system", "content": "You are a helpful assistant."},
1195 + {"role": "user", "content": "What is the capital of France?"}
1196 + ]
1197 + # Strength 0.5 should select gpt-4.1-nano
1198 + response_messages = llm_invoke(
1199 + messages=custom_messages,
1200 + strength=0.5,
1201 + temperature=0.1,
1202 + verbose=True
1203 + )
1204 + logger.info("\nExample 4 Response:")
1205 + logger.info(response_messages)
1206 + except Exception as e:
1207 + logger.error(f"\nExample 4 Failed: {e}", exc_info=True)
1208 +
1209 + # Example 5: Requesting thinking time (e.g., for Anthropic)
1210 + logger.info("\n--- Example 5: Requesting Thinking Time (Strength 1.0, Time 0.5) ---")
1211 + try:
1212 + # Ensure your CSV has max_reasoning_tokens for an Anthropic model
1213 + # Strength 1.0 should select claude-3 (highest ELO)
1214 + # Time 0.5 with budget type should request thinking
1215 + response_thinking = llm_invoke(
1216 + prompt="Explain the theory of relativity simply, taking some time to think.",
1217 + input_json={},
1218 + strength=1.0, # Try to get highest ELO model (claude-3)
1219 + temperature=1,
1220 + time=0.5, # Request moderate thinking time
1221 + verbose=True
1222 + )
1223 + logger.info("\nExample 5 Response:")
1224 + logger.info(response_thinking)
1225 + except Exception as e:
1226 + logger.error(f"\nExample 5 Failed: {e}", exc_info=True)
1227 +
1228 + # Example 6: Pydantic Fallback Parsing (Strength 0.3)
1229 + logger.info("\n--- Example 6: Pydantic Fallback Parsing (Strength 0.3) ---")
1230 + # This requires mocking litellm.completion to return a JSON string
1231 + # even when gemini-pro (which supports structured output) is selected.
1232 + # This is hard to demonstrate cleanly in the __main__ block without mocks.
1233 + # The unit test test_llm_invoke_output_pydantic_unsupported_parses covers this.
1234 + logger.info("(Covered by unit tests)")
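As the comment in Example 6 notes, demonstrating the fallback-parsing path needs a mock. Below is a hedged sketch of how it could be exercised with unittest.mock; the mocked response shape, the reuse of JokeStructure from Example 2, and the assumption that llm_model.csv and keys are already configured are all illustrative, and the package's real unit test may be structured differently.

# Hypothetical sketch only: patch litellm.completion so no network call is made
# and the fenced-JSON fallback in llm_invoke has to do the parsing.
from unittest.mock import MagicMock, patch

fenced = '```json\n{"setup": "Why?", "punchline": "Because.", "rating": 5}\n```'
fake_response = MagicMock()
fake_response.choices[0].message.content = fenced  # string content forces fallback parsing

with patch("litellm.completion", return_value=fake_response):
    mocked = llm_invoke(
        prompt="Create a joke about {topic}.",
        input_json={"topic": "testing"},
        strength=0.3,
        temperature=0.7,
        output_pydantic=JokeStructure,  # defined in Example 2 above
    )
    logger.info("Mocked fallback-parsing result: %s", mocked.get("result"))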