tokenbreak-scanner 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
1
+ """TokenBreak Model File Scanner.
2
+
3
+ Audit NLP model artifacts for TokenBreak vulnerabilities by inspecting
4
+ tokenizer configurations and model architectures.
5
+ """
6
+
7
# Package version string. cli.py passes its own hard-coded "0.1.0" to
# click.version_option, so the two literals must be bumped together.
__version__ = "0.1.0"
@@ -0,0 +1,203 @@
1
+ """CLI entrypoint for the TokenBreak model scanner."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import sys
7
+
8
+ import click
9
+ from rich.console import Console
10
+ from rich.panel import Panel
11
+ from rich.table import Table
12
+ from rich.text import Text
13
+
14
+ from .inspector import inspect_model
15
+ from .models import RiskLevel, ScannerReport
16
+ from .validator import AttackValidationResult, validate_attack
17
+
18
# Shared Rich console bound to stderr. Error/warning messages and the Rich
# table/panels are printed through it, while click.echo (used for the JSON
# report and blank separator lines) writes to stdout.
console = Console(stderr=True)
19
+
20
+
21
def _build_table(report: ScannerReport) -> Table:
    """Build a Rich table for a scanner report.

    Every free-text value is passed through :func:`rich.markup.escape` before
    being placed in the table. Rich interprets plain strings in cells and
    titles as console markup, so without escaping an evidence bullet such as
    ``" 1. [tokenizer.json model.type]"`` is parsed as a style tag and the
    signal name silently vanishes from the output.

    Parameters
    ----------
    report
        The scanner report to render.

    Returns
    -------
    Table
        A header-less two-column (field / value) table. When the report has
        detection sources, one extra row per source is appended after an
        "evidence tree" separator.
    """
    from rich.markup import escape

    table = Table(
        title=f"TokenBreak Scan Report: {escape(report.model_name)}",
        show_header=False,
    )
    table.add_column("Field", style="cyan", no_wrap=True)
    table.add_column("Value", style="white")

    # Risk level with color; unexpected levels fall back to plain white.
    risk_color = {
        RiskLevel.LOW: "green",
        RiskLevel.HIGH: "red",
        RiskLevel.UNKNOWN: "yellow",
    }.get(report.risk_level, "white")

    vulnerable_cell = (
        Text("YES", style="bold red")
        if report.vulnerable_to_tokenbreak
        else Text("NO", style="bold green")
    )

    table.add_row("Model Name", escape(report.model_name))
    table.add_row("Model Type", escape(report.model_type))
    table.add_row("Model Family", escape(report.model_family))
    table.add_row("Tokenizer Class", escape(report.tokenizer_class))
    table.add_row("Tokenizer Algorithm", escape(report.tokenizer_algorithm.value))
    table.add_row("Vocab Size", str(report.vocab_size) if report.vocab_size else "N/A")
    table.add_row("Confidence Score", f"{report.confidence_score:.2f}")
    table.add_row("Vulnerable to TokenBreak", vulnerable_cell)
    table.add_row("Risk Level", Text(report.risk_level.value, style=f"bold {risk_color}"))
    table.add_row("Source", escape(report.source))
    table.add_row("Recommendation", escape(report.recommendation))

    # Evidence tree
    if report.detection_sources:
        table.add_row("", "")
        table.add_row("Detection Sources", Text("(evidence tree)", style="dim"))
        for i, src in enumerate(report.detection_sources, 1):
            bullet = f" {i}. [{src.signal}]"
            detail = f"inferred={src.inferred or 'N/A'}, weight={src.weight:.2f}"
            if src.reason:
                detail += f" — {src.reason}"
            # Escape so "[signal name]" is shown literally, not eaten as a markup tag.
            table.add_row(escape(bullet), escape(detail))

    return table
61
+
62
+
63
def _print_json(report: ScannerReport, attack_result: AttackValidationResult | None = None) -> None:
    """Write the report to stdout as indented JSON.

    When an attack-validation result is supplied, its JSON-mode dump is
    attached to the payload under the ``attack_validation`` key.
    """
    payload = report.model_dump(mode="json")
    extras = (
        {}
        if attack_result is None
        else {"attack_validation": attack_result.model_dump(mode="json")}
    )
    payload.update(extras)
    rendered = json.dumps(payload, indent=2)
    click.echo(rendered)
69
+
70
+
71
def _print_table(report: ScannerReport, attack_result: AttackValidationResult | None = None) -> None:
    """Print the report as a Rich table, plus a panel for any live-attack result.

    The table and panel go through the module-level ``console``; the blank
    separator lines are written with ``click.echo``.

    Labels and the manipulated text come from the model and the attack
    generator, so they are escaped with :func:`rich.markup.escape` before
    being embedded in the panel's markup. Otherwise text containing e.g.
    ``[b]`` would be interpreted as a Rich style tag instead of being shown.
    """
    from rich.markup import escape

    click.echo()
    console.print(_build_table(report))

    if attack_result is not None:
        click.echo()
        original_label = escape(str(attack_result.original_label))
        if attack_result.success:
            body = (
                "[bold red]Attack Validation: VULNERABLE[/bold red]\n"
                f"Original text classified as: {original_label} "
                f"(confidence: {attack_result.original_confidence:.4f})\n"
                f"Manipulated text: {escape(str(attack_result.manipulated_text))}\n"
                f"Manipulated text classified as: {escape(str(attack_result.manipulated_label))} "
                f"(confidence: {attack_result.manipulated_confidence:.4f})\n"
                "Bypass successful: TokenBreak evades detection."
            )
            border = "red"
        else:
            body = (
                "[bold green]Attack Validation: NOT VULNERABLE[/bold green]\n"
                f"Original text classified as: {original_label} "
                f"(confidence: {attack_result.original_confidence:.4f})\n"
                f"Manipulated text: {escape(str(attack_result.manipulated_text or 'N/A'))}\n"
                "TokenBreak did not produce a successful bypass."
            )
            border = "green"
        console.print(Panel(body, title="Live Attack Test", border_style=border))

    click.echo()
106
+
107
+
108
@click.command(name="tokenbreak-scan")
@click.argument("source")
@click.option(
    "--output",
    "output_format",
    type=click.Choice(["json", "table"], case_sensitive=False),
    default="table",
    show_default=True,
    help="Output format for the report.",
)
@click.option(
    "--download",
    is_flag=True,
    default=False,
    help="Download model files from HuggingFace if source is a model ID.",
)
@click.option(
    "--trust-remote-code",
    is_flag=True,
    default=False,
    help="Trust remote code when loading tokenizers.",
)
@click.option(
    "--test-attack",
    is_flag=True,
    default=False,
    help="Run a live TokenBreak attack validation against the model. "
    "Requires model weights and a classification head.",
)
@click.option(
    "--threshold",
    type=float,
    default=0.995,
    show_default=True,
    help="Confidence threshold for TokenBreak attack validation.",
)
# NOTE(review): duplicates __version__ from the package __init__; keep in sync.
@click.version_option(version="0.1.0")
def main(
    source: str,
    output_format: str,
    download: bool,
    trust_remote_code: bool,
    test_attack: bool,
    threshold: float,
) -> None:
    """Scan SOURCE for TokenBreak tokenizer vulnerabilities.

    SOURCE can be a local directory containing model files
    (config.json, tokenizer.json, etc.) or a HuggingFace model ID.

    Exit status: 0 when the risk level is not High, 1 when it is High,
    2 when the model could not be inspected.
    """
    # Exception messages may contain "[...]" (e.g. quoted paths), which Rich
    # would otherwise parse as markup.
    from rich.markup import escape

    try:
        report = inspect_model(
            source,
            download=download,
            trust_remote_code=trust_remote_code,
        )
    except FileNotFoundError as exc:
        console.print(f"[bold red]Error:[/bold red] {escape(str(exc))}")
        sys.exit(2)
    except Exception as exc:  # top-level CLI boundary: report and exit non-zero
        console.print(
            f"[bold red]Unexpected error during inspection:[/bold red] {escape(str(exc))}"
        )
        sys.exit(2)

    attack_result: AttackValidationResult | None = None
    if test_attack:
        if report.vulnerable_to_tokenbreak:
            try:
                attack_result = validate_attack(
                    source,
                    threshold=threshold,
                    download=download,
                    trust_remote_code=trust_remote_code,
                )
            except Exception as exc:
                # Attack validation is best-effort; the static report is still printed.
                console.print(
                    f"[bold yellow]Warning:[/bold yellow] Attack validation failed: {escape(str(exc))}"
                )
        else:
            # Not-vulnerable covers every non-vulnerable algorithm, including
            # Unknown, so report what was actually detected.
            console.print(
                "[bold yellow]Skipping attack test:[/bold yellow] "
                "Model is not flagged as vulnerable "
                f"(detected tokenizer algorithm: {report.tokenizer_algorithm.value})."
            )

    if output_format == "json":
        _print_json(report, attack_result)
    else:
        _print_table(report, attack_result)

    # Exit codes for CI pipelines: only a High-risk report fails the build.
    sys.exit(1 if report.risk_level == RiskLevel.HIGH else 0)
200
+
201
+
202
# Allow running the CLI by executing this module directly.
if __name__ == "__main__":
    main()
@@ -0,0 +1,294 @@
1
+ """Model file introspection engine.
2
+
3
+ Scans downloaded model artifacts (config.json, tokenizer.json, tokenizer_config.json)
4
+ to determine tokenizer type, model family, and TokenBreak vulnerability.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import logging
11
+ import os
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ from transformers import AutoTokenizer
16
+ from transformers.utils import cached_file
17
+
18
+ from .models import DetectionSource, RiskLevel, ScannerReport, TokenizerAlgorithm
19
+ from .tokenizers import (
20
+ detect_from_remote_source,
21
+ detect_from_runtime_tokenizer,
22
+ detect_from_source_code,
23
+ detect_tokenizer_from_config,
24
+ detect_tokenizer_from_json,
25
+ get_model_family,
26
+ get_recommendation,
27
+ is_vulnerable,
28
+ )
29
+
30
logger = logging.getLogger(__name__)

# Names of the metadata files read from a HuggingFace-style model directory.
CONFIG_FILENAME = "config.json"
TOKENIZER_CONFIG_FILENAME = "tokenizer_config.json"
TOKENIZER_JSON_FILENAME = "tokenizer.json"

# Evidence weight contributed by each detection signal, strongest first.
# Every individual weight must stay <= 1.0 (DetectionSource.weight is
# validated to [0, 1]); the sum may exceed 1.0 because the aggregate
# confidence is clamped to 1.0.
SIGNAL_WEIGHTS: dict[str, float] = dict(
    [
        ("tokenizer.json model.type", 0.40),
        ("runtime._tokenizer.model", 0.40),
        ("source_code_fingerprint", 0.30),
        ("remote_source_file", 0.30),
        ("tokenizer_config.json class", 0.20),
        ("config.json model_type", 0.15),
    ]
)
47
+
48
+
49
+ def _load_json(path: Path | str) -> dict[str, Any] | None:
50
+ """Safely load a JSON file, returning None on any error."""
51
+ try:
52
+ with open(path, encoding="utf-8") as f:
53
+ return json.load(f)
54
+ except (OSError, json.JSONDecodeError):
55
+ return None
56
+
57
+
58
+ def _resolve_model_path(source: str, *, download: bool = False) -> Path:
59
+ """Resolve a model identifier to a local path.
60
+
61
+ * If `source` is an existing local directory, return it directly.
62
+ * If `source` looks like a HuggingFace model ID and `download` is True,
63
+ attempt to download/cache the tokenizer files via `transformers`.
64
+ * Otherwise raise FileNotFoundError.
65
+ """
66
+ local = Path(source)
67
+ if local.is_dir():
68
+ return local.resolve()
69
+
70
+ if download:
71
+ logger.info("Downloading tokenizer files for '%s' from HuggingFace...", source)
72
+ try:
73
+ # Use cached_file to resolve and download individual files
74
+ config_path = cached_file(source, CONFIG_FILENAME, _raise_exceptions_for_missing_entries=False)
75
+ tokenizer_config_path = cached_file(
76
+ source, TOKENIZER_CONFIG_FILENAME, _raise_exceptions_for_missing_entries=False
77
+ )
78
+ tokenizer_json_path = cached_file(
79
+ source, TOKENIZER_JSON_FILENAME, _raise_exceptions_for_missing_entries=False
80
+ )
81
+
82
+ if config_path:
83
+ return Path(config_path).parent.resolve()
84
+ if tokenizer_config_path:
85
+ return Path(tokenizer_config_path).parent.resolve()
86
+ if tokenizer_json_path:
87
+ return Path(tokenizer_json_path).parent.resolve()
88
+ except Exception as exc:
89
+ raise FileNotFoundError(
90
+ f"Could not download or cache model '{source}' from HuggingFace."
91
+ ) from exc
92
+
93
+ raise FileNotFoundError(
94
+ f"Model path not found: '{source}'. "
95
+ "Provide a valid local directory or use --download to fetch from HuggingFace."
96
+ )
97
+
98
+
99
def inspect_model(
    source: str,
    *,
    download: bool = False,
    trust_remote_code: bool = False,
) -> ScannerReport:
    """Inspect a model directory or HuggingFace model ID and return a vulnerability report.

    Up to six detection signals are collected, each weighted by
    :data:`SIGNAL_WEIGHTS`:

    1. tokenizer.json model type
    2. the loaded fast tokenizer's backend
    3. tokenizer class source fingerprint
    4. tokenizer_config.json class
    5. config.json model_type
    6. remote tokenization source

    The signals are reduced with :func:`_aggregate_signals` (weighted vote)
    and :func:`_confidence_from_sources` (capped sum).

    Parameters
    ----------
    source
        Local model directory path or HuggingFace model ID (e.g. ``distilbert-base-uncased``).
    download
        If True and ``source`` is a HuggingFace model ID, download tokenizer files.
    trust_remote_code
        Passed through to ``transformers.AutoTokenizer.from_pretrained`` and to
        the remote-source detector.

    Returns
    -------
    ScannerReport
        Risk level is UNKNOWN when no algorithm could be detected, otherwise
        HIGH if the detected algorithm is vulnerable and LOW if not.

    Raises
    ------
    FileNotFoundError
        Propagated from :func:`_resolve_model_path` when ``source`` cannot be
        resolved to a local directory.
    """
    model_path = _resolve_model_path(source, download=download)

    # Load available metadata files. Missing, unreadable, or non-object JSON
    # is treated as absent so the ``.get`` lookups below cannot blow up on,
    # say, a top-level JSON list.
    raw_config = _load_json(model_path / CONFIG_FILENAME)
    config: dict[str, Any] = raw_config if isinstance(raw_config, dict) else {}
    raw_tokenizer_config = _load_json(model_path / TOKENIZER_CONFIG_FILENAME)
    tokenizer_config: dict[str, Any] = (
        raw_tokenizer_config if isinstance(raw_tokenizer_config, dict) else {}
    )
    raw_tokenizer_json = _load_json(model_path / TOKENIZER_JSON_FILENAME)
    tokenizer_json: dict[str, Any] | None = (
        raw_tokenizer_json if isinstance(raw_tokenizer_json, dict) else None
    )

    # Extract model type (a JSON null becomes "").
    model_type = str(config.get("model_type") or "")
    model_family = get_model_family(model_type)

    # ── Detection: collect signals, then aggregate ──
    sources: list[DetectionSource] = []

    # Signal 1: tokenizer.json "model.type" — most reliable
    if tokenizer_json is not None:
        algo = detect_tokenizer_from_json(tokenizer_json)
        if algo is not None:
            # "model" is normally an object, but guard against other shapes
            # instead of raising AttributeError on ``.get``.
            model_section = tokenizer_json.get("model")
            raw_type = model_section.get("type") if isinstance(model_section, dict) else None
            sources.append(
                DetectionSource(
                    signal="tokenizer.json model.type",
                    value=str(raw_type or tokenizer_json.get("type")),
                    inferred=algo.value,
                    weight=SIGNAL_WEIGHTS["tokenizer.json model.type"],
                    reason="Direct algorithm type from tokenizers library metadata",
                )
            )
            logger.debug("Tokenizer algorithm detected from tokenizer.json: %s", algo)

    # Signal 2: Attempt to load AutoTokenizer and inspect Rust backend
    loaded_tokenizer: Any | None = None
    vocab_size: int | None = None
    tok_cls_name: str = "unknown"
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            str(model_path),
            trust_remote_code=trust_remote_code,
            local_files_only=True,
        )
        vocab_size = len(loaded_tokenizer)
        tok_cls_name = loaded_tokenizer.__class__.__name__
    except Exception as exc:
        # Best-effort: the scan continues on file-based signals alone.
        logger.warning("Could not load tokenizer: %s", exc)

    if loaded_tokenizer is not None:
        algo, reason = detect_from_runtime_tokenizer(loaded_tokenizer)
        if algo is not None:
            sources.append(
                DetectionSource(
                    signal="runtime._tokenizer.model",
                    value=reason,
                    inferred=algo.value,
                    weight=SIGNAL_WEIGHTS["runtime._tokenizer.model"],
                    reason="Rust fast-tokenizer backend model type",
                )
            )

        # Signal 3: source-code fingerprint (if inspect.getsource succeeds)
        algo_src, reason_src = detect_from_source_code(loaded_tokenizer)
        if algo_src is not None:
            sources.append(
                DetectionSource(
                    signal="source_code_fingerprint",
                    value=reason_src,
                    inferred=algo_src.value,
                    weight=SIGNAL_WEIGHTS["source_code_fingerprint"],
                    reason="Keyword fingerprinting on tokenizer class source",
                )
            )

    # Signal 4: tokenizer_config.json → tokenizer_class / model_type
    algo_cfg = detect_tokenizer_from_config(tokenizer_config)
    if algo_cfg is not None:
        sources.append(
            DetectionSource(
                signal="tokenizer_config.json class",
                value=tokenizer_config.get("tokenizer_class", tokenizer_config.get("model_type", "")),
                inferred=algo_cfg.value,
                weight=SIGNAL_WEIGHTS["tokenizer_config.json class"],
                reason="Tokenizer class name or model_type from tokenizer_config.json",
            )
        )

    # Signal 5: config.json model_type fallback
    if model_type:
        from .tokenizers import MODEL_TYPE_MAP

        algo_meta = MODEL_TYPE_MAP.get(model_type)
        if algo_meta is not None:
            sources.append(
                DetectionSource(
                    signal="config.json model_type",
                    value=model_type,
                    inferred=algo_meta.value,
                    weight=SIGNAL_WEIGHTS["config.json model_type"],
                    reason="Architecture model_type from config.json",
                )
            )

    # Signal 6: remote source file for trust_remote_code models
    algo_remote, reason_remote = detect_from_remote_source(model_path, trust_remote_code=trust_remote_code)
    if algo_remote is not None:
        sources.append(
            DetectionSource(
                signal="remote_source_file",
                value=reason_remote,
                inferred=algo_remote.value,
                weight=SIGNAL_WEIGHTS["remote_source_file"],
                reason="Tokenization Python module downloaded from HF Hub",
            )
        )

    # ── Aggregate weighted votes ──
    algorithm = _aggregate_signals(sources)
    confidence_score = _confidence_from_sources(sources)

    # Risk assessment: an undetected algorithm is reported as UNKNOWN risk.
    vulnerable = is_vulnerable(algorithm)
    if algorithm == TokenizerAlgorithm.UNKNOWN:
        risk_level = RiskLevel.UNKNOWN
    else:
        risk_level = RiskLevel.HIGH if vulnerable else RiskLevel.LOW

    recommendation = get_recommendation(algorithm)

    # ``Path(".").name`` is "", so resolve local paths before taking the last
    # component; fall back to ``source`` if even that is empty (e.g. "/").
    source_path = Path(source)
    model_name = (source_path.resolve().name or source) if source_path.exists() else source

    return ScannerReport(
        model_name=model_name,
        model_type=model_type or "unknown",
        model_family=model_family,
        tokenizer_class=tok_cls_name,
        tokenizer_algorithm=algorithm,
        vocab_size=vocab_size,
        vulnerable_to_tokenbreak=vulnerable,
        risk_level=risk_level,
        confidence_score=round(confidence_score, 3),
        detection_sources=sources,
        recommendation=recommendation,
        source=str(model_path),
        config_metadata={
            "config.json": config,
            "tokenizer_config.json": tokenizer_config,
            "detection_confidence": confidence_score,
        },
        tokenizer_metadata=tokenizer_json or {},
    )
260
+
261
+
262
def _aggregate_signals(sources: list[DetectionSource]) -> TokenizerAlgorithm:
    """Pick the tokenizer algorithm backed by the most evidence weight.

    Each source adds its ``weight`` to the algorithm named by ``inferred``.
    Sources with an empty ``inferred`` or one that is not a valid
    :class:`TokenizerAlgorithm` value are ignored. On a tie, the algorithm
    that received its first vote earliest wins, because ``max`` keeps the
    first maximal key in insertion order. Returns
    :attr:`TokenizerAlgorithm.UNKNOWN` when no vote was cast.
    """
    tally: dict[TokenizerAlgorithm, float] = {}
    for evidence in sources:
        if not evidence.inferred:
            continue
        try:
            candidate = TokenizerAlgorithm(evidence.inferred)
        except ValueError:
            continue
        tally[candidate] = tally.get(candidate, 0.0) + evidence.weight

    if not tally:
        return TokenizerAlgorithm.UNKNOWN
    return max(tally, key=tally.__getitem__)
285
+
286
+
287
+ def _confidence_from_sources(sources: list[DetectionSource]) -> float:
288
+ """Cap-and-normalise confidence from evidence.
289
+
290
+ Sum raw weights, then clamp to ``[0, 1]``. This is deliberately simple so
291
+ that adding more signals cannot push confidence past certainty.
292
+ """
293
+ total = sum(src.weight for src in sources)
294
+ return min(total, 1.0)
@@ -0,0 +1,67 @@
1
+ """Pydantic data models for scanner reports."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from enum import Enum
6
+ from typing import Any
7
+
8
+ from pydantic import BaseModel, Field
9
+
10
+
11
class TokenizerAlgorithm(str, Enum):
    """Known tokenizer algorithms relevant to TokenBreak.

    Member values are the strings stored in ``DetectionSource.inferred``;
    ``TokenizerAlgorithm(value)`` maps such a string back to its member.
    Because of the ``str`` mixin, members also compare equal to their values.
    """

    BPE = "BPE"
    WORDPIECE = "WordPiece"
    UNIGRAM = "Unigram"
    SENTENCEPIECE = "SentencePiece"
    # Fallback used when no detection signal produced a recognised algorithm.
    UNKNOWN = "Unknown"
19
+
20
+
21
class RiskLevel(str, Enum):
    """Risk assessment levels assigned to a scanned model.

    The inspector assigns HIGH when the detected algorithm is flagged as
    TokenBreak-vulnerable, LOW when it is not, and UNKNOWN when no algorithm
    could be detected. The CLI exits with status 1 only for HIGH.
    """

    LOW = "Low"
    HIGH = "High"
    UNKNOWN = "Unknown"
27
+
28
+
29
class DetectionSource(BaseModel):
    """A single piece of evidence that contributed to the algorithm detection.

    During aggregation each source votes ``weight`` points for the algorithm
    named in ``inferred``.
    """

    # Identifier of the signal, e.g. "config.json model_type".
    signal: str = Field(description="Name of the detection signal")
    # Raw observation the signal was based on (stringified); may be None.
    value: str | None = Field(default=None, description="Raw value returned by the signal")
    # A TokenizerAlgorithm value string; empty or unrecognised values cast no vote.
    inferred: str | None = Field(default=None, description="Algorithm inferred from this signal")
    # Vote strength, validated to lie within [0.0, 1.0].
    weight: float = Field(default=0.0, ge=0.0, le=1.0, description="Confidence weight of this signal")
    reason: str = Field(default="", description="Human-readable explanation")
37
+
38
+
39
class ScannerReport(BaseModel):
    """Complete report for a scanned model.

    Field declaration order determines the key order of ``model_dump()``,
    which is what the CLI prints as its JSON report.
    """

    # model_name / model_type / model_family start with "model_". Pydantic v2
    # releases before 2.10 reserve that prefix and emit a UserWarning per
    # field when the class is created. Clearing the protected namespaces
    # silences this without renaming the public fields. (A plain dict is
    # equivalent to ``ConfigDict`` at runtime.)
    model_config = {"protected_namespaces": ()}

    model_name: str = Field(description="Name or identifier of the model")
    model_type: str = Field(description="Model architecture type (e.g., roberta, bert)")
    model_family: str = Field(description="High-level model family (e.g., RoBERTa, BERT)")
    tokenizer_class: str = Field(description="Tokenizer class name (e.g., RobertaTokenizerFast)")
    tokenizer_algorithm: TokenizerAlgorithm = Field(description="Detected tokenizer algorithm")
    vocab_size: int | None = Field(default=None, description="Tokenizer vocabulary size")
    vulnerable_to_tokenbreak: bool = Field(description="Whether model is vulnerable to TokenBreak")
    risk_level: RiskLevel = Field(description="Risk level assessment")
    # NOTE(review): bounds allow ±0.01 slack beyond the documented 0.0–1.0
    # range, presumably as float-rounding tolerance — confirm intent.
    confidence_score: float = Field(
        default=0.0, ge=-0.01, le=1.01,
        description="Aggregated confidence score (0.0–1.0) for the detection",
    )
    detection_sources: list[DetectionSource] = Field(
        default_factory=list,
        description="Evidence tree showing why the algorithm was detected",
    )
    recommendation: str = Field(description="Remediation recommendation")
    source: str = Field(description="Source of scan: local path or HuggingFace ID")
    config_metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Raw metadata from config.json and tokenizer_config.json",
    )
    tokenizer_metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Raw metadata from tokenizer.json",
    )