dataact 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in the public registry.
- dataact/__init__.py +31 -0
- dataact/agent.py +237 -0
- dataact/cache.py +319 -0
- dataact/exceptions.py +21 -0
- dataact/format.py +108 -0
- dataact/logger.py +66 -0
- dataact/loop.py +153 -0
- dataact/observe.py +31 -0
- dataact/providers/__init__.py +0 -0
- dataact/providers/anthropic.py +112 -0
- dataact/providers/base.py +35 -0
- dataact/providers/openai.py +125 -0
- dataact/schema.py +79 -0
- dataact/serialize.py +111 -0
- dataact/testing.py +70 -0
- dataact/tools/__init__.py +0 -0
- dataact/tools/connectors.py +129 -0
- dataact/tools/interpreter.py +189 -0
- dataact/tools/planner.py +107 -0
- dataact/tools/subagent.py +222 -0
- dataact/tools/variables.py +25 -0
- dataact/types.py +54 -0
- dataact-0.1.0.dist-info/METADATA +212 -0
- dataact-0.1.0.dist-info/RECORD +26 -0
- dataact-0.1.0.dist-info/WHEEL +4 -0
- dataact-0.1.0.dist-info/licenses/LICENSE +21 -0
dataact/format.py
ADDED
@@ -0,0 +1,108 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from dataact.cache import SessionCache

_INLINE_STR_MAX = 500
_INLINE_JSON_MAX = 2000
_INLINE_COLLECTION_MAX = 20


def format_tool_output(
    value: Any,
    cache: "SessionCache | None" = None,
    preferred_name: str | None = None,
) -> str:
    """Decide whether to inline output or cache it; return a string for the message."""
    if isinstance(value, Exception):
        return f"Error: {type(value).__name__}: {value}"

    # DataFrame / ndarray → always cache
    if _is_dataframe(value) or _is_ndarray(value):
        return _cache_value(value, cache, preferred_name or _default_name(value))

    # Short string → inline
    if isinstance(value, str):
        if len(value) <= _INLINE_STR_MAX:
            return value
        return _cache_value(value, cache, preferred_name or "text_result")

    # Scalar (int, float, bool, None)
    if isinstance(value, (int, float, bool)) or value is None:
        return str(value)

    # Short dict/list → inline JSON repr
    if isinstance(value, (dict, list)):
        try:
            serialized = json.dumps(value, default=repr)
        except Exception:
            serialized = repr(value)
        if (
            len(serialized) <= _INLINE_JSON_MAX
            and _collection_size(value) <= _INLINE_COLLECTION_MAX
        ):
            return serialized
        return _cache_value(value, cache, preferred_name or _default_name(value))

    # Unknown object with preferred_name → cache
    if preferred_name is not None and cache is not None:
        return _cache_value(value, cache, preferred_name)

    # Unknown object → repr truncated
    r = repr(value)
    if len(r) > _INLINE_STR_MAX:
        return r[:_INLINE_STR_MAX] + "..."
    return r


def _cache_value(value: Any, cache: "SessionCache | None", name: str) -> str:
    if cache is None:
        # No cache available, fall back to repr
        r = repr(value)
        if len(r) > _INLINE_STR_MAX:
            return r[:_INLINE_STR_MAX] + "..."
        return r
    resolved = cache.put(name, value)
    snapshot = cache.snapshot(resolved)
    return f"Saved as `{resolved}`\nSnapshot: {snapshot}"


def _default_name(value: Any) -> str:
    if _is_dataframe(value):
        return "dataframe"
    if _is_ndarray(value):
        return "array"
    if isinstance(value, dict):
        return "result_dict"
    if isinstance(value, list):
        return "result_list"
    return "result"


def _collection_size(value: Any) -> int:
    if isinstance(value, dict):
        return len(value)
    if isinstance(value, list):
        return len(value)
    return 0


def _is_dataframe(value: Any) -> bool:
    try:
        import pandas as pd

        return isinstance(value, pd.DataFrame)
    except ImportError:
        return False


def _is_ndarray(value: Any) -> bool:
    try:
        import numpy as np

        return isinstance(value, np.ndarray)
    except ImportError:
        return False
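The routing above is easiest to see with a few concrete calls. A minimal sketch, not part of the package diff, assuming no SessionCache is wired in (so oversized values take the truncated-repr fallback):

from dataact.format import format_tool_output

print(format_tool_output(42))                     # scalar → inlined: "42"
print(format_tool_output("done"))                 # short string → inlined as-is
print(format_tool_output({"rows": 3}))            # small dict → inline JSON: '{"rows": 3}'
print(format_tool_output(ValueError("bad col")))  # exception → "Error: ValueError: bad col"
print(format_tool_output("x" * 600))              # no cache → repr truncated to 500 chars + "..."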
dataact/logger.py
ADDED
@@ -0,0 +1,66 @@
from __future__ import annotations

import hashlib
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from loguru import logger

from dataact.providers.base import NormalizedResponse
from dataact.serialize import to_jsonable
from dataact.types import Message, ToolResultBlock


def setup_logger(run_dir: str = "./runs") -> str:
    """Create a timestamped JSONL file and configure loguru. Returns the path."""
    Path(run_dir).mkdir(parents=True, exist_ok=True)
    ts = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%S")
    run_file = str(Path(run_dir) / f"{ts}.jsonl")
    # Ensure the file exists
    Path(run_file).touch()
    logger.remove()
    logger.add(
        lambda msg: None, level="INFO"
    )  # suppress default stderr; callers can add sinks
    return run_file


def log_turn(
    turn: int,
    system: str,
    messages: list[Message],
    response: NormalizedResponse,
    tool_results: list[ToolResultBlock],
    latency_ms: float,
    run_file: str,
    cache_storage: dict[str, dict[str, str]] | None = None,
) -> None:
    """Append one JSON line to the run JSONL file."""
    system_hash = hashlib.sha256(system.encode()).hexdigest()

    record: dict[str, Any] = {
        "turn": turn,
        "timestamp": datetime.now(tz=timezone.utc).isoformat(),
        "system_hash": system_hash,
        "messages": to_jsonable(messages),
        "response_content": to_jsonable(response.content),
        "stop_reason": response.stop_reason.value,
        "tool_results": to_jsonable(tool_results),
        "metrics": {
            "input_tokens": response.input_tokens,
            "output_tokens": response.output_tokens,
            "cache_read_tokens": response.cache_read_tokens,
            "cache_write_tokens": response.cache_write_tokens,
            "latency_ms": latency_ms,
        },
    }

    if turn == 1:
        record["system"] = system
    if cache_storage is not None:
        record["cache_storage"] = cache_storage

    with open(run_file, "a") as f:
        f.write(json.dumps(record) + "\n")
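Each log_turn call appends one self-contained JSON object, so a run file can be inspected with nothing but the standard library. A reader sketch; the file name below is a hypothetical instance of the timestamped pattern setup_logger produces:

import json

with open("runs/20240101T120000.jsonl") as f:  # hypothetical run file name
    for line in f:
        record = json.loads(line)
        m = record["metrics"]
        # The full system prompt is stored only on turn 1; every turn
        # carries its sha256 hash for correlation.
        print(record["turn"], record["stop_reason"], m["input_tokens"], m["latency_ms"])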
dataact/loop.py
ADDED
@@ -0,0 +1,153 @@
from __future__ import annotations

from typing import Callable

from dataact.cache import SessionCache
from dataact.exceptions import MaxTurnsExceeded
from dataact.format import format_tool_output
from dataact.logger import log_turn, setup_logger
from dataact.observe import time_block
from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
from dataact.types import Message, TextBlock, ToolResultBlock, ToolSpec, ToolUseBlock

_MAX_TURN_REMINDER = (
    "This is the final turn. You MUST produce your complete final output now. "
    "Do not use any more tools. Respond with your answer directly."
)


class Harness:
    def __init__(
        self,
        adapter: ProviderAdapter,
        system: str,
        tools: list[ToolSpec],
        max_turns: int = 25,
        run_dir: str = "./runs",
        cache: SessionCache | None = None,
    ) -> None:
        self._adapter = adapter
        self._system = system
        self._tools = list(tools)
        self._max_turns = max_turns
        self._run_dir = run_dir
        self._cache = cache if cache is not None else SessionCache()
        self._messages: list[Message] = []
        self._reminders: list[Callable[[int, int], str | None]] = []

    def register_reminder(self, hook: Callable[[int, int], str | None]) -> None:
        self._reminders.append(hook)

    def run(self, user_message: str) -> str:
        run_file = setup_logger(self._run_dir)
        self._messages = [Message(role="user", content=[TextBlock(text=user_message)])]
        last_response: NormalizedResponse | None = None

        for turn in range(1, self._max_turns + 1):
            self._apply_reminders(turn)
            visible_tools = [t for t in self._tools if t.visible]

            with time_block() as tb:
                response = self._adapter.chat(
                    system=self._system,
                    messages=self._messages,
                    tools=visible_tools,
                )
            last_response = response
            latency = tb.elapsed_ms

            # Append assistant message
            self._messages.append(Message(role="assistant", content=response.content))

            tool_results: list[ToolResultBlock] = []

            if response.stop_reason == StopReason.TOOL_USE:
                tool_results = self._dispatch_tools(response.content)
                user_msg = Message(role="user", content=list(tool_results))
                self._messages.append(user_msg)

            log_turn(
                turn=turn,
                system=self._system,
                messages=self._messages,
                response=response,
                tool_results=tool_results,
                latency_ms=latency,
                run_file=run_file,
                cache_storage=self._cache.storage_metadata(),
            )

            if response.stop_reason == StopReason.END_TURN:
                return self._extract_text(response)

            if turn == self._max_turns:
                raise MaxTurnsExceeded(turn, last_response)

        raise MaxTurnsExceeded(self._max_turns, last_response)

    def _apply_reminders(self, turn: int) -> None:
        reminder_texts: list[str] = []

        for hook in self._reminders:
            text = hook(turn, self._max_turns)
            if text:
                reminder_texts.append(text)

        # Built-in max-turn reminder
        if turn == self._max_turns - 1:
            reminder_texts.append(_MAX_TURN_REMINDER)

        if not reminder_texts:
            return

        combined = "\n\n".join(reminder_texts)
        reminder_block = TextBlock(text=combined)

        # Append to existing user message or create a new one
        if self._messages and self._messages[-1].role == "user":
            self._messages[-1].content.append(reminder_block)
        else:
            self._messages.append(Message(role="user", content=[reminder_block]))

    def _dispatch_tools(self, content: list) -> list[ToolResultBlock]:
        tool_uses = [b for b in content if isinstance(b, ToolUseBlock)]
        results = []
        tool_map = {t.name: t for t in self._tools}

        for tub in tool_uses:
            spec = tool_map.get(tub.tool_name)
            if spec is None or spec.handler is None:
                results.append(
                    ToolResultBlock(
                        tool_use_id=tub.tool_use_id,
                        content=f"Tool not found: {tub.tool_name!r}",
                        is_error=True,
                    )
                )
                continue
            try:
                raw = spec.handler(**tub.tool_input)
                output = format_tool_output(raw, cache=self._cache)
            except Exception as exc:
                output = repr(exc)
                results.append(
                    ToolResultBlock(
                        tool_use_id=tub.tool_use_id,
                        content=output,
                        is_error=True,
                    )
                )
                continue
            results.append(
                ToolResultBlock(
                    tool_use_id=tub.tool_use_id,
                    content=output,
                    is_error=False,
                )
            )

        return results

    def _extract_text(self, response: NormalizedResponse) -> str:
        texts = [b.text for b in response.content if isinstance(b, TextBlock)]
        return "\n".join(texts)
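Driving the loop end to end looks roughly like the sketch below. It is not part of the package diff, and the ToolSpec keyword arguments (name, description, input_schema, handler, visible) are an assumption inferred from the attributes the harness reads, not a confirmed constructor signature:

from dataact.loop import Harness
from dataact.providers.anthropic import AnthropicAdapter
from dataact.types import ToolSpec


def word_count(text: str) -> int:
    return len(text.split())


tool = ToolSpec(  # field names assumed from how loop.py uses ToolSpec
    name="word_count",
    description="Count the words in a string.",
    input_schema={
        "type": "object",
        "properties": {"text": {"type": "string"}},
        "required": ["text"],
    },
    handler=word_count,
    visible=True,
)

harness = Harness(
    adapter=AnthropicAdapter(),  # requires ANTHROPIC_API_KEY in the environment
    system="You are a data analyst. Use tools when they help.",
    tools=[tool],
    max_turns=10,
)

# Optional: reminder hooks run every turn; returning None adds nothing.
harness.register_reminder(
    lambda turn, max_turns: "Half the turn budget is spent."
    if turn == max_turns // 2
    else None
)

print(harness.run("How many words are in 'the quick brown fox'?"))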
dataact/observe.py
ADDED
@@ -0,0 +1,31 @@
from __future__ import annotations

import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Generator


@dataclass
class TurnMetrics:
    turn: int
    input_tokens: int
    output_tokens: int
    cache_read_tokens: int
    cache_write_tokens: int
    latency_ms: float


class _TimeResult:
    def __init__(self) -> None:
        self.elapsed_ms: float = 0.0


@contextmanager
def time_block() -> Generator[_TimeResult, None, None]:
    result = _TimeResult()
    start = time.monotonic()
    try:
        yield result
    finally:
        result.elapsed_ms = (time.monotonic() - start) * 1000
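A quick illustration of the context manager: elapsed_ms is assigned in the finally block, so it is populated even when the timed body raises.

import time

from dataact.observe import time_block

with time_block() as tb:
    time.sleep(0.05)  # stand-in for a provider call

print(f"{tb.elapsed_ms:.1f} ms")  # ≈50 ms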
dataact/providers/__init__.py
File without changes
dataact/providers/anthropic.py
ADDED
@@ -0,0 +1,112 @@
from __future__ import annotations

import copy

import anthropic

from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
from dataact.types import Message, TextBlock, ToolResultBlock, ToolSpec, ToolUseBlock

_STOP_REASON_MAP = {
    "end_turn": StopReason.END_TURN,
    "tool_use": StopReason.TOOL_USE,
    "max_tokens": StopReason.MAX_TOKENS,
    "stop_sequence": StopReason.STOP_SEQUENCE,
}


class AnthropicAdapter(ProviderAdapter):
    def __init__(
        self, model: str = "claude-sonnet-4-6", max_tokens: int = 8096
    ) -> None:
        self._model = model
        self._max_tokens = max_tokens
        self._client = anthropic.Anthropic()

    def format_cache_control(self, obj: dict) -> dict:
        result = copy.copy(obj)
        result["cache_control"] = {"type": "ephemeral"}
        return result

    def chat(
        self,
        system: str,
        messages: list[Message],
        tools: list[ToolSpec],
    ) -> NormalizedResponse:
        # Build fresh request structures so we never mutate harness state
        api_system = self._build_system(system)
        api_messages = self._build_messages(messages)
        api_tools = self._build_tools(tools)

        resp = self._client.messages.create(
            model=self._model,
            max_tokens=self._max_tokens,
            system=api_system,
            messages=api_messages,
            tools=api_tools or anthropic.NOT_GIVEN,
        )

        stop_reason = _STOP_REASON_MAP.get(resp.stop_reason, StopReason.END_TURN)
        content = self._normalize_content(resp.content)

        return NormalizedResponse(
            stop_reason=stop_reason,
            content=content,
            input_tokens=resp.usage.input_tokens,
            output_tokens=resp.usage.output_tokens,
            cache_read_tokens=getattr(resp.usage, "cache_read_input_tokens", 0) or 0,
            cache_write_tokens=getattr(resp.usage, "cache_creation_input_tokens", 0)
            or 0,
        )

    def _build_system(self, system: str) -> list[dict]:
        return [self.format_cache_control({"type": "text", "text": system})]

    def _build_messages(self, messages: list[Message]) -> list[dict]:
        result = []
        for i, msg in enumerate(messages):
            blocks = [self._block_to_dict(b) for b in msg.content]
            # Apply cache_control to the last block of the last user message
            if i == len(messages) - 1 and msg.role == "user" and blocks:
                last_block = blocks[-1]
                blocks[-1] = self.format_cache_control(last_block)
            result.append({"role": msg.role, "content": blocks})
        return result

    def _block_to_dict(self, block) -> dict:
        if isinstance(block, TextBlock):
            return {"type": "text", "text": block.text}
        if isinstance(block, ToolUseBlock):
            return {
                "type": "tool_use",
                "id": block.tool_use_id,
                "name": block.tool_name,
                "input": block.tool_input,
            }
        if isinstance(block, ToolResultBlock):
            return {
                "type": "tool_result",
                "tool_use_id": block.tool_use_id,
                "content": block.content,
                "is_error": block.is_error,
            }
        raise ValueError(f"Unknown block type: {type(block)}")

    def _build_tools(self, tools: list[ToolSpec]) -> list[dict]:
        return [t.to_provider_dict() for t in tools]

    def _normalize_content(self, content) -> list:
        blocks = []
        for block in content:
            if block.type == "text":
                blocks.append(TextBlock(text=block.text))
            elif block.type == "tool_use":
                blocks.append(
                    ToolUseBlock(
                        tool_use_id=block.id,
                        tool_name=block.name,
                        tool_input=dict(block.input),
                    )
                )
        return blocks
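For reference, this is the shape format_cache_control produces; per the code above, the adapter attaches it to the system block and to the final block of the last user message, which serve as the prompt-caching breakpoints. A sketch, noting that the Anthropic client is constructed eagerly in __init__, so ANTHROPIC_API_KEY must be set:

from dataact.providers.anthropic import AnthropicAdapter

adapter = AnthropicAdapter()  # builds anthropic.Anthropic(); needs ANTHROPIC_API_KEY
print(adapter.format_cache_control({"type": "text", "text": "You are a data analyst."}))
# {'type': 'text', 'text': 'You are a data analyst.',
#  'cache_control': {'type': 'ephemeral'}}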
dataact/providers/base.py
ADDED
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum

from dataact.types import ContentBlock, Message, ToolSpec


class StopReason(Enum):
    END_TURN = "end_turn"
    TOOL_USE = "tool_use"
    MAX_TOKENS = "max_tokens"
    STOP_SEQUENCE = "stop_sequence"


@dataclass
class NormalizedResponse:
    stop_reason: StopReason
    content: list[ContentBlock]
    input_tokens: int
    output_tokens: int
    cache_read_tokens: int
    cache_write_tokens: int


class ProviderAdapter(ABC):
    @abstractmethod
    def chat(
        self,
        system: str,
        messages: list[Message],
        tools: list[ToolSpec],
    ) -> NormalizedResponse: ...

    @abstractmethod
    def format_cache_control(self, obj: dict) -> dict: ...
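Because the interface is just these two methods, a scripted fake is enough to run the harness offline in tests. An illustrative stub, not something the package ships (dataact/testing.py may provide its own helpers):

from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
from dataact.types import Message, TextBlock, ToolSpec


class ScriptedAdapter(ProviderAdapter):
    """Replays canned responses in order; no network access."""

    def __init__(self, responses: list[NormalizedResponse]) -> None:
        self._responses = list(responses)

    def chat(
        self, system: str, messages: list[Message], tools: list[ToolSpec]
    ) -> NormalizedResponse:
        return self._responses.pop(0)

    def format_cache_control(self, obj: dict) -> dict:
        return dict(obj)  # caching is a no-op for a fake provider


final = NormalizedResponse(
    stop_reason=StopReason.END_TURN,
    content=[TextBlock(text="All done.")],
    input_tokens=0,
    output_tokens=0,
    cache_read_tokens=0,
    cache_write_tokens=0,
)
adapter = ScriptedAdapter([final])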
dataact/providers/openai.py
ADDED
@@ -0,0 +1,125 @@
from __future__ import annotations

import copy
import json

import openai

from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
from dataact.types import Message, TextBlock, ToolResultBlock, ToolSpec, ToolUseBlock

_STOP_REASON_MAP = {
    "stop": StopReason.END_TURN,
    "tool_calls": StopReason.TOOL_USE,
    "length": StopReason.MAX_TOKENS,
    # TODO: expose content-filter handling separately if the core stop model grows.
    "content_filter": StopReason.END_TURN,
    "function_call": StopReason.TOOL_USE,
}


class OpenAIAdapter(ProviderAdapter):
    def __init__(self, model: str = "gpt-4o-mini", max_tokens: int = 4096) -> None:
        self._model = model
        self._max_tokens = max_tokens
        self._client = openai.OpenAI()

    def format_cache_control(self, obj: dict) -> dict:
        return copy.copy(obj)

    def chat(
        self,
        system: str,
        messages: list[Message],
        tools: list[ToolSpec],
    ) -> NormalizedResponse:
        api_messages = self._build_messages(system, messages)
        api_tools = self._build_tools(tools)

        response = self._client.chat.completions.create(
            model=self._model,
            max_tokens=self._max_tokens,
            messages=api_messages,
            tools=api_tools or openai.NOT_GIVEN,
        )

        choice = response.choices[0]
        stop_reason = _STOP_REASON_MAP.get(choice.finish_reason, StopReason.END_TURN)

        return NormalizedResponse(
            stop_reason=stop_reason,
            content=self._normalize_message(choice.message),
            input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0,
            output_tokens=getattr(response.usage, "completion_tokens", 0) or 0,
            cache_read_tokens=0,
            cache_write_tokens=0,
        )

    def _build_messages(self, system: str, messages: list[Message]) -> list[dict]:
        result = [{"role": "system", "content": system}]
        for message in messages:
            text_blocks = [b for b in message.content if isinstance(b, TextBlock)]
            tool_uses = [b for b in message.content if isinstance(b, ToolUseBlock)]
            tool_results = [
                b for b in message.content if isinstance(b, ToolResultBlock)
            ]

            if text_blocks or tool_uses:
                api_message: dict = {
                    "role": message.role,
                    "content": "\n".join(b.text for b in text_blocks)
                    if text_blocks
                    else None,
                }
                if tool_uses:
                    api_message["tool_calls"] = [
                        {
                            "id": block.tool_use_id,
                            "type": "function",
                            "function": {
                                "name": block.tool_name,
                                "arguments": json.dumps(block.tool_input),
                            },
                        }
                        for block in tool_uses
                    ]
                result.append(api_message)

            for block in tool_results:
                result.append(
                    {
                        "role": "tool",
                        "tool_call_id": block.tool_use_id,
                        "content": block.content,
                    }
                )
        return result

    def _build_tools(self, tools: list[ToolSpec]) -> list[dict]:
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for tool in tools
        ]

    def _normalize_message(self, message) -> list:
        blocks = []
        content = getattr(message, "content", None)
        if content:
            blocks.append(TextBlock(text=content))

        for call in getattr(message, "tool_calls", None) or []:
            blocks.append(
                ToolUseBlock(
                    tool_use_id=call.id,
                    tool_name=call.function.name,
                    tool_input=json.loads(call.function.arguments or "{}"),
                )
            )
        return blocks
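To see the flattening concretely, the private builder can be called directly; this sketch only inspects the dicts it produces (it is not part of the package, it reaches into a private method purely for illustration, and the OpenAI client is constructed eagerly, so OPENAI_API_KEY must be set):

from dataact.providers.openai import OpenAIAdapter
from dataact.types import Message, TextBlock, ToolUseBlock

adapter = OpenAIAdapter()  # builds openai.OpenAI(); needs OPENAI_API_KEY

history = [
    Message(role="user", content=[TextBlock(text="Count the rows.")]),
    Message(
        role="assistant",
        content=[
            TextBlock(text="Calling the counter."),
            ToolUseBlock(
                tool_use_id="call_1",
                tool_name="count_rows",
                tool_input={"table": "t"},
            ),
        ],
    ),
]
for m in adapter._build_messages("You are terse.", history):
    print(m)
# {'role': 'system', 'content': 'You are terse.'}
# {'role': 'user', 'content': 'Count the rows.'}
# {'role': 'assistant', 'content': 'Calling the counter.', 'tool_calls': [...]}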