genxai_framework-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
genxai/llm/providers/google.py
@@ -0,0 +1,334 @@
"""Google Gemini LLM provider implementation."""

from typing import Any, Dict, Optional, AsyncIterator
import importlib
import os
import logging
import warnings

from genxai.llm.base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class GoogleProvider(LLMProvider):
    """Google Gemini LLM provider."""

    def __init__(
        self,
        model: str = "gemini-pro",
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Google provider.

        Args:
            model: Model name (gemini-pro, gemini-pro-vision, gemini-ultra)
            api_key: Google API key (defaults to GOOGLE_API_KEY env var)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional Google-specific parameters
        """
        super().__init__(model, temperature, max_tokens, **kwargs)

        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            logger.warning("Google API key not provided. Set GOOGLE_API_KEY environment variable.")

        self._client: Optional[Any] = None
        self._model_instance: Optional[Any] = None
        self._initialize_client()

    def _initialize_client(self) -> None:
        """Initialize Google Generative AI client."""
        try:
            # Prefer the newer google-genai SDK; fall back to the legacy
            # google.generativeai package below if it is not installed.
            genai = importlib.import_module("google.genai")

            if hasattr(genai, "configure"):
                genai.configure(api_key=self.api_key)

            if hasattr(genai, "GenerativeModel"):
                self._client = genai
                self._model_instance = genai.GenerativeModel(self.model)
                logger.info(f"Google Gemini client initialized with model: {self.model}")
                return

            if hasattr(genai, "Client"):
                self._client = genai.Client(api_key=self.api_key)
                models_attr = getattr(self._client, "models", None)
                if models_attr and hasattr(models_attr, "get"):
                    self._model_instance = models_attr.get(self.model)
                elif hasattr(self._client, "get_model"):
                    self._model_instance = self._client.get_model(self.model)
                else:
                    raise RuntimeError("google.genai client does not expose a model accessor")
                logger.info(f"Google Gemini client initialized with model: {self.model}")
                return

            raise RuntimeError("google.genai does not expose a known client API")
        except ImportError:
            # Legacy SDK fallback; nested so that a missing package is reported
            # rather than escaping as an unhandled ImportError.
            try:
                with warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore",
                        category=FutureWarning,
                        module=r"google\.generativeai",
                    )
                    warnings.filterwarnings(
                        "ignore",
                        message=r"All support for the `google.generativeai` package has ended.*",
                        category=FutureWarning,
                    )
                    genai = importlib.import_module("google.generativeai")

                genai.configure(api_key=self.api_key)
                self._client = genai
                self._model_instance = genai.GenerativeModel(self.model)
                logger.info(f"Google Gemini client initialized with model: {self.model}")
            except ImportError:
                logger.error(
                    "Google Generative AI package not installed. "
                    "Install with: pip install google-genai"
                )
                self._client = None
                self._model_instance = None
        except Exception as e:
            logger.error(f"Failed to initialize Google client: {e}")
            self._client = None
            self._model_instance = None

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate completion using Gemini.

        Args:
            prompt: User prompt
            system_prompt: System prompt (prepended to user prompt)
            **kwargs: Additional generation parameters

        Returns:
            LLM response

        Raises:
            RuntimeError: If client not initialized
            Exception: If API call fails
        """
        if not self._model_instance:
            raise RuntimeError("Google Gemini client not initialized")

        # Combine system prompt with user prompt
        full_prompt = prompt
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"

        # Configure generation parameters
        generation_config = {
            "temperature": kwargs.get("temperature", self.temperature),
            "max_output_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

        # Add additional parameters
        if "top_p" in kwargs:
            generation_config["top_p"] = kwargs["top_p"]
        if "top_k" in kwargs:
            generation_config["top_k"] = kwargs["top_k"]
        if "stop_sequences" in kwargs:
            generation_config["stop_sequences"] = kwargs["stop_sequences"]

        # Configure safety settings if provided
        safety_settings = kwargs.get("safety_settings")

        try:
            logger.debug(f"Calling Google Gemini API with model: {self.model}")

            response = await self._model_instance.generate_content_async(
                full_prompt,
                generation_config=generation_config,
                safety_settings=safety_settings,
            )

            # Extract response
            content = response.text if hasattr(response, 'text') else ""

            # Extract usage (Gemini doesn't always provide detailed token counts)
            usage = {
                "prompt_tokens": 0,  # Gemini API doesn't expose this directly
                "completion_tokens": 0,  # Gemini API doesn't expose this directly
                "total_tokens": 0,
            }

            # Try to get token count if available
            if hasattr(response, 'usage_metadata'):
                usage["prompt_tokens"] = getattr(response.usage_metadata, 'prompt_token_count', 0)
                usage["completion_tokens"] = getattr(response.usage_metadata, 'candidates_token_count', 0)
                usage["total_tokens"] = getattr(response.usage_metadata, 'total_token_count', 0)

            # Update stats
            self._update_stats(usage)

            # Get finish reason
            finish_reason = None
            if hasattr(response, 'candidates') and response.candidates:
                finish_reason = str(response.candidates[0].finish_reason)

            return LLMResponse(
                content=content,
                model=self.model,
                usage=usage,
                finish_reason=finish_reason,
                metadata={
                    "safety_ratings": (
                        [rating for candidate in response.candidates
                         for rating in candidate.safety_ratings]
                        if hasattr(response, 'candidates') else []
                    ),
                },
            )

        except Exception as e:
            logger.error(f"Google Gemini API call failed: {e}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Generate completion with streaming.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            **kwargs: Additional generation parameters

        Yields:
            Content chunks

        Raises:
            RuntimeError: If client not initialized
        """
        if not self._model_instance:
            raise RuntimeError("Google Gemini client not initialized")

        # Combine system prompt with user prompt
        full_prompt = prompt
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"

        # Configure generation parameters
        generation_config = {
            "temperature": kwargs.get("temperature", self.temperature),
            "max_output_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

        safety_settings = kwargs.get("safety_settings")

        try:
            logger.debug(f"Streaming from Google Gemini API with model: {self.model}")

            response = await self._model_instance.generate_content_async(
                full_prompt,
                generation_config=generation_config,
                safety_settings=safety_settings,
                stream=True,
            )

            async for chunk in response:
                if hasattr(chunk, 'text') and chunk.text:
                    yield chunk.text

        except Exception as e:
            logger.error(f"Google Gemini streaming failed: {e}")
            raise

    async def generate_chat(
        self,
        messages: list[Dict[str, str]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate completion for chat messages.

        Args:
            messages: List of message dictionaries
            **kwargs: Additional generation parameters

        Returns:
            LLM response
        """
        if not self._model_instance:
            raise RuntimeError("Google Gemini client not initialized")

        # Convert messages to Gemini format
        # Gemini uses a simpler format: alternating user/model messages
        chat_history = []
        system_prompt = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_prompt = content
            elif role == "user":
                chat_history.append({"role": "user", "parts": [content]})
            elif role == "assistant":
                chat_history.append({"role": "model", "parts": [content]})

        # Start chat session
        chat = self._model_instance.start_chat(history=chat_history[:-1] if chat_history else [])

        # Get last user message
        last_message = chat_history[-1]["parts"][0] if chat_history else ""

        # Prepend system prompt if provided
        if system_prompt:
            last_message = f"{system_prompt}\n\n{last_message}"

        # Configure generation
        generation_config = {
            "temperature": kwargs.get("temperature", self.temperature),
            "max_output_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

        try:
            response = await chat.send_message_async(
                last_message,
                generation_config=generation_config,
            )

            content = response.text if hasattr(response, 'text') else ""

            usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }

            if hasattr(response, 'usage_metadata'):
                usage["prompt_tokens"] = getattr(response.usage_metadata, 'prompt_token_count', 0)
                usage["completion_tokens"] = getattr(response.usage_metadata, 'candidates_token_count', 0)
                usage["total_tokens"] = getattr(response.usage_metadata, 'total_token_count', 0)

            self._update_stats(usage)

            finish_reason = None
            if hasattr(response, 'candidates') and response.candidates:
                finish_reason = str(response.candidates[0].finish_reason)

            return LLMResponse(
                content=content,
                model=self.model,
                usage=usage,
                finish_reason=finish_reason,
                metadata={},
            )

        except Exception as e:
            logger.error(f"Google Gemini chat API call failed: {e}")
            raise
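For orientation, a minimal usage sketch for this provider. The class name, constructor arguments, and method signatures come from the diff above; the import path genxai.llm.providers.google is inferred from the file listing, and the attribute-style access on LLMResponse is an assumption. It is illustrative only, not part of the packaged code.

import asyncio

from genxai.llm.providers.google import GoogleProvider  # import path assumed from the file listing


async def main() -> None:
    # Requires GOOGLE_API_KEY in the environment (see __init__ above).
    provider = GoogleProvider(model="gemini-pro", temperature=0.2, max_tokens=256)

    # Single-turn generation with an optional system prompt.
    response = await provider.generate(
        "Summarize the Gemini provider in one sentence.",
        system_prompt="You are a concise technical writer.",
    )
    print(response.content, response.usage)

    # Chat-style generation: the system message is folded into the last user turn,
    # and "assistant" turns are mapped to Gemini's "model" role (see generate_chat).
    chat_response = await provider.generate_chat(
        [
            {"role": "system", "content": "Answer briefly."},
            {"role": "user", "content": "Which Gemini models does this provider target?"},
        ]
    )
    print(chat_response.content)


asyncio.run(main())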
genxai/llm/providers/ollama.py
@@ -0,0 +1,147 @@
"""Ollama (local) LLM provider implementation."""

from __future__ import annotations

from typing import Any, Dict, Optional, AsyncIterator
import json
import logging
import os

import httpx

from genxai.llm.base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class OllamaProvider(LLMProvider):
    """Ollama LLM provider for local model execution.

    Docs: https://github.com/ollama/ollama/blob/main/docs/api.md
    """

    def __init__(
        self,
        model: str = "llama3",
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        base_url: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Ollama provider.

        Args:
            model: Ollama model name (e.g., llama3, mistral)
            api_key: Optional API key (Ollama typically runs locally)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            base_url: Ollama server URL (default: http://localhost:11434)
            **kwargs: Additional Ollama-specific parameters
        """
        super().__init__(model, temperature, max_tokens, **kwargs)
        self.api_key = api_key
        self.base_url = base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        self._client = httpx.AsyncClient(base_url=self.base_url, timeout=kwargs.get("timeout", 120))

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate completion using Ollama.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional generation parameters

        Returns:
            LLM response
        """
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": kwargs.get("temperature", self.temperature),
            },
        }

        if system_prompt:
            payload["system"] = system_prompt
        if self.max_tokens:
            payload["options"]["num_predict"] = kwargs.get("max_tokens", self.max_tokens)

        # Merge any custom options
        if "options" in kwargs and isinstance(kwargs["options"], dict):
            payload["options"].update(kwargs["options"])

        logger.debug("Calling Ollama generate with model %s", self.model)
        response = await self._client.post("/api/generate", json=payload)
        response.raise_for_status()
        data = response.json()

        content = data.get("response", "")
        usage = {
            "prompt_tokens": data.get("prompt_eval_count", 0),
            "completion_tokens": data.get("eval_count", 0),
            "total_tokens": (data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
        }

        self._update_stats(usage)

        return LLMResponse(
            content=content,
            model=self.model,
            usage=usage,
            finish_reason="stop",
            metadata={"done": data.get("done", True)},
        )

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Generate completion with streaming.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional generation parameters

        Yields:
            Content chunks
        """
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "temperature": kwargs.get("temperature", self.temperature),
            },
        }

        if system_prompt:
            payload["system"] = system_prompt
        if self.max_tokens:
            payload["options"]["num_predict"] = kwargs.get("max_tokens", self.max_tokens)
        if "options" in kwargs and isinstance(kwargs["options"], dict):
            payload["options"].update(kwargs["options"])

        logger.debug("Streaming from Ollama with model %s", self.model)
        async with self._client.stream("POST", "/api/generate", json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line:
                    continue
                # Each streamed line is a standalone JSON object.
                data = json.loads(line)
                chunk = data.get("response")
                if chunk:
                    yield chunk

    async def close(self) -> None:
        """Close HTTP client."""
        await self._client.aclose()
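Similarly, a small sketch of driving the Ollama provider against a local server on the default http://localhost:11434. Provider, method, and field names come from the code above; the import path genxai.llm.providers.ollama is inferred from the file listing, and everything else is illustrative.

import asyncio

from genxai.llm.providers.ollama import OllamaProvider  # import path assumed from the file listing


async def main() -> None:
    provider = OllamaProvider(model="llama3", temperature=0.1)
    try:
        # Non-streaming call: token counts come from Ollama's
        # prompt_eval_count / eval_count response fields.
        response = await provider.generate("Name three uses of a message bus.")
        print(response.content, response.usage)

        # Streaming call: chunks are yielded as each JSON line arrives.
        async for chunk in provider.generate_stream("Count from 1 to 5."):
            print(chunk, end="", flush=True)
        print()
    finally:
        # Release the underlying httpx.AsyncClient.
        await provider.close()


asyncio.run(main())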