dslighting-1.3.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. dsat/__init__.py +3 -0
  2. dsat/benchmark/__init__.py +1 -0
  3. dsat/benchmark/benchmark.py +168 -0
  4. dsat/benchmark/datasci.py +291 -0
  5. dsat/benchmark/mle.py +777 -0
  6. dsat/benchmark/sciencebench.py +304 -0
  7. dsat/common/__init__.py +0 -0
  8. dsat/common/constants.py +11 -0
  9. dsat/common/exceptions.py +48 -0
  10. dsat/common/typing.py +19 -0
  11. dsat/config.py +79 -0
  12. dsat/models/__init__.py +3 -0
  13. dsat/models/candidates.py +16 -0
  14. dsat/models/formats.py +52 -0
  15. dsat/models/task.py +64 -0
  16. dsat/operators/__init__.py +0 -0
  17. dsat/operators/aflow_ops.py +90 -0
  18. dsat/operators/autokaggle_ops.py +170 -0
  19. dsat/operators/automind_ops.py +38 -0
  20. dsat/operators/base.py +22 -0
  21. dsat/operators/code.py +45 -0
  22. dsat/operators/dsagent_ops.py +123 -0
  23. dsat/operators/llm_basic.py +84 -0
  24. dsat/prompts/__init__.py +0 -0
  25. dsat/prompts/aflow_prompt.py +76 -0
  26. dsat/prompts/aide_prompt.py +52 -0
  27. dsat/prompts/autokaggle_prompt.py +290 -0
  28. dsat/prompts/automind_prompt.py +29 -0
  29. dsat/prompts/common.py +51 -0
  30. dsat/prompts/data_interpreter_prompt.py +82 -0
  31. dsat/prompts/dsagent_prompt.py +88 -0
  32. dsat/runner.py +554 -0
  33. dsat/services/__init__.py +0 -0
  34. dsat/services/data_analyzer.py +387 -0
  35. dsat/services/llm.py +486 -0
  36. dsat/services/llm_single.py +421 -0
  37. dsat/services/sandbox.py +386 -0
  38. dsat/services/states/__init__.py +0 -0
  39. dsat/services/states/autokaggle_state.py +43 -0
  40. dsat/services/states/base.py +14 -0
  41. dsat/services/states/dsa_log.py +13 -0
  42. dsat/services/states/experience.py +237 -0
  43. dsat/services/states/journal.py +153 -0
  44. dsat/services/states/operator_library.py +290 -0
  45. dsat/services/vdb.py +76 -0
  46. dsat/services/workspace.py +178 -0
  47. dsat/tasks/__init__.py +3 -0
  48. dsat/tasks/handlers.py +376 -0
  49. dsat/templates/open_ended/grade_template.py +107 -0
  50. dsat/tools/__init__.py +4 -0
  51. dsat/utils/__init__.py +0 -0
  52. dsat/utils/context.py +172 -0
  53. dsat/utils/dynamic_import.py +71 -0
  54. dsat/utils/parsing.py +33 -0
  55. dsat/workflows/__init__.py +12 -0
  56. dsat/workflows/base.py +53 -0
  57. dsat/workflows/factory.py +439 -0
  58. dsat/workflows/manual/__init__.py +0 -0
  59. dsat/workflows/manual/autokaggle_workflow.py +148 -0
  60. dsat/workflows/manual/data_interpreter_workflow.py +153 -0
  61. dsat/workflows/manual/deepanalyze_workflow.py +484 -0
  62. dsat/workflows/manual/dsagent_workflow.py +76 -0
  63. dsat/workflows/search/__init__.py +0 -0
  64. dsat/workflows/search/aflow_workflow.py +344 -0
  65. dsat/workflows/search/aide_workflow.py +283 -0
  66. dsat/workflows/search/automind_workflow.py +237 -0
  67. dsat/workflows/templates/__init__.py +0 -0
  68. dsat/workflows/templates/basic_kaggle_loop.py +71 -0
  69. dslighting/__init__.py +170 -0
  70. dslighting/core/__init__.py +13 -0
  71. dslighting/core/agent.py +646 -0
  72. dslighting/core/config_builder.py +318 -0
  73. dslighting/core/data_loader.py +422 -0
  74. dslighting/core/task_detector.py +422 -0
  75. dslighting/utils/__init__.py +19 -0
  76. dslighting/utils/defaults.py +151 -0
  77. dslighting-1.3.9.dist-info/METADATA +554 -0
  78. dslighting-1.3.9.dist-info/RECORD +80 -0
  79. dslighting-1.3.9.dist-info/WHEEL +5 -0
  80. dslighting-1.3.9.dist-info/top_level.txt +2 -0
@@ -0,0 +1,421 @@
+ # dsat/services/llm.py
+
+ """
+ Unified, asynchronous LLM service powered by LiteLLM.
+ Provides a simple interface for standard calls, structured JSON output,
+ and automatic cost tracking.
+ """
+ import logging
+ import asyncio
+ import yaml
+ import copy
+ import time
+ import uuid
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Type, Optional, Any, Dict, List
+
+ import litellm
+ from pydantic import BaseModel, ValidationError
+
+ from dsat.config import LLMConfig  # Use the main pydantic config
+ from dsat.common.exceptions import LLMError
+
+ logger = logging.getLogger(__name__)
+
+ # Configure LiteLLM globally (litellm's callback hooks are the singular
+ # `input_callback` / `success_callback` / `failure_callback` lists)
+ litellm.telemetry = False  # Disable anonymous telemetry
+ litellm.input_callback = []
+ litellm.success_callback = []
+ litellm.failure_callback = []
+
+ # Load custom model pricing from a YAML configuration file
+ def _load_custom_model_pricing():
+     """Load custom model pricing configuration from the config.yaml file."""
+     try:
+         # Get the path to the config.yaml file relative to this module
+         current_dir = Path(__file__).parent
+         framework_dir = current_dir.parent.parent  # Go up to ds_agent_framework
+         config_yaml_path = framework_dir / "config.yaml"
+
+         if config_yaml_path.exists():
+             with open(config_yaml_path, 'r', encoding='utf-8') as f:
+                 config = yaml.safe_load(f) or {}  # Guard against an empty YAML file
+             return config.get('custom_model_pricing', {})
+         else:
+             # Debug level to avoid confusing warnings for pip-installed packages
+             logger.debug(f"Config file not found at {config_yaml_path} (this is expected for pip-installed packages)")
+             return {}
+     except Exception as e:
+         logger.error(f"Failed to load cost configuration: {e}")
+         return {}
+
+ # Load and apply custom model pricing
+ CUSTOM_MODEL_PRICING = _load_custom_model_pricing()
+ if CUSTOM_MODEL_PRICING:
+     litellm.model_cost.update(CUSTOM_MODEL_PRICING)
+     logger.info(f"Loaded custom model pricing for {len(CUSTOM_MODEL_PRICING)} models")
+
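+ # For illustration only: a `custom_model_pricing` section in config.yaml might
+ # look like the sketch below. The per-model keys follow the schema litellm uses
+ # in `litellm.model_cost` (costs are per token, not per 1k tokens); the model
+ # name is a hypothetical placeholder, not one shipped with this package.
+ #
+ #   custom_model_pricing:
+ #     my-provider/my-model:
+ #       input_cost_per_token: 2.5e-06
+ #       output_cost_per_token: 1.0e-05
+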
+ class LLMService:
+     """
+     A robust wrapper around LiteLLM that handles requests, structured output
+     formatting, and cost tracking. It is configured via the main DSATConfig's
+     LLMConfig.
+     """
+     def __init__(self, config: LLMConfig):
+         self.config = config
+         self.total_cost = 0.0
+         self.total_prompt_tokens = 0
+         self.total_completion_tokens = 0
+         self.total_prompt_cost = 0.0
+         self.total_completion_cost = 0.0
+         self.call_history: List[Dict[str, Any]] = []
+
+     @staticmethod
+     def _safe_float(value: Any) -> Optional[float]:
+         try:
+             return float(value)
+         except (TypeError, ValueError):
+             return None
+
+     @staticmethod
+     def _safe_int(value: Any) -> Optional[int]:
+         try:
+             return int(value)
+         except (TypeError, ValueError):
+             return None
+
+     def _is_retryable_error(self, error: Exception) -> bool:
+         """
+         Determines whether an error is retryable based on litellm exception types.
+         Only network timeouts, rate limits, and temporary service issues should be retried.
+         Authentication errors, invalid requests, and other permanent failures should not be retried.
+         """
+         # Import litellm exceptions locally to avoid import issues
+         try:
+             import litellm.exceptions as litellm_exceptions
+         except ImportError:
+             # If the litellm exceptions module is not available, be conservative and retry
+             return True
+
+         # Non-retryable errors - fail immediately
+         non_retryable_errors = (
+             litellm_exceptions.AuthenticationError,    # API key issues
+             litellm_exceptions.InvalidRequestError,    # Request format/parameter issues
+             litellm_exceptions.PermissionDeniedError,  # Insufficient permissions
+             litellm_exceptions.NotFoundError,          # Model/endpoint not found
+         )
+
+         # Retryable errors - can be retried
+         retryable_errors = (
+             litellm_exceptions.RateLimitError,           # Rate limit exceeded
+             litellm_exceptions.ServiceUnavailableError,  # Temporary service issues
+             litellm_exceptions.Timeout,                  # Network timeout
+             litellm_exceptions.APIConnectionError,       # Connection issues
+             litellm_exceptions.InternalServerError,      # Server-side temporary issues
+         )
+
+         # Check for specific error types
+         if isinstance(error, non_retryable_errors):
+             return False
+         elif isinstance(error, retryable_errors):
+             return True
+         else:
+             # For unknown errors, be conservative and retry
+             # This handles generic network errors, etc.
+             return True
+
+     def _supports_response_format(self) -> bool:
+         """
+         Whether it's safe to pass `response_format` through LiteLLM for this model.
+
+         Some OpenAI reasoning models (e.g. `o4-mini-*`) reject `response_format` and
+         require JSON-only behavior to be enforced via the prompt instead.
+         """
+         raw_model = (self.config.model or "").strip()
+         model = raw_model.split("/")[-1].strip()
+         if model.startswith("o4-mini-") or model == "o4-mini":
+             return False
+         return True
+
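+     # Illustrative behaviour of this check (the provider-prefixed model names
+     # below are hypothetical examples, derived from the split("/") logic above):
+     #   "o4-mini"        -> response_format is dropped
+     #   "openai/o4-mini" -> response_format is dropped (prefix stripped first)
+     #   "gpt-4o"         -> response_format is passed through
+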
+     async def _make_llm_call_with_retries(
+         self, messages: list, response_format: Optional[dict] = None, max_retries: int = 3, base_delay: float = 1.0
+     ):
+         """
+         Internal method to make LLM calls with centralized retry logic and exponential backoff.
+         This method is the single point of contact with the LiteLLM library.
+
+         Args:
+             messages: List of message dictionaries for the LLM.
+             response_format: Optional response format specification.
+             max_retries: Maximum number of retry attempts.
+             base_delay: Base delay in seconds for exponential backoff.
+
+         Returns:
+             The raw LiteLLM response object upon success.
+
+         Raises:
+             LLMError: If all retry attempts fail due to API errors or empty responses.
+         """
+         if response_format and not self._supports_response_format():
+             logger.info(
+                 "Dropping unsupported `response_format` for model %s; using prompt-enforced JSON instead.",
+                 self.config.model,
+             )
+             response_format = None
+
+         logger.info(f"prompt: {messages[-1]['content']}")
+         last_exception = None
+         for attempt in range(max_retries):
+             call_id = uuid.uuid4().hex
+             call_started_at = datetime.utcnow()
+             perf_start = time.perf_counter()
+             try:
+                 kwargs = {
+                     "model": self.config.model,
+                     "messages": messages,
+                     "temperature": self.config.temperature,
+                     "api_key": self.config.api_key,
+                     "api_base": self.config.api_base,
+                 }
+                 if self.config.provider:
+                     kwargs["custom_llm_provider"] = self.config.provider
+
+                 if response_format:
+                     kwargs["response_format"] = response_format
+
+                 response = await litellm.acompletion(**kwargs)
+
+                 try:
+                     content = response.choices[0].message.content
+                     if content and content.strip():
+                         duration = time.perf_counter() - perf_start
+                         self._record_successful_call(
+                             call_id=call_id,
+                             call_started_at=call_started_at,
+                             duration=duration,
+                             messages=messages,
+                             response=response,
+                             content=content,
+                             response_format=response_format
+                         )
+                         return response  # Success!
+                     else:
+                         # Treat an empty response as a failure to be retried
+                         logger.warning(f"LLM returned an empty response on attempt {attempt + 1}/{max_retries}.")
+                         last_exception = LLMError("LLM returned an empty response.")
+                 except (IndexError, AttributeError) as content_error:
+                     logger.warning(f"Invalid response structure on attempt {attempt + 1}/{max_retries}: {content_error}")
+                     last_exception = LLMError(f"Invalid response structure: {content_error}")
+
+             except Exception as e:
+                 # Check if this is a retryable error
+                 if self._is_retryable_error(e):
+                     logger.warning(f"Retryable LLM error on attempt {attempt + 1}/{max_retries}: {e}")
+                     last_exception = e
+                     logger.debug(f"Debug info - messages: {messages}, response_format: {response_format if response_format else 'None'}")
+                 else:
+                     # Non-retryable error - fail immediately
+                     logger.error(f"Non-retryable LLM error: {e}")
+                     raise LLMError(f"LLM call failed with non-retryable error: {e}") from e
+
+             # If this was the last attempt, break the loop to raise the final error
+             if attempt == max_retries - 1:
+                 break
+
+             # Exponential backoff (base_delay * 3**attempt) plus sub-second jitter
+             # derived from the event-loop clock
+             delay = base_delay * (3 ** attempt) + (asyncio.get_event_loop().time() % 1)
+             logger.info(f"Retrying LLM call in {delay:.2f} seconds ({attempt + 2}/{max_retries})...")
+             await asyncio.sleep(delay)
+
+         raise LLMError(f"LLM call failed after {max_retries} attempts. Last error: {last_exception}") from last_exception
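+
+     # Worked example of the backoff schedule above (a sketch using the signature
+     # defaults max_retries=3, base_delay=1.0): if attempt 1 fails, the loop
+     # sleeps ~1s + jitter (3**0); if attempt 2 fails, ~3s + jitter (3**1); a
+     # third failure breaks out and raises LLMError with no further wait.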
+
+     def _record_successful_call(
+         self,
+         call_id: str,
+         call_started_at: datetime,
+         duration: float,
+         messages: list,
+         response: Any,
+         content: str,
+         response_format: Optional[dict],
+     ) -> None:
+         """
+         Appends a successful call to the history and updates the accumulated
+         token and cost totals.
+         """
+         usage_payload = self._extract_usage(response)
+         try:
+             call_cost_raw = litellm.completion_cost(completion_response=response)
+             call_cost = float(call_cost_raw) if call_cost_raw is not None else 0.0
+         except Exception:
+             call_cost = 0.0
+
+         self.total_cost += call_cost
+
+         prompt_tokens = usage_payload.get("prompt_tokens") if usage_payload else None
+         completion_tokens = usage_payload.get("completion_tokens") if usage_payload else None
+
+         if prompt_tokens:
+             self.total_prompt_tokens += prompt_tokens
+         if completion_tokens:
+             self.total_completion_tokens += completion_tokens
+
+         prompt_cost_val = usage_payload.get("prompt_tokens_cost") if usage_payload else None
+         completion_cost_val = usage_payload.get("completion_tokens_cost") if usage_payload else None
+
+         if prompt_cost_val is not None:
+             self.total_prompt_cost += prompt_cost_val
+         if completion_cost_val is not None:
+             self.total_completion_cost += completion_cost_val
+
+         total_tokens = usage_payload.get("total_tokens") if usage_payload else None
+         cost_per_token = (call_cost / total_tokens) if total_tokens else None
+
+         history_entry = {
+             "call_id": call_id,
+             "model": self.config.model,
+             "provider": self.config.provider,
+             "timestamp_utc": call_started_at.isoformat() + "Z",
+             "duration_seconds": round(duration, 4),
+             "response_format": "json_object" if response_format else "text",
+             "messages": copy.deepcopy(messages),
+             "response": content,
+             "usage": usage_payload or None,
+             "cost": call_cost,
+             "cost_per_token": cost_per_token,
+         }
+         self.call_history.append(history_entry)
+         logger.info(f"LLM call complete. Model: {self.config.model}, Cost: ${call_cost:.6f}")
+
+     def _extract_usage(self, response: Any) -> Dict[str, Any]:
+         """
+         Extracts token and cost information from the LiteLLM response, ensuring
+         the result is JSON-serializable.
+         """
+         usage = getattr(response, "usage", None)
+         if not usage:
+             return {}
+
+         payload: Dict[str, Any] = {
+             "prompt_tokens": self._safe_int(getattr(usage, "prompt_tokens", None)),
+             "completion_tokens": self._safe_int(getattr(usage, "completion_tokens", None)),
+             "total_tokens": self._safe_int(getattr(usage, "total_tokens", None)),
+             "prompt_tokens_cost": self._safe_float(getattr(usage, "prompt_tokens_cost", None)),
+             "completion_tokens_cost": self._safe_float(getattr(usage, "completion_tokens_cost", None)),
+         }
+         total_tokens_cost = self._safe_float(getattr(usage, "total_tokens_cost", None))
+         if total_tokens_cost is None:
+             prompt_cost = payload.get("prompt_tokens_cost")
+             completion_cost = payload.get("completion_tokens_cost")
+             if prompt_cost is not None and completion_cost is not None:
+                 total_tokens_cost = prompt_cost + completion_cost
+         payload["total_tokens_cost"] = total_tokens_cost
+         return payload
+
+     async def call(self, prompt: str, system_message: Optional[str] = None, max_retries: Optional[int] = None) -> str:
+         """
+         Makes a standard, asynchronous call to the LLM and returns the text response.
+         The retry logic is handled by the internal _make_llm_call_with_retries method.
+
+         Args:
+             prompt: The user's prompt.
+             system_message: An optional system message to guide the LLM's behavior.
+             max_retries: Maximum number of retry attempts (defaults to the
+                 configured LLMConfig.max_retries).
+
+         Returns:
+             The string content of the LLM's response.
+         """
+         retries = max_retries if max_retries is not None else self.config.max_retries
+         messages = []
+         if system_message:
+             messages.append({"role": "system", "content": system_message})
+         messages.append({"role": "user", "content": prompt})
+
+         logger.debug(f"Calling LLM ({self.config.model}) with prompt: {prompt[:100]}...")
+
+         response = await self._make_llm_call_with_retries(messages, max_retries=retries)
+         content = response.choices[0].message.content
+         logger.info(f"content: {content}")
+         return content
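+
+     # Example usage (an illustrative sketch, not part of this module; `service`
+     # is an LLMService instance and the call must run inside an event loop):
+     #
+     #   answer = await service.call(
+     #       "Summarize the dataset in two sentences.",
+     #       system_message="You are a concise data analyst.",
+     #   )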
+
+     async def call_with_json(self, prompt: str, output_model: Type[BaseModel], max_retries: Optional[int] = None) -> BaseModel:
+         """
+         Calls the LLM and forces the output to be a JSON object conforming to the
+         provided Pydantic model. The retry logic is handled by the internal method.
+
+         Args:
+             prompt: The user's prompt.
+             output_model: The Pydantic model class for the desired output structure.
+             max_retries: Maximum number of retry attempts (defaults to the
+                 configured LLMConfig.max_retries).
+
+         Returns:
+             An instantiated Pydantic model with the LLM's response.
+         """
+         retries = max_retries if max_retries is not None else self.config.max_retries
+         system_message = (
+             "You are a helpful assistant that always responds with a JSON object "
+             "that strictly adheres to the provided JSON Schema. Do not add any "
+             "other text, explanations, or markdown formatting."
+         )
+
+         prompt_with_schema = (
+             f"{prompt}\n\n# RESPONSE JSON SCHEMA:\n"
+             f"```json\n{output_model.model_json_schema()}\n```"
+         )
+
+         messages = [
+             {"role": "system", "content": system_message},
+             {"role": "user", "content": prompt_with_schema}
+         ]
+
+         logger.debug(f"Calling LLM ({self.config.model}) for structured JSON output...")
+         if self._supports_response_format():
+             response = await self._make_llm_call_with_retries(
+                 messages,
+                 response_format={"type": "json_object"},
+                 max_retries=retries,
+             )
+         else:
+             logger.info(
+                 "Model %s does not support `response_format`; falling back to prompt-enforced JSON.",
+                 self.config.model,
+             )
+             response = await self._make_llm_call_with_retries(
+                 messages,
+                 response_format=None,
+                 max_retries=retries,
+             )
+
+         try:
+             response_content = response.choices[0].message.content
+             logger.info(f"content: {response_content}")
+         except (IndexError, AttributeError) as e:
+             raise LLMError(f"Invalid response structure from LLM: {e}") from e
+
+         try:
+             return output_model.model_validate_json(response_content)
+         except ValidationError as e:
+             logger.error(f"Failed to validate LLM JSON response against Pydantic model: {e}")
+             logger.debug(f"Invalid JSON received: {response_content}")
+             raise LLMError(f"LLM returned invalid JSON that could not be parsed: {e}") from e
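+
+     # Example usage (an illustrative sketch; the Pydantic model shown here is
+     # hypothetical, not part of this package):
+     #
+     #   class Plan(BaseModel):
+     #       steps: List[str]
+     #       rationale: str
+     #
+     #   plan = await service.call_with_json("Propose a modeling plan.", Plan)
+     #   print(plan.steps)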
+
+     def get_total_cost(self) -> float:
+         """Returns the total accumulated cost for this LLM instance."""
+         return self.total_cost
+
+     def get_call_history(self) -> List[Dict[str, Any]]:
+         """Returns a deep copy of the call history for telemetry persistence."""
+         return copy.deepcopy(self.call_history)
+
408
+ def get_usage_summary(self) -> Dict[str, Any]:
409
+ """汇总本实例的 token/费用信息。"""
410
+ total_tokens = self.total_prompt_tokens + self.total_completion_tokens
411
+ summary = {
412
+ "prompt_tokens": self.total_prompt_tokens,
413
+ "completion_tokens": self.total_completion_tokens,
414
+ "total_tokens": total_tokens,
415
+ "prompt_tokens_cost": round(self.total_prompt_cost, 12),
416
+ "completion_tokens_cost": round(self.total_completion_cost, 12),
417
+ "total_cost": round(self.total_cost, 12),
418
+ "call_count": len(self.call_history),
419
+ }
420
+ summary["cost_per_token"] = (self.total_cost / total_tokens) if total_tokens else None
421
+ return summary
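+
+ # Illustrative shape of the dict returned by get_usage_summary() (the numbers
+ # are made up purely to document the keys):
+ #
+ #   {"prompt_tokens": 1200, "completion_tokens": 400, "total_tokens": 1600,
+ #    "prompt_tokens_cost": 0.003, "completion_tokens_cost": 0.004,
+ #    "total_cost": 0.007, "call_count": 5, "cost_per_token": 4.375e-06}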