tactus 0.31.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (160)
  1. tactus/__init__.py +49 -0
  2. tactus/adapters/__init__.py +9 -0
  3. tactus/adapters/broker_log.py +76 -0
  4. tactus/adapters/cli_hitl.py +189 -0
  5. tactus/adapters/cli_log.py +223 -0
  6. tactus/adapters/cost_collector_log.py +56 -0
  7. tactus/adapters/file_storage.py +367 -0
  8. tactus/adapters/http_callback_log.py +109 -0
  9. tactus/adapters/ide_log.py +71 -0
  10. tactus/adapters/lua_tools.py +336 -0
  11. tactus/adapters/mcp.py +289 -0
  12. tactus/adapters/mcp_manager.py +196 -0
  13. tactus/adapters/memory.py +53 -0
  14. tactus/adapters/plugins.py +419 -0
  15. tactus/backends/http_backend.py +58 -0
  16. tactus/backends/model_backend.py +35 -0
  17. tactus/backends/pytorch_backend.py +110 -0
  18. tactus/broker/__init__.py +12 -0
  19. tactus/broker/client.py +247 -0
  20. tactus/broker/protocol.py +183 -0
  21. tactus/broker/server.py +1123 -0
  22. tactus/broker/stdio.py +12 -0
  23. tactus/cli/__init__.py +7 -0
  24. tactus/cli/app.py +2245 -0
  25. tactus/cli/commands/__init__.py +0 -0
  26. tactus/core/__init__.py +32 -0
  27. tactus/core/config_manager.py +790 -0
  28. tactus/core/dependencies/__init__.py +14 -0
  29. tactus/core/dependencies/registry.py +180 -0
  30. tactus/core/dsl_stubs.py +2117 -0
  31. tactus/core/exceptions.py +66 -0
  32. tactus/core/execution_context.py +480 -0
  33. tactus/core/lua_sandbox.py +508 -0
  34. tactus/core/message_history_manager.py +236 -0
  35. tactus/core/mocking.py +286 -0
  36. tactus/core/output_validator.py +291 -0
  37. tactus/core/registry.py +499 -0
  38. tactus/core/runtime.py +2907 -0
  39. tactus/core/template_resolver.py +142 -0
  40. tactus/core/yaml_parser.py +301 -0
  41. tactus/docker/Dockerfile +61 -0
  42. tactus/docker/entrypoint.sh +69 -0
  43. tactus/dspy/__init__.py +39 -0
  44. tactus/dspy/agent.py +1144 -0
  45. tactus/dspy/broker_lm.py +181 -0
  46. tactus/dspy/config.py +212 -0
  47. tactus/dspy/history.py +196 -0
  48. tactus/dspy/module.py +405 -0
  49. tactus/dspy/prediction.py +318 -0
  50. tactus/dspy/signature.py +185 -0
  51. tactus/formatting/__init__.py +7 -0
  52. tactus/formatting/formatter.py +437 -0
  53. tactus/ide/__init__.py +9 -0
  54. tactus/ide/coding_assistant.py +343 -0
  55. tactus/ide/server.py +2223 -0
  56. tactus/primitives/__init__.py +49 -0
  57. tactus/primitives/control.py +168 -0
  58. tactus/primitives/file.py +229 -0
  59. tactus/primitives/handles.py +378 -0
  60. tactus/primitives/host.py +94 -0
  61. tactus/primitives/human.py +342 -0
  62. tactus/primitives/json.py +189 -0
  63. tactus/primitives/log.py +187 -0
  64. tactus/primitives/message_history.py +157 -0
  65. tactus/primitives/model.py +163 -0
  66. tactus/primitives/procedure.py +564 -0
  67. tactus/primitives/procedure_callable.py +318 -0
  68. tactus/primitives/retry.py +155 -0
  69. tactus/primitives/session.py +152 -0
  70. tactus/primitives/state.py +182 -0
  71. tactus/primitives/step.py +209 -0
  72. tactus/primitives/system.py +93 -0
  73. tactus/primitives/tool.py +375 -0
  74. tactus/primitives/tool_handle.py +279 -0
  75. tactus/primitives/toolset.py +229 -0
  76. tactus/protocols/__init__.py +38 -0
  77. tactus/protocols/chat_recorder.py +81 -0
  78. tactus/protocols/config.py +97 -0
  79. tactus/protocols/cost.py +31 -0
  80. tactus/protocols/hitl.py +71 -0
  81. tactus/protocols/log_handler.py +27 -0
  82. tactus/protocols/models.py +355 -0
  83. tactus/protocols/result.py +33 -0
  84. tactus/protocols/storage.py +90 -0
  85. tactus/providers/__init__.py +13 -0
  86. tactus/providers/base.py +92 -0
  87. tactus/providers/bedrock.py +117 -0
  88. tactus/providers/google.py +105 -0
  89. tactus/providers/openai.py +98 -0
  90. tactus/sandbox/__init__.py +63 -0
  91. tactus/sandbox/config.py +171 -0
  92. tactus/sandbox/container_runner.py +1099 -0
  93. tactus/sandbox/docker_manager.py +433 -0
  94. tactus/sandbox/entrypoint.py +227 -0
  95. tactus/sandbox/protocol.py +213 -0
  96. tactus/stdlib/__init__.py +10 -0
  97. tactus/stdlib/io/__init__.py +13 -0
  98. tactus/stdlib/io/csv.py +88 -0
  99. tactus/stdlib/io/excel.py +136 -0
  100. tactus/stdlib/io/file.py +90 -0
  101. tactus/stdlib/io/fs.py +154 -0
  102. tactus/stdlib/io/hdf5.py +121 -0
  103. tactus/stdlib/io/json.py +109 -0
  104. tactus/stdlib/io/parquet.py +83 -0
  105. tactus/stdlib/io/tsv.py +88 -0
  106. tactus/stdlib/loader.py +274 -0
  107. tactus/stdlib/tac/tactus/tools/done.tac +33 -0
  108. tactus/stdlib/tac/tactus/tools/log.tac +50 -0
  109. tactus/testing/README.md +273 -0
  110. tactus/testing/__init__.py +61 -0
  111. tactus/testing/behave_integration.py +380 -0
  112. tactus/testing/context.py +486 -0
  113. tactus/testing/eval_models.py +114 -0
  114. tactus/testing/evaluation_runner.py +222 -0
  115. tactus/testing/evaluators.py +634 -0
  116. tactus/testing/events.py +94 -0
  117. tactus/testing/gherkin_parser.py +134 -0
  118. tactus/testing/mock_agent.py +315 -0
  119. tactus/testing/mock_dependencies.py +234 -0
  120. tactus/testing/mock_hitl.py +171 -0
  121. tactus/testing/mock_registry.py +168 -0
  122. tactus/testing/mock_tools.py +133 -0
  123. tactus/testing/models.py +115 -0
  124. tactus/testing/pydantic_eval_runner.py +508 -0
  125. tactus/testing/steps/__init__.py +13 -0
  126. tactus/testing/steps/builtin.py +902 -0
  127. tactus/testing/steps/custom.py +69 -0
  128. tactus/testing/steps/registry.py +68 -0
  129. tactus/testing/test_runner.py +489 -0
  130. tactus/tracing/__init__.py +5 -0
  131. tactus/tracing/trace_manager.py +417 -0
  132. tactus/utils/__init__.py +1 -0
  133. tactus/utils/cost_calculator.py +72 -0
  134. tactus/utils/model_pricing.py +132 -0
  135. tactus/utils/safe_file_library.py +502 -0
  136. tactus/utils/safe_libraries.py +234 -0
  137. tactus/validation/LuaLexerBase.py +66 -0
  138. tactus/validation/LuaParserBase.py +23 -0
  139. tactus/validation/README.md +224 -0
  140. tactus/validation/__init__.py +7 -0
  141. tactus/validation/error_listener.py +21 -0
  142. tactus/validation/generated/LuaLexer.interp +231 -0
  143. tactus/validation/generated/LuaLexer.py +5548 -0
  144. tactus/validation/generated/LuaLexer.tokens +124 -0
  145. tactus/validation/generated/LuaLexerBase.py +66 -0
  146. tactus/validation/generated/LuaParser.interp +173 -0
  147. tactus/validation/generated/LuaParser.py +6439 -0
  148. tactus/validation/generated/LuaParser.tokens +124 -0
  149. tactus/validation/generated/LuaParserBase.py +23 -0
  150. tactus/validation/generated/LuaParserVisitor.py +118 -0
  151. tactus/validation/generated/__init__.py +7 -0
  152. tactus/validation/grammar/LuaLexer.g4 +123 -0
  153. tactus/validation/grammar/LuaParser.g4 +178 -0
  154. tactus/validation/semantic_visitor.py +817 -0
  155. tactus/validation/validator.py +157 -0
  156. tactus-0.31.2.dist-info/METADATA +1809 -0
  157. tactus-0.31.2.dist-info/RECORD +160 -0
  158. tactus-0.31.2.dist-info/WHEEL +4 -0
  159. tactus-0.31.2.dist-info/entry_points.txt +2 -0
  160. tactus-0.31.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,417 @@
1
+ """
2
+ Trace Manager - API for managing execution traces and debugging sessions.
3
+
4
+ Provides operations for querying, filtering, and analyzing procedure execution traces.
5
+ """
6
+
7
+ import logging
8
+ from typing import Optional, List, Dict, Any
9
+ from pathlib import Path
10
+
11
+ from tactus.protocols.storage import StorageBackend
12
+ from tactus.protocols.models import ExecutionRun, CheckpointEntry, Breakpoint
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class TraceManager:
    """
    Manages execution traces and debugging sessions.

    Provides API for:
    - Listing and querying execution runs
    - Accessing checkpoint data
    - Managing breakpoints
    - Comparing runs
    - Exporting traces

    Every operation reads from or writes to the storage backend supplied at
    construction time; the manager itself keeps no other state.
    """

    def __init__(self, storage: StorageBackend):
        """
        Initialize trace manager.

        Args:
            storage: Storage backend for trace persistence
        """
        self.storage = storage
        logger.info("TraceManager initialized")

    # Run Management

    def list_runs(
        self, procedure_name: Optional[str] = None, limit: Optional[int] = None
    ) -> List[ExecutionRun]:
        """
        List execution runs, optionally filtered by procedure name.

        Args:
            procedure_name: Optional procedure name filter (passed to storage)
            limit: Optional maximum number of runs to return. None means
                no limit; 0 returns an empty list.

        Returns:
            Runs in the order produced by the storage backend (expected to be
            newest first), truncated to ``limit`` entries when given.
        """
        logger.debug("Listing runs (procedure=%s, limit=%s)", procedure_name, limit)

        runs = self.storage.list_runs(procedure_name=procedure_name)

        # Compare against None explicitly: a truthiness test would treat
        # limit=0 as "no limit" and return every run.
        if limit is not None:
            runs = runs[:limit]

        return runs

    def get_run(self, run_id: str) -> ExecutionRun:
        """
        Get complete run data.

        Args:
            run_id: Run identifier

        Returns:
            The run as loaded by the storage backend

        Raises:
            FileNotFoundError: If run not found (raised by the storage backend)
        """
        logger.debug("Getting run %s", run_id)
        return self.storage.load_run(run_id)

    def get_checkpoint(self, run_id: str, position: int) -> CheckpointEntry:
        """
        Get specific checkpoint from a run.

        Args:
            run_id: Run identifier
            position: Checkpoint position (0-indexed)

        Returns:
            Checkpoint entry at ``position`` in the run's execution log

        Raises:
            FileNotFoundError: If run not found
            IndexError: If position is negative or past the end of the log
        """
        logger.debug("Getting checkpoint %s from run %s", position, run_id)

        execution_log = self.get_run(run_id).execution_log
        count = len(execution_log)

        if not 0 <= position < count:
            # An empty log has no valid range; "0--1" would be misleading.
            valid = f"0-{count - 1}" if count else "run has no checkpoints"
            raise IndexError(f"Checkpoint position {position} out of range ({valid})")

        return execution_log[position]

    def get_checkpoints(
        self, run_id: str, start: int = 0, end: Optional[int] = None
    ) -> List[CheckpointEntry]:
        """
        Get range of checkpoints from a run.

        Args:
            run_id: Run identifier
            start: Start position (inclusive, 0-indexed)
            end: End position (exclusive, None = end of log)

        Returns:
            List of checkpoint entries (standard slice semantics, so
            out-of-range bounds yield a shorter or empty list)
        """
        logger.debug("Getting checkpoints %s:%s from run %s", start, end, run_id)

        # A slice with end=None already runs to the end of the log.
        return self.get_run(run_id).execution_log[start:end]

    # Breakpoint Management

    def set_breakpoint(self, file: str, line: int, condition: Optional[str] = None) -> Breakpoint:
        """
        Set a breakpoint at file:line.

        The breakpoint is appended to the list stored for the procedure named
        after the file's stem, and that list is saved back to storage.

        Args:
            file: File path
            line: Line number (1-indexed)
            condition: Optional Python expression to evaluate

        Returns:
            Created breakpoint (enabled, hit count 0, fresh UUID as id)
        """
        import uuid

        logger.info("Setting breakpoint at %s:%s", file, line)

        # Named new_breakpoint to avoid shadowing the builtin breakpoint().
        new_breakpoint = Breakpoint(
            breakpoint_id=str(uuid.uuid4()),
            file=file,
            line=line,
            condition=condition,
            enabled=True,
            hit_count=0,
        )

        # Breakpoints are stored per procedure; the procedure name is the
        # file name without its extension.
        procedure_name = Path(file).stem
        breakpoints = self.storage.load_breakpoints(procedure_name)
        breakpoints.append(new_breakpoint)
        self.storage.save_breakpoints(procedure_name, breakpoints)

        return new_breakpoint

    def remove_breakpoint(self, breakpoint_id: str) -> None:
        """
        Remove a breakpoint.

        Not implemented yet: breakpoints are stored per procedure, and there
        is no index from breakpoint id to procedure.

        Args:
            breakpoint_id: Breakpoint identifier

        Raises:
            NotImplementedError: Always
        """
        logger.info("Removing breakpoint %s", breakpoint_id)

        # TODO: Search all breakpoint files or maintain an id -> file index.
        raise NotImplementedError("Remove breakpoint requires file index")

    def list_breakpoints(self, file: Optional[str] = None) -> List[Breakpoint]:
        """
        List breakpoints for a file.

        Args:
            file: File path whose breakpoints should be listed

        Returns:
            List of breakpoints stored for the procedure named after the
            file's stem

        Raises:
            NotImplementedError: If ``file`` is None or empty; listing across
                all procedures is not implemented
        """
        logger.debug("Listing breakpoints (file=%s)", file)

        if not file:
            # Listing everything would require iterating the breakpoint
            # directory, which is not supported yet.
            raise NotImplementedError("Listing all breakpoints requires file parameter")

        return self.storage.load_breakpoints(Path(file).stem)

    def toggle_breakpoint(self, breakpoint_id: str, enabled: bool) -> None:
        """
        Enable or disable a breakpoint.

        Not implemented yet, for the same reason as remove_breakpoint: there
        is no index from breakpoint id to the procedure that stores it.

        Args:
            breakpoint_id: Breakpoint identifier
            enabled: Whether to enable or disable

        Raises:
            NotImplementedError: Always
        """
        logger.info("Toggling breakpoint %s to %s", breakpoint_id, enabled)

        raise NotImplementedError("Toggle breakpoint requires file index")

    # Query/Analysis

    def find_checkpoint_after_line(
        self, run_id: str, file: str, line: int
    ) -> Optional[CheckpointEntry]:
        """
        Find the first checkpoint at or after the specified line.

        This is used for breakpoint mapping: when a user sets a breakpoint at
        a specific line, we find the next checkpoint that will be hit.

        Args:
            run_id: Run identifier
            file: File path (compared for exact equality)
            line: Line number (1-indexed)

        Returns:
            The first checkpoint, in execution-log order, whose source
            location is in ``file`` at a line >= ``line``; None if there is
            no such checkpoint
        """
        logger.debug("Finding checkpoint after %s:%s in run %s", file, line, run_id)

        for checkpoint in self.get_run(run_id).execution_log:
            location = checkpoint.source_location
            # Checkpoints without a source location can never match.
            if location and location.file == file and location.line >= line:
                return checkpoint

        return None

    def find_checkpoints_by_type(self, run_id: str, checkpoint_type: str) -> List[CheckpointEntry]:
        """
        Find all checkpoints of a specific type.

        Args:
            run_id: Run identifier
            checkpoint_type: Checkpoint type (agent_turn, model_predict, etc.)

        Returns:
            List of matching checkpoints, in execution-log order
        """
        logger.debug("Finding checkpoints of type '%s' in run %s", checkpoint_type, run_id)

        run = self.get_run(run_id)

        return [cp for cp in run.execution_log if cp.type == checkpoint_type]

    def compare_runs(self, run_id1: str, run_id2: str) -> Dict[str, Any]:
        """
        Compare two runs for debugging non-determinism.

        Checkpoints are compared position by position up to the length of the
        shorter log. Source locations are compared by line number only, and
        only when both checkpoints have one.

        Args:
            run_id1: First run identifier
            run_id2: Second run identifier

        Returns:
            Dict with "run1" and "run2" summaries and a "differences" list of
            mismatch records (count, type, source line, result)
        """
        logger.info("Comparing runs %s vs %s", run_id1, run_id2)

        run1 = self.get_run(run_id1)
        run2 = self.get_run(run_id2)
        log1 = run1.execution_log
        log2 = run2.execution_log

        differences: List[Dict[str, Any]] = []
        comparison: Dict[str, Any] = {
            "run1": {
                "id": run_id1,
                "procedure": run1.procedure_name,
                "status": run1.status,
                "checkpoint_count": len(log1),
            },
            "run2": {
                "id": run_id2,
                "procedure": run2.procedure_name,
                "status": run2.status,
                "checkpoint_count": len(log2),
            },
            "differences": differences,
        }

        if len(log1) != len(log2):
            differences.append(
                {
                    "type": "checkpoint_count_mismatch",
                    "run1_count": len(log1),
                    "run2_count": len(log2),
                }
            )

        # zip stops at the shorter log, so trailing checkpoints of the longer
        # run are covered only by the count mismatch above.
        for position, (cp1, cp2) in enumerate(zip(log1, log2)):
            if cp1.type != cp2.type:
                differences.append(
                    {
                        "type": "checkpoint_type_mismatch",
                        "position": position,
                        "run1_type": cp1.type,
                        "run2_type": cp2.type,
                    }
                )

            if cp1.source_location and cp2.source_location:
                if cp1.source_location.line != cp2.source_location.line:
                    differences.append(
                        {
                            "type": "source_location_mismatch",
                            "position": position,
                            "run1_line": cp1.source_location.line,
                            "run2_line": cp2.source_location.line,
                        }
                    )

            # Simple equality check on the stored results.
            if cp1.result != cp2.result:
                differences.append(
                    {
                        "type": "result_mismatch",
                        "position": position,
                        "checkpoint_type": cp1.type,
                    }
                )

        return comparison

    def export_trace(self, run_id: str, format: str = "json") -> str:
        """
        Export trace for external analysis.

        Args:
            run_id: Run identifier
            format: Export format; only "json" is currently supported

        Returns:
            Exported trace data as string. For JSON, values that are not
            natively serializable are converted with ``str``.

        Raises:
            ValueError: If format not supported
        """
        logger.info("Exporting run %s as %s", run_id, format)

        run = self.get_run(run_id)

        if format != "json":
            raise ValueError(f"Unsupported export format: {format}")

        import json

        return json.dumps(run.model_dump(), indent=2, default=str)

    def get_statistics(self, run_id: str) -> Dict[str, Any]:
        """
        Get statistics about a run.

        Args:
            run_id: Run identifier

        Returns:
            Statistics dictionary with:
                - run_id, procedure, status
                - total_checkpoints: Length of the execution log
                - checkpoints_by_type: Checkpoint count per type
                - total_duration_ms: Sum of the checkpoints' duration_ms
                - total_time_sec: end_time - start_time in seconds, or None
                  when either timestamp is missing
                - has_source_locations: Number of checkpoints that carry a
                  source location (a count, not a boolean)
        """
        logger.debug("Getting statistics for run %s", run_id)

        run = self.get_run(run_id)
        execution_log = run.execution_log

        type_counts: Dict[str, int] = {}
        total_duration_ms = 0.0

        for checkpoint in execution_log:
            type_counts[checkpoint.type] = type_counts.get(checkpoint.type, 0) + 1
            # Missing (None) and zero durations contribute nothing.
            if checkpoint.duration_ms:
                total_duration_ms += checkpoint.duration_ms

        if run.start_time and run.end_time:
            total_time_sec = (run.end_time - run.start_time).total_seconds()
        else:
            total_time_sec = None

        return {
            "run_id": run_id,
            "procedure": run.procedure_name,
            "status": run.status,
            "total_checkpoints": len(execution_log),
            "checkpoints_by_type": type_counts,
            "total_duration_ms": total_duration_ms,
            "total_time_sec": total_time_sec,
            "has_source_locations": sum(
                1 for cp in execution_log if cp.source_location is not None
            ),
        }
@@ -0,0 +1 @@
1
+ """Utility modules for Tactus."""
@@ -0,0 +1,72 @@
1
+ """
2
+ Cost calculator for LLM usage.
3
+
4
+ Calculates costs based on token usage and model pricing.
5
+ """
6
+
7
+ from typing import Dict, Any, Optional
8
+ from .model_pricing import get_model_pricing, normalize_model_name
9
+
10
+
11
class CostCalculator:
    """
    Calculate LLM costs from token usage and model information.

    Aligned with pydantic-ai's usage tracking.
    """

    # Cached tokens are assumed to cost 10% of the input price, so the saving
    # per cached token is 90% of the input price.
    CACHE_SAVINGS_FRACTION = 0.9

    def calculate_cost(
        self,
        model_name: str,
        provider: Optional[str],
        prompt_tokens: int,
        completion_tokens: int,
        cache_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Calculate cost for a single LLM call.

        Args:
            model_name: Model identifier
            provider: Provider name (openai, anthropic, bedrock, google)
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens
            cache_tokens: Number of cached tokens (if applicable)

        Returns:
            Dict with:
                - prompt_cost: Cost for prompt tokens
                - completion_cost: Cost for completion tokens
                - cache_cost: Estimated savings from cached tokens, or None
                  when no positive cache token count was given. Reported for
                  information only; it is not subtracted from total_cost.
                - total_cost: prompt_cost + completion_cost
                - model: Normalized model name
                - provider: Detected provider
                - pricing_found: Whether pricing was found (False = using defaults)
        """
        # Local import: the module-level import only brings in the two helper
        # functions, and the default table is needed for the check below.
        from .model_pricing import DEFAULT_PRICING

        # Normalize model name and get provider
        normalized_model, detected_provider = normalize_model_name(model_name, provider)

        # Get pricing (per million tokens)
        pricing = get_model_pricing(model_name, provider)

        # get_model_pricing returns the DEFAULT_PRICING object itself when no
        # table entry matches. Identity is used rather than equality because
        # a real model may legitimately have the same prices as the default.
        pricing_found = pricing is not DEFAULT_PRICING

        prompt_cost = (prompt_tokens / 1_000_000) * pricing["input"]
        completion_cost = (completion_tokens / 1_000_000) * pricing["output"]

        # Calculate cache savings if applicable
        cache_cost = None
        if cache_tokens and cache_tokens > 0:
            cache_cost = (
                (cache_tokens / 1_000_000) * pricing["input"] * self.CACHE_SAVINGS_FRACTION
            )

        total_cost = prompt_cost + completion_cost

        return {
            "prompt_cost": prompt_cost,
            "completion_cost": completion_cost,
            "cache_cost": cache_cost,
            "total_cost": total_cost,
            "model": normalized_model,
            "provider": detected_provider,
            "pricing_found": pricing_found,
        }
@@ -0,0 +1,132 @@
1
+ """
2
+ Model pricing data for cost calculation.
3
+
4
+ Prices are per million tokens in USD.
5
+ Data sourced from provider documentation and pricing pages.
6
+ """
7
+
8
+ from typing import Dict, Optional
9
+
10
# Pricing per million tokens (USD).
# Layout: provider name -> model identifier -> {"input": price, "output": price}.
MODEL_PRICING: Dict[str, Dict[str, Dict[str, float]]] = {
    "openai": {
        "gpt-4o": {"input": 2.50, "output": 10.00},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "gpt-4o-2024-11-20": {"input": 2.50, "output": 10.00},
        "gpt-4o-2024-08-06": {"input": 2.50, "output": 10.00},
        "gpt-4o-2024-05-13": {"input": 5.00, "output": 15.00},
        "gpt-4-turbo": {"input": 10.00, "output": 30.00},
        "gpt-4-turbo-preview": {"input": 10.00, "output": 30.00},
        "gpt-4": {"input": 30.00, "output": 60.00},
        "gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
        "o1": {"input": 15.00, "output": 60.00},
        "o1-mini": {"input": 1.10, "output": 4.40},
        "o3-mini": {"input": 1.10, "output": 4.40},
    },
    "anthropic": {
        "claude-3-5-sonnet": {"input": 3.00, "output": 15.00},
        "claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
        "claude-3-5-haiku": {"input": 0.80, "output": 4.00},
        "claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
        "claude-3-opus": {"input": 15.00, "output": 75.00},
        "claude-3-sonnet": {"input": 3.00, "output": 15.00},
        "claude-3-haiku": {"input": 0.25, "output": 1.25},
        "claude-sonnet-4": {"input": 3.00, "output": 15.00},
        "claude-sonnet-4.5": {"input": 3.00, "output": 15.00},
        "claude-opus-4": {"input": 15.00, "output": 75.00},
    },
    "bedrock": {
        # Anthropic models on Bedrock, keyed by the full Bedrock model ID
        "anthropic.claude-3-5-sonnet-20241022-v2:0": {"input": 3.00, "output": 15.00},
        "anthropic.claude-3-5-sonnet-20240620-v1:0": {"input": 3.00, "output": 15.00},
        "anthropic.claude-3-5-haiku-20241022-v1:0": {"input": 0.80, "output": 4.00},
        "anthropic.claude-3-opus-20240229-v1:0": {"input": 15.00, "output": 75.00},
        "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 3.00, "output": 15.00},
        "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.25, "output": 1.25},
    },
    "google": {
        "gemini-2.5-flash": {"input": 0.30, "output": 2.50},
        "gemini-2.5-pro": {"input": 1.25, "output": 10.00},
        "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
        "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
        "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
    },
}

# Default pricing for unknown models (conservative estimate).
# get_model_pricing() returns this dict when no table entry matches.
DEFAULT_PRICING: Dict[str, float] = {"input": 10.00, "output": 30.00}
58
+
59
+
60
def normalize_model_name(model_name: str, provider: Optional[str] = None) -> tuple[str, str]:
    """
    Normalize model name and extract provider.

    Handles formats like:
    - "gpt-4o" -> ("gpt-4o", "openai")
    - "openai:gpt-4o" -> ("gpt-4o", "openai")
    - "anthropic.claude-3-5-sonnet-20241022-v2:0" -> (full name, "bedrock")

    Resolution order:
    1. Bedrock-style IDs ("anthropic." prefix) are returned unchanged.
    2. A "provider:model" prefix is split off and lower-cased.
    3. An explicit ``provider`` hint is used, lower-cased.
    4. The provider is inferred from well-known model name prefixes.
    5. Anything else defaults to "openai".

    Args:
        model_name: Model identifier
        provider: Optional provider hint

    Returns:
        Tuple of (normalized_model_name, provider)
    """
    # Bedrock IDs must be recognized before the "provider:model" split: they
    # end in a ":<version>" suffix, so splitting on the first colon would turn
    # the version number into the model name.
    if model_name.startswith("anthropic."):
        return (model_name, "bedrock")

    # Explicit "provider:model" form. Only the first colon separates the
    # provider, so "bedrock:anthropic.claude-...:0" keeps its version suffix.
    prefix, separator, remainder = model_name.partition(":")
    if separator:
        return (remainder, prefix.lower())

    # Use provided provider hint before guessing
    if provider:
        return (model_name, provider.lower())

    # Infer provider from model name patterns
    if model_name.startswith(("gpt-", "o1", "o3")):
        return (model_name, "openai")
    if model_name.startswith("claude-"):
        return (model_name, "anthropic")
    if model_name.startswith("gemini-"):
        return (model_name, "google")

    # Default to openai if unknown
    return (model_name, "openai")
101
+
102
+
103
def get_model_pricing(model_name: str, provider: Optional[str] = None) -> Dict[str, float]:
    """
    Get pricing for a model.

    Lookup order within the detected provider's table:
    1. Exact match on the normalized model name.
    2. The longest known model name that the requested name extends with a
       "-" suffix (e.g. "gpt-4o-mini-2024-07-18" -> "gpt-4o-mini").
    3. DEFAULT_PRICING.

    Args:
        model_name: Model identifier
        provider: Optional provider

    Returns:
        Dict with 'input' and 'output' pricing per million tokens. When
        nothing matches, the DEFAULT_PRICING object itself is returned, so
        callers can detect the fallback with an identity check.
    """
    normalized_model, detected_provider = normalize_model_name(model_name, provider)

    # Look up pricing
    provider_pricing = MODEL_PRICING.get(detected_provider, {})
    pricing = provider_pricing.get(normalized_model)

    if pricing:
        return pricing

    # Strip version/date suffixes by matching against known names. Choosing
    # the longest match keeps "gpt-4o-mini-..." from being priced as "gpt-4o"
    # and "gpt-4-turbo-..." from being priced as "gpt-4".
    candidates = [
        known_name
        for known_name in provider_pricing
        if normalized_model.startswith(known_name + "-")
    ]
    if candidates:
        return provider_pricing[max(candidates, key=len)]

    # No match: fall back to the default pricing (no warning is emitted).
    return DEFAULT_PRICING