tactus 0.34.1__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +1 -1
- tactus/adapters/broker_log.py +17 -14
- tactus/adapters/channels/__init__.py +17 -15
- tactus/adapters/channels/base.py +16 -7
- tactus/adapters/channels/broker.py +43 -13
- tactus/adapters/channels/cli.py +19 -15
- tactus/adapters/channels/host.py +40 -25
- tactus/adapters/channels/ipc.py +82 -31
- tactus/adapters/channels/sse.py +41 -23
- tactus/adapters/cli_hitl.py +19 -19
- tactus/adapters/cli_log.py +4 -4
- tactus/adapters/control_loop.py +138 -99
- tactus/adapters/cost_collector_log.py +9 -9
- tactus/adapters/file_storage.py +56 -52
- tactus/adapters/http_callback_log.py +23 -13
- tactus/adapters/ide_log.py +17 -9
- tactus/adapters/lua_tools.py +4 -5
- tactus/adapters/mcp.py +16 -19
- tactus/adapters/mcp_manager.py +46 -30
- tactus/adapters/memory.py +9 -9
- tactus/adapters/plugins.py +42 -42
- tactus/broker/client.py +75 -78
- tactus/broker/protocol.py +57 -57
- tactus/broker/server.py +252 -197
- tactus/cli/app.py +3 -1
- tactus/cli/control.py +2 -2
- tactus/core/config_manager.py +181 -135
- tactus/core/dependencies/registry.py +66 -48
- tactus/core/dsl_stubs.py +222 -163
- tactus/core/exceptions.py +10 -1
- tactus/core/execution_context.py +152 -112
- tactus/core/lua_sandbox.py +72 -64
- tactus/core/message_history_manager.py +138 -43
- tactus/core/mocking.py +41 -27
- tactus/core/output_validator.py +49 -44
- tactus/core/registry.py +94 -80
- tactus/core/runtime.py +211 -176
- tactus/core/template_resolver.py +16 -16
- tactus/core/yaml_parser.py +55 -45
- tactus/docs/extractor.py +7 -6
- tactus/ide/server.py +119 -78
- tactus/primitives/control.py +10 -6
- tactus/primitives/file.py +48 -46
- tactus/primitives/handles.py +47 -35
- tactus/primitives/host.py +29 -27
- tactus/primitives/human.py +154 -137
- tactus/primitives/json.py +22 -23
- tactus/primitives/log.py +26 -26
- tactus/primitives/message_history.py +285 -31
- tactus/primitives/model.py +15 -9
- tactus/primitives/procedure.py +86 -64
- tactus/primitives/procedure_callable.py +58 -51
- tactus/primitives/retry.py +31 -29
- tactus/primitives/session.py +42 -29
- tactus/primitives/state.py +54 -43
- tactus/primitives/step.py +9 -13
- tactus/primitives/system.py +34 -21
- tactus/primitives/tool.py +44 -31
- tactus/primitives/tool_handle.py +76 -54
- tactus/primitives/toolset.py +25 -22
- tactus/sandbox/config.py +4 -4
- tactus/sandbox/container_runner.py +161 -107
- tactus/sandbox/docker_manager.py +20 -20
- tactus/sandbox/entrypoint.py +16 -14
- tactus/sandbox/protocol.py +15 -15
- tactus/stdlib/classify/llm.py +1 -3
- tactus/stdlib/core/validation.py +0 -3
- tactus/testing/pydantic_eval_runner.py +1 -1
- tactus/utils/asyncio_helpers.py +27 -0
- tactus/utils/cost_calculator.py +7 -7
- tactus/utils/model_pricing.py +11 -12
- tactus/utils/safe_file_library.py +156 -132
- tactus/utils/safe_libraries.py +27 -27
- tactus/validation/error_listener.py +18 -5
- tactus/validation/semantic_visitor.py +392 -333
- tactus/validation/validator.py +89 -49
- {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/METADATA +15 -3
- {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/RECORD +81 -80
- {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/WHEEL +0 -0
- {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/entry_points.txt +0 -0
- {tactus-0.34.1.dist-info → tactus-0.35.1.dist-info}/licenses/LICENSE +0 -0
tactus/sandbox/docker_manager.py
CHANGED
|
@@ -9,7 +9,7 @@ import logging
|
|
|
9
9
|
import shutil
|
|
10
10
|
import subprocess
|
|
11
11
|
from pathlib import Path
|
|
12
|
-
from typing import
|
|
12
|
+
from typing import Optional
|
|
13
13
|
|
|
14
14
|
logger = logging.getLogger(__name__)
|
|
15
15
|
|
|
@@ -18,7 +18,7 @@ DEFAULT_IMAGE_NAME = "tactus-sandbox"
|
|
|
18
18
|
DEFAULT_IMAGE_TAG = "local"
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def resolve_dockerfile_path(tactus_root: Path) ->
|
|
21
|
+
def resolve_dockerfile_path(tactus_root: Path) -> tuple[Path, str]:
|
|
22
22
|
"""
|
|
23
23
|
Choose the appropriate Dockerfile for the sandbox build.
|
|
24
24
|
|
|
@@ -94,7 +94,7 @@ def calculate_source_hash(tactus_root: Path) -> str:
|
|
|
94
94
|
return hasher.hexdigest()[:16]
|
|
95
95
|
|
|
96
96
|
|
|
97
|
-
def is_docker_available() ->
|
|
97
|
+
def is_docker_available() -> tuple[bool, str]:
|
|
98
98
|
"""
|
|
99
99
|
Check if Docker is available and running.
|
|
100
100
|
|
|
@@ -134,8 +134,8 @@ def is_docker_available() -> Tuple[bool, str]:
|
|
|
134
134
|
return False, "Docker daemon not responding (timeout after 10s)"
|
|
135
135
|
except FileNotFoundError:
|
|
136
136
|
return False, "Docker CLI not found"
|
|
137
|
-
except Exception as
|
|
138
|
-
return False, f"Docker check failed: {
|
|
137
|
+
except Exception as error:
|
|
138
|
+
return False, f"Docker check failed: {error}"
|
|
139
139
|
|
|
140
140
|
|
|
141
141
|
class DockerManager:
|
|
@@ -263,7 +263,7 @@ class DockerManager:
|
|
|
263
263
|
return True
|
|
264
264
|
|
|
265
265
|
if image_hash != current_hash:
|
|
266
|
-
logger.debug(
|
|
266
|
+
logger.debug("Source hash mismatch: %s != %s", image_hash, current_hash)
|
|
267
267
|
return True
|
|
268
268
|
|
|
269
269
|
return False
|
|
@@ -275,7 +275,7 @@ class DockerManager:
|
|
|
275
275
|
version: str,
|
|
276
276
|
source_hash: Optional[str] = None,
|
|
277
277
|
verbose: bool = False,
|
|
278
|
-
) ->
|
|
278
|
+
) -> tuple[bool, str]:
|
|
279
279
|
"""
|
|
280
280
|
Build the sandbox Docker image.
|
|
281
281
|
|
|
@@ -295,7 +295,7 @@ class DockerManager:
|
|
|
295
295
|
if not context_path.exists():
|
|
296
296
|
return False, f"Build context not found: {context_path}"
|
|
297
297
|
|
|
298
|
-
logger.info(
|
|
298
|
+
logger.info("Building sandbox image: %s", self.full_image_name)
|
|
299
299
|
|
|
300
300
|
cmd = [
|
|
301
301
|
"docker",
|
|
@@ -329,7 +329,7 @@ class DockerManager:
|
|
|
329
329
|
for line in iter(process.stdout.readline, ""):
|
|
330
330
|
if line:
|
|
331
331
|
output_lines.append(line.rstrip())
|
|
332
|
-
logger.info(line.rstrip())
|
|
332
|
+
logger.info("%s", line.rstrip())
|
|
333
333
|
process.wait()
|
|
334
334
|
returncode = process.returncode
|
|
335
335
|
output = "\n".join(output_lines)
|
|
@@ -344,15 +344,15 @@ class DockerManager:
|
|
|
344
344
|
output = result.stderr if result.returncode != 0 else result.stdout
|
|
345
345
|
|
|
346
346
|
if returncode == 0:
|
|
347
|
-
logger.info(
|
|
347
|
+
logger.info("Successfully built: %s", self.full_image_name)
|
|
348
348
|
return True, f"Successfully built {self.full_image_name}"
|
|
349
349
|
else:
|
|
350
350
|
return False, f"Build failed: {output}"
|
|
351
351
|
|
|
352
352
|
except subprocess.TimeoutExpired:
|
|
353
353
|
return False, "Build timed out after 10 minutes"
|
|
354
|
-
except Exception as
|
|
355
|
-
return False, f"Build failed: {
|
|
354
|
+
except Exception as error:
|
|
355
|
+
return False, f"Build failed: {error}"
|
|
356
356
|
|
|
357
357
|
def ensure_image_exists(
|
|
358
358
|
self,
|
|
@@ -360,7 +360,7 @@ class DockerManager:
|
|
|
360
360
|
context_path: Path,
|
|
361
361
|
version: str,
|
|
362
362
|
force_rebuild: bool = False,
|
|
363
|
-
) ->
|
|
363
|
+
) -> tuple[bool, str]:
|
|
364
364
|
"""
|
|
365
365
|
Ensure the sandbox image exists, building if necessary.
|
|
366
366
|
|
|
@@ -378,7 +378,7 @@ class DockerManager:
|
|
|
378
378
|
|
|
379
379
|
return True, f"Image {self.full_image_name} is up to date"
|
|
380
380
|
|
|
381
|
-
def remove_image(self) ->
|
|
381
|
+
def remove_image(self) -> tuple[bool, str]:
|
|
382
382
|
"""
|
|
383
383
|
Remove the sandbox image.
|
|
384
384
|
|
|
@@ -399,10 +399,10 @@ class DockerManager:
|
|
|
399
399
|
return True, f"Removed {self.full_image_name}"
|
|
400
400
|
else:
|
|
401
401
|
return False, f"Failed to remove image: {result.stderr}"
|
|
402
|
-
except Exception as
|
|
403
|
-
return False, f"Failed to remove image: {
|
|
402
|
+
except Exception as error:
|
|
403
|
+
return False, f"Failed to remove image: {error}"
|
|
404
404
|
|
|
405
|
-
def cleanup_old_images(self, keep_tags: Optional[list] = None) -> int:
|
|
405
|
+
def cleanup_old_images(self, keep_tags: Optional[list[str]] = None) -> int:
|
|
406
406
|
"""
|
|
407
407
|
Remove old sandbox images, keeping specified tags.
|
|
408
408
|
|
|
@@ -447,10 +447,10 @@ class DockerManager:
|
|
|
447
447
|
)
|
|
448
448
|
if rm_result.returncode == 0:
|
|
449
449
|
removed += 1
|
|
450
|
-
logger.info(
|
|
450
|
+
logger.info("Removed old image: %s", line)
|
|
451
451
|
|
|
452
452
|
return removed
|
|
453
453
|
|
|
454
|
-
except Exception as
|
|
455
|
-
logger.warning(
|
|
454
|
+
except Exception as error:
|
|
455
|
+
logger.warning("Failed to cleanup old images: %s", error)
|
|
456
456
|
return 0
|
tactus/sandbox/entrypoint.py
CHANGED
|
@@ -18,7 +18,7 @@ import os
|
|
|
18
18
|
import sys
|
|
19
19
|
import time
|
|
20
20
|
import traceback
|
|
21
|
-
from typing import Any,
|
|
21
|
+
from typing import Any, Optional
|
|
22
22
|
|
|
23
23
|
from tactus.sandbox.protocol import ExecutionResult
|
|
24
24
|
|
|
@@ -32,8 +32,8 @@ _LOG_LEVELS = {
|
|
|
32
32
|
"critical": logging.CRITICAL,
|
|
33
33
|
}
|
|
34
34
|
|
|
35
|
-
|
|
36
|
-
_log_level = _LOG_LEVELS.get(
|
|
35
|
+
_log_level_name = os.environ.get("TACTUS_LOG_LEVEL", "info").strip().lower()
|
|
36
|
+
_log_level = _LOG_LEVELS.get(_log_level_name, logging.INFO)
|
|
37
37
|
|
|
38
38
|
# CloudWatch-friendly, one line per record.
|
|
39
39
|
_log_fmt = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
|
@@ -53,21 +53,21 @@ if _log_level > logging.DEBUG:
|
|
|
53
53
|
logging.getLogger("tactus.stdlib").setLevel(logging.WARNING)
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
def read_request_from_stdin() -> Optional[
|
|
56
|
+
def read_request_from_stdin() -> Optional[dict[str, Any]]:
|
|
57
57
|
"""Read the execution request from stdin as JSON."""
|
|
58
58
|
import json
|
|
59
59
|
|
|
60
60
|
try:
|
|
61
61
|
# Read exactly one JSON message (the initial ExecutionRequest).
|
|
62
62
|
# Keep stdin open for broker responses during execution.
|
|
63
|
-
|
|
64
|
-
if not
|
|
63
|
+
input_line = sys.stdin.readline()
|
|
64
|
+
if not input_line.strip():
|
|
65
65
|
logger.error("No input received on stdin")
|
|
66
66
|
return None
|
|
67
67
|
|
|
68
|
-
return json.loads(
|
|
69
|
-
except json.JSONDecodeError as
|
|
70
|
-
logger.error(
|
|
68
|
+
return json.loads(input_line)
|
|
69
|
+
except json.JSONDecodeError as error:
|
|
70
|
+
logger.error("Failed to parse JSON from stdin: %s", error)
|
|
71
71
|
return None
|
|
72
72
|
|
|
73
73
|
|
|
@@ -82,7 +82,7 @@ def write_result_to_stdout(result: ExecutionResult) -> None:
|
|
|
82
82
|
|
|
83
83
|
async def execute_procedure(
|
|
84
84
|
source: str,
|
|
85
|
-
params:
|
|
85
|
+
params: dict[str, Any],
|
|
86
86
|
source_file_path: Optional[str] = None,
|
|
87
87
|
format: str = "lua",
|
|
88
88
|
run_id: Optional[str] = None,
|
|
@@ -118,14 +118,16 @@ async def execute_procedure(
|
|
|
118
118
|
log_handler = HTTPCallbackLogHandler.from_environment()
|
|
119
119
|
if log_handler:
|
|
120
120
|
logger.info(
|
|
121
|
-
|
|
121
|
+
"[SANDBOX] Using HTTP callback log handler: %s",
|
|
122
|
+
os.environ.get("TACTUS_CALLBACK_URL"),
|
|
122
123
|
)
|
|
123
124
|
else:
|
|
124
125
|
# Otherwise, try broker socket streaming (works without container networking, e.g. stdio/UDS).
|
|
125
126
|
log_handler = BrokerLogHandler.from_environment()
|
|
126
127
|
if log_handler:
|
|
127
128
|
logger.info(
|
|
128
|
-
|
|
129
|
+
"[SANDBOX] Using broker log handler: %s",
|
|
130
|
+
os.environ.get("TACTUS_BROKER_SOCKET"),
|
|
129
131
|
)
|
|
130
132
|
else:
|
|
131
133
|
# Provide cost collection + checkpoint event handling even without IDE callbacks.
|
|
@@ -196,7 +198,7 @@ async def main_async() -> int:
|
|
|
196
198
|
try:
|
|
197
199
|
# Parse request
|
|
198
200
|
request = ExecutionRequest(**request_data)
|
|
199
|
-
logger.info(
|
|
201
|
+
logger.info("Executing procedure (id=%s)", request.execution_id)
|
|
200
202
|
|
|
201
203
|
# Execute procedure
|
|
202
204
|
proc_result = await execute_procedure(
|
|
@@ -218,7 +220,7 @@ async def main_async() -> int:
|
|
|
218
220
|
return 0
|
|
219
221
|
|
|
220
222
|
except Exception as e:
|
|
221
|
-
logger.exception(
|
|
223
|
+
logger.exception("Procedure execution failed: %s", e)
|
|
222
224
|
|
|
223
225
|
duration = time.time() - start_time
|
|
224
226
|
result = ExecutionResult.failure(
|
tactus/sandbox/protocol.py
CHANGED
|
@@ -6,9 +6,9 @@ over stdio between the host process and the sandboxed container.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import json
|
|
9
|
-
from dataclasses import dataclass, field
|
|
10
|
-
from typing import Any, Dict, List, Optional
|
|
9
|
+
from dataclasses import asdict, dataclass, field
|
|
11
10
|
from enum import Enum
|
|
11
|
+
from typing import Any, Optional
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel
|
|
14
14
|
|
|
@@ -46,7 +46,7 @@ class ExecutionRequest:
|
|
|
46
46
|
working_dir: str = "/workspace"
|
|
47
47
|
|
|
48
48
|
# Input parameters for the procedure
|
|
49
|
-
params:
|
|
49
|
+
params: dict[str, Any] = field(default_factory=dict)
|
|
50
50
|
|
|
51
51
|
# Unique execution ID for tracking
|
|
52
52
|
execution_id: Optional[str] = None
|
|
@@ -101,10 +101,10 @@ class ExecutionResult:
|
|
|
101
101
|
exit_code: int = 0
|
|
102
102
|
|
|
103
103
|
# Structured logs from execution
|
|
104
|
-
logs:
|
|
104
|
+
logs: list[dict[str, Any]] = field(default_factory=list)
|
|
105
105
|
|
|
106
106
|
# Metadata about the execution
|
|
107
|
-
metadata:
|
|
107
|
+
metadata: dict[str, Any] = field(default_factory=dict)
|
|
108
108
|
|
|
109
109
|
def to_json(self) -> str:
|
|
110
110
|
"""Serialize to JSON string."""
|
|
@@ -126,8 +126,8 @@ class ExecutionResult:
|
|
|
126
126
|
cls,
|
|
127
127
|
result: Any,
|
|
128
128
|
duration_seconds: float = 0.0,
|
|
129
|
-
logs: Optional[
|
|
130
|
-
metadata: Optional[
|
|
129
|
+
logs: Optional[list[dict[str, Any]]] = None,
|
|
130
|
+
metadata: Optional[dict[str, Any]] = None,
|
|
131
131
|
) -> "ExecutionResult":
|
|
132
132
|
"""Create a successful result."""
|
|
133
133
|
return cls(
|
|
@@ -147,7 +147,7 @@ class ExecutionResult:
|
|
|
147
147
|
traceback: Optional[str] = None,
|
|
148
148
|
duration_seconds: float = 0.0,
|
|
149
149
|
exit_code: int = 1,
|
|
150
|
-
logs: Optional[
|
|
150
|
+
logs: Optional[list[dict[str, Any]]] = None,
|
|
151
151
|
) -> "ExecutionResult":
|
|
152
152
|
"""Create a failed result."""
|
|
153
153
|
return cls(
|
|
@@ -164,7 +164,7 @@ class ExecutionResult:
|
|
|
164
164
|
def timeout(
|
|
165
165
|
cls,
|
|
166
166
|
duration_seconds: float,
|
|
167
|
-
logs: Optional[
|
|
167
|
+
logs: Optional[list[dict[str, Any]]] = None,
|
|
168
168
|
) -> "ExecutionResult":
|
|
169
169
|
"""Create a timeout result."""
|
|
170
170
|
return cls(
|
|
@@ -198,17 +198,17 @@ def extract_result_from_stdout(stdout: str) -> Optional[ExecutionResult]:
|
|
|
198
198
|
|
|
199
199
|
Returns None if no valid result is found.
|
|
200
200
|
"""
|
|
201
|
-
|
|
202
|
-
if
|
|
201
|
+
start_marker_index = stdout.find(RESULT_START_MARKER)
|
|
202
|
+
if start_marker_index == -1:
|
|
203
203
|
return None
|
|
204
204
|
|
|
205
|
-
|
|
206
|
-
if
|
|
205
|
+
end_marker_index = stdout.find(RESULT_END_MARKER, start_marker_index)
|
|
206
|
+
if end_marker_index == -1:
|
|
207
207
|
return None
|
|
208
208
|
|
|
209
209
|
# Extract JSON between markers
|
|
210
|
-
json_start =
|
|
211
|
-
json_str = stdout[json_start:
|
|
210
|
+
json_start = start_marker_index + len(RESULT_START_MARKER)
|
|
211
|
+
json_str = stdout[json_start:end_marker_index].strip()
|
|
212
212
|
|
|
213
213
|
try:
|
|
214
214
|
return ExecutionResult.from_json(json_str)
|
tactus/stdlib/classify/llm.py
CHANGED
|
@@ -221,12 +221,10 @@ Start your response with the classification on its own line."""
|
|
|
221
221
|
|
|
222
222
|
def _parse_response(self, response: str) -> Dict[str, Any]:
|
|
223
223
|
"""Parse classification response to extract value and explanation."""
|
|
224
|
-
if not response:
|
|
224
|
+
if not response or not response.strip():
|
|
225
225
|
return {"value": None, "explanation": None}
|
|
226
226
|
|
|
227
227
|
lines = response.strip().split("\n")
|
|
228
|
-
if not lines:
|
|
229
|
-
return {"value": None, "explanation": None}
|
|
230
228
|
|
|
231
229
|
# First non-empty line should be the classification
|
|
232
230
|
first_line = lines[0].strip()
|
tactus/stdlib/core/validation.py
CHANGED
|
@@ -12,7 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
|
|
12
12
|
|
|
13
13
|
from .eval_models import EvaluationConfig, EvalCase
|
|
14
14
|
|
|
15
|
-
if TYPE_CHECKING:
|
|
15
|
+
if TYPE_CHECKING: # pragma: no cover
|
|
16
16
|
from tactus.core.runtime import TactusRuntime
|
|
17
17
|
|
|
18
18
|
logger = logging.getLogger(__name__)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Asyncio helper utilities.
|
|
3
|
+
|
|
4
|
+
These helpers protect synchronous codepaths from inheriting a closed event loop.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def clear_closed_event_loop() -> None:
|
|
13
|
+
"""
|
|
14
|
+
Ensure the current thread does not hold a closed event loop.
|
|
15
|
+
|
|
16
|
+
Pytest-asyncio and other frameworks can leave a closed loop set as the
|
|
17
|
+
current event loop after async tests complete. Synchronous code that uses
|
|
18
|
+
asyncio.run() or creates its own event loop should not inherit a closed
|
|
19
|
+
loop reference. This helper resets the current loop to None when needed.
|
|
20
|
+
"""
|
|
21
|
+
try:
|
|
22
|
+
current_loop = asyncio.get_event_loop()
|
|
23
|
+
except RuntimeError:
|
|
24
|
+
return
|
|
25
|
+
|
|
26
|
+
if getattr(current_loop, "is_closed", lambda: False)():
|
|
27
|
+
asyncio.set_event_loop(None)
|
tactus/utils/cost_calculator.py
CHANGED
|
@@ -4,7 +4,7 @@ Cost calculator for LLM usage.
|
|
|
4
4
|
Calculates costs based on token usage and model pricing.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import
|
|
7
|
+
from typing import Any, Dict, Optional
|
|
8
8
|
from .model_pricing import get_model_pricing, normalize_model_name
|
|
9
9
|
|
|
10
10
|
|
|
@@ -44,20 +44,20 @@ class CostCalculator:
|
|
|
44
44
|
- pricing_found: Whether pricing was found (False = using defaults)
|
|
45
45
|
"""
|
|
46
46
|
# Normalize model name and get provider
|
|
47
|
-
|
|
47
|
+
normalized_model_name, detected_provider = normalize_model_name(model_name, provider)
|
|
48
48
|
|
|
49
49
|
# Get pricing
|
|
50
|
-
|
|
50
|
+
model_pricing = get_model_pricing(model_name, provider)
|
|
51
51
|
|
|
52
52
|
# Calculate costs (pricing is per million tokens)
|
|
53
|
-
prompt_cost = (prompt_tokens / 1_000_000) *
|
|
54
|
-
completion_cost = (completion_tokens / 1_000_000) *
|
|
53
|
+
prompt_cost = (prompt_tokens / 1_000_000) * model_pricing["input"]
|
|
54
|
+
completion_cost = (completion_tokens / 1_000_000) * model_pricing["output"]
|
|
55
55
|
|
|
56
56
|
# Calculate cache savings if applicable
|
|
57
57
|
cache_cost = None
|
|
58
58
|
if cache_tokens and cache_tokens > 0:
|
|
59
59
|
# Cached tokens typically cost 10% of input tokens
|
|
60
|
-
cache_cost = (cache_tokens / 1_000_000) *
|
|
60
|
+
cache_cost = (cache_tokens / 1_000_000) * model_pricing["input"] * 0.9
|
|
61
61
|
|
|
62
62
|
total_cost = prompt_cost + completion_cost
|
|
63
63
|
|
|
@@ -66,7 +66,7 @@ class CostCalculator:
|
|
|
66
66
|
"completion_cost": completion_cost,
|
|
67
67
|
"cache_cost": cache_cost,
|
|
68
68
|
"total_cost": total_cost,
|
|
69
|
-
"model":
|
|
69
|
+
"model": normalized_model_name,
|
|
70
70
|
"provider": detected_provider,
|
|
71
71
|
"pricing_found": True, # Could track if we used DEFAULT_PRICING
|
|
72
72
|
}
|
tactus/utils/model_pricing.py
CHANGED
|
@@ -73,17 +73,16 @@ def normalize_model_name(model_name: str, provider: Optional[str] = None) -> tup
|
|
|
73
73
|
Returns:
|
|
74
74
|
Tuple of (normalized_model_name, provider)
|
|
75
75
|
"""
|
|
76
|
-
# Extract provider from model name if present
|
|
77
|
-
if ":" in model_name:
|
|
78
|
-
parts = model_name.split(":", 1)
|
|
79
|
-
detected_provider = parts[0].lower()
|
|
80
|
-
model_only = parts[1]
|
|
81
|
-
return (model_only, detected_provider)
|
|
82
|
-
|
|
83
76
|
# Check for Bedrock format (anthropic.claude-...)
|
|
84
77
|
if model_name.startswith("anthropic."):
|
|
85
78
|
return (model_name, "bedrock")
|
|
86
79
|
|
|
80
|
+
# Extract provider from model name if present
|
|
81
|
+
if ":" in model_name:
|
|
82
|
+
provider_prefix, model_without_prefix = model_name.split(":", 1)
|
|
83
|
+
detected_provider = provider_prefix.lower()
|
|
84
|
+
return (model_without_prefix, detected_provider)
|
|
85
|
+
|
|
87
86
|
# Use provided provider or try to infer
|
|
88
87
|
if provider:
|
|
89
88
|
return (model_name, provider.lower())
|
|
@@ -111,19 +110,19 @@ def get_model_pricing(model_name: str, provider: Optional[str] = None) -> Dict[s
|
|
|
111
110
|
Returns:
|
|
112
111
|
Dict with 'input' and 'output' pricing per million tokens
|
|
113
112
|
"""
|
|
114
|
-
|
|
113
|
+
normalized_model_name, detected_provider = normalize_model_name(model_name, provider)
|
|
115
114
|
|
|
116
115
|
# Look up pricing
|
|
117
116
|
provider_pricing = MODEL_PRICING.get(detected_provider, {})
|
|
118
|
-
pricing = provider_pricing.get(
|
|
117
|
+
pricing = provider_pricing.get(normalized_model_name)
|
|
119
118
|
|
|
120
119
|
if pricing:
|
|
121
120
|
return pricing
|
|
122
121
|
|
|
123
122
|
# Try without version suffix (e.g., "gpt-4o-2024-11-20" -> "gpt-4o")
|
|
124
|
-
|
|
125
|
-
if len(
|
|
126
|
-
base_name = "-".join(
|
|
123
|
+
base_name_parts = normalized_model_name.split("-")[0:2]
|
|
124
|
+
if len(base_name_parts) >= 2:
|
|
125
|
+
base_name = "-".join(base_name_parts)
|
|
127
126
|
pricing = provider_pricing.get(base_name)
|
|
128
127
|
if pricing:
|
|
129
128
|
return pricing
|