drift-ml 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- drift/cli/client.py +14 -0
- drift/cli/repl.py +35 -1
- {drift_ml-0.1.1.dist-info → drift_ml-0.1.3.dist-info}/METADATA +1 -1
- {drift_ml-0.1.1.dist-info → drift_ml-0.1.3.dist-info}/RECORD +7 -7
- {drift_ml-0.1.1.dist-info → drift_ml-0.1.3.dist-info}/WHEEL +0 -0
- {drift_ml-0.1.1.dist-info → drift_ml-0.1.3.dist-info}/entry_points.txt +0 -0
- {drift_ml-0.1.1.dist-info → drift_ml-0.1.3.dist-info}/top_level.txt +0 -0
drift/cli/client.py
CHANGED
|
@@ -123,3 +123,17 @@ class BackendClient:
|
|
|
123
123
|
detail = r.text or str(r.status_code)
|
|
124
124
|
raise BackendError(f"Train failed: {detail}", status_code=r.status_code, body=r.text)
|
|
125
125
|
return r.json()
|
|
126
|
+
|
|
127
|
+
def download_notebook(self, run_id: str) -> Optional[bytes]:
|
|
128
|
+
"""GET /download/{run_id}/notebook — download notebook bytes, or None if failed."""
|
|
129
|
+
r = requests.get(self._url(f"/download/{run_id}/notebook"), timeout=60)
|
|
130
|
+
if r.status_code != 200:
|
|
131
|
+
return None
|
|
132
|
+
return r.content
|
|
133
|
+
|
|
134
|
+
def download_model(self, run_id: str) -> Optional[bytes]:
|
|
135
|
+
"""GET /download/{run_id}/model — download model pickle bytes, or None if refused/failed."""
|
|
136
|
+
r = requests.get(self._url(f"/download/{run_id}/model"), timeout=60)
|
|
137
|
+
if r.status_code != 200:
|
|
138
|
+
return None
|
|
139
|
+
return r.content
|
drift/cli/repl.py
CHANGED
|
@@ -3,10 +3,12 @@ Chat-based CLI loop for drift.
|
|
|
3
3
|
Natural language input; maintains session state; reuses backend planner + executor.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import os
|
|
6
7
|
import re
|
|
7
8
|
import sys
|
|
8
9
|
import threading
|
|
9
10
|
import time
|
|
11
|
+
from pathlib import Path
|
|
10
12
|
from typing import Any, Dict, Optional
|
|
11
13
|
|
|
12
14
|
from drift.cli.client import BackendClient, BackendError, DEFAULT_BASE_URL
|
|
@@ -194,8 +196,9 @@ def _run_training_and_show(client: BackendClient, session: SessionState) -> None
|
|
|
194
196
|
print(f"Training error: {train_error}", file=sys.stderr)
|
|
195
197
|
return
|
|
196
198
|
if train_result:
|
|
199
|
+
run_id = train_result.get("run_id", "")
|
|
197
200
|
session.update_after_train(
|
|
198
|
-
run_id=
|
|
201
|
+
run_id=run_id,
|
|
199
202
|
metrics=train_result.get("metrics") or {},
|
|
200
203
|
agent_message=train_result.get("agent_message"),
|
|
201
204
|
)
|
|
@@ -207,6 +210,37 @@ def _run_training_and_show(client: BackendClient, session: SessionState) -> None
|
|
|
207
210
|
primary = metrics.get("primary_metric_value") or metrics.get("accuracy") or metrics.get("r2")
|
|
208
211
|
if primary is not None:
|
|
209
212
|
print(f" Primary metric: {primary}")
|
|
213
|
+
# Auto-save notebook and model to current directory for terminal users
|
|
214
|
+
if run_id and not train_result.get("refused"):
|
|
215
|
+
_save_artifacts(client, run_id)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _save_artifacts(client: BackendClient, run_id: str) -> None:
|
|
219
|
+
"""Download notebook and model to current directory. Print where saved."""
|
|
220
|
+
cwd = Path(os.getcwd())
|
|
221
|
+
prefix = f"model_{run_id[:8]}"
|
|
222
|
+
saved = []
|
|
223
|
+
try:
|
|
224
|
+
nb = client.download_notebook(run_id)
|
|
225
|
+
if nb:
|
|
226
|
+
path = cwd / f"{prefix}_training_notebook.ipynb"
|
|
227
|
+
path.write_bytes(nb)
|
|
228
|
+
saved.append(str(path))
|
|
229
|
+
except Exception:
|
|
230
|
+
pass
|
|
231
|
+
try:
|
|
232
|
+
mdl = client.download_model(run_id)
|
|
233
|
+
if mdl:
|
|
234
|
+
path = cwd / f"{prefix}.pkl"
|
|
235
|
+
path.write_bytes(mdl)
|
|
236
|
+
saved.append(str(path))
|
|
237
|
+
except Exception:
|
|
238
|
+
pass
|
|
239
|
+
if saved:
|
|
240
|
+
print()
|
|
241
|
+
print(" Saved to current directory:")
|
|
242
|
+
for p in saved:
|
|
243
|
+
print(f" {p}")
|
|
210
244
|
|
|
211
245
|
|
|
212
246
|
def _print_message(content: str) -> None:
|
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
drift/__init__.py,sha256=X0NUP5ZAZSz-rBFfjvmmS4IYWP2_CZtu187mqwpWwqk,127
|
|
2
2
|
drift/__main__.py,sha256=MMUjNUbctbLHVe37ZCOW-66h8OZ8WIYiISxHJHRNxes,202
|
|
3
3
|
drift/cli/__init__.py,sha256=OQt7M06e98e_L_60Qz-HsmSqGXKWdH53SvgETJ_BZZ0,172
|
|
4
|
-
drift/cli/client.py,sha256=
|
|
5
|
-
drift/cli/repl.py,sha256=
|
|
4
|
+
drift/cli/client.py,sha256=hGxz-J8fC_gGV9BZnpCBwkQZg4Q0_s9-Agk2PuXQ5MQ,5527
|
|
5
|
+
drift/cli/repl.py,sha256=hRnx4nfP_StnXDbdRZXnnjA8aNvrdQSV-uTNsuaaO60,8262
|
|
6
6
|
drift/cli/session.py,sha256=fMBI_pgiOmm-hcf7GDcjDeMIKESy6mGeEWhWyzdBirM,2860
|
|
7
7
|
drift/llm_adapters/__init__.py,sha256=y1UhZWlC8Ik_OKfLcOp0JZP-FKR3MBBCemWwsL6TnkY,296
|
|
8
8
|
drift/llm_adapters/base.py,sha256=KlZUPYpvCI8pafklBWel0GHHLNtVKVwSHA92MZt3VsI,1331
|
|
9
9
|
drift/llm_adapters/gemini_cli.py,sha256=Z61wY3yFiZqPrQrpJAQrBtMbBkr2qLWBkni-A4p9lZo,2163
|
|
10
10
|
drift/llm_adapters/local_llm.py,sha256=Z6j6z1CXk2LMeQ5ZnY4o38PiYkHYmcIkgJGkHdU50M8,2279
|
|
11
|
-
drift_ml-0.1.
|
|
12
|
-
drift_ml-0.1.
|
|
13
|
-
drift_ml-0.1.
|
|
14
|
-
drift_ml-0.1.
|
|
15
|
-
drift_ml-0.1.
|
|
11
|
+
drift_ml-0.1.3.dist-info/METADATA,sha256=JYHG8KKtZxOeOrkaCUqZ2WnRDtxYnRqoGtDE9REyKmI,168
|
|
12
|
+
drift_ml-0.1.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
13
|
+
drift_ml-0.1.3.dist-info/entry_points.txt,sha256=aCY7U9M8nhYj_tIfTXJmYkVmXY3ZoxF0tebDZzYswv8,46
|
|
14
|
+
drift_ml-0.1.3.dist-info/top_level.txt,sha256=3u2KGqsciGZQ2uCoBivm55t3e8er8S4xnqkgdQ_8oeM,6
|
|
15
|
+
drift_ml-0.1.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|