experiment-configuration-agent 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- experiment_configuration_agent-0.1.0/PKG-INFO +117 -0
- experiment_configuration_agent-0.1.0/README.md +105 -0
- experiment_configuration_agent-0.1.0/pyproject.toml +27 -0
- experiment_configuration_agent-0.1.0/setup.cfg +4 -0
- experiment_configuration_agent-0.1.0/src/experiment_config_agent/__init__.py +0 -0
- experiment_configuration_agent-0.1.0/src/experiment_config_agent/agent.py +57 -0
- experiment_configuration_agent-0.1.0/src/experiment_config_agent/config.py +20 -0
- experiment_configuration_agent-0.1.0/src/experiment_config_agent/constants.py +235 -0
- experiment_configuration_agent-0.1.0/src/experiment_config_agent/models.py +30 -0
- experiment_configuration_agent-0.1.0/src/experiment_configuration_agent.egg-info/PKG-INFO +117 -0
- experiment_configuration_agent-0.1.0/src/experiment_configuration_agent.egg-info/SOURCES.txt +13 -0
- experiment_configuration_agent-0.1.0/src/experiment_configuration_agent.egg-info/dependency_links.txt +1 -0
- experiment_configuration_agent-0.1.0/src/experiment_configuration_agent.egg-info/requires.txt +6 -0
- experiment_configuration_agent-0.1.0/src/experiment_configuration_agent.egg-info/top_level.txt +1 -0
- experiment_configuration_agent-0.1.0/tests/test_agent.py +89 -0
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: experiment-configuration-agent
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: pydantic-settings
|
|
8
|
+
Requires-Dist: sfn-blueprint>=0.6.16
|
|
9
|
+
Provides-Extra: dev
|
|
10
|
+
Requires-Dist: pytest; extra == "dev"
|
|
11
|
+
Requires-Dist: pytest-mock; extra == "dev"
|
|
12
|
+
|
|
13
|
+
# Experiment Configuration Agent for AutoGluon
|
|
14
|
+
|
|
15
|
+
This agent uses a Large Language Model to recommend optimal configurations for AutoGluon's `TabularPredictor` based on your machine learning problem context. By providing details about your domain, use case, and dataset, the agent will generate a set of `TabularPredictor` parameters designed to optimize for performance and efficiency.
|
|
16
|
+
|
|
17
|
+
## Features
|
|
18
|
+
|
|
19
|
+
- **Intelligent Configuration:** Leverages LLMs to recommend `eval_metric`, `presets`, `time_limit`, and ensembling parameters.
|
|
20
|
+
- **Context-Aware:** Considers the business domain, specific use case, ML methodology (e.g., classification, regression), and dataset characteristics.
|
|
21
|
+
- **Flexible Backend:** Powered by `sfn-blueprint`, allowing for a configurable LLM backend.
|
|
22
|
+
- **Multiple Scenarios:** Provides recommendations for different optimization goals, such as maximizing accuracy, balancing performance and speed, or fast prototyping.
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
|
|
26
|
+
This project uses `uv` for dependency management and requires Python 3.10 or higher.
|
|
27
|
+
|
|
28
|
+
1. **Clone the repository:**
|
|
29
|
+
```bash
|
|
30
|
+
git clone <repository-url>
|
|
31
|
+
cd experiment-configuration-agent
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
2. **Set up the environment and install dependencies:**
|
|
35
|
+
It is recommended to use a virtual environment. `uv` can create one for you.
|
|
36
|
+
```bash
|
|
37
|
+
# Install uv if you don't have it
|
|
38
|
+
pip install uv
|
|
39
|
+
|
|
40
|
+
# Create a virtual environment and install dependencies
|
|
41
|
+
uv sync
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Usage
|
|
45
|
+
|
|
46
|
+
To get a configuration recommendation, instantiate the `AutoGluonConfigAgent` and pass a dictionary containing the problem context.
|
|
47
|
+
|
|
48
|
+
1. **Create a `.env` file** in the project root to configure the LLM provider. See the [Configuration](#configuration) section for more details.
|
|
49
|
+
|
|
50
|
+
```
|
|
51
|
+
PROVIDER="openai"
|
|
52
|
+
MODEL="gpt-4-turbo"
|
|
53
|
+
# Add your API key, e.g., OPENAI_API_KEY="sk-..."
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
2. **Create your Python script:**
|
|
57
|
+
|
|
58
|
+
```python
|
|
59
|
+
from experiment_config_agent.agent import AutoGluonConfigAgent
|
|
60
|
+
|
|
61
|
+
# 1. Define the problem context
|
|
62
|
+
task_data = {
|
|
63
|
+
"domain": {
|
|
64
|
+
"name": "Manufacturing",
|
|
65
|
+
"description": "An automotive parts manufacturing facility with multiple production lines."
|
|
66
|
+
},
|
|
67
|
+
"use_case": {
|
|
68
|
+
"name": "Predictive Maintenance",
|
|
69
|
+
"description": "Detect unusual temporal patterns in sensor data to predict equipment failure and prevent breakdowns."
|
|
70
|
+
},
|
|
71
|
+
"methodology": "binary_classification",
|
|
72
|
+
"dataset_insights": {
|
|
73
|
+
"num_samples": 5000,
|
|
74
|
+
"num_features": 10,
|
|
75
|
+
"target": {
|
|
76
|
+
"name": "failure_flag",
|
|
77
|
+
"imbalance_ratio": 0.05 # Highly imbalanced
|
|
78
|
+
},
|
|
79
|
+
"feature_summary": {
|
|
80
|
+
"sensor_A": {"min": 0.1, "max": 100.5, "dtype": "float"},
|
|
81
|
+
"production_line_id": {"unique_count": 3, "dtype": "category"}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
# 2. Initialize the agent
|
|
87
|
+
agent = AutoGluonConfigAgent()
|
|
88
|
+
|
|
89
|
+
# 3. Get the configuration recommendation
|
|
90
|
+
result = agent(task_data)
|
|
91
|
+
|
|
92
|
+
# 4. Print the result
|
|
93
|
+
print("Recommended AutoGluon Configuration:")
|
|
94
|
+
print(result.get("configuration"))
|
|
95
|
+
print("\nCost Summary:")
|
|
96
|
+
print(result.get("cost_summary"))
|
|
97
|
+
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## Configuration
|
|
101
|
+
|
|
102
|
+
The agent is configured via environment variables, which can be placed in a `.env` file in the project root. These settings are defined by the `GluonConfig` class (`experiment_config_agent/config.py`); variable names are case-insensitive.
|
|
103
|
+
|
|
104
|
+
- `PROVIDER`: The LLM provider to use (e.g., `"openai"`, `"anthropic"`).
|
|
105
|
+
- `MODEL`: The specific model to use (e.g., `"gpt-4-turbo"`, `"claude-3-opus-20240229"`).
|
|
106
|
+
- `TEMPERATURE`: The model's temperature setting; must be between `0.0` and `0.5` (default `0.3`).
|
|
107
|
+
- `MAX_TOKENS`: The maximum number of tokens for the response; must be between `0` and `8000` (default `4000`).
|
|
108
|
+
|
|
109
|
+
You will also need to set the API key for your chosen provider, for example `OPENAI_API_KEY="your-key-here"`.
|
|
110
|
+
|
|
111
|
+
## Testing
|
|
112
|
+
|
|
113
|
+
This project uses `pytest`. To run the test suite, execute the following command from the project root:
|
|
114
|
+
|
|
115
|
+
```bash
|
|
116
|
+
pytest
|
|
117
|
+
```
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Experiment Configuration Agent for AutoGluon
|
|
2
|
+
|
|
3
|
+
This agent uses a Large Language Model to recommend optimal configurations for AutoGluon's `TabularPredictor` based on your machine learning problem context. By providing details about your domain, use case, and dataset, the agent will generate a set of `TabularPredictor` parameters designed to optimize for performance and efficiency.
|
|
4
|
+
|
|
5
|
+
## Features
|
|
6
|
+
|
|
7
|
+
- **Intelligent Configuration:** Leverages LLMs to recommend `eval_metric`, `presets`, `time_limit`, and ensembling parameters.
|
|
8
|
+
- **Context-Aware:** Considers the business domain, specific use case, ML methodology (e.g., classification, regression), and dataset characteristics.
|
|
9
|
+
- **Flexible Backend:** Powered by `sfn-blueprint`, allowing for a configurable LLM backend.
|
|
10
|
+
- **Multiple Scenarios:** Provides recommendations for different optimization goals, such as maximizing accuracy, balancing performance and speed, or fast prototyping.
|
|
11
|
+
|
|
12
|
+
## Installation
|
|
13
|
+
|
|
14
|
+
This project uses `uv` for dependency management and requires Python 3.10 or higher.
|
|
15
|
+
|
|
16
|
+
1. **Clone the repository:**
|
|
17
|
+
```bash
|
|
18
|
+
git clone <repository-url>
|
|
19
|
+
cd experiment-configuration-agent
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
2. **Set up the environment and install dependencies:**
|
|
23
|
+
It is recommended to use a virtual environment. `uv` can create one for you.
|
|
24
|
+
```bash
|
|
25
|
+
# Install uv if you don't have it
|
|
26
|
+
pip install uv
|
|
27
|
+
|
|
28
|
+
# Create a virtual environment and install dependencies
|
|
29
|
+
uv sync
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Usage
|
|
33
|
+
|
|
34
|
+
To get a configuration recommendation, instantiate the `AutoGluonConfigAgent` and pass a dictionary containing the problem context.
|
|
35
|
+
|
|
36
|
+
1. **Create a `.env` file** in the project root to configure the LLM provider. See the [Configuration](#configuration) section for more details.
|
|
37
|
+
|
|
38
|
+
```
|
|
39
|
+
PROVIDER="openai"
|
|
40
|
+
MODEL="gpt-4-turbo"
|
|
41
|
+
# Add your API key, e.g., OPENAI_API_KEY="sk-..."
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
2. **Create your Python script:**
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from experiment_config_agent.agent import AutoGluonConfigAgent
|
|
48
|
+
|
|
49
|
+
# 1. Define the problem context
|
|
50
|
+
task_data = {
|
|
51
|
+
"domain": {
|
|
52
|
+
"name": "Manufacturing",
|
|
53
|
+
"description": "An automotive parts manufacturing facility with multiple production lines."
|
|
54
|
+
},
|
|
55
|
+
"use_case": {
|
|
56
|
+
"name": "Predictive Maintenance",
|
|
57
|
+
"description": "Detect unusual temporal patterns in sensor data to predict equipment failure and prevent breakdowns."
|
|
58
|
+
},
|
|
59
|
+
"methodology": "binary_classification",
|
|
60
|
+
"dataset_insights": {
|
|
61
|
+
"num_samples": 5000,
|
|
62
|
+
"num_features": 10,
|
|
63
|
+
"target": {
|
|
64
|
+
"name": "failure_flag",
|
|
65
|
+
"imbalance_ratio": 0.05 # Highly imbalanced
|
|
66
|
+
},
|
|
67
|
+
"feature_summary": {
|
|
68
|
+
"sensor_A": {"min": 0.1, "max": 100.5, "dtype": "float"},
|
|
69
|
+
"production_line_id": {"unique_count": 3, "dtype": "category"}
|
|
70
|
+
}
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
# 2. Initialize the agent
|
|
75
|
+
agent = AutoGluonConfigAgent()
|
|
76
|
+
|
|
77
|
+
# 3. Get the configuration recommendation
|
|
78
|
+
result = agent(task_data)
|
|
79
|
+
|
|
80
|
+
# 4. Print the result
|
|
81
|
+
print("Recommended AutoGluon Configuration:")
|
|
82
|
+
print(result.get("configuration"))
|
|
83
|
+
print("\nCost Summary:")
|
|
84
|
+
print(result.get("cost_summary"))
|
|
85
|
+
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
## Configuration
|
|
89
|
+
|
|
90
|
+
The agent is configured via environment variables, which can be placed in a `.env` file in the project root. These settings are defined by the `GluonConfig` class (`experiment_config_agent/config.py`); variable names are case-insensitive.
|
|
91
|
+
|
|
92
|
+
- `PROVIDER`: The LLM provider to use (e.g., `"openai"`, `"anthropic"`).
|
|
93
|
+
- `MODEL`: The specific model to use (e.g., `"gpt-4-turbo"`, `"claude-3-opus-20240229"`).
|
|
94
|
+
- `TEMPERATURE`: The model's temperature setting; must be between `0.0` and `0.5` (default `0.3`).
|
|
95
|
+
- `MAX_TOKENS`: The maximum number of tokens for the response; must be between `0` and `8000` (default `4000`).
|
|
96
|
+
|
|
97
|
+
You will also need to set the API key for your chosen provider, for example `OPENAI_API_KEY="your-key-here"`.
|
|
98
|
+
|
|
99
|
+
## Testing
|
|
100
|
+
|
|
101
|
+
This project uses `pytest`. To run the test suite, execute the following command from the project root:
|
|
102
|
+
|
|
103
|
+
```bash
|
|
104
|
+
pytest
|
|
105
|
+
```
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "experiment-configuration-agent"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "LLM-powered agent that recommends AutoGluon TabularPredictor configurations"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.10"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"pydantic-settings",
|
|
9
|
+
"sfn-blueprint>=0.6.16",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
[project.optional-dependencies]
|
|
13
|
+
dev = [
|
|
14
|
+
"pytest",
|
|
15
|
+
"pytest-mock",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[tool.setuptools.packages.find]
|
|
19
|
+
where = ["src"]
|
|
20
|
+
|
|
21
|
+
[tool.pytest.ini_options]
|
|
22
|
+
pythonpath = [
|
|
23
|
+
"src"
|
|
24
|
+
]
|
|
25
|
+
filterwarnings = [
|
|
26
|
+
"ignore::DeprecationWarning",
|
|
27
|
+
]
|
|
File without changes
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
|
|
2
|
+
from sfn_blueprint import BaseLangChainAgent
|
|
3
|
+
from .config import GluonConfig
|
|
4
|
+
from .constants import format_autogluon_config_prompt
|
|
5
|
+
from .models import AutoGluonConfig
|
|
6
|
+
from typing import Dict, Any, Tuple, Optional
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AutoGluonConfigAgent(BaseLangChainAgent):
    """Agent that asks an LLM for an AutoGluon ``TabularPredictor`` configuration.

    The problem context is rendered into a system/user prompt pair by
    ``format_autogluon_config_prompt`` and handed to ``route_with_langchain``
    together with the ``AutoGluonConfig`` model.
    """

    def __init__(self, config: Optional[GluonConfig] = None):
        """Initialise the agent.

        Args:
            config: Settings passed to the base agent. When omitted (or
                falsy), a fresh ``GluonConfig()`` is created instead.
        """
        super().__init__(config or GluonConfig())

    def configure_training(
        self,
        domain: Dict[str, str],
        use_case: Dict[str, str],
        methodology: str,
        dataset_insights: Dict[str, Any],
    ) -> Tuple[AutoGluonConfig, Any]:
        """Generate an AutoGluon configuration from domain context and data insights.

        Args:
            domain: Description of the business domain.
            use_case: Description of the specific use case.
            methodology: Problem type, e.g. ``"binary_classification"``.
            dataset_insights: Dataset characteristics to show to the LLM.

        Returns:
            The ``(response, cost_summary)`` pair exactly as returned by
            ``route_with_langchain``. The response is presumably an
            ``AutoGluonConfig`` instance (``execute_task`` calls
            ``model_dump()`` on it) -- TODO confirm against sfn-blueprint.
        """
        system_prompt, user_prompt = format_autogluon_config_prompt(
            domain=domain,
            use_case=use_case,
            methodology=methodology,
            dataset_insights=dataset_insights,
        )

        # AutoGluonConfig is passed as the third positional argument;
        # presumably the structured-output schema -- verify in sfn-blueprint.
        response, cost_summary = self.route_with_langchain(
            system_prompt,
            user_prompt,
            AutoGluonConfig,
        )

        return response, cost_summary

    def execute_task(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``configure_training`` from a standard task dictionary.

        Args:
            task_data: Mapping that must contain the keys ``"domain"``,
                ``"use_case"``, ``"methodology"`` and ``"dataset_insights"``.

        Returns:
            A dict with ``"configuration"`` (the recommendation dumped via
            ``model_dump()``) and ``"cost_summary"``.

        Raises:
            KeyError: If any required key is absent. All missing keys are
                reported at once, before any LLM call is made.
        """
        required_keys = ("domain", "use_case", "methodology", "dataset_insights")
        missing = [key for key in required_keys if key not in task_data]
        if missing:
            # Still a KeyError (as a plain lookup would raise), but naming
            # every missing key instead of only the first one.
            raise KeyError(
                f"task_data is missing required key(s): {', '.join(missing)}"
            )

        config_recommendation, cost = self.configure_training(
            domain=task_data["domain"],
            use_case=task_data["use_case"],
            methodology=task_data["methodology"],
            dataset_insights=task_data["dataset_insights"],
        )

        return {
            "configuration": config_recommendation.model_dump(),
            "cost_summary": cost,
        }

    def __call__(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Allow ``agent(task_data)`` as shorthand for ``execute_task``."""
        return self.execute_task(task_data)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from pydantic import Field
|
|
2
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GluonConfig(BaseSettings):
    """LLM backend settings for the configuration agent.

    Values are loaded by ``pydantic-settings``; ``model_config`` below points
    it at a UTF-8 ``.env`` file, makes names case-insensitive and ignores
    unknown entries.
    """

    provider: str = Field(
        default="openai",
        description="AI provider to use",
    )
    model: str = Field(
        default="gpt-5-mini",
        description="AI model to use",
    )
    # Validation bounds: temperature in [0.0, 0.5], max_tokens in [0, 8000].
    temperature: float = Field(
        default=0.3,
        ge=0.0,
        le=0.5,
        description="AI model temperature",
    )
    max_tokens: int = Field(
        default=4000,
        ge=0,
        le=8000,
        description="Maximum tokens for AI response",
    )

    model_config = SettingsConfigDict(
        extra='ignore',
        case_sensitive=False,
        env_file_encoding='utf-8',
        env_file='.env',
    )
|
|
19
|
+
|
|
20
|
+
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
AUTOGLUON_CONFIG_SYSTEM_PROMPT = """You are an expert AutoGluon configuration advisor specializing in optimizing TabularPredictor settings for various machine learning problems.
|
|
2
|
+
|
|
3
|
+
Your role is to analyze the provided information about:
|
|
4
|
+
1. The domain and business context
|
|
5
|
+
2. The specific use case and problem type
|
|
6
|
+
3. The methodology (classification type or regression)
|
|
7
|
+
4. Dataset characteristics from the modeling-ready insights
|
|
8
|
+
|
|
9
|
+
Based on this analysis, you must recommend optimal AutoGluon TabularPredictor configuration parameters that balance:
|
|
10
|
+
- Predictive accuracy and model quality
|
|
11
|
+
- Training time and computational resources
|
|
12
|
+
- Inference speed requirements
|
|
13
|
+
- Model complexity and interpretability needs
|
|
14
|
+
- Domain-specific requirements
|
|
15
|
+
|
|
16
|
+
Key AutoGluon Concepts:
|
|
17
|
+
======================
|
|
18
|
+
|
|
19
|
+
EVALUATION METRICS:
|
|
20
|
+
- For binary classification: f1, precision, recall, accuracy, roc_auc, balanced_accuracy, log_loss, mcc
|
|
21
|
+
- For multiclass classification: accuracy, balanced_accuracy, log_loss, f1_macro, f1_micro, f1_weighted
|
|
22
|
+
- For regression: rmse, mae, r2, mse, mape
|
|
23
|
+
- Choose metrics based on business requirements (e.g., f1 for imbalanced classes, accuracy for balanced classes)
|
|
24
|
+
|
|
25
|
+
PRESETS (Quality vs Speed Tradeoff):
|
|
26
|
+
- "best_quality": Maximum accuracy with auto_stack=True, uses bagging/stacking (slowest, most accurate)
|
|
27
|
+
- "high_quality": Strong accuracy with fast inference, moderate complexity
|
|
28
|
+
- "good_quality": Good accuracy with very fast inference (recommended default)
|
|
29
|
+
- "medium_quality": Fast training time, ideal for prototyping
|
|
30
|
+
- "optimize_for_deployment": Minimal memory/compute footprint
|
|
31
|
+
- "interpretable": Simple, explainable models (e.g., linear models, shallow trees)
|
|
32
|
+
|
|
33
|
+
BAGGING & STACKING (Ensemble Methods):
|
|
34
|
+
- num_bag_folds: 0=no bagging, 5-10=k-fold bagging (improves accuracy, increases training time)
|
|
35
|
+
- num_bag_sets: Number of times to repeat k-fold bagging (1-3, only if num_bag_folds>0)
|
|
36
|
+
- num_stack_levels: 0=no stacking, 1=single-level stacking, 2+=multi-level stacking
|
|
37
|
+
- auto_stack: True=AutoGluon automatically determines optimal stacking/bagging (recommended for best quality)
|
|
38
|
+
- refit_full: True=retrain on full dataset after validation (essential when bagging is used)
|
|
39
|
+
|
|
40
|
+
TIME CONSTRAINTS:
|
|
41
|
+
- time_limit: Total seconds for training (60-3600+ depending on dataset size)
|
|
42
|
+
- Larger datasets need more time, more complex models need more time
|
|
43
|
+
- Consider: small datasets (<10K rows): 60-300s, medium (10K-100K): 300-1800s, large (>100K): 1800-7200s
|
|
44
|
+
|
|
45
|
+
INFERENCE SPEED:
|
|
46
|
+
- infer_limit: Max seconds per row prediction (e.g., 0.001=1ms, 0.00005=0.05ms)
|
|
47
|
+
- infer_limit_batch_size: Batch size for speed calculation (1=online, 1000+=batch inference)
|
|
48
|
+
- Only specify if real-time inference speed is critical
|
|
49
|
+
|
|
50
|
+
DECISION THRESHOLD CALIBRATION (Binary Classification):
|
|
51
|
+
- "auto": Calibrate for metrics like f1, balanced_accuracy (recommended)
|
|
52
|
+
- True: Always calibrate
|
|
53
|
+
- False: Never calibrate
|
|
54
|
+
- Significantly improves f1, balanced_accuracy, and recall-related metrics
|
|
55
|
+
|
|
56
|
+
Domain-Specific Considerations:
|
|
57
|
+
==============================
|
|
58
|
+
|
|
59
|
+
MANUFACTURING/QUALITY CONTROL:
|
|
60
|
+
- High precision to avoid false positives (unnecessary inspections)
|
|
61
|
+
- Use precision, f1 metrics
|
|
62
|
+
- Consider interpretable models for regulatory compliance
|
|
63
|
+
- Fast inference for real-time monitoring
|
|
64
|
+
|
|
65
|
+
HEALTHCARE/MEDICAL:
|
|
66
|
+
- High recall to avoid missing critical cases
|
|
67
|
+
- Use recall, balanced_accuracy, f1 metrics
|
|
68
|
+
- Interpretability is often critical
|
|
69
|
+
- May need decision threshold calibration
|
|
70
|
+
|
|
71
|
+
FRAUD DETECTION:
|
|
72
|
+
- Highly imbalanced classes
|
|
73
|
+
- Use f1, balanced_accuracy, roc_auc, precision-recall metrics
|
|
74
|
+
- Fast inference for real-time detection
|
|
75
|
+
- Consider calibrate_decision_threshold=True
|
|
76
|
+
|
|
77
|
+
PREDICTIVE MAINTENANCE:
|
|
78
|
+
- Balance precision and recall
|
|
79
|
+
- Consider temporal patterns
|
|
80
|
+
- Fast inference for real-time monitoring
|
|
81
|
+
- Good interpretability for actionable insights
|
|
82
|
+
|
|
83
|
+
CUSTOMER CHURN/RETENTION:
|
|
84
|
+
- Focus on identifying at-risk customers (recall)
|
|
85
|
+
- Use f1, balanced_accuracy, roc_auc
|
|
86
|
+
- Interpretability helps with intervention strategies
|
|
87
|
+
|
|
88
|
+
FINANCIAL/CREDIT SCORING:
|
|
89
|
+
- Regulatory requirements for interpretability
|
|
90
|
+
- Balance precision and recall
|
|
91
|
+
- Consider "interpretable" preset or simple models
|
|
92
|
+
|
|
93
|
+
Dataset Characteristics Impact:
|
|
94
|
+
==============================
|
|
95
|
+
|
|
96
|
+
SMALL DATASETS (<10K rows):
|
|
97
|
+
- Benefit greatly from bagging (num_bag_folds=5-8)
|
|
98
|
+
- May overfit with too many models
|
|
99
|
+
- Shorter time_limit (60-300s)
|
|
100
|
+
- Consider cross-validation
|
|
101
|
+
|
|
102
|
+
IMBALANCED CLASSES:
|
|
103
|
+
- Use f1, balanced_accuracy, roc_auc metrics (NOT accuracy)
|
|
104
|
+
- Consider calibrate_decision_threshold=True
|
|
105
|
+
- May benefit from custom hyperparameters
|
|
106
|
+
|
|
107
|
+
HIGH-DIMENSIONAL DATA (many features):
|
|
108
|
+
- May need longer time_limit
|
|
109
|
+
- Feature importance analysis post-training
|
|
110
|
+
- Consider preset="best_quality" for feature interactions
|
|
111
|
+
|
|
112
|
+
TEMPORAL/TIME-SERIES PATTERNS:
|
|
113
|
+
- Be cautious with shuffling (AutoGluon does stratified splits)
|
|
114
|
+
- May need time-based validation splits
|
|
115
|
+
- Consider specialized time-series features
|
|
116
|
+
|
|
117
|
+
Be thoughtful, analytical, and provide production-ready recommendations."""
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def format_autogluon_config_prompt(
    domain: dict,
    use_case: dict,
    methodology: str,
    dataset_insights: dict
) -> tuple[str, str]:
    """
    Format the system and user prompts for AutoGluon configuration recommendation.

    Each argument is interpolated into the user prompt with its default
    string conversion, so dictionaries appear in their Python ``repr`` form.

    Args:
        domain: Business domain context (the README example uses a dict with
            ``name`` and ``description`` keys)
        use_case: Description of the specific use case and problem; the agent
            passes a dict, but any value with a readable string form works
        methodology: Type of ML problem (binary_classification, multiclass_classification, regression)
        dataset_insights: Dictionary containing feature and target information

    Returns:
        Tuple of (system_prompt, user_prompt), where system_prompt is the
        module-level AUTOGLUON_CONFIG_SYSTEM_PROMPT
    """

    # NOTE: only configuration items 1-7 are requested below. Items such as
    # hyperparameters, auto_stack, infer_limit, infer_limit_batch_size,
    # refit_full and calibrate_decision_threshold are not requested.
    user_prompt = f"""Please recommend optimal AutoGluon TabularPredictor configuration for the following scenario:

DOMAIN INFORMATION:
==================
Domain: {domain}

USE CASE:
=========
{use_case}

METHODOLOGY:
===========
Problem Type: {methodology}

DATASET INSIGHTS:
================
{dataset_insights}

TASK:
=====
Based on the above information, recommend an optimal AutoGluon configuration that includes:

1. eval_metric: The primary metric to optimize
2. preset: Quality/speed tradeoff preset
3. additional_metrics: Other metrics to track (list)
4. time_limit: Training time in seconds
5. num_bag_folds: Number of k-fold bagging folds (0 for none, 5-10 for bagging)
6. num_bag_sets: Number of bagging sets (1-3, only if bagging is used)
7. num_stack_levels: Number of stacking levels


Consider multiple scenarios:
- Scenario A: Maximum accuracy (accepting longer training time)
- Scenario B: Balanced accuracy and speed (production-ready)
- Scenario C: Fast training and inference (prototyping/deployment constrained)

"""

    return AUTOGLUON_CONFIG_SYSTEM_PROMPT, user_prompt
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
from typing import List, Literal, Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Response model passed by the agent to the LLM call for the recommended
# TabularPredictor settings. Documented with comments rather than a class
# docstring so the model's generated schema stays exactly as before.
class AutoGluonConfig(BaseModel):
    eval_metric: str = Field(
        ...,
        description=(
            "Primary metric to optimize. "
            "For binary classification: 'f1', 'precision', 'recall', 'accuracy', 'roc_auc', 'balanced_accuracy', 'log_loss'. "
            "For multiclass: 'accuracy', 'balanced_accuracy', 'log_loss', 'f1_macro', 'f1_micro', 'f1_weighted'. "
            "For regression: 'rmse', 'mae', 'r2', 'mse'"
        ),
    )

    preset: Literal[
        "best_quality",
        "high_quality",
        "good_quality",
        "medium_quality",
        "optimize_for_deployment",
        "interpretable",
    ] = Field(
        ...,
        description=(
            "Preset configurations that control model complexity and training strategy. "
            "'best_quality' for maximum accuracy with stacking/bagging, "
            "'high_quality' for strong accuracy with fast inference, "
            "'good_quality' for good accuracy with very fast inference, "
            "'medium_quality' for fast training, "
            "'optimize_for_deployment' for minimal memory/compute, "
            "'interpretable' for simple models"
        ),
    )

    additional_metrics: List[str] = Field(
        ...,
        description=(
            "List of additional metrics to track during training. "
            "Options: 'f1', 'precision', 'recall', 'accuracy', 'roc_auc', 'balanced_accuracy', 'log_loss', 'rmse', 'mae', 'r2', 'mse'"
        ),
    )

    time_limit: int = Field(
        ...,
        description=(
            "Time limit in seconds for training all models. "
            "Recommended: 60-3600 seconds depending on dataset size and complexity"
        ),
    )

    num_bag_folds: int = Field(
        ...,
        description=(
            "Number of folds for k-fold bagging. "
            "0 = no bagging, 5-10 = good for improving accuracy with longer training time. "
            "Higher values reduce overfitting but increase training time and memory usage"
        ),
    )

    num_bag_sets: int = Field(
        ...,
        description=(
            "Number of bagging sets. "
            "Each set repeats k-fold bagging to further reduce variance. "
            "1 = standard bagging, 2-3 = better accuracy but much longer training. "
            "Only relevant when num_bag_folds > 0"
        ),
    )

    num_stack_levels: int = Field(
        ...,
        description=(
            "Number of stacking levels. "
            "0 = no stacking, 1 = one level of stacking (models trained on predictions of base models), 2+ = multi-level stacking. "
            "Higher values can improve accuracy but increase training time exponentially"
        ),
    )

    # Fields drafted but currently disabled (not part of the schema):
    # hyperparameters, auto_stack, infer_limit, infer_limit_batch_size,
    # refit_full, calibrate_decision_threshold.
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: experiment-configuration-agent
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: pydantic-settings
|
|
8
|
+
Requires-Dist: sfn-blueprint>=0.6.16
|
|
9
|
+
Provides-Extra: dev
|
|
10
|
+
Requires-Dist: pytest; extra == "dev"
|
|
11
|
+
Requires-Dist: pytest-mock; extra == "dev"
|
|
12
|
+
|
|
13
|
+
# Experiment Configuration Agent for AutoGluon
|
|
14
|
+
|
|
15
|
+
This agent uses a Large Language Model to recommend optimal configurations for AutoGluon's `TabularPredictor` based on your machine learning problem context. By providing details about your domain, use case, and dataset, the agent will generate a set of `TabularPredictor` parameters designed to optimize for performance and efficiency.
|
|
16
|
+
|
|
17
|
+
## Features
|
|
18
|
+
|
|
19
|
+
- **Intelligent Configuration:** Leverages LLMs to recommend `eval_metric`, `presets`, `time_limit`, and ensembling parameters.
|
|
20
|
+
- **Context-Aware:** Considers the business domain, specific use case, ML methodology (e.g., classification, regression), and dataset characteristics.
|
|
21
|
+
- **Flexible Backend:** Powered by `sfn-blueprint`, allowing for a configurable LLM backend.
|
|
22
|
+
- **Multiple Scenarios:** Provides recommendations for different optimization goals, such as maximizing accuracy, balancing performance and speed, or fast prototyping.
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
|
|
26
|
+
This project uses `uv` for dependency management and requires Python 3.10 or higher.
|
|
27
|
+
|
|
28
|
+
1. **Clone the repository:**
|
|
29
|
+
```bash
|
|
30
|
+
git clone <repository-url>
|
|
31
|
+
cd experiment-configuration-agent
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
2. **Set up the environment and install dependencies:**
|
|
35
|
+
It is recommended to use a virtual environment. `uv` can create one for you.
|
|
36
|
+
```bash
|
|
37
|
+
# Install uv if you don't have it
|
|
38
|
+
pip install uv
|
|
39
|
+
|
|
40
|
+
# Create a virtual environment and install dependencies
|
|
41
|
+
uv sync
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Usage
|
|
45
|
+
|
|
46
|
+
To get a configuration recommendation, instantiate `AutoGluonConfigAgent` and call the instance with a dictionary containing the problem context (`domain`, `use_case`, `methodology`, and `dataset_insights`).
|
|
47
|
+
|
|
48
|
+
1. **Create a `.env` file** in the project root to configure the LLM provider. See the [Configuration](#configuration) section for more details.
|
|
49
|
+
|
|
50
|
+
```
|
|
51
|
+
PROVIDER="openai"
|
|
52
|
+
MODEL="gpt-4-turbo"
|
|
53
|
+
# Add your API key, e.g., OPENAI_API_KEY="sk-..."
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
2. **Create your Python script:**
|
|
57
|
+
|
|
58
|
+
```python
|
|
59
|
+
from experiment_config_agent.agent import AutoGluonConfigAgent
|
|
60
|
+
|
|
61
|
+
# 1. Define the problem context
|
|
62
|
+
task_data = {
|
|
63
|
+
"domain": {
|
|
64
|
+
"name": "Manufacturing",
|
|
65
|
+
"description": "An automotive parts manufacturing facility with multiple production lines."
|
|
66
|
+
},
|
|
67
|
+
"use_case": {
|
|
68
|
+
"name": "Predictive Maintenance",
|
|
69
|
+
"description": "Detect unusual temporal patterns in sensor data to predict equipment failure and prevent breakdowns."
|
|
70
|
+
},
|
|
71
|
+
"methodology": "binary_classification",
|
|
72
|
+
"dataset_insights": {
|
|
73
|
+
"num_samples": 5000,
|
|
74
|
+
"num_features": 10,
|
|
75
|
+
"target": {
|
|
76
|
+
"name": "failure_flag",
|
|
77
|
+
"imbalance_ratio": 0.05 # Highly imbalanced
|
|
78
|
+
},
|
|
79
|
+
"feature_summary": {
|
|
80
|
+
"sensor_A": {"min": 0.1, "max": 100.5, "dtype": "float"},
|
|
81
|
+
"production_line_id": {"unique_count": 3, "dtype": "category"}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
# 2. Initialize the agent
|
|
87
|
+
agent = AutoGluonConfigAgent()
|
|
88
|
+
|
|
89
|
+
# 3. Get the configuration recommendation
|
|
90
|
+
result = agent(task_data)
|
|
91
|
+
|
|
92
|
+
# 4. Print the result
|
|
93
|
+
print("Recommended AutoGluon Configuration:")
|
|
94
|
+
print(result.get("configuration"))
|
|
95
|
+
print("\nCost Summary:")
|
|
96
|
+
print(result.get("cost_summary"))
|
|
97
|
+
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## Configuration
|
|
101
|
+
|
|
102
|
+
The agent is configured via environment variables, which can be placed in a `.env` file in the project root. The primary configurations are inherited from the `GluonConfig` class.
|
|
103
|
+
|
|
104
|
+
- `PROVIDER`: The LLM provider to use (e.g., `"openai"`, `"anthropic"`).
|
|
105
|
+
- `MODEL`: The specific model to use (e.g., `"gpt-4-turbo"`, `"claude-3-opus-20240229"`).
|
|
106
|
+
- `TEMPERATURE`: The model's temperature setting (e.g., `0.3`).
|
|
107
|
+
- `MAX_TOKENS`: The maximum number of tokens for the response (e.g., `4000`).
|
|
108
|
+
|
|
109
|
+
You will also need to set the API key for your chosen provider, for example `OPENAI_API_KEY="your-key-here"`.
|
|
110
|
+
|
|
111
|
+
## Testing
|
|
112
|
+
|
|
113
|
+
This project uses `pytest`, which is declared under the optional `dev` extra; install it first with `uv sync --extra dev`. To run the test suite, execute the following command from the project root:
|
|
114
|
+
|
|
115
|
+
```bash
|
|
116
|
+
pytest
|
|
117
|
+
```
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
src/experiment_config_agent/__init__.py
|
|
4
|
+
src/experiment_config_agent/agent.py
|
|
5
|
+
src/experiment_config_agent/config.py
|
|
6
|
+
src/experiment_config_agent/constants.py
|
|
7
|
+
src/experiment_config_agent/models.py
|
|
8
|
+
src/experiment_configuration_agent.egg-info/PKG-INFO
|
|
9
|
+
src/experiment_configuration_agent.egg-info/SOURCES.txt
|
|
10
|
+
src/experiment_configuration_agent.egg-info/dependency_links.txt
|
|
11
|
+
src/experiment_configuration_agent.egg-info/requires.txt
|
|
12
|
+
src/experiment_configuration_agent.egg-info/top_level.txt
|
|
13
|
+
tests/test_agent.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
experiment_configuration_agent-0.1.0/src/experiment_configuration_agent.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
experiment_config_agent
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import patch, MagicMock
|
|
3
|
+
from experiment_config_agent.agent import AutoGluonConfigAgent
|
|
4
|
+
from experiment_config_agent.models import AutoGluonConfig
|
|
5
|
+
|
|
6
|
+
@pytest.fixture
|
|
7
|
+
def agent():
|
|
8
|
+
"""Fixture to create an AutoGluonConfigAgent instance."""
|
|
9
|
+
return AutoGluonConfigAgent()
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
|
|
12
|
+
def task_data():
|
|
13
|
+
"""Fixture to provide sample task data."""
|
|
14
|
+
return {
|
|
15
|
+
"domain": {"name": "Finance", "description": "Credit scoring"},
|
|
16
|
+
"use_case": {"name": "Loan Default Prediction", "description": "Predict if a customer will default on a loan"},
|
|
17
|
+
"methodology": "binary_classification",
|
|
18
|
+
"dataset_insights": {"samples": 10000, "features": 15}
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
def test_configure_training(agent, task_data):
|
|
22
|
+
"""Test the configure_training method."""
|
|
23
|
+
mock_response = AutoGluonConfig(
|
|
24
|
+
eval_metric='f1',
|
|
25
|
+
preset='high_quality',
|
|
26
|
+
additional_metrics=['roc_auc', 'accuracy'],
|
|
27
|
+
time_limit=3600,
|
|
28
|
+
num_bag_folds=5,
|
|
29
|
+
num_bag_sets=1,
|
|
30
|
+
num_stack_levels=1
|
|
31
|
+
)
|
|
32
|
+
mock_cost = {"total_cost": 0.05}
|
|
33
|
+
|
|
34
|
+
with patch.object(AutoGluonConfigAgent, 'route_with_langchain', return_value=(mock_response, mock_cost)) as mock_route:
|
|
35
|
+
response, cost_summary = agent.configure_training(
|
|
36
|
+
domain=task_data["domain"],
|
|
37
|
+
use_case=task_data["use_case"],
|
|
38
|
+
methodology=task_data["methodology"],
|
|
39
|
+
dataset_insights=task_data["dataset_insights"]
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
assert response == mock_response
|
|
43
|
+
assert cost_summary == mock_cost
|
|
44
|
+
mock_route.assert_called_once()
|
|
45
|
+
system_prompt, user_prompt, model = mock_route.call_args[0]
|
|
46
|
+
assert "You are an expert AutoGluon configuration advisor" in system_prompt
|
|
47
|
+
assert "DOMAIN INFORMATION" in user_prompt
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def test_execute_task(agent, task_data):
|
|
51
|
+
"""Test the execute_task method."""
|
|
52
|
+
mock_config = AutoGluonConfig(
|
|
53
|
+
eval_metric='f1',
|
|
54
|
+
preset='good_quality',
|
|
55
|
+
additional_metrics=['accuracy'],
|
|
56
|
+
time_limit=600,
|
|
57
|
+
num_bag_folds=0,
|
|
58
|
+
num_bag_sets=0,
|
|
59
|
+
num_stack_levels=0
|
|
60
|
+
)
|
|
61
|
+
mock_cost = {"total_cost": 0.02}
|
|
62
|
+
|
|
63
|
+
expected_config_dump = mock_config.model_dump()
|
|
64
|
+
|
|
65
|
+
with patch.object(AutoGluonConfigAgent, 'configure_training', return_value=(mock_config, mock_cost)) as mock_configure:
|
|
66
|
+
result = agent.execute_task(task_data)
|
|
67
|
+
|
|
68
|
+
mock_configure.assert_called_once_with(
|
|
69
|
+
domain=task_data["domain"],
|
|
70
|
+
use_case=task_data["use_case"],
|
|
71
|
+
methodology=task_data["methodology"],
|
|
72
|
+
dataset_insights=task_data["dataset_insights"]
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
assert "configuration" in result
|
|
76
|
+
assert "cost_summary" in result
|
|
77
|
+
assert result["configuration"] == expected_config_dump
|
|
78
|
+
assert result["cost_summary"] == mock_cost
|
|
79
|
+
|
|
80
|
+
def test_call_method(agent, task_data):
|
|
81
|
+
"""Test that the __call__ method invokes execute_task."""
|
|
82
|
+
mock_result = {
|
|
83
|
+
"configuration": {"eval_metric": "f1"},
|
|
84
|
+
"cost_summary": {"total_cost": 0.01}
|
|
85
|
+
}
|
|
86
|
+
with patch.object(AutoGluonConfigAgent, 'execute_task', return_value=mock_result) as mock_execute:
|
|
87
|
+
result = agent(task_data)
|
|
88
|
+
mock_execute.assert_called_once_with(task_data)
|
|
89
|
+
assert result == mock_result
|