dataact 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataact/__init__.py +31 -0
- dataact/agent.py +237 -0
- dataact/cache.py +319 -0
- dataact/exceptions.py +21 -0
- dataact/format.py +108 -0
- dataact/logger.py +66 -0
- dataact/loop.py +153 -0
- dataact/observe.py +31 -0
- dataact/providers/__init__.py +0 -0
- dataact/providers/anthropic.py +112 -0
- dataact/providers/base.py +35 -0
- dataact/providers/openai.py +125 -0
- dataact/schema.py +79 -0
- dataact/serialize.py +111 -0
- dataact/testing.py +70 -0
- dataact/tools/__init__.py +0 -0
- dataact/tools/connectors.py +129 -0
- dataact/tools/interpreter.py +189 -0
- dataact/tools/planner.py +107 -0
- dataact/tools/subagent.py +222 -0
- dataact/tools/variables.py +25 -0
- dataact/types.py +54 -0
- dataact-0.1.0.dist-info/METADATA +212 -0
- dataact-0.1.0.dist-info/RECORD +26 -0
- dataact-0.1.0.dist-info/WHEEL +4 -0
- dataact-0.1.0.dist-info/licenses/LICENSE +21 -0
dataact/schema.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Input-schema inference for small connector functions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import dataclasses
|
|
6
|
+
import inspect
|
|
7
|
+
import types
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from typing import Any, get_args, get_origin, get_type_hints
|
|
10
|
+
|
|
11
|
+
# Suffix appended to every schema-inference error message (see _unsupported),
# telling the caller that an explicit input_schema can be supplied instead.
_OVERRIDE_HINT = "pass input_schema=... to override"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def infer_input_schema(fn: Callable[..., Any]) -> dict:
    """Build a JSON-schema ``object`` describing ``fn``'s parameters.

    Every parameter must be annotated with a type that
    ``_schema_for_annotation`` understands. Parameters without a default value
    are listed under ``required``.

    Raises:
        TypeError: if the annotations cannot be resolved, a parameter is
            ``*args``/``**kwargs`` or positional-only, a parameter has no
            annotation, or an annotation type is not supported.
    """
    params = inspect.signature(fn).parameters
    try:
        resolved = get_type_hints(fn)
    except Exception as exc:
        raise _unsupported(fn, "could not resolve type annotations") from exc

    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    schema_props: dict[str, dict] = {}
    must_have: list[str] = []

    for pname, param in params.items():
        kind = param.kind
        if kind in variadic_kinds:
            raise _unsupported(fn, f"unsupported variadic parameter {pname!r}")
        if kind == inspect.Parameter.POSITIONAL_ONLY:
            raise _unsupported(fn, f"unsupported positional-only parameter {pname!r}")
        unannotated = (
            param.annotation is inspect.Parameter.empty or pname not in resolved
        )
        if unannotated:
            raise _unsupported(fn, f"missing annotation for parameter {pname!r}")

        schema_props[pname] = _schema_for_annotation(fn, resolved[pname])
        if param.default is inspect.Parameter.empty:
            must_have.append(pname)

    return {"type": "object", "properties": schema_props, "required": must_have}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _schema_for_annotation(fn: Callable[..., Any], annotation: Any) -> dict:
    """Map one resolved parameter annotation to a JSON-schema fragment.

    Supported annotations are ``str``, ``int``, ``float``, ``bool`` and
    ``list[str]``. Anything else raises the TypeError built by
    ``_unsupported(fn, ...)``. ``Any``, ``dict``, dataclasses and unions get
    a specific reason; all other types get a generic one.
    """
    from typing import Union  # local import keeps this block self-contained

    # Identity checks: bool must not be matched as int, and vice versa.
    if annotation is str:
        return {"type": "string"}
    if annotation is int:
        return {"type": "integer"}
    if annotation is float:
        return {"type": "number"}
    if annotation is bool:
        return {"type": "boolean"}

    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is list and args == (str,):
        return {"type": "array", "items": {"type": "string"}}

    if annotation is Any:
        raise _unsupported(fn, "Any is not supported")
    if annotation is dict or origin is dict:
        raise _unsupported(fn, "dict is not supported")
    if dataclasses.is_dataclass(annotation):
        raise _unsupported(fn, "dataclass annotations are not supported")

    # `X | Y` has origin types.UnionType, which only exists on Python 3.10+,
    # so it must be looked up with getattr rather than accessed directly.
    # Optional[X] / Union[X, Y] have origin typing.Union.
    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        raise _unsupported(fn, "union annotations are not supported")

    raise _unsupported(fn, f"unsupported annotation {annotation!r}")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _unsupported(fn: Callable[..., Any], reason: str) -> TypeError:
    """Build (without raising) the TypeError reported when inference fails.

    ``fn`` may be any callable. Objects without ``__name__`` (for example
    ``functools.partial`` objects or callable instances) are described by
    their ``repr``, so that building the message cannot itself fail with an
    AttributeError and mask the intended TypeError.

    Args:
        fn: The callable whose schema could not be inferred.
        reason: Short explanation inserted into the message.

    Returns:
        A TypeError whose message ends with the input_schema override hint.
    """
    label = getattr(fn, "__name__", None) or repr(fn)
    return TypeError(
        f"Cannot infer input schema for {label}: {reason}; {_OVERRIDE_HINT}"
    )
|
dataact/serialize.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def to_jsonable(obj: Any) -> Any:
    """Return a JSON-serializable rendering of ``obj``; this never raises.

    Conversion is delegated to ``_convert``. If that raises any ``Exception``,
    a ``"<serialization error: ...>"`` placeholder string is returned instead.
    """
    try:
        converted = _convert(obj)
    except Exception as exc:  # best-effort by contract: report, don't raise
        converted = f"<serialization error: {exc!r}>"
    return converted
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _convert(obj: Any) -> Any:
    """Recursively turn ``obj`` into JSON-serializable data (may raise).

    Rules, applied in order:

    * ``None``/``bool``/``int``/``float``/``str`` are returned unchanged.
    * ``Enum`` members become their recursively converted ``value``.
    * ``datetime`` becomes an ISO 8601 string.
    * Exceptions become ``{"error_type": ..., "error_message": ...}``.
    * Dataclass instances become dicts. ``TextBlock``, ``ToolUseBlock`` and
      ``ToolResultBlock`` get fixed, ``type``-tagged shapes.
    * ``dict`` keys are stringified; ``list``/``tuple`` become lists.
    * When the libraries are installed: pandas ``DataFrame`` and numpy
      ``ndarray`` become small snapshots (with converted contents), and numpy
      scalars become native Python values.
    * Anything else falls back to ``repr(obj)``.
    """
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj

    if isinstance(obj, Enum):
        # An enum value may itself be a tuple, datetime, another enum, ...
        return _convert(obj.value)

    if isinstance(obj, datetime):
        return obj.isoformat()

    if isinstance(obj, Exception):
        return {"error_type": type(obj).__name__, "error_message": str(obj)}

    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        result: dict[str, Any] = {}
        if isinstance(obj, _get_text_block_type()):
            result["type"] = "text"
            result["text"] = _convert(obj.text)  # type: ignore[attr-defined]
        elif isinstance(obj, _get_tool_use_block_type()):
            result["type"] = "tool_use"
            result["id"] = _convert(obj.tool_use_id)  # type: ignore[attr-defined]
            result["name"] = _convert(obj.tool_name)  # type: ignore[attr-defined]
            result["input"] = _convert(obj.tool_input)  # type: ignore[attr-defined]
        elif isinstance(obj, _get_tool_result_block_type()):
            result["type"] = "tool_result"
            result["tool_use_id"] = _convert(obj.tool_use_id)  # type: ignore[attr-defined]
            result["content"] = _convert(obj.content)  # type: ignore[attr-defined]
            result["is_error"] = _convert(obj.is_error)  # type: ignore[attr-defined]
        else:
            for f in dataclasses.fields(obj):
                result[f.name] = _convert(getattr(obj, f.name))
        return result

    if isinstance(obj, dict):
        return {str(k): _convert(v) for k, v in obj.items()}

    if isinstance(obj, (list, tuple)):
        return [_convert(item) for item in obj]

    # pandas DataFrame (optional dependency)
    try:
        import pandas as pd
    except ImportError:
        pd = None
    if pd is not None and isinstance(obj, pd.DataFrame):
        return {
            "type": "dataframe_snapshot",
            "shape": list(obj.shape),
            # Labels and cells can be numpy scalars, Timestamps, tuples, ...;
            # convert them so the snapshot itself stays JSON-serializable.
            "columns": [_convert(col) for col in obj.columns],
            "sample": _convert(obj.head(5).to_dict(orient="records")),
        }

    # numpy arrays and scalars (optional dependency)
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None:
        if isinstance(obj, np.ndarray):
            return {
                "type": "ndarray_snapshot",
                "shape": list(obj.shape),
                "dtype": str(obj.dtype),
                "sample": _convert(obj.flat[:5].tolist()),
            }
        if isinstance(obj, np.generic):
            # np.int64, np.bool_, ... -> native Python values.
            native = obj.item()
            # Some scalar types may have no native equivalent and come back as
            # numpy scalars; fall through to repr() instead of recursing forever.
            if not isinstance(native, np.generic):
                return _convert(native)

    return repr(obj)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _get_text_block_type():
    """Return ``dataact.types.TextBlock``, or ``NoneType`` if it cannot be imported.

    The import happens at call time. The ``NoneType`` fallback keeps the
    caller's ``isinstance`` check valid when the import fails.
    """
    try:
        from dataact.types import TextBlock as block_cls
    except ImportError:
        block_cls = type(None)
    return block_cls
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _get_tool_use_block_type():
    """Return ``dataact.types.ToolUseBlock``, or ``NoneType`` if it cannot be imported.

    The import happens at call time. The ``NoneType`` fallback keeps the
    caller's ``isinstance`` check valid when the import fails.
    """
    try:
        from dataact.types import ToolUseBlock as block_cls
    except ImportError:
        block_cls = type(None)
    return block_cls
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _get_tool_result_block_type():
    """Return ``dataact.types.ToolResultBlock``, or ``NoneType`` if it cannot be imported.

    The import happens at call time. The ``NoneType`` fallback keeps the
    caller's ``isinstance`` check valid when the import fails.
    """
    try:
        from dataact.types import ToolResultBlock as block_cls
    except ImportError:
        block_cls = type(None)
    return block_cls
|
dataact/testing.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Public testing helpers.
|
|
2
|
+
|
|
3
|
+
`FakeAdapter` is a scripted `ProviderAdapter` that returns pre-built responses
|
|
4
|
+
in order. It exists so that documentation snippets, unit tests, and the
|
|
5
|
+
`Agent` quick-start example can run without an API key.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import copy
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
|
|
14
|
+
from dataact.types import Message, TextBlock, ToolSpec, ToolUseBlock
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FakeAdapter(ProviderAdapter):
    """Scripted ``ProviderAdapter`` that replays pre-built responses in order.

    Every ``chat`` call is recorded in ``self.calls``. Messages and tools are
    deep-copied, so later mutation by the caller does not alter the record.
    """

    def __init__(self, responses: list[NormalizedResponse]) -> None:
        # Copy so that popping responses does not mutate the caller's list.
        self._responses = list(responses)
        self.calls: list[dict[str, Any]] = []

    def chat(
        self,
        system: str,
        messages: list[Message],
        tools: list[ToolSpec],
    ) -> NormalizedResponse:
        """Record the call and return the next scripted response.

        Raises:
            IndexError: if every scripted response has already been consumed.
                The call is still recorded in ``self.calls`` first.
        """
        self.calls.append(
            {
                "system": system,
                "messages": copy.deepcopy(messages),
                "tools": copy.deepcopy(tools),
            }
        )
        if not self._responses:
            raise IndexError(
                "FakeAdapter ran out of scripted responses "
                f"(chat() call #{len(self.calls)}); add more responses to the script"
            )
        return self._responses.pop(0)

    def format_cache_control(self, obj: dict) -> dict:
        """Return a shallow copy of ``obj`` tagged with ephemeral cache control."""
        result = dict(obj)
        result["cache_control"] = {"type": "ephemeral"}
        return result

    @staticmethod
    def text(text: str) -> NormalizedResponse:
        """Build an END_TURN response containing a single text block."""
        return NormalizedResponse(
            stop_reason=StopReason.END_TURN,
            content=[TextBlock(text=text)],
            input_tokens=0,
            output_tokens=0,
            cache_read_tokens=0,
            cache_write_tokens=0,
        )

    @staticmethod
    def tool_use(
        tool_use_id: str, tool_name: str, tool_input: dict
    ) -> NormalizedResponse:
        """Build a TOOL_USE response containing a single tool-use block."""
        return NormalizedResponse(
            stop_reason=StopReason.TOOL_USE,
            content=[
                ToolUseBlock(
                    tool_use_id=tool_use_id,
                    tool_name=tool_name,
                    tool_input=tool_input,
                )
            ],
            input_tokens=0,
            output_tokens=0,
            cache_read_tokens=0,
            cache_write_tokens=0,
        )
|
|
File without changes
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from dataact.cache import SessionCache
|
|
6
|
+
from dataact.format import format_tool_output
|
|
7
|
+
from dataact.types import ToolSpec
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ConnectorRegistry:
    """Data connectors whose tools stay hidden until the model loads them.

    Each connector has a one-line description, which is listed in the
    ``load_connectors`` tool, and a list of ``ToolSpec`` objects. Loading the
    connector sets ``visible = True`` on those specs.
    """

    def __init__(self) -> None:
        # connector name -> one-line description
        self._directory: dict[str, str] = {}
        # connector name -> ToolSpecs exposed by that connector
        self._connector_tools: dict[str, list[ToolSpec]] = {}

    def register(
        self,
        name: str,
        description: str,
        tools: list[ToolSpec],
    ) -> None:
        """Add (or replace) connector ``name``; all of its tools start hidden.

        The passed-in spec objects are mutated (``visible`` is set to False).
        """
        self._directory[name] = description
        for tool in tools:
            tool.visible = False  # revealed later by load_connectors
        self._connector_tools[name] = list(tools)

    def get_load_connectors_spec(self) -> ToolSpec:
        """Build the always-visible ``load_connectors`` tool.

        The descriptions are a snapshot of the connectors registered so far.
        The handler looks tools up in the live registry mapping, so specs that
        were later replaced by ``make_wrapped_specs`` are the ones it reveals.
        """
        known = dict(self._directory)
        live_tools = self._connector_tools

        def load_connector(name: str) -> str:
            if name not in live_tools:
                return f"Error: connector {name!r} not found. Available: {list(known)}"
            specs = live_tools[name]
            for spec in specs:
                spec.visible = True
            names = [spec.name for spec in specs]
            return (
                f"Loaded connector {name!r}.\n"
                f"Description: {known.get(name, '')}\n"
                f"Available tools: {names}"
            )

        listing = "\n".join(f"- {cname}: {cdesc}" for cname, cdesc in known.items())
        return ToolSpec(
            name="load_connectors",
            description=(
                "Load a data connector to make its tools available.\n"
                f"Available connectors:\n{listing}"
            ),
            input_schema={
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "description": f"Connector name. One of: {list(known)}",
                    }
                },
                "required": ["name"],
            },
            handler=load_connector,
            visible=True,
        )

    def all_tool_specs(self) -> list[ToolSpec]:
        """Return every registered connector tool, in registration order."""
        return [spec for specs in self._connector_tools.values() for spec in specs]

    def call_connector(
        self,
        tool_name: str,
        tool_input: dict,
        cache: SessionCache,
    ) -> str:
        """Invoke the first tool named ``tool_name`` that has a handler.

        The raw result is passed through ``format_tool_output``. The preferred
        cache name is the part of ``tool_name`` after the last ``"__"``.
        """
        candidates = (
            spec
            for specs in self._connector_tools.values()
            for spec in specs
            if spec.name == tool_name and spec.handler is not None
        )
        target = next(candidates, None)
        if target is None:
            return f"Error: tool {tool_name!r} not found"
        raw = target.handler(**tool_input)
        return format_tool_output(
            raw, cache=cache, preferred_name=tool_name.split("__")[-1]
        )

    def make_wrapped_specs(self, cache: SessionCache) -> list[ToolSpec]:
        """Return ToolSpecs whose handlers route results through ``cache``.

        Each spec with a handler is replaced by a copy whose handler passes
        the raw result to ``format_tool_output``. Specs without a handler are
        kept as-is. The registry lists are swapped in place, so
        ``load_connectors`` toggles visibility on the returned specs rather
        than on stale originals.
        """

        def wrap(inner: Callable, preferred_name: str):
            def wrapped(**kwargs: Any) -> str:
                raw = inner(**kwargs)
                return format_tool_output(
                    raw, cache=cache, preferred_name=preferred_name
                )

            return wrapped

        collected: list[ToolSpec] = []
        for connector_name, originals in self._connector_tools.items():
            replacements: list[ToolSpec] = []
            for spec in originals:
                inner = spec.handler
                if inner is None:
                    swapped = spec
                else:
                    swapped = ToolSpec(
                        name=spec.name,
                        description=spec.description,
                        input_schema=spec.input_schema,
                        handler=wrap(inner, spec.name.split("__")[-1]),
                        visible=spec.visible,
                    )
                replacements.append(swapped)
                collected.append(swapped)
            # Reassigning an existing key is safe while iterating .items().
            self._connector_tools[connector_name] = replacements
        return collected
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import builtins
|
|
5
|
+
import io
|
|
6
|
+
import traceback
|
|
7
|
+
from contextlib import redirect_stdout
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from dataact.cache import SessionCache
|
|
11
|
+
from dataact.types import ToolSpec
|
|
12
|
+
|
|
13
|
+
# Top-level module names that sandboxed code may import by default. Both the
# AST check (_SecurityVisitor) and the runtime __import__ guard compare only
# the first dotted component of the imported module's name against this set.
# NOTE(review): "pd" and "np" are alias names, not module names. The checks
# look at the module being imported (e.g. "pandas" in `import pandas as pd`),
# so these two entries only permit importing modules literally named "pd"/"np".
_DEFAULT_ALLOWLIST = frozenset(
    {
        "pandas",
        "numpy",
        "json",
        "math",
        "datetime",
        "collections",
        "itertools",
        "pd",
        "np",  # common aliases
    }
)

# Names sandboxed code may not reference at all. _SecurityVisitor flags them
# both when called directly and when used as a plain name.
_FORBIDDEN_NAMES = frozenset({"eval", "exec", "__import__", "open", "compile"})
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class _SecurityVisitor(ast.NodeVisitor):
    """Walk a parsed module and collect policy violations as messages.

    Nothing is raised. Each violation is appended to ``self.errors`` and the
    walk continues, so every problem is reported at once. The checks are:

    * imports whose top-level package is not in ``allowlist``;
    * any reference to a name in ``_FORBIDDEN_NAMES``. A direct call is
      reported twice, once as a call and once as a name;
    * ``obj.__dunder__`` attribute access.
    """

    def __init__(self, allowlist: frozenset[str]) -> None:
        self._allowlist = allowlist
        self.errors: list[str] = []

    def _check_module(self, module: str) -> None:
        # Only the first dotted component is compared with the allowlist.
        if module.partition(".")[0] not in self._allowlist:
            self.errors.append(f"Import not allowed: {module!r}")

    def visit_Import(self, node: ast.Import) -> None:
        for alias in node.names:
            self._check_module(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # Relative imports (`from . import x`) have module=None and are not
        # checked here; the runtime __import__ guard rejects level != 0.
        if node.module:
            self._check_module(node.module)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        callee = node.func
        if isinstance(callee, ast.Name) and callee.id in _FORBIDDEN_NAMES:
            self.errors.append(f"Call not allowed: {callee.id!r}")
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        attr_name = node.attr
        if attr_name.startswith("__") and attr_name.endswith("__"):
            self.errors.append(f"Dunder attribute access not allowed: {attr_name!r}")
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        ident = node.id
        if ident in _FORBIDDEN_NAMES:
            self.errors.append(f"Name not allowed: {ident!r}")
        self.generic_visit(node)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class PythonInterpreter:
    """Run model-written Python snippets over the values in a SessionCache.

    Each ``run`` call parses the code and rejects it if ``_SecurityVisitor``
    reports any violation. Otherwise the code is executed with the restricted
    ``__builtins__`` from ``_safe_builtins`` while stdout is captured.
    """

    def __init__(
        self,
        cache: SessionCache,
        allowlist: frozenset[str] | None = None,
    ) -> None:
        """Create an interpreter bound to ``cache``.

        Args:
            cache: Session cache whose entries are exposed to the code as
                variables and which ``save()`` writes into.
            allowlist: Top-level module names the code may import. Defaults to
                ``_DEFAULT_ALLOWLIST``.
        """
        self._cache = cache
        self._allowlist = allowlist if allowlist is not None else _DEFAULT_ALLOWLIST

    def run(self, code: str) -> str:
        """Execute ``code`` and return its captured stdout or an error report.

        Returns one of the following:

        * ``"SyntaxError: ..."`` if the code does not parse;
        * ``"SecurityError: ..."`` if the AST check finds violations;
        * ``"Error:"`` plus the traceback if execution raises an Exception;
        * otherwise, the printed output, or a fixed message when nothing was
          printed.
        """
        # AST security check
        try:
            tree = ast.parse(code)
        except SyntaxError as exc:
            return f"SyntaxError: {exc}"

        visitor = _SecurityVisitor(self._allowlist)
        visitor.visit(tree)
        if visitor.errors:
            return "SecurityError: " + "; ".join(visitor.errors) + " — not allowed"

        def save(name: str, value: Any) -> str:
            # Store a computed artifact back into the session cache.
            return self._cache.put(name, value)

        # A fresh dict per call serves as BOTH globals and locals. With
        # separate dicts, exec() runs the code as if inside a class body. Then
        # top-level names (imports, cache handles, assignments) are invisible
        # inside functions, lambdas and generator expressions the snippet
        # defines, and using them raises NameError.
        namespace: dict[str, Any] = {}
        for name, value in self._cache.items():
            namespace[name] = value
        namespace["save"] = save
        # Assigned last so that no cache entry can replace the restricted builtins.
        namespace["__builtins__"] = _safe_builtins(self._allowlist)

        # Capture stdout
        buf = io.StringIO()
        try:
            with redirect_stdout(buf):
                exec(compile(tree, "<code>", "exec"), namespace)  # noqa: S102
        except Exception:
            return f"Error:\n{traceback.format_exc()}"

        output = buf.getvalue()
        return output if output else "ran successfully with no output"

    @staticmethod
    def make_tool_spec(cache: SessionCache) -> ToolSpec:
        """Build the ``python_interpreter`` tool backed by a new interpreter."""
        interp = PythonInterpreter(cache=cache)
        return ToolSpec(
            name="python_interpreter",
            description=(
                "Run Python code over cached data handles. "
                "Cache handles are available as local variables. "
                "Call save(name, value) to store computed artifacts back to cache."
            ),
            input_schema={
                "type": "object",
                "properties": {
                    "code": {"type": "string", "description": "Python code to execute"},
                },
                "required": ["code"],
            },
            handler=interp.run,
        )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _safe_builtins(allowlist: frozenset[str]) -> dict:
    """Return the restricted ``__builtins__`` mapping for sandboxed code.

    Only a curated set of builtins is exposed, plus the common exception
    classes. Without those classes, ``raise ValueError(...)`` or
    ``except KeyError:`` would fail with NameError. ``__import__`` is replaced
    by a guard that enforces ``allowlist``. ``getattr``/``hasattr`` are
    wrapped to refuse dunder names. The AST check only sees literal
    ``obj.__attr__`` syntax, so ``getattr(obj, "__class__")`` would otherwise
    bypass it.

    NOTE(review): this is defense in depth, not a security boundary.
    ``vars`` is still exposed unwrapped and returns an object's ``__dict__``,
    whose dunder keys stay reachable by subscription. Allowed third-party
    modules may also expose other modules through ordinary attributes.
    """
    safe = {
        "print": print,
        "len": len,
        "range": range,
        "enumerate": enumerate,
        "zip": zip,
        "map": map,
        "filter": filter,
        "sorted": sorted,
        "reversed": reversed,
        "list": list,
        "dict": dict,
        "set": set,
        "tuple": tuple,
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "type": type,
        "isinstance": isinstance,
        "hasattr": _guarded_hasattr,
        "getattr": _guarded_getattr,
        "abs": abs,
        "round": round,
        "min": min,
        "max": max,
        "sum": sum,
        "any": any,
        "all": all,
        "repr": repr,
        "format": format,
        "vars": vars,
        "dir": dir,
        "None": None,
        "True": True,
        "False": False,
        "__import__": _make_safe_import(allowlist),
    }
    safe.update({exc_type.__name__: exc_type for exc_type in _SAFE_EXCEPTIONS})
    return safe


# Exception classes sandboxed code may raise and catch by name.
_SAFE_EXCEPTIONS = (
    ArithmeticError,
    AssertionError,
    AttributeError,
    Exception,
    ImportError,
    IndexError,
    KeyError,
    LookupError,
    NameError,
    NotImplementedError,
    RuntimeError,
    StopIteration,
    TypeError,
    ValueError,
    ZeroDivisionError,
)

# Sentinel distinguishing "no default given" from an explicit None default.
_NO_DEFAULT = object()


def _is_dunder(name: object) -> bool:
    """Return True for ``__x__``-style strings (same rule as the AST check)."""
    return isinstance(name, str) and name.startswith("__") and name.endswith("__")


def _guarded_getattr(obj: Any, name: str, default: Any = _NO_DEFAULT) -> Any:
    """``getattr`` that treats dunder attributes as missing."""
    if _is_dunder(name):
        if default is _NO_DEFAULT:
            raise AttributeError(f"Dunder attribute access not allowed: {name!r}")
        return default
    if default is _NO_DEFAULT:
        return getattr(obj, name)
    return getattr(obj, name, default)


def _guarded_hasattr(obj: Any, name: str) -> bool:
    """``hasattr`` that reports dunder attributes as absent."""
    if _is_dunder(name):
        return False
    return hasattr(obj, name)


def _make_safe_import(allowlist: frozenset[str]):
    """Return an ``__import__`` replacement restricted to ``allowlist``.

    Relative imports are rejected. For absolute imports only the first dotted
    component of the module name is checked; allowed imports are delegated to
    the real ``builtins.__import__``.
    """

    # Parameter names mirror builtins.__import__, which callers may pass by keyword.
    def safe_import(name, globals=None, locals=None, fromlist=(), level=0):
        if level != 0:
            raise ImportError("relative imports are not allowed")
        top = name.split(".")[0]
        if top not in allowlist:
            raise ImportError(f"Import not allowed: {name!r}")
        return builtins.__import__(name, globals, locals, fromlist, level)

    return safe_import
|
dataact/tools/planner.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from dataact.types import ToolSpec
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Planner:
    """In-memory todo list exposed to the model as ``planner__*`` tools.

    Each item is a dict with an 8-character ``id``, its ``text`` and a
    ``status`` string. ``reminder_hook`` returns escalating nudges while
    pending items exist and no ``add``/``update`` has happened for a while.
    """

    def __init__(self) -> None:
        self._items: list[dict[str, Any]] = []
        # Number of reminder_hook calls since the last add()/update().
        self._turns_since_update: int = 0

    def add(self, items: list[str]) -> str:
        """Append each string in ``items`` as a new pending item.

        A bare string is treated as a single item. A caller (e.g. a model
        ignoring the array schema) may send one, and iterating it would
        otherwise add one item per character.

        Returns:
            The full list rendering from ``list()``.
        """
        if isinstance(items, str):
            items = [items]
        for text in items:
            self._items.append(
                {
                    "id": str(uuid.uuid4())[:8],
                    "text": text,
                    "status": "pending",
                }
            )
        self._turns_since_update = 0
        return self.list()

    def update(self, id: str, status: str) -> str:
        """Set the status of the item with ``id``; report whether it was found.

        The parameter is named ``id`` because it is the tool-input key.
        """
        for item in self._items:
            if item["id"] == id:
                item["status"] = status
                self._turns_since_update = 0
                return f"Updated {id!r} to {status!r}"
        return f"Item {id!r} not found"

    def list(self) -> str:
        """Render every item as ``[id] (status) text``, one per line."""
        if not self._items:
            return "Todo list is empty."
        lines = []
        for item in self._items:
            lines.append(f"[{item['id']}] ({item['status']}) {item['text']}")
        return "\n".join(lines)

    def reminder_hook(self, current_turn: int, max_turns: int) -> str | None:
        """Return a nudge about stale pending items, or None.

        Every call increments the staleness counter. The message is chosen
        from the counter's value before the increment: at least 4 gives a
        reminder, at least 8 a warning, at least 12 an urgent notice. Only
        items with status ``"pending"`` count. ``current_turn`` and
        ``max_turns`` are currently unused.
        """
        pending = [i for i in self._items if i["status"] == "pending"]
        n = self._turns_since_update
        self._turns_since_update += 1

        if not pending:
            return None

        if n >= 12:
            return (
                f"URGENT: You have {len(pending)} pending todo item(s) "
                f"that haven't been updated in {n} turns. Address them immediately."
            )
        if n >= 8:
            return (
                f"WARNING: {len(pending)} pending todo item(s) remain "
                f"with no updates for {n} turns. Please make progress on your plan."
            )
        if n >= 4:
            return (
                f"Reminder: You have {len(pending)} pending todo item(s). "
                f"Consider updating your plan."
            )
        return None

    def make_tool_specs(self) -> list[ToolSpec]:
        """Return the add/update/list tools bound to this planner instance."""
        return [
            ToolSpec(
                name="planner__add",
                description="Add items to your todo list.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "items": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "List of task descriptions to add.",
                        }
                    },
                    "required": ["items"],
                },
                handler=self.add,
            ),
            ToolSpec(
                name="planner__update",
                description="Update the status of a todo item.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "id": {"type": "string", "description": "Item ID"},
                        "status": {
                            "type": "string",
                            "enum": ["pending", "in_progress", "done", "blocked"],
                        },
                    },
                    "required": ["id", "status"],
                },
                handler=self.update,
            ),
            ToolSpec(
                name="planner__list",
                description="List all todo items and their statuses.",
                input_schema={"type": "object", "properties": {}},
                handler=self.list,
            ),
        ]
|