pi-research 1.4.0 → 1.5.0
- package/README.md +30 -86
- package/ml/models/conflict-structured/feature-names.json +22 -0
- package/ml/models/conflict-structured/meta.json +5 -0
- package/ml/models/conflict-structured/model.joblib +0 -0
- package/ml/models/domain/metrics.json +16 -0
- package/ml/models/domain/model.joblib +0 -0
- package/ml/models/domain-lr/metrics.json +16 -0
- package/ml/models/domain-lr/model.joblib +0 -0
- package/ml/models/followup/meta.json +3 -0
- package/ml/models/followup/model.joblib +0 -0
- package/ml/models/sufficiency-structured/feature-names.json +22 -0
- package/ml/models/sufficiency-structured/meta.json +5 -0
- package/ml/models/sufficiency-structured/model.joblib +0 -0
- package/ml/router/README.md +106 -0
- package/ml/router/__pycache__/features.cpython-314.pyc +0 -0
- package/ml/router/benchmark_latency.py +81 -0
- package/ml/router/daemon.py +140 -0
- package/ml/router/embed_model2vec.py +48 -0
- package/ml/router/evaluate_domain.py +67 -0
- package/ml/router/features.py +60 -0
- package/ml/router/requirements.txt +5 -0
- package/ml/router/train_classifier.py +57 -0
- package/ml/router/train_domain_classifier.py +209 -0
- package/ml/router/train_structured_baseline.py +174 -0
- package/package.json +8 -5
package/README.md
CHANGED
|
@@ -1,100 +1,29 @@
|
|
|
1
|
-
#
|
|
1
|
+
# ⚠️ This package has moved
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
> **`pi-research` is deprecated and will no longer receive updates.**
|
|
4
4
|
|
|
5
|
-
[
|
|
6
|
-
[](https://github.com/endgegnerbert-tech/pi-research)
|
|
7
|
-
[](https://pi.ai)
|
|
5
|
+
## 👉 Please migrate to [`emet`](https://www.npmjs.com/package/emet)
|
|
8
6
|
|
|
9
|
-
**The Zero-Setup Research Engine for Autonomous AI Agents.**
|
|
10
|
-
|
|
11
|
-
`pi-research` is an advanced grounding tool designed specifically for AI coding agents. It prevents agents from hallucinating API endpoints, guessing library versions, or inventing CVE details by injecting real-time, highly authoritative, and conflict-resolved web research directly into their context window.
|
|
12
|
-
|
|
13
|
-

|
|
14
|
-
|
|
15
|
-
## 💡 Why `pi-research`?
|
|
16
|
-
|
|
17
|
-
The world does not need just another "AI Search Engine"—there are plenty of massive, standalone research tools out there.
|
|
18
|
-
|
|
19
|
-
Instead, `pi-research` was built specifically to solve a crucial problem in the **Agentic Workflow**: When an autonomous agent is deep in a coding loop, compiling errors, or debugging, it needs hard facts instantly without losing focus. Calling out to heavy external search services or trying to execute brittle Playwright scripts breaks the agent's flow, wastes context window tokens, and leads to hallucinations.
|
|
20
|
-
|
|
21
|
-
`pi-research` solves this by providing a lightweight, internal **cognitive research loop** directly into the agent harness:
|
|
22
|
-
1. **Agent-Centric Routing:** It knows exactly where developers look (GitHub, NPM, NIST, arXiv).
|
|
23
|
-
2. **Authority First:** It prioritizes official documentation over random SEO-optimized tutorials.
|
|
24
|
-
3. **Self-Awareness:** It extracts structured features to know when it lacks information, safely triggering follow-up questions *before* returning an answer to the agent.
|
|
25
|
-
|
|
26
|
-
Best of all? **Zero setup.** No external search API keys to configure, no heavy local LLMs to run, and no flaky browser automation scripts to maintain. It's built to run silently and reliably alongside your agent.
|
|
27
|
-
|
|
28
|
-
---
|
|
29
|
-
|
|
30
|
-
## ✨ Features
|
|
31
|
-
|
|
32
|
-
- 🚀 **Lightning Fast:** Powered by a Hybrid Tiny-Router Architecture (Model2Vec + SVC), routing queries in **< 0.6 milliseconds**.
|
|
33
|
-
- 🛡️ **Anti-Hallucination:** Built-in Veto-Power for high-risk queries. If a security question only finds blog posts, the system forces a follow-up to find authoritative NIST/CVE data.
|
|
34
|
-
- 🕸️ **Resilient Fetching:** Pre-emptively escalates blocked, JS-heavy, or thin pages through an integrated, robust Python `Scrapling` daemon (via IPC JSON-RPC 2.0).
|
|
35
|
-
- 🧩 **Domain Packs:** Built-in heuristics for `github`, `security`, `papers`, `package-registry`, and more.
|
|
36
|
-
- 📊 **Structured Outputs:** Returns citations, code blocks, missing aspects, confidence scores, and conflict summaries (e.g., "Source A contradicts Source B").
|
|
37
|
-
- 📂 **Local Context:** Ingests local files (`options.files`) to ground web research in your current repository context.
|
|
38
|
-
|
|
39
|
-
---
|
|
40
|
-
|
|
41
|
-
## 📦 Installation
|
|
42
|
-
|
|
43
|
-
### Pi Coding Agent (Extension)
|
|
44
|
-
If you are using the Pi Agent harness, install the extension directly:
|
|
45
|
-
```bash
|
|
46
|
-
pi install npm:pi-research
|
|
47
|
-
```
|
|
48
|
-
|
|
49
|
-
### Node.js / NPM (Standalone Server)
|
|
50
|
-
Install it globally to expose the MCP (Model Context Protocol) server for any compatible AI agent:
|
|
51
7
|
```bash
|
|
52
|
-
|
|
53
|
-
pi-research
|
|
54
|
-
```
|
|
55
|
-
*(The MCP server identifies itself as `unblind-mcp`, exposing the tool `pi-research`)*
|
|
8
|
+
# Uninstall old package
|
|
9
|
+
npm uninstall -g pi-research
|
|
56
10
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
## 🚀 Quick Start / Usage
|
|
60
|
-
|
|
61
|
-
Once installed, your agent has access to the `pi-research` tool. It accepts a `query`, a `mode`, and various `options`.
|
|
62
|
-
|
|
63
|
-
### Modes
|
|
64
|
-
| Mode | Best for |
|
|
65
|
-
| --- | --- |
|
|
66
|
-
| `fast` | Quick factual lookups (e.g., "What is the latest LTS version of Node.js?"). Stops fetching early if authoritative sources are found. |
|
|
67
|
-
| `deep` | Broader retrieval with automatic follow-up rounds. Perfect for comparisons, conflicts, or unclear architecture questions. |
|
|
68
|
-
| `code` | Docs, repositories, README-driven answers, and retrieving actual code snippets. |
|
|
69
|
-
| `academic` | Scholarly sources, DOI links, and paper-heavy topics. |
|
|
70
|
-
|
|
71
|
-
### Example Tool Calls (For Agents)
|
|
72
|
-
**Factual Lookup:**
|
|
73
|
-
```json
|
|
74
|
-
{
|
|
75
|
-
"query": "React 19 RC release notes",
|
|
76
|
-
"mode": "fast",
|
|
77
|
-
"options": { "requireAuthoritative": true }
|
|
78
|
-
}
|
|
79
|
-
```
|
|
11
|
+
# Install new package
|
|
12
|
+
npm install -g emet
|
|
80
13
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
{
|
|
84
|
-
"query": "Compare PostgreSQL and MySQL for multi-tenant SaaS",
|
|
85
|
-
"mode": "deep",
|
|
86
|
-
"options": { "preferRecent": true, "maxTurns": 2 }
|
|
87
|
-
}
|
|
14
|
+
# Pi Extension
|
|
15
|
+
pi install npm:emet
|
|
88
16
|
```
|
|
89
17
|
|
|
90
18
|
---
|
|
91
19
|
|
|
92
|
-
|
|
20
|
+
### What changed?
|
|
93
21
|
|
|
94
|
-
|
|
22
|
+
`pi-research` has been rebranded to **`emet`** (Hebrew for truth/fact) — a cleaner, standalone identity that better reflects what the tool does: it grounds AI agents in factual, real-time context.
|
|
95
23
|
|
|
24
|
+
<<<<<<< HEAD
|
|
96
25
|
- **Model2Vec & SVC:** Queries are classified via locally embedded features. Security and paper queries have a 0% downgrade rate.
|
|
97
|
-
- **Structured ML:** Instead of asking a heavy LLM "Is this enough data?", the system extracts deterministic features (`has_authority`, `conflict_state`) and uses an ultra-fast Logistic Regression model to evaluate sufficiency and follow-up actions
|
|
26
|
+
- **Structured ML:** Instead of asking a heavy LLM "Is this enough data?", the system extracts deterministic features (`has_authority`, `conflict_state`) and uses an ultra-fast Logistic Regression model to evaluate sufficiency and follow-up actions, which achieved 100% accuracy on the included `eval_unseen_hard.js` benchmark dataset (121 test cases).
|
|
98
27
|
- **Node.js-to-Python IPC:** Operates entirely locally using a highly optimized, line-delimited JSON-RPC daemon to manage Python dependencies (`Scrapling`, `Model2Vec`) without memory leaks.
|
|
99
28
|
|
|
100
29
|
---
|
|
@@ -102,7 +31,7 @@ With `1.4.0`, `pi-research` shifted from heavy, generative JSON-planners to a **
|
|
|
102
31
|
## 🛣️ Future Roadmap
|
|
103
32
|
|
|
104
33
|
We are actively working on scaling the reasoning capabilities:
|
|
105
|
-
- **LLM Data Augmentation (Weak Supervision):** Generating synthetic training data for underconfident domains to boost zero-shot accuracy
|
|
34
|
+
- **LLM Data Augmentation (Weak Supervision):** Generating synthetic training data for underconfident domains to boost zero-shot accuracy, targeting >95% without manual labeling.
|
|
106
35
|
- **Active Learning Telemetry Loop:** Clustering low-confidence predictions from cache logs into a weakly-supervised retraining pipeline to let the system "self-heal."
|
|
107
36
|
- **Cross-Encoder for Conflict Detection:** Transitioning to a fine-tuned Cross-Encoder (e.g., MiniLM + Natural Language Inference) to detect deep semantic contradiction across differing texts (e.g., recognizing that "Node 20 is stable" contradicts "Node 20 is broken").
|
|
108
37
|
|
|
@@ -111,4 +40,19 @@ We are actively working on scaling the reasoning capabilities:
|
|
|
111
40
|
## 📝 License & Notices
|
|
112
41
|
- **License:** MIT
|
|
113
42
|
- **Third-party notices:** See `THIRD_PARTY_NOTICES.md`
|
|
114
|
-
- **GitHub:** [https://github.com/endgegnerbert-tech/pi-research](https://github.com/endgegnerbert-tech/pi-research)
|
|
43
|
+
- **GitHub:** [https://github.com/endgegnerbert-tech/pi-research](https://github.com/endgegnerbert-tech/pi-research)
|
|
44
|
+
=======
|
|
45
|
+
| Old | New |
|
|
46
|
+
|---|---|
|
|
47
|
+
| `pi-research` | `emet` |
|
|
48
|
+
| `unblind-mcp` (MCP server) | `emet-mcp` |
|
|
49
|
+
| `PI_RESEARCH_*` env vars | `EMET_*` env vars |
|
|
50
|
+
| `pi-research` tool name | `emet` tool name |
|
|
51
|
+
|
|
52
|
+
All features, modes (`fast`, `deep`, `code`, `academic`), and the zero-setup
|
|
53
|
+
architecture carry over 100%. No functionality was removed.
|
|
54
|
+
|
|
55
|
+
---
|
|
56
|
+
|
|
57
|
+
**GitHub:** [tomsej/emet](https://github.com/tomsej/emet)
|
|
58
|
+
>>>>>>> 72ee46a (chore: deprecate package, move to emet)
|
|
@@ -0,0 +1,22 @@
+[
+  "authoritative_source_count",
+  "blocked_source_count",
+  "blog_count",
+  "candidate_conflict",
+  "file_count",
+  "forum_count",
+  "github_readme_count",
+  "github_repo_count",
+  "has_authority_resolution_path",
+  "negative_signal_sources",
+  "official_doc_count",
+  "other_count",
+  "paper_count",
+  "positive_signal_sources",
+  "query_academic",
+  "query_comparison",
+  "query_procedural",
+  "query_temporal",
+  "query_versioned",
+  "source_count"
+]
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,22 @@
+[
+  "authoritative_source_count",
+  "blocked_source_count",
+  "blog_count",
+  "file_count",
+  "forum_count",
+  "github_readme_count",
+  "github_repo_count",
+  "has_authority",
+  "has_only_one_good_source",
+  "negative_signal_sources",
+  "official_doc_count",
+  "other_count",
+  "paper_count",
+  "positive_signal_sources",
+  "query_academic",
+  "query_comparison",
+  "query_procedural",
+  "query_temporal",
+  "query_versioned",
+  "source_count"
+]
|
|
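The two 22-line hunks above are ordered lists of 20 feature names each; judging by the file list, they appear to be the `feature-names.json` files of the conflict and sufficiency models. `daemon.py`, further down in this diff, turns a request's `features` object into a model input row by walking such a list in order and defaulting missing keys to `0.0`. The following is a minimal stdlib sketch of that idea, not the package's code; the shortened name list and the sample values are made up for illustration.

```python
from typing import Mapping, Sequence


def vectorize(feature_names: Sequence[str], features: Mapping[str, float]) -> list:
    """Return one row whose columns follow feature_names; missing keys become 0.0."""
    return [float(features.get(name, 0.0)) for name in feature_names]


# Hypothetical, shortened name list (the shipped lists have 20 entries each).
names = ["authoritative_source_count", "blog_count", "has_authority", "source_count"]

# "unknown_key" is not in the list, so it is ignored; booleans become 0.0 / 1.0.
row = vectorize(names, {"source_count": 4, "has_authority": True, "unknown_key": 9})
print(row)
```

Because unknown keys are dropped silently and missing ones become zero, a misspelled feature name on the caller side would show up as a zero column rather than as an error.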
Binary file
|
|
@@ -0,0 +1,106 @@
+# Tiny Router Training Runbook
+
+Target budget:
+
+- GPU RAM: 2 GB
+- CPU RAM: 20 GB
+- Default path: CPU-first, frozen embeddings, small models
+
+## Environment
+
+```bash
+python3 -m venv .venv-router
+. .venv-router/bin/activate
+pip install -r ml/router/requirements.txt
+```
+
+## Phase 1 — domain router
+
+```bash
+node scripts/router/audit-cache.mjs
+node scripts/router/export-examples.mjs
+node scripts/router/split-examples.mjs
+
+python ml/router/embed_model2vec.py \
+  --input data/router/examples.jsonl \
+  --gold data/router/gold-domain.jsonl \
+  --synthetic data/router/synthetic-train.jsonl
+
+python ml/router/train_domain_classifier.py \
+  --embeddings data/router/domain-model2vec.npz data/router/synthetic-model2vec.npz \
+  --gold-embeddings data/router/gold-model2vec.npz \
+  --out .cache/models/pi-research-router/domain \
+  --model-type auto
+
+python ml/router/evaluate_domain.py \
+  --model .cache/models/pi-research-router/domain/model.joblib \
+  --embeddings data/router/gold-model2vec.npz \
+  --out metrics/router/domain-model2vec-lr.json
+
+python ml/router/benchmark_latency.py \
+  --model-dir .cache/models/pi-research-router/domain \
+  --examples data/router/gold-domain.jsonl \
+  --out metrics/router/latency.json
+
+python scripts/router/eval_domain_unknown.py \
+  --model-dir .cache/models/pi-research-router/domain \
+  --input data/router/unknown-domain-smoke.jsonl
+```
+
+## Phase 2 — structured baselines
+
+Build provisional structured rows:
+
+```bash
+node scripts/router/export_structured_provisional.mjs
+node scripts/router/eval_structured_baselines.mjs
+```
+
+Train conservative structured classifiers:
+
+```bash
+python ml/router/train_structured_baseline.py --task conflict
+python ml/router/train_structured_baseline.py --task sufficiency
+```
+
+Outputs:
+
+- `.cache/models/pi-research-router/conflict-structured/`
+- `.cache/models/pi-research-router/sufficiency-structured/`
+- `metrics/router/conflict-structured-models.json`
+- `metrics/router/sufficiency-structured-models.json`
+
+## Runtime flags
+
+```bash
+PI_RESEARCH_TINY_ROUTER=1
+PI_RESEARCH_TINY_ROUTER_MODEL=.cache/models/pi-research-router
+PI_RESEARCH_TINY_ROUTER_TIMEOUT_MS=50
+PI_RESEARCH_TINY_ROUTER_DOMAIN=1
+PI_RESEARCH_TINY_ROUTER_FOLLOWUP=1
+PI_RESEARCH_TINY_ROUTER_CONFLICT=0
+PI_RESEARCH_TINY_ROUTER_SUFFICIENCY=0
+```
+
+Keep conflict/sufficiency off until metrics are reviewed.
+
+## Server deploy
+
+Safe MCP runtime deploy:
+
+```bash
+scripts/router/deploy-server-runtime.sh \
+  blackknight@100.98.190.19 \
+  ~/work/pi-research-runtime
+```
+
+This syncs the repo, installs user-local Node if needed, copies trained router models, runs `npm install`, and writes:
+
+- `start-mcp-tiny-router-safe.sh`
+- `start-mcp-tiny-router-experimental.sh`
+
+Recommended start command:
+
+```bash
+ssh blackknight@100.98.190.19 'cd ~/work/pi-research-runtime && ./start-mcp-tiny-router-safe.sh'
+```
|
|
Binary file
|
|
@@ -0,0 +1,81 @@
+import json
+import argparse
+import time
+import numpy as np
+import joblib
+import os
+
+sys_path_added = False
+if not sys_path_added:
+    import sys
+    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+    sys_path_added = True
+
+from features import load_embedding_model, extract_domain_features
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", required=True)
+    parser.add_argument("--examples", required=True)
+    parser.add_argument("--out", required=True)
+    args = parser.parse_args()
+
+    print(f"Loading Model2Vec...")
+    emb_model = load_embedding_model()
+
+    print(f"Loading Classifier...")
+    clf = joblib.load(f"{args.model_dir}/model.joblib")
+
+    # Load a few queries to test
+    queries = []
+    with open(args.examples, "r") as f:
+        for line in f:
+            if not line.strip(): continue
+            ex = json.loads(line)
+            queries.append(ex["query"])
+
+    # Warmup
+    print("Warming up...")
+    for q in queries[:10]:
+        feats = extract_domain_features([q], ["fast"], emb_model=emb_model, show_progress_bar=False)
+        clf.predict(feats)
+
+    # Benchmark
+    print(f"Benchmarking {len(queries)} queries sequentially...")
+    latencies = []
+
+    for q in queries:
+        t0 = time.perf_counter()
+
+        feats = extract_domain_features([q], ["fast"], emb_model=emb_model, show_progress_bar=False)
+        pred = clf.predict(feats)[0]
+
+        t1 = time.perf_counter()
+        latencies.append((t1 - t0) * 1000) # ms
+
+    latencies = np.array(latencies)
+    p50 = np.percentile(latencies, 50)
+    p95 = np.percentile(latencies, 95)
+    mean = np.mean(latencies)
+
+    print(f"p50: {p50:.2f} ms")
+    print(f"p95: {p95:.2f} ms")
+    print(f"Mean: {mean:.2f} ms")
+
+    os.makedirs(os.path.dirname(args.out), exist_ok=True)
+
+    metrics = {
+        "task": "domain",
+        "latency_ms": {
+            "p50": p50,
+            "p95": p95,
+            "mean": mean,
+            "samples": len(latencies)
+        }
+    }
+
+    with open(args.out, "w") as f:
+        json.dump(metrics, f, indent=2)
+
+if __name__ == "__main__":
+    main()
|
|
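The benchmark script above times each query end to end (embedding plus classification) after a ten-query warmup and reports p50, p95, and mean in milliseconds. As a rough illustration of that measurement pattern without NumPy or the trained model, the sketch below times an arbitrary callable and computes percentiles by linear interpolation, which is the default method of `numpy.percentile`. The helper names and the stand-in workload are assumptions for illustration only.

```python
import time


def percentile(values, q):
    """Linear-interpolation percentile for q in 0..100 (numpy.percentile's default method)."""
    if not values:
        raise ValueError("values must not be empty")
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q / 100.0
    lo = int(rank)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (rank - lo)


def time_calls_ms(fn, inputs):
    """Call fn once per input and return the per-call wall time in milliseconds."""
    timings = []
    for item in inputs:
        t0 = time.perf_counter()
        fn(item)
        timings.append((time.perf_counter() - t0) * 1000.0)
    return timings


# Stand-in workload; the real script calls extract_domain_features and clf.predict here.
latencies = time_calls_ms(lambda q: q.lower().split(), ["What is X?"] * 50)
print(f"p50: {percentile(latencies, 50):.4f} ms, p95: {percentile(latencies, 95):.4f} ms")
```

Timing one query per call, as the script does, measures single-request latency; batching the queries would give a throughput figure instead, which is a different quantity.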
@@ -0,0 +1,140 @@
+import sys
+import json
+import logging
+import joblib
+import numpy as np
+import traceback
+import os
+
+# Add the directory containing features.py to sys.path
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+from features import load_embedding_model, extract_domain_features, extract_followup_features
+
+logging.basicConfig(level=logging.ERROR)
+
+
+def load_model(path):
+    return joblib.load(path) if os.path.exists(path) else None
+
+
+def load_feature_names(path):
+    if not os.path.exists(path):
+        return None
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def predict_proba_like(clf, features):
+    if hasattr(clf, "predict_proba"):
+        proba = clf.predict_proba(features)[0]
+        max_idx = int(np.argmax(proba))
+        return clf.classes_[max_idx], float(proba[max_idx])
+
+    pred = clf.predict(features)[0]
+    return pred, 1.0
+
+
+def vectorize_structured_features(feature_names, features):
+    row = [float(features.get(name, 0.0)) for name in feature_names]
+    return np.array([row], dtype=np.float32)
+
+
+def main():
+    if len(sys.argv) < 2:
+        print(json.dumps({"error": "Missing model path"}))
+        sys.exit(1)
+
+    model_dir = sys.argv[1]
+
+    try:
+        emb_model = load_embedding_model()
+        domain_clf = load_model(os.path.join(model_dir, "domain", "model.joblib"))
+        followup_clf = load_model(os.path.join(model_dir, "followup", "model.joblib"))
+        conflict_clf = load_model(os.path.join(model_dir, "conflict-structured", "model.joblib"))
+        sufficiency_clf = load_model(os.path.join(model_dir, "sufficiency-structured", "model.joblib"))
+        conflict_feature_names = load_feature_names(os.path.join(model_dir, "conflict-structured", "feature-names.json"))
+        sufficiency_feature_names = load_feature_names(os.path.join(model_dir, "sufficiency-structured", "feature-names.json"))
+    except Exception as e:
+        print(json.dumps({"error": f"Failed to load models: {str(e)}"}))
+        sys.exit(1)
+
+    print("READY", flush=True)
+
+    for line in sys.stdin:
+        line = line.strip()
+        if not line:
+            continue
+
+        try:
+            req = json.loads(line)
+            req_id = req.get("id")
+            task = req.get("task", "domain")
+            query = req.get("query", "")
+            mode = req.get("mode", "fast")
+
+            if task == "domain":
+                if not domain_clf:
+                    print(json.dumps({"id": req_id, "error": "Domain model not loaded"}), flush=True)
+                    continue
+
+                feats = extract_domain_features([query], [mode], emb_model=emb_model, show_progress_bar=False)
+                pred, confidence = predict_proba_like(domain_clf, feats)
+
+                print(json.dumps({
+                    "id": req_id,
+                    "domain": str(pred),
+                    "confidence": confidence
+                }), flush=True)
+
+            elif task == "followup":
+                if not followup_clf:
+                    print(json.dumps({"id": req_id, "error": "Followup model not loaded"}), flush=True)
+                    continue
+
+                conflict = req.get("conflict", "none")
+                sources = req.get("sources", {})
+
+                feats = extract_followup_features([query], [mode], [conflict], [sources], emb_model=emb_model, show_progress_bar=False)
+                pred, confidence = predict_proba_like(followup_clf, feats)
+
+                print(json.dumps({
+                    "id": req_id,
+                    "action": str(pred),
+                    "confidence": confidence
+                }), flush=True)
+
+            elif task == "conflict":
+                if not conflict_clf or not conflict_feature_names:
+                    print(json.dumps({"id": req_id, "error": "Conflict model not loaded"}), flush=True)
+                    continue
+
+                feats = vectorize_structured_features(conflict_feature_names, req.get("features", {}))
+                pred, confidence = predict_proba_like(conflict_clf, feats)
+                print(json.dumps({
+                    "id": req_id,
+                    "decision": str(pred),
+                    "confidence": confidence
+                }), flush=True)
+
+            elif task == "sufficiency":
+                if not sufficiency_clf or not sufficiency_feature_names:
+                    print(json.dumps({"id": req_id, "error": "Sufficiency model not loaded"}), flush=True)
+                    continue
+
+                feats = vectorize_structured_features(sufficiency_feature_names, req.get("features", {}))
+                pred, confidence = predict_proba_like(sufficiency_clf, feats)
+                print(json.dumps({
+                    "id": req_id,
+                    "decision": str(pred),
+                    "confidence": confidence
+                }), flush=True)
+
+            else:
+                print(json.dumps({"id": req_id, "error": f"Unknown task: {task}"}), flush=True)
+
+        except Exception as e:
+            print(json.dumps({"error": str(e), "trace": traceback.format_exc()}), flush=True)
+
+
+if __name__ == "__main__":
+    main()
|
|
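The daemon above prints `READY` once the models are loaded, then reads one JSON object per stdin line and answers with one JSON object per stdout line. A caller therefore needs to serialize each request onto a single line and match replies by `id`. The helpers below are a hedged sketch of the caller side, not code from the package: `build_request` and `parse_response` are hypothetical names, and the sample reply is invented to match the shape the daemon prints for the `domain` task. Note that replies produced by the daemon's generic exception handler carry no `id`, so the error check runs before the id check.

```python
import json


def build_request(req_id, task, **fields):
    """Serialize one request as a single JSON line, as the daemon's stdin loop expects."""
    payload = {"id": req_id, "task": task, **fields}
    return json.dumps(payload) + "\n"


def parse_response(line, expected_id=None):
    """Parse one reply line; raise on an 'error' field or on a mismatched id."""
    msg = json.loads(line)
    if "error" in msg:
        # Error replies from the generic exception handler have no "id" field.
        raise RuntimeError(msg["error"])
    if expected_id is not None and msg.get("id") != expected_id:
        raise ValueError(f"unexpected response id: {msg.get('id')!r}")
    return msg


req = build_request(7, "domain", query="CVE-2024-3094 details", mode="fast")

# Hypothetical reply in the shape daemon.py prints for the "domain" task.
reply = parse_response('{"id": 7, "domain": "security", "confidence": 0.93}', expected_id=7)
print(req.strip())
print(reply["domain"], reply["confidence"])
```

In a real client these lines would be written to and read from the daemon's pipes after waiting for the `READY` line; that process handling is left out here.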
@@ -0,0 +1,48 @@
+import json
+import argparse
+import numpy as np
+from features import load_embedding_model, extract_domain_features, extract_followup_features
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", required=True)
+    parser.add_argument("--out", required=True)
+    parser.add_argument("--model", default="minishlab/potion-base-8M")
+    parser.add_argument("--task", default="domain")
+    args = parser.parse_args()
+
+    examples = []
+    with open(args.input, "r") as f:
+        for line in f:
+            if not line.strip(): continue
+            ex = json.loads(line)
+            if "task" not in ex or ex["task"] == args.task:
+                examples.append(ex)
+
+    print(f"Loaded {len(examples)} examples for task '{args.task}'")
+
+    print(f"Loading StaticModel: {args.model}")
+    model = load_embedding_model()
+
+    queries = [ex["query"] for ex in examples]
+    modes = [ex.get("mode", ex.get("meta", {}).get("mode", "fast")) for ex in examples]
+
+    if args.task == "domain":
+        print(f"Encoding {len(queries)} queries for domain routing...")
+        features = extract_domain_features(queries, modes, emb_model=model, show_progress_bar=True)
+    elif args.task == "followup":
+        print(f"Encoding {len(queries)} queries for followup action...")
+        conflicts = [ex.get("conflict", "none") for ex in examples]
+        sources_list = [ex.get("sources", {}) for ex in examples]
+        features = extract_followup_features(queries, modes, conflicts, sources_list, emb_model=model, show_progress_bar=True)
+    else:
+        raise ValueError(f"Unknown task: {args.task}")
+
+    ids = np.array([ex.get("id", str(i)) for i, ex in enumerate(examples)])
+    labels = np.array([ex["label"] for ex in examples])
+
+    print(f"Saving features shape {features.shape} to {args.out}")
+    np.savez(args.out, features=features, ids=ids, labels=labels)
+
+if __name__ == "__main__":
+    main()
|
|
@@ -0,0 +1,67 @@
+import json
+import argparse
+import numpy as np
+import joblib
+import os
+from sklearn.metrics import classification_report, f1_score, confusion_matrix, accuracy_score
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", required=True)
+    parser.add_argument("--embeddings", required=True)
+    parser.add_argument("--out", required=True)
+    args = parser.parse_args()
+
+    # Load model
+    print(f"Loading model from {args.model}")
+    clf = joblib.load(args.model)
+
+    # Load data
+    data = np.load(args.embeddings)
+    X = data["features"]
+    y_true = data["labels"]
+
+    # Evaluate
+    y_pred = clf.predict(X)
+
+    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
+    accuracy = accuracy_score(y_true, y_pred)
+
+    print("\nGold Validation Report:")
+    print(classification_report(y_true, y_pred, zero_division=0))
+
+    print(f"\nMacro-F1: {macro_f1:.4f}")
+
+    # Extract high-risk misclassifications
+    classes = clf.classes_
+    high_risk_classes = ["security", "papers", "specs"]
+    cm = confusion_matrix(y_true, y_pred, labels=classes)
+
+    high_risk_errors = 0
+    web_idx = np.where(classes == "web")[0]
+    if len(web_idx) > 0:
+        web_idx = web_idx[0]
+        for hr_class in high_risk_classes:
+            hr_idx = np.where(classes == hr_class)[0]
+            if len(hr_idx) > 0:
+                errors = cm[hr_idx[0], web_idx]
+                if errors > 0:
+                    print(f"HIGH RISK WARNING: {errors} '{hr_class}' queries routed to 'web'")
+                    high_risk_errors += errors
+
+    # Save artifacts
+    os.makedirs(os.path.dirname(args.out), exist_ok=True)
+
+    metrics = {
+        "task": "domain",
+        "eval_set_size": len(X),
+        "macro_f1": macro_f1,
+        "accuracy": accuracy,
+        "high_risk_downgrades": int(high_risk_errors),
+        "classes": classes.tolist()
+    }
+    with open(args.out, "w") as f:
+        json.dump(metrics, f, indent=2)
+
+if __name__ == "__main__":
+    main()
|
|
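The evaluation script above reports a `high_risk_downgrades` count: the number of gold examples labeled `security`, `papers`, or `specs` that the classifier routed to `web`. It reads those cells from a scikit-learn confusion matrix, where rows are true labels and columns are predictions. The same count can be sketched without scikit-learn, as below; the function name and the sample labels are assumptions for illustration.

```python
HIGH_RISK = ("security", "papers", "specs")


def count_high_risk_downgrades(y_true, y_pred, fallback="web"):
    """Count examples whose true label is high-risk but whose prediction is the fallback class."""
    counts = {}
    for truth, pred in zip(y_true, y_pred):
        if truth in HIGH_RISK and pred == fallback:
            counts[truth] = counts.get(truth, 0) + 1
    return counts


# Made-up labels: one "security" and one "specs" example are routed to "web".
y_true = ["security", "web", "papers", "security", "specs"]
y_pred = ["web", "web", "papers", "security", "web"]

downgrades = count_high_risk_downgrades(y_true, y_pred)
print(downgrades, sum(downgrades.values()))
```

As in the script, only misroutes into `web` are counted; a `security` query predicted as some other non-`web` class would not appear in this number.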
@@ -0,0 +1,60 @@
+import numpy as np
+from model2vec import StaticModel
+
+EMBEDDING_MODEL_NAME = "minishlab/potion-base-8M"
+
+def load_embedding_model() -> StaticModel:
+    """Loads the base static model for feature extraction."""
+    return StaticModel.from_pretrained(EMBEDDING_MODEL_NAME)
+
+def encode_modes(modes: list) -> np.ndarray:
+    """Encodes a list of mode strings into a one-hot float32 numpy array."""
+    encoded = []
+    for mode in modes:
+        encoded.append([
+            1.0 if mode == "fast" else 0.0,
+            1.0 if mode == "deep" else 0.0,
+            1.0 if mode == "academic" else 0.0,
+            1.0 if mode == "code" else 0.0,
+        ])
+    return np.array(encoded, dtype=np.float32)
+
+def extract_domain_features(queries: list, modes: list, emb_model: StaticModel = None, show_progress_bar: bool = False) -> np.ndarray:
+    """Extracts the combined feature vector (text embeddings + one-hot mode) for domain routing."""
+    if emb_model is None:
+        emb_model = load_embedding_model()
+
+    emb = emb_model.encode(queries, show_progress_bar=show_progress_bar)
+    modes_np = encode_modes(modes)
+    return np.hstack([emb, modes_np])
+
+def encode_followup_meta(conflicts: list, sources_list: list) -> np.ndarray:
+    """Encodes conflict and source metadata into a feature array for followup classification."""
+    encoded = []
+    for conflict, sources in zip(conflicts, sources_list):
+        row = [
+            1.0 if conflict == "severe" else 0.0,
+            1.0 if conflict == "minor" else 0.0,
+            1.0 if conflict == "none" else 0.0,
+
+            1.0 if sources.get("has_authority", False) else 0.0,
+            1.0 if sources.get("has_forum", False) else 0.0,
+            1.0 if sources.get("has_news", False) else 0.0,
+            1.0 if sources.get("has_recent", False) else 0.0,
+
+            # Normalize source count (cap at 10)
+            min(float(sources.get("source_count", 3)) / 10.0, 1.0)
+        ]
+        encoded.append(row)
+    return np.array(encoded, dtype=np.float32)
+
+def extract_followup_features(queries: list, modes: list, conflicts: list, sources_list: list, emb_model: StaticModel = None, show_progress_bar: bool = False) -> np.ndarray:
+    """Extracts features for the followup action classifier."""
+    if emb_model is None:
+        emb_model = load_embedding_model()
+
+    emb = emb_model.encode(queries, show_progress_bar=show_progress_bar)
+    modes_np = encode_modes(modes)
+    meta_np = encode_followup_meta(conflicts, sources_list)
+
+    return np.hstack([emb, modes_np, meta_np])
|
|
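In the feature module above, every model input is the query embedding followed by hand-built columns: four one-hot mode columns for domain routing, plus eight metadata columns (three conflict levels, four source flags, one capped source count) for follow-up classification. To make that column layout concrete without loading Model2Vec, the sketch below rebuilds only the hand-built tail in plain Python. The function and constant names are illustrative assumptions; the column order and the defaults follow the diff.

```python
MODES = ("fast", "deep", "academic", "code")
CONFLICTS = ("severe", "minor", "none")
SOURCE_FLAGS = ("has_authority", "has_forum", "has_news", "has_recent")


def mode_one_hot(mode):
    """Four columns in the order fast, deep, academic, code."""
    return [1.0 if mode == m else 0.0 for m in MODES]


def followup_meta(conflict, sources):
    """Three conflict columns, four source flags, then source_count / 10 capped at 1.0."""
    row = [1.0 if conflict == c else 0.0 for c in CONFLICTS]
    row += [1.0 if sources.get(flag, False) else 0.0 for flag in SOURCE_FLAGS]
    row.append(min(float(sources.get("source_count", 3)) / 10.0, 1.0))
    return row


print(mode_one_hot("academic"))
print(mode_one_hot("turbo"))  # an unrecognized mode yields all zeros, not an error
print(followup_meta("minor", {"has_authority": True, "source_count": 25}))
print(followup_meta("none", {}))  # source_count defaults to 3, i.e. 0.3
```

One consequence of this encoding is that an unknown mode or conflict level is indistinguishable from "no value set", since both produce an all-zero group of columns.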
@@ -0,0 +1,57 @@
+import json
+import argparse
+import numpy as np
+import joblib
+from sklearn.svm import LinearSVC
+from sklearn.linear_model import LogisticRegression
+from sklearn.calibration import CalibratedClassifierCV
+from sklearn.metrics import classification_report, f1_score, confusion_matrix
+import os
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--embeddings", required=True, nargs="+")
+    parser.add_argument("--out", required=True)
+    parser.add_argument("--model-type", choices=["svc", "lr"], default="svc")
+    args = parser.parse_args()
+
+    # Load all data files and combine
+    X_list, y_list = [], []
+    for emb_file in args.embeddings:
+        data = np.load(emb_file)
+        X_list.append(data["features"])
+        y_list.append(data["labels"])
+
+    X_train = np.vstack(X_list)
+    y_train = np.hstack(y_list)
+
+    print(f"Combined Train size: {len(X_train)}")
+
+    from imblearn.over_sampling import RandomOverSampler
+    ros = RandomOverSampler(random_state=42)
+    X_res, y_res = ros.fit_resample(X_train, y_train)
+    print(f"Size after OverSampling: {len(X_res)}")
+
+    # Train model
+    print(f"Training {args.model_type} with class_weight='balanced'...")
+    if args.model_type == "svc":
+        base_clf = LinearSVC(class_weight="balanced", dual=False, max_iter=5000, C=0.5)
+        clf = CalibratedClassifierCV(base_clf, method="sigmoid", cv=5)
+    else:
+        clf = LogisticRegression(class_weight="balanced", max_iter=5000)
+
+    clf.fit(X_res, y_res)
+
+    # Save artifacts
+    os.makedirs(args.out, exist_ok=True)
+    model_path = os.path.join(args.out, "model.joblib")
+    joblib.dump(clf, model_path)
+    print(f"Model saved to {model_path}")
+
+    # Save a small report to standard out
+    preds = clf.predict(X_res)
+    print("Train Report:")
+    print(classification_report(y_res, preds))
+
+if __name__ == "__main__":
+    main()
|
|
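The script in the hunk above balances classes with imbalanced-learn's `RandomOverSampler` before fitting, and its "Train Report" is computed on that same resampled training data, so it describes fit rather than generalization. As a rough illustration of what random oversampling does, here is a standard-library sketch; the `oversample` helper and the toy data are hypothetical and this is not imbalanced-learn's implementation.

```python
import random
from collections import Counter


def oversample(X, y, seed=42):
    """Duplicate randomly chosen minority-class rows until every class matches the largest one."""
    rng = random.Random(seed)
    counts = Counter(y)
    target = max(counts.values())
    X_out, y_out = list(X), list(y)
    for label, count in counts.items():
        # Rows of this class in the original data; samples are drawn with replacement.
        pool = [x for x, lab in zip(X, y) if lab == label]
        for _ in range(target - count):
            X_out.append(rng.choice(pool))
            y_out.append(label)
    return X_out, y_out


# Toy data: three "web" rows and one "security" row.
X = [[0.1], [0.2], [0.3], [0.9]]
y = ["web", "web", "web", "security"]
X_res, y_res = oversample(X, y)
print(Counter(y_res))
```

Because the minority rows are exact duplicates, a classifier scored on the resampled set sees some rows several times, which is one reason a held-out set gives a more honest picture.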
@@ -0,0 +1,209 @@
+import argparse
+import json
+import os
+from collections import Counter
+
+import joblib
+import numpy as np
+from imblearn.over_sampling import RandomOverSampler
+from sklearn.calibration import CalibratedClassifierCV
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import accuracy_score, f1_score
+from sklearn.svm import LinearSVC
+
+HIGH_RISK_CLASSES = {"security", "papers", "specs"}
+PRECISION_TARGETS = (0.95, 0.90, 0.85)
+MIN_DEFAULT_THRESHOLD = 0.35
+MIN_HIGH_RISK_THRESHOLD = 0.55
+
+
+def load_embeddings(paths):
+    features, labels = [], []
+    for path in paths:
+        data = np.load(path)
+        features.append(data["features"])
+        labels.append(data["labels"])
+    return np.vstack(features), np.hstack(labels)
+
+
+def build_classifier(model_type):
+    if model_type == "svc":
+        base = LinearSVC(class_weight="balanced", dual=False, max_iter=5000, C=0.5)
+        return CalibratedClassifierCV(base, method="sigmoid", cv=5)
+    if model_type == "lr":
+        return LogisticRegression(class_weight="balanced", max_iter=5000)
+    raise ValueError(f"Unsupported model_type: {model_type}")
+
+
+def train_classifier(model_type, X_train, y_train):
+    ros = RandomOverSampler(random_state=42)
+    X_resampled, y_resampled = ros.fit_resample(X_train, y_train)
+    clf = build_classifier(model_type)
+    clf.fit(X_resampled, y_resampled)
+    return clf, len(X_resampled)
+
+
+def evaluate_classifier(clf, X_eval, y_eval):
+    probs = clf.predict_proba(X_eval)
+    pred_idx = np.argmax(probs, axis=1)
+    preds = clf.classes_[pred_idx]
+    confs = np.max(probs, axis=1)
+    accuracy = accuracy_score(y_eval, preds)
+    macro_f1 = f1_score(y_eval, preds, average="macro", zero_division=0)
+
+    high_risk_downgrades = 0
+    for gold, pred in zip(y_eval, preds):
+        if gold in HIGH_RISK_CLASSES and pred == "web":
+            high_risk_downgrades += 1
+
+    return {
+        "accuracy": float(accuracy),
+        "macro_f1": float(macro_f1),
+        "high_risk_downgrades": int(high_risk_downgrades),
+        "preds": preds,
+        "confs": confs,
+    }
+
+
+def derive_threshold_for_label(label, y_true, preds, confs):
+    candidate_thresholds = sorted({0.0, *[float(conf) for pred, conf in zip(preds, confs) if pred == label]})
+    best = None
+    floor = MIN_HIGH_RISK_THRESHOLD if label in HIGH_RISK_CLASSES else MIN_DEFAULT_THRESHOLD
+
+    gold_support = sum(1 for gold in y_true if gold == label)
+    for target_precision in PRECISION_TARGETS:
+        for threshold in candidate_thresholds:
+            accepted = [i for i, (pred, conf) in enumerate(zip(preds, confs)) if pred == label and conf >= threshold]
+            if not accepted:
+                continue
+            tp = sum(1 for i in accepted if y_true[i] == label)
+            fp = len(accepted) - tp
+            precision = tp / len(accepted)
+            recall = tp / gold_support if gold_support else 0.0
+            score = (precision >= target_precision, recall, precision, -threshold)
+            if best is None or score > best[0]:
+                best = (score, {
+                    "threshold": float(threshold),
+                    "precision": float(precision),
+                    "recall": float(recall),
+                    "accepted": len(accepted),
+                    "tp": int(tp),
+                    "fp": int(fp),
+                    "target_precision": float(target_precision),
+                })
+        if best and best[0][0]:
+            best[1]["threshold"] = max(float(best[1]["threshold"]), floor)
+            return best[1]
+
+    if best:
+        best[1]["threshold"] = max(float(best[1]["threshold"]), floor)
+        return best[1]
+    return {
+        "threshold": 0.75 if label in HIGH_RISK_CLASSES else 0.80,
+        "precision": 0.0,
+        "recall": 0.0,
+        "accepted": 0,
+        "tp": 0,
+        "fp": 0,
+        "target_precision": PRECISION_TARGETS[-1],
+    }
+
+
+def derive_calibration(clf, X_eval, y_eval):
+    evaluation = evaluate_classifier(clf, X_eval, y_eval)
+    preds = evaluation["preds"]
+    confs = evaluation["confs"]
+    thresholds = {}
+    diagnostics = {}
+
+    for label in clf.classes_:
+        diag = derive_threshold_for_label(label, y_eval, preds, confs)
+        thresholds[str(label)] = float(diag["threshold"])
+        diagnostics[str(label)] = diag
+
+    return {
+        "defaultThreshold": 0.80,
+        "highRiskThreshold": 0.75,
+        "domainThresholds": thresholds,
+        "diagnostics": diagnostics,
+    }
+
+
+def choose_best_report(reports):
+    return max(
+        reports,
+        key=lambda item: (-item["metrics"]["high_risk_downgrades"], item["metrics"]["accuracy"], item["metrics"]["macro_f1"], item["model_type"]),
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--embeddings", required=True, nargs="+")
+    parser.add_argument("--gold-embeddings")
+    parser.add_argument("--out", required=True)
+    parser.add_argument("--model-type", choices=["svc", "lr", "auto"], default="auto")
+    args = parser.parse_args()
+
+    X_train, y_train = load_embeddings(args.embeddings)
+    print(f"Combined Train size: {len(X_train)}")
+    print(f"Train label distribution: {dict(Counter(y_train))}")
+
+    candidate_model_types = [args.model_type] if args.model_type != "auto" else ["svc", "lr"]
+    gold_data = load_embeddings([args.gold_embeddings]) if args.gold_embeddings else None
+
+    reports = []
+    for model_type in candidate_model_types:
+        print(f"Training {model_type} with class_weight='balanced'...")
+        clf, resampled_size = train_classifier(model_type, X_train, y_train)
+        report = {
+            "model_type": model_type,
+            "clf": clf,
+            "resampled_size": resampled_size,
+        }
+        if gold_data:
+            X_gold, y_gold = gold_data
+            report["metrics"] = evaluate_classifier(clf, X_gold, y_gold)
+            report["calibration"] = derive_calibration(clf, X_gold, y_gold)
+            print(json.dumps({
+                "model_type": model_type,
+                "accuracy": report["metrics"]["accuracy"],
+                "macro_f1": report["metrics"]["macro_f1"],
+                "high_risk_downgrades": report["metrics"]["high_risk_downgrades"],
+            }, indent=2))
+        reports.append(report)
+
+    best = reports[0] if len(reports) == 1 or not gold_data else choose_best_report(reports)
+    print(f"Selected model: {best['model_type']}")
+
+    os.makedirs(args.out, exist_ok=True)
+    model_path = os.path.join(args.out, "model.joblib")
+    joblib.dump(best["clf"], model_path)
+    print(f"Model saved to {model_path}")
+
+    meta = {
+        "modelType": best["model_type"],
+        "trainSize": int(len(X_train)),
+        "resampledTrainSize": int(best["resampled_size"]),
+    }
+    with open(os.path.join(args.out, "meta.json"), "w") as f:
+        json.dump(meta, f, indent=2)
+
+    if best.get("calibration"):
+        with open(os.path.join(args.out, "calibration.json"), "w") as f:
+            json.dump(best["calibration"], f, indent=2)
+
+    if best.get("metrics"):
+        metrics = {
+            "task": "domain",
+            "modelType": best["model_type"],
+            "accuracy": best["metrics"]["accuracy"],
+            "macro_f1": best["metrics"]["macro_f1"],
+            "high_risk_downgrades": best["metrics"]["high_risk_downgrades"],
+            "classes": [str(label) for label in best["clf"].classes_],
+        }
+        with open(os.path.join(args.out, "metrics.json"), "w") as f:
+            json.dump(metrics, f, indent=2)
+
+
+if __name__ == "__main__":
+    main()
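The per-label threshold search in the hunk above ranks candidate thresholds with a tuple, so Python's left-to-right tuple comparison does the prioritizing: meeting the precision target first, then recall, then precision, then the lower threshold. Below is a simplified sketch of that idea in plain Python. The function name `pick_threshold`, the toy data, and the early return after each target are assumptions (the last one reflects one reading of the diff, whose indentation did not survive extraction); unlike the original, the sketch returns only the threshold and yields `None` when the label is never predicted.

```python
def pick_threshold(label, y_true, preds, confs, targets=(0.95, 0.90, 0.85), floor=0.35):
    """Hypothetical simplification of the per-label threshold search."""
    # Candidate thresholds: 0.0 plus every confidence at which `label` was predicted.
    candidates = sorted({0.0, *(float(c) for p, c in zip(preds, confs) if p == label)})
    support = sum(1 for g in y_true if g == label)
    best = None
    for target in targets:  # strictest precision target first
        for t in candidates:
            accepted = [i for i, (p, c) in enumerate(zip(preds, confs)) if p == label and c >= t]
            if not accepted:
                continue
            tp = sum(1 for i in accepted if y_true[i] == label)
            precision = tp / len(accepted)
            recall = tp / support if support else 0.0
            # Tuples compare element by element, so the first field dominates.
            score = (precision >= target, recall, precision, -t)
            if best is None or score > best[0]:
                best = (score, t)
        if best and best[0][0]:
            # A candidate met this target; clamp to the floor and stop.
            return max(best[1], floor)
    return max(best[1], floor) if best else None


# Toy evaluation set: every row is predicted "security", one of them wrongly.
y_true = ["security", "web", "security", "security"]
preds = ["security", "security", "security", "security"]
confs = [0.9, 0.4, 0.7, 0.6]
chosen = pick_threshold("security", y_true, preds, confs, floor=0.55)
print(chosen)
```

In this toy set the one false positive has confidence 0.4, so the lowest threshold that excludes it while keeping all three true positives is 0.6, which is above the 0.55 floor. Raising the floor above 0.6 would make the floor the returned value instead.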
@@ -0,0 +1,174 @@
+import argparse
+import json
+import os
+import re
+from collections import Counter
+
+import joblib
+import numpy as np
+from imblearn.over_sampling import RandomOverSampler
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import accuracy_score, classification_report, f1_score
+from sklearn.model_selection import GroupKFold
+from sklearn.neural_network import MLPClassifier
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import StandardScaler
+
+TASK_INPUTS = {
+    "conflict": os.path.join("data", "router", "gold-conflict-structured.jsonl"),
+    "sufficiency": os.path.join("data", "router", "gold-sufficiency-structured.jsonl"),
+}
+
+TASK_BASELINES = {
+    "conflict": os.path.join("metrics", "router", "conflict-baseline-provisional.json"),
+    "sufficiency": os.path.join("metrics", "router", "sufficiency-baseline-provisional.json"),
+}
+
+MODEL_BUILDERS = {
+    "lr": lambda: Pipeline([
+        ("scaler", StandardScaler()),
+        ("clf", LogisticRegression(max_iter=5000, class_weight="balanced")),
+    ]),
+    "mlp": lambda: Pipeline([
+        ("scaler", StandardScaler()),
+        ("clf", MLPClassifier(hidden_layer_sizes=(32, 16), max_iter=2000, random_state=42, early_stopping=False)),
+    ]),
+}
+
+
+def normalize_query_group(query: str) -> str:
+    return re.sub(r"\s+", " ", re.sub(r"[^a-z0-9\s]+", " ", (query or "").lower())).strip()
+
+
+def load_jsonl(path: str):
+    with open(path, "r") as f:
+        return [json.loads(line) for line in f if line.strip()]
+
+
+def build_xy(rows):
+    feature_names = sorted(rows[0]["features"].keys())
+    X = np.array([[row["features"][name] for name in feature_names] for row in rows], dtype=np.float32)
+    y = np.array([row["label"] for row in rows])
+    groups = np.array([normalize_query_group(row["query"]) for row in rows])
+    return X, y, groups, feature_names
+
+
+def choose_n_splits(y, groups):
+    min_class = min(Counter(y).values())
+    return max(2, min(5, len(set(groups)), min_class))
+
+
+def evaluate_model(model_name, rows):
+    X, y, groups, feature_names = build_xy(rows)
+    splitter = GroupKFold(n_splits=choose_n_splits(y, groups))
+    gold, pred = [], []
+    fold_rows = []
+
+    for fold, (train_idx, test_idx) in enumerate(splitter.split(X, y, groups), start=1):
+        ros = RandomOverSampler(random_state=42)
+        X_train, y_train = ros.fit_resample(X[train_idx], y[train_idx])
+        clf = MODEL_BUILDERS[model_name]()
+        clf.fit(X_train, y_train)
+        probs = clf.predict_proba(X[test_idx]) if hasattr(clf, "predict_proba") else None
+        preds = clf.predict(X[test_idx])
+
+        for local_idx, pred_label in enumerate(preds):
+            idx = test_idx[local_idx]
+            confidence = None
+            if probs is not None:
+                confidence = float(np.max(probs[local_idx]))
+            gold.append(str(y[idx]))
+            pred.append(str(pred_label))
+            fold_rows.append({
+                "fold": fold,
+                "query": rows[idx]["query"],
+                "gold": str(y[idx]),
+                "pred": str(pred_label),
+                "confidence": confidence,
+            })
+
+    return {
+        "model": model_name,
+        "accuracy": accuracy_score(gold, pred),
+        "macro_f1": f1_score(gold, pred, average="macro"),
+        "classification_report": classification_report(gold, pred, output_dict=True),
+        "rows": fold_rows,
+        "feature_names": feature_names,
+    }
+
+
+def train_full_model(model_name, rows):
+    X, y, _, feature_names = build_xy(rows)
+    ros = RandomOverSampler(random_state=42)
+    X_train, y_train = ros.fit_resample(X, y)
+    clf = MODEL_BUILDERS[model_name]()
+    clf.fit(X_train, y_train)
+    return clf, feature_names
+
+
+def load_baseline_metrics(task: str):
+    path = TASK_BASELINES[task]
+    if not os.path.exists(path):
+        return None
+    with open(path, "r") as f:
+        data = json.load(f)
+    return {
+        "accuracy": data.get("accuracy"),
+        "macroF1": data.get("macroF1"),
+        "falseSufficient": data.get("falseSufficient"),
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--task", choices=["conflict", "sufficiency"], required=True)
+    parser.add_argument("--input")
+    parser.add_argument("--out-dir")
+    args = parser.parse_args()
+
+    input_path = args.input or TASK_INPUTS[args.task]
+    out_dir = args.out_dir or os.path.join(".cache", "models", "pi-research-router", f"{args.task}-structured")
+    metrics_path = os.path.join("metrics", "router", f"{args.task}-structured-models.json")
+
+    rows = load_jsonl(input_path)
+    baseline = load_baseline_metrics(args.task)
+
+    reports = {
+        model_name: evaluate_model(model_name, rows)
+        for model_name in ["lr", "mlp"]
+    }
+    best_name = max(reports.keys(), key=lambda name: (reports[name]["macro_f1"], reports[name]["accuracy"]))
+    best_model, feature_names = train_full_model(best_name, rows)
+
+    os.makedirs(out_dir, exist_ok=True)
+    os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
+    joblib.dump(best_model, os.path.join(out_dir, "model.joblib"))
+    with open(os.path.join(out_dir, "feature-names.json"), "w") as f:
+        json.dump(feature_names, f, indent=2)
+    with open(os.path.join(out_dir, "meta.json"), "w") as f:
+        json.dump({"task": args.task, "bestModel": best_name, "rows": len(rows)}, f, indent=2)
+
+    summary = {
+        "task": args.task,
+        "rows": len(rows),
+        "baseline": baseline,
+        "best_model": best_name,
+        "models": reports,
+    }
+    with open(metrics_path, "w") as f:
+        json.dump(summary, f, indent=2)
+
+    print(json.dumps({
+        "task": args.task,
+        "rows": len(rows),
+        "baseline": baseline,
+        "best_model": best_name,
+        "best_accuracy": reports[best_name]["accuracy"],
+        "best_macro_f1": reports[best_name]["macro_f1"],
+        "lr_macro_f1": reports["lr"]["macro_f1"],
+        "mlp_macro_f1": reports["mlp"]["macro_f1"],
+    }, indent=2))
+
+
+if __name__ == "__main__":
+    main()
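The structured-baseline script in the hunk above cross-validates with `GroupKFold`, using a normalized form of each query as the group key, so rows whose queries differ only in case or punctuation cannot end up on both sides of a split. The sketch below reuses the `normalize_query_group` expression from the diff to show which queries collapse into one group; the sample queries are made up for illustration.

```python
import re
from collections import defaultdict


def normalize_query_group(query: str) -> str:
    # Same expression as in the diff: lowercase, turn runs of non-alphanumerics
    # into a space, collapse whitespace, and trim.
    return re.sub(r"\s+", " ", re.sub(r"[^a-z0-9\s]+", " ", (query or "").lower())).strip()


# Illustrative queries only; the first two are near-duplicates.
queries = ["What is CVE-2024-1234?", "what is cve 2024 1234", "Rust async runtimes"]
groups = defaultdict(list)
for q in queries:
    groups[normalize_query_group(q)].append(q)

for key, members in groups.items():
    print(key, "->", members)
```

Grouping this way trades a little training data per fold for a stricter evaluation: a model cannot score well simply by having seen a reworded copy of a test query during training.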
package/package.json
CHANGED
@@ -1,20 +1,21 @@
 {
   "name": "pi-research",
-  "version": "1.
+  "version": "1.5.0",
   "private": false,
   "type": "module",
-  "description": "
+  "description": "⚠️ DEPRECATED: This package has moved to 'emet'. Run: npm install -g emet",
   "license": "MIT",
   "main": "./index.js",
   "bin": {
-    "pi-research": "
-    "unblind-mcp": "
+    "pi-research": "bin/pi-research.js",
+    "unblind-mcp": "bin/unblind-mcp.js"
   },
   "files": [
     "bin",
     "extensions",
     "index.js",
     "lib",
+    "ml",
     "mcp",
     "mcp-server.js",
     "pi-research.js",
@@ -32,7 +33,9 @@
     "url": "https://github.com/endgegnerbert-tech/pi-research/issues"
   },
   "keywords": [
-    "
+    "deprecated",
+    "moved",
+    "use-emet"
   ],
   "scripts": {
     "test": "node --test",