gemi-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gemi/keys/manager.py ADDED
@@ -0,0 +1,265 @@
1
+ import time
2
+ from dataclasses import dataclass, field
3
+
4
+ from rich.console import Console
5
+
6
+ from gemi.config import load_config
7
+ from gemi.keys.store import get_decrypted_keys
8
+ from gemi.registry import ALL_PROVIDER_NAMES, PROVIDERS, get_provider_info
9
+
10
+ console = Console()  # module-wide rich console; KeyManager prints its rotation/failover notices through it
11
+
12
+
13
@dataclass
class KeyState:
    """Runtime bookkeeping for one API key of one provider.

    The first three fields identify the key; the remaining ones are
    counters and flags that start at their zero values and are mutated
    while the key is in use.
    """

    # Identity of the key.
    provider: str
    name: str
    api_key: str

    # Usage counters, bumped each time usage is recorded for this key.
    requests_used: int = 0
    tokens_used: int = 0

    # Timestamps compared against / filled from ``time.time()``.
    last_used: float = 0
    cooldown_until: float = 0

    # Set once the key has been reported as exhausted.
    is_exhausted: bool = False
23
+
24
+
25
@dataclass
class KeyManager:
    """Select, rotate and fail over between API keys of several providers.

    Keys are loaded once at construction time. The manager then tracks a
    "current" provider, a current key index per provider and a current
    model index per provider. A key is considered usable when it is not
    marked exhausted and its cooldown has elapsed.
    """

    config: dict = field(default_factory=dict)
    _keys: dict[str, list[KeyState]] = field(default_factory=dict)
    _current_index: dict[str, int] = field(default_factory=dict)
    _current_provider: str = ""
    _model_index: dict[str, int] = field(default_factory=dict)
    _failed_models: dict[str, set] = field(default_factory=dict)

    def __post_init__(self):
        """Load the configuration if none was given, then load all keys."""
        if not self.config:
            self.config = load_config()
        self._current_provider = self.config.get("default_provider", "gemini")
        self._load_all_keys()

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    def _provider_priority(self):
        """Return the configured provider priority (defaults to all providers)."""
        return self.config.get("rotation", {}).get(
            "provider_priority", ALL_PROVIDER_NAMES
        )

    @staticmethod
    def _is_usable(key, now: float) -> bool:
        """Return True when *key* is neither exhausted nor cooling down at *now*."""
        return not key.is_exhausted and key.cooldown_until <= now

    @staticmethod
    def _format_wait(seconds: int) -> str:
        """Format a whole number of seconds as ``"Xm Ys"`` or ``"Ys"``."""
        mins, secs = divmod(seconds, 60)
        if mins > 0:
            return f"{mins}m {secs}s"
        return f"{secs}s"

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _load_all_keys(self):
        """Load keys for every prioritised provider, then for the remaining ones.

        Providers are visited in a deterministic order (priority list first,
        then ``ALL_PROVIDER_NAMES`` order) and each provider only once, so the
        insertion order of ``_keys`` does not vary between runs.
        """
        priority = self._provider_priority()
        # dict.fromkeys de-duplicates while preserving first-seen order.
        for provider in dict.fromkeys([*priority, *ALL_PROVIDER_NAMES]):
            self._load_provider_keys(provider)

    def _load_provider_keys(self, provider: str):
        """Populate ``_keys`` and ``_current_index`` for a single provider."""
        if provider == "ollama":
            # ollama gets a single placeholder entry with an empty API key.
            self._keys[provider] = [
                KeyState(provider="ollama", name="local", api_key="")
            ]
            self._current_index[provider] = 0
            return

        keys = get_decrypted_keys(provider)
        if keys:
            self._keys[provider] = [
                KeyState(
                    provider=k["provider"],
                    name=k["name"],
                    api_key=k["api_key"],
                )
                for k in keys
            ]
            self._current_index[provider] = 0

    # ------------------------------------------------------------------
    # Current selection
    # ------------------------------------------------------------------

    def get_current_key(self) -> KeyState | None:
        """Return a usable key, rotating or failing over when necessary.

        Returns:
            The current key if it is usable, otherwise whatever rotation /
            failover finds, or ``None`` when nothing is usable.
        """
        keys = self._keys.get(self._current_provider, [])
        if not keys:
            return self._try_failover()
        idx = self._current_index.get(self._current_provider, 0)
        key = keys[idx]
        # An exhausted key must not be handed out either: after a failed
        # rotation the current index can still point at one.
        if not self._is_usable(key, time.time()):
            return self._rotate_key()
        return key

    def get_current_provider(self) -> str:
        """Return the name of the provider currently in use."""
        return self._current_provider

    def get_current_model(self) -> str | None:
        """Return the selected model of the current provider, if it has any."""
        info = PROVIDERS.get(self._current_provider)
        if not info or not info["models"]:
            return None
        idx = self._model_index.get(self._current_provider, 0)
        return info["models"][idx]

    def get_models_for_provider(self, provider: str) -> list[str]:
        """Return the model list registered for *provider* (empty if unknown)."""
        info = PROVIDERS.get(provider)
        if not info:
            return []
        return info["models"]

    def try_next_model(self) -> str | None:
        """Mark the current model as failed and advance to the next untried one.

        Returns:
            The newly selected model name, or ``None`` when the provider has
            no models or every model has already failed.
        """
        provider = self._current_provider
        info = PROVIDERS.get(provider)
        if not info or not info["models"]:
            return None

        models = info["models"]
        current_idx = self._model_index.get(provider, 0)
        failed = self._failed_models.setdefault(provider, set())
        failed.add(models[current_idx])

        # Walk the other models in ring order starting after the current one.
        for offset in range(1, len(models)):
            next_idx = (current_idx + offset) % len(models)
            candidate = models[next_idx]
            if candidate not in failed:
                self._model_index[provider] = next_idx
                console.print(f" [cyan]Trying model {candidate} on {provider}...[/cyan]")
                return candidate

        return None

    def reset_failed_models(self):
        """Forget all failed models and reset every provider's model index."""
        self._failed_models.clear()
        self._model_index.clear()

    # ------------------------------------------------------------------
    # Usage reporting
    # ------------------------------------------------------------------

    def record_usage(self, tokens: int = 0):
        """Count one request (and *tokens* tokens) against the current key."""
        key = self._get_current_key_state()
        if key:
            key.requests_used += 1
            key.tokens_used += tokens
            key.last_used = time.time()

    def report_rate_limit(self, retry_after: float | None = None) -> KeyState | None:
        """Put the current key on cooldown and rotate to another key.

        Args:
            retry_after: Cooldown in seconds. ``None``, zero or a negative
                value falls back to 60 seconds.

        Returns:
            The key selected by rotation / failover, or ``None``.
        """
        key = self._get_current_key_state()
        if key:
            # A non-positive value would produce a cooldown in the past and a
            # nonsensical wait message, so treat it like "not provided".
            cooldown = retry_after if retry_after and retry_after > 0 else 60
            key.cooldown_until = time.time() + cooldown
            wait_str = self._format_wait(int(cooldown))
            console.print(
                f" [yellow]Rate limited on {key.provider}/{key.name} — retry in {wait_str}, rotating...[/yellow]"
            )
        return self._rotate_key()

    def report_exhausted(self) -> KeyState | None:
        """Mark the current key as exhausted and rotate to another key."""
        key = self._get_current_key_state()
        if key:
            key.is_exhausted = True
            console.print(
                f" [yellow]{key.provider}/{key.name} exhausted[/yellow]"
            )
        return self._rotate_key()

    # ------------------------------------------------------------------
    # Rotation and failover
    # ------------------------------------------------------------------

    def _rotate_key(self) -> KeyState | None:
        """Select the next usable key of the current provider, else fail over."""
        keys = self._keys.get(self._current_provider, [])
        if not keys:
            return self._try_failover()

        start_idx = self._current_index.get(self._current_provider, 0)
        now = time.time()

        # Ring walk that ends on the current key itself, so it is re-selected
        # only when no other key of this provider is usable.
        for offset in range(1, len(keys) + 1):
            idx = (start_idx + offset) % len(keys)
            candidate = keys[idx]
            if self._is_usable(candidate, now):
                self._current_index[self._current_provider] = idx
                console.print(
                    f" [green]Switched to {candidate.provider}/{candidate.name}[/green]"
                )
                return candidate

        return self._try_failover()

    def _try_failover(self) -> KeyState | None:
        """Switch to the first usable key of another prioritised provider.

        Returns ``None`` when automatic provider switching is disabled in the
        configuration or when no other provider has a usable key.
        """
        if not self.config.get("rotation", {}).get("auto_switch_provider", True):
            return None

        now = time.time()
        for provider in self._provider_priority():
            if provider == self._current_provider:
                continue
            for i, key in enumerate(self._keys.get(provider, [])):
                if self._is_usable(key, now):
                    self._current_provider = provider
                    self._current_index[provider] = i
                    info = get_provider_info(provider)
                    display = info["name"] if info else provider
                    console.print(
                        f" [bold green]Failover → {display} ({key.name})[/bold green]"
                    )
                    return key

        return None

    def get_nearest_cooldown(self) -> float | None:
        """Return the shortest remaining cooldown in seconds, or ``None``.

        Exhausted keys are ignored; keys that are not cooling down do not
        contribute.
        """
        now = time.time()
        waits = [
            key.cooldown_until - now
            for keys in self._keys.values()
            for key in keys
            if not key.is_exhausted and key.cooldown_until > now
        ]
        return min(waits) if waits else None

    def get_any_available_key(self) -> KeyState | None:
        """Select and return the first usable key in provider-priority order.

        Unlike failover, the current provider is considered as well, and the
        switch happens without printing a message.
        """
        now = time.time()
        for provider in self._provider_priority():
            for i, key in enumerate(self._keys.get(provider, [])):
                if self._is_usable(key, now):
                    self._current_provider = provider
                    self._current_index[provider] = i
                    return key
        return None

    def _get_current_key_state(self) -> KeyState | None:
        """Return the current key as-is (no usability check), or ``None``."""
        keys = self._keys.get(self._current_provider, [])
        if not keys:
            return None
        idx = self._current_index.get(self._current_provider, 0)
        return keys[idx]

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_status(self) -> list[dict]:
        """Return one status dict per loaded key.

        Each dict has the keys ``provider``, ``name``, ``state``,
        ``requests``, ``tokens`` and ``cooldown_remaining``. ``state`` is
        ``"exhausted"``, ``"cooldown (...)"``, ``"active"`` or ``"standby"``,
        checked in that order of precedence.
        """
        status = []
        now = time.time()
        for provider, keys in self._keys.items():
            for i, key in enumerate(keys):
                is_current = (
                    provider == self._current_provider
                    and i == self._current_index.get(provider, 0)
                )
                state = "active" if is_current else "standby"
                cooldown_remaining = 0
                if key.is_exhausted:
                    state = "exhausted"
                elif key.cooldown_until > now:
                    cooldown_remaining = int(key.cooldown_until - now)
                    state = f"cooldown ({self._format_wait(cooldown_remaining)})"
                status.append({
                    "provider": provider,
                    "name": key.name,
                    "state": state,
                    "requests": key.requests_used,
                    "tokens": key.tokens_used,
                    "cooldown_remaining": cooldown_remaining,
                })
        return status
gemi/keys/store.py ADDED
@@ -0,0 +1,92 @@
1
+ import json
2
+ import os
3
+ import platform
4
+ from datetime import datetime, timezone
5
+ from pathlib import Path
6
+
7
+ from cryptography.fernet import Fernet
8
+
9
+ from gemi.config import GEMI_DIR, ensure_gemi_dir
10
+
11
+ KEYS_PATH = GEMI_DIR / "keys.json"  # JSON list of key records; each record's api_key is stored Fernet-encrypted
12
+ SECRET_PATH = GEMI_DIR / ".secret"  # file holding the Fernet key used to encrypt/decrypt the stored api_key values
13
+
14
+
15
def _get_or_create_secret() -> bytes:
    """Return the Fernet secret, generating and persisting it on first use.

    The secret file is created with ``O_CREAT | O_EXCL`` and mode ``0o600``,
    so it is never readable by other users (not even briefly) and an
    existing secret is never overwritten.

    Returns:
        The raw Fernet key bytes stored in ``SECRET_PATH``.
    """
    ensure_gemi_dir()
    try:
        return SECRET_PATH.read_bytes()
    except FileNotFoundError:
        pass  # first run: fall through and create the secret

    key = Fernet.generate_key()
    try:
        # Restrictive permissions are applied at creation time instead of
        # writing with default permissions and chmod-ing afterwards.
        fd = os.open(SECRET_PATH, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Another process created the secret after our read attempt; use
        # theirs so both processes end up with the same key.
        return SECRET_PATH.read_bytes()
    with os.fdopen(fd, "wb") as handle:
        handle.write(key)
    return key
23
+
24
+
25
def _cipher() -> Fernet:
    """Build a Fernet instance keyed with the locally stored secret."""
    secret = _get_or_create_secret()
    return Fernet(secret)
27
+
28
+
29
def _load_raw() -> list[dict]:
    """Read the stored key records without decrypting them.

    Returns:
        The parsed content of ``KEYS_PATH`` when it is a JSON list;
        an empty list when the file is missing or holds any other JSON type.
    """
    if KEYS_PATH.exists():
        parsed = json.loads(KEYS_PATH.read_text())
        if isinstance(parsed, list):
            return parsed
    return []
34
+
35
+
36
def _save_raw(keys: list[dict]):
    """Write *keys* to ``KEYS_PATH`` as indented JSON and restrict it to mode 0o600."""
    ensure_gemi_dir()
    serialized = json.dumps(keys, indent=2)
    KEYS_PATH.write_text(serialized)
    KEYS_PATH.chmod(0o600)
40
+
41
+
42
def add_key(provider: str, name: str, api_key: str):
    """Store *api_key* encrypted under ``(provider, name)``.

    If a record with the same provider and name already exists, only its
    encrypted key is replaced; otherwise a new record stamped with the
    current UTC time is appended. The records are saved in both cases.
    """
    records = _load_raw()
    existing = next(
        (
            rec
            for rec in records
            if rec["provider"] == provider and rec["name"] == name
        ),
        None,
    )
    token = _cipher().encrypt(api_key.encode()).decode()

    if existing is not None:
        existing["api_key"] = token
    else:
        records.append({
            "provider": provider,
            "name": name,
            "api_key": token,
            "added_at": datetime.now(timezone.utc).isoformat(),
        })
    _save_raw(records)
56
+
57
+
58
def remove_key(provider: str, name: str) -> bool:
    """Delete every record matching ``(provider, name)``.

    Returns:
        ``True`` if at least one record was removed (and the file rewritten),
        ``False`` if nothing matched and the file was left untouched.
    """
    existing = _load_raw()
    remaining = [
        rec
        for rec in existing
        if rec["provider"] != provider or rec["name"] != name
    ]
    removed_any = len(remaining) != len(existing)
    if removed_any:
        _save_raw(remaining)
    return removed_any
65
+
66
+
67
def list_keys(provider: str | None = None) -> list[dict]:
    """List stored key records without their key material.

    Args:
        provider: When truthy, only records of this provider are listed.

    Returns:
        One dict per record with the keys ``provider``, ``name`` and
        ``added_at``.
    """
    records = _load_raw()
    if provider:
        records = [rec for rec in records if rec["provider"] == provider]
    return [
        {
            "provider": rec["provider"],
            "name": rec["name"],
            "added_at": rec["added_at"],
        }
        for rec in records
    ]
79
+
80
+
81
def get_decrypted_keys(provider: str) -> list[dict]:
    """Return the records of *provider* with their API keys decrypted.

    Returns:
        One dict per matching record with the keys ``provider``, ``name``
        and ``api_key`` (plaintext).
    """
    stored = _load_raw()
    # The cipher is built even if no record matches, as before.
    cipher = _cipher()
    return [
        {
            "provider": rec["provider"],
            "name": rec["name"],
            "api_key": cipher.decrypt(rec["api_key"].encode()).decode(),
        }
        for rec in stored
        if rec["provider"] == provider
    ]