dslighting-1.3.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. dsat/__init__.py +3 -0
  2. dsat/benchmark/__init__.py +1 -0
  3. dsat/benchmark/benchmark.py +168 -0
  4. dsat/benchmark/datasci.py +291 -0
  5. dsat/benchmark/mle.py +777 -0
  6. dsat/benchmark/sciencebench.py +304 -0
  7. dsat/common/__init__.py +0 -0
  8. dsat/common/constants.py +11 -0
  9. dsat/common/exceptions.py +48 -0
  10. dsat/common/typing.py +19 -0
  11. dsat/config.py +79 -0
  12. dsat/models/__init__.py +3 -0
  13. dsat/models/candidates.py +16 -0
  14. dsat/models/formats.py +52 -0
  15. dsat/models/task.py +64 -0
  16. dsat/operators/__init__.py +0 -0
  17. dsat/operators/aflow_ops.py +90 -0
  18. dsat/operators/autokaggle_ops.py +170 -0
  19. dsat/operators/automind_ops.py +38 -0
  20. dsat/operators/base.py +22 -0
  21. dsat/operators/code.py +45 -0
  22. dsat/operators/dsagent_ops.py +123 -0
  23. dsat/operators/llm_basic.py +84 -0
  24. dsat/prompts/__init__.py +0 -0
  25. dsat/prompts/aflow_prompt.py +76 -0
  26. dsat/prompts/aide_prompt.py +52 -0
  27. dsat/prompts/autokaggle_prompt.py +290 -0
  28. dsat/prompts/automind_prompt.py +29 -0
  29. dsat/prompts/common.py +51 -0
  30. dsat/prompts/data_interpreter_prompt.py +82 -0
  31. dsat/prompts/dsagent_prompt.py +88 -0
  32. dsat/runner.py +554 -0
  33. dsat/services/__init__.py +0 -0
  34. dsat/services/data_analyzer.py +387 -0
  35. dsat/services/llm.py +486 -0
  36. dsat/services/llm_single.py +421 -0
  37. dsat/services/sandbox.py +386 -0
  38. dsat/services/states/__init__.py +0 -0
  39. dsat/services/states/autokaggle_state.py +43 -0
  40. dsat/services/states/base.py +14 -0
  41. dsat/services/states/dsa_log.py +13 -0
  42. dsat/services/states/experience.py +237 -0
  43. dsat/services/states/journal.py +153 -0
  44. dsat/services/states/operator_library.py +290 -0
  45. dsat/services/vdb.py +76 -0
  46. dsat/services/workspace.py +178 -0
  47. dsat/tasks/__init__.py +3 -0
  48. dsat/tasks/handlers.py +376 -0
  49. dsat/templates/open_ended/grade_template.py +107 -0
  50. dsat/tools/__init__.py +4 -0
  51. dsat/utils/__init__.py +0 -0
  52. dsat/utils/context.py +172 -0
  53. dsat/utils/dynamic_import.py +71 -0
  54. dsat/utils/parsing.py +33 -0
  55. dsat/workflows/__init__.py +12 -0
  56. dsat/workflows/base.py +53 -0
  57. dsat/workflows/factory.py +439 -0
  58. dsat/workflows/manual/__init__.py +0 -0
  59. dsat/workflows/manual/autokaggle_workflow.py +148 -0
  60. dsat/workflows/manual/data_interpreter_workflow.py +153 -0
  61. dsat/workflows/manual/deepanalyze_workflow.py +484 -0
  62. dsat/workflows/manual/dsagent_workflow.py +76 -0
  63. dsat/workflows/search/__init__.py +0 -0
  64. dsat/workflows/search/aflow_workflow.py +344 -0
  65. dsat/workflows/search/aide_workflow.py +283 -0
  66. dsat/workflows/search/automind_workflow.py +237 -0
  67. dsat/workflows/templates/__init__.py +0 -0
  68. dsat/workflows/templates/basic_kaggle_loop.py +71 -0
  69. dslighting/__init__.py +170 -0
  70. dslighting/core/__init__.py +13 -0
  71. dslighting/core/agent.py +646 -0
  72. dslighting/core/config_builder.py +318 -0
  73. dslighting/core/data_loader.py +422 -0
  74. dslighting/core/task_detector.py +422 -0
  75. dslighting/utils/__init__.py +19 -0
  76. dslighting/utils/defaults.py +151 -0
  77. dslighting-1.3.9.dist-info/METADATA +554 -0
  78. dslighting-1.3.9.dist-info/RECORD +80 -0
  79. dslighting-1.3.9.dist-info/WHEEL +5 -0
  80. dslighting-1.3.9.dist-info/top_level.txt +2 -0
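The wheel ships two top-level import packages, dsat and dslighting (matching the two entries recorded in top_level.txt). As a quick orientation before the file diff below, a minimal import sanity check, assuming the wheel and the third-party dependencies visible in the diff (mlebench, pandas, PyYAML) are installed; the module paths are taken from the listing above:

    # Hypothetical sanity check; module paths come from the file listing above.
    import dsat
    import dslighting
    from dsat.benchmark.mle import MLEBenchmark  # needs mlebench importable

    print(dsat.__name__, dslighting.__name__)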
dsat/benchmark/mle.py ADDED
@@ -0,0 +1,777 @@
+ # dsat/benchmark/mle.py
+
+ import os
+ import time
+ import uuid
+ import yaml
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Callable, List, Tuple, Optional, Dict
+ import pandas as pd
+ import logging
+
+ from dsat.benchmark.benchmark import BaseBenchmark
+
+ # --- mlebench related imports ---
+ from mlebench.data import is_dataset_prepared
+ from mlebench.grade import aggregate_reports, grade_csv
+ from mlebench.grade_helpers import CompetitionReport
+ from mlebench.registry import Competition, Registry
+ from mlebench.registry import registry as DEFAULT_MLE_REGISTRY
+
+ # --- DSAT core model imports ---
+ from dsat.models.task import TaskDefinition
+
+ logger = logging.getLogger(__name__)
+
+
+ class MLEBenchmark(BaseBenchmark):
+     """
+     Benchmark class to integrate mle_bench competitions into the DSAT framework.
+     """
+     def __init__(
+         self,
+         name: str,
+         file_path: Optional[str],
+         log_path: str,
+         data_dir: Optional[str] = None,
+         competitions: Optional[List[str]] = None,
+         data_source: str = "prepared" # Default to "prepared" for competition simulation
+     ):
+         # Set up data_dir and registry before calling parent constructor
+         self.data_dir = Path(data_dir) if data_dir else DEFAULT_MLE_REGISTRY.get_data_dir()
+         self.registry: Registry = DEFAULT_MLE_REGISTRY.set_data_dir(self.data_dir)
+         self.data_source = data_source # Save data source preference
+
+         # Load configuration
+         self.config = self._load_config()
+
+         if competitions:
+             # Intelligent merge: Filter config entries by CLI args, preserving metadata (like mode).
+             cli_ids = set(competitions)
+             merged_list = []
+
+             # 1. Index existing config entries
+             config_map = {}
+             for entry in self.config.get("competitions", []):
+                 if isinstance(entry, str):
+                     config_map[entry] = {"id": entry, "mode": "standard_ml"}
+                 elif isinstance(entry, dict) and "id" in entry:
+                     config_map[entry["id"]] = entry
+
+             # 2. Build new list based on CLI args
+             for cid in competitions:
+                 if cid in config_map:
+                     merged_list.append(config_map[cid])
+                 else:
+                     # New competition not in config, use default string format
+                     merged_list.append(cid)
+
+             self.config["competitions"] = merged_list
+
+         # file_path is accepted for compatibility but will be ignored.
+         super().__init__(name, file_path, log_path)
+
+         Path(self.log_path).mkdir(parents=True, exist_ok=True)
+
+         # RE-INITIALIZE problems by calling the correct loader after registry is set up.
+         self.problems = self._load_problems()
+         logger.info(f"MLEBenchmark initialized with data_dir: {self.data_dir}")
+
+     def _load_config(self) -> Dict[str, Any]:
+         """Load configuration from config.yaml file."""
+         try:
+             # Get the path to the config.yaml file relative to this module
+             framework_dir = Path(__file__).parent.parent.parent
+             config_path = framework_dir / "config.yaml"
+
+             if not config_path.exists():
+                 logger.warning(f"Config file not found at {config_path}, using default configuration")
+                 return {"competitions": []}
+
+             with open(config_path, 'r', encoding='utf-8') as f:
+                 config = yaml.safe_load(f)
+             return config or {"competitions": []}
+         except Exception as e:
+             logger.error(f"Error loading config file: {e}, using default configuration")
+             return {"competitions": []}
+
+     def _load_problems(self) -> List[Dict[str, Any]]:
+         """
+         Dynamically load competition problems from the mlebench registry
+         instead of a static file. This is the correct integration pattern.
+         """
+         logger.info(f"Discovering prepared competitions in {self.data_dir}...")
+
+         if not self.data_dir.exists():
+             raise FileNotFoundError(
+                 f"MLEBench data directory not found: {self.data_dir}. "
+                 "Please provide a valid path via --mle-data-dir."
+             )
+
+         problems = []
+         # The competition IDs are loaded from configuration (already merged with arguments in __init__)
+         competition_entries = self.config.get("competitions", [])
+
+         if not competition_entries:
+             logger.warning("No competitions configured in config.yaml")
+             return problems
+
+         # Iterate over competitions from configuration
+         for entry in competition_entries:
+             # Handle both string (legacy) and dict (new) config formats
+             if isinstance(entry, str):
+                 comp_id = entry
+                 mode = "standard_ml"
+             elif isinstance(entry, dict):
+                 comp_id = entry.get("id")
+                 mode = entry.get("mode", "standard_ml")
+             else:
+                 logger.warning(f"Invalid competition entry type: {type(entry)}")
+                 continue
+
+             if not comp_id:
+                 continue
+
+             try:
+                 competition = self.registry.get_competition(comp_id)
+
+                 # For open_ended tasks, skip the prepared check and add directly
+                 if mode == "open_ended":
+                     problems.append({
+                         "competition_id": comp_id,
+                         "mode": mode
+                     })
+                     logger.info(f"Found open-ended competition: {comp_id} (mode={mode})")
+                 elif is_dataset_prepared(competition, grading_only=False):
+                     problems.append({
+                         "competition_id": comp_id,
+                         "mode": mode
+                     })
+                     logger.debug(f"Found prepared competition: {comp_id} (mode={mode})")
+                 else:
+                     # Standard ML tasks must be prepared
+                     logger.warning(
+                         f"Skipping standard ML competition '{comp_id}' as its dataset is not fully prepared."
+                     )
+             except Exception as e:
+                 logger.warning(f"Error loading competition '{comp_id}': {e}")
+
+         if not problems:
+             logger.error(
+                 f"No prepared competitions found in {self.data_dir}. "
+                 "Please run `mlebench prep <competition_id>` or `mlebench prep --all`."
+             )
+
+         return problems
+
+     def set_mode(self, mode: str):
+         """Sets the benchmark mode to 'validation' or 'test'."""
+         logger.info(f"Setting MLEBenchmark mode to '{mode}'")
+         self.registry.set_mode(mode)
+
+     async def grade(self, submission_path: Path) -> float:
+         """Grades a submission and returns a numerical fitness score for AFlow."""
+         if not self.problems:
+             return 0.0
+
+         # Assume we are grading for the first (and likely only) competition
+         problem = self.problems[0]
+         competition_id = problem["competition_id"]
+         mode = problem.get("mode", "standard_ml")
+
+         if mode == "open_ended":
+             # For open-ended tasks, use LLM judge for scoring
+             return await self._grade_open_ended(submission_path, competition_id, mode)
+
+         try:
+             competition = self.registry.get_competition(competition_id)
+             if not submission_path.exists():
+                 logger.warning(f"Grading failed: submission file not found at {submission_path}")
+                 return 0.0
+
+             report = grade_csv(submission_path, competition)
+             score = report.score if report.score is not None else 0.0
+
+             # Normalize score to be a fitness value (higher is better)
+             if report.is_lower_better:
+                 return 1.0 / (1.0 + score) if score > 0 else 1.0
+             return score
+
+         except Exception as e:
+             logger.error(f"Error during grading for {competition_id}: {e}")
+             return 0.0
+
+     async def _grade_open_ended(self, artifacts_path: Path, competition_id: str, mode: str) -> float:
+         """
+         Grade open-ended tasks using LLM judge.
+
+         Args:
+             artifacts_path: Path to the artifacts directory
+             competition_id: Competition identifier
+             mode: Task mode (should be "open_ended")
+
+         Returns:
+             Float score between 0.0 and 1.0
+         """
+         try:
+             # Check if artifacts directory exists
+             if not artifacts_path.exists():
+                 logger.warning(f"Artifacts directory not found: {artifacts_path}")
+                 return 0.0
+
+             # Load task description and rubric
+             competition = self.registry.get_competition(competition_id)
+             competition_dir = competition.raw_dir.parent
+
+             description_path = competition_dir / "description.md"
+             rubric_path = competition_dir / "rubric.md"
+
+             description = ""
+             rubric = ""
+
+             if description_path.exists():
+                 description = description_path.read_text(encoding='utf-8')
+             else:
+                 description = competition.description
+
+             if rubric_path.exists():
+                 rubric = rubric_path.read_text(encoding='utf-8')
+
+             # Analyze artifacts directory contents
+             artifact_files = list(artifacts_path.iterdir()) if artifacts_path.is_dir() else []
+
+             # Build file listing
+             file_summary = []
+             for file in sorted(artifact_files):
+                 if file.is_file():
+                     # Read file content based on type
+                     try:
+                         if file.suffix in ['.py', '.md', '.txt', '.csv']:
+                             content = file.read_text(encoding='utf-8', errors='ignore')
+                             # Truncate large files
+                             if len(content) > 2000:
+                                 content = content[:2000] + "\n... (truncated)"
+                             file_summary.append(f"### {file.name}\n```\n{content}\n```")
+                         elif file.suffix in ['.png', '.jpg', '.jpeg']:
+                             file_summary.append(f"### {file.name}\n[Image file - {file.stat().st_size} bytes]")
+                         else:
+                             file_summary.append(f"### {file.name}\n[Binary file - {file.stat().st_size} bytes]")
+                     except Exception as e:
+                         file_summary.append(f"### {file.name}\n[Error reading file: {e}]")
+
+             # If no files found, return minimal score
+             if not file_summary:
+                 logger.warning(f"No artifacts found in {artifacts_path}")
+                 return 0.1 # Small score for at least creating the directory
+
+             # Construct LLM judge prompt
+             # Define separator outside f-string (backslashes not allowed in f-string expressions)
+             file_sep = "\n\n"
+             judge_prompt = f"""You are an expert judge for open-ended data science tasks. Your role is to evaluate submissions based on given criteria.
+
+ ## Task Description
+
+ {description}
+
+ ## Evaluation Criteria
+
+ {rubric if rubric else "Evaluate based on completeness, correctness, and quality of the solution."}
+
+ ## Submitted Artifacts
+
+ The following files were submitted:
+
+ {file_sep.join(file_summary)}
+
+ ## Scoring Instructions
+
+ Evaluate the submission on a scale from 0.0 to 1.0 based on:
+
+ 1. **Completeness** (0-0.3): Did the submission address all aspects of the task?
+ 2. **Correctness** (0-0.4): Are the methods and results correct?
+ 3. **Quality** (0-0.3): Is the code well-structured, documented, and the analysis thorough?
+
+ Provide your response in the following JSON format:
+
+ ```json
+ {{
+ "reasoning": "Brief explanation of the score",
+ "completeness": 0.0-0.3,
+ "correctness": 0.0-0.4,
+ "quality": 0.0-0.3,
+ "total_score": 0.0-1.0
+ }}
+ ```
+
+ Respond ONLY with the JSON object, no additional text.
+ """
+
+             # Call LLM for judgment
+             try:
+                 from dsat.services.llm import LLMService
+                 from dsat.config import LLMConfig
+
+                 # Get LLM service - need to initialize it
+                 # Try to get LLM model from config or environment
+                 llm_model = os.environ.get('LLM_MODEL', 'glm-4.7')
+                 # For DSATRunner context, api_key etc are usually available in env or config
+                 # Here we try to instantiate a service quickly
+
+                 # Assuming environment variables are set from .env
+                 api_key = os.environ.get('API_KEY')
+                 api_base = os.environ.get('API_BASE')
+                 provider = os.environ.get('LLM_PROVIDER', 'openai')
+
+                 llm_config = LLMConfig(
+                     model=llm_model,
+                     api_key=api_key,
+                     api_base=api_base,
+                     provider=provider
+                 )
+                 llm_service = LLMService(llm_config)
+
+                 logger.info(f"Calling LLM judge for {competition_id} with model {llm_model}")
+
+                 # Call LLM
+                 # Note: LLMService.achat is async
+                 messages = [{"role": "user", "content": judge_prompt}]
+                 llm_response = await llm_service.achat(messages)
+
+                 # Parse JSON response
+                 import json
+                 import re
+
+                 # Extract JSON from response
+                 json_match = re.search(r'\{[^{}]*\}', llm_response, re.DOTALL)
+                 if json_match:
+                     result = json.loads(json_match.group())
+                     score = float(result.get('total_score', 0.5))
+
+                     # Log detailed breakdown
+                     completeness = result.get('completeness', 0.0)
+                     correctness = result.get('correctness', 0.0)
+                     quality = result.get('quality', 0.0)
+                     reasoning = result.get('reasoning', '')
+
+                     logger.info(f"LLM judge score for {competition_id}: {score:.2f}")
+                     logger.info(f"  Breakdown: completeness={completeness:.2f}, correctness={correctness:.2f}, quality={quality:.2f}")
+                     logger.info(f"  Reasoning: {reasoning[:200]}...")
+
+                     return max(0.0, min(score, 1.0)) # Ensure score is in [0, 1]
+                 else:
+                     logger.warning(f"Failed to parse LLM judge response as JSON: {llm_response[:200]}")
+                     # Fallback to heuristic if parsing fails
+                     return self._heuristic_score(artifact_files)
+
+             except Exception as llm_error:
+                 logger.error(f"LLM judge call failed: {llm_error}, using heuristic scoring")
+                 return self._heuristic_score(artifact_files)
+
+         except Exception as e:
+             logger.error(f"Error during LLM judge grading: {e}", exc_info=True)
+             return 0.0
+
+     def _heuristic_score(self, artifact_files) -> float:
+         """Fallback heuristic scoring when LLM judge is unavailable."""
+         has_code = any(f.suffix == '.py' for f in artifact_files if f.is_file())
+         has_plots = any(f.suffix in ['.png', '.jpg', '.jpeg'] for f in artifact_files if f.is_file())
+         has_report = any(f.suffix in ['.md', '.txt'] for f in artifact_files if f.is_file())
+
+         # Simple heuristic: 0.6 base + 0.2 for code + 0.1 for plots + 0.1 for report
+         score = 0.6
+         if has_code:
+             score += 0.2
+         if has_plots:
+             score += 0.1
+         if has_report:
+             score += 0.1
+
+         score = min(score, 1.0) # Cap at 1.0
+         logger.info(f"Heuristic score: {score:.2f} (code={has_code}, plots={has_plots}, report={has_report})")
+         return score
+
+
+     def get_result_columns(self) -> List[str]:
+         return [
+             "competition_id", "submission_path", "answers_path", "score", "cost", "running_time",
+             "input_tokens", "output_tokens", "total_tokens",
+             "gold_medal", "silver_medal", "bronze_medal", "above_median",
+             "submission_exists", "valid_submission", "error_message",
+         ]
+
+     def _create_error_report(self, competition_id: str, submission_path: Path, error_msg: str) -> CompetitionReport:
+         """Creates a dummy report if grading or eval_fn execution fails."""
+         # Create a minimal dict that CompetitionReport.from_dict can parse
+         base_data = {
+             "competition_id": competition_id,
+             "score": None,
+             "gold_threshold": float('nan'),
+             "silver_threshold": float('nan'),
+             "bronze_threshold": float('nan'),
+             "median_threshold": float('nan'),
+             "any_medal": False,
+             "gold_medal": False,
+             "silver_medal": False,
+             "bronze_medal": False,
+             "above_median": False,
+             "submission_exists": submission_path.exists(),
+             "valid_submission": False,
+             "is_lower_better": False, # Default
+             "created_at": datetime.now().isoformat(),
+             "submission_path": str(submission_path),
+         }
+         # Use from_dict for safety
+         report = CompetitionReport.from_dict(base_data)
+         # Log the actual error separately
+         logger.error(f"Error for {competition_id}: {error_msg}")
+         return report
+
+     async def evaluate_problem(self, problem: dict, eval_fn: Callable) -> Tuple[Tuple, CompetitionReport, Optional[str]]:
+         """
+         Evaluates a single MLEBench competition.
+         """
+         competition_id = problem.get("competition_id")
+         mode = problem.get("mode", "standard_ml") # Default to standard_ml
+
+         if not competition_id:
+             raise ValueError("Problem data must contain 'competition_id'")
+
+         # Define unique output path
+         # timestamp = datetime.now().strftime('%Y%m%dT%H%M%S')
+         unique_id = uuid.uuid4().hex[:6]
+         output_filename = f"submission_{competition_id}_{unique_id}.csv"
+         output_submission_path = (Path(self.log_path) / output_filename).absolute()
+
+         # Start timing
+         start_time = time.perf_counter()
+
+         cost = 0.0
+         running_time = 0.0
+         input_tokens = 0
+         output_tokens = 0
+         total_tokens = 0
+         report: Optional[CompetitionReport] = None
+         error_message: Optional[str] = None
+         competition: Optional[Competition] = None
+
+         try:
+             competition = self.registry.get_competition(competition_id)
+             # Skip prepared check for open-ended tasks
+             if mode != "open_ended" and not is_dataset_prepared(competition, grading_only=False):
+                 raise ValueError(f"Dataset for '{competition_id}' not prepared in '{self.data_dir}'.")
+
+             # Determine source data directory (Host Path)
+             # This path will be symlinked to './data' (or similar) in the workspace by the Runner/WorkspaceService.
+             if self.data_source == "prepared":
+                 if competition.public_dir.exists():
+                     source_data_dir = competition.public_dir
+                     logger.info(f"Using prepared data source: {source_data_dir}")
+                 else:
+                     # Strict enforcement: Do not fallback to raw if prepared was requested.
+                     # This prevents accidental leakage of raw data (which might contain answers) in competition mode.
+                     raise FileNotFoundError(
+                         f"Prepared public data directory not found at {competition.public_dir}. "
+                         f"Please run `mlebench prep {competition_id}` first, or check your data directory structure."
+                     )
+             else:
+                 # data_source == "raw"
+                 source_data_dir = competition.raw_dir
+                 logger.info(f"Using raw data source: {source_data_dir}")
+
+             # 1. Create standardized TaskDefinition
+             # Determine task_type based on mode
+             task_type = "open_ended" if mode == "open_ended" else "kaggle"
+
+             # Prepare payload based on task type
+             if mode == "open_ended":
+                 # For open-ended tasks, provide file paths to description, rubric, and raw data
+                 competition_dir = competition.raw_dir.parent # Go up from /raw to competition root
+
+                 description_path = competition_dir / "description.md"
+                 rubric_path = competition_dir / "rubric.md"
+
+                 payload = {
+                     "description": competition.description, # Fallback brief description
+                     "description_file": str(description_path) if description_path.exists() else "",
+                     "rubric": "", # Will be loaded from file
+                     "rubric_file": str(rubric_path) if rubric_path.exists() else "",
+                     "public_data_dir": str(source_data_dir.absolute()),
+                     "raw_data_dir": "./data",
+                     "output_submission_path": str(output_submission_path.absolute())
+                 }
+                 logger.info(f"Open-ended task payload: data_source={source_data_dir}, mapped_to=./data")
+             else:
+                 # For standard ML tasks
+                 payload = {
+                     "description": competition.description,
+                     "public_data_dir": str(source_data_dir.absolute()), # Use the determined source
+                     "output_submission_path": str(output_submission_path.absolute())
+                 }
+
+             task = TaskDefinition(
+                 task_id=competition_id,
+                 task_type=task_type,
+                 mode=mode,
+                 payload=payload
+             )
+
+             # 2. Call the generic evaluation function
+             result, cost, usage_summary = await eval_fn(task)
+
+             # 3. Extract token information from usage_summary
+             input_tokens = usage_summary.get("prompt_tokens", 0)
+             output_tokens = usage_summary.get("completion_tokens", 0)
+             total_tokens = usage_summary.get("total_tokens", 0)
+
+             # 3. Process result and perform grading
+             if isinstance(result, Path):
+                 logger.debug(f"Grading submission {result} for {competition_id} (mode={mode})")
+
+                 if mode == "open_ended":
+                     # Skip standard grading for open-ended tasks
+                     logger.info(f"Skipping CSV grading for open-ended task {competition_id}")
+                     # Create a dummy 'success' report
+                     base_data = {
+                         "competition_id": competition_id,
+                         "score": 1.0, # Placeholder score
+                         "gold_threshold": 0.0,
+                         "silver_threshold": 0.0,
+                         "bronze_threshold": 0.0,
+                         "median_threshold": 0.0,
+                         "any_medal": True,
+                         "gold_medal": True, # Mark as success for visualization
+                         "silver_medal": False,
+                         "bronze_medal": False,
+                         "above_median": True,
+                         "submission_exists": result.exists(),
+                         "valid_submission": True, # It ran successfully
+                         "is_lower_better": False,
+                         "created_at": datetime.now().isoformat(),
+                         "submission_path": str(result),
+                     }
+                     report = CompetitionReport.from_dict(base_data)
+                 else:
+                     report = grade_csv(result, competition)
+
+             elif isinstance(result, str) and result.startswith("[ERROR]"):
+                 error_message = f"DSAT workflow failed: {result}"
+                 logger.error(error_message)
+                 report = self._create_error_report(competition_id, output_submission_path, error_message)
+             else:
+                 error_message = f"Unexpected result type from eval_fn: {type(result).__name__}"
+                 logger.error(error_message)
+                 report = self._create_error_report(competition_id, output_submission_path, error_message)
+
+         except Exception as e:
+             error_message = f"Error during MLEBenchmark evaluation of {competition_id}: {e}"
+             logger.error(error_message, exc_info=True)
+             report = self._create_error_report(competition_id, output_submission_path, error_message)
+
+         if report is None:
+             final_error = error_message or "Unknown error: report is None"
+             report = self._create_error_report(competition_id, output_submission_path, final_error)
+             error_message = final_error
+
+         if not report.valid_submission:
+             answers_path_str = str(getattr(competition, 'answers', 'N/A')) if competition else 'N/A'
+             self.log_mismatch(
+                 problem=competition_id,
+                 expected_output=answers_path_str,
+                 prediction=f"File: {output_submission_path}, Exists: {report.submission_exists}, Valid: {report.valid_submission}",
+                 extracted_output=report.score,
+                 extract_answer_code=error_message or "Grading function failed or file invalid/missing"
+             )
+             if not error_message:
+                 error_message = "Submission invalid or missing."
+
+         # Calculate running time
+         running_time = round(time.perf_counter() - start_time, 4)
+
+         answers_path_str = str(getattr(competition, 'answers', 'N/A')) if competition else 'N/A'
+         csv_tuple = (
+             report.competition_id, str(report.submission_path), answers_path_str,
+             report.score, cost, running_time,
+             input_tokens, output_tokens, total_tokens,
+             report.gold_medal, report.silver_medal, report.bronze_medal,
+             report.above_median, report.submission_exists, report.valid_submission, error_message,
+         )
+         return csv_tuple, report, error_message
+
+     async def _grade_open_ended(self, artifacts_path: Path, competition_id: str, mode: str) -> float:
+         """
+         Grade open-ended tasks using LLM judge.
+
+         Args:
+             artifacts_path: Path to the artifacts directory
+             competition_id: Competition identifier
+             mode: Task mode (should be "open_ended")
+
+         Returns:
+             Float score between 0.0 and 1.0
+         """
+         try:
+             # Check if artifacts directory exists
+             if not artifacts_path.exists():
+                 logger.warning(f"Artifacts directory not found: {artifacts_path}")
+                 return 0.0
+
+             # Load task description and rubric
+             competition = self.registry.get_competition(competition_id)
+             competition_dir = competition.raw_dir.parent
+
+             description_path = competition_dir / "description.md"
+             rubric_path = competition_dir / "rubric.md"
+
+             description = ""
+             rubric = ""
+
+             if description_path.exists():
+                 description = description_path.read_text(encoding='utf-8')
+             else:
+                 description = competition.description
+
+             if rubric_path.exists():
+                 rubric = rubric_path.read_text(encoding='utf-8')
+
+             # Analyze artifacts directory contents
+             artifact_files = list(artifacts_path.iterdir()) if artifacts_path.is_dir() else []
+
+             # Build file listing
+             file_summary = []
+             for file in sorted(artifact_files):
+                 if file.is_file():
+                     # Read file content based on type
+                     try:
+                         if file.suffix in ['.py', '.md', '.txt', '.csv']:
+                             content = file.read_text(encoding='utf-8', errors='ignore')
+                             # Truncate large files
+                             if len(content) > 2000:
+                                 content = content[:2000] + "\n... (truncated)"
+                             file_summary.append(f"### {file.name}\n```\n{content}\n```")
+                         elif file.suffix in ['.png', '.jpg', '.jpeg']:
+                             file_summary.append(f"### {file.name}\n[Image file - {file.stat().st_size} bytes]")
+                         else:
+                             file_summary.append(f"### {file.name}\n[Binary file - {file.stat().st_size} bytes]")
+                     except Exception as e:
+                         file_summary.append(f"### {file.name}\n[Error reading file: {e}]")
+
+             # If no files found, return minimal score
+             if not file_summary:
+                 logger.warning(f"No artifacts found in {artifacts_path}")
+                 return 0.1 # Small score for at least creating the directory
+
+             # Construct LLM judge prompt
+             # Define separator outside f-string (backslashes not allowed in f-string expressions)
+             file_sep = "\n\n"
+             judge_prompt = f"""You are an expert judge for open-ended data science tasks. Your role is to evaluate submissions based on given criteria.
+
+ ## Task Description
+
+ {description}
+
+ ## Evaluation Criteria
+
+ {rubric if rubric else "Evaluate based on completeness, correctness, and quality of the solution."}
+
+ ## Submitted Artifacts
+
+ The following files were submitted:
+
+ {file_sep.join(file_summary)}
+
+ ## Scoring Instructions
+
+ Evaluate the submission on a scale from 0.0 to 1.0 based on:
+
+ 1. **Completeness** (0-0.3): Did the submission address all aspects of the task?
+ 2. **Correctness** (0-0.4): Are the methods and results correct?
+ 3. **Quality** (0-0.3): Is the code well-structured, documented, and the analysis thorough?
+
+ Provide your response in the following JSON format:
+
+ ```json
+ {{
+ "reasoning": "Brief explanation of the score",
+ "completeness": 0.0-0.3,
+ "correctness": 0.0-0.4,
+ "quality": 0.0-0.3,
+ "total_score": 0.0-1.0
+ }}
+ ```
+
+ Respond ONLY with the JSON object, no additional text.
+ """
+
+             # Call LLM for judgment
+             try:
+                 from dsat.llm import LLMService
+                 from dsat.config import DSATConfig
+
+                 # Get LLM service - need to initialize it
+                 # Try to get LLM model from config or environment
+                 llm_model = os.environ.get('LLM_MODEL', 'glm-4.7')
+                 llm_provider = os.environ.get('LLM_PROVIDER', 'openai')
+
+                 llm_service = LLMService(
+                     model=llm_model,
+                     provider=llm_provider
+                 )
+
+                 logger.info(f"Calling LLM judge for {competition_id} with model {llm_model}")
+
+                 # Call LLM
+                 llm_response = await llm_service.call(judge_prompt)
+
+                 # Parse JSON response
+                 import json
+                 import re
+
+                 # Extract JSON from response
+                 json_match = re.search(r'\{[^{}]*\}', llm_response, re.DOTALL)
+                 if json_match:
+                     result = json.loads(json_match.group())
+                     score = float(result.get('total_score', 0.5))
+
+                     # Log detailed breakdown
+                     completeness = result.get('completeness', 0.0)
+                     correctness = result.get('correctness', 0.0)
+                     quality = result.get('quality', 0.0)
+                     reasoning = result.get('reasoning', '')
+
+                     logger.info(f"LLM judge score for {competition_id}: {score:.2f}")
+                     logger.info(f"  Breakdown: completeness={completeness:.2f}, correctness={correctness:.2f}, quality={quality:.2f}")
+                     logger.info(f"  Reasoning: {reasoning[:200]}...")
+
+                     return max(0.0, min(score, 1.0)) # Ensure score is in [0, 1]
+                 else:
+                     logger.warning(f"Failed to parse LLM judge response as JSON: {llm_response[:200]}")
+                     # Fallback to heuristic if parsing fails
+                     return self._heuristic_score(artifact_files)
+
+             except Exception as llm_error:
+                 logger.error(f"LLM judge call failed: {llm_error}, using heuristic scoring")
+                 return self._heuristic_score(artifact_files)
+
+         except Exception as e:
+             logger.error(f"Error during LLM judge grading: {e}", exc_info=True)
+             return 0.0
+
+     def _heuristic_score(self, artifact_files) -> float:
+         """Fallback heuristic scoring when LLM judge is unavailable."""
+         has_code = any(f.suffix == '.py' for f in artifact_files if f.is_file())
+         has_plots = any(f.suffix in ['.png', '.jpg', '.jpeg'] for f in artifact_files if f.is_file())
+         has_report = any(f.suffix in ['.md', '.txt'] for f in artifact_files if f.is_file())
+
+         # Simple heuristic: 0.6 base + 0.2 for code + 0.1 for plots + 0.1 for report
+         score = 0.6
+         if has_code:
+             score += 0.2
+         if has_plots:
+             score += 0.1
+         if has_report:
+             score += 0.1
+
+         score = min(score, 1.0) # Cap at 1.0
+         logger.info(f"Heuristic score: {score:.2f} (code={has_code}, plots={has_plots}, report={has_report})")
+         return score
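
Taken together, MLEBenchmark is constructed from a benchmark name, a file_path that is accepted but ignored, a log directory, and optional data_dir, competitions, and data_source arguments; its async grade() wraps mlebench's grade_csv() for standard competitions and maps lower-is-better metrics to a higher-is-better fitness via 1 / (1 + score). A hypothetical driver sketch based only on the signatures shown above; the paths and competition id are placeholders, and the corresponding mlebench dataset must already be prepared (e.g. `mlebench prep <competition_id>`):

    # Hypothetical usage sketch; not part of the wheel.
    import asyncio
    from pathlib import Path

    from dsat.benchmark.mle import MLEBenchmark

    benchmark = MLEBenchmark(
        name="mle",                            # placeholder benchmark name
        file_path=None,                        # accepted for compatibility, ignored
        log_path="./logs",                     # placeholder log directory
        data_dir="/data/mlebench",             # placeholder mlebench data root
        competitions=["spaceship-titanic"],    # placeholder competition id
        data_source="prepared",                # "prepared" (default) or "raw"
    )

    # grade() is async; for standard competitions it calls mlebench's grade_csv(),
    # for open-ended tasks it falls back to the LLM-judge path above.
    fitness = asyncio.run(benchmark.grade(Path("./logs/submission.csv")))
    print(f"fitness: {fitness:.4f}")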