promptforest 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
promptforest/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  __version__ = "0.1.0"
2
2
 
3
- from .lib import EnsembleGuard
3
+ from .lib import PFEnsemble
promptforest/cli.py CHANGED
@@ -33,11 +33,11 @@ def main():
33
33
  run_server(port=args.port, config=cfg)
34
34
 
35
35
  elif args.command == "check":
36
- from .lib import EnsembleGuard
36
+ from .lib import PFEnsemble
37
37
  cfg = get_user_config(args.config)
38
38
  try:
39
39
  print(f"Loading PromptForest...")
40
- guard = EnsembleGuard(config=cfg)
40
+ guard = PFEnsemble(config=cfg)
41
41
  print(f"Device: {guard.device_used}")
42
42
  result = guard.check_prompt(args.prompt)
43
43
  print(json.dumps(result, indent=2))
promptforest/config.py CHANGED
@@ -12,10 +12,9 @@ MODELS_DIR = USER_DATA_DIR / "models"
12
12
 
13
13
  DEFAULT_CONFIG = {
14
14
  "models": [
15
- {"name": "llama_guard", "path": "llama_guard", "type": "hf", "enabled": True},
16
- {"name": "protectai", "path": "protectai", "type": "hf", "enabled": True},
17
- {"name": "vijil", "path": "vijil_dome", "type": "hf", "enabled": True},
18
- {"name": "xgboost", "type": "xgboost", "enabled": True, "threshold": 0.10}
15
+ {"name": "llama_guard", "path": "llama_guard", "type": "hf", "enabled": True, "accuracy_weight": 1.0},
16
+ {"name": "vijil", "path": "vijil_dome", "type": "hf", "enabled": True, "accuracy_weight": 1.0},
17
+ {"name": "xgboost", "type": "xgboost", "enabled": True, "threshold": 0.10, "accuracy_weight": 1.0}
19
18
  ],
20
19
  "settings": {
21
20
  "device": "auto", # Options: auto, cuda, mps, cpu
@@ -30,7 +29,7 @@ def load_config(config_path=None):
30
29
  """
31
30
  Load configuration from a YAML file, merging with defaults.
32
31
  """
33
- # Start with a deep copy of the default config structure
32
+ # Deep copy of the default config structure
34
33
  config = {
35
34
  "models": [m.copy() for m in DEFAULT_CONFIG["models"]],
36
35
  "settings": DEFAULT_CONFIG["settings"].copy(),
@@ -44,19 +43,17 @@ def load_config(config_path=None):
44
43
  with open(path, 'r') as f:
45
44
  user_config = yaml.safe_load(f)
46
45
  if user_config:
47
- # 1. Merge Settings
46
+ # Merge config
47
+ # @todo: is there a smarter way to merge this?
48
48
  if "settings" in user_config:
49
49
  config["settings"].update(user_config["settings"])
50
50
 
51
- # 2. Merge Logging
52
51
  if "logging" in user_config:
53
52
  config["logging"].update(user_config["logging"])
54
53
 
55
- # 3. Merge Models (Smart Merge)
56
54
  if "models" in user_config:
57
55
  user_models = user_config["models"]
58
56
  if isinstance(user_models, list):
59
- # Convert current models to dict for easy lookup by name
60
57
  existing_model_map = {m["name"]: m for m in config["models"]}
61
58
 
62
59
  for u_model in user_models:
promptforest/download.py CHANGED
@@ -11,9 +11,9 @@ from sentence_transformers import SentenceTransformer
11
11
  from .config import MODELS_DIR
12
12
  from .llama_guard_86m_downloader import download_llama_guard
13
13
 
14
- # Configuration
15
14
  MODELS = {
16
- "protectai": "protectai/deberta-v3-base-prompt-injection-v2",
15
+ # We are currently not using ProtectAI as it actively degrades ensemble performance while being 86M params
16
+ # "protectai": "protectai/deberta-v3-base-prompt-injection-v2",
17
17
  "vijil_dome": "vijil/vijil_dome_prompt_injection_detection"
18
18
  }
19
19
 
@@ -31,8 +31,11 @@ def _download_hf_model(name, model_id):
31
31
  if save_path.exists():
32
32
  return
33
33
 
34
- # Special handling for Vijil (ModernBERT tokenizer issue)
34
+
35
35
  tokenizer_id = model_id
36
+
37
+ # Vijil uses ModernBERT tokenizer
38
+ # @todo this should not be hardcoded
36
39
  if "vijil" in name or "vijil" in model_id:
37
40
  tokenizer_id = "answerdotai/ModernBERT-base"
38
41
 
@@ -47,17 +50,14 @@ def _download_hf_model(name, model_id):
47
50
 
48
51
  def _download_sentence_transformer():
49
52
  """Download and save the SentenceTransformer model."""
50
- # print(f"Downloading SentenceTransformer ({EMBEDDING_MODEL_NAME})...")
51
53
  save_path = MODELS_DIR / 'sentence_transformer'
52
54
 
53
55
  try:
54
56
  if save_path.exists():
55
- # print(f" - Model already exists at {save_path}. Skipping.")
56
57
  return
57
58
 
58
59
  model = SentenceTransformer(EMBEDDING_MODEL_NAME)
59
60
  model.save(str(save_path))
60
- #print(f" - Saved to {save_path}")
61
61
 
62
62
  except Exception as e:
63
63
  print(f"SentenceTransformer download failed: {e}")
@@ -66,11 +66,11 @@ def download_all():
66
66
  print(f"Downloading models to {MODELS_DIR}...")
67
67
  _ensure_dir(MODELS_DIR)
68
68
 
69
- # Download Llama Guard in parallel (slowest download)
69
+ # Download Llama Guard first as it takes the longest
70
70
  llama_guard_thread = threading.Thread(target=download_llama_guard, daemon=False)
71
71
  llama_guard_thread.start()
72
72
 
73
- # Download HF Classification Models
73
+ # Download each model from Hugging Face
74
74
  for name, model_id in MODELS.items():
75
75
  _download_hf_model(name, model_id)
76
76
 
promptforest/lib.py CHANGED
@@ -16,7 +16,7 @@ from .config import MODELS_DIR, XGB_MODEL_PATH, load_config
16
16
 
17
17
  # Suppress Warnings
18
18
  transformers_logging.set_verbosity_error()
19
- os.environ["TOKENIZERS_PARALLELISM"] = "false" # Prevent deadlocks/warnings
19
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
20
 
21
21
  MALICIOUS_KEYWORDS = ['unsafe', 'malicious', 'injection', 'attack', 'jailbreak']
22
22
 
@@ -45,7 +45,7 @@ class HFModel(ModelInference):
45
45
 
46
46
  def _load(self):
47
47
  if not self.path.exists():
48
- print(f"[WARN] Model path not found: {self.path}")
48
+ print(f"Model path not found: {self.path}")
49
49
  return
50
50
 
51
51
  try:
@@ -61,7 +61,7 @@ class HFModel(ModelInference):
61
61
  self.model.eval()
62
62
  self._determine_label_map()
63
63
  except Exception as e:
64
- print(f"[ERR] Failed to load {self.name}: {e}")
64
+ print(f"Error: Failed to load {self.name}: {e}")
65
65
  self.model = None
66
66
 
67
67
  def _determine_label_map(self):
@@ -101,9 +101,9 @@ class HFModel(ModelInference):
101
101
 
102
102
  class XGBoostModel(ModelInference):
103
103
  def __init__(self, settings, config=None):
104
- self.name = "xgboost_custom"
105
- self.settings = settings
106
104
  self.config = config or {}
105
+ self.name = self.config.get("name", "xgboost")
106
+ self.settings = settings
107
107
  self.threshold = self.config.get('threshold', 0.5)
108
108
  self.model = None
109
109
  self.embedder = None
@@ -123,7 +123,6 @@ class XGBoostModel(ModelInference):
123
123
  self.embedder = SentenceTransformer(str(ST_PATH))
124
124
  else:
125
125
  print("Cannot find local SentenceTransformer model. Downloading...")
126
- # Fallback to download default if local cache missing
127
126
  self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
128
127
 
129
128
  if self.device_name in ['cuda', 'mps']:
@@ -134,7 +133,7 @@ class XGBoostModel(ModelInference):
134
133
  except:
135
134
  pass
136
135
  except Exception as e:
137
- print(f"[ERR] Failed to load XGBoost: {e}")
136
+ print(f"Error: Failed to load XGBoost: {e}")
138
137
  self.model = None
139
138
 
140
139
  def predict(self, prompt):
@@ -159,10 +158,10 @@ class XGBoostModel(ModelInference):
159
158
  return 0.0
160
159
 
161
160
 
162
- class EnsembleGuard:
161
+ class PFEnsemble:
163
162
  def __init__(self, config=None):
164
163
  """
165
- Initialize the EnsembleGuard.
164
+ Initialize the PFEnsemble.
166
165
  :param config: Dictionary containing configuration. If None, loads default/user config.
167
166
  """
168
167
  if config is None:
@@ -221,6 +220,11 @@ class EnsembleGuard:
221
220
  print(f"Unknown model type: {model_type}")
222
221
 
223
222
  def check_prompt(self, prompt):
223
+ """
224
+ Checks the prompt using the ensemble of models.
225
+
226
+ :param prompt: The prompt string to check.
227
+ """
224
228
  start_time = time.perf_counter()
225
229
  results = {}
226
230
 
@@ -241,24 +245,43 @@ class EnsembleGuard:
241
245
  if not probs:
242
246
  return {"error": "No models loaded"}
243
247
 
248
+ # Calculate weighted average
249
+ weighted_sum = 0.0
250
+ total_weight = 0.0
251
+ model_configs = {m['name']: m for m in self.config.get('models', [])}
252
+
253
+ for model_name, prob in results.items():
254
+ model_cfg = model_configs.get(model_name, {})
255
+ # Default weight is 1.0 if not specified
256
+ weight = float(model_cfg.get('accuracy_weight', 1.0))
257
+ weighted_sum += prob * weight
258
+ total_weight += weight
259
+
260
+ if total_weight > 0:
261
+ w_avg_prob = weighted_sum / total_weight
262
+ else:
263
+ w_avg_prob = np.mean(probs)
264
+
244
265
  avg_prob = np.mean(probs)
266
+
245
267
  max_prob = np.max(probs)
246
268
 
247
269
  # Uncertainty
248
270
  std_dev = np.std(probs)
249
271
  uncertainty_score = min(std_dev * 2, 1.0)
250
272
 
251
- is_malicious = avg_prob > 0.5
273
+ is_malicious = w_avg_prob > 0.5
252
274
 
253
275
  response = {
254
276
  "is_malicious": bool(is_malicious),
277
+ # We use average probability for confidence - better results (2-3x improvement in benchmarks)
255
278
  "confidence": float(avg_prob if is_malicious else 1 - avg_prob),
256
279
  "uncertainty": float(uncertainty_score),
257
280
  "malicious_score": float(avg_prob),
258
281
  "max_risk_score": float(max_prob)
259
282
  }
260
283
 
261
- # Add stats if requested
284
+ # Add stats if logging is requested
262
285
  if self.config.get('logging', {}).get('stats', True):
263
286
  response["details"] = results
264
287
  response["latency_ms"] = duration_ms
@@ -1,6 +1,7 @@
1
1
  """
2
2
  Script to download Llama Guard 2 86M from custom GitHub releases.
3
- Downloads files in parallel for speed.
3
+
4
+ @todo: This should be legal, but as soon as anyone in Meta hints us to remove this, we will comply immediately.
4
5
  """
5
6
 
6
7
  import os
@@ -42,6 +43,7 @@ def download_llama_guard():
42
43
 
43
44
  # Check if all files exist
44
45
  if save_dir.exists() and all((save_dir / f).exists() for f in FILES_TO_DOWNLOAD):
46
+ # All files already exist, we don't need to download them again
45
47
  return
46
48
 
47
49
  save_dir.mkdir(parents=True, exist_ok=True)
promptforest/server.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """
2
- Simple HTTP Server for PromptForest.
2
+ Server module for PromptForest. Sets up an HTTP server to handle inference requests
3
3
  """
4
4
 
5
5
  import http.server
@@ -7,12 +7,12 @@ import socketserver
7
7
  import json
8
8
  import sys
9
9
  import time
10
- from .lib import EnsembleGuard
10
+ from .lib import PFEnsemble
11
11
 
12
12
  PORT = 8000
13
13
  ensemble = None
14
14
 
15
- class GuardRequestHandler(http.server.BaseHTTPRequestHandler):
15
+ class PFRequestHandler(http.server.BaseHTTPRequestHandler):
16
16
  def do_POST(self):
17
17
  """Handle POST requests for inference."""
18
18
  if self.path == '/analyze':
@@ -65,7 +65,7 @@ def run_server(port=8000, config=None):
65
65
  global ensemble
66
66
  print(f"Initializing PromptForest...")
67
67
  try:
68
- ensemble = EnsembleGuard(config=config)
68
+ ensemble = PFEnsemble(config=config)
69
69
  print(f"Device: {ensemble.device_used}")
70
70
  print("Warming up...")
71
71
  ensemble.check_prompt("warmup")
@@ -77,7 +77,7 @@ def run_server(port=8000, config=None):
77
77
  print(f"\nListening on http://localhost:{port}")
78
78
  socketserver.TCPServer.allow_reuse_address = True
79
79
 
80
- with ThreadedHTTPServer(("", port), GuardRequestHandler) as httpd:
80
+ with ThreadedHTTPServer(("", port), PFRequestHandler) as httpd:
81
81
  try:
82
82
  httpd.serve_forever()
83
83
  except KeyboardInterrupt:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: promptforest
3
- Version: 0.5.0
3
+ Version: 0.7.0
4
4
  Summary: Ensemble Prompt Injection Detection
5
5
  Requires-Python: >=3.8
6
6
  Description-Content-Type: text/markdown
@@ -24,6 +24,10 @@ Dynamic: requires-python
24
24
  Dynamic: summary
25
25
 
26
26
  # PromptForest - Fast and Reliable Injection Detector Ensemble
27
+ ![PyPI Downloads](https://img.shields.io/pypi/dm/promptforest)
28
+ ![Apache License](https://img.shields.io/badge/license-Apache%20License%202.0-blue)
29
+
30
+ **📖 TRY IT OUT ON A NOTEBOOK [HERE](https://colab.research.google.com/drive/1EW49Qx1ZlaAYchqplDIVk2FJVzCqOs6B?usp=sharing)!**
27
31
 
28
32
  PromptForest is a prompt injection detector ensemble focused on real-world latency and reliability.
29
33
 
@@ -36,53 +40,52 @@ This discrepancy score enables downstream workflows such as:
36
40
  - Adaptive throttling or alerting in production systems
37
41
  - Continuous monitoring and model improvement
38
42
 
43
+ ## Quick Start
44
+ To use PromptForest, simply install the pip package and serve it at a port of your choice. It should automatically start downloading the default model ensemble.
45
+
46
+ Gated models are downloaded through our own [ensemble github respository](https://github.com/appleroll-research/promptforest-model-ensemble) and are released in accordance to their terms and conditions.
47
+
48
+ ```bash
49
+ pip install promptforest
50
+ promptforest serve --port 8000
51
+ ```
52
+
39
53
  ## Statistics
40
54
  **E2E Request Latency** \
41
55
  Average Case: 100ms \
42
56
  Worst Case: 200ms
43
57
 
44
- PromptForest was evaluated against the SOTA model Qualifire Sentinel model (v2).
58
+ PromptForest was evaluated against the models from Deepset, ProtectAI, Meta and Vijil, with Promptforest and the SOTA model Qualifire's Sentinel V2 performing the best in terms of reliability.
45
59
 
46
60
  | Metric | PromptForest | Sentinel v2 |
47
61
  | -------------------------------- | ------------ | ----------- |
48
- | Accuracy | 0.802 | 0.982 |
49
- | Avg Confidence on Wrong Answers | 0.643 | 0.858 |
50
- | Expected Calibration Error (ECE) | 0.060 | 0.202 |
51
- | Approximate Model Size | ~250M params | 600M params |
62
+ | Accuracy | 0.901 | 0.973 |
63
+ | Avg Confidence on Wrong Answers | 0.642 | 0.76 |
64
+ | Expected Calibration Error (ECE) | 0.070 | 0.096 |
65
+ | Total Model Size | ~237M params | ~600M params |
52
66
 
53
67
 
54
68
  ### Key Insights
55
69
 
56
- - Calibrated uncertainty: PromptForest is less confident on wrong predictions than Sentinel, resulting in a much lower ECE.
70
+ - Calibrated uncertainty: PromptForest is less confident on wrong predictions than compared models, resulting in a much lower ECE.
57
71
 
58
72
  - Parameter efficiency: Achieves competitive reliability with <50% of the parameters.
59
73
 
60
74
  - Interpretability: Confidence scores can be used to flag uncertain predictions for human review.
61
75
 
62
- Interpretation:
63
- While Sentinel has higher raw accuracy, PromptForest provides better-calibrated confidence. For systems where overconfidence on wrong answers is risky, PromptForest can reduce the chance of critical errors despite being smaller and faster.
64
-
65
76
  Using Sentinel v2 as a baseline, and given that other models perform worse than Sentinel in published benchmarks, PromptForest is expected to offer more reliable and calibrated predictions than most alternatives.
66
77
 
67
78
 
68
- ## Supported Models
79
+ ## Models
69
80
 
70
81
  | Provider | Model Name |
71
82
  | ------------- | ----------------------------------------- |
72
83
  | **Meta** | [Llama Prompt Guard 86M](https://huggingface.co/meta-llama/Prompt-Guard-86M) (Built with Llama) |
73
- | **ProtectAI** | [DebertaV3 Prompt Injection Finetune](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2) |
74
84
  | **Vijil** | [Vijil Dome Prompt Injection Detection](https://huggingface.co/vijil/vijil_dome_prompt_injection_detection) |
75
- | **Appleroll** | [PromptForest-XGB](appleroll/promptforest-xgb) |
76
-
77
- ## Quick Start
78
- To use PromptForest, simply install the pip package and serve it at a port of your choice. It should automatically start downloading the default model ensemble.
79
-
80
- Gated models are downloaded through our own [ensemble github respository](https://github.com/appleroll-research/promptforest-model-ensemble) and are released in accordance to their terms and conditions.
85
+ | **Appleroll** | [PromptForest-XGB](https://huggingface.co/appleroll/promptforest-xgb) |
81
86
 
82
- ```bash
83
- pip install promptforest
84
- promptforest serve --port 8000
85
- ```
87
+ ## Current Goals
88
+ This project is actively being updated. Our current focus is on implementing weights on individual models to improve accuracy, as well as retraining the XGBoost model with an updated corpus.
86
89
 
87
90
  ## Disclaimer & Limitations
88
91
 
@@ -0,0 +1,15 @@
1
+ promptforest/__init__.py,sha256=wbIBy-XFARDVm5TmFud-IBumHPZ5ps8Phsjz0tYTUgU,51
2
+ promptforest/cli.py,sha256=T4sLBrNp09mHePOBHSB4QobyCDZdbLcBEm3_saoRMQU,1850
3
+ promptforest/config.py,sha256=MyeH2qeHpE3mVmqaSrjC8VP8P0Q0ZsftdN3TMfQGUP4,3138
4
+ promptforest/download.py,sha256=lKje_L2-CU2e56U932l9Q3ueLDt-Mcq8SLEMNOk_lBA,2572
5
+ promptforest/lib.py,sha256=Sp8DdvTooKJVGqG2MtFaFhoGbEgACHxyVxRCjsALlXs,10053
6
+ promptforest/llama_guard_86m_downloader.py,sha256=MS9YG6MepU0ToskZ9f2iwESg2EnXKGyFzWXwR2s2Xac,1866
7
+ promptforest/server.py,sha256=NY6mn9l4-PpTvkQx6zlQpoE6vDTv_Cm1QvwE3SfWi6g,2940
8
+ promptforest/xgboost/xgb_model.pkl,sha256=kSG2r-6TGfhNJfzwklLQOSgG2z610Z5BXxtgQdXE8Vk,2116991
9
+ promptforest-0.7.0.dist-info/licenses/LICENSE.txt,sha256=GgVl4CdplCpCEssTcrmIRbz52zQc0fdcSETZp34uBF4,11349
10
+ promptforest-0.7.0.dist-info/licenses/NOTICE.md,sha256=XGjuV5VAWBinW6Jzu7-9h0Ph3xwCNzcJdbMH_EgU_g4,356
11
+ promptforest-0.7.0.dist-info/METADATA,sha256=99z2XiN2SFQVons3IHfZW6Iai7ndHNt6XkEk8LWTmb8,5222
12
+ promptforest-0.7.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
13
+ promptforest-0.7.0.dist-info/entry_points.txt,sha256=sVcjABvpA7P2fXca2KMZSYf0PNfDgLt1NHlYFMPO_eE,55
14
+ promptforest-0.7.0.dist-info/top_level.txt,sha256=NxasbbadJaf8w9zaRXo5KOdBqNA1oDe-2X7e6zdz3k0,13
15
+ promptforest-0.7.0.dist-info/RECORD,,
@@ -1,15 +0,0 @@
1
- promptforest/__init__.py,sha256=cE1cQyRL4vUzseCwLYbI5wrZuZ-NRMVXIjAgwTLwIEs,54
2
- promptforest/cli.py,sha256=LKsnbEQNQ9pP_Ww24Ql2Tb_uomO-StqHnk-IHONSKTM,1856
3
- promptforest/config.py,sha256=c_7GX7nh_1Aa-QU7SOZlthPNGXSoh2KvYOk7txJeQh4,3284
4
- promptforest/download.py,sha256=6TQvo2qd3tUUxJU6MMsFMgOciHP5HNDNEo3UTOeYI34,2637
5
- promptforest/lib.py,sha256=WEuEhNNlRQAerLyEIbTHdi15qdXUMuiQOhfsvaftj4M,9254
6
- promptforest/llama_guard_86m_downloader.py,sha256=0B2ttwLWHki0yLEoJG3BwyFE1oqJFY0M2mLEtmMWmPk,1720
7
- promptforest/server.py,sha256=uF4Yj7yR_2vEx_7nQabGHGGw-6GWnT0iBZx3UPQK634,2905
8
- promptforest/xgboost/xgb_model.pkl,sha256=kSG2r-6TGfhNJfzwklLQOSgG2z610Z5BXxtgQdXE8Vk,2116991
9
- promptforest-0.5.0.dist-info/licenses/LICENSE.txt,sha256=GgVl4CdplCpCEssTcrmIRbz52zQc0fdcSETZp34uBF4,11349
10
- promptforest-0.5.0.dist-info/licenses/NOTICE.md,sha256=XGjuV5VAWBinW6Jzu7-9h0Ph3xwCNzcJdbMH_EgU_g4,356
11
- promptforest-0.5.0.dist-info/METADATA,sha256=fEgp4u7q-P74Zo3eF0gnEjVSFMuIc9z9g-1AoAKPAZs,5002
12
- promptforest-0.5.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
13
- promptforest-0.5.0.dist-info/entry_points.txt,sha256=sVcjABvpA7P2fXca2KMZSYf0PNfDgLt1NHlYFMPO_eE,55
14
- promptforest-0.5.0.dist-info/top_level.txt,sha256=NxasbbadJaf8w9zaRXo5KOdBqNA1oDe-2X7e6zdz3k0,13
15
- promptforest-0.5.0.dist-info/RECORD,,