tactus 0.34.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (81)
  1. tactus/__init__.py +1 -1
  2. tactus/adapters/broker_log.py +17 -14
  3. tactus/adapters/channels/__init__.py +17 -15
  4. tactus/adapters/channels/base.py +16 -7
  5. tactus/adapters/channels/broker.py +43 -13
  6. tactus/adapters/channels/cli.py +19 -15
  7. tactus/adapters/channels/host.py +15 -6
  8. tactus/adapters/channels/ipc.py +82 -31
  9. tactus/adapters/channels/sse.py +41 -23
  10. tactus/adapters/cli_hitl.py +19 -19
  11. tactus/adapters/cli_log.py +4 -4
  12. tactus/adapters/control_loop.py +138 -99
  13. tactus/adapters/cost_collector_log.py +9 -9
  14. tactus/adapters/file_storage.py +56 -52
  15. tactus/adapters/http_callback_log.py +23 -13
  16. tactus/adapters/ide_log.py +17 -9
  17. tactus/adapters/lua_tools.py +4 -5
  18. tactus/adapters/mcp.py +16 -19
  19. tactus/adapters/mcp_manager.py +46 -30
  20. tactus/adapters/memory.py +9 -9
  21. tactus/adapters/plugins.py +42 -42
  22. tactus/broker/client.py +75 -78
  23. tactus/broker/protocol.py +57 -57
  24. tactus/broker/server.py +252 -197
  25. tactus/cli/app.py +3 -1
  26. tactus/cli/control.py +2 -2
  27. tactus/core/config_manager.py +181 -135
  28. tactus/core/dependencies/registry.py +66 -48
  29. tactus/core/dsl_stubs.py +222 -163
  30. tactus/core/exceptions.py +10 -1
  31. tactus/core/execution_context.py +152 -112
  32. tactus/core/lua_sandbox.py +72 -64
  33. tactus/core/message_history_manager.py +138 -43
  34. tactus/core/mocking.py +41 -27
  35. tactus/core/output_validator.py +49 -44
  36. tactus/core/registry.py +94 -80
  37. tactus/core/runtime.py +211 -176
  38. tactus/core/template_resolver.py +16 -16
  39. tactus/core/yaml_parser.py +55 -45
  40. tactus/docs/extractor.py +7 -6
  41. tactus/ide/server.py +119 -78
  42. tactus/primitives/control.py +10 -6
  43. tactus/primitives/file.py +48 -46
  44. tactus/primitives/handles.py +47 -35
  45. tactus/primitives/host.py +29 -27
  46. tactus/primitives/human.py +154 -137
  47. tactus/primitives/json.py +22 -23
  48. tactus/primitives/log.py +26 -26
  49. tactus/primitives/message_history.py +285 -31
  50. tactus/primitives/model.py +15 -9
  51. tactus/primitives/procedure.py +86 -64
  52. tactus/primitives/procedure_callable.py +58 -51
  53. tactus/primitives/retry.py +31 -29
  54. tactus/primitives/session.py +42 -29
  55. tactus/primitives/state.py +54 -43
  56. tactus/primitives/step.py +9 -13
  57. tactus/primitives/system.py +34 -21
  58. tactus/primitives/tool.py +44 -31
  59. tactus/primitives/tool_handle.py +76 -54
  60. tactus/primitives/toolset.py +25 -22
  61. tactus/sandbox/config.py +4 -4
  62. tactus/sandbox/container_runner.py +161 -107
  63. tactus/sandbox/docker_manager.py +20 -20
  64. tactus/sandbox/entrypoint.py +16 -14
  65. tactus/sandbox/protocol.py +15 -15
  66. tactus/stdlib/classify/llm.py +1 -3
  67. tactus/stdlib/core/validation.py +0 -3
  68. tactus/testing/pydantic_eval_runner.py +1 -1
  69. tactus/utils/asyncio_helpers.py +27 -0
  70. tactus/utils/cost_calculator.py +7 -7
  71. tactus/utils/model_pricing.py +11 -12
  72. tactus/utils/safe_file_library.py +156 -132
  73. tactus/utils/safe_libraries.py +27 -27
  74. tactus/validation/error_listener.py +18 -5
  75. tactus/validation/semantic_visitor.py +392 -333
  76. tactus/validation/validator.py +89 -49
  77. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/METADATA +12 -3
  78. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/RECORD +81 -80
  79. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/WHEEL +0 -0
  80. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/entry_points.txt +0 -0
  81. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@ Provides a sandboxed Lua runtime with:
12
12
 
13
13
  import logging
14
14
  import os
15
- from typing import Dict, Any, Optional
15
+ from typing import Any, Optional
16
16
 
17
17
  try:
18
18
  import lupa
@@ -59,10 +59,13 @@ class LuaSandbox:
59
59
  # Fix base_path at initialization time to prevent security boundary expansion
60
60
  # This ensures file I/O libraries and require() always use the same base path,
61
61
  # even if the working directory changes later
62
- self.base_path = base_path if base_path else os.getcwd()
62
+ self.base_path = base_path or os.getcwd()
63
63
 
64
64
  # Create Lua runtime with safety restrictions
65
- self.lua = LuaRuntime(unpack_returned_tuples=True, attribute_filter=self._attribute_filter)
65
+ self.lua = LuaRuntime(
66
+ unpack_returned_tuples=True,
67
+ attribute_filter=self._attribute_filter,
68
+ )
66
69
 
67
70
  # Remove dangerous modules
68
71
  self._remove_dangerous_modules()
@@ -75,7 +78,7 @@ class LuaSandbox:
75
78
 
76
79
  logger.debug("Lua sandbox initialized successfully")
77
80
 
78
- def _attribute_filter(self, obj, attr_name, is_setting):
81
+ def _attribute_filter(self, obj: Any, attr_name: str, is_setting: bool) -> str:
79
82
  """
80
83
  Filter attribute access to prevent dangerous operations.
81
84
 
@@ -86,7 +89,7 @@ class LuaSandbox:
86
89
  raise AttributeError(f"Access to private attribute '{attr_name}' is not allowed")
87
90
 
88
91
  # Block access to certain dangerous methods
89
- blocked_methods = {
92
+ blocked_attributes = {
90
93
  "__import__",
91
94
  "__loader__",
92
95
  "__spec__",
@@ -98,12 +101,12 @@ class LuaSandbox:
98
101
  "__subclasses__",
99
102
  }
100
103
 
101
- if attr_name in blocked_methods:
104
+ if attr_name in blocked_attributes:
102
105
  raise AttributeError(f"Access to '{attr_name}' is not allowed in sandbox")
103
106
 
104
107
  return attr_name
105
108
 
106
- def _remove_dangerous_modules(self):
109
+ def _remove_dangerous_modules(self) -> None:
107
110
  """Remove dangerous Lua standard library modules."""
108
111
  # Remove modules that provide file system or system access
109
112
  # Note: 'package' and 'require' are kept but restricted in _setup_safe_require()
@@ -125,17 +128,19 @@ class LuaSandbox:
125
128
  # Whitelist only safe debug functions for source location tracking
126
129
  # Keep debug.getinfo but remove dangerous debug functions
127
130
  if "debug" in lua_globals:
128
- self.lua.execute("""
131
+ self.lua.execute(
132
+ """
129
133
  if debug then
130
134
  local safe_debug = {
131
135
  getinfo = debug.getinfo
132
136
  }
133
137
  debug = safe_debug
134
138
  end
135
- """)
139
+ """
140
+ )
136
141
  logger.debug("Replaced debug module with safe_debug (only getinfo allowed)")
137
142
 
138
- def _setup_safe_require(self):
143
+ def _setup_safe_require(self) -> None:
139
144
  """Configure require/package to search user's project and stdlib.
140
145
 
141
146
  This allows using Lua's require() mechanism while restricting module
@@ -157,17 +162,22 @@ class LuaSandbox:
157
162
  # 1. User's project directory (existing behavior)
158
163
  # 2. Tactus stdlib .tac files
159
164
  # Both single-file modules (?.tac) and directory modules (?/init.tac) are supported
160
- user_path = os.path.join(self.base_path, "?.tac")
165
+ user_module_path = os.path.join(self.base_path, "?.tac")
161
166
  user_init_path = os.path.join(self.base_path, "?", "init.tac")
162
- stdlib_path = os.path.join(stdlib_tac_path, "?.tac")
167
+ stdlib_module_path = os.path.join(stdlib_tac_path, "?.tac")
163
168
  stdlib_init_path = os.path.join(stdlib_tac_path, "?", "init.tac")
164
169
 
165
170
  # Normalize backslashes for cross-platform compatibility
166
- paths = [user_path, user_init_path, stdlib_path, stdlib_init_path]
167
- paths = [p.replace("\\", "/") for p in paths]
171
+ raw_paths = [
172
+ user_module_path,
173
+ user_init_path,
174
+ stdlib_module_path,
175
+ stdlib_init_path,
176
+ ]
177
+ normalized_paths = [path.replace("\\", "/") for path in raw_paths]
168
178
 
169
179
  # Join with Lua's path separator (semicolon)
170
- safe_path = ";".join(paths)
180
+ safe_path = ";".join(normalized_paths)
171
181
 
172
182
  lua_globals = self.lua.globals()
173
183
  package = lua_globals["package"]
@@ -186,11 +196,11 @@ class LuaSandbox:
186
196
  # Add Python stdlib loader
187
197
  self._setup_python_stdlib_loader()
188
198
 
189
- logger.debug(f"Configured safe require with paths: {safe_path}")
199
+ logger.debug("Configured safe require with paths: %s", safe_path)
190
200
  else:
191
201
  logger.warning("package module not available - require will not work")
192
202
 
193
- def _setup_python_stdlib_loader(self):
203
+ def _setup_python_stdlib_loader(self) -> None:
194
204
  """Add custom loader for Python stdlib modules."""
195
205
  from tactus.stdlib.loader import StdlibModuleLoader
196
206
 
@@ -203,7 +213,8 @@ class LuaSandbox:
203
213
 
204
214
  # Add to package.loaders (Lua 5.1) or package.searchers (Lua 5.2+)
205
215
  # Lupa uses LuaJIT which follows Lua 5.1 conventions
206
- self.lua.execute("""
216
+ self.lua.execute(
217
+ """
207
218
  -- Add Python stdlib loader to package.loaders
208
219
  -- Insert after the preload loader but before path loader
209
220
  local loaders = package.loaders or package.searchers
@@ -221,11 +232,12 @@ class LuaSandbox:
221
232
  -- Insert at position 2 (after preload, before path)
222
233
  table.insert(loaders, 2, python_searcher)
223
234
  end
224
- """)
235
+ """
236
+ )
225
237
 
226
238
  logger.debug("Python stdlib loader installed")
227
239
 
228
- def _setup_safe_globals(self):
240
+ def _setup_safe_globals(self) -> None:
229
241
  """Setup safe global functions and utilities."""
230
242
  # Keep safe standard library functions
231
243
  # (These are already available by default, just documenting them)
@@ -252,7 +264,7 @@ class LuaSandbox:
252
264
  }
253
265
 
254
266
  # Just log what's available - no need to explicitly set
255
- logger.debug(f"Safe Lua functions available: {', '.join(safe_functions)}")
267
+ logger.debug("Safe Lua functions available: %s", ", ".join(safe_functions))
256
268
 
257
269
  # Replace math and os libraries with safe versions if context available
258
270
  if self.execution_context is not None:
@@ -301,7 +313,7 @@ class LuaSandbox:
301
313
  self.lua.globals()["os"] = safe_os
302
314
  logger.debug("Added safe os.date() function")
303
315
 
304
- def setup_assignment_interception(self, callback: Any):
316
+ def setup_assignment_interception(self, callback: Any) -> None:
305
317
  """
306
318
  Setup assignment interception on global scope to capture variable definitions.
307
319
 
@@ -337,11 +349,15 @@ class LuaSandbox:
337
349
  try:
338
350
  self.lua.execute(lua_code)
339
351
  logger.debug("Assignment interception enabled with metatable on _G")
340
- except Exception as e:
341
- logger.error(f"Failed to setup assignment interception: {e}", exc_info=True)
342
- raise LuaSandboxError(f"Could not setup assignment interception: {e}")
352
+ except Exception as exception:
353
+ logger.error(
354
+ "Failed to setup assignment interception: %s",
355
+ exception,
356
+ exc_info=True,
357
+ )
358
+ raise LuaSandboxError(f"Could not setup assignment interception: {exception}")
343
359
 
344
- def set_execution_context(self, context: Any):
360
+ def set_execution_context(self, context: Any) -> None:
345
361
  """
346
362
  Set or update execution context and refresh safe libraries.
347
363
 
@@ -353,7 +369,7 @@ class LuaSandbox:
353
369
  self._setup_safe_globals()
354
370
  logger.debug("ExecutionContext attached to LuaSandbox")
355
371
 
356
- def inject_primitive(self, name: str, primitive_obj: Any):
372
+ def inject_primitive(self, name: str, primitive_obj: Any) -> None:
357
373
  """
358
374
  Inject a Python primitive object into Lua globals.
359
375
 
@@ -362,9 +378,9 @@ class LuaSandbox:
362
378
  primitive_obj: Python object to expose to Lua
363
379
  """
364
380
  self.lua.globals()[name] = primitive_obj
365
- logger.debug(f"Injected primitive '{name}' into Lua sandbox")
381
+ logger.debug("Injected primitive '%s' into Lua sandbox", name)
366
382
 
367
- def set_global(self, name: str, value: Any):
383
+ def set_global(self, name: str, value: Any) -> None:
368
384
  """
369
385
  Set a global variable in Lua.
370
386
 
@@ -372,28 +388,20 @@ class LuaSandbox:
372
388
  name: Name of the global variable
373
389
  value: Value to set (can be Python object, dict, etc.)
374
390
  """
375
- # Convert Python dicts to Lua tables if needed
391
+ self.lua.globals()[name] = self._convert_python_value_to_lua(value)
392
+ logger.debug("Set global '%s' in Lua sandbox", name)
393
+
394
+ def _convert_python_value_to_lua(self, value: Any) -> Any:
395
+ """Convert Python values to Lua-friendly values."""
376
396
  if isinstance(value, dict):
377
- lua_table = self.lua.table()
378
- for k, v in value.items():
379
- if isinstance(v, dict):
380
- # Recursively convert nested dicts
381
- lua_table[k] = self._dict_to_lua_table(v)
382
- else:
383
- lua_table[k] = v
384
- self.lua.globals()[name] = lua_table
385
- else:
386
- self.lua.globals()[name] = value
387
- logger.debug(f"Set global '{name}' in Lua sandbox")
397
+ return self._dict_to_lua_table(value)
398
+ return value
388
399
 
389
- def _dict_to_lua_table(self, d: dict):
400
+ def _dict_to_lua_table(self, python_dict: dict) -> Any:
390
401
  """Convert Python dict to Lua table recursively."""
391
402
  lua_table = self.lua.table()
392
- for k, v in d.items():
393
- if isinstance(v, dict):
394
- lua_table[k] = self._dict_to_lua_table(v)
395
- else:
396
- lua_table[k] = v
403
+ for key, value in python_dict.items():
404
+ lua_table[key] = self._convert_python_value_to_lua(value)
397
405
  return lua_table
398
406
 
399
407
  def execute(self, lua_code: str) -> Any:
@@ -410,21 +418,21 @@ class LuaSandbox:
410
418
  LuaSandboxError: If execution fails
411
419
  """
412
420
  try:
413
- logger.debug(f"Executing Lua code ({len(lua_code)} bytes)")
421
+ logger.debug("Executing Lua code (%s bytes)", len(lua_code))
414
422
  result = self.lua.execute(lua_code)
415
423
  logger.debug("Lua execution completed successfully")
416
424
  return result
417
425
 
418
- except lupa.LuaError as e:
426
+ except lupa.LuaError as exception:
419
427
  # Lua runtime error
420
- error_msg = str(e)
421
- logger.error(f"Lua execution error: {error_msg}")
422
- raise LuaSandboxError(f"Lua runtime error: {error_msg}")
428
+ error_message = str(exception)
429
+ logger.error("Lua execution error: %s", error_message)
430
+ raise LuaSandboxError(f"Lua runtime error: {error_message}")
423
431
 
424
- except Exception as e:
432
+ except Exception as exception:
425
433
  # Other Python exceptions
426
- logger.error(f"Sandbox execution error: {e}")
427
- raise LuaSandboxError(f"Sandbox error: {e}")
434
+ logger.error("Sandbox execution error: %s", exception)
435
+ raise LuaSandboxError(f"Sandbox error: {exception}")
428
436
 
429
437
  def eval(self, lua_expression: str) -> Any:
430
438
  """
@@ -443,16 +451,16 @@ class LuaSandbox:
443
451
  result = self.lua.eval(lua_expression)
444
452
  return result
445
453
 
446
- except lupa.LuaError as e:
447
- error_msg = str(e)
448
- logger.error(f"Lua eval error: {error_msg}")
449
- raise LuaSandboxError(f"Lua eval error: {error_msg}")
454
+ except lupa.LuaError as exception:
455
+ error_message = str(exception)
456
+ logger.error("Lua eval error: %s", error_message)
457
+ raise LuaSandboxError(f"Lua eval error: {error_message}")
450
458
 
451
459
  def get_global(self, name: str) -> Any:
452
460
  """Get a value from Lua global scope."""
453
461
  return self.lua.globals()[name]
454
462
 
455
- def create_lua_table(self, python_dict: Optional[Dict[str, Any]] = None) -> Any:
463
+ def create_lua_table(self, python_dict: Optional[dict[str, Any]] = None) -> Any:
456
464
  """
457
465
  Create a Lua table from a Python dictionary.
458
466
 
@@ -469,11 +477,11 @@ class LuaSandbox:
469
477
  # Create and populate Lua table
470
478
  lua_table = self.lua.table()
471
479
  for key, value in python_dict.items():
472
- lua_table[key] = value
480
+ lua_table[key] = self._convert_python_value_to_lua(value)
473
481
 
474
482
  return lua_table
475
483
 
476
- def lua_table_to_dict(self, lua_table: Any) -> Dict[str, Any]:
484
+ def lua_table_to_dict(self, lua_table: Any) -> dict[str, Any]:
477
485
  """
478
486
  Convert a Lua table to a Python dictionary.
479
487
 
@@ -495,8 +503,8 @@ class LuaSandbox:
495
503
  else:
496
504
  result[key] = value
497
505
 
498
- except Exception as e:
499
- logger.warning(f"Error converting Lua table to dict: {e}")
506
+ except Exception as exception:
507
+ logger.warning("Error converting Lua table to dict: %s", exception)
500
508
  # Fallback: try direct iteration
501
509
  try:
502
510
  for key in lua_table:
@@ -7,6 +7,7 @@ token budgets, message limits, and custom filters.
7
7
  Aligned with pydantic-ai's message_history concept.
8
8
  """
9
9
 
10
+ from datetime import datetime, timezone
10
11
  from typing import Any, Optional
11
12
 
12
13
  try:
@@ -29,6 +30,8 @@ class MessageHistoryManager:
29
30
  """Initialize message history manager."""
30
31
  self.histories: dict[str, list[ModelMessage]] = {}
31
32
  self.shared_history: list[ModelMessage] = []
33
+ self._next_message_id = 1
34
+ self._checkpoints: dict[str, int] = {}
32
35
 
33
36
  def get_history_for_agent(
34
37
  self,
@@ -56,22 +59,24 @@ class MessageHistoryManager:
56
59
 
57
60
  # Determine source
58
61
  if message_history_config.source == "own":
59
- messages = self.histories.get(agent_name, [])
62
+ selected_messages = self.histories.get(agent_name, [])
60
63
  elif message_history_config.source == "shared":
61
- messages = self.shared_history
64
+ selected_messages = self.shared_history
62
65
  else:
63
66
  # Another agent's history
64
- messages = self.histories.get(message_history_config.source, [])
67
+ selected_messages = self.histories.get(message_history_config.source, [])
65
68
 
66
69
  # Apply filter if specified
67
70
  if message_history_config.filter:
68
- messages = self._apply_filter(messages, message_history_config.filter, context)
71
+ selected_messages = self._apply_filter(
72
+ selected_messages, message_history_config.filter, context
73
+ )
69
74
 
70
- return messages
75
+ return selected_messages
71
76
 
72
77
  def add_message(
73
78
  self,
74
- agent_name: str,
79
+ agent_name: Optional[str],
75
80
  message: ModelMessage,
76
81
  also_shared: bool = False,
77
82
  ) -> None:
@@ -83,6 +88,12 @@ class MessageHistoryManager:
83
88
  message: Message to add
84
89
  also_shared: Also add to shared history
85
90
  """
91
+ message = self._ensure_message_metadata(message)
92
+
93
+ if agent_name is None:
94
+ self.shared_history.append(message)
95
+ return
96
+
86
97
  if agent_name not in self.histories:
87
98
  self.histories[agent_name] = []
88
99
 
@@ -102,7 +113,7 @@ class MessageHistoryManager:
102
113
  def _apply_filter(
103
114
  self,
104
115
  messages: list[ModelMessage],
105
- filter_spec: Any,
116
+ filter_specification: Any,
106
117
  context: Optional[Any],
107
118
  ) -> list[ModelMessage]:
108
119
  """
@@ -117,33 +128,41 @@ class MessageHistoryManager:
117
128
  Filtered messages
118
129
  """
119
130
  # If it's a callable (Lua function), call it
120
- if callable(filter_spec):
131
+ if callable(filter_specification):
121
132
  try:
122
- return filter_spec(messages, context)
123
- except Exception as e:
133
+ return filter_specification(messages, context)
134
+ except Exception as exception:
124
135
  # If filter fails, return unfiltered
125
- print(f"Warning: Filter function failed: {e}")
136
+ print(f"Warning: Filter function failed: {exception}")
126
137
  return messages
127
138
 
128
139
  # Otherwise it's a tuple (filter_type, filter_arg)
129
- if not isinstance(filter_spec, tuple) or len(filter_spec) < 2:
140
+ if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
130
141
  return messages
131
142
 
132
- filter_type = filter_spec[0]
133
- filter_arg = filter_spec[1]
134
-
135
- if filter_type == "last_n":
136
- return self._filter_last_n(messages, filter_arg)
137
- elif filter_type == "token_budget":
138
- return self._filter_by_token_budget(messages, filter_arg)
139
- elif filter_type == "by_role":
140
- return self._filter_by_role(messages, filter_arg)
141
- elif filter_type == "compose":
143
+ filter_name = filter_specification[0]
144
+ filter_value = filter_specification[1]
145
+
146
+ if filter_name == "last_n":
147
+ return self._filter_last_n(messages, filter_value)
148
+ elif filter_name == "first_n":
149
+ return self._filter_first_n(messages, filter_value)
150
+ elif filter_name == "token_budget":
151
+ return self._filter_by_token_budget(messages, filter_value)
152
+ elif filter_name == "head_tokens":
153
+ return self._filter_head_tokens(messages, filter_value)
154
+ elif filter_name == "tail_tokens":
155
+ return self._filter_tail_tokens(messages, filter_value)
156
+ elif filter_name == "by_role":
157
+ return self._filter_by_role(messages, filter_value)
158
+ elif filter_name == "system_prefix":
159
+ return self._filter_system_prefix(messages)
160
+ elif filter_name == "compose":
142
161
  # Apply multiple filters in sequence
143
- result = messages
144
- for f in filter_arg:
145
- result = self._apply_filter(result, f, context)
146
- return result
162
+ filtered_messages = messages
163
+ for filter_step in filter_value:
164
+ filtered_messages = self._apply_filter(filtered_messages, filter_step, context)
165
+ return filtered_messages
147
166
  else:
148
167
  # Unknown filter type, return unfiltered
149
168
  return messages
@@ -156,6 +175,14 @@ class MessageHistoryManager:
156
175
  """Keep only the last N messages."""
157
176
  return messages[-n:] if n > 0 else []
158
177
 
178
+ def _filter_first_n(
179
+ self,
180
+ messages: list[ModelMessage],
181
+ n: int,
182
+ ) -> list[ModelMessage]:
183
+ """Keep only the first N messages."""
184
+ return messages[:n] if n > 0 else []
185
+
159
186
  def _filter_by_token_budget(
160
187
  self,
161
188
  messages: list[ModelMessage],
@@ -173,22 +200,52 @@ class MessageHistoryManager:
173
200
  # Rough estimate: 4 chars per token
174
201
  max_chars = max_tokens * 4
175
202
 
176
- result = []
177
- current_chars = 0
203
+ filtered_messages = []
204
+ current_character_count = 0
178
205
 
179
206
  # Work backwards from most recent
180
207
  for message in reversed(messages):
181
208
  # Estimate message size
182
- message_chars = self._estimate_message_chars(message)
209
+ message_character_count = self._estimate_message_chars(message)
183
210
 
184
- if current_chars + message_chars > max_chars:
211
+ if current_character_count + message_character_count > max_chars:
185
212
  # Would exceed budget, stop here
186
213
  break
187
214
 
188
- result.insert(0, message)
189
- current_chars += message_chars
215
+ filtered_messages.insert(0, message)
216
+ current_character_count += message_character_count
190
217
 
191
- return result
218
+ return filtered_messages
219
+
220
+ def _filter_head_tokens(
221
+ self,
222
+ messages: list[ModelMessage],
223
+ max_tokens: int,
224
+ ) -> list[ModelMessage]:
225
+ """Keep earliest messages that fit within the token budget."""
226
+ if max_tokens <= 0:
227
+ return []
228
+
229
+ max_chars = max_tokens * 4
230
+ filtered_messages = []
231
+ current_character_count = 0
232
+
233
+ for message in messages:
234
+ message_character_count = self._estimate_message_chars(message)
235
+ if current_character_count + message_character_count > max_chars:
236
+ break
237
+ filtered_messages.append(message)
238
+ current_character_count += message_character_count
239
+
240
+ return filtered_messages
241
+
242
+ def _filter_tail_tokens(
243
+ self,
244
+ messages: list[ModelMessage],
245
+ max_tokens: int,
246
+ ) -> list[ModelMessage]:
247
+ """Keep latest messages that fit within the token budget."""
248
+ return self._filter_by_token_budget(messages, max_tokens)
192
249
 
193
250
  def _filter_by_role(
194
251
  self,
@@ -198,23 +255,61 @@ class MessageHistoryManager:
198
255
  """Keep only messages with specified role."""
199
256
  return [m for m in messages if self._get_message_role(m) == role]
200
257
 
258
+ def _filter_system_prefix(
259
+ self,
260
+ messages: list[ModelMessage],
261
+ ) -> list[ModelMessage]:
262
+ """Keep only the leading contiguous system messages."""
263
+ system_prefix_messages: list[ModelMessage] = []
264
+ for message in messages:
265
+ if self._get_message_role(message) != "system":
266
+ break
267
+ system_prefix_messages.append(message)
268
+ return system_prefix_messages
269
+
270
+ def _ensure_message_metadata(self, message: ModelMessage) -> ModelMessage:
271
+ """Ensure message has id and created_at metadata when dict-based."""
272
+ if not isinstance(message, dict):
273
+ return message
274
+
275
+ if "id" not in message:
276
+ message["id"] = self._next_message_id
277
+ self._next_message_id += 1
278
+
279
+ if "created_at" not in message:
280
+ message["created_at"] = datetime.now(timezone.utc).isoformat()
281
+
282
+ return message
283
+
284
+ def record_checkpoint(self, name: str, message_id: int) -> None:
285
+ """Record a named checkpoint pointing at a message id."""
286
+ self._checkpoints[name] = message_id
287
+
288
+ def get_checkpoint(self, name: str) -> Optional[int]:
289
+ """Retrieve a checkpoint id by name."""
290
+ return self._checkpoints.get(name)
291
+
292
+ def next_message_id(self) -> int:
293
+ """Return the next message id that will be assigned."""
294
+ return self._next_message_id
295
+
201
296
  def _estimate_message_chars(self, message: ModelMessage) -> int:
202
297
  """Estimate character count of a message."""
203
298
  if isinstance(message, dict):
204
299
  # Dict-based message
205
- content = message.get("content", "")
206
- if isinstance(content, str):
207
- return len(content)
208
- elif isinstance(content, list):
300
+ message_content = message.get("content", "")
301
+ if isinstance(message_content, str):
302
+ return len(message_content)
303
+ elif isinstance(message_content, list):
209
304
  # Multiple content parts
210
- total = 0
211
- for part in content:
305
+ total_character_count = 0
306
+ for part in message_content:
212
307
  if isinstance(part, dict):
213
- total += len(str(part.get("text", "")))
308
+ total_character_count += len(str(part.get("text", "")))
214
309
  else:
215
- total += len(str(part))
216
- return total
217
- return len(str(content))
310
+ total_character_count += len(str(part))
311
+ return total_character_count
312
+ return len(str(message_content))
218
313
  else:
219
314
  # Pydantic AI ModelMessage object
220
315
  try: