drift-ml 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
+ Metadata-Version: 2.4
+ Name: drift-ml
+ Version: 0.1.1
+ Summary: Terminal-first AutoML CLI - chat-based ML engineer
+ Requires-Python: >=3.10
+ Requires-Dist: requests>=2.28.0
@@ -0,0 +1,186 @@
+ # Intent2Model 🚀
+
+ LLM-Guided AutoML Agent - Upload a CSV, chat with the AI, get a trained model.
+
+ ## ⚡ Quick Start (Easiest Way)
+
+ ```bash
+ # Make scripts executable (first time only)
+ chmod +x start.sh stop.sh
+
+ # Start everything
+ ./start.sh
+ ```
+
+ Then open **http://localhost:3000** in your browser!
+
+ To stop: Press `Ctrl+C` or run `./stop.sh`
+
+ ---
+
+ ## 📖 How to Use
+
+ ### 1. Upload CSV
+ - Drag & drop your CSV file or click "choose file"
+ - System analyzes it automatically
+
+ ### 2. Train Model
+ Just type a column name in the chat:
+ - **"variety"** → trains model to predict "variety"
+ - **"price"** → trains model to predict "price"
+ - Or any column name from your dataset
+
+ ### 3. View Results
+ - **"report"** → shows beautiful charts and metrics
+ - **"show me results"** → displays model performance
+
+ ### 4. Make Predictions
+ - **"predict"** or **"can you predict for me?"** → starts prediction flow
+ - Provide feature values: **"sepal.length: 5.1, sepal.width: 3.5"**
+
+ ---
+
+ ## 💬 Example Conversation
+
+ ```
+ You: [uploads iris.csv]
+ AI: ✓ analyzed your dataset • 150 rows • 5 columns
+ AI: suggested targets: variety, sepal.length, sepal.width
+
+ You: variety
+ AI: 🚀 training model to predict "variety"...
+ AI: ✅ model trained successfully!
+ AI: accuracy: 1.000 • best model: RandomForest
+ AI: [shows charts]
+
+ You: report
+ AI: [shows detailed charts: metrics, feature importance, CV scores]
+
+ You: predict
+ AI: sure! i need: sepal.length, sepal.width, petal.length, petal.width
+
+ You: sepal.length: 5.1, sepal.width: 3.5, petal.length: 1.4, petal.width: 0.2
+ AI: 🎯 prediction: Setosa
+ AI: probabilities: Setosa 99.8%, Versicolor 0.2%, Virginica 0.0%
+ ```
+
+ ---
+
+ ## 🛠️ Manual Setup (Alternative)
+
+ ### Backend
+ **Important:** Always run uvicorn from inside the `backend/` folder. Running from project root will fail with "Could not import module main".
+
+ ```bash
+ cd backend
+ pip install -r ../requirements.txt
+ # Optional: set API key in .env or export GEMINI_API_KEY=your_key
+ python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
+ ```
+
+ Or use the helper script (from project root):
+ ```bash
+ chmod +x backend/run.sh
+ ./backend/run.sh
+ ```
+
+ **Backend not running?**
+ - **"Could not import module main"** → You're in the wrong folder. Run `cd backend` first, then `python3 -m uvicorn main:app --host 0.0.0.0 --port 8000`.
+ - **"Address already in use" / port 8000** → Free the port: `lsof -ti:8000 | xargs kill -9`, then start again.
+ - **Dependencies missing** → From project root: `pip install -r requirements.txt` (or use `./start.sh` — it creates a venv and installs deps).
+
+ ### Frontend (New Terminal)
+ ```bash
+ cd frontend
+ npm install
+ npm run dev
+ ```
+
+ Visit **http://localhost:3000**
+
+ ---
+
+ ## 🎨 Features
+
+ - 📊 **Beautiful Charts**: Metrics, feature importance, CV scores
+ - 🎨 **Extravagant UI**: Gradient colors, smooth animations
+ - 🤖 **LLM-Powered**: Gemini AI generates optimal pipelines
+ - 🔮 **Smart Predictions**: Chat-based prediction interface
+ - 📈 **Model Comparison**: Tries multiple models, picks best
+ - ⚡ **Auto-Detection**: Automatically detects task type and metrics
+
+ ---
+
+ ## 📝 Requirements
+
+ - Python 3.10+
+ - Node.js 18+
+ - npm/yarn
+
+ Install dependencies:
+ ```bash
+ # Backend
+ pip install -r requirements.txt
+
+ # Frontend
+ cd frontend
+ npm install
+ ```
+
+ ---
+
+ ## 🐛 Troubleshooting
+
+ **Services not starting?**
+ - Check ports: `lsof -i :8000` and `lsof -i :3000`
+ - Check logs: `tail -f backend.log` or `tail -f frontend.log`
+
+ **Training errors?**
+ - Make sure the CSV contains valid data
+ - Check that the target column exists
+ - Try a different column name
+
+ ---
+
+ ## drift — Terminal-first CLI
+
+ **drift** by Lakshit Sachdeva. Terminal-first, chat-based AutoML — same engine as the web UI. No commands to memorize.
+
+ ### Exactly what to do (any computer)
+
+ 1. **Install drift** (pick one):
+ ```bash
+ npm install -g drift-ml
+ ```
+ or:
+ ```bash
+ pipx install drift-ml
+ ```
+
+ 2. **Run drift:**
+ ```bash
+ drift
+ ```
+ You’ll see the welcome and step-by-step instructions in the terminal.
+
+ 3. **Engine** — On first run the CLI downloads and starts the drift engine locally (or set `DRIFT_BACKEND_URL` to a running engine). You need an LLM: Gemini CLI, Ollama, or another local LLM.
+
+ 4. **In drift:** type `load path/to/your.csv`, then chat (e.g. `predict price`, `try something stronger`). Type `quit` to exit.
+
+ drift shows you the rest when you run it.
+
+ ### Details
+
+ - **Local-first** — same engine as the web app; planning and training run on your machine.
+ - **Chat-based** — e.g. `load iris.csv`, `predict price`, `try something stronger`, `why is accuracy capped`.
+ - **Engine** — runs locally (the CLI auto-starts it, or set `DRIFT_BACKEND_URL`). The web UI can be hosted on Vercel.
+
+ ---
+
+ ## 📚 More Info
+
+ See `HOW_TO_USE.md` for detailed instructions and examples.
+
+ ---
+
+ **That's it! Just run `./start.sh` and start chatting! 🎉**
@@ -0,0 +1,5 @@
+ """
+ drift — terminal-first, chat-based CLI for the same AutoML planner + executor as the web app.
+ """
+
+ __version__ = "0.1.1"
@@ -0,0 +1,14 @@
+ """
+ Entry point for `python -m drift` or `drift` command.
+ Runs the chat-based REPL.
+ """
+
+ from drift.cli.repl import run_repl
+
+
+ def main() -> None:
+     run_repl()
+
+
+ if __name__ == "__main__":
+     main()
@@ -0,0 +1,6 @@
+ """drift CLI: chat-based REPL and session state."""
+
+ from drift.cli.session import SessionState
+ from drift.cli.repl import run_repl
+
+ __all__ = ["SessionState", "run_repl"]
@@ -0,0 +1,125 @@
+ """
+ HTTP client for the existing backend API.
+ Uses the same endpoints as the web app: upload, chat, session, runs, train.
+ """
+
+ import os
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ try:
+     import requests
+ except ImportError:
+     requests = None  # type: ignore
+
+ DEFAULT_BASE_URL = os.environ.get("DRIFT_BACKEND_URL", "http://localhost:8000")
+
+
+ class BackendError(Exception):
+     """Raised when the backend returns an error or is unreachable."""
+
+     def __init__(self, message: str, status_code: Optional[int] = None, body: Any = None):
+         self.message = message
+         self.status_code = status_code
+         self.body = body
+         super().__init__(message)
+
+
+ class BackendClient:
+     """Client for backend upload, chat, session, runs, and train endpoints."""
+
+     def __init__(self, base_url: str = DEFAULT_BASE_URL, timeout: int = 300):
+         self.base_url = base_url.rstrip("/")
+         self.timeout = timeout
+         if requests is None:
+             raise BackendError(
+                 "The 'requests' library is required for the drift CLI. Install it with: pip install requests"
+             )
+
+     def _url(self, path: str) -> str:
+         return f"{self.base_url}{path}"
+
+     def health(self) -> Dict[str, Any]:
+         """GET /health — check backend is up and LLM status."""
+         r = requests.get(self._url("/health"), timeout=10)
+         r.raise_for_status()
+         return r.json()
+
+     def upload_csv(self, file_path: str) -> Dict[str, Any]:
+         """
+         POST /upload — upload a local CSV and create a session.
+         Returns session_id, dataset_id, profile, initial_message.
+         """
+         path = Path(file_path).resolve()
+         if not path.exists():
+             raise BackendError(f"File not found: {path}")
+         with open(path, "rb") as f:
+             name = path.name
+             files = {"file": (name, f, "text/csv")}
+             r = requests.post(self._url("/upload"), files=files, timeout=60)
+         if r.status_code >= 400:
+             try:
+                 body = r.json()
+                 detail = body.get("detail", str(r.text))
+             except Exception:
+                 detail = r.text or str(r.status_code)
+             raise BackendError(f"Upload failed: {detail}", status_code=r.status_code, body=r.text)
+         return r.json()
+
+     def chat(self, session_id: str, message: str) -> Dict[str, Any]:
+         """
+         POST /session/{session_id}/chat — send a user message, get agent reply.
+         Returns chat_history, trigger_training.
+         """
+         r = requests.post(
+             self._url(f"/session/{session_id}/chat"),
+             json={"message": message},
+             timeout=120,
+         )
+         if r.status_code >= 400:
+             try:
+                 body = r.json()
+                 detail = body.get("detail", str(r.text))
+             except Exception:
+                 detail = r.text or str(r.status_code)
+             raise BackendError(f"Chat failed: {detail}", status_code=r.status_code, body=r.text)
+         return r.json()
+
+     def get_session(self, session_id: str) -> Dict[str, Any]:
+         """GET /session/{session_id} — full session state (chat, model_state, current_run_id, etc.)."""
+         r = requests.get(self._url(f"/session/{session_id}"), timeout=30)
+         if r.status_code >= 400:
+             try:
+                 body = r.json()
+                 detail = body.get("detail", str(r.text))
+             except Exception:
+                 detail = r.text or str(r.status_code)
+             raise BackendError(f"Get session failed: {detail}", status_code=r.status_code, body=r.text)
+         return r.json()
+
+     def get_run_state(self, run_id: str) -> Dict[str, Any]:
+         """GET /runs/{run_id} — run state: status, current_step, progress, events[]."""
+         r = requests.get(self._url(f"/runs/{run_id}"), timeout=30)
+         if r.status_code >= 400:
+             try:
+                 body = r.json()
+                 detail = body.get("detail", str(r.text))
+             except Exception:
+                 detail = r.text or str(r.status_code)
+             raise BackendError(f"Get run failed: {detail}", status_code=r.status_code, body=r.text)
+         return r.json()
+
+     def train(self, session_id: str) -> Dict[str, Any]:
+         """
+         POST /session/{session_id}/train — run one training attempt (blocking).
+         Returns run_id, session_id, metrics, refused, refusal_reason, agent_message.
+         """
+         r = requests.post(self._url(f"/session/{session_id}/train"), timeout=self.timeout)
+         if r.status_code >= 400:
+             try:
+                 body = r.json()
+                 detail = body.get("detail", str(r.text))
+             except Exception:
+                 detail = r.text or str(r.status_code)
+             raise BackendError(f"Train failed: {detail}", status_code=r.status_code, body=r.text)
+         return r.json()
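For orientation, here is a minimal sketch of driving `BackendClient` directly against a local engine. The method names and response fields follow the docstrings above; the `iris.csv` path is illustrative:

```python
from drift.cli.client import BackendClient, BackendError

client = BackendClient()  # honors DRIFT_BACKEND_URL, defaults to http://localhost:8000
try:
    print(client.health())               # engine + LLM status
    out = client.upload_csv("iris.csv")  # hypothetical local file
    sid = out["session_id"]
    reply = client.chat(sid, "variety")  # ask to predict the "variety" column
    if reply.get("trigger_training"):
        result = client.train(sid)       # blocking; returns run_id, metrics, agent_message
        print(result.get("metrics"))
except BackendError as e:
    print(f"backend error: {e.message}")
```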
@@ -0,0 +1,220 @@
+ """
+ Chat-based CLI loop for drift.
+ Natural language input; maintains session state; reuses backend planner + executor.
+ """
+
+ import re
+ import sys
+ import threading
+ import time
+ from typing import Any, Dict, Optional
+
+ from drift.cli.client import BackendClient, BackendError, DEFAULT_BASE_URL
+ from drift.cli.session import SessionState
+
+
+ PROMPT = "drift › "
+ LOAD_PATTERN = re.compile(r"^\s*load\s+(.+)$", re.IGNORECASE)
+
+
+ def run_repl(base_url: Optional[str] = None) -> None:
+     """Run the chat-based REPL. Uses backend at base_url (default from DRIFT_BACKEND_URL or localhost:8000)."""
+     try:
+         client = BackendClient(base_url=base_url or DEFAULT_BASE_URL)
+     except BackendError as e:
+         print(f"drift: {e.message}", file=sys.stderr)
+         sys.exit(1)
+     session = SessionState()
+     _print_banner(client)
+     while True:
+         try:
+             line = input(PROMPT).strip()
+         except EOFError:
+             print()
+             break
+         if not line:
+             continue
+         if line.lower() in ("quit", "exit", "q"):
+             break
+         try:
+             _handle_input(line, client, session)
+         except BackendError as e:
+             print(f"Error: {e.message}", file=sys.stderr)
+         except KeyboardInterrupt:
+             print("\nInterrupted. Type 'quit' to exit.")
+             continue
+
+
+ def _print_banner(client: BackendClient) -> None:
+     _print_welcome()
+     engine_ok = False
+     llm_name = ""
+     try:
+         health = client.health()
+         engine_ok = (health.get("status") or "").lower() == "healthy"
+         llm_name = health.get("llm_provider") or health.get("current_model") or ""
+     except Exception:
+         pass
+     _print_status(engine_ok=engine_ok, llm_name=llm_name)
+     _print_examples()
+     print()
+
+
+ def _print_welcome() -> None:
+     """drift by Lakshit Sachdeva. Local-first ML engineer."""
+     print()
+     print(" ----------------------------------------")
+     print(" drift by Lakshit Sachdeva")
+     print(" Local-first ML engineer.")
+     print(" Same engine as the web app.")
+     print(" No commands to memorize.")
+     print(" ----------------------------------------")
+     print()
+
+
+ def _print_status(engine_ok: bool, llm_name: str) -> None:
+     """Engine running, LLM detected, Ready."""
+     if engine_ok:
+         print(" \u2713 Engine running")
+         if llm_name:
+             print(f" \u2713 LLM detected ({llm_name})")
+         else:
+             print(" \u2713 LLM detected (Gemini CLI / Ollama / local)")
+         print(" \u2713 Ready")
+     else:
+         print(" \u2717 Engine not running")
+         print(" Start the engine or set DRIFT_BACKEND_URL to a running engine URL.")
+     print()
+
+
+ def _print_examples() -> None:
+     """Examples: load, predict, try something stronger, quit."""
+     print(" Examples:")
+     print(" load data.csv")
+     print(" predict price")
+     print(" try something stronger")
+     print(" why is accuracy capped")
+     print(" quit")
+     print()
+
+
+ def _handle_input(line: str, client: BackendClient, session: SessionState) -> None:
+     load_match = LOAD_PATTERN.match(line)
+     if load_match:
+         path = load_match.group(1).strip().strip('"\'')
+         _do_load(path, client, session)
+         return
+     if not session.has_session():
+         print("Load a dataset first: load path/to/file.csv")
+         return
+     _do_chat(line, client, session)
+
+
+ def _do_load(path: str, client: BackendClient, session: SessionState) -> None:
+     out = client.upload_csv(path)
+     session.dataset_path = path
+     session.update_from_upload(
+         session_id=out["session_id"],
+         dataset_id=out["dataset_id"],
+         initial_message=out.get("initial_message"),
+     )
+     print("Dataset loaded.")
+     if out.get("initial_message"):
+         content = out["initial_message"].get("content") or ""
+         if content:
+             _print_message(content)
+     print()
+
+
+ def _do_chat(message: str, client: BackendClient, session: SessionState) -> None:
+     sid = session.session_id
+     out = client.chat(sid, message)
+     chat_history = out.get("chat_history") or []
+     session.update_from_chat(chat_history)
+     # Show latest agent reply
+     for m in reversed(chat_history):
+         if m.get("role") == "agent" and m.get("content"):
+             _print_message(m["content"])
+             break
+     trigger = out.get("trigger_training") is True
+     if trigger:
+         _run_training_and_show(client, session)
+     print()
+
+
+ def _run_training_and_show(client: BackendClient, session: SessionState) -> None:
+     sid = session.session_id
+     if not sid:
+         return
+     train_result: Optional[Dict[str, Any]] = None
+     train_error: Optional[Exception] = None
+
+     def do_train() -> None:
+         nonlocal train_result, train_error
+         try:
+             train_result = client.train(sid)
+         except Exception as e:
+             train_error = e
+
+     thread = threading.Thread(target=do_train, daemon=True)
+     thread.start()
+     print("Training started…")
+     run_id: Optional[str] = None
+     last_event_count = 0
+     poll_interval = 0.8
+     timeout_sec = 300
+     start = time.time()
+     while thread.is_alive() and (time.time() - start) < timeout_sec:
+         time.sleep(poll_interval)
+         try:
+             sess = client.get_session(sid)
+             run_id = sess.get("current_run_id")
+             if run_id:
+                 break
+         except Exception:
+             pass
+     if run_id:
+         while thread.is_alive() and (time.time() - start) < timeout_sec:
+             time.sleep(poll_interval)
+             try:
+                 run_state = client.get_run_state(run_id)
+                 events = run_state.get("events") or []
+                 for ev in events[last_event_count:]:
+                     msg = ev.get("message") or ev.get("step_name") or ""
+                     if msg:
+                         print(f" {msg}")
+                 last_event_count = len(events)
+                 status = run_state.get("status")
+                 if status in ("success", "failed", "refused"):
+                     break
+             except Exception:
+                 pass
+     thread.join(timeout=1.0)
+     if train_error:
+         print(f"Training error: {train_error}", file=sys.stderr)
+         return
+     if train_result:
+         session.update_after_train(
+             run_id=train_result.get("run_id", ""),
+             metrics=train_result.get("metrics") or {},
+             agent_message=train_result.get("agent_message"),
+         )
+         agent_message = train_result.get("agent_message")
+         if agent_message:
+             _print_message(agent_message)
+         metrics = train_result.get("metrics") or {}
+         if metrics:
+             primary = metrics.get("primary_metric_value") or metrics.get("accuracy") or metrics.get("r2")
+             if primary is not None:
+                 print(f" Primary metric: {primary}")
+
+
+ def _print_message(content: str) -> None:
+     """Print agent message; strip markdown bold for terminal if desired, keep newlines."""
+     text = content.strip()
+     if not text:
+         return
+     # Optional: replace **x** with x for readability in terminal
+     text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
+     for line in text.split("\n"):
+         print(line)
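A quick sketch of launching the REPL against a non-default engine (the URL is illustrative; omitting `base_url` falls back to `DRIFT_BACKEND_URL` or localhost:8000):

```python
from drift.cli.repl import run_repl

# Equivalent to exporting DRIFT_BACKEND_URL before running `drift`.
run_repl(base_url="http://127.0.0.1:9000")  # hypothetical engine URL
```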
@@ -0,0 +1,70 @@
+ """
+ Session state for the drift CLI.
+ Holds dataset ref, plan summary, last metrics, run_id, and chat history
+ so the REPL can maintain context and the backend can be queried correctly.
+ """
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+
+ @dataclass
+ class SessionState:
+     """In-memory CLI session: dataset, plan, run, metrics, chat."""
+
+     dataset_path: Optional[str] = None
+     dataset_id: Optional[str] = None
+     session_id: Optional[str] = None
+     run_id: Optional[str] = None
+     plan_summary: Optional[str] = None
+     last_metrics: Dict[str, Any] = field(default_factory=dict)
+     chat_history: List[Dict[str, Any]] = field(default_factory=list)
+
+     def has_session(self) -> bool:
+         return bool(self.session_id)
+
+     def has_dataset(self) -> bool:
+         return bool(self.dataset_id and self.session_id)
+
+     def clear_run(self) -> None:
+         self.run_id = None
+         self.last_metrics = {}
+
+     def update_from_upload(self, session_id: str, dataset_id: str, initial_message: Optional[Dict] = None) -> None:
+         self.session_id = session_id
+         self.dataset_id = dataset_id
+         self.run_id = None
+         self.last_metrics = {}
+         if initial_message:
+             self.chat_history = [initial_message]
+
+     def update_from_chat(self, chat_history: List[Dict[str, Any]], run_id: Optional[str] = None) -> None:
+         self.chat_history = chat_history or self.chat_history
+         if run_id is not None:
+             self.run_id = run_id
+
+     def update_from_session(self, session: Dict[str, Any]) -> None:
+         self.session_id = session.get("session_id") or self.session_id
+         self.dataset_id = session.get("dataset_id") or self.dataset_id
+         self.run_id = session.get("current_run_id") or self.run_id
+         self.chat_history = session.get("chat_history") or self.chat_history
+         ms = session.get("model_state") or {}
+         self.last_metrics = ms.get("metrics") or self.last_metrics
+         self.plan_summary = None
+         if session.get("structural_plan"):
+             self.plan_summary = _summarize_plan(session["structural_plan"])
+
+     def update_after_train(self, run_id: str, metrics: Dict[str, Any], agent_message: Optional[str] = None) -> None:
+         self.run_id = run_id
+         self.last_metrics = metrics or self.last_metrics
+         if agent_message and self.chat_history:
+             self.chat_history[-1]["content"] = agent_message
+
+
+ def _summarize_plan(structural_plan: Dict[str, Any]) -> str:
+     """One-line summary of structural plan for display (StructuralPlan has inferred_target, task_type; no models)."""
+     if not structural_plan:
+         return ""
+     target = structural_plan.get("inferred_target") or structural_plan.get("target") or "?"
+     task = structural_plan.get("task_type") or "?"
+     return f"Target: {target}, task: {task}"
@@ -0,0 +1,9 @@
+ """
+ Pluggable LLM adapters for drift CLI.
+ Primary planning/execution uses the backend; these adapters support
+ optional local processing (e.g. help, summaries) or future offline use.
+ """
+
+ from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse
+
+ __all__ = ["BaseLLMAdapter", "LLMResponse"]
@@ -0,0 +1,41 @@
+ """
+ Base contract for drift LLM adapters.
+ The CLI uses the backend as the planner/executor; adapters return
+ structured ML intent, plan updates, and explanations when used locally.
+ """
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional
+
+
+ @dataclass
+ class LLMResponse:
+     """Structured response from an LLM adapter: intent, plan updates, explanation."""
+
+     raw: str
+     intent: Optional[Dict[str, Any]] = None  # target, drop_columns, performance_mode, start_training
+     plan_summary: Optional[str] = None
+     explanation: Optional[str] = None
+
+
+ class BaseLLMAdapter(ABC):
+     """
+     LLM adapter contract.
+     The LLM is the PLANNER, not the executor; it updates plans and explains outcomes.
+     Execution is assumed to happen automatically (backend).
+     """
+
+     @abstractmethod
+     def generate(self, user_message: str, system_prompt: Optional[str] = None, context: Optional[Dict[str, Any]] = None) -> LLMResponse:
+         """
+         Turn user message + optional context into a structured response.
+         Returns intent, plan summary, and explanation — never claims to edit files or execute.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         """Adapter identifier (e.g. gemini_cli, local_llm)."""
+         pass
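To illustrate the contract, a minimal hypothetical adapter that satisfies both abstract members (it echoes the prompt and never emits an intent):

```python
from typing import Any, Dict, Optional

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse


class EchoAdapter(BaseLLMAdapter):
    """Toy adapter for tests: returns the prompt verbatim, no intent."""

    def generate(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> LLMResponse:
        # A real adapter would call an LLM here; this one just echoes.
        return LLMResponse(raw=user_message, explanation=user_message)

    @property
    def name(self) -> str:
        return "echo"
```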
@@ -0,0 +1,65 @@
+ """
+ Gemini CLI adapter — uses gemini CLI (default if available) for local LLM calls.
+ Used when drift CLI runs in a mode that calls a local LLM instead of the backend.
+ """
+
+ import os
+ import re
+ import subprocess
+ from typing import Any, Dict, Optional
+
+ from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse
+
+
+ class GeminiCLIAdapter(BaseLLMAdapter):
+     """Adapter that shells out to gemini CLI (e.g. `gemini` or GEMINI_CLI_CMD)."""
+
+     def __init__(self, cmd: Optional[str] = None):
+         self.cmd = (cmd or os.environ.get("GEMINI_CLI_CMD", "gemini")).strip()
+
+     @property
+     def name(self) -> str:
+         return "gemini_cli"
+
+     def generate(
+         self,
+         user_message: str,
+         system_prompt: Optional[str] = None,
+         context: Optional[Dict[str, Any]] = None,
+     ) -> LLMResponse:
+         prompt = user_message
+         if system_prompt:
+             prompt = f"{system_prompt}\n\n---\n\n{prompt}"
+         if context:
+             prompt = f"Context: {context}\n\n{prompt}"
+         raw = self._call_cli(prompt)
+         intent = self._parse_intent(raw)
+         return LLMResponse(raw=raw, intent=intent, explanation=raw)
+
+     def _call_cli(self, prompt: str) -> str:
+         try:
+             proc = subprocess.run(
+                 [self.cmd, "run", "-"],
+                 input=prompt,
+                 capture_output=True,
+                 text=True,
+                 timeout=120,
+             )
+             if proc.returncode != 0:
+                 return f"[gemini CLI error: {proc.stderr or proc.stdout or 'unknown'}]"
+             return (proc.stdout or "").strip()
+         except FileNotFoundError:
+             return f"[gemini CLI not found: {self.cmd}]"
+         except subprocess.TimeoutExpired:
+             return "[gemini CLI timeout]"
+
+     def _parse_intent(self, raw: str) -> Optional[Dict[str, Any]]:
+         """Extract INTENT_JSON from response if present."""
+         match = re.search(r"INTENT_JSON:\s*(\{.*?\})\s*$", raw, re.DOTALL)
+         if not match:
+             return None
+         try:
+             import json
+             return json.loads(match.group(1).strip())
+         except Exception:
+             return None
@@ -0,0 +1,65 @@
+ """
+ Local LLM adapter — ollama / llama.cpp style.
+ Used when drift CLI is configured to use a local model instead of the backend.
+ """
+
+ import os
+ import re
+ from typing import Any, Dict, Optional
+
+ from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse
+
+
+ class LocalLLMAdapter(BaseLLMAdapter):
+     """
+     Adapter for local LLM (e.g. Ollama, llama.cpp server).
+     Expects OLLAMA_BASE_URL or LOCAL_LLM_URL; model from OLLAMA_MODEL or LOCAL_LLM_MODEL.
+     """
+
+     def __init__(self, base_url: Optional[str] = None, model: Optional[str] = None):
+         self.base_url = (base_url or os.environ.get("OLLAMA_BASE_URL") or os.environ.get("LOCAL_LLM_URL", "http://localhost:11434")).rstrip("/")
+         self.model = model or os.environ.get("OLLAMA_MODEL") or os.environ.get("LOCAL_LLM_MODEL", "llama2")
+
+     @property
+     def name(self) -> str:
+         return "local_llm"
+
+     def generate(
+         self,
+         user_message: str,
+         system_prompt: Optional[str] = None,
+         context: Optional[Dict[str, Any]] = None,
+     ) -> LLMResponse:
+         prompt = user_message
+         if system_prompt:
+             prompt = f"{system_prompt}\n\n---\n\n{prompt}"
+         if context:
+             prompt = f"Context: {context}\n\n{prompt}"
+         raw = self._call_api(prompt)
+         intent = self._parse_intent(raw)
+         return LLMResponse(raw=raw, intent=intent, explanation=raw)
+
+     def _call_api(self, prompt: str) -> str:
+         try:
+             import requests
+             r = requests.post(
+                 f"{self.base_url}/api/generate",
+                 json={"model": self.model, "prompt": prompt, "stream": False},
+                 timeout=120,
+             )
+             if r.status_code >= 400:
+                 return f"[Local LLM error: {r.status_code} {r.text[:200]}]"
+             data = r.json()
+             return (data.get("response") or "").strip()
+         except Exception as e:
+             return f"[Local LLM error: {e}]"
+
+     def _parse_intent(self, raw: str) -> Optional[Dict[str, Any]]:
+         match = re.search(r"INTENT_JSON:\s*(\{.*?\})\s*$", raw, re.DOTALL)
+         if not match:
+             return None
+         try:
+             import json
+             return json.loads(match.group(1).strip())
+         except Exception:
+             return None
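Both adapters look for an optional `INTENT_JSON:` trailer in the raw model output. A sketch of that convention (the intent keys follow the comment on `LLMResponse.intent`; the values are illustrative):

```python
import json
import re

raw = (
    "I'll train a model to predict price.\n"
    'INTENT_JSON: {"target": "price", "start_training": true}'
)

# Same pattern as GeminiCLIAdapter._parse_intent / LocalLLMAdapter._parse_intent.
match = re.search(r"INTENT_JSON:\s*(\{.*?\})\s*$", raw, re.DOTALL)
intent = json.loads(match.group(1)) if match else None
print(intent)  # {'target': 'price', 'start_training': True}
```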
@@ -0,0 +1,6 @@
+ Metadata-Version: 2.4
+ Name: drift-ml
+ Version: 0.1.1
+ Summary: Terminal-first AutoML CLI - chat-based ML engineer
+ Requires-Python: >=3.10
+ Requires-Dist: requests>=2.28.0
@@ -0,0 +1,18 @@
+ README.md
+ pyproject.toml
+ drift/__init__.py
+ drift/__main__.py
+ drift/cli/__init__.py
+ drift/cli/client.py
+ drift/cli/repl.py
+ drift/cli/session.py
+ drift/llm_adapters/__init__.py
+ drift/llm_adapters/base.py
+ drift/llm_adapters/gemini_cli.py
+ drift/llm_adapters/local_llm.py
+ drift_ml.egg-info/PKG-INFO
+ drift_ml.egg-info/SOURCES.txt
+ drift_ml.egg-info/dependency_links.txt
+ drift_ml.egg-info/entry_points.txt
+ drift_ml.egg-info/requires.txt
+ drift_ml.egg-info/top_level.txt
@@ -0,0 +1,2 @@
+ [console_scripts]
+ drift = drift.__main__:main
@@ -0,0 +1 @@
+ requests>=2.28.0
@@ -0,0 +1,2 @@
+ drift
+ drift-npm
@@ -0,0 +1,19 @@
+ [build-system]
+ requires = ["setuptools>=61"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "drift-ml"
+ version = "0.1.1"
+ description = "Terminal-first AutoML CLI - chat-based ML engineer"
+ requires-python = ">=3.10"
+ dependencies = [
+     "requests>=2.28.0",
+ ]
+
+ [project.scripts]
+ drift = "drift.__main__:main"
+
+ [tool.setuptools.packages.find]
+ where = ["."]
+ include = ["drift*"]
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+