shiftgate 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
shiftgate/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ shiftgate — Intelligent LoRA adapter routing for local LLM inference.
3
+
4
+ Automatically selects the right adapter for each task using semantic
5
+ similarity, inspired by the LORAUTER paper (EPFL, 2026).
6
+ """
7
+
8
# Package version string (the wheel this file ships in is also labelled 0.1.0).
__version__ = "0.1.0"
# Attribution string for the package.
__author__ = "shiftgate contributors"
shiftgate/cli.py ADDED
@@ -0,0 +1,513 @@
1
+ """
2
+ shiftgate CLI — Typer-based command-line interface.
3
+
4
+ All user-facing interactions happen through this module. Commands are
5
+ grouped into three sub-apps (adapter, task, feedback) plus top-level
6
+ commands (init, route, run, status, demo).
7
+
8
+ Entry point (registered in pyproject.toml):
9
+ shiftgate = "shiftgate.cli:app"
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Annotated, Optional
17
+
18
+ import typer
19
+ from rich.console import Console
20
+ from rich.prompt import Confirm, Prompt
21
+
22
+ from shiftgate.registry.schemas import AdapterEntry, TaskCluster
23
+
24
# Shared Rich console; every command in this module prints through it.
console = Console()
# Root Typer application. The module docstring names it as the console-script
# entry point ("shiftgate.cli:app").
app = typer.Typer(
    name="shiftgate",
    help="Intelligent LoRA adapter routing for local LLM inference.",
    # Invoking the program with no arguments prints the help text.
    no_args_is_help=True,
    # Keep local variable values out of Typer's pretty tracebacks.
    pretty_exceptions_show_locals=False,
)

# Sub-apps: one Typer instance per command group, each of which also prints
# its help when invoked without a sub-command.
adapter_app = typer.Typer(help="Manage the adapter registry.", no_args_is_help=True)
task_app = typer.Typer(help="Manage task clusters.", no_args_is_help=True)
feedback_app = typer.Typer(help="Record and review routing feedback.", no_args_is_help=True)

# Mount the groups so they are reachable as `shiftgate adapter ...`,
# `shiftgate task ...` and `shiftgate feedback ...`.
app.add_typer(adapter_app, name="adapter")
app.add_typer(task_app, name="task")
app.add_typer(feedback_app, name="feedback")
40
+
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Helpers
44
+ # ---------------------------------------------------------------------------
45
+
46
def _load_registries():
    """Load the task and adapter registries for a command.

    Returns:
        A ``(task_registry, adapter_registry)`` tuple, in that order.

    Raises:
        typer.Exit: With exit code 1, after printing the error, when either
            ``load()`` call raises ``FileNotFoundError``.
    """
    # Imported here rather than at module import time, like the other
    # shiftgate imports used by the commands in this module.
    from shiftgate.registry.adapter_registry import AdapterRegistry
    from shiftgate.registry.task_registry import TaskRegistry

    try:
        # Tuple elements are evaluated left to right: tasks first, then adapters.
        registries = (TaskRegistry.load(), AdapterRegistry.load())
    except FileNotFoundError as missing:
        console.print(f"[red]Error:[/red] {missing}")
        raise typer.Exit(1)
    return registries
58
+
59
+
60
def _get_embedder():
    """Build and return a fresh ``Embedder`` with default arguments.

    The ``Embedder`` class is imported inside the function, so the import
    only happens when a command actually asks for an embedder.
    """
    from shiftgate.router.embedder import Embedder

    instance = Embedder()
    return instance
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # shiftgate init
67
+ # ---------------------------------------------------------------------------
68
+
69
@app.command()
def init() -> None:
    """Set up ~/.shiftgate/, compute task embeddings, and show a welcome message."""
    # The docstring above is what Typer displays as this command's help text.
    from shiftgate.registry.adapter_registry import AdapterRegistry
    from shiftgate.registry.task_registry import TaskRegistry
    from shiftgate.utils.display import show_adapter_table, show_task_table, show_welcome_banner

    show_welcome_banner()

    # Make sure the per-user config directory exists before anything is saved.
    config_dir = Path.home() / ".shiftgate"
    config_dir.mkdir(parents=True, exist_ok=True)
    console.print(f"[dim]Config directory:[/dim] {config_dir}")
    console.print()

    # Load both registries; they are written back with save() further down.
    console.print("[cyan]Loading task registry…[/cyan]")
    tasks = TaskRegistry.load()
    adapters = AdapterRegistry.load()

    # Only compute embeddings when the task registry reports they are missing.
    if not tasks.embeddings_ready():
        console.print("[cyan]Computing task embeddings (first run — model download may take a moment)…[/cyan]")
        tasks.compute_embeddings(_get_embedder())
        console.print("[green]✓[/green] Embeddings computed.")
    else:
        console.print("[dim]Embeddings already computed. Skipping (delete embeddings_cache.npy to force refresh).[/dim]")

    # Persist tasks first, then adapters.
    for registry in (tasks, adapters):
        registry.save()
    console.print(f"[green]✓[/green] Registry saved to {config_dir}")
    console.print()

    # Summary tables, each followed by a blank line.
    show_task_table(tasks.get_all_tasks())
    console.print()
    show_adapter_table(adapters.list_adapters())
    console.print()

    next_steps = (
        "[bold green]shiftgate is ready![/bold green]\n\n"
        " Try it:\n"
        ' [cyan]shiftgate route "write a python function"[/cyan]\n\n'
        " Add a LoRA adapter:\n"
        " [cyan]shiftgate adapter add monology/pmc-llama-13b-lora[/cyan]\n"
    )
    console.print(next_steps)
115
+
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # shiftgate adapter
119
+ # ---------------------------------------------------------------------------
120
+
121
@adapter_app.command("add")
def adapter_add(
    hf_repo_or_path: Annotated[str, typer.Argument(help="HuggingFace repo ID or local path.")],
    tags: Annotated[
        Optional[list[str]],
        # A list-typed Typer option takes ONE value per occurrence, so the
        # flag has to be repeated; "--tags code python" would make "python"
        # an unexpected extra positional argument.
        typer.Option(
            "--tags",
            "-t",
            help="Task tag; repeat the flag for several, e.g. --tags code --tags python",
        ),
    ] = None,
    base: Annotated[
        Optional[str],
        typer.Option("--base", "-b", help="Base model name, e.g. 'meta-llama/Meta-Llama-3-8B'"),
    ] = None,
    name: Annotated[Optional[str], typer.Option(help="Override display name.")] = None,
    description: Annotated[Optional[str], typer.Option(help="Short description.")] = None,
) -> None:
    """Register a new LoRA adapter from a HuggingFace repo or local path."""
    # The docstring above is what Typer displays as this command's help text.
    _, adapter_reg = _load_registries()

    # Forward only the options the user actually supplied, so that
    # add_adapter() falls back to its own defaults for the rest.
    supplied = {"tags": tags, "base_model": base, "description": description}
    kwargs: dict = {key: value for key, value in supplied.items() if value}

    with console.status("[cyan]Fetching adapter metadata…[/cyan]"):
        adapter = adapter_reg.add_adapter(hf_repo_or_path, **kwargs)

    if name:
        # The display-name override is applied after registration, then the
        # entry is written back into the registry's mapping under its id.
        adapter.name = name
        adapter_reg._adapters[adapter.id] = adapter  # refresh

    adapter_reg.save()
    console.print(f"[green]✓[/green] Adapter '[bold magenta]{adapter.id}[/bold magenta]' registered.")
    console.print(f" Name: {adapter.name}")
    console.print(f" Base: {adapter.base_model}")
    if adapter.task_tags:
        console.print(f" Tags: {', '.join(adapter.task_tags)}")
159
+
160
+
161
@adapter_app.command("list")
def adapter_list() -> None:
    """Show all registered adapters in a Rich table."""
    from shiftgate.utils.display import show_adapter_table

    # _load_registries() returns (tasks, adapters); only the second is used here.
    registries = _load_registries()
    registered = registries[1].list_adapters()
    show_adapter_table(registered)
168
+
169
+
170
@adapter_app.command("remove")
def adapter_remove(
    adapter_id: Annotated[str, typer.Argument(help="Adapter ID to remove.")],
) -> None:
    """Remove an adapter from the registry."""
    _, registry = _load_registries()

    removed = registry.remove_adapter(adapter_id)
    if not removed:
        # Nothing was removed: report it and exit non-zero without saving.
        console.print(f"[red]Error:[/red] Adapter '{adapter_id}' not found.")
        raise typer.Exit(1)

    # Only a successful removal is written back.
    registry.save()
    console.print(f"[green]✓[/green] Adapter '{adapter_id}' removed.")
182
+
183
+
184
+ # ---------------------------------------------------------------------------
185
+ # shiftgate task
186
+ # ---------------------------------------------------------------------------
187
+
188
@task_app.command("list")
def task_list() -> None:
    """Show all task clusters in a Rich table."""
    from shiftgate.utils.display import show_task_table

    # _load_registries() returns (tasks, adapters); only the first is used here.
    registries = _load_registries()
    clusters = registries[0].get_all_tasks()
    show_task_table(clusters)
195
+
196
+
197
@task_app.command("add")
def task_add() -> None:
    """Interactively add a new task cluster to the registry."""
    # The docstring above is what Typer displays as this command's help text.
    task_reg, _ = _load_registries()

    console.print("[bold cyan]Add a new task cluster[/bold cyan]")
    console.print("[dim]Press Ctrl-C to cancel.[/dim]\n")

    task_id = Prompt.ask("Task ID (slug, e.g. code_rust)").strip()
    if not task_id:
        # An empty string is not a usable slug; stop before asking for the rest.
        console.print("[red]Error:[/red] Task ID must not be empty.")
        raise typer.Exit(1)
    task_name = Prompt.ask("Display name (e.g. Rust Code Generation)")
    task_desc = Prompt.ask("Short description")

    console.print("\nEnter validation examples (one per line). Empty line to finish:")
    examples: list[str] = []
    # An empty answer (plain Enter) ends the collection loop.
    while ex := Prompt.ask(f" Example {len(examples) + 1} (or Enter to finish)", default=""):
        examples.append(ex)

    if len(examples) < 3:
        console.print("[yellow]Warning:[/yellow] Fewer than 3 examples may produce a poor centroid.")

    preferred_raw = Prompt.ask("Preferred adapter IDs (comma-separated, or leave blank)")
    preferred = [p.strip() for p in preferred_raw.split(",") if p.strip()]

    task = TaskCluster(
        id=task_id,
        name=task_name,
        description=task_desc,
        validation_examples=examples,
        preferred_adapters=preferred,
    )

    if not Confirm.ask(f"\nSave task '[bold]{task_id}[/bold]'?"):
        console.print("[dim]Cancelled.[/dim]")
        return

    if examples:
        # Compute the centroid for the new task only: mean of the example
        # embeddings, scaled to unit length when the norm is non-zero.
        # Best-effort: on any failure the task is still saved, just without
        # a centroid, and the reason is shown to the user.
        try:
            embedder = _get_embedder()
            import numpy as np

            vecs = embedder.embed_batch(task.validation_examples)
            centroid = vecs.mean(axis=0)
            norm = np.linalg.norm(centroid)
            if norm > 0:
                centroid = centroid / norm
            task.embedding_centroid = centroid.tolist()
            console.print("[green]✓[/green] Centroid computed for new task.")
        except Exception as exc:
            console.print(f"[yellow]Warning:[/yellow] Could not compute centroid: {exc}")
    else:
        # With no examples there is nothing to average. Depending on what
        # embed_batch() returns for an empty list, the mean could raise or
        # turn into NaNs, so the embedder is not loaded at all.
        console.print("[yellow]Warning:[/yellow] No validation examples given; saving without a centroid.")

    task_reg.add_task(task)
    task_reg.save()
    console.print(f"[green]✓[/green] Task '[bold]{task_id}[/bold]' saved.")
251
+
252
+
253
+ # ---------------------------------------------------------------------------
254
+ # shiftgate route — routing only, no inference
255
+ # ---------------------------------------------------------------------------
256
+
257
@app.command()
def route(
    query: Annotated[str, typer.Argument(help="Query to route.")],
    top_k: Annotated[int, typer.Option("--top-k", "-k", help="Number of candidate tasks.")] = 3,
    record: Annotated[bool, typer.Option(help="Save trace to ~/.shiftgate/traces.jsonl.")] = True,
) -> None:
    """Route a query to the best adapter (no inference — just the routing decision)."""
    # The docstring above is what Typer displays as this command's help text.
    # (The previously unused function-scope import of TaskRegistry was dropped;
    # _load_registries() performs that import itself.)
    from shiftgate.feedback import loop as feedback_loop
    from shiftgate.router import router as routing
    from shiftgate.utils.display import show_routing_decision

    task_reg, adapter_reg = _load_registries()

    # Routing needs the task embeddings; point the user at `init` if absent.
    if not task_reg.embeddings_ready():
        console.print("[red]Error:[/red] Task embeddings not initialised. Run `shiftgate init` first.")
        raise typer.Exit(1)

    embedder = _get_embedder()

    # Any routing failure is reported as a one-line error with exit code 1.
    try:
        trace = routing.route(query, task_reg, adapter_reg, embedder, top_k=top_k)
    except Exception as exc:
        console.print(f"[red]Routing error:[/red] {exc}")
        raise typer.Exit(1)

    # Resolve the ids recorded on the trace back to registry entries for display.
    adapter = adapter_reg.get_adapter(trace.selected_adapter_id)
    task = task_reg.get_task(trace.matched_task_id)

    show_routing_decision(
        trace,
        adapter=adapter,
        task_name=task.name if task else None,
        backend_name=None,  # routing only: no backend is probed by this command
    )

    if record:
        feedback_loop.record_trace(trace)
        console.print(f"[dim]Trace {trace.id[:8]}… recorded. Run `shiftgate feedback accept/reject` to rate it.[/dim]")
296
+
297
+
298
+ # ---------------------------------------------------------------------------
299
+ # shiftgate run — route + run inference
300
+ # ---------------------------------------------------------------------------
301
+
302
@app.command()
def run(
    query: Annotated[str, typer.Argument(help="Query to route and run.")],
    top_k: Annotated[int, typer.Option("--top-k", "-k")] = 3,
) -> None:
    """Route a query and run it through the detected Ollama or vLLM backend."""
    # The docstring above is what Typer displays as this command's help text.
    from shiftgate.feedback import loop as feedback_loop
    from shiftgate.router import router as routing
    from shiftgate.runtime.backend import BackendRouter, NoBackendError
    from shiftgate.utils.display import show_routing_decision

    tasks, adapters = _load_registries()

    # Routing needs the task embeddings; point the user at `init` if absent.
    if not tasks.embeddings_ready():
        console.print("[red]Error:[/red] Task embeddings not initialised. Run `shiftgate init` first.")
        raise typer.Exit(1)

    encoder = _get_embedder()

    try:
        trace = routing.route(query, tasks, adapters, encoder, top_k=top_k)
    except Exception as routing_error:
        console.print(f"[red]Routing error:[/red] {routing_error}")
        raise typer.Exit(1)

    # Resolve the ids on the trace, then probe for a backend, before showing
    # the decision.
    chosen_adapter = adapters.get_adapter(trace.selected_adapter_id)
    matched_task = tasks.get_task(trace.matched_task_id)
    backends = BackendRouter()
    detected = backends.detect()

    task_label = matched_task.name if matched_task else None
    show_routing_decision(
        trace,
        adapter=chosen_adapter,
        task_name=task_label,
        backend_name=detected,
    )

    # The decision is shown first; a missing adapter entry is then fatal.
    if chosen_adapter is None:
        console.print(f"[red]Adapter '{trace.selected_adapter_id}' not found in registry.[/red]")
        raise typer.Exit(1)

    # Without a backend the routing result is still recorded, and the command
    # exits successfully after explaining how to start one.
    if detected is None:
        hint = (
            "[yellow]No inference backend detected.[/yellow]\n"
            " shiftgate routed your query to "
            f"[bold magenta]{trace.selected_adapter_id}[/bold magenta].\n"
            " To run inference, start a backend:\n"
            " [cyan]ollama serve[/cyan]\n"
            " [cyan]python -m vllm.entrypoints.openai.api_server --model <base_model>[/cyan]"
        )
        console.print(hint)
        feedback_loop.record_trace(trace)
        raise typer.Exit(0)

    # Inference: only the generate() call is timed.
    console.print(f"[cyan]Running via [bold]{detected}[/bold]…[/cyan]")
    started = time.monotonic()
    try:
        response = backends.generate(query, chosen_adapter)
    except NoBackendError as no_backend:
        console.print(f"[red]{no_backend}[/red]")
        raise typer.Exit(1)
    except Exception as inference_error:
        console.print(f"[red]Inference error:[/red] {inference_error}")
        raise typer.Exit(1)

    elapsed_ms = (time.monotonic() - started) * 1000
    trace.latency_ms = elapsed_ms

    console.print()
    console.rule("[dim]Response[/dim]")
    console.print(response)
    console.rule()
    console.print(f"[dim]Latency: {elapsed_ms:.0f} ms[/dim]")

    # The trace is recorded with its latency only after a successful generation.
    feedback_loop.record_trace(trace)
377
+
378
+
379
+ # ---------------------------------------------------------------------------
380
+ # shiftgate feedback
381
+ # ---------------------------------------------------------------------------
382
+
383
@feedback_app.command("accept")
def feedback_accept() -> None:
    """Mark the last routing trace as accepted (good routing decision)."""
    from shiftgate.feedback import loop as feedback_loop

    marked = feedback_loop.mark_last_accepted(True)
    if not marked:
        # A falsy result is reported to the user as "no traces".
        console.print("[yellow]No traces found.[/yellow]")
        return
    console.print(f"[green]✓[/green] Trace {marked.id[:8]}… marked [green]accepted[/green].")
393
+
394
+
395
@feedback_app.command("reject")
def feedback_reject() -> None:
    """Mark the last routing trace as rejected (bad routing decision)."""
    from shiftgate.feedback import loop as feedback_loop

    marked = feedback_loop.mark_last_accepted(False)
    if not marked:
        # A falsy result is reported to the user as "no traces".
        console.print("[yellow]No traces found.[/yellow]")
        return
    console.print(f"[green]✓[/green] Trace {marked.id[:8]}… marked [red]rejected[/red].")
405
+
406
+
407
@feedback_app.command("stats")
def feedback_stats() -> None:
    """Show adapter acceptance rates across all rated traces."""
    from shiftgate.feedback import loop as feedback_loop
    from shiftgate.utils.display import show_feedback_stats

    # Arguments are evaluated left to right: adapter scores first, then the
    # trace statistics.
    show_feedback_stats(
        feedback_loop.compute_adapter_scores(),
        feedback_loop.get_trace_stats(),
    )
416
+
417
+
418
+ # ---------------------------------------------------------------------------
419
+ # shiftgate status
420
+ # ---------------------------------------------------------------------------
421
+
422
@app.command()
def status() -> None:
    """Show backend connectivity, registry sizes, and embedding status."""
    from shiftgate.runtime.backend import BackendRouter
    from shiftgate.utils.display import show_status

    tasks, adapters = _load_registries()

    # Show a status spinner while the backend router is built and probed.
    with console.status("[cyan]Probing backends…[/cyan]"):
        prober = BackendRouter()
        detected = prober.detect()

    adapter_count = len(adapters)
    task_count = len(tasks)
    have_embeddings = tasks.embeddings_ready()
    show_status(
        backend_name=detected,
        n_adapters=adapter_count,
        n_tasks=task_count,
        embeddings_ready=have_embeddings,
    )
440
+
441
+
442
+ # ---------------------------------------------------------------------------
443
+ # shiftgate demo
444
+ # ---------------------------------------------------------------------------
445
+
446
@app.command()
def demo() -> None:
    """Run an animated demo: fake routing traces and an adapter swap."""
    # The docstring above is what Typer displays as this command's help text.
    import uuid
    from datetime import datetime, timezone

    from shiftgate.registry.schemas import RoutingTrace
    from shiftgate.utils.display import (
        animate_swap,
        show_routing_decision,
        show_welcome_banner,
    )

    show_welcome_banner()
    time.sleep(0.5)

    # Canned samples: (query, task display name, adapter id, similarity score).
    samples = (
        ("Write a Python function to parse JSON from a REST API", "Python Code Generation", "python-lora-llama3", 0.91),
        ("SELECT all users who signed up in the last 30 days", "SQL Query Writing", "sql-lora-mistral", 0.87),
        ("Summarise this research paper in 3 bullet points", "Text Summarization", "summarize-lora-llama3", 0.83),
        ("Fix the KeyError on line 42 in this Python script", "Debugging & Error Fixing", "debug-lora-codellama", 0.78),
    )
    final_index = len(samples) - 1

    for position, (query_text, task_label, adapter_name, score) in enumerate(samples):
        # Build a throwaway trace; the task id is the display name lower-cased
        # with spaces replaced by underscores.
        trace = RoutingTrace(
            id=uuid.uuid4().hex,
            query=query_text,
            matched_task_id=task_label.lower().replace(" ", "_"),
            similarity_score=score,
            selected_adapter_id=adapter_name,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
        show_routing_decision(trace, task_name=task_label)
        time.sleep(0.8)

        if position == final_index:
            # No swap animation after the last sample.
            continue

        upcoming_adapter = samples[position + 1][2]
        animate_swap(adapter_name, upcoming_adapter, duration=1.0)
        time.sleep(0.4)

    console.print()
    console.print("[bold green]Demo complete![/bold green] shiftgate routes tasks at inference time — zero training required.")
    getting_started = (
        "\n Get started:\n"
        ' [cyan]shiftgate init[/cyan]\n'
        ' [cyan]shiftgate adapter add <hf_repo>[/cyan]\n'
        ' [cyan]shiftgate route "your query here"[/cyan]\n'
    )
    console.print(getting_started)
@@ -0,0 +1 @@
1
+ """Feedback sub-package: trace storage and adapter scoring loop."""