arcade-core 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arcade_core/__init__.py +2 -0
- arcade_core/annotations.py +8 -0
- arcade_core/auth.py +177 -0
- arcade_core/catalog.py +894 -0
- arcade_core/config.py +23 -0
- arcade_core/config_model.py +146 -0
- arcade_core/errors.py +103 -0
- arcade_core/executor.py +129 -0
- arcade_core/output.py +64 -0
- arcade_core/parse.py +63 -0
- arcade_core/py.typed +0 -0
- arcade_core/schema.py +441 -0
- arcade_core/telemetry.py +130 -0
- arcade_core/toolkit.py +155 -0
- arcade_core/utils.py +99 -0
- arcade_core/version.py +1 -0
- arcade_core-2.0.0.dist-info/METADATA +77 -0
- arcade_core-2.0.0.dist-info/RECORD +19 -0
- arcade_core-2.0.0.dist-info/WHEEL +4 -0
arcade_core/config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
|
|
3
|
+
from arcade_core.config_model import Config
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@lru_cache(maxsize=1)
def get_config() -> Config:
    """
    Return the Arcade configuration, loading it from file on first use.

    The result is memoized: later calls hand back the same Config object
    instead of reading the file again.

    If the configuration file changes on disk, call
    `get_config.cache_clear()` so the next call reloads it.

    Returns:
        Config: The Arcade configuration.
    """
    loaded_config = Config.load_from_file()
    return loaded_config
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Module-level configuration object. This runs when the module is imported, so
# importing this module triggers `Config.load_from_file()` via `get_config()`.
config = get_config()
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, ValidationError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseConfig(BaseModel):
    """
    Shared Pydantic base class for the Arcade configuration models.
    """

    # extra="ignore": keys that are not declared as fields are ignored on validation.
    model_config = ConfigDict(extra="ignore")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ApiConfig(BaseConfig):
    """
    API section of the Arcade configuration: the API key and the API version.
    """

    # Arcade API key (required, no default).
    key: str

    # Arcade API version.
    version: str = "v1"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class UserConfig(BaseConfig):
    """
    User section of the Arcade configuration.
    """

    # User email; None when not provided.
    email: str | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Config(BaseConfig):
    """
    Configuration for Arcade.

    The configuration is persisted as YAML in `credentials.yaml` inside the
    Arcade configuration directory, nested under a top-level `cloud` key.
    """

    api: ApiConfig
    """
    Arcade API configuration.
    """
    user: UserConfig | None = None
    """
    Arcade user configuration.
    """

    def __init__(self, **data: Any):
        super().__init__(**data)

    @classmethod
    def get_config_dir_path(cls) -> Path:
        """
        Get the path to the Arcade configuration directory.

        A non-empty `ARCADE_WORK_DIR` environment variable takes precedence;
        otherwise `.arcade` under the user's home directory is used.

        Returns:
            Path: The resolved configuration directory path.
        """
        config_path = os.getenv("ARCADE_WORK_DIR") or Path.home() / ".arcade"
        return Path(config_path).resolve()

    @classmethod
    def get_config_file_path(cls) -> Path:
        """
        Get the path to the Arcade configuration file.

        Returns:
            Path: `credentials.yaml` inside the configuration directory.
        """
        return cls.get_config_dir_path() / "credentials.yaml"

    @classmethod
    def ensure_config_dir_exists(cls) -> None:
        """
        Create the configuration directory if it does not exist.
        """
        config_dir = Config.get_config_dir_path()
        if not config_dir.exists():
            config_dir.mkdir(parents=True, exist_ok=True)

    @classmethod
    def load_from_file(cls) -> "Config":
        """
        Load the configuration from the YAML file in the configuration directory.

        If no configuration file exists, a default one is written first. The
        default contains an API section without a key and no user section, so
        loading it still fails validation until an API key is supplied.

        Returns:
            Config: The loaded configuration.

        Raises:
            ValueError: If the file is empty, is not a mapping with a 'cloud'
                mapping inside it, or fails model validation.
        """
        cls.ensure_config_dir_exists()

        config_file_path = cls.get_config_file_path()

        if not config_file_path.exists():
            # Create a file using the default configuration
            default_config = cls.model_construct(api=ApiConfig.model_construct())
            default_config.save_to_file()

        # Explicit UTF-8 so the result does not depend on the platform locale.
        config_data = yaml.safe_load(config_file_path.read_text(encoding="utf-8"))

        if config_data is None:
            raise ValueError(
                "Invalid credentials.yaml file. Please ensure it is a valid YAML file."
            )

        # A scalar or list is valid YAML but cannot carry a 'cloud' section; checking the
        # type first keeps the failure a ValueError instead of a stray TypeError.
        if not isinstance(config_data, dict) or "cloud" not in config_data:
            raise ValueError("Invalid credentials.yaml file. Expected a 'cloud' key.")

        cloud_data = config_data["cloud"]
        if not isinstance(cloud_data, dict):
            raise ValueError("Invalid credentials.yaml file. Expected 'cloud' to be a mapping.")

        try:
            return cls(**cloud_data)
        except ValidationError as e:
            # Get only the errors with {type:missing} and combine them
            # into a nicely-formatted string message.
            # Any other errors without {type:missing} should just be str()ed
            errors = e.errors()
            missing_field_errors = [
                ".".join(map(str, error["loc"])) for error in errors if error["type"] == "missing"
            ]
            other_errors = [str(error) for error in errors if error["type"] != "missing"]

            missing_field_errors_str = ", ".join(missing_field_errors)
            other_errors_str = "\n".join(other_errors)

            pretty_str: str = "Invalid Arcade configuration."
            if missing_field_errors_str:
                pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
            if other_errors_str:
                pretty_str += f"\nOther errors:\n{other_errors_str}"

            raise ValueError(pretty_str) from e

    def save_to_file(self) -> None:
        """
        Save the configuration to the YAML file in the configuration directory.

        The model is written under a top-level `cloud` key, which is the layout
        `load_from_file` reads back.
        """
        Config.ensure_config_dir_exists()
        config_file_path = Config.get_config_file_path()
        config_file_path.write_text(yaml.dump({"cloud": self.model_dump()}), encoding="utf-8")
|
arcade_core/errors.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import traceback
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ToolkitError(Exception):
    """Common ancestor of every toolkit-related error."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ToolkitLoadError(ToolkitError):
    """Signals that a toolkit could not be loaded."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ToolError(Exception):
    """Common ancestor of every tool-related error."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ToolDefinitionError(ToolError):
    """Signals a problem in the way a tool is defined."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ------ runtime errors ------
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ToolRuntimeError(RuntimeError):
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
message: str,
|
|
44
|
+
developer_message: Optional[str] = None,
|
|
45
|
+
):
|
|
46
|
+
super().__init__(message)
|
|
47
|
+
self.message = message
|
|
48
|
+
self.developer_message = developer_message
|
|
49
|
+
|
|
50
|
+
def traceback_info(self) -> str | None:
|
|
51
|
+
# return the traceback information of the parent exception
|
|
52
|
+
if self.__cause__:
|
|
53
|
+
return "\n".join(traceback.format_exception(self.__cause__))
|
|
54
|
+
return None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ToolExecutionError(ToolRuntimeError):
    """Signals a failure while executing a tool."""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class RetryableToolError(ToolExecutionError):
    """
    Tool execution error that is retryable.

    In addition to the base error's messages it stores
    `additional_prompt_content` and `retry_after_ms` exactly as given.
    """

    def __init__(
        self,
        message: str,
        developer_message: Optional[str] = None,
        additional_prompt_content: Optional[str] = None,
        retry_after_ms: Optional[int] = None,
    ):
        # Retry metadata is independent of the attributes set by the parent initializer.
        self.retry_after_ms = retry_after_ms
        self.additional_prompt_content = additional_prompt_content
        super().__init__(message, developer_message)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ToolSerializationError(ToolRuntimeError):
    """
    Base class for errors in serializing a tool's input or output.
    """

    pass
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ToolInputError(ToolSerializationError):
    """Signals a problem with the input given to a tool."""
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ToolOutputError(ToolSerializationError):
    """Signals a problem with the output produced by a tool."""
|
arcade_core/executor.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import traceback
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ValidationError
|
|
6
|
+
|
|
7
|
+
from arcade_core.errors import (
|
|
8
|
+
RetryableToolError,
|
|
9
|
+
ToolInputError,
|
|
10
|
+
ToolOutputError,
|
|
11
|
+
ToolRuntimeError,
|
|
12
|
+
ToolSerializationError,
|
|
13
|
+
)
|
|
14
|
+
from arcade_core.output import output_factory
|
|
15
|
+
from arcade_core.schema import ToolCallLog, ToolCallOutput, ToolContext, ToolDefinition
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ToolExecutor:
    """
    Runs tool functions with inputs and outputs validated through Pydantic models.
    """

    @staticmethod
    async def run(
        func: Callable,
        definition: ToolDefinition,
        input_model: type[BaseModel],
        output_model: type[BaseModel],
        context: ToolContext,
        *args: Any,
        **kwargs: Any,
    ) -> ToolCallOutput:
        """
        Execute a callable function with validated inputs and outputs via Pydantic models.

        Args:
            func: The tool function; awaited when it is a coroutine function,
                called directly otherwise.
            definition: The tool definition; supplies the deprecation message and
                the name of the context parameter, if any.
            input_model: Model used to validate `kwargs`.
            output_model: Model used to validate the function's return value.
            context: Injected into the call when the definition names a context parameter.
            *args: Accepted but not used.
            **kwargs: The tool's input values.

        Returns:
            A ToolCallOutput carrying either the tool's result or an error. Errors
            are reported through the return value rather than raised. Collected
            logs (currently only the deprecation warning) are attached in both cases.
        """
        # only gathering deprecation log for now
        tool_call_logs = []
        if definition.deprecation_message is not None:
            tool_call_logs.append(
                ToolCallLog(
                    message=definition.deprecation_message, level="warning", subtype="deprecation"
                )
            )

        try:
            # serialize the input model
            inputs = await ToolExecutor._serialize_input(input_model, **kwargs)

            # prepare the arguments for the function call
            func_args = inputs.model_dump()

            # inject ToolContext, if the target function supports it
            if definition.input.tool_context_parameter_name is not None:
                func_args[definition.input.tool_context_parameter_name] = context

            # execute the tool function
            if asyncio.iscoroutinefunction(func):
                results = await func(**func_args)
            else:
                results = func(**func_args)

            # serialize the output model
            output = await ToolExecutor._serialize_output(output_model, results)

            # return the output
            return output_factory.success(data=output, logs=tool_call_logs)

        # The failure paths below also pass the collected logs, so a deprecation
        # warning is reported whether or not the call succeeds.
        except RetryableToolError as e:
            return output_factory.fail_retry(
                message=e.message,
                developer_message=e.developer_message,
                additional_prompt_content=e.additional_prompt_content,
                retry_after_ms=e.retry_after_ms,
                logs=tool_call_logs,
            )

        except ToolSerializationError as e:
            return output_factory.fail(
                message=e.message,
                developer_message=e.developer_message,
                logs=tool_call_logs,
            )

        # should catch all tool exceptions due to the try/except in the tool decorator
        except ToolRuntimeError as e:
            return output_factory.fail(
                message=e.message,
                developer_message=e.developer_message,
                traceback_info=e.traceback_info(),
                logs=tool_call_logs,
            )

        # if we get here we're in trouble
        except Exception as e:
            return output_factory.fail(
                message="Error in execution",
                developer_message=str(e),
                traceback_info=traceback.format_exc(),
                logs=tool_call_logs,
            )

    @staticmethod
    async def _serialize_input(input_model: type[BaseModel], **kwargs: Any) -> BaseModel:
        """
        Serialize the input to a tool function.

        Args:
            input_model: Model class instantiated with `kwargs`.
            **kwargs: The raw input values.

        Returns:
            The validated input model instance.

        Raises:
            ToolInputError: If validation of `kwargs` fails.
        """
        try:
            # TODO Logging and telemetry

            # build in the input model to the tool function
            inputs = input_model(**kwargs)

        except ValidationError as e:
            raise ToolInputError(
                message="Error in tool input deserialization", developer_message=str(e)
            ) from e

        return inputs

    @staticmethod
    async def _serialize_output(output_model: type[BaseModel], results: dict) -> BaseModel:
        """
        Serialize the output of a tool function.

        Args:
            output_model: Model class instantiated with the results under its `result` field.
            results: The value returned by the tool function.

        Returns:
            The validated output model instance.

        Raises:
            ToolOutputError: If validation of `results` fails.
        """
        # TODO how to type this the results object?
        # TODO how to ensure `results` contains only safe (serializable) stuff?
        try:
            # TODO Logging and telemetry

            # build the output model
            output = output_model(**{"result": results})

        except ValidationError as e:
            raise ToolOutputError(
                message="Failed to serialize tool output",
                developer_message=f"Validation error occurred while serializing tool output: {e!s}. "
                f"Please ensure the tool's output matches the expected schema.",
            ) from e

        return output
|
arcade_core/output.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import TypeVar
|
|
2
|
+
|
|
3
|
+
from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput
|
|
4
|
+
from arcade_core.utils import coerce_empty_list_to_none
|
|
5
|
+
|
|
6
|
+
T = TypeVar("T")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ToolOutputFactory:
    """
    Builds `ToolCallOutput` objects for successful and failed tool calls.
    """

    def success(
        self,
        *,
        data: T | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build the output for a successful call.

        The value is `data.result` when `data` is truthy and has that attribute,
        and an empty string otherwise.
        """
        if data:
            value = getattr(data, "result", "")
        else:
            value = ""
        return ToolCallOutput(value=value, logs=coerce_empty_list_to_none(logs))

    def fail(
        self,
        *,
        message: str,
        developer_message: str | None = None,
        traceback_info: str | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build the output for a failed call that is marked as not retryable.
        """
        call_error = ToolCallError(
            message=message,
            developer_message=developer_message,
            can_retry=False,
            traceback_info=traceback_info,
        )
        return ToolCallOutput(error=call_error, logs=coerce_empty_list_to_none(logs))

    def fail_retry(
        self,
        *,
        message: str,
        developer_message: str | None = None,
        additional_prompt_content: str | None = None,
        retry_after_ms: int | None = None,
        traceback_info: str | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build the output for a failed call that is marked as retryable.

        `traceback_info` is accepted but not included in the error.
        """
        call_error = ToolCallError(
            message=message,
            developer_message=developer_message,
            can_retry=True,
            additional_prompt_content=additional_prompt_content,
            retry_after_ms=retry_after_ms,
        )
        return ToolCallOutput(error=call_error, logs=coerce_empty_list_to_none(logs))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Module-level ToolOutputFactory instance, created once at import time.
output_factory = ToolOutputFactory()
|
arcade_core/parse.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def load_ast_tree(filepath: str | Path) -> ast.AST:
|
|
7
|
+
"""
|
|
8
|
+
Load and parse the Abstract Syntax Tree (AST) from a Python file.
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
try:
|
|
12
|
+
with open(filepath) as file:
|
|
13
|
+
return ast.parse(file.read(), filename=filepath)
|
|
14
|
+
except FileNotFoundError:
|
|
15
|
+
raise FileNotFoundError(f"File {filepath} not found")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_function_name_if_decorated(
    node: Union[ast.FunctionDef, ast.AsyncFunctionDef],
) -> Optional[str]:
    """
    Return the function's name if it carries a tool decorator.

    Recognized decorators are `tool` and `arc.tool`, either placed bare on the
    function (`@tool`, `@arc.tool`) or called (`@tool(...)`, `@arc.tool(...)`).

    Args:
        node: The function definition node to inspect.

    Returns:
        The function name when a recognized decorator is present, otherwise None.
    """
    decorator_ids = {"arc.tool", "tool"}
    for decorator in node.decorator_list:
        # A called decorator wraps the actual reference in an ast.Call; unwrap it so
        # the bare and called forms go through the same name check.
        target = decorator.func if isinstance(decorator, ast.Call) else decorator

        if isinstance(target, ast.Name):
            dotted_name = target.id
        elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
            dotted_name = f"{target.value.id}.{target.attr}"
        else:
            # Deeper attribute chains or arbitrary expressions are not tool decorators.
            continue

        if dotted_name in decorator_ids:
            return node.name
    return None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_tools_from_file(filepath: str | Path) -> list[str]:
    """
    List the names of the tool functions found in a Python file.

    The file is parsed into an AST, which is then scanned for decorated functions.
    """
    return get_tools_from_ast(load_ast_tree(filepath))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_tools_from_ast(tree: ast.AST) -> list[str]:
    """
    List the names of the tool functions found in a parsed Python module.

    Every sync and async function definition in the tree is checked, including
    nested ones; names are returned in `ast.walk` order.
    """
    function_nodes = (
        candidate
        for candidate in ast.walk(tree)
        if isinstance(candidate, (ast.FunctionDef, ast.AsyncFunctionDef))
    )
    return [
        tool_name
        for function_node in function_nodes
        if (tool_name := get_function_name_if_decorated(function_node))
    ]
|
arcade_core/py.typed
ADDED
|
File without changes
|