shiftgate 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {shiftgate-0.2.0 → shiftgate-0.2.1}/PKG-INFO +35 -2
  2. {shiftgate-0.2.0 → shiftgate-0.2.1}/README.md +34 -1
  3. {shiftgate-0.2.0 → shiftgate-0.2.1}/pyproject.toml +1 -1
  4. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/cli.py +108 -17
  5. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/adapter_registry.py +26 -0
  6. shiftgate-0.2.1/shiftgate/router/embedder.py +174 -0
  7. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/router/router.py +9 -0
  8. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/runtime/backend.py +223 -9
  9. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/serve/app.py +102 -23
  10. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/utils/display.py +5 -0
  11. shiftgate-0.2.1/tests/test_backend.py +466 -0
  12. shiftgate-0.2.1/tests/test_cli.py +146 -0
  13. {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_packaging.py +6 -0
  14. {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_serve.py +95 -0
  15. shiftgate-0.2.0/shiftgate/router/embedder.py +0 -95
  16. shiftgate-0.2.0/tests/test_backend.py +0 -232
  17. {shiftgate-0.2.0 → shiftgate-0.2.1}/.gitignore +0 -0
  18. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/__init__.py +0 -0
  19. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/data/__init__.py +0 -0
  20. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/data/default_tasks.json +0 -0
  21. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/feedback/__init__.py +0 -0
  22. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/feedback/loop.py +0 -0
  23. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/__init__.py +0 -0
  24. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/schemas.py +0 -0
  25. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/task_registry.py +0 -0
  26. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/router/__init__.py +0 -0
  27. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/router/matcher.py +0 -0
  28. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/runtime/__init__.py +0 -0
  29. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/serve/__init__.py +0 -0
  30. {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/utils/__init__.py +0 -0
  31. {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/__init__.py +0 -0
  32. {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_feedback.py +0 -0
  33. {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_registry.py +0 -0
  34. {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_router.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shiftgate
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop.
5
5
  Project-URL: Homepage, https://github.com/shiftgate-ai/shiftgate
6
6
  Project-URL: Repository, https://github.com/shiftgate-ai/shiftgate
@@ -356,10 +356,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
356
356
  shiftgate run "write a python sorting function"
357
357
  ```
358
358
 
359
- shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and Cerebras is used only when no local backend is running.
359
+ shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
360
360
 
361
361
  > **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
362
362
 
363
+ ### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
364
+
365
+ [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
366
+
367
+ ```bash
368
+ # 1. Upload your LoRA to Cloudflare (one-time)
369
+ npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
370
+
371
+ # 2. Set credentials
372
+ export CLOUDFLARE_ACCOUNT_ID=...
373
+ export CLOUDFLARE_API_TOKEN=...
374
+
375
+ # 3. Register in shiftgate — note --base is the Cloudflare model name
376
+ shiftgate adapter add my-sql-lora \
377
+ --runtime my-sql-lora \
378
+ --base @cf/mistral/mistral-7b-instruct-v0.2-lora \
379
+ --tags sql
380
+
381
+ # 4. Run
382
+ shiftgate run "write a sql join query"
383
+ ```
384
+
385
+ You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
386
+
387
+ **Architectural difference (handled transparently):** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field, rather than as the `model` value the way vLLM and Cerebras do. Your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
388
+
389
+ **Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
390
+
391
+ - Up to **100 LoRAs** per account.
392
+ - LoRA file must be **< 300 MB**.
393
+ - Must be trained with rank **r ≤ 8** (up to 32 on some models).
394
+ - **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
395
+
363
396
  ---
364
397
 
365
398
  ## How to contribute adapters
@@ -320,10 +320,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
320
320
  shiftgate run "write a python sorting function"
321
321
  ```
322
322
 
323
- shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and Cerebras is used only when no local backend is running.
323
+ shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
324
324
 
325
325
  > **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
326
326
 
327
+ ### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
328
+
329
+ [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
330
+
331
+ ```bash
332
+ # 1. Upload your LoRA to Cloudflare (one-time)
333
+ npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
334
+
335
+ # 2. Set credentials
336
+ export CLOUDFLARE_ACCOUNT_ID=...
337
+ export CLOUDFLARE_API_TOKEN=...
338
+
339
+ # 3. Register in shiftgate — note --base is the Cloudflare model name
340
+ shiftgate adapter add my-sql-lora \
341
+ --runtime my-sql-lora \
342
+ --base @cf/mistral/mistral-7b-instruct-v0.2-lora \
343
+ --tags sql
344
+
345
+ # 4. Run
346
+ shiftgate run "write a sql join query"
347
+ ```
348
+
349
+ You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
350
+
351
+ **Architectural difference (handled transparently):** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field, rather than as the `model` value the way vLLM and Cerebras do. Your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
352
+
353
+ **Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
354
+
355
+ - Up to **100 LoRAs** per account.
356
+ - LoRA file must be **< 300 MB**.
357
+ - Must be trained with rank **r ≤ 8** (up to 32 on some models).
358
+ - **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
359
+
327
360
  ---
328
361
 
329
362
  ## How to contribute adapters
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "shiftgate"
7
- version = "0.2.0"
7
+ version = "0.2.1"
8
8
  description = "Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -52,12 +52,47 @@ def _main(
52
52
  ),
53
53
  ),
54
54
  ] = None,
55
+ cf_account_id: Annotated[
56
+ Optional[str],
57
+ typer.Option(
58
+ "--cf-account-id",
59
+ help="Cloudflare account ID. Sets CLOUDFLARE_ACCOUNT_ID for this run.",
60
+ ),
61
+ ] = None,
62
+ cf_api_token: Annotated[
63
+ Optional[str],
64
+ typer.Option(
65
+ "--cf-api-token",
66
+ help="Cloudflare API token. Sets CLOUDFLARE_API_TOKEN for this run.",
67
+ ),
68
+ ] = None,
69
+ verbose: Annotated[
70
+ bool,
71
+ typer.Option(
72
+ "--verbose",
73
+ "-v",
74
+ help="Enable DEBUG logging (shows routing internals like runtime filtering).",
75
+ ),
76
+ ] = False,
55
77
  ) -> None:
56
78
  """Global options applied before any command runs."""
57
- if cerebras_key:
58
- import os
79
+ import os
59
80
 
81
+ if cerebras_key:
60
82
  os.environ["CEREBRAS_API_KEY"] = cerebras_key
83
+ if cf_account_id:
84
+ os.environ["CLOUDFLARE_ACCOUNT_ID"] = cf_account_id
85
+ if cf_api_token:
86
+ os.environ["CLOUDFLARE_API_TOKEN"] = cf_api_token
87
+
88
+ if verbose:
89
+ import logging
90
+
91
+ logging.basicConfig(
92
+ level=logging.DEBUG,
93
+ format="%(levelname)s [%(name)s] %(message)s",
94
+ )
95
+ logging.getLogger("shiftgate").setLevel(logging.DEBUG)
61
96
 
62
97
 
63
98
  # ---------------------------------------------------------------------------
@@ -83,17 +118,33 @@ def _get_embedder():
83
118
  return Embedder()
84
119
 
85
120
 
86
- def _active_runtimes(backend_router) -> set[str] | None:
87
- """Return the set of runtime names loaded on the active backend, or None.
121
+ def _active_runtimes(backend_router, adapter_reg=None) -> set[str] | None:
122
+ """Return the set of runtime names usable on the active backend, or None.
88
123
 
89
124
  ``None`` means no backend is active → the router should not filter
90
125
  (preview behaviour). An empty set means a backend is active but reports no
91
- loaded models.
126
+ usable models.
127
+
128
+ For Cloudflare, base models are always available without any finetune
129
+ upload, so every registered ``@cf/`` adapter that has no finetune
130
+ ``runtime_name`` is considered usable (base-model inference).
92
131
  """
93
132
  active = backend_router.active_backend
94
133
  if active is None:
95
134
  return None
96
- return set(active.list_loaded_adapters())
135
+
136
+ usable = set(active.list_loaded_adapters_cached())
137
+
138
+ from shiftgate.runtime.backend import CloudflareBackend
139
+
140
+ if isinstance(active, CloudflareBackend) and adapter_reg is not None:
141
+ for adapter in adapter_reg.list_adapters():
142
+ is_cf_base = (adapter.base_model or "").startswith("@cf/")
143
+ has_finetune = bool((adapter.runtime_name or "").strip())
144
+ if is_cf_base and not has_finetune:
145
+ usable.add(adapter.effective_backend_name())
146
+
147
+ return usable
97
148
 
98
149
 
99
150
  def _auto_link_adapter(adapter: AdapterEntry, task_reg) -> list[str]:
@@ -160,9 +211,15 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
160
211
  """
161
212
  from shiftgate.runtime.backend import BackendRouter
162
213
 
214
+ # A Cloudflare base model (@cf/...) implies this is a Cloudflare adapter —
215
+ # prefer the Cloudflare backend for verification when it's reachable.
216
+ is_cloudflare = (adapter.base_model or "").startswith("@cf/")
217
+
163
218
  try:
164
219
  with console.status("[cyan]Verifying adapter against running backend…[/cyan]"):
165
220
  router = BackendRouter()
221
+ if is_cloudflare and router._cloudflare.is_available():
222
+ router.select("cloudflare")
166
223
  is_loaded, backend_name = router.verify_adapter(adapter)
167
224
  except Exception as exc: # pragma: no cover - defensive, should not happen
168
225
  logger_msg = f"verification error: {exc}"
@@ -179,9 +236,14 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
179
236
  console.print(f" [green]Backend: {backend_name} ✓ verified[/green]")
180
237
  else:
181
238
  adapter.verified = False
239
+ hint = (
240
+ "— upload it with `npx wrangler ai finetune create`"
241
+ if backend_name == "cloudflare"
242
+ else "— did you pass --lora-modules?"
243
+ )
182
244
  console.print(
183
245
  f" [yellow]Backend: {backend_name} ⚠ runtime '{runtime}' not loaded "
184
- "— did you pass --lora-modules?[/yellow]"
246
+ f"{hint}[/yellow]"
185
247
  )
186
248
 
187
249
  adapter_reg.save()
@@ -207,12 +269,23 @@ def init() -> None:
207
269
  task_reg = TaskRegistry.load()
208
270
 
209
271
  if task_reg.embeddings_ready():
210
- console.print("[dim]Embeddings already computed. Skipping (delete embeddings_cache.npy to force refresh).[/dim]")
272
+ console.print(
273
+ "[dim]Task centroids already computed — skipping re-embed "
274
+ "(delete embeddings_cache.npy to force refresh).[/dim]"
275
+ )
211
276
  else:
212
277
  console.print("[cyan]Computing task embeddings (first run — model download may take a moment)…[/cyan]")
213
278
  embedder = _get_embedder()
214
279
  task_reg.compute_embeddings(embedder)
215
- console.print("[green][/green] Embeddings computed.")
280
+ console.print("[green]OK[/green] Embeddings computed.")
281
+
282
+ # Centroids can exist while the ONNX runtime model is missing (e.g. after
283
+ # deleting fastembed_cache). Always warm-load so routing works.
284
+ console.print("[cyan]Loading embedder model (one-time ~30 MB download if not cached)…[/cyan]")
285
+ from shiftgate.router.embedder import warm_up
286
+
287
+ dim = warm_up()
288
+ console.print(f"[green]OK[/green] Embedder ready (dim={dim}).")
216
289
 
217
290
  task_reg.save()
218
291
  console.print(f"[green]✓[/green] Task registry saved to {shiftgate_dir}")
@@ -301,6 +374,7 @@ def adapter_add(
301
374
  """
302
375
  from shiftgate.registry.adapter_registry import (
303
376
  AdapterRegistry,
377
+ adapter_from_base_model,
304
378
  adapter_from_hf,
305
379
  adapter_from_local,
306
380
  adapter_from_runtime,
@@ -347,12 +421,22 @@ def adapter_add(
347
421
  **shared_kwargs,
348
422
  )
349
423
 
350
- # --- Ambiguous: no '/', no --local, no --runtime ---
424
+ # --- Mode D: Cloudflare base model (always available, no finetune) ---
425
+ elif (base or "").startswith("@cf/"):
426
+ base_model_kwargs = {k: v for k, v in shared_kwargs.items() if k != "base_model"}
427
+ adapter = adapter_from_base_model(
428
+ base_model=base,
429
+ adapter_id=adapter_id or identifier,
430
+ **base_model_kwargs,
431
+ )
432
+
433
+ # --- Ambiguous: no '/', no --local, no --runtime, no @cf/ base ---
351
434
  else:
352
435
  console.print(
353
436
  f"[red]Error:[/red] '{identifier}' doesn't look like a HuggingFace repo ID (missing '/').\n"
354
- " Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter, or\n"
355
- " use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter."
437
+ " Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter,\n"
438
+ " use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter, or\n"
439
+ " use [cyan]--base @cf/...[/cyan] for a Cloudflare Workers AI base model."
356
440
  )
357
441
  raise typer.Exit(1)
358
442
 
@@ -363,6 +447,11 @@ def adapter_add(
363
447
  # has it loaded. Purely informational — never fails the add command.
364
448
  if adapter.runtime_name:
365
449
  _verify_runtime_adapter(adapter, adapter_reg)
450
+ elif (adapter.base_model or "").startswith("@cf/"):
451
+ console.print(
452
+ " [green]Backend:[/green] cloudflare "
453
+ "[green]✓[/green] base model is always available (no upload needed)"
454
+ )
366
455
 
367
456
 
368
457
  @adapter_app.command("list")
@@ -490,7 +579,7 @@ def route(
490
579
 
491
580
  backend_router = BackendRouter()
492
581
  backend_name = backend_router.detect()
493
- available_runtimes = _active_runtimes(backend_router)
582
+ available_runtimes = _active_runtimes(backend_router, adapter_reg)
494
583
 
495
584
  try:
496
585
  trace, match_result = routing.route(
@@ -549,7 +638,7 @@ def run(
549
638
 
550
639
  backend_router = BackendRouter()
551
640
  backend_name = backend_router.detect()
552
- available_runtimes = _active_runtimes(backend_router)
641
+ available_runtimes = _active_runtimes(backend_router, adapter_reg)
553
642
 
554
643
  try:
555
644
  trace, match_result = routing.route(
@@ -716,9 +805,11 @@ def doctor() -> None:
716
805
  router = BackendRouter()
717
806
  backend_name = router.detect()
718
807
  backend_url = router.active_backend_url
719
- loaded_adapters: list[str] = []
808
+ loaded_adapters: set[str] = set()
720
809
  if backend_name is not None and router._active is not None:
721
- loaded_adapters = router._active.list_loaded_adapters()
810
+ # Cloudflare base models are always available, so use the same
811
+ # usable-runtime computation the router uses when filtering.
812
+ loaded_adapters = _active_runtimes(router, adapter_reg) or set()
722
813
 
723
814
  # --- 3. Per-adapter runtime availability ---
724
815
  adapter_rows = []
@@ -774,7 +865,7 @@ def serve(
774
865
  str,
775
866
  typer.Option(
776
867
  "--backend",
777
- help="Backend to forward to: auto | ollama | vllm | cerebras.",
868
+ help="Backend to forward to: auto | ollama | vllm | cerebras | cloudflare.",
778
869
  ),
779
870
  ] = "auto",
780
871
  ) -> None:
@@ -244,6 +244,32 @@ def adapter_from_runtime(
244
244
  )
245
245
 
246
246
 
247
+ def adapter_from_base_model(
248
+ base_model: str,
249
+ *,
250
+ adapter_id: str,
251
+ name: str | None = None,
252
+ tags: list[str] | None = None,
253
+ description: str | None = None,
254
+ benchmark_score: float | None = None,
255
+ ) -> AdapterEntry:
256
+ """Build an AdapterEntry that routes to a base model with no finetune.
257
+
258
+ Used for backends whose base models are always available without any
259
+ upload (e.g. Cloudflare Workers AI ``@cf/`` models). ``runtime_name`` is
260
+ left unset so the backend runs the base model directly.
261
+ """
262
+ slug = _slugify(adapter_id)
263
+ return AdapterEntry(
264
+ id=slug,
265
+ name=name or slug.replace("-", " ").title(),
266
+ base_model=base_model,
267
+ task_tags=list(tags or []),
268
+ description=description or f"Base model: {base_model}",
269
+ benchmark_score=benchmark_score,
270
+ )
271
+
272
+
247
273
  # ---------------------------------------------------------------------------
248
274
  # Internal helpers
249
275
  # ---------------------------------------------------------------------------
@@ -0,0 +1,174 @@
1
+ """
2
+ Text embedder backed by fastembed.
3
+
4
+ Uses ``BAAI/bge-small-en-v1.5`` — a compact (33 M param) model that runs
5
+ efficiently on CPU. The model is downloaded once by fastembed and cached in
6
+ ``~/.shiftgate/fastembed_cache`` (avoids Windows ``%TEMP%`` corruption issues).
7
+
8
+ A module-level singleton (``_MODEL``) is created lazily on first use so that
9
+ importing this module is cheap. The model is NOT re-created between calls.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import time
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Stable cache location — fastembed defaults to %TEMP% on Windows, which is
24
+ # prone to partial downloads and "file sizes do not match" corruption.
25
+ FASTEMBED_CACHE_DIR = Path.home() / ".shiftgate" / "fastembed_cache"
26
+
27
+ # -------------------------------------------------------------------------
28
+ # Default model — small, CPU-friendly, strong quality/speed trade-off.
29
+ # -------------------------------------------------------------------------
30
+ DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
31
+
32
+ # HuggingFace downloads can flake; retry a few times before giving up.
33
+ _LOAD_RETRIES = 3
34
+ _LOAD_RETRY_DELAY_S = 2.0
35
+
36
+ # Module-level singleton; populated on first call to `_get_model()`.
37
+ _MODEL: Any | None = None
38
+
39
+
40
+ def _is_retryable_download_error(exc: BaseException) -> bool:
41
+ """Return True for transient HuggingFace / network failures."""
42
+ msg = str(exc).lower()
43
+ needles = (
44
+ "server disconnected",
45
+ "connection reset",
46
+ "connection aborted",
47
+ "timed out",
48
+ "timeout",
49
+ "temporary failure",
50
+ "503",
51
+ "502",
52
+ "429",
53
+ )
54
+ return any(n in msg for n in needles)
55
+
56
+
57
+ def _format_load_error(model_name: str, exc: BaseException) -> str:
58
+ cache_hint = str(FASTEMBED_CACHE_DIR)
59
+ if _is_retryable_download_error(exc):
60
+ return (
61
+ f"Failed to download embedding model '{model_name}' from HuggingFace: {exc}\n"
62
+ "This is usually a transient network/rate-limit issue. Retry:\n"
63
+ f" uv run shiftgate init\n"
64
+ "Optional: set HF_TOKEN for higher HuggingFace rate limits.\n"
65
+ f"If downloads keep failing, delete '{cache_hint}' and retry."
66
+ )
67
+ return (
68
+ f"Failed to load embedding model '{model_name}': {exc}\n"
69
+ "If you see NO_SUCHFILE or 'file sizes do not match', delete the cache "
70
+ "and retry:\n"
71
+ f" Remove-Item -Recurse -Force '{cache_hint}'\n"
72
+ "Also clear any stale copy at $env:TEMP\\fastembed_cache, then run "
73
+ "`shiftgate init` again."
74
+ )
75
+
76
+
77
+ def _get_model(model_name: str = DEFAULT_MODEL) -> Any:
78
+ """Return the fastembed TextEmbedding singleton, creating it if needed.
79
+
80
+ The model is loaded once per process. If you need a different model,
81
+ call ``reset_model()`` first.
82
+ """
83
+ global _MODEL
84
+ if _MODEL is None:
85
+ try:
86
+ from fastembed import TextEmbedding # type: ignore
87
+ except ImportError as exc:
88
+ raise ImportError(
89
+ "fastembed is required for shiftgate routing. "
90
+ "Install it with: pip install fastembed"
91
+ ) from exc
92
+
93
+ FASTEMBED_CACHE_DIR.mkdir(parents=True, exist_ok=True)
94
+ logger.info(
95
+ "Loading embedding model '%s' (first use — one-time download may occur)…",
96
+ model_name,
97
+ )
98
+ last_exc: BaseException | None = None
99
+ for attempt in range(1, _LOAD_RETRIES + 1):
100
+ try:
101
+ _MODEL = TextEmbedding(
102
+ model_name=model_name,
103
+ cache_dir=str(FASTEMBED_CACHE_DIR),
104
+ )
105
+ break
106
+ except Exception as exc:
107
+ last_exc = exc
108
+ if attempt < _LOAD_RETRIES and _is_retryable_download_error(exc):
109
+ logger.warning(
110
+ "Embedder load attempt %d/%d failed (%s); retrying…",
111
+ attempt,
112
+ _LOAD_RETRIES,
113
+ exc,
114
+ )
115
+ time.sleep(_LOAD_RETRY_DELAY_S * attempt)
116
+ continue
117
+ raise RuntimeError(_format_load_error(model_name, exc)) from exc
118
+ else:
119
+ assert last_exc is not None
120
+ raise RuntimeError(_format_load_error(model_name, last_exc)) from last_exc
121
+ logger.info("Embedding model loaded.")
122
+ return _MODEL
123
+
124
+
125
+ def warm_up(model_name: str = DEFAULT_MODEL) -> int:
126
+ """Load the embedder and run a dummy embed. Returns embedding dimension."""
127
+ vec = Embedder(model_name).embed("warmup")
128
+ return int(vec.shape[0])
129
+
130
+
131
+ def reset_model() -> None:
132
+ """Force the next embed call to recreate the model singleton.
133
+
134
+ Useful in tests or when switching models at runtime.
135
+ """
136
+ global _MODEL
137
+ _MODEL = None
138
+
139
+
140
+ class Embedder:
141
+ """Thin wrapper around the fastembed TextEmbedding model.
142
+
143
+ All embedding operations are synchronous and run on CPU. The model
144
+ is shared across all ``Embedder`` instances via the module-level singleton.
145
+ """
146
+
147
+ def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
148
+ self._model_name = model_name
149
+
150
+ @property
151
+ def _model(self) -> Any:
152
+ return _get_model(self._model_name)
153
+
154
+ def embed(self, text: str) -> np.ndarray:
155
+ """Embed a single text string.
156
+
157
+ Returns a 1-D float32 numpy array of shape ``(dim,)``.
158
+ The vector is **not** L2-normalised here; normalisation is done
159
+ where appropriate (e.g. when computing task centroids).
160
+ """
161
+ # fastembed returns a generator of numpy arrays.
162
+ results = list(self._model.embed([text]))
163
+ return np.array(results[0], dtype=np.float32)
164
+
165
+ def embed_batch(self, texts: list[str]) -> np.ndarray:
166
+ """Embed a list of strings.
167
+
168
+ Returns a 2-D float32 numpy array of shape ``(n, dim)`` where
169
+ ``n = len(texts)``.
170
+ """
171
+ if not texts:
172
+ raise ValueError("embed_batch received an empty list.")
173
+ results = list(self._model.embed(texts))
174
+ return np.array(results, dtype=np.float32)
@@ -80,6 +80,15 @@ def route(
80
80
  query_embedding = embedder.embed(query)
81
81
  all_tasks = task_registry.get_all_tasks()
82
82
  ranked = top_k_tasks(query_embedding, all_tasks, k=top_k)
83
+
84
+ if available_runtimes is not None:
85
+ logger.debug(
86
+ "filtering adapters to backend runtimes: %s",
87
+ sorted(available_runtimes),
88
+ )
89
+ else:
90
+ logger.debug("no active backend — adapter runtime filtering disabled")
91
+
83
92
  result = select_adapter(ranked, adapter_registry, available_runtimes=available_runtimes)
84
93
 
85
94
  selected_id = result.selected_adapter.id if result.selected_adapter else None