mcpswitch-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcpswitch/tier.py ADDED
@@ -0,0 +1,170 @@
1
+ """Tier management — Free, Pro (donation-ware), and Team (Lemon Squeezy).
2
+
3
+ Free: profile management, token estimation, auto-switch, analyze
4
+ Pro: waste detection, usage stats, weekly digest, community profiles, sync
5
+ -> always free; a periodic donation nudge is shown (non-blocking)
6
+ Team: team analytics, shared profiles, Slack alerts
7
+ -> requires a paid Team license (Lemon Squeezy)
8
+
9
+ License keys are stored in ``~/.mcpswitch/license.json`` and validated
10
+ offline via HMAC (see ``mcpswitch.billing``).
11
+ """
12
+
13
+ import json
14
+ import time
15
+ from pathlib import Path
16
+
17
+ from .billing import (
18
+ TEAM_PAYMENT_URL,
19
+ validate_license_key,
20
+ )
21
+
22
# Per-user state files, both kept under ``~/.mcpswitch``.
LICENSE_FILE = Path.home().joinpath(".mcpswitch", "license.json")
NUDGE_FILE = Path.home().joinpath(".mcpswitch", "nudge.json")

# Minimum gap between two donation nudges, in seconds (one week).
_NUDGE_INTERVAL = 7 * 24 * 60 * 60
27
+
28
+
29
+
30
+ # ── License persistence ───────────────────────────────────────────────────────
31
+
32
def _load_license() -> dict:
    """Read the persisted license record from ``LICENSE_FILE``.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(LICENSE_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON or
        # undecodable bytes. Either way there is no usable license.
        return {}
    # Callers call ``.get`` on the result, so a JSON list/string/number/null
    # is treated the same as a corrupt file.
    return data if isinstance(data, dict) else {}
40
+
41
+
42
def _save_license(data: dict) -> None:
    """Write *data* to ``LICENSE_FILE`` as indented JSON.

    The parent directory is created when missing.

    Raises:
        TypeError / ValueError: if *data* is not JSON-serialisable.
        OSError: if the directory or file cannot be written.
    """
    # Serialise before opening the file: opening with "w" truncates it, so a
    # serialisation error must not be able to wipe an existing license.
    payload = json.dumps(data, indent=2)
    LICENSE_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(LICENSE_FILE, "w", encoding="utf-8") as f:
        f.write(payload)
46
+
47
+
48
+ # ── Activation ────────────────────────────────────────────────────────────────
49
+
50
def activate_license(key: str) -> dict:
    """Validate *key* and, when it is valid, persist it as the active license.

    Args:
        key: license key as entered by the user; surrounding whitespace is
            stripped before validation.

    Returns:
        ``{"success": bool, "tier": str, "message": str}``. On failure the
        tier is ``"free"`` and nothing is written to disk.
    """
    key = key.strip()
    outcome = validate_license_key(key)

    if outcome["valid"]:
        tier = outcome["tier"]  # always "team" for Lemon Squeezy-issued keys
        seats = outcome.get("seats", 1)

        record = {
            "key": key.upper(),
            "tier": tier,
            "activated_at": time.time(),
            "seats": seats,
            "email_hash": outcome.get("email_hash", ""),
        }
        _save_license(record)

        plural = "" if seats == 1 else "s"
        return {
            "success": True,
            "tier": tier,
            "message": f"Activated {tier.upper()} tier ({seats} seat{plural}).",
        }

    return {
        "success": False,
        "tier": "free",
        "message": f"Invalid license key: {outcome['error']}",
    }
82
+
83
+
84
+ # ── Tier queries ──────────────────────────────────────────────────────────────
85
+
86
def get_tier() -> str:
    """Return the current tier: ``'free'``, ``'pro'``, or ``'team'``.

    The value is read from the stored license record. A missing record, a
    record that is not a JSON object, or a ``tier`` value outside the three
    documented names all resolve to ``'free'`` so callers always get one of
    the documented strings.
    """
    record = _load_license()
    if not isinstance(record, dict):
        return "free"
    tier = record.get("tier", "free")
    # A hand-edited or corrupted file may hold null, a number, or an unknown
    # name here; never hand that back to callers.
    return tier if tier in ("free", "pro", "team") else "free"
89
+
90
+
91
def is_pro() -> bool:
    """Report whether Pro features are available; unconditionally ``True``.

    Pro features are free, so every user counts as "pro".
    """
    return True
94
+
95
+
96
def is_team() -> bool:
    """Return ``True`` when the stored license tier is ``"team"``."""
    current = get_tier()
    return current == "team"
98
+
99
+
100
def deactivate() -> None:
    """Delete the stored license file, if there is one.

    A missing file is not an error.
    """
    # Attempt the delete directly instead of checking existence first, so a
    # file removed between the check and the unlink cannot raise.
    try:
        LICENSE_FILE.unlink()
    except FileNotFoundError:
        pass
103
+
104
+
105
+ # ── Feature gates ─────────────────────────────────────────────────────────────
106
+
107
def require_pro(feature_name: str) -> bool:
    """Gate check for a Pro feature; the answer is ``True`` for everyone.

    *feature_name* is accepted but not inspected. The optional donation
    nudge is not triggered here; call ``maybe_show_donation_nudge()``
    separately for that.
    """
    return True
113
+
114
+
115
def pro_gate_message(feature_name: str) -> str:
    """Backward-compatible alias that yields the donation nudge text.

    *feature_name* is ignored.
    """
    message = donation_nudge_message()
    return message
118
+
119
+
120
+ # ── Donation nudge (non-blocking) ─────────────────────────────────────────────
121
+
122
def _load_nudge() -> dict:
    """Read the donation-nudge state from ``NUDGE_FILE``.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(NUDGE_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON or
        # undecodable bytes. Treat both as "never shown".
        return {}
    # Callers call ``.get`` on the result, so only a JSON object is usable.
    return data if isinstance(data, dict) else {}
130
+
131
+
132
def _save_nudge(data: dict) -> None:
    """Write *data* to ``NUDGE_FILE`` as indented JSON.

    The parent directory is created when missing.

    Raises:
        TypeError / ValueError: if *data* is not JSON-serialisable.
        OSError: if the directory or file cannot be written.
    """
    # Serialise first: opening with "w" truncates the file, so a failure to
    # encode *data* must happen before the existing state is touched.
    payload = json.dumps(data, indent=2)
    NUDGE_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(NUDGE_FILE, "w", encoding="utf-8") as f:
        f.write(payload)
136
+
137
+
138
def should_show_nudge() -> bool:
    """Return True if enough time has passed since the last donation nudge.

    Team-tier users never see the nudge. For everyone else the stored
    ``last_shown`` timestamp is compared against ``_NUDGE_INTERVAL``; a
    missing or non-numeric timestamp counts as "never shown".
    """
    if is_team():
        return False  # paid users don't see donation prompts

    nudge = _load_nudge()
    last = nudge.get("last_shown", 0) if isinstance(nudge, dict) else 0
    # The state file is user-editable JSON: a string/null/bool here would
    # make the subtraction below raise (or give a meaningless result).
    if isinstance(last, bool) or not isinstance(last, (int, float)):
        last = 0
    return (time.time() - last) >= _NUDGE_INTERVAL
145
+
146
+
147
def record_nudge_shown() -> None:
    """Persist the current time as the moment the nudge was last displayed."""
    stamp = {"last_shown": time.time()}
    _save_nudge(stamp)
150
+
151
+
152
def donation_nudge_message() -> str:
    """Build the donation / upgrade text, including console markup tags.

    The message is framed by two dim rule lines and embeds
    ``TEAM_PAYMENT_URL`` in the upgrade hint.
    """
    rule = "[dim]─────────────────────────────────────────────────────[/dim]\n"
    segments = [
        "\n",
        rule,
        "[bold yellow]MCPSwitch is free and open-source.[/bold yellow] ",
        "If it saves you context, consider supporting it:\n\n",
        " [bold cyan]mcpswitch donate[/bold cyan] ",
        "→ one-time tip, any amount, no account needed\n\n",
        "Want team features (shared profiles, Slack alerts)?\n",
        " [bold cyan]mcpswitch upgrade[/bold cyan] ",
        f"→ {TEAM_PAYMENT_URL}\n",
        rule,
    ]
    return "".join(segments)
164
+
165
+
166
def maybe_show_donation_nudge(console) -> None:
    """Print the donation nudge if the weekly timer has elapsed.

    Args:
        console: object exposing ``print(str)``; the nudge text is passed to it.

    The nudge is meant to be non-blocking, so a failure to persist the
    "last shown" timestamp is ignored rather than propagated to the caller.
    """
    if not should_show_nudge():
        return

    console.print(donation_nudge_message())
    try:
        record_nudge_shown()
    except OSError:
        # Unwritable state directory: the nudge will simply be shown again
        # next time; that is preferable to aborting the user's command.
        pass
mcpswitch/tokens.py ADDED
@@ -0,0 +1,426 @@
1
+ """Estimate token cost of MCP server tool schemas.
2
+
3
+ Strategy (in order):
4
+ 1. Cache hit — return stored exact count from previous live query
5
+ 2. Known list — return hardcoded count for popular servers
6
+ 3. Live query — start the server process, call tools/list, measure real schemas
7
+ 4. Fallback — assume 5 tools if server won't start
8
+ """
9
+
10
+ import json
11
+ import sqlite3
12
+ import subprocess
13
+ import time
14
+ from pathlib import Path
15
+ from typing import Optional
16
+
17
+ import tiktoken
18
+
19
+ # ── Constants ────────────────────────────────────────────────────────────────
20
+
21
# Average tokens per tool definition (name + description + schema).
TOKENS_PER_TOOL_AVG = 180
# Flat per-server allowance for MCP server metadata / transport overhead.
BASE_OVERHEAD_PER_SERVER = 50
# Context size, in tokens, that totals are expressed as a percentage of.
CONTEXT_WINDOW = 200_000
# SQLite file holding token counts measured by live queries.
CACHE_DB = Path.home().joinpath(".mcpswitch", "token_cache.db")
# Cached entries older than this many days are re-queried.
CACHE_TTL_DAYS = 7
26
+
27
# Tool counts for popular servers, keyed by a name fragment. This is the fast
# path that avoids spawning the server process. Lookups scan the entries in
# insertion order, so the order below is significant and must be kept.
KNOWN_TOOL_COUNTS: dict[str, int] = {
    "github": 30,
    "context7": 3,
    "brave-search": 2,
    "playwright": 25,
    "sequential-thinking": 1,
    "filesystem": 11,
    "postgres": 6,
    "sqlite": 8,
    "memory": 5,
    "fetch": 2,
    "slack": 12,
    "notion": 8,
    "obsidian": 4,
    "prompts-local": 2,
    "linear": 20,
    "jira": 18,
    "figma": 10,
    "google-drive": 8,
    "gmail": 6,
    "calendar": 5,
    "stripe": 14,
    "supabase": 12,
    "vercel": 8,
    "cloudflare": 10,
}
54
+
55
# Process-wide encoder instance, created on first use by ``_get_encoder``.
_enc: Optional[tiktoken.Encoding] = None


def _get_encoder() -> tiktoken.Encoding:
    """Return the shared ``cl100k_base`` encoder, creating it on first call."""
    global _enc
    if _enc is not None:
        return _enc
    _enc = tiktoken.get_encoding("cl100k_base")
    return _enc
63
+
64
+
65
def _count_tokens(text: str) -> int:
    """Return how many tokens the shared encoder produces for *text*."""
    encoder = _get_encoder()
    token_ids = encoder.encode(text)
    return len(token_ids)
67
+
68
+
69
+ # ── Cache (SQLite) ────────────────────────────────────────────────────────────
70
+
71
def _get_cache_db() -> sqlite3.Connection:
    """Open the token-cache database, creating the file and table if needed.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        OSError: if the cache directory cannot be created.
        sqlite3.Error: if the database cannot be opened or initialised.
    """
    CACHE_DB.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(CACHE_DB))
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS tool_cache (
                server_name TEXT PRIMARY KEY,
                token_count INTEGER NOT NULL,
                tool_count INTEGER NOT NULL,
                queried_at REAL NOT NULL
            )
        """)
        conn.commit()
    except sqlite3.Error:
        # Do not hand back (or leak) a connection to an unusable database.
        conn.close()
        raise
    return conn
84
+
85
+
86
def _cache_get(server_name: str) -> Optional[int]:
    """Return the cached token count for *server_name* if it is still fresh.

    Returns:
        The stored token count when an entry exists and is younger than
        ``CACHE_TTL_DAYS``; otherwise ``None``. Cache failures (unreadable
        database, bad row) are treated as a miss.
    """
    try:
        conn = _get_cache_db()
        try:
            row = conn.execute(
                "SELECT token_count, queried_at FROM tool_cache WHERE server_name = ?",
                (server_name,),
            ).fetchone()
        finally:
            # Close even when the query raises, so the handle never leaks.
            conn.close()

        if row is not None:
            token_count, queried_at = row
            age_days = (time.time() - queried_at) / 86400
            if age_days < CACHE_TTL_DAYS:
                return token_count
    except (sqlite3.Error, OSError, TypeError):
        # The cache is best-effort: any storage problem is just a miss.
        pass
    return None
103
+
104
+
105
def _cache_set(server_name: str, token_count: int, tool_count: int) -> None:
    """Store (or overwrite) the cache entry for *server_name*.

    The entry is stamped with the current time. Storage failures are ignored
    because the cache is only an optimisation.
    """
    try:
        conn = _get_cache_db()
        try:
            conn.execute(
                """INSERT OR REPLACE INTO tool_cache
                   (server_name, token_count, tool_count, queried_at)
                   VALUES (?, ?, ?, ?)""",
                (server_name, token_count, tool_count, time.time()),
            )
            conn.commit()
        finally:
            # Close even when the write raises, so the handle never leaks.
            conn.close()
    except (sqlite3.Error, OSError):
        pass
118
+
119
+
120
def _cache_set_bulk(entries: list[tuple[str, int, int]]) -> None:
    """Insert many ``(server_name, token_count, tool_count)`` entries at once.

    All rows share one timestamp and are written in a single transaction.
    An empty *entries* list is a no-op. Storage failures and malformed
    entries are ignored because the cache is only an optimisation.
    """
    if not entries:
        return
    try:
        now = time.time()
        rows = [(name, tok, tools, now) for name, tok, tools in entries]
        conn = _get_cache_db()
        try:
            conn.executemany(
                """INSERT OR REPLACE INTO tool_cache
                   (server_name, token_count, tool_count, queried_at)
                   VALUES (?, ?, ?, ?)""",
                rows,
            )
            conn.commit()
        finally:
            # Close even when the write raises, so the handle never leaks.
            conn.close()
    except (sqlite3.Error, OSError, TypeError, ValueError):
        pass
139
+
140
+
141
def clear_cache(server_name: Optional[str] = None) -> int:
    """Delete cached entries.

    Args:
        server_name: name of the single entry to remove, or ``None`` to
            remove every entry. Only ``None`` means "all"; an empty string
            is treated as a (non-matching) name so it can never wipe the
            whole cache by accident.

    Returns:
        Number of rows deleted, or 0 if the cache could not be accessed.
    """
    try:
        conn = _get_cache_db()
        try:
            if server_name is not None:
                cur = conn.execute(
                    "DELETE FROM tool_cache WHERE server_name = ?", (server_name,)
                )
            else:
                cur = conn.execute("DELETE FROM tool_cache")
            conn.commit()
            return cur.rowcount
        finally:
            # Close even when the delete raises, so the handle never leaks.
            conn.close()
    except (sqlite3.Error, OSError):
        return 0
156
+
157
+
158
def list_cache() -> list[dict]:
    """Return all cached entries, largest token count first.

    Returns:
        A list of ``{"server", "tokens", "tools", "age_hours"}`` dicts, where
        ``age_hours`` is rounded to one decimal. An empty list is returned
        when the cache cannot be read.
    """
    try:
        conn = _get_cache_db()
        try:
            rows = conn.execute(
                "SELECT server_name, token_count, tool_count, queried_at FROM tool_cache ORDER BY token_count DESC"
            ).fetchall()
        finally:
            # Close even when the query raises, so the handle never leaks.
            conn.close()

        # One reference time for the whole listing keeps ages consistent.
        now = time.time()
        return [
            {
                "server": name,
                "tokens": tokens,
                "tools": tools,
                "age_hours": round((now - queried_at) / 3600, 1),
            }
            for name, tokens, tools, queried_at in rows
        ]
    except (sqlite3.Error, OSError, TypeError):
        return []
177
+
178
+
179
+ # ── Live Query via MCP Protocol ───────────────────────────────────────────────
180
+
181
+ _MCP_INIT_MSG = json.dumps({
182
+ "jsonrpc": "2.0",
183
+ "id": 1,
184
+ "method": "initialize",
185
+ "params": {
186
+ "protocolVersion": "2024-11-05",
187
+ "capabilities": {},
188
+ "clientInfo": {"name": "mcpswitch", "version": "0.1.0"},
189
+ },
190
+ }) + "\n"
191
+
192
+ _MCP_TOOLS_MSG = json.dumps({
193
+ "jsonrpc": "2.0",
194
+ "id": 2,
195
+ "method": "tools/list",
196
+ "params": {},
197
+ }) + "\n"
198
+
199
+
200
def _live_query_tools(server_config: dict, timeout: float = 8.0) -> Optional[list[dict]]:
    """Start the MCP server process and ask it for its tool list.

    Speaks newline-delimited JSON-RPC over the child's stdin/stdout in this
    order: ``initialize`` request, ``notifications/initialized``
    notification, ``tools/list`` request. The child is always stopped and
    reaped before returning.

    Args:
        server_config: server entry; ``command`` is required, ``args`` and
            ``env`` are optional. ``env`` is layered over the current
            environment.
        timeout: maximum number of seconds to wait for *each* of the two
            responses.

    Returns:
        The ``tools`` list from the ``tools/list`` result (an empty list when
        the result carries no ``tools`` key), or ``None`` when there is no
        command, the process cannot be started, it exits or stays silent
        past the timeout, or the response is not usable.
    """
    import os
    import queue
    import threading

    command = server_config.get("command")
    if not command:
        return None

    try:
        env = os.environ.copy()
        env.update(server_config.get("env") or {})
        proc = subprocess.Popen(
            [command, *(server_config.get("args") or [])],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            env=env,
            text=True,
            # Decode independently of the platform locale so non-ASCII tool
            # descriptions cannot abort the read.
            encoding="utf-8",
            errors="replace",
        )
    except (OSError, ValueError, TypeError):
        return None

    # stdout is drained by a helper thread so that waiting for a response can
    # honour the timeout; a direct readline() blocks indefinitely when the
    # server stays silent. ``None`` in the queue marks end-of-stream.
    inbox = queue.Queue()

    def _pump_stdout() -> None:
        try:
            for raw_line in proc.stdout:
                inbox.put(raw_line)
        except (OSError, ValueError):
            pass
        finally:
            inbox.put(None)
            # Only this thread touches stdout, so it also closes it.
            try:
                proc.stdout.close()
            except OSError:
                pass

    def _await_response(request_id: int) -> Optional[dict]:
        """Return the JSON object whose ``id`` matches, or None on EOF/timeout."""
        deadline = time.monotonic() + timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            try:
                raw_line = inbox.get(timeout=remaining)
            except queue.Empty:
                return None
            if raw_line is None:
                # Server closed stdout (most likely exited): stop waiting.
                return None
            try:
                message = json.loads(raw_line)
            except json.JSONDecodeError:
                continue  # log noise or partial output on stdout
            if isinstance(message, dict) and message.get("id") == request_id:
                return message

    tools_resp = None
    try:
        threading.Thread(target=_pump_stdout, daemon=True).start()

        proc.stdin.write(_MCP_INIT_MSG)
        proc.stdin.flush()

        if _await_response(1) is not None:
            proc.stdin.write(json.dumps({
                "jsonrpc": "2.0",
                "method": "notifications/initialized",
                "params": {},
            }) + "\n")
            proc.stdin.write(_MCP_TOOLS_MSG)
            proc.stdin.flush()
            tools_resp = _await_response(2)
    except (OSError, ValueError, RuntimeError):
        # Broken pipe / closed stream / thread could not start.
        tools_resp = None
    finally:
        # Always stop and reap the child, whatever happened above.
        try:
            proc.stdin.close()
        except OSError:
            pass
        try:
            proc.terminate()
            proc.wait(timeout=2)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        except OSError:
            pass

    if not tools_resp:
        return None
    result = tools_resp.get("result")
    if not isinstance(result, dict):
        return None
    tools = result.get("tools", [])
    return tools if isinstance(tools, list) else None
288
+
289
+
290
def _tokens_from_tool_list(tools: list[dict]) -> int:
    """Return the token cost of *tools*, measured on their JSON serialisation.

    Each tool is serialised with ``json.dumps`` and tokenised; the per-server
    base overhead is added once.
    """
    schema_tokens = sum(_count_tokens(json.dumps(tool)) for tool in tools)
    return BASE_OVERHEAD_PER_SERVER + schema_tokens
297
+
298
+
299
+ # ── Public API ────────────────────────────────────────────────────────────────
300
+
301
def get_server_tokens(
    server_name: str,
    server_config: dict,
    live: bool = False,
) -> tuple[int, str]:
    """Return ``(token_count, source)`` for one server.

    ``source`` is ``'cache'``, ``'known'``, ``'live'`` or ``'fallback'``,
    naming the first strategy that produced a figure.

    Args:
        server_name: name from the mcpServers config.
        server_config: full server config dict, used for the live query.
        live: when True, skip the cache and the known-server table and go
            straight to a live query.
    """
    lowered = server_name.lower()

    if not live:
        # 1. Fresh cache entry from an earlier live query.
        hit = _cache_get(server_name)
        if hit is not None:
            return hit, "cache"

        # 2. Known-server table: the first key contained in the name wins.
        known_count = next(
            (count for known, count in KNOWN_TOOL_COUNTS.items() if known in lowered),
            None,
        )
        if known_count is not None:
            return BASE_OVERHEAD_PER_SERVER + (known_count * TOKENS_PER_TOOL_AVG), "known"

    # 3. Live query: measure the real schemas and remember the result.
    tools = _live_query_tools(server_config)
    if tools is None:
        # 4. Fallback: assume 5 average-sized tools.
        return BASE_OVERHEAD_PER_SERVER + (5 * TOKENS_PER_TOOL_AVG), "fallback"

    measured = _tokens_from_tool_list(tools)
    _cache_set(server_name, measured, len(tools))
    return measured, "live"
340
+
341
+
342
def _cache_get_bulk(names: list[str]) -> dict[str, int]:
    """Fetch many cache entries over a single connection.

    Names are looked up in batches so the number of bound parameters per
    statement stays below SQLite's host-parameter limit (999 in older
    builds); one oversized ``IN (...)`` would otherwise fail and turn every
    lookup into a miss.

    Returns:
        ``{server_name: token_count}`` for entries younger than
        ``CACHE_TTL_DAYS``. Empty when *names* is empty or the cache cannot
        be read.
    """
    if not names:
        return {}

    batch_size = 500
    try:
        conn = _get_cache_db()
        try:
            rows = []
            for start in range(0, len(names), batch_size):
                batch = names[start:start + batch_size]
                placeholders = ",".join("?" * len(batch))
                rows.extend(
                    conn.execute(
                        f"SELECT server_name, token_count, queried_at FROM tool_cache WHERE server_name IN ({placeholders})",
                        batch,
                    ).fetchall()
                )
        finally:
            # Close even when a query raises, so the handle never leaks.
            conn.close()

        now = time.time()
        return {
            name: token_count
            for name, token_count, queried_at in rows
            if (now - queried_at) / 86400 < CACHE_TTL_DAYS
        }
    except (sqlite3.Error, OSError, TypeError):
        return {}
364
+
365
+
366
def estimate_total_tokens(servers: dict, live: bool = False) -> dict:
    """Return token estimates for a set of MCP servers.

    Cache entries for all servers are fetched with one bulk lookup (skipped
    when *live* is True). Each server is then resolved by the first
    strategy that applies: cache, known-server table, live query, fallback.

    Returns:
        {
            "total": int,
            "per_server": { "name": {"tokens": int, "source": str} },
            "context_pct": float
        }
    """
    cached = {} if live else _cache_get_bulk(list(servers))

    breakdown = {}
    for server, config in servers.items():
        # 1. Cache hit
        if server in cached:
            breakdown[server] = {"tokens": cached[server], "source": "cache"}
            continue

        # 2. Known list: the first key contained in the name wins
        if not live:
            lowered = server.lower()
            known_count = next(
                (count for known, count in KNOWN_TOOL_COUNTS.items() if known in lowered),
                None,
            )
            if known_count is not None:
                breakdown[server] = {
                    "tokens": BASE_OVERHEAD_PER_SERVER + (known_count * TOKENS_PER_TOOL_AVG),
                    "source": "known",
                }
                continue

        # 3. Live query, 4. Fallback (assume 5 average-sized tools)
        tools = _live_query_tools(config)
        if tools is None:
            breakdown[server] = {
                "tokens": BASE_OVERHEAD_PER_SERVER + (5 * TOKENS_PER_TOOL_AVG),
                "source": "fallback",
            }
        else:
            measured = _tokens_from_tool_list(tools)
            _cache_set(server, measured, len(tools))
            breakdown[server] = {"tokens": measured, "source": "live"}

    total = sum(entry["tokens"] for entry in breakdown.values())
    return {
        "total": total,
        "per_server": breakdown,
        "context_pct": round((total / CONTEXT_WINDOW) * 100, 1),
    }
421
+
422
+
423
def format_token_count(n: int) -> str:
    """Render a token count for display.

    Values of 1000 and above get a thousands separator plus an approximate
    "k" figure, e.g. ``12,345 (~12.3k)``; smaller values are shown as-is.
    """
    return f"{n:,} (~{n / 1000:.1f}k)" if n >= 1000 else str(n)