shiftgate 0.1.9__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shiftgate-0.1.9 → shiftgate-0.2.1}/PKG-INFO +39 -2
- {shiftgate-0.1.9 → shiftgate-0.2.1}/README.md +38 -1
- {shiftgate-0.1.9 → shiftgate-0.2.1}/pyproject.toml +1 -1
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/cli.py +137 -16
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/adapter_registry.py +26 -0
- shiftgate-0.2.1/shiftgate/router/embedder.py +174 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/router/matcher.py +64 -28
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/router/router.py +17 -1
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/runtime/backend.py +223 -9
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/serve/app.py +123 -23
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/utils/display.py +33 -2
- shiftgate-0.2.1/tests/test_backend.py +466 -0
- shiftgate-0.2.1/tests/test_cli.py +146 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_packaging.py +6 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_router.py +88 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_serve.py +99 -0
- shiftgate-0.1.9/shiftgate/router/embedder.py +0 -95
- shiftgate-0.1.9/tests/test_backend.py +0 -232
- {shiftgate-0.1.9 → shiftgate-0.2.1}/.gitignore +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/data/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/data/default_tasks.json +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/feedback/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/feedback/loop.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/schemas.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/registry/task_registry.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/router/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/runtime/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/serve/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/shiftgate/utils/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/__init__.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_feedback.py +0 -0
- {shiftgate-0.1.9 → shiftgate-0.2.1}/tests/test_registry.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: shiftgate
|
|
3
|
-
Version: 0.1.9
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop.
|
|
5
5
|
Project-URL: Homepage, https://github.com/shiftgate-ai/shiftgate
|
|
6
6
|
Project-URL: Repository, https://github.com/shiftgate-ai/shiftgate
|
|
@@ -276,6 +276,10 @@ User query
|
|
|
276
276
|
└────────────────────────────────┘
|
|
277
277
|
```
|
|
278
278
|
|
|
279
|
+
### How routing works
|
|
280
|
+
|
|
281
|
+
When a backend is active, shiftgate filters candidate adapters to only those that backend can actually serve: adapters loaded on it, plus — on Cloudflare — `@cf/` base-model adapters registered without a `--runtime`, which need no upload. Switch from vLLM to Cerebras and shiftgate automatically picks Cerebras-compatible adapters — no re-registration needed. (When you run `shiftgate route` with no backend running, no filtering is applied, so you still see the full routing preview.)
|
|
282
|
+
|
|
279
283
|
---
|
|
280
284
|
|
|
281
285
|
## Bring Your Own Models
|
|
@@ -352,10 +356,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
|
|
|
352
356
|
shiftgate run "write a python sorting function"
|
|
353
357
|
```
|
|
354
358
|
|
|
355
|
-
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and
|
|
359
|
+
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
|
|
356
360
|
|
|
357
361
|
> **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
|
|
358
362
|
|
|
363
|
+
### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
|
|
364
|
+
|
|
365
|
+
[Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
|
|
366
|
+
|
|
367
|
+
```bash
|
|
368
|
+
# 1. Upload your LoRA to Cloudflare (one-time)
|
|
369
|
+
npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
|
|
370
|
+
|
|
371
|
+
# 2. Set credentials
|
|
372
|
+
export CLOUDFLARE_ACCOUNT_ID=...
|
|
373
|
+
export CLOUDFLARE_API_TOKEN=...
|
|
374
|
+
|
|
375
|
+
# 3. Register in shiftgate — note --base is the Cloudflare model name
|
|
376
|
+
shiftgate adapter add my-sql-lora \
|
|
377
|
+
--runtime my-sql-lora \
|
|
378
|
+
--base @cf/mistral/mistral-7b-instruct-v0.2-lora \
|
|
379
|
+
--tags sql
|
|
380
|
+
|
|
381
|
+
# 4. Run
|
|
382
|
+
shiftgate run "write a sql join query"
|
|
383
|
+
```
|
|
384
|
+
|
|
385
|
+
You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
|
|
386
|
+
|
|
387
|
+
**Architectural difference:** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field, rather than as the `model` value the way vLLM and Cerebras do. shiftgate handles this transparently, so your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
|
|
388
|
+
|
|
389
|
+
**Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
|
|
390
|
+
|
|
391
|
+
- Up to **100 LoRAs** per account.
|
|
392
|
+
- LoRA file must be **< 300 MB**.
|
|
393
|
+
- Must be trained with rank **r ≤ 8** (up to 32 on some models).
|
|
394
|
+
- **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
|
|
395
|
+
|
|
359
396
|
---
|
|
360
397
|
|
|
361
398
|
## How to contribute adapters
|
|
@@ -240,6 +240,10 @@ User query
|
|
|
240
240
|
└────────────────────────────────┘
|
|
241
241
|
```
|
|
242
242
|
|
|
243
|
+
### How routing works
|
|
244
|
+
|
|
245
|
+
When a backend is active, shiftgate filters candidate adapters to only those that backend can actually serve: adapters loaded on it, plus — on Cloudflare — `@cf/` base-model adapters registered without a `--runtime`, which need no upload. Switch from vLLM to Cerebras and shiftgate automatically picks Cerebras-compatible adapters — no re-registration needed. (When you run `shiftgate route` with no backend running, no filtering is applied, so you still see the full routing preview.)
|
|
246
|
+
|
|
243
247
|
---
|
|
244
248
|
|
|
245
249
|
## Bring Your Own Models
|
|
@@ -316,10 +320,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
|
|
|
316
320
|
shiftgate run "write a python sorting function"
|
|
317
321
|
```
|
|
318
322
|
|
|
319
|
-
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and
|
|
323
|
+
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
|
|
320
324
|
|
|
321
325
|
> **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
|
|
322
326
|
|
|
327
|
+
### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
|
|
328
|
+
|
|
329
|
+
[Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
|
|
330
|
+
|
|
331
|
+
```bash
|
|
332
|
+
# 1. Upload your LoRA to Cloudflare (one-time)
|
|
333
|
+
npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
|
|
334
|
+
|
|
335
|
+
# 2. Set credentials
|
|
336
|
+
export CLOUDFLARE_ACCOUNT_ID=...
|
|
337
|
+
export CLOUDFLARE_API_TOKEN=...
|
|
338
|
+
|
|
339
|
+
# 3. Register in shiftgate — note --base is the Cloudflare model name
|
|
340
|
+
shiftgate adapter add my-sql-lora \
|
|
341
|
+
--runtime my-sql-lora \
|
|
342
|
+
--base @cf/mistral/mistral-7b-instruct-v0.2-lora \
|
|
343
|
+
--tags sql
|
|
344
|
+
|
|
345
|
+
# 4. Run
|
|
346
|
+
shiftgate run "write a sql join query"
|
|
347
|
+
```
|
|
348
|
+
|
|
349
|
+
You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
|
|
350
|
+
|
|
351
|
+
**Architectural difference:** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field, rather than as the `model` value the way vLLM and Cerebras do. shiftgate handles this transparently, so your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
|
|
352
|
+
|
|
353
|
+
**Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
|
|
354
|
+
|
|
355
|
+
- Up to **100 LoRAs** per account.
|
|
356
|
+
- LoRA file must be **< 300 MB**.
|
|
357
|
+
- Must be trained with rank **r ≤ 8** (up to 32 on some models).
|
|
358
|
+
- **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
|
|
359
|
+
|
|
323
360
|
---
|
|
324
361
|
|
|
325
362
|
## How to contribute adapters
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "shiftgate"
|
|
7
|
-
version = "0.1.9"
|
|
7
|
+
version = "0.2.1"
|
|
8
8
|
description = "Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.10"
|
|
@@ -52,12 +52,47 @@ def _main(
|
|
|
52
52
|
),
|
|
53
53
|
),
|
|
54
54
|
] = None,
|
|
55
|
+
cf_account_id: Annotated[
|
|
56
|
+
Optional[str],
|
|
57
|
+
typer.Option(
|
|
58
|
+
"--cf-account-id",
|
|
59
|
+
help="Cloudflare account ID. Sets CLOUDFLARE_ACCOUNT_ID for this run.",
|
|
60
|
+
),
|
|
61
|
+
] = None,
|
|
62
|
+
cf_api_token: Annotated[
|
|
63
|
+
Optional[str],
|
|
64
|
+
typer.Option(
|
|
65
|
+
"--cf-api-token",
|
|
66
|
+
help="Cloudflare API token. Sets CLOUDFLARE_API_TOKEN for this run.",
|
|
67
|
+
),
|
|
68
|
+
] = None,
|
|
69
|
+
verbose: Annotated[
|
|
70
|
+
bool,
|
|
71
|
+
typer.Option(
|
|
72
|
+
"--verbose",
|
|
73
|
+
"-v",
|
|
74
|
+
help="Enable DEBUG logging (shows routing internals like runtime filtering).",
|
|
75
|
+
),
|
|
76
|
+
] = False,
|
|
55
77
|
) -> None:
|
|
56
78
|
"""Global options applied before any command runs."""
|
|
57
|
-
|
|
58
|
-
import os
|
|
79
|
+
import os
|
|
59
80
|
|
|
81
|
+
if cerebras_key:
|
|
60
82
|
os.environ["CEREBRAS_API_KEY"] = cerebras_key
|
|
83
|
+
if cf_account_id:
|
|
84
|
+
os.environ["CLOUDFLARE_ACCOUNT_ID"] = cf_account_id
|
|
85
|
+
if cf_api_token:
|
|
86
|
+
os.environ["CLOUDFLARE_API_TOKEN"] = cf_api_token
|
|
87
|
+
|
|
88
|
+
if verbose:
|
|
89
|
+
import logging
|
|
90
|
+
|
|
91
|
+
logging.basicConfig(
|
|
92
|
+
level=logging.DEBUG,
|
|
93
|
+
format="%(levelname)s [%(name)s] %(message)s",
|
|
94
|
+
)
|
|
95
|
+
logging.getLogger("shiftgate").setLevel(logging.DEBUG)
|
|
61
96
|
|
|
62
97
|
|
|
63
98
|
# ---------------------------------------------------------------------------
|
|
@@ -83,6 +118,35 @@ def _get_embedder():
|
|
|
83
118
|
return Embedder()
|
|
84
119
|
|
|
85
120
|
|
|
121
|
+
def _active_runtimes(backend_router, adapter_reg=None) -> set[str] | None:
|
|
122
|
+
"""Return the set of runtime names usable on the active backend, or None.
|
|
123
|
+
|
|
124
|
+
``None`` means no backend is active → the router should not filter
|
|
125
|
+
(preview behaviour). An empty set means a backend is active but reports no
|
|
126
|
+
usable models.
|
|
127
|
+
|
|
128
|
+
For Cloudflare, base models are always available without any finetune
|
|
129
|
+
upload, so every registered ``@cf/`` adapter that has no finetune
|
|
130
|
+
``runtime_name`` is considered usable (base-model inference).
|
|
131
|
+
"""
|
|
132
|
+
active = backend_router.active_backend
|
|
133
|
+
if active is None:
|
|
134
|
+
return None
|
|
135
|
+
|
|
136
|
+
usable = set(active.list_loaded_adapters_cached())
|
|
137
|
+
|
|
138
|
+
from shiftgate.runtime.backend import CloudflareBackend
|
|
139
|
+
|
|
140
|
+
if isinstance(active, CloudflareBackend) and adapter_reg is not None:
|
|
141
|
+
for adapter in adapter_reg.list_adapters():
|
|
142
|
+
is_cf_base = (adapter.base_model or "").startswith("@cf/")
|
|
143
|
+
has_finetune = bool((adapter.runtime_name or "").strip())
|
|
144
|
+
if is_cf_base and not has_finetune:
|
|
145
|
+
usable.add(adapter.effective_backend_name())
|
|
146
|
+
|
|
147
|
+
return usable
|
|
148
|
+
|
|
149
|
+
|
|
86
150
|
def _auto_link_adapter(adapter: AdapterEntry, task_reg) -> list[str]:
|
|
87
151
|
"""Add ``adapter.id`` to the ``preferred_adapters`` of matching task clusters.
|
|
88
152
|
|
|
@@ -147,9 +211,15 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
|
|
|
147
211
|
"""
|
|
148
212
|
from shiftgate.runtime.backend import BackendRouter
|
|
149
213
|
|
|
214
|
+
# A Cloudflare base model (@cf/...) implies this is a Cloudflare adapter —
|
|
215
|
+
# prefer the Cloudflare backend for verification when it's reachable.
|
|
216
|
+
is_cloudflare = (adapter.base_model or "").startswith("@cf/")
|
|
217
|
+
|
|
150
218
|
try:
|
|
151
219
|
with console.status("[cyan]Verifying adapter against running backend…[/cyan]"):
|
|
152
220
|
router = BackendRouter()
|
|
221
|
+
if is_cloudflare and router._cloudflare.is_available():
|
|
222
|
+
router.select("cloudflare")
|
|
153
223
|
is_loaded, backend_name = router.verify_adapter(adapter)
|
|
154
224
|
except Exception as exc: # pragma: no cover - defensive, should not happen
|
|
155
225
|
logger_msg = f"verification error: {exc}"
|
|
@@ -166,9 +236,14 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
|
|
|
166
236
|
console.print(f" [green]Backend: {backend_name} ✓ verified[/green]")
|
|
167
237
|
else:
|
|
168
238
|
adapter.verified = False
|
|
239
|
+
hint = (
|
|
240
|
+
"— upload it with `npx wrangler ai finetune create`"
|
|
241
|
+
if backend_name == "cloudflare"
|
|
242
|
+
else "— did you pass --lora-modules?"
|
|
243
|
+
)
|
|
169
244
|
console.print(
|
|
170
245
|
f" [yellow]Backend: {backend_name} ⚠ runtime '{runtime}' not loaded "
|
|
171
|
-
"
|
|
246
|
+
f"{hint}[/yellow]"
|
|
172
247
|
)
|
|
173
248
|
|
|
174
249
|
adapter_reg.save()
|
|
@@ -194,12 +269,23 @@ def init() -> None:
|
|
|
194
269
|
task_reg = TaskRegistry.load()
|
|
195
270
|
|
|
196
271
|
if task_reg.embeddings_ready():
|
|
197
|
-
console.print(
|
|
272
|
+
console.print(
|
|
273
|
+
"[dim]Task centroids already computed — skipping re-embed "
|
|
274
|
+
"(delete embeddings_cache.npy to force refresh).[/dim]"
|
|
275
|
+
)
|
|
198
276
|
else:
|
|
199
277
|
console.print("[cyan]Computing task embeddings (first run — model download may take a moment)…[/cyan]")
|
|
200
278
|
embedder = _get_embedder()
|
|
201
279
|
task_reg.compute_embeddings(embedder)
|
|
202
|
-
console.print("[green]
|
|
280
|
+
console.print("[green]OK[/green] Embeddings computed.")
|
|
281
|
+
|
|
282
|
+
# Centroids can exist while the ONNX runtime model is missing (e.g. after
|
|
283
|
+
# deleting fastembed_cache). Always warm-load so routing works.
|
|
284
|
+
console.print("[cyan]Loading embedder model (one-time ~30 MB download if not cached)…[/cyan]")
|
|
285
|
+
from shiftgate.router.embedder import warm_up
|
|
286
|
+
|
|
287
|
+
dim = warm_up()
|
|
288
|
+
console.print(f"[green]OK[/green] Embedder ready (dim={dim}).")
|
|
203
289
|
|
|
204
290
|
task_reg.save()
|
|
205
291
|
console.print(f"[green]✓[/green] Task registry saved to {shiftgate_dir}")
|
|
@@ -288,6 +374,7 @@ def adapter_add(
|
|
|
288
374
|
"""
|
|
289
375
|
from shiftgate.registry.adapter_registry import (
|
|
290
376
|
AdapterRegistry,
|
|
377
|
+
adapter_from_base_model,
|
|
291
378
|
adapter_from_hf,
|
|
292
379
|
adapter_from_local,
|
|
293
380
|
adapter_from_runtime,
|
|
@@ -334,12 +421,22 @@ def adapter_add(
|
|
|
334
421
|
**shared_kwargs,
|
|
335
422
|
)
|
|
336
423
|
|
|
337
|
-
# ---
|
|
424
|
+
# --- Mode D: Cloudflare base model (always available, no finetune) ---
|
|
425
|
+
elif (base or "").startswith("@cf/"):
|
|
426
|
+
base_model_kwargs = {k: v for k, v in shared_kwargs.items() if k != "base_model"}
|
|
427
|
+
adapter = adapter_from_base_model(
|
|
428
|
+
base_model=base,
|
|
429
|
+
adapter_id=adapter_id or identifier,
|
|
430
|
+
**base_model_kwargs,
|
|
431
|
+
)
|
|
432
|
+
|
|
433
|
+
# --- Ambiguous: no '/', no --local, no --runtime, no @cf/ base ---
|
|
338
434
|
else:
|
|
339
435
|
console.print(
|
|
340
436
|
f"[red]Error:[/red] '{identifier}' doesn't look like a HuggingFace repo ID (missing '/').\n"
|
|
341
|
-
" Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter
|
|
342
|
-
" use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter
|
|
437
|
+
" Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter,\n"
|
|
438
|
+
" use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter, or\n"
|
|
439
|
+
" use [cyan]--base @cf/...[/cyan] for a Cloudflare Workers AI base model."
|
|
343
440
|
)
|
|
344
441
|
raise typer.Exit(1)
|
|
345
442
|
|
|
@@ -350,6 +447,11 @@ def adapter_add(
|
|
|
350
447
|
# has it loaded. Purely informational — never fails the add command.
|
|
351
448
|
if adapter.runtime_name:
|
|
352
449
|
_verify_runtime_adapter(adapter, adapter_reg)
|
|
450
|
+
elif (adapter.base_model or "").startswith("@cf/"):
|
|
451
|
+
console.print(
|
|
452
|
+
" [green]Backend:[/green] cloudflare "
|
|
453
|
+
"[green]✓[/green] base model is always available (no upload needed)"
|
|
454
|
+
)
|
|
353
455
|
|
|
354
456
|
|
|
355
457
|
@adapter_app.command("list")
|
|
@@ -464,6 +566,7 @@ def route(
|
|
|
464
566
|
"""
|
|
465
567
|
from shiftgate.feedback import loop as feedback_loop
|
|
466
568
|
from shiftgate.router import router as routing
|
|
569
|
+
from shiftgate.runtime.backend import BackendRouter
|
|
467
570
|
from shiftgate.utils.display import show_explain_decision, show_routing_decision
|
|
468
571
|
|
|
469
572
|
task_reg, adapter_reg = _load_registries()
|
|
@@ -474,8 +577,15 @@ def route(
|
|
|
474
577
|
|
|
475
578
|
embedder = _get_embedder()
|
|
476
579
|
|
|
580
|
+
backend_router = BackendRouter()
|
|
581
|
+
backend_name = backend_router.detect()
|
|
582
|
+
available_runtimes = _active_runtimes(backend_router, adapter_reg)
|
|
583
|
+
|
|
477
584
|
try:
|
|
478
|
-
trace, match_result = routing.route(
|
|
585
|
+
trace, match_result = routing.route(
|
|
586
|
+
query, task_reg, adapter_reg, embedder,
|
|
587
|
+
top_k=top_k, available_runtimes=available_runtimes,
|
|
588
|
+
)
|
|
479
589
|
except Exception as exc:
|
|
480
590
|
console.print(f"[red]Routing error:[/red] {exc}")
|
|
481
591
|
raise typer.Exit(1)
|
|
@@ -487,7 +597,9 @@ def route(
|
|
|
487
597
|
trace,
|
|
488
598
|
adapter=adapter,
|
|
489
599
|
task_name=task.name if task else None,
|
|
490
|
-
backend_name=
|
|
600
|
+
backend_name=backend_name,
|
|
601
|
+
loaded_runtimes=available_runtimes,
|
|
602
|
+
selection_method=match_result.selection_method,
|
|
491
603
|
)
|
|
492
604
|
|
|
493
605
|
if explain:
|
|
@@ -524,22 +636,29 @@ def run(
|
|
|
524
636
|
|
|
525
637
|
embedder = _get_embedder()
|
|
526
638
|
|
|
639
|
+
backend_router = BackendRouter()
|
|
640
|
+
backend_name = backend_router.detect()
|
|
641
|
+
available_runtimes = _active_runtimes(backend_router, adapter_reg)
|
|
642
|
+
|
|
527
643
|
try:
|
|
528
|
-
trace, match_result = routing.route(
|
|
644
|
+
trace, match_result = routing.route(
|
|
645
|
+
query, task_reg, adapter_reg, embedder,
|
|
646
|
+
top_k=top_k, available_runtimes=available_runtimes,
|
|
647
|
+
)
|
|
529
648
|
except Exception as exc:
|
|
530
649
|
console.print(f"[red]Routing error:[/red] {exc}")
|
|
531
650
|
raise typer.Exit(1)
|
|
532
651
|
|
|
533
652
|
adapter = adapter_reg.get_adapter(trace.selected_adapter_id)
|
|
534
653
|
task = task_reg.get_task(trace.matched_task_id)
|
|
535
|
-
backend_router = BackendRouter()
|
|
536
|
-
backend_name = backend_router.detect()
|
|
537
654
|
|
|
538
655
|
show_routing_decision(
|
|
539
656
|
trace,
|
|
540
657
|
adapter=adapter,
|
|
541
658
|
task_name=task.name if task else None,
|
|
542
659
|
backend_name=backend_name,
|
|
660
|
+
loaded_runtimes=available_runtimes,
|
|
661
|
+
selection_method=match_result.selection_method,
|
|
543
662
|
)
|
|
544
663
|
|
|
545
664
|
if adapter is None:
|
|
@@ -686,9 +805,11 @@ def doctor() -> None:
|
|
|
686
805
|
router = BackendRouter()
|
|
687
806
|
backend_name = router.detect()
|
|
688
807
|
backend_url = router.active_backend_url
|
|
689
|
-
loaded_adapters:
|
|
808
|
+
loaded_adapters: set[str] = set()
|
|
690
809
|
if backend_name is not None and router._active is not None:
|
|
691
|
-
|
|
810
|
+
# Cloudflare base models are always available, so use the same
|
|
811
|
+
# usable-runtime computation the router uses when filtering.
|
|
812
|
+
loaded_adapters = _active_runtimes(router, adapter_reg) or set()
|
|
692
813
|
|
|
693
814
|
# --- 3. Per-adapter runtime availability ---
|
|
694
815
|
adapter_rows = []
|
|
@@ -744,7 +865,7 @@ def serve(
|
|
|
744
865
|
str,
|
|
745
866
|
typer.Option(
|
|
746
867
|
"--backend",
|
|
747
|
-
help="Backend to forward to: auto | ollama | vllm | cerebras.",
|
|
868
|
+
help="Backend to forward to: auto | ollama | vllm | cerebras | cloudflare.",
|
|
748
869
|
),
|
|
749
870
|
] = "auto",
|
|
750
871
|
) -> None:
|
|
@@ -244,6 +244,32 @@ def adapter_from_runtime(
|
|
|
244
244
|
)
|
|
245
245
|
|
|
246
246
|
|
|
247
|
+
def adapter_from_base_model(
    base_model: str,
    *,
    adapter_id: str,
    name: str | None = None,
    tags: list[str] | None = None,
    description: str | None = None,
    benchmark_score: float | None = None,
) -> AdapterEntry:
    """Create an ``AdapterEntry`` that runs a backend's base model directly.

    Intended for backends that serve base models without any finetune upload,
    such as Cloudflare Workers AI ``@cf/`` models. No ``runtime_name`` is
    supplied, so the entry carries no finetune and the backend runs the base
    model itself.

    Args:
        base_model: Backend model identifier (e.g. an ``@cf/...`` name).
        adapter_id: Requested id; passed through ``_slugify`` before use.
        name: Display name. Defaults to the slug with dashes replaced by
            spaces, title-cased.
        tags: Task tags, copied into a fresh list. Defaults to no tags.
        description: Defaults to ``"Base model: <base_model>"``.
        benchmark_score: Optional score stored on the entry as-is.
    """
    slug = _slugify(adapter_id)
    display_name = name if name else slug.replace("-", " ").title()
    tag_list = list(tags) if tags else []
    summary = description if description else f"Base model: {base_model}"
    return AdapterEntry(
        id=slug,
        name=display_name,
        base_model=base_model,
        task_tags=tag_list,
        description=summary,
        benchmark_score=benchmark_score,
    )
|
|
271
|
+
|
|
272
|
+
|
|
247
273
|
# ---------------------------------------------------------------------------
|
|
248
274
|
# Internal helpers
|
|
249
275
|
# ---------------------------------------------------------------------------
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Text embedder backed by fastembed.
|
|
3
|
+
|
|
4
|
+
Uses ``BAAI/bge-small-en-v1.5`` — a compact (33 M param) model that runs
|
|
5
|
+
efficiently on CPU. The model is downloaded once by fastembed and cached in
|
|
6
|
+
``~/.shiftgate/fastembed_cache`` (avoids Windows ``%TEMP%`` corruption issues).
|
|
7
|
+
|
|
8
|
+
A module-level singleton (``_MODEL``) is created lazily on first use so that
|
|
9
|
+
importing this module is cheap. The model is NOT re-created between calls.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
import time
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Stable cache location — fastembed defaults to %TEMP% on Windows, which is
|
|
24
|
+
# prone to partial downloads and "file sizes do not match" corruption.
|
|
25
|
+
FASTEMBED_CACHE_DIR = Path.home() / ".shiftgate" / "fastembed_cache"
|
|
26
|
+
|
|
27
|
+
# -------------------------------------------------------------------------
|
|
28
|
+
# Default model — small, CPU-friendly, strong quality/speed trade-off.
|
|
29
|
+
# -------------------------------------------------------------------------
|
|
30
|
+
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
|
|
31
|
+
|
|
32
|
+
# HuggingFace downloads can flake; retry a few times before giving up.
|
|
33
|
+
_LOAD_RETRIES = 3
|
|
34
|
+
_LOAD_RETRY_DELAY_S = 2.0
|
|
35
|
+
|
|
36
|
+
# Module-level singleton; populated on first call to `_get_model()`.
|
|
37
|
+
_MODEL: Any | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _is_retryable_download_error(exc: BaseException) -> bool:
|
|
41
|
+
"""Return True for transient HuggingFace / network failures."""
|
|
42
|
+
msg = str(exc).lower()
|
|
43
|
+
needles = (
|
|
44
|
+
"server disconnected",
|
|
45
|
+
"connection reset",
|
|
46
|
+
"connection aborted",
|
|
47
|
+
"timed out",
|
|
48
|
+
"timeout",
|
|
49
|
+
"temporary failure",
|
|
50
|
+
"503",
|
|
51
|
+
"502",
|
|
52
|
+
"429",
|
|
53
|
+
)
|
|
54
|
+
return any(n in msg for n in needles)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _format_load_error(model_name: str, exc: BaseException) -> str:
    """Build the user-facing explanation for a failed embedder load.

    Errors classified as transient by ``_is_retryable_download_error`` get
    retry advice. Anything else is presented as a likely corrupted cache,
    with instructions to delete ``FASTEMBED_CACHE_DIR`` and fastembed's
    temp-dir cache.

    The cache-wipe command is shell-specific. PowerShell syntax is shown only
    on Windows; other platforms get a POSIX ``rm -rf`` and the concrete temp
    directory path.

    Args:
        model_name: Model that failed to load (echoed in the message).
        exc: The underlying exception (echoed in the message).

    Returns:
        A multi-line message suitable for wrapping in ``RuntimeError``.
    """
    import sys
    import tempfile

    cache_hint = str(FASTEMBED_CACHE_DIR)
    if _is_retryable_download_error(exc):
        return (
            f"Failed to download embedding model '{model_name}' from HuggingFace: {exc}\n"
            "This is usually a transient network/rate-limit issue. Retry:\n"
            " uv run shiftgate init\n"
            "Optional: set HF_TOKEN for higher HuggingFace rate limits.\n"
            f"If downloads keep failing, delete '{cache_hint}' and retry."
        )

    if sys.platform == "win32":
        wipe_cmd = f" Remove-Item -Recurse -Force '{cache_hint}'"
        stale_copy = "$env:TEMP\\fastembed_cache"
    else:
        wipe_cmd = f" rm -rf '{cache_hint}'"
        # Mirrors the Windows hint: a stale fastembed_cache in the temp dir.
        stale_copy = str(Path(tempfile.gettempdir()) / "fastembed_cache")

    return (
        f"Failed to load embedding model '{model_name}': {exc}\n"
        "If you see NO_SUCHFILE or 'file sizes do not match', delete the cache "
        "and retry:\n"
        f"{wipe_cmd}\n"
        f"Also clear any stale copy at {stale_copy}, then run "
        "`shiftgate init` again."
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _get_model(model_name: str = DEFAULT_MODEL) -> Any:
    """Return the process-wide fastembed ``TextEmbedding``, loading it lazily.

    On the first call, fastembed is imported, ``FASTEMBED_CACHE_DIR`` is
    created if missing, and the model is constructed. Later calls return the
    cached instance regardless of ``model_name``; call ``reset_model()``
    before asking for a different model.

    Construction is attempted up to ``_LOAD_RETRIES`` times. Only errors that
    ``_is_retryable_download_error`` accepts are retried, sleeping
    ``_LOAD_RETRY_DELAY_S * attempt`` seconds between tries. Any other
    failure, or a transient one on the final attempt, is re-raised as
    ``RuntimeError`` with guidance from ``_format_load_error``.

    Raises:
        ImportError: fastembed is not installed.
        RuntimeError: the model could not be constructed.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    try:
        from fastembed import TextEmbedding  # type: ignore
    except ImportError as exc:
        raise ImportError(
            "fastembed is required for shiftgate routing. "
            "Install it with: pip install fastembed"
        ) from exc

    FASTEMBED_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    logger.info(
        "Loading embedding model '%s' (first use — one-time download may occur)…",
        model_name,
    )

    failure: BaseException | None = None
    for attempt in range(1, _LOAD_RETRIES + 1):
        try:
            _MODEL = TextEmbedding(
                model_name=model_name,
                cache_dir=str(FASTEMBED_CACHE_DIR),
            )
        except Exception as exc:
            failure = exc
            if attempt >= _LOAD_RETRIES or not _is_retryable_download_error(exc):
                raise RuntimeError(_format_load_error(model_name, exc)) from exc
            logger.warning(
                "Embedder load attempt %d/%d failed (%s); retrying…",
                attempt,
                _LOAD_RETRIES,
                exc,
            )
            time.sleep(_LOAD_RETRY_DELAY_S * attempt)
        else:
            logger.info("Embedding model loaded.")
            return _MODEL

    # Only reachable when _LOAD_RETRIES < 1, i.e. no attempt was made at all.
    assert failure is not None
    raise RuntimeError(_format_load_error(model_name, failure)) from failure
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def warm_up(model_name: str = DEFAULT_MODEL) -> int:
    """Force the embedding model to load and report its output dimension.

    Embeds a throwaway string through ``Embedder``. Any download or model
    initialisation therefore happens now rather than on the first real
    routing call.

    Returns:
        Length of the produced embedding vector.
    """
    probe = Embedder(model_name).embed("warmup")
    dim = probe.shape[0]
    return int(dim)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def reset_model() -> None:
    """Discard the cached embedding model.

    The next ``_get_model()`` call, direct or through ``Embedder``, builds a
    fresh ``TextEmbedding``. This is handy in tests, and required before
    switching model names, since the cached singleton otherwise ignores
    ``model_name``.
    """
    global _MODEL
    # Dropping the reference is all that is needed; loading is lazy.
    _MODEL = None
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class Embedder:
    """Synchronous CPU text embedder backed by fastembed.

    An instance only remembers a model name. The heavy ``TextEmbedding``
    object comes from ``_get_model`` and is therefore shared by every
    ``Embedder`` in the process.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
        self._model_name = model_name

    @property
    def _model(self) -> Any:
        # Looked up on every access, so a reset_model() call also affects
        # Embedder instances that already exist.
        return _get_model(self._model_name)

    def embed(self, text: str) -> np.ndarray:
        """Return the embedding of *text* as a 1-D float32 array ``(dim,)``.

        No L2 normalisation happens here. Callers that need unit vectors
        (e.g. when computing task centroids) normalise on their side.
        """
        # fastembed.embed() yields one array per input text.
        vectors = list(self._model.embed([text]))
        first = vectors[0]
        return np.array(first, dtype=np.float32)

    def embed_batch(self, texts: list[str]) -> np.ndarray:
        """Return embeddings for *texts* as a 2-D float32 array ``(n, dim)``.

        ``n`` equals ``len(texts)``.

        Raises:
            ValueError: *texts* is empty.
        """
        if not texts:
            raise ValueError("embed_batch received an empty list.")
        rows = list(self._model.embed(texts))
        return np.array(rows, dtype=np.float32)
|