tokenable 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,155 @@
1
+ """Configuration loading and volume resolution for TokEnable."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from fnmatch import fnmatch
8
+ from pathlib import Path
9
+
10
+ from tokenable.models import (
11
+ BudgetConfig,
12
+ OverrideConfig,
13
+ TelemetryConfig,
14
+ TokEnableConfig,
15
+ )
16
+
17
CONFIG_FILENAME = "tokenable.config.json"

# Named monthly call-volume presets accepted by parseVolume (case-insensitive).
_VOLUME_NAMES: dict[str, int] = {
    "low": 100,
    "medium": 1000,
    "high": 10000,
}


def parseVolume(value: str) -> int | None:
    """Parse a call-volume string.

    Accepts a named preset from ``_VOLUME_NAMES`` ("low", "medium", "high";
    case-insensitive, surrounding whitespace ignored) or a base-10 integer.

    Args:
        value: Raw volume text, e.g. from the CLI or ``TOKENABLE_VOLUME``.

    Returns:
        The volume as a non-negative int, or ``None`` if ``value`` is neither
        a known preset nor a non-negative integer.
    """
    lower = value.strip().lower()
    if lower in _VOLUME_NAMES:
        return _VOLUME_NAMES[lower]
    try:
        volume = int(lower)
    except ValueError:
        return None
    # A negative call volume would yield negative monthly costs; treat it as
    # unparseable so callers fall back to their next volume source.
    return volume if volume >= 0 else None
35
+
36
+
37
def getEnvVolume() -> int | None:
    """Return the volume configured via the ``TOKENABLE_VOLUME`` environment variable.

    An unset or empty variable yields ``None``. Any other value is handed to
    ``parseVolume``, which may itself return ``None`` for unparseable text.
    """
    raw = os.environ.get("TOKENABLE_VOLUME")
    return parseVolume(raw) if raw else None
43
+
44
+
45
def _find_config_file(start: Path | None = None) -> Path | None:
    """Search ``start`` (default: the CWD) and each ancestor for ``CONFIG_FILENAME``.

    Directories are checked nearest-first, up to and including the
    filesystem root. Returns the first candidate that is a regular file, or
    ``None`` if no directory on the way up contains one.
    """
    origin = (start or Path.cwd()).resolve()
    for directory in (origin, *origin.parents):
        candidate = directory / CONFIG_FILENAME
        if candidate.is_file():
            return candidate
    return None
58
+
59
+
60
def load_config(path: str | None = None) -> TokEnableConfig | None:
    """Locate, read, and parse the TokEnable JSON config.

    The location is chosen in this order:
    1. the explicit ``path`` argument (if non-empty);
    2. the ``TOKENABLE_CONFIG`` environment variable (if non-empty);
    3. an upward search from the current directory (``_find_config_file``).

    Returns ``None`` when no location is found or the chosen path is not a
    regular file. The file is read as UTF-8; JSON decode errors propagate.
    """
    config_path: Path | None
    if path:
        config_path = Path(path)
    else:
        env_path = os.environ.get("TOKENABLE_CONFIG")
        if env_path:
            config_path = Path(env_path)
        else:
            config_path = _find_config_file()

    if config_path is None:
        return None
    if not config_path.is_file():
        return None

    text = config_path.read_text("utf-8")
    return _parse_config(json.loads(text))
75
+
76
+
77
def _parse_config(raw: dict) -> TokEnableConfig:
    """Build a ``TokEnableConfig`` from the decoded JSON config object.

    JSON keys are camelCase (``defaultVolume``, ``requireApproval``,
    ``maxMonthlyCost``, ``apiKey``, ...) and map onto snake_case model
    fields. A key explicitly set to ``null`` is treated like an absent key,
    so e.g. ``"budgets": null`` yields no budget config instead of crashing,
    and ``"ignore": null`` yields an empty list.

    Deprecated: top-level ``apiUrl``/``apiKey`` without a ``telemetry``
    section are mapped onto an ``inferwise-cloud`` telemetry backend. Both
    values are also passed through unchanged as ``api_url``/``api_key``.

    Raises:
        KeyError: an override lacks ``pattern``, or a ``telemetry`` section
            lacks ``backend`` or ``endpoint``.
    """
    # `or []` / `or {}` below map JSON null onto the same empty default as a missing key.
    overrides = [
        OverrideConfig(pattern=o["pattern"], volume=o.get("volume"))
        for o in raw.get("overrides") or []
    ]

    budgets: BudgetConfig | None = None
    b = raw.get("budgets")
    if b is not None:
        budgets = BudgetConfig(
            warn=b.get("warn"),
            block=b.get("block"),
            require_approval=b.get("requireApproval"),
            approvers=b.get("approvers") or [],
            max_monthly_cost=b.get("maxMonthlyCost"),
        )

    telemetry: TelemetryConfig | None = None
    t = raw.get("telemetry")
    if t is not None:
        telemetry = TelemetryConfig(
            backend=t["backend"],
            endpoint=t["endpoint"],
            headers=t.get("headers") or {},
            api_key=t.get("apiKey"),
        )
    elif raw.get("apiUrl") or raw.get("apiKey"):
        # Deprecated: map apiUrl/apiKey to inferwise-cloud backend.
        # `or ""` also covers an explicit "apiUrl": null alongside an apiKey.
        telemetry = TelemetryConfig(
            backend="inferwise-cloud",
            endpoint=raw.get("apiUrl") or "",
            api_key=raw.get("apiKey"),
        )

    return TokEnableConfig(
        default_volume=raw.get("defaultVolume"),
        ignore=raw.get("ignore") or [],
        overrides=overrides,
        budgets=budgets,
        telemetry=telemetry,
        api_url=raw.get("apiUrl"),
        api_key=raw.get("apiKey"),
    )
121
+
122
+
123
def resolveVolume(
    config: TokEnableConfig | None,
    file_path: str,
    cli_volume: int | None,
    cli_volume_explicit: bool,
) -> int:
    """Resolve the effective call volume for one source file.

    Sources are consulted in priority order; the first that yields a value wins:

    1. ``cli_volume`` when ``cli_volume_explicit`` is true.
    2. The first entry of ``config.overrides`` whose glob ``pattern`` matches
       ``file_path`` (``fnmatch`` semantics) and whose ``volume`` is set.
    3. The ``TOKENABLE_VOLUME`` environment variable (via ``getEnvVolume``).
    4. ``config.default_volume``.
    5. ``cli_volume`` even when not explicit (presumably a CLI default).
    6. 1000 (the same value as the "medium" preset).

    Args:
        config: Parsed config, or ``None`` when no config file was found.
        file_path: Path of the scanned source file, matched against override patterns.
        cli_volume: Volume supplied through the CLI, if any.
        cli_volume_explicit: Whether ``cli_volume`` was explicitly given;
            only then does it take top priority.

    Returns:
        The volume used to scale per-call cost into monthly cost.
    """
    if cli_volume_explicit and cli_volume is not None:
        return cli_volume

    # Check overrides (first matching pattern with a volume wins)
    if config and config.overrides:
        for override in config.overrides:
            if fnmatch(file_path, override.pattern) and override.volume is not None:
                return override.volume

    # Env var
    env_vol = getEnvVolume()
    if env_vol is not None:
        return env_vol

    # Config default
    if config and config.default_volume is not None:
        return config.default_volume

    # CLI non-explicit (fallback)
    if cli_volume is not None:
        return cli_volume

    return 1000
@@ -0,0 +1,124 @@
1
+ """Enforcement — pre-commit hooks, CI gate generation, budget policy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import stat
7
+ from pathlib import Path
8
+
9
# Shell script installed by setup_hooks for every hook_type other than
# "pre-push": runs `tokenable check .` from the hook's working directory.
PRE_COMMIT_HOOK = """#!/bin/sh
# TokEnable cost check — blocks commits that exceed budget thresholds
# Installed by: tokenable init
# Config: tokenable.config.json

tokenable check .
"""

# Shell script installed by setup_hooks when hook_type == "pre-push": runs
# `tokenable diff --format table`, merging stderr into stdout (2>&1). As the
# script's last command, its exit status becomes the hook's exit status.
PRE_PUSH_HOOK = """#!/bin/sh
# TokEnable cost diff — compares HEAD against main before push
# Installed by: tokenable init
# Config: tokenable.config.json (budgets.block threshold)

tokenable diff --format table 2>&1
"""

# GitHub Actions workflow written by generate_ci_workflow. On pull requests
# targeting main it installs tokenable and runs a markdown cost diff of HEAD
# against origin/main; `fetch-depth: 0` checks out full history so the
# origin/main ref is available to diff against.
GITHUB_ACTION_WORKFLOW = """name: TokEnable Cost Check
on:
  pull_request:
    branches: [main]

jobs:
  cost-check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - run: pip install tokenable
      - run: tokenable diff --base origin/main --head HEAD --format markdown
"""

# Starter config serialized by create_config; keys use the camelCase names
# that the config loader reads (defaultVolume, ignore, budgets.warn/block).
CONFIG_TEMPLATE = {
    "defaultVolume": 1000,
    "ignore": ["node_modules", "dist", "build", "test", "__tests__", "*.test.ts", "*.spec.ts"],
    "budgets": {"warn": 2000, "block": 50000},
}
49
+
50
+
51
def find_git_root(start_dir: str) -> str | None:
    """Return the nearest directory at or above ``start_dir`` that contains ``.git``.

    Only existence is checked, so ``.git`` may be a directory or a file.
    The search covers ``start_dir`` (resolved to an absolute path) and
    every ancestor up to the filesystem root.

    Returns:
        The repository root as a string, or ``None`` if none is found.
    """
    here = Path(start_dir).resolve()
    hits = (d for d in (here, *here.parents) if (d / ".git").exists())
    found = next(hits, None)
    return None if found is None else str(found)
61
+
62
+
63
def detect_hook_manager(git_root: str) -> str:
    """Detect which git hook manager a repository uses.

    Args:
        git_root: Repository root directory.

    Returns:
        ``"husky"`` if a ``.husky`` entry exists; otherwise ``"lefthook"`` if
        a lefthook YAML config exists (``lefthook.yml``/``lefthook.yaml``,
        optionally dot-prefixed); otherwise ``"git"`` for plain ``.git/hooks``.
    """
    root = Path(git_root)
    if (root / ".husky").exists():
        return "husky"
    # Lefthook accepts both the .yml and .yaml extensions, with or without a
    # leading dot; checking only .yml missed repos that use lefthook.yaml.
    lefthook_configs = ("lefthook.yml", ".lefthook.yml", "lefthook.yaml", ".lefthook.yaml")
    if any((root / name).exists() for name in lefthook_configs):
        return "lefthook"
    return "git"
71
+
72
+
73
def install_hook(hooks_dir: str, hook_name: str, content: str) -> str:
    """Install ``content`` as the git hook ``hook_name`` inside ``hooks_dir``.

    - If the hook file exists and already mentions "tokenable" (any case),
      it is left untouched and ``"exists"`` is returned.
    - If it exists without such a mention, ``content`` is appended after a
      newline separator and ``"updated"`` is returned.
    - Otherwise ``hooks_dir`` is created if needed, the hook is written, and
      ``"created"`` is returned.

    Whenever the file is written, the owner-execute bit is added.

    Hook files are handled as bytes (UTF-8 with ``surrogateescape``, so any
    undecodable bytes in an existing hook round-trip unchanged). Text-mode
    I/O would use the locale encoding and, on Windows, translate LF line
    endings to CRLF, which breaks ``/bin/sh`` scripts.
    """
    hook_path = Path(hooks_dir) / hook_name

    def _write_executable(text: str) -> None:
        # Write with exact LF endings, then make the hook runnable by its owner.
        hook_path.write_bytes(text.encode("utf-8", "surrogateescape"))
        hook_path.chmod(hook_path.stat().st_mode | stat.S_IEXEC)

    if hook_path.exists():
        existing = hook_path.read_bytes().decode("utf-8", "surrogateescape")
        if "tokenable" in existing.lower():
            return "exists"
        _write_executable(existing + "\n" + content)
        return "updated"

    hook_path.parent.mkdir(parents=True, exist_ok=True)
    _write_executable(content)
    return "created"
89
+
90
+
91
def setup_hooks(git_root: str, hook_type: str = "pre-commit") -> str:
    """Install the TokEnable hook named ``hook_type`` for the repository at ``git_root``.

    The pre-push script is used when ``hook_type == "pre-push"``, the
    pre-commit script otherwise. Husky repositories get the hook in
    ``.husky/``, plain repositories in ``.git/hooks/``. Lefthook
    repositories are not modified and receive manual instructions instead.

    Returns:
        A human-readable status line such as ``"created pre-commit hook (git)"``.
    """
    manager = detect_hook_manager(git_root)
    if manager == "lefthook":
        return "lefthook detected — add tokenable check to your lefthook.yml manually"

    script = PRE_PUSH_HOOK if hook_type == "pre-push" else PRE_COMMIT_HOOK
    target_dir = (
        os.path.join(git_root, ".husky")
        if manager == "husky"
        else os.path.join(git_root, ".git", "hooks")
    )
    outcome = install_hook(target_dir, hook_type, script)
    return f"{outcome} {hook_type} hook ({manager})"
105
+
106
+
107
def create_config(project_dir: str) -> str | None:
    """Write the starter ``tokenable.config.json`` into ``project_dir``.

    Never overwrites an existing file. In that case ``None`` is returned.
    Otherwise ``CONFIG_TEMPLATE`` is written as 2-space-indented JSON with a
    trailing newline, and the new file's path is returned as a string.
    """
    import json

    target = Path(project_dir) / "tokenable.config.json"
    if not target.exists():
        payload = json.dumps(CONFIG_TEMPLATE, indent=2)
        target.write_text(f"{payload}\n")
        return str(target)
    return None
116
+
117
+
118
def generate_ci_workflow(project_dir: str) -> str:
    """Write the TokEnable GitHub Actions workflow and return its path.

    Creates ``.github/workflows/`` under ``project_dir`` as needed and
    writes ``tokenable-cost-check.yml`` with ``GITHUB_ACTION_WORKFLOW``.
    Unlike ``create_config``, an existing file is overwritten.
    """
    target = Path(project_dir, ".github", "workflows", "tokenable-cost-check.yml")
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(GITHUB_ACTION_WORKFLOW)
    return str(target)
@@ -0,0 +1,192 @@
1
+ """Cost estimation engine for TokEnable."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import tiktoken
6
+
7
+ from tokenable.models import (
8
+ CalibrationData,
9
+ EstimateRow,
10
+ ModelPricing,
11
+ ModelStats,
12
+ ScanResult,
13
+ TokEnableConfig,
14
+ TokenSource,
15
+ )
16
+ from tokenable.providers import calculate_cost as _calc_cost
17
+ from tokenable.providers import get_model, get_provider_models
18
+
19
# ── Token counting ───────────────────────────────────────────────────

# Lazily-created tiktoken encoding shared by the token-counting helpers.
_encoding: tiktoken.Encoding | None = None


def _get_encoding() -> tiktoken.Encoding:
    """Return the shared ``cl100k_base`` encoding, loading it on first use.

    The instance is cached in the module-level ``_encoding`` so later calls
    reuse it instead of calling ``tiktoken.get_encoding`` again.
    """
    global _encoding
    if _encoding is not None:
        return _encoding
    _encoding = tiktoken.get_encoding("cl100k_base")
    return _encoding
29
+
30
+
31
def count_tokens(provider: str, model: str, text: str) -> int:
    """Return the number of ``cl100k_base`` tokens in ``text``.

    ``provider`` and ``model`` are accepted but not used: every model is
    counted with the same shared encoding.
    """
    encoding = _get_encoding()
    token_ids = encoding.encode(text)
    return len(token_ids)
34
+
35
+
36
def count_message_tokens(provider: str, model: str, system: str | None, user: str | None) -> int:
    """Return the combined ``cl100k_base`` token count of a system and a user prompt.

    ``None`` or empty prompts contribute zero. ``provider`` and ``model`` are
    not used, and no per-message chat framing overhead is added.
    """
    enc = _get_encoding()
    return sum(len(enc.encode(part)) for part in (system, user) if part)
45
+
46
+
47
+ # ── Typical token heuristics ─────────────────────────────────────────
48
+
49
+
50
def typicalInputTokens(pricing: ModelPricing) -> int:
    """Heuristic input size (in tokens) for a call whose prompt is unknown.

    Models with a context window under 16,384 tokens are assumed to use a
    quarter of it (truncated to an int). Larger models get a flat 4,096.
    """
    window = pricing.context_window
    return int(window * 0.25) if window < 16384 else 4096


# Snake-case aliases
typical_input_tokens = typicalInputTokens
59
+
60
+
61
def typicalOutputTokens(pricing: ModelPricing) -> int:
    """Heuristic output size (in tokens) for a call when no limit is set in code.

    Takes 5% of the model's ``max_output_tokens``, clamped to [512, 4096].
    When ``max_output_tokens`` is positive, the result is also capped at it,
    so a model whose hard output limit is below 512 is never estimated to
    emit more tokens than it can.
    """
    limit = pricing.max_output_tokens
    estimate = max(512, min(int(limit * 0.05), 4096))
    # A non-positive limit presumably means "unknown"; keep the plain
    # heuristic then rather than estimating zero output.
    if limit > 0:
        estimate = min(estimate, int(limit))
    return estimate


# Snake-case aliases
typical_output_tokens = typicalOutputTokens
69
+
70
+
71
+ # ── Estimate builder ─────────────────────────────────────────────────
72
+
73
+
74
def calculateCost(model: ModelPricing, input_tokens: int, output_tokens: int) -> float:
    """Return the cost of a single call with the given token counts.

    Thin wrapper that delegates to ``tokenable.providers.calculate_cost``
    (imported in this module as ``_calc_cost``).
    """
    cost = _calc_cost(model, input_tokens, output_tokens)
    return cost
77
+
78
+
79
def buildEstimateRows(
    results: list[ScanResult],
    config: TokEnableConfig | None,
    cli_volume: int | None,
    cli_volume_explicit: bool,
    stats_map: dict[str, ModelStats] | None = None,
    calibration: CalibrationData | None = None,
) -> tuple[list[EstimateRow], list[str]]:
    """Turn scanner hits into per-call and monthly cost estimate rows.

    For every scan result whose model is known to the pricing table:

    * the volume comes from ``resolveVolume``;
    * input/output token counts and their sources come from
      ``_resolve_input_tokens`` / ``_resolve_output_tokens``;
    * if ``calibration`` has an entry for ``"<provider>/<model>"``, counts
      whose source is TYPICAL or MODEL_LIMIT are scaled by its
      ``input_ratio`` / ``output_ratio`` (truncated to int) and relabelled
      CALIBRATED;
    * monthly cost = per-call cost × volume.

    Results with no model id are reported as ``"<provider>/?"``, and results
    whose model has no pricing as ``"<provider>/<model>"``. Neither produces
    a row.

    Returns:
        ``(rows, unknown_models)``.
    """
    # Imported at call time rather than module level; presumably avoids an
    # import cycle with tokenable.config — verify before hoisting.
    from tokenable.config import resolveVolume

    estimated_sources = (TokenSource.TYPICAL, TokenSource.MODEL_LIMIT)
    rows: list[EstimateRow] = []
    unknown_models: list[str] = []

    for result in results:
        provider_name = result.provider.value
        model_id = result.model

        if not model_id:
            unknown_models.append(f"{provider_name}/?")
            continue
        pricing = get_model(result.provider, model_id)
        if pricing is None:
            unknown_models.append(f"{provider_name}/{model_id}")
            continue

        volume = resolveVolume(config, result.file_path, cli_volume, cli_volume_explicit)
        input_tokens, input_source = _resolve_input_tokens(result, pricing, stats_map)
        output_tokens, output_source = _resolve_output_tokens(result, pricing, stats_map)

        # Calibration is keyed by the model id as written in code, whereas
        # production stats are keyed by the canonical pricing.id.
        cal = calibration.models.get(f"{provider_name}/{model_id}") if calibration else None
        if cal:
            if input_source in estimated_sources:
                input_tokens = int(input_tokens * cal.input_ratio)
                input_source = TokenSource.CALIBRATED
            if output_source in estimated_sources:
                output_tokens = int(output_tokens * cal.output_ratio)
                output_source = TokenSource.CALIBRATED

        cost_per_call = calculateCost(pricing, input_tokens, output_tokens)
        rows.append(
            EstimateRow(
                file=result.file_path,
                line=result.line_number,
                provider=provider_name,
                model=pricing.id,
                input_tokens=input_tokens,
                input_token_source=input_source,
                output_tokens=output_tokens,
                output_token_source=output_source,
                cost_per_call=cost_per_call,
                monthly_cost=cost_per_call * volume,
                system_prompt=result.system_prompt,
                user_prompt=result.user_prompt,
            )
        )

    return rows, unknown_models
147
+
148
+
149
def _resolve_input_tokens(
    result: ScanResult,
    pricing: ModelPricing,
    stats_map: dict[str, ModelStats] | None,
) -> tuple[int, TokenSource]:
    """Pick the input-token count for one call site, with its provenance.

    Preference order:
    1. PRODUCTION: ``avg_input_tokens`` (truncated to int) from
       ``stats_map["<provider>/<pricing.id>"]``.
    2. CODE: tokens counted in the system/user prompts found in code, if non-zero.
    3. TYPICAL: the ``typicalInputTokens`` heuristic.
    """
    stats = stats_map.get(f"{result.provider.value}/{pricing.id}") if stats_map else None
    if stats:
        return int(stats.avg_input_tokens), TokenSource.PRODUCTION

    has_prompt = bool(result.system_prompt or result.user_prompt)
    counted = (
        count_message_tokens(
            result.provider.value, pricing.id, result.system_prompt, result.user_prompt
        )
        if has_prompt
        else 0
    )
    if counted > 0:
        return counted, TokenSource.CODE

    return typicalInputTokens(pricing), TokenSource.TYPICAL
172
+
173
+
174
def _resolve_output_tokens(
    result: ScanResult,
    pricing: ModelPricing,
    stats_map: dict[str, ModelStats] | None,
) -> tuple[int, TokenSource]:
    """Pick the output-token count for one call site, with its provenance.

    Preference order:
    1. PRODUCTION: ``avg_output_tokens`` (truncated to int) from
       ``stats_map["<provider>/<pricing.id>"]``.
    2. MODEL_LIMIT: an explicit ``max_output_tokens`` found in the code.
    3. TYPICAL: the ``typicalOutputTokens`` heuristic.
    """
    stats = stats_map.get(f"{result.provider.value}/{pricing.id}") if stats_map else None
    if stats:
        return int(stats.avg_output_tokens), TokenSource.PRODUCTION

    explicit_limit = result.max_output_tokens
    if explicit_limit is None:
        return typicalOutputTokens(pricing), TokenSource.TYPICAL
    return explicit_limit, TokenSource.MODEL_LIMIT
@@ -0,0 +1,101 @@
1
+ """Fixer module — applies model swap recommendations to source files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import defaultdict
6
+ from pathlib import Path
7
+
8
+ from tokenable.models import AppliedSwap, ApplyResult, ModelSwap, SkippedSwap
9
+
10
# Bounds, in lines, of the search around the reported line for the model literal.
SEARCH_WINDOW = 10


def apply_model_swap(
    file_content: str,
    line_number: int,
    current_model: str,
    suggested_model: str,
) -> dict[str, object] | None:
    """Replace a quoted ``current_model`` literal near ``line_number``.

    Looks for ``current_model`` wrapped in double quotes, single quotes, or
    backticks on lines near the 1-based ``line_number``. On the first line
    that matches, the first such occurrence is replaced with
    ``suggested_model`` in the same quote style.

    Lines are tried nearest-first: the reported line, then one above, one
    below, two above, and so on. This targets the reported call site rather
    than an earlier call that uses the same model a few lines up.

    Line endings (LF/CRLF/other separators) and any trailing newline are
    preserved exactly.

    Returns:
        ``{"content": new_text, "actual_line": 1-based changed line}``, or
        ``None`` if no quoted literal was found in the window.
    """
    lines = file_content.splitlines(keepends=True)
    start = max(0, line_number - SEARCH_WINDOW)
    end = min(len(lines), line_number + SEARCH_WINDOW)
    target = line_number - 1  # 0-based index of the reported line

    # Nearest-first; at equal distance the line above wins (False sorts before True).
    for i in sorted(range(start, end), key=lambda idx: (abs(idx - target), idx > target)):
        line = lines[i]
        for quote in ('"', "'", "`"):
            token = f"{quote}{current_model}{quote}"
            if token in line:
                lines[i] = line.replace(token, f"{quote}{suggested_model}{quote}", 1)
                return {"content": "".join(lines), "actual_line": i + 1}
    return None
32
+
33
+
34
def apply_recommendations(
    swaps: list[ModelSwap],
    base_path: str,
    dry_run: bool = False,
) -> ApplyResult:
    """Apply model-swap recommendations to source files under ``base_path``.

    Swaps are grouped by file (``swap.file`` is relative to ``base_path``)
    and applied to an in-memory copy of each file, bottom-to-top by line.
    A swap is skipped when its file is missing or is not a regular file, or
    when ``apply_model_swap`` cannot find the quoted model literal near its
    line.

    Files are read and written as bytes, decoded/encoded as UTF-8 with
    ``surrogateescape``. This keeps the file's original line endings
    (text-mode I/O translates them) and round-trips bytes that are not valid
    UTF-8, instead of depending on the locale encoding.

    Args:
        swaps: Recommendations to apply.
        base_path: Directory that each ``swap.file`` is relative to.
        dry_run: When true, compute the result but write no files.

    Returns:
        An ``ApplyResult`` with the applied and skipped swaps, their counts,
        and the summed ``monthly_savings`` of the applied swaps.
    """
    applied: list[AppliedSwap] = []
    skipped: list[SkippedSwap] = []
    total_savings = 0.0

    def _skip(swap: ModelSwap, reason: str) -> None:
        # Record a swap that could not be applied, with a human-readable reason.
        skipped.append(
            SkippedSwap(
                file=swap.file,
                line=swap.line,
                from_model=swap.current_model,
                to_model=swap.suggested_model,
                reason=reason,
            )
        )

    by_file: dict[str, list[ModelSwap]] = defaultdict(list)
    for s in swaps:
        by_file[s.file].append(s)

    for file_rel, file_swaps in by_file.items():
        path = Path(base_path) / file_rel
        # is_file() rather than exists(): a directory would crash on read.
        if not path.is_file():
            for s in file_swaps:
                _skip(s, "file not found")
            continue

        content = path.read_bytes().decode("utf-8", "surrogateescape")
        modified = False

        # Bottom-to-top so line numbers of not-yet-applied swaps stay valid.
        for s in sorted(file_swaps, key=lambda sw: sw.line, reverse=True):
            result = apply_model_swap(content, s.line, s.current_model, s.suggested_model)
            if result is None:
                _skip(s, "model string not found as literal")
                continue
            content = result["content"]  # type: ignore[assignment]
            modified = True
            applied.append(
                AppliedSwap(
                    file=s.file,
                    line=s.line,
                    from_model=s.current_model,
                    to_model=s.suggested_model,
                )
            )
            total_savings += s.monthly_savings

        if modified and not dry_run:
            path.write_bytes(content.encode("utf-8", "surrogateescape"))

    return ApplyResult(
        applied=applied,
        skipped=skipped,
        total_applied=len(applied),
        total_skipped=len(skipped),
        estimated_monthly_savings=total_savings,
    )