shiftgate 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shiftgate-0.2.0 → shiftgate-0.2.1}/PKG-INFO +35 -2
- {shiftgate-0.2.0 → shiftgate-0.2.1}/README.md +34 -1
- {shiftgate-0.2.0 → shiftgate-0.2.1}/pyproject.toml +1 -1
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/cli.py +108 -17
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/adapter_registry.py +26 -0
- shiftgate-0.2.1/shiftgate/router/embedder.py +174 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/router/router.py +9 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/runtime/backend.py +223 -9
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/serve/app.py +102 -23
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/utils/display.py +5 -0
- shiftgate-0.2.1/tests/test_backend.py +466 -0
- shiftgate-0.2.1/tests/test_cli.py +146 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_packaging.py +6 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_serve.py +95 -0
- shiftgate-0.2.0/shiftgate/router/embedder.py +0 -95
- shiftgate-0.2.0/tests/test_backend.py +0 -232
- {shiftgate-0.2.0 → shiftgate-0.2.1}/.gitignore +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/data/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/data/default_tasks.json +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/feedback/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/feedback/loop.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/schemas.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/registry/task_registry.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/router/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/router/matcher.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/runtime/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/serve/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/shiftgate/utils/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/__init__.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_feedback.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_registry.py +0 -0
- {shiftgate-0.2.0 → shiftgate-0.2.1}/tests/test_router.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: shiftgate
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop.
|
|
5
5
|
Project-URL: Homepage, https://github.com/shiftgate-ai/shiftgate
|
|
6
6
|
Project-URL: Repository, https://github.com/shiftgate-ai/shiftgate
|
|
@@ -356,10 +356,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
|
|
|
356
356
|
shiftgate run "write a python sorting function"
|
|
357
357
|
```
|
|
358
358
|
|
|
359
|
-
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and
|
|
359
|
+
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
|
|
360
360
|
|
|
361
361
|
> **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
|
|
362
362
|
|
|
363
|
+
### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
|
|
364
|
+
|
|
365
|
+
[Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
|
|
366
|
+
|
|
367
|
+
```bash
|
|
368
|
+
# 1. Upload your LoRA to Cloudflare (one-time)
|
|
369
|
+
npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
|
|
370
|
+
|
|
371
|
+
# 2. Set credentials
|
|
372
|
+
export CLOUDFLARE_ACCOUNT_ID=...
|
|
373
|
+
export CLOUDFLARE_API_TOKEN=...
|
|
374
|
+
|
|
375
|
+
# 3. Register in shiftgate — note --base is the Cloudflare model name
|
|
376
|
+
shiftgate adapter add my-sql-lora \
|
|
377
|
+
--runtime my-sql-lora \
|
|
378
|
+
--base @cf/mistral/mistral-7b-instruct-v0.2-lora \
|
|
379
|
+
--tags sql
|
|
380
|
+
|
|
381
|
+
# 4. Run
|
|
382
|
+
shiftgate run "write a sql join query"
|
|
383
|
+
```
|
|
384
|
+
|
|
385
|
+
You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
|
|
386
|
+
|
|
387
|
+
**Architectural difference (handled transparently):** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field — not as the `model` value like vLLM/Cerebras. shiftgate handles this transparently; your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
|
|
388
|
+
|
|
389
|
+
**Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
|
|
390
|
+
|
|
391
|
+
- Up to **100 LoRAs** per account.
|
|
392
|
+
- LoRA file must be **< 300 MB**.
|
|
393
|
+
- Must be trained with rank **r ≤ 8** (up to 32 on some models).
|
|
394
|
+
- **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
|
|
395
|
+
|
|
363
396
|
---
|
|
364
397
|
|
|
365
398
|
## How to contribute adapters
|
|
@@ -320,10 +320,43 @@ shiftgate adapter add llama3.1-8b --runtime llama3.1-8b --tags general --base ll
|
|
|
320
320
|
shiftgate run "write a python sorting function"
|
|
321
321
|
```
|
|
322
322
|
|
|
323
|
-
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras**, so local backends always win and
|
|
323
|
+
shiftgate auto-detects backends in the order **Ollama → vLLM → Cerebras → Cloudflare**, so local backends always win and the cloud backends are used only when no local backend is running.
|
|
324
324
|
|
|
325
325
|
> **Honest status:** shiftgate routes to Cerebras' base-model inference today. When Cerebras Multi-LoRA goes public, register your adapter with `--runtime <cerebras-lora-id>` and it just works — no shiftgate update needed.
|
|
326
326
|
|
|
327
|
+
### Option 5 — Cloudflare Workers AI (cloud, LoRA-native)
|
|
328
|
+
|
|
329
|
+
[Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) serves your own LoRA finetunes on top of supported base models.
|
|
330
|
+
|
|
331
|
+
```bash
|
|
332
|
+
# 1. Upload your LoRA to Cloudflare (one-time)
|
|
333
|
+
npx wrangler ai finetune create @cf/mistral/mistral-7b-instruct-v0.2-lora my-sql-lora ./adapter-folder
|
|
334
|
+
|
|
335
|
+
# 2. Set credentials
|
|
336
|
+
export CLOUDFLARE_ACCOUNT_ID=...
|
|
337
|
+
export CLOUDFLARE_API_TOKEN=...
|
|
338
|
+
|
|
339
|
+
# 3. Register in shiftgate — note --base is the Cloudflare model name
|
|
340
|
+
shiftgate adapter add my-sql-lora \
|
|
341
|
+
--runtime my-sql-lora \
|
|
342
|
+
--base @cf/mistral/mistral-7b-instruct-v0.2-lora \
|
|
343
|
+
--tags sql
|
|
344
|
+
|
|
345
|
+
# 4. Run
|
|
346
|
+
shiftgate run "write a sql join query"
|
|
347
|
+
```
|
|
348
|
+
|
|
349
|
+
You can also pass credentials per-run with the `--cf-account-id` and `--cf-api-token` global flags.
|
|
350
|
+
|
|
351
|
+
**Architectural difference (handled transparently):** Cloudflare keeps the base model in the URL and accepts the LoRA name as a separate `lora` field — not as the `model` value like vLLM/Cerebras. shiftgate handles this transparently; your routing logic doesn't change. The `--base` you register **must** be a Cloudflare model name starting with `@cf/`.
|
|
352
|
+
|
|
353
|
+
**Limitations** (from [Cloudflare's docs](https://developers.cloudflare.com/workers-ai/features/fine-tunes/)):
|
|
354
|
+
|
|
355
|
+
- Up to **100 LoRAs** per account.
|
|
356
|
+
- LoRA file must be **< 300 MB**.
|
|
357
|
+
- Must be trained with rank **r ≤ 8** (up to 32 on some models).
|
|
358
|
+
- **Streaming is not yet supported** through shiftgate for Cloudflare — you get a single response. (Streaming requests to `shiftgate serve` against Cloudflare return HTTP 501.)
|
|
359
|
+
|
|
327
360
|
---
|
|
328
361
|
|
|
329
362
|
## How to contribute adapters
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "shiftgate"
|
|
7
|
-
version = "0.2.0"
|
|
7
|
+
version = "0.2.1"
|
|
8
8
|
description = "Intelligent routing layer that automatically selects the right LoRA adapter for each task in your local agent loop."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.10"
|
|
@@ -52,12 +52,47 @@ def _main(
|
|
|
52
52
|
),
|
|
53
53
|
),
|
|
54
54
|
] = None,
|
|
55
|
+
cf_account_id: Annotated[
|
|
56
|
+
Optional[str],
|
|
57
|
+
typer.Option(
|
|
58
|
+
"--cf-account-id",
|
|
59
|
+
help="Cloudflare account ID. Sets CLOUDFLARE_ACCOUNT_ID for this run.",
|
|
60
|
+
),
|
|
61
|
+
] = None,
|
|
62
|
+
cf_api_token: Annotated[
|
|
63
|
+
Optional[str],
|
|
64
|
+
typer.Option(
|
|
65
|
+
"--cf-api-token",
|
|
66
|
+
help="Cloudflare API token. Sets CLOUDFLARE_API_TOKEN for this run.",
|
|
67
|
+
),
|
|
68
|
+
] = None,
|
|
69
|
+
verbose: Annotated[
|
|
70
|
+
bool,
|
|
71
|
+
typer.Option(
|
|
72
|
+
"--verbose",
|
|
73
|
+
"-v",
|
|
74
|
+
help="Enable DEBUG logging (shows routing internals like runtime filtering).",
|
|
75
|
+
),
|
|
76
|
+
] = False,
|
|
55
77
|
) -> None:
|
|
56
78
|
"""Global options applied before any command runs."""
|
|
57
|
-
|
|
58
|
-
import os
|
|
79
|
+
import os
|
|
59
80
|
|
|
81
|
+
if cerebras_key:
|
|
60
82
|
os.environ["CEREBRAS_API_KEY"] = cerebras_key
|
|
83
|
+
if cf_account_id:
|
|
84
|
+
os.environ["CLOUDFLARE_ACCOUNT_ID"] = cf_account_id
|
|
85
|
+
if cf_api_token:
|
|
86
|
+
os.environ["CLOUDFLARE_API_TOKEN"] = cf_api_token
|
|
87
|
+
|
|
88
|
+
if verbose:
|
|
89
|
+
import logging
|
|
90
|
+
|
|
91
|
+
logging.basicConfig(
|
|
92
|
+
level=logging.DEBUG,
|
|
93
|
+
format="%(levelname)s [%(name)s] %(message)s",
|
|
94
|
+
)
|
|
95
|
+
logging.getLogger("shiftgate").setLevel(logging.DEBUG)
|
|
61
96
|
|
|
62
97
|
|
|
63
98
|
# ---------------------------------------------------------------------------
|
|
@@ -83,17 +118,33 @@ def _get_embedder():
|
|
|
83
118
|
return Embedder()
|
|
84
119
|
|
|
85
120
|
|
|
86
|
-
def _active_runtimes(backend_router) -> set[str] | None:
|
|
87
|
-
"""Return the set of runtime names
|
|
121
|
+
def _active_runtimes(backend_router, adapter_reg=None) -> set[str] | None:
|
|
122
|
+
"""Return the set of runtime names usable on the active backend, or None.
|
|
88
123
|
|
|
89
124
|
``None`` means no backend is active → the router should not filter
|
|
90
125
|
(preview behaviour). An empty set means a backend is active but reports no
|
|
91
|
-
|
|
126
|
+
usable models.
|
|
127
|
+
|
|
128
|
+
For Cloudflare, base models are always available without any finetune
|
|
129
|
+
upload, so every registered ``@cf/`` adapter that has no finetune
|
|
130
|
+
``runtime_name`` is considered usable (base-model inference).
|
|
92
131
|
"""
|
|
93
132
|
active = backend_router.active_backend
|
|
94
133
|
if active is None:
|
|
95
134
|
return None
|
|
96
|
-
|
|
135
|
+
|
|
136
|
+
usable = set(active.list_loaded_adapters_cached())
|
|
137
|
+
|
|
138
|
+
from shiftgate.runtime.backend import CloudflareBackend
|
|
139
|
+
|
|
140
|
+
if isinstance(active, CloudflareBackend) and adapter_reg is not None:
|
|
141
|
+
for adapter in adapter_reg.list_adapters():
|
|
142
|
+
is_cf_base = (adapter.base_model or "").startswith("@cf/")
|
|
143
|
+
has_finetune = bool((adapter.runtime_name or "").strip())
|
|
144
|
+
if is_cf_base and not has_finetune:
|
|
145
|
+
usable.add(adapter.effective_backend_name())
|
|
146
|
+
|
|
147
|
+
return usable
|
|
97
148
|
|
|
98
149
|
|
|
99
150
|
def _auto_link_adapter(adapter: AdapterEntry, task_reg) -> list[str]:
|
|
@@ -160,9 +211,15 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
|
|
|
160
211
|
"""
|
|
161
212
|
from shiftgate.runtime.backend import BackendRouter
|
|
162
213
|
|
|
214
|
+
# A Cloudflare base model (@cf/...) implies this is a Cloudflare adapter —
|
|
215
|
+
# prefer the Cloudflare backend for verification when it's reachable.
|
|
216
|
+
is_cloudflare = (adapter.base_model or "").startswith("@cf/")
|
|
217
|
+
|
|
163
218
|
try:
|
|
164
219
|
with console.status("[cyan]Verifying adapter against running backend…[/cyan]"):
|
|
165
220
|
router = BackendRouter()
|
|
221
|
+
if is_cloudflare and router._cloudflare.is_available():
|
|
222
|
+
router.select("cloudflare")
|
|
166
223
|
is_loaded, backend_name = router.verify_adapter(adapter)
|
|
167
224
|
except Exception as exc: # pragma: no cover - defensive, should not happen
|
|
168
225
|
logger_msg = f"verification error: {exc}"
|
|
@@ -179,9 +236,14 @@ def _verify_runtime_adapter(adapter: AdapterEntry, adapter_reg) -> None:
|
|
|
179
236
|
console.print(f" [green]Backend: {backend_name} ✓ verified[/green]")
|
|
180
237
|
else:
|
|
181
238
|
adapter.verified = False
|
|
239
|
+
hint = (
|
|
240
|
+
"— upload it with `npx wrangler ai finetune create`"
|
|
241
|
+
if backend_name == "cloudflare"
|
|
242
|
+
else "— did you pass --lora-modules?"
|
|
243
|
+
)
|
|
182
244
|
console.print(
|
|
183
245
|
f" [yellow]Backend: {backend_name} ⚠ runtime '{runtime}' not loaded "
|
|
184
|
-
"
|
|
246
|
+
f"{hint}[/yellow]"
|
|
185
247
|
)
|
|
186
248
|
|
|
187
249
|
adapter_reg.save()
|
|
@@ -207,12 +269,23 @@ def init() -> None:
|
|
|
207
269
|
task_reg = TaskRegistry.load()
|
|
208
270
|
|
|
209
271
|
if task_reg.embeddings_ready():
|
|
210
|
-
console.print(
|
|
272
|
+
console.print(
|
|
273
|
+
"[dim]Task centroids already computed — skipping re-embed "
|
|
274
|
+
"(delete embeddings_cache.npy to force refresh).[/dim]"
|
|
275
|
+
)
|
|
211
276
|
else:
|
|
212
277
|
console.print("[cyan]Computing task embeddings (first run — model download may take a moment)…[/cyan]")
|
|
213
278
|
embedder = _get_embedder()
|
|
214
279
|
task_reg.compute_embeddings(embedder)
|
|
215
|
-
console.print("[green]
|
|
280
|
+
console.print("[green]OK[/green] Embeddings computed.")
|
|
281
|
+
|
|
282
|
+
# Centroids can exist while the ONNX runtime model is missing (e.g. after
|
|
283
|
+
# deleting fastembed_cache). Always warm-load so routing works.
|
|
284
|
+
console.print("[cyan]Loading embedder model (one-time ~30 MB download if not cached)…[/cyan]")
|
|
285
|
+
from shiftgate.router.embedder import warm_up
|
|
286
|
+
|
|
287
|
+
dim = warm_up()
|
|
288
|
+
console.print(f"[green]OK[/green] Embedder ready (dim={dim}).")
|
|
216
289
|
|
|
217
290
|
task_reg.save()
|
|
218
291
|
console.print(f"[green]✓[/green] Task registry saved to {shiftgate_dir}")
|
|
@@ -301,6 +374,7 @@ def adapter_add(
|
|
|
301
374
|
"""
|
|
302
375
|
from shiftgate.registry.adapter_registry import (
|
|
303
376
|
AdapterRegistry,
|
|
377
|
+
adapter_from_base_model,
|
|
304
378
|
adapter_from_hf,
|
|
305
379
|
adapter_from_local,
|
|
306
380
|
adapter_from_runtime,
|
|
@@ -347,12 +421,22 @@ def adapter_add(
|
|
|
347
421
|
**shared_kwargs,
|
|
348
422
|
)
|
|
349
423
|
|
|
350
|
-
# ---
|
|
424
|
+
# --- Mode D: Cloudflare base model (always available, no finetune) ---
|
|
425
|
+
elif (base or "").startswith("@cf/"):
|
|
426
|
+
base_model_kwargs = {k: v for k, v in shared_kwargs.items() if k != "base_model"}
|
|
427
|
+
adapter = adapter_from_base_model(
|
|
428
|
+
base_model=base,
|
|
429
|
+
adapter_id=adapter_id or identifier,
|
|
430
|
+
**base_model_kwargs,
|
|
431
|
+
)
|
|
432
|
+
|
|
433
|
+
# --- Ambiguous: no '/', no --local, no --runtime, no @cf/ base ---
|
|
351
434
|
else:
|
|
352
435
|
console.print(
|
|
353
436
|
f"[red]Error:[/red] '{identifier}' doesn't look like a HuggingFace repo ID (missing '/').\n"
|
|
354
|
-
" Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter
|
|
355
|
-
" use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter
|
|
437
|
+
" Use [cyan]--local /path/to/adapter[/cyan] to register a local adapter,\n"
|
|
438
|
+
" use [cyan]--runtime <backend-name>[/cyan] for a runtime-registered adapter, or\n"
|
|
439
|
+
" use [cyan]--base @cf/...[/cyan] for a Cloudflare Workers AI base model."
|
|
356
440
|
)
|
|
357
441
|
raise typer.Exit(1)
|
|
358
442
|
|
|
@@ -363,6 +447,11 @@ def adapter_add(
|
|
|
363
447
|
# has it loaded. Purely informational — never fails the add command.
|
|
364
448
|
if adapter.runtime_name:
|
|
365
449
|
_verify_runtime_adapter(adapter, adapter_reg)
|
|
450
|
+
elif (adapter.base_model or "").startswith("@cf/"):
|
|
451
|
+
console.print(
|
|
452
|
+
" [green]Backend:[/green] cloudflare "
|
|
453
|
+
"[green]✓[/green] base model is always available (no upload needed)"
|
|
454
|
+
)
|
|
366
455
|
|
|
367
456
|
|
|
368
457
|
@adapter_app.command("list")
|
|
@@ -490,7 +579,7 @@ def route(
|
|
|
490
579
|
|
|
491
580
|
backend_router = BackendRouter()
|
|
492
581
|
backend_name = backend_router.detect()
|
|
493
|
-
available_runtimes = _active_runtimes(backend_router)
|
|
582
|
+
available_runtimes = _active_runtimes(backend_router, adapter_reg)
|
|
494
583
|
|
|
495
584
|
try:
|
|
496
585
|
trace, match_result = routing.route(
|
|
@@ -549,7 +638,7 @@ def run(
|
|
|
549
638
|
|
|
550
639
|
backend_router = BackendRouter()
|
|
551
640
|
backend_name = backend_router.detect()
|
|
552
|
-
available_runtimes = _active_runtimes(backend_router)
|
|
641
|
+
available_runtimes = _active_runtimes(backend_router, adapter_reg)
|
|
553
642
|
|
|
554
643
|
try:
|
|
555
644
|
trace, match_result = routing.route(
|
|
@@ -716,9 +805,11 @@ def doctor() -> None:
|
|
|
716
805
|
router = BackendRouter()
|
|
717
806
|
backend_name = router.detect()
|
|
718
807
|
backend_url = router.active_backend_url
|
|
719
|
-
loaded_adapters:
|
|
808
|
+
loaded_adapters: set[str] = set()
|
|
720
809
|
if backend_name is not None and router._active is not None:
|
|
721
|
-
|
|
810
|
+
# Cloudflare base models are always available, so use the same
|
|
811
|
+
# usable-runtime computation the router uses when filtering.
|
|
812
|
+
loaded_adapters = _active_runtimes(router, adapter_reg) or set()
|
|
722
813
|
|
|
723
814
|
# --- 3. Per-adapter runtime availability ---
|
|
724
815
|
adapter_rows = []
|
|
@@ -774,7 +865,7 @@ def serve(
|
|
|
774
865
|
str,
|
|
775
866
|
typer.Option(
|
|
776
867
|
"--backend",
|
|
777
|
-
help="Backend to forward to: auto | ollama | vllm | cerebras.",
|
|
868
|
+
help="Backend to forward to: auto | ollama | vllm | cerebras | cloudflare.",
|
|
778
869
|
),
|
|
779
870
|
] = "auto",
|
|
780
871
|
) -> None:
|
|
@@ -244,6 +244,32 @@ def adapter_from_runtime(
|
|
|
244
244
|
)
|
|
245
245
|
|
|
246
246
|
|
|
247
|
+
def adapter_from_base_model(
|
|
248
|
+
base_model: str,
|
|
249
|
+
*,
|
|
250
|
+
adapter_id: str,
|
|
251
|
+
name: str | None = None,
|
|
252
|
+
tags: list[str] | None = None,
|
|
253
|
+
description: str | None = None,
|
|
254
|
+
benchmark_score: float | None = None,
|
|
255
|
+
) -> AdapterEntry:
|
|
256
|
+
"""Build an AdapterEntry that routes to a base model with no finetune.
|
|
257
|
+
|
|
258
|
+
Used for backends whose base models are always available without any
|
|
259
|
+
upload (e.g. Cloudflare Workers AI ``@cf/`` models). ``runtime_name`` is
|
|
260
|
+
left unset so the backend runs the base model directly.
|
|
261
|
+
"""
|
|
262
|
+
slug = _slugify(adapter_id)
|
|
263
|
+
return AdapterEntry(
|
|
264
|
+
id=slug,
|
|
265
|
+
name=name or slug.replace("-", " ").title(),
|
|
266
|
+
base_model=base_model,
|
|
267
|
+
task_tags=list(tags or []),
|
|
268
|
+
description=description or f"Base model: {base_model}",
|
|
269
|
+
benchmark_score=benchmark_score,
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
|
|
247
273
|
# ---------------------------------------------------------------------------
|
|
248
274
|
# Internal helpers
|
|
249
275
|
# ---------------------------------------------------------------------------
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Text embedder backed by fastembed.
|
|
3
|
+
|
|
4
|
+
Uses ``BAAI/bge-small-en-v1.5`` — a compact (33 M param) model that runs
|
|
5
|
+
efficiently on CPU. The model is downloaded once by fastembed and cached in
|
|
6
|
+
``~/.shiftgate/fastembed_cache`` (avoids Windows ``%TEMP%`` corruption issues).
|
|
7
|
+
|
|
8
|
+
A module-level singleton (``_MODEL``) is created lazily on first use so that
|
|
9
|
+
importing this module is cheap. The model is NOT re-created between calls.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
import time
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Stable cache location — fastembed defaults to %TEMP% on Windows, which is
|
|
24
|
+
# prone to partial downloads and "file sizes do not match" corruption.
|
|
25
|
+
FASTEMBED_CACHE_DIR = Path.home() / ".shiftgate" / "fastembed_cache"
|
|
26
|
+
|
|
27
|
+
# -------------------------------------------------------------------------
|
|
28
|
+
# Default model — small, CPU-friendly, strong quality/speed trade-off.
|
|
29
|
+
# -------------------------------------------------------------------------
|
|
30
|
+
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
|
|
31
|
+
|
|
32
|
+
# HuggingFace downloads can flake; retry a few times before giving up.
|
|
33
|
+
_LOAD_RETRIES = 3
|
|
34
|
+
_LOAD_RETRY_DELAY_S = 2.0
|
|
35
|
+
|
|
36
|
+
# Module-level singleton; populated on first call to `_get_model()`.
|
|
37
|
+
_MODEL: Any | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _is_retryable_download_error(exc: BaseException) -> bool:
|
|
41
|
+
"""Return True for transient HuggingFace / network failures."""
|
|
42
|
+
msg = str(exc).lower()
|
|
43
|
+
needles = (
|
|
44
|
+
"server disconnected",
|
|
45
|
+
"connection reset",
|
|
46
|
+
"connection aborted",
|
|
47
|
+
"timed out",
|
|
48
|
+
"timeout",
|
|
49
|
+
"temporary failure",
|
|
50
|
+
"503",
|
|
51
|
+
"502",
|
|
52
|
+
"429",
|
|
53
|
+
)
|
|
54
|
+
return any(n in msg for n in needles)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _format_load_error(model_name: str, exc: BaseException) -> str:
|
|
58
|
+
cache_hint = str(FASTEMBED_CACHE_DIR)
|
|
59
|
+
if _is_retryable_download_error(exc):
|
|
60
|
+
return (
|
|
61
|
+
f"Failed to download embedding model '{model_name}' from HuggingFace: {exc}\n"
|
|
62
|
+
"This is usually a transient network/rate-limit issue. Retry:\n"
|
|
63
|
+
f" uv run shiftgate init\n"
|
|
64
|
+
"Optional: set HF_TOKEN for higher HuggingFace rate limits.\n"
|
|
65
|
+
f"If downloads keep failing, delete '{cache_hint}' and retry."
|
|
66
|
+
)
|
|
67
|
+
return (
|
|
68
|
+
f"Failed to load embedding model '{model_name}': {exc}\n"
|
|
69
|
+
"If you see NO_SUCHFILE or 'file sizes do not match', delete the cache "
|
|
70
|
+
"and retry:\n"
|
|
71
|
+
f" Remove-Item -Recurse -Force '{cache_hint}'\n"
|
|
72
|
+
"Also clear any stale copy at $env:TEMP\\fastembed_cache, then run "
|
|
73
|
+
"`shiftgate init` again."
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _get_model(model_name: str = DEFAULT_MODEL) -> Any:
|
|
78
|
+
"""Return the fastembed TextEmbedding singleton, creating it if needed.
|
|
79
|
+
|
|
80
|
+
The model is loaded once per process. If you need a different model,
|
|
81
|
+
call ``reset_model()`` first.
|
|
82
|
+
"""
|
|
83
|
+
global _MODEL
|
|
84
|
+
if _MODEL is None:
|
|
85
|
+
try:
|
|
86
|
+
from fastembed import TextEmbedding # type: ignore
|
|
87
|
+
except ImportError as exc:
|
|
88
|
+
raise ImportError(
|
|
89
|
+
"fastembed is required for shiftgate routing. "
|
|
90
|
+
"Install it with: pip install fastembed"
|
|
91
|
+
) from exc
|
|
92
|
+
|
|
93
|
+
FASTEMBED_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
94
|
+
logger.info(
|
|
95
|
+
"Loading embedding model '%s' (first use — one-time download may occur)…",
|
|
96
|
+
model_name,
|
|
97
|
+
)
|
|
98
|
+
last_exc: BaseException | None = None
|
|
99
|
+
for attempt in range(1, _LOAD_RETRIES + 1):
|
|
100
|
+
try:
|
|
101
|
+
_MODEL = TextEmbedding(
|
|
102
|
+
model_name=model_name,
|
|
103
|
+
cache_dir=str(FASTEMBED_CACHE_DIR),
|
|
104
|
+
)
|
|
105
|
+
break
|
|
106
|
+
except Exception as exc:
|
|
107
|
+
last_exc = exc
|
|
108
|
+
if attempt < _LOAD_RETRIES and _is_retryable_download_error(exc):
|
|
109
|
+
logger.warning(
|
|
110
|
+
"Embedder load attempt %d/%d failed (%s); retrying…",
|
|
111
|
+
attempt,
|
|
112
|
+
_LOAD_RETRIES,
|
|
113
|
+
exc,
|
|
114
|
+
)
|
|
115
|
+
time.sleep(_LOAD_RETRY_DELAY_S * attempt)
|
|
116
|
+
continue
|
|
117
|
+
raise RuntimeError(_format_load_error(model_name, exc)) from exc
|
|
118
|
+
else:
|
|
119
|
+
assert last_exc is not None
|
|
120
|
+
raise RuntimeError(_format_load_error(model_name, last_exc)) from last_exc
|
|
121
|
+
logger.info("Embedding model loaded.")
|
|
122
|
+
return _MODEL
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def warm_up(model_name: str = DEFAULT_MODEL) -> int:
|
|
126
|
+
"""Load the embedder and run a dummy embed. Returns embedding dimension."""
|
|
127
|
+
vec = Embedder(model_name).embed("warmup")
|
|
128
|
+
return int(vec.shape[0])
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def reset_model() -> None:
|
|
132
|
+
"""Force the next embed call to recreate the model singleton.
|
|
133
|
+
|
|
134
|
+
Useful in tests or when switching models at runtime.
|
|
135
|
+
"""
|
|
136
|
+
global _MODEL
|
|
137
|
+
_MODEL = None
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class Embedder:
|
|
141
|
+
"""Thin wrapper around the fastembed TextEmbedding model.
|
|
142
|
+
|
|
143
|
+
All embedding operations are synchronous and run on CPU. The model
|
|
144
|
+
is shared across all ``Embedder`` instances via the module-level singleton.
|
|
145
|
+
"""
|
|
146
|
+
|
|
147
|
+
def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
|
|
148
|
+
self._model_name = model_name
|
|
149
|
+
|
|
150
|
+
@property
|
|
151
|
+
def _model(self) -> Any:
|
|
152
|
+
return _get_model(self._model_name)
|
|
153
|
+
|
|
154
|
+
def embed(self, text: str) -> np.ndarray:
|
|
155
|
+
"""Embed a single text string.
|
|
156
|
+
|
|
157
|
+
Returns a 1-D float32 numpy array of shape ``(dim,)``.
|
|
158
|
+
The vector is **not** L2-normalised here; normalisation is done
|
|
159
|
+
where appropriate (e.g. when computing task centroids).
|
|
160
|
+
"""
|
|
161
|
+
# fastembed returns a generator of numpy arrays.
|
|
162
|
+
results = list(self._model.embed([text]))
|
|
163
|
+
return np.array(results[0], dtype=np.float32)
|
|
164
|
+
|
|
165
|
+
def embed_batch(self, texts: list[str]) -> np.ndarray:
|
|
166
|
+
"""Embed a list of strings.
|
|
167
|
+
|
|
168
|
+
Returns a 2-D float32 numpy array of shape ``(n, dim)`` where
|
|
169
|
+
``n = len(texts)``.
|
|
170
|
+
"""
|
|
171
|
+
if not texts:
|
|
172
|
+
raise ValueError("embed_batch received an empty list.")
|
|
173
|
+
results = list(self._model.embed(texts))
|
|
174
|
+
return np.array(results, dtype=np.float32)
|
|
@@ -80,6 +80,15 @@ def route(
|
|
|
80
80
|
query_embedding = embedder.embed(query)
|
|
81
81
|
all_tasks = task_registry.get_all_tasks()
|
|
82
82
|
ranked = top_k_tasks(query_embedding, all_tasks, k=top_k)
|
|
83
|
+
|
|
84
|
+
if available_runtimes is not None:
|
|
85
|
+
logger.debug(
|
|
86
|
+
"filtering adapters to backend runtimes: %s",
|
|
87
|
+
sorted(available_runtimes),
|
|
88
|
+
)
|
|
89
|
+
else:
|
|
90
|
+
logger.debug("no active backend — adapter runtime filtering disabled")
|
|
91
|
+
|
|
83
92
|
result = select_adapter(ranked, adapter_registry, available_runtimes=available_runtimes)
|
|
84
93
|
|
|
85
94
|
selected_id = result.selected_adapter.id if result.selected_adapter else None
|