dslighting-1.3.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. dsat/__init__.py +3 -0
  2. dsat/benchmark/__init__.py +1 -0
  3. dsat/benchmark/benchmark.py +168 -0
  4. dsat/benchmark/datasci.py +291 -0
  5. dsat/benchmark/mle.py +777 -0
  6. dsat/benchmark/sciencebench.py +304 -0
  7. dsat/common/__init__.py +0 -0
  8. dsat/common/constants.py +11 -0
  9. dsat/common/exceptions.py +48 -0
  10. dsat/common/typing.py +19 -0
  11. dsat/config.py +79 -0
  12. dsat/models/__init__.py +3 -0
  13. dsat/models/candidates.py +16 -0
  14. dsat/models/formats.py +52 -0
  15. dsat/models/task.py +64 -0
  16. dsat/operators/__init__.py +0 -0
  17. dsat/operators/aflow_ops.py +90 -0
  18. dsat/operators/autokaggle_ops.py +170 -0
  19. dsat/operators/automind_ops.py +38 -0
  20. dsat/operators/base.py +22 -0
  21. dsat/operators/code.py +45 -0
  22. dsat/operators/dsagent_ops.py +123 -0
  23. dsat/operators/llm_basic.py +84 -0
  24. dsat/prompts/__init__.py +0 -0
  25. dsat/prompts/aflow_prompt.py +76 -0
  26. dsat/prompts/aide_prompt.py +52 -0
  27. dsat/prompts/autokaggle_prompt.py +290 -0
  28. dsat/prompts/automind_prompt.py +29 -0
  29. dsat/prompts/common.py +51 -0
  30. dsat/prompts/data_interpreter_prompt.py +82 -0
  31. dsat/prompts/dsagent_prompt.py +88 -0
  32. dsat/runner.py +554 -0
  33. dsat/services/__init__.py +0 -0
  34. dsat/services/data_analyzer.py +387 -0
  35. dsat/services/llm.py +486 -0
  36. dsat/services/llm_single.py +421 -0
  37. dsat/services/sandbox.py +386 -0
  38. dsat/services/states/__init__.py +0 -0
  39. dsat/services/states/autokaggle_state.py +43 -0
  40. dsat/services/states/base.py +14 -0
  41. dsat/services/states/dsa_log.py +13 -0
  42. dsat/services/states/experience.py +237 -0
  43. dsat/services/states/journal.py +153 -0
  44. dsat/services/states/operator_library.py +290 -0
  45. dsat/services/vdb.py +76 -0
  46. dsat/services/workspace.py +178 -0
  47. dsat/tasks/__init__.py +3 -0
  48. dsat/tasks/handlers.py +376 -0
  49. dsat/templates/open_ended/grade_template.py +107 -0
  50. dsat/tools/__init__.py +4 -0
  51. dsat/utils/__init__.py +0 -0
  52. dsat/utils/context.py +172 -0
  53. dsat/utils/dynamic_import.py +71 -0
  54. dsat/utils/parsing.py +33 -0
  55. dsat/workflows/__init__.py +12 -0
  56. dsat/workflows/base.py +53 -0
  57. dsat/workflows/factory.py +439 -0
  58. dsat/workflows/manual/__init__.py +0 -0
  59. dsat/workflows/manual/autokaggle_workflow.py +148 -0
  60. dsat/workflows/manual/data_interpreter_workflow.py +153 -0
  61. dsat/workflows/manual/deepanalyze_workflow.py +484 -0
  62. dsat/workflows/manual/dsagent_workflow.py +76 -0
  63. dsat/workflows/search/__init__.py +0 -0
  64. dsat/workflows/search/aflow_workflow.py +344 -0
  65. dsat/workflows/search/aide_workflow.py +283 -0
  66. dsat/workflows/search/automind_workflow.py +237 -0
  67. dsat/workflows/templates/__init__.py +0 -0
  68. dsat/workflows/templates/basic_kaggle_loop.py +71 -0
  69. dslighting/__init__.py +170 -0
  70. dslighting/core/__init__.py +13 -0
  71. dslighting/core/agent.py +646 -0
  72. dslighting/core/config_builder.py +318 -0
  73. dslighting/core/data_loader.py +422 -0
  74. dslighting/core/task_detector.py +422 -0
  75. dslighting/utils/__init__.py +19 -0
  76. dslighting/utils/defaults.py +151 -0
  77. dslighting-1.3.9.dist-info/METADATA +554 -0
  78. dslighting-1.3.9.dist-info/RECORD +80 -0
  79. dslighting-1.3.9.dist-info/WHEEL +5 -0
  80. dslighting-1.3.9.dist-info/top_level.txt +2 -0
dsat/workflows/factory.py
@@ -0,0 +1,439 @@
+ # dsat/workflows/factory.py
+
+ from abc import ABC, abstractmethod
+ import inspect
+ import logging
+ from typing import Dict, Any, Optional, Type, TYPE_CHECKING
+
+ # --- Core configuration and interface imports ---
+ from dsat.config import DSATConfig
+ from dsat.workflows.base import DSATWorkflow
+ from dsat.benchmark.benchmark import BaseBenchmark
+
+ # --- Services imports ---
+ from dsat.services.workspace import WorkspaceService
+ from dsat.services.llm import LLMService
+ from dsat.services.sandbox import SandboxService
+ from dsat.services.vdb import VDBService
+
+ # --- State management imports ---
+ from dsat.services.states.journal import JournalState
+ from dsat.services.states.dsa_log import DSAgentState
+
+ # --- Operators imports ---
+ # General operators
+ from dsat.operators.llm_basic import GenerateCodeAndPlanOperator, ReviewOperator, PlanOperator
+ from dsat.operators.code import ExecuteAndTestOperator
+ # AutoMind-specific operators
+ from dsat.operators.automind_ops import ComplexityScorerOperator, PlanDecomposerOperator
+ # DS-Agent-specific operators
+ from dsat.operators.dsagent_ops import DevelopPlanOperator, ExecutePlanOperator, ReviseLogOperator
+ # AutoKaggle-specific operators
+ from dsat.operators.autokaggle_ops import *  # Import all AutoKaggle operators
+
+ from dsat.operators.aflow_ops import ScEnsembleOperator, ReviewOperator as AFlowReviewOperator, ReviseOperator as AFlowReviseOperator
+ from dsat.utils.dynamic_import import import_workflow_from_string
+ from dsat.common.exceptions import DynamicImportError
+
+ if TYPE_CHECKING:
+     # Assumed home of the Operator base class referenced in type hints below.
+     from dsat.operators.base import Operator
+
+ # --- Concrete workflow imports ---
+ from dsat.workflows.search.automind_workflow import AutoMindWorkflow
+ from dsat.workflows.search.aide_workflow import AIDEWorkflow
+ from dsat.workflows.search.aflow_workflow import AFlowWorkflow
+ from dsat.workflows.manual.deepanalyze_workflow import DeepAnalyzeWorkflow
+ from dsat.workflows.manual.dsagent_workflow import DSAgentWorkflow
+ from dsat.workflows.manual.data_interpreter_workflow import DataInterpreterWorkflow
+ from dsat.workflows.manual.autokaggle_workflow import AutoKaggleWorkflow
+
+ logger = logging.getLogger(__name__)
+
+
+ class WorkflowFactory(ABC):
+     """
+     Abstract base class for workflow factories.
+
+     Defines a unified interface for creating workflow instances from configuration.
+     This follows the factory pattern, separating object-creation logic from usage.
+     """
+     @abstractmethod
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> DSATWorkflow:
+         """
+         Create and return a configured workflow instance based on the provided configuration.
+
+         Args:
+             config: Complete DSATConfig object containing all runtime parameters.
+             benchmark: Optional benchmark instance for workflows that evaluate against one.
+
+         Returns:
+             A fully initialized DSATWorkflow instance whose solve() method is ready to run.
+         """
+         raise NotImplementedError
+
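
A minimal sketch of how a caller might dispatch over the concrete factories defined below. The registry mapping and the build_workflow entry point are illustrative assumptions, not part of this file:

    # Hypothetical paradigm-name -> factory registry, for illustration only.
    FACTORY_REGISTRY: Dict[str, Type[WorkflowFactory]] = {
        "aide": AIDEWorkflowFactory,
        "automind": AutoMindWorkflowFactory,
        "dsagent": DSAgentWorkflowFactory,
        "autokaggle": AutoKaggleWorkflowFactory,
    }

    def build_workflow(paradigm: str, config: DSATConfig) -> DSATWorkflow:
        """Pick a concrete factory by name and let it assemble the workflow."""
        try:
            factory_cls = FACTORY_REGISTRY[paradigm]
        except KeyError as e:
            raise ValueError(f"Unknown workflow paradigm: {paradigm!r}") from e
        return factory_cls().create_workflow(config)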
+
+ # ==============================================================================
+ # == AIDE WORKFLOW FACTORY ==
+ # ==============================================================================
+ class AIDEWorkflowFactory(WorkflowFactory):
+     """A specialized factory for creating and assembling AIDEWorkflow."""
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> AIDEWorkflow:
+         logger.info("AIDEWorkflowFactory: Assembling AIDE workflow...")
+
+         workspace_base = None
+         if config.workflow and config.workflow.params:
+             workspace_base = config.workflow.params.get("workspace_base_dir")
+         workspace = WorkspaceService(run_name=config.run.name, base_dir=workspace_base)
+         llm_service = LLMService(config=config.llm)
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+         state = JournalState()
+
+         operators = {
+             "generate": GenerateCodeAndPlanOperator(llm_service=llm_service),
+             "execute": ExecuteAndTestOperator(sandbox_service=sandbox_service),
+             "review": ReviewOperator(llm_service=llm_service),
+         }
+
+         services = {
+             "llm": llm_service,
+             "sandbox": sandbox_service,
+             "state": state,
+             "workspace": workspace,
+         }
+
+         workflow = AIDEWorkflow(
+             operators=operators,
+             services=services,
+             agent_config=config.agent.model_dump(),
+             benchmark=benchmark,
+         )
+
+         logger.info("AIDE workflow assembled successfully.")
+         return workflow
+
+
+ # ==============================================================================
+ # == AUTOMIND WORKFLOW FACTORY ==
+ # ==============================================================================
+ class AutoMindWorkflowFactory(WorkflowFactory):
+     """
+     A specialized factory for creating and assembling AutoMindWorkflow.
+
+     This class encapsulates everything required to create an AutoMindWorkflow,
+     including instantiating its dependent services, state managers, and operators.
+     """
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> AutoMindWorkflow:
+         """
+         Build a fully functional AutoMindWorkflow instance.
+         """
+         logger.info("AutoMindWorkflowFactory: Assembling AutoMind workflow...")
+
+         # 1. Instantiate all base services required by this workflow
+         logger.debug("Instantiating services...")
+         workspace = WorkspaceService(run_name=config.run.name)
+         llm_service = LLMService(config=config.llm)
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+         # Guard against a missing workflow config, mirroring the check in AIDEWorkflowFactory
+         case_dir = 'experience_replay'
+         if config.workflow and config.workflow.params:
+             case_dir = config.workflow.params.get('case_dir', 'experience_replay')
+         vdb_service = VDBService(case_dir=case_dir)
+         state = JournalState()
+
+         # 2. Instantiate all operators required by this workflow, injecting their service dependencies
+         logger.debug("Instantiating operators...")
+         operators = {
+             "generate": GenerateCodeAndPlanOperator(llm_service=llm_service),
+             "execute": ExecuteAndTestOperator(sandbox_service=sandbox_service),
+             "review": ReviewOperator(llm_service=llm_service),
+             "complexity_scorer": ComplexityScorerOperator(llm_service=llm_service),
+             "plan_decomposer": PlanDecomposerOperator(llm_service=llm_service),
+         }
+
+         # 3. Package all services for injection
+         services = {
+             "llm": llm_service,
+             "sandbox": sandbox_service,
+             "vdb": vdb_service,
+             "state": state,
+             "workspace": workspace,  # Also inject the workspace for convenience
+         }
+
+         logger.debug("Instantiating AutoMindWorkflow with dependencies...")
+         workflow = AutoMindWorkflow(
+             operators=operators,
+             services=services,
+             agent_config=config.agent.model_dump(),
+             benchmark=benchmark,
+         )
+
+         logger.info("AutoMind workflow assembled successfully.")
+         return workflow
+
+
+ # ==============================================================================
+ # == DS-AGENT WORKFLOW FACTORY ==
+ # ==============================================================================
+ class DSAgentWorkflowFactory(WorkflowFactory):
+     """A specialized factory for creating and assembling DSAgentWorkflow."""
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> DSAgentWorkflow:
+         logger.info("DSAgentWorkflowFactory: Assembling DS-Agent workflow...")
+
+         workspace = WorkspaceService(run_name=config.run.name)
+         llm_service = LLMService(config=config.llm)
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+         # Guard against a missing workflow config, mirroring the check in AIDEWorkflowFactory
+         case_dir = 'experience_replay'
+         if config.workflow and config.workflow.params:
+             case_dir = config.workflow.params.get('case_dir', 'experience_replay')
+         vdb_service = VDBService(case_dir=case_dir)
+         state = DSAgentState()
+
+         operators = {
+             "planner": DevelopPlanOperator(llm_service=llm_service, vdb_service=vdb_service),
+             "executor": ExecutePlanOperator(llm_service=llm_service, sandbox_service=sandbox_service),
+             "logger": ReviseLogOperator(llm_service=llm_service),
+         }
+
+         services = {
+             "llm": llm_service,
+             "sandbox": sandbox_service,
+             "vdb": vdb_service,
+             "state": state,
+             "workspace": workspace,
+         }
+
+         workflow = DSAgentWorkflow(
+             operators=operators,
+             services=services,
+             agent_config=config.agent.model_dump(),
+         )
+
+         logger.info("DS-Agent workflow assembled successfully.")
+         return workflow
+
+
+ # ==============================================================================
+ # == DATA INTERPRETER WORKFLOW FACTORY ==
+ # ==============================================================================
+ class DataInterpreterWorkflowFactory(WorkflowFactory):
+     """A specialized factory for creating and assembling DataInterpreterWorkflow."""
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> DataInterpreterWorkflow:
+         logger.info("DataInterpreterWorkflowFactory: Assembling Data Interpreter workflow...")
+
+         workspace = WorkspaceService(run_name=config.run.name)
+         llm_service = LLMService(config=config.llm)
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+
+         operators = {
+             "planner": PlanOperator(llm_service=llm_service),
+             "generator": GenerateCodeAndPlanOperator(llm_service=llm_service),
+             "debugger": GenerateCodeAndPlanOperator(llm_service=llm_service),
+             "executor": ExecuteAndTestOperator(sandbox_service=sandbox_service),
+         }
+
+         services = {
+             "llm": llm_service,
+             "sandbox": sandbox_service,
+             "workspace": workspace,
+         }
+
+         workflow = DataInterpreterWorkflow(
+             operators=operators,
+             services=services,
+             agent_config=config.agent.model_dump(),
+         )
+
+         logger.info("Data Interpreter workflow assembled successfully.")
+         return workflow
+
+
+ # ==============================================================================
+ # == AUTOKAGGLE SOP WORKFLOW FACTORY ==
+ # ==============================================================================
+ class AutoKaggleWorkflowFactory(WorkflowFactory):
+     """A specialized factory for creating and assembling the dynamic AutoKaggleWorkflow."""
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> AutoKaggleWorkflow:
+         logger.info("AutoKaggleWorkflowFactory: Assembling AutoKaggle SOP workflow...")
+
+         workspace = WorkspaceService(run_name=config.run.name)
+         llm_service = LLMService(config=config.llm)
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+
+         services = {
+             "llm": llm_service,
+             "sandbox": sandbox_service,
+             "workspace": workspace,
+         }
+
+         # The workflow instantiates its own operators, so we pass an empty dict
+         workflow = AutoKaggleWorkflow(
+             operators={},
+             services=services,
+             agent_config=config.agent.model_dump(),
+         )
+
+         logger.info("AutoKaggle SOP workflow assembled successfully.")
+         return workflow
+
+
+ # ==============================================================================
+ # == DEEPANALYZE WORKFLOW FACTORY ==
+ # ==============================================================================
+ class DeepAnalyzeWorkflowFactory(WorkflowFactory):
+     """Factory for assembling DeepAnalyzeWorkflow."""
+
+     def create_workflow(
+         self,
+         config: DSATConfig,
+         benchmark: Optional[BaseBenchmark] = None,
+     ) -> DeepAnalyzeWorkflow:
+         logger.info("DeepAnalyzeWorkflowFactory: Assembling DeepAnalyze workflow...")
+
+         workspace = WorkspaceService(run_name=config.run.name)
+         llm_service = LLMService(config=config.llm)
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+
+         operators = {
+             "execute": ExecuteAndTestOperator(sandbox_service=sandbox_service),
+         }
+
+         services = {
+             "llm": llm_service,
+             "sandbox": sandbox_service,
+             "workspace": workspace,
+         }
+
+         workflow = DeepAnalyzeWorkflow(
+             operators=operators,
+             services=services,
+             agent_config=config.agent.model_dump(),
+             benchmark=benchmark,
+         )
+
+         logger.info("DeepAnalyze workflow assembled successfully.")
+         return workflow
+
+
+ # ==============================================================================
+ # == AFLOW WORKFLOW FACTORY ==
+ # ==============================================================================
+ class AFlowWorkflowFactory(WorkflowFactory):
+     """A specialized factory for creating and assembling AFlowWorkflow."""
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> AFlowWorkflow:
+         logger.info("AFlowWorkflowFactory: Assembling AFlow workflow...")
+
+         workspace_base = None
+         if config.workflow and config.workflow.params:
+             workspace_base = config.workflow.params.get("workspace_base_dir")
+         workspace = WorkspaceService(run_name=config.run.name, base_dir=workspace_base)
+         llm_service = LLMService(config=config.llm)
+         # Add SandboxService for code execution capabilities
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+
+         services = {
+             "llm": llm_service,
+             "workspace": workspace,
+             "sandbox": sandbox_service,
+         }
+
+         agent_config = config.agent.model_dump()
+         if config.optimizer:
+             agent_config["optimizer"] = config.optimizer.model_dump()
+
+         workflow = AFlowWorkflow(
+             operators={},  # AFlow creates its own operators
+             services=services,
+             agent_config=agent_config,
+             benchmark=benchmark,  # Pass the benchmark instance
+         )
+
+         logger.info("AFlow workflow assembled successfully.")
+         return workflow
+
+
+ class DynamicWorkflowFactory(WorkflowFactory):
+     """
+     A factory that creates a workflow instance from a Python code string at runtime.
+     This is used by the AFlow paradigm to evaluate its discovered "best" workflow.
+     """
+     def __init__(
+         self,
+         code_string: str,
+         operator_classes: Optional[Dict[str, Type["Operator"]]] = None,
+     ):
+         self.code_string = code_string
+         self.operator_classes = operator_classes
+         try:
+             self.workflow_class = import_workflow_from_string(self.code_string)
+         except DynamicImportError as e:
+             raise ValueError("Failed to dynamically import 'Workflow' class from the provided code string.") from e
+
+     def create_workflow(self, config: DSATConfig, benchmark: Optional[BaseBenchmark] = None) -> DSATWorkflow:
+         logger.info("DynamicWorkflowFactory: Instantiating workflow from code string...")
+
+         workspace = WorkspaceService(run_name=config.run.name)
+         llm_service = LLMService(config=config.llm)
+         sandbox_service = SandboxService(workspace=workspace, timeout=config.sandbox.timeout)
+
+         services = {
+             "llm": llm_service,
+             "sandbox": sandbox_service,
+             "workspace": workspace,
+         }
+
+         operators = self._build_operator_instances(
+             llm_service=llm_service,
+             sandbox_service=sandbox_service,
+             workspace=workspace,
+         )
+
+         # Instantiate the dynamically imported class
+         workflow_instance = self.workflow_class(
+             operators=operators,
+             services=services,
+             agent_config=config.agent.model_dump(),
+         )
+         logger.info("Dynamically loaded workflow instantiated successfully.")
+         return workflow_instance
+
+     def _build_operator_instances(
+         self,
+         llm_service: LLMService,
+         sandbox_service: SandboxService,
+         workspace: WorkspaceService,
+     ) -> Dict[str, Any]:
+         """
+         Build operator instances for a dynamically imported workflow.
+
+         - Default: provide the AFlow operators (backwards compatible).
+         - Override: callers may pass `operator_classes` to inject a custom toolbox.
+         """
+         if not self.operator_classes:
+             return {
+                 "ScEnsemble": ScEnsembleOperator(llm_service=llm_service),
+                 "Review": AFlowReviewOperator(llm_service=llm_service),
+                 "Revise": AFlowReviseOperator(llm_service=llm_service),
+             }
+
+         operators: Dict[str, Any] = {}
+         for name, cls in self.operator_classes.items():
+             operators[name] = self._instantiate_operator(
+                 cls=cls,
+                 llm_service=llm_service,
+                 sandbox_service=sandbox_service,
+                 workspace=workspace,
+                 operators=operators,
+             )
+         return operators
+
+     @staticmethod
+     def _instantiate_operator(
+         cls: Type["Operator"],
+         llm_service: LLMService,
+         sandbox_service: SandboxService,
+         workspace: WorkspaceService,
+         operators: Dict[str, Any],
+     ) -> Any:
+         # Inspect the operator's __init__ signature and inject only the
+         # dependencies it actually declares.
+         params = inspect.signature(cls.__init__).parameters
+         kwargs: Dict[str, Any] = {}
+         if "llm_service" in params:
+             kwargs["llm_service"] = llm_service
+         if "sandbox_service" in params:
+             kwargs["sandbox_service"] = sandbox_service
+         if "workspace" in params:
+             kwargs["workspace"] = workspace
+         if "operators" in params:
+             kwargs["operators"] = operators
+         return cls(**kwargs)  # type: ignore[arg-type]
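
A minimal usage sketch for DynamicWorkflowFactory. The code string below is illustrative; the only contract relied on (grounded in the ValueError message above) is that import_workflow_from_string resolves a class named Workflow accepting operators, services, and agent_config keyword arguments:

    CODE = (
        "class Workflow:\n"
        "    def __init__(self, operators, services, agent_config):\n"
        "        self.ensemble = operators['ScEnsemble']\n"
    )

    factory = DynamicWorkflowFactory(code_string=CODE)
    # With operator_classes=None, the default AFlow toolbox (ScEnsemble, Review,
    # Revise) is injected. Passing operator_classes={...} swaps in a custom toolbox;
    # _instantiate_operator then inspects each class's __init__ signature and supplies
    # only the services (llm_service, sandbox_service, workspace, operators) it asks for.
    # workflow = factory.create_workflow(config)  # config: a populated DSATConfig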
dsat/workflows/manual/autokaggle_workflow.py
@@ -0,0 +1,148 @@
+ # dsat/workflows/manual/autokaggle_workflow.py
+ import logging
+ import shutil
+ from pathlib import Path
+ from typing import Dict, Any
+
+ from dsat.workflows.base import DSATWorkflow
+ from dsat.services.llm import LLMService
+ from dsat.services.sandbox import SandboxService
+ from dsat.services.workspace import WorkspaceService
+ from dsat.services.states.autokaggle_state import AutoKaggleState, PhaseMemory, AttemptMemory
+ from dsat.operators.autokaggle_ops import (
+     TaskDeconstructionOperator,
+     AutoKagglePlannerOperator,
+     AutoKaggleDeveloperOperator,
+     DynamicValidationOperator,
+     AutoKaggleReviewerOperator,
+     AutoKaggleSummarizerOperator,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class AutoKaggleWorkflow(DSATWorkflow):
+     """
+     Implements the contract-driven, dynamic AutoKaggle Standard Operating Procedure (SOP).
+     """
+
+     def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any]):
+         super().__init__(operators, services, agent_config)
+         self.workspace: WorkspaceService = services["workspace"]
+         self.llm_service: LLMService = services["llm"]
+         self.sandbox: SandboxService = services["sandbox"]
+
+         # Operator initialization with dependency injection
+         validator = DynamicValidationOperator(llm_service=self.llm_service)
+         self.operators = {
+             "deconstructor": TaskDeconstructionOperator(llm_service=self.llm_service),
+             "planner": AutoKagglePlannerOperator(llm_service=self.llm_service),
+             "developer": AutoKaggleDeveloperOperator(llm_service=self.llm_service, sandbox_service=self.sandbox, validator=validator),
+             "reviewer": AutoKaggleReviewerOperator(llm_service=self.llm_service),
+             "summarizer": AutoKaggleSummarizerOperator(llm_service=self.llm_service),
+         }
+         sop_config = agent_config.get("autokaggle", {})
+         self.config = {
+             "max_attempts_per_phase": sop_config.get("max_attempts_per_phase", 5),  # Allow a few retries per phase
+             "success_threshold": sop_config.get("success_threshold", 3.0),
+         }
+
+     async def solve(self, description: str, io_instructions: str, data_dir: Path, output_path: Path) -> None:
+         logger.info("Starting stateful, contract-driven dynamic SOP workflow...")
+
+         full_context_for_deconstructor = f"{description}\n\n{io_instructions}"
+
+         task_contract = await self.operators["deconstructor"](full_context_for_deconstructor)
+         dynamic_phases = await self.operators["planner"].plan_phases(task_contract)
+
+         state = AutoKaggleState(
+             contract=task_contract,
+             dynamic_phases=dynamic_phases,
+             io_instructions=io_instructions,
+             full_task_description=description,
+         )
+
+         # Resolve the sandbox working directory once, up front; it is needed both for
+         # per-attempt artifact checks and for the final artifact collection below.
+         sandbox_workdir = self.sandbox.workspace.get_path("sandbox_workdir")
+
+         for i, phase_goal in enumerate(state.dynamic_phases):
+             logger.info(f"--- Starting Dynamic Phase {i+1}/{len(state.dynamic_phases)}: '{phase_goal}' ---")
+
+             current_phase_memory = PhaseMemory(phase_goal=phase_goal)
+             phase_succeeded = False
+
+             for attempt in range(self.config["max_attempts_per_phase"]):
+                 logger.info(f"--- Phase '{phase_goal}', Attempt {attempt + 1} ---")
+
+                 step_plan = await self.operators["planner"].plan_step_details(state, phase_goal)
+                 dev_result = await self.operators["developer"](state, phase_goal, step_plan.plan, current_phase_memory.attempts)
+                 review_result = await self.operators["reviewer"](state, phase_goal, dev_result, plan=step_plan.plan)
+
+                 ## --- Strict artifact-based validation --- ##
+                 all_artifacts_produced = True
+                 if dev_result['status']:  # Only check for artifacts if the code ran successfully
+                     if not step_plan.output_files:
+                         logger.warning(f"Phase '{phase_goal}' has no planned output files. Relying on the reviewer score alone.")
+                     for filename in step_plan.output_files:
+                         if not (sandbox_workdir / filename).exists():
+                             logger.error(f"Attempt failed: planned artifact '{filename}' was NOT created.")
+                             all_artifacts_produced = False
+                             break  # No need to check the remaining files
+                 else:
+                     all_artifacts_produced = False
+
+                 attempt_memory = AttemptMemory(
+                     attempt_number=attempt,
+                     plan=step_plan.plan,
+                     code=dev_result['code'],
+                     execution_output=dev_result['output'],
+                     execution_error=dev_result['error'],
+                     validation_result=dev_result.get('validation_result', {}),
+                     review_score=review_result.get('score', 1.0),
+                     review_suggestion=review_result.get('suggestion', 'No suggestion provided.'),
+                 )
+                 current_phase_memory.attempts.append(attempt_memory)
+
+                 if dev_result['status'] and all_artifacts_produced and review_result.get('score', 1.0) >= self.config["success_threshold"]:
+                     logger.info(f"--- Phase '{phase_goal}' Succeeded ---")
+                     phase_succeeded = True
+
+                     # Register newly created artifacts in the global state
+                     # (a local name avoids shadowing the `description` parameter)
+                     for filename in step_plan.output_files:
+                         artifact_description = f"Generated during phase: {phase_goal}"
+                         state.global_artifacts[filename] = artifact_description
+                         current_phase_memory.output_artifacts[filename] = artifact_description
+                         logger.info(f"Registered new artifact: {filename}")
+
+                     break  # Exit the attempt loop on success
+                 else:
+                     logger.warning(f"Attempt failed. Code success: {dev_result['status']}. Artifacts produced: {all_artifacts_produced}. Score: {review_result.get('score', 1.0)}. Retrying...")
+
+             if phase_succeeded:
+                 summary_report = await self.operators["summarizer"](state, current_phase_memory)
+                 current_phase_memory.final_report = summary_report
+                 current_phase_memory.is_successful = True
+                 state.phase_history.append(current_phase_memory)
+             else:
+                 logger.error(f"--- Phase '{phase_goal}' FAILED after all attempts. Aborting workflow. ---")
+                 return  # Abort the entire workflow if a phase fails
+
+         logger.info("All dynamic phases completed successfully.")
+
+         ## --- Final artifact collection --- ##
+         final_submission_filename = None
+         if state.contract.output_files:
+             # Assume the first output file in the contract is the required one.
+             final_submission_filename = state.contract.output_files[0].filename
+
+         if final_submission_filename and final_submission_filename in state.global_artifacts:
+             source_file = sandbox_workdir / final_submission_filename
+             destination_file = output_path
+
+             logger.info(f"Collecting final submission artifact '{source_file}' to '{destination_file}'.")
+             try:
+                 destination_file.parent.mkdir(parents=True, exist_ok=True)
+                 shutil.copy(source_file, destination_file)
+                 logger.info("Final artifact collected successfully.")
+             except Exception as e:
+                 logger.error(f"Failed to collect final artifact: {e}")
+         else:
+             logger.error(f"Workflow finished, but required output file '{final_submission_filename}' was not found in the global artifact registry.")