thoughtflow 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,27 @@
1
"""
Tool interfaces for ThoughtFlow.

Tools are functions with contracts that agents can invoke.
ThoughtFlow makes tool use explicit, testable, and auditable.

Example:
    >>> from thoughtflow.tools import Tool, ToolResult
    >>>
    >>> class Adder(Tool):
    ...     name = "adder"
    ...     description = "Add two numbers"
    ...
    ...     def call(self, payload, params=None):
    ...         return ToolResult.ok(payload["a"] + payload["b"])
"""
17
+
18
+ from __future__ import annotations
19
+
20
+ from thoughtflow.tools.base import Tool, ToolResult
21
+ from thoughtflow.tools.registry import ToolRegistry
22
+
23
# Public API of the tools package (names imported from the submodules above).
__all__ = ["Tool", "ToolResult", "ToolRegistry"]
@@ -0,0 +1,145 @@
1
+ """
2
+ Base tool interface for ThoughtFlow.
3
+
4
+ Tools are functions with contracts. Tool invocation is an explicit step,
5
+ tool results are recorded in the trace, and tools can be simulated/stubbed
6
+ for deterministic tests.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from abc import ABC, abstractmethod
12
+ from dataclasses import dataclass, field
13
+ from typing import Any
14
+
15
+
16
@dataclass
class ToolResult:
    """Outcome of a single tool invocation.

    Prefer the ``ok`` / ``fail`` constructors over building instances by hand.

    Attributes:
        success: True when the tool call succeeded.
        output: Value produced by the tool; ``fail`` leaves it as ``None``.
        error: Failure message; ``ok`` leaves it as ``None``.
        metadata: Extra information about the call (a fresh empty dict per
            instance unless provided).
    """

    success: bool
    output: Any = None
    error: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def ok(cls, output: Any, **metadata: Any) -> ToolResult:
        """Build a successful result.

        Args:
            output: Value the tool produced.
            **metadata: Extra key/value pairs stored in ``metadata``.

        Returns:
            A ToolResult with ``success=True`` and ``error=None``.
        """
        # Positional arguments follow the field order:
        # success, output, error, metadata.
        return cls(True, output, None, metadata)

    @classmethod
    def fail(cls, error: str, **metadata: Any) -> ToolResult:
        """Build a failed result.

        Args:
            error: Message describing what went wrong.
            **metadata: Extra key/value pairs stored in ``metadata``.

        Returns:
            A ToolResult with ``success=False`` and ``output=None``.
        """
        return cls(False, None, error, metadata)
57
+
58
+
59
class Tool(ABC):
    """Abstract base for every ThoughtFlow tool.

    Tools are how agents act on the outside world. A tool bundles:

    - ``name``: a unique identifier
    - ``description``: text telling the LLM when the tool is useful
    - ``get_schema()``: a JSON Schema for the expected input
    - ``call()``: the implementation that actually runs

    Example:
        >>> class WebSearch(Tool):
        ...     name = "web_search"
        ...     description = "Search the web for information"
        ...
        ...     def get_schema(self):
        ...         return {
        ...             "type": "object",
        ...             "properties": {"query": {"type": "string"}},
        ...             "required": ["query"],
        ...         }
        ...
        ...     def call(self, payload, params=None):
        ...         query = payload["query"]
        ...         hits = []  # ... search for `query` and collect hits ...
        ...         return ToolResult.ok(hits)
    """

    # Class-level defaults; concrete tools are expected to override both.
    name: str = "unnamed_tool"
    description: str = "No description provided"

    @abstractmethod
    def call(
        self,
        payload: dict[str, Any],
        params: dict[str, Any] | None = None,
    ) -> ToolResult:
        """Run the tool.

        Args:
            payload: Input data, expected to follow ``get_schema()``.
            params: Optional execution parameters.

        Returns:
            A ToolResult describing success or failure and the output.
        """
        raise NotImplementedError

    def get_schema(self) -> dict[str, Any]:
        """Return the JSON Schema describing this tool's input.

        The base version describes an object with no declared properties;
        override it so the LLM knows what input to send.
        """
        return dict(type="object", properties={})

    def to_openai_tool(self) -> dict[str, Any]:
        """Describe this tool in OpenAI's function-tool format."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.get_schema(),
        }
        return {"type": "function", "function": function_spec}

    def to_anthropic_tool(self) -> dict[str, Any]:
        """Describe this tool in Anthropic's tool format."""
        return dict(
            name=self.name,
            description=self.description,
            input_schema=self.get_schema(),
        )
@@ -0,0 +1,122 @@
1
+ """
2
+ Tool registry for ThoughtFlow.
3
+
4
+ Provides an explicit registry for tools. This is optional - you can
5
+ also pass tools directly to agents. The registry is useful for
6
+ organizing and discovering available tools.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import TYPE_CHECKING
12
+
13
+ if TYPE_CHECKING:
14
+ from thoughtflow.tools.base import Tool
15
+
16
+
17
class ToolRegistry:
    """Registry for managing available tools.

    The registry provides a central place to register and look up tools.
    This is completely optional - ThoughtFlow doesn't require using a registry.

    Example:
        >>> registry = ToolRegistry()
        >>> registry.register(calculator_tool)
        >>> registry.register(web_search_tool)
        >>>
        >>> # Get a tool by name
        >>> calc = registry.get("calculator")
        >>>
        >>> # Swap in a new implementation under the same name
        >>> registry.register(better_calculator_tool, replace=True)
        >>>
        >>> # Get all tools
        >>> all_tools = registry.list()
    """

    def __init__(self) -> None:
        """Initialize an empty registry."""
        # Keyed by ``tool.name``. Dicts keep insertion order, so ``list()``,
        # ``names()`` and the ``to_*_tools()`` exports follow registration order.
        self._tools: dict[str, Tool] = {}

    def register(self, tool: Tool, *, replace: bool = False) -> None:
        """Register a tool under its ``name``.

        Args:
            tool: The tool to register.
            replace: If True, overwrite an already-registered tool with the
                same name instead of raising. A replaced tool keeps its
                original position in the registry's ordering.

        Raises:
            ValueError: If a tool with the same name is already registered
                and ``replace`` is False.
        """
        if not replace and tool.name in self._tools:
            raise ValueError(
                f"Tool '{tool.name}' is already registered. "
                "Use replace=True to override."
            )
        self._tools[tool.name] = tool

    def unregister(self, name: str) -> None:
        """Unregister a tool by name.

        Args:
            name: Name of the tool to unregister.

        Raises:
            KeyError: If no tool with that name exists.
        """
        try:
            del self._tools[name]
        except KeyError:
            raise KeyError(f"Tool '{name}' not found in registry") from None

    def get(self, name: str) -> Tool:
        """Get a tool by name.

        Args:
            name: Name of the tool.

        Returns:
            The registered Tool.

        Raises:
            KeyError: If no tool with that name exists.
        """
        try:
            return self._tools[name]
        except KeyError:
            raise KeyError(f"Tool '{name}' not found in registry") from None

    # NOTE: this method name shadows the builtin ``list`` inside the class
    # body. The ``list[...]`` annotations below are safe only because the
    # module uses ``from __future__ import annotations`` (lazy annotations).
    def list(self) -> list[Tool]:
        """List all registered tools.

        Returns:
            A new list of all registered tools, in registration order.
        """
        # Method bodies do not see class scope, so this is the builtin list.
        return list(self._tools.values())

    def names(self) -> list[str]:
        """List all registered tool names.

        Returns:
            A new list of tool names, in registration order.
        """
        return list(self._tools)

    def to_openai_tools(self) -> list[dict]:
        """Convert all tools to OpenAI format.

        Returns:
            List of tool dicts in OpenAI format.
        """
        return [tool.to_openai_tool() for tool in self._tools.values()]

    def to_anthropic_tools(self) -> list[dict]:
        """Convert all tools to Anthropic format.

        Returns:
            List of tool dicts in Anthropic format.
        """
        return [tool.to_anthropic_tool() for tool in self._tools.values()]

    def __len__(self) -> int:
        """Return the number of registered tools."""
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        """Return True if a tool with ``name`` is registered."""
        return name in self._tools
@@ -0,0 +1,34 @@
1
+ """
2
+ Tracing and session management for ThoughtFlow.
3
+
4
+ Traces capture complete run state: inputs, outputs, tool calls, model calls,
5
+ timing, token usage, and costs. This enables debugging, evaluation,
6
+ reproducibility, regression testing, and replay/diff across versions.
7
+
8
+ Example:
9
+ >>> from thoughtflow.trace import Session
10
+ >>>
11
+ >>> session = Session()
12
+ >>> response = agent.call(messages, session=session)
13
+ >>>
14
+ >>> # Inspect the trace
15
+ >>> print(session.events)
16
+ >>> print(session.total_tokens)
17
+ >>> print(session.total_cost)
18
+ >>>
19
+ >>> # Save for replay
20
+ >>> session.save("trace.json")
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from thoughtflow.trace.session import Session
26
+ from thoughtflow.trace.events import Event, EventType
27
+ from thoughtflow.trace.schema import TraceSchema
28
+
29
# Public API of the trace package (names imported from the submodules above).
__all__ = ["Session", "Event", "EventType", "TraceSchema"]
@@ -0,0 +1,183 @@
1
+ """
2
+ Event types for ThoughtFlow tracing.
3
+
4
+ Events represent discrete occurrences during an agent run:
5
+ - Model calls (start, end, error)
6
+ - Tool invocations
7
+ - Memory operations
8
+ - Custom user events
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass, field
14
+ from datetime import datetime
15
+ from enum import Enum
16
+ from typing import Any
17
+
18
+
19
class EventType(str, Enum):
    """Types of events that can occur during a session.

    Members are ``str`` subclasses; their values are the strings written to
    serialized traces by ``Event.to_dict``.
    """

    # Agent lifecycle
    CALL_START = "call_start"
    CALL_END = "call_end"
    CALL_ERROR = "call_error"

    # Model interactions
    MODEL_REQUEST = "model_request"
    MODEL_RESPONSE = "model_response"
    MODEL_ERROR = "model_error"

    # Tool interactions
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    TOOL_ERROR = "tool_error"

    # Memory interactions
    MEMORY_RETRIEVE = "memory_retrieve"
    MEMORY_STORE = "memory_store"

    # Custom
    CUSTOM = "custom"


@dataclass
class Event:
    """A single event in a session trace.

    Events capture everything that happens during an agent run,
    enabling complete visibility and replay capability.

    Attributes:
        event_type: An ``EventType`` member, or a plain string for event
            kinds the enum does not cover.
        timestamp: When the event occurred. Defaults to ``datetime.now()``,
            i.e. naive local time.
        data: Event-specific data.
        duration_ms: Duration in milliseconds (for end events), if measured.
        metadata: Additional metadata.

    Example:
        >>> event = Event(
        ...     event_type=EventType.MODEL_REQUEST,
        ...     data={
        ...         "messages": [...],
        ...         "params": {"model": "gpt-4", "temperature": 0.7}
        ...     }
        ... )
    """

    event_type: EventType | str
    timestamp: datetime = field(default_factory=datetime.now)
    data: dict[str, Any] = field(default_factory=dict)
    duration_ms: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to a serializable dict.

        ``data`` and ``metadata`` are shallow-copied, so mutating the returned
        dict cannot rewrite the recorded event.

        Returns:
            Dict representation of the event.
        """
        kind = self.event_type
        return {
            # Enum members serialize as their string value; custom string
            # types pass through unchanged.
            "event_type": kind.value if isinstance(kind, EventType) else kind,
            "timestamp": self.timestamp.isoformat(),
            "data": dict(self.data),
            "duration_ms": self.duration_ms,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Event:
        """Create an Event from a dict produced by ``to_dict``.

        Unknown ``event_type`` strings are kept as plain strings. A missing
        or null ``data``/``metadata`` becomes an empty dict. Both are
        shallow-copied, so the event does not share them with ``data``.

        Args:
            data: Dict with event data.

        Returns:
            Event instance.

        Raises:
            KeyError: If ``event_type`` or ``timestamp`` is missing.
            ValueError: If ``timestamp`` is not an ISO-format string.
        """
        event_type_str = data["event_type"]
        try:
            event_type = EventType(event_type_str)
        except ValueError:
            # Not a built-in type: keep the raw string (custom event kind).
            event_type = event_type_str

        return cls(
            event_type=event_type,
            timestamp=datetime.fromisoformat(data["timestamp"]),
            data=dict(data.get("data") or {}),
            duration_ms=data.get("duration_ms"),
            metadata=dict(data.get("metadata") or {}),
        )
116
+
117
+
118
+ # Convenience functions for creating common events
119
+
120
+
121
def call_start(messages: list[dict], params: dict | None = None) -> Event:
    """Build a ``CALL_START`` event recording the inputs of an agent call.

    Args:
        messages: The input messages, stored as given.
        params: Call parameters; ``None`` or an empty value is stored as ``{}``.

    Returns:
        A new Event of type ``EventType.CALL_START``.
    """
    recorded = {"messages": messages, "params": params if params else {}}
    return Event(EventType.CALL_START, data=recorded)
135
+
136
+
137
def call_end(response: str, tokens: dict | None = None) -> Event:
    """Build a ``CALL_END`` event recording the outcome of an agent call.

    Args:
        response: The agent's response.
        tokens: Token usage information; ``None`` or an empty value is
            stored as ``{}``.

    Returns:
        A new Event of type ``EventType.CALL_END``.
    """
    recorded = {"response": response, "tokens": tokens if tokens else {}}
    return Event(EventType.CALL_END, data=recorded)
151
+
152
+
153
def tool_call(tool_name: str, payload: dict) -> Event:
    """Build a ``TOOL_CALL`` event recording a tool invocation request.

    Args:
        tool_name: Name of the tool being called.
        payload: The tool's input payload, stored as given.

    Returns:
        A new Event of type ``EventType.TOOL_CALL``.
    """
    recorded = dict(tool_name=tool_name, payload=payload)
    return Event(EventType.TOOL_CALL, data=recorded)
167
+
168
+
169
def tool_result(tool_name: str, result: Any, success: bool = True) -> Event:
    """Build a ``TOOL_RESULT`` event recording what a tool returned.

    Args:
        tool_name: Name of the tool.
        result: The tool's output, stored as given.
        success: Whether the tool call succeeded.

    Returns:
        A new Event of type ``EventType.TOOL_RESULT``.
    """
    recorded = dict(tool_name=tool_name, result=result, success=success)
    return Event(EventType.TOOL_RESULT, data=recorded)
@@ -0,0 +1,111 @@
1
+ """
2
+ Trace schema versioning for ThoughtFlow.
3
+
4
+ Once trace schemas are in use, downstream systems depend on them.
5
+ This module provides schema versioning to maintain backward compatibility.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Any
12
+
13
+
14
# Current schema version written by this release. Traces whose major version
# differs are reported as incompatible by ``validate_trace``.
SCHEMA_VERSION = "1.0.0"

# Version assumed for serialized data that carries no ``schema_version``.
# NOTE(review): pinned to "1.0.0" rather than SCHEMA_VERSION, presumably so
# unversioned traces are read as the first schema — confirm before bumping.
_DEFAULT_TRACE_VERSION = "1.0.0"


@dataclass
class TraceSchema:
    """Schema metadata for trace files.

    Enables forward/backward compatibility as the trace format evolves.

    Attributes:
        version: Dotted schema version string such as ``"1.0.0"``.
        thoughtflow_version: ThoughtFlow version that created the trace, if known.
        features: List of optional features used in this trace, if any.
    """

    version: str = SCHEMA_VERSION
    thoughtflow_version: str | None = None
    features: list[str] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dict for embedding in trace files.

        Returns:
            Dict with schema information. ``features`` is always a new list,
            so mutating it does not affect this schema object.
        """
        return {
            "schema_version": self.version,
            "thoughtflow_version": self.thoughtflow_version,
            "features": list(self.features or []),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TraceSchema:
        """Create from a dict.

        Args:
            data: Dict with schema information; a missing ``schema_version``
                is read as ``"1.0.0"``.

        Returns:
            TraceSchema instance.
        """
        return cls(
            version=data.get("schema_version", _DEFAULT_TRACE_VERSION),
            thoughtflow_version=data.get("thoughtflow_version"),
            features=data.get("features"),
        )

    def is_compatible(self, other_version: str) -> bool:
        """Check if this schema is compatible with another version.

        Compatibility rules:
        - Same major version = compatible
        - Different major version = incompatible

        The major version is the text before the first ``"."``.

        Args:
            other_version: Version string to check against.

        Returns:
            True if compatible, False otherwise.
        """
        this_major = self.version.split(".")[0]
        other_major = other_version.split(".")[0]
        return this_major == other_major


def validate_trace(trace_data: dict[str, Any]) -> list[str]:
    """Validate a trace against the current schema.

    Problems are reported, not raised. A non-mapping trace, a non-mapping
    ``schema`` section or a non-string ``schema_version`` yields an error
    message instead of an exception.

    Args:
        trace_data: The trace data to validate.

    Returns:
        List of validation errors (empty if valid).
    """
    from collections.abc import Mapping

    if not isinstance(trace_data, Mapping):
        return [f"Trace must be a mapping, got {type(trace_data).__name__}"]

    errors: list[str] = []

    # Check required fields
    if "session_id" not in trace_data:
        errors.append("Missing required field: session_id")

    if "events" not in trace_data:
        errors.append("Missing required field: events")
    elif not isinstance(trace_data["events"], list):
        errors.append("Field 'events' must be a list")

    # Check schema version compatibility; an absent or empty section is allowed.
    schema_info = trace_data.get("schema", {})
    if schema_info:
        if not isinstance(schema_info, Mapping):
            errors.append("Field 'schema' must be a mapping")
        else:
            trace_version = schema_info.get("schema_version", _DEFAULT_TRACE_VERSION)
            if not isinstance(trace_version, str):
                # is_compatible() would crash on e.g. an int version.
                errors.append("Field 'schema.schema_version' must be a string")
            elif not TraceSchema().is_compatible(trace_version):
                errors.append(
                    f"Schema version mismatch: trace is v{trace_version}, "
                    f"current is v{SCHEMA_VERSION}"
                )

    return errors