swarms-7.6.4.tar.gz → swarms-7.6.5.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {swarms-7.6.4 → swarms-7.6.5}/PKG-INFO +1 -1
- {swarms-7.6.4 → swarms-7.6.5}/pyproject.toml +1 -1
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/__init__.py +1 -3
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/agent.py +77 -0
- swarms-7.6.5/swarms/tools/mcp_integration.py +392 -0
- swarms-7.6.5/swarms/utils/vllm_wrapper.py +146 -0
- swarms-7.6.4/swarms/structs/auto_swarm.py +0 -229
- swarms-7.6.4/swarms/tools/mcp_integration.py +0 -554
- {swarms-7.6.4 → swarms-7.6.5}/LICENSE +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/README.md +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/agent_judge.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/agent_print.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/ape_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/auto_generate_swarm_config.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/consistency_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/create_agents_from_yaml.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/flexion_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/gkp_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/i_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/openai_assistant.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/reasoning_agents.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/reasoning_duo.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/agents/tool_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/artifacts/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/artifacts/main_artifact.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/cli/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/cli/create_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/cli/main.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/cli/onboarding_process.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/client/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/client/main.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/accountant_swarm_prompts.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/ag_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/aga.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/agent_judge_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/agent_prompts.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/agent_system_prompts.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/ai_research_team.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/aot_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/autobloggen.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/autoswarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/chat_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/code_interpreter.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/code_spawner.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/debate.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/documentation.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/education.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/finance_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/finance_agent_sys_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/growth_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/idea2img.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/legal_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/logistics.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/meta_system_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/multi_agent_collab_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/multi_modal_autonomous_instruction_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/multi_modal_prompts.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/multi_modal_visual_prompts.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/operations_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/personal_stylist.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/product_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/programming.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/project_manager.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/prompt_generator.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/prompt_generator_optimizer.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/python.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/react.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/reasoning_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/refiner_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/sales.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/sales_prompts.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/security_team.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/self_operating_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/sop_generator_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/summaries_prompts.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/support_agent_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/swarm_manager_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/task_assignment_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/tests.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/tools.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/urban_planning.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/visual_cot.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/worker_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/prompts/xray_swarm_prompt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/schemas/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/schemas/agent_input_schema.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/schemas/agent_step_schemas.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/schemas/base_schemas.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/agent_builder.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/agent_registry.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/agent_roles.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/agent_router.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/agents_available.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/async_workflow.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/auto_swarm_builder.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/base_structure.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/base_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/base_workflow.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/concat.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/concurrent_workflow.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/conversation.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/csv_to_agent.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/de_hallucination_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/deep_research_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/dynamic_conversational_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/graph_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/graph_workflow.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/groupchat.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/hiearchical_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/hybrid_hiearchical_peer_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/majority_voting.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/malt.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/matrix_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/meme_agent_persona_generator.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/mixture_of_agents.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/model_router.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/multi_agent_collab.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/multi_agent_exec.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/multi_agent_orchestrator.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/octotools.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/omni_agent_types.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/output_types.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/pulsar_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/queue_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/rearrange.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/round_robin.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/safe_loading.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/sequential_workflow.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/spreadsheet_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/stopping_conditions.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_arange.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_builder.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_eval.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_id_generator.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_load_balancer.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_matcher.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_output_type.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_registry.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarm_router.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/swarming_architectures.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/talk_hier.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/tree_swarm.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/utils.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/various_alt_swarms.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/structs/workspace_manager.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/telemetry/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/telemetry/bootup.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/telemetry/main.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/base_tool.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/cohere_func_call_schema.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/func_calling_utils.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/func_to_str.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/function_util.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/json_former.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/json_utils.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/logits_processor.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/openai_func_calling_schema_pydantic.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/openai_tool_creator_decorator.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/py_func_to_openai_func_str.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/pydantic_to_json.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/tool_parse_exec.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/tool_registry.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/tool_schema_base_model.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/tools/tool_utils.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/__init__.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/any_to_str.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/auto_download_check_packages.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/calculate_func_metrics.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/data_to_text.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/disable_logging.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/file_processing.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/formatter.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/function_caller_model.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/history_output_formatter.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/litellm_tokenizer.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/litellm_wrapper.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/loguru_logger.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/markdown_message.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/parse_code.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/pdf_to_text.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/str_to_dict.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/swarm_reliability_checks.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/try_except_wrapper.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/visualizer.py +0 -0
- {swarms-7.6.4 → swarms-7.6.5}/swarms/utils/wrapper_clusterop.py +0 -0
```diff
--- swarms-7.6.4/swarms/structs/__init__.py
+++ swarms-7.6.5/swarms/structs/__init__.py
@@ -2,7 +2,7 @@ from swarms.structs.agent import Agent
 from swarms.structs.agent_builder import AgentsBuilder
 from swarms.structs.agents_available import showcase_available_agents
 from swarms.structs.async_workflow import AsyncWorkflow
-from swarms.structs.auto_swarm import AutoSwarm, AutoSwarmRouter
+from experimental.auto_swarm import AutoSwarm, AutoSwarmRouter
 from swarms.structs.base_structure import BaseStructure
 from swarms.structs.base_swarm import BaseSwarm
 from swarms.structs.base_workflow import BaseWorkflow
@@ -85,8 +85,6 @@ from swarms.structs.swarming_architectures import (
 __all__ = [
     "Agent",
     "AsyncWorkflow",
-    "AutoSwarm",
-    "AutoSwarmRouter",
     "BaseStructure",
     "BaseSwarm",
     "BaseWorkflow",
```
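One thing worth flagging in this hunk: `swarms/structs/auto_swarm.py` is deleted in 7.6.5 (see the file list above), and no `experimental` package is added anywhere in this release, so the rewritten import targets a module that does not ship in the sdist. A minimal check, sketched under the assumption that no third-party `experimental` package happens to be importable:

```python
# Hypothetical smoke test for the 7.6.5 import change. swarms.structs.__init__
# now runs `from experimental.auto_swarm import AutoSwarm, AutoSwarmRouter` at
# import time, so importing the subpackage itself is expected to fail unless
# an `experimental` module exists on the path.
try:
    import swarms.structs  # noqa: F401
except ModuleNotFoundError as exc:
    print(f"swarms.structs failed to import: {exc}")
```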
```diff
--- swarms-7.6.4/swarms/structs/agent.py
+++ swarms-7.6.5/swarms/structs/agent.py
@@ -58,6 +58,12 @@ from swarms.utils.litellm_tokenizer import count_tokens
 from swarms.utils.pdf_to_text import pdf_to_text
 from swarms.utils.str_to_dict import str_to_dict
 
+from swarms.tools.mcp_integration import (
+    batch_mcp_flow,
+    mcp_flow_get_tool_schema,
+    MCPServerSseParams,
+)
+
 
 # Utils
 # Custom stopping condition
@@ -352,6 +358,7 @@ class Agent:
         role: agent_roles = "worker",
         no_print: bool = False,
         tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
+        mcp_servers: List[MCPServerSseParams] = [],
         *args,
         **kwargs,
     ):
@@ -471,6 +478,7 @@ class Agent:
         self.role = role
         self.no_print = no_print
         self.tools_list_dictionary = tools_list_dictionary
+        self.mcp_servers = mcp_servers
 
         if (
             self.agent_name is not None
@@ -584,6 +592,12 @@ class Agent:
         if self.llm is None:
             self.llm = self.llm_handling()
 
+        if (
+            self.tools_list_dictionary is None
+            and self.mcp_servers is not None
+        ):
+            self.tools_list_dictionary = self.mcp_tool_handling()
+
     def llm_handling(self):
         from swarms.utils.litellm_wrapper import LiteLLM
 
@@ -631,6 +645,69 @@ class Agent:
             logger.error(f"Error in llm_handling: {e}")
             return None
 
+    def mcp_execution_flow(self, response: any):
+        """
+        Executes the MCP (Model Context Protocol) flow based on the provided response.
+
+        This method takes a response, converts it from a string to a dictionary format,
+        and checks for the presence of a tool name or a name in the response. If either
+        is found, it retrieves the tool name and proceeds to call the batch_mcp_flow
+        function to execute the corresponding tool actions.
+
+        Args:
+            response (any): The response to be processed, which can be in string format
+                that represents a dictionary.
+
+        Returns:
+            The output from the batch_mcp_flow function, which contains the results of
+            the tool execution. If an error occurs during processing, it logs the error
+            and returns None.
+
+        Raises:
+            Exception: Logs any exceptions that occur during the execution flow.
+        """
+        try:
+            response = str_to_dict(response)
+
+            tool_output = batch_mcp_flow(
+                self.mcp_servers,
+                function_call=response,
+            )
+
+            return tool_output
+        except Exception as e:
+            logger.error(f"Error in mcp_execution_flow: {e}")
+            return None
+
+    def mcp_tool_handling(self):
+        """
+        Handles the retrieval of tool schemas from the MCP servers.
+
+        This method iterates over the list of MCP servers, retrieves the tool schema
+        for each server using the mcp_flow_get_tool_schema function, and compiles
+        these schemas into a list. The resulting list is stored in the
+        tools_list_dictionary attribute.
+
+        Returns:
+            list: A list of tool schemas retrieved from the MCP servers. If an error
+                occurs during the retrieval process, it logs the error and returns None.
+
+        Raises:
+            Exception: Logs any exceptions that occur during the tool handling process.
+        """
+        try:
+            self.tools_list_dictionary = []
+
+            for mcp_server in self.mcp_servers:
+                tool_schema = mcp_flow_get_tool_schema(mcp_server)
+                self.tools_list_dictionary.append(tool_schema)
+
+            print(self.tools_list_dictionary)
+            return self.tools_list_dictionary
+        except Exception as e:
+            logger.error(f"Error in mcp_tool_handling: {e}")
+            return None
+
     def setup_config(self):
         # The max_loops will be set dynamically if the dynamic_loop
         if self.dynamic_loops is True:
```
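Taken together, the agent.py hunks wire MCP into `Agent`: `__init__` grows an `mcp_servers` parameter, and when `tools_list_dictionary` is left as `None` it is populated by `mcp_tool_handling()`, which fetches a schema from each configured server; `mcp_execution_flow()` then routes a model's function-call payload through `batch_mcp_flow()`. A minimal sketch of the new surface, assuming a running MCP SSE server at the placeholder URL and whatever model configuration your `Agent` setup already uses:

```python
from swarms.structs.agent import Agent
from swarms.tools.mcp_integration import MCPServerSseParams

# MCPServerSseParams is a TypedDict, so it constructs like a dict.
server = MCPServerSseParams(
    url="http://localhost:8000/sse",  # placeholder MCP SSE endpoint
    timeout=5,
)

# With tools_list_dictionary unset, Agent.__init__ calls mcp_tool_handling()
# and stores the fetched schemas on the agent.
agent = Agent(
    agent_name="mcp-demo-agent",
    mcp_servers=[server],
)

# mcp_execution_flow() expects a string that str_to_dict() can parse into a
# function call carrying a "tool_name" or "name" key; "add" is a hypothetical
# tool name, not one from this diff.
result = agent.mcp_execution_flow(
    '{"tool_name": "add", "arguments": {"a": 1, "b": 2}}'
)
print(result)
```

Note also the mutable default `mcp_servers: List[MCPServerSseParams] = []`, which is shared across all `Agent` instances that do not pass their own list.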
```diff
--- /dev/null
+++ swarms-7.6.5/swarms/tools/mcp_integration.py
@@ -0,0 +1,392 @@
+from __future__ import annotations
+
+from typing import Any, List
+
+
+from loguru import logger
+
+import abc
+import asyncio
+from contextlib import AbstractAsyncContextManager, AsyncExitStack
+from pathlib import Path
+from typing import Literal
+
+from anyio.streams.memory import (
+    MemoryObjectReceiveStream,
+    MemoryObjectSendStream,
+)
+from mcp import (
+    ClientSession,
+    StdioServerParameters,
+    Tool as MCPTool,
+    stdio_client,
+)
+from mcp.client.sse import sse_client
+from mcp.types import CallToolResult, JSONRPCMessage
+from typing_extensions import NotRequired, TypedDict
+
+from swarms.utils.any_to_str import any_to_str
+
+
+class MCPServer(abc.ABC):
+    """Base class for Model Context Protocol servers."""
+
+    @abc.abstractmethod
+    async def connect(self):
+        """Connect to the server. For example, this might mean spawning a subprocess or
+        opening a network connection. The server is expected to remain connected until
+        `cleanup()` is called.
+        """
+        pass
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """A readable name for the server."""
+        pass
+
+    @abc.abstractmethod
+    async def cleanup(self):
+        """Cleanup the server. For example, this might mean closing a subprocess or
+        closing a network connection.
+        """
+        pass
+
+    @abc.abstractmethod
+    async def list_tools(self) -> list[MCPTool]:
+        """List the tools available on the server."""
+        pass
+
+    @abc.abstractmethod
+    async def call_tool(
+        self, tool_name: str, arguments: dict[str, Any] | None
+    ) -> CallToolResult:
+        """Invoke a tool on the server."""
+        pass
+
+
+class _MCPServerWithClientSession(MCPServer, abc.ABC):
+    """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
+
+    def __init__(self, cache_tools_list: bool):
+        """
+        Args:
+            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
+                cached and only fetched from the server once. If `False`, the tools list will be
+                fetched from the server on each call to `list_tools()`. The cache can be invalidated
+                by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
+                server will not change its tools list, because it can drastically improve latency
+                (by avoiding a round-trip to the server every time).
+        """
+        self.session: ClientSession | None = None
+        self.exit_stack: AsyncExitStack = AsyncExitStack()
+        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
+        self.cache_tools_list = cache_tools_list
+
+        # The cache is always dirty at startup, so that we fetch tools at least once
+        self._cache_dirty = True
+        self._tools_list: list[MCPTool] | None = None
+
+    @abc.abstractmethod
+    def create_streams(
+        self,
+    ) -> AbstractAsyncContextManager[
+        tuple[
+            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
+            MemoryObjectSendStream[JSONRPCMessage],
+        ]
+    ]:
+        """Create the streams for the server."""
+        pass
+
+    async def __aenter__(self):
+        await self.connect()
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.cleanup()
+
+    def invalidate_tools_cache(self):
+        """Invalidate the tools cache."""
+        self._cache_dirty = True
+
+    async def connect(self):
+        """Connect to the server."""
+        try:
+            transport = await self.exit_stack.enter_async_context(
+                self.create_streams()
+            )
+            read, write = transport
+            session = await self.exit_stack.enter_async_context(
+                ClientSession(read, write)
+            )
+            await session.initialize()
+            self.session = session
+        except Exception as e:
+            logger.error(f"Error initializing MCP server: {e}")
+            await self.cleanup()
+            raise
+
+    async def list_tools(self) -> list[MCPTool]:
+        """List the tools available on the server."""
+        if not self.session:
+            raise Exception(
+                "Server not initialized. Make sure you call `connect()` first."
+            )
+
+        # Return from cache if caching is enabled, we have tools, and the cache is not dirty
+        if (
+            self.cache_tools_list
+            and not self._cache_dirty
+            and self._tools_list
+        ):
+            return self._tools_list
+
+        # Reset the cache dirty to False
+        self._cache_dirty = False
+
+        # Fetch the tools from the server
+        self._tools_list = (await self.session.list_tools()).tools
+        return self._tools_list
+
+    async def call_tool(
+        self, arguments: dict[str, Any] | None
+    ) -> CallToolResult:
+        """Invoke a tool on the server."""
+        tool_name = arguments.get("tool_name") or arguments.get(
+            "name"
+        )
+
+        if not tool_name:
+            raise Exception("No tool name found in arguments")
+
+        if not self.session:
+            raise Exception(
+                "Server not initialized. Make sure you call `connect()` first."
+            )
+
+        return await self.session.call_tool(tool_name, arguments)
+
+    async def cleanup(self):
+        """Cleanup the server."""
+        async with self._cleanup_lock:
+            try:
+                await self.exit_stack.aclose()
+                self.session = None
+            except Exception as e:
+                logger.error(f"Error cleaning up server: {e}")
+
+
+class MCPServerStdioParams(TypedDict):
+    """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
+    import.
+    """
+
+    command: str
+    """The executable to run to start the server. For example, `python` or `node`."""
+
+    args: NotRequired[list[str]]
+    """Command line args to pass to the `command` executable. For example, `['foo.py']` or
+    `['server.js', '--port', '8080']`."""
+
+    env: NotRequired[dict[str, str]]
+    """The environment variables to set for the server. ."""
+
+    cwd: NotRequired[str | Path]
+    """The working directory to use when spawning the process."""
+
+    encoding: NotRequired[str]
+    """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
+
+    encoding_error_handler: NotRequired[
+        Literal["strict", "ignore", "replace"]
+    ]
+    """The text encoding error handler. Defaults to `strict`.
+
+    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
+    explanations of possible values.
+    """
+
+
+class MCPServerStdio(_MCPServerWithClientSession):
+    """MCP server implementation that uses the stdio transport. See the [spec]
+    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
+    details.
+    """
+
+    def __init__(
+        self,
+        params: MCPServerStdioParams,
+        cache_tools_list: bool = False,
+        name: str | None = None,
+    ):
+        """Create a new MCP server based on the stdio transport.
+
+        Args:
+            params: The params that configure the server. This includes the command to run to
+                start the server, the args to pass to the command, the environment variables to
+                set for the server, the working directory to use when spawning the process, and
+                the text encoding used when sending/receiving messages to the server.
+            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
+                cached and only fetched from the server once. If `False`, the tools list will be
+                fetched from the server on each call to `list_tools()`. The cache can be
+                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
+                if you know the server will not change its tools list, because it can drastically
+                improve latency (by avoiding a round-trip to the server every time).
+            name: A readable name for the server. If not provided, we'll create one from the
+                command.
+        """
+        super().__init__(cache_tools_list)
+
+        self.params = StdioServerParameters(
+            command=params["command"],
+            args=params.get("args", []),
+            env=params.get("env"),
+            cwd=params.get("cwd"),
+            encoding=params.get("encoding", "utf-8"),
+            encoding_error_handler=params.get(
+                "encoding_error_handler", "strict"
+            ),
+        )
+
+        self._name = name or f"stdio: {self.params.command}"
+
+    def create_streams(
+        self,
+    ) -> AbstractAsyncContextManager[
+        tuple[
+            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
+            MemoryObjectSendStream[JSONRPCMessage],
+        ]
+    ]:
+        """Create the streams for the server."""
+        return stdio_client(self.params)
+
+    @property
+    def name(self) -> str:
+        """A readable name for the server."""
+        return self._name
+
+
+class MCPServerSseParams(TypedDict):
+    """Mirrors the params in`mcp.client.sse.sse_client`."""
+
+    url: str
+    """The URL of the server."""
+
+    headers: NotRequired[dict[str, str]]
+    """The headers to send to the server."""
+
+    timeout: NotRequired[float]
+    """The timeout for the HTTP request. Defaults to 5 seconds."""
+
+    sse_read_timeout: NotRequired[float]
+    """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
+
+
+class MCPServerSse(_MCPServerWithClientSession):
+    """MCP server implementation that uses the HTTP with SSE transport. See the [spec]
+    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
+    for details.
+    """
+
+    def __init__(
+        self,
+        params: MCPServerSseParams,
+        cache_tools_list: bool = False,
+        name: str | None = None,
+    ):
+        """Create a new MCP server based on the HTTP with SSE transport.
+
+        Args:
+            params: The params that configure the server. This includes the URL of the server,
+                the headers to send to the server, the timeout for the HTTP request, and the
+                timeout for the SSE connection.
+
+            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
+                cached and only fetched from the server once. If `False`, the tools list will be
+                fetched from the server on each call to `list_tools()`. The cache can be
+                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
+                if you know the server will not change its tools list, because it can drastically
+                improve latency (by avoiding a round-trip to the server every time).
+
+            name: A readable name for the server. If not provided, we'll create one from the
+                URL.
+        """
+        super().__init__(cache_tools_list)
+
+        self.params = params
+        self._name = name or f"sse: {self.params['url']}"
+
+    def create_streams(
+        self,
+    ) -> AbstractAsyncContextManager[
+        tuple[
+            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
+            MemoryObjectSendStream[JSONRPCMessage],
+        ]
+    ]:
+        """Create the streams for the server."""
+        return sse_client(
+            url=self.params["url"],
+            headers=self.params.get("headers", None),
+            timeout=self.params.get("timeout", 5),
+            sse_read_timeout=self.params.get(
+                "sse_read_timeout", 60 * 5
+            ),
+        )
+
+    @property
+    def name(self) -> str:
+        """A readable name for the server."""
+        return self._name
+
+
+def mcp_flow_get_tool_schema(
+    params: MCPServerSseParams,
+) -> MCPServer:
+    server = MCPServerSse(params, cache_tools_list=True)
+
+    # Connect the server
+    asyncio.run(server.connect())
+
+    # Return the server
+    output = asyncio.run(server.list_tools())
+
+    # Cleanup the server
+    asyncio.run(server.cleanup())
+
+    return output.model_dump()
+
+
+def mcp_flow(
+    params: MCPServerSseParams,
+    function_call: dict[str, Any],
+) -> MCPServer:
+    server = MCPServerSse(params, cache_tools_list=True)
+
+    # Connect the server
+    asyncio.run(server.connect())
+
+    # Return the server
+    output = asyncio.run(server.call_tool(function_call))
+
+    output = output.model_dump()
+
+    # Cleanup the server
+    asyncio.run(server.cleanup())
+
+    return any_to_str(output)
+
+
+def batch_mcp_flow(
+    params: List[MCPServerSseParams],
+    function_call: List[dict[str, Any]] = [],
+) -> MCPServer:
+    output_list = []
+
+    for param in params:
+        output = mcp_flow(param, function_call)
+        output_list.append(output)
+
+    return output_list
```
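The new module layers a small client API over the `mcp` SDK: `MCPServer` defines the interface, `_MCPServerWithClientSession` owns an `AsyncExitStack`-scoped `ClientSession` plus optional tool-list caching, and `MCPServerStdio`/`MCPServerSse` bind that machinery to the two transports. The async context-manager path is the most direct way to exercise it; a minimal sketch, assuming an MCP server at the placeholder URL exposing a hypothetical `add` tool:

```python
import asyncio

from swarms.tools.mcp_integration import MCPServerSse, MCPServerSseParams


async def main() -> None:
    params = MCPServerSseParams(url="http://localhost:8000/sse")  # placeholder

    # __aenter__ calls connect(); __aexit__ calls cleanup().
    async with MCPServerSse(params, cache_tools_list=True) as server:
        for tool in await server.list_tools():
            print(tool.name)

        # The concrete call_tool() takes one dict and pulls the tool name
        # out of it ("tool_name" or "name"); "add" is hypothetical.
        result = await server.call_tool(
            {"tool_name": "add", "arguments": {"a": 1, "b": 2}}
        )
        print(result)


asyncio.run(main())
```

A few quirks in this version are worth knowing before relying on the synchronous helpers: the concrete `call_tool(self, arguments)` does not match the `call_tool(self, tool_name, arguments)` signature declared on the `MCPServer` ABC; `mcp_flow_get_tool_schema()` and `mcp_flow()` drive one connection through several separate `asyncio.run()` calls, each of which creates a fresh event loop; and `mcp_flow_get_tool_schema()` ends with `output.model_dump()` on the plain list returned by `list_tools()`, which lists do not provide. The context-manager route above sidesteps all three.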
```diff
--- /dev/null
+++ swarms-7.6.5/swarms/utils/vllm_wrapper.py
@@ -0,0 +1,146 @@
+from typing import List, Optional, Dict, Any
+from loguru import logger
+
+try:
+    from vllm import LLM, SamplingParams
+except ImportError:
+    import subprocess
+    import sys
+
+    print("Installing vllm")
+    subprocess.check_call(
+        [sys.executable, "-m", "pip", "install", "-U", "vllm"]
+    )
+    print("vllm installed")
+    from vllm import LLM, SamplingParams
+
+
+class VLLMWrapper:
+    """
+    A wrapper class for vLLM that provides a similar interface to LiteLLM.
+    This class handles model initialization and inference using vLLM.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
+        system_prompt: Optional[str] = None,
+        stream: bool = False,
+        temperature: float = 0.5,
+        max_tokens: int = 4000,
+        max_completion_tokens: int = 4000,
+        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: str = "auto",
+        parallel_tool_calls: bool = False,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initialize the vLLM wrapper with the given parameters.
+
+        Args:
+            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
+            system_prompt (str, optional): The system prompt to use. Defaults to None.
+            stream (bool): Whether to stream the output. Defaults to False.
+            temperature (float): The temperature for sampling. Defaults to 0.5.
+            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
+            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
+            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
+            tool_choice (str): How to choose tools. Defaults to "auto".
+            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
+        """
+        self.model_name = model_name
+        self.system_prompt = system_prompt
+        self.stream = stream
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
+        self.tools_list_dictionary = tools_list_dictionary
+        self.tool_choice = tool_choice
+        self.parallel_tool_calls = parallel_tool_calls
+
+        # Initialize vLLM
+        self.llm = LLM(model=model_name, **kwargs)
+        self.sampling_params = SamplingParams(
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+    def _prepare_prompt(self, task: str) -> str:
+        """
+        Prepare the prompt for the given task.
+
+        Args:
+            task (str): The task to prepare the prompt for.
+
+        Returns:
+            str: The prepared prompt.
+        """
+        if self.system_prompt:
+            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
+        return f"User: {task}\nAssistant:"
+
+    def run(self, task: str, *args, **kwargs) -> str:
+        """
+        Run the model for the given task.
+
+        Args:
+            task (str): The task to run the model for.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The model's response.
+        """
+        try:
+            prompt = self._prepare_prompt(task)
+
+            outputs = self.llm.generate(prompt, self.sampling_params)
+            response = outputs[0].outputs[0].text.strip()
+
+            return response
+
+        except Exception as error:
+            logger.error(f"Error in VLLMWrapper: {error}")
+            raise error
+
+    def __call__(self, task: str, *args, **kwargs) -> str:
+        """
+        Call the model for the given task.
+
+        Args:
+            task (str): The task to run the model for.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The model's response.
+        """
+        return self.run(task, *args, **kwargs)
+
+    def batched_run(
+        self, tasks: List[str], batch_size: int = 10
+    ) -> List[str]:
+        """
+        Run the model for multiple tasks in batches.
+
+        Args:
+            tasks (List[str]): List of tasks to run.
+            batch_size (int): Size of each batch. Defaults to 10.
+
+        Returns:
+            List[str]: List of model responses.
+        """
+        logger.info(
+            f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
+        )
+        results = []
+
+        for i in range(0, len(tasks), batch_size):
+            batch = tasks[i : i + batch_size]
+            for task in batch:
+                logger.info(f"Running task: {task}")
+                results.append(self.run(task))
+
+        logger.info("Completed all tasks.")
+        return results
```
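`VLLMWrapper` mirrors the LiteLLM-style call surface (`run`, `__call__`, `batched_run`). Note the import-time side effect: if `vllm` is missing, the module pip-installs it the first time it is imported. A hedged usage sketch; the model name below is just the module's default, which is gated on Hugging Face and needs a GPU, so substitute any model you can actually load:

```python
from swarms.utils.vllm_wrapper import VLLMWrapper

llm = VLLMWrapper(
    model_name="meta-llama/Llama-2-7b-chat-hf",  # default; swap as needed
    system_prompt="You are a concise assistant.",
    temperature=0.5,
    max_tokens=256,
)

# Single task: builds a "system\n\nUser: ...\nAssistant:" prompt and
# returns the stripped completion text.
print(llm.run("Explain the SSE transport in one sentence."))

# Batched: slices tasks into groups of batch_size, but note it still calls
# run() once per task rather than passing the whole batch to vLLM's generate().
print(llm.batched_run(["2 + 2?", "Capital of France?"], batch_size=2))
```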