shiftgate 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,95 @@
1
+ """
2
+ Text embedder backed by fastembed.
3
+
4
+ Uses ``BAAI/bge-small-en-v1.5`` — a compact (33 M param) model that runs
5
+ efficiently on CPU. The model is downloaded once by fastembed and cached in
6
+ ``~/.cache/fastembed``.
7
+
8
+ A module-level singleton (``_MODEL``) is created lazily on first use so that
9
+ importing this module is cheap. The model is NOT re-created between calls.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ from typing import Any
16
+
17
+ import numpy as np
18
+
19
logger = logging.getLogger(__name__)

# -------------------------------------------------------------------------
# Default model — small, CPU-friendly, strong quality/speed trade-off.
# -------------------------------------------------------------------------
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"

# Module-level singleton; populated on first call to `_get_model()`.
_MODEL: Any | None = None
# Name the singleton was created with. Lets `_get_model()` notice when a
# caller asks for a different model instead of silently ignoring the request.
_MODEL_NAME: str | None = None


def _get_model(model_name: str = DEFAULT_MODEL) -> Any:
    """Return the fastembed ``TextEmbedding`` singleton, creating it if needed.

    The model is loaded once per process. If a different ``model_name`` is
    requested while a model is already loaded, the already-loaded model is
    still returned (so every embedding in the process stays in one vector
    space), but a warning is logged. Call ``reset_model()`` first to switch
    models.

    Parameters
    ----------
    model_name:
        fastembed model identifier. Only used when the singleton is created.

    Returns
    -------
    The shared ``TextEmbedding`` instance.

    Raises
    ------
    ImportError
        If ``fastembed`` is not installed.
    """
    global _MODEL, _MODEL_NAME
    if _MODEL is not None:
        if model_name != _MODEL_NAME:
            # Previously this mismatch was silent, so `Embedder("x")` could
            # unknowingly embed with a different model.
            logger.warning(
                "Embedding model '%s' requested but '%s' is already loaded; "
                "call reset_model() to switch models.",
                model_name,
                _MODEL_NAME,
            )
        return _MODEL

    try:
        # Imported lazily so importing this module stays cheap.
        from fastembed import TextEmbedding  # type: ignore
    except ImportError as exc:
        raise ImportError(
            "fastembed is required for shiftgate routing. "
            "Install it with: pip install fastembed"
        ) from exc

    logger.info("Loading embedding model '%s' (first use — one-time download may occur)…", model_name)
    _MODEL = TextEmbedding(model_name=model_name)
    _MODEL_NAME = model_name
    logger.info("Embedding model loaded.")
    return _MODEL


def reset_model() -> None:
    """Force the next embed call to recreate the model singleton.

    Clears both the cached model and the name it was loaded with. Useful in
    tests or when switching models at runtime.
    """
    global _MODEL, _MODEL_NAME
    _MODEL = None
    _MODEL_NAME = None
59
+
60
+
61
class Embedder:
    """Synchronous, CPU-only facade over the shared fastembed model.

    An instance stores only the model name it was constructed with; the
    actual ``TextEmbedding`` object is fetched from the module-level
    singleton via ``_get_model`` on each use, so all ``Embedder`` instances
    share one model object.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
        self._model_name = model_name

    @property
    def _model(self) -> Any:
        # Looked up on every access; cheap once the singleton is loaded.
        return _get_model(self._model_name)

    def embed(self, text: str) -> np.ndarray:
        """Embed one string and return a 1-D float32 vector of shape ``(dim,)``.

        No L2 normalisation is applied here; callers that need unit vectors
        (e.g. centroid computation) normalise themselves.
        """
        # fastembed yields one vector per input text; drain it, keep the first.
        vectors = [*self._model.embed([text])]
        first = vectors[0]
        return np.array(first, dtype=np.float32)

    def embed_batch(self, texts: list[str]) -> np.ndarray:
        """Embed several strings at once.

        Returns a 2-D float32 array of shape ``(n, dim)`` with
        ``n == len(texts)``.

        Raises
        ------
        ValueError
            If ``texts`` is empty.
        """
        if not texts:
            raise ValueError("embed_batch received an empty list.")
        vectors = [*self._model.embed(texts)]
        return np.array(vectors, dtype=np.float32)
@@ -0,0 +1,115 @@
1
+ """
2
+ Cosine-similarity matcher: maps query embeddings to task clusters and adapters.
3
+
4
+ This module is deliberately stateless — all context (registries, embeddings)
5
+ is passed explicitly so the functions are easy to test in isolation.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+
12
+ import numpy as np
13
+
14
+ from shiftgate.registry.schemas import AdapterEntry, TaskCluster
15
+
16
# Per-module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
17
+
18
+
19
def top_k_tasks(
    query_embedding: np.ndarray,
    task_clusters: list[TaskCluster],
    k: int = 3,
) -> list[tuple[TaskCluster, float]]:
    """Return the top-K task clusters by cosine similarity to the query.

    Parameters
    ----------
    query_embedding:
        1-D array (or array-like) of shape ``(dim,)``. Need not be
        L2-normalised; this function normalises internally.
    task_clusters:
        All clusters in the registry. Clusters whose ``embedding_centroid``
        is ``None`` are silently skipped.
    k:
        Number of top clusters to return. Must be >= 1; values larger than
        the number of eligible clusters are clamped.

    Returns
    -------
    list of ``(TaskCluster, score)`` pairs sorted by score descending.
    Ties keep the clusters' original order.

    Raises
    ------
    ValueError
        If ``k < 1``; if no cluster has a centroid; if the query has a zero
        or non-finite norm; or if the query is not 1-D with the same
        dimension as the centroids.
    """
    # A negative k would silently drop trailing results via slicing, and
    # k == 0 would hand an empty ranking to the adapter selector.
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}.")

    eligible = [t for t in task_clusters if t.embedding_centroid is not None]
    if not eligible:
        raise ValueError(
            "No task cluster has a computed embedding centroid. "
            "Run `shiftgate init` to compute embeddings."
        )

    # Stack centroids into a matrix for a vectorised dot product.
    centroid_matrix = np.array(
        [t.embedding_centroid for t in eligible], dtype=np.float32
    )  # shape: (n_tasks, dim)

    query = np.asarray(query_embedding)
    if query.ndim != 1 or query.shape[0] != centroid_matrix.shape[1]:
        raise ValueError(
            f"Query embedding shape {query.shape} does not match centroid "
            f"dimension {centroid_matrix.shape[1]}."
        )

    # L2-normalise the query vector.
    q_norm = np.linalg.norm(query)
    if not np.isfinite(q_norm):
        raise ValueError("Query produced a non-finite embedding.")
    if q_norm == 0:
        raise ValueError("Query produced a zero-norm embedding.")
    q_unit = query / q_norm

    # Centroids are expected to be unit length already, but normalise them
    # here too so the scores are true cosine similarities even for a stale
    # or hand-edited registry. Zero-norm centroids are left as zeros (score 0)
    # instead of dividing by zero.
    c_norms = np.linalg.norm(centroid_matrix, axis=1, keepdims=True)
    centroid_unit = centroid_matrix / np.where(c_norms == 0, 1.0, c_norms)

    scores = centroid_unit @ q_unit  # shape: (n_tasks,)

    # Stable sort on negated scores: descending order, ties keep registry
    # order, and NaN scores sort last rather than first.
    top_indices = np.argsort(-scores, kind="stable")[: min(k, len(eligible))]

    return [(eligible[i], float(scores[i])) for i in top_indices]
68
+
69
+
70
def select_adapter(
    top_tasks: list[tuple[TaskCluster, float]],
    adapter_registry,  # AdapterRegistry — left unannotated to avoid a circular import
) -> tuple[AdapterEntry, TaskCluster, float]:
    """Pick the first registered adapter reachable from the ranked task list.

    Tasks are visited in the order given (expected: highest similarity
    first). Within each task, every ID in ``preferred_adapters`` is tried
    before any ID in ``fallback_adapters``. The first ID for which
    ``adapter_registry.get_adapter`` returns something other than ``None``
    wins.

    Parameters
    ----------
    top_tasks:
        Output of ``top_k_tasks`` — ``(TaskCluster, score)`` pairs, descending.
    adapter_registry:
        Object exposing ``get_adapter(adapter_id)``; presumably an
        ``AdapterRegistry``.

    Returns
    -------
    ``(AdapterEntry, TaskCluster, similarity_score)`` for the winning adapter.

    Raises
    ------
    NoAdapterError
        If no candidate ID of any task resolves to a registered adapter.
    """
    for task, score in top_tasks:
        # Preferred adapters first, then fallbacks.
        for adapter_id in (*task.preferred_adapters, *task.fallback_adapters):
            adapter = adapter_registry.get_adapter(adapter_id)
            if adapter is None:
                continue
            logger.debug(
                "Selected adapter '%s' via task '%s' (score=%.4f)",
                adapter.id,
                task.id,
                score,
            )
            return adapter, task, score

    # Nothing resolved — list the tasks that were tried in the error.
    tried_task_ids = [task.id for task, _ in top_tasks]
    raise NoAdapterError(
        f"No registered adapter found for tasks {tried_task_ids}. "
        "Add adapters with `shiftgate adapter add <hf_repo>`."
    )
112
+
113
+
114
class NoAdapterError(RuntimeError):
    """Signal that no registered adapter could be matched to a query.

    Raised by ``select_adapter`` once every candidate adapter ID of every
    ranked task has failed to resolve in the adapter registry.
    """
@@ -0,0 +1,97 @@
1
+ """
2
+ Main routing orchestrator: query string → RoutingTrace.
3
+
4
+ This module ties together the embedder, matcher, and registries.
5
+ It is the single function that CLI commands and the runtime backend call.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import uuid
12
+ from datetime import datetime, timezone
13
+
14
+ from shiftgate.registry.adapter_registry import AdapterRegistry
15
+ from shiftgate.registry.schemas import RoutingTrace
16
+ from shiftgate.registry.task_registry import TaskRegistry
17
+ from shiftgate.router.embedder import Embedder
18
+ from shiftgate.router.matcher import NoAdapterError, select_adapter, top_k_tasks
19
+
20
# Per-module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
21
+
22
+
23
def route(
    query: str,
    task_registry: TaskRegistry,
    adapter_registry: AdapterRegistry,
    embedder: Embedder,
    top_k: int = 3,
) -> RoutingTrace:
    """Route a query string to the best matching adapter.

    Steps
    -----
    1. Reject blank queries and registries without computed embeddings.
    2. Embed the query with ``embedder``.
    3. Rank task clusters by cosine similarity to their centroid embeddings
       (``top_k_tasks``).
    4. Walk the ranked tasks and pick the first registered adapter
       (``select_adapter``).
    5. Build and return a ``RoutingTrace``.

    Parameters
    ----------
    query:
        The user's natural-language query or instruction. Must contain at
        least one non-whitespace character.
    task_registry:
        Loaded ``TaskRegistry`` with pre-computed centroids.
    adapter_registry:
        Loaded ``AdapterRegistry``.
    embedder:
        ``Embedder`` instance (wraps the fastembed singleton).
    top_k:
        Number of top task candidates to consider when walking the fallback
        chain. Defaults to 3.

    Returns
    -------
    A ``RoutingTrace`` describing the decision. The trace is **not**
    persisted here — call ``feedback.loop.record_trace(trace)`` separately.

    Raises
    ------
    NoAdapterError
        If no registered adapter matches any of the top-K tasks.
    ValueError
        If ``query`` is empty or whitespace-only, or if task embeddings have
        not been computed (missing centroids).
    """
    # A blank query carries no signal: embedding it would still "route" to
    # some arbitrary task and produce a misleading trace.
    if not query or not query.strip():
        raise ValueError("Cannot route an empty query.")

    if not task_registry.embeddings_ready():
        raise ValueError(
            "Task embeddings are not initialised. Run `shiftgate init` first."
        )

    # Step 1: embed the query
    query_embedding = embedder.embed(query)

    # Step 2: rank tasks by similarity
    all_tasks = task_registry.get_all_tasks()
    ranked = top_k_tasks(query_embedding, all_tasks, k=top_k)

    # Step 3: pick the best adapter
    adapter, matched_task, score = select_adapter(ranked, adapter_registry)

    # Step 4: assemble trace
    trace = RoutingTrace(
        id=uuid.uuid4().hex,
        query=query,
        matched_task_id=matched_task.id,
        similarity_score=score,
        selected_adapter_id=adapter.id,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )

    # Only the first 60 characters of the query are logged.
    logger.info(
        "Routed '%s' → task='%s' (%.2f%%) → adapter='%s'",
        query[:60],
        matched_task.id,
        score * 100,
        adapter.id,
    )
    return trace
@@ -0,0 +1 @@
1
+ """Runtime sub-package: thin clients for Ollama and vLLM backends."""
@@ -0,0 +1,289 @@
1
+ """
2
+ Thin clients for local LLM inference backends.
3
+
4
+ Two backends are supported:
5
+
6
+ Ollama (http://localhost:11434)
7
+ Open-source inference server that supports LoRA adapters via custom
8
+ Modelfiles. To use a LoRA with Ollama, create a Modelfile like:
9
+
10
+ FROM llama3
11
+ ADAPTER /path/to/adapter.safetensors
12
+
13
+ and run ``ollama create my-model -f Modelfile``. shiftgate sets
14
+ ``model`` to the adapter's Ollama model name (derived from adapter.id).
15
+
16
+ vLLM (http://localhost:8000)
17
+ Provides an OpenAI-compatible ``/v1/chat/completions`` endpoint.
18
+ LoRA adapters are loaded at server start-up with ``--lora-modules``
19
+ and are addressed by name via the ``model`` field in the request.
20
+
21
+ Both backends are auto-detected by ``BackendRouter``, which pings each
22
+ health endpoint and delegates to whichever is available.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import logging
28
+ from abc import ABC, abstractmethod
29
+
30
+ import httpx
31
+
32
+ from shiftgate.registry.schemas import AdapterEntry
33
+
34
# Per-module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

# Default HTTP timeouts, in seconds: a short one used for quick reachability
# checks and a long one for generation requests.
_CONNECT_TIMEOUT = 3.0
_READ_TIMEOUT = 120.0
39
+
40
+
41
class BaseBackend(ABC):
    """Interface that every inference backend implements.

    Concrete backends must supply a reachability probe (``is_available``)
    and a text-generation call (``generate``).
    """

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether the backend can currently be reached."""
        ...

    @abstractmethod
    def generate(self, prompt: str, adapter: AdapterEntry) -> str:
        """Send *prompt* to the backend using *adapter* and return the generated text."""
        ...
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Ollama
55
+ # ---------------------------------------------------------------------------
56
+
57
class OllamaBackend(BaseBackend):
    """Thin httpx client for the Ollama inference server.

    Ollama API reference: https://github.com/ollama/ollama/blob/main/docs/api.md

    LoRA adapters in Ollama
    -----------------------
    Ollama does not have a first-class "load this adapter onto this base model"
    API at request time. Instead you pre-register a composite model via a
    Modelfile::

        FROM llama3
        ADAPTER /path/to/my-lora.safetensors

        ollama create my-lora-model -f Modelfile

    shiftgate uses ``adapter.id`` as the Ollama model name by convention.
    Ensure your Ollama model names match shiftgate adapter IDs.
    """

    def __init__(self, base_url: str = "http://localhost:11434") -> None:
        self.base_url = base_url.rstrip("/")

    def is_available(self) -> bool:
        """Return True if ``GET /api/tags`` answers with HTTP 200.

        Any exception (connection refused, timeout, malformed URL, ...) is
        treated as "not available": this is a deliberate best-effort probe.
        """
        try:
            r = httpx.get(f"{self.base_url}/api/tags", timeout=_CONNECT_TIMEOUT)
            return r.status_code == 200
        except Exception:
            return False

    def generate(
        self,
        prompt: str,
        adapter: AdapterEntry,
        *,
        model_name: str | None = None,
        stream: bool = False,
    ) -> str:
        """Generate text via Ollama's ``/api/generate`` endpoint.

        Parameters
        ----------
        prompt:
            The full prompt string.
        adapter:
            The selected ``AdapterEntry``; ``adapter.id`` is used as the
            Ollama model name unless overridden by ``model_name``.
        model_name:
            Override the Ollama model name (useful when the Ollama model name
            differs from the shiftgate adapter ID).
        stream:
            If True, Ollama streams response chunks as newline-delimited
            JSON. This client reads the whole stream and returns the
            concatenated text.

        Returns
        -------
        The generated text (``""`` if the response has no ``response`` field).

        Raises
        ------
        BackendError
            On transport errors, HTTP error statuses, unparseable JSON, or an
            ``error`` object in the response.
        """
        model = model_name or adapter.id
        payload = {"model": model, "prompt": prompt, "stream": stream}

        logger.debug("Ollama generate: model=%s", model)
        try:
            r = httpx.post(
                f"{self.base_url}/api/generate",
                json=payload,
                # Long read timeout for generation, but fail fast if the
                # server cannot even be connected to.
                timeout=httpx.Timeout(_READ_TIMEOUT, connect=_CONNECT_TIMEOUT),
            )
            r.raise_for_status()
        except httpx.HTTPError as exc:
            raise BackendError(f"Ollama request failed: {exc}") from exc

        if stream:
            # A streamed body is one JSON object per line, not a single JSON
            # document, so `r.json()` would fail on it.
            return self._join_stream(r.text)

        try:
            data = r.json()
        except ValueError as exc:
            raise BackendError(
                f"Ollama returned invalid JSON: {r.text[:200]!r}"
            ) from exc
        if not isinstance(data, dict):
            raise BackendError(f"Unexpected Ollama response format: {data!r}")
        if "error" in data:
            raise BackendError(f"Ollama error: {data['error']}")
        return data.get("response", "")

    @staticmethod
    def _join_stream(body: str) -> str:
        """Concatenate the ``response`` fields of an Ollama NDJSON stream body.

        Blank lines are skipped. Raises ``BackendError`` on a malformed line
        or on a chunk carrying an ``error`` field.
        """
        import json

        parts: list[str] = []
        for line in body.splitlines():
            if not line.strip():
                continue
            try:
                chunk = json.loads(line)
            except ValueError as exc:
                raise BackendError(
                    f"Malformed Ollama stream line: {line[:200]!r}"
                ) from exc
            if not isinstance(chunk, dict):
                raise BackendError(f"Unexpected Ollama stream chunk: {chunk!r}")
            if "error" in chunk:
                raise BackendError(f"Ollama stream error: {chunk['error']}")
            parts.append(chunk.get("response", ""))
        return "".join(parts)
128
+
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # vLLM
132
+ # ---------------------------------------------------------------------------
133
+
134
class VLLMBackend(BaseBackend):
    """Thin httpx client for the vLLM OpenAI-compatible inference server.

    vLLM API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

    LoRA adapters in vLLM
    ---------------------
    vLLM loads LoRA adapters at startup via the ``--lora-modules`` flag::

        python -m vllm.entrypoints.openai.api_server \\
            --model meta-llama/Meta-Llama-3-8B \\
            --lora-modules my-lora=/path/to/adapter \\
            --enable-lora

    After that, passing ``"model": "my-lora"`` in a chat completion request
    automatically activates the adapter.
    """

    def __init__(self, base_url: str = "http://localhost:8000") -> None:
        self.base_url = base_url.rstrip("/")

    def is_available(self) -> bool:
        """Return True if ``GET /health`` answers with HTTP 200.

        Any exception is treated as "not available": this is a deliberate
        best-effort probe.
        """
        try:
            r = httpx.get(f"{self.base_url}/health", timeout=_CONNECT_TIMEOUT)
            return r.status_code == 200
        except Exception:
            return False

    def generate(
        self,
        prompt: str,
        adapter: AdapterEntry,
        *,
        lora_name: str | None = None,
        system_prompt: str = "You are a helpful assistant.",
    ) -> str:
        """Generate text via vLLM's ``/v1/chat/completions`` endpoint.

        Parameters
        ----------
        prompt:
            The user message content.
        adapter:
            The selected ``AdapterEntry``; ``adapter.id`` is used as the
            ``model`` field (which vLLM maps to the LoRA name) unless
            overridden by ``lora_name``.
        lora_name:
            Override the vLLM model/lora name.
        system_prompt:
            System message prepended before the user message.

        Returns
        -------
        The first choice's message content (``""`` if it is null).

        Raises
        ------
        BackendError
            On transport errors, HTTP error statuses, unparseable JSON, or a
            response that lacks ``choices[0].message.content``.
        """
        model = lora_name or adapter.id
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
        }

        logger.debug("vLLM generate: model=%s", model)
        try:
            r = httpx.post(
                f"{self.base_url}/v1/chat/completions",
                json=payload,
                # Long read timeout for generation, but fail fast if the
                # server cannot even be connected to.
                timeout=httpx.Timeout(_READ_TIMEOUT, connect=_CONNECT_TIMEOUT),
            )
            r.raise_for_status()
        except httpx.HTTPError as exc:
            raise BackendError(f"vLLM request failed: {exc}") from exc

        try:
            data = r.json()
        except ValueError as exc:
            raise BackendError(
                f"vLLM returned invalid JSON: {r.text[:200]!r}"
            ) from exc

        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            # TypeError covers non-dict / non-list payloads (e.g. `null`).
            raise BackendError(f"Unexpected vLLM response format: {data}") from exc
        # The OpenAI chat schema allows `content` to be null; keep the
        # declared `str` return type.
        return content if content is not None else ""
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # BackendRouter — auto-detects which backend is live
215
+ # ---------------------------------------------------------------------------
216
+
217
class BackendRouter:
    """Choose whichever local inference backend is reachable and forward calls to it.

    Ollama is probed before vLLM. Once ``detect()`` finds a backend it is
    cached and reused by ``generate``. If neither is reachable, ``generate``
    raises ``NoBackendError`` with setup hints.
    """

    def __init__(
        self,
        ollama_url: str = "http://localhost:11434",
        vllm_url: str = "http://localhost:8000",
    ) -> None:
        self._ollama = OllamaBackend(ollama_url)
        self._vllm = VLLMBackend(vllm_url)
        self._active: BaseBackend | None = None

    def _named_backends(self):
        """Return ``(name, backend)`` pairs in probe-priority order."""
        return (("ollama", self._ollama), ("vllm", self._vllm))

    def detect(self) -> str | None:
        """Probe backends in priority order and remember the first reachable one.

        Returns ``"ollama"``, ``"vllm"``, or ``None`` when neither responds
        (in which case the cached backend is cleared).
        """
        for name, backend in self._named_backends():
            if backend.is_available():
                self._active = backend
                return name
        self._active = None
        return None

    def generate(self, prompt: str, adapter: AdapterEntry) -> str:
        """Forward the prompt to the cached backend, detecting one first if needed.

        Raises
        ------
        NoBackendError
            If neither Ollama nor vLLM is reachable.
        """
        if self._active is None:
            self.detect()
        backend = self._active
        if backend is None:
            raise NoBackendError(
                "No inference backend detected.\n"
                " • Start Ollama : ollama serve\n"
                " • Start vLLM : python -m vllm.entrypoints.openai.api_server "
                "--model <base_model> --enable-lora\n\n"
                "shiftgate can route queries without a backend. "
                "Use `shiftgate route` to see routing decisions without inference."
            )
        return backend.generate(prompt, adapter)

    @property
    def active_backend_name(self) -> str | None:
        """Name of the cached backend (``"ollama"`` / ``"vllm"``), or None."""
        current = self._active
        for name, backend in self._named_backends():
            if current is backend:
                return name
        return None
278
+
279
+
280
+ # ---------------------------------------------------------------------------
281
+ # Exceptions
282
+ # ---------------------------------------------------------------------------
283
+
284
class BackendError(RuntimeError):
    """Signal a failed backend call: a transport or HTTP error, or a response the client cannot interpret."""


class NoBackendError(RuntimeError):
    """Signal that none of the supported local inference servers could be reached."""
@@ -0,0 +1 @@
1
+ """Utils sub-package: Rich terminal UI helpers."""