promptforest 0.1.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
promptforest/config.py CHANGED
@@ -13,10 +13,9 @@ MODELS_DIR = USER_DATA_DIR / "models"
  DEFAULT_CONFIG = {
      "models": [
          {"name": "llama_guard", "path": "llama_guard", "type": "hf", "enabled": True},
-         {"name": "protectai", "path": "protectai_deberta", "type": "hf", "enabled": True},
-         {"name": "deepset", "path": "deepset_deberta", "type": "hf", "enabled": True},
-         {"name": "katanemo", "path": "katanemo_arch", "type": "hf", "enabled": True},
-         {"name": "xgboost", "type": "xgboost", "enabled": True}
+         {"name": "protectai", "path": "protectai", "type": "hf", "enabled": True},
+         {"name": "vijil", "path": "vijil_dome", "type": "hf", "enabled": True},
+         {"name": "xgboost", "type": "xgboost", "enabled": True, "threshold": 0.10}
      ],
      "settings": {
          "device": "auto", # Options: auto, cuda, mps, cpu
promptforest/download.py CHANGED
@@ -13,9 +13,8 @@ from .llama_guard_86m_downloader import download_llama_guard

  # Configuration
  MODELS = {
-     "protectai_deberta": "protectai/deberta-v3-base-prompt-injection",
-     "deepset_deberta": "deepset/deberta-v3-base-injection",
-     "katanemo_arch": "katanemo/Arch-Guard"
+     "protectai": "protectai/deberta-v3-base-prompt-injection-v2",
+     "vijil_dome": "vijil/vijil_dome_prompt_injection_detection"
  }

  EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
@@ -31,15 +30,20 @@ def _download_hf_model(name, model_id):
      try:
          if save_path.exists():
              return
+
+         # Special handling for Vijil (ModernBERT tokenizer issue)
+         tokenizer_id = model_id
+         if "vijil" in name or "vijil" in model_id:
+             tokenizer_id = "answerdotai/ModernBERT-base"

-         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
          model = AutoModelForSequenceClassification.from_pretrained(model_id)

          tokenizer.save_pretrained(save_path)
          model.save_pretrained(save_path)

      except Exception as e:
-         print("Failed to download a model!")
+         print(f"Failed to download {model_id}: {e}")

  def _download_sentence_transformer():
      """Download and save the SentenceTransformer model."""
promptforest/lib.py CHANGED
@@ -68,7 +68,8 @@ class HFModel(ModelInference):
          id2label = self.model.config.id2label
          found = False
          for idx, label in id2label.items():
-             if any(kw in label.lower() for kw in MALICIOUS_KEYWORDS):
+             # Check if label is string before calling lower()
+             if isinstance(label, str) and any(kw in label.lower() for kw in MALICIOUS_KEYWORDS):
                  self.malicious_idx = idx
                  found = True
                  break
@@ -99,9 +100,11 @@ class HFModel(ModelInference):


  class XGBoostModel(ModelInference):
-     def __init__(self, settings):
+     def __init__(self, settings, config=None):
          self.name = "xgboost_custom"
          self.settings = settings
+         self.config = config or {}
+         self.threshold = self.config.get('threshold', 0.5)
          self.model = None
          self.embedder = None

@@ -140,7 +143,18 @@ class XGBoostModel(ModelInference):
          try:
              emb = self.embedder.encode([prompt])
              prob = self.model.predict_proba(emb)[0][1]
-             return float(prob)
+             prob = float(prob)
+
+             # Apply threshold adjustment if a custom threshold is set
+             if self.threshold != 0.5:
+                 if prob < self.threshold:
+                     # Map [0, threshold) -> [0, 0.5)
+                     prob = 0.5 * (prob / self.threshold)
+                 else:
+                     # Map [threshold, 1.0] -> [0.5, 1.0]
+                     prob = 0.5 + 0.5 * (prob - self.threshold) / (1.0 - self.threshold)
+
+             return prob
          except Exception:
              return 0.0
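A quick worked check of the remapping above with the shipped threshold of 0.10: raw scores below the threshold are squeezed into [0, 0.5) and scores at or above it are stretched into [0.5, 1.0], so 0.5 stays the effective decision boundary downstream. A standalone sketch, not part of the package:

```python
# Standalone sketch of the piecewise-linear remapping introduced above.
def remap(prob, threshold=0.10):
    if threshold == 0.5:  # default threshold: scores pass through unchanged
        return prob
    if prob < threshold:  # [0, threshold) -> [0, 0.5)
        return 0.5 * (prob / threshold)
    return 0.5 + 0.5 * (prob - threshold) / (1.0 - threshold)  # [threshold, 1.0] -> [0.5, 1.0]

print(remap(0.05))  # 0.25 -> still reads as benign
print(remap(0.10))  # 0.50 -> sits exactly on the boundary
print(remap(0.55))  # 0.75 -> a raw 0.55 now counts as clearly malicious
```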
 
@@ -151,14 +165,14 @@ class EnsembleGuard:
          Initialize the EnsembleGuard.
          :param config: Dictionary containing configuration. If None, loads default/user config.
          """
-         # Check if models need to be downloaded
-         self._ensure_models_available()
-
          if config is None:
              self.config = load_config()
          else:
              self.config = config

+         # Check if models need to be downloaded
+         self._ensure_models_available()
+
          self.models = []
          self._init_models()
          self.device_used = get_device(self.config['settings'].get('device', 'auto'))
@@ -167,14 +181,27 @@ class EnsembleGuard:
          """Check if models are available, download if needed."""
          from .config import MODELS_DIR

-         # Check if models directory exists and has content
-         if MODELS_DIR.exists() and any(MODELS_DIR.iterdir()):
-             return
-
-         # Models not found, download them
-         print("Models not found. Downloading...")
-         from .download import download_all
-         download_all()
+         missing = False
+         if not MODELS_DIR.exists():
+             missing = True
+         else:
+             # Check for each enabled HF model
+             for model_cfg in self.config.get('models', []):
+                 if model_cfg.get('type') == 'hf' and model_cfg.get('enabled', True):
+                     path = MODELS_DIR / model_cfg['path']
+                     if not path.exists():
+                         missing = True
+                         break
+
+             # Check for SentenceTransformer (needed for XGBoost)
+             st_path = MODELS_DIR / 'sentence_transformer'
+             if not st_path.exists():
+                 missing = True
+
+         if missing:
+             print("Some models not found. Downloading required models...")
+             from .download import download_all
+             download_all()

      def _init_models(self):
          settings = self.config.get('settings', {})
@@ -189,7 +216,7 @@ class EnsembleGuard:
              if model_type == 'hf':
                  self.models.append(HFModel(model_cfg['name'], model_cfg['path'], settings))
              elif model_type == 'xgboost':
-                 self.models.append(XGBoostModel(settings))
+                 self.models.append(XGBoostModel(settings, model_cfg))
              else:
                  print(f"Unknown model type: {model_type}")

promptforest/llama_guard_86m_downloader.py CHANGED
@@ -1,67 +1,58 @@
  """
- Script to download Llama Guard 2 86M from a custom GitHub repository.
- Handles split safetensor files and combines them locally.
+ Script to download Llama Guard 2 86M from custom GitHub releases.
+ Downloads files in parallel for speed.
  """

  import os
- import shutil
- import subprocess
- import tempfile
+ import requests
  from pathlib import Path
+ from concurrent.futures import ThreadPoolExecutor
  from .config import MODELS_DIR

- LLAMA_GUARD_REPO = "https://github.com/appleroll-research/promptforest-model-ensemble.git"
+ BASE_URL = "https://github.com/appleroll-research/promptforest-model-ensemble/releases/download/v0.5.0-alpha"
+ FILES_TO_DOWNLOAD = [
+     "config.json",
+     "model.safetensors",
+     "special_tokens_map.json",
+     "tokenizer.json",
+     "tokenizer_config.json"
+ ]

- def _download_llama_guard():
-     """Download Llama Guard from custom repository and combine split files."""
-     save_path = MODELS_DIR / "llama_guard"
-
+ def _download_file(url, save_path):
+     """Download a single file."""
      if save_path.exists():
          return
-
+
      try:
-         # Use temporary directory for cloning
-         with tempfile.TemporaryDirectory() as temp_dir:
-             temp_path = Path(temp_dir)
-
-             # Clone repository silently
-             subprocess.run(
-                 ["git", "clone", "--depth", "1", LLAMA_GUARD_REPO, str(temp_path)],
-                 stdout=subprocess.DEVNULL,
-                 stderr=subprocess.DEVNULL,
-                 check=True
-             )
-
-             # Get the llama_guard folder from the cloned repo
-             source_dir = temp_path / "llama_guard"
-             if not source_dir.exists():
-                 raise FileNotFoundError(f"llama_guard folder not found in repository")
-
-             # Copy to models directory
-             save_path.parent.mkdir(parents=True, exist_ok=True)
-             shutil.copytree(source_dir, save_path)
-
-             # Combine split safetensor files
-             model_file = save_path / "model.safetensors"
-             if not model_file.exists():
-                 # Find and combine c_* files
-                 split_files = sorted(save_path.glob("c_*"))
-                 if split_files:
-                     with open(model_file, 'wb') as outfile:
-                         for split_file in split_files:
-                             with open(split_file, 'rb') as infile:
-                                 outfile.write(infile.read())
-
-                     # Delete split files
-                     for split_file in split_files:
-                         split_file.unlink()
+         response = requests.get(url, stream=True)
+         response.raise_for_status()

+         with open(save_path, 'wb') as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
      except Exception as e:
-         # Clean up on failure
+         print(f"Failed to download {url}: {e}")
+         # Clean up partial file
          if save_path.exists():
-             shutil.rmtree(save_path)
-         raise Exception(f"Failed to download Llama Guard: {e}")
+             os.remove(save_path)

  def download_llama_guard():
-     """Public interface for downloading Llama Guard."""
-     _download_llama_guard()
+     """Download Llama Guard files in parallel."""
+     save_dir = MODELS_DIR / "llama_guard"
+
+     # Check if all files exist
+     if save_dir.exists() and all((save_dir / f).exists() for f in FILES_TO_DOWNLOAD):
+         return
+
+     save_dir.mkdir(parents=True, exist_ok=True)
+
+     with ThreadPoolExecutor(max_workers=5) as executor:
+         futures = []
+         for filename in FILES_TO_DOWNLOAD:
+             url = f"{BASE_URL}/{filename}"
+             save_path = save_dir / filename
+             futures.append(executor.submit(_download_file, url, save_path))
+
+         for future in futures:
+             future.result()
+
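For reference, `download.py` above imports `download_llama_guard` from this module, so the rewritten parallel downloader can also be driven directly. A minimal usage sketch; the import path follows the package layout listed in the RECORD below:

```python
# Usage sketch: invoke the parallel Llama Guard downloader directly.
# Safe to call repeatedly: it returns early once every file in FILES_TO_DOWNLOAD
# already exists under MODELS_DIR / "llama_guard".
from promptforest.llama_guard_86m_downloader import download_llama_guard

download_llama_guard()
```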
promptforest/xgboost/xgb_model.pkl CHANGED
Binary file (contents not shown)
promptforest-0.5.0.dist-info/METADATA ADDED
@@ -0,0 +1,99 @@
+ Metadata-Version: 2.4
+ Name: promptforest
+ Version: 0.5.0
+ Summary: Ensemble Prompt Injection Detection
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE.txt
+ License-File: NOTICE.md
+ Requires-Dist: numpy
+ Requires-Dist: pandas
+ Requires-Dist: torch
+ Requires-Dist: transformers
+ Requires-Dist: sentence-transformers
+ Requires-Dist: xgboost
+ Requires-Dist: scikit-learn
+ Requires-Dist: pyyaml
+ Requires-Dist: joblib
+ Requires-Dist: protobuf
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # PromptForest - Fast and Reliable Injection Detector Ensemble
+
+ PromptForest is a prompt injection detector ensemble focused on real-world latency and reliability.
+
+ We rely on an ensemble of small, accurate prompt-injection detection models and combine their predictions through a voting system.
+
+ By comparing predictions across multiple models, the system can flag prompts where the models disagree, helping to reduce the risk of false negatives.
+
+ This discrepancy score enables downstream workflows such as:
+ - Human-in-the-loop review for high-risk or ambiguous prompts
+ - Adaptive throttling or alerting in production systems
+ - Continuous monitoring and model improvement
+
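A toy illustration of the voting and discrepancy idea described above. This is not the package's aggregation code (that logic is not part of this diff); the 0.5 vote boundary and the max-minus-min discrepancy measure are assumptions for illustration only:

```python
# Toy sketch only: majority vote over per-model scores plus a disagreement signal.
def aggregate(scores):
    votes = [s >= 0.5 for s in scores]       # each model votes malicious / benign
    verdict = sum(votes) > len(votes) / 2    # simple majority vote
    discrepancy = max(scores) - min(scores)  # how strongly the models disagree
    return {"malicious": verdict, "discrepancy": round(discrepancy, 2)}

print(aggregate([0.92, 0.88, 0.15, 0.75]))  # models disagree -> high discrepancy, flag for review
```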
+ ## Statistics
+ **E2E Request Latency** \
+ Average Case: 100ms \
+ Worst Case: 200ms
+
+ PromptForest was evaluated against the SOTA Qualifire Sentinel model (v2).
+
+ | Metric                           | PromptForest | Sentinel v2 |
+ | -------------------------------- | ------------ | ----------- |
+ | Accuracy                         | 0.802        | 0.982       |
+ | Avg Confidence on Wrong Answers  | 0.643        | 0.858       |
+ | Expected Calibration Error (ECE) | 0.060        | 0.202       |
+ | Approximate Model Size           | ~250M params | 600M params |
+
+
+ ### Key Insights
+
+ - Calibrated uncertainty: PromptForest is less confident on wrong predictions than Sentinel, resulting in a much lower ECE.
+
+ - Parameter efficiency: Achieves competitive reliability with <50% of the parameters.
+
+ - Interpretability: Confidence scores can be used to flag uncertain predictions for human review.
+
+ Interpretation:
+ While Sentinel has higher raw accuracy, PromptForest provides better-calibrated confidence. For systems where overconfidence on wrong answers is risky, PromptForest can reduce the chance of critical errors despite being smaller and faster.
+
+ Using Sentinel v2 as a baseline, and given that other models perform worse than Sentinel in published benchmarks, PromptForest is expected to offer more reliable and better-calibrated predictions than most alternatives.
+
+
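ECE in the table follows the standard binned definition: predictions are grouped by confidence, and ECE is the sample-weighted average gap between each bin's accuracy and its mean confidence. A minimal sketch of that computation; the evaluation script actually used is not part of this diff:

```python
import numpy as np

# Standard binned Expected Calibration Error (illustrative, not the project's eval script).
def expected_calibration_error(confidences, correct, n_bins=10):
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap  # weight by the fraction of samples in the bin
    return ece

print(expected_calibration_error([0.95, 0.60, 0.80, 0.55], [1, 0, 1, 1]))
```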
+ ## Supported Models
+
+ | Provider      | Model Name |
+ | ------------- | ----------------------------------------- |
+ | **Meta**      | [Llama Prompt Guard 86M](https://huggingface.co/meta-llama/Prompt-Guard-86M) (Built with Llama) |
+ | **ProtectAI** | [DebertaV3 Prompt Injection Finetune](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2) |
+ | **Vijil**     | [Vijil Dome Prompt Injection Detection](https://huggingface.co/vijil/vijil_dome_prompt_injection_detection) |
+ | **Appleroll** | [PromptForest-XGB](appleroll/promptforest-xgb) |
+
+ ## Quick Start
+ To use PromptForest, install the pip package and serve it on a port of your choice. On first run it automatically downloads the default model ensemble.
+
+ Gated models are downloaded through our own [ensemble GitHub repository](https://github.com/appleroll-research/promptforest-model-ensemble) and are released in accordance with their terms and conditions.
+
+ ```bash
+ pip install promptforest
+ promptforest serve --port 8000
+ ```
+
+ ## Disclaimer & Limitations
+
+ PromptForest uses a combination of open-source and third-party machine learning models, including models and weights released by other organizations under their respective licenses (e.g. Meta Llama Prompt Guard and other public prompt-injection detectors).
+ All third-party components remain the property of their original authors and are used in accordance with their licenses.
+
+ PromptForest is not a standalone security solution and should not be relied upon as the sole defense mechanism for protecting production systems. Prompt injection detection is an inherently adversarial and evolving problem, and no automated system can guarantee complete protection.
+
+ This project has not yet been extensively validated against real-world, large-scale, or targeted prompt-injection attacks. Results may vary depending on deployment context, model configuration, and threat model.
+
+ PromptForest is intended to be used as one layer in a defense-in-depth strategy, alongside input validation, output filtering, access control, sandboxing, monitoring, and human oversight.
+
+ ## License
+ This project is licensed under Apache 2.0. Third-party models and weights are redistributed under their original licenses (see the THIRD_PARTY_LICENSES folder for details). Users must comply with these licenses.
promptforest-0.5.0.dist-info/RECORD ADDED
@@ -0,0 +1,15 @@
+ promptforest/__init__.py,sha256=cE1cQyRL4vUzseCwLYbI5wrZuZ-NRMVXIjAgwTLwIEs,54
+ promptforest/cli.py,sha256=LKsnbEQNQ9pP_Ww24Ql2Tb_uomO-StqHnk-IHONSKTM,1856
+ promptforest/config.py,sha256=c_7GX7nh_1Aa-QU7SOZlthPNGXSoh2KvYOk7txJeQh4,3284
+ promptforest/download.py,sha256=6TQvo2qd3tUUxJU6MMsFMgOciHP5HNDNEo3UTOeYI34,2637
+ promptforest/lib.py,sha256=WEuEhNNlRQAerLyEIbTHdi15qdXUMuiQOhfsvaftj4M,9254
+ promptforest/llama_guard_86m_downloader.py,sha256=0B2ttwLWHki0yLEoJG3BwyFE1oqJFY0M2mLEtmMWmPk,1720
+ promptforest/server.py,sha256=uF4Yj7yR_2vEx_7nQabGHGGw-6GWnT0iBZx3UPQK634,2905
+ promptforest/xgboost/xgb_model.pkl,sha256=kSG2r-6TGfhNJfzwklLQOSgG2z610Z5BXxtgQdXE8Vk,2116991
+ promptforest-0.5.0.dist-info/licenses/LICENSE.txt,sha256=GgVl4CdplCpCEssTcrmIRbz52zQc0fdcSETZp34uBF4,11349
+ promptforest-0.5.0.dist-info/licenses/NOTICE.md,sha256=XGjuV5VAWBinW6Jzu7-9h0Ph3xwCNzcJdbMH_EgU_g4,356
+ promptforest-0.5.0.dist-info/METADATA,sha256=fEgp4u7q-P74Zo3eF0gnEjVSFMuIc9z9g-1AoAKPAZs,5002
+ promptforest-0.5.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ promptforest-0.5.0.dist-info/entry_points.txt,sha256=sVcjABvpA7P2fXca2KMZSYf0PNfDgLt1NHlYFMPO_eE,55
+ promptforest-0.5.0.dist-info/top_level.txt,sha256=NxasbbadJaf8w9zaRXo5KOdBqNA1oDe-2X7e6zdz3k0,13
+ promptforest-0.5.0.dist-info/RECORD,,
promptforest-0.1.0.dist-info/WHEEL → promptforest-0.5.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.10.1)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

promptforest-0.1.0.dist-info/METADATA DELETED
@@ -1,21 +0,0 @@
- Metadata-Version: 2.4
- Name: promptforest
- Version: 0.1.0
- Summary: Ensemble Prompt Injection Detection
- Requires-Python: >=3.8
- License-File: LICENSE.txt
- License-File: NOTICE.md
- Requires-Dist: numpy
- Requires-Dist: pandas
- Requires-Dist: torch
- Requires-Dist: transformers
- Requires-Dist: sentence-transformers
- Requires-Dist: xgboost
- Requires-Dist: scikit-learn
- Requires-Dist: pyyaml
- Requires-Dist: joblib
- Requires-Dist: protobuf
- Dynamic: license-file
- Dynamic: requires-dist
- Dynamic: requires-python
- Dynamic: summary
promptforest-0.1.0.dist-info/RECORD DELETED
@@ -1,15 +0,0 @@
- promptforest/__init__.py,sha256=cE1cQyRL4vUzseCwLYbI5wrZuZ-NRMVXIjAgwTLwIEs,54
- promptforest/cli.py,sha256=LKsnbEQNQ9pP_Ww24Ql2Tb_uomO-StqHnk-IHONSKTM,1856
- promptforest/config.py,sha256=bOFHlK0kI7c3fzccZrcjccNUfZJPzvLtKEAZ_loLvzE,3366
- promptforest/download.py,sha256=3Ss1BX6kQatfhif1cbErUekPlSA2RCqtiatUzGi72zo,2454
- promptforest/lib.py,sha256=LT8A1_veV9tB2DyrZ0JEOBW4EWEs9El5xOxF0zNHOAc,8042
- promptforest/llama_guard_86m_downloader.py,sha256=ibFeeuDgMBVe-8aD0zl23xJKOPdKyw-4Bsf0iZJih4s,2412
- promptforest/server.py,sha256=uF4Yj7yR_2vEx_7nQabGHGGw-6GWnT0iBZx3UPQK634,2905
- promptforest/xgboost/xgb_model.pkl,sha256=97Y_Dfu8PwubkplRXJdNEuAj9te1v-nEJlXfPpEZWdM,748772
- promptforest-0.1.0.dist-info/licenses/LICENSE.txt,sha256=GgVl4CdplCpCEssTcrmIRbz52zQc0fdcSETZp34uBF4,11349
- promptforest-0.1.0.dist-info/licenses/NOTICE.md,sha256=XGjuV5VAWBinW6Jzu7-9h0Ph3xwCNzcJdbMH_EgU_g4,356
- promptforest-0.1.0.dist-info/METADATA,sha256=OYvSPhnatbf97rur1W3zaY4FE0MFRE67j8QmC8hpz_M,509
- promptforest-0.1.0.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
- promptforest-0.1.0.dist-info/entry_points.txt,sha256=sVcjABvpA7P2fXca2KMZSYf0PNfDgLt1NHlYFMPO_eE,55
- promptforest-0.1.0.dist-info/top_level.txt,sha256=NxasbbadJaf8w9zaRXo5KOdBqNA1oDe-2X7e6zdz3k0,13
- promptforest-0.1.0.dist-info/RECORD,,