tactus 0.34.1__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. tactus/__init__.py +1 -1
  2. tactus/adapters/broker_log.py +17 -14
  3. tactus/adapters/channels/__init__.py +17 -15
  4. tactus/adapters/channels/base.py +16 -7
  5. tactus/adapters/channels/broker.py +43 -13
  6. tactus/adapters/channels/cli.py +19 -15
  7. tactus/adapters/channels/host.py +40 -25
  8. tactus/adapters/channels/ipc.py +82 -31
  9. tactus/adapters/channels/sse.py +41 -23
  10. tactus/adapters/cli_hitl.py +19 -19
  11. tactus/adapters/cli_log.py +4 -4
  12. tactus/adapters/control_loop.py +138 -99
  13. tactus/adapters/cost_collector_log.py +9 -9
  14. tactus/adapters/file_storage.py +56 -52
  15. tactus/adapters/http_callback_log.py +23 -13
  16. tactus/adapters/ide_log.py +17 -9
  17. tactus/adapters/lua_tools.py +4 -5
  18. tactus/adapters/mcp.py +16 -19
  19. tactus/adapters/mcp_manager.py +46 -30
  20. tactus/adapters/memory.py +9 -9
  21. tactus/adapters/plugins.py +42 -42
  22. tactus/broker/client.py +75 -78
  23. tactus/broker/protocol.py +57 -57
  24. tactus/broker/server.py +252 -197
  25. tactus/cli/app.py +3 -1
  26. tactus/cli/control.py +2 -2
  27. tactus/core/config_manager.py +181 -135
  28. tactus/core/dependencies/registry.py +66 -48
  29. tactus/core/dsl_stubs.py +222 -163
  30. tactus/core/exceptions.py +10 -1
  31. tactus/core/execution_context.py +152 -112
  32. tactus/core/lua_sandbox.py +72 -64
  33. tactus/core/message_history_manager.py +138 -43
  34. tactus/core/mocking.py +41 -27
  35. tactus/core/output_validator.py +49 -44
  36. tactus/core/registry.py +94 -80
  37. tactus/core/runtime.py +211 -176
  38. tactus/core/template_resolver.py +16 -16
  39. tactus/core/yaml_parser.py +55 -45
  40. tactus/docs/extractor.py +7 -6
  41. tactus/ide/server.py +119 -78
  42. tactus/primitives/control.py +10 -6
  43. tactus/primitives/file.py +48 -46
  44. tactus/primitives/handles.py +47 -35
  45. tactus/primitives/host.py +29 -27
  46. tactus/primitives/human.py +154 -137
  47. tactus/primitives/json.py +22 -23
  48. tactus/primitives/log.py +26 -26
  49. tactus/primitives/message_history.py +285 -31
  50. tactus/primitives/model.py +15 -9
  51. tactus/primitives/procedure.py +86 -64
  52. tactus/primitives/procedure_callable.py +58 -51
  53. tactus/primitives/retry.py +31 -29
  54. tactus/primitives/session.py +42 -29
  55. tactus/primitives/state.py +54 -43
  56. tactus/primitives/step.py +9 -13
  57. tactus/primitives/system.py +34 -21
  58. tactus/primitives/tool.py +44 -31
  59. tactus/primitives/tool_handle.py +76 -54
  60. tactus/primitives/toolset.py +25 -22
  61. tactus/sandbox/config.py +4 -4
  62. tactus/sandbox/container_runner.py +161 -107
  63. tactus/sandbox/docker_manager.py +20 -20
  64. tactus/sandbox/entrypoint.py +16 -14
  65. tactus/sandbox/protocol.py +15 -15
  66. tactus/stdlib/classify/llm.py +1 -3
  67. tactus/stdlib/core/validation.py +0 -3
  68. tactus/testing/pydantic_eval_runner.py +1 -1
  69. tactus/utils/asyncio_helpers.py +27 -0
  70. tactus/utils/cost_calculator.py +7 -7
  71. tactus/utils/model_pricing.py +11 -12
  72. tactus/utils/safe_file_library.py +156 -132
  73. tactus/utils/safe_libraries.py +27 -27
  74. tactus/validation/error_listener.py +18 -5
  75. tactus/validation/semantic_visitor.py +392 -333
  76. tactus/validation/validator.py +89 -49
  77. {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/METADATA +15 -3
  78. {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/RECORD +81 -80
  79. {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/WHEEL +0 -0
  80. {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/entry_points.txt +0 -0
  81. {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/licenses/LICENSE +0 -0
tactus/primitives/json.py CHANGED
@@ -60,16 +60,16 @@ class JsonPrimitive:
60
60
  # Convert Lua tables to Python dicts recursively if needed
61
61
  python_data = self._lua_to_python(data)
62
62
 
63
- json_str = json.dumps(python_data, ensure_ascii=False, indent=None)
64
- logger.debug(f"Encoded data to JSON ({len(json_str)} bytes)")
65
- return json_str
63
+ json_payload = json.dumps(python_data, ensure_ascii=False, indent=None)
64
+ logger.debug("Encoded data to JSON (%s bytes)", len(json_payload))
65
+ return json_payload
66
66
 
67
- except (TypeError, ValueError) as e:
68
- error_msg = f"Failed to encode to JSON: {e}"
69
- logger.error(error_msg)
70
- raise ValueError(error_msg)
67
+ except (TypeError, ValueError) as error:
68
+ error_message = f"Failed to encode to JSON: {error}"
69
+ logger.error(error_message)
70
+ raise ValueError(error_message)
71
71
 
72
- def decode(self, json_str: str):
72
+ def decode(self, json_str: str) -> Any:
73
73
  """
74
74
  Decode JSON string to Lua table.
75
75
 
@@ -93,7 +93,7 @@ class JsonPrimitive:
93
93
  try:
94
94
  # Parse JSON to Python dict
95
95
  python_data = json.loads(json_str)
96
- logger.debug(f"Decoded JSON string ({len(json_str)} bytes)")
96
+ logger.debug("Decoded JSON string (%s bytes)", len(json_str))
97
97
 
98
98
  # Convert to Lua table if lua_sandbox available
99
99
  if self.lua_sandbox:
@@ -102,10 +102,10 @@ class JsonPrimitive:
102
102
  # Fallback: return Python dict (will work but not ideal)
103
103
  return python_data
104
104
 
105
- except json.JSONDecodeError as e:
106
- error_msg = f"Failed to decode JSON: {e}"
107
- logger.error(error_msg)
108
- raise ValueError(error_msg)
105
+ except json.JSONDecodeError as error:
106
+ error_message = f"Failed to decode JSON: {error}"
107
+ logger.error(error_message)
108
+ raise ValueError(error_message)
109
109
 
110
110
  def _lua_to_python(self, value: Any) -> Any:
111
111
  """
@@ -125,27 +125,27 @@ class JsonPrimitive:
125
125
  if lua_type(value) == "table":
126
126
  # Try to determine if it's an array or dict
127
127
  # Lua arrays have consecutive integer keys starting at 1
128
- result = {}
128
+ converted = {}
129
129
  is_array = True
130
130
  keys = []
131
131
 
132
132
  for k, v in value.items():
133
133
  keys.append(k)
134
- result[k] = self._lua_to_python(v)
134
+ converted[k] = self._lua_to_python(v)
135
135
  if not isinstance(k, int) or k < 1:
136
136
  is_array = False
137
137
 
138
138
  # Check if keys are consecutive integers starting at 1
139
139
  if is_array and keys:
140
- keys_sorted = sorted(keys)
141
- if keys_sorted != list(range(1, len(keys) + 1)):
140
+ sorted_keys = sorted(keys)
141
+ if sorted_keys != list(range(1, len(keys) + 1)):
142
142
  is_array = False
143
143
 
144
144
  # Convert to list if it's an array
145
145
  if is_array and keys:
146
- return [result[i] for i in range(1, len(keys) + 1)]
146
+ return [converted[i] for i in range(1, len(keys) + 1)]
147
147
  else:
148
- return result
148
+ return converted
149
149
  else:
150
150
  # Primitive value
151
151
  return value
@@ -174,16 +174,15 @@ class JsonPrimitive:
174
174
  lua_table[k] = self._python_to_lua(v)
175
175
  return lua_table
176
176
 
177
- elif isinstance(value, (list, tuple)):
177
+ if isinstance(value, (list, tuple)):
178
178
  # Convert list to Lua array (1-indexed)
179
179
  lua_table = self.lua_sandbox.lua.table()
180
180
  for i, item in enumerate(value, start=1):
181
181
  lua_table[i] = self._python_to_lua(item)
182
182
  return lua_table
183
183
 
184
- else:
185
- # Primitive value (str, int, float, bool, None)
186
- return value
184
+ # Primitive value (str, int, float, bool, None)
185
+ return value
187
186
 
188
187
  def __repr__(self) -> str:
189
188
  return "JsonPrimitive()"
tactus/primitives/log.py CHANGED
@@ -9,9 +9,9 @@ Provides:
9
9
  """
10
10
 
11
11
  import logging
12
- from typing import Any, Dict, Optional, TYPE_CHECKING
12
+ from typing import Any, Optional, TYPE_CHECKING
13
13
 
14
- if TYPE_CHECKING:
14
+ if TYPE_CHECKING: # pragma: no cover
15
15
  from tactus.protocols.log_handler import LogHandler
16
16
 
17
17
  logger = logging.getLogger(__name__)
@@ -38,32 +38,32 @@ class LogPrimitive:
38
38
  self.logger = logging.getLogger(f"procedure.{procedure_id}")
39
39
  self.log_handler = log_handler
40
40
 
41
- def _format_message(self, message: str, context: Optional[Dict[str, Any]] = None) -> str:
41
+ def _format_message(self, message: str, context: Optional[dict[str, Any]] = None) -> str:
42
42
  """Format log message with context."""
43
43
  if context:
44
44
  import json
45
45
 
46
46
  # Convert Lua tables to Python dicts
47
- context_dict = self._lua_to_python(context)
48
- context_str = json.dumps(context_dict, indent=2)
49
- return f"{message}\nContext: {context_str}"
47
+ context_payload = self._lua_to_python(context)
48
+ context_json = json.dumps(context_payload, indent=2)
49
+ return f"{message}\nContext: {context_json}"
50
50
  return message
51
51
 
52
- def _lua_to_python(self, obj: Any) -> Any:
52
+ def _lua_to_python(self, value: Any) -> Any:
53
53
  """Convert Lua objects to Python equivalents recursively."""
54
54
  # Check if it's a Lua table
55
- if hasattr(obj, "items"): # Lua table with dict-like interface
56
- return {self._lua_to_python(k): self._lua_to_python(v) for k, v in obj.items()}
57
- elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)): # Lua array
55
+ if hasattr(value, "items"): # Lua table with dict-like interface
56
+ return {self._lua_to_python(k): self._lua_to_python(v) for k, v in value.items()}
57
+ elif hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): # Lua array
58
58
  try:
59
- return [self._lua_to_python(v) for v in obj]
59
+ return [self._lua_to_python(v) for v in value]
60
60
  except Exception: # noqa: E722
61
61
  # If iteration fails, return as-is
62
- return obj
62
+ return value
63
63
  else:
64
- return obj
64
+ return value
65
65
 
66
- def debug(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
66
+ def debug(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
67
67
  """
68
68
  Log debug message.
69
69
 
@@ -78,11 +78,11 @@ class LogPrimitive:
78
78
  if self.log_handler:
79
79
  from tactus.protocols.models import LogEvent
80
80
 
81
- context_dict = self._lua_to_python(context) if context else None
81
+ context_payload = self._lua_to_python(context) if context else None
82
82
  event = LogEvent(
83
83
  level="DEBUG",
84
84
  message=message,
85
- context=context_dict,
85
+ context=context_payload,
86
86
  logger_name=self.logger.name,
87
87
  procedure_id=self.procedure_id,
88
88
  )
@@ -92,7 +92,7 @@ class LogPrimitive:
92
92
  formatted = self._format_message(message, context)
93
93
  self.logger.debug(formatted)
94
94
 
95
- def info(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
95
+ def info(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
96
96
  """
97
97
  Log info message.
98
98
 
@@ -107,11 +107,11 @@ class LogPrimitive:
107
107
  if self.log_handler:
108
108
  from tactus.protocols.models import LogEvent
109
109
 
110
- context_dict = self._lua_to_python(context) if context else None
110
+ context_payload = self._lua_to_python(context) if context else None
111
111
  event = LogEvent(
112
112
  level="INFO",
113
113
  message=message,
114
- context=context_dict,
114
+ context=context_payload,
115
115
  logger_name=self.logger.name,
116
116
  procedure_id=self.procedure_id,
117
117
  )
@@ -121,7 +121,7 @@ class LogPrimitive:
121
121
  formatted = self._format_message(message, context)
122
122
  self.logger.info(formatted)
123
123
 
124
- def warn(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
124
+ def warn(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
125
125
  """
126
126
  Log warning message.
127
127
 
@@ -136,11 +136,11 @@ class LogPrimitive:
136
136
  if self.log_handler:
137
137
  from tactus.protocols.models import LogEvent
138
138
 
139
- context_dict = self._lua_to_python(context) if context else None
139
+ context_payload = self._lua_to_python(context) if context else None
140
140
  event = LogEvent(
141
141
  level="WARNING",
142
142
  message=message,
143
- context=context_dict,
143
+ context=context_payload,
144
144
  logger_name=self.logger.name,
145
145
  procedure_id=self.procedure_id,
146
146
  )
@@ -150,11 +150,11 @@ class LogPrimitive:
150
150
  formatted = self._format_message(message, context)
151
151
  self.logger.warning(formatted)
152
152
 
153
- def warning(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
153
+ def warning(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
154
154
  """Alias for warn(), matching common logging APIs."""
155
155
  self.warn(message, context)
156
156
 
157
- def error(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
157
+ def error(self, message: str, context: Optional[dict[str, Any]] = None) -> None:
158
158
  """
159
159
  Log error message.
160
160
 
@@ -169,11 +169,11 @@ class LogPrimitive:
169
169
  if self.log_handler:
170
170
  from tactus.protocols.models import LogEvent
171
171
 
172
- context_dict = self._lua_to_python(context) if context else None
172
+ context_payload = self._lua_to_python(context) if context else None
173
173
  event = LogEvent(
174
174
  level="ERROR",
175
175
  message=message,
176
- context=context_dict,
176
+ context=context_payload,
177
177
  logger_name=self.logger.name,
178
178
  procedure_id=self.procedure_id,
179
179
  )
tactus/primitives/message_history.py CHANGED
@@ -42,28 +42,31 @@ class MessageHistoryPrimitive:
42
42
  self.message_history_manager = message_history_manager
43
43
  self.agent_name = agent_name
44
44
 
45
- def append(self, message_data: dict) -> None:
45
+ def append(self, message_payload: dict[str, Any]) -> None:
46
46
  """
47
47
  Append a message to the message history.
48
48
 
49
49
  Args:
50
- message_data: Dict with 'role' and 'content' keys
50
+ message_payload: dict with 'role' and 'content' keys
51
51
  role: 'user', 'assistant', 'system'
52
52
  content: message text
53
53
 
54
54
  Example:
55
55
  MessageHistory.append({role = "user", content = "Hello"})
56
56
  """
57
- if not self.message_history_manager or not self.agent_name:
57
+ if not self.message_history_manager:
58
58
  return
59
59
 
60
- role = message_data.get("role", "user")
61
- content = message_data.get("content", "")
60
+ message_payload = self._normalize_message_payload(message_payload)
61
+ role = message_payload.get("role", "user")
62
+ content = message_payload.get("content", "")
62
63
 
63
- # Create a simple message dict
64
- message = {"role": role, "content": content}
64
+ # Create a message dict and preserve extra fields
65
+ message_entry = dict(message_payload)
66
+ message_entry["role"] = role
67
+ message_entry["content"] = content
65
68
 
66
- self.message_history_manager.add_message(self.agent_name, message)
69
+ self.message_history_manager.add_message(self.agent_name, message_entry)
67
70
 
68
71
  def inject_system(self, text: str) -> None:
69
72
  """
@@ -87,12 +90,14 @@ class MessageHistoryPrimitive:
87
90
  Example:
88
91
  MessageHistory.clear()
89
92
  """
90
- if not self.message_history_manager or not self.agent_name:
93
+ if not self.message_history_manager:
91
94
  return
95
+ if self.agent_name:
96
+ self.message_history_manager.clear_agent_history(self.agent_name)
97
+ else:
98
+ self.message_history_manager.clear_shared_history()
92
99
 
93
- self.message_history_manager.clear_agent_history(self.agent_name)
94
-
95
- def get(self) -> list:
100
+ def get(self) -> list[dict[str, Any]]:
96
101
  """
97
102
  Get the full message history for this agent.
98
103
 
@@ -107,31 +112,280 @@ class MessageHistoryPrimitive:
107
112
  Log.info(msg.role .. ": " .. msg.content)
108
113
  end
109
114
  """
110
- if not self.message_history_manager or not self.agent_name:
115
+ if not self.message_history_manager:
111
116
  return []
112
-
113
- messages = self.message_history_manager.histories.get(self.agent_name, [])
117
+ messages = self._get_history_ref()
114
118
 
115
119
  # Convert to Lua-friendly format
116
- result = []
117
- for msg in messages:
118
- if isinstance(msg, dict):
119
- result.append({"role": msg.get("role", ""), "content": str(msg.get("content", ""))})
120
- else:
121
- # Handle pydantic_ai ModelMessage objects
122
- try:
123
- result.append(
124
- {
125
- "role": getattr(msg, "role", ""),
126
- "content": str(getattr(msg, "content", "")),
127
- }
128
- )
129
- except Exception:
130
- # Fallback: convert to string
131
- result.append({"role": "unknown", "content": str(msg)})
120
+ result: list[dict[str, Any]] = []
121
+ for message in messages:
122
+ serialized_message = self._serialize_message(message)
123
+ result.append(serialized_message)
132
124
 
133
125
  return result
134
126
 
127
+ def replace(self, messages: list[Any]) -> None:
128
+ """
129
+ Replace the current message history with a new list.
130
+
131
+ Args:
132
+ messages: List of message dicts to set as the new history
133
+ """
134
+ if not self.message_history_manager:
135
+ return
136
+
137
+ normalized_messages = self._normalize_messages(messages)
138
+ normalized_messages = [
139
+ self.message_history_manager._ensure_message_metadata(message)
140
+ for message in normalized_messages
141
+ ]
142
+
143
+ if self.agent_name:
144
+ self.message_history_manager.histories[self.agent_name] = normalized_messages
145
+ else:
146
+ self.message_history_manager.shared_history = normalized_messages
147
+
148
+ def reset(self, options: Optional[dict[str, Any]] = None) -> None:
149
+ """
150
+ Reset history while optionally keeping leading system messages.
151
+
152
+ Args:
153
+ options: Optional dict with keep mode:
154
+ - "system_prefix" (default): keep leading system messages only
155
+ - "system_all": keep all system messages
156
+ - "none": clear all messages
157
+ """
158
+ if not self.message_history_manager:
159
+ return
160
+
161
+ keep_mode = "system_prefix"
162
+ normalized_options = self._normalize_options(options)
163
+ if normalized_options:
164
+ keep_mode = normalized_options.get("keep", keep_mode)
165
+ elif isinstance(options, str):
166
+ keep_mode = options
167
+
168
+ messages = self._get_history_ref()
169
+
170
+ if keep_mode == "none":
171
+ self.replace([])
172
+ return
173
+ if keep_mode == "system_all":
174
+ system_messages = self.message_history_manager._filter_by_role(messages, "system")
175
+ self.replace(system_messages)
176
+ return
177
+
178
+ system_prefix_messages = self.message_history_manager._filter_system_prefix(messages)
179
+ self.replace(system_prefix_messages)
180
+
181
+ def head(self, n: int) -> list[dict[str, Any]]:
182
+ """Return the first N messages without mutating history."""
183
+ if not self.message_history_manager:
184
+ return []
185
+ messages = self._get_history_ref()
186
+ limit = max(int(n or 0), 0)
187
+ return self._serialize_messages(messages[:limit])
188
+
189
+ def tail(self, n: int) -> list[dict[str, Any]]:
190
+ """Return the last N messages without mutating history."""
191
+ if not self.message_history_manager:
192
+ return []
193
+ messages = self._get_history_ref()
194
+ limit = max(int(n or 0), 0)
195
+ return self._serialize_messages(messages[-limit:] if limit > 0 else [])
196
+
197
+ def slice(self, options: dict[str, Any]) -> list[dict[str, Any]]:
198
+ """Return a slice of messages using 1-based start/stop indices."""
199
+ if not self.message_history_manager:
200
+ return []
201
+ normalized_options = self._normalize_options(options)
202
+ if not normalized_options:
203
+ return []
204
+ messages = self._get_history_ref()
205
+ start = normalized_options.get("start")
206
+ stop = normalized_options.get("stop")
207
+ start_index = max(int(start or 1) - 1, 0)
208
+ stop_index = int(stop) if stop is not None else None
209
+ sliced = messages[start_index:stop_index]
210
+ return self._serialize_messages(sliced)
211
+
212
+ def tail_tokens(
213
+ self, max_tokens: int, options: Optional[dict[str, Any]] = None
214
+ ) -> list[dict[str, Any]]:
215
+ """Return the last messages that fit within the token budget."""
216
+ if not self.message_history_manager:
217
+ return []
218
+ messages = self._get_history_ref()
219
+ token_filtered_messages = self.message_history_manager._filter_tail_tokens(
220
+ messages, max_tokens
221
+ )
222
+ return self._serialize_messages(token_filtered_messages)
223
+
224
+ def keep_head(self, n: int) -> None:
225
+ """Keep only the first N messages."""
226
+ if not self.message_history_manager:
227
+ return
228
+ messages = self._get_history_ref()
229
+ limit = max(int(n or 0), 0)
230
+ self.replace(messages[:limit])
231
+
232
+ def keep_tail(self, n: int) -> None:
233
+ """Keep only the last N messages."""
234
+ if not self.message_history_manager:
235
+ return
236
+ messages = self._get_history_ref()
237
+ limit = max(int(n or 0), 0)
238
+ self.replace(messages[-limit:] if limit > 0 else [])
239
+
240
+ def keep_tail_tokens(self, max_tokens: int, options: Optional[dict[str, Any]] = None) -> None:
241
+ """Keep only the last messages that fit within the token budget."""
242
+ if not self.message_history_manager:
243
+ return
244
+ messages = self._get_history_ref()
245
+ token_filtered_messages = self.message_history_manager._filter_tail_tokens(
246
+ messages, max_tokens
247
+ )
248
+ self.replace(token_filtered_messages)
249
+
250
+ def rewind(self, n: int) -> None:
251
+ """Remove the last N messages from history."""
252
+ if not self.message_history_manager:
253
+ return
254
+ messages = self._get_history_ref()
255
+ count = max(int(n or 0), 0)
256
+ if count <= 0:
257
+ return
258
+ self.replace(messages[:-count])
259
+
260
+ def rewind_to(self, message_id: Any) -> None:
261
+ """Rewind history back to a message id or checkpoint name."""
262
+ if not self.message_history_manager:
263
+ return
264
+
265
+ target_message_id = message_id
266
+ if isinstance(message_id, str):
267
+ checkpoint_id = self.message_history_manager.get_checkpoint(message_id)
268
+ target_message_id = checkpoint_id if checkpoint_id is not None else message_id
269
+
270
+ try:
271
+ target_message_id = int(target_message_id)
272
+ except (TypeError, ValueError):
273
+ return
274
+
275
+ messages = self._get_history_ref()
276
+ for index, message in enumerate(messages):
277
+ message_id_value = (
278
+ message.get("id") if isinstance(message, dict) else getattr(message, "id", None)
279
+ )
280
+ if message_id_value == target_message_id:
281
+ self.replace(messages[: index + 1])
282
+ return
283
+
284
+ def checkpoint(self, name: Optional[str] = None) -> Optional[int]:
285
+ """Return the id of the last message and optionally store a named checkpoint."""
286
+ if not self.message_history_manager:
287
+ return None
288
+
289
+ messages = self._get_history_ref()
290
+ if not messages:
291
+ return None
292
+
293
+ last_message = messages[-1]
294
+ if isinstance(last_message, dict):
295
+ last_message = self.message_history_manager._ensure_message_metadata(last_message)
296
+ message_id = last_message.get("id")
297
+ else:
298
+ message_id = getattr(last_message, "id", None)
299
+
300
+ if isinstance(name, str) and message_id is not None:
301
+ self.message_history_manager.record_checkpoint(name, message_id)
302
+
303
+ return message_id
304
+
305
+ def _get_history_ref(self) -> list[Any]:
306
+ """Get a direct reference to the underlying history list."""
307
+ if not self.message_history_manager:
308
+ return []
309
+ if self.agent_name:
310
+ return self.message_history_manager.histories.setdefault(self.agent_name, [])
311
+ return self.message_history_manager.shared_history
312
+
313
+ def _serialize_messages(self, messages: list[Any]) -> list[dict[str, Any]]:
314
+ """Serialize message objects to Lua-friendly dicts."""
315
+ result: list[dict[str, Any]] = []
316
+ for message in messages:
317
+ result.append(self._serialize_message(message))
318
+ return result
319
+
320
+ def _normalize_messages(self, messages: Any) -> list[Any]:
321
+ """Normalize Python lists or Lua tables into a list of message dicts."""
322
+ if messages is None:
323
+ return []
324
+ if isinstance(messages, list):
325
+ return messages
326
+ if isinstance(messages, tuple):
327
+ return list(messages)
328
+ if hasattr(messages, "items"):
329
+ items = list(messages.items())
330
+ if items and all(isinstance(key, int) for key, _ in items):
331
+ items.sort(key=lambda pair: pair[0])
332
+ return [value for _, value in items]
333
+ return list(messages)
334
+
335
+ def _normalize_message_payload(self, message_payload: Any) -> dict[str, Any]:
336
+ """Normalize a single message payload into a dict."""
337
+ if message_payload is None:
338
+ return {}
339
+ if isinstance(message_payload, dict):
340
+ return message_payload
341
+ if hasattr(message_payload, "items"):
342
+ try:
343
+ return dict(message_payload.items())
344
+ except Exception:
345
+ pass
346
+ return {"role": "user", "content": str(message_payload)}
347
+
348
+ def _normalize_message_data(self, message_data: Any) -> dict[str, Any]:
349
+ """Compatibility alias for existing tests and external callers."""
350
+ return self._normalize_message_payload(message_data)
351
+
352
+ def _normalize_options(self, options: Any) -> dict[str, Any]:
353
+ """Normalize options from Lua tables or dicts."""
354
+ if options is None:
355
+ return {}
356
+ if isinstance(options, dict):
357
+ return options
358
+ if hasattr(options, "items"):
359
+ try:
360
+ return dict(options.items())
361
+ except Exception:
362
+ return {}
363
+ return {}
364
+
365
+ def _serialize_message(self, message: Any) -> dict[str, Any]:
366
+ """Serialize a single message into a Lua-friendly dict."""
367
+ if isinstance(message, dict):
368
+ message = self.message_history_manager._ensure_message_metadata(message)
369
+ serialized = dict(message)
370
+ serialized["role"] = str(serialized.get("role", ""))
371
+ serialized["content"] = str(serialized.get("content", ""))
372
+ return serialized
373
+
374
+ # Handle pydantic_ai ModelMessage objects
375
+ try:
376
+ serialized = {"role": getattr(message, "role", "")}
377
+ serialized["content"] = str(getattr(message, "content", ""))
378
+ message_id = getattr(message, "id", None)
379
+ if message_id is not None:
380
+ serialized["id"] = message_id
381
+ created_at = getattr(message, "created_at", None)
382
+ if created_at is not None:
383
+ serialized["created_at"] = created_at
384
+ return serialized
385
+ except Exception:
386
+ # Fallback: convert to string
387
+ return {"role": "unknown", "content": str(message)}
388
+
135
389
  def load_from_node(self, node: Any) -> None:
136
390
  """
137
391
  Load message history from a graph node.
tactus/primitives/model.py CHANGED
@@ -74,7 +74,7 @@ class ModelPrimitive:
74
74
  headers=config.get("headers"),
75
75
  )
76
76
 
77
- elif model_type == "pytorch":
77
+ if model_type == "pytorch":
78
78
  from tactus.backends.pytorch_backend import PyTorchModelBackend
79
79
 
80
80
  return PyTorchModelBackend(
@@ -83,8 +83,7 @@ class ModelPrimitive:
83
83
  labels=config.get("labels"),
84
84
  )
85
85
 
86
- else:
87
- raise ValueError(f"Unknown model type: {model_type}. Supported types: http, pytorch")
86
+ raise ValueError(f"Unknown model type: {model_type}. Supported types: http, pytorch")
88
87
 
89
88
  def predict(self, input_data: Any) -> Any:
90
89
  """
@@ -104,9 +103,9 @@ class ModelPrimitive:
104
103
  # Capture source location
105
104
  import inspect
106
105
 
107
- frame = inspect.currentframe()
108
- if frame and frame.f_back:
109
- caller_frame = frame.f_back
106
+ current_frame = inspect.currentframe()
107
+ if current_frame and current_frame.f_back:
108
+ caller_frame = current_frame.f_back
110
109
  source_info = {
111
110
  "file": caller_frame.f_code.co_filename,
112
111
  "line": caller_frame.f_lineno,
@@ -132,12 +131,19 @@ class ModelPrimitive:
132
131
  Model prediction result
133
132
  """
134
133
  if self.mock_manager is not None:
135
- args = input_data if isinstance(input_data, dict) else {"input": input_data}
136
- mock_result = self.mock_manager.get_mock_response(self.model_name, args)
134
+ args_payload = input_data if isinstance(input_data, dict) else {"input": input_data}
135
+ mock_result = self.mock_manager.get_mock_response(
136
+ self.model_name,
137
+ args_payload,
138
+ )
137
139
  if mock_result is not None:
138
140
  # Ensure temporal mocks advance and calls are available for assertions.
139
141
  try:
140
- self.mock_manager.record_call(self.model_name, args, mock_result)
142
+ self.mock_manager.record_call(
143
+ self.model_name,
144
+ args_payload,
145
+ mock_result,
146
+ )
141
147
  except Exception:
142
148
  pass
143
149
  return mock_result