genxai-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Trigger-driven workflow execution utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from genxai.core.graph.executor import WorkflowExecutor
|
|
9
|
+
from genxai.triggers.base import TriggerEvent
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TriggerWorkflowRunner:
    """Connect incoming trigger events to workflow execution.

    Holds a node/edge graph definition plus a ``WorkflowExecutor``; every
    trigger event's payload becomes the input data for one workflow run.
    """

    def __init__(
        self,
        nodes: List[Dict[str, Any]],
        edges: List[Dict[str, Any]],
        openai_api_key: Optional[str] = None,
        anthropic_api_key: Optional[str] = None,
    ) -> None:
        self.nodes = nodes
        self.edges = edges
        # The executor carries the provider credentials; the graph itself is
        # supplied per run in handle_event().
        self.executor = WorkflowExecutor(
            openai_api_key=openai_api_key,
            anthropic_api_key=anthropic_api_key,
        )

    async def handle_event(self, event: TriggerEvent) -> Dict[str, Any]:
        """Execute the workflow using the trigger event payload as input."""
        logger.info("Trigger event received: %s", event.trigger_id)
        # An absent/empty payload still yields a valid (empty) input dict.
        payload = event.payload if event.payload else {}
        return await self.executor.execute(
            nodes=self.nodes,
            edges=self.edges,
            input_data=payload,
        )
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Memory system for GenXAI agents."""
|
|
2
|
+
|
|
3
|
+
from genxai.core.memory.base import Memory, MemoryType, MemoryConfig
|
|
4
|
+
from genxai.core.memory.short_term import ShortTermMemory
|
|
5
|
+
from genxai.core.memory.shared import SharedMemoryBus
|
|
6
|
+
from genxai.core.memory.long_term import LongTermMemory
|
|
7
|
+
from genxai.core.memory.manager import MemorySystem
|
|
8
|
+
from genxai.core.memory.persistence import MemoryPersistenceConfig, JsonMemoryStore
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"Memory",
|
|
12
|
+
"MemoryType",
|
|
13
|
+
"MemoryConfig",
|
|
14
|
+
"ShortTermMemory",
|
|
15
|
+
"LongTermMemory",
|
|
16
|
+
"MemorySystem",
|
|
17
|
+
"MemoryPersistenceConfig",
|
|
18
|
+
"JsonMemoryStore",
|
|
19
|
+
]
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""Base memory classes and types."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from enum import Enum
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MemoryType(str, Enum):
    """The kinds of memory supported by the framework.

    Subclasses ``str`` so members compare and serialize as their plain
    string values.
    """

    SHORT_TERM = "short_term"
    LONG_TERM = "long_term"
    EPISODIC = "episodic"
    SEMANTIC = "semantic"
    PROCEDURAL = "procedural"
    WORKING = "working"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Memory(BaseModel):
    """A single memory record handled by the memory subsystem.

    ``content`` may hold arbitrary Python objects, hence
    ``arbitrary_types_allowed`` in the model config.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str  # unique identifier for this record
    type: MemoryType  # which MemoryType category this record belongs to
    content: Any  # the remembered payload; any type is permitted
    metadata: Dict[str, Any] = Field(default_factory=dict)  # free-form extra data
    timestamp: datetime  # creation time; required, supplied by the caller
    importance: float = Field(default=0.5, ge=0.0, le=1.0)  # validated to [0.0, 1.0]
    access_count: int = 0  # retrieval counter; maintained by callers, not updated here
    # Default to "now" so callers don't have to provide it explicitly.
    last_accessed: datetime = Field(default_factory=datetime.now)
    embedding: Optional[List[float]] = None  # optional vector representation
    tags: List[str] = Field(default_factory=list)  # labels for filtering/search


    def __repr__(self) -> str:
        """Return a short debug representation (id, type, importance)."""
        return f"Memory(id={self.id}, type={self.type}, importance={self.importance})"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MemoryConfig(BaseModel):
    """Configuration for the memory system.

    Groups capacity limits, per-store feature flags, storage backend
    selection, consolidation policy, and embedding settings.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Capacity settings (minimum 1, enforced by Field validators)
    short_term_capacity: int = Field(default=20, ge=1)
    working_capacity: int = Field(default=5, ge=1)

    # Feature flags — toggle individual memory stores on/off
    long_term_enabled: bool = True
    episodic_enabled: bool = True
    semantic_enabled: bool = True
    procedural_enabled: bool = True

    # Storage backends — backend identifiers consumed elsewhere in the
    # package (presumably by the store implementations; verify at call site)
    vector_db: Optional[str] = "pinecone"
    graph_db: Optional[str] = "neo4j"
    cache_db: Optional[str] = "redis"

    # Consolidation settings
    consolidation_enabled: bool = True
    # Cron-style expression — TODO confirm which scheduler consumes this
    consolidation_schedule: str = "0 2 * * *"  # Daily at 2 AM
    importance_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
    retention_days: int = Field(default=365, ge=1)

    # Embedding settings — defaults match OpenAI's ada-002 (1536 dims)
    embedding_model: str = "text-embedding-ada-002"
    embedding_dimension: int = 1536
|
|
72
|
+
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
"""Embedding service for memory vectorization."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional, Union
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EmbeddingService(ABC):
    """Abstract interface for text-embedding backends.

    Concrete services turn text into fixed-length float vectors; a single
    call accepts either one string or a batch of strings.
    """

    @abstractmethod
    async def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """Embed *text*.

        Args:
            text: A single string, or a list of strings for batching.

        Returns:
            One embedding vector for a single string, or a list of
            vectors for a batch.
        """

    @abstractmethod
    def get_dimension(self) -> int:
        """Return the length of the vectors produced by ``embed``."""
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OpenAIEmbeddingService(EmbeddingService):
    """OpenAI embedding service."""

    def __init__(
        self,
        model: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
    ) -> None:
        """Initialize OpenAI embedding service.

        Args:
            model: OpenAI embedding model name.
            api_key: OpenAI API key (or set OPENAI_API_KEY env var).
        """
        self.model = model
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self._client = None
        self._initialized = False

        # Known output dimensions per model; get_dimension() falls back
        # to 1536 for models not listed here.
        self._dimensions = {
            "text-embedding-ada-002": 1536,
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
        }

    async def _ensure_initialized(self) -> None:
        """Lazily create the AsyncOpenAI client on first use.

        Raises:
            RuntimeError: If the ``openai`` package is not installed.
            ValueError: If no API key was provided or found in the env.
        """
        if self._initialized:
            return

        try:
            from openai import AsyncOpenAI

            if not self.api_key:
                raise ValueError(
                    "OpenAI API key required. Set OPENAI_API_KEY env var."
                )

            self._client = AsyncOpenAI(api_key=self.api_key)
            self._initialized = True
            logger.info("Initialized OpenAI embedding service: %s", self.model)

        except ImportError as exc:
            logger.error(
                "OpenAI not installed. Install with: pip install openai"
            )
            # Chain the ImportError so the root cause stays visible.
            raise RuntimeError("OpenAI not available") from exc
        except Exception as e:
            logger.error("Failed to initialize OpenAI: %s", e)
            raise

    async def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """Generate embeddings for *text*.

        Args:
            text: A single string or a list of strings.

        Returns:
            A single embedding vector, or a list of vectors for a batch.
        """
        await self._ensure_initialized()

        try:
            # Normalize to a batch; remember whether to unwrap the result.
            is_single = isinstance(text, str)
            texts = [text] if is_single else text

            response = await self._client.embeddings.create(
                model=self.model,
                input=texts,
            )

            embeddings = [item.embedding for item in response.data]
            return embeddings[0] if is_single else embeddings

        except Exception as e:
            logger.error("Failed to generate embeddings: %s", e)
            raise

    def get_dimension(self) -> int:
        """Return the embedding dimension for the configured model (default 1536)."""
        return self._dimensions.get(self.model, 1536)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class LocalEmbeddingService(EmbeddingService):
    """Local embedding service using sentence-transformers."""

    def __init__(
        self,
        model: str = "all-MiniLM-L6-v2",
        device: Optional[str] = None,
    ) -> None:
        """Initialize local embedding service.

        Args:
            model: Sentence-transformers model name.
            device: Device to use ('cpu', 'cuda', or None for auto).
        """
        self.model_name = model
        self.device = device
        self._model = None
        self._initialized = False

        # Known dimensions for common models; get_dimension() prefers the
        # loaded model's reported size and falls back to 384 otherwise.
        self._dimensions = {
            "all-MiniLM-L6-v2": 384,
            "all-mpnet-base-v2": 768,
            "all-MiniLM-L12-v2": 384,
        }

    async def _ensure_initialized(self) -> None:
        """Lazily load the sentence-transformers model on first use.

        Raises:
            RuntimeError: If sentence-transformers is not installed.
        """
        if self._initialized:
            return

        try:
            from sentence_transformers import SentenceTransformer

            self._model = SentenceTransformer(self.model_name, device=self.device)
            self._initialized = True
            logger.info("Initialized local embedding model: %s", self.model_name)

        except ImportError as exc:
            logger.error(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )
            # Chain the ImportError so the root cause stays visible.
            raise RuntimeError("sentence-transformers not available") from exc
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise

    async def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """Generate embeddings for *text*.

        NOTE(review): ``SentenceTransformer.encode`` is a synchronous call,
        so large batches block the event loop; consider an executor if that
        becomes a problem.
        """
        await self._ensure_initialized()

        try:
            # Normalize to a batch; remember whether to unwrap the result.
            is_single = isinstance(text, str)
            texts = [text] if is_single else text

            embeddings = self._model.encode(texts, convert_to_numpy=True)
            embeddings = embeddings.tolist()
            return embeddings[0] if is_single else embeddings

        except Exception as e:
            logger.error("Failed to generate embeddings: %s", e)
            raise

    def get_dimension(self) -> int:
        """Return the embedding dimension (model-reported when loaded)."""
        if self._initialized and self._model:
            return self._model.get_sentence_embedding_dimension()
        return self._dimensions.get(self.model_name, 384)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class CohereEmbeddingService(EmbeddingService):
    """Cohere embedding service."""

    def __init__(
        self,
        model: str = "embed-english-v3.0",
        api_key: Optional[str] = None,
        input_type: str = "search_document",
    ) -> None:
        """Initialize Cohere embedding service.

        Args:
            model: Cohere embedding model name.
            api_key: Cohere API key (or set COHERE_API_KEY env var).
            input_type: Input type ('search_document', 'search_query',
                'classification', 'clustering').
        """
        self.model = model
        self.api_key = api_key or os.getenv("COHERE_API_KEY")
        self.input_type = input_type
        self._client = None
        self._initialized = False

        # Known output dimensions per model; get_dimension() falls back
        # to 1024 for models not listed here.
        self._dimensions = {
            "embed-english-v3.0": 1024,
            "embed-english-light-v3.0": 384,
            "embed-multilingual-v3.0": 1024,
        }

    async def _ensure_initialized(self) -> None:
        """Lazily create the Cohere async client on first use.

        Raises:
            RuntimeError: If the ``cohere`` package is not installed.
            ValueError: If no API key was provided or found in the env.
        """
        if self._initialized:
            return

        try:
            import cohere

            if not self.api_key:
                raise ValueError(
                    "Cohere API key required. Set COHERE_API_KEY env var."
                )

            self._client = cohere.AsyncClient(api_key=self.api_key)
            self._initialized = True
            logger.info("Initialized Cohere embedding service: %s", self.model)

        except ImportError as exc:
            logger.error(
                "Cohere not installed. Install with: pip install cohere"
            )
            # Chain the ImportError so the root cause stays visible.
            raise RuntimeError("Cohere not available") from exc
        except Exception as e:
            logger.error("Failed to initialize Cohere: %s", e)
            raise

    async def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """Generate embeddings for *text*.

        Args:
            text: A single string or a list of strings.

        Returns:
            A single embedding vector, or a list of vectors for a batch.
        """
        await self._ensure_initialized()

        try:
            # Normalize to a batch; remember whether to unwrap the result.
            is_single = isinstance(text, str)
            texts = [text] if is_single else text

            response = await self._client.embed(
                texts=texts,
                model=self.model,
                input_type=self.input_type,
            )

            embeddings = response.embeddings
            return embeddings[0] if is_single else embeddings

        except Exception as e:
            logger.error("Failed to generate embeddings: %s", e)
            raise

    def get_dimension(self) -> int:
        """Return the embedding dimension for the configured model (default 1024)."""
        return self._dimensions.get(self.model, 1024)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class EmbeddingServiceFactory:
    """Factory for creating embedding services."""

    # Built-in providers; register() can extend this at runtime.
    _services = {
        "openai": OpenAIEmbeddingService,
        "local": LocalEmbeddingService,
        "cohere": CohereEmbeddingService,
    }

    @classmethod
    def create(cls, provider: str, **kwargs) -> EmbeddingService:
        """Create an embedding service instance.

        Args:
            provider: Embedding provider ("openai", "local", "cohere")
            **kwargs: Provider-specific arguments

        Returns:
            EmbeddingService instance

        Raises:
            ValueError: If provider is not supported
        """
        if provider in cls._services:
            return cls._services[provider](**kwargs)
        raise ValueError(
            f"Unsupported embedding provider: {provider}. "
            f"Supported: {list(cls._services.keys())}"
        )

    @classmethod
    def register(cls, name: str, service_class: type) -> None:
        """Register a custom embedding service under *name*.

        Args:
            name: Name of the service
            service_class: Embedding service class
        """
        cls._services[name] = service_class
        logger.info(f"Registered embedding service: {name}")

    @classmethod
    def list_providers(cls) -> List[str]:
        """List available embedding providers."""
        return [*cls._services]
|