tactus 0.34.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +1 -1
- tactus/adapters/broker_log.py +17 -14
- tactus/adapters/channels/__init__.py +17 -15
- tactus/adapters/channels/base.py +16 -7
- tactus/adapters/channels/broker.py +43 -13
- tactus/adapters/channels/cli.py +19 -15
- tactus/adapters/channels/host.py +15 -6
- tactus/adapters/channels/ipc.py +82 -31
- tactus/adapters/channels/sse.py +41 -23
- tactus/adapters/cli_hitl.py +19 -19
- tactus/adapters/cli_log.py +4 -4
- tactus/adapters/control_loop.py +138 -99
- tactus/adapters/cost_collector_log.py +9 -9
- tactus/adapters/file_storage.py +56 -52
- tactus/adapters/http_callback_log.py +23 -13
- tactus/adapters/ide_log.py +17 -9
- tactus/adapters/lua_tools.py +4 -5
- tactus/adapters/mcp.py +16 -19
- tactus/adapters/mcp_manager.py +46 -30
- tactus/adapters/memory.py +9 -9
- tactus/adapters/plugins.py +42 -42
- tactus/broker/client.py +75 -78
- tactus/broker/protocol.py +57 -57
- tactus/broker/server.py +252 -197
- tactus/cli/app.py +3 -1
- tactus/cli/control.py +2 -2
- tactus/core/config_manager.py +181 -135
- tactus/core/dependencies/registry.py +66 -48
- tactus/core/dsl_stubs.py +222 -163
- tactus/core/exceptions.py +10 -1
- tactus/core/execution_context.py +152 -112
- tactus/core/lua_sandbox.py +72 -64
- tactus/core/message_history_manager.py +138 -43
- tactus/core/mocking.py +41 -27
- tactus/core/output_validator.py +49 -44
- tactus/core/registry.py +94 -80
- tactus/core/runtime.py +211 -176
- tactus/core/template_resolver.py +16 -16
- tactus/core/yaml_parser.py +55 -45
- tactus/docs/extractor.py +7 -6
- tactus/ide/server.py +119 -78
- tactus/primitives/control.py +10 -6
- tactus/primitives/file.py +48 -46
- tactus/primitives/handles.py +47 -35
- tactus/primitives/host.py +29 -27
- tactus/primitives/human.py +154 -137
- tactus/primitives/json.py +22 -23
- tactus/primitives/log.py +26 -26
- tactus/primitives/message_history.py +285 -31
- tactus/primitives/model.py +15 -9
- tactus/primitives/procedure.py +86 -64
- tactus/primitives/procedure_callable.py +58 -51
- tactus/primitives/retry.py +31 -29
- tactus/primitives/session.py +42 -29
- tactus/primitives/state.py +54 -43
- tactus/primitives/step.py +9 -13
- tactus/primitives/system.py +34 -21
- tactus/primitives/tool.py +44 -31
- tactus/primitives/tool_handle.py +76 -54
- tactus/primitives/toolset.py +25 -22
- tactus/sandbox/config.py +4 -4
- tactus/sandbox/container_runner.py +161 -107
- tactus/sandbox/docker_manager.py +20 -20
- tactus/sandbox/entrypoint.py +16 -14
- tactus/sandbox/protocol.py +15 -15
- tactus/stdlib/classify/llm.py +1 -3
- tactus/stdlib/core/validation.py +0 -3
- tactus/testing/pydantic_eval_runner.py +1 -1
- tactus/utils/asyncio_helpers.py +27 -0
- tactus/utils/cost_calculator.py +7 -7
- tactus/utils/model_pricing.py +11 -12
- tactus/utils/safe_file_library.py +156 -132
- tactus/utils/safe_libraries.py +27 -27
- tactus/validation/error_listener.py +18 -5
- tactus/validation/semantic_visitor.py +392 -333
- tactus/validation/validator.py +89 -49
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/METADATA +12 -3
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/RECORD +81 -80
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/WHEEL +0 -0
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/entry_points.txt +0 -0
- {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/licenses/LICENSE +0 -0
tactus/core/lua_sandbox.py
CHANGED
|
@@ -12,7 +12,7 @@ Provides a sandboxed Lua runtime with:
|
|
|
12
12
|
|
|
13
13
|
import logging
|
|
14
14
|
import os
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import Any, Optional
|
|
16
16
|
|
|
17
17
|
try:
|
|
18
18
|
import lupa
|
|
@@ -59,10 +59,13 @@ class LuaSandbox:
|
|
|
59
59
|
# Fix base_path at initialization time to prevent security boundary expansion
|
|
60
60
|
# This ensures file I/O libraries and require() always use the same base path,
|
|
61
61
|
# even if the working directory changes later
|
|
62
|
-
self.base_path = base_path
|
|
62
|
+
self.base_path = base_path or os.getcwd()
|
|
63
63
|
|
|
64
64
|
# Create Lua runtime with safety restrictions
|
|
65
|
-
self.lua = LuaRuntime(
|
|
65
|
+
self.lua = LuaRuntime(
|
|
66
|
+
unpack_returned_tuples=True,
|
|
67
|
+
attribute_filter=self._attribute_filter,
|
|
68
|
+
)
|
|
66
69
|
|
|
67
70
|
# Remove dangerous modules
|
|
68
71
|
self._remove_dangerous_modules()
|
|
@@ -75,7 +78,7 @@ class LuaSandbox:
|
|
|
75
78
|
|
|
76
79
|
logger.debug("Lua sandbox initialized successfully")
|
|
77
80
|
|
|
78
|
-
def _attribute_filter(self, obj, attr_name, is_setting):
|
|
81
|
+
def _attribute_filter(self, obj: Any, attr_name: str, is_setting: bool) -> str:
|
|
79
82
|
"""
|
|
80
83
|
Filter attribute access to prevent dangerous operations.
|
|
81
84
|
|
|
@@ -86,7 +89,7 @@ class LuaSandbox:
|
|
|
86
89
|
raise AttributeError(f"Access to private attribute '{attr_name}' is not allowed")
|
|
87
90
|
|
|
88
91
|
# Block access to certain dangerous methods
|
|
89
|
-
|
|
92
|
+
blocked_attributes = {
|
|
90
93
|
"__import__",
|
|
91
94
|
"__loader__",
|
|
92
95
|
"__spec__",
|
|
@@ -98,12 +101,12 @@ class LuaSandbox:
|
|
|
98
101
|
"__subclasses__",
|
|
99
102
|
}
|
|
100
103
|
|
|
101
|
-
if attr_name in
|
|
104
|
+
if attr_name in blocked_attributes:
|
|
102
105
|
raise AttributeError(f"Access to '{attr_name}' is not allowed in sandbox")
|
|
103
106
|
|
|
104
107
|
return attr_name
|
|
105
108
|
|
|
106
|
-
def _remove_dangerous_modules(self):
|
|
109
|
+
def _remove_dangerous_modules(self) -> None:
|
|
107
110
|
"""Remove dangerous Lua standard library modules."""
|
|
108
111
|
# Remove modules that provide file system or system access
|
|
109
112
|
# Note: 'package' and 'require' are kept but restricted in _setup_safe_require()
|
|
@@ -125,17 +128,19 @@ class LuaSandbox:
|
|
|
125
128
|
# Whitelist only safe debug functions for source location tracking
|
|
126
129
|
# Keep debug.getinfo but remove dangerous debug functions
|
|
127
130
|
if "debug" in lua_globals:
|
|
128
|
-
self.lua.execute(
|
|
131
|
+
self.lua.execute(
|
|
132
|
+
"""
|
|
129
133
|
if debug then
|
|
130
134
|
local safe_debug = {
|
|
131
135
|
getinfo = debug.getinfo
|
|
132
136
|
}
|
|
133
137
|
debug = safe_debug
|
|
134
138
|
end
|
|
135
|
-
|
|
139
|
+
"""
|
|
140
|
+
)
|
|
136
141
|
logger.debug("Replaced debug module with safe_debug (only getinfo allowed)")
|
|
137
142
|
|
|
138
|
-
def _setup_safe_require(self):
|
|
143
|
+
def _setup_safe_require(self) -> None:
|
|
139
144
|
"""Configure require/package to search user's project and stdlib.
|
|
140
145
|
|
|
141
146
|
This allows using Lua's require() mechanism while restricting module
|
|
@@ -157,17 +162,22 @@ class LuaSandbox:
|
|
|
157
162
|
# 1. User's project directory (existing behavior)
|
|
158
163
|
# 2. Tactus stdlib .tac files
|
|
159
164
|
# Both single-file modules (?.tac) and directory modules (?/init.tac) are supported
|
|
160
|
-
|
|
165
|
+
user_module_path = os.path.join(self.base_path, "?.tac")
|
|
161
166
|
user_init_path = os.path.join(self.base_path, "?", "init.tac")
|
|
162
|
-
|
|
167
|
+
stdlib_module_path = os.path.join(stdlib_tac_path, "?.tac")
|
|
163
168
|
stdlib_init_path = os.path.join(stdlib_tac_path, "?", "init.tac")
|
|
164
169
|
|
|
165
170
|
# Normalize backslashes for cross-platform compatibility
|
|
166
|
-
|
|
167
|
-
|
|
171
|
+
raw_paths = [
|
|
172
|
+
user_module_path,
|
|
173
|
+
user_init_path,
|
|
174
|
+
stdlib_module_path,
|
|
175
|
+
stdlib_init_path,
|
|
176
|
+
]
|
|
177
|
+
normalized_paths = [path.replace("\\", "/") for path in raw_paths]
|
|
168
178
|
|
|
169
179
|
# Join with Lua's path separator (semicolon)
|
|
170
|
-
safe_path = ";".join(
|
|
180
|
+
safe_path = ";".join(normalized_paths)
|
|
171
181
|
|
|
172
182
|
lua_globals = self.lua.globals()
|
|
173
183
|
package = lua_globals["package"]
|
|
@@ -186,11 +196,11 @@ class LuaSandbox:
|
|
|
186
196
|
# Add Python stdlib loader
|
|
187
197
|
self._setup_python_stdlib_loader()
|
|
188
198
|
|
|
189
|
-
logger.debug(
|
|
199
|
+
logger.debug("Configured safe require with paths: %s", safe_path)
|
|
190
200
|
else:
|
|
191
201
|
logger.warning("package module not available - require will not work")
|
|
192
202
|
|
|
193
|
-
def _setup_python_stdlib_loader(self):
|
|
203
|
+
def _setup_python_stdlib_loader(self) -> None:
|
|
194
204
|
"""Add custom loader for Python stdlib modules."""
|
|
195
205
|
from tactus.stdlib.loader import StdlibModuleLoader
|
|
196
206
|
|
|
@@ -203,7 +213,8 @@ class LuaSandbox:
|
|
|
203
213
|
|
|
204
214
|
# Add to package.loaders (Lua 5.1) or package.searchers (Lua 5.2+)
|
|
205
215
|
# Lupa uses LuaJIT which follows Lua 5.1 conventions
|
|
206
|
-
self.lua.execute(
|
|
216
|
+
self.lua.execute(
|
|
217
|
+
"""
|
|
207
218
|
-- Add Python stdlib loader to package.loaders
|
|
208
219
|
-- Insert after the preload loader but before path loader
|
|
209
220
|
local loaders = package.loaders or package.searchers
|
|
@@ -221,11 +232,12 @@ class LuaSandbox:
|
|
|
221
232
|
-- Insert at position 2 (after preload, before path)
|
|
222
233
|
table.insert(loaders, 2, python_searcher)
|
|
223
234
|
end
|
|
224
|
-
|
|
235
|
+
"""
|
|
236
|
+
)
|
|
225
237
|
|
|
226
238
|
logger.debug("Python stdlib loader installed")
|
|
227
239
|
|
|
228
|
-
def _setup_safe_globals(self):
|
|
240
|
+
def _setup_safe_globals(self) -> None:
|
|
229
241
|
"""Setup safe global functions and utilities."""
|
|
230
242
|
# Keep safe standard library functions
|
|
231
243
|
# (These are already available by default, just documenting them)
|
|
@@ -252,7 +264,7 @@ class LuaSandbox:
|
|
|
252
264
|
}
|
|
253
265
|
|
|
254
266
|
# Just log what's available - no need to explicitly set
|
|
255
|
-
logger.debug(
|
|
267
|
+
logger.debug("Safe Lua functions available: %s", ", ".join(safe_functions))
|
|
256
268
|
|
|
257
269
|
# Replace math and os libraries with safe versions if context available
|
|
258
270
|
if self.execution_context is not None:
|
|
@@ -301,7 +313,7 @@ class LuaSandbox:
|
|
|
301
313
|
self.lua.globals()["os"] = safe_os
|
|
302
314
|
logger.debug("Added safe os.date() function")
|
|
303
315
|
|
|
304
|
-
def setup_assignment_interception(self, callback: Any):
|
|
316
|
+
def setup_assignment_interception(self, callback: Any) -> None:
|
|
305
317
|
"""
|
|
306
318
|
Setup assignment interception on global scope to capture variable definitions.
|
|
307
319
|
|
|
@@ -337,11 +349,15 @@ class LuaSandbox:
|
|
|
337
349
|
try:
|
|
338
350
|
self.lua.execute(lua_code)
|
|
339
351
|
logger.debug("Assignment interception enabled with metatable on _G")
|
|
340
|
-
except Exception as
|
|
341
|
-
logger.error(
|
|
342
|
-
|
|
352
|
+
except Exception as exception:
|
|
353
|
+
logger.error(
|
|
354
|
+
"Failed to setup assignment interception: %s",
|
|
355
|
+
exception,
|
|
356
|
+
exc_info=True,
|
|
357
|
+
)
|
|
358
|
+
raise LuaSandboxError(f"Could not setup assignment interception: {exception}")
|
|
343
359
|
|
|
344
|
-
def set_execution_context(self, context: Any):
|
|
360
|
+
def set_execution_context(self, context: Any) -> None:
|
|
345
361
|
"""
|
|
346
362
|
Set or update execution context and refresh safe libraries.
|
|
347
363
|
|
|
@@ -353,7 +369,7 @@ class LuaSandbox:
|
|
|
353
369
|
self._setup_safe_globals()
|
|
354
370
|
logger.debug("ExecutionContext attached to LuaSandbox")
|
|
355
371
|
|
|
356
|
-
def inject_primitive(self, name: str, primitive_obj: Any):
|
|
372
|
+
def inject_primitive(self, name: str, primitive_obj: Any) -> None:
|
|
357
373
|
"""
|
|
358
374
|
Inject a Python primitive object into Lua globals.
|
|
359
375
|
|
|
@@ -362,9 +378,9 @@ class LuaSandbox:
|
|
|
362
378
|
primitive_obj: Python object to expose to Lua
|
|
363
379
|
"""
|
|
364
380
|
self.lua.globals()[name] = primitive_obj
|
|
365
|
-
logger.debug(
|
|
381
|
+
logger.debug("Injected primitive '%s' into Lua sandbox", name)
|
|
366
382
|
|
|
367
|
-
def set_global(self, name: str, value: Any):
|
|
383
|
+
def set_global(self, name: str, value: Any) -> None:
|
|
368
384
|
"""
|
|
369
385
|
Set a global variable in Lua.
|
|
370
386
|
|
|
@@ -372,28 +388,20 @@ class LuaSandbox:
|
|
|
372
388
|
name: Name of the global variable
|
|
373
389
|
value: Value to set (can be Python object, dict, etc.)
|
|
374
390
|
"""
|
|
375
|
-
|
|
391
|
+
self.lua.globals()[name] = self._convert_python_value_to_lua(value)
|
|
392
|
+
logger.debug("Set global '%s' in Lua sandbox", name)
|
|
393
|
+
|
|
394
|
+
def _convert_python_value_to_lua(self, value: Any) -> Any:
|
|
395
|
+
"""Convert Python values to Lua-friendly values."""
|
|
376
396
|
if isinstance(value, dict):
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
if isinstance(v, dict):
|
|
380
|
-
# Recursively convert nested dicts
|
|
381
|
-
lua_table[k] = self._dict_to_lua_table(v)
|
|
382
|
-
else:
|
|
383
|
-
lua_table[k] = v
|
|
384
|
-
self.lua.globals()[name] = lua_table
|
|
385
|
-
else:
|
|
386
|
-
self.lua.globals()[name] = value
|
|
387
|
-
logger.debug(f"Set global '{name}' in Lua sandbox")
|
|
397
|
+
return self._dict_to_lua_table(value)
|
|
398
|
+
return value
|
|
388
399
|
|
|
389
|
-
def _dict_to_lua_table(self,
|
|
400
|
+
def _dict_to_lua_table(self, python_dict: dict) -> Any:
|
|
390
401
|
"""Convert Python dict to Lua table recursively."""
|
|
391
402
|
lua_table = self.lua.table()
|
|
392
|
-
for
|
|
393
|
-
|
|
394
|
-
lua_table[k] = self._dict_to_lua_table(v)
|
|
395
|
-
else:
|
|
396
|
-
lua_table[k] = v
|
|
403
|
+
for key, value in python_dict.items():
|
|
404
|
+
lua_table[key] = self._convert_python_value_to_lua(value)
|
|
397
405
|
return lua_table
|
|
398
406
|
|
|
399
407
|
def execute(self, lua_code: str) -> Any:
|
|
@@ -410,21 +418,21 @@ class LuaSandbox:
|
|
|
410
418
|
LuaSandboxError: If execution fails
|
|
411
419
|
"""
|
|
412
420
|
try:
|
|
413
|
-
logger.debug(
|
|
421
|
+
logger.debug("Executing Lua code (%s bytes)", len(lua_code))
|
|
414
422
|
result = self.lua.execute(lua_code)
|
|
415
423
|
logger.debug("Lua execution completed successfully")
|
|
416
424
|
return result
|
|
417
425
|
|
|
418
|
-
except lupa.LuaError as
|
|
426
|
+
except lupa.LuaError as exception:
|
|
419
427
|
# Lua runtime error
|
|
420
|
-
|
|
421
|
-
logger.error(
|
|
422
|
-
raise LuaSandboxError(f"Lua runtime error: {
|
|
428
|
+
error_message = str(exception)
|
|
429
|
+
logger.error("Lua execution error: %s", error_message)
|
|
430
|
+
raise LuaSandboxError(f"Lua runtime error: {error_message}")
|
|
423
431
|
|
|
424
|
-
except Exception as
|
|
432
|
+
except Exception as exception:
|
|
425
433
|
# Other Python exceptions
|
|
426
|
-
logger.error(
|
|
427
|
-
raise LuaSandboxError(f"Sandbox error: {
|
|
434
|
+
logger.error("Sandbox execution error: %s", exception)
|
|
435
|
+
raise LuaSandboxError(f"Sandbox error: {exception}")
|
|
428
436
|
|
|
429
437
|
def eval(self, lua_expression: str) -> Any:
|
|
430
438
|
"""
|
|
@@ -443,16 +451,16 @@ class LuaSandbox:
|
|
|
443
451
|
result = self.lua.eval(lua_expression)
|
|
444
452
|
return result
|
|
445
453
|
|
|
446
|
-
except lupa.LuaError as
|
|
447
|
-
|
|
448
|
-
logger.error(
|
|
449
|
-
raise LuaSandboxError(f"Lua eval error: {
|
|
454
|
+
except lupa.LuaError as exception:
|
|
455
|
+
error_message = str(exception)
|
|
456
|
+
logger.error("Lua eval error: %s", error_message)
|
|
457
|
+
raise LuaSandboxError(f"Lua eval error: {error_message}")
|
|
450
458
|
|
|
451
459
|
def get_global(self, name: str) -> Any:
|
|
452
460
|
"""Get a value from Lua global scope."""
|
|
453
461
|
return self.lua.globals()[name]
|
|
454
462
|
|
|
455
|
-
def create_lua_table(self, python_dict: Optional[
|
|
463
|
+
def create_lua_table(self, python_dict: Optional[dict[str, Any]] = None) -> Any:
|
|
456
464
|
"""
|
|
457
465
|
Create a Lua table from a Python dictionary.
|
|
458
466
|
|
|
@@ -469,11 +477,11 @@ class LuaSandbox:
|
|
|
469
477
|
# Create and populate Lua table
|
|
470
478
|
lua_table = self.lua.table()
|
|
471
479
|
for key, value in python_dict.items():
|
|
472
|
-
lua_table[key] = value
|
|
480
|
+
lua_table[key] = self._convert_python_value_to_lua(value)
|
|
473
481
|
|
|
474
482
|
return lua_table
|
|
475
483
|
|
|
476
|
-
def lua_table_to_dict(self, lua_table: Any) ->
|
|
484
|
+
def lua_table_to_dict(self, lua_table: Any) -> dict[str, Any]:
|
|
477
485
|
"""
|
|
478
486
|
Convert a Lua table to a Python dictionary.
|
|
479
487
|
|
|
@@ -495,8 +503,8 @@ class LuaSandbox:
|
|
|
495
503
|
else:
|
|
496
504
|
result[key] = value
|
|
497
505
|
|
|
498
|
-
except Exception as
|
|
499
|
-
logger.warning(
|
|
506
|
+
except Exception as exception:
|
|
507
|
+
logger.warning("Error converting Lua table to dict: %s", exception)
|
|
500
508
|
# Fallback: try direct iteration
|
|
501
509
|
try:
|
|
502
510
|
for key in lua_table:
|
|
@@ -7,6 +7,7 @@ token budgets, message limits, and custom filters.
|
|
|
7
7
|
Aligned with pydantic-ai's message_history concept.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
+
from datetime import datetime, timezone
|
|
10
11
|
from typing import Any, Optional
|
|
11
12
|
|
|
12
13
|
try:
|
|
@@ -29,6 +30,8 @@ class MessageHistoryManager:
|
|
|
29
30
|
"""Initialize message history manager."""
|
|
30
31
|
self.histories: dict[str, list[ModelMessage]] = {}
|
|
31
32
|
self.shared_history: list[ModelMessage] = []
|
|
33
|
+
self._next_message_id = 1
|
|
34
|
+
self._checkpoints: dict[str, int] = {}
|
|
32
35
|
|
|
33
36
|
def get_history_for_agent(
|
|
34
37
|
self,
|
|
@@ -56,22 +59,24 @@ class MessageHistoryManager:
|
|
|
56
59
|
|
|
57
60
|
# Determine source
|
|
58
61
|
if message_history_config.source == "own":
|
|
59
|
-
|
|
62
|
+
selected_messages = self.histories.get(agent_name, [])
|
|
60
63
|
elif message_history_config.source == "shared":
|
|
61
|
-
|
|
64
|
+
selected_messages = self.shared_history
|
|
62
65
|
else:
|
|
63
66
|
# Another agent's history
|
|
64
|
-
|
|
67
|
+
selected_messages = self.histories.get(message_history_config.source, [])
|
|
65
68
|
|
|
66
69
|
# Apply filter if specified
|
|
67
70
|
if message_history_config.filter:
|
|
68
|
-
|
|
71
|
+
selected_messages = self._apply_filter(
|
|
72
|
+
selected_messages, message_history_config.filter, context
|
|
73
|
+
)
|
|
69
74
|
|
|
70
|
-
return
|
|
75
|
+
return selected_messages
|
|
71
76
|
|
|
72
77
|
def add_message(
|
|
73
78
|
self,
|
|
74
|
-
agent_name: str,
|
|
79
|
+
agent_name: Optional[str],
|
|
75
80
|
message: ModelMessage,
|
|
76
81
|
also_shared: bool = False,
|
|
77
82
|
) -> None:
|
|
@@ -83,6 +88,12 @@ class MessageHistoryManager:
|
|
|
83
88
|
message: Message to add
|
|
84
89
|
also_shared: Also add to shared history
|
|
85
90
|
"""
|
|
91
|
+
message = self._ensure_message_metadata(message)
|
|
92
|
+
|
|
93
|
+
if agent_name is None:
|
|
94
|
+
self.shared_history.append(message)
|
|
95
|
+
return
|
|
96
|
+
|
|
86
97
|
if agent_name not in self.histories:
|
|
87
98
|
self.histories[agent_name] = []
|
|
88
99
|
|
|
@@ -102,7 +113,7 @@ class MessageHistoryManager:
|
|
|
102
113
|
def _apply_filter(
|
|
103
114
|
self,
|
|
104
115
|
messages: list[ModelMessage],
|
|
105
|
-
|
|
116
|
+
filter_specification: Any,
|
|
106
117
|
context: Optional[Any],
|
|
107
118
|
) -> list[ModelMessage]:
|
|
108
119
|
"""
|
|
@@ -117,33 +128,41 @@ class MessageHistoryManager:
|
|
|
117
128
|
Filtered messages
|
|
118
129
|
"""
|
|
119
130
|
# If it's a callable (Lua function), call it
|
|
120
|
-
if callable(
|
|
131
|
+
if callable(filter_specification):
|
|
121
132
|
try:
|
|
122
|
-
return
|
|
123
|
-
except Exception as
|
|
133
|
+
return filter_specification(messages, context)
|
|
134
|
+
except Exception as exception:
|
|
124
135
|
# If filter fails, return unfiltered
|
|
125
|
-
print(f"Warning: Filter function failed: {
|
|
136
|
+
print(f"Warning: Filter function failed: {exception}")
|
|
126
137
|
return messages
|
|
127
138
|
|
|
128
139
|
# Otherwise it's a tuple (filter_type, filter_arg)
|
|
129
|
-
if not isinstance(
|
|
140
|
+
if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
|
|
130
141
|
return messages
|
|
131
142
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
if
|
|
136
|
-
return self._filter_last_n(messages,
|
|
137
|
-
elif
|
|
138
|
-
return self.
|
|
139
|
-
elif
|
|
140
|
-
return self.
|
|
141
|
-
elif
|
|
143
|
+
filter_name = filter_specification[0]
|
|
144
|
+
filter_value = filter_specification[1]
|
|
145
|
+
|
|
146
|
+
if filter_name == "last_n":
|
|
147
|
+
return self._filter_last_n(messages, filter_value)
|
|
148
|
+
elif filter_name == "first_n":
|
|
149
|
+
return self._filter_first_n(messages, filter_value)
|
|
150
|
+
elif filter_name == "token_budget":
|
|
151
|
+
return self._filter_by_token_budget(messages, filter_value)
|
|
152
|
+
elif filter_name == "head_tokens":
|
|
153
|
+
return self._filter_head_tokens(messages, filter_value)
|
|
154
|
+
elif filter_name == "tail_tokens":
|
|
155
|
+
return self._filter_tail_tokens(messages, filter_value)
|
|
156
|
+
elif filter_name == "by_role":
|
|
157
|
+
return self._filter_by_role(messages, filter_value)
|
|
158
|
+
elif filter_name == "system_prefix":
|
|
159
|
+
return self._filter_system_prefix(messages)
|
|
160
|
+
elif filter_name == "compose":
|
|
142
161
|
# Apply multiple filters in sequence
|
|
143
|
-
|
|
144
|
-
for
|
|
145
|
-
|
|
146
|
-
return
|
|
162
|
+
filtered_messages = messages
|
|
163
|
+
for filter_step in filter_value:
|
|
164
|
+
filtered_messages = self._apply_filter(filtered_messages, filter_step, context)
|
|
165
|
+
return filtered_messages
|
|
147
166
|
else:
|
|
148
167
|
# Unknown filter type, return unfiltered
|
|
149
168
|
return messages
|
|
@@ -156,6 +175,14 @@ class MessageHistoryManager:
|
|
|
156
175
|
"""Keep only the last N messages."""
|
|
157
176
|
return messages[-n:] if n > 0 else []
|
|
158
177
|
|
|
178
|
+
def _filter_first_n(
|
|
179
|
+
self,
|
|
180
|
+
messages: list[ModelMessage],
|
|
181
|
+
n: int,
|
|
182
|
+
) -> list[ModelMessage]:
|
|
183
|
+
"""Keep only the first N messages."""
|
|
184
|
+
return messages[:n] if n > 0 else []
|
|
185
|
+
|
|
159
186
|
def _filter_by_token_budget(
|
|
160
187
|
self,
|
|
161
188
|
messages: list[ModelMessage],
|
|
@@ -173,22 +200,52 @@ class MessageHistoryManager:
|
|
|
173
200
|
# Rough estimate: 4 chars per token
|
|
174
201
|
max_chars = max_tokens * 4
|
|
175
202
|
|
|
176
|
-
|
|
177
|
-
|
|
203
|
+
filtered_messages = []
|
|
204
|
+
current_character_count = 0
|
|
178
205
|
|
|
179
206
|
# Work backwards from most recent
|
|
180
207
|
for message in reversed(messages):
|
|
181
208
|
# Estimate message size
|
|
182
|
-
|
|
209
|
+
message_character_count = self._estimate_message_chars(message)
|
|
183
210
|
|
|
184
|
-
if
|
|
211
|
+
if current_character_count + message_character_count > max_chars:
|
|
185
212
|
# Would exceed budget, stop here
|
|
186
213
|
break
|
|
187
214
|
|
|
188
|
-
|
|
189
|
-
|
|
215
|
+
filtered_messages.insert(0, message)
|
|
216
|
+
current_character_count += message_character_count
|
|
190
217
|
|
|
191
|
-
return
|
|
218
|
+
return filtered_messages
|
|
219
|
+
|
|
220
|
+
def _filter_head_tokens(
|
|
221
|
+
self,
|
|
222
|
+
messages: list[ModelMessage],
|
|
223
|
+
max_tokens: int,
|
|
224
|
+
) -> list[ModelMessage]:
|
|
225
|
+
"""Keep earliest messages that fit within the token budget."""
|
|
226
|
+
if max_tokens <= 0:
|
|
227
|
+
return []
|
|
228
|
+
|
|
229
|
+
max_chars = max_tokens * 4
|
|
230
|
+
filtered_messages = []
|
|
231
|
+
current_character_count = 0
|
|
232
|
+
|
|
233
|
+
for message in messages:
|
|
234
|
+
message_character_count = self._estimate_message_chars(message)
|
|
235
|
+
if current_character_count + message_character_count > max_chars:
|
|
236
|
+
break
|
|
237
|
+
filtered_messages.append(message)
|
|
238
|
+
current_character_count += message_character_count
|
|
239
|
+
|
|
240
|
+
return filtered_messages
|
|
241
|
+
|
|
242
|
+
def _filter_tail_tokens(
|
|
243
|
+
self,
|
|
244
|
+
messages: list[ModelMessage],
|
|
245
|
+
max_tokens: int,
|
|
246
|
+
) -> list[ModelMessage]:
|
|
247
|
+
"""Keep latest messages that fit within the token budget."""
|
|
248
|
+
return self._filter_by_token_budget(messages, max_tokens)
|
|
192
249
|
|
|
193
250
|
def _filter_by_role(
|
|
194
251
|
self,
|
|
@@ -198,23 +255,61 @@ class MessageHistoryManager:
|
|
|
198
255
|
"""Keep only messages with specified role."""
|
|
199
256
|
return [m for m in messages if self._get_message_role(m) == role]
|
|
200
257
|
|
|
258
|
+
def _filter_system_prefix(
|
|
259
|
+
self,
|
|
260
|
+
messages: list[ModelMessage],
|
|
261
|
+
) -> list[ModelMessage]:
|
|
262
|
+
"""Keep only the leading contiguous system messages."""
|
|
263
|
+
system_prefix_messages: list[ModelMessage] = []
|
|
264
|
+
for message in messages:
|
|
265
|
+
if self._get_message_role(message) != "system":
|
|
266
|
+
break
|
|
267
|
+
system_prefix_messages.append(message)
|
|
268
|
+
return system_prefix_messages
|
|
269
|
+
|
|
270
|
+
def _ensure_message_metadata(self, message: ModelMessage) -> ModelMessage:
|
|
271
|
+
"""Ensure message has id and created_at metadata when dict-based."""
|
|
272
|
+
if not isinstance(message, dict):
|
|
273
|
+
return message
|
|
274
|
+
|
|
275
|
+
if "id" not in message:
|
|
276
|
+
message["id"] = self._next_message_id
|
|
277
|
+
self._next_message_id += 1
|
|
278
|
+
|
|
279
|
+
if "created_at" not in message:
|
|
280
|
+
message["created_at"] = datetime.now(timezone.utc).isoformat()
|
|
281
|
+
|
|
282
|
+
return message
|
|
283
|
+
|
|
284
|
+
def record_checkpoint(self, name: str, message_id: int) -> None:
|
|
285
|
+
"""Record a named checkpoint pointing at a message id."""
|
|
286
|
+
self._checkpoints[name] = message_id
|
|
287
|
+
|
|
288
|
+
def get_checkpoint(self, name: str) -> Optional[int]:
|
|
289
|
+
"""Retrieve a checkpoint id by name."""
|
|
290
|
+
return self._checkpoints.get(name)
|
|
291
|
+
|
|
292
|
+
def next_message_id(self) -> int:
|
|
293
|
+
"""Return the next message id that will be assigned."""
|
|
294
|
+
return self._next_message_id
|
|
295
|
+
|
|
201
296
|
def _estimate_message_chars(self, message: ModelMessage) -> int:
|
|
202
297
|
"""Estimate character count of a message."""
|
|
203
298
|
if isinstance(message, dict):
|
|
204
299
|
# Dict-based message
|
|
205
|
-
|
|
206
|
-
if isinstance(
|
|
207
|
-
return len(
|
|
208
|
-
elif isinstance(
|
|
300
|
+
message_content = message.get("content", "")
|
|
301
|
+
if isinstance(message_content, str):
|
|
302
|
+
return len(message_content)
|
|
303
|
+
elif isinstance(message_content, list):
|
|
209
304
|
# Multiple content parts
|
|
210
|
-
|
|
211
|
-
for part in
|
|
305
|
+
total_character_count = 0
|
|
306
|
+
for part in message_content:
|
|
212
307
|
if isinstance(part, dict):
|
|
213
|
-
|
|
308
|
+
total_character_count += len(str(part.get("text", "")))
|
|
214
309
|
else:
|
|
215
|
-
|
|
216
|
-
return
|
|
217
|
-
return len(str(
|
|
310
|
+
total_character_count += len(str(part))
|
|
311
|
+
return total_character_count
|
|
312
|
+
return len(str(message_content))
|
|
218
313
|
else:
|
|
219
314
|
# Pydantic AI ModelMessage object
|
|
220
315
|
try:
|