yaicli-0.6.3-py3-none-any.whl → yaicli-0.7.0-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- pyproject.toml +14 -1
- yaicli/cli.py +2 -2
- yaicli/config.py +1 -1
- yaicli/const.py +19 -8
- yaicli/entry.py +24 -1
- yaicli/functions/__init__.py +13 -1
- yaicli/llms/client.py +71 -38
- yaicli/llms/provider.py +3 -0
- yaicli/llms/providers/ai21_provider.py +33 -1
- yaicli/llms/providers/chatglm_provider.py +38 -11
- yaicli/llms/providers/cohere_provider.py +6 -3
- yaicli/llms/providers/deepseek_provider.py +2 -1
- yaicli/llms/providers/doubao_provider.py +0 -14
- yaicli/llms/providers/gemini_provider.py +29 -28
- yaicli/llms/providers/huggingface_provider.py +40 -0
- yaicli/llms/providers/infiniai_provider.py +4 -2
- yaicli/llms/providers/modelscope_provider.py +2 -1
- yaicli/llms/providers/openai_provider.py +20 -11
- yaicli/llms/providers/siliconflow_provider.py +2 -1
- yaicli/tools/__init__.py +127 -0
- yaicli/tools/function.py +90 -0
- yaicli/tools/mcp.py +459 -0
- yaicli/utils.py +34 -0
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/METADATA +231 -19
- yaicli-0.7.0.dist-info/RECORD +49 -0
- yaicli/tools.py +0 -159
- yaicli-0.6.3.dist-info/RECORD +0 -46
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/WHEEL +0 -0
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/entry_points.txt +0 -0
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,25 +1,18 @@
|
|
1
1
|
import json
|
2
|
-
from functools import wraps
|
3
2
|
from typing import Any, Callable, Dict, Generator, List
|
4
3
|
|
5
4
|
import google.genai as genai
|
6
5
|
from google.genai import types
|
7
6
|
|
7
|
+
from yaicli.tools.mcp import get_mcp_manager
|
8
|
+
|
8
9
|
from ...config import cfg
|
9
10
|
from ...console import get_console
|
10
11
|
from ...schemas import ChatMessage, LLMResponse
|
11
|
-
from ...tools import
|
12
|
+
from ...tools.function import get_functions_gemini_format
|
12
13
|
from ..provider import Provider
|
13
14
|
|
14
15
|
|
15
|
-
def wrap_function(func):
|
16
|
-
@wraps(func)
|
17
|
-
def wrapper(*args, **kwargs):
|
18
|
-
return func(*args, **kwargs)
|
19
|
-
|
20
|
-
return wrapper
|
21
|
-
|
22
|
-
|
23
16
|
class GeminiProvider(Provider):
|
24
17
|
"""Gemini provider implementation based on google-genai library"""
|
25
18
|
|
@@ -28,6 +21,7 @@ class GeminiProvider(Provider):
|
|
28
21
|
def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
|
29
22
|
self.config = config
|
30
23
|
self.enable_function = self.config["ENABLE_FUNCTIONS"]
|
24
|
+
self.enable_mcp = self.config["ENABLE_MCP"]
|
31
25
|
self.verbose = verbose
|
32
26
|
|
33
27
|
# Initialize client
|
@@ -67,16 +61,17 @@ class GeminiProvider(Provider):
|
|
67
61
|
config_map["frequency_penalty"] = self.config["FREQUENCY_PENALTY"]
|
68
62
|
if self.config.get("SEED"):
|
69
63
|
config_map["seed"] = self.config["SEED"]
|
70
|
-
# Indicates whether to include thoughts in the response.
|
64
|
+
# Indicates whether to include thoughts in the response.
|
65
|
+
# If true, thoughts are returned only if the model supports thought and thoughts are available.
|
71
66
|
thinking_config_map = {"include_thoughts": self.config.get("INCLUDE_THOUGHTS", True)}
|
72
67
|
if self.config.get("THINKING_BUDGET"):
|
73
68
|
thinking_config_map["thinking_budget"] = int(self.config["THINKING_BUDGET"])
|
74
69
|
config_map["thinking_config"] = types.ThinkingConfig(**thinking_config_map)
|
75
|
-
|
76
|
-
if self.enable_function:
|
70
|
+
if self.enable_function or self.enable_mcp:
|
77
71
|
# TODO: support disable automatic function calling
|
78
72
|
# config.automatic_function_calling = types.AutomaticFunctionCallingConfig(disable=False)
|
79
|
-
|
73
|
+
config_map["tools"] = self.gen_gemini_functions()
|
74
|
+
config = types.GenerateContentConfig(**config_map)
|
80
75
|
return config
|
81
76
|
|
82
77
|
def _convert_messages(self, messages: List[ChatMessage]) -> List[types.Content]:
|
@@ -88,7 +83,9 @@ class GeminiProvider(Provider):
|
|
88
83
|
content = types.Content(role=self._map_role(msg.role), parts=[types.Part(text=msg.content)])
|
89
84
|
if msg.role == "tool":
|
90
85
|
content.role = "user"
|
91
|
-
content.parts = [
|
86
|
+
content.parts = [
|
87
|
+
types.Part.from_function_response(name=msg.name or "", response={"result": msg.content})
|
88
|
+
]
|
92
89
|
converted_messages.append(content)
|
93
90
|
return converted_messages
|
94
91
|
|
@@ -101,15 +98,19 @@ class GeminiProvider(Provider):
|
|
101
98
|
|
102
99
|
def gen_gemini_functions(self) -> List[Callable[..., Any]]:
|
103
100
|
"""Wrap Gemini functions from OpenAI functions for automatic function calling"""
|
104
|
-
func_name_map = get_func_name_map()
|
105
|
-
if not func_name_map:
|
106
|
-
return []
|
107
101
|
funcs = []
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
102
|
+
|
103
|
+
# Add regular functions
|
104
|
+
if self.enable_function:
|
105
|
+
funcs.extend(get_functions_gemini_format())
|
106
|
+
|
107
|
+
# Add MCP functions if enabled
|
108
|
+
if self.enable_mcp:
|
109
|
+
try:
|
110
|
+
mcp_tools = get_mcp_manager().to_gemini_tools()
|
111
|
+
funcs.extend(mcp_tools)
|
112
|
+
except (ImportError, Exception) as e:
|
113
|
+
self.console.print(f"Failed to load MCP tools for Gemini: {e}", style="red")
|
113
114
|
return funcs
|
114
115
|
|
115
116
|
def completion(
|
@@ -137,14 +138,14 @@ class GeminiProvider(Provider):
|
|
137
138
|
self.console.print(gemini_messages)
|
138
139
|
chat_config = self.get_chat_config()
|
139
140
|
chat_config.system_instruction = messages[0].content
|
140
|
-
chat = self.client.chats.create(model=self.config["MODEL"], history=gemini_messages, config=chat_config)
|
141
|
+
chat = self.client.chats.create(model=self.config["MODEL"], history=gemini_messages, config=chat_config) # type: ignore
|
141
142
|
message = messages[-1].content
|
142
143
|
|
143
144
|
if stream:
|
144
|
-
response = chat.send_message_stream(message=message)
|
145
|
+
response = chat.send_message_stream(message=message) # type: ignore
|
145
146
|
yield from self._handle_stream_response(response)
|
146
147
|
else:
|
147
|
-
response = chat.send_message(message=message)
|
148
|
+
response = chat.send_message(message=message) # type: ignore
|
148
149
|
yield from self._handle_normal_response(response)
|
149
150
|
|
150
151
|
def _handle_normal_response(self, response) -> Generator[LLMResponse, None, None]:
|
@@ -158,7 +159,7 @@ class GeminiProvider(Provider):
|
|
158
159
|
return
|
159
160
|
for part in response.candidates[0].content.parts:
|
160
161
|
if part.thought:
|
161
|
-
yield LLMResponse(reasoning=part.text,
|
162
|
+
yield LLMResponse(reasoning=part.text, finish_reason="stop")
|
162
163
|
else:
|
163
164
|
yield LLMResponse(reasoning=None, content=part.text, finish_reason="stop")
|
164
165
|
|
@@ -181,7 +182,7 @@ class GeminiProvider(Provider):
|
|
181
182
|
reasoning = None
|
182
183
|
yield LLMResponse(
|
183
184
|
reasoning=reasoning,
|
184
|
-
content=content,
|
185
|
+
content=content or "",
|
185
186
|
tool_call=tool_call if finish_reason == "tool_calls" else None,
|
186
187
|
finish_reason=finish_reason or None,
|
187
188
|
)
|
@@ -0,0 +1,40 @@
|
|
1
|
+
from typing import Any, Dict
|
2
|
+
|
3
|
+
from huggingface_hub import InferenceClient
|
4
|
+
|
5
|
+
from .chatglm_provider import ChatglmProvider
|
6
|
+
|
7
|
+
|
8
|
+
class HuggingFaceProvider(ChatglmProvider):
    """Provider for the Hugging Face Inference API.

    Uses ``huggingface_hub.InferenceClient`` as the client class. Only the client
    class, the accepted completion parameters and the client construction are
    overridden here; everything else is inherited from ``ChatglmProvider``.
    """

    CLIENT_CLS = InferenceClient
    # Value passed as ``provider`` to the client when HF_PROVIDER is not configured.
    DEFAULT_PROVIDER = "auto"

    # API parameter name -> config key
    COMPLETION_PARAMS_KEYS = {
        "model": "MODEL",
        "temperature": "TEMPERATURE",
        "top_p": "TOP_P",
        "max_tokens": "MAX_TOKENS",
        "extra_body": "EXTRA_BODY",
    }

    def get_client_params(self) -> Dict[str, Any]:
        """Build the keyword arguments used to construct the ``InferenceClient``.

        Returns:
            Dict[str, Any]: Always ``api_key``, ``timeout`` and ``provider``; ``base_url``,
            ``headers`` and ``bill_to`` are added only when the matching config value
            is truthy.
        """
        conf = self.config
        params: Dict[str, Any] = {
            "api_key": conf["API_KEY"],
            "timeout": conf["TIMEOUT"],
            "provider": conf.get("HF_PROVIDER") or self.DEFAULT_PROVIDER,
        }

        base_url = conf["BASE_URL"]
        if base_url:
            params["base_url"] = base_url

        extra_headers = conf["EXTRA_HEADERS"]
        if extra_headers:
            # The app identification headers are written last, so they win over
            # user-supplied headers with the same name.
            params["headers"] = {
                **extra_headers,
                "X-Title": self.APP_NAME,
                "HTTP-Referer": self.APP_REFERER,
            }

        bill_to = conf.get("BILL_TO")
        if bill_to:
            params["bill_to"] = bill_to

        return params
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from typing import Any, Dict
|
2
2
|
|
3
|
+
from ...config import cfg
|
3
4
|
from .openai_provider import OpenAIProvider
|
4
5
|
|
5
6
|
|
@@ -8,7 +9,7 @@ class InfiniAIProvider(OpenAIProvider):
|
|
8
9
|
|
9
10
|
DEFAULT_BASE_URL = "https://cloud.infini-ai.com/maas/v1"
|
10
11
|
|
11
|
-
def __init__(self, config: dict =
|
12
|
+
def __init__(self, config: dict = cfg, **kwargs):
|
12
13
|
super().__init__(config, **kwargs)
|
13
14
|
if self.enable_function:
|
14
15
|
self.console.print("InfiniAI does not support functions, disabled", style="yellow")
|
@@ -16,5 +17,6 @@ class InfiniAIProvider(OpenAIProvider):
|
|
16
17
|
|
17
18
|
def get_completion_params(self) -> Dict[str, Any]:
|
18
19
|
params = super().get_completion_params()
|
19
|
-
|
20
|
+
if "max_completion_tokens" in params:
|
21
|
+
params["max_tokens"] = params.pop("max_completion_tokens")
|
20
22
|
return params
|
@@ -10,5 +10,6 @@ class ModelScopeProvider(OpenAIProvider):
|
|
10
10
|
|
11
11
|
def get_completion_params(self) -> Dict[str, Any]:
|
12
12
|
params = super().get_completion_params()
|
13
|
-
|
13
|
+
if "max_completion_tokens" in params:
|
14
|
+
params["max_tokens"] = params.pop("max_completion_tokens")
|
14
15
|
return params
|
@@ -10,6 +10,7 @@ from ...config import cfg
|
|
10
10
|
from ...console import get_console
|
11
11
|
from ...schemas import ChatMessage, LLMResponse, ToolCall
|
12
12
|
from ...tools import get_openai_schemas
|
13
|
+
from ...tools.mcp import get_mcp_manager
|
13
14
|
from ..provider import Provider
|
14
15
|
|
15
16
|
|
@@ -19,7 +20,7 @@ class OpenAIProvider(Provider):
|
|
19
20
|
DEFAULT_BASE_URL = "https://api.openai.com/v1"
|
20
21
|
CLIENT_CLS = openai.OpenAI
|
21
22
|
# Base mapping between config keys and API parameter names
|
22
|
-
|
23
|
+
COMPLETION_PARAMS_KEYS = {
|
23
24
|
"model": "MODEL",
|
24
25
|
"temperature": "TEMPERATURE",
|
25
26
|
"top_p": "TOP_P",
|
@@ -34,6 +35,7 @@ class OpenAIProvider(Provider):
|
|
34
35
|
if not self.config.get("API_KEY"):
|
35
36
|
raise ValueError("API_KEY is required")
|
36
37
|
self.enable_function = self.config["ENABLE_FUNCTIONS"]
|
38
|
+
self.enable_mcp = self.config["ENABLE_MCP"]
|
37
39
|
self.verbose = verbose
|
38
40
|
|
39
41
|
# Initialize client
|
@@ -50,15 +52,12 @@ class OpenAIProvider(Provider):
|
|
50
52
|
client_params = {
|
51
53
|
"api_key": self.config["API_KEY"],
|
52
54
|
"base_url": self.config["BASE_URL"] or self.DEFAULT_BASE_URL,
|
55
|
+
"default_headers": {"X-Title": self.APP_NAME, "HTTP_Referer": self.APP_REFERER},
|
53
56
|
}
|
54
57
|
|
55
58
|
# Add extra headers if set
|
56
59
|
if self.config["EXTRA_HEADERS"]:
|
57
|
-
client_params["default_headers"] = {
|
58
|
-
**self.config["EXTRA_HEADERS"],
|
59
|
-
"X-Title": self.APP_NAME,
|
60
|
-
"HTTP-Referer": self.APP_REFERER,
|
61
|
-
}
|
60
|
+
client_params["default_headers"] = {**self.config["EXTRA_HEADERS"], **client_params["default_headers"]}
|
62
61
|
return client_params
|
63
62
|
|
64
63
|
def get_completion_params_keys(self) -> Dict[str, str]:
|
@@ -69,7 +68,7 @@ class OpenAIProvider(Provider):
|
|
69
68
|
Returns:
|
70
69
|
Dict[str, str]: Mapping from API parameter names to config keys
|
71
70
|
"""
|
72
|
-
return self.
|
71
|
+
return self.COMPLETION_PARAMS_KEYS.copy()
|
73
72
|
|
74
73
|
def get_completion_params(self) -> Dict[str, Any]:
|
75
74
|
"""
|
@@ -89,7 +88,7 @@ class OpenAIProvider(Provider):
|
|
89
88
|
"""Convert a list of ChatMessage objects to a list of OpenAI message dicts."""
|
90
89
|
converted_messages = []
|
91
90
|
for msg in messages:
|
92
|
-
message = {"role": msg.role, "content": msg.content or ""}
|
91
|
+
message: Dict[str, Any] = {"role": msg.role, "content": msg.content or ""}
|
93
92
|
|
94
93
|
if msg.name:
|
95
94
|
message["name"] = msg.name
|
@@ -134,11 +133,21 @@ class OpenAIProvider(Provider):
|
|
134
133
|
params = self.completion_params.copy()
|
135
134
|
params["messages"] = openai_messages
|
136
135
|
params["stream"] = stream
|
136
|
+
tools = []
|
137
137
|
|
138
138
|
if self.enable_function:
|
139
|
-
tools
|
140
|
-
|
141
|
-
|
139
|
+
tools.extend(get_openai_schemas())
|
140
|
+
|
141
|
+
# Add MCP tools if enabled
|
142
|
+
if self.enable_mcp:
|
143
|
+
try:
|
144
|
+
mcp_tools = get_mcp_manager().to_openai_tools()
|
145
|
+
except (ValueError, FileNotFoundError) as e:
|
146
|
+
self.console.print(f"Failed to load MCP tools: {e}", style="red")
|
147
|
+
mcp_tools = []
|
148
|
+
tools.extend(mcp_tools)
|
149
|
+
if tools:
|
150
|
+
params["tools"] = tools
|
142
151
|
|
143
152
|
try:
|
144
153
|
if stream:
|
@@ -10,5 +10,6 @@ class SiliconFlowProvider(OpenAIProvider):
|
|
10
10
|
|
11
11
|
def get_completion_params(self) -> Dict[str, Any]:
|
12
12
|
params = super().get_completion_params()
|
13
|
-
|
13
|
+
if "max_completion_tokens" in params:
|
14
|
+
params["max_tokens"] = params.pop("max_completion_tokens")
|
14
15
|
return params
|
yaicli/tools/__init__.py
ADDED
@@ -0,0 +1,127 @@
|
|
1
|
+
from typing import Any, Dict, List, Tuple, cast
|
2
|
+
|
3
|
+
from json_repair import repair_json
|
4
|
+
from mcp import types
|
5
|
+
from rich.panel import Panel
|
6
|
+
|
7
|
+
from ..config import cfg
|
8
|
+
from ..console import get_console
|
9
|
+
from ..schemas import ToolCall
|
10
|
+
from .function import get_function, list_functions
|
11
|
+
from .mcp import MCP_TOOL_NAME_PREFIX, get_mcp, get_mcp_manager, parse_mcp_tool_name
|
12
|
+
|
13
|
+
console = get_console()
|
14
|
+
|
15
|
+
|
16
|
+
def get_openai_schemas() -> List[Dict[str, Any]]:
    """Describe every locally loaded function as an OpenAI ``tools`` entry.

    Returns:
        List[Dict[str, Any]]: One ``{"type": "function", "function": {...}}`` dict per
        function returned by ``list_functions()``, carrying its name, description and
        parameters schema.
    """
    return [
        {
            "type": "function",
            "function": {
                "name": fn.name,
                "description": fn.description,
                "parameters": fn.parameters,
            },
        }
        for fn in list_functions()
    ]
|
34
|
+
|
35
|
+
|
36
|
+
def get_openai_mcp_tools() -> list[dict[str, Any]]:
    """Get the MCP tools known to the MCP manager in OpenAI ``tools`` format.

    Thin wrapper around ``get_mcp_manager().to_openai_tools()``; unlike
    ``get_openai_schemas`` it does not include the locally loaded functions.

    Returns:
        list[dict[str, Any]]: MCP tool schemas in OpenAI format.
    """
    return get_mcp_manager().to_openai_tools()
|
43
|
+
|
44
|
+
|
45
|
+
def execute_mcp_tool(tool_name: str, tool_kwargs: dict) -> str:
    """Execute an MCP tool and return its output as text.

    Args:
        tool_name: The name of the tool to execute
        tool_kwargs: The arguments to pass to the tool

    Returns:
        str: The tool output. A non-empty list result is rendered item by item
        (``types.TextContent`` via ``.text``, anything else via ``str()``) and the
        pieces are joined with newlines, so no content item is dropped. If executing
        the tool raises, the error is printed and the error message is returned.
    """
    manager = get_mcp_manager()
    # The lookup is outside the try: errors from get_tool propagate to the caller.
    tool = manager.get_tool(tool_name)

    def _as_text(item: Any) -> str:
        # Render a single MCP content item (or any other value) as text.
        return item.text if isinstance(item, types.TextContent) else str(item)

    try:
        result = tool.execute(**tool_kwargs)
        if isinstance(result, list) and result:
            # Keep every content item instead of only the first one.
            return "\n".join(_as_text(item) for item in result)
        return _as_text(result)
    except Exception as e:
        error_msg = f"Call MCP tool error:\nTool name: {tool_name!r}\nArguments: {tool_kwargs!r}\nError: {e}"
        console.print(error_msg, style="red")
        return error_msg
|
66
|
+
|
67
|
+
|
68
|
+
def execute_tool_call(tool_call: ToolCall) -> Tuple[str, bool]:
    """Execute a tool call and return the result

    Calls whose name starts with ``MCP_TOOL_NAME_PREFIX`` are dispatched to MCP tools,
    and ``tool_call.name`` is rewritten in place to the parsed MCP tool name. All other
    calls are dispatched to the locally loaded functions.

    Args:
        tool_call: The tool call to execute

    Returns:
        Tuple[str, bool]: (result text, success flag). On failure the text is the error
        message, which is also printed to the console. Non-string tool results are
        converted with ``str()``.
    """
    if tool_call.name.startswith(MCP_TOOL_NAME_PREFIX):
        tool_call.name = parse_mcp_tool_name(tool_call.name)
        get_tool_func = get_mcp
        show_output = cfg["SHOW_MCP_OUTPUT"]
        _type = "mcp"
    else:
        get_tool_func = get_function
        show_output = cfg["SHOW_FUNCTION_OUTPUT"]
        _type = "function"
    label = _type.title()

    console.print(f"@{label} call: {tool_call.name}({tool_call.arguments})", style="blue")

    # 1. Get the tool
    try:
        tool = get_tool_func(tool_call.name)
    except ValueError as e:
        # `!r` already quotes the name; wrapping it in extra quotes doubled them.
        error_msg = f"{label} {tool_call.name!r} does not exist: {e}"
        console.print(error_msg, style="red")
        return error_msg, False

    # 2. Parse tool arguments (repair_json attempts to fix malformed JSON from the LLM)
    try:
        arguments = repair_json(tool_call.arguments, return_objects=True)
    except Exception as e:
        error_msg = f"Invalid arguments from llm: {e}\nRaw arguments: {tool_call.arguments!r}"
        console.print(error_msg, style="red")
        return error_msg, False
    if not isinstance(arguments, dict):
        error_msg = f"Invalid arguments type: {arguments!r}, should be JSON object"
        console.print(error_msg, style="red")
        return error_msg, False
    arguments = cast(dict, arguments)

    # 3. Execute the tool. Only the call itself is guarded, so a problem while
    #    displaying the output cannot turn a successful call into a failure.
    try:
        result = tool.execute(**arguments)
    except Exception as e:
        error_msg = f"Call {_type} error: {e}\n{_type} name: {tool_call.name!r}\nArguments: {arguments!r}"
        console.print(error_msg, style="red")
        return error_msg, False

    # Tools may return non-str values; normalise to honour the declared return type.
    result_text = result if isinstance(result, str) else str(result)
    if show_output:
        # Wrap in Text so tool output is shown literally instead of parsed as rich markup
        # (e.g. an unmatched "[/x]" in the output would otherwise raise).
        from rich.text import Text

        panel = Panel(
            Text(result_text),
            title=f"{label} output",
            title_align="left",
            expand=False,
            border_style="blue",
            style="dim",
        )
        console.print(panel)
    return result_text, True
|
yaicli/tools/function.py
ADDED
@@ -0,0 +1,90 @@
|
|
1
|
+
import importlib.util
|
2
|
+
import sys
|
3
|
+
from typing import Callable, List, Optional
|
4
|
+
|
5
|
+
from instructor import OpenAISchema
|
6
|
+
|
7
|
+
from ..const import FUNCTIONS_DIR
|
8
|
+
from ..utils import wrap_function
|
9
|
+
|
10
|
+
|
11
|
+
class Function:
    """Wrapper around a user-defined ``instructor.OpenAISchema`` subclass.

    Exposes the schema's ``name``, ``description`` (default ``""``) and
    ``parameters`` (default ``{}``), plus the class's ``execute`` callable.
    """

    def __init__(self, function: type[OpenAISchema]):
        schema = function.openai_schema
        self.name = schema["name"]
        self.description = schema.get("description", "")
        self.parameters = schema.get("parameters", {})
        self.execute = function.execute  # type: ignore
|
19
|
+
|
20
|
+
|
21
|
+
_func_name_map: Optional[dict[str, Function]] = None
|
22
|
+
|
23
|
+
|
24
|
+
def get_func_name_map() -> dict[str, Function]:
    """Load (once) and return the user-defined functions keyed by name.

    Every ``*.py`` file in ``FUNCTIONS_DIR`` whose name does not start with ``_`` is
    imported and must define a ``Function`` class that subclasses
    ``instructor.OpenAISchema`` and has an ``execute`` attribute. Files are processed
    in sorted order. The result is cached in the module-level ``_func_name_map``;
    an empty result is cached too.

    Returns:
        dict[str, Function]: Function name -> ``Function`` wrapper. Empty when
        ``FUNCTIONS_DIR`` did not exist (it is created in that case) or holds no
        function files.

    Raises:
        ImportError: If a function file cannot be turned into a loadable module spec.
        TypeError: If a function file lacks a valid ``Function`` class or ``execute``.
    """
    global _func_name_map
    # `is not None` (not truthiness) so an empty map is not rescanned on every call.
    if _func_name_map is not None:
        return _func_name_map
    if not FUNCTIONS_DIR.exists():
        FUNCTIONS_DIR.mkdir(parents=True, exist_ok=True)
        _func_name_map = {}
        return _func_name_map

    functions: list[Function] = []
    for file in sorted(FUNCTIONS_DIR.glob("*.py")):
        if file.name.startswith("_"):
            continue
        # Dotted module name from the path minus its ".py" suffix. with_suffix("") is
        # used because str.rstrip(".py") strips characters, not a suffix
        # ("happy.py" would become "ha"). Both separators are handled for Windows paths.
        module_name = str(file.with_suffix("")).replace("\\", ".").replace("/", ".")
        spec = importlib.util.spec_from_file_location(module_name, str(file))
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load function module from {file}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(module_name, None)
            raise

        func_cls = getattr(module, "Function", None)
        if not isinstance(func_cls, type) or not issubclass(func_cls, OpenAISchema):
            raise TypeError(f"Function {module_name} must be a subclass of instructor.OpenAISchema")
        if not hasattr(func_cls, "execute"):
            raise TypeError(f"Function {module_name} must have an 'execute' classmethod")

        functions.append(Function(function=func_cls))

    # Cache the function map
    _func_name_map = {func.name: func for func in functions}
    return _func_name_map
|
53
|
+
|
54
|
+
|
55
|
+
def list_functions() -> list[Function]:
    """Return every function loaded from the functions directory.

    The functions are loaded on first use via ``get_func_name_map``, which also
    handles caching.
    """
    return list(get_func_name_map().values())
|
62
|
+
|
63
|
+
|
64
|
+
def get_function(name: str) -> Function:
    """Get a function by name

    Args:
        name: Function name

    Returns:
        Function: The ``Function`` wrapper registered under ``name``; call its
        ``execute`` attribute to run it.

    Raises:
        ValueError: If function not found
    """
    try:
        return get_func_name_map()[name]
    except KeyError:
        raise ValueError(f"Function {name!r} not found") from None
|
80
|
+
|
81
|
+
|
82
|
+
def get_functions_gemini_format() -> List[Callable]:
    """Build plain callables for Gemini automatic function calling.

    Each loaded function's ``execute`` is wrapped with ``wrap_function``. The wrapper
    is renamed to the function's name and its docstring is set to the function's
    description.
    """

    def _as_gemini_callable(name: str, func: Function) -> Callable:
        wrapper = wrap_function(func.execute)
        wrapper.__name__ = name
        wrapper.__doc__ = func.description
        return wrapper

    return [_as_gemini_callable(name, func) for name, func in get_func_name_map().items()]
|