tactus 0.31.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +49 -0
- tactus/adapters/__init__.py +9 -0
- tactus/adapters/broker_log.py +76 -0
- tactus/adapters/cli_hitl.py +189 -0
- tactus/adapters/cli_log.py +223 -0
- tactus/adapters/cost_collector_log.py +56 -0
- tactus/adapters/file_storage.py +367 -0
- tactus/adapters/http_callback_log.py +109 -0
- tactus/adapters/ide_log.py +71 -0
- tactus/adapters/lua_tools.py +336 -0
- tactus/adapters/mcp.py +289 -0
- tactus/adapters/mcp_manager.py +196 -0
- tactus/adapters/memory.py +53 -0
- tactus/adapters/plugins.py +419 -0
- tactus/backends/http_backend.py +58 -0
- tactus/backends/model_backend.py +35 -0
- tactus/backends/pytorch_backend.py +110 -0
- tactus/broker/__init__.py +12 -0
- tactus/broker/client.py +247 -0
- tactus/broker/protocol.py +183 -0
- tactus/broker/server.py +1123 -0
- tactus/broker/stdio.py +12 -0
- tactus/cli/__init__.py +7 -0
- tactus/cli/app.py +2245 -0
- tactus/cli/commands/__init__.py +0 -0
- tactus/core/__init__.py +32 -0
- tactus/core/config_manager.py +790 -0
- tactus/core/dependencies/__init__.py +14 -0
- tactus/core/dependencies/registry.py +180 -0
- tactus/core/dsl_stubs.py +2117 -0
- tactus/core/exceptions.py +66 -0
- tactus/core/execution_context.py +480 -0
- tactus/core/lua_sandbox.py +508 -0
- tactus/core/message_history_manager.py +236 -0
- tactus/core/mocking.py +286 -0
- tactus/core/output_validator.py +291 -0
- tactus/core/registry.py +499 -0
- tactus/core/runtime.py +2907 -0
- tactus/core/template_resolver.py +142 -0
- tactus/core/yaml_parser.py +301 -0
- tactus/docker/Dockerfile +61 -0
- tactus/docker/entrypoint.sh +69 -0
- tactus/dspy/__init__.py +39 -0
- tactus/dspy/agent.py +1144 -0
- tactus/dspy/broker_lm.py +181 -0
- tactus/dspy/config.py +212 -0
- tactus/dspy/history.py +196 -0
- tactus/dspy/module.py +405 -0
- tactus/dspy/prediction.py +318 -0
- tactus/dspy/signature.py +185 -0
- tactus/formatting/__init__.py +7 -0
- tactus/formatting/formatter.py +437 -0
- tactus/ide/__init__.py +9 -0
- tactus/ide/coding_assistant.py +343 -0
- tactus/ide/server.py +2223 -0
- tactus/primitives/__init__.py +49 -0
- tactus/primitives/control.py +168 -0
- tactus/primitives/file.py +229 -0
- tactus/primitives/handles.py +378 -0
- tactus/primitives/host.py +94 -0
- tactus/primitives/human.py +342 -0
- tactus/primitives/json.py +189 -0
- tactus/primitives/log.py +187 -0
- tactus/primitives/message_history.py +157 -0
- tactus/primitives/model.py +163 -0
- tactus/primitives/procedure.py +564 -0
- tactus/primitives/procedure_callable.py +318 -0
- tactus/primitives/retry.py +155 -0
- tactus/primitives/session.py +152 -0
- tactus/primitives/state.py +182 -0
- tactus/primitives/step.py +209 -0
- tactus/primitives/system.py +93 -0
- tactus/primitives/tool.py +375 -0
- tactus/primitives/tool_handle.py +279 -0
- tactus/primitives/toolset.py +229 -0
- tactus/protocols/__init__.py +38 -0
- tactus/protocols/chat_recorder.py +81 -0
- tactus/protocols/config.py +97 -0
- tactus/protocols/cost.py +31 -0
- tactus/protocols/hitl.py +71 -0
- tactus/protocols/log_handler.py +27 -0
- tactus/protocols/models.py +355 -0
- tactus/protocols/result.py +33 -0
- tactus/protocols/storage.py +90 -0
- tactus/providers/__init__.py +13 -0
- tactus/providers/base.py +92 -0
- tactus/providers/bedrock.py +117 -0
- tactus/providers/google.py +105 -0
- tactus/providers/openai.py +98 -0
- tactus/sandbox/__init__.py +63 -0
- tactus/sandbox/config.py +171 -0
- tactus/sandbox/container_runner.py +1099 -0
- tactus/sandbox/docker_manager.py +433 -0
- tactus/sandbox/entrypoint.py +227 -0
- tactus/sandbox/protocol.py +213 -0
- tactus/stdlib/__init__.py +10 -0
- tactus/stdlib/io/__init__.py +13 -0
- tactus/stdlib/io/csv.py +88 -0
- tactus/stdlib/io/excel.py +136 -0
- tactus/stdlib/io/file.py +90 -0
- tactus/stdlib/io/fs.py +154 -0
- tactus/stdlib/io/hdf5.py +121 -0
- tactus/stdlib/io/json.py +109 -0
- tactus/stdlib/io/parquet.py +83 -0
- tactus/stdlib/io/tsv.py +88 -0
- tactus/stdlib/loader.py +274 -0
- tactus/stdlib/tac/tactus/tools/done.tac +33 -0
- tactus/stdlib/tac/tactus/tools/log.tac +50 -0
- tactus/testing/README.md +273 -0
- tactus/testing/__init__.py +61 -0
- tactus/testing/behave_integration.py +380 -0
- tactus/testing/context.py +486 -0
- tactus/testing/eval_models.py +114 -0
- tactus/testing/evaluation_runner.py +222 -0
- tactus/testing/evaluators.py +634 -0
- tactus/testing/events.py +94 -0
- tactus/testing/gherkin_parser.py +134 -0
- tactus/testing/mock_agent.py +315 -0
- tactus/testing/mock_dependencies.py +234 -0
- tactus/testing/mock_hitl.py +171 -0
- tactus/testing/mock_registry.py +168 -0
- tactus/testing/mock_tools.py +133 -0
- tactus/testing/models.py +115 -0
- tactus/testing/pydantic_eval_runner.py +508 -0
- tactus/testing/steps/__init__.py +13 -0
- tactus/testing/steps/builtin.py +902 -0
- tactus/testing/steps/custom.py +69 -0
- tactus/testing/steps/registry.py +68 -0
- tactus/testing/test_runner.py +489 -0
- tactus/tracing/__init__.py +5 -0
- tactus/tracing/trace_manager.py +417 -0
- tactus/utils/__init__.py +1 -0
- tactus/utils/cost_calculator.py +72 -0
- tactus/utils/model_pricing.py +132 -0
- tactus/utils/safe_file_library.py +502 -0
- tactus/utils/safe_libraries.py +234 -0
- tactus/validation/LuaLexerBase.py +66 -0
- tactus/validation/LuaParserBase.py +23 -0
- tactus/validation/README.md +224 -0
- tactus/validation/__init__.py +7 -0
- tactus/validation/error_listener.py +21 -0
- tactus/validation/generated/LuaLexer.interp +231 -0
- tactus/validation/generated/LuaLexer.py +5548 -0
- tactus/validation/generated/LuaLexer.tokens +124 -0
- tactus/validation/generated/LuaLexerBase.py +66 -0
- tactus/validation/generated/LuaParser.interp +173 -0
- tactus/validation/generated/LuaParser.py +6439 -0
- tactus/validation/generated/LuaParser.tokens +124 -0
- tactus/validation/generated/LuaParserBase.py +23 -0
- tactus/validation/generated/LuaParserVisitor.py +118 -0
- tactus/validation/generated/__init__.py +7 -0
- tactus/validation/grammar/LuaLexer.g4 +123 -0
- tactus/validation/grammar/LuaParser.g4 +178 -0
- tactus/validation/semantic_visitor.py +817 -0
- tactus/validation/validator.py +157 -0
- tactus-0.31.2.dist-info/METADATA +1809 -0
- tactus-0.31.2.dist-info/RECORD +160 -0
- tactus-0.31.2.dist-info/WHEEL +4 -0
- tactus-0.31.2.dist-info/entry_points.txt +2 -0
- tactus-0.31.2.dist-info/licenses/LICENSE +21 -0
tactus/primitives/log.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Log Primitive - Logging operations.
|
|
3
|
+
|
|
4
|
+
Provides:
|
|
5
|
+
- Log.debug(message, context={}) - Debug logging
|
|
6
|
+
- Log.info(message, context={}) - Info logging
|
|
7
|
+
- Log.warn(message, context={}) - Warning logging
|
|
8
|
+
- Log.error(message, context={}) - Error logging
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from typing import Any, Dict, Optional, TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from tactus.protocols.log_handler import LogHandler
|
|
16
|
+
|
|
17
|
+
# Module-level logger. LogPrimitive itself writes through per-procedure loggers
# named "procedure.<procedure_id>" rather than through this one.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LogPrimitive:
    """
    Provides logging operations for procedures.

    Every level method routes through ``_log``. When a ``LogHandler`` is
    configured, a structured ``LogEvent`` is sent to it. Otherwise the message
    (with any context rendered as JSON) is written to a Python logger named
    ``procedure.<procedure_id>``. The two outputs are mutually exclusive.
    """

    def __init__(self, procedure_id: str, log_handler: Optional["LogHandler"] = None):
        """
        Initialize Log primitive.

        Args:
            procedure_id: ID of the running procedure (for context)
            log_handler: Optional handler for structured log events. When set,
                events are sent only to it and Python logging is bypassed.
        """
        self.procedure_id = procedure_id
        self.logger = logging.getLogger(f"procedure.{procedure_id}")
        self.log_handler = log_handler

    def _normalize_context(self, context: Any) -> Any:
        """
        Convert ``context`` to plain Python data, or return None if it is absent or empty.

        Conversion happens *before* the emptiness test, so the decision is made
        on ordinary Python dicts/lists rather than on the incoming object.
        NOTE(review): this guards against Lua table proxies whose truthiness may
        not reflect their key count (e.g. a hash-only table reporting a ``#``
        length of 0). Confirm against the Lua bridge in use.
        """
        if context is None:
            return None
        converted = self._lua_to_python(context)
        return converted if converted else None

    def _format_message(self, message: str, context: Optional[Dict[str, Any]] = None) -> str:
        """
        Format a log message for Python logging, appending context as indented JSON.

        Args:
            message: Log message text.
            context: Optional context (Python dict or Lua table).

        Returns:
            ``message`` unchanged when there is no non-empty context. Otherwise,
            ``message`` followed by a newline, ``"Context: "`` and the JSON.
            Values JSON cannot encode are rendered with ``str()``. If the
            structure still cannot be encoded (e.g. unsupported key types), its
            ``repr()`` is used, so formatting never fails on odd context values.
        """
        context_data = self._normalize_context(context)
        if context_data is None:
            return message

        import json

        try:
            context_str = json.dumps(context_data, indent=2, default=str)
        except (TypeError, ValueError):
            # `default` only applies to values; e.g. dict keys JSON cannot encode still raise.
            context_str = repr(context_data)
        return f"{message}\nContext: {context_str}"

    def _lua_to_python(self, obj: Any) -> Any:
        """
        Recursively convert Lua tables (and Python containers) to plain dicts/lists.

        Objects exposing ``items()`` become dicts, with keys and values converted.
        Other iterables except ``str``/``bytes`` become lists. Anything else, or
        an iterable whose iteration fails, is returned unchanged.
        """
        if hasattr(obj, "items"):  # Lua table / mapping with a dict-like interface
            return {self._lua_to_python(k): self._lua_to_python(v) for k, v in obj.items()}
        if hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)):  # array-like
            try:
                return [self._lua_to_python(v) for v in obj]
            except Exception:
                # If iteration fails, return as-is
                return obj
        return obj

    def _log(self, level_name: str, level: int, message: str, context: Any) -> None:
        """
        Route one log call to the structured handler or to Python logging.

        Args:
            level_name: Level string stored on the LogEvent ("DEBUG", "INFO", ...).
            level: Matching ``logging`` level used on the fallback path.
            message: Log message text.
            context: Optional context (Python dict or Lua table).
        """
        if self.log_handler:
            # Imported lazily, as in the original per-level methods.
            from tactus.protocols.models import LogEvent

            event = LogEvent(
                level=level_name,
                message=message,
                context=self._normalize_context(context),
                logger_name=self.logger.name,
                procedure_id=self.procedure_id,
            )
            self.log_handler.log(event)
        else:
            # Fall back to Python logging if no handler
            self.logger.log(level, self._format_message(message, context))

    def debug(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
        """
        Log debug message.

        Args:
            message: Debug message
            context: Optional context dict

        Example (Lua):
            Log.debug("Processing item", {index = i, item = item})
        """
        self._log("DEBUG", logging.DEBUG, message, context)

    def info(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
        """
        Log info message.

        Args:
            message: Info message
            context: Optional context dict

        Example (Lua):
            Log.info("Phase complete", {duration = elapsed, items = count})
        """
        self._log("INFO", logging.INFO, message, context)

    def warn(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
        """
        Log warning message.

        Args:
            message: Warning message
            context: Optional context dict

        Example (Lua):
            Log.warn("Retry limit reached", {attempts = attempts})
        """
        self._log("WARNING", logging.WARNING, message, context)

    def warning(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
        """Alias for warn(), matching common logging APIs."""
        self.warn(message, context)

    def error(self, message: str, context: Optional[Dict[str, Any]] = None) -> None:
        """
        Log error message.

        Args:
            message: Error message
            context: Optional context dict

        Example (Lua):
            Log.error("Operation failed", {error = last_error})
        """
        self._log("ERROR", logging.ERROR, message, context)

    def __repr__(self) -> str:
        return f"LogPrimitive(procedure_id={self.procedure_id})"
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MessageHistory primitive for managing conversation history.
|
|
3
|
+
|
|
4
|
+
Provides Lua-accessible methods for manipulating message history,
|
|
5
|
+
aligned with pydantic-ai's message_history concept.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Optional
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart
|
|
12
|
+
except ImportError:
|
|
13
|
+
# Fallback types if pydantic_ai not available
|
|
14
|
+
ModelMessage = dict
|
|
15
|
+
ModelRequest = dict
|
|
16
|
+
ModelResponse = dict
|
|
17
|
+
TextPart = dict
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MessageHistoryPrimitive:
    """
    Primitive for managing conversation message history.

    Aligned with pydantic-ai's message_history concept.

    Provides methods to:
    - Append messages to history
    - Inject system messages
    - Clear history
    - Access full history
    - Save/load message history state (placeholders; not yet implemented)

    When either the message history manager or the agent name is missing,
    mutating methods do nothing and ``get`` returns an empty list.
    """

    # Chat role for each text-bearing pydantic-ai part kind, used by get().
    # NOTE(review): keys follow pydantic-ai's ``part_kind`` literals. Confirm
    # against the installed pydantic-ai version.
    _PART_KIND_ROLES = {
        "system-prompt": "system",
        "user-prompt": "user",
        "retry-prompt": "user",
        "tool-return": "tool",
        "text": "assistant",
    }

    # Fallback chat role keyed by a pydantic-ai message's ``kind``.
    _MESSAGE_KIND_ROLES = {"request": "user", "response": "assistant"}

    def __init__(self, message_history_manager=None, agent_name: Optional[str] = None):
        """
        Initialize MessageHistory primitive.

        Args:
            message_history_manager: MessageHistoryManager instance
            agent_name: Name of the agent this message history belongs to
        """
        self.message_history_manager = message_history_manager
        self.agent_name = agent_name

    @staticmethod
    def _field(data: Any, key: str, default: Any) -> Any:
        """
        Read ``key`` from a message payload.

        Uses ``data.get(key, default)`` when ``data`` has a callable ``get``, as
        for Python dicts. Otherwise falls back to item access, treating a missing
        key or a ``None`` value as absent. NOTE(review): the fallback targets
        Lua table proxies, which presumably expose fields via ``table[key]``
        rather than ``.get``. Confirm against the Lua bridge.
        """
        getter = getattr(data, "get", None)
        if callable(getter):
            return getter(key, default)
        try:
            value = data[key]
        except (KeyError, IndexError, TypeError):
            return default
        return default if value is None else value

    def append(self, message_data: dict) -> None:
        """
        Append a message to the message history.

        Args:
            message_data: Dict (or table) with 'role' and 'content' keys
                role: 'user', 'assistant', 'system' (defaults to 'user')
                content: message text (defaults to '')

        Raises:
            TypeError: If ``message_data`` is None while a manager and agent name are set.

        Example:
            MessageHistory.append({role = "user", content = "Hello"})
        """
        if not self.message_history_manager or not self.agent_name:
            return
        if message_data is None:
            raise TypeError("MessageHistory.append expects a table with 'role' and 'content' fields")

        message = {
            "role": self._field(message_data, "role", "user"),
            "content": self._field(message_data, "content", ""),
        }
        self.message_history_manager.add_message(self.agent_name, message)

    def inject_system(self, text: str) -> None:
        """
        Inject a system message into the message history.

        This is useful for providing context or instructions
        for the next agent turn.

        Args:
            text: System message content

        Example:
            MessageHistory.inject_system("Focus on security implications")
        """
        self.append({"role": "system", "content": text})

    def clear(self) -> None:
        """
        Clear the message history for this agent.

        Example:
            MessageHistory.clear()
        """
        if not self.message_history_manager or not self.agent_name:
            return

        self.message_history_manager.clear_agent_history(self.agent_name)

    def get(self) -> list:
        """
        Get the full message history for this agent.

        Aligned with pydantic-ai's message_history concept.

        Returns:
            List of message dicts with 'role' and 'content' keys, one per stored message.

        Example:
            local messages = MessageHistory.get()
            for i, msg in ipairs(messages) do
                Log.info(msg.role .. ": " .. msg.content)
            end
        """
        if not self.message_history_manager or not self.agent_name:
            return []

        messages = self.message_history_manager.histories.get(self.agent_name, [])
        return [self._to_lua_message(msg) for msg in messages]

    @classmethod
    def _to_lua_message(cls, msg: Any) -> dict:
        """
        Convert one stored message to a Lua-friendly ``{"role", "content"}`` dict.

        Handles three shapes: plain dicts; objects with a ``role`` attribute; and
        parts-based messages such as pydantic-ai ``ModelRequest``/``ModelResponse``,
        which have ``parts`` but no ``role``/``content``. Non-dict messages that
        fail to convert become ``{"role": "unknown", "content": str(msg)}``.
        """
        if isinstance(msg, dict):
            return {"role": msg.get("role", ""), "content": str(msg.get("content", ""))}
        try:
            if not hasattr(msg, "role") and hasattr(msg, "parts"):
                return cls._from_parts(msg)
            return {
                "role": getattr(msg, "role", ""),
                "content": str(getattr(msg, "content", "")),
            }
        except Exception:
            # Fallback: convert to string
            return {"role": "unknown", "content": str(msg)}

    @classmethod
    def _from_parts(cls, msg: Any) -> dict:
        """
        Flatten a parts-based message into a single role/content dict.

        The content is the newline-joined ``content`` of every part whose
        content is a string. Parts without string content (e.g. tool calls) are
        skipped. The role is the shared role of those parts when they all map to
        the same one. Otherwise it is derived from the message's ``kind``.
        """
        texts = []
        part_roles = set()
        for part in msg.parts:
            content = getattr(part, "content", None)
            if not isinstance(content, str):
                continue
            texts.append(content)
            part_roles.add(cls._PART_KIND_ROLES.get(getattr(part, "part_kind", None)))
        if len(part_roles) == 1 and None not in part_roles:
            (role,) = part_roles
        else:
            role = cls._MESSAGE_KIND_ROLES.get(getattr(msg, "kind", None), "")
        return {"role": role, "content": "\n".join(texts)}

    def load_from_node(self, node: Any) -> None:
        """
        Load message history from a graph node.

        Not yet implemented - placeholder for future graph support.

        Args:
            node: Graph node containing saved message history
        """
        # TODO: Implement when graph primitives are added
        pass

    def save_to_node(self, node: Any) -> None:
        """
        Save message history to a graph node.

        Not yet implemented - placeholder for future graph support.

        Args:
            node: Graph node to save message history to
        """
        # TODO: Implement when graph primitives are added
        pass
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model primitive for ML inference with automatic checkpointing.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
from tactus.core.execution_context import ExecutionContext
|
|
9
|
+
|
|
10
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ModelPrimitive:
    """
    Model primitive for ML inference operations.

    Unlike agents (conversational LLMs), models handle:
    - Classification (sentiment, intent, NER)
    - Extraction (quotes, entities, facts)
    - Embeddings (semantic search, clustering)
    - Custom ML inference (any trained model)

    When an execution context is present, each .predict() call is run through
    ``context.checkpoint`` for durability. Configured mocks are honoured with
    or without a context.
    """

    def __init__(
        self,
        model_name: str,
        config: dict,
        context: Optional["ExecutionContext"] = None,
        mock_manager: Optional[Any] = None,
    ):
        """
        Initialize model primitive.

        Args:
            model_name: Name of the model (for checkpointing and mock lookup)
            config: Model configuration dict with:
                - type: Backend type ("http" or "pytorch")
                - input: Optional input schema
                - output: Optional output schema
                - Backend-specific config (endpoint, path, etc.)
            context: Execution context for checkpointing
            mock_manager: Optional mock manager consulted before the backend

        Raises:
            ValueError: If ``config["type"]`` is not a supported backend type.
            KeyError: If a required backend key (``endpoint``/``path``) is missing.
        """
        self.model_name = model_name
        self.config = config
        self.context = context
        self.mock_manager = mock_manager

        # Extract optional input/output schemas
        self.input_schema = config.get("input", {})
        self.output_schema = config.get("output", {})

        self.backend = self._create_backend(config)

    def _create_backend(self, config: dict):
        """
        Create appropriate backend based on model type.

        Args:
            config: Model configuration

        Returns:
            Backend instance

        Raises:
            ValueError: For an unsupported ``type``.
            KeyError: If ``endpoint`` (http) or ``path`` (pytorch) is missing.
        """
        model_type = config.get("type")

        if model_type == "http":
            from tactus.backends.http_backend import HTTPModelBackend

            return HTTPModelBackend(
                endpoint=config["endpoint"],
                timeout=config.get("timeout", 30.0),
                headers=config.get("headers"),
            )

        elif model_type == "pytorch":
            from tactus.backends.pytorch_backend import PyTorchModelBackend

            return PyTorchModelBackend(
                path=config["path"],
                device=config.get("device", "cpu"),
                labels=config.get("labels"),
            )

        else:
            raise ValueError(f"Unknown model type: {model_type}. Supported types: http, pytorch")

    def predict(self, input_data: Any) -> Any:
        """
        Run model inference, checkpointed when an execution context is present.

        Args:
            input_data: Input to the model (format depends on backend)

        Returns:
            Model prediction result (a mock response when one is configured)
        """
        if self.context is None:
            # No context: run without checkpointing, but still honour mocks.
            return self._execute_predict(input_data)

        # Capture source location of the caller for the checkpoint record.
        import inspect

        frame = inspect.currentframe()
        try:
            caller = frame.f_back if frame is not None else None
            # When invoked as model(...), skip the __call__ wrapper frame so the
            # recorded location is the user's call site, not this module.
            if caller is not None and caller.f_code is ModelPrimitive.__call__.__code__:
                caller = caller.f_back
            if caller is not None:
                source_info = {
                    "file": caller.f_code.co_filename,
                    "line": caller.f_lineno,
                    "function": caller.f_code.co_name,
                }
            else:
                source_info = None
        finally:
            # Break the frame -> local -> frame reference cycle (see inspect docs).
            del frame

        return self.context.checkpoint(
            fn=lambda: self._execute_predict(input_data),
            checkpoint_type="model_predict",
            source_info=source_info,
        )

    def _execute_predict(self, input_data: Any) -> Any:
        """
        Execute the actual prediction, preferring a configured mock response.

        Args:
            input_data: Input to the model. Non-dict input is wrapped as
                ``{"input": input_data}`` for mock lookup.

        Returns:
            Mock response if the mock manager returns a non-None value,
            otherwise the backend's prediction result.
        """
        if self.mock_manager is not None:
            args = input_data if isinstance(input_data, dict) else {"input": input_data}
            mock_result = self.mock_manager.get_mock_response(self.model_name, args)
            if mock_result is not None:
                # Ensure temporal mocks advance and calls are available for assertions.
                try:
                    self.mock_manager.record_call(self.model_name, args, mock_result)
                except Exception:
                    # Best-effort: recording must not mask the mocked result, but
                    # leave a trace so missing call assertions are diagnosable.
                    logger.debug(
                        "Failed to record mock call for model %r", self.model_name, exc_info=True
                    )
                return mock_result

        return self.backend.predict_sync(input_data)

    def __call__(self, input_data: Any) -> Any:
        """
        Execute model inference using the callable interface.

        This is an alias for predict() that enables the unified callable syntax:
            result = classifier({text = "Hello"})

        Args:
            input_data: Input to the model (format depends on backend)

        Returns:
            Model prediction result
        """
        return self.predict(input_data)

    def __repr__(self) -> str:
        return f"ModelPrimitive({self.model_name}, type={self.config.get('type')})"
|