prompt-compress 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ # Secrets
2
+ .env
3
+
4
+ # Python bytecode / packaging
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+ *.egg-info/
9
+ *.egg
10
+ dist/
11
+ build/
12
+ .eggs/
13
+
14
+ # Virtualenvs
15
+ .venv/
16
+ venv/
17
+ env/
18
+
19
+ # OS / editor cruft
20
+ .DS_Store
21
+ .idea/
22
+ .vscode/
23
+
24
+ # Frontend dependencies / build metadata
25
+ node_modules/
26
+ *.tsbuildinfo
27
+
28
+ # Research caches (regenerate via research/benchmark.py + research/evaluate.py)
29
+ data/results/benchmark/*.json
30
+ !data/results/benchmark/.gitkeep
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 joela03
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,133 @@
1
+ Metadata-Version: 2.4
2
+ Name: prompt-compress
3
+ Version: 0.1.0
4
+ Summary: Structural prompt compression with safety gating
5
+ Project-URL: Homepage, https://github.com/joela03/bayesian-prompt-compressor-
6
+ Project-URL: Repository, https://github.com/joela03/bayesian-prompt-compressor-
7
+ Project-URL: Issues, https://github.com/joela03/bayesian-prompt-compressor-/issues
8
+ Author: joela03
9
+ License: MIT
10
+ License-File: LICENSE
11
+ Keywords: bayesian-optimization,compression,llm,prompt
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.10
18
+ Classifier: Programming Language :: Python :: 3.11
19
+ Classifier: Programming Language :: Python :: 3.12
20
+ Classifier: Programming Language :: Python :: 3.13
21
+ Classifier: Programming Language :: Python :: 3.14
22
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
+ Classifier: Typing :: Typed
24
+ Requires-Python: >=3.10
25
+ Requires-Dist: networkx
26
+ Requires-Dist: numpy
27
+ Requires-Dist: scikit-learn
28
+ Requires-Dist: sentence-transformers
29
+ Provides-Extra: openai
30
+ Requires-Dist: openai>=1.0; extra == 'openai'
31
+ Provides-Extra: research
32
+ Requires-Dist: datasets; extra == 'research'
33
+ Requires-Dist: llmlingua>=0.2; extra == 'research'
34
+ Requires-Dist: matplotlib; extra == 'research'
35
+ Requires-Dist: openai>=1.0; extra == 'research'
36
+ Requires-Dist: pandas; extra == 'research'
37
+ Requires-Dist: python-dotenv; extra == 'research'
38
+ Requires-Dist: scipy; extra == 'research'
39
+ Requires-Dist: seaborn; extra == 'research'
40
+ Description-Content-Type: text/markdown
41
+
42
+ # prompt-compress
43
+
44
+ Structural prompt compression for production LLM apps. Where LLMLingua removes individual low-perplexity tokens, this library parses your system prompt into named components (instruction, examples, constraints, style, context), uses Bayesian optimisation to search which components to keep and how aggressively to compress each, scores candidates by semantic similarity to the original, and gates every output through a post-compression validator (persona / placeholder / similarity). Prompts that are already information-dense are detected up front and passed through unchanged.
45
+
46
+ ## Install
47
+
48
+ ```bash
49
+ pip install prompt-compress
50
+ ```
51
+
52
+ ## Quickstart — production integration
53
+
54
+ ```python
55
+ from prompt_compress import PromptCompressor, CompressionFailedError
56
+
57
+ compressor = PromptCompressor()
58
+
59
+ try:
60
+ result = compressor.compress(
61
+ SYSTEM_PROMPT,
62
+ min_similarity=0.80,
63
+ on_failure='raise',
64
+ )
65
+ SYSTEM_PROMPT = result.compressed_text
66
+ print(f"Saved {result.tokens_saved} tokens per call ({result.compression_ratio:.1%})")
67
+ except CompressionFailedError as e:
68
+ print(f"Compression unsafe, using original: {e}")
69
+ ```
70
+
71
+ `on_failure` accepts `'fallback'` (default — return the original silently with `gate_passed=False`), `'raise'` (raise `CompressionFailedError`), or `'warn'` (log a warning and return the fallback). The library never blocks on user input.
72
+
73
+ ## Inspecting results
74
+
75
+ ```python
76
+ result = compressor.compress(SYSTEM_PROMPT)
77
+
78
+ print(result.summary()) # one-screen terminal summary
79
+ print(result.diff()) # side-by-side original vs compressed
80
+ result.to_dict() # JSON-serialisable, useful for caching/logging
81
+ ```
82
+
83
+ Key properties on `CompressionResult`:
84
+
85
+ | Property | Description |
86
+ |---|---|
87
+ | `compressed_text` | the output you should use |
88
+ | `compression_ratio` | tokens saved / original tokens |
89
+ | `tokens_saved` | absolute token count saved |
90
+ | `semantic_similarity` | cosine sim of original vs compressed (MiniLM) |
91
+ | `compression_efficiency` | `compression_ratio × semantic_similarity` |
92
+ | `safe_to_use` | True iff all validator checks passed |
93
+ | `persona_preserved` | True iff the "You are…" line survived |
94
+ | `placeholders_preserved` | True iff every `{var}` from the original is in the output |
95
+ | `tier` / `tier_label` | which pipeline tier ran (1 BO, 2 TextRank, 3 Preserved) |
96
+ | `density` | information density score used for routing |
97
+
98
+ ## Configuration
99
+
100
+ ```python
101
+ from prompt_compress import PromptCompressor, OptimisationConfig
102
+
103
+ compressor = PromptCompressor(
104
+ # Optimiser variants:
105
+ use_informed_prior=False, # seed BO with P3-derived prior
106
+ use_attention_prior=False, # per-prompt attention prior + ISR safety gate
107
+ # Trade-off knob:
108
+ alpha=0.3, # "auto" → 0.3 (validated benchmark default)
109
+ # Tune BO budget:
110
+ optimisation_config=OptimisationConfig(
111
+ n_iterations=20, n_init=5, beta=2.0, random_seed=42,
112
+ ),
113
+ )
114
+ ```
115
+
116
+ `min_similarity` and `on_failure` are per-call (`compressor.compress(prompt, min_similarity=…, on_failure=…)`) so different parts of your app can adopt different safety bars without rebuilding the compressor.
117
+
118
+ ## Benchmark results
119
+
120
+ Matched-subset comparison against LLMLingua on the 38 prompts both systems successfully compressed (see `research/benchmark.py` and `research/evaluate.py` to reproduce):
121
+
122
+ | Metric | Ours | LLMLingua |
123
+ |------------------------------|---------|-----------|
124
+ | Compression ratio | 24.1% | 24.2% |
125
+ | LLM judge score (0–100) | 73.3 | 70.2 |
126
+ | Persona preservation | 100% | 53% |
127
+ | Compression efficiency | 0.179 | 0.155 |
128
+
129
+ `Compression efficiency = compression_ratio × semantic_similarity` (the same quantity as `CompressionResult.compression_efficiency`) — rewards being high on both axes.
130
+
131
+ ## Citation
132
+
133
+ > *EMNLP manuscript in preparation.*
@@ -0,0 +1,92 @@
1
+ # prompt-compress
2
+
3
+ Structural prompt compression for production LLM apps. Where LLMLingua removes individual low-perplexity tokens, this library parses your system prompt into named components (instruction, examples, constraints, style, context), uses Bayesian optimisation to search which components to keep and how aggressively to compress each, scores candidates by semantic similarity to the original, and gates every output through a post-compression validator (persona / placeholder / similarity). Prompts that are already information-dense are detected up front and passed through unchanged.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install prompt-compress
9
+ ```
10
+
11
+ ## Quickstart — production integration
12
+
13
+ ```python
14
+ from prompt_compress import PromptCompressor, CompressionFailedError
15
+
16
+ compressor = PromptCompressor()
17
+
18
+ try:
19
+ result = compressor.compress(
20
+ SYSTEM_PROMPT,
21
+ min_similarity=0.80,
22
+ on_failure='raise',
23
+ )
24
+ SYSTEM_PROMPT = result.compressed_text
25
+ print(f"Saved {result.tokens_saved} tokens per call ({result.compression_ratio:.1%})")
26
+ except CompressionFailedError as e:
27
+ print(f"Compression unsafe, using original: {e}")
28
+ ```
29
+
30
+ `on_failure` accepts `'fallback'` (default — return the original silently with `gate_passed=False`), `'raise'` (raise `CompressionFailedError`), or `'warn'` (log a warning and return the fallback). The library never blocks on user input.
31
+
32
+ ## Inspecting results
33
+
34
+ ```python
35
+ result = compressor.compress(SYSTEM_PROMPT)
36
+
37
+ print(result.summary()) # one-screen terminal summary
38
+ print(result.diff()) # side-by-side original vs compressed
39
+ result.to_dict() # JSON-serialisable, useful for caching/logging
40
+ ```
41
+
42
+ Key properties on `CompressionResult`:
43
+
44
+ | Property | Description |
45
+ |---|---|
46
+ | `compressed_text` | the output you should use |
47
+ | `compression_ratio` | tokens saved / original tokens |
48
+ | `tokens_saved` | absolute token count saved |
49
+ | `semantic_similarity` | cosine sim of original vs compressed (MiniLM) |
50
+ | `compression_efficiency` | `compression_ratio × semantic_similarity` |
51
+ | `safe_to_use` | True iff all validator checks passed |
52
+ | `persona_preserved` | True iff the "You are…" line survived |
53
+ | `placeholders_preserved` | True iff every `{var}` from the original is in the output |
54
+ | `tier` / `tier_label` | which pipeline tier ran (1 BO, 2 TextRank, 3 Preserved) |
55
+ | `density` | information density score used for routing |
56
+
57
+ ## Configuration
58
+
59
+ ```python
60
+ from prompt_compress import PromptCompressor, OptimisationConfig
61
+
62
+ compressor = PromptCompressor(
63
+ # Optimiser variants:
64
+ use_informed_prior=False, # seed BO with P3-derived prior
65
+ use_attention_prior=False, # per-prompt attention prior + ISR safety gate
66
+ # Trade-off knob:
67
+ alpha=0.3, # "auto" → 0.3 (validated benchmark default)
68
+ # Tune BO budget:
69
+ optimisation_config=OptimisationConfig(
70
+ n_iterations=20, n_init=5, beta=2.0, random_seed=42,
71
+ ),
72
+ )
73
+ ```
74
+
75
+ `min_similarity` and `on_failure` are per-call (`compressor.compress(prompt, min_similarity=…, on_failure=…)`) so different parts of your app can adopt different safety bars without rebuilding the compressor.
76
+
77
+ ## Benchmark results
78
+
79
+ Matched-subset comparison against LLMLingua on the 38 prompts both systems successfully compressed (see `research/benchmark.py` and `research/evaluate.py` to reproduce):
80
+
81
+ | Metric | Ours | LLMLingua |
82
+ |------------------------------|---------|-----------|
83
+ | Compression ratio | 24.1% | 24.2% |
84
+ | LLM judge score (0–100) | 73.3 | 70.2 |
85
+ | Persona preservation | 100% | 53% |
86
+ | Compression efficiency | 0.179 | 0.155 |
87
+
88
+ `Compression efficiency = compression_ratio × semantic_similarity` (the same quantity as `CompressionResult.compression_efficiency`) — rewards being high on both axes.
89
+
90
+ ## Citation
91
+
92
+ > *EMNLP manuscript in preparation.*
@@ -0,0 +1,26 @@
1
+ """
2
+ prompt-compress — structural prompt compression with safety gating.
3
+
4
+ Public API:
5
+ PromptCompressor — the compressor
6
+ CompressionResult — what compress() returns
7
+ CompressionFailedError — raised when on_failure='raise' and validation fails
8
+ OptimisationConfig — knobs for the Bayesian optimiser
9
+
10
+ Internal classes (parser, validator, evaluator, etc.) are importable from
11
+ their submodules for advanced use but are intentionally not promoted here.
12
+ """
13
+
14
+ from .compressor import CompressionFailedError, PromptCompressor
15
+ from .optimiser import OptimisationConfig
16
+ from .result import CompressionResult
17
+
18
# Package version; kept in step with the sdist metadata (Version: 0.1.0).
__version__ = "0.1.0"

# Public surface re-exported by ``from prompt_compress import *``.
# Internal helpers (parser, validators, evaluator, ...) stay importable from
# their own submodules but are deliberately left out of this list.
__all__ = [
    "PromptCompressor",
    "CompressionFailedError",
    "OptimisationConfig",
    "CompressionResult",
    "__version__",
]
@@ -0,0 +1,27 @@
1
+ """
2
+ Shared persona-detection patterns.
3
+
4
+ Lives in its own module so both `parser.py` and `validators.py` can import
5
+ from it without introducing a circular dependency between them.
6
+ """
7
+
8
+ import re
9
+
10
# Regex fragments that mark a role/persona opener. They are written in lower
# case and matched against the lower-cased first sentence of a prompt; each is
# anchored at the start of that sentence and must end on a word boundary.
PERSONA_PATTERNS = (
    'you are',
    'act as',
    'i want you to act',
    'pretend to be',
    'pretend you are',
    'imagine you are',
    "let's role[- ]?play as",
    'you will (be|act as|play)',
    'assume the role',
)


def persona_present(text: str) -> bool:
    """
    Return True if *text* opens with a recognised persona pattern.

    Only the first sentence (split on '.', '!' or '?') of the first non-blank
    line is inspected, so a persona declared later in the prompt is not
    detected.

    Args:
        text: Raw prompt text.

    Returns:
        True when that opening sentence starts with any entry of
        PERSONA_PATTERNS followed by a word boundary, else False.
    """
    opening_line = text.lstrip().partition('\n')[0].strip()
    opening_sentence = re.split(r'[.!?]', opening_line, maxsplit=1)[0]
    opening_sentence = opening_sentence.strip().lower()
    for pattern in PERSONA_PATTERNS:
        # \b stops e.g. "you are" from matching the start of "you aren't".
        if re.match(rf'^{pattern}\b', opening_sentence):
            return True
    return False
@@ -0,0 +1,201 @@
1
+ """
2
+ Attention-informed Bayesian optimiser.
3
+
4
+ Subclass of InformedBayesianOptimiser that swaps the static P3-JSON prior
5
+ for a per-prompt prior generated from component-to-component attention,
6
+ and adds an Information Sufficiency Ratio (ISR) pre-compression gate so
7
+ prompts already near their Minimum Description Length are not compressed.
8
+ """
9
+
10
+ import logging
11
+ from typing import Dict, Optional
12
+
13
+ from .attention_priors import AttentionPriorGenerator
14
+ from .encoders import PromptStructure
15
+ from .information_sufficiency import ISRGate
16
+ from .informed_optimiser import InformedBayesianOptimiser
17
+ from .optimiser import OptimisationConfig, OptimisationResult
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class AttentionInformedOptimiser(InformedBayesianOptimiser):
    """
    Use a fresh attention-derived prior for each prompt instead of a static
    dataset-wide prior, with an ISR gate that can short-circuit BO when the
    prompt is information-dense (near MDL) or flag aggressive compression
    when the prompt is highly redundant.
    """

    AGGRESSIVE_BETA_MULTIPLIER = 1.2  # widen UCB exploration when ISR is low
    AUTO_ALPHA = 0.3  # default for alpha="auto"; matches the
    # validated benchmark (see SemanticEvaluator.AUTO_ALPHA)

    def __init__(
        self,
        encoder,
        evaluator,
        config: Optional[OptimisationConfig] = None,
        components: Optional[Dict] = None,
        prior_generator: Optional[AttentionPriorGenerator] = None,
        prompt_text: Optional[str] = None,
        use_isr_gate: bool = True,
        isr_high_threshold: Optional[float] = 0.85,
        isr_low_threshold: Optional[float] = 0.40,
        alpha=AUTO_ALPHA,
    ):
        """
        Args:
            encoder: Passed through to BayesianPromptOptimiser.__init__.
            evaluator: Passed through to BayesianPromptOptimiser.__init__.
            config: Optimiser settings. Defaults to
                OptimisationConfig(n_iterations=20, n_init=5, beta=2.0,
                random_seed=42).
            components: Parsed components dict (PromptParser output) for
                the current prompt. Required to generate the prior.
            prior_generator: Optional shared AttentionPriorGenerator instance.
            prompt_text: Raw prompt text. Required for the ISR gate to
                fire; if None, the gate is silently skipped.
            use_isr_gate: Master switch for the ISR pre-check. The gate is
                only enabled when this AND config.enable_isr_gate are true.
            isr_high_threshold: ISR above this ⇒ skip compression entirely.
                Pass None explicitly to use config.isr_high instead.
            isr_low_threshold: ISR below this ⇒ enable aggressive mode (the
                UCB beta is multiplied by AGGRESSIVE_BETA_MULTIPLIER
                to favour exploration of compressive structures).
                Pass None explicitly to use config.isr_low instead.
            alpha: Quality/compression trade-off. Larger alpha
                penalises length more heavily. Accepts "auto"
                (resolves to AUTO_ALPHA=0.3) or any float.
                The objective lives on SemanticEvaluator —
                this kwarg is stored for logging and is
                surfaced on the OptimisationResult.

        Raises:
            ValueError: if components is None, or alpha is a string other
                than "auto".
        """
        self.alpha = self._resolve_alpha(alpha)
        if components is None:
            raise ValueError(
                "AttentionInformedOptimiser requires the parsed components dict"
            )

        self._components = components
        self._prior_generator = prior_generator or AttentionPriorGenerator()
        self._prompt_text = prompt_text

        if config is None:
            config = OptimisationConfig(
                n_iterations=20, n_init=5, beta=2.0, random_seed=42
            )

        # Honour both the explicit kwargs and the OptimisationConfig fields.
        # Explicit kwargs win so callers can override per-instance.
        # NOTE(review): because the threshold defaults are non-None, the
        # config.isr_high / config.isr_low fallbacks are only reached when a
        # caller passes None explicitly.
        self.use_isr_gate = use_isr_gate and config.enable_isr_gate
        self.isr_gate = ISRGate(
            high_threshold=isr_high_threshold if isr_high_threshold is not None else config.isr_high,
            low_threshold=isr_low_threshold if isr_low_threshold is not None else config.isr_low,
        )

        # Skip the parent's JSON-loading constructor by initialising the
        # grandparent (BayesianPromptOptimiser) directly.
        from .optimiser import BayesianPromptOptimiser
        BayesianPromptOptimiser.__init__(self, encoder, evaluator, config)

        self.prior = self._prior_generator.generate(components)
        logger.info(
            "Loaded attention-informed prior "
            "(mean_attention=%s)",
            {k: round(v, 2) for k, v in self.prior['mean_attention'].items()},
        )

    @classmethod
    def _resolve_alpha(cls, alpha) -> float:
        """Map "auto" to AUTO_ALPHA; coerce anything else to float."""
        if isinstance(alpha, str):
            if alpha == "auto":
                return cls.AUTO_ALPHA
            raise ValueError(f"alpha must be 'auto' or numeric, got {alpha!r}")
        return float(alpha)

    def _isr_gate_active(self) -> bool:
        """True when the ISR gate is switched on and there is text to score."""
        return self.use_isr_gate and self._prompt_text is not None

    def check_isr(self) -> tuple[bool, float, str]:
        """
        Run the ISR gate against the stored prompt text. Returns the same
        triple as ISRGate.should_compress.

        Returns (True, 0.0, "isr disabled") if the gate is off or no prompt
        text was supplied. The 0.0 is a placeholder, not a measured ISR.
        """
        if not self._isr_gate_active():
            return True, 0.0, "isr disabled"
        return self.isr_gate.should_compress(self._prompt_text)

    def optimise(self) -> OptimisationResult:
        """
        Run ISR gate first, then dispatch to the parent BO loop.

        Behaviour:
          - ISR > high_threshold: return a `skipped=True` result; the
            orchestrator (PromptCompressor) treats this as "return original".
          - ISR < low_threshold: temporarily inflate UCB beta to favour
            exploration of compressive structures.
          - Otherwise (or when the gate is inactive): run BO unchanged.
        """
        should_compress, isr, reason = self.check_isr()

        if not should_compress:
            logger.info(
                "ISR=%.2f exceeds threshold. Prompt already near Minimum "
                "Description Length. Preserving original.",
                isr,
            )
            return self._skipped_result(isr, reason)

        gate_active = self._isr_gate_active()
        if gate_active:
            band = "low" if isr < self.isr_gate.low_threshold else (
                "high" if isr > self.isr_gate.high_threshold else "moderate"
            )
            logger.info("ISR=%.2f (%s). %s.", isr, band, reason.capitalize())

        # Aggressive mode: temporarily widen UCB exploration. Restore beta
        # after the run so the optimiser stays reusable. Only a real ISR
        # measurement may trigger this — when the gate is inactive,
        # check_isr() reports a placeholder 0.0 that would otherwise always
        # fall below the low threshold.
        original_beta = self.config.beta
        if gate_active and isr < self.isr_gate.low_threshold:
            self.config.beta = original_beta * self.AGGRESSIVE_BETA_MULTIPLIER
            logger.info(
                "ISR=%.2f (low). Enabling aggressive compression "
                "(beta %.2f -> %.2f).",
                isr,
                original_beta,
                self.config.beta,
            )

        try:
            result = super().optimise()
        finally:
            self.config.beta = original_beta

        result.skipped = False
        result.isr_score = isr if gate_active else None
        result.isr_reason = reason if gate_active else None
        result.alpha_used = self.alpha
        return result

    def _skipped_result(self, isr: float, reason: str) -> OptimisationResult:
        """
        Build a sentinel OptimisationResult signalling "do not compress".

        The structure is set to "keep everything" so that if a caller ignores
        `skipped` and materialises this structure anyway, the output is at
        worst the full original prompt — not a corrupted compression.
        """
        identity_structure = PromptStructure(
            has_instruction=True,
            has_examples=True,
            has_constraints=True,
            has_style=True,
            has_context=True,
            num_examples=1.0,
            instruction_length=1.0,
            total_tokens=1.0,
            component_ordering=[1, 2, 3, 4, 5],
        )
        return OptimisationResult(
            best_structure=identity_structure,
            best_score=1.0,
            all_scores=[1.0],
            all_structures=[identity_structure],
            total_evaluations=0,
            skipped=True,
            isr_score=isr,
            isr_reason=reason,
            alpha_used=self.alpha,
        )
@@ -0,0 +1,149 @@
1
+ """
2
+ Per-prompt attention-informed priors for the Bayesian optimiser.
3
+
4
+ For each component in the parsed prompt, compute the mean cosine similarity
5
+ ("attention") to every other component using sentence-transformers embeddings.
6
+ High mean attention (>0.7) means the component is tightly coupled to the rest
7
+ of the prompt — keep it. Low mean attention (<0.4) means it is independent —
8
+ safe to drop.
9
+
10
+ The output is shaped to match `InformedBayesianOptimiser._load_prior()` so the
11
+ existing prior pipeline can consume it without changes.
12
+ """
13
+
14
+ import numpy as np
15
+ from typing import Dict, Optional
16
+ from sentence_transformers import SentenceTransformer
17
+
18
COMPONENT_NAMES = ['instruction', 'examples', 'constraints', 'style', 'context']


class AttentionPriorGenerator:
    """
    Generate a prompt-specific prior from component-to-component attention.

    "Attention" is the cosine similarity between sentence-transformer
    embeddings of whole components. The embedding model is loaded on first
    construction and shared by every instance through a class attribute.
    """

    EMBEDDING_MODEL = 'all-MiniLM-L6-v2'

    HIGH_THRESHOLD = 0.7
    LOW_THRESHOLD = 0.4

    # Inclusion probabilities mapped from coupling strength
    HIGH_INCLUSION = 0.85
    MID_INCLUSION = 0.5
    LOW_INCLUSION = 0.25

    _shared_model: Optional[SentenceTransformer] = None

    def __init__(self):
        # Load the model once per process and reuse it for later instances.
        if AttentionPriorGenerator._shared_model is None:
            AttentionPriorGenerator._shared_model = SentenceTransformer(self.EMBEDDING_MODEL)
        self.model = AttentionPriorGenerator._shared_model

    def generate(self, components: Dict) -> Dict:
        """
        Build a prior dict compatible with InformedBayesianOptimiser.

        Args:
            components: Dict[str, list[str]] from PromptParser.parse(). May
                include a '__protected__' key, which is ignored.

        Returns:
            Dict with keys 'mean', 'variance', 'source', plus an extra
            'attention_matrix' (5x5, rows/cols in COMPONENT_NAMES order) and
            'mean_attention' (per-component) for debugging/inspection.

            Attention only exists between pairs of components, so when fewer
            than two components are present every inclusion probability keeps
            its default and every mean_attention entry is 0.0.
        """
        present, embeddings = self._embed_components(components)

        # Default inclusion probabilities (used when there is no attention
        # signal for a component: it is empty, or it is the only one present)
        prior_mean = {
            'has_instruction': self.HIGH_INCLUSION,  # always favour instruction
            'has_examples': self.MID_INCLUSION,
            'has_constraints': self.MID_INCLUSION,
            'has_style': self.LOW_INCLUSION,
            'has_context': self.LOW_INCLUSION,
            'instruction_length': 0.7,
            'num_examples': 0.4,
        }

        n_names = len(COMPONENT_NAMES)
        attention_matrix = np.zeros((n_names, n_names))
        mean_attention = {name: 0.0 for name in COMPONENT_NAMES}

        if len(present) >= 2:
            sim = self._cosine_matrix(embeddings)
            for i, name_i in enumerate(present):
                row = COMPONENT_NAMES.index(name_i)
                for j, name_j in enumerate(present):
                    col = COMPONENT_NAMES.index(name_j)
                    attention_matrix[row, col] = sim[i, j]
                # Mean attention excludes the diagonal (self-similarity = 1.0);
                # with >= 2 components present this list is never empty.
                others = [sim[i, j] for j in range(len(present)) if j != i]
                mean_attention[name_i] = float(np.mean(others))

            # Map coupling strength to inclusion probability. This must stay
            # inside the >= 2 branch: a lone component's mean_attention is a
            # 0.0 placeholder, not evidence that it is "independent".
            for name in present:
                ma = mean_attention[name]
                if ma >= self.HIGH_THRESHOLD:
                    prior_mean[f'has_{name}'] = self.HIGH_INCLUSION
                elif ma < self.LOW_THRESHOLD:
                    prior_mean[f'has_{name}'] = self.LOW_INCLUSION
                else:
                    prior_mean[f'has_{name}'] = self.MID_INCLUSION

        # Variance: shrink when we have strong signal (many present components),
        # widen when we have little to learn from.
        prior_variance = 0.15 if len(present) >= 3 else 0.25

        return {
            'mean': prior_mean,
            'variance': prior_variance,
            'source': 'per-prompt attention',
            'attention_matrix': attention_matrix.tolist(),
            'mean_attention': mean_attention,
        }

    def _embed_components(self, components: Dict):
        """
        Embed each present component (concatenated sentences) and return the
        ordered list of names and their embedding matrix.

        Names come back in COMPONENT_NAMES order; components that are missing,
        empty, or whitespace-only are skipped.
        """
        present_names: list[str] = []
        texts: list[str] = []
        for name in COMPONENT_NAMES:
            sentences = components.get(name, [])
            if not sentences:
                continue
            joined = ' '.join(sentences) if isinstance(sentences, list) else str(sentences)
            if not joined.strip():
                continue
            present_names.append(name)
            texts.append(joined)
        if not texts:
            # 384 is the all-MiniLM-L6-v2 embedding width; the empty matrix is
            # never read because generate() needs >= 2 present components.
            return [], np.zeros((0, 384))
        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        return present_names, embeddings

    @staticmethod
    def _cosine_matrix(embeddings: np.ndarray) -> np.ndarray:
        """Pairwise cosine similarity of the rows of *embeddings*.

        Zero-norm rows are divided by 1.0 instead of 0.0, so they yield
        similarity 0 rather than NaN.
        """
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms = np.where(norms == 0.0, 1.0, norms)
        normalised = embeddings / norms
        return normalised @ normalised.T
131
+
132
+
133
if __name__ == "__main__":
    # Smoke test: build a prior for a small hand-written prompt and print it.
    demo_components = {
        'instruction': ['You are a SQL tutor. Explain query plans.'],
        'examples': ['Example: SELECT * FROM users WHERE id = 1.'],
        'constraints': ['Always include the cost estimate.', 'Never invent table names.'],
        'style': ['Use a formal tone.'],
        'context': [],
    }
    demo_prior = AttentionPriorGenerator().generate(demo_components)

    report_sections = (
        ('mean inclusion probabilities:', demo_prior['mean']),
        ('mean attention per component:', demo_prior['mean_attention']),
    )
    for section_index, (heading, values) in enumerate(report_sections):
        if section_index:
            print()  # blank line between sections
        print(heading)
        for key, value in values.items():
            print(f' {key}: {value:.3f}')