ml-approach-suggestion-agent 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml_approach_suggestion_agent/__init__.py +0 -0
- ml_approach_suggestion_agent/agent.py +120 -0
- ml_approach_suggestion_agent/config.py +18 -0
- ml_approach_suggestion_agent/constants.py +343 -0
- ml_approach_suggestion_agent/models.py +18 -0
- ml_approach_suggestion_agent-0.1.8.dist-info/METADATA +214 -0
- ml_approach_suggestion_agent-0.1.8.dist-info/RECORD +9 -0
- ml_approach_suggestion_agent-0.1.8.dist-info/WHEEL +5 -0
- ml_approach_suggestion_agent-0.1.8.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from sfn_blueprint import SFNAIHandler, self_correcting_sql, Context
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
from .config import MethodologyConfig
|
|
9
|
+
from .constants import format_approach_prompt
|
|
10
|
+
from .models import MethodologyRecommendation
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MLApproachDecisionAgent:
    """Recommend an ML methodology for a business use case via an LLM call.

    The agent builds a (system, user) prompt pair with
    ``format_approach_prompt``, sends it through ``SFNAIHandler.route_to`` and
    hands back whatever the handler returns (expected to be a
    ``MethodologyRecommendation``) together with the handler's cost summary.
    """

    def __init__(self, config: Optional[MethodologyConfig] = None):
        """Create the agent.

        Args:
            config: Provider / model / token settings. When omitted (or
                falsy) a default ``MethodologyConfig()`` is constructed.
        """
        self.logger = logging.getLogger(__name__)
        self.config = config or MethodologyConfig()
        self.ai_handler = SFNAIHandler()

    def suggest_approach(self, domain_name, domain_description, use_case, column_insights, max_try=1) -> Tuple[MethodologyRecommendation, Dict[str, Any]]:
        """Ask the configured LLM for the most suitable ML methodology.

        Args:
            domain_name: The name of the domain.
            domain_description: The description of the domain.
            use_case: The problem that needs to be solved.
            column_insights: Insights about the dataset columns; the value is
                interpolated into the user prompt as-is.
            max_try: Maximum number of attempts to make the API call.
                Defaults to 1.

        Returns:
            ``(response, cost_summary)`` exactly as returned by the AI handler
            on the first successful attempt. If every attempt raises (or
            ``max_try`` is less than 1) the falsy pair ``({}, {})`` is
            returned instead, which ``execute_task`` treats as failure.
        """
        # TODO: USER prompt should consider those approaches which will be supported.
        system_prompt, user_prompt = format_approach_prompt(
            domain_name=domain_name,
            domain_description=domain_description,
            use_case=use_case,
            column_insights=column_insights,
        )

        for _ in range(max_try):
            try:
                response, cost_summary = self.ai_handler.route_to(
                    llm_provider=self.config.methodology_ai_provider,
                    configuration={
                        "messages": [
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_prompt},
                        ],
                        "max_tokens": self.config.methodology_max_tokens,
                        # NOTE(review): temperature is deliberately not sent;
                        # confirm whether the configured model accepts it.
                        # "temperature": self.config.methodology_temperature,
                        "text_format": MethodologyRecommendation,
                    },
                    model=self.config.methodology_ai_model,
                )
            except Exception as e:
                # Best effort: log the failure and move on to the next attempt.
                self.logger.error(
                    "Error while executing API call to %s: %s",
                    self.config.methodology_ai_provider,
                    e,
                )
            else:
                return response, cost_summary

        # Every attempt failed (or none was made): signal failure to callers.
        return {}, {}

    def execute_task(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the suggestion for one task payload and optionally persist it.

        Args:
            task_data: Must contain ``domain_name``, ``domain_description``,
                ``use_case`` and ``column_insights``. If it also contains
                ``workflow_storage_path`` or ``workflow_id`` the result is
                saved through ``WorkflowStorageManager``.

        Returns:
            On success ``{"success": True, "result": {"approach": ...,
            "cost_summary": ...}, "agent": <class name>}``; when no approach
            could be obtained ``{"success": False, "error": ..., "agent":
            <class name>}``.

        Raises:
            KeyError: If one of the four required keys is missing.
        """
        self.logger.info("Executing ML approach suggestion task.")
        domain_name, domain_description, use_case, column_insights = (
            task_data["domain_name"],
            task_data["domain_description"],
            task_data["use_case"],
            task_data["column_insights"],
        )

        # Suggest an approach
        result, cost_summary = self.suggest_approach(
            domain_name=domain_name,
            domain_description=domain_description,
            use_case=use_case,
            column_insights=column_insights,
        )
        if not result:
            return {
                "success": False,
                "error": "Failed to suggest approach.",
                "agent": self.__class__.__name__,
            }

        # Persisting is best effort: a storage failure must not hide a
        # successful recommendation, so it is only logged as a warning.
        try:
            # Check if we have workflow storage information
            if 'workflow_storage_path' in task_data or 'workflow_id' in task_data:
                from sfn_blueprint import WorkflowStorageManager

                # Determine workflow storage path
                workflow_storage_path = task_data.get('workflow_storage_path', 'outputs/workflows')
                workflow_id = task_data.get('workflow_id', 'unknown')

                # Initialize storage manager
                storage_manager = WorkflowStorageManager(workflow_storage_path, workflow_id)
                storage_manager.save_agent_result(
                    agent_name=self.__class__.__name__,
                    # NOTE(review): step name is a single blank and the payload
                    # key is "quality_reports"; both look inherited from another
                    # agent but are kept because stored data may be read elsewhere.
                    step_name=" ",
                    data={"quality_reports": result.model_dump(), "cost_summary": cost_summary},
                    metadata={"execution_time": datetime.now().isoformat()},
                )
                self.logger.info("Approach suggestion saved to workflow storage.")
        except Exception as e:
            self.logger.warning(f"Failed to save results to workflow storage: {e}")

        return {
            "success": True,
            "result": {
                "approach": result,
                "cost_summary": cost_summary,
            },
            "agent": self.__class__.__name__,
        }

    def __call__(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Make the agent callable; equivalent to ``execute_task(task_data)``."""
        return self.execute_task(task_data)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from pydantic import Field
|
|
2
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class MethodologyConfig(BaseSettings):
    # Settings for the methodology-selection LLM call. Values are read from
    # the environment or a local ".env" file; names are matched
    # case-insensitively and unknown entries are ignored.
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # Which provider / model the request is routed to.
    methodology_ai_provider: str = Field(
        default="openai",
        description="AI provider to use",
    )
    methodology_ai_model: str = Field(
        default="gpt-5-mini",
        description="AI model to use",
    )

    # Sampling temperature, constrained to the range 0.0 - 0.5.
    methodology_temperature: float = Field(
        default=0.3,
        ge=0.0,
        le=0.5,
        description="AI model temperature",
    )

    # Upper bound on response tokens, constrained to the range 100 - 8000.
    methodology_max_tokens: int = Field(
        default=4000,
        ge=100,
        le=8000,
        description="Maximum tokens for AI response",
    )
|
|
17
|
+
|
|
18
|
+
|
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
# METHODOLOGY_SELECTION_SYSTEM_PROMPT = """You are an expert ML methodology advisor. Your task is to analyze the problem carefully and select the single most appropriate methodology from: binary_classification, timeseries_binary_classification, or not_applicable.
|
|
2
|
+
|
|
3
|
+
# **Methodology Definitions:**
|
|
4
|
+
|
|
5
|
+
# 1. **Binary Classification**
|
|
6
|
+
# - Predicts one of TWO possible outcomes (Yes/No, True/False, 1/0, Pass/Fail)
|
|
7
|
+
# - Uses historical data with labels to learn patterns
|
|
8
|
+
# - Predictions are categorical, not numerical values
|
|
9
|
+
# - Time is NOT a critical feature for prediction
|
|
10
|
+
|
|
11
|
+
# **Examples:**
|
|
12
|
+
# - "Is this transaction fraudulent?" → Fraud/Not Fraud
|
|
13
|
+
# - "Will the machine fail?" → Fail/Not Fail
|
|
14
|
+
# - "Is this email spam?" → Spam/Not Spam
|
|
15
|
+
|
|
16
|
+
# 2. **Timeseries Binary Classification**
|
|
17
|
+
# - Predicts one of TWO categories for sequential/temporal data
|
|
18
|
+
# - ORDER and TEMPORAL PATTERNS in the data are CRITICAL for making predictions
|
|
19
|
+
# - The sequence itself contains information (trends, seasonality, patterns over time)
|
|
20
|
+
# - Uses time-ordered observations where the temporal relationship matters
|
|
21
|
+
|
|
22
|
+
# **Examples:**
|
|
23
|
+
# - "Classify equipment state as 'normal' or 'anomalous' based on sensor readings over time"
|
|
24
|
+
# - "Predict if a patient will be readmitted within 30 days based on their medical history sequence"
|
|
25
|
+
# - "Classify stock price movement as 'upward' or 'downward' based on historical patterns"
|
|
26
|
+
# - "Detect if a time series pattern indicates an upcoming failure event"
|
|
27
|
+
|
|
28
|
+
# 3. **Not Applicable**
|
|
29
|
+
# - No machine learning prediction is needed
|
|
30
|
+
# - Pure data analysis, reporting, dashboards, or descriptive statistics
|
|
31
|
+
# - Insufficient information to determine methodology
|
|
32
|
+
# - Problem requires regression, multi-class classification, or other ML approaches not listed
|
|
33
|
+
|
|
34
|
+
# **Critical Decision Framework:**
|
|
35
|
+
|
|
36
|
+
# Ask yourself these questions in order:
|
|
37
|
+
|
|
38
|
+
# 1. **Is a prediction needed?**
|
|
39
|
+
# - No → `not_applicable`
|
|
40
|
+
# - Yes → Continue
|
|
41
|
+
|
|
42
|
+
# 2. **What is being predicted?**
|
|
43
|
+
# - A binary outcome (2 categories) → Continue to question 3
|
|
44
|
+
# - A numerical value → `not_applicable` (this is regression)
|
|
45
|
+
# - Multiple categories (3+) → `not_applicable` (this is multi-class)
|
|
46
|
+
# - Nothing specific → `not_applicable`
|
|
47
|
+
|
|
48
|
+
# 3. **Are temporal patterns ESSENTIAL for making the prediction?**
|
|
49
|
+
# - Yes, the sequence/order of observations contains critical information → `timeseries_binary_classification`
|
|
50
|
+
# - No, individual data points or snapshots are sufficient → `binary_classification`
|
|
51
|
+
|
|
52
|
+
# **Common Pitfalls to Avoid:**
|
|
53
|
+
|
|
54
|
+
# **Don't assume time series just because:**
|
|
55
|
+
# - The data has timestamps (most datasets do)
|
|
56
|
+
# - Events happened over time
|
|
57
|
+
# - There's a date column
|
|
58
|
+
|
|
59
|
+
# **Choose timeseries binary classification ONLY when:**
|
|
60
|
+
# - The temporal sequence itself reveals patterns needed for classification
|
|
61
|
+
# - Order matters: shuffling observations would lose critical information
|
|
62
|
+
# - Trends, seasonality, or temporal dependencies are key features
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# **Don't confuse classification with forecasting:**
|
|
66
|
+
# - Timeseries Binary Classification → Predict a category based on temporal patterns
|
|
67
|
+
# - Time Series Forecasting → Predict future numerical values (NOT an option here)
|
|
68
|
+
|
|
69
|
+
# **Output Requirements:**
|
|
70
|
+
|
|
71
|
+
# You must provide:
|
|
72
|
+
|
|
73
|
+
# 1. **selected_methodology**: Exactly one of: `binary_classification`, `timeseries_binary_classification`, or `not_applicable`
|
|
74
|
+
|
|
75
|
+
# 2. **justification**: A clear, structured explanation that includes:
|
|
76
|
+
# - **Business Goal**: What problem is being solved?
|
|
77
|
+
# - **Prediction Type**: What specific outcome needs to be predicted?
|
|
78
|
+
# - **Temporal Dependency**: Are time-based patterns essential for this prediction?
|
|
79
|
+
# - **Methodology Fit**: Why is the selected methodology the best match?
|
|
80
|
+
# - **Key Reasoning**: The critical factors that led to this decision
|
|
81
|
+
|
|
82
|
+
# Be decisive, analytical, and precise in your selection.
|
|
83
|
+
# """
|
|
84
|
+
|
|
85
|
+
# System prompt for methodology selection.
#
# The methodology names offered here must stay in sync with the values the
# response model accepts. `timeseries_anomaly_detection` used to be listed in
# the opening enumeration although it had no definition, was absent from the
# "Output Requirements" list below and is not an accepted response value, so
# it has been dropped from the enumeration.
#
# NOTE(review): the lines starting with "# " under "Common Pitfalls to Avoid"
# are part of the string literal, so they are sent to the model verbatim
# rather than being commented out - confirm whether that is intended.
METHODOLOGY_SELECTION_SYSTEM_PROMPT = """You are an expert ML methodology advisor. Your task is to analyze the problem carefully and select the single most appropriate methodology from:

binary_classification,
multiclass_classification,
regression,
timeseries_regression,
timeseries_binary_classification,
recommendation_engine,
timeseries_recommendation_engine,
clustering,
anomaly_detection,
or not_applicable.

---

## Methodology Definitions

### 1. Binary Classification
- Predicts one of TWO discrete outcomes (Yes/No, 1/0, Pass/Fail)
- Uses labeled historical data
- Output is categorical (2 classes)
- Time is NOT essential

**Examples:**
- "Will a customer churn?" → Yes / No
- "Is this transaction fraudulent?" → Fraud / Not Fraud

---

### 2. Multiclass Classification
- Predicts one of THREE OR MORE discrete classes
- Uses labeled historical data
- Output is categorical (3+ classes)
- Time is NOT essential

**Examples:**
- "Classify customer into Bronze / Silver / Gold"
- "Predict Heart Disease Level Low / Medium / High"

---

### 3. Regression
- Predicts a continuous numerical value
- Uses labeled historical data
- Time may exist but is NOT the main signal

**Examples:**
- "Predict loan amount"
- "Estimate house price"
- "Predict expected revenue"

---

### 4. Time Series Regression
- Predicts future numerical values
- Temporal order and patterns are ESSENTIAL
- Forecasting based on trends, seasonality, lag effects

**Examples:**
- "Predict energy consumption over time"
- "Estimate future demand"

---

### 5. Time Series Binary Classification
- Predicts one of TWO classes
- Temporal sequence is ESSENTIAL
- Order, trends, and patterns matter

**Examples:**
- "Predict machine failure in next 24 hours using sensor history"
- "Classify stock movement as Up / Down using price history"

---

### 6. Recommendation Engine
- Produces ranked lists or personalized suggestions
- Output is NOT a single class or number
- Time is NOT essential

**Examples:**
- "Recommend products to users"
- "Suggest movies based on viewing history"

---

### 7. Time Series Recommendation Engine
- Recommendations depend on sequence or recency
- Session-based or time-aware recommendations

**Examples:**
- "Recommend next product based on recent actions"
- "Suggest content based on session behavior"

---

### 8. Clustering
- No labeled target variable
- Groups similar entities together
- Discovers structure in data

**Examples:**
- "Customer segmentation"
- "Group users by behavior"

---

### 9. Anomaly Detection
- Detects rare, unusual, or abnormal observations
- Snapshot-based (time not essential)

**Examples:**
- "Detect fraudulent transactions"
- "Identify outlier sensor readings"

---

### 10. Not Applicable
- No ML prediction required
- Pure reporting, dashboards, or descriptive analysis
- Problem requires an unsupported methodology
- Insufficient information

---

## Critical Decision Framework

Ask these questions in order:

### 1. Is a prediction, recommendation, grouping, or anomaly detection needed?
- No → not_applicable
- Yes → Continue

---

### 2. What is the nature of the output?
- Ranked list or suggestions → Recommendation Engine
- Continuous numeric value → Regression
- Two classes → Binary Classification
- Three or more classes → Multiclass Classification
- No labels, discover groups → Clustering
- Detect rare/abnormal behavior → Anomaly Detection

---

### 3. Are temporal patterns ESSENTIAL?
If YES:
- Regression → timeseries_regression
- Binary classification → timeseries_binary_classification
- Recommendation → timeseries_recommendation_engine

If NO:
- Use the non-time-series variant

---

## Common Pitfalls to Avoid

**Don't confuse classification with regression:**
- "Will customer default?" → Binary Classification
- "How much will customer default?" → Regression

# **Don't assume time series just because:**
# - The data has timestamps
# - Events happened over time
# - There's a date column

**Don't confuse regression with forecasting:**
- Regression predicts numeric value using features
- Forecasting predicts future numeric values using time
- Forecasting = Time Series Regression

**Don't confuse recommendation with prediction:**
- "Will user buy product X?" → Binary Classification
- "Which products should we show?" → Recommendation Engine

---

## Output Requirements

You MUST return:

1. selected_methodology: Exactly ONE of:
- binary_classification
- multiclass_classification
- regression
- timeseries_regression
- timeseries_binary_classification
- recommendation_engine
- timeseries_recommendation_engine
- clustering
- anomaly_detection
- not_applicable

2. justification:
A clear, structured explanation including:
- Business Goal
- Output Type
- Target Variable / Output
- Temporal Dependency
- Methodology Fit
- Key Reasoning

Be decisive, analytical, and precise.
Do NOT choose not_applicable if any methodology clearly fits.
"""
|
|
292
|
+
|
|
293
|
+
# User prompt template. Placeholders filled by str.format():
# {domain_name}, {domain_description}, {use_case_description}, {column_insights}.
METHODOLOGY_SELECTION_USER_PROMPT = (
    "**Business Context:**\n"
    "Domain: {domain_name}\n"
    "{domain_description}\n"
    "\n"
    "**Use Case:**\n"
    "{use_case_description}\n"
    "\n"
    "**Dataset Characteristics:**\n"
    "{column_insights}\n"
    "\n"
    "Analyze the above information and determine the most appropriate ML methodology."
)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def format_approach_prompt(
    domain_name: str,
    domain_description: str,
    use_case: str,
    column_insights: str
) -> tuple[str, str]:
    """
    Build the (system prompt, user prompt) pair for methodology selection.

    The system prompt is the module constant returned unchanged; the user
    prompt is the module's user-prompt template with the four arguments
    substituted in.

    Args:
        domain_name: The domain of the data (e.g., "Healthcare", "Finance")
        domain_description: Detailed description of the domain context
        use_case: Description of what the user wants to achieve; fills the
            template's ``{use_case_description}`` placeholder
        column_insights: Statistical insights about the columns (data types,
            unique counts, distributions, etc.)

    Returns:
        tuple[str, str]: The system prompt and the formatted user prompt

    Example:
        system_prompt, user_prompt = format_approach_prompt(
            domain_name="E-commerce",
            domain_description="Online retail platform with customer transactions",
            use_case="Predict if a customer will make a purchase",
            column_insights="4 columns, 10000 rows, mixed types"
        )
    """
    # Map the public argument names onto the template's placeholder names.
    template_values = {
        "domain_name": domain_name,
        "domain_description": domain_description,
        "use_case_description": use_case,
        "column_insights": column_insights,
    }
    filled_user_prompt = METHODOLOGY_SELECTION_USER_PROMPT.format(**template_values)
    return METHODOLOGY_SELECTION_SYSTEM_PROMPT, filled_user_prompt
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
class MethodologyRecommendation(BaseModel):
    # Structured answer for a methodology-selection request: one methodology
    # identifier plus a free-text justification. Both fields are required.

    # Closed set of methodology identifiers accepted in a response.
    selected_methodology: Literal[
        "binary_classification",
        "multiclass_classification",
        "regression",
        "timeseries_regression",
        "timeseries_binary_classification",
        "recommendation_engine",
        "timeseries_recommendation_engine",
        "clustering",
        "anomaly_detection",
        "not_applicable",
    ] = Field(
        ...,
        description="The most appropriate ML approach for this problem",
    )

    # Explanation accompanying the selected methodology.
    justification: str = Field(
        ...,
        description="Structured explanation with: business goal, prediction type, temporal dependency analysis, and methodology fit",
    )
|
|
17
|
+
|
|
18
|
+
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: ml_approach_suggestion_agent
|
|
3
|
+
Version: 0.1.8
|
|
4
|
+
Summary: AI-powered agent that recommends the most appropriate machine learning methodology for a dataset and use case
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Requires-Python: >=3.11
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
Requires-Dist: pydantic-settings
|
|
15
|
+
Requires-Dist: sfn-blueprint>=0.6.16
|
|
16
|
+
Provides-Extra: dev
|
|
17
|
+
Requires-Dist: pytest; extra == "dev"
|
|
18
|
+
Requires-Dist: pytest-mock; extra == "dev"
|
|
19
|
+
|
|
20
|
+
# ml_approach_suggestion_agent
|
|
21
|
+
|
|
22
|
+
An AI-powered agent that analyzes a dataset and use case to recommend the most appropriate machine learning methodology.
|
|
23
|
+
|
|
24
|
+
## Description
|
|
25
|
+
|
|
26
|
+
This agent takes a detailed description of a business domain, a specific use case, and information about the dataset—including column descriptions, insights, and target variable details—to suggest the best ML approach. It uses a large language model to:
|
|
27
|
+
|
|
28
|
+
1. **Analyze** the relationship between the use case and the target variable.
|
|
29
|
+
2. **Evaluate** the characteristics of the data (especially the target column).
|
|
30
|
+
3. **Recommend** the most suitable methodology from a predefined list: `binary_classification`, `multiclass_classification`, `regression`, `timeseries_regression`, `timeseries_binary_classification`, `recommendation_engine`, `timeseries_recommendation_engine`, `clustering`, `anomaly_detection`, or `not_applicable`.
|
|
31
|
+
4. **Provide** a clear justification for its recommendation.
|
|
32
|
+
|
|
33
|
+
This helps data scientists and analysts quickly and confidently choose the right path for their modeling efforts, saving time and reducing the risk of starting with an incorrect approach.
|
|
34
|
+
|
|
35
|
+
## Key Features
|
|
36
|
+
|
|
37
|
+
- **Intelligent Use Case Analysis**: Leverages an LLM to understand the core objective of the business problem.
|
|
38
|
+
- **Target-Aware Recommendation**: Places special emphasis on the nature of the target variable to guide its decision.
|
|
39
|
+
- **Context-Driven Suggestions**: Considers the entire data context, including domain and column descriptions, to make an informed choice.
|
|
40
|
+
- **Accelerates Model Planning**: Provides a validated starting point for ML projects, ensuring alignment between the problem and the proposed solution.
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
### Prerequisites
|
|
45
|
+
|
|
46
|
+
- [**uv**](https://docs.astral.sh/uv/getting-started/installation/) – A fast Python package and environment manager.
|
|
47
|
+
- For a quick setup on macOS/Linux, you can use:
|
|
48
|
+
```bash
|
|
49
|
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
50
|
+
```
|
|
51
|
+
- [**Git**](https://git-scm.com/)
|
|
52
|
+
|
|
53
|
+
### Steps
|
|
54
|
+
|
|
55
|
+
1. **Clone the `ml_approach_suggestion_agent` repository:**
|
|
56
|
+
```bash
|
|
57
|
+
git clone https://github.com/stepfnAI/ml_approach_suggestion_agent.git
|
|
58
|
+
cd ml_approach_suggestion_agent
|
|
59
|
+
git switch main
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
2. **Create a virtual environment and install dependencies:**
|
|
63
|
+
This command creates a `.venv` folder in the current directory and installs all required packages.
|
|
64
|
+
```bash
|
|
65
|
+
uv sync --extra dev
|
|
66
|
+
source .venv/bin/activate
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Configuration
|
|
70
|
+
|
|
71
|
+
You can configure the agent by creating a `.env` file in the project root or by exporting environment variables in your shell. Settings loaded via `export` will override those in a `.env` file.
|
|
72
|
+
|
|
73
|
+
### Available Settings
|
|
74
|
+
|
|
75
|
+
| Environment Variable | Description | Default |
|
|
76
|
+
| ------------------------------- | -------------------------------------------- | -------- |
|
|
77
|
+
| `OPENAI_API_KEY` | **(Required)** Your OpenAI API key. | *None* |
|
|
78
|
+
| `METHODOLOGY_AI_PROVIDER` | AI provider for methodology suggestions. | `openai` |
|
|
79
|
+
| `METHODOLOGY_AI_MODEL` | AI model for methodology suggestions. | `gpt-5-mini` |
|
|
80
|
+
| `METHODOLOGY_TEMPERATURE` | AI model temperature (must be between `0.0` and `0.5`). Validated, but not currently sent with the request. | `0.3` |
|
|
81
|
+
| `METHODOLOGY_MAX_TOKENS` | Maximum tokens for the AI response. | `4000` |
|
|
82
|
+
|
|
83
|
+
---
|
|
84
|
+
|
|
85
|
+
### Method 1: Using a `.env` File (Recommended)
|
|
86
|
+
|
|
87
|
+
Create a `.env` file in the root directory to store API keys and project-wide defaults.
|
|
88
|
+
|
|
89
|
+
#### Example `.env` file:
|
|
90
|
+
|
|
91
|
+
```dotenv
|
|
92
|
+
# .env
|
|
93
|
+
|
|
94
|
+
# --- Required Settings ---
|
|
95
|
+
OPENAI_API_KEY="sk-your-api-key-here"
|
|
96
|
+
|
|
97
|
+
# --- Optional Overrides ---
|
|
98
|
+
# Use a different model
|
|
99
|
+
METHODOLOGY_AI_MODEL="gpt-4o-mini"
|
|
100
|
+
|
|
101
|
+
# Use a lower temperature for more deterministic responses
|
|
102
|
+
METHODOLOGY_TEMPERATURE=0.1
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
---
|
|
106
|
+
|
|
107
|
+
### Method 2: Using `export` Commands
|
|
108
|
+
|
|
109
|
+
Use `export` in your terminal for temporary settings or in CI/CD environments.
|
|
110
|
+
|
|
111
|
+
#### Example `export` commands:
|
|
112
|
+
|
|
113
|
+
```bash
|
|
114
|
+
# Set the environment variables for the current terminal session
|
|
115
|
+
export OPENAI_API_KEY="sk-your-api-key-here"
|
|
116
|
+
export METHODOLOGY_AI_MODEL="gpt-4o-mini"
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
## Testing
|
|
120
|
+
|
|
121
|
+
To run the test suite, use the following command from the root of the project directory:
|
|
122
|
+
|
|
123
|
+
```bash
|
|
124
|
+
pytest -s
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
## Usage
|
|
128
|
+
|
|
129
|
+
### Running the Example Script
|
|
130
|
+
|
|
131
|
+
To see a quick demonstration, run the provided example script. This will execute the agent with pre-defined data and print the recommended methodology.
|
|
132
|
+
|
|
133
|
+
```bash
|
|
134
|
+
python examples/basic_usage.py
|
|
135
|
+
```
|
|
136
|
+
|
|
137
|
+
### Using as a Library
|
|
138
|
+
|
|
139
|
+
Integrate the `MLApproachDecisionAgent` directly into your Python applications to get methodology recommendations programmatically.
|
|
140
|
+
|
|
141
|
+
```python
|
|
142
|
+
import logging
|
|
143
|
+
from ml_approach_suggestion_agent.agent import MLApproachDecisionAgent
|
|
144
|
+
|
|
145
|
+
# Configure logging
|
|
146
|
+
logging.basicConfig(level=logging.INFO)
|
|
147
|
+
|
|
148
|
+
# 1. Define the domain, use case, and data context
|
|
149
|
+
domain_name = "Mortgage Loan Servicing"
|
|
150
|
+
domain_description = "Managing mortgage loans from post-origination to payoff, including payment collection, escrow management, and compliance for domestic and international loans."
|
|
151
|
+
use_case = "To predict the likelihood of a borrower becoming delinquent on their mortgage payment within the next 60 days using their demographic and financial data to enable proactive intervention."
|
|
152
|
+
|
|
153
|
+
column_descriptions = {
|
|
154
|
+
"CreditScore": "Borrower's credit score from credit bureau sources",
|
|
155
|
+
"EmploymentStatus": "Current employment status (e.g., employed, self-employed, unemployed)",
|
|
156
|
+
# ... other column descriptions
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
column_insights = {
|
|
160
|
+
"table_info": { "row_count": 50000 },
|
|
161
|
+
"table_columns_info": {
|
|
162
|
+
"CreditScore": { "data_type": "Int64", "min_max_value": [350, 850] },
|
|
163
|
+
"EmploymentStatus": { "data_type": "string", "distinct_count": 5 },
|
|
164
|
+
# ... other column insights
|
|
165
|
+
}
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
target_column_name = "IsDelinquent"
|
|
169
|
+
target_column_insights = {
|
|
170
|
+
"Target Column Description": "A binary categorical flag indicating if the borrower has missed one or more mortgage payments in the last 60 days.",
|
|
171
|
+
"Data Type": "Integer (or Boolean)",
|
|
172
|
+
"Value Distribution": {
|
|
173
|
+
"0 (Not Delinquent)": "92%",
|
|
174
|
+
"1 (Delinquent)": "8%"
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
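# 2. Prepare the task data payload (the agent reads only domain_name, domain_description, use_case and column_insights, plus the optional workflow_storage_path / workflow_id; other keys are ignored)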
|
|
179
|
+
task_data = {
|
|
180
|
+
"domain_name": domain_name,
|
|
181
|
+
"domain_description": domain_description,
|
|
182
|
+
"use_case": use_case,
|
|
183
|
+
"column_descriptions": column_descriptions,
|
|
184
|
+
"column_insights": column_insights,
|
|
185
|
+
"target_column_name": target_column_name,
|
|
186
|
+
"target_column_insights": target_column_insights
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
# 3. Initialize and execute the agent
|
|
190
|
+
agent = MLApproachDecisionAgent()
|
|
191
|
+
result = agent(task_data)
|
|
192
|
+
|
|
193
|
+
# 4. Print the suggested methodology
|
|
194
|
+
if result["success"]:
|
|
195
|
+
print("Successfully suggested an approach:")
|
|
196
|
+
print(result["result"]["approach"].model_dump_json(indent=4))
|
|
197
|
+
print(f"Cost summary: {result['result']['cost_summary']}")
|
|
198
|
+
else:
|
|
199
|
+
print("Failed to suggest an approach.")
|
|
200
|
+
|
|
201
|
+
```
|
|
202
|
+
|
|
203
|
+
### Example Output
|
|
204
|
+
|
|
205
|
+
The agent returns a JSON object containing the recommended methodology and a detailed explanation for the choice.
|
|
206
|
+
|
|
207
|
+
*(Note: The actual output may vary slightly based on the LLM's response.)*
|
|
208
|
+
|
|
209
|
+
```json
|
|
210
|
+
{
|
|
211
|
+
"selected_methodology": "binary_classification",
|
|
212
|
+
"justification": "The goal is to predict the likelihood of a borrower becoming delinquent on their mortgage payment within the next 60 days. This is a binary outcome (delinquent or not delinquent), making binary classification the appropriate methodology. The target variable is categorical, and the available demographic and financial data can be used as features to train a classification model."
|
|
213
|
+
}
|
|
214
|
+
```
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
ml_approach_suggestion_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
ml_approach_suggestion_agent/agent.py,sha256=ea46aRQVp2uFKj4DslF3rCTX4vMYv0sms9TE-SP-XuE,5009
|
|
3
|
+
ml_approach_suggestion_agent/config.py,sha256=kNZDiHYOB-A90Su-4iFcDik_I60w_cVGvtpASJKn638,701
|
|
4
|
+
ml_approach_suggestion_agent/constants.py,sha256=KLjZXxPhWHhoSgBNXeOwElpYzB4gr1jmjG-prdn9lPc,10628
|
|
5
|
+
ml_approach_suggestion_agent/models.py,sha256=g-rhJZjFpl8Wyu14Gq83NIw1XjP4cAy33XLGL7vE_OM,961
|
|
6
|
+
ml_approach_suggestion_agent-0.1.8.dist-info/METADATA,sha256=rzcEpMZUFNNZbCAfgPITx3dscJtJDWApXQM3EJyTfD8,8205
|
|
7
|
+
ml_approach_suggestion_agent-0.1.8.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
8
|
+
ml_approach_suggestion_agent-0.1.8.dist-info/top_level.txt,sha256=3-KHls6umFXtNFJoP7OFCLvb4zd12AWH71PVKNd5Aok,29
|
|
9
|
+
ml_approach_suggestion_agent-0.1.8.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ml_approach_suggestion_agent
|