genxai_framework-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
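Both LLM provider modules diffed below subclass genxai.llm.base.LLMProvider, so callers can treat them interchangeably behind one async interface. The following is a minimal sketch under that assumption; it is not shipped in the wheel, and it requires the anthropic and cohere SDKs plus ANTHROPIC_API_KEY and COHERE_API_KEY to actually run.

# Hedged sketch (not part of the wheel): swap providers behind the shared
# LLMProvider interface. Assumes only what the diffs below show: both classes
# subclass genxai.llm.base.LLMProvider and expose async generate(prompt, system_prompt=None).
import asyncio

from genxai.llm.base import LLMProvider
from genxai.llm.providers.anthropic import AnthropicProvider
from genxai.llm.providers.cohere import CohereProvider


async def ask(provider: LLMProvider, question: str) -> str:
    response = await provider.generate(question, system_prompt="Be brief.")
    return response.content


async def main() -> None:
    providers: list[LLMProvider] = [AnthropicProvider(), CohereProvider()]
    for provider in providers:
        print(type(provider).__name__, "->", await ask(provider, "What is a webhook trigger?"))


asyncio.run(main())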
genxai/llm/providers/anthropic.py
@@ -0,0 +1,249 @@
"""Anthropic Claude LLM provider implementation."""

from typing import Any, Dict, Optional, AsyncIterator
import os
import logging

from genxai.llm.base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class AnthropicProvider(LLMProvider):
    """Anthropic Claude LLM provider."""

    def __init__(
        self,
        model: str = "claude-3-opus-20240229",
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Anthropic provider.

        Args:
            model: Model name (claude-3-opus, claude-3-sonnet, claude-3-haiku)
            api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional Anthropic-specific parameters
        """
        super().__init__(model, temperature, max_tokens, **kwargs)

        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not self.api_key:
            logger.warning("Anthropic API key not provided. Set ANTHROPIC_API_KEY environment variable.")

        self._client: Optional[Any] = None
        self._initialize_client()

    def _initialize_client(self) -> None:
        """Initialize Anthropic client."""
        try:
            from anthropic import AsyncAnthropic
            self._client = AsyncAnthropic(api_key=self.api_key)
            logger.info(f"Anthropic client initialized with model: {self.model}")
        except ImportError:
            logger.error(
                "Anthropic package not installed. Install with: pip install anthropic"
            )
            self._client = None
        except Exception as e:
            logger.error(f"Failed to initialize Anthropic client: {e}")
            self._client = None

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate completion using Claude.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            **kwargs: Additional generation parameters

        Returns:
            LLM response

        Raises:
            RuntimeError: If client not initialized
            Exception: If API call fails
        """
        if not self._client:
            raise RuntimeError("Anthropic client not initialized")

        # Build messages
        messages = [{"role": "user", "content": prompt}]

        # Merge parameters
        params: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens or 1024),
        }

        # Add system prompt if provided
        if system_prompt:
            params["system"] = system_prompt

        # Add additional parameters
        for key in ["top_p", "top_k", "stop_sequences"]:
            if key in kwargs:
                params[key] = kwargs[key]

        try:
            logger.debug(f"Calling Anthropic API with model: {self.model}")
            response = await self._client.messages.create(**params)

            # Extract response
            content = response.content[0].text if response.content else ""
            finish_reason = response.stop_reason

            # Extract usage
            usage = {
                "prompt_tokens": response.usage.input_tokens if response.usage else 0,
                "completion_tokens": response.usage.output_tokens if response.usage else 0,
                "total_tokens": (
                    (response.usage.input_tokens + response.usage.output_tokens)
                    if response.usage else 0
                ),
            }

            # Update stats
            self._update_stats(usage)

            return LLMResponse(
                content=content,
                model=response.model,
                usage=usage,
                finish_reason=finish_reason,
                metadata={"response_id": response.id, "type": response.type},
            )

        except Exception as e:
            logger.error(f"Anthropic API call failed: {e}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Generate completion with streaming.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            **kwargs: Additional generation parameters

        Yields:
            Content chunks

        Raises:
            RuntimeError: If client not initialized
        """
        if not self._client:
            raise RuntimeError("Anthropic client not initialized")

        # Build messages
        messages = [{"role": "user", "content": prompt}]

        # Merge parameters
        params: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens or 1024),
            "stream": True,
        }

        # Add system prompt if provided
        if system_prompt:
            params["system"] = system_prompt

        try:
            logger.debug(f"Streaming from Anthropic API with model: {self.model}")

            async with self._client.messages.stream(**params) as stream:
                async for text in stream.text_stream:
                    yield text

        except Exception as e:
            logger.error(f"Anthropic streaming failed: {e}")
            raise

    async def generate_chat(
        self,
        messages: list[Dict[str, str]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate completion for chat messages.

        Args:
            messages: List of message dictionaries
            **kwargs: Additional generation parameters

        Returns:
            LLM response
        """
        if not self._client:
            raise RuntimeError("Anthropic client not initialized")

        # Extract system prompt if present
        system_prompt = None
        chat_messages = []

        for msg in messages:
            if msg.get("role") == "system":
                system_prompt = msg.get("content", "")
            else:
                chat_messages.append({
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", ""),
                })

        # Merge parameters
        params: Dict[str, Any] = {
            "model": self.model,
            "messages": chat_messages,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens or 1024),
        }

        if system_prompt:
            params["system"] = system_prompt

        try:
            response = await self._client.messages.create(**params)

            content = response.content[0].text if response.content else ""
            finish_reason = response.stop_reason

            usage = {
                "prompt_tokens": response.usage.input_tokens if response.usage else 0,
                "completion_tokens": response.usage.output_tokens if response.usage else 0,
                "total_tokens": (
                    (response.usage.input_tokens + response.usage.output_tokens)
                    if response.usage else 0
                ),
            }

            self._update_stats(usage)

            return LLMResponse(
                content=content,
                model=response.model,
                usage=usage,
                finish_reason=finish_reason,
                metadata={"response_id": response.id, "type": response.type},
            )

        except Exception as e:
            logger.error(f"Anthropic chat API call failed: {e}")
            raise
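For orientation, a hedged usage sketch for the provider above; it is not part of the wheel and assumes only the generate_chat signature shown, an ANTHROPIC_API_KEY in the environment, and the anthropic SDK being installed.

# Hypothetical usage sketch -- not shipped in the package.
import asyncio

from genxai.llm.providers.anthropic import AnthropicProvider


async def main() -> None:
    # Uses the default model ("claude-3-opus-20240229") from the constructor above.
    provider = AnthropicProvider(max_tokens=256)

    # generate_chat() lifts the "system" message into Anthropic's `system`
    # parameter and forwards the remaining turns as `messages`.
    response = await provider.generate_chat(
        [
            {"role": "system", "content": "Answer in one short sentence."},
            {"role": "user", "content": "What does a message bus do?"},
        ]
    )
    print(response.content)
    print(response.usage)  # prompt/completion/total token counts


asyncio.run(main())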
genxai/llm/providers/cohere.py
@@ -0,0 +1,274 @@
"""Cohere LLM provider implementation."""

from typing import Any, Dict, Optional, AsyncIterator
import os
import logging

from genxai.llm.base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class CohereProvider(LLMProvider):
    """Cohere LLM provider."""

    def __init__(
        self,
        model: str = "command",
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Cohere provider.

        Args:
            model: Model name (command, command-light, command-r, command-r-plus)
            api_key: Cohere API key (defaults to COHERE_API_KEY env var)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional Cohere-specific parameters
        """
        super().__init__(model, temperature, max_tokens, **kwargs)

        self.api_key = api_key or os.getenv("COHERE_API_KEY")
        if not self.api_key:
            logger.warning("Cohere API key not provided. Set COHERE_API_KEY environment variable.")

        self._client: Optional[Any] = None
        self._initialize_client()

    def _initialize_client(self) -> None:
        """Initialize Cohere client."""
        try:
            import cohere
            self._client = cohere.AsyncClient(api_key=self.api_key)
            logger.info(f"Cohere client initialized with model: {self.model}")
        except ImportError:
            logger.error(
                "Cohere package not installed. Install with: pip install cohere"
            )
            self._client = None
        except Exception as e:
            logger.error(f"Failed to initialize Cohere client: {e}")
            self._client = None

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate completion using Cohere.

        Args:
            prompt: User prompt
            system_prompt: System prompt (prepended to user prompt)
            **kwargs: Additional generation parameters

        Returns:
            LLM response

        Raises:
            RuntimeError: If client not initialized
            Exception: If API call fails
        """
        if not self._client:
            raise RuntimeError("Cohere client not initialized")

        # Combine system prompt with user prompt
        full_prompt = prompt
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"

        # Merge parameters
        params: Dict[str, Any] = {
            "model": self.model,
            "prompt": full_prompt,
            "temperature": kwargs.get("temperature", self.temperature),
        }

        if self.max_tokens:
            params["max_tokens"] = kwargs.get("max_tokens", self.max_tokens)

        # Add additional parameters
        for key in ["p", "k", "frequency_penalty", "presence_penalty", "stop_sequences"]:
            if key in kwargs:
                params[key] = kwargs[key]

        try:
            logger.debug(f"Calling Cohere API with model: {self.model}")
            response = await self._client.generate(**params)

            # Extract response
            content = response.generations[0].text if response.generations else ""
            finish_reason = response.generations[0].finish_reason if response.generations else None

            # Extract usage (Cohere provides token counts in meta)
            usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }

            if hasattr(response, 'meta') and response.meta:
                billed_units = response.meta.billed_units
                if billed_units:
                    usage["prompt_tokens"] = getattr(billed_units, 'input_tokens', 0)
                    usage["completion_tokens"] = getattr(billed_units, 'output_tokens', 0)
                    usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]

            # Update stats
            self._update_stats(usage)

            return LLMResponse(
                content=content,
                model=self.model,
                usage=usage,
                finish_reason=finish_reason,
                metadata={
                    "generation_id": response.generations[0].id if response.generations else None,
                },
            )

        except Exception as e:
            logger.error(f"Cohere API call failed: {e}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Generate completion with streaming.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            **kwargs: Additional generation parameters

        Yields:
            Content chunks

        Raises:
            RuntimeError: If client not initialized
        """
        if not self._client:
            raise RuntimeError("Cohere client not initialized")

        # Combine system prompt with user prompt
        full_prompt = prompt
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"

        # Merge parameters
        params: Dict[str, Any] = {
            "model": self.model,
            "prompt": full_prompt,
            "temperature": kwargs.get("temperature", self.temperature),
            "stream": True,
        }

        if self.max_tokens:
            params["max_tokens"] = kwargs.get("max_tokens", self.max_tokens)

        try:
            logger.debug(f"Streaming from Cohere API with model: {self.model}")

            response = await self._client.generate(**params)

            async for event in response:
                if event.event_type == "text-generation":
                    yield event.text

        except Exception as e:
            logger.error(f"Cohere streaming failed: {e}")
            raise

    async def generate_chat(
        self,
        messages: list[Dict[str, str]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate completion for chat messages.

        Args:
            messages: List of message dictionaries
            **kwargs: Additional generation parameters

        Returns:
            LLM response
        """
        if not self._client:
            raise RuntimeError("Cohere client not initialized")

        # Convert messages to Cohere chat format
        chat_history = []
        system_prompt = None
        last_user_message = ""

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_prompt = content
            elif role == "user":
                last_user_message = content
                # Add to history if not the last message
                if msg != messages[-1]:
                    chat_history.append({"role": "USER", "message": content})
            elif role == "assistant":
                chat_history.append({"role": "CHATBOT", "message": content})

        # Merge parameters
        params: Dict[str, Any] = {
            "model": self.model,
            "message": last_user_message,
            "temperature": kwargs.get("temperature", self.temperature),
        }

        if chat_history:
            params["chat_history"] = chat_history

        if system_prompt:
            params["preamble"] = system_prompt

        if self.max_tokens:
            params["max_tokens"] = kwargs.get("max_tokens", self.max_tokens)

        try:
            response = await self._client.chat(**params)

            content = response.text if hasattr(response, 'text') else ""
            finish_reason = response.finish_reason if hasattr(response, 'finish_reason') else None

            usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }

            if hasattr(response, 'meta') and response.meta:
                billed_units = response.meta.billed_units
                if billed_units:
                    usage["prompt_tokens"] = getattr(billed_units, 'input_tokens', 0)
                    usage["completion_tokens"] = getattr(billed_units, 'output_tokens', 0)
                    usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]

            self._update_stats(usage)

            return LLMResponse(
                content=content,
                model=self.model,
                usage=usage,
                finish_reason=finish_reason,
                metadata={
                    "generation_id": response.generation_id if hasattr(response, 'generation_id') else None,
                },
            )

        except Exception as e:
            logger.error(f"Cohere chat API call failed: {e}")
            raise
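To make the chat-path conversion above easier to follow, here is a standalone restatement of that role-mapping loop with hypothetical input messages; it runs without any API call and prints the request shape generate_chat would build.

# Illustration only: mirrors the mapping logic in CohereProvider.generate_chat above.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi."},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "Name one queue trigger."},
]

chat_history, system_prompt, last_user_message = [], None, ""
for msg in messages:
    role, content = msg.get("role", "user"), msg.get("content", "")
    if role == "system":
        system_prompt = content
    elif role == "user":
        last_user_message = content
        # Earlier user turns go into history; the final one becomes `message`.
        if msg != messages[-1]:
            chat_history.append({"role": "USER", "message": content})
    elif role == "assistant":
        chat_history.append({"role": "CHATBOT", "message": content})

# Resulting request pieces:
#   message      -> "Name one queue trigger."
#   preamble     -> "You are terse."
#   chat_history -> [{"role": "USER", "message": "Hi."},
#                    {"role": "CHATBOT", "message": "Hello."}]
print(last_user_message, system_prompt, chat_history)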