ag2 0.9.6__py3-none-any.whl → 0.9.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of ag2 might be problematic.
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/METADATA +102 -75
- ag2-0.9.8.post1.dist-info/RECORD +387 -0
- autogen/__init__.py +1 -2
- autogen/_website/generate_api_references.py +4 -5
- autogen/_website/generate_mkdocs.py +9 -15
- autogen/_website/notebook_processor.py +13 -14
- autogen/_website/process_notebooks.py +10 -10
- autogen/_website/utils.py +5 -4
- autogen/agentchat/agent.py +13 -13
- autogen/agentchat/assistant_agent.py +7 -6
- autogen/agentchat/contrib/agent_eval/agent_eval.py +3 -3
- autogen/agentchat/contrib/agent_eval/critic_agent.py +3 -3
- autogen/agentchat/contrib/agent_eval/quantifier_agent.py +3 -3
- autogen/agentchat/contrib/agent_eval/subcritic_agent.py +3 -3
- autogen/agentchat/contrib/agent_optimizer.py +3 -3
- autogen/agentchat/contrib/capabilities/generate_images.py +11 -11
- autogen/agentchat/contrib/capabilities/teachability.py +15 -15
- autogen/agentchat/contrib/capabilities/transforms.py +17 -18
- autogen/agentchat/contrib/capabilities/transforms_util.py +5 -5
- autogen/agentchat/contrib/capabilities/vision_capability.py +4 -3
- autogen/agentchat/contrib/captainagent/agent_builder.py +30 -30
- autogen/agentchat/contrib/captainagent/captainagent.py +22 -21
- autogen/agentchat/contrib/captainagent/tool_retriever.py +2 -3
- autogen/agentchat/contrib/gpt_assistant_agent.py +9 -9
- autogen/agentchat/contrib/graph_rag/document.py +3 -3
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +3 -3
- autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +3 -3
- autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +5 -11
- autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py +7 -7
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/img_utils.py +1 -1
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +11 -11
- autogen/agentchat/contrib/llava_agent.py +18 -4
- autogen/agentchat/contrib/math_user_proxy_agent.py +11 -11
- autogen/agentchat/contrib/multimodal_conversable_agent.py +8 -8
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +6 -5
- autogen/agentchat/contrib/rag/chromadb_query_engine.py +22 -26
- autogen/agentchat/contrib/rag/llamaindex_query_engine.py +14 -17
- autogen/agentchat/contrib/rag/mongodb_query_engine.py +27 -37
- autogen/agentchat/contrib/rag/query_engine.py +7 -5
- autogen/agentchat/contrib/retrieve_assistant_agent.py +5 -5
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +8 -7
- autogen/agentchat/contrib/society_of_mind_agent.py +15 -14
- autogen/agentchat/contrib/swarm_agent.py +76 -98
- autogen/agentchat/contrib/text_analyzer_agent.py +7 -7
- autogen/agentchat/contrib/vectordb/base.py +10 -18
- autogen/agentchat/contrib/vectordb/chromadb.py +2 -1
- autogen/agentchat/contrib/vectordb/couchbase.py +18 -20
- autogen/agentchat/contrib/vectordb/mongodb.py +6 -5
- autogen/agentchat/contrib/vectordb/pgvectordb.py +40 -41
- autogen/agentchat/contrib/vectordb/qdrant.py +5 -5
- autogen/agentchat/contrib/web_surfer.py +20 -19
- autogen/agentchat/conversable_agent.py +311 -295
- autogen/agentchat/group/context_str.py +1 -3
- autogen/agentchat/group/context_variables.py +15 -25
- autogen/agentchat/group/group_tool_executor.py +10 -10
- autogen/agentchat/group/group_utils.py +15 -15
- autogen/agentchat/group/guardrails.py +7 -7
- autogen/agentchat/group/handoffs.py +19 -36
- autogen/agentchat/group/multi_agent_chat.py +7 -7
- autogen/agentchat/group/on_condition.py +4 -7
- autogen/agentchat/group/on_context_condition.py +4 -7
- autogen/agentchat/group/patterns/auto.py +8 -7
- autogen/agentchat/group/patterns/manual.py +7 -6
- autogen/agentchat/group/patterns/pattern.py +13 -12
- autogen/agentchat/group/patterns/random.py +3 -3
- autogen/agentchat/group/patterns/round_robin.py +3 -3
- autogen/agentchat/group/reply_result.py +2 -4
- autogen/agentchat/group/speaker_selection_result.py +5 -5
- autogen/agentchat/group/targets/group_chat_target.py +7 -6
- autogen/agentchat/group/targets/group_manager_target.py +4 -4
- autogen/agentchat/group/targets/transition_target.py +2 -1
- autogen/agentchat/groupchat.py +58 -61
- autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py +4 -4
- autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py +4 -4
- autogen/agentchat/realtime/experimental/clients/gemini/client.py +7 -7
- autogen/agentchat/realtime/experimental/clients/oai/base_client.py +8 -8
- autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py +6 -6
- autogen/agentchat/realtime/experimental/clients/realtime_client.py +10 -9
- autogen/agentchat/realtime/experimental/realtime_agent.py +10 -9
- autogen/agentchat/realtime/experimental/realtime_observer.py +3 -3
- autogen/agentchat/realtime/experimental/realtime_swarm.py +44 -44
- autogen/agentchat/user_proxy_agent.py +10 -9
- autogen/agentchat/utils.py +3 -3
- autogen/agents/contrib/time/time_reply_agent.py +6 -5
- autogen/agents/contrib/time/time_tool_agent.py +2 -1
- autogen/agents/experimental/deep_research/deep_research.py +3 -3
- autogen/agents/experimental/discord/discord.py +2 -2
- autogen/agents/experimental/document_agent/chroma_query_engine.py +29 -44
- autogen/agents/experimental/document_agent/docling_doc_ingest_agent.py +9 -14
- autogen/agents/experimental/document_agent/document_agent.py +15 -16
- autogen/agents/experimental/document_agent/document_conditions.py +3 -3
- autogen/agents/experimental/document_agent/document_utils.py +5 -9
- autogen/agents/experimental/document_agent/inmemory_query_engine.py +14 -20
- autogen/agents/experimental/document_agent/parser_utils.py +4 -4
- autogen/agents/experimental/document_agent/url_utils.py +14 -23
- autogen/agents/experimental/reasoning/reasoning_agent.py +33 -33
- autogen/agents/experimental/slack/slack.py +2 -2
- autogen/agents/experimental/telegram/telegram.py +2 -3
- autogen/agents/experimental/websurfer/websurfer.py +4 -4
- autogen/agents/experimental/wikipedia/wikipedia.py +5 -7
- autogen/browser_utils.py +8 -8
- autogen/cache/abstract_cache_base.py +5 -5
- autogen/cache/cache.py +12 -12
- autogen/cache/cache_factory.py +4 -4
- autogen/cache/cosmos_db_cache.py +9 -9
- autogen/cache/disk_cache.py +6 -6
- autogen/cache/in_memory_cache.py +4 -4
- autogen/cache/redis_cache.py +4 -4
- autogen/code_utils.py +18 -18
- autogen/coding/base.py +6 -6
- autogen/coding/docker_commandline_code_executor.py +9 -9
- autogen/coding/func_with_reqs.py +7 -6
- autogen/coding/jupyter/base.py +3 -3
- autogen/coding/jupyter/docker_jupyter_server.py +3 -4
- autogen/coding/jupyter/import_utils.py +3 -3
- autogen/coding/jupyter/jupyter_client.py +5 -5
- autogen/coding/jupyter/jupyter_code_executor.py +3 -4
- autogen/coding/jupyter/local_jupyter_server.py +2 -6
- autogen/coding/local_commandline_code_executor.py +8 -7
- autogen/coding/markdown_code_extractor.py +1 -2
- autogen/coding/utils.py +1 -2
- autogen/doc_utils.py +3 -2
- autogen/environments/docker_python_environment.py +19 -29
- autogen/environments/python_environment.py +8 -17
- autogen/environments/system_python_environment.py +3 -4
- autogen/environments/venv_python_environment.py +8 -12
- autogen/environments/working_directory.py +1 -2
- autogen/events/agent_events.py +106 -109
- autogen/events/base_event.py +6 -5
- autogen/events/client_events.py +15 -14
- autogen/events/helpers.py +1 -1
- autogen/events/print_event.py +4 -5
- autogen/fast_depends/_compat.py +10 -15
- autogen/fast_depends/core/build.py +17 -36
- autogen/fast_depends/core/model.py +64 -113
- autogen/fast_depends/dependencies/model.py +2 -1
- autogen/fast_depends/dependencies/provider.py +3 -2
- autogen/fast_depends/library/model.py +4 -4
- autogen/fast_depends/schema.py +7 -7
- autogen/fast_depends/use.py +17 -25
- autogen/fast_depends/utils.py +10 -30
- autogen/formatting_utils.py +6 -6
- autogen/graph_utils.py +1 -4
- autogen/import_utils.py +13 -13
- autogen/interop/crewai/crewai.py +2 -2
- autogen/interop/interoperable.py +2 -2
- autogen/interop/langchain/langchain_chat_model_factory.py +3 -2
- autogen/interop/langchain/langchain_tool.py +2 -6
- autogen/interop/litellm/litellm_config_factory.py +6 -7
- autogen/interop/pydantic_ai/pydantic_ai.py +4 -7
- autogen/interop/registry.py +2 -1
- autogen/io/base.py +5 -5
- autogen/io/run_response.py +33 -32
- autogen/io/websockets.py +6 -5
- autogen/json_utils.py +1 -2
- autogen/llm_config/__init__.py +11 -0
- autogen/llm_config/client.py +58 -0
- autogen/llm_config/config.py +384 -0
- autogen/llm_config/entry.py +154 -0
- autogen/logger/base_logger.py +4 -3
- autogen/logger/file_logger.py +2 -1
- autogen/logger/logger_factory.py +2 -2
- autogen/logger/logger_utils.py +2 -2
- autogen/logger/sqlite_logger.py +3 -2
- autogen/math_utils.py +4 -5
- autogen/mcp/__main__.py +6 -6
- autogen/mcp/helpers.py +4 -4
- autogen/mcp/mcp_client.py +170 -29
- autogen/mcp/mcp_proxy/fastapi_code_generator_helpers.py +3 -4
- autogen/mcp/mcp_proxy/mcp_proxy.py +23 -26
- autogen/mcp/mcp_proxy/operation_grouping.py +4 -5
- autogen/mcp/mcp_proxy/operation_renaming.py +6 -10
- autogen/mcp/mcp_proxy/security.py +2 -3
- autogen/messages/agent_messages.py +96 -98
- autogen/messages/base_message.py +6 -5
- autogen/messages/client_messages.py +15 -14
- autogen/messages/print_message.py +4 -5
- autogen/oai/__init__.py +1 -2
- autogen/oai/anthropic.py +42 -41
- autogen/oai/bedrock.py +68 -57
- autogen/oai/cerebras.py +26 -25
- autogen/oai/client.py +118 -138
- autogen/oai/client_utils.py +3 -3
- autogen/oai/cohere.py +34 -11
- autogen/oai/gemini.py +40 -17
- autogen/oai/gemini_types.py +11 -12
- autogen/oai/groq.py +22 -10
- autogen/oai/mistral.py +17 -11
- autogen/oai/oai_models/__init__.py +14 -2
- autogen/oai/oai_models/_models.py +2 -2
- autogen/oai/oai_models/chat_completion.py +13 -14
- autogen/oai/oai_models/chat_completion_message.py +11 -9
- autogen/oai/oai_models/chat_completion_message_tool_call.py +26 -3
- autogen/oai/oai_models/chat_completion_token_logprob.py +3 -4
- autogen/oai/oai_models/completion_usage.py +8 -9
- autogen/oai/ollama.py +22 -10
- autogen/oai/openai_responses.py +40 -17
- autogen/oai/openai_utils.py +159 -85
- autogen/oai/together.py +29 -14
- autogen/retrieve_utils.py +6 -7
- autogen/runtime_logging.py +5 -4
- autogen/token_count_utils.py +7 -4
- autogen/tools/contrib/time/time.py +0 -1
- autogen/tools/dependency_injection.py +5 -6
- autogen/tools/experimental/browser_use/browser_use.py +10 -10
- autogen/tools/experimental/code_execution/python_code_execution.py +5 -7
- autogen/tools/experimental/crawl4ai/crawl4ai.py +12 -15
- autogen/tools/experimental/deep_research/deep_research.py +9 -8
- autogen/tools/experimental/duckduckgo/duckduckgo_search.py +5 -11
- autogen/tools/experimental/firecrawl/firecrawl_tool.py +98 -115
- autogen/tools/experimental/google/authentication/credentials_local_provider.py +1 -1
- autogen/tools/experimental/google/drive/drive_functions.py +4 -4
- autogen/tools/experimental/google/drive/toolkit.py +5 -5
- autogen/tools/experimental/google_search/google_search.py +5 -5
- autogen/tools/experimental/google_search/youtube_search.py +5 -5
- autogen/tools/experimental/messageplatform/discord/discord.py +8 -12
- autogen/tools/experimental/messageplatform/slack/slack.py +14 -20
- autogen/tools/experimental/messageplatform/telegram/telegram.py +8 -12
- autogen/tools/experimental/perplexity/perplexity_search.py +18 -29
- autogen/tools/experimental/reliable/reliable.py +68 -74
- autogen/tools/experimental/searxng/searxng_search.py +20 -19
- autogen/tools/experimental/tavily/tavily_search.py +12 -19
- autogen/tools/experimental/web_search_preview/web_search_preview.py +13 -7
- autogen/tools/experimental/wikipedia/wikipedia.py +7 -10
- autogen/tools/function_utils.py +7 -7
- autogen/tools/tool.py +6 -5
- autogen/types.py +2 -2
- autogen/version.py +1 -1
- ag2-0.9.6.dist-info/RECORD +0 -421
- autogen/llm_config.py +0 -385
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/WHEEL +0 -0
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/licenses/LICENSE +0 -0
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/licenses/NOTICE.md +0 -0
autogen/agentchat/contrib/swarm_agent.py

@@ -7,13 +7,14 @@ import copy
 import inspect
 import threading
 import warnings
+from collections.abc import Callable
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
 from types import MethodType
-from typing import Annotated, Any,
+from typing import Annotated, Any, Literal
 
-from pydantic import BaseModel, field_serializer
+from pydantic import BaseModel, ConfigDict, field_serializer
 
 from ...doc_utils import export_module
 from ...events.agent_events import ErrorEvent, RunCompletionEvent
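The bulk of this release is a typing-modernization pass across the package: `Optional[...]`/`Union[...]` annotations are rewritten as PEP 604 unions (`X | None`, `X | Y`), `Callable` is imported from `collections.abc` rather than `typing`, and Pydantic v1 `class Config` blocks become `model_config = ConfigDict(...)`. A minimal before/after sketch of the style (illustrative only, not code from the package):

    from collections.abc import Callable
    from typing import Any

    # Old style, removed throughout this diff:
    #   def reply(messages: Optional[list[dict[str, Any]]] = None) -> Union[str, None]: ...

    # New style, added throughout this diff (PEP 604 unions):
    def reply(
        messages: list[dict[str, Any]] | None = None,
        on_done: Callable[[str], None] | None = None,
    ) -> str | None:
        return None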
@@ -75,10 +76,8 @@ class AfterWork: # noqa: N801
         def my_selection_message(agent: ConversableAgent, messages: list[dict[str, Any]]) -> str
     """
 
-    agent:
-    next_agent_selection_msg:
-        Union[str, ContextStr, Callable[[ConversableAgent, list[dict[str, Any]]], str]]
-    ] = None
+    agent: AfterWorkOption | ConversableAgent | str | Callable[..., Any]
+    next_agent_selection_msg: str | ContextStr | Callable[[ConversableAgent, list[dict[str, Any]]], str] | None = None
 
     def __post_init__(self) -> None:
         if isinstance(self.agent, str):

@@ -138,9 +137,9 @@ class OnCondition: # noqa: N801
         def my_available_func(agent: ConversableAgent, messages: list[Dict[str, Any]]) -> bool
     """
 
-    target:
-    condition:
-    available:
+    target: ConversableAgent | dict[str, Any] | None = None
+    condition: str | ContextStr | Callable[[ConversableAgent, list[dict[str, Any]]], str] | None = None
+    available: Callable[[ConversableAgent, list[dict[str, Any]]], bool] | str | ContextExpression | None = None
 
     def __post_init__(self) -> None:
         # Ensure valid types

@@ -200,9 +199,9 @@ class OnContextCondition: # noqa: N801
 
     """
 
-    target:
-    condition:
-    available:
+    target: ConversableAgent | dict[str, Any] | None = None
+    condition: str | ContextExpression | None = None
+    available: Callable[[ConversableAgent, list[dict[str, Any]]], bool] | str | ContextExpression | None = None
 
     def __post_init__(self) -> None:
         # Ensure valid types
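The handoff dataclasses keep their field names; only the annotations move to the union style. A small, hedged construction sketch (the agent names and the condition string are placeholders, not taken from the package):

    from autogen import ConversableAgent
    from autogen.agentchat.contrib.swarm_agent import AfterWork, AfterWorkOption, OnCondition

    # Two throwaway agents; llm_config=False keeps the sketch self-contained.
    triage = ConversableAgent(name="triage", llm_config=False)
    billing = ConversableAgent(name="billing", llm_config=False)

    # condition accepts str | ContextStr | Callable[...]; available stays optional.
    to_billing = OnCondition(target=billing, condition="The user asks about invoices or payments")
    # agent accepts AfterWorkOption | ConversableAgent | str | Callable[..., Any].
    wrap_up = AfterWork(agent=AfterWorkOption.TERMINATE)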
@@ -277,10 +276,10 @@ def _link_agents_to_swarm_manager(agents: list[Agent], group_chat_manager: Agent
 
 def _run_oncontextconditions(
     agent: ConversableAgent,
-    messages:
-    sender:
-    config:
-) -> tuple[bool,
+    messages: list[dict[str, Any]] | None = None,
+    sender: Agent | None = None,
+    config: Any | None = None,
+) -> tuple[bool, str | dict[str, Any] | None]:
     """Run OnContextConditions for an agent before any other reply function."""
     for on_condition in agent._swarm_oncontextconditions: # type: ignore[attr-defined]
         is_available = True

@@ -346,7 +345,6 @@ def _change_tool_context_variables_to_depends(
     agent: ConversableAgent, current_tool: Tool, context_variables: ContextVariables
 ) -> None:
     """Checks for the context_variables parameter in the tool and updates it to use dependency injection."""
-
     # If the tool has a context_variables parameter, remove the tool and reregister it without the parameter
     if __CONTEXT_VARIABLES_PARAM_NAME__ in current_tool.tool_schema["function"]["parameters"]["properties"]:
         # We'll replace the tool, so start with getting the underlying function
@@ -505,11 +503,11 @@ def _create_nested_chats(agent: ConversableAgent, nested_chat_agents: list[Conve
 
 
 def _process_initial_messages(
-    messages:
-    user_agent:
+    messages: list[dict[str, Any]] | str,
+    user_agent: UserProxyAgent | None,
     agents: list[ConversableAgent],
     nested_chat_agents: list[ConversableAgent],
-) -> tuple[list[dict[str, Any]],
+) -> tuple[list[dict[str, Any]], Agent | None, list[str], list[Agent]]:
     """Process initial messages, validating agent names against messages, and determining the last agent to speak.
 
     Args:

@@ -531,8 +529,8 @@ def _process_initial_messages(
 
     # If there's only one message and there's no identified swarm agent
     # Start with a user proxy agent, creating one if they haven't passed one in
-    last_agent:
-    temp_user_proxy:
+    last_agent: Agent | None
+    temp_user_proxy: Agent | None = None
     temp_user_list: list[Agent] = []
     if len(messages) == 1 and "name" not in messages[0] and not user_agent:
         temp_user_proxy = UserProxyAgent(name="_User", code_execution_config=False)

@@ -585,9 +583,10 @@ def _cleanup_temp_user_messages(chat_result: ChatResult) -> None:
 def _prepare_groupchat_auto_speaker(
     groupchat: GroupChat,
     last_swarm_agent: ConversableAgent,
-    after_work_next_agent_selection_msg:
-
-    ],
+    after_work_next_agent_selection_msg: str
+    | ContextStr
+    | Callable[[ConversableAgent, list[dict[str, Any]]], str]
+    | None,
 ) -> None:
     """Prepare the group chat for auto speaker selection, includes updating or restore the groupchat speaker selection message.
 
@@ -645,9 +644,9 @@ def _determine_next_agent(
     use_initial_agent: bool,
     tool_execution: ConversableAgent,
     swarm_agent_names: list[str],
-    user_agent:
-    swarm_after_work:
-) ->
+    user_agent: UserProxyAgent | None,
+    swarm_after_work: AfterWorkOption | Callable[..., Any] | None,
+) -> Agent | Literal["auto"] | None:
     """Determine the next agent in the conversation.
 
     Args:

@@ -669,7 +668,7 @@ def _determine_next_agent(
     after_work_condition = None
 
     if tool_execution._swarm_next_agent is not None: # type: ignore[attr-defined]
-        next_agent:
+        next_agent: Agent | None = tool_execution._swarm_next_agent # type: ignore[attr-defined]
         tool_execution._swarm_next_agent = None # type: ignore[attr-defined]
 
         if not isinstance(next_agent, AfterWorkOption):

@@ -747,9 +746,9 @@ def create_swarm_transition(
     initial_agent: ConversableAgent,
     tool_execution: ConversableAgent,
     swarm_agent_names: list[str],
-    user_agent:
-    swarm_after_work:
-) -> Callable[[ConversableAgent, GroupChat],
+    user_agent: UserProxyAgent | None,
+    swarm_after_work: AfterWorkOption | Callable[..., Any] | None,
+) -> Callable[[ConversableAgent, GroupChat], Agent | Literal["auto"] | None]:
     """Creates a transition function for swarm chat with enclosed state for the use_initial_agent.
 
     Args:

@@ -766,9 +765,7 @@ def create_swarm_transition(
     # of swarm_transition
     state = {"use_initial_agent": True}
 
-    def swarm_transition(
-        last_speaker: ConversableAgent, groupchat: GroupChat
-    ) -> Optional[Union[Agent, Literal["auto"]]]:
+    def swarm_transition(last_speaker: ConversableAgent, groupchat: GroupChat) -> Agent | Literal["auto"] | None:
         result = _determine_next_agent(
             last_speaker=last_speaker,
             groupchat=groupchat,

@@ -786,7 +783,7 @@ def create_swarm_transition(
 
 
 def _create_swarm_manager(
-    groupchat: GroupChat, swarm_manager_args:
+    groupchat: GroupChat, swarm_manager_args: dict[str, Any] | None, agents: list[ConversableAgent]
 ) -> GroupChatManager:
     """Create a GroupChatManager for the swarm chat utilising any arguments passed in and ensure an LLM Config exists if needed
 
@@ -871,20 +868,15 @@ def make_remove_function(tool_msgs_to_remove: list[str]) -> Callable[[list[dict[
 @export_module("autogen")
 def initiate_swarm_chat(
     initial_agent: ConversableAgent,
-    messages:
+    messages: list[dict[str, Any]] | str,
     agents: list[ConversableAgent],
-    user_agent:
-    swarm_manager_args:
+    user_agent: UserProxyAgent | None = None,
+    swarm_manager_args: dict[str, Any] | None = None,
     max_rounds: int = 20,
-    context_variables:
-    after_work:
-
-
-        Callable[
-            [ConversableAgent, list[dict[str, Any]], GroupChat], Union[AfterWorkOption, ConversableAgent, str]
-        ],
-    ]
-    ] = AfterWorkOption.TERMINATE,
+    context_variables: ContextVariables | None = None,
+    after_work: AfterWorkOption
+    | Callable[[ConversableAgent, list[dict[str, Any]], GroupChat], AfterWorkOption | ConversableAgent | str]
+    | None = AfterWorkOption.TERMINATE,
     exclude_transit_message: bool = True,
 ) -> tuple[ChatResult, ContextVariables, ConversableAgent]:
     """Initialize and run a swarm chat

@@ -910,6 +902,7 @@ def initiate_swarm_chat(
         ```
         exclude_transit_message: all registered handoff function call and responses messages will be removed from message list before calling an LLM.
             Note: only with transition functions added with `register_handoff` will be removed. If you pass in a function to manage workflow, it will not be removed. You may register a cumstomized hook to `process_all_messages_before_reply` to remove that.
+
     Returns:
         ChatResult: Conversations chat history.
         ContextVariables: Updated Context variables.
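Read together, the new `initiate_swarm_chat` signature is plain keyword arguments; a hedged call sketch (the agents and the LLM configuration are placeholders, not from the package, and a real `llm_config` is needed for an actual run):

    from autogen import ConversableAgent, initiate_swarm_chat
    from autogen.agentchat.contrib.swarm_agent import AfterWorkOption

    planner = ConversableAgent(name="planner", llm_config=False)  # placeholder config
    writer = ConversableAgent(name="writer", llm_config=False)    # placeholder config

    chat_result, context_variables, last_agent = initiate_swarm_chat(
        initial_agent=planner,
        messages="Draft a short release note.",
        agents=[planner, writer],
        max_rounds=20,                         # default shown in the diff
        after_work=AfterWorkOption.TERMINATE,  # default shown in the diff
    )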
@@ -974,20 +967,15 @@
 @export_module("autogen")
 def run_swarm(
     initial_agent: ConversableAgent,
-    messages:
+    messages: list[dict[str, Any]] | str,
     agents: list[ConversableAgent],
-    user_agent:
-    swarm_manager_args:
+    user_agent: UserProxyAgent | None = None,
+    swarm_manager_args: dict[str, Any] | None = None,
     max_rounds: int = 20,
-    context_variables:
-    after_work:
-
-
-        Callable[
-            [ConversableAgent, list[dict[str, Any]], GroupChat], Union[AfterWorkOption, ConversableAgent, str]
-        ],
-    ]
-    ] = AfterWorkOption.TERMINATE,
+    context_variables: ContextVariables | None = None,
+    after_work: AfterWorkOption
+    | Callable[[ConversableAgent, list[dict[str, Any]], GroupChat], AfterWorkOption | ConversableAgent | str]
+    | None = AfterWorkOption.TERMINATE,
     exclude_transit_message: bool = True,
 ) -> RunResponseProtocol:
     iostream = ThreadIOStream()
@@ -1033,20 +1021,15 @@
 @export_module("autogen")
 async def a_initiate_swarm_chat(
     initial_agent: ConversableAgent,
-    messages:
+    messages: list[dict[str, Any]] | str,
     agents: list[ConversableAgent],
-    user_agent:
-    swarm_manager_args:
+    user_agent: UserProxyAgent | None = None,
+    swarm_manager_args: dict[str, Any] | None = None,
     max_rounds: int = 20,
-    context_variables:
-    after_work:
-
-
-        Callable[
-            [ConversableAgent, list[dict[str, Any]], GroupChat], Union[AfterWorkOption, ConversableAgent, str]
-        ],
-    ]
-    ] = AfterWorkOption.TERMINATE,
+    context_variables: ContextVariables | None = None,
+    after_work: AfterWorkOption
+    | Callable[[ConversableAgent, list[dict[str, Any]], GroupChat], AfterWorkOption | ConversableAgent | str]
+    | None = AfterWorkOption.TERMINATE,
     exclude_transit_message: bool = True,
 ) -> tuple[ChatResult, ContextVariables, ConversableAgent]:
     """Initialize and run a swarm chat asynchronously

@@ -1072,6 +1055,7 @@ async def a_initiate_swarm_chat(
         ```
         exclude_transit_message: all registered handoff function call and responses messages will be removed from message list before calling an LLM.
             Note: only with transition functions added with `register_handoff` will be removed. If you pass in a function to manage workflow, it will not be removed. You may register a cumstomized hook to `process_all_messages_before_reply` to remove that.
+
     Returns:
         ChatResult: Conversations chat history.
         ContextVariables: Updated Context variables.
@@ -1135,20 +1119,15 @@
 @export_module("autogen")
 async def a_run_swarm(
     initial_agent: ConversableAgent,
-    messages:
+    messages: list[dict[str, Any]] | str,
     agents: list[ConversableAgent],
-    user_agent:
-    swarm_manager_args:
+    user_agent: UserProxyAgent | None = None,
+    swarm_manager_args: dict[str, Any] | None = None,
     max_rounds: int = 20,
-    context_variables:
-    after_work:
-
-
-        Callable[
-            [ConversableAgent, list[dict[str, Any]], GroupChat], Union[AfterWorkOption, ConversableAgent, str]
-        ],
-    ]
-    ] = AfterWorkOption.TERMINATE,
+    context_variables: ContextVariables | None = None,
+    after_work: AfterWorkOption
+    | Callable[[ConversableAgent, list[dict[str, Any]], GroupChat], AfterWorkOption | ConversableAgent | str]
+    | None = AfterWorkOption.TERMINATE,
     exclude_transit_message: bool = True,
 ) -> AsyncRunResponseProtocol:
     iostream = AsyncThreadIOStream()
@@ -1194,11 +1173,11 @@ class SwarmResult(BaseModel):
     """Encapsulates the possible return values for a swarm agent function."""
 
     values: str = ""
-    agent:
-    context_variables:
+    agent: ConversableAgent | AfterWorkOption | str | None = None
+    context_variables: ContextVariables | None = None
 
     @field_serializer("agent", when_used="json")
-    def serialize_agent(self, agent:
+    def serialize_agent(self, agent: ConversableAgent | str) -> str:
         if isinstance(agent, ConversableAgent):
             return agent.name
         return agent

@@ -1208,8 +1187,7 @@ class SwarmResult(BaseModel):
         if self.context_variables is None:
             self.context_variables = ContextVariables()
 
-
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
     def __str__(self) -> str:
         return self.values
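The `SwarmResult` change is the standard Pydantic v1-to-v2 configuration migration: the nested `class Config` is replaced by a `model_config = ConfigDict(...)` class attribute. A minimal sketch of the pattern on a generic model (not the package's class):

    from pydantic import BaseModel, ConfigDict

    class ArbitraryHolder(BaseModel):
        # Pydantic v2 style; the v1 equivalent was a nested `class Config`
        # containing `arbitrary_types_allowed = True`.
        model_config = ConfigDict(arbitrary_types_allowed=True)

        payload: object | None = None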
@@ -1229,7 +1207,7 @@ def _set_to_tool_execution(agent: ConversableAgent) -> None:
 @export_module("autogen")
 def register_hand_off(
     agent: ConversableAgent,
-    hand_to:
+    hand_to: list[OnCondition | OnContextCondition | AfterWork] | OnCondition | OnContextCondition | AfterWork,
 ) -> None:
     """Register a function to hand off to another agent.
 

@@ -1302,7 +1280,7 @@ def register_hand_off(
         raise ValueError("Invalid hand off condition, must be either OnCondition or AfterWork")
 
 
-def _update_conditional_functions(agent: ConversableAgent, messages:
+def _update_conditional_functions(agent: ConversableAgent, messages: list[dict[str, Any]] | None = None) -> None:
     """Updates the agent's functions based on the OnCondition's available condition."""
     for func_name, (func, on_condition) in agent._swarm_conditional_functions.items(): # type: ignore[attr-defined]
         is_available = True
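With the widened `hand_to` annotation, `register_hand_off` accepts a single condition or a list mixing `OnCondition`, `OnContextCondition`, and `AfterWork`. A hedged sketch reusing the placeholder agents from the earlier example:

    from autogen import register_hand_off
    from autogen.agentchat.contrib.swarm_agent import AfterWork, AfterWorkOption, OnCondition

    # `triage` and `billing` are the placeholder agents defined above.
    register_hand_off(
        agent=triage,
        hand_to=[
            OnCondition(target=billing, condition="The user asks about invoices or payments"),
            AfterWork(agent=AfterWorkOption.TERMINATE),
        ],
    )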
@@ -1334,17 +1312,17 @@ def _update_conditional_functions(agent: ConversableAgent, messages: Optional[li
 
 def _generate_swarm_tool_reply(
     agent: ConversableAgent,
-    messages:
-    sender:
-    config:
-) -> tuple[bool,
+    messages: list[dict[str, Any]] | None = None,
+    sender: Agent | None = None,
+    config: OpenAIWrapper | None = None,
+) -> tuple[bool, dict[str, Any] | None]:
     """Pre-processes and generates tool call replies.
 
     This function:
     1. Adds context_variables back to the tool call for the function, if necessary.
     2. Generates the tool calls reply.
-    3. Updates context_variables and next_agent based on the tool call response.
-
+    3. Updates context_variables and next_agent based on the tool call response.
+    """
     if config is None:
         config = agent # type: ignore[assignment]
     if messages is None:

@@ -1355,7 +1333,7 @@ def _generate_swarm_tool_reply(
     tool_call_count = len(message["tool_calls"])
 
     # Loop through tool calls individually (so context can be updated after each function call)
-    next_agent:
+    next_agent: Agent | None = None
     tool_responses_inner = []
     contents = []
     for index in range(tool_call_count):
autogen/agentchat/contrib/text_analyzer_agent.py

@@ -4,7 +4,7 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
-from typing import Any, Literal
+from typing import Any, Literal
 
 from ...llm_config import LLMConfig
 from ..agent import Agent

@@ -22,9 +22,9 @@ class TextAnalyzerAgent(ConversableAgent):
     def __init__(
         self,
         name="analyzer",
-        system_message:
+        system_message: str | None = system_message,
         human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
-        llm_config:
+        llm_config: LLMConfig | dict[str, Any] | bool | None = None,
         **kwargs: Any,
     ):
         """Args:

@@ -48,10 +48,10 @@ class TextAnalyzerAgent(ConversableAgent):
 
     def _analyze_in_reply(
         self,
-        messages:
-        sender:
-        config:
-    ) -> tuple[bool,
+        messages: list[dict[str, Any]] | None = None,
+        sender: Agent | None = None,
+        config: Any | None = None,
+    ) -> tuple[bool, str | dict[str, Any] | None]:
         """Analyzes the given text as instructed, and returns the analysis as a message.
         Assumes exactly two messages containing the text to analyze and the analysis instructions.
         See Teachability.analyze for an example of how to use this method.
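`TextAnalyzerAgent` keeps its defaults (`name="analyzer"`, `human_input_mode="NEVER"`); only the annotations change. A hedged instantiation sketch (the LLM configuration is a placeholder):

    from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent

    # llm_config accepts LLMConfig | dict[str, Any] | bool | None; False disables the LLM
    # so the sketch stays self-contained (a real config is needed for actual analysis).
    analyzer = TextAnalyzerAgent(name="analyzer", llm_config=False)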
autogen/agentchat/contrib/vectordb/base.py

@@ -4,20 +4,12 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
-from collections.abc import Mapping, Sequence
-from typing import
-
-
-
-
-    TypedDict,
-    Union,
-    runtime_checkable,
-)
-
-Metadata = Union[Mapping[str, Any], None]
-Vector = Union[Sequence[float], Sequence[int]]
-ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does
+from collections.abc import Callable, Mapping, Sequence
+from typing import Any, Protocol, TypedDict, runtime_checkable
+
+Metadata = Mapping[str, Any] | None
+Vector = Sequence[float] | Sequence[int]
+ItemID = str | int # chromadb doesn't support int ids, VikingDB does
 
 
 class Document(TypedDict):

@@ -31,8 +23,8 @@ class Document(TypedDict):
 
     id: ItemID
     content: str
-    metadata:
-    embedding:
+    metadata: Metadata | None
+    embedding: Vector | None
 
 
 """QueryResults is the response from the vector database for a query/queries.

@@ -63,7 +55,7 @@ class VectorDB(Protocol):
 
     active_collection: Any = None
     type: str = ""
-    embedding_function:
+    embedding_function: Callable[[list[str]], list[list[float]]] | None = (
         None # embeddings = embedding_function(sentences)
     )
 

@@ -172,7 +164,7 @@ class VectorDB(Protocol):
        ...
 
     def get_docs_by_ids(
-        self, ids: list[ItemID] = None, collection_name: str = None, include:
+        self, ids: list[ItemID] = None, collection_name: str = None, include: list[str] | None = None, **kwargs: Any
     ) -> list[Document]:
         """Retrieve documents from the collection of the vector database based on the ids.
 
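The rewritten aliases make the vector-DB document shape easy to state directly; a small, hedged sketch using the aliases exactly as defined above (the values are made up):

    from autogen.agentchat.contrib.vectordb.base import Document

    # Document is a TypedDict: id is str | int, metadata is Mapping[str, Any] | None,
    # and embedding is Sequence[float] | Sequence[int] | None.
    doc: Document = {
        "id": "doc-1",
        "content": "AG2 0.9.8 modernizes the vector-DB type annotations.",
        "metadata": {"source": "release-notes"},
        "embedding": [0.12, 0.05, 0.91],
    }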
autogen/agentchat/contrib/vectordb/chromadb.py

@@ -5,7 +5,8 @@
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
 import os
-from
+from collections.abc import Callable
+from typing import Any
 
 from ....import_utils import optional_import_block, require_optional_import
 from .base import Document, ItemID, QueryResults, VectorDB
autogen/agentchat/contrib/vectordb/couchbase.py

@@ -7,8 +7,9 @@
 
 import json
 import time
+from collections.abc import Callable
 from datetime import timedelta
-from typing import Any,
+from typing import Any, Literal, Optional
 
 from ....import_utils import optional_import_block, require_optional_import
 from .base import Document, ItemID, QueryResults, VectorDB

@@ -36,9 +37,7 @@ EMBEDDING_KEY = "embedding"
 
 @require_optional_import(["couchbase", "sentence_transformers"], "retrievechat-couchbase")
 class CouchbaseVectorDB(VectorDB):
-    """
-    A vector database implementation that uses Couchbase as the backend.
-    """
+    """A vector database implementation that uses Couchbase as the backend."""
 
     def __init__(
         self,
@@ -51,8 +50,8 @@ class CouchbaseVectorDB(VectorDB):
         collection_name: str = "_default",
         index_name: str = None,
     ):
-        """
-
+        """Initialize the vector database.
+
         Args:
             connection_string (str): The Couchbase connection string to connect to. Default is 'couchbase://localhost'.
             username (str): The username for Couchbase authentication. Default is 'Administrator'.

@@ -107,8 +106,8 @@ class CouchbaseVectorDB(VectorDB):
         overwrite: bool = False,
         get_or_create: bool = True,
     ) -> "Collection":
-        """
-
+        """Create a collection in the vector database and create a vector search index in the collection.
+
         Args:
             collection_name (str): The name of the collection.
             overwrite (bool): Whether to overwrite the collection if it exists. Default is False.

@@ -135,8 +134,8 @@ class CouchbaseVectorDB(VectorDB):
     def create_index_if_not_exists(
         self, index_name: str = "vector_index", collection: Optional["Collection"] = None
     ) -> None:
-        """
-
+        """Creates a vector search index on the specified collection in Couchbase.
+
         Args:
             index_name (str, optional): The name of the vector search index to create. Defaults to "vector_search_index".
             collection (Collection, optional): The Couchbase collection to create the index on. Defaults to None.

@@ -144,11 +143,12 @@ class CouchbaseVectorDB(VectorDB):
         if not self.search_index_exists(index_name):
             self.create_vector_search_index(collection, index_name)
 
-    def get_collection(self, collection_name:
-        """
-
+    def get_collection(self, collection_name: str | None = None) -> "Collection":
+        """Get the collection from the vector database.
+
         Args:
             collection_name (str): The name of the collection. Default is None. If None, return the current active collection.
+
         Returns:
             The collection object (Collection)
         """
@@ -165,8 +165,8 @@ class CouchbaseVectorDB(VectorDB):
         return self.active_collection
 
     def delete_collection(self, collection_name: str) -> None:
-        """
-
+        """Delete the collection from the vector database.
+
         Args:
             collection_name (str): The name of the collection.
         """

@@ -179,7 +179,7 @@ class CouchbaseVectorDB(VectorDB):
     def create_vector_search_index(
         self,
         collection,
-        index_name:
+        index_name: str | None = "vector_index",
         similarity: Literal["l2_norm", "dot_product"] = "dot_product",
     ) -> None:
         """Create a vector search index in the collection."""

@@ -329,9 +329,9 @@ class CouchbaseVectorDB(VectorDB):
 
     def get_docs_by_ids(
         self,
-        ids:
+        ids: list[ItemID] | None = None,
         collection_name: str = None,
-        include:
+        include: list[str] | None = None,
         **kwargs: Any,
     ) -> list[Document]:
         """Retrieve documents from the collection of the vector database based on the ids."""

@@ -365,7 +365,6 @@ class CouchbaseVectorDB(VectorDB):
         """Retrieve documents from the collection of the vector database based on the queries.
         Note: Distance threshold is not supported in Couchbase FTS.
         """
-
         results: QueryResults = []
         for query_text in queries:
             query_vector = np.array(self.embedding_function([query_text])).tolist()[0]

@@ -381,7 +380,6 @@ class CouchbaseVectorDB(VectorDB):
         self, embedding_vector: list[float], n_results: int = 10, **kwargs
     ) -> list[tuple[dict[str, Any], float]]:
         """Core vector search using Couchbase FTS."""
-
         search_req = search.SearchRequest.create(
             VectorSearch.from_vector_query(
                 VectorQuery(
autogen/agentchat/contrib/vectordb/mongodb.py

@@ -4,9 +4,10 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
+from collections.abc import Callable, Iterable, Mapping
 from copy import deepcopy
 from time import monotonic, sleep
-from typing import Any,
+from typing import Any, Literal
 
 from ....import_utils import optional_import_block, require_optional_import
 from .base import Document, ItemID, QueryResults, VectorDB

@@ -40,12 +41,12 @@ class MongoDBAtlasVectorDB(VectorDB):
         self,
         connection_string: str = "",
         database_name: str = "vector_db",
-        embedding_function:
+        embedding_function: Callable[..., Any] | None = None,
         collection_name: str = None,
         index_name: str = "vector_index",
         overwrite: bool = False,
-        wait_until_index_ready:
-        wait_until_document_ready:
+        wait_until_index_ready: float | None = None,
+        wait_until_document_ready: float | None = None,
     ):
         """Initialize the vector database.
 

@@ -221,7 +222,7 @@ class MongoDBAtlasVectorDB(VectorDB):
     def create_vector_search_index(
         self,
         collection: "Collection",
-        index_name:
+        index_name: str | None = "vector_index",
         similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
     ) -> None:
         """Create a vector search index in the collection.
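The MongoDB Atlas constructor keeps its defaults; the optional waits are now annotated `float | None`. A hedged instantiation sketch (the connection string and embedding function are placeholders, and a reachable Atlas cluster plus the optional `pymongo` dependency are needed for a real run):

    from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB

    def fake_embedder(texts: list[str]) -> list[list[float]]:
        # Placeholder embedding function; a real deployment would call an embedding model.
        return [[0.0] * 8 for _ in texts]

    db = MongoDBAtlasVectorDB(
        connection_string="mongodb+srv://user:pass@example.mongodb.net",  # placeholder
        database_name="vector_db",
        embedding_function=fake_embedder,
        index_name="vector_index",
        wait_until_index_ready=60.0,  # float | None per the new annotation
    )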