dslighting 1.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
@@ -0,0 +1,421 @@
# dsat/services/llm.py

"""
Unified, asynchronous LLM service powered by LiteLLM.
Provides a simple interface for standard calls, structured JSON output,
and automatic cost tracking.
"""
import logging
import asyncio
import yaml
import copy
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Type, Optional, Any, Dict, List

import litellm
from pydantic import BaseModel, ValidationError

from dsat.config import LLMConfig  # Use the main pydantic config
from dsat.common.exceptions import LLMError

logger = logging.getLogger(__name__)

# Configure LiteLLM globally
litellm.telemetry = False  # Disable anonymous telemetry
litellm.input_callbacks = []
litellm.success_callbacks = []
litellm.failure_callbacks = []

# Load custom model pricing from YAML configuration file
def _load_custom_model_pricing():
    """Load custom model pricing configuration from config.yaml file."""
    try:
        # Get the path to the config.yaml file relative to this module
        current_dir = Path(__file__).parent
        framework_dir = current_dir.parent.parent  # Go up to ds_agent_framework
        config_yaml_path = framework_dir / "config.yaml"

        if config_yaml_path.exists():
            with open(config_yaml_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            return config.get('custom_model_pricing', {})
        else:
            # Changed to debug to avoid confusing warnings for pip-installed packages
            logger.debug(f"Config file not found at {config_yaml_path} (this is expected for pip-installed packages)")
            return {}
    except Exception as e:
        logger.error(f"Failed to load cost configuration: {e}")
        return {}

# Load and apply custom model pricing
CUSTOM_MODEL_PRICING = _load_custom_model_pricing()
if CUSTOM_MODEL_PRICING:
    litellm.model_cost.update(CUSTOM_MODEL_PRICING)
    logger.info(f"Loaded custom model pricing for {len(CUSTOM_MODEL_PRICING)} models")

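Illustration (not part of the packaged file): a minimal sketch of the structure _load_custom_model_pricing() expects under the custom_model_pricing key of config.yaml, written as the dict yaml.safe_load would return. The key names follow LiteLLM's litellm.model_cost schema as I understand it, and the model id is hypothetical.

# Sketch only: shape of the custom_model_pricing mapping merged into litellm.model_cost.
custom_model_pricing = {
    "my-org/internal-model": {            # hypothetical model id
        "input_cost_per_token": 5e-07,    # USD per prompt token
        "output_cost_per_token": 1.5e-06, # USD per completion token
    }
}
# At import time the module effectively runs:
# litellm.model_cost.update(custom_model_pricing)
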
class LLMService:
    """
    A robust wrapper around LiteLLM that handles requests, structured formatting,
    and cost tracking. It's configured via the main DSATConfig's LLMConfig.
    """
    def __init__(self, config: LLMConfig):
        self.config = config
        self.total_cost = 0.0
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_prompt_cost = 0.0
        self.total_completion_cost = 0.0
        self.call_history: List[Dict[str, Any]] = []

    @staticmethod
    def _safe_float(value: Any) -> Optional[float]:
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _safe_int(value: Any) -> Optional[int]:
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    def _is_retryable_error(self, error: Exception) -> bool:
        """
        Determines whether an error is retryable based on litellm exception types.
        Only network timeouts, rate limits, and temporary service issues should be retried.
        Authentication errors, invalid requests, and other permanent failures should not be retried.
        """
        # Import litellm exceptions locally to avoid import issues
        try:
            import litellm.exceptions as litellm_exceptions
        except ImportError:
            # If litellm exceptions module is not available, be conservative and retry
            return True

        # Non-retryable errors - fail immediately
        non_retryable_errors = (
            litellm_exceptions.AuthenticationError,    # API key issues
            litellm_exceptions.InvalidRequestError,    # Request format/parameter issues
            litellm_exceptions.PermissionDeniedError,  # Insufficient permissions
            litellm_exceptions.NotFoundError,          # Model/endpoint not found
        )

        # Retryable errors - can be retried
        retryable_errors = (
            litellm_exceptions.RateLimitError,           # Rate limit exceeded
            litellm_exceptions.ServiceUnavailableError,  # Temporary service issues
            litellm_exceptions.Timeout,                  # Network timeout
            litellm_exceptions.APIConnectionError,       # Connection issues
            litellm_exceptions.InternalServerError,      # Server-side temporary issues
        )

        # Check for specific error types
        if isinstance(error, non_retryable_errors):
            return False
        elif isinstance(error, retryable_errors):
            return True
        else:
            # For unknown errors, be conservative and retry
            # This handles generic network errors, etc.
            return True

    def _supports_response_format(self) -> bool:
        """
        Whether it's safe to pass `response_format` through LiteLLM for this model.

        Some OpenAI reasoning models (e.g. `o4-mini-*`) reject `response_format` and
        require JSON-only behavior to be enforced via prompt instead.
        """
        raw_model = (self.config.model or "").strip()
        model = raw_model.split("/")[-1].strip()
        if model.startswith("o4-mini-") or model == "o4-mini":
            return False
        return True

    async def _make_llm_call_with_retries(
        self, messages: list, response_format: Optional[dict] = None, max_retries: int = 3, base_delay: float = 1.0
    ):
        """
        Internal method to make LLM calls with centralized retry logic and exponential backoff.
        This method is the single point of contact with the LiteLLM library.

        Args:
            messages: List of message dictionaries for the LLM.
            response_format: Optional response format specification.
            max_retries: Maximum number of retry attempts.
            base_delay: Base delay in seconds for exponential backoff.

        Returns:
            The raw LiteLLM response object upon success.

        Raises:
            LLMError: If all retry attempts fail due to API errors or empty responses.
        """
        if response_format and not self._supports_response_format():
            logger.info(
                "Dropping unsupported `response_format` for model %s; using prompt-enforced JSON instead.",
                self.config.model,
            )
            response_format = None

        logger.info(f"prompt: {messages[-1]['content']}")
        last_exception = None
        for attempt in range(max_retries):
            call_id = uuid.uuid4().hex
            call_started_at = datetime.utcnow()
            perf_start = time.perf_counter()
            try:
                kwargs = {
                    "model": self.config.model,
                    "messages": messages,
                    "temperature": self.config.temperature,
                    "api_key": self.config.api_key,
                    "api_base": self.config.api_base
                }
                if self.config.provider:
                    kwargs["custom_llm_provider"] = self.config.provider

                if response_format:
                    kwargs["response_format"] = response_format

                response = await litellm.acompletion(**kwargs)

                try:
                    content = response.choices[0].message.content
                    if content and content.strip():
                        duration = time.perf_counter() - perf_start
                        self._record_successful_call(
                            call_id=call_id,
                            call_started_at=call_started_at,
                            duration=duration,
                            messages=messages,
                            response=response,
                            content=content,
                            response_format=response_format
                        )
                        return response  # Success!
                    else:
                        # Treat empty response as a failure to be retried
                        logger.warning(f"LLM returned an empty response on attempt {attempt + 1}/{max_retries}.")
                        last_exception = LLMError("LLM returned an empty response.")
                except (IndexError, AttributeError) as content_error:
                    logger.warning(f"Invalid response structure on attempt {attempt + 1}/{max_retries}: {content_error}")
                    last_exception = LLMError(f"Invalid response structure: {content_error}")

            except Exception as e:
                # Check if this is a retryable error
                if self._is_retryable_error(e):
                    logger.warning(f"Retryable LLM error on attempt {attempt + 1}/{max_retries}: {e}")
                    last_exception = e
                    logger.debug(f"Debug info - messages: {messages}, response_format: {response_format if response_format else 'None'}")
                else:
                    # Non-retryable error - fail immediately
                    logger.error(f"Non-retryable LLM error: {e}")
                    raise LLMError(f"LLM call failed with non-retryable error: {e}") from e

            # If this was the last attempt, break the loop to raise the final error
            if attempt == max_retries - 1:
                break

            # Exponential backoff with jitter
            delay = base_delay * (3 ** attempt) + (asyncio.get_event_loop().time() % 1)
            logger.info(f"Retrying LLM call in {delay:.2f} seconds ({attempt + 2}/{max_retries})...")
            await asyncio.sleep(delay)

        raise LLMError(f"LLM call failed after {max_retries} attempts. Last error: {last_exception}") from last_exception

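For orientation only (not part of the packaged file): with the defaults above, the backoff grows by a factor of 3 per attempt plus a sub-second jitter term, and no sleep follows the final attempt. A minimal sketch of that schedule:

# Sketch of the retry schedule used by _make_llm_call_with_retries; illustrative only.
base_delay = 1.0
max_retries = 3
jitter = 0.42  # stand-in for asyncio.get_event_loop().time() % 1
for attempt in range(max_retries - 1):  # the final attempt is not followed by a sleep
    print(f"retry {attempt + 2}/{max_retries} after {base_delay * (3 ** attempt) + jitter:.2f}s")
# retry 2/3 after 1.42s
# retry 3/3 after 3.42s
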
    def _record_successful_call(
        self,
        call_id: str,
        call_started_at: datetime,
        duration: float,
        messages: list,
        response: Any,
        content: str,
        response_format: Optional[dict],
    ) -> None:
        """
        Append a successful call to the history and update cumulative token / cost totals.
        """
        usage_payload = self._extract_usage(response)
        try:
            call_cost_raw = litellm.completion_cost(completion_response=response)
            call_cost = float(call_cost_raw) if call_cost_raw is not None else 0.0
        except Exception:
            call_cost = 0.0

        self.total_cost += call_cost

        prompt_tokens = usage_payload.get("prompt_tokens") if usage_payload else None
        completion_tokens = usage_payload.get("completion_tokens") if usage_payload else None

        if prompt_tokens:
            self.total_prompt_tokens += prompt_tokens
        if completion_tokens:
            self.total_completion_tokens += completion_tokens

        prompt_cost_val = usage_payload.get("prompt_tokens_cost") if usage_payload else None
        completion_cost_val = usage_payload.get("completion_tokens_cost") if usage_payload else None

        if prompt_cost_val is not None:
            self.total_prompt_cost += prompt_cost_val
        if completion_cost_val is not None:
            self.total_completion_cost += completion_cost_val

        total_tokens = usage_payload.get("total_tokens") if usage_payload else None
        cost_per_token = (call_cost / total_tokens) if total_tokens else None

        history_entry = {
            "call_id": call_id,
            "model": self.config.model,
            "provider": self.config.provider,
            "timestamp_utc": call_started_at.isoformat() + "Z",
            "duration_seconds": round(duration, 4),
            "response_format": "json_object" if response_format else "text",
            "messages": copy.deepcopy(messages),
            "response": content,
            "usage": usage_payload or None,
            "cost": call_cost,
            "cost_per_token": cost_per_token,
        }
        self.call_history.append(history_entry)
        logger.info(f"LLM call complete. Model: {self.config.model}, Cost: ${call_cost:.6f}")

    def _extract_usage(self, response: Any) -> Dict[str, Any]:
        """
        Extract token / cost information from the LiteLLM response, ensuring it is JSON-serializable.
        """
        usage = getattr(response, "usage", None)
        if not usage:
            return {}

        payload: Dict[str, Any] = {
            "prompt_tokens": self._safe_int(getattr(usage, "prompt_tokens", None)),
            "completion_tokens": self._safe_int(getattr(usage, "completion_tokens", None)),
            "total_tokens": self._safe_int(getattr(usage, "total_tokens", None)),
            "prompt_tokens_cost": self._safe_float(getattr(usage, "prompt_tokens_cost", None)),
            "completion_tokens_cost": self._safe_float(getattr(usage, "completion_tokens_cost", None)),
        }
        total_tokens_cost = self._safe_float(getattr(usage, "total_tokens_cost", None))
        if total_tokens_cost is None:
            prompt_cost = payload.get("prompt_tokens_cost")
            completion_cost = payload.get("completion_tokens_cost")
            if prompt_cost is not None and completion_cost is not None:
                total_tokens_cost = prompt_cost + completion_cost
        payload["total_tokens_cost"] = total_tokens_cost
        return payload

    async def call(self, prompt: str, system_message: Optional[str] = None, max_retries: Optional[int] = None) -> str:
        """
        Makes a standard, asynchronous call to the LLM and returns the text response.
        The retry logic is handled by the internal _make_llm_call_with_retries method.

        Args:
            prompt: The user's prompt.
            system_message: An optional system message to guide the LLM's behavior.
            max_retries: Maximum number of retry attempts (defaults to the configured max_retries).

        Returns:
            The string content of the LLM's response.
        """
        retries = max_retries if max_retries is not None else self.config.max_retries
        messages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": prompt})

        logger.debug(f"Calling LLM ({self.config.model}) with prompt: {prompt[:100]}...")

        response = await self._make_llm_call_with_retries(messages, max_retries=retries)
        content = response.choices[0].message.content
        logger.info(f"content: {content}")
        return content

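Minimal usage sketch of the plain-text call path (illustrative, not part of the packaged file). LLMConfig is defined in dsat/config.py; the field names below are inferred from how this module reads the config and may not match its exact constructor, so treat them as assumptions.

import asyncio
from dsat.config import LLMConfig
from dsat.services.llm import LLMService

async def main() -> None:
    # Hypothetical configuration; real field names and defaults live in dsat/config.py.
    config = LLMConfig(model="gpt-4o-mini", temperature=0.0, api_key="sk-...", api_base=None)
    llm = LLMService(config)
    answer = await llm.call(
        "Summarize the training data in one sentence.",
        system_message="You are a concise data analyst.",
    )
    print(answer)
    print(f"cost so far: ${llm.get_total_cost():.6f}")

asyncio.run(main())
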
    async def call_with_json(self, prompt: str, output_model: Type[BaseModel], max_retries: Optional[int] = None) -> BaseModel:
        """
        Calls the LLM and forces the output to be a JSON object conforming to the
        provided Pydantic model. The retry logic is handled by the internal method.

        Args:
            prompt: The user's prompt.
            output_model: The Pydantic model class for the desired output structure.
            max_retries: Maximum number of retry attempts (defaults to the configured max_retries).

        Returns:
            An instantiated Pydantic model with the LLM's response.
        """
        retries = max_retries if max_retries is not None else self.config.max_retries
        system_message = (
            "You are a helpful assistant that always responds with a JSON object "
            "that strictly adheres to the provided JSON Schema. Do not add any "
            "other text, explanations, or markdown formatting."
        )

        prompt_with_schema = (
            f"{prompt}\n\n# RESPONSE JSON SCHEMA:\n"
            f"```json\n{output_model.model_json_schema()}\n```"
        )

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": prompt_with_schema}
        ]

        logger.debug(f"Calling LLM ({self.config.model}) for structured JSON output...")
        if self._supports_response_format():
            response = await self._make_llm_call_with_retries(
                messages,
                response_format={"type": "json_object"},
                max_retries=retries,
            )
        else:
            logger.info(
                "Model %s does not support `response_format`; falling back to prompt-enforced JSON.",
                self.config.model,
            )
            response = await self._make_llm_call_with_retries(
                messages,
                response_format=None,
                max_retries=retries,
            )

        try:
            response_content = response.choices[0].message.content
            logger.info(f"content: {response_content}")
        except (IndexError, AttributeError) as e:
            raise LLMError(f"Invalid response structure from LLM: {e}") from e

        try:
            return output_model.model_validate_json(response_content)
        except ValidationError as e:
            logger.error(f"Failed to validate LLM JSON response against Pydantic model: {e}")
            logger.debug(f"Invalid JSON received: {response_content}")
            raise LLMError(f"LLM returned invalid JSON that could not be parsed: {e}") from e

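Illustrative sketch of call_with_json (not package code): define a Pydantic model for the expected structure and let the service validate the model's JSON output against it. The ColumnSummary schema below is hypothetical.

from pydantic import BaseModel
from dsat.services.llm import LLMService

class ColumnSummary(BaseModel):  # hypothetical output schema
    name: str
    dtype: str
    missing_ratio: float

async def describe_column(llm: LLMService) -> ColumnSummary:
    # Returns a validated ColumnSummary; unparsable or non-conforming JSON raises LLMError.
    return await llm.call_with_json(
        "Describe the 'age' column of the training dataframe.",
        output_model=ColumnSummary,
    )
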
    def get_total_cost(self) -> float:
        """Returns the total accumulated cost for this LLM instance."""
        return self.total_cost

    def get_call_history(self) -> List[Dict[str, Any]]:
        """Returns a deep copy of the call history for telemetry persistence."""
        return copy.deepcopy(self.call_history)

    def get_usage_summary(self) -> Dict[str, Any]:
        """Summarizes the token / cost totals accumulated by this instance."""
        total_tokens = self.total_prompt_tokens + self.total_completion_tokens
        summary = {
            "prompt_tokens": self.total_prompt_tokens,
            "completion_tokens": self.total_completion_tokens,
            "total_tokens": total_tokens,
            "prompt_tokens_cost": round(self.total_prompt_cost, 12),
            "completion_tokens_cost": round(self.total_completion_cost, 12),
            "total_cost": round(self.total_cost, 12),
            "call_count": len(self.call_history),
        }
        summary["cost_per_token"] = (self.total_cost / total_tokens) if total_tokens else None
        return summary