ccs-llmconnector 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ccs_llmconnector-1.0.0.dist-info/METADATA +349 -0
- ccs_llmconnector-1.0.0.dist-info/RECORD +14 -0
- ccs_llmconnector-1.0.0.dist-info/WHEEL +5 -0
- ccs_llmconnector-1.0.0.dist-info/entry_points.txt +2 -0
- ccs_llmconnector-1.0.0.dist-info/licenses/LICENSE +22 -0
- ccs_llmconnector-1.0.0.dist-info/top_level.txt +1 -0
- llmconnector/__init__.py +39 -0
- llmconnector/anthropic_client.py +190 -0
- llmconnector/client.py +148 -0
- llmconnector/client_cli.py +325 -0
- llmconnector/gemini_client.py +191 -0
- llmconnector/grok_client.py +139 -0
- llmconnector/openai_client.py +139 -0
- llmconnector/py.typed +1 -0
llmconnector/client.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""Provider-agnostic entry point for working with large language models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Dict, Optional, Protocol, Sequence, Union
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from .openai_client import ImageInput, OpenAIResponsesClient
|
|
10
|
+
else:
|
|
11
|
+
ImageInput = Union[str, Path]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SupportsGenerateResponse(Protocol):
    """Structural interface that every provider client must satisfy.

    Any object exposing these two keyword-only methods can be registered with
    ``LLMClient``; no inheritance from this class is required.
    """

    def generate_response(
        self,
        *,
        api_key: str,
        prompt: str,
        model: str,
        max_tokens: int = 32000,
        reasoning_effort: Optional[str] = None,
        images: Optional[Sequence[ImageInput]] = None,
    ) -> str:
        """Return the model's text reply for *prompt* and optional *images*."""

    def list_models(self, *, api_key: str) -> Sequence[dict[str, Optional[str]]]:
        """Return one dictionary per model available to *api_key*."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LLMClient:
    """Central client capable of routing requests to different providers.

    Provider names are stored lower-cased, so every lookup is case-insensitive.
    """

    def __init__(
        self,
        providers: Optional[Dict[str, SupportsGenerateResponse]] = None,
    ) -> None:
        """Register the given providers, or discover the bundled ones.

        Args:
            providers: Mapping of provider name to client. When it is falsy
                (``None`` or an empty mapping) the bundled clients whose SDKs
                are importable are discovered and registered instead.

        Raises:
            ValueError: If a supplied name is empty or a client is ``None``.
            RuntimeError: If no provider ends up registered.
        """
        self._providers: Dict[str, SupportsGenerateResponse] = {}
        initial = providers or self._discover_default_providers()
        for name, client in initial.items():
            self.register_provider(name, client)

        if not self._providers:
            raise RuntimeError(
                "No provider implementations registered. Install the required extras "
                "for your target provider (e.g. `pip install openai`)."
            )

    def register_provider(self, name: str, client: SupportsGenerateResponse) -> None:
        """Register or overwrite a provider implementation.

        Args:
            name: Provider name; stored lower-cased.
            client: Object implementing ``generate_response`` and ``list_models``.

        Raises:
            ValueError: If *name* is empty or *client* is ``None``.
        """
        if not name:
            raise ValueError("Provider name must be provided.")
        if client is None:
            raise ValueError("Provider client must be provided.")

        self._providers[name.lower()] = client

    def generate_response(
        self,
        *,
        provider: str,
        api_key: str,
        prompt: str,
        model: str,
        max_tokens: int = 32000,
        reasoning_effort: Optional[str] = None,
        images: Optional[Sequence[ImageInput]] = None,
    ) -> str:
        """Generate a response using the selected provider.

        Every argument except *provider* is forwarded unchanged to the provider
        client's ``generate_response``.

        Raises:
            ValueError: If *provider* is empty or not registered.
        """
        return self._get_provider(provider).generate_response(
            api_key=api_key,
            prompt=prompt,
            model=model,
            max_tokens=max_tokens,
            reasoning_effort=reasoning_effort,
            images=images,
        )

    def list_models(
        self,
        *,
        provider: str,
        api_key: str,
    ) -> Sequence[dict[str, Optional[str]]]:
        """List models available for the specified provider.

        Raises:
            ValueError: If *provider* is empty or not registered.
        """
        return self._get_provider(provider).list_models(api_key=api_key)

    def _get_provider(self, provider: str) -> SupportsGenerateResponse:
        """Look up a registered provider client by case-insensitive name.

        Raises:
            ValueError: If *provider* is empty or unknown; the message lists the
                registered provider names.
        """
        if not provider:
            raise ValueError("provider must be provided.")

        provider_client = self._providers.get(provider.lower())
        if provider_client is None:
            available = ", ".join(sorted(self._providers)) or "<none>"
            raise ValueError(
                f"Unknown provider '{provider}'. Available providers: {available}."
            )
        return provider_client

    @staticmethod
    def _discover_default_providers() -> Dict[str, SupportsGenerateResponse]:
        """Instantiate the bundled provider clients whose SDKs are installed.

        Each provider is probed independently. A ``ModuleNotFoundError`` for that
        provider's own SDK skips only that provider; any other import failure
        propagates. The Grok client is registered under both ``"grok"`` and
        ``"xai"`` (the same instance).
        """
        providers: Dict[str, SupportsGenerateResponse] = {}

        # A missing openai SDK must not stop the remaining providers from being
        # discovered, so it is skipped just like the others below.
        try:
            from .openai_client import OpenAIResponsesClient  # type: ignore
        except ModuleNotFoundError as exc:
            if exc.name != "openai":
                raise
        else:
            providers["openai"] = OpenAIResponsesClient()

        try:
            from .gemini_client import GeminiClient  # type: ignore
        except ModuleNotFoundError as exc:
            if exc.name not in {"google", "google.genai"}:
                raise
        else:
            providers["gemini"] = GeminiClient()

        try:
            from .anthropic_client import AnthropicClient  # type: ignore
        except ModuleNotFoundError as exc:
            if exc.name != "anthropic":
                raise
        else:
            providers["anthropic"] = AnthropicClient()

        try:
            from .grok_client import GrokClient  # type: ignore
        except ModuleNotFoundError as exc:
            if exc.name != "xai_sdk":
                raise
        else:
            grok_client = GrokClient()
            providers["grok"] = grok_client
            providers.setdefault("xai", grok_client)

        return providers
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
"""Command-line client for llmconnector.
|
|
2
|
+
|
|
3
|
+
Reads the API key from an environment variable derived from the provider name
|
|
4
|
+
("{PROVIDER}_API_KEY", e.g. OPENAI_API_KEY) and exposes
|
|
5
|
+
simple commands to generate a response or list available models for a
|
|
6
|
+
provider.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
import sys
|
|
15
|
+
from typing import Sequence
|
|
16
|
+
|
|
17
|
+
# Support both package execution and direct file execution.
# When this file is run directly (``__package__`` is empty), put the directory
# that contains the ``llmconnector`` package on ``sys.path`` so the absolute
# import resolves; otherwise use the normal relative import.
try:
    if __package__ in (None, ""):
        _here = os.path.dirname(os.path.abspath(__file__))
        _pkg_root = os.path.dirname(_here)
        if _pkg_root not in sys.path:
            sys.path.insert(0, _pkg_root)
        from llmconnector.client import LLMClient  # type: ignore
    else:
        from .client import LLMClient
except ImportError:  # pragma: no cover - defensive import fallback
    # Last-resort fallback when the import above cannot be resolved in an
    # unusual environment. Only import failures are retried; errors raised
    # while executing the client module propagate with their own traceback.
    from llmconnector.client import LLMClient  # type: ignore
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _env_api_key(provider: str) -> str:
    """Return the API key for *provider* from the environment, or exit.

    Tries ``{PROVIDER}_API_KEY`` first, then the provider-specific fallbacks
    known to ``_env_api_key_with_fallbacks`` (e.g. ``GOOGLE_API_KEY`` for
    gemini, ``XAI_API_KEY`` for grok). This keeps ``respond``/``models`` in
    line with ``all-models``, which already honours those fallbacks.

    Raises:
        SystemExit: If none of the candidate variables holds a non-empty value.
    """
    api_key, tried = _env_api_key_with_fallbacks(provider)
    if not api_key:
        # dict.fromkeys de-duplicates while preserving lookup order.
        names = " or ".join(dict.fromkeys(tried))
        raise SystemExit(
            f"Missing {names} environment variable. Set it before running."
        )
    return api_key
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level argument parser and its three subcommands.

    Subcommands:
        respond     -- send a prompt and/or images to one provider model.
        models      -- list the models exposed by a single provider.
        all-models  -- list models for every registered provider.
    """
    root = argparse.ArgumentParser(
        prog="client_cli",
        description="CLI for provider-agnostic LLM requests",
    )
    commands = root.add_subparsers(dest="command", required=True)

    def add_provider_option(sub: argparse.ArgumentParser, verb: str) -> None:
        # Optional on purpose: the handlers prompt interactively when it is absent.
        sub.add_argument(
            "--provider",
            required=False,
            help=f"Provider to {verb} (e.g. openai, gemini, anthropic, grok). If omitted, you will be prompted.",
        )

    def add_json_flag(sub: argparse.ArgumentParser) -> None:
        sub.add_argument(
            "--json",
            action="store_true",
            help="Output as JSON (default is human-readable)",
        )

    respond = commands.add_parser(
        "respond", help="Generate a response from a provider model"
    )
    add_provider_option(respond, "use")
    respond.add_argument(
        "--model",
        required=False,
        help="Model identifier for the provider (e.g. gpt-4o). If omitted, you will be prompted.",
    )
    respond.add_argument(
        "--prompt",
        default="",
        help="Text prompt to send (omit if only using images)",
    )
    # Repeatable flag collected into args.images; stays None when never given.
    respond.add_argument(
        "--image",
        action="append",
        dest="images",
        default=None,
        help="Image path or URL; may be provided multiple times",
    )
    respond.add_argument(
        "--max-tokens",
        type=int,
        default=32000,
        help="Maximum output tokens (provider-specific meaning)",
    )
    respond.add_argument(
        "--reasoning-effort",
        choices=["low", "medium", "high"],
        default=None,
        help="Optional reasoning effort hint if supported",
    )

    models = commands.add_parser(
        "models", help="List models available to the provider"
    )
    add_provider_option(models, "query")
    add_json_flag(models)

    all_models = commands.add_parser(
        "all-models", help="List models for all registered providers"
    )
    add_json_flag(all_models)

    return root
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _cmd_respond(args: argparse.Namespace) -> int:
    """Handle the ``respond`` subcommand: send one request and print the reply.

    ``--provider`` and ``--model`` are prompted for when omitted. The prompt
    text is prompted for when neither ``--prompt`` nor ``--image`` was given.
    End-of-file on stdin counts as an empty answer, so a non-interactive run
    reports the usual "is required" error (exit status 2) instead of crashing
    with an ``EOFError`` traceback.

    Returns:
        0 after printing the response; 2 on missing input or a provider error.
        A missing API key exits via the ``SystemExit`` raised by ``_env_api_key``.
    """

    def _ask(label: str) -> str:
        # input() raises EOFError once stdin is exhausted or closed.
        try:
            return input(label)
        except EOFError:
            return ""

    client = LLMClient()
    provider = args.provider
    if not provider:
        # Hint with the names this client already registered. Re-running
        # provider discovery would instantiate every SDK client a second time.
        known = sorted(getattr(client, "_providers", {}))  # type: ignore[attr-defined]
        hint = f" ({'/'.join(known)})" if known else ""
        provider = _ask(f"Provider{hint}: ").strip()
        if not provider:
            print("Error: provider is required.", file=sys.stderr)
            return 2

    api_key = _env_api_key(provider)

    model = args.model
    if not model:
        model = _ask("Model id: ").strip()
        if not model:
            print("Error: model is required.", file=sys.stderr)
            return 2

    prompt = args.prompt
    images: Sequence[str] | None = args.images
    if not prompt and not images:
        prompt = _ask("Prompt: ")
        if not prompt and not images:
            print("Error: provide a prompt or at least one image.", file=sys.stderr)
            return 2
    try:
        output = client.generate_response(
            provider=provider,
            api_key=api_key,
            prompt=prompt,
            model=model,
            max_tokens=args.max_tokens,
            reasoning_effort=args.reasoning_effort,
            images=images,
        )
    except Exception as exc:  # pragma: no cover - CLI surface
        print(f"Error: {exc}", file=sys.stderr)
        return 2

    print(output)
    return 0
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _cmd_models(args: argparse.Namespace) -> int:
    """Handle the ``models`` subcommand: list the models of one provider.

    ``--provider`` is prompted for when omitted. End-of-file on stdin counts
    as an empty answer, so a non-interactive run reports the usual error
    (exit status 2) instead of crashing with an ``EOFError`` traceback.
    Output is JSON with ``--json``, otherwise one ``id - display_name`` line
    per model (just ``id`` when there is no display name).

    Returns:
        0 after printing; 2 on a missing provider or a provider error.
        A missing API key exits via the ``SystemExit`` raised by ``_env_api_key``.
    """

    def _ask(label: str) -> str:
        # input() raises EOFError once stdin is exhausted or closed.
        try:
            return input(label)
        except EOFError:
            return ""

    client = LLMClient()
    provider = args.provider
    if not provider:
        # Hint with the names this client already registered instead of
        # re-running provider discovery (which re-instantiates every client).
        known = sorted(getattr(client, "_providers", {}))  # type: ignore[attr-defined]
        hint = f" ({'/'.join(known)})" if known else ""
        provider = _ask(f"Provider{hint}: ").strip()
        if not provider:
            print("Error: provider is required.", file=sys.stderr)
            return 2

    api_key = _env_api_key(provider)
    try:
        models = client.list_models(provider=provider, api_key=api_key)
    except Exception as exc:  # pragma: no cover - CLI surface
        print(f"Error: {exc}", file=sys.stderr)
        return 2

    if args.json:
        print(json.dumps(models, indent=2))
    elif not models:
        print("No models found.")
    else:
        for m in models:
            mid = m.get("id") or "<unknown>"
            name = m.get("display_name") or ""
            print(f"{mid} - {name}" if name else mid)
    return 0
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _env_api_key_with_fallbacks(provider: str) -> tuple[str | None, list[str]]:
    """Return the API key for a provider, trying common env var fallbacks.

    ``{PROVIDER}_API_KEY`` is always tried first. Then come provider-specific
    names used by the vendors' SDKs/docs: ``GOOGLE_API_KEY`` for gemini, and
    ``XAI_API_KEY``/``GROK_API_KEY`` for grok and xai. Empty values count as unset.

    Returns:
        ``(api_key_or_none, tried_env_names)``. The names are unique and in
        lookup order, and the full list is returned even when an early
        candidate matched.
    """
    key = provider.lower()
    candidates = [f"{provider.upper()}_API_KEY"]

    # openai/anthropic need no extras: OPENAI_API_KEY / ANTHROPIC_API_KEY are
    # already produced by the default convention above.
    if key == "gemini":
        candidates.append("GOOGLE_API_KEY")  # google-genai default
    elif key in {"grok", "xai"}:
        candidates.extend(["XAI_API_KEY", "GROK_API_KEY"])

    # grok/xai would otherwise list their default name twice; drop later
    # duplicates (lookup order is unaffected) so error messages stay clean.
    env_names = list(dict.fromkeys(candidates))

    for env_name in env_names:
        val = os.environ.get(env_name)
        if val:
            return val, env_names

    return None, env_names
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _cmd_all_models(args: argparse.Namespace) -> int:
    """Handle the ``all-models`` subcommand.

    Names registered for the same client object (e.g. ``grok`` and ``xai``)
    are queried once, using the alphabetically first name, and reported under
    a combined ``name1/name2`` label. A provider without an API key in the
    environment, or whose listing raises, is reported with an ``error`` entry
    instead of aborting the whole command.
    """
    client = LLMClient()
    registry = getattr(client, "_providers", {})  # type: ignore[attr-defined]

    # id(client object) -> (client object, registered names), in registration order.
    groups: dict[int, tuple[object, list[str]]] = {}
    for alias, backend in registry.items():
        groups.setdefault(id(backend), (backend, []))[1].append(alias)

    results: list[dict[str, object]] = []
    for _, aliases in groups.values():
        ordered = sorted(aliases)
        lead = ordered[0]
        label = "/".join(ordered)

        api_key, tried = _env_api_key_with_fallbacks(lead)
        if not api_key:
            results.append(
                {
                    "provider": label,
                    "error": f"missing API key (tried: {', '.join(tried)})",
                    "models": [],
                }
            )
            continue

        try:
            listing = client.list_models(provider=lead, api_key=api_key)
        except Exception as exc:  # pragma: no cover - CLI surface
            results.append({"provider": label, "error": str(exc), "models": []})
        else:
            results.append({"provider": label, "models": listing})

    if args.json:
        print(json.dumps(results, indent=2))
        return 0

    if not results:
        print("No providers registered.")
        return 0

    for item in results:
        heading = str(item.get("provider", "<unknown>"))
        print(f"== {heading} ==")
        if item.get("error"):
            print(f" Skipped: {item['error']}")
            continue
        entries = item.get("models") or []
        if not entries:
            print(" No models found.")
            continue
        for m in entries:  # type: ignore[assignment]
            if not isinstance(m, dict):
                print(f" {m}")
                continue
            mid = m.get("id") or "<unknown>"
            name = m.get("display_name") or ""
            print(f" {mid} - {name}" if name else f" {mid}")

    return 0
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def main(argv: Sequence[str] | None = None) -> int:
    """Parse *argv* (``sys.argv[1:]`` when ``None``) and run the chosen subcommand.

    Returns the subcommand handler's exit status.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    handlers = {
        "respond": _cmd_respond,
        "models": _cmd_models,
        "all-models": _cmd_all_models,
    }
    handler = handlers.get(args.command)
    if handler is None:
        # Not reachable with the required subparsers; parser.error() exits with status 2.
        parser.error("unknown command")
        return 2
    return handler(args)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Thin wrapper around the Google Gemini API via the google-genai SDK."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import mimetypes
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional, Sequence, Union
|
|
9
|
+
from urllib.error import URLError
|
|
10
|
+
from urllib.request import urlopen
|
|
11
|
+
|
|
12
|
+
from google import genai
|
|
13
|
+
from google.genai import types
|
|
14
|
+
|
|
15
|
+
# An image reference: a pathlib.Path, or a string holding a local path,
# an http(s) URL, or a ``data:`` URL (see GeminiClient._to_image_part).
ImageInput = Union[str, Path]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GeminiClient:
    """Convenience wrapper around the Google Gemini SDK.

    A fresh ``genai.Client`` is created for every call and closed afterwards;
    instances hold no per-call state of their own.
    """

    def generate_response(
        self,
        *,
        api_key: str,
        prompt: str,
        model: str,
        max_tokens: int = 32000,
        reasoning_effort: Optional[str] = None,
        images: Optional[Sequence[ImageInput]] = None,
    ) -> str:
        """Generate a response from the specified Gemini model.

        Args:
            api_key: API key used to authenticate with the Gemini API.
            prompt: Natural-language instruction or query for the model.
            model: Identifier of the Gemini model to target (for example, ``"gemini-2.5-flash"``).
            max_tokens: Upper bound on generated tokens, passed to the SDK as
                ``max_output_tokens``; defaults to 32000.
            reasoning_effort: Accepted for parity with the other provider
                clients; currently ignored.
            images: Optional collection of image references (local paths, URLs, or data URLs).

        Returns:
            The text output produced by the model. If ``response.text`` is empty,
            the text parts of all candidates are joined with newlines.

        Raises:
            ValueError: If required arguments are missing or the request payload is empty.
            URLError: If an image URL cannot be retrieved.
            RuntimeError: If the response contains no text output at all.
            google.genai.errors.APIError: If the underlying Gemini request fails.
        """
        if not api_key:
            raise ValueError("api_key must be provided.")
        if not prompt and not images:
            raise ValueError("At least one of prompt or images must be provided.")
        if not model:
            raise ValueError("model must be provided.")

        parts: list[types.Part] = []
        if prompt:
            parts.append(types.Part.from_text(text=prompt))
        for image in images or ():
            parts.append(self._to_image_part(image))

        if not parts:
            raise ValueError("No content provided for response generation.")

        content = types.Content(role="user", parts=parts)
        # reasoning_effort is accepted for signature parity with the other
        # provider clients but is not mapped to any Gemini setting here.
        config = types.GenerateContentConfig(max_output_tokens=max_tokens)

        client = genai.Client(api_key=api_key)
        try:
            response = client.models.generate_content(
                model=model,
                contents=[content],
                config=config,
            )
        finally:
            self._close_quietly(client)

        # Read the property once rather than evaluating it twice.
        text = response.text
        if text:
            return text

        candidate_texts: list[str] = []
        for candidate in getattr(response, "candidates", []) or []:
            content_obj = getattr(candidate, "content", None)
            if not content_obj:
                continue
            for part in getattr(content_obj, "parts", []) or []:
                part_text = getattr(part, "text", None)
                if part_text:
                    candidate_texts.append(part_text)

        if candidate_texts:
            return "\n".join(candidate_texts)

        raise RuntimeError("Gemini response did not include any text output.")

    def list_models(self, *, api_key: str) -> list[dict[str, Optional[str]]]:
        """Return the models available to the authenticated Gemini account.

        Each entry is ``{"id": ..., "display_name": ...}``. Names of the form
        ``"models/<id>"`` are reduced to ``"<id>"``, and entries without a name
        are skipped.

        Raises:
            ValueError: If *api_key* is empty.
        """
        if not api_key:
            raise ValueError("api_key must be provided.")

        models: list[dict[str, Optional[str]]] = []
        client = genai.Client(api_key=api_key)
        try:
            for model in client.models.list():
                model_id = getattr(model, "name", None)
                if model_id is None and isinstance(model, dict):
                    model_id = model.get("name")
                if not model_id:
                    continue

                # Normalize IDs like "models/<id>" -> "<id>"
                if isinstance(model_id, str) and model_id.startswith("models/"):
                    model_id = model_id.split("/", 1)[1]

                display_name = getattr(model, "display_name", None)
                if display_name is None and isinstance(model, dict):
                    display_name = model.get("display_name")

                models.append({"id": model_id, "display_name": display_name})
        finally:
            self._close_quietly(client)

        return models

    @staticmethod
    def _close_quietly(client: object) -> None:
        """Call ``client.close()`` if it exists, ignoring any error it raises.

        Closing is best-effort cleanup: a failure there must not mask the
        request's result or the exception already propagating.
        """
        closer = getattr(client, "close", None)
        if callable(closer):
            try:
                closer()
            except Exception:
                pass

    @staticmethod
    def _to_image_part(image: ImageInput) -> types.Part:
        """Convert an image reference into a Gemini SDK part.

        ``Path`` objects and plain strings are read from disk; strings starting
        with ``data:`` or ``http(s)://`` are decoded or downloaded instead.
        """
        if isinstance(image, Path):
            return _part_from_path(image)

        if image.startswith("data:"):
            return _part_from_data_url(image)

        if image.startswith(("http://", "https://")):
            return _part_from_url(image)

        return _part_from_path(Path(image))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _part_from_path(path: Path) -> types.Part:
    """Read a local image file and wrap its bytes in a Gemini ``Part``.

    A leading ``~`` is expanded. The MIME type is guessed from the file name
    and falls back to ``application/octet-stream`` when it is unknown.
    """
    resolved = path.expanduser()
    guessed, _encoding = mimetypes.guess_type(resolved.name)
    return types.Part.from_bytes(
        data=resolved.read_bytes(),
        mime_type=guessed or "application/octet-stream",
    )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _part_from_url(url: str, timeout: Optional[float] = 60.0) -> types.Part:
    """Create an image part by downloading content from a URL.

    The MIME type comes from the response's ``Content-Type`` header. When the
    header is absent or only says ``application/octet-stream``, the type is
    guessed from the URL, falling back to ``application/octet-stream``.

    Args:
        url: The ``http://`` or ``https://`` URL to fetch.
        timeout: Seconds to wait on the connection before giving up. ``None``
            disables the timeout, so a stalled server could block forever.

    Raises:
        URLError: If the URL cannot be retrieved (``HTTPError`` is a subclass).
            A timeout while reading the body may surface as ``TimeoutError``.
    """
    with urlopen(url, timeout=timeout) as response:
        data = response.read()
        headers = response.info()
        # get_content_type() reports "text/plain" when no Content-Type header
        # exists, which would mislabel image bytes; trust it only if present.
        mime_type = headers.get_content_type() if headers.get("Content-Type") else None

    if not mime_type or mime_type == "application/octet-stream":
        mime_type = mimetypes.guess_type(url)[0] or "application/octet-stream"

    return types.Part.from_bytes(data=data, mime_type=mime_type)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _part_from_data_url(data_url: str) -> types.Part:
    """Create an image part from a ``data:`` URL (RFC 2397).

    Expected form: ``data:[<mediatype>][;param=value...][;base64],<data>``.
    The payload is base64-decoded when a ``base64`` flag parameter is present,
    and percent-decoded otherwise. An empty media type falls back to
    ``application/octet-stream``.

    Raises:
        ValueError: If no ``,`` separates the header from the payload, or the
            base64 payload is malformed (``binascii.Error`` subclasses ValueError).
    """
    from urllib.parse import unquote_to_bytes

    header, sep, payload = data_url.partition(",")
    if not sep:
        raise ValueError("Invalid data URL: missing ',' between header and data.")

    # Drop the "data:" prefix, then split "<mediatype>;param;...;base64".
    params = header[len("data:") :].split(";")
    mime_type = params[0].strip() or "application/octet-stream"
    # Match the flag as a whole parameter, not as a substring of another one.
    is_base64 = any(p.strip().lower() == "base64" for p in params[1:])

    data = base64.b64decode(payload) if is_base64 else unquote_to_bytes(payload)
    return types.Part.from_bytes(data=data, mime_type=mime_type)
|