shiftgate 0.1.9__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. {shiftgate-0.1.9 → shiftgate-0.2.1}/PKG-INFO +39 -2
  2. {shiftgate-0.1.9 → shiftgate-0.2.1}/README.md +38 -1
  3. {shiftgate-0.1.9 → shiftgate-0.2.1}/pyproject.toml +1 -1
  4. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/cli.py +137 -16
  5. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/adapter_registry.py +26 -0
  6. shiftgate-0.2.1/shiftgate/router/embedder.py +174 -0
  7. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/router/matcher.py +64 -28
  8. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/router/router.py +17 -1
  9. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/runtime/backend.py +223 -9
  10. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/serve/app.py +123 -23
  11. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/utils/display.py +33 -2
  12. shiftgate-0.2.1/tests/test_backend.py +466 -0
  13. shiftgate-0.2.1/tests/test_cli.py +146 -0
  14. {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_packaging.py +6 -0
  15. {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_router.py +88 -0
  16. {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_serve.py +99 -0
  17. shiftgate-0.1.9/shiftgate/router/embedder.py +0 -95
  18. shiftgate-0.1.9/tests/test_backend.py +0 -232
  19. {shiftgate-0.1.9 → shiftgate-0.2.1}/.gitignore +0 -0
  20. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/__init__.py +0 -0
  21. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/data/__init__.py +0 -0
  22. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/data/default_tasks.json +0 -0
  23. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/feedback/__init__.py +0 -0
  24. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/feedback/loop.py +0 -0
  25. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/__init__.py +0 -0
  26. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/schemas.py +0 -0
  27. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/task_registry.py +0 -0
  28. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/router/__init__.py +0 -0
  29. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/runtime/__init__.py +0 -0
  30. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/serve/__init__.py +0 -0
  31. {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/utils/__init__.py +0 -0
  32. {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/__init__.py +0 -0
  33. {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_feedback.py +0 -0
  34. {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_registry.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shiftgate
3
- Version: 0.1.9
3
+ Version: 0.2.1
4
4
  Summary: Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop.
5
5
  Project-URL: Homepage, https://github.com/shiftgate-ai/shiftgate
6
6
  Project-URL: Repository, https://github.com/shiftgate-ai/shiftgate
@@ -276,6 +276,10 @@ User query
276
276
  └────────────────────────────────┘
277
277
  ```
278
278
 
279
+ ### How routing works
280
+
281
+ When a backend is active, shiftgate filters candidate adapters to only those actually loaded on that backend. Switch from vLLM to Cerebras and shiftgate automatically picks Cerebras-compatible adapters — no re-registration needed. (When you run `shiftgate route` with no backend running, no filtering is applied, so you still see the full routing preview.)
282
+
279
283
  ---
280
284
 
281
285
  ## Bring Your Own Models
@@ -352,10 +356,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
352
356
  shiftgate run "write a python sorting function"
353
357
  ```
354
358
 
355
- shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and Cerebras is used only when no local backend is running.
359
+ shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
356
360
 
357
361
  > **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
358
362
 
363
+ ### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
364
+
365
+ [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
366
+
367
+ ```bash
368
+ # 1. Upload your LoRA to Cloudflare (one-time)
369
+ npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
370
+
371
+ # 2. Set credentials
372
+ export CLOUDFLARE_ACCOUNT_ID=...
373
+ export CLOUDFLARE_API_TOKEN=...
374
+
375
+ # 3. Register in shiftgate — note --base is the Cloudflare model name
376
+ shiftgate adapter add my-sql-lora \
377
+ --runtime my-sql-lora \
378
+ --base @cf/mistral/mistral-7b-instruct-v0.2-lora \
379
+ --tags sql
380
+
381
+ # 4. Run
382
+ shiftgate run "write a sql join query"
383
+ ```
384
+
385
+ You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
386
+
387
+ **Architectural difference (handled transparently):** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field, rather than as the `model` value the way vLLM and Cerebras do. shiftgate builds the request in the right shape for you, so your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
388
+
389
+ **Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
390
+
391
+ - Up to **100 LoRAs** per account.
392
+ - LoRA file must be **< 300 MB**.
393
+ - Must be trained with rank **r ≤ 8** (up to 32 on some models).
394
+ - **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
395
+
359
396
  ---
360
397
 
361
398
  ## How to contribute adapters
@@ -240,6 +240,10 @@ User query
240
240
  └────────────────────────────────┘
241
241
  ```
242
242
 
243
+ ### How routing works
244
+
245
+ When a backend is active, shiftgate filters candidate adapters to only those actually loaded on that backend. Switch from vLLM to Cerebras and shiftgate automatically picks Cerebras-compatible adapters — no re-registration needed. (When you run `shiftgate route` with no backend running, no filtering is applied, so you still see the full routing preview.)
246
+
243
247
  ---
244
248
 
245
249
  ## Bring Your Own Models
@@ -316,10 +320,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
316
320
  shiftgate run "write a python sorting function"
317
321
  ```
318
322
 
319
- shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and Cerebras is used only when no local backend is running.
323
+ shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
320
324
 
321
325
  > **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
322
326
 
327
+ ### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
328
+
329
+ [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
330
+
331
+ ```bash
332
+ # 1. Upload your LoRA to Cloudflare (one-time)
333
+ npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
334
+
335
+ # 2. Set credentials
336
+ export CLOUDFLARE_ACCOUNT_ID=...
337
+ export CLOUDFLARE_API_TOKEN=...
338
+
339
+ # 3. Register in shiftgate — note --base is the Cloudflare model name
340
+ shiftgate adapter add my-sql-lora \
341
+ --runtime my-sql-lora \
342
+ --base @cf/mistral/mistral-7b-instruct-v0.2-lora \
343
+ --tags sql
344
+
345
+ # 4. Run
346
+ shiftgate run "write a sql join query"
347
+ ```
348
+
349
+ You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
350
+
351
+ **Architectural difference (handled transparently):** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field, rather than as the `model` value the way vLLM and Cerebras do. shiftgate builds the request in the right shape for you, so your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
352
+
353
+ **Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
354
+
355
+ - Up to **100 LoRAs** per account.
356
+ - LoRA file must be **< 300 MB**.
357
+ - Must be trained with rank **r ≤ 8** (up to 32 on some models).
358
+ - **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
359
+
323
360
  ---
324
361
 
325
362
  ## How to contribute adapters
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "shiftgate"
7
- version = "0.1.9"
7
+ version = "0.2.1"
8
8
  description = "Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -52,12 +52,47 @@ def _main(
52
52
  ),
53
53
  ),
54
54
  ] = None,
55
+ cf_account_id: Annotated[
56
+ Optional[str],
57
+ typer.Option(
58
+ "--cf-account-id",
59
+ help="Cloudflare account ID. Sets CLOUDFLARE_ACCOUNT_ID for this run.",
60
+ ),
61
+ ] = None,
62
+ cf_api_token: Annotated[
63
+ Optional[str],
64
+ typer.Option(
65
+ "--cf-api-token",
66
+ help="Cloudflare API token. Sets CLOUDFLARE_API_TOKEN for this run.",
67
+ ),
68
+ ] = None,
69
+ verbose: Annotated[
70
+ bool,
71
+ typer.Option(
72
+ "--verbose",
73
+ "-v",
74
+ help="Enable DEBUG logging (shows routing internals like runtime filtering).",
75
+ ),
76
+ ] = False,
55
77
  ) -> None:
56
78
  """Global options applied before any command runs."""
57
- if cerebras_key:
58
- import os
79
+ import os
59
80
 
81
+ if cerebras_key:
60
82
  os.environ["CEREBRAS_API_KEY"] = cerebras_key
83
+ if cf_account_id:
84
+ os.environ["CLOUDFLARE_ACCOUNT_ID"] = cf_account_id
85
+ if cf_api_token:
86
+ os.environ["CLOUDFLARE_API_TOKEN"] = cf_api_token
87
+
88
+ if verbose:
89
+ import logging
90
+
91
+ logging.basicConfig(
92
+ level=logging.DEBUG,
93
+ format="%(levelname)s [%(name)s] %(message)s",
94
+ )
95
+ logging.getLogger("shiftgate").setLevel(logging.DEBUG)
61
96
 
62
97
 
63
98
  # ---------------------------------------------------------------------------
@@ -83,6 +118,35 @@ def _get_embedder():
83
118
  return Embedder()
84
119
 
85
120
 
121
+ def _active_runtimes(backend_router, adapter_reg=None) -> set[str] | None:
122
+ """Return the set of runtime names usable on the active backend, or None.
123
+
124
+ ``None`` means no backend is active → the router should not filter
125
+ (preview behaviour). An empty set means a backend is active but reports no
126
+ usable models.
127
+
128
+ For Cloudflare, base models are always available without any finetune
129
+ upload, so every registered ``@cf/`` adapter that has no finetune
130
+ ``runtime_name`` is considered usable (base-model inference).
131
+ """
132
+ active = backend_router.active_backend
133
+ if active is None:
134
+ return None
135
+
136
+ usable = set(active.list_loaded_adapters_cached())
137
+
138
+ from shiftgate.runtime.backend import CloudflareBackend
139
+
140
+ if isinstance(active, CloudflareBackend) and adapter_reg is not None:
141
+ for adapter in adapter_reg.list_adapters():
142
+ is_cf_base = (adapter.base_model or "").startswith("@cf/")
143
+ has_finetune = bool((adapter.runtime_name or "").strip())
144
+ if is_cf_base and not has_finetune:
145
+ usable.add(adapter.effective_backend_name())
146
+
147
+ return usable
148
+
149
+
86
150
  def _auto_link_adapter(adapter: AdapterEntry, task_reg) -> list[str]:
87
151
  """Add ``adapter.id`` to the ``preferred_adapters`` of matching task clusters.
88
152
 
@@ -147,9 +211,15 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
147
211
  """
148
212
  from shiftgate.runtime.backend import BackendRouter
149
213
 
214
+ # A Cloudflare base model (@cf/...) implies this is a Cloudflare adapter —
215
+ # prefer the Cloudflare backend for verification when it's reachable.
216
+ is_cloudflare = (adapter.base_model or "").startswith("@cf/")
217
+
150
218
  try:
151
219
  with console.status("[cyan]Verifying adapter against running backend…[/cyan]"):
152
220
  router = BackendRouter()
221
+ if is_cloudflare and router._cloudflare.is_available():
222
+ router.select("cloudflare")
153
223
  is_loaded, backend_name = router.verify_adapter(adapter)
154
224
  except Exception as exc: # pragma: no cover - defensive, should not happen
155
225
  logger_msg = f"verification error: {exc}"
@@ -166,9 +236,14 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
166
236
  console.print(f" [green]Backend: {backend_name} ✓ verified[/green]")
167
237
  else:
168
238
  adapter.verified = False
239
+ hint = (
240
+ "— upload it with `npx wrangler ai finetune create`"
241
+ if backend_name == "cloudflare"
242
+ else "— did you pass --lora-modules?"
243
+ )
169
244
  console.print(
170
245
  f" [yellow]Backend: {backend_name} ⚠ runtime '{runtime}' not loaded "
171
- "— did you pass --lora-modules?[/yellow]"
246
+ f"{hint}[/yellow]"
172
247
  )
173
248
 
174
249
  adapter_reg.save()
@@ -194,12 +269,23 @@ def init() -> None:
194
269
  task_reg = TaskRegistry.load()
195
270
 
196
271
  if task_reg.embeddings_ready():
197
- console.print("[dim]Embeddings already computed. Skipping (delete embeddings_cache.npy to force refresh).[/dim]")
272
+ console.print(
273
+ "[dim]Task centroids already computed — skipping re-embed "
274
+ "(delete embeddings_cache.npy to force refresh).[/dim]"
275
+ )
198
276
  else:
199
277
  console.print("[cyan]Computing task embeddings (first run — model download may take a moment)…[/cyan]")
200
278
  embedder = _get_embedder()
201
279
  task_reg.compute_embeddings(embedder)
202
- console.print("[green][/green] Embeddings computed.")
280
+ console.print("[green]OK[/green] Embeddings computed.")
281
+
282
+ # Centroids can exist while the ONNX runtime model is missing (e.g. after
283
+ # deleting fastembed_cache). Always warm-load so routing works.
284
+ console.print("[cyan]Loading embedder model (one-time ~30 MB download if not cached)…[/cyan]")
285
+ from shiftgate.router.embedder import warm_up
286
+
287
+ dim = warm_up()
288
+ console.print(f"[green]OK[/green] Embedder ready (dim={dim}).")
203
289
 
204
290
  task_reg.save()
205
291
  console.print(f"[green]✓[/green] Task registry saved to {shiftgate_dir}")
@@ -288,6 +374,7 @@ def adapter_add(
288
374
  """
289
375
  from shiftgate.registry.adapter_registry import (
290
376
  AdapterRegistry,
377
+ adapter_from_base_model,
291
378
  adapter_from_hf,
292
379
  adapter_from_local,
293
380
  adapter_from_runtime,
@@ -334,12 +421,22 @@ def adapter_add(
334
421
  **shared_kwargs,
335
422
  )
336
423
 
337
- # --- Ambiguous: no '/', no --local, no --runtime ---
424
+ # --- Mode D: Cloudflare base model (always available, no finetune) ---
425
+ elif (base or "").startswith("@cf/"):
426
+ base_model_kwargs = {k: v for k, v in shared_kwargs.items() if k != "base_model"}
427
+ adapter = adapter_from_base_model(
428
+ base_model=base,
429
+ adapter_id=adapter_id or identifier,
430
+ **base_model_kwargs,
431
+ )
432
+
433
+ # --- Ambiguous: no '/', no --local, no --runtime, no @cf/ base ---
338
434
  else:
339
435
  console.print(
340
436
  f"[red]Error:[/red] '{identifier}' doesn't look like a HuggingFace repo ID (missing '/').\n"
341
- " Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter, or\n"
342
- " use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter."
437
+ " Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter,\n"
438
+ " use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter, or\n"
439
+ " use [cyan]--base @cf/...[/cyan] for a Cloudflare Workers AI base model."
343
440
  )
344
441
  raise typer.Exit(1)
345
442
 
@@ -350,6 +447,11 @@ def adapter_add(
350
447
  # has it loaded. Purely informational — never fails the add command.
351
448
  if adapter.runtime_name:
352
449
  _verify_runtime_adapter(adapter, adapter_reg)
450
+ elif (adapter.base_model or "").startswith("@cf/"):
451
+ console.print(
452
+ " [green]Backend:[/green] cloudflare "
453
+ "[green]✓[/green] base model is always available (no upload needed)"
454
+ )
353
455
 
354
456
 
355
457
  @adapter_app.command("list")
@@ -464,6 +566,7 @@ def route(
464
566
  """
465
567
  from shiftgate.feedback import loop as feedback_loop
466
568
  from shiftgate.router import router as routing
569
+ from shiftgate.runtime.backend import BackendRouter
467
570
  from shiftgate.utils.display import show_explain_decision, show_routing_decision
468
571
 
469
572
  task_reg, adapter_reg = _load_registries()
@@ -474,8 +577,15 @@ def route(
474
577
 
475
578
  embedder = _get_embedder()
476
579
 
580
+ backend_router = BackendRouter()
581
+ backend_name = backend_router.detect()
582
+ available_runtimes = _active_runtimes(backend_router, adapter_reg)
583
+
477
584
  try:
478
- trace, match_result = routing.route(query, task_reg, adapter_reg, embedder, top_k=top_k)
585
+ trace, match_result = routing.route(
586
+ query, task_reg, adapter_reg, embedder,
587
+ top_k=top_k, available_runtimes=available_runtimes,
588
+ )
479
589
  except Exception as exc:
480
590
  console.print(f"[red]Routing error:[/red] {exc}")
481
591
  raise typer.Exit(1)
@@ -487,7 +597,9 @@ def route(
487
597
  trace,
488
598
  adapter=adapter,
489
599
  task_name=task.name if task else None,
490
- backend_name=None,
600
+ backend_name=backend_name,
601
+ loaded_runtimes=available_runtimes,
602
+ selection_method=match_result.selection_method,
491
603
  )
492
604
 
493
605
  if explain:
@@ -524,22 +636,29 @@ def run(
524
636
 
525
637
  embedder = _get_embedder()
526
638
 
639
+ backend_router = BackendRouter()
640
+ backend_name = backend_router.detect()
641
+ available_runtimes = _active_runtimes(backend_router, adapter_reg)
642
+
527
643
  try:
528
- trace, match_result = routing.route(query, task_reg, adapter_reg, embedder, top_k=top_k)
644
+ trace, match_result = routing.route(
645
+ query, task_reg, adapter_reg, embedder,
646
+ top_k=top_k, available_runtimes=available_runtimes,
647
+ )
529
648
  except Exception as exc:
530
649
  console.print(f"[red]Routing error:[/red] {exc}")
531
650
  raise typer.Exit(1)
532
651
 
533
652
  adapter = adapter_reg.get_adapter(trace.selected_adapter_id)
534
653
  task = task_reg.get_task(trace.matched_task_id)
535
- backend_router = BackendRouter()
536
- backend_name = backend_router.detect()
537
654
 
538
655
  show_routing_decision(
539
656
  trace,
540
657
  adapter=adapter,
541
658
  task_name=task.name if task else None,
542
659
  backend_name=backend_name,
660
+ loaded_runtimes=available_runtimes,
661
+ selection_method=match_result.selection_method,
543
662
  )
544
663
 
545
664
  if adapter is None:
@@ -686,9 +805,11 @@ def doctor() -> None:
686
805
  router = BackendRouter()
687
806
  backend_name = router.detect()
688
807
  backend_url = router.active_backend_url
689
- loaded_adapters: list[str] = []
808
+ loaded_adapters: set[str] = set()
690
809
  if backend_name is not None and router._active is not None:
691
- loaded_adapters = router._active.list_loaded_adapters()
810
+ # Cloudflare base models are always available, so use the same
811
+ # usable-runtime computation the router uses when filtering.
812
+ loaded_adapters = _active_runtimes(router, adapter_reg) or set()
692
813
 
693
814
  # --- 3. Per-adapter runtime availability ---
694
815
  adapter_rows = []
@@ -744,7 +865,7 @@ def serve(
744
865
  str,
745
866
  typer.Option(
746
867
  "--backend",
747
- help="Backend to forward to: auto | ollama | vllm | cerebras.",
868
+ help="Backend to forward to: auto | ollama | vllm | cerebras | cloudflare.",
748
869
  ),
749
870
  ] = "auto",
750
871
  ) -> None:
@@ -244,6 +244,32 @@ def adapter_from_runtime(
244
244
  )
245
245
 
246
246
 
247
def adapter_from_base_model(
    base_model: str,
    *,
    adapter_id: str,
    name: str | None = None,
    tags: list[str] | None = None,
    description: str | None = None,
    benchmark_score: float | None = None,
) -> AdapterEntry:
    """Create an AdapterEntry that targets a bare base model (no finetune).

    Intended for backends where base models need no upload step, such as
    Cloudflare Workers AI ``@cf/`` models. No ``runtime_name`` is passed to
    ``AdapterEntry``, so the entry carries only the base model.

    The entry id is the slugified ``adapter_id``. When ``name`` or
    ``description`` is omitted (or empty), a default is derived from the slug
    and the base model name respectively.
    """
    entry_id = _slugify(adapter_id)
    display_name = name if name else entry_id.replace("-", " ").title()
    summary = description if description else f"Base model: {base_model}"
    tag_list = [*tags] if tags else []

    return AdapterEntry(
        id=entry_id,
        name=display_name,
        base_model=base_model,
        task_tags=tag_list,
        description=summary,
        benchmark_score=benchmark_score,
    )
271
+
272
+
247
273
  # ---------------------------------------------------------------------------
248
274
  # Internal helpers
249
275
  # ---------------------------------------------------------------------------
@@ -0,0 +1,174 @@
1
+ """
2
+ Text embedder backed by fastembed.
3
+
4
+ Uses ``BAAI/bge-small-en-v1.5`` — a compact (33 M param) model that runs
5
+ efficiently on CPU. The model is downloaded once by fastembed and cached in
6
+ ``~/.shiftgate/fastembed_cache`` (avoids Windows ``%TEMP%`` corruption issues).
7
+
8
+ A module-level singleton (``_MODEL``) is created lazily on first use so that
9
+ importing this module is cheap. The model is NOT re-created between calls.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import time
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Stable cache location — fastembed defaults to %TEMP% on Windows, which is
24
+ # prone to partial downloads and "file sizes do not match" corruption.
25
+ FASTEMBED_CACHE_DIR = Path.home() / ".shiftgate" / "fastembed_cache"
26
+
27
+ # -------------------------------------------------------------------------
28
+ # Default model — small, CPU-friendly, strong quality/speed trade-off.
29
+ # -------------------------------------------------------------------------
30
+ DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
31
+
32
+ # HuggingFace downloads can flake; retry a few times before giving up.
33
+ _LOAD_RETRIES = 3
34
+ _LOAD_RETRY_DELAY_S = 2.0
35
+
36
+ # Module-level singleton; populated on first call to `_get_model()`.
37
+ _MODEL: Any | None = None
38
+
39
+
40
+ def _is_retryable_download_error(exc: BaseException) -> bool:
41
+ """Return True for transient HuggingFace / network failures."""
42
+ msg = str(exc).lower()
43
+ needles = (
44
+ "server disconnected",
45
+ "connection reset",
46
+ "connection aborted",
47
+ "timed out",
48
+ "timeout",
49
+ "temporary failure",
50
+ "503",
51
+ "502",
52
+ "429",
53
+ )
54
+ return any(n in msg for n in needles)
55
+
56
+
57
def _format_load_error(model_name: str, exc: BaseException) -> str:
    """Build a user-facing message explaining why the embedder failed to load.

    Two shapes of message are produced:

    * for errors classified as transient by ``_is_retryable_download_error``,
      a "retry later" hint;
    * otherwise, a hint to delete the (possibly corrupted) model cache.

    The cache-deletion command is chosen for the current platform: PowerShell
    syntax on Windows, ``rm -rf`` elsewhere.
    """
    # Local imports: the visible module header imports neither module.
    import os
    import tempfile

    cache_hint = str(FASTEMBED_CACHE_DIR)
    if _is_retryable_download_error(exc):
        return (
            f"Failed to download embedding model '{model_name}' from HuggingFace: {exc}\n"
            "This is usually a transient network/rate-limit issue. Retry:\n"
            # Plain command, consistent with the other branch; not every
            # installation runs shiftgate through ``uv``.
            " shiftgate init\n"
            "Optional: set HF_TOKEN for higher HuggingFace rate limits.\n"
            f"If downloads keep failing, delete '{cache_hint}' and retry."
        )

    if os.name == "nt":
        delete_cmd = f"Remove-Item -Recurse -Force '{cache_hint}'"
        stale_copy = "$env:TEMP\\fastembed_cache"
    else:
        delete_cmd = f"rm -rf '{cache_hint}'"
        # NOTE(review): assumes fastembed's default cache lives under the
        # system temp directory on POSIX as it does on Windows — confirm.
        stale_copy = os.path.join(tempfile.gettempdir(), "fastembed_cache")

    return (
        f"Failed to load embedding model '{model_name}': {exc}\n"
        "If you see NO_SUCHFILE or 'file sizes do not match', delete the cache "
        "and retry:\n"
        f" {delete_cmd}\n"
        f"Also clear any stale copy at {stale_copy}, then run "
        "`shiftgate init` again."
    )
75
+
76
+
77
def _get_model(model_name: str = DEFAULT_MODEL) -> Any:
    """Return the fastembed TextEmbedding singleton, creating it if needed.

    The model is loaded once per process and stored in the module-level
    ``_MODEL``. If a model is already loaded it is returned as-is and
    ``model_name`` is ignored; call ``reset_model()`` first to switch models.

    Loading is attempted up to ``_LOAD_RETRIES`` times (at least once). A
    failure is retried only when ``_is_retryable_download_error`` classifies it
    as transient, sleeping ``_LOAD_RETRY_DELAY_S * attempt`` seconds between
    attempts.

    Raises:
        ImportError: if ``fastembed`` is not installed.
        RuntimeError: if the model cannot be loaded; the message comes from
            ``_format_load_error`` and the original exception is chained.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    try:
        from fastembed import TextEmbedding  # type: ignore
    except ImportError as exc:
        raise ImportError(
            "fastembed is required for shiftgate routing. "
            "Install it with: pip install fastembed"
        ) from exc

    FASTEMBED_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    logger.info(
        "Loading embedding model '%s' (first use — one-time download may occur)…",
        model_name,
    )

    # Every iteration either succeeds, retries, or raises, so the loop can
    # only be left with a loaded model. max() guarantees at least one attempt.
    attempts = max(1, _LOAD_RETRIES)
    for attempt in range(1, attempts + 1):
        try:
            model = TextEmbedding(
                model_name=model_name,
                cache_dir=str(FASTEMBED_CACHE_DIR),
            )
        except Exception as exc:
            if attempt < attempts and _is_retryable_download_error(exc):
                logger.warning(
                    "Embedder load attempt %d/%d failed (%s); retrying…",
                    attempt,
                    attempts,
                    exc,
                )
                time.sleep(_LOAD_RETRY_DELAY_S * attempt)
                continue
            raise RuntimeError(_format_load_error(model_name, exc)) from exc
        _MODEL = model
        break

    logger.info("Embedding model loaded.")
    return _MODEL
123
+
124
+
125
def warm_up(model_name: str = DEFAULT_MODEL) -> int:
    """Force the embedder to load by embedding a throwaway string.

    Returns the length of the resulting vector, i.e. the embedding dimension.
    """
    probe = Embedder(model_name).embed("warmup")
    dim = probe.shape[0]
    return int(dim)
129
+
130
+
131
def reset_model() -> None:
    """Drop the cached model so the next embed call loads it again.

    Clears the module-level ``_MODEL`` singleton. Handy in tests, or when a
    different model name should take effect at runtime.
    """
    global _MODEL
    _MODEL = None
138
+
139
+
140
class Embedder:
    """Small synchronous facade over the shared fastembed model.

    An instance stores only a model name. The model object itself is fetched
    through the module-level ``_get_model`` singleton, so all ``Embedder``
    instances in a process use the same loaded model.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
        self._model_name = model_name

    @property
    def _model(self) -> Any:
        # Looked up on every access; constructing an Embedder loads nothing.
        return _get_model(self._model_name)

    def embed(self, text: str) -> np.ndarray:
        """Embed one string and return a 1-D float32 array of shape ``(dim,)``.

        No L2 normalisation is applied here; callers normalise where they
        need to (e.g. when computing task centroids).
        """
        # The model's ``embed`` yields one vector per input; materialise them.
        vectors = [*self._model.embed([text])]
        first = vectors[0]
        return np.array(first, dtype=np.float32)

    def embed_batch(self, texts: list[str]) -> np.ndarray:
        """Embed several strings and return a float32 array of shape ``(n, dim)``.

        ``n`` equals ``len(texts)``. Raises ``ValueError`` for an empty list.
        """
        if texts:
            vectors = [*self._model.embed(texts)]
            return np.array(vectors, dtype=np.float32)
        raise ValueError("embed_batch received an empty list.")