dataact 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataact/format.py ADDED
@@ -0,0 +1,108 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ if TYPE_CHECKING:
7
+ from dataact.cache import SessionCache
8
+
9
+ _INLINE_STR_MAX = 500
10
+ _INLINE_JSON_MAX = 2000
11
+ _INLINE_COLLECTION_MAX = 20
12
+
13
+
14
def format_tool_output(
    value: Any,
    cache: "SessionCache | None" = None,
    preferred_name: str | None = None,
) -> str:
    """Turn a tool's raw return value into message text, caching large results.

    Small scalars, short strings, and small JSON-able collections are inlined
    verbatim; DataFrames, ndarrays, and oversized values are handed to the
    session cache (when one is available) and replaced by a reference string.
    """
    # Exceptions are always rendered inline as a single error line.
    if isinstance(value, Exception):
        return f"Error: {type(value).__name__}: {value}"

    # Tabular / array data never goes inline.
    if _is_dataframe(value) or _is_ndarray(value):
        return _cache_value(value, cache, preferred_name or _default_name(value))

    if isinstance(value, str):
        if len(value) <= _INLINE_STR_MAX:
            return value
        return _cache_value(value, cache, preferred_name or "text_result")

    # Plain scalars (bool is already covered by int; listed for clarity).
    if value is None or isinstance(value, (int, float, bool)):
        return str(value)

    if isinstance(value, (dict, list)):
        try:
            payload = json.dumps(value, default=repr)
        except Exception:
            payload = repr(value)
        small_enough = (
            len(payload) <= _INLINE_JSON_MAX
            and _collection_size(value) <= _INLINE_COLLECTION_MAX
        )
        if small_enough:
            return payload
        return _cache_value(value, cache, preferred_name or _default_name(value))

    # Caller explicitly asked for a named cache entry for an unknown object.
    if preferred_name is not None and cache is not None:
        return _cache_value(value, cache, preferred_name)

    # Last resort: a truncated repr.
    rendered = repr(value)
    if len(rendered) > _INLINE_STR_MAX:
        rendered = rendered[:_INLINE_STR_MAX] + "..."
    return rendered
59
+
60
+
61
def _cache_value(value: Any, cache: "SessionCache | None", name: str) -> str:
    """Store *value* in the session cache under *name* and return a reference.

    Without a cache, degrade gracefully to a (possibly truncated) repr.
    """
    if cache is None:
        fallback = repr(value)
        if len(fallback) > _INLINE_STR_MAX:
            fallback = fallback[:_INLINE_STR_MAX] + "..."
        return fallback
    resolved = cache.put(name, value)
    snapshot = cache.snapshot(resolved)
    return f"Saved as `{resolved}`\nSnapshot: {snapshot}"
71
+
72
+
73
def _default_name(value: Any) -> str:
    """Pick a fallback cache name from the value's broad category."""
    if _is_dataframe(value):
        return "dataframe"
    if _is_ndarray(value):
        return "array"
    if isinstance(value, dict):
        return "result_dict"
    return "result_list" if isinstance(value, list) else "result"
83
+
84
+
85
+ def _collection_size(value: Any) -> int:
86
+ if isinstance(value, dict):
87
+ return len(value)
88
+ if isinstance(value, list):
89
+ return len(value)
90
+ return 0
91
+
92
+
93
+ def _is_dataframe(value: Any) -> bool:
94
+ try:
95
+ import pandas as pd
96
+
97
+ return isinstance(value, pd.DataFrame)
98
+ except ImportError:
99
+ return False
100
+
101
+
102
+ def _is_ndarray(value: Any) -> bool:
103
+ try:
104
+ import numpy as np
105
+
106
+ return isinstance(value, np.ndarray)
107
+ except ImportError:
108
+ return False
dataact/logger.py ADDED
@@ -0,0 +1,66 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ from datetime import datetime, timezone
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from loguru import logger
10
+
11
+ from dataact.providers.base import NormalizedResponse
12
+ from dataact.serialize import to_jsonable
13
+ from dataact.types import Message, ToolResultBlock
14
+
15
+
16
def setup_logger(run_dir: str = "./runs") -> str:
    """Create a timestamped JSONL run file and configure loguru; return its path."""
    base = Path(run_dir)
    base.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%S")
    run_path = base / f"{stamp}.jsonl"
    # Make sure the file exists even if no turn ever gets logged.
    run_path.touch()
    # Drop loguru's default stderr sink and install a no-op one;
    # callers can add their own sinks.
    logger.remove()
    logger.add(
        lambda msg: None, level="INFO"
    )
    return str(run_path)
28
+
29
+
30
def log_turn(
    turn: int,
    system: str,
    messages: list[Message],
    response: NormalizedResponse,
    tool_results: list[ToolResultBlock],
    latency_ms: float,
    run_file: str,
    cache_storage: dict[str, dict[str, str]] | None = None,
) -> None:
    """Append one JSON record describing a single agent turn to *run_file*.

    The full system prompt is stored only on turn 1; every record carries its
    SHA-256 hash so later turns stay small while remaining correlatable.
    """
    system_hash = hashlib.sha256(system.encode()).hexdigest()

    record: dict[str, Any] = {
        "turn": turn,
        "timestamp": datetime.now(tz=timezone.utc).isoformat(),
        "system_hash": system_hash,
        "messages": to_jsonable(messages),
        "response_content": to_jsonable(response.content),
        "stop_reason": response.stop_reason.value,
        "tool_results": to_jsonable(tool_results),
        "metrics": {
            "input_tokens": response.input_tokens,
            "output_tokens": response.output_tokens,
            "cache_read_tokens": response.cache_read_tokens,
            "cache_write_tokens": response.cache_write_tokens,
            "latency_ms": latency_ms,
        },
    }

    if turn == 1:
        record["system"] = system
    if cache_storage is not None:
        record["cache_storage"] = cache_storage

    # Explicit encoding: the default text encoding is platform-dependent
    # (PEP 597), which could corrupt the log on non-UTF-8 locales.
    with open(run_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
dataact/loop.py ADDED
@@ -0,0 +1,153 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+
5
+ from dataact.cache import SessionCache
6
+ from dataact.exceptions import MaxTurnsExceeded
7
+ from dataact.format import format_tool_output
8
+ from dataact.logger import log_turn, setup_logger
9
+ from dataact.observe import time_block
10
+ from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
11
+ from dataact.types import Message, TextBlock, ToolResultBlock, ToolSpec, ToolUseBlock
12
+
13
+ _MAX_TURN_REMINDER = (
14
+ "This is the final turn. You MUST produce your complete final output now. "
15
+ "Do not use any more tools. Respond with your answer directly."
16
+ )
17
+
18
+
19
class Harness:
    """Tool-use agent loop: drives a provider adapter until the model stops.

    Each turn sends the accumulated message list to the adapter, dispatches
    any requested tool calls, logs the turn to JSONL, and finally returns the
    model's text answer (or raises ``MaxTurnsExceeded``).
    """

    def __init__(
        self,
        adapter: ProviderAdapter,
        system: str,
        tools: list[ToolSpec],
        max_turns: int = 25,
        run_dir: str = "./runs",
        cache: SessionCache | None = None,
    ) -> None:
        self._adapter = adapter
        self._system = system
        # Defensive copy so later mutation of the caller's list has no effect.
        self._tools = list(tools)
        self._max_turns = max_turns
        self._run_dir = run_dir
        self._cache = cache if cache is not None else SessionCache()
        self._messages: list[Message] = []
        # Hooks called as hook(turn, max_turns); a truthy return is injected
        # into the conversation as a reminder (see _apply_reminders).
        self._reminders: list[Callable[[int, int], str | None]] = []

    def register_reminder(self, hook: Callable[[int, int], str | None]) -> None:
        """Register a per-turn hook whose non-empty return text becomes a reminder."""
        self._reminders.append(hook)

    def run(self, user_message: str) -> str:
        """Run the agent loop for *user_message* and return the final text.

        Raises MaxTurnsExceeded when the model has not finished within
        ``max_turns`` turns.
        """
        run_file = setup_logger(self._run_dir)
        self._messages = [Message(role="user", content=[TextBlock(text=user_message)])]
        last_response: NormalizedResponse | None = None

        for turn in range(1, self._max_turns + 1):
            self._apply_reminders(turn)
            # Only tools marked visible are advertised to the model.
            visible_tools = [t for t in self._tools if t.visible]

            with time_block() as tb:
                response = self._adapter.chat(
                    system=self._system,
                    messages=self._messages,
                    tools=visible_tools,
                )
            last_response = response
            latency = tb.elapsed_ms

            # Append assistant message
            self._messages.append(Message(role="assistant", content=response.content))

            tool_results: list[ToolResultBlock] = []

            if response.stop_reason == StopReason.TOOL_USE:
                tool_results = self._dispatch_tools(response.content)
                user_msg = Message(role="user", content=list(tool_results))
                self._messages.append(user_msg)

            log_turn(
                turn=turn,
                system=self._system,
                messages=self._messages,
                response=response,
                tool_results=tool_results,
                latency_ms=latency,
                run_file=run_file,
                cache_storage=self._cache.storage_metadata(),
            )

            if response.stop_reason == StopReason.END_TURN:
                return self._extract_text(response)

            if turn == self._max_turns:
                raise MaxTurnsExceeded(turn, last_response)

        # NOTE(review): only reachable when max_turns < 1 (the loop body never
        # runs); otherwise the branch above raises on the final iteration.
        raise MaxTurnsExceeded(self._max_turns, last_response)

    def _apply_reminders(self, turn: int) -> None:
        """Collect reminder texts for this turn and attach them to the pending user message."""
        reminder_texts: list[str] = []

        for hook in self._reminders:
            text = hook(turn, self._max_turns)
            if text:
                reminder_texts.append(text)

        # Built-in max-turn reminder
        # NOTE(review): fires at turn == max_turns - 1, i.e. one turn BEFORE
        # the actual last turn, even though the text says "final turn" —
        # confirm this off-by-one is intentional.
        if turn == self._max_turns - 1:
            reminder_texts.append(_MAX_TURN_REMINDER)

        if not reminder_texts:
            return

        combined = "\n\n".join(reminder_texts)
        reminder_block = TextBlock(text=combined)

        # Append to existing user message or create a new one
        if self._messages and self._messages[-1].role == "user":
            self._messages[-1].content.append(reminder_block)
        else:
            self._messages.append(Message(role="user", content=[reminder_block]))

    def _dispatch_tools(self, content: list) -> list[ToolResultBlock]:
        """Execute every ToolUseBlock in *content* and return one result per call.

        Unknown tools and handler exceptions become is_error results instead
        of propagating, so the model sees the failure and can recover.
        """
        tool_uses = [b for b in content if isinstance(b, ToolUseBlock)]
        results = []
        tool_map = {t.name: t for t in self._tools}

        for tub in tool_uses:
            spec = tool_map.get(tub.tool_name)
            if spec is None or spec.handler is None:
                results.append(
                    ToolResultBlock(
                        tool_use_id=tub.tool_use_id,
                        content=f"Tool not found: {tub.tool_name!r}",
                        is_error=True,
                    )
                )
                continue
            try:
                raw = spec.handler(**tub.tool_input)
                output = format_tool_output(raw, cache=self._cache)
            except Exception as exc:
                # NOTE(review): repr(exc) differs from format_tool_output's
                # "Error: Type: msg" style for Exception values — confirm the
                # inconsistency is acceptable.
                output = repr(exc)
                results.append(
                    ToolResultBlock(
                        tool_use_id=tub.tool_use_id,
                        content=output,
                        is_error=True,
                    )
                )
                continue
            results.append(
                ToolResultBlock(
                    tool_use_id=tub.tool_use_id,
                    content=output,
                    is_error=False,
                )
            )

        return results

    def _extract_text(self, response: NormalizedResponse) -> str:
        """Concatenate the text of every TextBlock in the response."""
        texts = [b.text for b in response.content if isinstance(b, TextBlock)]
        return "\n".join(texts)
dataact/observe.py ADDED
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from contextlib import contextmanager
5
+ from dataclasses import dataclass
6
+ from typing import Generator
7
+
8
+
9
@dataclass
class TurnMetrics:
    """Per-turn token usage and timing metrics.

    NOTE(review): not constructed anywhere in this module — presumably filled
    in by callers from a provider response; confirm usage.
    """

    turn: int  # 1-based turn index
    input_tokens: int
    output_tokens: int
    cache_read_tokens: int  # prompt tokens served from a provider cache
    cache_write_tokens: int  # prompt tokens written to a provider cache
    latency_ms: float  # wall-clock latency of the provider call
17
+
18
+
19
class _TimeResult:
    """Mutable holder filled with the elapsed time once a timed block exits."""

    def __init__(self) -> None:
        self.elapsed_ms: float = 0.0


@contextmanager
def time_block() -> Generator[_TimeResult, None, None]:
    """Measure wall-clock time of the ``with`` body in milliseconds.

    The yielded holder's ``elapsed_ms`` is set on exit — even when the body
    raises — using the monotonic clock.
    """
    holder = _TimeResult()
    began = time.monotonic()
    try:
        yield holder
    finally:
        holder.elapsed_ms = (time.monotonic() - began) * 1000
File without changes
@@ -0,0 +1,112 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+
5
+ import anthropic
6
+
7
+ from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
8
+ from dataact.types import Message, TextBlock, ToolResultBlock, ToolSpec, ToolUseBlock
9
+
10
+ _STOP_REASON_MAP = {
11
+ "end_turn": StopReason.END_TURN,
12
+ "tool_use": StopReason.TOOL_USE,
13
+ "max_tokens": StopReason.MAX_TOKENS,
14
+ "stop_sequence": StopReason.STOP_SEQUENCE,
15
+ }
16
+
17
+
18
class AnthropicAdapter(ProviderAdapter):
    """ProviderAdapter backed by the Anthropic Messages API."""

    def __init__(
        self, model: str = "claude-sonnet-4-6", max_tokens: int = 8096
    ) -> None:
        # NOTE(review): 8096 looks like a typo for 8192 — confirm intended cap.
        self._model = model
        self._max_tokens = max_tokens
        self._client = anthropic.Anthropic()

    def format_cache_control(self, obj: dict) -> dict:
        """Return a shallow copy of *obj* tagged with ephemeral cache_control."""
        result = copy.copy(obj)
        result["cache_control"] = {"type": "ephemeral"}
        return result

    def chat(
        self,
        system: str,
        messages: list[Message],
        tools: list[ToolSpec],
    ) -> NormalizedResponse:
        """Send one request to the Messages API and normalize the response."""
        # Deep-copy inputs so we never mutate harness state
        # NOTE(review): the builders actually create new top-level dicts
        # (shallow copies); nested tool_input dicts are still shared — confirm.
        api_system = self._build_system(system)
        api_messages = self._build_messages(messages)
        api_tools = self._build_tools(tools)

        resp = self._client.messages.create(
            model=self._model,
            max_tokens=self._max_tokens,
            system=api_system,
            messages=api_messages,
            tools=api_tools or anthropic.NOT_GIVEN,
        )

        # Unknown stop reasons degrade to END_TURN so the loop terminates.
        stop_reason = _STOP_REASON_MAP.get(resp.stop_reason, StopReason.END_TURN)
        content = self._normalize_content(resp.content)

        return NormalizedResponse(
            stop_reason=stop_reason,
            content=content,
            input_tokens=resp.usage.input_tokens,
            output_tokens=resp.usage.output_tokens,
            # getattr guards: cache usage fields may be absent on the usage
            # object; "or 0" also maps None to 0.
            cache_read_tokens=getattr(resp.usage, "cache_read_input_tokens", 0) or 0,
            cache_write_tokens=getattr(resp.usage, "cache_creation_input_tokens", 0)
            or 0,
        )

    def _build_system(self, system: str) -> list[dict]:
        """Wrap the system prompt as a single cache-tagged text block."""
        return [self.format_cache_control({"type": "text", "text": system})]

    def _build_messages(self, messages: list[Message]) -> list[dict]:
        """Convert harness messages to API dicts, cache-tagging the final user block."""
        result = []
        for i, msg in enumerate(messages):
            blocks = [self._block_to_dict(b) for b in msg.content]
            # Apply cache_control to last user message
            if i == len(messages) - 1 and msg.role == "user" and blocks:
                last_block = blocks[-1]
                blocks[-1] = self.format_cache_control(last_block)
            result.append({"role": msg.role, "content": blocks})
        return result

    def _block_to_dict(self, block) -> dict:
        """Serialize one harness content block to its wire-format dict.

        Raises ValueError for block types this adapter cannot serialize.
        """
        if isinstance(block, TextBlock):
            return {"type": "text", "text": block.text}
        if isinstance(block, ToolUseBlock):
            return {
                "type": "tool_use",
                "id": block.tool_use_id,
                "name": block.tool_name,
                "input": block.tool_input,
            }
        if isinstance(block, ToolResultBlock):
            return {
                "type": "tool_result",
                "tool_use_id": block.tool_use_id,
                "content": block.content,
                "is_error": block.is_error,
            }
        raise ValueError(f"Unknown block type: {type(block)}")

    def _build_tools(self, tools: list[ToolSpec]) -> list[dict]:
        """Serialize tool specs via their provider-dict representation."""
        return [t.to_provider_dict() for t in tools]

    def _normalize_content(self, content) -> list:
        """Convert API response blocks back into harness content blocks.

        NOTE(review): block types other than "text" / "tool_use" (e.g.
        thinking blocks) are silently dropped — confirm this is acceptable.
        """
        blocks = []
        for block in content:
            if block.type == "text":
                blocks.append(TextBlock(text=block.text))
            elif block.type == "tool_use":
                blocks.append(
                    ToolUseBlock(
                        tool_use_id=block.id,
                        tool_name=block.name,
                        tool_input=dict(block.input),
                    )
                )
        return blocks
@@ -0,0 +1,35 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+
5
+ from dataact.types import ContentBlock, Message, ToolSpec
6
+
7
+
8
class StopReason(Enum):
    """Provider-agnostic reason a model response ended."""

    END_TURN = "end_turn"  # model finished its answer
    TOOL_USE = "tool_use"  # model requested one or more tool calls
    MAX_TOKENS = "max_tokens"  # output truncated at the token limit
    STOP_SEQUENCE = "stop_sequence"  # a configured stop sequence fired
13
+
14
+
15
@dataclass
class NormalizedResponse:
    """Provider-independent view of a single model response."""

    stop_reason: StopReason
    content: list[ContentBlock]
    input_tokens: int
    output_tokens: int
    cache_read_tokens: int  # 0 when the provider reports no cache usage
    cache_write_tokens: int  # 0 when the provider reports no cache usage
23
+
24
+
25
class ProviderAdapter(ABC):
    """Interface each LLM provider backend implements for the harness."""

    @abstractmethod
    def chat(
        self,
        system: str,
        messages: list[Message],
        tools: list[ToolSpec],
    ) -> NormalizedResponse:
        """Send one request to the provider and return a normalized response."""
        ...

    @abstractmethod
    def format_cache_control(self, obj: dict) -> dict:
        """Return a copy of *obj* annotated for prompt caching (provider-specific; may be a no-op)."""
        ...
@@ -0,0 +1,125 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import json
5
+
6
+ import openai
7
+
8
+ from dataact.providers.base import NormalizedResponse, ProviderAdapter, StopReason
9
+ from dataact.types import Message, TextBlock, ToolResultBlock, ToolSpec, ToolUseBlock
10
+
11
+ _STOP_REASON_MAP = {
12
+ "stop": StopReason.END_TURN,
13
+ "tool_calls": StopReason.TOOL_USE,
14
+ "length": StopReason.MAX_TOKENS,
15
+ # TODO: expose content-filter handling separately if the core stop model grows.
16
+ "content_filter": StopReason.END_TURN,
17
+ "function_call": StopReason.TOOL_USE,
18
+ }
19
+
20
+
21
class OpenAIAdapter(ProviderAdapter):
    """ProviderAdapter backed by the OpenAI Chat Completions API."""

    def __init__(self, model: str = "gpt-4o-mini", max_tokens: int = 4096) -> None:
        self._model = model
        self._max_tokens = max_tokens
        self._client = openai.OpenAI()

    def format_cache_control(self, obj: dict) -> dict:
        """No-op copy: OpenAI exposes no explicit cache_control annotation."""
        return copy.copy(obj)

    def chat(
        self,
        system: str,
        messages: list[Message],
        tools: list[ToolSpec],
    ) -> NormalizedResponse:
        """Send one request to Chat Completions and normalize the response."""
        api_messages = self._build_messages(system, messages)
        api_tools = self._build_tools(tools)

        # NOTE(review): `max_tokens` is deprecated in favor of
        # `max_completion_tokens` on newer models — confirm target models.
        response = self._client.chat.completions.create(
            model=self._model,
            max_tokens=self._max_tokens,
            messages=api_messages,
            tools=api_tools or openai.NOT_GIVEN,
        )

        choice = response.choices[0]
        # Unknown finish reasons degrade to END_TURN so the loop terminates.
        stop_reason = _STOP_REASON_MAP.get(choice.finish_reason, StopReason.END_TURN)

        return NormalizedResponse(
            stop_reason=stop_reason,
            content=self._normalize_message(choice.message),
            # getattr guards: usage fields may be missing; "or 0" maps None to 0.
            input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0,
            output_tokens=getattr(response.usage, "completion_tokens", 0) or 0,
            # Prompt-cache token counts are not populated for this provider.
            cache_read_tokens=0,
            cache_write_tokens=0,
        )

    def _build_messages(self, system: str, messages: list[Message]) -> list[dict]:
        """Flatten harness messages into Chat Completions format.

        Text and tool_use blocks of one message merge into a single API
        message; each ToolResultBlock becomes its own role="tool" message.
        """
        result = [{"role": "system", "content": system}]
        for message in messages:
            text_blocks = [b for b in message.content if isinstance(b, TextBlock)]
            tool_uses = [b for b in message.content if isinstance(b, ToolUseBlock)]
            tool_results = [
                b for b in message.content if isinstance(b, ToolResultBlock)
            ]

            if text_blocks or tool_uses:
                api_message: dict = {
                    "role": message.role,
                    # None content is valid for assistant messages that only
                    # carry tool_calls.
                    "content": "\n".join(b.text for b in text_blocks)
                    if text_blocks
                    else None,
                }
                if tool_uses:
                    api_message["tool_calls"] = [
                        {
                            "id": block.tool_use_id,
                            "type": "function",
                            "function": {
                                "name": block.tool_name,
                                "arguments": json.dumps(block.tool_input),
                            },
                        }
                        for block in tool_uses
                    ]
                result.append(api_message)

            for block in tool_results:
                result.append(
                    {
                        "role": "tool",
                        "tool_call_id": block.tool_use_id,
                        "content": block.content,
                    }
                )
        return result

    def _build_tools(self, tools: list[ToolSpec]) -> list[dict]:
        """Serialize tool specs into OpenAI function-tool dicts."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for tool in tools
        ]

    def _normalize_message(self, message) -> list:
        """Convert one completion message into harness content blocks."""
        blocks = []
        content = getattr(message, "content", None)
        if content:
            blocks.append(TextBlock(text=content))

        for call in getattr(message, "tool_calls", None) or []:
            blocks.append(
                ToolUseBlock(
                    tool_use_id=call.id,
                    tool_name=call.function.name,
                    # Empty/None arguments decode to an empty input dict.
                    tool_input=json.loads(call.function.arguments or "{}"),
                )
            )
        return blocks