drift-ml 0.1.1 (drift_ml-0.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- drift/__init__.py +5 -0
- drift/__main__.py +14 -0
- drift/cli/__init__.py +6 -0
- drift/cli/client.py +125 -0
- drift/cli/repl.py +220 -0
- drift/cli/session.py +70 -0
- drift/llm_adapters/__init__.py +9 -0
- drift/llm_adapters/base.py +41 -0
- drift/llm_adapters/gemini_cli.py +65 -0
- drift/llm_adapters/local_llm.py +65 -0
- drift_ml-0.1.1.dist-info/METADATA +6 -0
- drift_ml-0.1.1.dist-info/RECORD +15 -0
- drift_ml-0.1.1.dist-info/WHEEL +5 -0
- drift_ml-0.1.1.dist-info/entry_points.txt +2 -0
- drift_ml-0.1.1.dist-info/top_level.txt +1 -0
drift/__init__.py
ADDED
drift/__main__.py
ADDED
drift/cli/__init__.py
ADDED
drift/cli/client.py
ADDED
@@ -0,0 +1,125 @@
"""
HTTP client for the existing backend API.
Uses the same endpoints as the web app: upload, chat, session, runs, train.
"""

import os
from pathlib import Path
from typing import Any, Dict, Optional

try:
    import requests
except ImportError:
    requests = None  # type: ignore

DEFAULT_BASE_URL = os.environ.get("DRIFT_BACKEND_URL", "http://localhost:8000")


class BackendError(Exception):
    """Raised when the backend returns an error or is unreachable."""

    def __init__(self, message: str, status_code: Optional[int] = None, body: Any = None):
        self.message = message
        self.status_code = status_code
        self.body = body
        super().__init__(message)


class BackendClient:
    """Client for backend upload, chat, session, runs, and train endpoints."""

    def __init__(self, base_url: str = DEFAULT_BASE_URL, timeout: int = 300):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        if requests is None:
            raise BackendError(
                "The 'requests' library is required for the drift CLI. Install it with: pip install requests"
            )

    def _url(self, path: str) -> str:
        return f"{self.base_url}{path}"

    def health(self) -> Dict[str, Any]:
        """GET /health — check backend is up and LLM status."""
        r = requests.get(self._url("/health"), timeout=10)
        r.raise_for_status()
        return r.json()

    def upload_csv(self, file_path: str) -> Dict[str, Any]:
        """
        POST /upload — upload a local CSV and create a session.
        Returns session_id, dataset_id, profile, initial_message.
        """
        path = Path(file_path).resolve()
        if not path.exists():
            raise BackendError(f"File not found: {path}")
        with open(path, "rb") as f:
            name = path.name
            files = {"file": (name, f, "text/csv")}
            r = requests.post(self._url("/upload"), files=files, timeout=60)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Upload failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def chat(self, session_id: str, message: str) -> Dict[str, Any]:
        """
        POST /session/{session_id}/chat — send a user message, get agent reply.
        Returns chat_history, trigger_training.
        """
        r = requests.post(
            self._url(f"/session/{session_id}/chat"),
            json={"message": message},
            timeout=120,
        )
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Chat failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def get_session(self, session_id: str) -> Dict[str, Any]:
        """GET /session/{session_id} — full session state (chat, model_state, current_run_id, etc.)."""
        r = requests.get(self._url(f"/session/{session_id}"), timeout=30)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Get session failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def get_run_state(self, run_id: str) -> Dict[str, Any]:
        """GET /runs/{run_id} — run state: status, current_step, progress, events[]."""
        r = requests.get(self._url(f"/runs/{run_id}"), timeout=30)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Get run failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()

    def train(self, session_id: str) -> Dict[str, Any]:
        """
        POST /session/{session_id}/train — run one training attempt (blocking).
        Returns run_id, session_id, metrics, refused, refusal_reason, agent_message.
        """
        r = requests.post(self._url(f"/session/{session_id}/train"), timeout=self.timeout)
        if r.status_code >= 400:
            try:
                body = r.json()
                detail = body.get("detail", str(r.text))
            except Exception:
                detail = r.text or str(r.status_code)
            raise BackendError(f"Train failed: {detail}", status_code=r.status_code, body=r.text)
        return r.json()
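
Taken together, the client gives a small blocking workflow: upload a CSV, chat until the agent decides to train, then train. A minimal driver sketch, assuming a backend is running locally; the data.csv path and the message text are illustrative, not part of the package:

# Hypothetical driver script; assumes a drift backend at localhost:8000
# and a data.csv file next to the script.
from drift.cli.client import BackendClient, BackendError

client = BackendClient()  # honours DRIFT_BACKEND_URL when set
try:
    out = client.upload_csv("data.csv")        # POST /upload -> session_id, dataset_id
    sid = out["session_id"]
    reply = client.chat(sid, "predict price")  # POST /session/{id}/chat
    if reply.get("trigger_training"):
        result = client.train(sid)             # POST /session/{id}/train (blocking)
        print(result.get("metrics"))
except BackendError as e:
    print(f"backend error ({e.status_code}): {e.message}")
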
drift/cli/repl.py
ADDED
@@ -0,0 +1,220 @@
"""
Chat-based CLI loop for drift.
Natural language input; maintains session state; reuses backend planner + executor.
"""

import re
import sys
import threading
import time
from typing import Any, Dict, Optional

from drift.cli.client import BackendClient, BackendError, DEFAULT_BASE_URL
from drift.cli.session import SessionState


PROMPT = "drift › "
LOAD_PATTERN = re.compile(r"^\s*load\s+(.+)$", re.IGNORECASE)


def run_repl(base_url: Optional[str] = None) -> None:
    """Run the chat-based REPL. Uses backend at base_url (default from DRIFT_BACKEND_URL or localhost:8000)."""
    try:
        client = BackendClient(base_url=base_url or DEFAULT_BASE_URL)
    except BackendError as e:
        print(f"drift: {e.message}", file=sys.stderr)
        sys.exit(1)
    session = SessionState()
    _print_banner(client)
    while True:
        try:
            line = input(PROMPT).strip()
        except EOFError:
            print()
            break
        if not line:
            continue
        if line.lower() in ("quit", "exit", "q"):
            break
        try:
            _handle_input(line, client, session)
        except BackendError as e:
            print(f"Error: {e.message}", file=sys.stderr)
        except KeyboardInterrupt:
            print("\nInterrupted. Type 'quit' to exit.")
            continue


def _print_banner(client: BackendClient) -> None:
    _print_welcome()
    engine_ok = False
    llm_name = ""
    try:
        health = client.health()
        engine_ok = (health.get("status") or "").lower() == "healthy"
        llm_name = health.get("llm_provider") or health.get("current_model") or ""
    except Exception:
        pass
    _print_status(engine_ok=engine_ok, llm_name=llm_name)
    _print_examples()
    print()


def _print_welcome() -> None:
    """drift by Lakshit Sachdeva. Local-first ML engineer."""
    print()
    print(" ----------------------------------------")
    print(" drift by Lakshit Sachdeva")
    print(" Local-first ML engineer.")
    print(" Same engine as the web app.")
    print(" No commands to memorize.")
    print(" ----------------------------------------")
    print()


def _print_status(engine_ok: bool, llm_name: str) -> None:
    """Engine running, LLM detected, Ready."""
    if engine_ok:
        print(" \u2713 Engine running")
        if llm_name:
            print(f" \u2713 LLM detected ({llm_name})")
        else:
            print(" \u2713 LLM detected (Gemini CLI / Ollama / local)")
        print(" \u2713 Ready")
    else:
        print(" \u2717 Engine not running")
        print(" Start the engine or set DRIFT_BACKEND_URL to a running engine URL.")
    print()


def _print_examples() -> None:
    """Examples: load, predict, try something stronger, quit."""
    print(" Examples:")
    print("   load data.csv")
    print("   predict price")
    print("   try something stronger")
    print("   why is accuracy capped")
    print("   quit")
    print()


def _handle_input(line: str, client: BackendClient, session: SessionState) -> None:
    load_match = LOAD_PATTERN.match(line)
    if load_match:
        path = load_match.group(1).strip().strip('"\'')
        _do_load(path, client, session)
        return
    if not session.has_session():
        print("Load a dataset first: load path/to/file.csv")
        return
    _do_chat(line, client, session)


def _do_load(path: str, client: BackendClient, session: SessionState) -> None:
    out = client.upload_csv(path)
    session.dataset_path = path
    session.update_from_upload(
        session_id=out["session_id"],
        dataset_id=out["dataset_id"],
        initial_message=out.get("initial_message"),
    )
    print("Dataset loaded.")
    if out.get("initial_message"):
        content = out["initial_message"].get("content") or ""
        if content:
            _print_message(content)
    print()


def _do_chat(message: str, client: BackendClient, session: SessionState) -> None:
    sid = session.session_id
    out = client.chat(sid, message)
    chat_history = out.get("chat_history") or []
    session.update_from_chat(chat_history)
    # Show latest agent reply
    for m in reversed(chat_history):
        if m.get("role") == "agent" and m.get("content"):
            _print_message(m["content"])
            break
    trigger = out.get("trigger_training") is True
    if trigger:
        _run_training_and_show(client, session)
    print()


def _run_training_and_show(client: BackendClient, session: SessionState) -> None:
    sid = session.session_id
    if not sid:
        return
    train_result: Optional[Dict[str, Any]] = None
    train_error: Optional[Exception] = None

    def do_train() -> None:
        nonlocal train_result, train_error
        try:
            train_result = client.train(sid)
        except Exception as e:
            train_error = e

    thread = threading.Thread(target=do_train, daemon=True)
    thread.start()
    print("Training started…")
    run_id: Optional[str] = None
    last_event_count = 0
    poll_interval = 0.8
    timeout_sec = 300
    start = time.time()
    while thread.is_alive() and (time.time() - start) < timeout_sec:
        time.sleep(poll_interval)
        try:
            sess = client.get_session(sid)
            run_id = sess.get("current_run_id")
            if run_id:
                break
        except Exception:
            pass
    if run_id:
        while thread.is_alive() and (time.time() - start) < timeout_sec:
            time.sleep(poll_interval)
            try:
                run_state = client.get_run_state(run_id)
                events = run_state.get("events") or []
                for ev in events[last_event_count:]:
                    msg = ev.get("message") or ev.get("step_name") or ""
                    if msg:
                        print(f"  {msg}")
                last_event_count = len(events)
                status = run_state.get("status")
                if status in ("success", "failed", "refused"):
                    break
            except Exception:
                pass
    thread.join(timeout=1.0)
    if train_error:
        print(f"Training error: {train_error}", file=sys.stderr)
        return
    if train_result:
        session.update_after_train(
            run_id=train_result.get("run_id", ""),
            metrics=train_result.get("metrics") or {},
            agent_message=train_result.get("agent_message"),
        )
        agent_message = train_result.get("agent_message")
        if agent_message:
            _print_message(agent_message)
        metrics = train_result.get("metrics") or {}
        if metrics:
            primary = metrics.get("primary_metric_value") or metrics.get("accuracy") or metrics.get("r2")
            if primary is not None:
                print(f"  Primary metric: {primary}")


def _print_message(content: str) -> None:
    """Print agent message; strip markdown bold for terminal if desired, keep newlines."""
    text = content.strip()
    if not text:
        return
    # Optional: replace **x** with x for readability in terminal
    text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
    for line in text.split("\n"):
        print(line)
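
The wheel also ships drift/__main__.py and an entry_points.txt (contents not shown in this diff), so the REPL is presumably reachable as a console script or via python -m drift. A direct programmatic launch, assuming the backend engine is already running:

# Launching the REPL directly; the remote URL below is illustrative.
from drift.cli.repl import run_repl

run_repl()  # falls back to DRIFT_BACKEND_URL, then http://localhost:8000
# run_repl(base_url="http://192.168.1.10:8000")  # target a remote engine instead
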
drift/cli/session.py
ADDED
@@ -0,0 +1,70 @@
"""
Session state for the drift CLI.
Holds dataset ref, plan summary, last metrics, run_id, and chat history
so the REPL can maintain context and the backend can be queried correctly.
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class SessionState:
    """In-memory CLI session: dataset, plan, run, metrics, chat."""

    dataset_path: Optional[str] = None
    dataset_id: Optional[str] = None
    session_id: Optional[str] = None
    run_id: Optional[str] = None
    plan_summary: Optional[str] = None
    last_metrics: Dict[str, Any] = field(default_factory=dict)
    chat_history: List[Dict[str, Any]] = field(default_factory=list)

    def has_session(self) -> bool:
        return bool(self.session_id)

    def has_dataset(self) -> bool:
        return bool(self.dataset_id and self.session_id)

    def clear_run(self) -> None:
        self.run_id = None
        self.last_metrics = {}

    def update_from_upload(self, session_id: str, dataset_id: str, initial_message: Optional[Dict] = None) -> None:
        self.session_id = session_id
        self.dataset_id = dataset_id
        self.run_id = None
        self.last_metrics = {}
        if initial_message:
            self.chat_history = [initial_message]

    def update_from_chat(self, chat_history: List[Dict[str, Any]], run_id: Optional[str] = None) -> None:
        self.chat_history = chat_history or self.chat_history
        if run_id is not None:
            self.run_id = run_id

    def update_from_session(self, session: Dict[str, Any]) -> None:
        self.session_id = session.get("session_id") or self.session_id
        self.dataset_id = session.get("dataset_id") or self.dataset_id
        self.run_id = session.get("current_run_id") or self.run_id
        self.chat_history = session.get("chat_history") or self.chat_history
        ms = session.get("model_state") or {}
        self.last_metrics = ms.get("metrics") or self.last_metrics
        self.plan_summary = None
        if session.get("structural_plan"):
            self.plan_summary = _summarize_plan(session["structural_plan"])

    def update_after_train(self, run_id: str, metrics: Dict[str, Any], agent_message: Optional[str] = None) -> None:
        self.run_id = run_id
        self.last_metrics = metrics or self.last_metrics
        if agent_message and self.chat_history:
            self.chat_history[-1]["content"] = agent_message


def _summarize_plan(structural_plan: Dict[str, Any]) -> str:
    """One-line summary of structural plan for display (StructuralPlan has inferred_target, task_type; no models)."""
    if not structural_plan:
        return ""
    target = structural_plan.get("inferred_target") or structural_plan.get("target") or "?"
    task = structural_plan.get("task_type") or "?"
    return f"Target: {target}, task: {task}"
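
The dataclass is plain in-memory state, so its lifecycle is easy to see in isolation. A short sketch of the transitions the REPL drives (all IDs and values are illustrative):

# Illustrative walk through the SessionState lifecycle.
from drift.cli.session import SessionState

state = SessionState()
assert not state.has_session()

state.update_from_upload(
    session_id="s1",
    dataset_id="d1",
    initial_message={"role": "agent", "content": "Loaded 500 rows."},
)
assert state.has_dataset()

state.update_after_train(run_id="r1", metrics={"accuracy": 0.91})
print(state.last_metrics)  # {'accuracy': 0.91}

state.clear_run()  # drops run_id and metrics but keeps the session itself
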
drift/llm_adapters/__init__.py
ADDED
@@ -0,0 +1,9 @@
"""
Pluggable LLM adapters for drift CLI.
Primary planning/execution uses the backend; these adapters support
optional local processing (e.g. help, summaries) or future offline use.
"""

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse

__all__ = ["BaseLLMAdapter", "LLMResponse"]
drift/llm_adapters/base.py
ADDED
@@ -0,0 +1,41 @@
"""
Base contract for drift LLM adapters.
The CLI uses the backend as the planner/executor; adapters return
structured ML intent, plan updates, and explanations when used locally.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class LLMResponse:
    """Structured response from an LLM adapter: intent, plan updates, explanation."""

    raw: str
    intent: Optional[Dict[str, Any]] = None  # target, drop_columns, performance_mode, start_training
    plan_summary: Optional[str] = None
    explanation: Optional[str] = None


class BaseLLMAdapter(ABC):
    """
    LLM adapter contract.
    The LLM is the PLANNER, not the executor; it updates plans and explains outcomes.
    Execution is assumed to happen automatically (backend).
    """

    @abstractmethod
    def generate(self, user_message: str, system_prompt: Optional[str] = None, context: Optional[Dict[str, Any]] = None) -> LLMResponse:
        """
        Turn user message + optional context into a structured response.
        Returns intent, plan summary, and explanation — never claims to edit files or execute.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Adapter identifier (e.g. gemini_cli, local_llm)."""
        pass
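
The contract is small: one generate method plus a name property. A hypothetical minimal subclass (not shipped in the wheel) that satisfies it without calling any model:

# Hypothetical adapter used only to illustrate the BaseLLMAdapter contract;
# it echoes the prompt instead of invoking an LLM.
from typing import Any, Dict, Optional

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse


class EchoAdapter(BaseLLMAdapter):
    @property
    def name(self) -> str:
        return "echo"

    def generate(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> LLMResponse:
        # A real adapter would call a model here and parse its reply.
        return LLMResponse(raw=user_message, explanation=user_message)
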
drift/llm_adapters/gemini_cli.py
ADDED
@@ -0,0 +1,65 @@
"""
Gemini CLI adapter — uses gemini CLI (default if available) for local LLM calls.
Used when drift CLI runs in a mode that calls a local LLM instead of the backend.
"""

import os
import re
import subprocess
from typing import Any, Dict, Optional

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse


class GeminiCLIAdapter(BaseLLMAdapter):
    """Adapter that shells out to gemini CLI (e.g. `gemini` or GEMINI_CLI_CMD)."""

    def __init__(self, cmd: Optional[str] = None):
        self.cmd = (cmd or os.environ.get("GEMINI_CLI_CMD", "gemini")).strip()

    @property
    def name(self) -> str:
        return "gemini_cli"

    def generate(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> LLMResponse:
        prompt = user_message
        if system_prompt:
            prompt = f"{system_prompt}\n\n---\n\n{prompt}"
        if context:
            prompt = f"Context: {context}\n\n{prompt}"
        raw = self._call_cli(prompt)
        intent = self._parse_intent(raw)
        return LLMResponse(raw=raw, intent=intent, explanation=raw)

    def _call_cli(self, prompt: str) -> str:
        try:
            proc = subprocess.run(
                [self.cmd, "run", "-"],
                input=prompt,
                capture_output=True,
                text=True,
                timeout=120,
            )
            if proc.returncode != 0:
                return f"[gemini CLI error: {proc.stderr or proc.stdout or 'unknown'}]"
            return (proc.stdout or "").strip()
        except FileNotFoundError:
            return f"[gemini CLI not found: {self.cmd}]"
        except subprocess.TimeoutExpired:
            return "[gemini CLI timeout]"

    def _parse_intent(self, raw: str) -> Optional[Dict[str, Any]]:
        """Extract INTENT_JSON from response if present."""
        match = re.search(r"INTENT_JSON:\s*(\{.*?\})\s*$", raw, re.DOTALL)
        if not match:
            return None
        try:
            import json
            return json.loads(match.group(1).strip())
        except Exception:
            return None
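
_parse_intent expects the model reply to end with an INTENT_JSON: line. The sketch below fabricates such a reply to show the round trip; the reply text and intent keys are illustrative only:

# Illustrative reply in the shape _parse_intent expects: free-form text,
# then a trailing INTENT_JSON line (the intent payload here is made up).
from drift.llm_adapters.gemini_cli import GeminiCLIAdapter

reply = (
    "Setting the target to price and starting training.\n"
    'INTENT_JSON: {"target": "price", "start_training": true}'
)
adapter = GeminiCLIAdapter()  # shells out to `gemini`, or $GEMINI_CLI_CMD
print(adapter._parse_intent(reply))  # {'target': 'price', 'start_training': True}
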
drift/llm_adapters/local_llm.py
ADDED
@@ -0,0 +1,65 @@
"""
Local LLM adapter — ollama / llama.cpp style.
Used when drift CLI is configured to use a local model instead of the backend.
"""

import os
import re
from typing import Any, Dict, Optional

from drift.llm_adapters.base import BaseLLMAdapter, LLMResponse


class LocalLLMAdapter(BaseLLMAdapter):
    """
    Adapter for local LLM (e.g. Ollama, llama.cpp server).
    Expects OLLAMA_BASE_URL or LOCAL_LLM_URL; model from OLLAMA_MODEL or LOCAL_LLM_MODEL.
    """

    def __init__(self, base_url: Optional[str] = None, model: Optional[str] = None):
        self.base_url = (base_url or os.environ.get("OLLAMA_BASE_URL") or os.environ.get("LOCAL_LLM_URL", "http://localhost:11434")).rstrip("/")
        self.model = model or os.environ.get("OLLAMA_MODEL") or os.environ.get("LOCAL_LLM_MODEL", "llama2")

    @property
    def name(self) -> str:
        return "local_llm"

    def generate(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> LLMResponse:
        prompt = user_message
        if system_prompt:
            prompt = f"{system_prompt}\n\n---\n\n{prompt}"
        if context:
            prompt = f"Context: {context}\n\n{prompt}"
        raw = self._call_api(prompt)
        intent = self._parse_intent(raw)
        return LLMResponse(raw=raw, intent=intent, explanation=raw)

    def _call_api(self, prompt: str) -> str:
        try:
            import requests
            r = requests.post(
                f"{self.base_url}/api/generate",
                json={"model": self.model, "prompt": prompt, "stream": False},
                timeout=120,
            )
            if r.status_code >= 400:
                return f"[Local LLM error: {r.status_code} {r.text[:200]}]"
            data = r.json()
            return (data.get("response") or "").strip()
        except Exception as e:
            return f"[Local LLM error: {e}]"

    def _parse_intent(self, raw: str) -> Optional[Dict[str, Any]]:
        match = re.search(r"INTENT_JSON:\s*(\{.*?\})\s*$", raw, re.DOTALL)
        if not match:
            return None
        try:
            import json
            return json.loads(match.group(1).strip())
        except Exception:
            return None
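
The /api/generate call with "stream": False matches Ollama's non-streaming generate endpoint. A sketch of constructing the adapter explicitly against a local Ollama instance (the model name is just the adapter's default):

# Sketch: explicit construction against a local Ollama server.
from drift.llm_adapters.local_llm import LocalLLMAdapter

adapter = LocalLLMAdapter(base_url="http://localhost:11434", model="llama2")
resp = adapter.generate("Summarize: accuracy plateaued at 0.82.")
print(resp.raw)  # model text, or a bracketed "[Local LLM error: ...]" string on failure
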
drift_ml-0.1.1.dist-info/RECORD
ADDED
@@ -0,0 +1,15 @@
drift/__init__.py,sha256=X0NUP5ZAZSz-rBFfjvmmS4IYWP2_CZtu187mqwpWwqk,127
drift/__main__.py,sha256=MMUjNUbctbLHVe37ZCOW-66h8OZ8WIYiISxHJHRNxes,202
drift/cli/__init__.py,sha256=OQt7M06e98e_L_60Qz-HsmSqGXKWdH53SvgETJ_BZZ0,172
drift/cli/client.py,sha256=AiV5hNCWgQbH9biznPUD7jctsSpDaXdy1WcgAIoBqtM,4880
drift/cli/repl.py,sha256=ifZH2xoWFwCfdsGSh7HEkdO378t1c4_h07iNZGiUrMs,7220
drift/cli/session.py,sha256=fMBI_pgiOmm-hcf7GDcjDeMIKESy6mGeEWhWyzdBirM,2860
drift/llm_adapters/__init__.py,sha256=y1UhZWlC8Ik_OKfLcOp0JZP-FKR3MBBCemWwsL6TnkY,296
drift/llm_adapters/base.py,sha256=KlZUPYpvCI8pafklBWel0GHHLNtVKVwSHA92MZt3VsI,1331
drift/llm_adapters/gemini_cli.py,sha256=Z61wY3yFiZqPrQrpJAQrBtMbBkr2qLWBkni-A4p9lZo,2163
drift/llm_adapters/local_llm.py,sha256=Z6j6z1CXk2LMeQ5ZnY4o38PiYkHYmcIkgJGkHdU50M8,2279
drift_ml-0.1.1.dist-info/METADATA,sha256=EGmvqr1xDEnVM42RZFtYS4htDu7KfDimXC2aD9qPu6o,168
drift_ml-0.1.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
drift_ml-0.1.1.dist-info/entry_points.txt,sha256=aCY7U9M8nhYj_tIfTXJmYkVmXY3ZoxF0tebDZzYswv8,46
drift_ml-0.1.1.dist-info/top_level.txt,sha256=3u2KGqsciGZQ2uCoBivm55t3e8er8S4xnqkgdQ_8oeM,6
drift_ml-0.1.1.dist-info/RECORD,,
drift_ml-0.1.1.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
drift