agno 2.0.0rc2__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +6009 -2874
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +385 -6
- agno/db/dynamo/dynamo.py +388 -81
- agno/db/dynamo/schemas.py +47 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +435 -64
- agno/db/firestore/schemas.py +11 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +384 -42
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +351 -66
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +339 -48
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +510 -37
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2036 -0
- agno/db/mongo/mongo.py +653 -76
- agno/db/mongo/schemas.py +13 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/mysql.py +687 -25
- agno/db/mysql/schemas.py +61 -37
- agno/db/mysql/utils.py +60 -2
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2001 -0
- agno/db/postgres/postgres.py +676 -57
- agno/db/postgres/schemas.py +43 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +344 -38
- agno/db/redis/schemas.py +18 -0
- agno/db/redis/utils.py +60 -2
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/memory.py +13 -0
- agno/db/singlestore/schemas.py +26 -1
- agno/db/singlestore/singlestore.py +687 -53
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2371 -0
- agno/db/sqlite/schemas.py +24 -0
- agno/db/sqlite/sqlite.py +774 -85
- agno/db/sqlite/utils.py +168 -5
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +309 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1361 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +50 -22
- agno/eval/accuracy.py +50 -43
- agno/eval/performance.py +6 -3
- agno/eval/reliability.py +6 -3
- agno/eval/utils.py +33 -16
- agno/exceptions.py +68 -1
- agno/filters.py +354 -0
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +52 -0
- agno/integrations/discord/client.py +1 -0
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +1 -1
- agno/knowledge/chunking/semantic.py +40 -8
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +13 -0
- agno/knowledge/embedder/openai.py +37 -65
- agno/knowledge/embedder/sentence_transformer.py +8 -4
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +595 -187
- agno/knowledge/reader/base.py +9 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +290 -0
- agno/knowledge/reader/json_reader.py +6 -5
- agno/knowledge/reader/markdown_reader.py +13 -13
- agno/knowledge/reader/pdf_reader.py +43 -68
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +51 -6
- agno/knowledge/reader/s3_reader.py +3 -15
- agno/knowledge/reader/tavily_reader.py +194 -0
- agno/knowledge/reader/text_reader.py +13 -13
- agno/knowledge/reader/web_search_reader.py +2 -43
- agno/knowledge/reader/website_reader.py +43 -25
- agno/knowledge/reranker/__init__.py +3 -0
- agno/knowledge/types.py +9 -0
- agno/knowledge/utils.py +20 -0
- agno/media.py +339 -266
- agno/memory/manager.py +336 -82
- agno/models/aimlapi/aimlapi.py +2 -2
- agno/models/anthropic/claude.py +183 -37
- agno/models/aws/bedrock.py +52 -112
- agno/models/aws/claude.py +33 -1
- agno/models/azure/ai_foundry.py +33 -15
- agno/models/azure/openai_chat.py +25 -8
- agno/models/base.py +1011 -566
- agno/models/cerebras/cerebras.py +19 -13
- agno/models/cerebras/cerebras_openai.py +8 -5
- agno/models/cohere/chat.py +27 -1
- agno/models/cometapi/__init__.py +5 -0
- agno/models/cometapi/cometapi.py +57 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/deepinfra/deepinfra.py +2 -2
- agno/models/deepseek/deepseek.py +2 -2
- agno/models/fireworks/fireworks.py +2 -2
- agno/models/google/gemini.py +110 -37
- agno/models/groq/groq.py +28 -11
- agno/models/huggingface/huggingface.py +2 -1
- agno/models/internlm/internlm.py +2 -2
- agno/models/langdb/langdb.py +4 -4
- agno/models/litellm/chat.py +18 -1
- agno/models/litellm/litellm_openai.py +2 -2
- agno/models/llama_cpp/__init__.py +5 -0
- agno/models/llama_cpp/llama_cpp.py +22 -0
- agno/models/message.py +143 -4
- agno/models/meta/llama.py +27 -10
- agno/models/meta/llama_openai.py +5 -17
- agno/models/nebius/nebius.py +6 -6
- agno/models/nexus/__init__.py +3 -0
- agno/models/nexus/nexus.py +22 -0
- agno/models/nvidia/nvidia.py +2 -2
- agno/models/ollama/chat.py +60 -6
- agno/models/openai/chat.py +102 -43
- agno/models/openai/responses.py +103 -106
- agno/models/openrouter/openrouter.py +41 -3
- agno/models/perplexity/perplexity.py +4 -5
- agno/models/portkey/portkey.py +3 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +52 -0
- agno/models/response.py +81 -5
- agno/models/sambanova/sambanova.py +2 -2
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/models/together/together.py +2 -2
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +2 -2
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +96 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +3 -2
- agno/os/app.py +543 -175
- agno/os/auth.py +24 -14
- agno/os/config.py +1 -0
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +250 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/agui.py +23 -7
- agno/os/interfaces/agui/router.py +27 -3
- agno/os/interfaces/agui/utils.py +242 -142
- agno/os/interfaces/base.py +6 -2
- agno/os/interfaces/slack/router.py +81 -23
- agno/os/interfaces/slack/slack.py +29 -14
- agno/os/interfaces/whatsapp/router.py +11 -4
- agno/os/interfaces/whatsapp/whatsapp.py +14 -7
- agno/os/mcp.py +111 -54
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +556 -139
- agno/os/routers/evals/evals.py +71 -34
- agno/os/routers/evals/schemas.py +31 -31
- agno/os/routers/evals/utils.py +6 -5
- agno/os/routers/health.py +31 -0
- agno/os/routers/home.py +52 -0
- agno/os/routers/knowledge/knowledge.py +185 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +158 -53
- agno/os/routers/memory/schemas.py +20 -16
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +499 -38
- agno/os/schema.py +308 -198
- agno/os/utils.py +401 -41
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/azure_ai_foundry.py +2 -2
- agno/reasoning/deepseek.py +2 -2
- agno/reasoning/default.py +3 -1
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/groq.py +2 -2
- agno/reasoning/ollama.py +2 -2
- agno/reasoning/openai.py +7 -2
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +266 -112
- agno/run/base.py +53 -24
- agno/run/team.py +252 -111
- agno/run/workflow.py +156 -45
- agno/session/agent.py +105 -89
- agno/session/summary.py +65 -25
- agno/session/team.py +176 -96
- agno/session/workflow.py +406 -40
- agno/team/team.py +3854 -1692
- agno/tools/brightdata.py +3 -3
- agno/tools/cartesia.py +3 -5
- agno/tools/dalle.py +9 -8
- agno/tools/decorator.py +4 -2
- agno/tools/desi_vocal.py +2 -2
- agno/tools/duckduckgo.py +15 -11
- agno/tools/e2b.py +20 -13
- agno/tools/eleven_labs.py +26 -28
- agno/tools/exa.py +21 -16
- agno/tools/fal.py +4 -4
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +350 -0
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +257 -37
- agno/tools/giphy.py +2 -2
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +270 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/knowledge.py +3 -3
- agno/tools/lumalab.py +3 -3
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +284 -0
- agno/tools/mem0.py +11 -17
- agno/tools/memori.py +1 -53
- agno/tools/memory.py +419 -0
- agno/tools/models/azure_openai.py +2 -2
- agno/tools/models/gemini.py +3 -3
- agno/tools/models/groq.py +3 -5
- agno/tools/models/nebius.py +7 -7
- agno/tools/models_labs.py +25 -15
- agno/tools/notion.py +204 -0
- agno/tools/openai.py +4 -9
- agno/tools/opencv.py +3 -3
- agno/tools/parallel.py +314 -0
- agno/tools/replicate.py +7 -7
- agno/tools/scrapegraph.py +58 -31
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/slack.py +18 -3
- agno/tools/spider.py +2 -2
- agno/tools/tavily.py +146 -0
- agno/tools/whatsapp.py +1 -1
- agno/tools/workflow.py +278 -0
- agno/tools/yfinance.py +12 -11
- agno/utils/agent.py +820 -0
- agno/utils/audio.py +27 -0
- agno/utils/common.py +90 -1
- agno/utils/events.py +222 -7
- agno/utils/gemini.py +181 -23
- agno/utils/hooks.py +57 -0
- agno/utils/http.py +111 -0
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +95 -5
- agno/utils/media.py +188 -10
- agno/utils/merge_dict.py +22 -1
- agno/utils/message.py +60 -0
- agno/utils/models/claude.py +40 -11
- agno/utils/models/cohere.py +1 -1
- agno/utils/models/watsonx.py +1 -1
- agno/utils/openai.py +1 -1
- agno/utils/print_response/agent.py +105 -21
- agno/utils/print_response/team.py +103 -38
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/reasoning.py +22 -1
- agno/utils/serialize.py +32 -0
- agno/utils/streamlit.py +16 -10
- agno/utils/string.py +41 -0
- agno/utils/team.py +98 -9
- agno/utils/tools.py +1 -1
- agno/vectordb/base.py +23 -4
- agno/vectordb/cassandra/cassandra.py +65 -9
- agno/vectordb/chroma/chromadb.py +182 -38
- agno/vectordb/clickhouse/clickhousedb.py +64 -11
- agno/vectordb/couchbase/couchbase.py +105 -10
- agno/vectordb/lancedb/lance_db.py +183 -135
- agno/vectordb/langchaindb/langchaindb.py +25 -7
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/llamaindex/llamaindexdb.py +46 -7
- agno/vectordb/milvus/milvus.py +126 -9
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +112 -7
- agno/vectordb/pgvector/pgvector.py +142 -21
- agno/vectordb/pineconedb/pineconedb.py +80 -8
- agno/vectordb/qdrant/qdrant.py +125 -39
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +694 -0
- agno/vectordb/singlestore/singlestore.py +111 -25
- agno/vectordb/surrealdb/surrealdb.py +31 -5
- agno/vectordb/upstashdb/upstashdb.py +76 -8
- agno/vectordb/weaviate/weaviate.py +86 -15
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +112 -18
- agno/workflow/loop.py +69 -10
- agno/workflow/parallel.py +266 -118
- agno/workflow/router.py +110 -17
- agno/workflow/step.py +645 -136
- agno/workflow/steps.py +65 -6
- agno/workflow/types.py +71 -33
- agno/workflow/workflow.py +2113 -300
- agno-2.3.0.dist-info/METADATA +618 -0
- agno-2.3.0.dist-info/RECORD +577 -0
- agno-2.3.0.dist-info/licenses/LICENSE +201 -0
- agno/knowledge/reader/url_reader.py +0 -128
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -610
- agno/utils/models/aws_claude.py +0 -170
- agno-2.0.0rc2.dist-info/METADATA +0 -355
- agno-2.0.0rc2.dist-info/RECORD +0 -515
- agno-2.0.0rc2.dist-info/licenses/LICENSE +0 -375
- {agno-2.0.0rc2.dist-info → agno-2.3.0.dist-info}/WHEEL +0 -0
- {agno-2.0.0rc2.dist-info → agno-2.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from os import getenv
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
from agno.knowledge.embedder.base import Embedder
|
|
7
|
+
from agno.utils.log import logger
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from vllm import LLM # type: ignore
|
|
11
|
+
from vllm.outputs import EmbeddingRequestOutput # type: ignore
|
|
12
|
+
except ImportError:
|
|
13
|
+
raise ImportError("`vllm` not installed. Please install using `pip install vllm`.")
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from openai import AsyncOpenAI
|
|
17
|
+
from openai import OpenAI as OpenAIClient
|
|
18
|
+
from openai.types.create_embedding_response import CreateEmbeddingResponse
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class VLLMEmbedder(Embedder):
    """
    VLLM Embedder supporting both local and remote deployment modes.

    Local Mode (default):
    - Loads model locally and runs inference on your GPU/CPU
    - No API key required
    - Example: VLLMEmbedder(id="intfloat/e5-mistral-7b-instruct")

    Remote Mode:
    - Connects to a remote vLLM server via OpenAI-compatible API
    - Uses OpenAI SDK to communicate with vLLM's OpenAI-compatible endpoint
    - Requires base_url and optionally api_key
    - Example: VLLMEmbedder(base_url="http://localhost:8000/v1", api_key="your-key")
    - Ref: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
    """

    id: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Expected embedding width; a mismatch only logs a warning, it does not raise.
    dimensions: int = 4096
    # Local mode parameters
    enforce_eager: bool = True
    vllm_kwargs: Optional[Dict[str, Any]] = None
    vllm_client: Optional[LLM] = None
    # Remote mode parameters
    api_key: Optional[str] = getenv("VLLM_API_KEY")
    base_url: Optional[str] = None
    request_params: Optional[Dict[str, Any]] = None
    client_params: Optional[Dict[str, Any]] = None
    remote_client: Optional["OpenAIClient"] = None  # OpenAI-compatible client for vLLM server
    async_remote_client: Optional["AsyncOpenAI"] = None  # Async OpenAI-compatible client for vLLM server

    @property
    def is_remote(self) -> bool:
        """Remote mode is active when a base_url has been provided."""
        return self.base_url is not None

    def _build_remote_request(self, input_data: Any) -> Dict[str, Any]:
        """Build kwargs for the OpenAI-compatible `embeddings.create` call.

        Shared by the sync, async, and batch paths so `request_params` overrides
        are applied consistently in one place.
        """
        request: Dict[str, Any] = {
            "input": input_data,
            "model": self.id,
        }
        if self.request_params:
            request.update(self.request_params)
        return request

    def _get_vllm_client(self) -> LLM:
        """Lazily create (and cache on the instance) the local VLLM client."""
        if self.vllm_client:
            return self.vllm_client

        _vllm_params: Dict[str, Any] = {
            "model": self.id,
            "task": "embed",
            "enforce_eager": self.enforce_eager,
        }
        if self.vllm_kwargs:
            _vllm_params.update(self.vllm_kwargs)
        self.vllm_client = LLM(**_vllm_params)
        return self.vllm_client

    def _get_remote_client(self) -> "OpenAIClient":
        """Lazily create (and cache) the OpenAI-compatible client for a remote vLLM server."""
        if self.remote_client:
            return self.remote_client

        try:
            from openai import OpenAI as OpenAIClient
        except ImportError:
            raise ImportError("`openai` package required for remote vLLM mode. ")

        _client_params: Dict[str, Any] = {
            "api_key": self.api_key or "EMPTY",  # VLLM can run without API key
            "base_url": self.base_url,
        }
        if self.client_params:
            _client_params.update(self.client_params)
        self.remote_client = OpenAIClient(**_client_params)
        return self.remote_client

    def _get_async_remote_client(self) -> "AsyncOpenAI":
        """Lazily create (and cache) the async OpenAI-compatible client for a remote vLLM server."""
        if self.async_remote_client:
            return self.async_remote_client

        try:
            from openai import AsyncOpenAI
        except ImportError:
            raise ImportError("`openai` package required for remote vLLM mode. ")

        _client_params: Dict[str, Any] = {
            "api_key": self.api_key or "EMPTY",
            "base_url": self.base_url,
        }
        if self.client_params:
            _client_params.update(self.client_params)
        self.async_remote_client = AsyncOpenAI(**_client_params)
        return self.async_remote_client

    def _create_embedding_local(self, text: str) -> Optional[EmbeddingRequestOutput]:
        """Create an embedding with the local VLLM engine; returns None on failure."""
        try:
            outputs = self._get_vllm_client().embed([text])
            return outputs[0] if outputs else None
        except Exception as e:
            logger.warning(f"Error creating local embedding: {e}")
            return None

    def _create_embedding_remote(self, text: str) -> "CreateEmbeddingResponse":
        """Create an embedding via the remote vLLM server (OpenAI-compatible API)."""
        return self._get_remote_client().embeddings.create(**self._build_remote_request(text))

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for `text`, or [] on any error."""
        try:
            if self.is_remote:
                # Remote mode: OpenAI-compatible API
                response: "CreateEmbeddingResponse" = self._create_embedding_remote(text=text)
                return response.data[0].embedding
            else:
                # Local mode: Direct VLLM
                output = self._create_embedding_local(text=text)
                if output and hasattr(output, "outputs") and hasattr(output.outputs, "embedding"):
                    embedding = output.outputs.embedding
                    if len(embedding) != self.dimensions:
                        logger.warning(f"Expected embedding dimension {self.dimensions}, but got {len(embedding)}")
                    return embedding
                return []
        except Exception as e:
            logger.warning(f"Error extracting embedding: {e}")
            return []

    def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
        """Return (embedding, usage). Local VLLM provides no usage information."""
        if self.is_remote:
            try:
                response: "CreateEmbeddingResponse" = self._create_embedding_remote(text=text)
                embedding = response.data[0].embedding
                usage = response.usage
                if usage:
                    return embedding, usage.model_dump()
                return embedding, None
            except Exception as e:
                logger.warning(f"Error in remote embedding: {e}")
                return [], None
        else:
            embedding = self.get_embedding(text=text)
            # Local VLLM doesn't provide usage information
            return embedding, None

    async def async_get_embedding(self, text: str) -> List[float]:
        """Async version of get_embedding using thread executor for local mode."""
        if self.is_remote:
            # Remote mode: async client for vLLM server
            try:
                req = self._build_remote_request(text)
                response: "CreateEmbeddingResponse" = await self._get_async_remote_client().embeddings.create(**req)
                return response.data[0].embedding
            except Exception as e:
                logger.warning(f"Error in async remote embedding: {e}")
                return []
        else:
            # Local mode: run the blocking inference in a thread executor.
            # get_running_loop() is the supported API inside a coroutine;
            # get_event_loop() is deprecated here since Python 3.10.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self.get_embedding, text)

    async def async_get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
        """Async version of get_embedding_and_usage using thread executor for local mode."""
        if self.is_remote:
            try:
                req = self._build_remote_request(text)
                response: "CreateEmbeddingResponse" = await self._get_async_remote_client().embeddings.create(**req)
                embedding = response.data[0].embedding
                usage = response.usage
                return embedding, usage.model_dump() if usage else None
            except Exception as e:
                logger.warning(f"Error in async remote embedding: {e}")
                return [], None
        else:
            # Local mode: use thread executor for CPU-bound operations
            try:
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(None, self.get_embedding_and_usage, text)
            except Exception as e:
                logger.warning(f"Error in async local embedding: {e}")
                return [], None

    async def async_get_embeddings_batch_and_usage(
        self, texts: List[str]
    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
        """
        Get embeddings and usage for multiple texts in batches (async version).

        Args:
            texts: List of text strings to embed

        Returns:
            Tuple of (List of embedding vectors, List of usage dictionaries)
        """
        all_embeddings: List[List[float]] = []
        all_usage: List[Optional[Dict]] = []
        # NOTE: batch_size is inherited from the Embedder base class — confirm its default there.
        logger.info(f"Getting embeddings for {len(texts)} texts in batches of {self.batch_size} (async)")

        for i in range(0, len(texts), self.batch_size):
            batch_texts = texts[i : i + self.batch_size]

            try:
                if self.is_remote:
                    # Remote mode: one API call for the whole batch
                    req = self._build_remote_request(batch_texts)
                    response: "CreateEmbeddingResponse" = await self._get_async_remote_client().embeddings.create(
                        **req
                    )
                    batch_embeddings = [data.embedding for data in response.data]
                    all_embeddings.extend(batch_embeddings)

                    # Usage is reported once per request; mirror it for each embedding in the batch
                    usage_dict = response.usage.model_dump() if response.usage else None
                    all_usage.extend([usage_dict] * len(batch_embeddings))
                else:
                    # Local mode: process individually using thread executor
                    for text in batch_texts:
                        embedding, usage = await self.async_get_embedding_and_usage(text)
                        all_embeddings.append(embedding)
                        all_usage.append(usage)

            except Exception as e:
                logger.warning(f"Error in async batch embedding: {e}")
                # Fallback: add empty results for failed batch
                for _ in batch_texts:
                    all_embeddings.append([])
                    all_usage.append(None)

        return all_embeddings, all_usage
|
@@ -30,12 +30,13 @@ class VoyageAIEmbedder(Embedder):
|
|
|
30
30
|
if self.voyage_client:
|
|
31
31
|
return self.voyage_client
|
|
32
32
|
|
|
33
|
-
_client_params = {
|
|
34
|
-
|
|
35
|
-
"
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
33
|
+
_client_params: Dict[str, Any] = {}
|
|
34
|
+
if self.api_key is not None:
|
|
35
|
+
_client_params["api_key"] = self.api_key
|
|
36
|
+
if self.max_retries is not None:
|
|
37
|
+
_client_params["max_retries"] = self.max_retries
|
|
38
|
+
if self.timeout is not None:
|
|
39
|
+
_client_params["timeout"] = self.timeout
|
|
39
40
|
if self.client_params:
|
|
40
41
|
_client_params.update(self.client_params)
|
|
41
42
|
self.voyage_client = VoyageClient(**_client_params)
|
|
@@ -46,12 +47,13 @@ class VoyageAIEmbedder(Embedder):
|
|
|
46
47
|
if self.async_client:
|
|
47
48
|
return self.async_client
|
|
48
49
|
|
|
49
|
-
_client_params = {
|
|
50
|
-
|
|
51
|
-
"
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
50
|
+
_client_params: Dict[str, Any] = {}
|
|
51
|
+
if self.api_key is not None:
|
|
52
|
+
_client_params["api_key"] = self.api_key
|
|
53
|
+
if self.max_retries is not None:
|
|
54
|
+
_client_params["max_retries"] = self.max_retries
|
|
55
|
+
if self.timeout is not None:
|
|
56
|
+
_client_params["timeout"] = self.timeout
|
|
55
57
|
if self.client_params:
|
|
56
58
|
_client_params.update(self.client_params)
|
|
57
59
|
self.async_client = AsyncVoyageClient(**_client_params)
|
|
@@ -69,7 +71,8 @@ class VoyageAIEmbedder(Embedder):
|
|
|
69
71
|
def get_embedding(self, text: str) -> List[float]:
|
|
70
72
|
response: EmbeddingsObject = self._response(text=text)
|
|
71
73
|
try:
|
|
72
|
-
|
|
74
|
+
embedding = response.embeddings[0]
|
|
75
|
+
return [float(x) for x in embedding] # Ensure all values are float
|
|
73
76
|
except Exception as e:
|
|
74
77
|
logger.warning(e)
|
|
75
78
|
return []
|
|
@@ -79,7 +82,7 @@ class VoyageAIEmbedder(Embedder):
|
|
|
79
82
|
|
|
80
83
|
embedding = response.embeddings[0]
|
|
81
84
|
usage = {"total_tokens": response.total_tokens}
|
|
82
|
-
return embedding, usage
|
|
85
|
+
return [float(x) for x in embedding], usage
|
|
83
86
|
|
|
84
87
|
async def _async_response(self, text: str) -> EmbeddingsObject:
|
|
85
88
|
"""Async version of _response using AsyncVoyageClient."""
|
|
@@ -95,7 +98,8 @@ class VoyageAIEmbedder(Embedder):
|
|
|
95
98
|
"""Async version of get_embedding."""
|
|
96
99
|
try:
|
|
97
100
|
response: EmbeddingsObject = await self._async_response(text=text)
|
|
98
|
-
|
|
101
|
+
embedding = response.embeddings[0]
|
|
102
|
+
return [float(x) for x in embedding] # Ensure all values are float
|
|
99
103
|
except Exception as e:
|
|
100
104
|
logger.warning(f"Error getting embedding: {e}")
|
|
101
105
|
return []
|
|
@@ -106,7 +110,56 @@ class VoyageAIEmbedder(Embedder):
|
|
|
106
110
|
response: EmbeddingsObject = await self._async_response(text=text)
|
|
107
111
|
embedding = response.embeddings[0]
|
|
108
112
|
usage = {"total_tokens": response.total_tokens}
|
|
109
|
-
return embedding, usage
|
|
113
|
+
return [float(x) for x in embedding], usage
|
|
110
114
|
except Exception as e:
|
|
111
115
|
logger.warning(f"Error getting embedding and usage: {e}")
|
|
112
116
|
return [], None
|
|
117
|
+
|
|
118
|
+
async def async_get_embeddings_batch_and_usage(
|
|
119
|
+
self, texts: List[str]
|
|
120
|
+
) -> Tuple[List[List[float]], List[Optional[Dict]]]:
|
|
121
|
+
"""
|
|
122
|
+
Get embeddings and usage for multiple texts in batches.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
texts: List of text strings to embed
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
Tuple of (List of embedding vectors, List of usage dictionaries)
|
|
129
|
+
"""
|
|
130
|
+
all_embeddings: List[List[float]] = []
|
|
131
|
+
all_usage: List[Optional[Dict]] = []
|
|
132
|
+
logger.info(f"Getting embeddings and usage for {len(texts)} texts in batches of {self.batch_size}")
|
|
133
|
+
|
|
134
|
+
for i in range(0, len(texts), self.batch_size):
|
|
135
|
+
batch_texts = texts[i : i + self.batch_size]
|
|
136
|
+
|
|
137
|
+
req: Dict[str, Any] = {
|
|
138
|
+
"texts": batch_texts,
|
|
139
|
+
"model": self.id,
|
|
140
|
+
}
|
|
141
|
+
if self.request_params:
|
|
142
|
+
req.update(self.request_params)
|
|
143
|
+
|
|
144
|
+
try:
|
|
145
|
+
response: EmbeddingsObject = await self.aclient.embed(**req)
|
|
146
|
+
batch_embeddings = [[float(x) for x in emb] for emb in response.embeddings]
|
|
147
|
+
all_embeddings.extend(batch_embeddings)
|
|
148
|
+
|
|
149
|
+
# For each embedding in the batch, add the same usage information
|
|
150
|
+
usage_dict = {"total_tokens": response.total_tokens}
|
|
151
|
+
all_usage.extend([usage_dict] * len(batch_embeddings))
|
|
152
|
+
except Exception as e:
|
|
153
|
+
logger.warning(f"Error in async batch embedding: {e}")
|
|
154
|
+
# Fallback to individual calls for this batch
|
|
155
|
+
for text in batch_texts:
|
|
156
|
+
try:
|
|
157
|
+
embedding, usage = await self.async_get_embedding_and_usage(text)
|
|
158
|
+
all_embeddings.append(embedding)
|
|
159
|
+
all_usage.append(usage)
|
|
160
|
+
except Exception as e2:
|
|
161
|
+
logger.warning(f"Error in individual async embedding fallback: {e2}")
|
|
162
|
+
all_embeddings.append([])
|
|
163
|
+
all_usage.append(None)
|
|
164
|
+
|
|
165
|
+
return all_embeddings, all_usage
|