tactus 0.34.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +1 -1
- tactus/adapters/broker_log.py +17 -14
- tactus/adapters/channels/__init__.py +17 -15
- tactus/adapters/channels/base.py +16 -7
- tactus/adapters/channels/broker.py +43 -13
- tactus/adapters/channels/cli.py +19 -15
- tactus/adapters/channels/host.py +15 -6
- tactus/adapters/channels/ipc.py +82 -31
- tactus/adapters/channels/sse.py +41 -23
- tactus/adapters/cli_hitl.py +19 -19
- tactus/adapters/cli_log.py +4 -4
- tactus/adapters/control_loop.py +138 -99
- tactus/adapters/cost_collector_log.py +9 -9
- tactus/adapters/file_storage.py +56 -52
- tactus/adapters/http_callback_log.py +23 -13
- tactus/adapters/ide_log.py +17 -9
- tactus/adapters/lua_tools.py +4 -5
- tactus/adapters/mcp.py +16 -19
- tactus/adapters/mcp_manager.py +46 -30
- tactus/adapters/memory.py +9 -9
- tactus/adapters/plugins.py +42 -42
- tactus/broker/client.py +75 -78
- tactus/broker/protocol.py +57 -57
- tactus/broker/server.py +252 -197
- tactus/cli/app.py +3 -1
- tactus/cli/control.py +2 -2
- tactus/core/config_manager.py +181 -135
- tactus/core/dependencies/registry.py +66 -48
- tactus/core/dsl_stubs.py +222 -163
- tactus/core/exceptions.py +10 -1
- tactus/core/execution_context.py +152 -112
- tactus/core/lua_sandbox.py +72 -64
- tactus/core/message_history_manager.py +138 -43
- tactus/core/mocking.py +41 -27
- tactus/core/output_validator.py +49 -44
- tactus/core/registry.py +94 -80
- tactus/core/runtime.py +211 -176
- tactus/core/template_resolver.py +16 -16
- tactus/core/yaml_parser.py +55 -45
- tactus/docs/extractor.py +7 -6
- tactus/ide/server.py +119 -78
- tactus/primitives/control.py +10 -6
- tactus/primitives/file.py +48 -46
- tactus/primitives/handles.py +47 -35
- tactus/primitives/host.py +29 -27
- tactus/primitives/human.py +154 -137
- tactus/primitives/json.py +22 -23
- tactus/primitives/log.py +26 -26
- tactus/primitives/message_history.py +285 -31
- tactus/primitives/model.py +15 -9
- tactus/primitives/procedure.py +86 -64
- tactus/primitives/procedure_callable.py +58 -51
- tactus/primitives/retry.py +31 -29
- tactus/primitives/session.py +42 -29
- tactus/primitives/state.py +54 -43
- tactus/primitives/step.py +9 -13
- tactus/primitives/system.py +34 -21
- tactus/primitives/tool.py +44 -31
- tactus/primitives/tool_handle.py +76 -54
- tactus/primitives/toolset.py +25 -22
- tactus/sandbox/config.py +4 -4
- tactus/sandbox/container_runner.py +161 -107
- tactus/sandbox/docker_manager.py +20 -20
- tactus/sandbox/entrypoint.py +16 -14
- tactus/sandbox/protocol.py +15 -15
- tactus/stdlib/classify/llm.py +1 -3
- tactus/stdlib/core/validation.py +0 -3
- tactus/testing/pydantic_eval_runner.py +1 -1
- tactus/utils/asyncio_helpers.py +27 -0
- tactus/utils/cost_calculator.py +7 -7
- tactus/utils/model_pricing.py +11 -12
- tactus/utils/safe_file_library.py +156 -132
- tactus/utils/safe_libraries.py +27 -27
- tactus/validation/error_listener.py +18 -5
- tactus/validation/semantic_visitor.py +392 -333
- tactus/validation/validator.py +89 -49
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/METADATA +12 -3
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/RECORD +81 -80
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/WHEEL +0 -0
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/entry_points.txt +0 -0
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/licenses/LICENSE +0 -0
tactus/primitives/json.py
CHANGED
|
@@ -60,16 +60,16 @@ class JsonPrimitive:
|
|
|
60
60
|
# Convert Lua tables to Python dicts recursively if needed
|
|
61
61
|
python_data = self._lua_to_python(data)
|
|
62
62
|
|
|
63
|
-
|
|
64
|
-
logger.debug(
|
|
65
|
-
return
|
|
63
|
+
json_payload = json.dumps(python_data, ensure_ascii=False, indent=None)
|
|
64
|
+
logger.debug("Encoded data to JSON (%s bytes)", len(json_payload))
|
|
65
|
+
return json_payload
|
|
66
66
|
|
|
67
|
-
except (TypeError, ValueError) as
|
|
68
|
-
|
|
69
|
-
logger.error(
|
|
70
|
-
raise ValueError(
|
|
67
|
+
except (TypeError, ValueError) as error:
|
|
68
|
+
error_message = f"Failed to encode to JSON: {error}"
|
|
69
|
+
logger.error(error_message)
|
|
70
|
+
raise ValueError(error_message)
|
|
71
71
|
|
|
72
|
-
def decode(self, json_str: str):
|
|
72
|
+
def decode(self, json_str: str) -> Any:
|
|
73
73
|
"""
|
|
74
74
|
Decode JSON string to Lua table.
|
|
75
75
|
|
|
@@ -93,7 +93,7 @@ class JsonPrimitive:
|
|
|
93
93
|
try:
|
|
94
94
|
# Parse JSON to Python dict
|
|
95
95
|
python_data = json.loads(json_str)
|
|
96
|
-
logger.debug(
|
|
96
|
+
logger.debug("Decoded JSON string (%s bytes)", len(json_str))
|
|
97
97
|
|
|
98
98
|
# Convert to Lua table if lua_sandbox available
|
|
99
99
|
if self.lua_sandbox:
|
|
@@ -102,10 +102,10 @@ class JsonPrimitive:
|
|
|
102
102
|
# Fallback: return Python dict (will work but not ideal)
|
|
103
103
|
return python_data
|
|
104
104
|
|
|
105
|
-
except json.JSONDecodeError as
|
|
106
|
-
|
|
107
|
-
logger.error(
|
|
108
|
-
raise ValueError(
|
|
105
|
+
except json.JSONDecodeError as error:
|
|
106
|
+
error_message = f"Failed to decode JSON: {error}"
|
|
107
|
+
logger.error(error_message)
|
|
108
|
+
raise ValueError(error_message)
|
|
109
109
|
|
|
110
110
|
def _lua_to_python(self, value: Any) -> Any:
|
|
111
111
|
"""
|
|
@@ -125,27 +125,27 @@ class JsonPrimitive:
|
|
|
125
125
|
if lua_type(value) == "table":
|
|
126
126
|
# Try to determine if it's an array or dict
|
|
127
127
|
# Lua arrays have consecutive integer keys starting at 1
|
|
128
|
-
|
|
128
|
+
converted = {}
|
|
129
129
|
is_array = True
|
|
130
130
|
keys = []
|
|
131
131
|
|
|
132
132
|
for k, v in value.items():
|
|
133
133
|
keys.append(k)
|
|
134
|
-
|
|
134
|
+
converted[k] = self._lua_to_python(v)
|
|
135
135
|
if not isinstance(k, int) or k < 1:
|
|
136
136
|
is_array = False
|
|
137
137
|
|
|
138
138
|
# Check if keys are consecutive integers starting at 1
|
|
139
139
|
if is_array and keys:
|
|
140
|
-
|
|
141
|
-
if
|
|
140
|
+
sorted_keys = sorted(keys)
|
|
141
|
+
if sorted_keys != list(range(1, len(keys) + 1)):
|
|
142
142
|
is_array = False
|
|
143
143
|
|
|
144
144
|
# Convert to list if it's an array
|
|
145
145
|
if is_array and keys:
|
|
146
|
-
return [
|
|
146
|
+
return [converted[i] for i in range(1, len(keys) + 1)]
|
|
147
147
|
else:
|
|
148
|
-
return
|
|
148
|
+
return converted
|
|
149
149
|
else:
|
|
150
150
|
# Primitive value
|
|
151
151
|
return value
|
|
@@ -174,16 +174,15 @@ class JsonPrimitive:
|
|
|
174
174
|
lua_table[k] = self._python_to_lua(v)
|
|
175
175
|
return lua_table
|
|
176
176
|
|
|
177
|
-
|
|
177
|
+
if isinstance(value, (list, tuple)):
|
|
178
178
|
# Convert list to Lua array (1-indexed)
|
|
179
179
|
lua_table = self.lua_sandbox.lua.table()
|
|
180
180
|
for i, item in enumerate(value, start=1):
|
|
181
181
|
lua_table[i] = self._python_to_lua(item)
|
|
182
182
|
return lua_table
|
|
183
183
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
return value
|
|
184
|
+
# Primitive value (str, int, float, bool, None)
|
|
185
|
+
return value
|
|
187
186
|
|
|
188
187
|
def __repr__(self) -> str:
|
|
189
188
|
return "JsonPrimitive()"
|
tactus/primitives/log.py
CHANGED
|
@@ -9,9 +9,9 @@ Provides:
|
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
|
-
from typing import Any,
|
|
12
|
+
from typing import Any, Optional, TYPE_CHECKING
|
|
13
13
|
|
|
14
|
-
if TYPE_CHECKING:
|
|
14
|
+
if TYPE_CHECKING: # pragma: no cover
|
|
15
15
|
from tactus.protocols.log_handler import LogHandler
|
|
16
16
|
|
|
17
17
|
logger = logging.getLogger(__name__)
|
|
@@ -38,32 +38,32 @@ class LogPrimitive:
|
|
|
38
38
|
self.logger = logging.getLogger(f"procedure.{procedure_id}")
|
|
39
39
|
self.log_handler = log_handler
|
|
40
40
|
|
|
41
|
-
def _format_message(self, message: str, context: Optional[
|
|
41
|
+
def _format_message(self, message: str, context: Optional[dict[str, Any]] = None) -> str:
|
|
42
42
|
"""Format log message with context."""
|
|
43
43
|
if context:
|
|
44
44
|
import json
|
|
45
45
|
|
|
46
46
|
# Convert Lua tables to Python dicts
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
return f"{message}\nContext: {
|
|
47
|
+
context_payload = self._lua_to_python(context)
|
|
48
|
+
context_json = json.dumps(context_payload, indent=2)
|
|
49
|
+
return f"{message}\nContext: {context_json}"
|
|
50
50
|
return message
|
|
51
51
|
|
|
52
|
-
def _lua_to_python(self,
|
|
52
|
+
def _lua_to_python(self, value: Any) -> Any:
|
|
53
53
|
"""Convert Lua objects to Python equivalents recursively."""
|
|
54
54
|
# Check if it's a Lua table
|
|
55
|
-
if hasattr(
|
|
56
|
-
return {self._lua_to_python(k): self._lua_to_python(v) for k, v in
|
|
57
|
-
elif hasattr(
|
|
55
|
+
if hasattr(value, "items"): # Lua table with dict-like interface
|
|
56
|
+
return {self._lua_to_python(k): self._lua_to_python(v) for k, v in value.items()}
|
|
57
|
+
elif hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): # Lua array
|
|
58
58
|
try:
|
|
59
|
-
return [self._lua_to_python(v) for v in
|
|
59
|
+
return [self._lua_to_python(v) for v in value]
|
|
60
60
|
except Exception: # noqa: E722
|
|
61
61
|
# If iteration fails, return as-is
|
|
62
|
-
return
|
|
62
|
+
return value
|
|
63
63
|
else:
|
|
64
|
-
return
|
|
64
|
+
return value
|
|
65
65
|
|
|
66
|
-
def debug(self, message: str, context: Optional[
|
|
66
|
+
def debug(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
|
|
67
67
|
"""
|
|
68
68
|
Log debug message.
|
|
69
69
|
|
|
@@ -78,11 +78,11 @@ class LogPrimitive:
|
|
|
78
78
|
if self.log_handler:
|
|
79
79
|
from tactus.protocols.models import LogEvent
|
|
80
80
|
|
|
81
|
-
|
|
81
|
+
context_payload = self._lua_to_python(context) if context else None
|
|
82
82
|
event = LogEvent(
|
|
83
83
|
level="DEBUG",
|
|
84
84
|
message=message,
|
|
85
|
-
context=
|
|
85
|
+
context=context_payload,
|
|
86
86
|
logger_name=self.logger.name,
|
|
87
87
|
procedure_id=self.procedure_id,
|
|
88
88
|
)
|
|
@@ -92,7 +92,7 @@ class LogPrimitive:
|
|
|
92
92
|
formatted = self._format_message(message, context)
|
|
93
93
|
self.logger.debug(formatted)
|
|
94
94
|
|
|
95
|
-
def info(self, message: str, context: Optional[
|
|
95
|
+
def info(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
|
|
96
96
|
"""
|
|
97
97
|
Log info message.
|
|
98
98
|
|
|
@@ -107,11 +107,11 @@ class LogPrimitive:
|
|
|
107
107
|
if self.log_handler:
|
|
108
108
|
from tactus.protocols.models import LogEvent
|
|
109
109
|
|
|
110
|
-
|
|
110
|
+
context_payload = self._lua_to_python(context) if context else None
|
|
111
111
|
event = LogEvent(
|
|
112
112
|
level="INFO",
|
|
113
113
|
message=message,
|
|
114
|
-
context=
|
|
114
|
+
context=context_payload,
|
|
115
115
|
logger_name=self.logger.name,
|
|
116
116
|
procedure_id=self.procedure_id,
|
|
117
117
|
)
|
|
@@ -121,7 +121,7 @@ class LogPrimitive:
|
|
|
121
121
|
formatted = self._format_message(message, context)
|
|
122
122
|
self.logger.info(formatted)
|
|
123
123
|
|
|
124
|
-
def warn(self, message: str, context: Optional[
|
|
124
|
+
def warn(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
|
|
125
125
|
"""
|
|
126
126
|
Log warning message.
|
|
127
127
|
|
|
@@ -136,11 +136,11 @@ class LogPrimitive:
|
|
|
136
136
|
if self.log_handler:
|
|
137
137
|
from tactus.protocols.models import LogEvent
|
|
138
138
|
|
|
139
|
-
|
|
139
|
+
context_payload = self._lua_to_python(context) if context else None
|
|
140
140
|
event = LogEvent(
|
|
141
141
|
level="WARNING",
|
|
142
142
|
message=message,
|
|
143
|
-
context=
|
|
143
|
+
context=context_payload,
|
|
144
144
|
logger_name=self.logger.name,
|
|
145
145
|
procedure_id=self.procedure_id,
|
|
146
146
|
)
|
|
@@ -150,11 +150,11 @@ class LogPrimitive:
|
|
|
150
150
|
formatted = self._format_message(message, context)
|
|
151
151
|
self.logger.warning(formatted)
|
|
152
152
|
|
|
153
|
-
def warning(self, message: str, context: Optional[
|
|
153
|
+
def warning(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
|
|
154
154
|
"""Alias for warn(), matching common logging APIs."""
|
|
155
155
|
self.warn(message, context)
|
|
156
156
|
|
|
157
|
-
def error(self, message: str, context: Optional[
|
|
157
|
+
def error(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
|
|
158
158
|
"""
|
|
159
159
|
Log error message.
|
|
160
160
|
|
|
@@ -169,11 +169,11 @@ class LogPrimitive:
|
|
|
169
169
|
if self.log_handler:
|
|
170
170
|
from tactus.protocols.models import LogEvent
|
|
171
171
|
|
|
172
|
-
|
|
172
|
+
context_payload = self._lua_to_python(context) if context else None
|
|
173
173
|
event = LogEvent(
|
|
174
174
|
level="ERROR",
|
|
175
175
|
message=message,
|
|
176
|
-
context=
|
|
176
|
+
context=context_payload,
|
|
177
177
|
logger_name=self.logger.name,
|
|
178
178
|
procedure_id=self.procedure_id,
|
|
179
179
|
)
|
|
@@ -42,28 +42,31 @@ class MessageHistoryPrimitive:
|
|
|
42
42
|
self.message_history_manager = message_history_manager
|
|
43
43
|
self.agent_name = agent_name
|
|
44
44
|
|
|
45
|
-
def append(self,
|
|
45
|
+
def append(self, message_payload: dict[str, Any]) -> None:
|
|
46
46
|
"""
|
|
47
47
|
Append a message to the message history.
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
|
|
50
|
+
message_payload: dict with 'role' and 'content' keys
|
|
51
51
|
role: 'user', 'assistant', 'system'
|
|
52
52
|
content: message text
|
|
53
53
|
|
|
54
54
|
Example:
|
|
55
55
|
MessageHistory.append({role = "user", content = "Hello"})
|
|
56
56
|
"""
|
|
57
|
-
if not self.message_history_manager
|
|
57
|
+
if not self.message_history_manager:
|
|
58
58
|
return
|
|
59
59
|
|
|
60
|
-
|
|
61
|
-
|
|
60
|
+
message_payload = self._normalize_message_payload(message_payload)
|
|
61
|
+
role = message_payload.get("role", "user")
|
|
62
|
+
content = message_payload.get("content", "")
|
|
62
63
|
|
|
63
|
-
# Create a
|
|
64
|
-
|
|
64
|
+
# Create a message dict and preserve extra fields
|
|
65
|
+
message_entry = dict(message_payload)
|
|
66
|
+
message_entry["role"] = role
|
|
67
|
+
message_entry["content"] = content
|
|
65
68
|
|
|
66
|
-
self.message_history_manager.add_message(self.agent_name,
|
|
69
|
+
self.message_history_manager.add_message(self.agent_name, message_entry)
|
|
67
70
|
|
|
68
71
|
def inject_system(self, text: str) -> None:
|
|
69
72
|
"""
|
|
@@ -87,12 +90,14 @@ class MessageHistoryPrimitive:
|
|
|
87
90
|
Example:
|
|
88
91
|
MessageHistory.clear()
|
|
89
92
|
"""
|
|
90
|
-
if not self.message_history_manager
|
|
93
|
+
if not self.message_history_manager:
|
|
91
94
|
return
|
|
95
|
+
if self.agent_name:
|
|
96
|
+
self.message_history_manager.clear_agent_history(self.agent_name)
|
|
97
|
+
else:
|
|
98
|
+
self.message_history_manager.clear_shared_history()
|
|
92
99
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def get(self) -> list:
|
|
100
|
+
def get(self) -> list[dict[str, Any]]:
|
|
96
101
|
"""
|
|
97
102
|
Get the full message history for this agent.
|
|
98
103
|
|
|
@@ -107,31 +112,280 @@ class MessageHistoryPrimitive:
|
|
|
107
112
|
Log.info(msg.role .. ": " .. msg.content)
|
|
108
113
|
end
|
|
109
114
|
"""
|
|
110
|
-
if not self.message_history_manager
|
|
115
|
+
if not self.message_history_manager:
|
|
111
116
|
return []
|
|
112
|
-
|
|
113
|
-
messages = self.message_history_manager.histories.get(self.agent_name, [])
|
|
117
|
+
messages = self._get_history_ref()
|
|
114
118
|
|
|
115
119
|
# Convert to Lua-friendly format
|
|
116
|
-
result = []
|
|
117
|
-
for
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
else:
|
|
121
|
-
# Handle pydantic_ai ModelMessage objects
|
|
122
|
-
try:
|
|
123
|
-
result.append(
|
|
124
|
-
{
|
|
125
|
-
"role": getattr(msg, "role", ""),
|
|
126
|
-
"content": str(getattr(msg, "content", "")),
|
|
127
|
-
}
|
|
128
|
-
)
|
|
129
|
-
except Exception:
|
|
130
|
-
# Fallback: convert to string
|
|
131
|
-
result.append({"role": "unknown", "content": str(msg)})
|
|
120
|
+
result: list[dict[str, Any]] = []
|
|
121
|
+
for message in messages:
|
|
122
|
+
serialized_message = self._serialize_message(message)
|
|
123
|
+
result.append(serialized_message)
|
|
132
124
|
|
|
133
125
|
return result
|
|
134
126
|
|
|
127
|
+
def replace(self, messages: list[Any]) -> None:
|
|
128
|
+
"""
|
|
129
|
+
Replace the current message history with a new list.
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
messages: List of message dicts to set as the new history
|
|
133
|
+
"""
|
|
134
|
+
if not self.message_history_manager:
|
|
135
|
+
return
|
|
136
|
+
|
|
137
|
+
normalized_messages = self._normalize_messages(messages)
|
|
138
|
+
normalized_messages = [
|
|
139
|
+
self.message_history_manager._ensure_message_metadata(message)
|
|
140
|
+
for message in normalized_messages
|
|
141
|
+
]
|
|
142
|
+
|
|
143
|
+
if self.agent_name:
|
|
144
|
+
self.message_history_manager.histories[self.agent_name] = normalized_messages
|
|
145
|
+
else:
|
|
146
|
+
self.message_history_manager.shared_history = normalized_messages
|
|
147
|
+
|
|
148
|
+
def reset(self, options: Optional[dict[str, Any]] = None) -> None:
|
|
149
|
+
"""
|
|
150
|
+
Reset history while optionally keeping leading system messages.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
options: Optional dict with keep mode:
|
|
154
|
+
- "system_prefix" (default): keep leading system messages only
|
|
155
|
+
- "system_all": keep all system messages
|
|
156
|
+
- "none": clear all messages
|
|
157
|
+
"""
|
|
158
|
+
if not self.message_history_manager:
|
|
159
|
+
return
|
|
160
|
+
|
|
161
|
+
keep_mode = "system_prefix"
|
|
162
|
+
normalized_options = self._normalize_options(options)
|
|
163
|
+
if normalized_options:
|
|
164
|
+
keep_mode = normalized_options.get("keep", keep_mode)
|
|
165
|
+
elif isinstance(options, str):
|
|
166
|
+
keep_mode = options
|
|
167
|
+
|
|
168
|
+
messages = self._get_history_ref()
|
|
169
|
+
|
|
170
|
+
if keep_mode == "none":
|
|
171
|
+
self.replace([])
|
|
172
|
+
return
|
|
173
|
+
if keep_mode == "system_all":
|
|
174
|
+
system_messages = self.message_history_manager._filter_by_role(messages, "system")
|
|
175
|
+
self.replace(system_messages)
|
|
176
|
+
return
|
|
177
|
+
|
|
178
|
+
system_prefix_messages = self.message_history_manager._filter_system_prefix(messages)
|
|
179
|
+
self.replace(system_prefix_messages)
|
|
180
|
+
|
|
181
|
+
def head(self, n: int) -> list[dict[str, Any]]:
|
|
182
|
+
"""Return the first N messages without mutating history."""
|
|
183
|
+
if not self.message_history_manager:
|
|
184
|
+
return []
|
|
185
|
+
messages = self._get_history_ref()
|
|
186
|
+
limit = max(int(n or 0), 0)
|
|
187
|
+
return self._serialize_messages(messages[:limit])
|
|
188
|
+
|
|
189
|
+
def tail(self, n: int) -> list[dict[str, Any]]:
|
|
190
|
+
"""Return the last N messages without mutating history."""
|
|
191
|
+
if not self.message_history_manager:
|
|
192
|
+
return []
|
|
193
|
+
messages = self._get_history_ref()
|
|
194
|
+
limit = max(int(n or 0), 0)
|
|
195
|
+
return self._serialize_messages(messages[-limit:] if limit > 0 else [])
|
|
196
|
+
|
|
197
|
+
def slice(self, options: dict[str, Any]) -> list[dict[str, Any]]:
|
|
198
|
+
"""Return a slice of messages using 1-based start/stop indices."""
|
|
199
|
+
if not self.message_history_manager:
|
|
200
|
+
return []
|
|
201
|
+
normalized_options = self._normalize_options(options)
|
|
202
|
+
if not normalized_options:
|
|
203
|
+
return []
|
|
204
|
+
messages = self._get_history_ref()
|
|
205
|
+
start = normalized_options.get("start")
|
|
206
|
+
stop = normalized_options.get("stop")
|
|
207
|
+
start_index = max(int(start or 1) - 1, 0)
|
|
208
|
+
stop_index = int(stop) if stop is not None else None
|
|
209
|
+
sliced = messages[start_index:stop_index]
|
|
210
|
+
return self._serialize_messages(sliced)
|
|
211
|
+
|
|
212
|
+
def tail_tokens(
|
|
213
|
+
self, max_tokens: int, options: Optional[dict[str, Any]] = None
|
|
214
|
+
) -> list[dict[str, Any]]:
|
|
215
|
+
"""Return the last messages that fit within the token budget."""
|
|
216
|
+
if not self.message_history_manager:
|
|
217
|
+
return []
|
|
218
|
+
messages = self._get_history_ref()
|
|
219
|
+
token_filtered_messages = self.message_history_manager._filter_tail_tokens(
|
|
220
|
+
messages, max_tokens
|
|
221
|
+
)
|
|
222
|
+
return self._serialize_messages(token_filtered_messages)
|
|
223
|
+
|
|
224
|
+
def keep_head(self, n: int) -> None:
|
|
225
|
+
"""Keep only the first N messages."""
|
|
226
|
+
if not self.message_history_manager:
|
|
227
|
+
return
|
|
228
|
+
messages = self._get_history_ref()
|
|
229
|
+
limit = max(int(n or 0), 0)
|
|
230
|
+
self.replace(messages[:limit])
|
|
231
|
+
|
|
232
|
+
def keep_tail(self, n: int) -> None:
|
|
233
|
+
"""Keep only the last N messages."""
|
|
234
|
+
if not self.message_history_manager:
|
|
235
|
+
return
|
|
236
|
+
messages = self._get_history_ref()
|
|
237
|
+
limit = max(int(n or 0), 0)
|
|
238
|
+
self.replace(messages[-limit:] if limit > 0 else [])
|
|
239
|
+
|
|
240
|
+
def keep_tail_tokens(self, max_tokens: int, options: Optional[dict[str, Any]] = None) -> None:
|
|
241
|
+
"""Keep only the last messages that fit within the token budget."""
|
|
242
|
+
if not self.message_history_manager:
|
|
243
|
+
return
|
|
244
|
+
messages = self._get_history_ref()
|
|
245
|
+
token_filtered_messages = self.message_history_manager._filter_tail_tokens(
|
|
246
|
+
messages, max_tokens
|
|
247
|
+
)
|
|
248
|
+
self.replace(token_filtered_messages)
|
|
249
|
+
|
|
250
|
+
def rewind(self, n: int) -> None:
|
|
251
|
+
"""Remove the last N messages from history."""
|
|
252
|
+
if not self.message_history_manager:
|
|
253
|
+
return
|
|
254
|
+
messages = self._get_history_ref()
|
|
255
|
+
count = max(int(n or 0), 0)
|
|
256
|
+
if count <= 0:
|
|
257
|
+
return
|
|
258
|
+
self.replace(messages[:-count])
|
|
259
|
+
|
|
260
|
+
def rewind_to(self, message_id: Any) -> None:
|
|
261
|
+
"""Rewind history back to a message id or checkpoint name."""
|
|
262
|
+
if not self.message_history_manager:
|
|
263
|
+
return
|
|
264
|
+
|
|
265
|
+
target_message_id = message_id
|
|
266
|
+
if isinstance(message_id, str):
|
|
267
|
+
checkpoint_id = self.message_history_manager.get_checkpoint(message_id)
|
|
268
|
+
target_message_id = checkpoint_id if checkpoint_id is not None else message_id
|
|
269
|
+
|
|
270
|
+
try:
|
|
271
|
+
target_message_id = int(target_message_id)
|
|
272
|
+
except (TypeError, ValueError):
|
|
273
|
+
return
|
|
274
|
+
|
|
275
|
+
messages = self._get_history_ref()
|
|
276
|
+
for index, message in enumerate(messages):
|
|
277
|
+
message_id_value = (
|
|
278
|
+
message.get("id") if isinstance(message, dict) else getattr(message, "id", None)
|
|
279
|
+
)
|
|
280
|
+
if message_id_value == target_message_id:
|
|
281
|
+
self.replace(messages[: index + 1])
|
|
282
|
+
return
|
|
283
|
+
|
|
284
|
+
def checkpoint(self, name: Optional[str] = None) -> Optional[int]:
|
|
285
|
+
"""Return the id of the last message and optionally store a named checkpoint."""
|
|
286
|
+
if not self.message_history_manager:
|
|
287
|
+
return None
|
|
288
|
+
|
|
289
|
+
messages = self._get_history_ref()
|
|
290
|
+
if not messages:
|
|
291
|
+
return None
|
|
292
|
+
|
|
293
|
+
last_message = messages[-1]
|
|
294
|
+
if isinstance(last_message, dict):
|
|
295
|
+
last_message = self.message_history_manager._ensure_message_metadata(last_message)
|
|
296
|
+
message_id = last_message.get("id")
|
|
297
|
+
else:
|
|
298
|
+
message_id = getattr(last_message, "id", None)
|
|
299
|
+
|
|
300
|
+
if isinstance(name, str) and message_id is not None:
|
|
301
|
+
self.message_history_manager.record_checkpoint(name, message_id)
|
|
302
|
+
|
|
303
|
+
return message_id
|
|
304
|
+
|
|
305
|
+
def _get_history_ref(self) -> list[Any]:
|
|
306
|
+
"""Get a direct reference to the underlying history list."""
|
|
307
|
+
if not self.message_history_manager:
|
|
308
|
+
return []
|
|
309
|
+
if self.agent_name:
|
|
310
|
+
return self.message_history_manager.histories.setdefault(self.agent_name, [])
|
|
311
|
+
return self.message_history_manager.shared_history
|
|
312
|
+
|
|
313
|
+
def _serialize_messages(self, messages: list[Any]) -> list[dict[str, Any]]:
|
|
314
|
+
"""Serialize message objects to Lua-friendly dicts."""
|
|
315
|
+
result: list[dict[str, Any]] = []
|
|
316
|
+
for message in messages:
|
|
317
|
+
result.append(self._serialize_message(message))
|
|
318
|
+
return result
|
|
319
|
+
|
|
320
|
+
def _normalize_messages(self, messages: Any) -> list[Any]:
|
|
321
|
+
"""Normalize Python lists or Lua tables into a list of message dicts."""
|
|
322
|
+
if messages is None:
|
|
323
|
+
return []
|
|
324
|
+
if isinstance(messages, list):
|
|
325
|
+
return messages
|
|
326
|
+
if isinstance(messages, tuple):
|
|
327
|
+
return list(messages)
|
|
328
|
+
if hasattr(messages, "items"):
|
|
329
|
+
items = list(messages.items())
|
|
330
|
+
if items and all(isinstance(key, int) for key, _ in items):
|
|
331
|
+
items.sort(key=lambda pair: pair[0])
|
|
332
|
+
return [value for _, value in items]
|
|
333
|
+
return list(messages)
|
|
334
|
+
|
|
335
|
+
def _normalize_message_payload(self, message_payload: Any) -> dict[str, Any]:
|
|
336
|
+
"""Normalize a single message payload into a dict."""
|
|
337
|
+
if message_payload is None:
|
|
338
|
+
return {}
|
|
339
|
+
if isinstance(message_payload, dict):
|
|
340
|
+
return message_payload
|
|
341
|
+
if hasattr(message_payload, "items"):
|
|
342
|
+
try:
|
|
343
|
+
return dict(message_payload.items())
|
|
344
|
+
except Exception:
|
|
345
|
+
pass
|
|
346
|
+
return {"role": "user", "content": str(message_payload)}
|
|
347
|
+
|
|
348
|
+
def _normalize_message_data(self, message_data: Any) -> dict[str, Any]:
|
|
349
|
+
"""Compatibility alias for existing tests and external callers."""
|
|
350
|
+
return self._normalize_message_payload(message_data)
|
|
351
|
+
|
|
352
|
+
def _normalize_options(self, options: Any) -> dict[str, Any]:
|
|
353
|
+
"""Normalize options from Lua tables or dicts."""
|
|
354
|
+
if options is None:
|
|
355
|
+
return {}
|
|
356
|
+
if isinstance(options, dict):
|
|
357
|
+
return options
|
|
358
|
+
if hasattr(options, "items"):
|
|
359
|
+
try:
|
|
360
|
+
return dict(options.items())
|
|
361
|
+
except Exception:
|
|
362
|
+
return {}
|
|
363
|
+
return {}
|
|
364
|
+
|
|
365
|
+
def _serialize_message(self, message: Any) -> dict[str, Any]:
|
|
366
|
+
"""Serialize a single message into a Lua-friendly dict."""
|
|
367
|
+
if isinstance(message, dict):
|
|
368
|
+
message = self.message_history_manager._ensure_message_metadata(message)
|
|
369
|
+
serialized = dict(message)
|
|
370
|
+
serialized["role"] = str(serialized.get("role", ""))
|
|
371
|
+
serialized["content"] = str(serialized.get("content", ""))
|
|
372
|
+
return serialized
|
|
373
|
+
|
|
374
|
+
# Handle pydantic_ai ModelMessage objects
|
|
375
|
+
try:
|
|
376
|
+
serialized = {"role": getattr(message, "role", "")}
|
|
377
|
+
serialized["content"] = str(getattr(message, "content", ""))
|
|
378
|
+
message_id = getattr(message, "id", None)
|
|
379
|
+
if message_id is not None:
|
|
380
|
+
serialized["id"] = message_id
|
|
381
|
+
created_at = getattr(message, "created_at", None)
|
|
382
|
+
if created_at is not None:
|
|
383
|
+
serialized["created_at"] = created_at
|
|
384
|
+
return serialized
|
|
385
|
+
except Exception:
|
|
386
|
+
# Fallback: convert to string
|
|
387
|
+
return {"role": "unknown", "content": str(message)}
|
|
388
|
+
|
|
135
389
|
def load_from_node(self, node: Any) -> None:
|
|
136
390
|
"""
|
|
137
391
|
Load message history from a graph node.
|
tactus/primitives/model.py
CHANGED
|
@@ -74,7 +74,7 @@ class ModelPrimitive:
|
|
|
74
74
|
headers=config.get("headers"),
|
|
75
75
|
)
|
|
76
76
|
|
|
77
|
-
|
|
77
|
+
if model_type == "pytorch":
|
|
78
78
|
from tactus.backends.pytorch_backend import PyTorchModelBackend
|
|
79
79
|
|
|
80
80
|
return PyTorchModelBackend(
|
|
@@ -83,8 +83,7 @@ class ModelPrimitive:
|
|
|
83
83
|
labels=config.get("labels"),
|
|
84
84
|
)
|
|
85
85
|
|
|
86
|
-
|
|
87
|
-
raise ValueError(f"Unknown model type: {model_type}. Supported types: http, pytorch")
|
|
86
|
+
raise ValueError(f"Unknown model type: {model_type}. Supported types: http, pytorch")
|
|
88
87
|
|
|
89
88
|
def predict(self, input_data: Any) -> Any:
|
|
90
89
|
"""
|
|
@@ -104,9 +103,9 @@ class ModelPrimitive:
|
|
|
104
103
|
# Capture source location
|
|
105
104
|
import inspect
|
|
106
105
|
|
|
107
|
-
|
|
108
|
-
if
|
|
109
|
-
caller_frame =
|
|
106
|
+
current_frame = inspect.currentframe()
|
|
107
|
+
if current_frame and current_frame.f_back:
|
|
108
|
+
caller_frame = current_frame.f_back
|
|
110
109
|
source_info = {
|
|
111
110
|
"file": caller_frame.f_code.co_filename,
|
|
112
111
|
"line": caller_frame.f_lineno,
|
|
@@ -132,12 +131,19 @@ class ModelPrimitive:
|
|
|
132
131
|
Model prediction result
|
|
133
132
|
"""
|
|
134
133
|
if self.mock_manager is not None:
|
|
135
|
-
|
|
136
|
-
mock_result = self.mock_manager.get_mock_response(
|
|
134
|
+
args_payload = input_data if isinstance(input_data, dict) else {"input": input_data}
|
|
135
|
+
mock_result = self.mock_manager.get_mock_response(
|
|
136
|
+
self.model_name,
|
|
137
|
+
args_payload,
|
|
138
|
+
)
|
|
137
139
|
if mock_result is not None:
|
|
138
140
|
# Ensure temporal mocks advance and calls are available for assertions.
|
|
139
141
|
try:
|
|
140
|
-
self.mock_manager.record_call(
|
|
142
|
+
self.mock_manager.record_call(
|
|
143
|
+
self.model_name,
|
|
144
|
+
args_payload,
|
|
145
|
+
mock_result,
|
|
146
|
+
)
|
|
141
147
|
except Exception:
|
|
142
148
|
pass
|
|
143
149
|
return mock_result
|