shiftgate 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,182 @@
1
+ """
2
+ Feedback loop: persist routing traces and compute adapter acceptance scores.
3
+
4
+ Traces are stored as newline-delimited JSON in ``~/.shiftgate/traces.jsonl``.
5
+ Each line is a serialised ``RoutingTrace``. This format is append-only and
6
+ easy to stream-process without loading the entire file into memory.
7
+
8
+ Workflow
9
+ --------
10
+ 1. After every ``shiftgate route`` / ``shiftgate run``, call ``record_trace``.
11
+ 2. User runs ``shiftgate feedback accept`` or ``shiftgate feedback reject``.
12
+ 3. Call ``mark_accepted(trace_id, accepted)`` to annotate the trace.
13
+ 4. ``compute_adapter_scores()`` aggregates acceptance rates per adapter.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import logging
20
+ from pathlib import Path
21
+
22
+ from shiftgate.registry.schemas import RoutingTrace
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ _SHIFTGATE_DIR = Path.home() / ".shiftgate"
27
+ _TRACES_PATH = _SHIFTGATE_DIR / "traces.jsonl"
28
+
29
+ # How many recent traces to scan when ``mark_accepted`` searches by trace ID.
30
+ _RECENT_SCAN_LIMIT = 200
31
+
32
+
33
+ def record_trace(trace: RoutingTrace) -> None:
34
+ """Append a ``RoutingTrace`` as a JSON line to the traces log.
35
+
36
+ The file is created (along with its parent directory) on first write.
37
+ """
38
+ _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
39
+ line = trace.model_dump_json()
40
+ with _TRACES_PATH.open("a", encoding="utf-8") as fh:
41
+ fh.write(line + "\n")
42
+ logger.debug("Trace %s recorded.", trace.id)
43
+
44
+
45
+ def get_last_trace() -> RoutingTrace | None:
46
+ """Return the most recently recorded trace, or None if no traces exist."""
47
+ if not _TRACES_PATH.exists():
48
+ return None
49
+ last_line: str | None = None
50
+ with _TRACES_PATH.open("r", encoding="utf-8") as fh:
51
+ for line in fh:
52
+ line = line.strip()
53
+ if line:
54
+ last_line = line
55
+ if last_line is None:
56
+ return None
57
+ return RoutingTrace.model_validate_json(last_line)
58
+
59
+
60
+ def mark_accepted(trace_id: str, accepted: bool) -> bool:
61
+ """Set the ``accepted`` field on a specific trace.
62
+
63
+ Rewrites the last ``_RECENT_SCAN_LIMIT`` lines of the traces file in-place
64
+ (only those lines, prepending unchanged older lines). Trades slight memory
65
+ use for simplicity.
66
+
67
+ Parameters
68
+ ----------
69
+ trace_id:
70
+ The ``RoutingTrace.id`` hex string to update.
71
+ accepted:
72
+ True = good routing decision, False = bad routing decision.
73
+
74
+ Returns
75
+ -------
76
+ True if the trace was found and updated; False if not found.
77
+ """
78
+ if not _TRACES_PATH.exists():
79
+ logger.warning("No traces file found at %s.", _TRACES_PATH)
80
+ return False
81
+
82
+ lines = _TRACES_PATH.read_text(encoding="utf-8").splitlines()
83
+ updated = False
84
+
85
+ for i in range(len(lines) - 1, max(-1, len(lines) - _RECENT_SCAN_LIMIT - 1), -1):
86
+ line = lines[i].strip()
87
+ if not line:
88
+ continue
89
+ try:
90
+ data = json.loads(line)
91
+ except json.JSONDecodeError:
92
+ continue
93
+ if data.get("id") == trace_id:
94
+ data["accepted"] = accepted
95
+ lines[i] = json.dumps(data, ensure_ascii=False)
96
+ updated = True
97
+ break
98
+
99
+ if updated:
100
+ _TRACES_PATH.write_text("\n".join(lines) + "\n", encoding="utf-8")
101
+ logger.debug("Trace %s marked accepted=%s.", trace_id, accepted)
102
+
103
+ return updated
104
+
105
+
106
+ def mark_last_accepted(accepted: bool) -> RoutingTrace | None:
107
+ """Convenience: mark the most recent trace as accepted/rejected.
108
+
109
+ Returns the updated trace, or None if no traces exist.
110
+ """
111
+ trace = get_last_trace()
112
+ if trace is None:
113
+ return None
114
+ mark_accepted(trace.id, accepted)
115
+ trace.accepted = accepted
116
+ return trace
117
+
118
+
119
+ def load_all_traces() -> list[RoutingTrace]:
120
+ """Load all traces from disk into memory.
121
+
122
+ For large files prefer streaming with ``iter_traces()`` instead.
123
+ """
124
+ return list(iter_traces())
125
+
126
+
127
+ def iter_traces():
128
+ """Yield ``RoutingTrace`` objects one at a time from the traces file."""
129
+ if not _TRACES_PATH.exists():
130
+ return
131
+ with _TRACES_PATH.open("r", encoding="utf-8") as fh:
132
+ for line in fh:
133
+ line = line.strip()
134
+ if not line:
135
+ continue
136
+ try:
137
+ yield RoutingTrace.model_validate_json(line)
138
+ except Exception as exc:
139
+ logger.warning("Skipping malformed trace line: %s", exc)
140
+
141
+
142
+ def compute_adapter_scores() -> dict[str, float]:
143
+ """Compute the acceptance rate for each adapter across all rated traces.
144
+
145
+ Returns
146
+ -------
147
+ A dict mapping ``adapter_id`` → acceptance rate (0.0 – 1.0).
148
+ Only adapters with at least one rated trace are included.
149
+ Adapters with a 0 % acceptance rate are included with score 0.0.
150
+ """
151
+ totals: dict[str, int] = {}
152
+ accepted_counts: dict[str, int] = {}
153
+
154
+ for trace in iter_traces():
155
+ if trace.accepted is None:
156
+ continue
157
+ aid = trace.selected_adapter_id
158
+ totals[aid] = totals.get(aid, 0) + 1
159
+ if trace.accepted:
160
+ accepted_counts[aid] = accepted_counts.get(aid, 0) + 1
161
+
162
+ return {
163
+ aid: accepted_counts.get(aid, 0) / total
164
+ for aid, total in totals.items()
165
+ }
166
+
167
+
168
+ def get_trace_stats() -> dict[str, int]:
169
+ """Return summary statistics about the traces file.
170
+
171
+ Keys: ``total``, ``accepted``, ``rejected``, ``unrated``.
172
+ """
173
+ stats = {"total": 0, "accepted": 0, "rejected": 0, "unrated": 0}
174
+ for trace in iter_traces():
175
+ stats["total"] += 1
176
+ if trace.accepted is True:
177
+ stats["accepted"] += 1
178
+ elif trace.accepted is False:
179
+ stats["rejected"] += 1
180
+ else:
181
+ stats["unrated"] += 1
182
+ return stats
@@ -0,0 +1 @@
1
+ """Registry sub-package: adapter catalog, task clusters, and Pydantic schemas."""
@@ -0,0 +1,162 @@
1
+ """
2
+ Adapter registry: load, persist, and manage AdapterEntry definitions.
3
+
4
+ The registry reads from (in priority order):
5
+ 1. ``~/.shiftgate/adapters.json`` — user-edited / previously saved
6
+ 2. ``<package>/../../data/default_adapters.json`` — bundled defaults (empty list)
7
+
8
+ Adapters can be added by passing a HuggingFace repo ID string or a full
9
+ ``AdapterEntry`` object. When a bare HF repo ID is provided, metadata is
10
+ fetched from the Hub to fill in the entry automatically.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import logging
17
+ from pathlib import Path
18
+
19
+ from shiftgate.registry.schemas import AdapterEntry
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ _SHIFTGATE_DIR = Path.home() / ".shiftgate"
24
+ _USER_ADAPTERS_PATH = _SHIFTGATE_DIR / "adapters.json"
25
+ _DEFAULT_ADAPTERS_PATH = Path(__file__).parent.parent.parent / "data" / "default_adapters.json"
26
+
27
+
28
+ class AdapterRegistry:
29
+ """In-memory store for AdapterEntry objects, backed by a JSON file.
30
+
31
+ Usage::
32
+
33
+ registry = AdapterRegistry.load()
34
+ registry.add_adapter(AdapterEntry(id="my-lora", ...))
35
+ registry.save()
36
+ """
37
+
38
+ def __init__(self, adapters: list[AdapterEntry], source_path: Path) -> None:
39
+ self._adapters: dict[str, AdapterEntry] = {a.id: a for a in adapters}
40
+ self._source_path = source_path
41
+
42
+ # ------------------------------------------------------------------
43
+ # Factory / persistence
44
+ # ------------------------------------------------------------------
45
+
46
+ @classmethod
47
+ def load(cls) -> "AdapterRegistry":
48
+ """Load the adapter registry from disk.
49
+
50
+ Prefers ``~/.shiftgate/adapters.json``. Falls back to the bundled
51
+ ``data/default_adapters.json`` (which ships as an empty list).
52
+ """
53
+ if _USER_ADAPTERS_PATH.exists():
54
+ source = _USER_ADAPTERS_PATH
55
+ elif _DEFAULT_ADAPTERS_PATH.exists():
56
+ source = _DEFAULT_ADAPTERS_PATH
57
+ else:
58
+ logger.warning("No adapter registry found; starting empty.")
59
+ return cls([], source_path=_USER_ADAPTERS_PATH)
60
+
61
+ logger.debug("Loading adapter registry from %s", source)
62
+ raw = json.loads(source.read_text(encoding="utf-8"))
63
+ adapters = [AdapterEntry.model_validate(a) for a in raw]
64
+ return cls(adapters, source_path=source)
65
+
66
+ def save(self) -> None:
67
+ """Persist the current registry to ``~/.shiftgate/adapters.json``."""
68
+ _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
69
+ data = [a.model_dump() for a in self._adapters.values()]
70
+ _USER_ADAPTERS_PATH.write_text(
71
+ json.dumps(data, indent=2, ensure_ascii=False),
72
+ encoding="utf-8",
73
+ )
74
+ logger.debug("Adapter registry saved to %s", _USER_ADAPTERS_PATH)
75
+
76
+ # ------------------------------------------------------------------
77
+ # CRUD
78
+ # ------------------------------------------------------------------
79
+
80
+ def get_adapter(self, adapter_id: str) -> AdapterEntry | None:
81
+ """Return an adapter by ID, or None if not found."""
82
+ return self._adapters.get(adapter_id)
83
+
84
+ def list_adapters(self) -> list[AdapterEntry]:
85
+ """Return all registered adapters."""
86
+ return list(self._adapters.values())
87
+
88
+ def add_adapter(self, adapter: AdapterEntry | str, **kwargs: object) -> AdapterEntry:
89
+ """Add or replace an adapter in the registry.
90
+
91
+ Parameters
92
+ ----------
93
+ adapter:
94
+ Either a fully-constructed ``AdapterEntry`` or a HuggingFace
95
+ repo ID string (e.g. ``"username/my-lora-adapter"``). When a
96
+ string is provided the repo ID is used as ``hf_repo`` and a
97
+ best-effort ID slug is derived from it. Extra keyword arguments
98
+ (``tags``, ``base_model``, ``description``) override auto-derived
99
+ values.
100
+ """
101
+ if isinstance(adapter, str):
102
+ adapter = _adapter_from_hf_repo(adapter, **kwargs)
103
+
104
+ self._adapters[adapter.id] = adapter
105
+ logger.debug("Adapter '%s' added to registry.", adapter.id)
106
+ return adapter
107
+
108
+ def remove_adapter(self, adapter_id: str) -> bool:
109
+ """Remove an adapter by ID. Returns True if it existed."""
110
+ if adapter_id in self._adapters:
111
+ del self._adapters[adapter_id]
112
+ return True
113
+ return False
114
+
115
+ def __len__(self) -> int:
116
+ return len(self._adapters)
117
+
118
+
119
+ # ---------------------------------------------------------------------------
120
+ # Helpers
121
+ # ---------------------------------------------------------------------------
122
+
123
+ def _adapter_from_hf_repo(hf_repo: str, **kwargs: object) -> AdapterEntry:
124
+ """Construct a minimal AdapterEntry from a HuggingFace repo ID.
125
+
126
+ Tries to pull card metadata from the Hub. If that fails (offline, private
127
+ repo, etc.) it builds a stub entry from the repo ID alone.
128
+
129
+ Extra ``kwargs`` are merged after auto-detection and override any
130
+ auto-derived fields (``tags``, ``base_model``, ``description``).
131
+ """
132
+ # Derive a clean ID slug from the repo path (e.g. "org/my-lora" → "my-lora")
133
+ slug = hf_repo.split("/")[-1].lower().replace("_", "-")
134
+
135
+ entry_data: dict = {
136
+ "id": slug,
137
+ "name": slug.replace("-", " ").title(),
138
+ "base_model": kwargs.pop("base_model", "unknown"),
139
+ "task_tags": kwargs.pop("tags", []),
140
+ "description": kwargs.pop("description", f"Imported from {hf_repo}"),
141
+ "hf_repo": hf_repo,
142
+ }
143
+ entry_data.update(kwargs)
144
+
145
+ # Attempt to enrich from HuggingFace Hub metadata.
146
+ try:
147
+ from huggingface_hub import hf_hub_download, model_info # type: ignore
148
+
149
+ info = model_info(hf_repo)
150
+ if info.card_data:
151
+ card = info.card_data
152
+ if hasattr(card, "base_model") and card.base_model:
153
+ base = card.base_model
154
+ entry_data["base_model"] = base[0] if isinstance(base, list) else base
155
+ if hasattr(card, "tags") and card.tags and not entry_data["task_tags"]:
156
+ entry_data["task_tags"] = list(card.tags)[:8]
157
+ if info.id:
158
+ entry_data["name"] = info.id.split("/")[-1]
159
+ except Exception as exc:
160
+ logger.debug("Could not fetch HF metadata for '%s': %s", hf_repo, exc)
161
+
162
+ return AdapterEntry.model_validate(entry_data)
@@ -0,0 +1,115 @@
1
+ """
2
+ Pydantic v2 schemas for shiftgate's core data model.
3
+
4
+ Three top-level types:
5
+ - AdapterEntry : a LoRA adapter (or fine-tuned model) in the registry
6
+ - TaskCluster : a group of semantically related tasks with example queries
7
+ - RoutingTrace : one routing decision, optionally annotated with user feedback
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from pydantic import BaseModel, Field
13
+
14
+
15
+ class AdapterEntry(BaseModel):
16
+ """A LoRA adapter registered with shiftgate.
17
+
18
+ Adapters can live on HuggingFace (``hf_repo``) or locally (``local_path``).
19
+ At least one of the two must be set for inference to work, though the
20
+ registry itself does not enforce this so adapters can be catalogued before
21
+ they are downloaded.
22
+ """
23
+
24
+ id: str = Field(
25
+ description="Unique slug, e.g. 'python-lora-llama3'. Used as a stable reference key."
26
+ )
27
+ name: str = Field(description="Human-readable display name.")
28
+ base_model: str = Field(
29
+ description="The base model this adapter was trained on, e.g. 'meta-llama/Meta-Llama-3-8B'."
30
+ )
31
+ task_tags: list[str] = Field(
32
+ default_factory=list,
33
+ description="Free-form tags describing the adapter's specialisation, e.g. ['code', 'python'].",
34
+ )
35
+ description: str = Field(default="", description="Short prose description of the adapter's purpose.")
36
+ hf_repo: str | None = Field(
37
+ default=None,
38
+ description="HuggingFace Hub repository ID, e.g. 'username/my-lora-adapter'.",
39
+ )
40
+ local_path: str | None = Field(
41
+ default=None,
42
+ description="Absolute path to a local .safetensors file or adapter directory.",
43
+ )
44
+ benchmark_score: float | None = Field(
45
+ default=None,
46
+ description="Optional benchmark score (0–1) reported by the adapter author.",
47
+ )
48
+ context_length: int = Field(
49
+ default=4096,
50
+ description="Maximum context window in tokens.",
51
+ )
52
+ memory_mb: int | None = Field(
53
+ default=None,
54
+ description="Approximate VRAM/RAM usage in MB when the adapter is loaded.",
55
+ )
56
+
57
+
58
+ class TaskCluster(BaseModel):
59
+ """A cluster of semantically related tasks used for routing.
60
+
61
+ During ``shiftgate init``, the ``validation_examples`` are embedded and
62
+ averaged to produce ``embedding_centroid``. At routing time the query
63
+ embedding is compared against every cluster's centroid.
64
+ """
65
+
66
+ id: str = Field(
67
+ description="Unique slug, e.g. 'code_python'. Used as a stable routing key."
68
+ )
69
+ name: str = Field(description="Human-readable cluster name, e.g. 'Python Code Generation'.")
70
+ description: str = Field(description="Short description of what tasks belong here.")
71
+ validation_examples: list[str] = Field(
72
+ description="3–10 representative query strings used to compute the centroid embedding.",
73
+ )
74
+ embedding_centroid: list[float] | None = Field(
75
+ default=None,
76
+ description="Pre-computed mean embedding of the validation_examples. Populated by init.",
77
+ )
78
+ preferred_adapters: list[str] = Field(
79
+ default_factory=list,
80
+ description="Adapter IDs in priority order. The first available adapter is selected.",
81
+ )
82
+ fallback_adapters: list[str] = Field(
83
+ default_factory=list,
84
+ description="Adapter IDs to try when none of the preferred adapters are available.",
85
+ )
86
+
87
+
88
+ class RoutingTrace(BaseModel):
89
+ """A single routing decision recorded for observability and feedback.
90
+
91
+ Traces are appended as JSON lines to ``~/.shiftgate/traces.jsonl``.
92
+ The ``accepted`` field starts as ``None`` and is filled in via
93
+ ``shiftgate feedback accept/reject``.
94
+ """
95
+
96
+ id: str = Field(
97
+ description="Unique trace ID (UUID4 hex string) for targeted feedback updates."
98
+ )
99
+ query: str = Field(description="The original user query that triggered this routing decision.")
100
+ matched_task_id: str = Field(description="ID of the TaskCluster that won the similarity match.")
101
+ similarity_score: float = Field(
102
+ description="Cosine similarity between the query embedding and the winning centroid (0–1)."
103
+ )
104
+ selected_adapter_id: str = Field(description="ID of the adapter that was selected for inference.")
105
+ accepted: bool | None = Field(
106
+ default=None,
107
+ description="User feedback: True = good routing, False = bad routing, None = not yet rated.",
108
+ )
109
+ latency_ms: float | None = Field(
110
+ default=None,
111
+ description="End-to-end inference latency in milliseconds (None if only routing, no run).",
112
+ )
113
+ timestamp: str = Field(
114
+ description="ISO-8601 UTC timestamp of when this trace was created."
115
+ )
@@ -0,0 +1,186 @@
1
+ """
2
+ Task registry: load, persist, and manage TaskCluster definitions.
3
+
4
+ The registry reads from (in priority order):
5
+ 1. ``~/.shiftgate/tasks.json`` — user-edited / previously saved
6
+ 2. ``<package>/../../data/default_tasks.json`` — bundled defaults
7
+
8
+ On first run (``shiftgate init``) the ``compute_embeddings`` method is called
9
+ to populate ``embedding_centroid`` for every cluster and cache them to
10
+ ``~/.shiftgate/embeddings_cache.npy`` so subsequent startups are instant.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import logging
17
+ from pathlib import Path
18
+ from typing import TYPE_CHECKING
19
+
20
+ import numpy as np
21
+
22
+ from shiftgate.registry.schemas import TaskCluster
23
+
24
+ if TYPE_CHECKING:
25
+ from shiftgate.router.embedder import Embedder
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Canonical locations
30
+ _SHIFTGATE_DIR = Path.home() / ".shiftgate"
31
+ _USER_TASKS_PATH = _SHIFTGATE_DIR / "tasks.json"
32
+ _CACHE_PATH = _SHIFTGATE_DIR / "embeddings_cache.npy"
33
+
34
+ # Path to the bundled default tasks, resolved relative to this file's location.
35
+ _DEFAULT_TASKS_PATH = Path(__file__).parent.parent.parent / "data" / "default_tasks.json"
36
+
37
+
38
+ class TaskRegistry:
39
+ """In-memory store for TaskCluster objects, backed by a JSON file.
40
+
41
+ Usage::
42
+
43
+ registry = TaskRegistry.load()
44
+ registry.compute_embeddings(embedder)
45
+ registry.save()
46
+ """
47
+
48
+ def __init__(self, tasks: list[TaskCluster], source_path: Path) -> None:
49
+ self._tasks: dict[str, TaskCluster] = {t.id: t for t in tasks}
50
+ self._source_path = source_path
51
+
52
+ # ------------------------------------------------------------------
53
+ # Factory / persistence
54
+ # ------------------------------------------------------------------
55
+
56
+ @classmethod
57
+ def load(cls) -> "TaskRegistry":
58
+ """Load the task registry from disk.
59
+
60
+ Prefers the user's ``~/.shiftgate/tasks.json`` and falls back to the
61
+ bundled ``data/default_tasks.json`` if the user file does not exist.
62
+ """
63
+ if _USER_TASKS_PATH.exists():
64
+ source = _USER_TASKS_PATH
65
+ elif _DEFAULT_TASKS_PATH.exists():
66
+ source = _DEFAULT_TASKS_PATH
67
+ else:
68
+ raise FileNotFoundError(
69
+ f"No task registry found. Expected one of:\n"
70
+ f" {_USER_TASKS_PATH}\n"
71
+ f" {_DEFAULT_TASKS_PATH}\n"
72
+ "Run `shiftgate init` to set up the default registry."
73
+ )
74
+
75
+ logger.debug("Loading task registry from %s", source)
76
+ raw = json.loads(source.read_text(encoding="utf-8"))
77
+ tasks = [TaskCluster.model_validate(t) for t in raw]
78
+ instance = cls(tasks, source_path=source)
79
+
80
+ # Eagerly restore cached centroids so ``compute_embeddings`` can be
81
+ # skipped on normal runs (not first init).
82
+ instance._restore_cache()
83
+ return instance
84
+
85
+ def save(self) -> None:
86
+ """Persist the current registry to ``~/.shiftgate/tasks.json``."""
87
+ _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
88
+ data = [t.model_dump() for t in self._tasks.values()]
89
+ _USER_TASKS_PATH.write_text(
90
+ json.dumps(data, indent=2, ensure_ascii=False),
91
+ encoding="utf-8",
92
+ )
93
+ logger.debug("Task registry saved to %s", _USER_TASKS_PATH)
94
+
95
+ # ------------------------------------------------------------------
96
+ # Embedding management
97
+ # ------------------------------------------------------------------
98
+
99
+ def compute_embeddings(self, embedder: "Embedder") -> None:
100
+ """Compute and store the centroid embedding for every task cluster.
101
+
102
+ For each cluster the validation examples are embedded individually and
103
+ then averaged (L2-normalised mean) to form a single centroid vector.
104
+ The results are written back into each ``TaskCluster.embedding_centroid``
105
+ field **and** saved to ``~/.shiftgate/embeddings_cache.npy`` as a
106
+ (n_tasks × dim) float32 array for fast loading on future runs.
107
+ """
108
+ task_list = list(self._tasks.values())
109
+ logger.info("Computing embeddings for %d task clusters…", len(task_list))
110
+
111
+ for task in task_list:
112
+ all_examples = task.validation_examples
113
+ embeddings = embedder.embed_batch(all_examples) # shape: (n, dim)
114
+ centroid = embeddings.mean(axis=0)
115
+ # L2-normalise so cosine similarity reduces to dot product later.
116
+ norm = np.linalg.norm(centroid)
117
+ if norm > 0:
118
+ centroid = centroid / norm
119
+ task.embedding_centroid = centroid.tolist()
120
+
121
+ # Persist centroids as a numpy array indexed by task order.
122
+ self._save_cache(task_list)
123
+ logger.info("Embeddings computed and cached.")
124
+
125
+ def _save_cache(self, task_list: list[TaskCluster]) -> None:
126
+ """Write centroids to the numpy cache file."""
127
+ _SHIFTGATE_DIR.mkdir(parents=True, exist_ok=True)
128
+ centroids = [t.embedding_centroid for t in task_list if t.embedding_centroid]
129
+ if centroids:
130
+ arr = np.array(centroids, dtype=np.float32)
131
+ np.save(_CACHE_PATH, arr)
132
+ logger.debug("Centroid cache written to %s (%s)", _CACHE_PATH, arr.shape)
133
+
134
+ def _restore_cache(self) -> None:
135
+ """Re-populate ``embedding_centroid`` from the numpy cache if available.
136
+
137
+ This avoids a full re-embedding on every startup. The cache is keyed
138
+ positionally — tasks must stay in the same order between runs, which is
139
+ true as long as the registry JSON is not manually reordered.
140
+ """
141
+ if not _CACHE_PATH.exists():
142
+ return
143
+ try:
144
+ arr = np.load(_CACHE_PATH)
145
+ task_list = list(self._tasks.values())
146
+ for i, task in enumerate(task_list):
147
+ if i < len(arr):
148
+ task.embedding_centroid = arr[i].tolist()
149
+ logger.debug("Restored centroids from cache (%d tasks)", len(task_list))
150
+ except Exception as exc:
151
+ logger.warning("Could not restore embedding cache (%s). Re-run `shiftgate init`.", exc)
152
+
153
+ def embeddings_ready(self) -> bool:
154
+ """Return True if all task clusters have a computed centroid."""
155
+ return all(t.embedding_centroid is not None for t in self._tasks.values())
156
+
157
+ # ------------------------------------------------------------------
158
+ # CRUD
159
+ # ------------------------------------------------------------------
160
+
161
+ def get_all_tasks(self) -> list[TaskCluster]:
162
+ """Return all registered task clusters."""
163
+ return list(self._tasks.values())
164
+
165
+ def get_task(self, task_id: str) -> TaskCluster | None:
166
+ """Return a single task cluster by ID, or None if not found."""
167
+ return self._tasks.get(task_id)
168
+
169
+ def add_task(self, task: TaskCluster) -> None:
170
+ """Add or replace a task cluster in the registry.
171
+
172
+ If a task with the same ID already exists it is silently overwritten.
173
+ Call ``save()`` afterwards to persist the change.
174
+ """
175
+ self._tasks[task.id] = task
176
+ logger.debug("Task '%s' added to registry.", task.id)
177
+
178
+ def remove_task(self, task_id: str) -> bool:
179
+ """Remove a task by ID. Returns True if it existed."""
180
+ if task_id in self._tasks:
181
+ del self._tasks[task_id]
182
+ return True
183
+ return False
184
+
185
+ def __len__(self) -> int:
186
+ return len(self._tasks)
@@ -0,0 +1 @@
1
+ """Router sub-package: embedding, cosine matching, and routing logic."""