genxai-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
genxai/llm/base.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Base LLM provider interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional, AsyncIterator
|
|
4
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LLMResponse(BaseModel):
    """Normalized result returned by every LLM provider.

    Attributes:
        content: Text produced by the model (required).
        model: Name of the model that produced ``content`` (required).
        usage: Token accounting reported by the provider; empty when unknown.
        finish_reason: Why generation stopped, if the provider reported it.
        metadata: Free-form, provider-specific extras.
    """

    # Allow non-pydantic values (e.g. raw SDK objects) to be stored on the model.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    content: str = Field(default=..., description="Generated content")
    model: str = Field(default=..., description="Model used")
    usage: Dict[str, int] = Field(
        default_factory=dict,
        description="Token usage",
    )
    finish_reason: Optional[str] = Field(
        default=None,
        description="Reason for completion",
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LLMProvider(ABC):
    """Abstract base class for LLM providers.

    Subclasses must implement :meth:`generate` and :meth:`generate_stream`.
    This base class supplies a default :meth:`generate_chat` that flattens a
    message list into a single prompt, plus simple request/token counters
    exposed through :meth:`get_stats`.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize LLM provider.

        Args:
            model: Model name
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific arguments, kept on ``self.kwargs``
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.kwargs = kwargs
        self._total_tokens = 0
        self._request_count = 0

    @abstractmethod
    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> "LLMResponse":
        """Generate completion for prompt.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            **kwargs: Additional generation parameters

        Returns:
            LLM response
        """
        pass

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Generate completion with streaming.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            **kwargs: Additional generation parameters

        Yields:
            Content chunks
        """
        pass

    async def generate_chat(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> "LLMResponse":
        """Generate a completion for a list of chat messages.

        The default implementation flattens the conversation into a single
        prompt of ``"User: ..."`` / ``"Assistant: ..."`` lines joined by
        newlines and delegates to :meth:`generate`. Providers with a native
        chat API may override this.

        All system messages are collected, in order, into the system prompt
        (joined by a blank line) instead of the last one silently replacing
        the earlier ones. Messages with any other role are ignored.

        Args:
            messages: Message dicts with 'role' and 'content' keys. A missing
                role defaults to "user"; missing or None content is treated as
                an empty string.
            **kwargs: Additional generation parameters forwarded to generate()

        Returns:
            LLM response
        """
        prompt_parts: List[str] = []
        system_parts: List[str] = []

        for msg in messages:
            role = msg.get("role", "user")
            # `or ""` also maps an explicit None content to an empty string,
            # so the prompt never contains the literal text "None".
            content = msg.get("content") or ""

            if role == "system":
                system_parts.append(content)
            elif role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")

        system_prompt = "\n\n".join(system_parts) if system_parts else None
        prompt = "\n".join(prompt_parts)
        return await self.generate(prompt, system_prompt, **kwargs)

    def get_stats(self) -> Dict[str, Any]:
        """Get provider statistics.

        Returns:
            Dict with the model name, accumulated token count, request count
            and average tokens per request (0 when no requests were recorded).
        """
        return {
            "model": self.model,
            "total_tokens": self._total_tokens,
            "request_count": self._request_count,
            "avg_tokens_per_request": (
                self._total_tokens / self._request_count if self._request_count > 0 else 0
            ),
        }

    def reset_stats(self) -> None:
        """Reset provider statistics."""
        self._total_tokens = 0
        self._request_count = 0

    def _update_stats(self, usage: Optional[Dict[str, int]]) -> None:
        """Record one request and its token usage.

        ``usage["total_tokens"]`` is used when present. Otherwise the total is
        derived from ``prompt_tokens + completion_tokens`` or, if that is zero,
        ``input_tokens + output_tokens``, so usage dicts that do not carry a
        precomputed total are still counted.

        Args:
            usage: Token usage dictionary (None is treated as empty)
        """
        self._request_count += 1
        usage = usage or {}

        total = usage.get("total_tokens")
        if total is None:
            total = (usage.get("prompt_tokens") or 0) + (usage.get("completion_tokens") or 0)
            if not total:
                total = (usage.get("input_tokens") or 0) + (usage.get("output_tokens") or 0)
        self._total_tokens += total

    def __repr__(self) -> str:
        """String representation."""
        return f"{self.__class__.__name__}(model={self.model})"
|
genxai/llm/factory.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
"""LLM Provider Factory for creating and managing LLM providers."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Dict, Any, Iterable, List
|
|
4
|
+
import os
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from genxai.llm.base import LLMProvider
|
|
8
|
+
from genxai.llm.routing import RoutedLLMProvider
|
|
9
|
+
from genxai.llm.providers.openai import OpenAIProvider
|
|
10
|
+
|
|
11
|
+
# Module-level logger, named after this module via ``__name__``; used by
# LLMProviderFactory to report registrations, creation and fallback attempts.
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LLMProviderFactory:
    """Factory for creating LLM provider instances.

    A model name is resolved to a provider class by, in order: an exact key in
    ``_providers`` (eagerly imported classes), an exact key in
    ``_provider_modules`` (imported on first use), and finally a model-name
    prefix heuristic. Lazily resolved classes are cached in ``_providers``.

    All tables are class attributes, so registrations and cached lookups are
    shared by every user of the factory in the process.
    """

    # Default chain used by create_routed_provider() when the caller gives none.
    _fallback_chain: List[str] = [
        "gpt-4",
        "gpt-4-turbo",
        "gpt-3.5-turbo",
        "claude-3-sonnet",
        "claude-3-haiku",
    ]

    # Eagerly imported provider classes, keyed by provider name or model name.
    _providers: Dict[str, type[LLMProvider]] = {
        # OpenAI
        "openai": OpenAIProvider,
        "gpt-4": OpenAIProvider,
        "gpt-3.5-turbo": OpenAIProvider,
        "gpt-4-turbo": OpenAIProvider,
        "gpt-4o": OpenAIProvider,
    }

    # Lazy-loaded providers (imported on demand), as "package.module.ClassName".
    _provider_modules: Dict[str, str] = {
        "anthropic": "genxai.llm.providers.anthropic.AnthropicProvider",
        "claude-3-opus": "genxai.llm.providers.anthropic.AnthropicProvider",
        "claude-3-sonnet": "genxai.llm.providers.anthropic.AnthropicProvider",
        "claude-3-haiku": "genxai.llm.providers.anthropic.AnthropicProvider",
        "google": "genxai.llm.providers.google.GoogleProvider",
        "gemini-pro": "genxai.llm.providers.google.GoogleProvider",
        "gemini-ultra": "genxai.llm.providers.google.GoogleProvider",
        "cohere": "genxai.llm.providers.cohere.CohereProvider",
        "command": "genxai.llm.providers.cohere.CohereProvider",
        "command-r": "genxai.llm.providers.cohere.CohereProvider",
        "ollama": "genxai.llm.providers.ollama.OllamaProvider",
        "llama3": "genxai.llm.providers.ollama.OllamaProvider",
        "mistral": "genxai.llm.providers.ollama.OllamaProvider",
        "phi3": "genxai.llm.providers.ollama.OllamaProvider",
    }

    # Environment variable holding the API key for each provider class name.
    _api_key_env_vars: Dict[str, str] = {
        "OpenAIProvider": "OPENAI_API_KEY",
        "AnthropicProvider": "ANTHROPIC_API_KEY",
        "GoogleProvider": "GOOGLE_API_KEY",
        "CohereProvider": "COHERE_API_KEY",
        "OllamaProvider": "OLLAMA_API_KEY",
    }

    @classmethod
    def register_provider(cls, name: str, provider_class: type[LLMProvider]) -> None:
        """Register a new LLM provider.

        ``_providers`` is consulted before lazy modules and prefix rules, so a
        registration takes precedence over the built-in mapping for ``name``.

        Args:
            name: Provider name or model name
            provider_class: Provider class
        """
        cls._providers[name] = provider_class
        logger.info("Registered LLM provider: %s", name)

    @classmethod
    def create_provider(
        cls,
        model: str,
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        fallback_models: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> LLMProvider:
        """Create an LLM provider instance.

        When ``fallback_models`` is given and the primary provider constructs
        successfully, the result is a ``RoutedLLMProvider`` wrapping the primary
        and every fallback that could be constructed. If the primary fails to
        construct, each fallback model is tried in order and the first one that
        constructs is returned on its own.

        Args:
            model: Model name (e.g., "gpt-4", "claude-3-opus")
            api_key: API key for the provider. When None, the key is read from
                the provider's environment variable; fallback models then also
                resolve their own keys from their own environment variables.
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            fallback_models: List of fallback models if primary fails
            **kwargs: Additional provider-specific parameters

        Returns:
            LLM provider instance

        Raises:
            ValueError: If no provider matches ``model``, or if the provider
                (and every fallback) fails to initialize. A missing API key
                only logs a warning.
        """
        # Determine provider from model name
        provider_class = cls._get_provider_class(model)

        if not provider_class:
            known = list(cls._providers) + [
                name for name in cls._provider_modules if name not in cls._providers
            ]
            raise ValueError(
                f"No provider found for model '{model}'. "
                f"Available providers: {known}"
            )

        # Keep the caller-supplied key separate from the env-resolved one: the
        # env key belongs to *this* provider (e.g. OPENAI_API_KEY) and must not
        # be handed to fallback models that may use a different provider.
        explicit_api_key = api_key
        if api_key is None:
            api_key = cls._get_api_key_for_provider(provider_class)

        if not api_key:
            logger.warning(
                "No API key provided for %s. Provider may fail at runtime.",
                provider_class.__name__,
            )

        # Create provider instance
        try:
            provider = provider_class(
                model=model,
                api_key=api_key,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
            logger.info(
                "Created LLM provider: %s with model %s", provider_class.__name__, model
            )
            if fallback_models:
                return RoutedLLMProvider(
                    primary=provider,
                    fallbacks=cls._create_fallback_providers(
                        fallback_models,
                        api_key=explicit_api_key,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        **kwargs,
                    ),
                )
            return provider

        except Exception as e:
            logger.error("Failed to create provider for model '%s': %s", model, e)

            # Try fallback models if provided
            if fallback_models:
                logger.info("Attempting fallback models: %s", fallback_models)
                for fallback_model in fallback_models:
                    try:
                        return cls.create_provider(
                            model=fallback_model,
                            api_key=explicit_api_key,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            **kwargs,
                        )
                    except Exception as fallback_error:
                        logger.warning(
                            "Fallback model '%s' failed: %s", fallback_model, fallback_error
                        )

            raise ValueError(f"Failed to create provider for model '{model}': {e}") from e

    @classmethod
    def create_routed_provider(
        cls,
        primary_model: str,
        fallback_models: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LLMProvider:
        """Create a routed provider with a fallback chain.

        Args:
            primary_model: Primary model name
            fallback_models: Optional override for fallback chain; None or an
                empty list selects the default chain (``_fallback_chain``)
            **kwargs: Provider options forwarded to create_provider

        Returns:
            RoutedLLMProvider instance
        """
        fallback_chain = fallback_models or cls._fallback_chain
        return cls.create_provider(
            model=primary_model,
            fallback_models=fallback_chain,
            **kwargs,
        )

    @classmethod
    def set_default_fallback_chain(cls, models: Iterable[str]) -> None:
        """Set the default fallback model chain (copied into a new list)."""
        cls._fallback_chain = list(models)

    @classmethod
    def _create_fallback_providers(
        cls,
        fallback_models: Iterable[str],
        api_key: Optional[str],
        temperature: float,
        max_tokens: Optional[int],
        **kwargs: Any,
    ) -> List[LLMProvider]:
        """Instantiate providers for the fallback models.

        Models that cannot be resolved or constructed are logged and skipped,
        so the returned list may be shorter than ``fallback_models``.

        Args:
            fallback_models: Fallback model names, in priority order
            api_key: Key explicitly supplied by the caller, or None to resolve
                each fallback's key from its own environment variable
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            Successfully constructed fallback providers, in input order
        """
        providers: List[LLMProvider] = []
        for fallback_model in fallback_models:
            try:
                provider_class = cls._get_provider_class(fallback_model)
                if not provider_class:
                    logger.warning("No provider class for fallback model %s", fallback_model)
                    continue
                fallback_key = api_key or cls._get_api_key_for_provider(provider_class)
                providers.append(
                    provider_class(
                        model=fallback_model,
                        api_key=fallback_key,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        **kwargs,
                    )
                )
            except Exception as exc:
                logger.warning("Failed to initialize fallback model %s: %s", fallback_model, exc)
        return providers

    @classmethod
    def _load_provider_class(cls, module_path: str) -> Optional[type[LLMProvider]]:
        """Dynamically load a provider class.

        Args:
            module_path: Full module path (e.g., 'genxai.llm.providers.anthropic.AnthropicProvider')

        Returns:
            Provider class or None
        """
        import importlib

        try:
            module_name, class_name = module_path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            return getattr(module, class_name)
        except Exception as e:
            # Deliberately best-effort: any import or lookup failure is logged
            # and reported as None, which callers treat as "unsupported model".
            logger.error("Failed to load provider from %s: %s", module_path, e)
            return None

    @classmethod
    def _get_provider_class(cls, model: str) -> Optional[type[LLMProvider]]:
        """Get provider class for a model.

        Side effects: may import a provider module, and caches any lazily
        resolved class in ``_providers`` under ``model``.

        Args:
            model: Model name

        Returns:
            Provider class or None
        """
        # Direct match in pre-loaded providers
        if model in cls._providers:
            return cls._providers[model]

        # Check lazy-loaded providers
        if model in cls._provider_modules:
            provider_class = cls._load_provider_class(cls._provider_modules[model])
            if provider_class:
                # Cache it for future use
                cls._providers[model] = provider_class
            return provider_class

        # Fall back to well-known model-name prefixes
        model_lower = model.lower()
        if model_lower.startswith("gpt"):
            return OpenAIProvider

        if model_lower.startswith("claude"):
            module_path = "genxai.llm.providers.anthropic.AnthropicProvider"
        elif model_lower.startswith("gemini"):
            module_path = "genxai.llm.providers.google.GoogleProvider"
        elif model_lower.startswith("command"):
            module_path = "genxai.llm.providers.cohere.CohereProvider"
        elif model_lower.startswith(("llama", "mistral", "phi")):
            module_path = "genxai.llm.providers.ollama.OllamaProvider"
        else:
            return None

        provider_class = cls._load_provider_class(module_path)
        if provider_class:
            cls._providers[model] = provider_class
        return provider_class

    @classmethod
    def _get_api_key_for_provider(cls, provider_class: type[LLMProvider]) -> Optional[str]:
        """Get API key from environment for a provider.

        Args:
            provider_class: Provider class (matched by class name)

        Returns:
            API key, or None if the class is unknown or the variable is unset
        """
        env_var = cls._api_key_env_vars.get(provider_class.__name__)
        return os.getenv(env_var) if env_var else None

    @classmethod
    def list_available_providers(cls) -> list[str]:
        """List provider/model names currently registered in ``_providers``.

        Lazily loaded entries appear here only after they have been resolved.

        Returns:
            List of provider names
        """
        return list(cls._providers.keys())

    @classmethod
    def list_providers(cls) -> list[str]:
        """List the canonical provider identifiers.

        The unit tests expect these high-level names (not model aliases).
        """
        return ["openai", "anthropic", "google", "cohere", "ollama"]

    @classmethod
    def supports_model(cls, model: str) -> bool:
        """Check if a model is supported.

        May import (and cache) the provider module as a side effect.

        Args:
            model: Model name

        Returns:
            True if supported
        """
        return cls._get_provider_class(model) is not None
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""LLM provider implementations."""
|