shiftgate 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shiftgate/__init__.py +9 -0
- shiftgate/cli.py +513 -0
- shiftgate/feedback/__init__.py +1 -0
- shiftgate/feedback/loop.py +182 -0
- shiftgate/registry/__init__.py +1 -0
- shiftgate/registry/adapter_registry.py +162 -0
- shiftgate/registry/schemas.py +115 -0
- shiftgate/registry/task_registry.py +186 -0
- shiftgate/router/__init__.py +1 -0
- shiftgate/router/embedder.py +95 -0
- shiftgate/router/matcher.py +115 -0
- shiftgate/router/router.py +97 -0
- shiftgate/runtime/__init__.py +1 -0
- shiftgate/runtime/backend.py +289 -0
- shiftgate/utils/__init__.py +1 -0
- shiftgate/utils/display.py +297 -0
- shiftgate-0.1.0.dist-info/METADATA +273 -0
- shiftgate-0.1.0.dist-info/RECORD +20 -0
- shiftgate-0.1.0.dist-info/WHEEL +4 -0
- shiftgate-0.1.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Text embedder backed by fastembed.
|
|
3
|
+
|
|
4
|
+
Uses ``BAAI/bge-small-en-v1.5`` — a compact (33 M param) model that runs
|
|
5
|
+
efficiently on CPU. The model is downloaded once by fastembed and cached in
|
|
6
|
+
``~/.cache/fastembed``.
|
|
7
|
+
|
|
8
|
+
A module-level singleton (``_MODEL``) is created lazily on first use so that
|
|
9
|
+
importing this module is cheap. The model is NOT re-created between calls.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
# -------------------------------------------------------------------------
|
|
22
|
+
# Default model — small, CPU-friendly, strong quality/speed trade-off.
|
|
23
|
+
# -------------------------------------------------------------------------
|
|
24
|
+
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
|
|
25
|
+
|
|
26
|
+
# Module-level singleton; populated on first call to `_get_model()`.
|
|
27
|
+
_MODEL: Any | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_model(model_name: str = DEFAULT_MODEL) -> Any:
    """Fetch the shared fastembed ``TextEmbedding`` instance, building it on demand.

    The instance is kept in the module-level ``_MODEL`` slot.  When that slot
    is already populated it is returned unchanged and ``model_name`` is not
    consulted; call ``reset_model()`` beforehand to load a different model.

    Parameters
    ----------
    model_name:
        Name passed to ``TextEmbedding`` when the singleton has to be created.

    Raises
    ------
    ImportError
        If the ``fastembed`` package cannot be imported.
    """
    global _MODEL

    # Fast path: the singleton already exists, nothing to build.
    if _MODEL is not None:
        return _MODEL

    # Imported lazily so that importing this module does not pull in fastembed.
    try:
        from fastembed import TextEmbedding  # type: ignore
    except ImportError as exc:
        message = (
            "fastembed is required for shiftgate routing. "
            "Install it with: pip install fastembed"
        )
        raise ImportError(message) from exc

    logger.info(
        "Loading embedding model '%s' (first use — one-time download may occur)…",
        model_name,
    )
    _MODEL = TextEmbedding(model_name=model_name)
    logger.info("Embedding model loaded.")
    return _MODEL
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def reset_model() -> None:
    """Discard the cached model so the next embed call builds a new one.

    Handy in tests, or when a different model has to be loaded at runtime.
    """
    # Clearing the slot is enough: ``_get_model`` rebuilds whenever it is None.
    global _MODEL
    _MODEL = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Embedder:
    """Lightweight facade over the fastembed TextEmbedding model.

    Every call is synchronous.  An instance only remembers a model name; the
    model itself is obtained through the module-level ``_get_model`` helper,
    so all ``Embedder`` objects share one underlying model.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
        # Only the name is stored here; the model is resolved on first use.
        self._model_name = model_name

    @property
    def _model(self) -> Any:
        """The shared model instance, as returned by ``_get_model``."""
        return _get_model(self._model_name)

    def embed(self, text: str) -> np.ndarray:
        """Embed one text string.

        Returns the first vector produced for ``text`` as a float32 numpy
        array (1-D, shape ``(dim,)``).  No L2-normalisation is applied here;
        callers normalise where they need to (e.g. for task centroids).
        """
        # ``embed`` yields vectors lazily; materialise them before indexing.
        vectors = list(self._model.embed([text]))
        first_vector = vectors[0]
        return np.array(first_vector, dtype=np.float32)

    def embed_batch(self, texts: list[str]) -> np.ndarray:
        """Embed several strings in one call.

        Returns a float32 numpy array with one row per input text
        (2-D, shape ``(n, dim)`` where ``n = len(texts)``).

        Raises
        ------
        ValueError
            If ``texts`` is empty.
        """
        if texts:
            vectors = list(self._model.embed(texts))
            return np.array(vectors, dtype=np.float32)
        raise ValueError("embed_batch received an empty list.")
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cosine-similarity matcher: maps query embeddings to task clusters and adapters.
|
|
3
|
+
|
|
4
|
+
This module is deliberately stateless — all context (registries, embeddings)
|
|
5
|
+
is passed explicitly so the functions are easy to test in isolation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from shiftgate.registry.schemas import AdapterEntry, TaskCluster
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def top_k_tasks(
    query_embedding: np.ndarray,
    task_clusters: list[TaskCluster],
    k: int = 3,
) -> list[tuple[TaskCluster, float]]:
    """Return the top-K task clusters by cosine similarity to the query.

    Parameters
    ----------
    query_embedding:
        1-D float32 array of shape ``(dim,)``.  Need not be L2-normalised;
        this function normalises internally.
    task_clusters:
        All clusters in the registry.  Clusters without a computed centroid
        are silently skipped.
    k:
        Maximum number of clusters to return.  Capped at the number of
        eligible clusters; values ``<= 0`` yield an empty list.

    Returns
    -------
    list of ``(TaskCluster, score)`` pairs sorted by score descending.

    Raises
    ------
    ValueError
        If no cluster has a centroid, or if the query embedding has zero norm.
    """
    eligible = [t for t in task_clusters if t.embedding_centroid is not None]
    if not eligible:
        raise ValueError(
            "No task cluster has a computed embedding centroid. "
            "Run `shiftgate init` to compute embeddings."
        )

    # Stack centroids into a matrix for a vectorised dot product.
    centroid_matrix = np.array(
        [t.embedding_centroid for t in eligible], dtype=np.float32
    )  # shape: (n_tasks, dim)

    # L2-normalise the query vector.
    q_norm = np.linalg.norm(query_embedding)
    if q_norm == 0:
        raise ValueError("Query produced a zero-norm embedding.")
    q_unit = query_embedding / q_norm

    # The dot product equals cosine similarity only if the centroids are unit
    # vectors; this function does not re-normalise them (the original comment
    # states they are normalised at compute time in task_registry.py).
    scores = centroid_matrix @ q_unit  # shape: (n_tasks,)

    # Clamp k into [0, n_tasks].  Without the lower bound a negative k would
    # become a negative slice stop and silently return "all but the worst"
    # clusters instead of an empty result.
    k = max(0, min(k, len(eligible)))
    top_indices = np.argsort(scores)[::-1][:k]

    return [(eligible[i], float(scores[i])) for i in top_indices]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def select_adapter(
    top_tasks: list[tuple[TaskCluster, float]],
    adapter_registry,  # AdapterRegistry — avoid circular import with string hint
) -> tuple[AdapterEntry, TaskCluster, float]:
    """Pick the first registered adapter reachable from the ranked tasks.

    Tasks are visited in the order given (highest similarity first).  For
    each task the ``preferred_adapters`` IDs are tried before the
    ``fallback_adapters`` IDs, and the first ID that the registry resolves
    wins.

    Parameters
    ----------
    top_tasks:
        Output of ``top_k_tasks`` — ``(TaskCluster, score)`` pairs, best first.
    adapter_registry:
        Object exposing ``get_adapter(adapter_id)``; a ``None`` result means
        the ID is not registered.

    Returns
    -------
    ``(AdapterEntry, TaskCluster, similarity_score)`` for the winning adapter
    and the task through which it was found.

    Raises
    ------
    NoAdapterError
        If none of the candidate IDs of any task is registered.
    """
    for cluster, similarity in top_tasks:
        # Preferred adapters come first, fallbacks after, for the same task.
        candidate_ids = [*cluster.preferred_adapters, *cluster.fallback_adapters]
        for candidate_id in candidate_ids:
            entry = adapter_registry.get_adapter(candidate_id)
            if entry is None:
                continue
            logger.debug(
                "Selected adapter '%s' via task '%s' (score=%.4f)",
                entry.id,
                cluster.id,
                similarity,
            )
            return entry, cluster, similarity

    # Nothing resolved — report which tasks were examined.
    examined = [cluster.id for cluster, _ in top_tasks]
    raise NoAdapterError(
        f"No registered adapter found for tasks {examined}. "
        "Add adapters with `shiftgate adapter add <hf_repo>`."
    )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class NoAdapterError(RuntimeError):
    """Signals that the matcher found no registered adapter for a query.

    Raised by ``select_adapter`` when none of the candidate adapter IDs of
    the ranked tasks resolves to an entry in the adapter registry.
    """
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main routing orchestrator: query string → RoutingTrace.
|
|
3
|
+
|
|
4
|
+
This module ties together the embedder, matcher, and registries.
|
|
5
|
+
It is the single function that CLI commands and the runtime backend call.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import uuid
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
|
|
14
|
+
from shiftgate.registry.adapter_registry import AdapterRegistry
|
|
15
|
+
from shiftgate.registry.schemas import RoutingTrace
|
|
16
|
+
from shiftgate.registry.task_registry import TaskRegistry
|
|
17
|
+
from shiftgate.router.embedder import Embedder
|
|
18
|
+
from shiftgate.router.matcher import NoAdapterError, select_adapter, top_k_tasks
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def route(
    query: str,
    task_registry: TaskRegistry,
    adapter_registry: AdapterRegistry,
    embedder: Embedder,
    top_k: int = 3,
) -> RoutingTrace:
    """Resolve a query string to an adapter and describe the decision.

    Steps
    -----
    1. Embed the query with ``embedder``.
    2. Rank the registry's tasks against the query via ``top_k_tasks``.
    3. Let ``select_adapter`` walk the ranked tasks and pick the first
       registered adapter (preferred adapters first, then fallbacks).
    4. Build and return a ``RoutingTrace``.

    Parameters
    ----------
    query:
        The user's natural-language query or instruction.
    task_registry:
        Registry providing ``embeddings_ready()`` and ``get_all_tasks()``.
    adapter_registry:
        Registry used to resolve adapter IDs.
    embedder:
        ``Embedder`` used to embed the query.
    top_k:
        How many top-ranked tasks are handed to adapter selection.
        Defaults to 3.

    Returns
    -------
    A ``RoutingTrace`` with a fresh hex UUID, the query, the matched task ID,
    the similarity score, the selected adapter ID and a UTC ISO-8601
    timestamp.  The trace is only constructed here, not stored.

    Raises
    ------
    ValueError
        If the task registry reports that embeddings are not ready.
    NoAdapterError
        If no registered adapter matches any of the top-K tasks.
    """
    # Guard: ranking is impossible without pre-computed task embeddings.
    if not task_registry.embeddings_ready():
        raise ValueError(
            "Task embeddings are not initialised. Run `shiftgate init` first."
        )

    # Embed → rank → select.
    query_vector = embedder.embed(query)
    candidates = top_k_tasks(query_vector, task_registry.get_all_tasks(), k=top_k)
    adapter, matched_task, score = select_adapter(candidates, adapter_registry)

    # Collect the trace fields first, then build the record in one go.
    trace_fields = {
        "id": uuid.uuid4().hex,
        "query": query,
        "matched_task_id": matched_task.id,
        "similarity_score": score,
        "selected_adapter_id": adapter.id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    trace = RoutingTrace(**trace_fields)

    # Only the first 60 characters of the query are logged.
    logger.info(
        "Routed '%s' → task='%s' (%.2f%%) → adapter='%s'",
        query[:60],
        matched_task.id,
        score * 100,
        adapter.id,
    )
    return trace
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Runtime sub-package: thin clients for Ollama and vLLM backends."""
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Thin clients for local LLM inference backends.
|
|
3
|
+
|
|
4
|
+
Two backends are supported:
|
|
5
|
+
|
|
6
|
+
Ollama (http://localhost:11434)
|
|
7
|
+
Open-source inference server that supports LoRA adapters via custom
|
|
8
|
+
Modelfiles. To use a LoRA with Ollama, create a Modelfile like:
|
|
9
|
+
|
|
10
|
+
FROM llama3
|
|
11
|
+
ADAPTER /path/to/adapter.safetensors
|
|
12
|
+
|
|
13
|
+
and run ``ollama create my-model -f Modelfile``. shiftgate sets
|
|
14
|
+
``model`` to the adapter's Ollama model name (derived from adapter.id).
|
|
15
|
+
|
|
16
|
+
vLLM (http://localhost:8000)
|
|
17
|
+
Provides an OpenAI-compatible ``/v1/chat/completions`` endpoint.
|
|
18
|
+
LoRA adapters are loaded at server start-up with ``--lora-modules``
|
|
19
|
+
and are addressed by name via the ``model`` field in the request.
|
|
20
|
+
|
|
21
|
+
Both backends are auto-detected by ``BackendRouter``, which pings each
|
|
22
|
+
health endpoint and delegates to whichever is available.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
import logging
|
|
28
|
+
from abc import ABC, abstractmethod
|
|
29
|
+
|
|
30
|
+
import httpx
|
|
31
|
+
|
|
32
|
+
from shiftgate.registry.schemas import AdapterEntry
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
# Default timeouts (seconds).
|
|
37
|
+
_CONNECT_TIMEOUT = 3.0
|
|
38
|
+
_READ_TIMEOUT = 120.0
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BaseBackend(ABC):
    """Common interface that every inference backend implements."""

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether the backend can currently be reached."""

    @abstractmethod
    def generate(self, prompt: str, adapter: AdapterEntry) -> str:
        """Submit ``prompt`` to the backend and return the text it generates."""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# ---------------------------------------------------------------------------
|
|
54
|
+
# Ollama
|
|
55
|
+
# ---------------------------------------------------------------------------
|
|
56
|
+
|
|
57
|
+
class OllamaBackend(BaseBackend):
    """Thin httpx client for the Ollama inference server.

    Ollama API reference: https://github.com/ollama/ollama/blob/main/docs/api.md

    LoRA adapters in Ollama
    -----------------------
    Ollama does not have a first-class "load this adapter onto this base model"
    API at request time.  Instead you pre-register a composite model via a
    Modelfile::

        FROM llama3
        ADAPTER /path/to/my-lora.safetensors

        ollama create my-lora-model -f Modelfile

    shiftgate uses ``adapter.id`` as the Ollama model name by convention.
    Ensure your Ollama model names match shiftgate adapter IDs.
    """

    def __init__(self, base_url: str = "http://localhost:11434") -> None:
        # Trailing slashes are stripped so endpoint paths can be appended
        # with a single "/".
        self.base_url = base_url.rstrip("/")

    def is_available(self) -> bool:
        """Return True if ``GET /api/tags`` answers with HTTP 200."""
        try:
            r = httpx.get(f"{self.base_url}/api/tags", timeout=_CONNECT_TIMEOUT)
            return r.status_code == 200
        except Exception:
            # Deliberately broad: this is a best-effort probe, and any failure
            # to get an answer simply means "not available".
            return False

    def generate(
        self,
        prompt: str,
        adapter: AdapterEntry,
        *,
        model_name: str | None = None,
        stream: bool = False,
    ) -> str:
        """Generate text via Ollama's ``/api/generate`` endpoint.

        Parameters
        ----------
        prompt:
            The full prompt string.
        adapter:
            The selected ``AdapterEntry``; ``adapter.id`` is used as the
            Ollama model name unless overridden by ``model_name``.
        model_name:
            Override the Ollama model name (useful when the Ollama model name
            differs from the shiftgate adapter ID).
        stream:
            Forwarded to Ollama as the ``stream`` flag.  When True the reply
            body holds one JSON object per line; this client reads the whole
            body and returns the concatenated ``response`` fragments.

        Returns
        -------
        The generated text, or ``""`` when the reply carries no ``response``
        field.

        Raises
        ------
        BackendError
            If the HTTP request fails, returns an error status, or the body
            cannot be decoded in the expected format.
        """
        model = model_name or adapter.id
        payload = {"model": model, "prompt": prompt, "stream": stream}

        logger.debug("Ollama generate: model=%s", model)
        try:
            r = httpx.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=_READ_TIMEOUT,
            )
            r.raise_for_status()
        except httpx.HTTPError as exc:
            raise BackendError(f"Ollama request failed: {exc}") from exc

        try:
            if stream:
                # A streamed body is several JSON documents, so a single
                # ``r.json()`` call would fail on the second one.
                return self._join_stream_chunks(r.text)
            return r.json().get("response", "")
        except (ValueError, AttributeError, TypeError) as exc:
            # ValueError covers JSON decode errors; the other two cover bodies
            # that decode but do not have the expected object shape.
            raise BackendError(
                f"Unexpected Ollama response format: {r.text[:200]}"
            ) from exc

    @staticmethod
    def _join_stream_chunks(body: str) -> str:
        """Concatenate the ``response`` fields of a line-delimited JSON body."""
        import json  # local import: only needed on the streaming path

        fragments = []
        for line in body.splitlines():
            if not line.strip():
                continue  # tolerate blank separator lines
            fragments.append(json.loads(line).get("response", ""))
        return "".join(fragments)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# ---------------------------------------------------------------------------
|
|
131
|
+
# vLLM
|
|
132
|
+
# ---------------------------------------------------------------------------
|
|
133
|
+
|
|
134
|
+
class VLLMBackend(BaseBackend):
    """Thin httpx client for the vLLM OpenAI-compatible inference server.

    vLLM API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

    LoRA adapters in vLLM
    ---------------------
    vLLM loads LoRA adapters at startup via the ``--lora-modules`` flag::

        python -m vllm.entrypoints.openai.api_server \\
            --model meta-llama/Meta-Llama-3-8B \\
            --lora-modules my-lora=/path/to/adapter \\
            --enable-lora

    After that, passing ``"model": "my-lora"`` in a chat completion request
    automatically activates the adapter.
    """

    def __init__(self, base_url: str = "http://localhost:8000") -> None:
        # Trailing slashes are stripped so endpoint paths can be appended
        # with a single "/".
        self.base_url = base_url.rstrip("/")

    def is_available(self) -> bool:
        """Return True if ``GET /health`` answers with HTTP 200."""
        try:
            r = httpx.get(f"{self.base_url}/health", timeout=_CONNECT_TIMEOUT)
            return r.status_code == 200
        except Exception:
            # Deliberately broad: this is a best-effort probe, and any failure
            # to get an answer simply means "not available".
            return False

    def generate(
        self,
        prompt: str,
        adapter: AdapterEntry,
        *,
        lora_name: str | None = None,
        system_prompt: str = "You are a helpful assistant.",
    ) -> str:
        """Generate text via vLLM's ``/v1/chat/completions`` endpoint.

        Parameters
        ----------
        prompt:
            The user message content.
        adapter:
            The selected ``AdapterEntry``; ``adapter.id`` is used as the
            ``model`` field (which vLLM maps to the LoRA name) unless
            overridden by ``lora_name``.
        lora_name:
            Override the vLLM model/lora name.
        system_prompt:
            System message placed before the user message.

        Returns
        -------
        The ``content`` of the first choice's message.

        Raises
        ------
        BackendError
            If the HTTP request fails, returns an error status, the body is
            not valid JSON, or the JSON lacks the expected
            ``choices[0].message.content`` path.
        """
        model = lora_name or adapter.id
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
        }

        logger.debug("vLLM generate: model=%s", model)
        try:
            r = httpx.post(
                f"{self.base_url}/v1/chat/completions",
                json=payload,
                timeout=_READ_TIMEOUT,
            )
            r.raise_for_status()
        except httpx.HTTPError as exc:
            raise BackendError(f"vLLM request failed: {exc}") from exc

        # Decode separately so a non-JSON body is reported as a backend
        # problem instead of leaking a raw decode error to the caller.
        try:
            data = r.json()
        except ValueError as exc:
            raise BackendError(
                f"vLLM returned a non-JSON response: {r.text[:200]}"
            ) from exc

        try:
            return data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            # TypeError covers shapes such as a list body or ``choices: null``.
            raise BackendError(f"Unexpected vLLM response format: {data}") from exc
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# ---------------------------------------------------------------------------
|
|
214
|
+
# BackendRouter — auto-detects which backend is live
|
|
215
|
+
# ---------------------------------------------------------------------------
|
|
216
|
+
|
|
217
|
+
class BackendRouter:
    """Finds a running local backend and forwards generation calls to it.

    Ollama is probed before vLLM.  When neither answers, ``generate`` raises
    ``NoBackendError`` with start-up hints.
    """

    def __init__(
        self,
        ollama_url: str = "http://localhost:11434",
        vllm_url: str = "http://localhost:8000",
    ) -> None:
        self._ollama = OllamaBackend(ollama_url)
        self._vllm = VLLMBackend(vllm_url)
        # Set by ``detect()``; None until a backend has answered a probe.
        self._active: BaseBackend | None = None

    def detect(self) -> str | None:
        """Probe the backends in priority order and remember the first live one.

        Returns ``"ollama"``, ``"vllm"``, or ``None`` when neither responds
        (in which case any previously remembered backend is cleared).
        """
        probes = (("ollama", self._ollama), ("vllm", self._vllm))
        for label, backend in probes:
            # Stops at the first hit, so vLLM is only probed if Ollama is down.
            if backend.is_available():
                self._active = backend
                return label
        self._active = None
        return None

    def generate(self, prompt: str, adapter: AdapterEntry) -> str:
        """Forward ``prompt`` and ``adapter`` to the active backend.

        Runs ``detect()`` first when no backend has been remembered yet.

        Raises
        ------
        NoBackendError
            If no backend is active even after detection.
        """
        if self._active is None:
            self.detect()

        backend = self._active
        if backend is None:
            raise NoBackendError(
                "No inference backend detected.\n"
                " • Start Ollama : ollama serve\n"
                " • Start vLLM : python -m vllm.entrypoints.openai.api_server "
                "--model <base_model> --enable-lora\n\n"
                "shiftgate can route queries without a backend. "
                "Use `shiftgate route` to see routing decisions without inference."
            )
        return backend.generate(prompt, adapter)

    @property
    def active_backend_name(self) -> str | None:
        """Name of the remembered backend: 'ollama', 'vllm', or None."""
        current = self._active
        if current is self._ollama:
            return "ollama"
        return "vllm" if current is self._vllm else None
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# ---------------------------------------------------------------------------
|
|
281
|
+
# Exceptions
|
|
282
|
+
# ---------------------------------------------------------------------------
|
|
283
|
+
|
|
284
|
+
class BackendError(RuntimeError):
    """Signals that an inference backend call failed.

    Raised by the backend clients when the HTTP request fails or the reply
    does not have the expected format.
    """
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class NoBackendError(RuntimeError):
    """Signals that no local inference backend could be reached.

    Raised by ``BackendRouter.generate`` when detection finds neither
    backend available.
    """
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utils sub-package: Rich terminal UI helpers."""
|