promptforest 0.1.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
promptforest/config.py CHANGED
@@ -13,10 +13,9 @@ MODELS_DIR = USER_DATA_DIR / "models"
13
13
  DEFAULT_CONFIG = {
14
14
  "models": [
15
15
  {"name": "llama_guard", "path": "llama_guard", "type": "hf", "enabled": True},
16
- {"name": "protectai", "path": "protectai_deberta", "type": "hf", "enabled": True},
17
- {"name": "deepset", "path": "deepset_deberta", "type": "hf", "enabled": True},
18
- {"name": "katanemo", "path": "katanemo_arch", "type": "hf", "enabled": True},
19
- {"name": "xgboost", "type": "xgboost", "enabled": True}
16
+ {"name": "protectai", "path": "protectai", "type": "hf", "enabled": True},
17
+ {"name": "vijil", "path": "vijil_dome", "type": "hf", "enabled": True},
18
+ {"name": "xgboost", "type": "xgboost", "enabled": True, "threshold": 0.10}
20
19
  ],
21
20
  "settings": {
22
21
  "device": "auto", # Options: auto, cuda, mps, cpu
promptforest/download.py CHANGED
@@ -13,9 +13,8 @@ from .llama_guard_86m_downloader import download_llama_guard
13
13
 
14
14
  # Configuration
15
15
  MODELS = {
16
- "protectai_deberta": "protectai/deberta-v3-base-prompt-injection",
17
- "deepset_deberta": "deepset/deberta-v3-base-injection",
18
- "katanemo_arch": "katanemo/Arch-Guard"
16
+ "protectai": "protectai/deberta-v3-base-prompt-injection-v2",
17
+ "vijil_dome": "vijil/vijil_dome_prompt_injection_detection"
19
18
  }
20
19
 
21
20
  EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
@@ -31,15 +30,20 @@ def _download_hf_model(name, model_id):
31
30
  try:
32
31
  if save_path.exists():
33
32
  return
33
+
34
+ # Special handling for Vijil (ModernBERT tokenizer issue)
35
+ tokenizer_id = model_id
36
+ if "vijil" in name or "vijil" in model_id:
37
+ tokenizer_id = "answerdotai/ModernBERT-base"
34
38
 
35
- tokenizer = AutoTokenizer.from_pretrained(model_id)
39
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
36
40
  model = AutoModelForSequenceClassification.from_pretrained(model_id)
37
41
 
38
42
  tokenizer.save_pretrained(save_path)
39
43
  model.save_pretrained(save_path)
40
44
 
41
45
  except Exception as e:
42
- print("Failed to download a model!")
46
+ print(f"Failed to download {model_id}: {e}")
43
47
 
44
48
  def _download_sentence_transformer():
45
49
  """Download and save the SentenceTransformer model."""
promptforest/lib.py CHANGED
@@ -68,7 +68,8 @@ class HFModel(ModelInference):
68
68
  id2label = self.model.config.id2label
69
69
  found = False
70
70
  for idx, label in id2label.items():
71
- if any(kw in label.lower() for kw in MALICIOUS_KEYWORDS):
71
+ # Check if label is string before calling lower()
72
+ if isinstance(label, str) and any(kw in label.lower() for kw in MALICIOUS_KEYWORDS):
72
73
  self.malicious_idx = idx
73
74
  found = True
74
75
  break
@@ -99,9 +100,11 @@ class HFModel(ModelInference):
99
100
 
100
101
 
101
102
  class XGBoostModel(ModelInference):
102
- def __init__(self, settings):
103
+ def __init__(self, settings, config=None):
103
104
  self.name = "xgboost_custom"
104
105
  self.settings = settings
106
+ self.config = config or {}
107
+ self.threshold = self.config.get('threshold', 0.5)
105
108
  self.model = None
106
109
  self.embedder = None
107
110
 
@@ -140,7 +143,18 @@ class XGBoostModel(ModelInference):
140
143
  try:
141
144
  emb = self.embedder.encode([prompt])
142
145
  prob = self.model.predict_proba(emb)[0][1]
143
- return float(prob)
146
+ prob = float(prob)
147
+
148
+ # Apply threshold adjustment if a custom threshold is set
149
+ if self.threshold != 0.5:
150
+ if prob < self.threshold:
151
+ # Map [0, threshold) -> [0, 0.5)
152
+ prob = 0.5 * (prob / self.threshold)
153
+ else:
154
+ # Map [threshold, 1.0] -> [0.5, 1.0]
155
+ prob = 0.5 + 0.5 * (prob - self.threshold) / (1.0 - self.threshold)
156
+
157
+ return prob
144
158
  except Exception:
145
159
  return 0.0
146
160
 
@@ -151,14 +165,14 @@ class EnsembleGuard:
151
165
  Initialize the EnsembleGuard.
152
166
  :param config: Dictionary containing configuration. If None, loads default/user config.
153
167
  """
154
- # Check if models need to be downloaded
155
- self._ensure_models_available()
156
-
157
168
  if config is None:
158
169
  self.config = load_config()
159
170
  else:
160
171
  self.config = config
161
172
 
173
+ # Check if models need to be downloaded
174
+ self._ensure_models_available()
175
+
162
176
  self.models = []
163
177
  self._init_models()
164
178
  self.device_used = get_device(self.config['settings'].get('device', 'auto'))
@@ -167,14 +181,27 @@ class EnsembleGuard:
167
181
  """Check if models are available, download if needed."""
168
182
  from .config import MODELS_DIR
169
183
 
170
- # Check if models directory exists and has content
171
- if MODELS_DIR.exists() and any(MODELS_DIR.iterdir()):
172
- return
173
-
174
- # Models not found, download them
175
- print("Models not found. Downloading...")
176
- from .download import download_all
177
- download_all()
184
+ missing = False
185
+ if not MODELS_DIR.exists():
186
+ missing = True
187
+ else:
188
+ # Check for each enabled HF model
189
+ for model_cfg in self.config.get('models', []):
190
+ if model_cfg.get('type') == 'hf' and model_cfg.get('enabled', True):
191
+ path = MODELS_DIR / model_cfg['path']
192
+ if not path.exists():
193
+ missing = True
194
+ break
195
+
196
+ # Check for SentenceTransformer (needed for XGBoost)
197
+ st_path = MODELS_DIR / 'sentence_transformer'
198
+ if not st_path.exists():
199
+ missing = True
200
+
201
+ if missing:
202
+ print("Some models not found. Downloading required models...")
203
+ from .download import download_all
204
+ download_all()
178
205
 
179
206
  def _init_models(self):
180
207
  settings = self.config.get('settings', {})
@@ -189,7 +216,7 @@ class EnsembleGuard:
189
216
  if model_type == 'hf':
190
217
  self.models.append(HFModel(model_cfg['name'], model_cfg['path'], settings))
191
218
  elif model_type == 'xgboost':
192
- self.models.append(XGBoostModel(settings))
219
+ self.models.append(XGBoostModel(settings, model_cfg))
193
220
  else:
194
221
  print(f"Unknown model type: {model_type}")
195
222
 
@@ -1,67 +1,58 @@
1
1
  """
2
- Script to download Llama Guard 2 86M from a custom GitHub repository.
3
- Handles split safetensor files and combines them locally.
2
+ Script to download Llama Guard 2 86M from custom GitHub releases.
3
+ Downloads files in parallel for speed.
4
4
  """
5
5
 
6
6
  import os
7
- import shutil
8
- import subprocess
9
- import tempfile
7
+ import requests
10
8
  from pathlib import Path
9
+ from concurrent.futures import ThreadPoolExecutor
11
10
  from .config import MODELS_DIR
12
11
 
13
- LLAMA_GUARD_REPO = "https://github.com/appleroll-research/promptforest-model-ensemble.git"
12
+ BASE_URL = "https://github.com/appleroll-research/promptforest-model-ensemble/releases/download/v0.5.0-alpha"
13
+ FILES_TO_DOWNLOAD = [
14
+ "config.json",
15
+ "model.safetensors",
16
+ "special_tokens_map.json",
17
+ "tokenizer.json",
18
+ "tokenizer_config.json"
19
+ ]
14
20
 
15
- def _download_llama_guard():
16
- """Download Llama Guard from custom repository and combine split files."""
17
- save_path = MODELS_DIR / "llama_guard"
18
-
21
+ def _download_file(url, save_path):
22
+ """Download a single file."""
19
23
  if save_path.exists():
20
24
  return
21
-
25
+
22
26
  try:
23
- # Use temporary directory for cloning
24
- with tempfile.TemporaryDirectory() as temp_dir:
25
- temp_path = Path(temp_dir)
26
-
27
- # Clone repository silently
28
- subprocess.run(
29
- ["git", "clone", "--depth", "1", LLAMA_GUARD_REPO, str(temp_path)],
30
- stdout=subprocess.DEVNULL,
31
- stderr=subprocess.DEVNULL,
32
- check=True
33
- )
34
-
35
- # Get the llama_guard folder from the cloned repo
36
- source_dir = temp_path / "llama_guard"
37
- if not source_dir.exists():
38
- raise FileNotFoundError(f"llama_guard folder not found in repository")
39
-
40
- # Copy to models directory
41
- save_path.parent.mkdir(parents=True, exist_ok=True)
42
- shutil.copytree(source_dir, save_path)
43
-
44
- # Combine split safetensor files
45
- model_file = save_path / "model.safetensors"
46
- if not model_file.exists():
47
- # Find and combine c_* files
48
- split_files = sorted(save_path.glob("c_*"))
49
- if split_files:
50
- with open(model_file, 'wb') as outfile:
51
- for split_file in split_files:
52
- with open(split_file, 'rb') as infile:
53
- outfile.write(infile.read())
54
-
55
- # Delete split files
56
- for split_file in split_files:
57
- split_file.unlink()
27
+ response = requests.get(url, stream=True)
28
+ response.raise_for_status()
58
29
 
30
+ with open(save_path, 'wb') as f:
31
+ for chunk in response.iter_content(chunk_size=8192):
32
+ f.write(chunk)
59
33
  except Exception as e:
60
- # Clean up on failure
34
+ print(f"Failed to download {url}: {e}")
35
+ # Clean up partial file
61
36
  if save_path.exists():
62
- shutil.rmtree(save_path)
63
- raise Exception(f"Failed to download Llama Guard: {e}")
37
+ os.remove(save_path)
64
38
 
65
39
  def download_llama_guard():
66
- """Public interface for downloading Llama Guard."""
67
- _download_llama_guard()
40
+ """Download Llama Guard files in parallel."""
41
+ save_dir = MODELS_DIR / "llama_guard"
42
+
43
+ # Check if all files exist
44
+ if save_dir.exists() and all((save_dir / f).exists() for f in FILES_TO_DOWNLOAD):
45
+ return
46
+
47
+ save_dir.mkdir(parents=True, exist_ok=True)
48
+
49
+ with ThreadPoolExecutor(max_workers=5) as executor:
50
+ futures = []
51
+ for filename in FILES_TO_DOWNLOAD:
52
+ url = f"{BASE_URL}/{filename}"
53
+ save_path = save_dir / filename
54
+ futures.append(executor.submit(_download_file, url, save_path))
55
+
56
+ for future in futures:
57
+ future.result()
58
+
Binary file
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: promptforest
3
- Version: 0.1.1
3
+ Version: 0.5.0
4
4
  Summary: Ensemble Prompt Injection Detection
5
5
  Requires-Python: >=3.8
6
6
  Description-Content-Type: text/markdown
@@ -36,25 +36,43 @@ This discrepancy score enables downstream workflows such as:
36
36
  - Adaptive throttling or alerting in production systems
37
37
  - Continuous monitoring and model improvement
38
38
 
39
+ ## Statistics
40
+ **E2E Request Latency** \
41
+ Average Case: 100ms \
42
+ Worst Case: 200ms
43
+
44
+ PromptForest was evaluated against the SOTA Qualifire Sentinel model (v2).
45
+
46
+ | Metric | PromptForest | Sentinel v2 |
47
+ | -------------------------------- | ------------ | ----------- |
48
+ | Accuracy | 0.802 | 0.982 |
49
+ | Avg Confidence on Wrong Answers | 0.643 | 0.858 |
50
+ | Expected Calibration Error (ECE) | 0.060 | 0.202 |
51
+ | Approximate Model Size | ~250M params | 600M params |
52
+
53
+
54
+ ### Key Insights
55
+
56
+ - Calibrated uncertainty: PromptForest is less confident on wrong predictions than Sentinel, resulting in a much lower ECE.
57
+
58
+ - Parameter efficiency: Achieves competitive reliability with less than 50% of the parameters.
59
+
60
+ - Interpretability: Confidence scores can be used to flag uncertain predictions for human review.
61
+
62
+ Interpretation:
63
+ While Sentinel has higher raw accuracy, PromptForest provides better-calibrated confidence. For systems where overconfidence on wrong answers is risky, PromptForest can reduce the chance of critical errors despite being smaller and faster.
64
+
65
+ Using Sentinel v2 as a baseline, and given that other models perform worse than Sentinel in published benchmarks, PromptForest is expected to offer more reliable and calibrated predictions than most alternatives.
66
+
39
67
 
40
68
  ## Supported Models
41
69
 
42
70
  | Provider | Model Name |
43
71
  | ------------- | ----------------------------------------- |
44
- | **Meta** | Llama Prompt Guard 86M (Built with Llama) |
45
- | **ProtectAI** | DebertaV3 Prompt Injection Finetune |
46
- | **Deepset** | DebertaV3-base Injection Finetune |
47
- | **Katanemo** | Arch-Guard |
48
- | **Appleroll** | PromptForest-XGBoost |
49
-
50
- ## Performance
51
- **Request Latency** \
52
- Best Case: 50ms \
53
- Worst Case: 200ms
54
-
55
- **Accuracy** \
56
- Preliminary results indicate ensemble performance is at least as good as any individual model. Extensive benchmarking is ongoing.
57
-
72
+ | **Meta** | [Llama Prompt Guard 86M](https://huggingface.co/meta-llama/Prompt-Guard-86M) (Built with Llama) |
73
+ | **ProtectAI** | [DebertaV3 Prompt Injection Finetune](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2) |
74
+ | **Vijil** | [Vijil Dome Prompt Injection Detection](https://huggingface.co/vijil/vijil_dome_prompt_injection_detection) |
75
+ | **Appleroll** | [PromptForest-XGB](https://huggingface.co/appleroll/promptforest-xgb) |
58
76
 
59
77
  ## Quick Start
60
78
  To use PromptForest, simply install the pip package and serve it at a port of your choice. It should automatically start downloading the default model ensemble.
@@ -0,0 +1,15 @@
1
+ promptforest/__init__.py,sha256=cE1cQyRL4vUzseCwLYbI5wrZuZ-NRMVXIjAgwTLwIEs,54
2
+ promptforest/cli.py,sha256=LKsnbEQNQ9pP_Ww24Ql2Tb_uomO-StqHnk-IHONSKTM,1856
3
+ promptforest/config.py,sha256=c_7GX7nh_1Aa-QU7SOZlthPNGXSoh2KvYOk7txJeQh4,3284
4
+ promptforest/download.py,sha256=6TQvo2qd3tUUxJU6MMsFMgOciHP5HNDNEo3UTOeYI34,2637
5
+ promptforest/lib.py,sha256=WEuEhNNlRQAerLyEIbTHdi15qdXUMuiQOhfsvaftj4M,9254
6
+ promptforest/llama_guard_86m_downloader.py,sha256=0B2ttwLWHki0yLEoJG3BwyFE1oqJFY0M2mLEtmMWmPk,1720
7
+ promptforest/server.py,sha256=uF4Yj7yR_2vEx_7nQabGHGGw-6GWnT0iBZx3UPQK634,2905
8
+ promptforest/xgboost/xgb_model.pkl,sha256=kSG2r-6TGfhNJfzwklLQOSgG2z610Z5BXxtgQdXE8Vk,2116991
9
+ promptforest-0.5.0.dist-info/licenses/LICENSE.txt,sha256=GgVl4CdplCpCEssTcrmIRbz52zQc0fdcSETZp34uBF4,11349
10
+ promptforest-0.5.0.dist-info/licenses/NOTICE.md,sha256=XGjuV5VAWBinW6Jzu7-9h0Ph3xwCNzcJdbMH_EgU_g4,356
11
+ promptforest-0.5.0.dist-info/METADATA,sha256=fEgp4u7q-P74Zo3eF0gnEjVSFMuIc9z9g-1AoAKPAZs,5002
12
+ promptforest-0.5.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
13
+ promptforest-0.5.0.dist-info/entry_points.txt,sha256=sVcjABvpA7P2fXca2KMZSYf0PNfDgLt1NHlYFMPO_eE,55
14
+ promptforest-0.5.0.dist-info/top_level.txt,sha256=NxasbbadJaf8w9zaRXo5KOdBqNA1oDe-2X7e6zdz3k0,13
15
+ promptforest-0.5.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.10.1)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,15 +0,0 @@
1
- promptforest/__init__.py,sha256=cE1cQyRL4vUzseCwLYbI5wrZuZ-NRMVXIjAgwTLwIEs,54
2
- promptforest/cli.py,sha256=LKsnbEQNQ9pP_Ww24Ql2Tb_uomO-StqHnk-IHONSKTM,1856
3
- promptforest/config.py,sha256=bOFHlK0kI7c3fzccZrcjccNUfZJPzvLtKEAZ_loLvzE,3366
4
- promptforest/download.py,sha256=3Ss1BX6kQatfhif1cbErUekPlSA2RCqtiatUzGi72zo,2454
5
- promptforest/lib.py,sha256=LT8A1_veV9tB2DyrZ0JEOBW4EWEs9El5xOxF0zNHOAc,8042
6
- promptforest/llama_guard_86m_downloader.py,sha256=ibFeeuDgMBVe-8aD0zl23xJKOPdKyw-4Bsf0iZJih4s,2412
7
- promptforest/server.py,sha256=uF4Yj7yR_2vEx_7nQabGHGGw-6GWnT0iBZx3UPQK634,2905
8
- promptforest/xgboost/xgb_model.pkl,sha256=97Y_Dfu8PwubkplRXJdNEuAj9te1v-nEJlXfPpEZWdM,748772
9
- promptforest-0.1.1.dist-info/licenses/LICENSE.txt,sha256=GgVl4CdplCpCEssTcrmIRbz52zQc0fdcSETZp34uBF4,11349
10
- promptforest-0.1.1.dist-info/licenses/NOTICE.md,sha256=XGjuV5VAWBinW6Jzu7-9h0Ph3xwCNzcJdbMH_EgU_g4,356
11
- promptforest-0.1.1.dist-info/METADATA,sha256=o1T79TkOnH3uMEWzI31xwmyP-QvFKH2JMHBLFv-WGVI,3700
12
- promptforest-0.1.1.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
13
- promptforest-0.1.1.dist-info/entry_points.txt,sha256=sVcjABvpA7P2fXca2KMZSYf0PNfDgLt1NHlYFMPO_eE,55
14
- promptforest-0.1.1.dist-info/top_level.txt,sha256=NxasbbadJaf8w9zaRXo5KOdBqNA1oDe-2X7e6zdz3k0,13
15
- promptforest-0.1.1.dist-info/RECORD,,