shiftgate 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shiftgate/__init__.py +9 -0
- shiftgate/cli.py +513 -0
- shiftgate/feedback/__init__.py +1 -0
- shiftgate/feedback/loop.py +182 -0
- shiftgate/registry/__init__.py +1 -0
- shiftgate/registry/adapter_registry.py +162 -0
- shiftgate/registry/schemas.py +115 -0
- shiftgate/registry/task_registry.py +186 -0
- shiftgate/router/__init__.py +1 -0
- shiftgate/router/embedder.py +95 -0
- shiftgate/router/matcher.py +115 -0
- shiftgate/router/router.py +97 -0
- shiftgate/runtime/__init__.py +1 -0
- shiftgate/runtime/backend.py +289 -0
- shiftgate/utils/__init__.py +1 -0
- shiftgate/utils/display.py +297 -0
- shiftgate-0.1.0.dist-info/METADATA +273 -0
- shiftgate-0.1.0.dist-info/RECORD +20 -0
- shiftgate-0.1.0.dist-info/WHEEL +4 -0
- shiftgate-0.1.0.dist-info/entry_points.txt +2 -0
shiftgate/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
shiftgate — Intelligent LoRA adapter routing for local LLM inference.
|
|
3
|
+
|
|
4
|
+
Automatically selects the right adapter for each task using semantic
|
|
5
|
+
similarity, inspired by the LORAUTER paper (EPFL, 2026).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Package metadata exposed at import time.
__author__ = "shiftgate contributors"
__version__ = "0.1.0"  # matches the shiftgate-0.1.0 dist-info of this wheel
|
shiftgate/cli.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
"""
|
|
2
|
+
shiftgate CLI — Typer-based command-line interface.
|
|
3
|
+
|
|
4
|
+
All user-facing interactions happen through this module. Commands are
|
|
5
|
+
grouped into three sub-apps (adapter, task, feedback) plus top-level
|
|
6
|
+
commands (init, route, run, status, demo).
|
|
7
|
+
|
|
8
|
+
Entry point (registered in pyproject.toml):
|
|
9
|
+
shiftgate = "shiftgate.cli:app"
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Annotated, Optional
|
|
17
|
+
|
|
18
|
+
import typer
|
|
19
|
+
from rich.console import Console
|
|
20
|
+
from rich.prompt import Confirm, Prompt
|
|
21
|
+
|
|
22
|
+
from shiftgate.registry.schemas import AdapterEntry, TaskCluster
|
|
23
|
+
|
|
24
|
+
console = Console()  # shared Rich console used by every command for output

# Root Typer application (entry point: ``shiftgate.cli:app``). Invoking it
# with no arguments prints help; local variables are hidden from pretty
# exception tracebacks.
app = typer.Typer(
    name="shiftgate",
    help="Intelligent LoRA adapter routing for local LLM inference.",
    no_args_is_help=True,
    pretty_exceptions_show_locals=False,
)

# Sub-apps, each mounted below under its own command-group name.
adapter_app = typer.Typer(no_args_is_help=True, help="Manage the adapter registry.")
task_app = typer.Typer(no_args_is_help=True, help="Manage task clusters.")
feedback_app = typer.Typer(no_args_is_help=True, help="Record and review routing feedback.")

for _sub_app, _group in (
    (adapter_app, "adapter"),
    (task_app, "task"),
    (feedback_app, "feedback"),
):
    app.add_typer(_sub_app, name=_group)
del _sub_app, _group  # keep the module namespace identical to the unrolled form
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
# Helpers
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
|
|
46
|
+
def _load_registries():
    """Load the task and adapter registries for a CLI command.

    Returns:
        A ``(task_registry, adapter_registry)`` tuple; the task registry is
        loaded first.

    Raises:
        typer.Exit: with code 1 when either ``load()`` raises
            ``FileNotFoundError``; the exception message is printed first.
    """
    from shiftgate.registry.adapter_registry import AdapterRegistry
    from shiftgate.registry.task_registry import TaskRegistry

    try:
        # Tuple elements evaluate left to right: tasks, then adapters.
        registries = (TaskRegistry.load(), AdapterRegistry.load())
    except FileNotFoundError as missing:
        console.print(f"[red]Error:[/red] {missing}")
        raise typer.Exit(1)
    return registries
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _get_embedder():
    """Construct and return a new ``Embedder``.

    The import happens at call time, so importing this CLI module does not
    import ``shiftgate.router.embedder``.
    """
    from shiftgate.router.embedder import Embedder as _Embedder

    embedder = _Embedder()
    return embedder
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ---------------------------------------------------------------------------
|
|
66
|
+
# shiftgate init
|
|
67
|
+
# ---------------------------------------------------------------------------
|
|
68
|
+
|
|
69
|
+
@app.command()
def init() -> None:
    """Set up ~/.shiftgate/, compute task embeddings, and show a welcome message.

    Creates the config directory if needed, loads the task and adapter
    registries, computes task embeddings unless ``embeddings_ready()`` says
    they already exist, saves both registries, and prints summary tables plus
    next steps.

    Raises:
        typer.Exit: with code 1 if loading a registry raises
            ``FileNotFoundError`` (handled by ``_load_registries``).
    """
    from shiftgate.utils.display import show_adapter_table, show_task_table, show_welcome_banner

    show_welcome_banner()

    shiftgate_dir = Path.home() / ".shiftgate"
    shiftgate_dir.mkdir(parents=True, exist_ok=True)
    console.print(f"[dim]Config directory:[/dim] {shiftgate_dir}")
    console.print()

    # Load defaults (copies them into the user's ~/.shiftgate/ on first save).
    # Going through _load_registries gives the same clean error message and
    # exit code as every other command instead of an uncaught traceback.
    console.print("[cyan]Loading task registry…[/cyan]")
    task_reg, adapter_reg = _load_registries()

    # Compute embeddings (downloads model on first run)
    if task_reg.embeddings_ready():
        console.print("[dim]Embeddings already computed. Skipping (delete embeddings_cache.npy to force refresh).[/dim]")
    else:
        console.print("[cyan]Computing task embeddings (first run — model download may take a moment)…[/cyan]")
        embedder = _get_embedder()
        task_reg.compute_embeddings(embedder)
        console.print("[green]✓[/green] Embeddings computed.")

    # Persist to ~/.shiftgate/
    task_reg.save()
    adapter_reg.save()
    console.print(f"[green]✓[/green] Registry saved to {shiftgate_dir}")
    console.print()

    show_task_table(task_reg.get_all_tasks())
    console.print()
    show_adapter_table(adapter_reg.list_adapters())
    console.print()

    console.print(
        "[bold green]shiftgate is ready![/bold green]\n\n"
        " Try it:\n"
        ' [cyan]shiftgate route "write a python function"[/cyan]\n\n'
        " Add a LoRA adapter:\n"
        " [cyan]shiftgate adapter add monology/pmc-llama-13b-lora[/cyan]\n"
    )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# ---------------------------------------------------------------------------
|
|
118
|
+
# shiftgate adapter
|
|
119
|
+
# ---------------------------------------------------------------------------
|
|
120
|
+
|
|
121
|
+
@adapter_app.command("add")
def adapter_add(
    hf_repo_or_path: Annotated[str, typer.Argument(help="HuggingFace repo ID or local path.")],
    tags: Annotated[
        Optional[list[str]],
        typer.Option("--tags", "-t", help="Task tags; repeat per tag, e.g. --tags code --tags python"),
    ] = None,
    base: Annotated[
        Optional[str],
        typer.Option("--base", "-b", help="Base model name, e.g. 'meta-llama/Meta-Llama-3-8B'"),
    ] = None,
    name: Annotated[Optional[str], typer.Option(help="Override display name.")] = None,
    description: Annotated[Optional[str], typer.Option(help="Short description.")] = None,
) -> None:
    """Register a new LoRA adapter from a HuggingFace repo or local path.

    Args:
        hf_repo_or_path: HuggingFace repo ID or local path, passed to
            ``AdapterRegistry.add_adapter``.
        tags: Optional task tags. A Typer list option takes one value per
            flag occurrence, so several tags need ``-t code -t python``.
        base: Optional base model name, forwarded as ``base_model``.
        name: Optional display name that overrides the registered name.
        description: Optional short description.

    Raises:
        typer.Exit: with code 1 if the registries cannot be loaded or if
            ``add_adapter`` raises.
    """
    from rich.markup import escape

    _, adapter_reg = _load_registries()

    # Forward only the options the user actually supplied, so anything else
    # is left to add_adapter's own defaults.
    kwargs: dict = {}
    if tags:
        kwargs["tags"] = tags
    if base:
        kwargs["base_model"] = base
    if description:
        kwargs["description"] = description

    try:
        with console.status("[cyan]Fetching adapter metadata…[/cyan]"):
            adapter = adapter_reg.add_adapter(hf_repo_or_path, **kwargs)
    except Exception as exc:
        # Same treatment as routing/inference errors elsewhere in the CLI:
        # report and exit non-zero instead of dumping a traceback. Escape the
        # text so brackets in paths or messages aren't parsed as Rich markup.
        console.print(
            f"[red]Error:[/red] Could not register adapter '{escape(hf_repo_or_path)}': {escape(str(exc))}"
        )
        raise typer.Exit(1)

    if name:
        adapter.name = name
        # Re-store the mutated entry; presumably guards against add_adapter
        # returning a copy rather than the stored object — TODO confirm.
        adapter_reg._adapters[adapter.id] = adapter

    adapter_reg.save()
    console.print(f"[green]✓[/green] Adapter '[bold magenta]{adapter.id}[/bold magenta]' registered.")
    console.print(f" Name: {adapter.name}")
    console.print(f" Base: {adapter.base_model}")
    if adapter.task_tags:
        console.print(f" Tags: {', '.join(adapter.task_tags)}")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@adapter_app.command("list")
def adapter_list() -> None:
    """Print every registered adapter as a Rich table."""
    from shiftgate.utils.display import show_adapter_table

    registries = _load_registries()
    adapters = registries[1].list_adapters()
    show_adapter_table(adapters)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@adapter_app.command("remove")
def adapter_remove(
    adapter_id: Annotated[str, typer.Argument(help="Adapter ID to remove.")],
) -> None:
    """Delete an adapter from the registry and persist the change.

    Raises:
        typer.Exit: with code 1 when ``remove_adapter`` reports the ID was
            not found (returns a falsy value).
    """
    _, registry = _load_registries()

    removed = registry.remove_adapter(adapter_id)
    if not removed:
        console.print(f"[red]Error:[/red] Adapter '{adapter_id}' not found.")
        raise typer.Exit(1)

    registry.save()
    console.print(f"[green]✓[/green] Adapter '{adapter_id}' removed.")
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# ---------------------------------------------------------------------------
|
|
185
|
+
# shiftgate task
|
|
186
|
+
# ---------------------------------------------------------------------------
|
|
187
|
+
|
|
188
|
+
@task_app.command("list")
def task_list() -> None:
    """Print every task cluster as a Rich table."""
    from shiftgate.utils.display import show_task_table

    registries = _load_registries()
    show_task_table(registries[0].get_all_tasks())
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@task_app.command("add")
def task_add() -> None:
    """Interactively add a new task cluster to the registry.

    Prompts for an ID, display name, description, validation examples and
    preferred adapter IDs. After confirmation, if any examples were given,
    they are embedded and their mean is L2-normalised into the task's
    ``embedding_centroid``. The task is then added and the registry saved.
    Centroid computation is best-effort: a failure is reported as a warning
    and the task is still saved.

    Raises:
        typer.Exit: with code 1 if the registries cannot be loaded or the
            entered task ID is empty.
    """
    task_reg, _ = _load_registries()

    console.print("[bold cyan]Add a new task cluster[/bold cyan]")
    console.print("[dim]Press Ctrl-C to cancel.[/dim]\n")

    task_id = Prompt.ask("Task ID (slug, e.g. code_rust)").strip()
    if not task_id:
        console.print("[red]Error:[/red] Task ID must not be empty.")
        raise typer.Exit(1)
    if task_reg.get_task(task_id):
        console.print(f"[yellow]Warning:[/yellow] A task with ID '{task_id}' already exists.")

    task_name = Prompt.ask("Display name (e.g. Rust Code Generation)")
    task_desc = Prompt.ask("Short description")

    console.print("\nEnter validation examples (one per line). Empty line to finish:")
    examples: list[str] = []
    while True:
        ex = Prompt.ask(f" Example {len(examples) + 1} (or Enter to finish)", default="")
        if not ex:
            break
        examples.append(ex)

    if not examples:
        console.print("[yellow]Warning:[/yellow] No examples given — no centroid will be computed for this task.")
    elif len(examples) < 3:
        console.print("[yellow]Warning:[/yellow] Fewer than 3 examples may produce a poor centroid.")

    preferred_raw = Prompt.ask("Preferred adapter IDs (comma-separated, or leave blank)")
    preferred = [p.strip() for p in preferred_raw.split(",") if p.strip()]

    task = TaskCluster(
        id=task_id,
        name=task_name,
        description=task_desc,
        validation_examples=examples,
        preferred_adapters=preferred,
    )

    if not Confirm.ask(f"\nSave task '[bold]{task_id}[/bold]'?"):
        console.print("[dim]Cancelled.[/dim]")
        return

    # The mean of zero vectors is undefined (NumPy yields NaN for an empty
    # mean), which would otherwise be stored as the centroid, so only compute
    # one when there are examples.
    if task.validation_examples:
        try:
            import numpy as np

            embedder = _get_embedder()
            vecs = embedder.embed_batch(task.validation_examples)
            centroid = vecs.mean(axis=0)
            norm = np.linalg.norm(centroid)
            if norm > 0:
                centroid = centroid / norm
            task.embedding_centroid = centroid.tolist()
            console.print("[green]✓[/green] Centroid computed for new task.")
        except Exception as exc:
            # Deliberately best-effort: the task is still saved without it.
            console.print(f"[yellow]Warning:[/yellow] Could not compute centroid: {exc}")

    task_reg.add_task(task)
    task_reg.save()
    console.print(f"[green]✓[/green] Task '[bold]{task_id}[/bold]' saved.")
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# ---------------------------------------------------------------------------
|
|
254
|
+
# shiftgate route — routing only, no inference
|
|
255
|
+
# ---------------------------------------------------------------------------
|
|
256
|
+
|
|
257
|
+
@app.command()
def route(
    query: Annotated[str, typer.Argument(help="Query to route.")],
    top_k: Annotated[int, typer.Option("--top-k", "-k", help="Number of candidate tasks.")] = 3,
    record: Annotated[bool, typer.Option(help="Save trace to ~/.shiftgate/traces.jsonl.")] = True,
) -> None:
    """Route a query to the best adapter (no inference — just the routing decision).

    Args:
        query: The query to route.
        top_k: Number of candidate tasks passed to the router.
        record: When true, the trace is recorded via the feedback loop so it
            can later be rated with ``shiftgate feedback accept/reject``.

    Raises:
        typer.Exit: with code 1 if the registries cannot be loaded, task
            embeddings are not initialised, or routing raises.
    """
    from rich.markup import escape

    from shiftgate.feedback import loop as feedback_loop
    from shiftgate.router import router as routing
    from shiftgate.utils.display import show_routing_decision

    task_reg, adapter_reg = _load_registries()

    if not task_reg.embeddings_ready():
        console.print("[red]Error:[/red] Task embeddings not initialised. Run `shiftgate init` first.")
        raise typer.Exit(1)

    embedder = _get_embedder()

    try:
        trace = routing.route(query, task_reg, adapter_reg, embedder, top_k=top_k)
    except Exception as exc:
        # Escape: exception text may contain "[...]" that Rich would
        # otherwise interpret as markup.
        console.print(f"[red]Routing error:[/red] {escape(str(exc))}")
        raise typer.Exit(1)

    # Either lookup may come back empty; the display helper receives it as-is.
    adapter = adapter_reg.get_adapter(trace.selected_adapter_id)
    task = task_reg.get_task(trace.matched_task_id)

    show_routing_decision(
        trace,
        adapter=adapter,
        task_name=task.name if task else None,
        backend_name=None,
    )

    if record:
        feedback_loop.record_trace(trace)
        console.print(f"[dim]Trace {trace.id[:8]}… recorded. Run `shiftgate feedback accept/reject` to rate it.[/dim]")
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
# ---------------------------------------------------------------------------
|
|
299
|
+
# shiftgate run — route + run inference
|
|
300
|
+
# ---------------------------------------------------------------------------
|
|
301
|
+
|
|
302
|
+
@app.command()
def run(
    query: Annotated[str, typer.Argument(help="Query to route and run.")],
    top_k: Annotated[int, typer.Option("--top-k", "-k", help="Number of candidate tasks.")] = 3,
) -> None:
    """Route a query and run it through the detected Ollama or vLLM backend.

    The routing decision is always displayed.

    - If the selected adapter is missing from the registry, the command exits
      with code 1.
    - If no backend is detected, the trace is recorded, start-up hints are
      printed, and the command exits with code 0.
    - Otherwise the query is sent to the backend with the selected adapter.
      The response and latency are printed, and the trace (with
      ``latency_ms`` set) is recorded.

    Raises:
        typer.Exit: code 1 on missing registries, uninitialised embeddings,
            routing/inference errors, or an unknown selected adapter; code 0
            when no backend is available.
    """
    from rich.markup import escape

    from shiftgate.feedback import loop as feedback_loop
    from shiftgate.router import router as routing
    from shiftgate.runtime.backend import BackendRouter, NoBackendError
    from shiftgate.utils.display import show_routing_decision

    task_reg, adapter_reg = _load_registries()

    if not task_reg.embeddings_ready():
        console.print("[red]Error:[/red] Task embeddings not initialised. Run `shiftgate init` first.")
        raise typer.Exit(1)

    embedder = _get_embedder()

    try:
        trace = routing.route(query, task_reg, adapter_reg, embedder, top_k=top_k)
    except Exception as exc:
        console.print(f"[red]Routing error:[/red] {escape(str(exc))}")
        raise typer.Exit(1)

    adapter = adapter_reg.get_adapter(trace.selected_adapter_id)
    task = task_reg.get_task(trace.matched_task_id)

    # Show a spinner while probing, as `shiftgate status` does.
    with console.status("[cyan]Probing backends…[/cyan]"):
        backend_router = BackendRouter()
        backend_name = backend_router.detect()

    show_routing_decision(
        trace,
        adapter=adapter,
        task_name=task.name if task else None,
        backend_name=backend_name,
    )

    if adapter is None:
        console.print(f"[red]Adapter '{trace.selected_adapter_id}' not found in registry.[/red]")
        raise typer.Exit(1)

    if backend_name is None:
        console.print(
            "[yellow]No inference backend detected.[/yellow]\n"
            " shiftgate routed your query to "
            f"[bold magenta]{trace.selected_adapter_id}[/bold magenta].\n"
            " To run inference, start a backend:\n"
            " [cyan]ollama serve[/cyan]\n"
            " [cyan]python -m vllm.entrypoints.openai.api_server --model <base_model>[/cyan]"
        )
        feedback_loop.record_trace(trace)
        raise typer.Exit(0)

    # Run inference
    console.print(f"[cyan]Running via [bold]{backend_name}[/bold]…[/cyan]")
    t0 = time.monotonic()
    try:
        response = backend_router.generate(query, adapter)
    except NoBackendError as exc:
        console.print(f"[red]{escape(str(exc))}[/red]")
        raise typer.Exit(1)
    except Exception as exc:
        console.print(f"[red]Inference error:[/red] {escape(str(exc))}")
        raise typer.Exit(1)

    elapsed_ms = (time.monotonic() - t0) * 1000
    trace.latency_ms = elapsed_ms

    console.print()
    console.rule("[dim]Response[/dim]")
    # Model output is arbitrary text (code often contains "[...]"). Print it
    # verbatim: with markup enabled Rich would restyle it or raise MarkupError.
    console.print(response, markup=False, highlight=False)
    console.rule()
    console.print(f"[dim]Latency: {elapsed_ms:.0f} ms[/dim]")

    feedback_loop.record_trace(trace)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
# ---------------------------------------------------------------------------
|
|
380
|
+
# shiftgate feedback
|
|
381
|
+
# ---------------------------------------------------------------------------
|
|
382
|
+
|
|
383
|
+
@feedback_app.command("accept")
def feedback_accept() -> None:
    """Rate the most recent routing trace as a good decision."""
    from shiftgate.feedback import loop as feedback_loop

    latest = feedback_loop.mark_last_accepted(True)
    if not latest:
        console.print("[yellow]No traces found.[/yellow]")
        return
    console.print(f"[green]✓[/green] Trace {latest.id[:8]}… marked [green]accepted[/green].")
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
@feedback_app.command("reject")
def feedback_reject() -> None:
    """Rate the most recent routing trace as a bad decision."""
    from shiftgate.feedback import loop as feedback_loop

    latest = feedback_loop.mark_last_accepted(False)
    if not latest:
        console.print("[yellow]No traces found.[/yellow]")
        return
    console.print(f"[green]✓[/green] Trace {latest.id[:8]}… marked [red]rejected[/red].")
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
@feedback_app.command("stats")
def feedback_stats() -> None:
    """Display per-adapter acceptance scores and overall trace statistics."""
    from shiftgate.feedback import loop as feedback_loop
    from shiftgate.utils.display import show_feedback_stats

    # Arguments evaluate left to right: scores first, then trace stats.
    show_feedback_stats(
        feedback_loop.compute_adapter_scores(),
        feedback_loop.get_trace_stats(),
    )
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# ---------------------------------------------------------------------------
|
|
419
|
+
# shiftgate status
|
|
420
|
+
# ---------------------------------------------------------------------------
|
|
421
|
+
|
|
422
|
+
@app.command()
def status() -> None:
    """Summarise backend reachability, registry sizes and embedding readiness."""
    from shiftgate.runtime.backend import BackendRouter
    from shiftgate.utils.display import show_status

    task_reg, adapter_reg = _load_registries()

    with console.status("[cyan]Probing backends…[/cyan]"):
        detected = BackendRouter().detect()

    show_status(
        backend_name=detected,
        n_adapters=len(adapter_reg),
        n_tasks=len(task_reg),
        embeddings_ready=task_reg.embeddings_ready(),
    )
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# ---------------------------------------------------------------------------
|
|
443
|
+
# shiftgate demo
|
|
444
|
+
# ---------------------------------------------------------------------------
|
|
445
|
+
|
|
446
|
+
@app.command()
def demo() -> None:
    """Run an animated demo: fake routing traces and an adapter swap.

    No registry, embedder or backend is used. Four hard-coded routing
    decisions are rendered in sequence, with an adapter-swap animation between
    consecutive entries, followed by getting-started hints.
    """
    import re
    import uuid
    from datetime import datetime, timezone

    from shiftgate.registry.schemas import RoutingTrace
    from shiftgate.utils.display import (
        animate_swap,
        show_routing_decision,
        show_welcome_banner,
    )

    def _slug(label: str) -> str:
        """Turn a display name into a snake_case ID.

        A plain space-to-underscore replace keeps punctuation
        ("Debugging & Error Fixing" -> "debugging_&_error_fixing"); collapsing
        every non-alphanumeric run gives "debugging_error_fixing".
        """
        return re.sub(r"[^a-z0-9]+", "_", label.lower()).strip("_")

    show_welcome_banner()
    time.sleep(0.5)

    demo_traces = [
        {
            "query": "Write a Python function to parse JSON from a REST API",
            "task": "Python Code Generation",
            "adapter": "python-lora-llama3",
            "score": 0.91,
        },
        {
            "query": "SELECT all users who signed up in the last 30 days",
            "task": "SQL Query Writing",
            "adapter": "sql-lora-mistral",
            "score": 0.87,
        },
        {
            "query": "Summarise this research paper in 3 bullet points",
            "task": "Text Summarization",
            "adapter": "summarize-lora-llama3",
            "score": 0.83,
        },
        {
            "query": "Fix the KeyError on line 42 in this Python script",
            "task": "Debugging & Error Fixing",
            "adapter": "debug-lora-codellama",
            "score": 0.78,
        },
    ]

    for i, entry in enumerate(demo_traces):
        trace = RoutingTrace(
            id=uuid.uuid4().hex,
            query=entry["query"],
            matched_task_id=_slug(entry["task"]),
            similarity_score=entry["score"],
            selected_adapter_id=entry["adapter"],
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
        show_routing_decision(trace, task_name=entry["task"])
        time.sleep(0.8)

        # Animate the swap to the next entry's adapter; nothing after the last.
        if i < len(demo_traces) - 1:
            next_adapter = demo_traces[i + 1]["adapter"]
            animate_swap(entry["adapter"], next_adapter, duration=1.0)
            time.sleep(0.4)

    console.print()
    console.print("[bold green]Demo complete![/bold green] shiftgate routes tasks at inference time — zero training required.")
    console.print(
        "\n Get started:\n"
        ' [cyan]shiftgate init[/cyan]\n'
        ' [cyan]shiftgate adapter add <hf_repo>[/cyan]\n'
        ' [cyan]shiftgate route "your query here"[/cyan]\n'
    )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Feedback sub-package: trace storage and adapter scoring loop."""
|