drift-ml 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- drift_ml-0.1.1/PKG-INFO +6 -0
- drift_ml-0.1.1/README.md +186 -0
- drift_ml-0.1.1/drift/__init__.py +5 -0
- drift_ml-0.1.1/drift/__main__.py +14 -0
- drift_ml-0.1.1/drift/cli/__init__.py +6 -0
- drift_ml-0.1.1/drift/cli/client.py +125 -0
- drift_ml-0.1.1/drift/cli/repl.py +220 -0
- drift_ml-0.1.1/drift/cli/session.py +70 -0
- drift_ml-0.1.1/drift/llm_adapters/__init__.py +9 -0
- drift_ml-0.1.1/drift/llm_adapters/base.py +41 -0
- drift_ml-0.1.1/drift/llm_adapters/gemini_cli.py +65 -0
- drift_ml-0.1.1/drift/llm_adapters/local_llm.py +65 -0
- drift_ml-0.1.1/drift_ml.egg-info/PKG-INFO +6 -0
- drift_ml-0.1.1/drift_ml.egg-info/SOURCES.txt +18 -0
- drift_ml-0.1.1/drift_ml.egg-info/dependency_links.txt +1 -0
- drift_ml-0.1.1/drift_ml.egg-info/entry_points.txt +2 -0
- drift_ml-0.1.1/drift_ml.egg-info/requires.txt +1 -0
- drift_ml-0.1.1/drift_ml.egg-info/top_level.txt +2 -0
- drift_ml-0.1.1/pyproject.toml +19 -0
- drift_ml-0.1.1/setup.cfg +4 -0
drift_ml-0.1.1/PKG-INFO
ADDED
drift_ml-0.1.1/README.md
ADDED
@@ -0,0 +1,186 @@
# Intent2Model 🚀

LLM-Guided AutoML Agent - Upload a CSV, chat with the AI, get a trained model.

## ⚡ Quick Start (Easiest Way)

```bash
# Make scripts executable (first time only)
chmod +x start.sh stop.sh

# Start everything
./start.sh
```

Then open **http://localhost:3000** in your browser!

To stop: Press `Ctrl+C` or run `./stop.sh`

---

## 📖 How to Use

### 1. Upload CSV
- Drag & drop your CSV file or click "choose file"
- The system analyzes it automatically

### 2. Train Model
Just type a column name in the chat:
- **"variety"** → trains a model to predict "variety"
- **"price"** → trains a model to predict "price"
- Or any column name from your dataset

### 3. View Results
- **"report"** → shows beautiful charts and metrics
- **"show me results"** → displays model performance

### 4. Make Predictions
- **"predict"** or **"can you predict for me?"** → starts the prediction flow
- Provide feature values: **"sepal.length: 5.1, sepal.width: 3.5"**

---

## 💬 Example Conversation

```
You: [uploads iris.csv]
AI: ✓ analyzed your dataset • 150 rows • 5 columns
AI: suggested targets: variety, sepal.length, sepal.width

You: variety
AI: 🚀 training model to predict "variety"...
AI: ✅ model trained successfully!
AI: accuracy: 1.000 • best model: RandomForest
AI: [shows charts]

You: report
AI: [shows detailed charts: metrics, feature importance, CV scores]

You: predict
AI: sure! i need: sepal.length, sepal.width, petal.length, petal.width

You: sepal.length: 5.1, sepal.width: 3.5, petal.length: 1.4, petal.width: 0.2
AI: 🎯 prediction: Setosa
AI: probabilities: Setosa 99.8%, Versicolor 0.2%, Virginica 0.0%
```

---

## 🛠️ Manual Setup (Alternative)

### Backend
**Important:** Always run uvicorn from inside the `backend/` folder. Running from the project root will fail with "Could not import module main".

```bash
cd backend
pip install -r ../requirements.txt
# Optional: set the API key in .env or export GEMINI_API_KEY=your_key
python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
```

Or use the helper script (from the project root):
```bash
chmod +x backend/run.sh
./backend/run.sh
```

**Backend not running?**
- **"Could not import module main"** → You're in the wrong folder. Run `cd backend` first, then `python3 -m uvicorn main:app --host 0.0.0.0 --port 8000`.
- **"Address already in use" / port 8000** → Free the port: `lsof -ti:8000 | xargs kill -9`, then start again.
- **Dependencies missing** → From the project root: `pip install -r requirements.txt` (or use `./start.sh` — it creates a venv and installs deps).

### Frontend (New Terminal)
```bash
cd frontend
npm install
npm run dev
```

Visit **http://localhost:3000**

---

## 🎨 Features

- 📊 **Beautiful Charts**: Metrics, feature importance, CV scores
- 🎨 **Extravagant UI**: Gradient colors, smooth animations
- 🤖 **LLM-Powered**: Gemini AI generates optimal pipelines
- 🔮 **Smart Predictions**: Chat-based prediction interface
- 📈 **Model Comparison**: Tries multiple models, picks the best
- ⚡ **Auto-Detection**: Automatically detects task type and metrics

---

## 📝 Requirements

- Python 3.10+
- Node.js 18+
- npm/yarn

Install dependencies:
```bash
# Backend
pip install -r requirements.txt

# Frontend
cd frontend
npm install
```

---

## 🐛 Troubleshooting

**Services not starting?**
- Check ports: `lsof -i :8000` and `lsof -i :3000`
- Check logs: `tail -f backend.log` or `tail -f frontend.log`

**Training errors?**
- Make sure the CSV has valid data
- Check that the target column exists
- Try a different column name

---

## drift — Terminal-first CLI

**drift** by Lakshit Sachdeva. Terminal-first, chat-based AutoML — same engine as the web UI. No commands to memorize.

### Exactly what to do (any computer)

1. **Install drift** (pick one):
```bash
pipx install drift-ml
```
or:
```bash
pip install drift-ml
```

2. **Run drift:**
```bash
drift
```
You’ll see the welcome and step-by-step instructions in the terminal.

3. **Engine** — On first run the CLI downloads and starts the drift engine locally (or set `DRIFT_BACKEND_URL` to point at a running engine). You need an LLM: Gemini CLI, Ollama, or another local LLM.

4. **In drift:** type `load path/to/your.csv`, then chat (e.g. `predict price`, `try something stronger`). Type `quit` to exit.

drift shows you the rest when you run it.

### Details

- **Local-first** — Same engine as the web app; planning and training run on your machine.
- **Chat-based**: e.g. `load iris.csv`, `predict price`, `try something stronger`, `why is accuracy capped`.
- **Engine** runs locally (the CLI auto-starts it, or set `DRIFT_BACKEND_URL`). The web UI can be hosted on Vercel.

---

## 📚 More Info

See `HOW_TO_USE.md` for detailed instructions and examples.

---

**That's it! Just run `./start.sh` and start chatting! 🎉**

drift_ml-0.1.1/drift/cli/client.py
ADDED
@@ -0,0 +1,125 @@
"""
HTTP client for the existing backend API.
Uses the same endpoints as the web app: upload, chat, session, runs, train.
"""

import os
from pathlib import Path
from typing import Any, Dict, Optional

try:
    import requests
except ImportError:
    requests = None  # type: ignore

DEFAULT_BASE_URL = os.environ.get("DRIFT_BACKEND_URL", "http://localhost:8000")


class BackendError(Exception):
    """Raised when the backend returns an error or is unreachable."""

    def __init__(self, message: str, status_code: Optional[int] = None, body: Any = None):
        self.message = message
        self.status_code = status_code
        self.body = body
        super().__init__(message)


class BackendClient:
    """Client for backend upload, chat, session, runs, and train endpoints."""

    def __init__(self, base_url: str = DEFAULT_BASE_URL, timeout: int = 300):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        if requests is None:
            raise BackendError(
                "The 'requests' library is required for the drift CLI. Install it with: pip install requests"
            )

    def _url(self, path: str) -> str:
        return f"{self.base_url}{path}"

    def health(self) -> Dict[str, Any]:
        """GET /health — check backend is up and LLM status."""
        r = requests.get(self._url("/health"), timeout=10)
        r.raise_for_status()
        return r.json()

    def upload_csv(self, file_path: str) -> Dict[str, Any]:
        """
        POST /upload — upload a local CSV and create a session.
        Returns session_id, dataset_id, profile, initial_message.
        """
        path = Path(file_path).resolve()
        if not path.exists():
            raise BackendError(f"File not found: {path}")
        with open(path, "rb") as f:
            name = path.name
            files = {"file": (name, f, "text/csv")}
            r = requests.post(self._url("/upload"), files=files, timeout=60)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Upload failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def chat(self, session_id: str, message: str) -> Dict[str, Any]:
        """
        POST /session/{session_id}/chat — send a user message, get agent reply.
        Returns chat_history, trigger_training.
        """
        r = requests.post(
            self._url(f"/session/{session_id}/chat"),
            json={"message": message},
            timeout=120,
        )
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Chat failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def get_session(self, session_id: str) -> Dict[str, Any]:
        """GET /session/{session_id} — full session state (chat, model_state, current_run_id, etc.)."""
        r = requests.get(self._url(f"/session/{session_id}"), timeout=30)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Get session failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def get_run_state(self, run_id: str) -> Dict[str, Any]:
        """GET /runs/{run_id} — run state: status, current_step, progress, events[]."""
        r = requests.get(self._url(f"/runs/{run_id}"), timeout=30)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Get run failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def train(self, session_id: str) -> Dict[str, Any]:
        """
        POST /session/{session_id}/train — run one training attempt (blocking).
        Returns run_id, session_id, metrics, refused, refusal_reason, agent_message.
        """
        r = requests.post(self._url(f"/session/{session_id}/train"), timeout=self.timeout)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Train failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()
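
For orientation, a minimal sketch of driving these endpoints the way the REPL does (health, then upload, then chat, then train). The CSV path and chat message are illustrative; session IDs come from the backend.

```python
# Sketch: script the drift backend directly with BackendClient.
# "iris.csv" and the message are illustrative; a backend must be reachable
# (DRIFT_BACKEND_URL or http://localhost:8000).
from drift.cli.client import BackendClient, BackendError

client = BackendClient()
try:
    print(client.health())                       # e.g. {'status': 'healthy', ...}
    up = client.upload_csv("iris.csv")           # creates a session
    sid = up["session_id"]
    reply = client.chat(sid, "predict variety")  # agent reply + trigger flag
    if reply.get("trigger_training"):
        result = client.train(sid)               # blocking training attempt
        print(result.get("metrics"))
except BackendError as e:
    print(f"backend error: {e.message}")
```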

drift_ml-0.1.1/drift/cli/repl.py
ADDED
@@ -0,0 +1,220 @@
"""
Chat-based CLI loop for drift.
Natural language input; maintains session state; reuses backend planner + executor.
"""

import re
import sys
import threading
import time
from typing import Any, Dict, Optional

from drift.cli.client import BackendClient, BackendError, DEFAULT_BASE_URL
from drift.cli.session import SessionState


PROMPT = "drift › "
LOAD_PATTERN = re.compile(r"^\s*load\s+(.+)$", re.IGNORECASE)


def run_repl(base_url: Optional[str] = None) -> None:
    """Run the chat-based REPL. Uses backend at base_url (default from DRIFT_BACKEND_URL or localhost:8000)."""
    try:
        client = BackendClient(base_url=base_url or DEFAULT_BASE_URL)
    except BackendError as e:
        print(f"drift: {e.message}", file=sys.stderr)
        sys.exit(1)
    session = SessionState()
    _print_banner(client)
    while True:
        try:
            line = input(PROMPT).strip()
        except EOFError:
            print()
            break
        if not line:
            continue
        if line.lower() in ("quit", "exit", "q"):
            break
        try:
            _handle_input(line, client, session)
        except BackendError as e:
            print(f"Error: {e.message}", file=sys.stderr)
        except KeyboardInterrupt:
            print("\nInterrupted. Type 'quit' to exit.")
            continue


def _print_banner(client: BackendClient) -> None:
    _print_welcome()
    engine_ok = False
    llm_name = ""
    try:
        health = client.health()
        engine_ok = (health.get("status") or "").lower() == "healthy"
        llm_name = health.get("llm_provider") or health.get("current_model") or ""
    except Exception:
        pass
    _print_status(engine_ok=engine_ok, llm_name=llm_name)
    _print_examples()
    print()


def _print_welcome() -> None:
    """drift by Lakshit Sachdeva. Local-first ML engineer."""
    print()
    print("  ----------------------------------------")
    print("  drift by Lakshit Sachdeva")
    print("  Local-first ML engineer.")
    print("  Same engine as the web app.")
    print("  No commands to memorize.")
    print("  ----------------------------------------")
    print()


def _print_status(engine_ok: bool, llm_name: str) -> None:
    """Engine running, LLM detected, Ready."""
    if engine_ok:
        print("  \u2713 Engine running")
        if llm_name:
            print(f"  \u2713 LLM detected ({llm_name})")
        else:
            print("  \u2713 LLM detected (Gemini CLI / Ollama / local)")
        print("  \u2713 Ready")
    else:
        print("  \u2717 Engine not running")
        print("  Start the engine or set DRIFT_BACKEND_URL to a running engine URL.")
    print()


def _print_examples() -> None:
    """Examples: load, predict, try something stronger, quit."""
    print("  Examples:")
    print("    load data.csv")
    print("    predict price")
    print("    try something stronger")
    print("    why is accuracy capped")
    print("    quit")
    print()


def _handle_input(line: str, client: BackendClient, session: SessionState) -> None:
    load_match = LOAD_PATTERN.match(line)
    if load_match:
        path = load_match.group(1).strip().strip('"\'')
        _do_load(path, client, session)
        return
    if not session.has_session():
        print("Load a dataset first: load path/to/file.csv")
        return
    _do_chat(line, client, session)


def _do_load(path: str, client: BackendClient, session: SessionState) -> None:
    out = client.upload_csv(path)
    session.dataset_path = path
    session.update_from_upload(
        session_id=out["session_id"],
        dataset_id=out["dataset_id"],
        initial_message=out.get("initial_message"),
    )
    print("Dataset loaded.")
    if out.get("initial_message"):
        content = out["initial_message"].get("content") or ""
        if content:
            _print_message(content)
    print()


def _do_chat(message: str, client: BackendClient, session: SessionState) -> None:
    sid = session.session_id
    out = client.chat(sid, message)
    chat_history = out.get("chat_history") or []
    session.update_from_chat(chat_history)
    # Show latest agent reply
    for m in reversed(chat_history):
        if m.get("role") == "agent" and m.get("content"):
            _print_message(m["content"])
            break
    trigger = out.get("trigger_training") is True
    if trigger:
        _run_training_and_show(client, session)
    print()


def _run_training_and_show(client: BackendClient, session: SessionState) -> None:
    sid = session.session_id
    if not sid:
        return
    train_result: Optional[Dict[str, Any]] = None
    train_error: Optional[Exception] = None

    def do_train() -> None:
        nonlocal train_result, train_error
        try:
            train_result = client.train(sid)
        except Exception as e:
            train_error = e

    thread = threading.Thread(target=do_train, daemon=True)
    thread.start()
    print("Training started…")
    run_id: Optional[str] = None
    last_event_count = 0
    poll_interval = 0.8
    timeout_sec = 300
    start = time.time()
    while thread.is_alive() and (time.time() - start) < timeout_sec:
        time.sleep(poll_interval)
        try:
            sess = client.get_session(sid)
            run_id = sess.get("current_run_id")
            if run_id:
                break
        except Exception:
            pass
    if run_id:
        while thread.is_alive() and (time.time() - start) < timeout_sec:
            time.sleep(poll_interval)
            try:
                run_state = client.get_run_state(run_id)
                events = run_state.get("events") or []
                for ev in events[last_event_count:]:
                    msg = ev.get("message") or ev.get("step_name") or ""
                    if msg:
                        print(f"  {msg}")
                last_event_count = len(events)
                status = run_state.get("status")
                if status in ("success", "failed", "refused"):
                    break
            except Exception:
                pass
    thread.join(timeout=1.0)
    if train_error:
        print(f"Training error: {train_error}", file=sys.stderr)
        return
    if train_result:
        session.update_after_train(
            run_id=train_result.get("run_id", ""),
            metrics=train_result.get("metrics") or {},
            agent_message=train_result.get("agent_message"),
        )
        agent_message = train_result.get("agent_message")
        if agent_message:
            _print_message(agent_message)
        metrics = train_result.get("metrics") or {}
        if metrics:
            primary = metrics.get("primary_metric_value") or metrics.get("accuracy") or metrics.get("r2")
            if primary is not None:
                print(f"  Primary metric: {primary}")


def _print_message(content: str) -> None:
    """Print agent message; strip markdown bold for terminal if desired, keep newlines."""
    text = content.strip()
    if not text:
        return
    # Optional: replace **x** with x for readability in terminal
    text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
    for line in text.split("\n"):
        print(line)
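
The `load` command is the only structured command; everything else is forwarded to the backend chat endpoint. A quick illustration of how `LOAD_PATTERN` plus the quote-stripping in `_handle_input` classify input lines:

```python
# Demonstrates the REPL's load-command parsing (same regex and quote
# stripping as repl.py); non-matching lines fall through to chat.
import re

LOAD_PATTERN = re.compile(r"^\s*load\s+(.+)$", re.IGNORECASE)
for line in ('load data.csv', '  LOAD "my data.csv"', 'predict price'):
    m = LOAD_PATTERN.match(line)
    print(repr(line), "->", m.group(1).strip().strip('"\'') if m else "chat")
# 'load data.csv' -> data.csv
# '  LOAD "my data.csv"' -> my data.csv
# 'predict price' -> chat
```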

drift_ml-0.1.1/drift/cli/session.py
ADDED
@@ -0,0 +1,70 @@
"""
Session state for the drift CLI.
Holds dataset ref, plan summary, last metrics, run_id, and chat history
so the REPL can maintain context and the backend can be queried correctly.
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class SessionState:
    """In-memory CLI session: dataset, plan, run, metrics, chat."""

    dataset_path: Optional[str] = None
    dataset_id: Optional[str] = None
    session_id: Optional[str] = None
    run_id: Optional[str] = None
    plan_summary: Optional[str] = None
    last_metrics: Dict[str, Any] = field(default_factory=dict)
    chat_history: List[Dict[str, Any]] = field(default_factory=list)

    def has_session(self) -> bool:
        return bool(self.session_id)

    def has_dataset(self) -> bool:
        return bool(self.dataset_id and self.session_id)

    def clear_run(self) -> None:
        self.run_id = None
        self.last_metrics = {}

    def update_from_upload(self, session_id: str, dataset_id: str, initial_message: Optional[Dict] = None) -> None:
        self.session_id = session_id
        self.dataset_id = dataset_id
        self.run_id = None
        self.last_metrics = {}
        if initial_message:
            self.chat_history = [initial_message]

    def update_from_chat(self, chat_history: List[Dict[str, Any]], run_id: Optional[str] = None) -> None:
        self.chat_history = chat_history or self.chat_history
        if run_id is not None:
            self.run_id = run_id

    def update_from_session(self, session: Dict[str, Any]) -> None:
        self.session_id = session.get("session_id") or self.session_id
        self.dataset_id = session.get("dataset_id") or self.dataset_id
        self.run_id = session.get("current_run_id") or self.run_id
        self.chat_history = session.get("chat_history") or self.chat_history
        ms = session.get("model_state") or {}
        self.last_metrics = ms.get("metrics") or self.last_metrics
        self.plan_summary = None
        if session.get("structural_plan"):
            self.plan_summary = _summarize_plan(session["structural_plan"])

    def update_after_train(self, run_id: str, metrics: Dict[str, Any], agent_message: Optional[str] = None) -> None:
        self.run_id = run_id
        self.last_metrics = metrics or self.last_metrics
        if agent_message and self.chat_history:
            self.chat_history[-1]["content"] = agent_message


def _summarize_plan(structural_plan: Dict[str, Any]) -> str:
    """One-line summary of structural plan for display (StructuralPlan has inferred_target, task_type; no models)."""
    if not structural_plan:
        return ""
    target = structural_plan.get("inferred_target") or structural_plan.get("target") or "?"
    task = structural_plan.get("task_type") or "?"
    return f"Target: {target}, task: {task}"
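
A short sketch of the state transitions the REPL drives; the IDs and metrics here are illustrative values, not real backend output.

```python
# Illustrative SessionState lifecycle; "sess-1"/"ds-1"/metrics are made up.
from drift.cli.session import SessionState

s = SessionState()
assert not s.has_session()
s.update_from_upload(session_id="sess-1", dataset_id="ds-1",
                     initial_message={"role": "agent", "content": "analyzed your dataset"})
assert s.has_dataset()                     # session + dataset present
s.update_after_train(run_id="run-1", metrics={"accuracy": 0.97})
print(s.last_metrics)                      # {'accuracy': 0.97}
s.clear_run()                              # forget run + metrics, keep dataset
```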

drift_ml-0.1.1/drift/llm_adapters/__init__.py
ADDED
@@ -0,0 +1,9 @@
"""
Pluggable LLM adapters for drift CLI.
Primary planning/execution uses the backend; these adapters support
optional local processing (e.g. help, summaries) or future offline use.
"""

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse

__all__ = ["BaseLLMAdapter", "LLMResponse"]

drift_ml-0.1.1/drift/llm_adapters/base.py
ADDED
@@ -0,0 +1,41 @@
"""
Base contract for drift LLM adapters.
The CLI uses the backend as the planner/executor; adapters return
structured ML intent, plan updates, and explanations when used locally.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class LLMResponse:
    """Structured response from an LLM adapter: intent, plan updates, explanation."""

    raw: str
    intent: Optional[Dict[str, Any]] = None  # target, drop_columns, performance_mode, start_training
    plan_summary: Optional[str] = None
    explanation: Optional[str] = None


class BaseLLMAdapter(ABC):
    """
    LLM adapter contract.
    The LLM is the PLANNER, not the executor; it updates plans and explains outcomes.
    Execution is assumed to happen automatically (backend).
    """

    @abstractmethod
    def generate(self, user_message: str, system_prompt: Optional[str] = None, context: Optional[Dict[str, Any]] = None) -> LLMResponse:
        """
        Turn user message + optional context into a structured response.
        Returns intent, plan summary, and explanation — never claims to edit files or execute.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Adapter identifier (e.g. gemini_cli, local_llm)."""
        pass
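
A hypothetical minimal adapter showing the contract; EchoAdapter is not part of the package, and a real adapter would call an actual model.

```python
# Hypothetical: smallest possible BaseLLMAdapter implementation, useful only
# for wiring tests; it returns the user message with no structured intent.
from typing import Any, Dict, Optional

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse


class EchoAdapter(BaseLLMAdapter):
    @property
    def name(self) -> str:
        return "echo"

    def generate(self, user_message: str, system_prompt: Optional[str] = None,
                 context: Optional[Dict[str, Any]] = None) -> LLMResponse:
        return LLMResponse(raw=user_message, explanation=user_message)


print(EchoAdapter().generate("predict price").raw)  # predict price
```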

drift_ml-0.1.1/drift/llm_adapters/gemini_cli.py
ADDED
@@ -0,0 +1,65 @@
"""
Gemini CLI adapter — uses gemini CLI (default if available) for local LLM calls.
Used when drift CLI runs in a mode that calls a local LLM instead of the backend.
"""

import os
import re
import subprocess
from typing import Any, Dict, Optional

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse


class GeminiCLIAdapter(BaseLLMAdapter):
    """Adapter that shells out to gemini CLI (e.g. `gemini` or GEMINI_CLI_CMD)."""

    def __init__(self, cmd: Optional[str] = None):
        self.cmd = (cmd or os.environ.get("GEMINI_CLI_CMD", "gemini")).strip()

    @property
    def name(self) -> str:
        return "gemini_cli"

    def generate(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> LLMResponse:
        prompt = user_message
        if system_prompt:
            prompt = f"{system_prompt}\n\n---\n\n{prompt}"
        if context:
            prompt = f"Context: {context}\n\n{prompt}"
        raw = self._call_cli(prompt)
        intent = self._parse_intent(raw)
        return LLMResponse(raw=raw, intent=intent, explanation=raw)

    def _call_cli(self, prompt: str) -> str:
        try:
            proc = subprocess.run(
                [self.cmd, "run", "-"],
                input=prompt,
                capture_output=True,
                text=True,
                timeout=120,
            )
            if proc.returncode != 0:
                return f"[gemini CLI error: {proc.stderr or proc.stdout or 'unknown'}]"
            return (proc.stdout or "").strip()
        except FileNotFoundError:
            return f"[gemini CLI not found: {self.cmd}]"
        except subprocess.TimeoutExpired:
            return "[gemini CLI timeout]"

    def _parse_intent(self, raw: str) -> Optional[Dict[str, Any]]:
        """Extract INTENT_JSON from response if present."""
        match = re.search(r"INTENT_JSON:\s*(\{.*?\})\s*$", raw, re.DOTALL)
        if not match:
            return None
        try:
            import json
            return json.loads(match.group(1).strip())
        except Exception:
            return None
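
Both adapters expect the model to end its reply with an `INTENT_JSON: {...}` line. A sketch of what `_parse_intent` extracts; the exact keys are an assumption based on the `LLMResponse.intent` comment (target, drop_columns, performance_mode, start_training), and `_parse_intent` is called directly here only for illustration.

```python
# Illustrative INTENT_JSON round-trip using the adapter's own parser.
from drift.llm_adapters.gemini_cli import GeminiCLIAdapter

raw = ("I'll train a model to predict price.\n"
       'INTENT_JSON: {"target": "price", "start_training": true}')
print(GeminiCLIAdapter()._parse_intent(raw))
# {'target': 'price', 'start_training': True}
```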

drift_ml-0.1.1/drift/llm_adapters/local_llm.py
ADDED
@@ -0,0 +1,65 @@
"""
Local LLM adapter — ollama / llama.cpp style.
Used when drift CLI is configured to use a local model instead of the backend.
"""

import os
import re
from typing import Any, Dict, Optional

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse


class LocalLLMAdapter(BaseLLMAdapter):
    """
    Adapter for local LLM (e.g. Ollama, llama.cpp server).
    Expects OLLAMA_BASE_URL or LOCAL_LLM_URL; model from OLLAMA_MODEL or LOCAL_LLM_MODEL.
    """

    def __init__(self, base_url: Optional[str] = None, model: Optional[str] = None):
        self.base_url = (base_url or os.environ.get("OLLAMA_BASE_URL") or os.environ.get("LOCAL_LLM_URL", "http://localhost:11434")).rstrip("/")
        self.model = model or os.environ.get("OLLAMA_MODEL") or os.environ.get("LOCAL_LLM_MODEL", "llama2")

    @property
    def name(self) -> str:
        return "local_llm"

    def generate(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> LLMResponse:
        prompt = user_message
        if system_prompt:
            prompt = f"{system_prompt}\n\n---\n\n{prompt}"
        if context:
            prompt = f"Context: {context}\n\n{prompt}"
        raw = self._call_api(prompt)
        intent = self._parse_intent(raw)
        return LLMResponse(raw=raw, intent=intent, explanation=raw)

    def _call_api(self, prompt: str) -> str:
        try:
            import requests
            r = requests.post(
                f"{self.base_url}/api/generate",
                json={"model": self.model, "prompt": prompt, "stream": False},
                timeout=120,
            )
            if r.status_code >= 400:
                return f"[Local LLM error: {r.status_code} {r.text[:200]}]"
            data = r.json()
            return (data.get("response") or "").strip()
        except Exception as e:
            return f"[Local LLM error: {e}]"

    def _parse_intent(self, raw: str) -> Optional[Dict[str, Any]]:
        match = re.search(r"INTENT_JSON:\s*(\{.*?\})\s*$", raw, re.DOTALL)
        if not match:
            return None
        try:
            import json
            return json.loads(match.group(1).strip())
        except Exception:
            return None
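
Usage sketch, assuming an Ollama server on the default port with the named model already pulled; `llama3` here is illustrative.

```python
# Point the adapter at a local Ollama server; failures come back as
# "[Local LLM error: ...]" strings rather than raised exceptions.
from drift.llm_adapters.local_llm import LocalLLMAdapter

adapter = LocalLLMAdapter(base_url="http://localhost:11434", model="llama3")
resp = adapter.generate("Suggest a target column for iris.csv",
                        system_prompt="You are an ML planning assistant.")
print(resp.raw)
```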

drift_ml-0.1.1/drift_ml.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,18 @@
README.md
pyproject.toml
drift/__init__.py
drift/__main__.py
drift/cli/__init__.py
drift/cli/client.py
drift/cli/repl.py
drift/cli/session.py
drift/llm_adapters/__init__.py
drift/llm_adapters/base.py
drift/llm_adapters/gemini_cli.py
drift/llm_adapters/local_llm.py
drift_ml.egg-info/PKG-INFO
drift_ml.egg-info/SOURCES.txt
drift_ml.egg-info/dependency_links.txt
drift_ml.egg-info/entry_points.txt
drift_ml.egg-info/requires.txt
drift_ml.egg-info/top_level.txt

drift_ml-0.1.1/drift_ml.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
(single blank line)

drift_ml-0.1.1/drift_ml.egg-info/requires.txt
ADDED
@@ -0,0 +1 @@
requests>=2.28.0

drift_ml-0.1.1/pyproject.toml
ADDED
@@ -0,0 +1,19 @@
[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"

[project]
name = "drift-ml"
version = "0.1.1"
description = "Terminal-first AutoML CLI - chat-based ML engineer"
requires-python = ">=3.10"
dependencies = [
    "requests>=2.28.0",
]

[project.scripts]
drift = "drift.__main__:main"

[tool.setuptools.packages.find]
where = ["."]
include = ["drift*"]
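
`drift/__main__.py` (+14) is listed in the file summary but its body is not shown in this diff view. For orientation only, a hypothetical entry point consistent with the `drift = "drift.__main__:main"` console script and the `run_repl(base_url=...)` signature above; this is not the package's actual code.

```python
# Hypothetical sketch of drift/__main__.py; the real file is not shown here.
import sys

from drift.cli.repl import run_repl


def main() -> None:
    # Assumption: an optional first argument overrides the backend URL.
    base_url = sys.argv[1] if len(sys.argv) > 1 else None
    run_repl(base_url=base_url)


if __name__ == "__main__":
    main()
```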
drift_ml-0.1.1/setup.cfg
ADDED