gemi-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gemi/__init__.py +1 -0
- gemi/agent/__init__.py +0 -0
- gemi/agent/loop.py +594 -0
- gemi/agent/tools.py +571 -0
- gemi/compaction.py +67 -0
- gemi/config.py +53 -0
- gemi/keys/__init__.py +0 -0
- gemi/keys/manager.py +265 -0
- gemi/keys/store.py +92 -0
- gemi/main.py +426 -0
- gemi/providers/__init__.py +0 -0
- gemi/providers/base.py +35 -0
- gemi/providers/gemini.py +126 -0
- gemi/providers/ollama.py +72 -0
- gemi/providers/openai_compat.py +140 -0
- gemi/registry.py +201 -0
- gemi/sessions.py +84 -0
- gemi/ui.py +387 -0
- gemi_cli-0.1.0.dist-info/METADATA +462 -0
- gemi_cli-0.1.0.dist-info/RECORD +22 -0
- gemi_cli-0.1.0.dist-info/WHEEL +4 -0
- gemi_cli-0.1.0.dist-info/entry_points.txt +2 -0
gemi/main.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
import typer
|
|
6
|
+
from prompt_toolkit import PromptSession
|
|
7
|
+
from prompt_toolkit.history import FileHistory
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
from rich.table import Table
|
|
10
|
+
|
|
11
|
+
from gemi.config import GEMI_DIR, load_config, save_config
|
|
12
|
+
from gemi.keys.store import add_key, list_keys, remove_key
|
|
13
|
+
from gemi.registry import ALL_PROVIDER_NAMES, PROVIDERS, get_base_url, get_provider_info, get_provider_type
|
|
14
|
+
from gemi.sessions import (
|
|
15
|
+
generate_session_id,
|
|
16
|
+
list_sessions,
|
|
17
|
+
load_session,
|
|
18
|
+
)
|
|
19
|
+
from gemi.ui import print_banner, print_help, print_key_status, print_welcome
|
|
20
|
+
|
|
21
|
+
# Root command-line application. `no_args_is_help=False` lets a bare `gemi`
# invocation reach the callback instead of printing the help screen.
app = typer.Typer(
    name="gemi",
    help="Free AI coding agent — multi-account key rotation with provider failover",
    no_args_is_help=False,
)

# Rich console shared by every command defined in this module.
console = Console()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Root callback: runs for every invocation. When a sub-command was given it
# does nothing; otherwise it starts the interactive session, optionally
# resuming the session whose ID was passed via --resume / -r.
@app.callback(invoke_without_command=True)
def main(
    ctx: typer.Context,
    resume: str = typer.Option(None, "--resume", "-r", help="Resume a previous session by ID"),
):
    if ctx.invoked_subcommand is None:
        asyncio.run(_interactive_loop(resume_id=resume))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _handle_slash_command(agent, user_input: str) -> bool:
    """Execute one ``/command`` typed at the interactive prompt.

    Args:
        agent: The active ``AgentLoop`` instance.
        user_input: The stripped input line; it starts with ``/``.

    Returns:
        True when the user asked to leave the session, False otherwise.
    """
    # Only the command word is case-insensitive; arguments keep their case.
    cmd = user_input.lower().split()[0]

    if cmd in ("/quit", "/exit", "/q"):
        return True

    if cmd == "/help":
        print_help()
    elif cmd == "/status":
        print_key_status(agent.key_manager.get_status())
        console.print(f"\n {agent.get_status_line()}")
    elif cmd == "/clear":
        # Keep only the first message (presumably the system prompt — verify
        # against AgentLoop) and reset the usage counters.
        agent.messages = agent.messages[:1]
        agent.total_tokens_used = 0
        agent.total_requests = 0
        console.print("[dim]Conversation cleared.[/dim]")
    elif cmd == "/model":
        parts = user_input.split()
        if len(parts) < 2:
            console.print(f" Current model: [green]{agent.model}[/green]")
            console.print(" Usage: /model <model-name>")
        else:
            agent.model = parts[1]
            console.print(f" Switched to: [green]{agent.model}[/green]")
    elif cmd == "/undo":
        from pathlib import Path

        from gemi.agent.tools import get_last_edit

        last_edit = get_last_edit()
        if last_edit:
            try:
                Path(last_edit["path"]).write_text(last_edit["old_content"])
            except OSError as exc:
                # A failed restore must not take the whole session down.
                console.print(f" [red]Could not undo edit to {last_edit['path']}: {exc}[/red]")
            else:
                console.print(f" [green]Undone last edit to {last_edit['path']}[/green]")
        else:
            console.print(" [yellow]Nothing to undo[/yellow]")
    elif cmd == "/plan":
        plan = agent.get_plan()
        if plan:
            from gemi.ui import print_plan

            print_plan(plan["title"], plan["steps"])
        else:
            console.print(" [dim]No active plan. Start a complex task and gemi will create one automatically.[/dim]")
    elif cmd == "/sessions":
        _show_sessions()
    elif cmd == "/tokens":
        console.print(agent.get_detailed_status())
    else:
        console.print(f"[yellow]Unknown command: {cmd}[/yellow]")

    return False


async def _interactive_loop(resume_id: str | None = None):
    """Run the interactive chat session until the user quits.

    Prints the banner, builds an ``AgentLoop``, optionally restores a saved
    session, then reads lines from the prompt. Lines starting with ``/`` are
    handled as local commands; everything else is sent to ``agent.chat``.

    Args:
        resume_id: ID of a saved session to resume. When omitted a new
            session ID is generated.
    """
    from gemi.agent.loop import AgentLoop
    from gemi.agent.tools import cleanup_background_processes

    print_banner()

    session_id = resume_id or generate_session_id()
    agent = AgentLoop(session_id=session_id)

    if not agent.provider:
        console.print("[bold yellow]No API keys found.[/bold yellow]")
        console.print("Add a Gemini API key to get started:\n")
        console.print(" [bold]gemi key add gemini[/bold]\n")
        console.print("Get a free key at: [link]https://aistudio.google.com/apikey[/link]\n")
        return

    if resume_id:
        old_messages = load_session(resume_id)
        if old_messages:
            agent.load_session(old_messages)
            console.print(f" [green]Resumed session {resume_id} ({len(old_messages)} messages)[/green]")
        else:
            console.print(f" [yellow]Session {resume_id} not found, starting fresh[/yellow]")

    provider_name = agent.key_manager.get_current_provider()
    print_welcome(provider_name, agent.model, os.getcwd())
    console.print(f" Session: [dim]{session_id}[/dim]")

    GEMI_DIR.mkdir(parents=True, exist_ok=True)
    history_file = str(GEMI_DIR / "history.txt")
    session = PromptSession(history=FileHistory(history_file))

    def _get_border():
        # Horizontal rule sized to the terminal; 80 columns when no terminal
        # is attached. Two columns are left for the corner glyphs.
        try:
            w = os.get_terminal_size().columns
        except OSError:
            w = 80
        return "─" * (w - 2)

    def _prompt():
        # Runs in a worker thread; Ctrl-D / Ctrl-C are reported as None.
        try:
            return session.prompt("│ ❯ ")
        except (EOFError, KeyboardInterrupt):
            return None

    console.print(f"\n[dim]╭{_get_border()}╮[/dim]")

    # We are inside a coroutine, so the running loop is the right handle.
    loop = asyncio.get_running_loop()
    try:
        while True:
            try:
                user_input = await loop.run_in_executor(None, _prompt)
            except (EOFError, KeyboardInterrupt, asyncio.CancelledError):
                user_input = None

            if user_input is None:
                console.print(f"[dim]╰{_get_border()}╯[/dim]")
                break

            user_input = user_input.strip()
            if not user_input:
                # Move the cursor up one line and erase it so an empty
                # submission leaves no blank prompt row behind.
                sys.stdout.write("\033[A\033[2K")
                sys.stdout.flush()
                continue

            console.print(f"[dim]╰{_get_border()}╯[/dim]")

            if user_input.startswith("/"):
                if _handle_slash_command(agent, user_input):
                    break
                console.print(f"\n[dim]╭{_get_border()}╮[/dim]")
                continue

            await agent.chat(user_input)
            console.print(f"\n[dim]╭{_get_border()}╮[/dim]")
    finally:
        # Also runs when agent.chat raises or the task is cancelled, so
        # background processes are cleaned up on every exit path.
        cleanup_background_processes()

    console.print("[dim]Goodbye![/dim]")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _show_sessions():
    """Print a table of at most ten saved sessions, or a notice if none exist."""
    saved = list_sessions()
    if not saved:
        console.print(" [yellow]No saved sessions[/yellow]")
        return

    table = Table(title="Saved Sessions", show_header=True)
    column_specs = (
        ("ID", {"style": "cyan"}),
        ("Messages", {"justify": "right"}),
        ("Directory", {"style": "dim"}),
        ("Preview", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for entry in saved[:10]:
        # "cwd" is treated as optional; the preview is cut to 60 characters.
        table.add_row(
            entry["id"],
            str(entry["messages"]),
            entry.get("cwd", ""),
            entry["preview"][:60],
        )

    console.print(table)
    console.print("\n Resume with: [bold]gemi --resume <id>[/bold]")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# Sub-application holding the `gemi key ...` commands, mounted on the root app.
key_app = typer.Typer(help="Manage API keys")
app.add_typer(key_app, name="key")
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _validate_api_key(provider: str, api_key: str) -> tuple[bool, str]:
    """Check an API key by sending one tiny request to the provider.

    Args:
        provider: Registry name of the provider the key belongs to.
        api_key: The key to validate.

    Returns:
        ``(True, "")`` when the key is accepted, when the failure does not
        depend on the key (model not found, rate limit / quota), or when the
        provider type has no validation path here. ``(False, reason)`` when
        authentication failed or the request failed for any other reason.
    """
    import re

    provider_type = get_provider_type(provider)
    info = get_provider_info(provider)
    model = info["default_model"] if info else "gpt-4o-mini"

    try:
        if provider_type == "gemini":
            from google import genai

            client = genai.Client(api_key=api_key)
            client.models.generate_content(
                model=model,
                contents="Say hi in one word.",
                config={"max_output_tokens": 5},
            )
        elif provider_type == "openai_compat":
            from openai import OpenAI

            base_url = get_base_url(provider) or "https://api.openai.com/v1"
            extra_headers = {}
            if "openrouter" in base_url:
                extra_headers["HTTP-Referer"] = "https://github.com/gemi-cli/gemi"
                extra_headers["X-Title"] = "gemi"
            client = OpenAI(
                api_key=api_key,
                base_url=base_url,
                default_headers=extra_headers or None,
            )
            client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": "Say hi in one word."}],
                max_tokens=5,
            )
    except Exception as e:
        # The SDKs raise different exception types, so the failure is
        # classified from the message text. Order matters: authentication
        # markers are checked first.
        error_str = str(e).lower()

        auth_markers = ("401", "403", "unauthorized", "user not found", "invalid")
        if any(marker in error_str for marker in auth_markers):
            return False, "Authentication failed — key is invalid or expired"

        # A missing model means the request got past authentication.
        if "404" in error_str or ("not found" in error_str and "model" in error_str):
            return True, ""

        # Rate limiting / quota exhaustion also implies an accepted key.
        # "rate" is matched at a word start so that words merely containing
        # it (e.g. "generate") do not count as a rate-limit signal.
        if "429" in error_str or "quota" in error_str or re.search(r"\brate", error_str):
            return True, ""

        return False, str(e)

    return True, ""
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# `gemi key add <provider>`: validate an API key against the provider and
# store it under a label.
#
# Flow: reject unknown providers -> short-circuit providers that need no key
# -> resolve a label clash on "default" -> obtain the key (option or hidden
# prompt) -> validate it with a live request -> save it and print a summary.
# Nothing is saved when validation fails.
@key_app.command("add")
def key_add(
    provider: str = typer.Argument(help="Provider: " + ", ".join(ALL_PROVIDER_NAMES)),
    name: str = typer.Option("default", "--name", "-n", help="Label for this key"),
    key: str = typer.Option(None, "--key", "-k", help="API key (or omit to enter interactively)"),
):
    # Escape values that did not originate in this program before they are
    # interpolated into Rich markup, so stray "[...]" sequences are printed
    # literally instead of being parsed as tags.
    from rich.markup import escape

    info = get_provider_info(provider)
    if not info:
        console.print(f"[red]Unknown provider: {escape(provider)}[/red]")
        console.print(f"Available: {', '.join(ALL_PROVIDER_NAMES)}")
        return

    if not info["needs_key"]:
        console.print(f"[green]{info['name']} doesn't need an API key — it runs locally![/green]")
        console.print(f" Default model: {info['default_model']}")
        console.print(" Make sure it's running: ollama serve")
        return

    existing = list_keys(provider)
    existing_names = {k["name"] for k in existing}

    # Only the implicit "default" label triggers the clash prompt.
    if name == "default" and "default" in existing_names:
        console.print(f"\n [yellow]You already have a '{name}' key for {provider}.[/yellow]")
        console.print(f" Existing keys: {', '.join(existing_names)}")
        choice = typer.prompt(
            " Enter a new name for this key, or 'replace' to overwrite",
            default=f"acc{len(existing) + 1}",
        ).strip()
        # 'replace' keeps the "default" label; anything else becomes the label.
        if choice.lower() != "replace":
            name = choice

    if key:
        api_key = key
    else:
        console.print(f"\n [bold]{info['name']}[/bold]")
        if info["key_url"]:
            console.print(f" Get your key at: [link]{info['key_url']}[/link]")
        if info["free_tier"]:
            console.print(" [green]Free tier available[/green]")
        console.print(f" Default model: {info['default_model']}")
        console.print()
        api_key = typer.prompt("Paste your API key", hide_input=True)

    api_key = api_key.strip()
    if not api_key:
        console.print("[red]Empty key, aborting.[/red]")
        return

    console.print(f"\n [dim]Validating key with {info['name']}...[/dim]")
    valid, error_msg = _validate_api_key(provider, api_key)

    if not valid:
        console.print(f" [bold red]Invalid API key:[/bold red] {escape(error_msg)}")
        console.print(" Key was NOT saved. Please check your key and try again.")
        if info["key_url"]:
            console.print(f" Get a valid key at: [link]{info['key_url']}[/link]")
        return

    console.print(" [green]Key verified![/green]")

    add_key(provider, name, api_key)
    console.print(f"\n[green]Key '{escape(name)}' added for {info['name']}[/green]")
    console.print(f" Base URL: [dim]{info['base_url'] or 'native SDK'}[/dim]")
    console.print(f" Model: [dim]{info['default_model']}[/dim]")

    # Re-read the store so the count includes the key that was just added.
    existing = list_keys(provider)
    if len(existing) > 1:
        console.print(f" [dim]You now have {len(existing)} keys for {provider} — rotation enabled![/dim]")
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
# `gemi key list [provider]`: show the stored keys (provider, label and the
# first ten characters of the "added_at" value), optionally filtered by provider.
@key_app.command("list")
def key_list(
    provider: str = typer.Argument(None, help="Filter by provider"),
):
    stored = list_keys(provider)
    if not stored:
        console.print("[yellow]No keys found.[/yellow]")
        console.print("Add one with: [bold]gemi key add gemini[/bold]")
        return

    table = Table(title="API Keys", show_header=True)
    for header, style in (("Provider", "cyan"), ("Name", "white"), ("Added", "dim")):
        table.add_column(header, style=style)

    for entry in stored:
        added = entry["added_at"][:10]
        table.add_row(entry["provider"], entry["name"], added)

    console.print(table)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
# `gemi key remove <provider> [name]`: delete one stored key and report
# whether it existed. The label defaults to "default".
@key_app.command("remove")
def key_remove(
    provider: str = typer.Argument(help="Provider name"),
    name: str = typer.Argument("default", help="Key label"),
):
    removed = remove_key(provider, name)
    if not removed:
        console.print(f"[red]Key '{name}' for {provider} not found[/red]")
        return
    console.print(f"[green]Removed key '{name}' for {provider}[/green]")
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
# `gemi key status`: build an AgentLoop with default arguments and print the
# status reported by its key manager.
@key_app.command("status")
def key_status():
    # Imported lazily, as elsewhere in this module.
    from gemi.agent.loop import AgentLoop

    status = AgentLoop().key_manager.get_status()
    print_key_status(status)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
# `gemi sessions`: command-line entry point for the same listing that the
# in-session `/sessions` command prints.
@app.command("sessions")
def sessions_cmd():
    _show_sessions()
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
# Sub-application holding the `gemi config ...` commands, mounted on the root app.
config_app = typer.Typer(help="View and edit configuration")
app.add_typer(config_app, name="config")
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
# `gemi config` with no sub-command: dump the loaded configuration as
# block-style YAML. Does nothing when a sub-command (e.g. `set`) was invoked.
@config_app.callback(invoke_without_command=True)
def config_show(ctx: typer.Context):
    if ctx.invoked_subcommand is None:
        import yaml

        current = load_config()
        console.print("[bold]Current configuration:[/bold]\n")
        console.print(yaml.dump(current, default_flow_style=False))
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
# `gemi config set <key> <value>`: write one configuration value.
# Dots in the key walk into nested mappings; missing levels, and levels that
# currently hold a non-mapping value, are replaced by empty mappings on the
# way down. The value is stored exactly as given, i.e. as a string.
@config_app.command("set")
def config_set(
    key: str = typer.Argument(help="Config key (dot notation, e.g. 'default_model')"),
    value: str = typer.Argument(help="Value to set"),
):
    config = load_config()

    *parents, leaf = key.split(".")
    node = config
    for segment in parents:
        child = node.get(segment)
        if not isinstance(child, dict):
            child = node[segment] = {}
        node = child
    node[leaf] = value

    save_config(config)
    console.print(f"[green]Set {key} = {value}[/green]")
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
# `gemi model [name]`: with a name, persist it as the default model; without
# one, print the current model/provider and, for each known provider, whether
# it is configured plus up to four of its models.
@app.command("model")
def model_cmd(name: str = typer.Argument(None, help="Model to switch to")):
    config = load_config()

    if name:
        config["default_model"] = name
        save_config(config)
        console.print(f"[green]Default model set to: {name}[/green]")
        return

    console.print(f"Current model: [green]{config.get('default_model', 'gemini-2.5-flash')}[/green]")
    console.print(f"Current provider: [green]{config.get('default_provider', 'gemini')}[/green]")
    console.print()

    # "ollama" always counts as configured, whether or not a key is stored.
    configured = {entry["provider"] for entry in list_keys()} | {"ollama"}

    for pname in ALL_PROVIDER_NAMES:
        info = PROVIDERS[pname]
        has_key = pname in configured
        if has_key:
            status = "[green]configured[/green]"
        else:
            status = "[dim]no key[/dim]"
        console.print(f" [bold]{pname}[/bold] ({info['name']}) — {status}")
        if not has_key:
            continue
        for m in info["models"][:4]:
            console.print(f" - {m}")
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
# `gemi providers`: print one table row per registered provider with its
# display name, free-tier flag, default model, key URL and whether it is
# ready to use.
@app.command("providers")
def providers_cmd():
    table = Table(title="Available Providers", show_header=True)
    column_specs = (
        ("Provider", {"style": "cyan"}),
        ("Name", {"style": "white"}),
        ("Free", {"style": "green"}),
        ("Default Model", {"style": "dim"}),
        ("Key URL", {}),
        ("Status", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    # "ollama" always counts as configured, whether or not a key is stored.
    configured = {entry["provider"] for entry in list_keys()}
    configured.add("ollama")

    for pname, info in PROVIDERS.items():
        if pname in configured:
            status = "[green]Ready[/green]"
        else:
            status = "[dim]Not configured[/dim]"
        table.add_row(
            pname,
            info["name"],
            "Yes" if info["free_tier"] else "No",
            info["default_model"],
            info["key_url"] or "—",
            status,
        )

    console.print(table)
    console.print("\nAdd a provider: [bold]gemi key add <provider>[/bold]")
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
# Allow running the module directly as a script.
if __name__ == "__main__":
    app()
|
|
File without changes
|
gemi/providers/base.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, AsyncIterator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
|
|
7
|
+
class Chunk:
|
|
8
|
+
text: str = ""
|
|
9
|
+
tool_calls: list[dict] | None = None
|
|
10
|
+
finish_reason: str | None = None
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class Message:
|
|
15
|
+
role: str # "user", "assistant", "system", "tool"
|
|
16
|
+
content: str = ""
|
|
17
|
+
tool_calls: list[dict] | None = None
|
|
18
|
+
tool_call_id: str | None = None
|
|
19
|
+
name: str | None = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BaseProvider(ABC):
    """Abstract interface every model provider implements."""

    @abstractmethod
    async def chat(
        self,
        messages: list[Message],
        tools: list[dict] | None = None,
        model: str | None = None,
        stream: bool = True,
    ) -> AsyncIterator[Chunk]:
        """Send the conversation to the model and produce reply chunks.

        Args:
            messages: Conversation history.
            tools: Tool definitions offered to the model, if any.
            model: Model name; each provider chooses its own default.
            stream: Whether to request a streamed reply.
        """

    @abstractmethod
    def format_tools(self, tools: list[dict]) -> list:
        """Convert tool definitions into the provider's own tool format."""
|
gemi/providers/gemini.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import AsyncIterator
|
|
3
|
+
|
|
4
|
+
from google import genai
|
|
5
|
+
from google.genai import types
|
|
6
|
+
|
|
7
|
+
from gemi.providers.base import BaseProvider, Chunk, Message
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GeminiProvider(BaseProvider):
    """Provider that talks to Gemini models through the ``google.genai`` client."""

    def __init__(self, api_key: str):
        """Create the client for the given API key."""
        self.client = genai.Client(api_key=api_key)

    def update_key(self, api_key: str):
        """Replace the client with one built from a different API key."""
        self.client = genai.Client(api_key=api_key)

    async def chat(
        self,
        messages: list[Message],
        tools: list[dict] | None = None,
        model: str | None = None,
        stream: bool = True,
    ) -> AsyncIterator[Chunk]:
        """Send the conversation to Gemini and yield the reply as chunks.

        Args:
            messages: Conversation history. The first ``system`` message is
                used as the system instruction; system messages are not part
                of the request contents.
            tools: Tool definitions to offer to the model, if any.
            model: Model name; defaults to ``gemini-2.5-flash``.
            stream: Use the streaming endpoint when True, otherwise a single
                request.

        Yields:
            One ``Chunk`` per text part or function-call part in the reply.
        """
        model = model or "gemini-2.5-flash"
        contents = self._build_contents(messages)
        formatted_tools = self.format_tools(tools) if tools else None

        config = types.GenerateContentConfig()
        if formatted_tools:
            config.tools = formatted_tools

        first_system = next((m for m in messages if m.role == "system"), None)
        if first_system is not None:
            config.system_instruction = first_system.content

        if stream:
            stream_response = await self.client.aio.models.generate_content_stream(
                model=model,
                contents=contents,
                config=config,
            )
            async for response in stream_response:
                for chunk in self._chunks_from_response(response):
                    yield chunk
        else:
            response = await self.client.aio.models.generate_content(
                model=model,
                contents=contents,
                config=config,
            )
            for chunk in self._chunks_from_response(response):
                yield chunk

    @staticmethod
    def _chunks_from_response(response) -> list[Chunk]:
        """Convert the first candidate of a response into ``Chunk`` objects.

        Shared by the streaming and non-streaming paths. Returns an empty
        list when the response has no candidates, no content or no parts.
        """
        candidates = response.candidates
        if not (candidates and candidates[0].content and candidates[0].content.parts):
            return []

        chunks = []
        for part in candidates[0].content.parts:
            if part.text:
                chunks.append(Chunk(text=part.text))
            elif part.function_call:
                call = part.function_call
                # The function name doubles as the tool-call id.
                chunks.append(Chunk(tool_calls=[{
                    "id": call.name,
                    "function": {
                        "name": call.name,
                        "arguments": dict(call.args) if call.args else {},
                    },
                }]))
        return chunks

    def _build_contents(self, messages: list[Message]) -> list:
        """Translate the message history into Gemini ``Content`` objects.

        System messages are skipped here (``chat`` passes the system
        instruction through the config). Assistant messages become
        ``model``-role contents, and tool results are sent as ``user``-role
        function responses. Messages with any other role are ignored.
        """
        contents = []
        for msg in messages:
            if msg.role == "system":
                continue
            elif msg.role == "user":
                contents.append(types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=msg.content)],
                ))
            elif msg.role == "assistant":
                parts = []
                if msg.content:
                    parts.append(types.Part.from_text(text=msg.content))
                if msg.tool_calls:
                    for tc in msg.tool_calls:
                        parts.append(types.Part.from_function_call(
                            name=tc["function"]["name"],
                            args=tc["function"]["arguments"],
                        ))
                # An assistant turn with neither text nor tool calls is dropped.
                if parts:
                    contents.append(types.Content(role="model", parts=parts))
            elif msg.role == "tool":
                contents.append(types.Content(
                    role="user",
                    parts=[types.Part.from_function_response(
                        name=msg.name or "unknown",
                        response={"result": msg.content},
                    )],
                ))
        return contents

    def format_tools(self, tools: list[dict]) -> list:
        """Convert tool definitions into a single Gemini ``Tool``.

        Each property is mapped to a ``Schema`` carrying only its upper-cased
        type (``string`` when unspecified) and its description. A tool with
        no properties is declared without a parameters schema.
        """
        declarations = []
        for tool in tools:
            params = tool.get("parameters", {})
            required = params.get("required", [])
            props = {
                pname: types.Schema(
                    type=pdef.get("type", "string").upper(),
                    description=pdef.get("description", ""),
                )
                for pname, pdef in params.get("properties", {}).items()
            }
            declarations.append(types.FunctionDeclaration(
                name=tool["name"],
                description=tool.get("description", ""),
                parameters=types.Schema(
                    type="OBJECT",
                    properties=props,
                    required=required,
                ) if props else None,
            ))
        return [types.Tool(function_declarations=declarations)]
|
gemi/providers/ollama.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import AsyncIterator
|
|
3
|
+
|
|
4
|
+
import ollama as ollama_sdk
|
|
5
|
+
|
|
6
|
+
from gemi.providers.base import BaseProvider, Chunk, Message
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OllamaProvider(BaseProvider):
    """Provider that talks to an Ollama server through the ``ollama`` async client."""

    def __init__(self, base_url: str = "http://localhost:11434"):
        """Create the async client for the server at ``base_url``."""
        self.client = ollama_sdk.AsyncClient(host=base_url)

    async def chat(
        self,
        messages: list[Message],
        tools: list[dict] | None = None,
        model: str | None = None,
        stream: bool = True,
    ) -> AsyncIterator[Chunk]:
        """Send the conversation to Ollama and yield the reply as chunks.

        Streaming is used only when no tools are supplied; with tools (or
        with ``stream=False``) a single non-streaming request is made.

        Args:
            messages: Conversation history.
            tools: Tool definitions to offer to the model, if any.
            model: Model name; defaults to ``llama3``.
            stream: Whether streaming is requested.

        Yields:
            Text chunks, followed by one chunk per tool call in the
            non-streaming path.
        """
        model = model or "llama3"
        formatted_messages = self._build_messages(messages)
        formatted_tools = self.format_tools(tools) if tools else None

        if stream and not formatted_tools:
            reply_stream = await self.client.chat(
                model=model,
                messages=formatted_messages,
                stream=True,
            )
            async for part in reply_stream:
                text = part.get("message", {}).get("content")
                if text:
                    yield Chunk(text=text)
            return

        response = await self.client.chat(
            model=model,
            messages=formatted_messages,
            tools=formatted_tools,
            stream=False,
        )
        reply = response.get("message", {})

        text = reply.get("content")
        if text:
            yield Chunk(text=text)

        for call in reply.get("tool_calls") or []:
            fn = call["function"]
            # The function name doubles as the tool-call id.
            yield Chunk(tool_calls=[{
                "id": fn["name"],
                "function": {
                    "name": fn["name"],
                    "arguments": fn["arguments"],
                },
            }])

    def _build_messages(self, messages: list[Message]) -> list[dict]:
        """Reduce each message to the ``role``/``content`` pair sent to Ollama."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    def format_tools(self, tools: list[dict]) -> list[dict]:
        """Wrap each tool definition in a ``function``-typed tool entry."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "parameters": tool.get("parameters", {}),
                },
            }
            for tool in tools
        ]
|