arcade-core 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arcade_core/config.py ADDED
@@ -0,0 +1,23 @@
1
+ from functools import lru_cache
2
+
3
+ from arcade_core.config_model import Config
4
+
5
+
6
@lru_cache(maxsize=1)
def get_config() -> Config:
    """
    Return the Arcade configuration, loading it from disk on the first call.

    The result is memoized: later calls hand back the very same ``Config``
    instance without re-reading the file. If the configuration file changes,
    call ``get_config.cache_clear()`` so that the next call reloads it.

    Returns:
        Config: The Arcade configuration.
    """
    loaded = Config.load_from_file()
    return loaded


# Loaded eagerly when this module is imported: this reads (and, if missing,
# creates) the configuration file, and raises ValueError if it is invalid.
config: Config = get_config()
@@ -0,0 +1,146 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import yaml
6
+ from pydantic import BaseModel, ConfigDict, ValidationError
7
+
8
+
9
class BaseConfig(BaseModel):
    # Shared base for all Arcade config models: keys in the input data that
    # do not correspond to a declared field are silently dropped.
    model_config = ConfigDict(extra="ignore")


class ApiConfig(BaseConfig):
    """Arcade API configuration."""

    key: str
    """Arcade API key."""

    version: str = "v1"
    """Arcade API version."""


class UserConfig(BaseConfig):
    """Arcade user configuration."""

    email: str | None = None
    """User email."""
37
+
38
+
39
class Config(BaseConfig):
    """
    Configuration for Arcade.

    On disk, the configuration is stored in ``credentials.yaml`` inside the
    Arcade configuration directory, nested under a top-level ``cloud`` key.
    """

    api: ApiConfig
    """
    Arcade API configuration.
    """
    user: UserConfig | None = None
    """
    Arcade user configuration.
    """

    @classmethod
    def get_config_dir_path(cls) -> Path:
        """
        Get the path to the Arcade configuration directory.

        Uses ``$ARCADE_WORK_DIR`` when it is set to a non-empty value, and
        ``~/.arcade`` otherwise. The returned path is resolved to an absolute path.
        """
        config_path = os.getenv("ARCADE_WORK_DIR") or Path.home() / ".arcade"
        return Path(config_path).resolve()

    @classmethod
    def get_config_file_path(cls) -> Path:
        """
        Get the path to the Arcade configuration file (``credentials.yaml``).
        """
        return cls.get_config_dir_path() / "credentials.yaml"

    @classmethod
    def ensure_config_dir_exists(cls) -> None:
        """
        Create the configuration directory (and any missing parents) if it does not exist.
        """
        # exist_ok=True makes a separate existence check unnecessary and avoids
        # a check-then-create race.
        cls.get_config_dir_path().mkdir(parents=True, exist_ok=True)

    @classmethod
    def load_from_file(cls) -> "Config":
        """
        Load the configuration from the YAML file in the configuration directory.

        The file must contain a top-level ``cloud`` mapping. That mapping is
        validated as a ``Config``.

        If no configuration file exists, a placeholder is written first. It holds
        only the default API version (no API key and no user), so loading it is
        expected to fail with a ``ValueError`` that reports ``api.key`` as a
        missing field until a key is added to the file.

        Returns:
            Config: The loaded configuration.

        Raises:
            ValueError: If the configuration file is empty, lacks a ``cloud``
                mapping, or does not pass validation.
        """
        cls.ensure_config_dir_exists()

        config_file_path = cls.get_config_file_path()

        if not config_file_path.exists():
            # model_construct skips validation, which is what allows writing a
            # placeholder that still lacks the required API key.
            default_config = cls.model_construct(api=ApiConfig.model_construct())
            default_config.save_to_file()

        config_data = yaml.safe_load(config_file_path.read_text(encoding="utf-8"))

        if config_data is None:
            raise ValueError(
                "Invalid credentials.yaml file. Please ensure it is a valid YAML file."
            )

        # Check the type as well as the key. Otherwise `in` on a YAML scalar or
        # list would do a substring/element test instead of a key lookup.
        if not isinstance(config_data, dict) or "cloud" not in config_data:
            raise ValueError("Invalid credentials.yaml file. Expected a 'cloud' key.")

        cloud_data = config_data["cloud"]
        if not isinstance(cloud_data, dict):
            raise ValueError(
                "Invalid credentials.yaml file. Expected the 'cloud' key to be a mapping."
            )

        try:
            return cls.model_validate(cloud_data)
        except ValidationError as e:
            # Collect the {type: missing} errors into one comma-separated list
            # of dotted field paths; str() every other error individually.
            missing_field_errors = [
                ".".join(map(str, error["loc"]))
                for error in e.errors()
                if error["type"] == "missing"
            ]
            other_errors = [str(error) for error in e.errors() if error["type"] != "missing"]

            missing_field_errors_str = ", ".join(missing_field_errors)
            other_errors_str = "\n".join(other_errors)

            pretty_str: str = "Invalid Arcade configuration."
            if missing_field_errors_str:
                pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
            if other_errors_str:
                pretty_str += f"\nOther errors:\n{other_errors_str}"

            raise ValueError(pretty_str) from e

    def save_to_file(self) -> None:
        """
        Save the configuration to the YAML file in the configuration directory.

        The data is written under a top-level ``cloud`` key, which is the layout
        ``load_from_file`` reads, so a saved file can be loaded back.
        """
        config_cls = type(self)
        config_cls.ensure_config_dir_exists()
        config_file_path = config_cls.get_config_file_path()
        config_file_path.write_text(
            yaml.dump({"cloud": self.model_dump()}), encoding="utf-8"
        )
arcade_core/errors.py ADDED
@@ -0,0 +1,103 @@
1
+ import traceback
2
+ from typing import Optional
3
+
4
+
5
class ToolkitError(Exception):
    """Base class for every toolkit-related error."""


class ToolkitLoadError(ToolkitError):
    """Raised when a toolkit fails to load."""


class ToolError(Exception):
    """Base class for every tool-related error."""


class ToolDefinitionError(ToolError):
    """Raised when a tool's definition contains an error."""
35
+
36
+
37
+ # ------ runtime errors ------
38
+
39
+
40
+ class ToolRuntimeError(RuntimeError):
41
+ def __init__(
42
+ self,
43
+ message: str,
44
+ developer_message: Optional[str] = None,
45
+ ):
46
+ super().__init__(message)
47
+ self.message = message
48
+ self.developer_message = developer_message
49
+
50
+ def traceback_info(self) -> str | None:
51
+ # return the traceback information of the parent exception
52
+ if self.__cause__:
53
+ return "\n".join(traceback.format_exception(self.__cause__))
54
+ return None
55
+
56
+
57
class ToolExecutionError(ToolRuntimeError):
    """
    Raised when there is an error executing a tool.
    """


class RetryableToolError(ToolExecutionError):
    """
    Raised when a tool error is retryable.

    Args:
        message: The error message.
        developer_message: Optional additional detail aimed at developers.
        additional_prompt_content: Optional extra content stored on the error
            for use when retrying.
        retry_after_ms: Optional retry delay, in milliseconds, stored on the error.
    """

    def __init__(
        self,
        message: str,
        developer_message: Optional[str] = None,
        additional_prompt_content: Optional[str] = None,
        retry_after_ms: Optional[int] = None,
    ):
        super().__init__(message, developer_message)
        self.additional_prompt_content = additional_prompt_content
        self.retry_after_ms = retry_after_ms


class ToolSerializationError(ToolRuntimeError):
    """
    Raised when a tool's input or output cannot be serialized or validated.
    """


class ToolInputError(ToolSerializationError):
    """
    Raised when there is an error in the input to a tool.
    """


class ToolOutputError(ToolSerializationError):
    """
    Raised when there is an error in the output of a tool.
    """
@@ -0,0 +1,129 @@
1
+ import asyncio
2
+ import traceback
3
+ from typing import Any, Callable
4
+
5
+ from pydantic import BaseModel, ValidationError
6
+
7
+ from arcade_core.errors import (
8
+ RetryableToolError,
9
+ ToolInputError,
10
+ ToolOutputError,
11
+ ToolRuntimeError,
12
+ ToolSerializationError,
13
+ )
14
+ from arcade_core.output import output_factory
15
+ from arcade_core.schema import ToolCallLog, ToolCallOutput, ToolContext, ToolDefinition
16
+
17
+
18
class ToolExecutor:
    """
    Run tool functions and wrap their outcomes in ``ToolCallOutput`` values.

    Inputs and outputs are validated with Pydantic models, and both results and
    errors are converted into ``ToolCallOutput`` values.
    """

    @staticmethod
    async def run(
        func: Callable,
        definition: ToolDefinition,
        input_model: type[BaseModel],
        output_model: type[BaseModel],
        context: ToolContext,
        *args: Any,
        **kwargs: Any,
    ) -> ToolCallOutput:
        """
        Execute a callable function with validated inputs and outputs via Pydantic models.

        Args:
            func: The tool function, either synchronous or asynchronous. If
                calling it returns a coroutine, that coroutine is awaited.
            definition: The tool's definition. Its ``deprecation_message`` (if
                any) is reported as a warning log. Its
                ``input.tool_context_parameter_name`` (if any) names the
                parameter that receives ``context``.
            input_model: Pydantic model used to validate ``kwargs``.
            output_model: Pydantic model, built as ``output_model(result=...)``,
                used to validate the function's return value.
            context: The ToolContext to inject when the tool requests it.
            *args: Accepted but not used.
            **kwargs: The raw tool inputs.

        Returns:
            ToolCallOutput: A success output carrying the validated result, or a
            failure output describing the error. Any deprecation log is attached
            in both cases. Exceptions derived from ``Exception`` are converted
            into failure outputs rather than propagated.
        """
        # Only deprecation warnings are gathered as logs for now.
        tool_call_logs: list[ToolCallLog] = []
        if definition.deprecation_message is not None:
            tool_call_logs.append(
                ToolCallLog(
                    message=definition.deprecation_message, level="warning", subtype="deprecation"
                )
            )

        try:
            inputs = await ToolExecutor._serialize_input(input_model, **kwargs)

            # Arguments for the tool call, taken from the validated input model.
            func_args = inputs.model_dump()

            # Inject the ToolContext if the target function declares a parameter for it.
            context_param = definition.input.tool_context_parameter_name
            if context_param is not None:
                func_args[context_param] = context

            # Calling an async function yields a coroutine, which is then awaited.
            # Inspecting the returned object, rather than using
            # asyncio.iscoroutinefunction(func) (deprecated since Python 3.14),
            # also handles wrappers that hide the function's async nature.
            results = func(**func_args)
            if asyncio.iscoroutine(results):
                results = await results

            output = await ToolExecutor._serialize_output(output_model, results)
            return output_factory.success(data=output, logs=tool_call_logs)

        except RetryableToolError as e:
            return output_factory.fail_retry(
                message=e.message,
                developer_message=e.developer_message,
                additional_prompt_content=e.additional_prompt_content,
                retry_after_ms=e.retry_after_ms,
                logs=tool_call_logs,
            )

        except ToolSerializationError as e:
            return output_factory.fail(
                message=e.message,
                developer_message=e.developer_message,
                logs=tool_call_logs,
            )

        # Expected to catch tool exceptions wrapped by the tool decorator.
        except ToolRuntimeError as e:
            return output_factory.fail(
                message=e.message,
                developer_message=e.developer_message,
                traceback_info=e.traceback_info(),
                logs=tool_call_logs,
            )

        # Last-resort boundary: anything else is an unexpected failure.
        except Exception as e:
            return output_factory.fail(
                message="Error in execution",
                developer_message=str(e),
                traceback_info=traceback.format_exc(),
                logs=tool_call_logs,
            )

    @staticmethod
    async def _serialize_input(input_model: type[BaseModel], **kwargs: Any) -> BaseModel:
        """
        Validate the raw tool inputs against ``input_model``.

        Raises:
            ToolInputError: If validation fails. The ValidationError text is
                used as the developer message.
        """
        try:
            # TODO Logging and telemetry
            inputs = input_model(**kwargs)
        except ValidationError as e:
            raise ToolInputError(
                message="Error in tool input deserialization", developer_message=str(e)
            ) from e

        return inputs

    @staticmethod
    async def _serialize_output(output_model: type[BaseModel], results: Any) -> BaseModel:
        """
        Validate a tool function's return value as the ``result`` field of ``output_model``.

        Raises:
            ToolOutputError: If validation fails.
        """
        # TODO how to ensure `results` contains only safe (serializable) stuff?
        try:
            # TODO Logging and telemetry
            output = output_model(result=results)
        except ValidationError as e:
            raise ToolOutputError(
                message="Failed to serialize tool output",
                developer_message=f"Validation error occurred while serializing tool output: {e!s}. "
                f"Please ensure the tool's output matches the expected schema.",
            ) from e

        return output
arcade_core/output.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import TypeVar
2
+
3
+ from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput
4
+ from arcade_core.utils import coerce_empty_list_to_none
5
+
6
T = TypeVar("T")


class ToolOutputFactory:
    """
    Build the ``ToolCallOutput`` values returned from tool calls.

    A shared instance is exposed as the module-level ``output_factory``.
    """

    def success(
        self,
        *,
        data: T | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build a successful output.

        Args:
            data: Object whose ``result`` attribute becomes the output value.
                When ``data`` is falsy or has no ``result`` attribute, the value
                is the empty string.
            logs: Logs to attach. They are passed through ``coerce_empty_list_to_none``.
        """
        value = getattr(data, "result", "") if data else ""
        logs = coerce_empty_list_to_none(logs)
        return ToolCallOutput(value=value, logs=logs)

    def fail(
        self,
        *,
        message: str,
        developer_message: str | None = None,
        traceback_info: str | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build a non-retryable failure output (``can_retry=False``).

        Args:
            message: The error message.
            developer_message: Optional additional detail aimed at developers.
            traceback_info: Optional formatted traceback.
            logs: Logs to attach. They are passed through ``coerce_empty_list_to_none``.
        """
        return ToolCallOutput(
            error=ToolCallError(
                message=message,
                developer_message=developer_message,
                can_retry=False,
                traceback_info=traceback_info,
            ),
            logs=coerce_empty_list_to_none(logs),
        )

    def fail_retry(
        self,
        *,
        message: str,
        developer_message: str | None = None,
        additional_prompt_content: str | None = None,
        retry_after_ms: int | None = None,
        traceback_info: str | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build a retryable failure output (``can_retry=True``).

        Args:
            message: The error message.
            developer_message: Optional additional detail aimed at developers.
            additional_prompt_content: Optional extra content for the retry.
            retry_after_ms: Optional retry delay in milliseconds.
            traceback_info: Optional formatted traceback. It was previously
                accepted but dropped; it is now forwarded, as in ``fail``.
            logs: Logs to attach. They are passed through ``coerce_empty_list_to_none``.
        """
        return ToolCallOutput(
            error=ToolCallError(
                message=message,
                developer_message=developer_message,
                can_retry=True,
                additional_prompt_content=additional_prompt_content,
                retry_after_ms=retry_after_ms,
                traceback_info=traceback_info,
            ),
            logs=coerce_empty_list_to_none(logs),
        )


# Shared factory instance used to build tool call outputs.
output_factory = ToolOutputFactory()
arcade_core/parse.py ADDED
@@ -0,0 +1,63 @@
1
+ import ast
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+
6
+ def load_ast_tree(filepath: str | Path) -> ast.AST:
7
+ """
8
+ Load and parse the Abstract Syntax Tree (AST) from a Python file.
9
+
10
+ """
11
+ try:
12
+ with open(filepath) as file:
13
+ return ast.parse(file.read(), filename=filepath)
14
+ except FileNotFoundError:
15
+ raise FileNotFoundError(f"File {filepath} not found")
16
+
17
+
18
def get_function_name_if_decorated(
    node: Union[ast.FunctionDef, ast.AsyncFunctionDef],
) -> Optional[str]:
    """
    Return the function's name if it is decorated as a tool, otherwise None.

    Recognised decorators are ``tool`` and ``arc.tool``. Each may be applied
    bare (``@tool``, ``@arc.tool``) or called (``@tool(...)``, ``@arc.tool(...)``).
    """
    decorator_ids = {"arc.tool", "tool"}
    for decorator in node.decorator_list:
        # For a called decorator, inspect the callee rather than the call itself.
        target = decorator.func if isinstance(decorator, ast.Call) else decorator
        if _decorator_dotted_name(target) in decorator_ids:
            return node.name
    return None


def _decorator_dotted_name(expr: ast.expr) -> Optional[str]:
    """Return ``name`` or ``value.attr`` for a simple decorator expression, else None."""
    if isinstance(expr, ast.Name):
        return expr.id
    if isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name):
        return f"{expr.value.id}.{expr.attr}"
    return None
43
+
44
+
45
def get_tools_from_file(filepath: str | Path) -> list[str]:
    """
    Return the names of the tool functions defined in the Python file at ``filepath``.

    The file is parsed into an AST, not imported, so none of its code runs.
    """
    return get_tools_from_ast(load_ast_tree(filepath))
51
+
52
+
53
def get_tools_from_ast(tree: ast.AST) -> list[str]:
    """
    Return the names of all tool-decorated functions found anywhere in ``tree``.

    Every node reachable via ``ast.walk`` is inspected, so nested functions and
    methods are included. Names are listed in ``ast.walk`` order.
    """
    return [
        name
        for candidate in ast.walk(tree)
        if isinstance(candidate, (ast.FunctionDef, ast.AsyncFunctionDef))
        and (name := get_function_name_if_decorated(candidate))
    ]
arcade_core/py.typed ADDED
File without changes