schift-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- schift_cli/__init__.py +1 -0
- schift_cli/client.py +119 -0
- schift_cli/commands/__init__.py +0 -0
- schift_cli/commands/auth.py +68 -0
- schift_cli/commands/bench.py +65 -0
- schift_cli/commands/catalog.py +74 -0
- schift_cli/commands/db.py +96 -0
- schift_cli/commands/embed.py +104 -0
- schift_cli/commands/migrate.py +127 -0
- schift_cli/commands/query.py +66 -0
- schift_cli/commands/skill.py +110 -0
- schift_cli/commands/usage.py +50 -0
- schift_cli/config.py +58 -0
- schift_cli/data/schift-best-practices/AGENTS.md +77 -0
- schift_cli/data/schift-best-practices/CLAUDE.md +77 -0
- schift_cli/data/schift-best-practices/SKILL.md +89 -0
- schift_cli/data/schift-best-practices/references/bucket-organization.md +126 -0
- schift_cli/data/schift-best-practices/references/bucket-upload.md +116 -0
- schift_cli/data/schift-best-practices/references/chatbot-widget.md +238 -0
- schift_cli/data/schift-best-practices/references/cost-batching.md +179 -0
- schift_cli/data/schift-best-practices/references/cost-storage-tiers.md +183 -0
- schift_cli/data/schift-best-practices/references/deploy-cloudrun.md +140 -0
- schift_cli/data/schift-best-practices/references/embed-batch-processing.md +86 -0
- schift_cli/data/schift-best-practices/references/embed-error-handling.md +155 -0
- schift_cli/data/schift-best-practices/references/embed-multimodal.md +100 -0
- schift_cli/data/schift-best-practices/references/embed-task-types.md +135 -0
- schift_cli/data/schift-best-practices/references/rag-chunking.md +173 -0
- schift_cli/data/schift-best-practices/references/rag-workflow-builder.md +205 -0
- schift_cli/data/schift-best-practices/references/sdk-async-patterns.md +103 -0
- schift_cli/data/schift-best-practices/references/sdk-auth-patterns.md +76 -0
- schift_cli/data/schift-best-practices/references/search-collection-design.md +229 -0
- schift_cli/data/schift-best-practices/references/search-hybrid.md +163 -0
- schift_cli/data/schift-best-practices/references/search-similarity-tuning.md +134 -0
- schift_cli/display.py +85 -0
- schift_cli/main.py +39 -0
- schift_cli-0.1.0.dist-info/METADATA +12 -0
- schift_cli-0.1.0.dist-info/RECORD +40 -0
- schift_cli-0.1.0.dist-info/WHEEL +5 -0
- schift_cli-0.1.0.dist-info/entry_points.txt +2 -0
- schift_cli-0.1.0.dist-info/top_level.txt +1 -0
schift_cli/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.1.0"
|
schift_cli/client.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from schift_cli.config import get_api_key, get_api_url
|
|
9
|
+
|
|
10
|
+
# Default HTTP timeouts for every API call. Connecting, writing the request
# and waiting for a pooled connection each time out after 30s; reads get a
# longer 120s because some server-side operations (e.g. migrations) are slow
# to respond.
DEFAULT_TIMEOUT = httpx.Timeout(connect=30.0, read=120.0, write=30.0, pool=30.0)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SchiftAPIError(Exception):
    """Error for a failed Schift API call.

    Carries the HTTP status code and the server-provided detail message so
    command modules can branch on the status (e.g. 404) and show the detail.
    """

    def __init__(self, status_code: int, detail: str):
        message = f"HTTP {status_code}: {detail}"
        super().__init__(message)
        # Kept as attributes so callers can inspect them after catching.
        self.status_code = status_code
        self.detail = detail
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SchiftClient:
    """HTTP client for the Schift API.

    Handles authentication headers, base URL resolution, and consistent
    error handling so command modules can stay focused on CLI logic.
    Use it as a context manager so the underlying ``httpx.Client`` is closed
    when the command finishes.
    """

    def __init__(self, api_key: str | None = None, base_url: str | None = None):
        # Explicit arguments take precedence over the configured values.
        self.api_key = api_key or get_api_key()
        self.base_url = (base_url or get_api_url()).rstrip("/")
        self._http = httpx.Client(
            base_url=self.base_url,
            timeout=DEFAULT_TIMEOUT,
            headers=self._build_headers(),
        )

    def _build_headers(self) -> dict[str, str]:
        """Return the default request headers, plus a Bearer token if a key is set."""
        headers: dict[str, str] = {
            "User-Agent": "schift-cli/0.1.0",
            "Accept": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    # -- HTTP verbs ----------------------------------------------------------

    def get(self, path: str, **kwargs: Any) -> Any:
        return self._request("GET", path, **kwargs)

    def post(self, path: str, **kwargs: Any) -> Any:
        return self._request("POST", path, **kwargs)

    def put(self, path: str, **kwargs: Any) -> Any:
        return self._request("PUT", path, **kwargs)

    def delete(self, path: str, **kwargs: Any) -> Any:
        return self._request("DELETE", path, **kwargs)

    # -- Internal -------------------------------------------------------------

    def _request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send a request and return its decoded JSON body.

        ``kwargs`` are forwarded to ``httpx.Client.request``. Returns ``None``
        for a 204 No Content response.

        Raises:
            click.ClickException: on connection failure, timeout or any other
                transport error; on HTTP 401; and when a successful response
                body is not valid JSON.
            SchiftAPIError: on any other HTTP status >= 400.
        """
        try:
            resp = self._http.request(method, path, **kwargs)
        except httpx.ConnectError as exc:
            raise click.ClickException(
                f"Could not connect to Schift API at {self.base_url}\n"
                " The server may be unavailable. Check your network or set "
                "SCHIFT_API_URL if using a custom endpoint."
            ) from exc
        except httpx.TimeoutException as exc:
            raise click.ClickException(
                "Request timed out. The server may be under heavy load — try again."
            ) from exc
        except httpx.TransportError as exc:
            # Any other network-level failure (connection dropped mid-response,
            # protocol or proxy error, ...) would otherwise surface as a raw
            # traceback instead of a CLI error message.
            raise click.ClickException(
                f"Network error while talking to Schift API at {self.base_url}: {exc}"
            ) from exc

        if resp.status_code == 401:
            raise click.ClickException(
                "Authentication failed. Run `schift auth login` to set your API key."
            )

        if resp.status_code >= 400:
            raise SchiftAPIError(resp.status_code, self._error_detail(resp))

        if resp.status_code == 204:
            return None
        try:
            return resp.json()
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError both subclass
            # ValueError; an empty or non-JSON success body lands here.
            raise click.ClickException(
                "Schift API returned a response that is not valid JSON "
                f"(HTTP {resp.status_code})."
            ) from exc

    @staticmethod
    def _error_detail(resp: httpx.Response) -> str:
        """Return a human-readable message for an error response.

        Prefers the ``detail`` field, then ``message``, of a JSON object body;
        falls back to the raw response text.
        """
        try:
            body = resp.json()
        except ValueError:
            return resp.text
        if isinstance(body, dict):
            detail = body.get("detail") or body.get("message")
            if detail:
                return str(detail)
        return resp.text

    def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        self._http.close()

    def __enter__(self) -> SchiftClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def require_api_key() -> str:
    """Fetch the configured API key, aborting the command if none is set.

    Raises:
        click.ClickException: when get_api_key() yields no key.
    """
    key = get_api_key()
    if key:
        return key
    raise click.ClickException(
        "No API key configured.\n"
        " Run `schift auth login` or set the SCHIFT_API_KEY environment variable."
    )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_client() -> SchiftClient:
    """Build a SchiftClient authenticated with the configured API key.

    Aborts (via require_api_key) when no key is configured.
    """
    return SchiftClient(api_key=require_api_key())
|
|
File without changes
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
|
|
5
|
+
from schift_cli.config import clear_api_key, get_api_key, set_api_key, CONFIG_FILE
|
|
6
|
+
from schift_cli.display import success, info, error
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.group(name="auth")
def auth() -> None:
    """Manage authentication with the Schift platform."""
    # Container group only; subcommands are attached with @auth.command().
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@auth.command()
def login() -> None:
    """Set your Schift API key."""
    # Only ask about overwriting when a key is already configured; declining
    # leaves the existing key in place.
    if get_api_key() and not click.confirm(
        "An API key is already configured. Overwrite?", default=False
    ):
        info("Keeping existing API key.")
        return

    # hide_input keeps the key off the terminal while typing.
    new_key = click.prompt("Enter your Schift API key", hide_input=True).strip()
    if not new_key:
        raise click.ClickException("API key cannot be empty.")

    set_api_key(new_key)
    success(f"API key saved to {CONFIG_FILE}")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@auth.command()
def logout() -> None:
    """Remove the stored API key."""
    if get_api_key():
        clear_api_key()
        success("API key removed.")
    else:
        info("No API key is currently stored.")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@auth.command()
def status() -> None:
    """Show current authentication status."""
    import os

    from schift_cli.config import ENV_API_KEY

    def _mask(key: str) -> str:
        # Show only a short prefix/suffix so the key is recognisable without
        # being disclosed; short keys are hidden entirely.
        if len(key) > 12:
            return key[:8] + "..." + key[-4:]
        return "***"

    env_key = os.environ.get(ENV_API_KEY)

    # Reading the config file is best-effort: any failure is treated as
    # "no key stored in the file".
    file_key = None
    try:
        from schift_cli.config import load_config

        file_key = load_config().get("api_key")
    except Exception:
        pass

    # The environment variable is reported first when both sources are set.
    if env_key:
        success(f"Authenticated via {ENV_API_KEY} env var (key: {_mask(env_key)})")
    elif file_key:
        success(f"Authenticated via config file (key: {_mask(file_key)})")
    else:
        error("Not authenticated. Run `schift auth login` to get started.")
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
|
|
7
|
+
from schift_cli.client import get_client, SchiftAPIError
|
|
8
|
+
from schift_cli.display import console, error, info, print_kv, spinner, success
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@click.command("bench")
@click.option("--source", "-s", required=True, help="Source model ID (e.g. openai/text-embedding-3-large)")
@click.option("--target", "-t", required=True, help="Target model ID (e.g. google/gemini-embedding-004)")
@click.option("--data", "-d", type=click.Path(exists=True, path_type=Path), required=True,
              help="JSONL file with benchmark queries")
# IntRange rejects 0 or negative values at parse time instead of sending them
# to the API.
@click.option("--top-k", "-k", type=click.IntRange(min=1), default=10, show_default=True,
              help="Number of results to compare per query")
def bench(source: str, target: str, data: Path, top_k: int) -> None:
    """Benchmark embedding quality between two models.

    Measures how well a Schift projection preserves retrieval quality
    when switching from SOURCE to TARGET model.
    """
    info(f"Benchmarking projection: {source} -> {target}")
    info(f"Data: {data} | top-k: {top_k}")

    # NOTE(review): only the local *path* of --data is sent, not the file's
    # contents, so this works only if the API server can read that same path.
    # Confirm against the /bench API contract.
    try:
        with get_client() as client:
            with spinner("Running benchmark...") as progress:
                progress.add_task("Running benchmark...", total=None)
                result = client.post(
                    "/bench",
                    json={
                        "source_model": source,
                        "target_model": target,
                        "data_path": str(data),
                        "top_k": top_k,
                    },
                )
    except SchiftAPIError as e:
        error(f"Benchmark failed: {e.detail}")
        raise SystemExit(1)

    # Some responses wrap the metrics under "report"; otherwise use the body.
    report = result.get("report", result)
    print_kv("Benchmark Report", {
        "Source Model": source,
        "Target Model": target,
        "Queries": report.get("num_queries", "-"),
        "Recall@k": report.get("recall_at_k", "-"),
        "MRR": report.get("mrr", "-"),
        "Cosine Similarity (avg)": report.get("avg_cosine_similarity", "-"),
        "Latency (p50)": report.get("latency_p50_ms", "-"),
        "Latency (p99)": report.get("latency_p99_ms", "-"),
    })

    # The verdict needs a numeric recall. A missing value skips it, and so does
    # a non-numeric one (already shown verbatim above); before, that crashed
    # float() after the report was printed.
    try:
        recall = float(report.get("recall_at_k"))
    except (TypeError, ValueError):
        return

    if recall >= 0.95:
        success("Projection quality is excellent.")
    elif recall >= 0.85:
        console.print("[yellow]Projection quality is acceptable but may degrade edge cases.[/]")
    else:
        console.print("[red]Projection quality is low. Consider increasing sample size in `migrate fit`.[/]")
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
|
|
5
|
+
from schift_cli.client import get_client, SchiftAPIError
|
|
6
|
+
from schift_cli.display import print_table, print_kv, error
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.group(name="catalog")
def catalog() -> None:
    """Browse available embedding models."""
    # Container group; `list` and `get` are registered via @catalog.command.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@catalog.command("list")
def list_models() -> None:
    """List all supported embedding models."""
    try:
        with get_client() as client:
            data = client.get("/catalog/models")
    except SchiftAPIError as e:
        error(f"Failed to fetch model catalog: {e.detail}")
        raise SystemExit(1)

    models = data.get("models", [])
    if not models:
        click.echo("No models found in the catalog.")
        return

    def cell(value: object) -> str:
        # Stringify every cell. Before, only "dimensions" was str()-ed, so a
        # numeric max_tokens reached print_table raw (Rich tables, presumably
        # behind print_table, refuse to render ints). A missing key or JSON
        # null becomes an empty cell rather than "None".
        return "" if value is None else str(value)

    columns = ("id", "provider", "dimensions", "max_tokens", "status")
    rows = [tuple(cell(m.get(key)) for key in columns) for m in models]
    print_table(
        "Embedding Model Catalog",
        ["Model ID", "Provider", "Dimensions", "Max Tokens", "Status"],
        rows,
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@catalog.command("get")
@click.argument("model_id")
def get_model(model_id: str) -> None:
    """Show details for a specific model.

    MODEL_ID is the fully qualified model name, e.g. openai/text-embedding-3-large
    """
    try:
        with get_client() as client:
            payload = client.get(f"/catalog/models/{model_id}")
    except SchiftAPIError as e:
        # A 404 gets a friendlier message than the raw server detail.
        message = (
            f"Model not found: {model_id}"
            if e.status_code == 404
            else f"Failed to fetch model: {e.detail}"
        )
        error(message)
        raise SystemExit(1)

    # The model may be wrapped under "model" or be the response body itself.
    info_block = payload.get("model", payload)
    field_map = (
        ("Provider", "provider"),
        ("Dimensions", "dimensions"),
        ("Max Tokens", "max_tokens"),
        ("Status", "status"),
        ("Description", "description"),
    )
    print_kv(
        f"Model: {model_id}",
        {label: info_block.get(key, "-") for label, key in field_map},
    )
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
|
|
5
|
+
from schift_cli.client import get_client, SchiftAPIError
|
|
6
|
+
from schift_cli.display import error, info, print_kv, print_table, success
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.group(name="db")
def db() -> None:
    """Manage vector collections."""
    # Container group; subcommands are registered via @db.command.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@db.command("create")
@click.argument("name")
@click.option("--dim", "-d", type=int, required=True, help="Vector dimensions (e.g. 3072)")
@click.option("--metric", type=click.Choice(["cosine", "euclidean", "dot"]),
              default="cosine", show_default=True, help="Distance metric")
def create(name: str, dim: int, metric: str) -> None:
    """Create a new vector collection."""
    payload = {"name": name, "dimensions": dim, "metric": metric}
    try:
        with get_client() as client:
            response = client.post("/collections", json=payload)
    except SchiftAPIError as e:
        error(f"Failed to create collection: {e.detail}")
        raise SystemExit(1)

    # The new collection may be wrapped under "collection" or be the body itself.
    created = response.get("collection", response)
    success(f"Collection '{name}' created (id: {created.get('id', '-')})")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@db.command("list")
def list_collections() -> None:
    """List all vector collections."""
    try:
        with get_client() as client:
            response = client.get("/collections")
    except SchiftAPIError as e:
        error(f"Failed to list collections: {e.detail}")
        raise SystemExit(1)

    collections = response.get("collections", [])
    if not collections:
        info("No collections found. Create one with `schift db create`.")
        return

    headers = ["Name", "Dimensions", "Metric", "Vectors", "Created"]
    table_rows = []
    for coll in collections:
        table_rows.append((
            coll.get("name", ""),
            str(coll.get("dimensions", "")),
            coll.get("metric", ""),
            str(coll.get("vector_count", "")),
            coll.get("created_at", ""),
        ))
    print_table("Vector Collections", headers, table_rows)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@db.command("stats")
@click.argument("name")
def stats(name: str) -> None:
    """Show statistics for a collection."""
    try:
        with get_client() as client:
            payload = client.get(f"/collections/{name}/stats")
    except SchiftAPIError as e:
        if e.status_code == 404:
            message = f"Collection not found: {name}"
        else:
            message = f"Failed to get stats: {e.detail}"
        error(message)
        raise SystemExit(1)

    # Stats may be wrapped under "stats" or be the response body itself.
    summary = payload.get("stats", payload)
    labelled_keys = (
        ("Vectors", "vector_count"),
        ("Dimensions", "dimensions"),
        ("Metric", "metric"),
        ("Index Type", "index_type"),
        ("Storage Size", "storage_size"),
        ("Created", "created_at"),
        ("Last Updated", "updated_at"),
    )
    print_kv(
        f"Collection: {name}",
        {label: summary.get(key, "-") for label, key in labelled_keys},
    )
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import click
|
|
7
|
+
|
|
8
|
+
from schift_cli.client import get_client, SchiftAPIError
|
|
9
|
+
from schift_cli.display import console, error, info, success
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _EmbedGroup(click.Group):
    """Group that treats an unrecognised first token as TEXT to embed.

    Click parses a group's own positional arguments before resolving the
    subcommand, and stops parsing group options at the first positional
    token. A plain group with an optional TEXT argument therefore bound
    "batch" to TEXT in ``schift embed batch ...`` and tried to run a
    subcommand named ``--model`` for ``schift embed "hi" --model m``.

    This group declares no positional argument. When the first remaining
    token is not a registered subcommand, the whole remaining command line
    is handed to ``default_command`` instead.
    """

    def __init__(self, *args, default_command: click.Command, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.default_command = default_command

    def resolve_command(
        self, ctx: click.Context, args: list[str]
    ) -> tuple[str | None, click.Command | None, list[str]]:
        if args and self.get_command(ctx, args[0]) is None:
            # Keep args[0]: it is the TEXT argument of the default command.
            return self.default_command.name, self.default_command, args
        return super().resolve_command(ctx, args)


@click.command("text")
@click.argument("text", required=False, default=None)
@click.option("--model", "-m", default=None,
              help="Embedding model ID (e.g. openai/text-embedding-3-large)")
@click.pass_context
def _embed_text(ctx: click.Context, text: str | None, model: str | None) -> None:
    """Embed a single TEXT and print a preview of the resulting vector."""
    group_ctx = ctx.parent
    # --model may also be given before TEXT, where the `embed` group parses it.
    if not model and group_ctx is not None:
        model = group_ctx.params.get("model")

    # Usage errors are reported against the `embed` group's usage line.
    if not text:
        raise click.UsageError(
            "Provide TEXT to embed, or use `schift embed batch`.", ctx=group_ctx
        )
    if not model:
        raise click.UsageError("--model is required.", ctx=group_ctx)

    try:
        with get_client() as client:
            data = client.post("/embed", json={"text": text, "model": model})
    except SchiftAPIError as e:
        error(f"Embedding failed: {e.detail}")
        raise SystemExit(1)

    embedding = data.get("embedding", [])
    dims = len(embedding)
    preview = ", ".join(f"{v:.6f}" for v in embedding[:5])
    # Mark the preview as truncated only when components were actually
    # omitted (an empty vector used to print as "[, ...]").
    if dims > 5:
        preview += ", ..."

    success(f"Generated {dims}-dimensional embedding")
    console.print(f" [{preview}]")


@click.group(
    "embed",
    cls=_EmbedGroup,
    default_command=_embed_text,
    invoke_without_command=True,
    subcommand_metavar="[TEXT | COMMAND [ARGS]...]",
)
@click.option("--model", "-m", default=None,
              help="Embedding model ID (e.g. openai/text-embedding-3-large)")
@click.pass_context
def embed(ctx: click.Context, model: str | None) -> None:
    """Generate embeddings for text.

    \b
    Usage:
    schift embed "hello world" --model openai/text-embedding-3-large
    schift embed batch --file texts.jsonl --model google/gemini-embedding-004
    """
    # `model` is consumed by the default text command through ctx.parent.params.
    # With neither TEXT nor a subcommand there is nothing to do.
    if ctx.invoked_subcommand is None:
        raise click.UsageError("Provide TEXT to embed, or use `schift embed batch`.")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _read_jsonl_texts(file_path: Path) -> list[str]:
    """Return the "text" field of every non-blank line of a JSONL file.

    Prints an error and exits with status 1 on a malformed line or on a file
    that is not valid UTF-8.
    """
    texts: list[str] = []
    try:
        # JSONL is UTF-8; "utf-8-sig" also strips a leading BOM (common from
        # Windows editors) that json.loads would reject. Without an explicit
        # encoding, open() falls back to the platform locale encoding.
        with open(file_path, encoding="utf-8-sig") as f:
            for i, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    obj = None
                # Non-object lines (a bare JSON string, list, ...) used to raise
                # an uncaught TypeError; reject them like any malformed line.
                text = obj.get("text") if isinstance(obj, dict) else None
                if not isinstance(text, str):
                    error(f"Line {i}: expected a JSON object with a string \"text\" field")
                    raise SystemExit(1)
                texts.append(text)
    except UnicodeDecodeError:
        error(f"{file_path} is not valid UTF-8 text.")
        raise SystemExit(1)
    return texts


@embed.command("batch")
@click.option("--file", "-f", "file_path", required=True,
              type=click.Path(exists=True, path_type=Path),
              help="JSONL file with one text per line (field: \"text\")")
@click.option("--model", "-m", required=True,
              help="Embedding model ID (e.g. google/gemini-embedding-004)")
@click.option("--output", "-o", type=click.Path(path_type=Path), default=None,
              help="Output JSONL file for results (default: stdout summary)")
def embed_batch(file_path: Path, model: str, output: Path | None) -> None:
    """Embed multiple texts from a JSONL file.

    Each line in the input file must be a JSON object with a "text" field.
    """
    texts = _read_jsonl_texts(file_path)
    if not texts:
        error("No texts found in input file.")
        raise SystemExit(1)

    info(f"Embedding {len(texts)} texts with model {model}")

    try:
        with get_client() as client:
            data = client.post(
                "/embed/batch",
                json={"texts": texts, "model": model},
            )
    except SchiftAPIError as e:
        error(f"Batch embedding failed: {e.detail}")
        raise SystemExit(1)

    embeddings = data.get("embeddings", [])
    # zip() below would silently drop rows (or pair texts with the wrong
    # vectors) if the API returned a different number of embeddings.
    if len(embeddings) != len(texts):
        error(f"Expected {len(texts)} embeddings but the API returned {len(embeddings)}.")
        raise SystemExit(1)

    if output:
        with open(output, "w", encoding="utf-8") as f:
            for text_val, emb in zip(texts, embeddings):
                f.write(json.dumps({"text": text_val, "embedding": emb}) + "\n")
        success(f"Wrote {len(embeddings)} embeddings to {output}")
    else:
        dims = len(embeddings[0]) if embeddings else 0
        success(f"Generated {len(embeddings)} embeddings ({dims} dimensions each)")
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
|
|
5
|
+
from schift_cli.client import get_client, SchiftAPIError
|
|
6
|
+
from schift_cli.display import console, error, info, print_kv, spinner, success
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.group(name="migrate")
def migrate() -> None:
    """Fit projection matrices and migrate vector databases."""
    # Container group; `fit` and `run` are registered via @migrate.command.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@migrate.command("fit")
@click.option("--source", "-s", required=True, help="Source embedding model ID")
@click.option("--target", "-t", required=True, help="Target embedding model ID")
@click.option("--sample", type=float, default=0.1, show_default=True,
              help="Fraction of data to sample for fitting (0.0-1.0)")
def fit(source: str, target: str, sample: float) -> None:
    """Fit a projection matrix between two embedding models.

    The projection is computed server-side. You never see the matrix --
    Schift stores it and returns a projection ID for use in `migrate run`.
    """
    # Valid fractions lie in (0, 1]; NaN fails both comparisons and is rejected.
    sample_ok = 0.0 < sample <= 1.0
    if not sample_ok:
        raise click.BadParameter("Sample must be between 0 and 1.", param_hint="--sample")

    info(f"Fitting projection: {source} -> {target} (sample={sample})")

    label = "Fitting projection matrix..."
    request_body = {
        "source_model": source,
        "target_model": target,
        "sample_fraction": sample,
    }
    try:
        with get_client() as client, spinner(label) as progress:
            progress.add_task(label, total=None)
            result = client.post("/migrate/fit", json=request_body)
    except SchiftAPIError as e:
        error(f"Fit failed: {e.detail}")
        raise SystemExit(1)

    # The projection may be wrapped under "projection" or be the body itself.
    projection = result.get("projection", result)
    proj_id = projection.get("id", "unknown")

    summary = {
        "Projection ID": proj_id,
        "Source Model": source,
        "Target Model": target,
        "Sample Fraction": sample,
        "Status": projection.get("status", "ready"),
        "Quality (R2)": projection.get("r2_score", "-"),
    }
    print_kv("Projection Created", summary)

    success(f"Use this projection ID to migrate: schift migrate run --projection {proj_id}")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@migrate.command("run")
@click.option("--projection", "-p", required=True, help="Projection ID from `migrate fit`")
@click.option("--db", required=True, help="Database connection string (e.g. pgvector://...)")
@click.option("--dry-run", is_flag=True, default=False, help="Preview the migration without applying changes")
@click.option("--batch-size", type=int, default=1000, show_default=True,
              help="Number of vectors to migrate per batch")
def run(projection: str, db: str, dry_run: bool, batch_size: int) -> None:
    """Apply a projection to migrate vectors in a database.

    Use --dry-run first to preview the migration plan.
    """
    mode = "LIVE"
    if dry_run:
        mode = "DRY RUN"
    info(f"Migration [{mode}]: projection={projection}")
    # Never echo the raw connection string; it may carry a password.
    info(f"Database: {_mask_connection_string(db)}")

    # A live run modifies the user's database, so require confirmation;
    # abort=True stops the command if the user declines.
    if not dry_run:
        click.confirm(
            "This will modify vectors in your database. Continue?",
            abort=True,
        )

    label = "Running migration..."
    request_body = {
        "projection_id": projection,
        "db_connection": db,
        "dry_run": dry_run,
        "batch_size": batch_size,
    }
    try:
        with get_client() as client, spinner(label) as progress:
            progress.add_task(label, total=None)
            result = client.post("/migrate/run", json=request_body)
    except SchiftAPIError as e:
        error(f"Migration failed: {e.detail}")
        raise SystemExit(1)

    # The result may be wrapped under "migration" or be the body itself.
    outcome = result.get("migration", result)
    print_kv(f"Migration Result ({mode})", {
        "Vectors Processed": outcome.get("vectors_processed", "-"),
        "Vectors Skipped": outcome.get("vectors_skipped", "-"),
        "Duration": outcome.get("duration", "-"),
        "Status": outcome.get("status", "-"),
    })

    if not dry_run:
        success("Migration complete.")
    else:
        info("No changes were applied. Remove --dry-run to execute.")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _mask_connection_string(conn: str) -> str:
|
|
119
|
+
"""Hide password in connection strings for display."""
|
|
120
|
+
if "@" in conn:
|
|
121
|
+
# pgvector://user:pass@host -> pgvector://user:***@host
|
|
122
|
+
before_at = conn.split("@")[0]
|
|
123
|
+
after_at = conn.split("@", 1)[1]
|
|
124
|
+
if ":" in before_at:
|
|
125
|
+
scheme_user = before_at.rsplit(":", 1)[0]
|
|
126
|
+
return f"{scheme_user}:***@{after_at}"
|
|
127
|
+
return conn
|