memstack-skill-loader 3.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memstack_skill_loader/__init__.py +1 -0
- memstack_skill_loader/__main__.py +18 -0
- memstack_skill_loader/compression.py +345 -0
- memstack_skill_loader/config.py +114 -0
- memstack_skill_loader/dashboard.html +829 -0
- memstack_skill_loader/dashboard.py +360 -0
- memstack_skill_loader/indexer.py +240 -0
- memstack_skill_loader/license.py +409 -0
- memstack_skill_loader/search.py +164 -0
- memstack_skill_loader/server.py +883 -0
- memstack_skill_loader/stats.py +428 -0
- memstack_skill_loader/tfidf_search.py +142 -0
- memstack_skill_loader/version_check.py +93 -0
- memstack_skill_loader-3.5.0.dist-info/METADATA +10 -0
- memstack_skill_loader-3.5.0.dist-info/RECORD +18 -0
- memstack_skill_loader-3.5.0.dist-info/WHEEL +5 -0
- memstack_skill_loader-3.5.0.dist-info/entry_points.txt +2 -0
- memstack_skill_loader-3.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
"""License validation for MemStack Pro skill gating."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import functools
|
|
5
|
+
import getpass
|
|
6
|
+
import hashlib
|
|
7
|
+
import hmac as hmac_mod
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import socket
|
|
11
|
+
import sys
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from datetime import date
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
import httpx
|
|
19
|
+
|
|
20
|
+
# ---------------------------------------------------------------------------
|
|
21
|
+
# Data types
|
|
22
|
+
# ---------------------------------------------------------------------------
|
|
23
|
+
|
|
24
|
+
VALIDATE_URL = "https://admin.cwaffiliateinvestments.com/api/licenses/validate"
|
|
25
|
+
LICENSE_FILE = Path.home() / ".memstack" / "license.json"
|
|
26
|
+
GRACE_PERIOD_FILE = Path.home() / ".memstack" / "grace-period.json"
|
|
27
|
+
DISK_CACHE_FILE = Path.home() / ".memstack" / "license-cache.json"
|
|
28
|
+
DISK_CACHE_MAX_AGE = 24 * 60 * 60 # 24 hours in seconds
|
|
29
|
+
HTTP_TIMEOUT = 5.0 # seconds
|
|
30
|
+
MAX_KEY_LEN = 256
|
|
31
|
+
GRACE_PERIOD_DAYS = 3
|
|
32
|
+
|
|
33
|
+
# Skills that require a Pro license. Free-tier users see a locked message.
|
|
34
|
+
PRO_EXCLUSIVE_SKILLS = frozenset({"consolidate", "context-db", "api-docs", "branching", "multi-agent", "codebase-index", "doc-index", "diagram-generator", "browser-use", "session-restore", "drift-detection", "mcp-builder", "claude-api-helper", "test-generator", "log-analyzer", "performance-profiler", "dependency-auditor", "git-worktrees", "error-handler", "web-scraper", "advanced-security", "env-manager-pro", "hooks-integration", "developer-growth-analysis", "gtm-validator", "meeting-insights-analyzer", "model-router", "rag-builder", "video-pipeline"})
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class LicenseStatus:
|
|
39
|
+
valid: bool
|
|
40
|
+
tier: str
|
|
41
|
+
cached_at: float # time.time() when this entry was cached
|
|
42
|
+
grace_period: bool = False # True if access granted via grace period
|
|
43
|
+
grace_days_remaining: int = 0
|
|
44
|
+
grace_expired: bool = False # True only when grace period has ended
|
|
45
|
+
grace_tampered: bool = False # True if HMAC integrity check failed
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def is_pro_exclusive(skill_name: str) -> bool:
|
|
49
|
+
"""Return True if *skill_name* requires a Pro license."""
|
|
50
|
+
return skill_name.lower() in PRO_EXCLUSIVE_SKILLS
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# ---------------------------------------------------------------------------
|
|
54
|
+
# In-memory cache: {license_key: (LicenseStatus, expires_at)}
|
|
55
|
+
# Per-key locks prevent duplicate API calls from concurrent requests.
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
|
|
58
|
+
_cache: dict[str, tuple[LicenseStatus, float]] = {}
|
|
59
|
+
_locks: dict[str, asyncio.Lock] = {}
|
|
60
|
+
_session_validated: set[str] = set() # keys validated this session
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _get_lock(key: str) -> asyncio.Lock:
|
|
64
|
+
"""Return a per-key asyncio.Lock, creating one if needed."""
|
|
65
|
+
if key not in _locks:
|
|
66
|
+
_locks[key] = asyncio.Lock()
|
|
67
|
+
return _locks[key]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def clear_cache(key: str | None = None) -> None:
|
|
71
|
+
"""Remove cached license status.
|
|
72
|
+
|
|
73
|
+
Pass *key* to clear a single entry, or ``None`` to clear all.
|
|
74
|
+
"""
|
|
75
|
+
if key is None:
|
|
76
|
+
_cache.clear()
|
|
77
|
+
_session_validated.clear()
|
|
78
|
+
else:
|
|
79
|
+
_cache.pop(key, None)
|
|
80
|
+
_session_validated.discard(key)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _save_disk_cache(key: str, status: LicenseStatus) -> None:
|
|
84
|
+
"""Persist license validation result to disk with a timestamp."""
|
|
85
|
+
try:
|
|
86
|
+
DISK_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
|
87
|
+
DISK_CACHE_FILE.write_text(
|
|
88
|
+
json.dumps({
|
|
89
|
+
"license_key": key,
|
|
90
|
+
"valid": status.valid,
|
|
91
|
+
"tier": status.tier,
|
|
92
|
+
"timestamp": time.time(),
|
|
93
|
+
}, indent=2),
|
|
94
|
+
encoding="utf-8",
|
|
95
|
+
)
|
|
96
|
+
except OSError as exc:
|
|
97
|
+
print(f"[license] failed to write disk cache: {exc}", file=sys.stderr)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _load_disk_cache(key: str) -> LicenseStatus | None:
|
|
101
|
+
"""Load cached license result from disk if it exists and is < 24hrs old."""
|
|
102
|
+
try:
|
|
103
|
+
if not DISK_CACHE_FILE.exists():
|
|
104
|
+
return None
|
|
105
|
+
data = json.loads(DISK_CACHE_FILE.read_text(encoding="utf-8"))
|
|
106
|
+
if data.get("license_key") != key:
|
|
107
|
+
return None
|
|
108
|
+
cached_time = data.get("timestamp", 0)
|
|
109
|
+
age = time.time() - cached_time
|
|
110
|
+
if age > DISK_CACHE_MAX_AGE:
|
|
111
|
+
print("[license] disk cache expired (>24hrs)", file=sys.stderr)
|
|
112
|
+
return None
|
|
113
|
+
return LicenseStatus(
|
|
114
|
+
valid=bool(data.get("valid", False)),
|
|
115
|
+
tier=data.get("tier", "free"),
|
|
116
|
+
cached_at=cached_time,
|
|
117
|
+
)
|
|
118
|
+
except (OSError, json.JSONDecodeError, ValueError) as exc:
|
|
119
|
+
print(f"[license] failed to read disk cache: {exc}", file=sys.stderr)
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# ---------------------------------------------------------------------------
|
|
124
|
+
# HMAC helpers — machine-tied tamper detection for grace period file.
|
|
125
|
+
# ---------------------------------------------------------------------------
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@functools.lru_cache(maxsize=None)
|
|
129
|
+
def _machine_secret() -> bytes:
|
|
130
|
+
"""Derive a machine-specific secret from hostname + username."""
|
|
131
|
+
try:
|
|
132
|
+
username = os.getlogin()
|
|
133
|
+
except OSError:
|
|
134
|
+
username = getpass.getuser()
|
|
135
|
+
identity = f"{socket.gethostname()}{username}"
|
|
136
|
+
return hmac_mod.new(
|
|
137
|
+
b"memstack-grace-v1", identity.encode("utf-8"), hashlib.sha256
|
|
138
|
+
).digest()
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@functools.lru_cache(maxsize=None)
|
|
142
|
+
def _machine_id() -> str:
|
|
143
|
+
"""Generate a stable machine fingerprint for license binding."""
|
|
144
|
+
try:
|
|
145
|
+
username = os.getlogin()
|
|
146
|
+
except OSError:
|
|
147
|
+
username = getpass.getuser()
|
|
148
|
+
raw = f"{socket.gethostname()}-{username}-{sys.platform}"
|
|
149
|
+
return hashlib.sha256(raw.encode()).hexdigest()[:32]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _compute_hmac(first_run_str: str) -> str:
|
|
153
|
+
"""Compute HMAC-SHA256 hex digest for a first_run date string."""
|
|
154
|
+
return hmac_mod.new(
|
|
155
|
+
_machine_secret(), first_run_str.encode("utf-8"), hashlib.sha256
|
|
156
|
+
).hexdigest()
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# ---------------------------------------------------------------------------
|
|
160
|
+
# Grace period helpers — cached in-memory, refreshed daily.
|
|
161
|
+
# Uses a threading lock (not asyncio) since file I/O is synchronous.
|
|
162
|
+
# ---------------------------------------------------------------------------
|
|
163
|
+
|
|
164
|
+
_grace_lock = threading.Lock()
|
|
165
|
+
_grace_cache: tuple[bool, int, str, bool] | None = None
|
|
166
|
+
# (allowed, days_remaining, date_str, tampered)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _get_grace_period_status() -> tuple[bool, int, bool]:
|
|
170
|
+
"""Check the 3-day grace period for users without a license key.
|
|
171
|
+
|
|
172
|
+
Returns (is_active, days_remaining, tampered).
|
|
173
|
+
- is_active=True, days_remaining>0: grace period still active
|
|
174
|
+
- is_active=False, days_remaining<=0: grace period expired
|
|
175
|
+
- tampered=True: HMAC integrity check failed
|
|
176
|
+
"""
|
|
177
|
+
global _grace_cache
|
|
178
|
+
|
|
179
|
+
today_str = date.today().isoformat()
|
|
180
|
+
|
|
181
|
+
# Check in-memory cache — valid if same calendar day
|
|
182
|
+
if _grace_cache is not None:
|
|
183
|
+
cached_allowed, cached_remaining, cached_date, cached_tampered = _grace_cache
|
|
184
|
+
if cached_date == today_str:
|
|
185
|
+
return (cached_allowed, cached_remaining, cached_tampered)
|
|
186
|
+
|
|
187
|
+
with _grace_lock:
|
|
188
|
+
# Double-check after acquiring lock
|
|
189
|
+
if _grace_cache is not None:
|
|
190
|
+
cached_allowed, cached_remaining, cached_date, cached_tampered = _grace_cache
|
|
191
|
+
if cached_date == today_str:
|
|
192
|
+
return (cached_allowed, cached_remaining, cached_tampered)
|
|
193
|
+
|
|
194
|
+
try:
|
|
195
|
+
if GRACE_PERIOD_FILE.exists():
|
|
196
|
+
data = json.loads(GRACE_PERIOD_FILE.read_text(encoding="utf-8"))
|
|
197
|
+
first_run_str = data["first_run"]
|
|
198
|
+
|
|
199
|
+
if "hmac" in data:
|
|
200
|
+
# Verify HMAC
|
|
201
|
+
expected = _compute_hmac(first_run_str)
|
|
202
|
+
if not hmac_mod.compare_digest(data["hmac"], expected):
|
|
203
|
+
print(
|
|
204
|
+
"[license] grace period HMAC mismatch — tampered",
|
|
205
|
+
file=sys.stderr,
|
|
206
|
+
)
|
|
207
|
+
_grace_cache = (False, 0, today_str, True)
|
|
208
|
+
return (False, 0, True)
|
|
209
|
+
else:
|
|
210
|
+
# Legacy file without HMAC — migrate by rewriting with HMAC
|
|
211
|
+
print(
|
|
212
|
+
"[license] migrating legacy grace-period.json (adding HMAC)",
|
|
213
|
+
file=sys.stderr,
|
|
214
|
+
)
|
|
215
|
+
data["hmac"] = _compute_hmac(first_run_str)
|
|
216
|
+
GRACE_PERIOD_FILE.write_text(
|
|
217
|
+
json.dumps(data, indent=2), encoding="utf-8"
|
|
218
|
+
)
|
|
219
|
+
try:
|
|
220
|
+
os.chmod(GRACE_PERIOD_FILE, 0o600)
|
|
221
|
+
except OSError:
|
|
222
|
+
pass
|
|
223
|
+
|
|
224
|
+
first_run = date.fromisoformat(first_run_str)
|
|
225
|
+
else:
|
|
226
|
+
# First run — create the grace period file with HMAC
|
|
227
|
+
first_run = date.today()
|
|
228
|
+
first_run_str = first_run.isoformat()
|
|
229
|
+
file_data = {
|
|
230
|
+
"first_run": first_run_str,
|
|
231
|
+
"hmac": _compute_hmac(first_run_str),
|
|
232
|
+
}
|
|
233
|
+
GRACE_PERIOD_FILE.parent.mkdir(parents=True, exist_ok=True)
|
|
234
|
+
GRACE_PERIOD_FILE.write_text(
|
|
235
|
+
json.dumps(file_data, indent=2), encoding="utf-8"
|
|
236
|
+
)
|
|
237
|
+
try:
|
|
238
|
+
os.chmod(GRACE_PERIOD_FILE, 0o600)
|
|
239
|
+
except OSError:
|
|
240
|
+
pass
|
|
241
|
+
print(
|
|
242
|
+
f"[license] grace period started: {first_run_str}",
|
|
243
|
+
file=sys.stderr,
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
elapsed = (date.today() - first_run).days
|
|
247
|
+
remaining = max(GRACE_PERIOD_DAYS - elapsed, 0)
|
|
248
|
+
allowed = remaining > 0
|
|
249
|
+
|
|
250
|
+
_grace_cache = (allowed, remaining, today_str, False)
|
|
251
|
+
return (allowed, remaining, False)
|
|
252
|
+
except (json.JSONDecodeError, KeyError, ValueError, OSError) as exc:
|
|
253
|
+
print(f"[license] grace period check failed: {exc}", file=sys.stderr)
|
|
254
|
+
_grace_cache = (False, 0, today_str, False)
|
|
255
|
+
return (False, 0, False)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
# ---------------------------------------------------------------------------
|
|
259
|
+
# License key helpers
|
|
260
|
+
# ---------------------------------------------------------------------------
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def get_license_key() -> str | None:
|
|
264
|
+
"""Read the license key from env var or ~/.memstack/license.json."""
|
|
265
|
+
key = os.environ.get("MEMSTACK_PRO_LICENSE_KEY")
|
|
266
|
+
if key:
|
|
267
|
+
return key.strip()
|
|
268
|
+
|
|
269
|
+
if LICENSE_FILE.exists():
|
|
270
|
+
try:
|
|
271
|
+
data = json.loads(LICENSE_FILE.read_text(encoding="utf-8"))
|
|
272
|
+
k = data.get("license_key", "")
|
|
273
|
+
if k:
|
|
274
|
+
return k.strip()
|
|
275
|
+
except (json.JSONDecodeError, OSError) as exc:
|
|
276
|
+
print(f"[license] failed to read {LICENSE_FILE}: {exc}", file=sys.stderr)
|
|
277
|
+
|
|
278
|
+
return None
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def save_license_key(key: str) -> None:
|
|
282
|
+
"""Persist *key* to ~/.memstack/license.json with restricted permissions."""
|
|
283
|
+
LICENSE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
|
284
|
+
LICENSE_FILE.write_text(
|
|
285
|
+
json.dumps({"license_key": key}, indent=2),
|
|
286
|
+
encoding="utf-8",
|
|
287
|
+
)
|
|
288
|
+
try:
|
|
289
|
+
os.chmod(LICENSE_FILE, 0o600)
|
|
290
|
+
except OSError:
|
|
291
|
+
pass
|
|
292
|
+
print(f"[license] saved key to {LICENSE_FILE}", file=sys.stderr)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# ---------------------------------------------------------------------------
|
|
296
|
+
# Validation
|
|
297
|
+
# ---------------------------------------------------------------------------
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
async def validate_license(
|
|
301
|
+
license_key: str | None = None,
|
|
302
|
+
email: str | None = None,
|
|
303
|
+
) -> LicenseStatus:
|
|
304
|
+
"""Validate a MemStack Pro license key.
|
|
305
|
+
|
|
306
|
+
Resolution order for the key:
|
|
307
|
+
1. Explicit *license_key* argument
|
|
308
|
+
2. MEMSTACK_PRO_LICENSE_KEY env var
|
|
309
|
+
3. ~/.memstack/license.json
|
|
310
|
+
4. If none found, check 14-day grace period
|
|
311
|
+
|
|
312
|
+
*email* is forwarded to the API for contact tracking but does not
|
|
313
|
+
affect caching — the cache key is the license key alone.
|
|
314
|
+
|
|
315
|
+
Returns a cached result when the cache is still fresh.
|
|
316
|
+
"""
|
|
317
|
+
key = license_key or get_license_key()
|
|
318
|
+
if not key:
|
|
319
|
+
# No key — check grace period
|
|
320
|
+
grace_active, days_remaining, tampered = _get_grace_period_status()
|
|
321
|
+
|
|
322
|
+
if tampered:
|
|
323
|
+
return LicenseStatus(
|
|
324
|
+
valid=False, tier="free", cached_at=0, grace_tampered=True
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
if grace_active:
|
|
328
|
+
print(
|
|
329
|
+
f"[license] grace period active: {days_remaining} days remaining",
|
|
330
|
+
file=sys.stderr,
|
|
331
|
+
)
|
|
332
|
+
return LicenseStatus(
|
|
333
|
+
valid=True,
|
|
334
|
+
tier="pro",
|
|
335
|
+
cached_at=time.time(),
|
|
336
|
+
grace_period=True,
|
|
337
|
+
grace_days_remaining=days_remaining,
|
|
338
|
+
)
|
|
339
|
+
# Grace expired
|
|
340
|
+
return LicenseStatus(
|
|
341
|
+
valid=False, tier="free", cached_at=0, grace_expired=True
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
# Basic format check
|
|
345
|
+
if len(key) > MAX_KEY_LEN or not key.strip():
|
|
346
|
+
return LicenseStatus(valid=False, tier="free", cached_at=0)
|
|
347
|
+
|
|
348
|
+
async with _get_lock(key):
|
|
349
|
+
return await _validate_locked(key, email)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
async def _validate_locked(key: str, email: str | None) -> LicenseStatus:
|
|
353
|
+
"""Inner validation, called while holding the per-key lock.
|
|
354
|
+
|
|
355
|
+
- First call per session: always hits the API.
|
|
356
|
+
- Subsequent calls in the same session: return in-memory result.
|
|
357
|
+
- If API is unreachable: fall back to disk cache (< 24hrs old).
|
|
358
|
+
- If disk cache is stale/missing and API is unreachable: deny access.
|
|
359
|
+
"""
|
|
360
|
+
now = time.time()
|
|
361
|
+
|
|
362
|
+
# Within a session, reuse the in-memory result (don't re-validate)
|
|
363
|
+
if key in _session_validated and key in _cache:
|
|
364
|
+
cached_status, _ = _cache[key]
|
|
365
|
+
return cached_status
|
|
366
|
+
|
|
367
|
+
# --- build request body ---
|
|
368
|
+
body_payload: dict[str, str] = {"license_key": key, "machine_id": _machine_id()}
|
|
369
|
+
if email:
|
|
370
|
+
body_payload["email"] = email
|
|
371
|
+
|
|
372
|
+
# --- always call API on session start ---
|
|
373
|
+
try:
|
|
374
|
+
async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
|
|
375
|
+
resp = await client.post(VALIDATE_URL, json=body_payload)
|
|
376
|
+
resp.raise_for_status()
|
|
377
|
+
body = resp.json()
|
|
378
|
+
except httpx.HTTPStatusError as exc:
|
|
379
|
+
print(f"[license] HTTP {exc.response.status_code}: {exc}", file=sys.stderr)
|
|
380
|
+
status = LicenseStatus(valid=False, tier="free", cached_at=now)
|
|
381
|
+
_cache[key] = (status, now + 60)
|
|
382
|
+
_session_validated.add(key)
|
|
383
|
+
return status
|
|
384
|
+
except (httpx.HTTPError, json.JSONDecodeError, ValueError) as exc:
|
|
385
|
+
print(f"[license] network error: {exc}", file=sys.stderr)
|
|
386
|
+
# API unreachable — fall back to disk cache if < 24hrs old
|
|
387
|
+
disk_status = _load_disk_cache(key)
|
|
388
|
+
if disk_status is not None:
|
|
389
|
+
print("[license] using disk cache fallback", file=sys.stderr)
|
|
390
|
+
_cache[key] = (disk_status, now + 3600)
|
|
391
|
+
_session_validated.add(key)
|
|
392
|
+
return disk_status
|
|
393
|
+
# No valid disk cache — deny access
|
|
394
|
+
status = LicenseStatus(valid=False, tier="free", cached_at=now)
|
|
395
|
+
_cache[key] = (status, now + 60)
|
|
396
|
+
_session_validated.add(key)
|
|
397
|
+
return status
|
|
398
|
+
|
|
399
|
+
valid = bool(body.get("valid", False))
|
|
400
|
+
tier = body.get("tier", "free") if valid else "free"
|
|
401
|
+
|
|
402
|
+
status = LicenseStatus(valid=valid, tier=tier, cached_at=now)
|
|
403
|
+
_cache[key] = (status, now + 86400) # session-long in-memory cache
|
|
404
|
+
_session_validated.add(key)
|
|
405
|
+
|
|
406
|
+
# Persist to disk for offline fallback
|
|
407
|
+
_save_disk_cache(key, status)
|
|
408
|
+
|
|
409
|
+
return status
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Semantic search against the LanceDB skill index."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import lancedb
|
|
8
|
+
|
|
9
|
+
from .config import Config, load_config
|
|
10
|
+
|
|
11
|
+
_model = None
|
|
12
|
+
_db = None
|
|
13
|
+
_table = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _get_model(config: Config):
|
|
17
|
+
"""Lazy-load the sentence-transformers model."""
|
|
18
|
+
global _model
|
|
19
|
+
if _model is None:
|
|
20
|
+
from sentence_transformers import SentenceTransformer
|
|
21
|
+
_model = SentenceTransformer(config.embedding_model)
|
|
22
|
+
return _model
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def warmup_model(config: Config):
|
|
26
|
+
"""Pre-load the embedding model into memory. Call from a background thread at startup."""
|
|
27
|
+
cache_dir = os.environ.get("HF_HOME") or os.environ.get(
|
|
28
|
+
"TRANSFORMERS_CACHE", os.path.join(Path.home(), ".cache", "huggingface")
|
|
29
|
+
)
|
|
30
|
+
print(f"HuggingFace cache directory: {cache_dir}", file=sys.stderr)
|
|
31
|
+
print(f"Loading embedding model '{config.embedding_model}' in background...", file=sys.stderr)
|
|
32
|
+
try:
|
|
33
|
+
_get_model(config)
|
|
34
|
+
print("Embedding model loaded and ready.", file=sys.stderr)
|
|
35
|
+
except Exception as e:
|
|
36
|
+
print(f"Background model warmup failed: {e}", file=sys.stderr)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_table(config: Config):
|
|
40
|
+
"""Get the LanceDB skills table, return None if not indexed yet."""
|
|
41
|
+
global _db, _table
|
|
42
|
+
if _table is None:
|
|
43
|
+
db_path = str(config.resolved_vector_db_path)
|
|
44
|
+
if not Path(db_path).exists():
|
|
45
|
+
return None
|
|
46
|
+
_db = lancedb.connect(db_path)
|
|
47
|
+
try:
|
|
48
|
+
_table = _db.open_table("skills")
|
|
49
|
+
except Exception:
|
|
50
|
+
return None
|
|
51
|
+
return _table
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def reset_cache():
|
|
55
|
+
"""Reset cached db/table/model. Call after reindexing."""
|
|
56
|
+
global _db, _table, _model
|
|
57
|
+
_db = None
|
|
58
|
+
_table = None
|
|
59
|
+
_model = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def search_skills(query: str, config: Config | None = None, top_k: int | None = None) -> list[dict]:
|
|
63
|
+
"""Semantic search for skills matching the query.
|
|
64
|
+
|
|
65
|
+
Returns list of dicts with: name, description, source_label, content, score
|
|
66
|
+
"""
|
|
67
|
+
if config is None:
|
|
68
|
+
config = load_config()
|
|
69
|
+
if top_k is None:
|
|
70
|
+
top_k = config.default_top_k
|
|
71
|
+
|
|
72
|
+
table = _get_table(config)
|
|
73
|
+
if table is None:
|
|
74
|
+
return []
|
|
75
|
+
|
|
76
|
+
model = _get_model(config)
|
|
77
|
+
query_vector = model.encode(query).tolist()
|
|
78
|
+
|
|
79
|
+
results = (
|
|
80
|
+
table.search(query_vector)
|
|
81
|
+
.metric("cosine")
|
|
82
|
+
.limit(top_k)
|
|
83
|
+
.to_pandas()
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
skills = []
|
|
87
|
+
for _, row in results.iterrows():
|
|
88
|
+
# Cosine distance is 0-2; convert to similarity 0-1
|
|
89
|
+
distance = row.get("_distance", 0)
|
|
90
|
+
similarity = round(max(0, 1 - distance), 4)
|
|
91
|
+
skills.append({
|
|
92
|
+
"name": row["name"],
|
|
93
|
+
"description": row["description"],
|
|
94
|
+
"source_label": row["source_label"],
|
|
95
|
+
"content": row["content"],
|
|
96
|
+
"filepath": row["filepath"],
|
|
97
|
+
"score": similarity,
|
|
98
|
+
})
|
|
99
|
+
|
|
100
|
+
return skills
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def list_all_skills(config: Config | None = None) -> list[dict]:
|
|
104
|
+
"""Return all indexed skills (name + description only, no content)."""
|
|
105
|
+
if config is None:
|
|
106
|
+
config = load_config()
|
|
107
|
+
|
|
108
|
+
table = _get_table(config)
|
|
109
|
+
if table is None:
|
|
110
|
+
return []
|
|
111
|
+
|
|
112
|
+
df = table.to_pandas()
|
|
113
|
+
skills = []
|
|
114
|
+
for _, row in df.iterrows():
|
|
115
|
+
skills.append({
|
|
116
|
+
"name": row["name"],
|
|
117
|
+
"description": row["description"],
|
|
118
|
+
"source_label": row["source_label"],
|
|
119
|
+
})
|
|
120
|
+
|
|
121
|
+
return sorted(skills, key=lambda s: s["name"])
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_skill_by_name(name: str, config: Config | None = None) -> dict | None:
|
|
125
|
+
"""Get a skill by exact or fuzzy name match."""
|
|
126
|
+
if config is None:
|
|
127
|
+
config = load_config()
|
|
128
|
+
|
|
129
|
+
table = _get_table(config)
|
|
130
|
+
if table is None:
|
|
131
|
+
return None
|
|
132
|
+
|
|
133
|
+
df = table.to_pandas()
|
|
134
|
+
name_lower = name.lower().replace("-", " ").replace("_", " ")
|
|
135
|
+
|
|
136
|
+
# Exact match first
|
|
137
|
+
for _, row in df.iterrows():
|
|
138
|
+
if row["name"].lower() == name_lower:
|
|
139
|
+
return {
|
|
140
|
+
"name": row["name"],
|
|
141
|
+
"description": row["description"],
|
|
142
|
+
"source_label": row["source_label"],
|
|
143
|
+
"content": row["content"],
|
|
144
|
+
"filepath": row["filepath"],
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
# Fuzzy: check if query is substring of name or name is substring of query
|
|
148
|
+
for _, row in df.iterrows():
|
|
149
|
+
row_name_lower = row["name"].lower()
|
|
150
|
+
if name_lower in row_name_lower or row_name_lower in name_lower:
|
|
151
|
+
return {
|
|
152
|
+
"name": row["name"],
|
|
153
|
+
"description": row["description"],
|
|
154
|
+
"source_label": row["source_label"],
|
|
155
|
+
"content": row["content"],
|
|
156
|
+
"filepath": row["filepath"],
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
# Last resort: semantic search with top_k=1
|
|
160
|
+
results = search_skills(name, config, top_k=1)
|
|
161
|
+
if results and results[0]["score"] > 0.5:
|
|
162
|
+
return results[0]
|
|
163
|
+
|
|
164
|
+
return None
|