shiftgate 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shiftgate/__init__.py +9 -0
- shiftgate/cli.py +513 -0
- shiftgate/feedback/__init__.py +1 -0
- shiftgate/feedback/loop.py +182 -0
- shiftgate/registry/__init__.py +1 -0
- shiftgate/registry/adapter_registry.py +162 -0
- shiftgate/registry/schemas.py +115 -0
- shiftgate/registry/task_registry.py +186 -0
- shiftgate/router/__init__.py +1 -0
- shiftgate/router/embedder.py +95 -0
- shiftgate/router/matcher.py +115 -0
- shiftgate/router/router.py +97 -0
- shiftgate/runtime/__init__.py +1 -0
- shiftgate/runtime/backend.py +289 -0
- shiftgate/utils/__init__.py +1 -0
- shiftgate/utils/display.py +297 -0
- shiftgate-0.1.0.dist-info/METADATA +273 -0
- shiftgate-0.1.0.dist-info/RECORD +20 -0
- shiftgate-0.1.0.dist-info/WHEEL +4 -0
- shiftgate-0.1.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Feedback loop: persist routing traces and compute adapter acceptance scores.
|
|
3
|
+
|
|
4
|
+
Traces are stored as newline-delimited JSON in ``~/.shiftgate/traces.jsonl``.
|
|
5
|
+
Each line is a serialised ``RoutingTrace``. This format is append-only and
|
|
6
|
+
easy to stream-process without loading the entire file into memory.
|
|
7
|
+
|
|
8
|
+
Workflow
|
|
9
|
+
--------
|
|
10
|
+
1. After every ``shiftgate route`` / ``shiftgate run``, call ``record_trace``.
|
|
11
|
+
2. User runs ``shiftgate feedback accept`` or ``shiftgate feedback reject``.
|
|
12
|
+
3. Call ``mark_accepted(trace_id, accepted)`` to annotate the trace.
|
|
13
|
+
4. ``compute_adapter_scores()`` aggregates acceptance rates per adapter.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
from shiftgate.registry.schemas import RoutingTrace
|
|
23
|
+
|
|
24
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Per-user state directory and the newline-delimited JSON trace log inside it.
# Both are resolved once, at import time, from the current user's home directory.
_SHIFTGATE_DIR = Path.home() / ".shiftgate"
_TRACES_PATH = _SHIFTGATE_DIR / "traces.jsonl"

# How many recent traces to scan when ``mark_accepted`` searches by trace ID.
_RECENT_SCAN_LIMIT = 200
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def record_trace(trace: RoutingTrace) -> None:
    """Serialise *trace* to JSON and append it as one line of the traces log.

    The log's parent directory is created when missing; opening the file in
    append mode creates the file itself on first use.
    """
    _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
    payload = f"{trace.model_dump_json()}\n"
    with _TRACES_PATH.open("a", encoding="utf-8") as log_file:
        log_file.write(payload)
    logger.debug("Trace %s recorded.", trace.id)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_last_trace() -> RoutingTrace | None:
    """Return the trace stored on the last non-blank line of the traces log.

    Returns ``None`` when the log file does not exist or holds no non-blank
    line. The file is streamed line by line, so only one line is kept in
    memory at a time.
    """
    if not _TRACES_PATH.exists():
        return None

    newest = None
    with _TRACES_PATH.open("r", encoding="utf-8") as log_file:
        for raw in log_file:
            stripped = raw.strip()
            # Blank lines never replace the last real record seen so far.
            newest = stripped or newest

    return None if newest is None else RoutingTrace.model_validate_json(newest)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def mark_accepted(trace_id: str, accepted: bool) -> bool:
    """Set the ``accepted`` field on a specific trace.

    The whole traces file is read into memory, the newest
    ``_RECENT_SCAN_LIMIT`` lines are searched (newest first) for a JSON object
    whose ``"id"`` equals *trace_id*, and, if one is found, the complete file
    is written back with that single line replaced. Blank lines, lines that
    are not valid JSON, and JSON values that are not objects are skipped and
    preserved as-is.

    Parameters
    ----------
    trace_id:
        The ``RoutingTrace.id`` hex string to update.
    accepted:
        True = good routing decision, False = bad routing decision.

    Returns
    -------
    True if the trace was found and updated; False if the file is missing or
    no matching trace lies within the scanned window.
    """
    if not _TRACES_PATH.exists():
        logger.warning("No traces file found at %s.", _TRACES_PATH)
        return False

    # Split on "\n" only. ``str.splitlines`` would also break on characters
    # such as U+2028 that may appear unescaped inside a JSON string, cutting
    # one record into two invalid lines and corrupting it on rewrite.
    lines = _TRACES_PATH.read_text(encoding="utf-8").split("\n")
    if lines and lines[-1] == "":
        lines.pop()  # artefact of the file's trailing newline

    oldest_scanned = max(0, len(lines) - _RECENT_SCAN_LIMIT)
    for index in range(len(lines) - 1, oldest_scanned - 1, -1):
        candidate = lines[index].strip()
        if not candidate:
            continue
        try:
            record = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # A line may be valid JSON without being an object (e.g. a list).
        if not isinstance(record, dict) or record.get("id") != trace_id:
            continue
        record["accepted"] = accepted
        lines[index] = json.dumps(record, ensure_ascii=False)
        break
    else:
        return False

    # Write to a sibling temp file and rename it over the log so that an
    # interrupted write cannot leave a truncated traces file behind.
    tmp_path = _TRACES_PATH.with_name(_TRACES_PATH.name + ".tmp")
    try:
        tmp_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
        tmp_path.replace(_TRACES_PATH)
    finally:
        # No-op after a successful rename; removes the leftover on failure.
        tmp_path.unlink(missing_ok=True)

    logger.debug("Trace %s marked accepted=%s.", trace_id, accepted)
    return True
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def mark_last_accepted(accepted: bool) -> RoutingTrace | None:
    """Apply accept/reject feedback to the most recently recorded trace.

    Looks up the last trace, asks ``mark_accepted`` to persist the verdict
    under that trace's ID, and mirrors the verdict on the returned object.

    Returns the trace with ``accepted`` set, or ``None`` when there is no
    trace to mark.
    """
    latest = get_last_trace()
    if latest is not None:
        mark_accepted(latest.id, accepted)
        latest.accepted = accepted
    return latest
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def load_all_traces() -> list[RoutingTrace]:
    """Materialise every trace yielded by ``iter_traces()`` into a list.

    The whole log ends up in memory; iterate ``iter_traces()`` directly when
    the file may be large.
    """
    traces: list[RoutingTrace] = []
    traces.extend(iter_traces())
    return traces
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def iter_traces():
    """Yield ``RoutingTrace`` objects one at a time from the traces file.

    Yields nothing when the file does not exist. Blank lines are ignored;
    a line that fails to parse or validate is logged at WARNING level and
    skipped, so one bad record does not abort the iteration.
    """
    if not _TRACES_PATH.exists():
        return
    with _TRACES_PATH.open("r", encoding="utf-8") as log_file:
        for raw in log_file:
            payload = raw.strip()
            if not payload:
                continue
            try:
                trace = RoutingTrace.model_validate_json(payload)
            except Exception as exc:  # deliberately broad: best-effort parsing
                logger.warning("Skipping malformed trace line: %s", exc)
                continue
            # Yield outside the ``try`` so that an exception thrown into the
            # suspended generator by its consumer is not swallowed and
            # misreported as a malformed line.
            yield trace
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def compute_adapter_scores() -> dict[str, float]:
    """Aggregate per-adapter acceptance rates over all rated traces.

    Traces whose ``accepted`` field is ``None`` (not yet rated) are ignored.

    Returns
    -------
    A dict mapping ``adapter_id`` to accepted / rated (0.0 – 1.0), in the
    order each adapter was first seen among rated traces. An adapter appears
    only if it has at least one rated trace; an adapter whose rated traces
    were all rejected is present with score 0.0.
    """
    # adapter id -> [accepted count, rated count]
    tallies: dict[str, list[int]] = {}

    for trace in iter_traces():
        verdict = trace.accepted
        if verdict is None:
            continue
        tally = tallies.setdefault(trace.selected_adapter_id, [0, 0])
        tally[1] += 1
        if verdict:
            tally[0] += 1

    scores: dict[str, float] = {}
    for adapter_id, (accepted_count, rated_count) in tallies.items():
        scores[adapter_id] = accepted_count / rated_count
    return scores
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def get_trace_stats() -> dict[str, int]:
    """Count the traces in the log, broken down by feedback state.

    Returns a dict with the keys ``total``, ``accepted``, ``rejected`` and
    ``unrated``; ``total`` is the sum of the other three.
    """
    accepted = rejected = unrated = 0
    for trace in iter_traces():
        verdict = trace.accepted
        if verdict is True:
            accepted += 1
        elif verdict is False:
            rejected += 1
        else:
            unrated += 1
    return {
        "total": accepted + rejected + unrated,
        "accepted": accepted,
        "rejected": rejected,
        "unrated": unrated,
    }
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Registry sub-package: adapter catalog, task clusters, and Pydantic schemas."""
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Adapter registry: load, persist, and manage AdapterEntry definitions.
|
|
3
|
+
|
|
4
|
+
The registry reads from (in priority order):
|
|
5
|
+
1. ``~/.shiftgate/adapters.json`` — user-edited / previously saved
|
|
6
|
+
2. ``<package>/../../data/default_adapters.json`` — bundled defaults (empty list)
|
|
7
|
+
|
|
8
|
+
Adapters can be added by passing a HuggingFace repo ID string or a full
|
|
9
|
+
``AdapterEntry`` object. When a bare HF repo ID is provided, metadata is
|
|
10
|
+
fetched from the Hub to fill in the entry automatically.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
from shiftgate.registry.schemas import AdapterEntry
|
|
20
|
+
|
|
21
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Per-user state directory and the user's adapter catalogue inside it; both
# are resolved once, at import time, from the current user's home directory.
_SHIFTGATE_DIR = Path.home() / ".shiftgate"
_USER_ADAPTERS_PATH = _SHIFTGATE_DIR / "adapters.json"
# Bundled defaults: ``data/default_adapters.json`` three directory levels
# above this file.
# NOTE(review): for an installed package that is the directory *containing*
# the top-level package, not a location inside it — verify that a ``data/``
# directory is actually shipped there.
_DEFAULT_ADAPTERS_PATH = Path(__file__).parent.parent.parent / "data" / "default_adapters.json"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AdapterRegistry:
    """Dictionary-backed catalogue of ``AdapterEntry`` objects keyed by ID.

    The catalogue is read from a JSON file by ``load()`` and written to the
    user's ``~/.shiftgate/adapters.json`` by ``save()``.

    Usage::

        registry = AdapterRegistry.load()
        registry.add_adapter(AdapterEntry(id="my-lora", ...))
        registry.save()
    """

    def __init__(self, adapters: list[AdapterEntry], source_path: Path) -> None:
        # Insertion order is kept; a later entry with a duplicate ID replaces
        # the earlier one.
        self._adapters: dict[str, AdapterEntry] = {}
        for entry in adapters:
            self._adapters[entry.id] = entry
        self._source_path = source_path

    # ------------------------------------------------------------------
    # Factory / persistence
    # ------------------------------------------------------------------

    @classmethod
    def load(cls) -> "AdapterRegistry":
        """Build a registry from the first adapter file found on disk.

        The user's ``~/.shiftgate/adapters.json`` is tried first, then the
        bundled ``data/default_adapters.json``. When neither exists a warning
        is logged and an empty registry is returned.
        """
        for candidate in (_USER_ADAPTERS_PATH, _DEFAULT_ADAPTERS_PATH):
            if candidate.exists():
                break
        else:
            logger.warning("No adapter registry found; starting empty.")
            return cls([], source_path=_USER_ADAPTERS_PATH)

        logger.debug("Loading adapter registry from %s", candidate)
        payload = json.loads(candidate.read_text(encoding="utf-8"))
        entries = [AdapterEntry.model_validate(item) for item in payload]
        return cls(entries, source_path=candidate)

    def save(self) -> None:
        """Write every adapter to ``~/.shiftgate/adapters.json`` as indented JSON.

        The target is always the user file, regardless of which file the
        registry was loaded from; its directory is created when missing.
        """
        _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
        serialised = json.dumps(
            [entry.model_dump() for entry in self._adapters.values()],
            indent=2,
            ensure_ascii=False,
        )
        _USER_ADAPTERS_PATH.write_text(serialised, encoding="utf-8")
        logger.debug("Adapter registry saved to %s", _USER_ADAPTERS_PATH)

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def get_adapter(self, adapter_id: str) -> AdapterEntry | None:
        """Look up an adapter by ID; ``None`` when the ID is unknown."""
        try:
            return self._adapters[adapter_id]
        except KeyError:
            return None

    def list_adapters(self) -> list[AdapterEntry]:
        """Return a new list of all adapters in registration order."""
        return [*self._adapters.values()]

    def add_adapter(self, adapter: AdapterEntry | str, **kwargs: object) -> AdapterEntry:
        """Register an adapter, replacing any existing entry with the same ID.

        Parameters
        ----------
        adapter:
            A ready ``AdapterEntry``, or a HuggingFace repo ID string
            (e.g. ``"username/my-lora-adapter"``). A string is handed to
            ``_adapter_from_hf_repo`` together with ``kwargs`` to build the
            entry.
        **kwargs:
            Only used for the string form (``tags``, ``base_model``,
            ``description``, ...); ignored when an ``AdapterEntry`` is given.

        Returns
        -------
        The entry that is now stored in the registry.
        """
        entry = _adapter_from_hf_repo(adapter, **kwargs) if isinstance(adapter, str) else adapter
        self._adapters[entry.id] = entry
        logger.debug("Adapter '%s' added to registry.", entry.id)
        return entry

    def remove_adapter(self, adapter_id: str) -> bool:
        """Delete the adapter with this ID; report whether it was present."""
        existed = adapter_id in self._adapters
        if existed:
            del self._adapters[adapter_id]
        return existed

    def __len__(self) -> int:
        """Number of registered adapters."""
        return len(self._adapters)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# ---------------------------------------------------------------------------
|
|
120
|
+
# Helpers
|
|
121
|
+
# ---------------------------------------------------------------------------
|
|
122
|
+
|
|
123
|
+
def _adapter_from_hf_repo(hf_repo: str, **kwargs: object) -> AdapterEntry:
    """Construct a minimal AdapterEntry from a HuggingFace repo ID.

    A stub entry is derived from the repo ID alone, then enriched with card
    metadata from the Hub on a best-effort basis: any failure (package not
    installed, offline, private repo, ...) is logged at DEBUG level and the
    stub is used as-is.

    Precedence, highest first:

    1. Values passed in ``kwargs``. ``tags`` is stored as ``task_tags``; any
       other key is copied onto the entry under its own name.
    2. Hub metadata (``base_model``, up to eight ``tags``, display ``name``).
    3. Values derived from the repo ID.

    Hub tags are used only when the entry has no tags yet.
    """
    # Remember which fields the caller pinned before the pops below consume
    # them, so Hub metadata cannot overwrite an explicit choice.
    caller_set = set(kwargs)

    # Derive a clean ID slug from the repo path (e.g. "org/my-lora" → "my-lora")
    slug = hf_repo.split("/")[-1].lower().replace("_", "-")

    entry_data: dict = {
        "id": slug,
        "name": slug.replace("-", " ").title(),
        "base_model": kwargs.pop("base_model", "unknown"),
        "task_tags": kwargs.pop("tags", []),
        "description": kwargs.pop("description", f"Imported from {hf_repo}"),
        "hf_repo": hf_repo,
    }
    entry_data.update(kwargs)

    # Attempt to enrich from HuggingFace Hub metadata.
    try:
        from huggingface_hub import model_info  # type: ignore

        info = model_info(hf_repo)
        card = info.card_data
        if card:
            base = getattr(card, "base_model", None)
            if base and "base_model" not in caller_set:
                # The card may list several base models; keep the first.
                entry_data["base_model"] = base[0] if isinstance(base, list) else base
            tags = getattr(card, "tags", None)
            if tags and not entry_data["task_tags"]:
                entry_data["task_tags"] = list(tags)[:8]
        if info.id and "name" not in caller_set:
            entry_data["name"] = info.id.split("/")[-1]
    except Exception as exc:  # deliberately broad: enrichment is optional
        logger.debug("Could not fetch HF metadata for '%s': %s", hf_repo, exc)

    return AdapterEntry.model_validate(entry_data)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pydantic v2 schemas for shiftgate's core data model.
|
|
3
|
+
|
|
4
|
+
Three top-level types:
|
|
5
|
+
- AdapterEntry : a LoRA adapter (or fine-tuned model) in the registry
|
|
6
|
+
- TaskCluster : a group of semantically related tasks with example queries
|
|
7
|
+
- RoutingTrace : one routing decision, optionally annotated with user feedback
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel, Field
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AdapterEntry(BaseModel):
    """A LoRA adapter registered with shiftgate.

    Adapters can live on HuggingFace (``hf_repo``) or locally (``local_path``).
    At least one of the two must be set for inference to work, though the
    registry itself does not enforce this so adapters can be catalogued before
    they are downloaded.
    """

    # --- Identity: these three fields declare no default. ---
    id: str = Field(
        description="Unique slug, e.g. 'python-lora-llama3'. Used as a stable reference key."
    )
    name: str = Field(description="Human-readable display name.")
    base_model: str = Field(
        description="The base model this adapter was trained on, e.g. 'meta-llama/Meta-Llama-3-8B'."
    )

    # --- Descriptive metadata; default to an empty list / empty string. ---
    task_tags: list[str] = Field(
        default_factory=list,
        description="Free-form tags describing the adapter's specialisation, e.g. ['code', 'python'].",
    )
    description: str = Field(default="", description="Short prose description of the adapter's purpose.")

    # --- Where the weights live; both default to None and no validator here
    # requires either one to be set. ---
    hf_repo: str | None = Field(
        default=None,
        description="HuggingFace Hub repository ID, e.g. 'username/my-lora-adapter'.",
    )
    local_path: str | None = Field(
        default=None,
        description="Absolute path to a local .safetensors file or adapter directory.",
    )

    # --- Optional quality / resource hints. ---
    # No range constraint is declared, so the 0–1 range in the description is
    # not checked by this model.
    benchmark_score: float | None = Field(
        default=None,
        description="Optional benchmark score (0–1) reported by the adapter author.",
    )
    context_length: int = Field(
        default=4096,
        description="Maximum context window in tokens.",
    )
    memory_mb: int | None = Field(
        default=None,
        description="Approximate VRAM/RAM usage in MB when the adapter is loaded.",
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TaskCluster(BaseModel):
    """A cluster of semantically related tasks used for routing.

    During ``shiftgate init``, the ``validation_examples`` are embedded and
    averaged to produce ``embedding_centroid``. At routing time the query
    embedding is compared against every cluster's centroid.
    """

    # --- Identity and description: no defaults declared. ---
    id: str = Field(
        description="Unique slug, e.g. 'code_python'. Used as a stable routing key."
    )
    name: str = Field(description="Human-readable cluster name, e.g. 'Python Code Generation'.")
    description: str = Field(description="Short description of what tasks belong here.")
    # No length constraint is declared, so the 3–10 guidance in the
    # description is not checked by this model.
    validation_examples: list[str] = Field(
        description="3–10 representative query strings used to compute the centroid embedding.",
    )

    # None until a centroid has been assigned to this cluster.
    embedding_centroid: list[float] | None = Field(
        default=None,
        description="Pre-computed mean embedding of the validation_examples. Populated by init.",
    )

    # --- Adapter selection lists; both default to empty. The IDs are plain
    # strings and are not checked against any adapter registry here. ---
    preferred_adapters: list[str] = Field(
        default_factory=list,
        description="Adapter IDs in priority order. The first available adapter is selected.",
    )
    fallback_adapters: list[str] = Field(
        default_factory=list,
        description="Adapter IDs to try when none of the preferred adapters are available.",
    )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class RoutingTrace(BaseModel):
    """A single routing decision recorded for observability and feedback.

    Traces are appended as JSON lines to ``~/.shiftgate/traces.jsonl``.
    The ``accepted`` field starts as ``None`` and is filled in via
    ``shiftgate feedback accept/reject``.
    """

    # --- The routing decision itself: none of these declare a default. ---
    # The ID is supplied by the caller; this model declares no default or
    # generator for it.
    id: str = Field(
        description="Unique trace ID (UUID4 hex string) for targeted feedback updates."
    )
    query: str = Field(description="The original user query that triggered this routing decision.")
    matched_task_id: str = Field(description="ID of the TaskCluster that won the similarity match.")
    # No range constraint is declared, so the 0–1 range in the description is
    # not checked by this model.
    similarity_score: float = Field(
        description="Cosine similarity between the query embedding and the winning centroid (0–1)."
    )
    selected_adapter_id: str = Field(description="ID of the adapter that was selected for inference.")

    # --- Filled in after the fact; both default to None. ---
    accepted: bool | None = Field(
        default=None,
        description="User feedback: True = good routing, False = bad routing, None = not yet rated.",
    )
    latency_ms: float | None = Field(
        default=None,
        description="End-to-end inference latency in milliseconds (None if only routing, no run).",
    )

    # Stored as a plain string with no default; the ISO-8601 format named in
    # the description is not validated by this model.
    timestamp: str = Field(
        description="ISO-8601 UTC timestamp of when this trace was created."
    )
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task registry: load, persist, and manage TaskCluster definitions.
|
|
3
|
+
|
|
4
|
+
The registry reads from (in priority order):
|
|
5
|
+
1. ``~/.shiftgate/tasks.json`` — user-edited / previously saved
|
|
6
|
+
2. ``<package>/../../data/default_tasks.json`` — bundled defaults
|
|
7
|
+
|
|
8
|
+
On first run (``shiftgate init``) the ``compute_embeddings`` method is called
|
|
9
|
+
to populate ``embedding_centroid`` for every cluster and cache them to
|
|
10
|
+
``~/.shiftgate/embeddings_cache.npy`` so subsequent startups are instant.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
from shiftgate.registry.schemas import TaskCluster
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from shiftgate.router.embedder import Embedder
|
|
26
|
+
|
|
27
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Canonical locations, resolved once at import time from the current user's
# home directory: the state directory, the user's task file, and the numpy
# file holding cached centroids.
_SHIFTGATE_DIR = Path.home() / ".shiftgate"
_USER_TASKS_PATH = _SHIFTGATE_DIR / "tasks.json"
_CACHE_PATH = _SHIFTGATE_DIR / "embeddings_cache.npy"

# Path to the bundled default tasks, resolved relative to this file's location
# (``data/default_tasks.json`` three directory levels above this file).
# NOTE(review): for an installed package that is the directory *containing*
# the top-level package, not a location inside it — verify that a ``data/``
# directory is actually shipped there.
_DEFAULT_TASKS_PATH = Path(__file__).parent.parent.parent / "data" / "default_tasks.json"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TaskRegistry:
    """In-memory store for TaskCluster objects, backed by a JSON file.

    Centroid embeddings are additionally cached in a numpy file whose rows
    correspond positionally to the registry's tasks.

    Usage::

        registry = TaskRegistry.load()
        registry.compute_embeddings(embedder)
        registry.save()
    """

    def __init__(self, tasks: list[TaskCluster], source_path: Path) -> None:
        # Keyed by task ID in insertion order; a later task with a duplicate
        # ID replaces the earlier one.
        self._tasks: dict[str, TaskCluster] = {t.id: t for t in tasks}
        self._source_path = source_path

    # ------------------------------------------------------------------
    # Factory / persistence
    # ------------------------------------------------------------------

    @classmethod
    def load(cls) -> "TaskRegistry":
        """Load the task registry from disk.

        Prefers the user's ``~/.shiftgate/tasks.json`` and falls back to the
        bundled ``data/default_tasks.json`` if the user file does not exist.
        After loading, cached centroids are restored when a cache file whose
        row count matches the loaded tasks is present.

        Raises
        ------
        FileNotFoundError
            If neither file exists.
        """
        if _USER_TASKS_PATH.exists():
            source = _USER_TASKS_PATH
        elif _DEFAULT_TASKS_PATH.exists():
            source = _DEFAULT_TASKS_PATH
        else:
            raise FileNotFoundError(
                f"No task registry found. Expected one of:\n"
                f"  {_USER_TASKS_PATH}\n"
                f"  {_DEFAULT_TASKS_PATH}\n"
                "Run `shiftgate init` to set up the default registry."
            )

        logger.debug("Loading task registry from %s", source)
        raw = json.loads(source.read_text(encoding="utf-8"))
        tasks = [TaskCluster.model_validate(t) for t in raw]
        instance = cls(tasks, source_path=source)

        # Eagerly restore cached centroids so ``compute_embeddings`` can be
        # skipped on normal runs (not first init).
        instance._restore_cache()
        return instance

    def save(self) -> None:
        """Persist the current registry to ``~/.shiftgate/tasks.json``.

        The user file is always the target, regardless of which file the
        registry was loaded from; its directory is created when missing.
        """
        _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
        data = [t.model_dump() for t in self._tasks.values()]
        _USER_TASKS_PATH.write_text(
            json.dumps(data, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        logger.debug("Task registry saved to %s", _USER_TASKS_PATH)

    # ------------------------------------------------------------------
    # Embedding management
    # ------------------------------------------------------------------

    def compute_embeddings(self, embedder: "Embedder") -> None:
        """Compute and store the centroid embedding for every task cluster.

        For each cluster the validation examples are passed to
        ``embedder.embed_batch`` and the result is averaged over axis 0, then
        L2-normalised, to form a single centroid vector. The centroid is
        written into ``TaskCluster.embedding_centroid`` **and** all centroids
        are saved to ``~/.shiftgate/embeddings_cache.npy`` as a float32 array
        with one row per task.
        """
        task_list = list(self._tasks.values())
        logger.info("Computing embeddings for %d task clusters…", len(task_list))

        for task in task_list:
            # Expected shape (n, dim) per the original author's note —
            # TODO confirm against ``Embedder.embed_batch``.
            embeddings = embedder.embed_batch(task.validation_examples)
            centroid = embeddings.mean(axis=0)
            # L2-normalise so cosine similarity reduces to dot product later.
            norm = np.linalg.norm(centroid)
            if norm > 0:
                centroid = centroid / norm
            task.embedding_centroid = centroid.tolist()

        # Persist centroids as a numpy array indexed by task order.
        self._save_cache(task_list)
        logger.info("Embeddings computed and cached.")

    def _save_cache(self, task_list: list[TaskCluster]) -> None:
        """Write the centroids of *task_list* to the numpy cache file.

        Nothing is written when no task has a centroid. Nothing is written
        either when only some tasks have one: the cache is positional, so
        dropping rows would make later rows line up with the wrong tasks.
        """
        _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
        centroids = [t.embedding_centroid for t in task_list if t.embedding_centroid]
        if not centroids:
            return
        if len(centroids) != len(task_list):
            logger.warning(
                "Only %d of %d tasks have a centroid; not writing the positional cache.",
                len(centroids),
                len(task_list),
            )
            return
        arr = np.array(centroids, dtype=np.float32)
        np.save(_CACHE_PATH, arr)
        logger.debug("Centroid cache written to %s (%s)", _CACHE_PATH, arr.shape)

    def _restore_cache(self) -> None:
        """Re-populate ``embedding_centroid`` from the numpy cache if available.

        This avoids a full re-embedding on every startup. The cache is keyed
        positionally, so it is applied only when it is a 2-D array with
        exactly one row per registered task. A cache with any other shape is
        stale (tasks were added or removed since it was written) and is
        ignored with a warning rather than assigned to the wrong tasks.
        Reordering tasks without changing their count cannot be detected.
        """
        if not _CACHE_PATH.exists():
            return
        try:
            arr = np.load(_CACHE_PATH)
            task_list = list(self._tasks.values())
            if arr.ndim != 2 or arr.shape[0] != len(task_list):
                logger.warning(
                    "Embedding cache shape %s does not match %d tasks; ignoring it. "
                    "Re-run `shiftgate init`.",
                    arr.shape,
                    len(task_list),
                )
                return
            for task, row in zip(task_list, arr):
                task.embedding_centroid = row.tolist()
            logger.debug("Restored centroids from cache (%d tasks)", len(task_list))
        except Exception as exc:  # deliberately broad: the cache is optional
            logger.warning("Could not restore embedding cache (%s). Re-run `shiftgate init`.", exc)

    def embeddings_ready(self) -> bool:
        """Return True if all task clusters have a computed centroid.

        An empty registry counts as ready.
        """
        return all(t.embedding_centroid is not None for t in self._tasks.values())

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def get_all_tasks(self) -> list[TaskCluster]:
        """Return all registered task clusters as a new list."""
        return list(self._tasks.values())

    def get_task(self, task_id: str) -> TaskCluster | None:
        """Return a single task cluster by ID, or None if not found."""
        return self._tasks.get(task_id)

    def add_task(self, task: TaskCluster) -> None:
        """Add or replace a task cluster in the registry.

        If a task with the same ID already exists it is silently overwritten.
        Call ``save()`` afterwards to persist the change.
        """
        self._tasks[task.id] = task
        logger.debug("Task '%s' added to registry.", task.id)

    def remove_task(self, task_id: str) -> bool:
        """Remove a task by ID. Returns True if it existed."""
        if task_id in self._tasks:
            del self._tasks[task_id]
            return True
        return False

    def __len__(self) -> int:
        """Number of registered task clusters."""
        return len(self._tasks)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Router sub-package: embedding, cosine matching, and routing logic."""
|