leakit 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- leakit-0.1.0/.gitignore +8 -0
- leakit-0.1.0/LICENSE +21 -0
- leakit-0.1.0/PKG-INFO +136 -0
- leakit-0.1.0/README.md +114 -0
- leakit-0.1.0/install.sh +69 -0
- leakit-0.1.0/pyproject.toml +44 -0
- leakit-0.1.0/src/leakit/__init__.py +21 -0
- leakit-0.1.0/src/leakit/_stats.py +69 -0
- leakit-0.1.0/src/leakit/cli.py +175 -0
- leakit-0.1.0/src/leakit/core.py +122 -0
- leakit-0.1.0/src/leakit/sampler.py +155 -0
- leakit-0.1.0/tests/conftest.py +80 -0
- leakit-0.1.0/tests/test_cli.py +86 -0
- leakit-0.1.0/tests/test_integration_http.py +102 -0
- leakit-0.1.0/tests/test_sampler.py +62 -0
- leakit-0.1.0/tests/test_stats.py +54 -0
leakit-0.1.0/.gitignore
ADDED
leakit-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Victor Maricato
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
leakit-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: leakit
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Continuation-free membership inference on closed language models via sample self-concentration
|
|
5
|
+
Project-URL: Homepage, https://github.com/victormaricato/leakit
|
|
6
|
+
Project-URL: Repository, https://github.com/victormaricato/leakit
|
|
7
|
+
Author: Victor Maricato
|
|
8
|
+
License: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Keywords: audit,llm,membership-inference,mia,privacy,security
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
+
Classifier: Topic :: Security
|
|
17
|
+
Requires-Python: >=3.9
|
|
18
|
+
Requires-Dist: openai>=1.40
|
|
19
|
+
Provides-Extra: dev
|
|
20
|
+
Requires-Dist: pytest>=7; extra == 'dev'
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
|
|
23
|
+
# leakit
|
|
24
|
+
|
|
25
|
+
**Continuation-free membership inference for closed language models.**
|
|
26
|
+
|
|
27
|
+
`leakit` tells you whether a document was likely in a model's training set using
|
|
28
|
+
nothing but its sampling API. No logits, no log-probabilities, and -- unlike
|
|
29
|
+
prior sampling attacks such as SaMIA -- no need to know the document's true
|
|
30
|
+
continuation. You give it the *opening* of a document; it samples several
|
|
31
|
+
continuations and measures how much they agree with each other. Training
|
|
32
|
+
documents pull the model's continuation distribution toward the memorised text,
|
|
33
|
+
so the samples concentrate; novel documents leave the distribution diffuse.
|
|
34
|
+
|
|
35
|
+
This is the reference implementation of the *self-concentration* attack from the
|
|
36
|
+
paper *"Leak It: Continuation-Free Membership Inference on Closed Language Models
|
|
37
|
+
via Sample Self-Concentration."*
|
|
38
|
+
|
|
39
|
+
## Install
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
curl -fsSL https://raw.githubusercontent.com/victormaricato/leakit/main/install.sh | bash
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
or, directly, with any of:
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
uv tool install leakit # recommended
|
|
49
|
+
pipx install leakit
|
|
50
|
+
pip install leakit
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
## Use
|
|
54
|
+
|
|
55
|
+
`leakit` talks to any **OpenAI-compatible** endpoint. Set the API key for the
|
|
56
|
+
service you are probing -- the key maps to whatever provider `--base-url` points
|
|
57
|
+
at -- then run it.
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
export LEAKIT_API_KEY="sk-..." # or OPENAI_API_KEY
|
|
61
|
+
|
|
62
|
+
# OpenAI
|
|
63
|
+
leakit --model gpt-4o-mini suspect.txt
|
|
64
|
+
|
|
65
|
+
# Anything OpenAI-compatible (OpenRouter, Anthropic compat route, vLLM, Together, local server)
|
|
66
|
+
leakit --model anthropic/claude-3.5-sonnet \
|
|
67
|
+
--base-url https://openrouter.ai/api/v1 \
|
|
68
|
+
--api-key-env OPENROUTER_API_KEY \
|
|
69
|
+
-n 32 suspect.txt
|
|
70
|
+
|
|
71
|
+
# Compare a candidate against known non-member documents (relative percentile); quote the glob so leakit, not the shell, expands it
|
|
72
|
+
leakit --model gpt-4o-mini --calibrate 'clean/*.txt' suspect.txt
|
|
73
|
+
|
|
74
|
+
# Pipe text in, get JSON out
|
|
75
|
+
cat article.txt | leakit --model gpt-4o-mini --json
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
Output:
|
|
79
|
+
|
|
80
|
+
```
|
|
81
|
+
document score samples
|
|
82
|
+
-----------------------------
|
|
83
|
+
suspect.txt 0.4213 32/32
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
A higher score means the sampled continuations agree more, which correlates with
|
|
87
|
+
membership. The absolute scale is model-dependent, so interpret scores
|
|
88
|
+
*relatively*: score several documents together, or use `--calibrate` with a set
|
|
89
|
+
of documents you know were **not** in training to get a percentile.
|
|
90
|
+
|
|
91
|
+
### Key options
|
|
92
|
+
|
|
93
|
+
| Flag | Meaning | Default |
|
|
94
|
+
|------|---------|---------|
|
|
95
|
+
| `--model` | model id passed to the API | required |
|
|
96
|
+
| `--base-url` | OpenAI-compatible endpoint | OpenAI |
|
|
97
|
+
| `--api-key-env` | env var holding the key | `LEAKIT_API_KEY`, then `OPENAI_API_KEY` |
|
|
98
|
+
| `-n, --samples` | continuations per document | 16 |
|
|
99
|
+
| `--max-tokens` | tokens per continuation | 64 |
|
|
100
|
+
| `--temperature` | sampling temperature | 1.0 |
|
|
101
|
+
| `--prefix-chars` | chars of each doc used as the prefix (0 = whole doc) | 256 |
|
|
102
|
+
| `--statistic` | `word-jaccard` (parameter-free) or `kgram` | `word-jaccard` |
|
|
103
|
+
| `--mode` | `chat` (closed APIs) or `completion` (base models) | `chat` |
|
|
104
|
+
| `--calibrate` | quoted glob or comma-separated path(s) of known non-member documents; adds a percentile column | off |
|
|
105
|
+
| `--json` | machine-readable output | off |
|
|
106
|
+
|
|
107
|
+
For base/text-completion models (e.g. self-hosted Pythia/Llama base), use
|
|
108
|
+
`--mode completion` to sample the raw continuation distribution. For chat/instruct
|
|
109
|
+
models, the default `chat` mode asks the model to continue the passage verbatim.
|
|
110
|
+
|
|
111
|
+
## Python API
|
|
112
|
+
|
|
113
|
+
```python
|
|
114
|
+
from leakit import LeakIt
|
|
115
|
+
|
|
116
|
+
scorer = LeakIt(model="gpt-4o-mini", n_samples=32) # reads LEAKIT_API_KEY/OPENAI_API_KEY
|
|
117
|
+
result = scorer.score(open("suspect.txt").read())
|
|
118
|
+
print(result.score, result.n_returned)
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
The raw statistics are exposed too:
|
|
122
|
+
|
|
123
|
+
```python
|
|
124
|
+
from leakit import self_concentration_word_jaccard
|
|
125
|
+
self_concentration_word_jaccard(["a b c", "a b c", "x y z"])
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
## Responsible use
|
|
129
|
+
|
|
130
|
+
`leakit` is a privacy-auditing and red-teaming tool: use it to test models you
|
|
131
|
+
own or are authorised to assess. The self-concentration signal is a statistical
|
|
132
|
+
indicator, not proof of membership; calibrate before drawing conclusions.
|
|
133
|
+
|
|
134
|
+
## License
|
|
135
|
+
|
|
136
|
+
MIT.
|
leakit-0.1.0/README.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# leakit
|
|
2
|
+
|
|
3
|
+
**Continuation-free membership inference for closed language models.**
|
|
4
|
+
|
|
5
|
+
`leakit` tells you whether a document was likely in a model's training set using
|
|
6
|
+
nothing but its sampling API. No logits, no log-probabilities, and -- unlike
|
|
7
|
+
prior sampling attacks such as SaMIA -- no need to know the document's true
|
|
8
|
+
continuation. You give it the *opening* of a document; it samples several
|
|
9
|
+
continuations and measures how much they agree with each other. Training
|
|
10
|
+
documents pull the model's continuation distribution toward the memorised text,
|
|
11
|
+
so the samples concentrate; novel documents leave the distribution diffuse.
|
|
12
|
+
|
|
13
|
+
This is the reference implementation of the *self-concentration* attack from the
|
|
14
|
+
paper *"Leak It: Continuation-Free Membership Inference on Closed Language Models
|
|
15
|
+
via Sample Self-Concentration."*
|
|
16
|
+
|
|
17
|
+
## Install
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
curl -fsSL https://raw.githubusercontent.com/victormaricato/leakit/main/install.sh | bash
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
or, directly, with any of:
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
uv tool install leakit # recommended
|
|
27
|
+
pipx install leakit
|
|
28
|
+
pip install leakit
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
## Use
|
|
32
|
+
|
|
33
|
+
`leakit` talks to any **OpenAI-compatible** endpoint. Set the API key for the
|
|
34
|
+
service you are probing -- the key maps to whatever provider `--base-url` points
|
|
35
|
+
at -- then run it.
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
export LEAKIT_API_KEY="sk-..." # or OPENAI_API_KEY
|
|
39
|
+
|
|
40
|
+
# OpenAI
|
|
41
|
+
leakit --model gpt-4o-mini suspect.txt
|
|
42
|
+
|
|
43
|
+
# Anything OpenAI-compatible (OpenRouter, Anthropic compat route, vLLM, Together, local server)
|
|
44
|
+
leakit --model anthropic/claude-3.5-sonnet \
|
|
45
|
+
--base-url https://openrouter.ai/api/v1 \
|
|
46
|
+
--api-key-env OPENROUTER_API_KEY \
|
|
47
|
+
-n 32 suspect.txt
|
|
48
|
+
|
|
49
|
+
# Compare a candidate against known non-member documents (relative percentile); quote the glob so leakit, not the shell, expands it
|
|
50
|
+
leakit --model gpt-4o-mini --calibrate 'clean/*.txt' suspect.txt
|
|
51
|
+
|
|
52
|
+
# Pipe text in, get JSON out
|
|
53
|
+
cat article.txt | leakit --model gpt-4o-mini --json
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
Output:
|
|
57
|
+
|
|
58
|
+
```
|
|
59
|
+
document score samples
|
|
60
|
+
-----------------------------
|
|
61
|
+
suspect.txt 0.4213 32/32
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
A higher score means the sampled continuations agree more, which correlates with
|
|
65
|
+
membership. The absolute scale is model-dependent, so interpret scores
|
|
66
|
+
*relatively*: score several documents together, or use `--calibrate` with a set
|
|
67
|
+
of documents you know were **not** in training to get a percentile.
|
|
68
|
+
|
|
69
|
+
### Key options
|
|
70
|
+
|
|
71
|
+
| Flag | Meaning | Default |
|
|
72
|
+
|------|---------|---------|
|
|
73
|
+
| `--model` | model id passed to the API | required |
|
|
74
|
+
| `--base-url` | OpenAI-compatible endpoint | OpenAI |
|
|
75
|
+
| `--api-key-env` | env var holding the key | `LEAKIT_API_KEY`, then `OPENAI_API_KEY` |
|
|
76
|
+
| `-n, --samples` | continuations per document | 16 |
|
|
77
|
+
| `--max-tokens` | tokens per continuation | 64 |
|
|
78
|
+
| `--temperature` | sampling temperature | 1.0 |
|
|
79
|
+
| `--prefix-chars` | chars of each doc used as the prefix (0 = whole doc) | 256 |
|
|
80
|
+
| `--statistic` | `word-jaccard` (parameter-free) or `kgram` | `word-jaccard` |
|
|
81
|
+
| `--mode` | `chat` (closed APIs) or `completion` (base models) | `chat` |
|
|
82
|
+
| `--calibrate` | quoted glob or comma-separated path(s) of known non-member documents; adds a percentile column | off |
|
|
83
|
+
| `--json` | machine-readable output | off |
|
|
84
|
+
|
|
85
|
+
For base/text-completion models (e.g. self-hosted Pythia/Llama base), use
|
|
86
|
+
`--mode completion` to sample the raw continuation distribution. For chat/instruct
|
|
87
|
+
models, the default `chat` mode asks the model to continue the passage verbatim.
|
|
88
|
+
|
|
89
|
+
## Python API
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
from leakit import LeakIt
|
|
93
|
+
|
|
94
|
+
scorer = LeakIt(model="gpt-4o-mini", n_samples=32) # reads LEAKIT_API_KEY/OPENAI_API_KEY
|
|
95
|
+
result = scorer.score(open("suspect.txt").read())
|
|
96
|
+
print(result.score, result.n_returned)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
The raw statistics are exposed too:
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
from leakit import self_concentration_word_jaccard
|
|
103
|
+
self_concentration_word_jaccard(["a b c", "a b c", "x y z"])
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
## Responsible use
|
|
107
|
+
|
|
108
|
+
`leakit` is a privacy-auditing and red-teaming tool: use it to test models you
|
|
109
|
+
own or are authorised to assess. The self-concentration signal is a statistical
|
|
110
|
+
indicator, not proof of membership; calibrate before drawing conclusions.
|
|
111
|
+
|
|
112
|
+
## License
|
|
113
|
+
|
|
114
|
+
MIT.
|
leakit-0.1.0/install.sh
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
#!/usr/bin/env bash
|
|
2
|
+
# leakit installer.
|
|
3
|
+
#
|
|
4
|
+
# curl -fsSL https://raw.githubusercontent.com/victormaricato/leakit/main/install.sh | bash
|
|
5
|
+
#
|
|
6
|
+
# Installs the `leakit` CLI into an isolated environment using uv (preferred)
|
|
7
|
+
# or pipx. Override the source with LEAKIT_SOURCE, e.g.
|
|
8
|
+
# LEAKIT_SOURCE="git+https://github.com/victormaricato/leakit.git" bash install.sh
|
|
9
|
+
set -euo pipefail
|
|
10
|
+
|
|
11
|
+
SOURCE="${LEAKIT_SOURCE:-leakit}"
|
|
12
|
+
GIT_FALLBACK="git+https://github.com/victormaricato/leakit.git"
|
|
13
|
+
|
|
14
|
+
# Logging helpers: a blue "==>" progress line on stdout, a yellow warning on
# stderr, and a red error on stderr that aborts the script with status 1.
say() {
  printf '\033[1;34m==>\033[0m %s\n' "$*"
}

warn() {
  printf '\033[1;33mwarning:\033[0m %s\n' "$*" >&2
}

die() {
  printf '\033[1;31merror:\033[0m %s\n' "$*" >&2
  exit 1
}
|
|
17
|
+
|
|
18
|
+
# Make sure `uv` is on PATH, installing it with the official installer if it
# is missing. Returns 0 when uv is usable afterwards, non-zero otherwise, so
# the caller can fall back to pipx.
ensure_uv() {
  if command -v uv >/dev/null 2>&1; then return 0; fi
  say "uv not found; installing it..."
  # Return explicitly on failure: this function runs as an `if` condition in
  # main(), where bash suspends errexit.
  curl -fsSL https://astral.sh/uv/install.sh | sh || return 1
  # The uv installer writes to $UV_INSTALL_DIR when set, else $XDG_BIN_HOME,
  # else ~/.local/bin. Put every candidate that applies on PATH so the
  # freshly installed binary is found for the rest of this script.
  local dir
  for dir in "$HOME/.local/bin" "${XDG_BIN_HOME:-}" "${UV_INSTALL_DIR:-}"; do
    if [ -n "$dir" ]; then PATH="$dir:$PATH"; fi
  done
  export PATH
  command -v uv >/dev/null 2>&1
}
|
|
26
|
+
|
|
27
|
+
# Install the leakit CLI as a uv tool from $SOURCE. If that fails, retry once
# from the GitHub source $GIT_FALLBACK.
install_with_uv() {
  say "Installing leakit from '$SOURCE' with uv..."
  # uv's stderr is left visible so a failed first attempt can be diagnosed
  # before the fallback runs.
  if uv tool install --force "$SOURCE"; then
    return 0
  fi
  # Retrying is pointless when the source that just failed *is* the fallback.
  if [ "$SOURCE" = "$GIT_FALLBACK" ]; then
    die "install from '$SOURCE' failed"
  fi
  warn "install from '$SOURCE' failed; falling back to GitHub source"
  uv tool install --force "$GIT_FALLBACK"
}
|
|
34
|
+
|
|
35
|
+
# Install the leakit CLI with pipx from $SOURCE. If that fails, retry once
# from the GitHub source $GIT_FALLBACK.
install_with_pipx() {
  say "Installing leakit from '$SOURCE' with pipx..."
  # pipx's stderr is left visible so a failed first attempt can be diagnosed
  # before the fallback runs.
  if pipx install --force "$SOURCE"; then
    return 0
  fi
  # Retrying is pointless when the source that just failed *is* the fallback.
  if [ "$SOURCE" = "$GIT_FALLBACK" ]; then
    die "install from '$SOURCE' failed"
  fi
  warn "install from '$SOURCE' failed; falling back to GitHub source"
  pipx install --force "$GIT_FALLBACK"
}
|
|
42
|
+
|
|
43
|
+
# Installer entry point. It requires curl, then prefers uv (installing uv if
# needed) and falls back to an existing pipx. On success it prints next steps.
main() {
  if ! command -v curl >/dev/null 2>&1; then
    die "curl is required"
  fi

  if ensure_uv; then
    install_with_uv
  elif command -v pipx >/dev/null 2>&1; then
    install_with_pipx
  else
    die "need uv or pipx to install; see https://docs.astral.sh/uv/"
  fi

  printf '\n'
  say "leakit installed. Verify with: leakit --version"
  cat <<'EOF'

Next steps:
1. Export the API key for the service you want to probe, e.g.
export LEAKIT_API_KEY="sk-..." # or OPENAI_API_KEY
2. Score a document:
leakit --model gpt-4o-mini suspect.txt
3. Probe any OpenAI-compatible endpoint with --base-url / --api-key-env.

If `leakit` is not found, add uv's tool bin to your PATH:
export PATH="$HOME/.local/bin:$PATH"
EOF
}
|
|
68
|
+
|
|
69
|
+
# Run the installer. Script arguments are forwarded, though main() does not currently read any.
main "$@"
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "leakit"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Continuation-free membership inference on closed language models via sample self-concentration"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.9"
|
|
7
|
+
license = { text = "MIT" }
|
|
8
|
+
authors = [{ name = "Victor Maricato" }]
|
|
9
|
+
keywords = ["membership-inference", "privacy", "llm", "security", "mia", "audit"]
|
|
10
|
+
classifiers = [
|
|
11
|
+
"Development Status :: 4 - Beta",
|
|
12
|
+
"Intended Audience :: Science/Research",
|
|
13
|
+
"License :: OSI Approved :: MIT License",
|
|
14
|
+
"Programming Language :: Python :: 3",
|
|
15
|
+
"Topic :: Security",
|
|
16
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
17
|
+
]
|
|
18
|
+
dependencies = [
|
|
19
|
+
"openai>=1.40",
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[project.urls]
|
|
23
|
+
Homepage = "https://github.com/victormaricato/leakit"
|
|
24
|
+
Repository = "https://github.com/victormaricato/leakit"
|
|
25
|
+
|
|
26
|
+
[project.scripts]
|
|
27
|
+
leakit = "leakit.cli:main"
|
|
28
|
+
|
|
29
|
+
[project.optional-dependencies]
|
|
30
|
+
dev = ["pytest>=7"]
|
|
31
|
+
|
|
32
|
+
[build-system]
|
|
33
|
+
requires = ["hatchling"]
|
|
34
|
+
build-backend = "hatchling.build"
|
|
35
|
+
|
|
36
|
+
[tool.hatch.build.targets.wheel]
|
|
37
|
+
packages = ["src/leakit"]
|
|
38
|
+
|
|
39
|
+
[tool.hatch.build.targets.sdist]
|
|
40
|
+
include = ["src/leakit", "tests", "README.md", "LICENSE", "install.sh", "pyproject.toml"]
|
|
41
|
+
|
|
42
|
+
[tool.ruff]
|
|
43
|
+
line-length = 100
|
|
44
|
+
target-version = "py39"
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""leakit: continuation-free membership inference for closed language models.
|
|
2
|
+
|
|
3
|
+
Public API:
|
|
4
|
+
LeakIt - high-level scorer over any OpenAI-compatible endpoint
|
|
5
|
+
ScoreResult - per-document result
|
|
6
|
+
self_concentration_word_jaccard / self_concentration_kgram - raw statistics
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from ._stats import self_concentration_kgram, self_concentration_word_jaccard
|
|
10
|
+
from .core import LeakIt, ScoreResult, percentile_of
|
|
11
|
+
|
|
12
|
+
__version__ = "0.1.0"

# Public names exported by ``from leakit import *``.
__all__ = [
    "LeakIt", "ScoreResult", "percentile_of",
    "self_concentration_word_jaccard", "self_concentration_kgram",
    "__version__",
]
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Self-concentration statistics for continuation-free membership inference.
|
|
2
|
+
|
|
3
|
+
These are the membership signals defined in the paper. The statistic operates
|
|
4
|
+
purely on a set of sampled continuations: it never sees a gold continuation and
|
|
5
|
+
needs no model internals. Higher values indicate a more concentrated sampling
|
|
6
|
+
distribution, which the paper shows is predictive of training-set membership.
|
|
7
|
+
|
|
8
|
+
Pure Python, no third-party dependencies, so the package stays lightweight.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from itertools import combinations
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _word_set(text: str) -> set[str]:
    """Return the distinct whitespace-delimited tokens of *text*."""
    return {*text.split()}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _kgram_set(text: str, k: int) -> set[str]:
    """Return the distinct length-``k`` character windows of *text*.

    Windows overlap with stride 1. A text shorter than ``k`` has no windows,
    so the result is an empty set.
    """
    window_count = len(text) - k + 1
    grams: set[str] = set()
    for start in range(max(window_count, 0)):
        grams.add(text[start : start + k])
    return grams
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _mean_pairwise_jaccard(sets: list[set[str]]) -> float:
    """Average the Jaccard similarity over every unordered pair of *sets*.

    A pair whose union is empty (both sets empty) has no defined Jaccard
    value and is left out of the average. Returns 0.0 when fewer than two
    sets are given, or when no pair has a non-empty union.
    """
    if len(sets) < 2:
        return 0.0
    running_sum = 0.0
    counted = 0
    for left, right in combinations(sets, 2):
        union_size = len(left | right)
        if not union_size:
            continue
        running_sum += len(left & right) / union_size
        counted += 1
    if counted == 0:
        return 0.0
    return running_sum / counted
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def self_concentration_word_jaccard(completions: list[str]) -> float:
    """Parameter-free self-concentration: mean pairwise word-set Jaccard.

    Every completion becomes the set of its whitespace-separated tokens. The
    result is the mean Jaccard similarity across all unordered pairs of those
    sets. This is the headline statistic in the paper, because it has no
    n-gram size to tune.
    """
    token_sets = list(map(_word_set, completions))
    return _mean_pairwise_jaccard(token_sets)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def self_concentration_kgram(completions: list[str], k: int = 5) -> float:
    """Self-concentration over character k-grams: mean pairwise k-gram Jaccard.

    Each completion is reduced to its set of overlapping character windows of
    length ``k``. The result is the mean Jaccard similarity over all unordered
    pairs of those sets.

    Args:
        completions: sampled continuations for one prefix.
        k: window length in characters; must be at least 1.

    Raises:
        ValueError: if ``k`` < 1. With ``k == 0`` every completion maps to the
            set ``{""}``, so any two samples look identical and the score is a
            meaningless 1.0. A negative ``k`` yields arbitrary slices.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")
    return _mean_pairwise_jaccard([_kgram_set(c, k) for c in completions])
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _word_jaccard_adapter(completions: list[str], k: int) -> float:
    """Adapt the parameter-free word statistic to the shared ``(completions, k)`` call shape."""
    del k  # word-set Jaccard has no size parameter; k is accepted only for uniform dispatch
    return self_concentration_word_jaccard(completions)


# Registry of statistic name -> callable(completions, k) -> float.
STATISTICS = {
    "word-jaccard": _word_jaccard_adapter,
    "kgram": self_concentration_kgram,
}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def compute(completions: list[str], statistic: str = "word-jaccard", k: int = 5) -> float:
    """Evaluate the statistic registered under *statistic* on *completions*.

    ``k`` is forwarded to every statistic; the word-Jaccard entry ignores it.
    Raises ValueError when *statistic* is not a key of ``STATISTICS``.
    """
    scorer_fn = STATISTICS.get(statistic)
    if scorer_fn is None:
        raise ValueError(
            f"unknown statistic {statistic!r}; choose from {sorted(STATISTICS)}"
        )
    return scorer_fn(completions, k)
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Command-line interface for leakit."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import json
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from . import __version__, _stats
|
|
11
|
+
from .core import LeakIt, ScoreResult, percentile_of
|
|
12
|
+
|
|
13
|
+
# Text appended to ``leakit --help``. The parser uses
# RawDescriptionHelpFormatter, so the line breaks are kept as written.
# The --calibrate glob is quoted on purpose: --calibrate takes a single value.
# An unquoted glob would be expanded by the shell, and every match after the
# first would be scored as a candidate document instead of a baseline.
_EPILOG = """\
examples:
# score a file against an OpenAI model (reads LEAKIT_API_KEY or OPENAI_API_KEY)
leakit --model gpt-4o-mini suspect.txt

# probe a model served behind any OpenAI-compatible endpoint
leakit --model anthropic/claude-3.5-sonnet \\
--base-url https://openrouter.ai/api/v1 \\
--api-key-env OPENROUTER_API_KEY \\
-n 32 suspect.txt

# compare a candidate against a baseline of known non-member documents
leakit --model gpt-4o-mini --calibrate 'known_clean/*.txt' suspect.txt

# pipe text in
cat article.txt | leakit --model gpt-4o-mini

A higher self-concentration score means the model's continuations agree more,
which the paper shows correlates with training-set membership. The absolute
scale is model-dependent; use --calibrate (or score several documents together)
to interpret a score relatively.
"""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def build_parser() -> argparse.ArgumentParser:
    """Build the ``leakit`` command-line parser.

    Count-like options (samples, max tokens, k, concurrency, n per request) are
    validated as positive integers. A zero or negative value is therefore
    reported as a usage error instead of producing a meaningless score or
    failing later during sampling.
    """
    p = argparse.ArgumentParser(
        prog="leakit",
        description="Continuation-free membership inference on closed language models.",
        epilog=_EPILOG,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("documents", nargs="*", help="document file(s) to score; omit to read stdin")
    p.add_argument("--text", action="append", default=[], metavar="STR",
                   help="inline document text (repeatable)")
    p.add_argument("--model", required=True, help="model id passed to the API")
    p.add_argument("--base-url", default=None,
                   help="OpenAI-compatible base URL (default: OpenAI)")
    p.add_argument("--api-key-env", default=None, metavar="VAR",
                   help="env var holding the API key (default: LEAKIT_API_KEY then OPENAI_API_KEY)")
    p.add_argument("-n", "--samples", type=_positive_int, default=16,
                   help="number of sampled continuations per document (default: 16)")
    p.add_argument("--max-tokens", type=_positive_int, default=64,
                   help="tokens generated per continuation (default: 64)")
    p.add_argument("--temperature", type=float, default=1.0, help="sampling temperature (default: 1.0)")
    p.add_argument("--top-p", type=float, default=1.0, help="nucleus sampling top-p (default: 1.0)")
    p.add_argument("--prefix-chars", type=int, default=256,
                   help="chars of each document used as the conditioning prefix; 0 = whole document (default: 256)")
    p.add_argument("--statistic", choices=sorted(_stats.STATISTICS), default="word-jaccard",
                   help="self-concentration statistic (default: word-jaccard)")
    p.add_argument("--k", type=_positive_int, default=5, help="k for the kgram statistic (default: 5)")
    p.add_argument("--mode", choices=("chat", "completion"), default="chat",
                   help="API surface: chat for instruct/closed APIs (default), completion for base models")
    p.add_argument("--concurrency", type=_positive_int, default=8, help="parallel requests (default: 8)")
    p.add_argument("--n-per-request", type=_positive_int, default=1,
                   help="continuations per API call via the provider's n param (default: 1)")
    p.add_argument("--calibrate", default=None, metavar="GLOB",
                   help="glob (quote it) or comma-separated path(s) of known NON-member documents; "
                        "report each candidate's percentile vs this baseline")
    p.add_argument("--json", action="store_true", help="emit JSON instead of a table")
    p.add_argument("--show-completions", action="store_true",
                   help="include raw completions in JSON output")
    p.add_argument("--version", action="version", version=f"leakit {__version__}")
    return p
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _read_documents(args) -> list[tuple[str, str]]:
    """Collect ``(id, text)`` pairs to score.

    File paths come first, in the order given, with their path as id. They are
    followed by ``--text`` values, with ids ``--text[i]``. Only when neither is
    supplied and stdin is not a terminal is stdin read, and it is used only if
    it contains non-whitespace text.
    """
    file_docs = [
        (path, Path(path).read_text(encoding="utf-8", errors="replace"))
        for path in args.documents
    ]
    inline_docs = [(f"--text[{idx}]", body) for idx, body in enumerate(args.text)]
    docs = file_docs + inline_docs
    if docs or sys.stdin.isatty():
        return docs
    piped = sys.stdin.read()
    if piped.strip():
        docs.append(("<stdin>", piped))
    return docs
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _make_scorer(args) -> LeakIt:
    """Translate parsed CLI arguments into a configured :class:`LeakIt` scorer."""
    sampling_options = dict(
        n_samples=args.samples,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        mode=args.mode,
        concurrency=args.concurrency,
        n_per_request=args.n_per_request,
    )
    scoring_options = dict(
        statistic=args.statistic,
        k=args.k,
        prefix_chars=args.prefix_chars,
    )
    return LeakIt(
        model=args.model,
        base_url=args.base_url,
        api_key_env=args.api_key_env,
        **sampling_options,
        **scoring_options,
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _print_table(results: list[ScoreResult], percentiles: dict[str, float] | None) -> None:
    """Print a fixed-width results table to stdout.

    Columns are document id (at least 8 chars wide), score (4 decimals), and
    returned/requested sample counts. When *percentiles* is given, a baseline
    percentile column is added; a document missing from the mapping shows
    ``nan``.
    """
    width = max(8, max((len(r.document_id) for r in results), default=0))
    show_pct = percentiles is not None

    header_cells = [f"{'document':<{width}}", f"{'score':>7}", f"{'samples':>7}"]
    if show_pct:
        header_cells.append(f"{'pctile':>7}")
    header = " ".join(header_cells)
    print(header)
    print("-" * len(header))

    for r in results:
        cells = [
            f"{r.document_id:<{width}}",
            f"{r.score:>7.4f}",
            f"{r.n_returned:>3}/{r.n_requested:<3}",
        ]
        if show_pct:
            cells.append(f"{percentiles.get(r.document_id, float('nan')):>6.1f}%")
        print(" ".join(cells))
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def main(argv: list[str] | None = None) -> int:
    """Run the ``leakit`` CLI and return the process exit status.

    Exit status:
        0 -- every document was scored with at least one completion.
        1 -- scoring ran, but some document got no completions back.
        2 -- setup error. Causes: no documents, an unreadable document or
             calibration file, a ``--calibrate`` value that matched nothing,
             or a RuntimeError/ValueError while building the scorer.
             argparse usage errors also exit with 2.
    """
    import glob

    args = build_parser().parse_args(argv)

    try:
        docs = _read_documents(args)
    except OSError as exc:
        print(f"error: cannot read document: {exc}", file=sys.stderr)
        return 2
    if not docs:
        print("error: no documents given (pass file paths, --text, or pipe stdin)", file=sys.stderr)
        return 2

    # Resolve and read the calibration baseline *before* any API call, so a bad
    # --calibrate value fails fast instead of after every candidate document
    # has already been sampled.
    baseline: list[tuple[str, str]] = []
    if args.calibrate:
        paths: list[str] = []
        for pattern in args.calibrate.split(","):
            paths.extend(sorted(glob.glob(pattern.strip())))
        if not paths:
            print(f"error: --calibrate matched no files: {args.calibrate!r}", file=sys.stderr)
            return 2
        try:
            baseline = [
                (p, Path(p).read_text(encoding="utf-8", errors="replace")) for p in paths
            ]
        except OSError as exc:
            print(f"error: cannot read calibration document: {exc}", file=sys.stderr)
            return 2

    try:
        scorer = _make_scorer(args)
    except (RuntimeError, ValueError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 2

    results = [scorer.score(text, document_id=doc_id) for doc_id, text in docs]

    percentiles: dict[str, float] | None = None
    if baseline:
        baseline_scores = [scorer.score(text, document_id=p).score for p, text in baseline]
        # NOTE: keyed by document id, so candidates sharing an id share one entry.
        percentiles = {r.document_id: percentile_of(r.score, baseline_scores) for r in results}

    if args.json:
        out = [r.as_dict(include_completions=args.show_completions) for r in results]
        if percentiles is not None:
            for d in out:
                d["percentile_vs_baseline"] = percentiles.get(d["document"])
        print(json.dumps(out, indent=2))
    else:
        _print_table(results, percentiles)

    if any(r.n_returned == 0 for r in results):
        print("warning: some documents returned no completions (check model/endpoint/key)",
              file=sys.stderr)
        return 1
    return 0
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
if __name__ == "__main__":  # pragma: no cover
    # Use main()'s integer status as the process exit code.
    sys.exit(main())
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""High-level scoring API: prefix extraction, sampling, self-concentration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
from . import _stats
|
|
8
|
+
from .sampler import Sampler, SamplerConfig, resolve_api_key
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class ScoreResult:
    """Outcome of scoring a single document."""

    document_id: str  # caller-supplied label; emitted as "document" by as_dict()
    score: float  # value produced by the selected self-concentration statistic
    n_requested: int  # continuations asked for (presumably n_samples -- TODO confirm)
    n_returned: int  # continuations actually received
    statistic: str  # statistic name, e.g. "word-jaccard" or "kgram"
    prefix: str  # conditioning prefix sent to the model
    completions: list[str]  # raw sampled continuations

    def as_dict(self, include_completions: bool = False) -> dict:
        """Return a JSON-friendly summary of this result.

        The prefix is cut to its first 120 characters under ``prefix_preview``.
        The raw completions are attached under ``completions`` only when
        *include_completions* is true.
        """
        summary = dict(
            document=self.document_id,
            score=self.score,
            statistic=self.statistic,
            n_requested=self.n_requested,
            n_returned=self.n_returned,
            prefix_preview=self.prefix[:120],
        )
        if not include_completions:
            return summary
        summary["completions"] = self.completions
        return summary
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def make_prefix(document: str, prefix_chars: int) -> str:
    """Return the conditioning prefix taken from the start of *document*.

    The document is stripped of surrounding whitespace first. A positive
    *prefix_chars* keeps that many leading characters. Zero, a negative value,
    or a falsy value keeps the whole stripped document. Only this prefix is
    sent to the model; the rest of the document never is (continuation-free).
    """
    stripped = document.strip()
    if not prefix_chars or prefix_chars <= 0:
        return stripped
    return stripped[:prefix_chars]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LeakIt:
    """Continuation-free membership-inference scorer over an OpenAI-compatible API.

    For each document, a prefix is taken, ``n_samples`` continuations are
    sampled from the endpoint, and the chosen statistic summarises how much
    those continuations agree with one another. Logits are never requested.

    Example (requires network access and a valid key, so it is not a doctest)::

        scorer = LeakIt(model="gpt-4o-mini")  # key from LEAKIT_API_KEY / OPENAI_API_KEY
        result = scorer.score("In the beginning the Universe was created.")
        result.score  # a float
    """

    def __init__(
        self,
        model: str,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        api_key_env: str | None = None,
        n_samples: int = 16,
        max_tokens: int = 64,
        temperature: float = 1.0,
        top_p: float = 1.0,
        mode: str = "chat",
        concurrency: int = 8,
        n_per_request: int = 1,
        statistic: str = "word-jaccard",
        k: int = 5,
        prefix_chars: int = 256,
        client=None,
    ):
        """Configure the scorer and build its sampler.

        Args:
            model: Model name sent with every request.
            base_url: Endpoint root forwarded to the sampler.
            api_key: Explicit key; when falsy, it is looked up via
                ``resolve_api_key(api_key_env)``.
            api_key_env: Name of the single environment variable to read the key from.
            n_samples, max_tokens, temperature, top_p, mode, concurrency,
                n_per_request: Forwarded unchanged to ``SamplerConfig``.
            statistic: Key into ``_stats.STATISTICS``.
            k: Forwarded to ``_stats.compute`` (presumably the k-gram size for
                k-gram statistics; confirm in ``_stats``).
            prefix_chars: Prefix length passed to ``make_prefix``.
            client: Optional pre-built client injected into the sampler.

        Raises:
            ValueError: ``statistic`` is not a known statistic.
            RuntimeError: no ``api_key`` was given and none is found in the environment.
        """
        known = _stats.STATISTICS
        if statistic not in known:
            raise ValueError(
                f"unknown statistic {statistic!r}; choose from {sorted(known)}"
            )

        self.statistic = statistic
        self.k = k
        self.prefix_chars = prefix_chars
        self.n_samples = n_samples

        resolved_key = api_key if api_key else resolve_api_key(api_key_env)
        sampler_config = SamplerConfig(
            model=model,
            base_url=base_url,
            n_samples=n_samples,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            mode=mode,
            concurrency=concurrency,
            n_per_request=n_per_request,
        )
        self.sampler = Sampler(sampler_config, api_key=resolved_key, client=client)

    def score(self, document: str, document_id: str = "<text>") -> ScoreResult:
        """Sample continuations of ``document``'s prefix and score their concentration.

        ``document_id`` is only a label copied into the result.
        """
        prefix = make_prefix(document, self.prefix_chars)
        samples = self.sampler.sample(prefix)
        return ScoreResult(
            document_id=document_id,
            score=_stats.compute(samples, self.statistic, self.k),
            n_requested=self.n_samples,
            n_returned=len(samples),
            statistic=self.statistic,
            prefix=prefix,
            completions=samples,
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def percentile_of(value: float, baseline: list[float]) -> float:
    """Return the percentage (0-100) of ``baseline`` scores strictly below ``value``.

    Ties do not count as "below". An empty baseline gives ``nan``.
    """
    if not baseline:
        return float("nan")
    n_below = len([b for b in baseline if b < value])
    return n_below * 100.0 / len(baseline)
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""Sampling backend built on the OpenAI Python SDK.
|
|
2
|
+
|
|
3
|
+
Any OpenAI-compatible endpoint works: set ``base_url`` to the provider you are
|
|
4
|
+
probing (OpenAI, Anthropic via its OpenAI-compatible route, OpenRouter, vLLM,
|
|
5
|
+
Together, a local server, ...) and supply the matching API key. The attack only
|
|
6
|
+
needs a sampling endpoint; it never reads logits.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
|
|
15
|
+
# System prompt used in chat mode. Closed chat APIs do not expose a raw
# text-completion surface, so the sampler sends this as the system message and
# the document prefix as the user message, asking the model for the most
# plausible continuation. Base/text models should use mode="completion" for the
# unbiased sampling distribution the paper studies.
_CONTINUE_SYSTEM = (
    "You continue text. Given the beginning of a passage, write the text that "
    "most plausibly comes next. Output only the continuation, with no preamble, "
    "quotation marks, or commentary."
)

# Environment variables tried, in order, by resolve_api_key() when no explicit
# variable name is supplied.
DEFAULT_API_KEY_ENVS = ("LEAKIT_API_KEY", "OPENAI_API_KEY")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def resolve_api_key(api_key_env: str | None = None) -> str:
    """Look up the provider API key in the process environment.

    If ``api_key_env`` names a variable, only that variable is read. Otherwise
    ``DEFAULT_API_KEY_ENVS`` (LEAKIT_API_KEY, then OPENAI_API_KEY) is searched in
    order. Variables set to an empty string count as unset.

    Raises:
        RuntimeError: none of the consulted variables holds a non-empty value.
    """
    names = (api_key_env,) if api_key_env else DEFAULT_API_KEY_ENVS
    found = next(
        (os.environ[var] for var in names if var and os.environ.get(var)),
        None,
    )
    if found is not None:
        return found
    looked_at = ", ".join(var for var in names if var)
    raise RuntimeError(
        f"no API key found in environment (looked at: {looked_at}). "
        f"Export your provider key, e.g. `export {names[0]}=sk-...`."
    )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class SamplerConfig:
    """Settings consumed by ``Sampler``: which model to call and how to sample it."""

    model: str  # model name sent with every request
    base_url: str | None = None  # passed to the OpenAI client as ``base_url``
    n_samples: int = 16  # total continuations requested per prefix
    max_tokens: int = 64  # per-continuation token cap sent to the API
    temperature: float = 1.0  # sampling temperature sent to the API
    top_p: float = 1.0  # nucleus-sampling cutoff sent to the API
    mode: str = "chat"  # "chat" or "completion"
    concurrency: int = 8  # upper bound on parallel request threads
    n_per_request: int = 1  # set >1 to batch via the provider's `n` param
    timeout: float = 120.0  # passed to the default OpenAI client (seconds, per the SDK)
    max_retries: int = 4  # attempts per request batch (counts the first try), with backoff
    extra_body: dict = field(default_factory=dict)  # extra JSON fields; empty dict is sent as None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Sampler:
    """Draws continuations of a prefix from an OpenAI-compatible endpoint."""

    def __init__(self, config: SamplerConfig, api_key: str, client=None):
        """Create a sampler.

        Args:
            config: Sampling parameters (model, mode, batching, retries, ...).
            api_key: Key for the default OpenAI client; unused when ``client``
                is injected.
            client: Optional object exposing ``chat.completions.create`` and
                ``completions.create`` (injectable for testing).

        Raises:
            ValueError: ``config.mode`` is neither "chat" nor "completion". This
                is checked up front because ``sample`` drops failed batches, so a
                bad mode would otherwise silently yield zero completions.
        """
        if config.mode not in ("chat", "completion"):
            raise ValueError(f"unknown mode {config.mode!r}; use 'chat' or 'completion'")
        self.config = config
        # `client` is injectable for testing; defaults to a real OpenAI client.
        if client is None:
            from openai import OpenAI

            client = OpenAI(
                api_key=api_key,
                base_url=config.base_url,
                timeout=config.timeout,
                max_retries=0,  # we handle retries/backoff ourselves
            )
        self.client = client

    def _request(self, prefix: str, n: int) -> list[str]:
        """One API call returning up to ``n`` continuations.

        Missing texts/contents are normalised to "". Raises ValueError for an
        unknown mode (the config object may have been mutated after construction).
        """
        cfg = self.config
        if cfg.mode == "completion":
            resp = self.client.completions.create(
                model=cfg.model,
                prompt=prefix,
                max_tokens=cfg.max_tokens,
                temperature=cfg.temperature,
                top_p=cfg.top_p,
                n=n,
                extra_body=cfg.extra_body or None,
            )
            return [choice.text or "" for choice in resp.choices]
        elif cfg.mode == "chat":
            resp = self.client.chat.completions.create(
                model=cfg.model,
                messages=[
                    {"role": "system", "content": _CONTINUE_SYSTEM},
                    {"role": "user", "content": prefix},
                ],
                max_tokens=cfg.max_tokens,
                temperature=cfg.temperature,
                top_p=cfg.top_p,
                n=n,
                extra_body=cfg.extra_body or None,
            )
            return [(choice.message.content or "") for choice in resp.choices]
        raise ValueError(f"unknown mode {cfg.mode!r}; use 'chat' or 'completion'")

    def _request_with_retry(self, prefix: str, n: int) -> list[str]:
        """Call ``_request`` with exponential backoff on retryable errors.

        Makes up to ``max(1, max_retries)`` attempts, sleeping ``min(2**attempt, 30)``
        seconds between them. Non-retryable errors (per ``_is_retryable``) stop
        immediately.

        Raises:
            RuntimeError: wrapping the last error once attempts are exhausted.
        """
        import time

        # Always try at least once. With max_retries <= 0 the loop previously
        # never ran and raised "sampling request failed: None" without any call.
        attempts = max(1, self.config.max_retries)
        last_exc: Exception | None = None
        for attempt in range(attempts):
            try:
                return self._request(prefix, n)
            except Exception as exc:  # noqa: BLE001 - provider errors are heterogeneous
                last_exc = exc
                if not _is_retryable(exc) or attempt == attempts - 1:
                    break
                time.sleep(min(2**attempt, 30))
        raise RuntimeError(f"sampling request failed: {last_exc}") from last_exc

    def sample(self, prefix: str) -> list[str]:
        """Return up to ``n_samples`` continuations of ``prefix``.

        The total is split into batches of at most ``n_per_request`` (one API
        call each), run concurrently on up to ``concurrency`` threads. A batch
        that still fails after retries is skipped rather than aborting the run,
        so the returned list may be shorter than n_samples (even empty).
        Completions are gathered in completion order, not request order.
        """
        cfg = self.config
        per = max(1, cfg.n_per_request)
        # Build a list of per-request batch sizes summing to n_samples.
        batches: list[int] = []
        remaining = cfg.n_samples
        while remaining > 0:
            batches.append(min(per, remaining))
            remaining -= batches[-1]

        completions: list[str] = []
        workers = max(1, min(cfg.concurrency, len(batches)))
        with ThreadPoolExecutor(max_workers=workers) as ex:
            futures = [ex.submit(self._request_with_retry, prefix, b) for b in batches]
            for fut in as_completed(futures):
                try:
                    completions.extend(fut.result())
                except Exception:  # noqa: BLE001 - one dead batch shouldn't kill the run
                    continue
        return completions
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _is_retryable(exc: Exception) -> bool:
|
|
150
|
+
"""Heuristic: retry on rate limits, timeouts, and 5xx; not on 4xx/auth."""
|
|
151
|
+
status = getattr(exc, "status_code", None)
|
|
152
|
+
if status is not None:
|
|
153
|
+
return status == 429 or status >= 500
|
|
154
|
+
name = type(exc).__name__.lower()
|
|
155
|
+
return any(tok in name for tok in ("timeout", "connection", "ratelimit", "apierror"))
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Shared fakes that mimic the openai client response shape."""
|
|
2
|
+
|
|
3
|
+
import threading
|
|
4
|
+
import types
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _chat_response(texts):
|
|
10
|
+
return types.SimpleNamespace(
|
|
11
|
+
choices=[
|
|
12
|
+
types.SimpleNamespace(message=types.SimpleNamespace(content=t)) for t in texts
|
|
13
|
+
]
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _completion_response(texts):
|
|
18
|
+
return types.SimpleNamespace(
|
|
19
|
+
choices=[types.SimpleNamespace(text=t) for t in texts]
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FakeAPIError(Exception):
    """Stand-in for a provider error carrying an HTTP ``status_code``.

    The default, 503, is a server error, which the sampler's retry heuristic
    treats as retryable.
    """

    def __init__(self, status_code=503):
        message = f"fake {status_code}"
        super().__init__(message)
        self.status_code = status_code
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FakeClient:
    """OpenAI-client stand-in that returns canned continuations.

    ``script`` maps a substring to a pool of continuations. The first key found
    in the prompt selects its pool; otherwise ``default`` is used. A shared,
    lock-protected cursor rotates through the pool across calls, so independent
    single-sample requests differ the way real samples would. The first
    ``fail_times`` calls raise a retryable ``FakeAPIError(503)`` to exercise
    backoff. ``calls`` counts every request, including failed ones.
    """

    def __init__(self, script=None, default=None, fail_times=0):
        self.script = script if script else {}
        self.default = default if default else ["a b c", "a b d", "a b e"]
        self.fail_times = fail_times
        self.calls = 0
        self._offset = 0
        self._lock = threading.Lock()
        # Mirror the attribute paths of the real SDK client.
        chat_completions = types.SimpleNamespace(create=self._chat)
        self.chat = types.SimpleNamespace(completions=chat_completions)
        self.completions = types.SimpleNamespace(create=self._completion)

    def _pick(self, prompt):
        """Return the pool of the first script key contained in ``prompt``."""
        return next(
            (pool for key, pool in self.script.items() if key in prompt),
            self.default,
        )

    def _emit(self, texts, n):
        """Return ``n`` texts from the pool, continuing from the shared cursor."""
        with self._lock:
            start = self._offset
            self._offset = start + n
        return [texts[(start + i) % len(texts)] for i in range(n)]

    def _guard(self):
        """Count this call and raise while still within ``fail_times``."""
        with self._lock:
            self.calls += 1
            should_fail = self.calls <= self.fail_times
        if should_fail:
            raise FakeAPIError(503)

    def _chat(self, *, model, messages, n=1, **kw):
        self._guard()
        user_text = messages[-1]["content"]
        return _chat_response(self._emit(self._pick(user_text), n))

    def _completion(self, *, model, prompt, n=1, **kw):
        self._guard()
        return _completion_response(self._emit(self._pick(prompt), n))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@pytest.fixture
def fake_client():
    """Provide the FakeClient class itself (not an instance).

    Each test constructs its own client with the script/default/fail_times it needs.
    """
    return FakeClient
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""End-to-end CLI tests with an injected fake client (no network)."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from leakit import cli
|
|
8
|
+
from leakit.core import LeakIt
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
def patch_scorer(monkeypatch, fake_client):
    """Swap ``cli._make_scorer`` for a factory that injects a scripted fake client.

    Prompts containing "MEMBER" always get the same continuation, "NOVEL"
    prompts rotate through unrelated continuations, and anything else uses a
    small default pool. No network is touched.
    """

    def _factory(args):
        client = fake_client(
            script={
                "MEMBER": ["born 1809 hardin county"],  # identical -> full concentration
                "NOVEL": ["sky blue today", "i wonder if", "banana telephone qux"],
            },
            default=["a b c", "a b d"],
        )
        settings = dict(
            model=args.model,
            n_samples=args.samples,
            statistic=args.statistic,
            prefix_chars=args.prefix_chars,
            mode=args.mode,
            client=client,
            api_key="k",
        )
        return LeakIt(**settings)

    monkeypatch.setattr(cli, "_make_scorer", _factory)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_cli_table_output(patch_scorer, tmp_path, capsys):
    suspect = tmp_path / "suspect.txt"
    suspect.write_text("MEMBER text here")

    exit_code = cli.main(["--model", "m", str(suspect)])
    stdout = capsys.readouterr().out

    assert exit_code == 0
    for needle in ("suspect.txt", "score"):
        assert needle in stdout
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def test_cli_json_output(patch_scorer, tmp_path, capsys):
    path = tmp_path / "d.txt"
    path.write_text("MEMBER text")

    exit_code = cli.main(["--model", "m", "--json", str(path)])
    first = json.loads(capsys.readouterr().out)[0]

    assert exit_code == 0
    assert first["document"].endswith("d.txt")
    assert first["n_returned"] == 16  # CLI default sample count, all served by the fake
    assert 0.0 <= first["score"] <= 1.0
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_cli_member_scores_above_novel(patch_scorer, tmp_path, capsys):
    """A concentrated (member-like) document must outscore a diffuse (novel) one."""
    from pathlib import Path

    member = tmp_path / "m.txt"
    member.write_text("MEMBER passage")
    novel = tmp_path / "n.txt"
    novel.write_text("NOVEL passage")

    rc = cli.main(["--model", "m", "--json", str(member), str(novel)])
    data = json.loads(capsys.readouterr().out)

    assert rc == 0
    # Key by file name via Path so Windows "\\" separators work too
    # (a plain split("/") would not).
    by_name = {Path(d["document"]).name: d["score"] for d in data}
    assert by_name["m.txt"] > by_name["n.txt"]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_cli_calibrate_percentile(patch_scorer, tmp_path, capsys):
    suspect = tmp_path / "suspect.txt"
    suspect.write_text("MEMBER passage")
    clean1 = tmp_path / "c1.txt"
    clean1.write_text("NOVEL one")
    clean2 = tmp_path / "c2.txt"
    clean2.write_text("NOVEL two")

    argv = ["--model", "m", "--json", "--calibrate", f"{clean1},{clean2}", str(suspect)]
    rc = cli.main(argv)
    record = json.loads(capsys.readouterr().out)[0]

    assert rc == 0
    # The MEMBER document concentrates more than every NOVEL baseline document.
    assert record["percentile_vs_baseline"] == 100.0
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_cli_no_documents_errors(monkeypatch, capsys):
    # With no paths and an interactive (tty) stdin, the CLI must report that
    # there are no documents and exit with status 2.
    monkeypatch.setattr("sys.stdin.isatty", lambda: True)
    assert cli.main(["--model", "m"]) == 2
    assert "no documents" in capsys.readouterr().err
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_cli_stdin(patch_scorer, monkeypatch, capsys):
    import io

    # Feed the document through a non-tty stdin instead of a file path.
    piped = io.StringIO("MEMBER from stdin")
    monkeypatch.setattr("sys.stdin", piped)
    monkeypatch.setattr("sys.stdin.isatty", lambda: False, raising=False)

    exit_code = cli.main(["--model", "m", "--json"])
    records = json.loads(capsys.readouterr().out)

    assert exit_code == 0
    assert records[0]["document"] == "<stdin>"
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Live HTTP integration: real openai SDK against a local OpenAI-shaped server.
|
|
2
|
+
|
|
3
|
+
Proves the actual network path (request shape, base_url routing, response
|
|
4
|
+
parsing) without any paid API. The server returns concentrated continuations
|
|
5
|
+
for a "MEMBER" prefix and diffuse ones for a "NOVEL" prefix.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import threading
|
|
10
|
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
11
|
+
|
|
12
|
+
import pytest
|
|
13
|
+
|
|
14
|
+
from leakit import LeakIt
|
|
15
|
+
|
|
16
|
+
# Members: identical continuations (model concentrates on memorised text).
# Novel: four distinct continuations rotated across requests (diffuse distribution).
# _Handler uses the MEMBER pool whenever the prompt contains the substring "MEMBER".
MEMBER_CONTINUATIONS = ["in Hardin County Kentucky"]
NOVEL_CONTINUATIONS = ["a winding road ahead", "thoughts about nothing much",
                       "quux frobnicate widget", "the seventeenth of never"]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class _Handler(BaseHTTPRequestHandler):
    """Minimal OpenAI-shaped handler for ``/chat/completions`` and ``/completions``.

    Prompts containing "MEMBER" draw from MEMBER_CONTINUATIONS; all others draw
    from NOVEL_CONTINUATIONS. The owning server must carry ``_cursor`` (int) and
    ``_lock`` attributes, as set up by the ``local_openai_server`` fixture.
    """

    def log_message(self, *a):  # silence per-request logging
        pass

    def _emit(self, pool, n):
        # Rotate a server-wide cursor so independent requests vary, like a model.
        srv = self.server
        with srv._lock:  # type: ignore[attr-defined]
            start = srv._cursor  # type: ignore[attr-defined]
            srv._cursor += n  # type: ignore[attr-defined]
        return [pool[(start + i) % len(pool)] for i in range(n)]

    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        body = json.loads(self.rfile.read(length) or b"{}")
        n = int(body.get("n", 1) or 1)

        is_chat = self.path.endswith("/chat/completions")
        if is_chat:
            prompt = body["messages"][-1]["content"]
        else:  # /completions
            prompt = body.get("prompt", "")
        pool = MEMBER_CONTINUATIONS if "MEMBER" in prompt else NOVEL_CONTINUATIONS
        texts = self._emit(pool, n)

        if is_chat:
            object_type = "chat.completion"
            choices = [{"index": i, "message": {"role": "assistant", "content": t},
                        "finish_reason": "stop"}
                       for i, t in enumerate(texts)]
        else:
            # The legacy Completions API labels its responses "text_completion".
            object_type = "text_completion"
            choices = [{"index": i, "text": t, "finish_reason": "stop"}
                       for i, t in enumerate(texts)]

        payload = json.dumps({
            "id": "cmpl-test", "object": object_type, "model": body.get("model", "m"),
            "choices": choices,
            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
        }).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.fixture
def local_openai_server():
    """Serve ``_Handler`` on an ephemeral localhost port for one test.

    Yields the OpenAI-style base URL ``http://127.0.0.1:<port>/v1``. On
    teardown, the serve loop is stopped, the listening socket is closed (the
    original never called ``server_close`` and leaked it), and the thread is joined.
    """
    server = HTTPServer(("127.0.0.1", 0), _Handler)
    server._cursor = 0  # rotation position shared by all requests
    server._lock = threading.Lock()  # guards _cursor
    port = server.server_address[1]
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    try:
        yield f"http://127.0.0.1:{port}/v1"
    finally:
        server.shutdown()
        server.server_close()
        thread.join(timeout=5)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_live_http_chat_mode_member_vs_novel(local_openai_server):
    scorer = LeakIt(
        model="local-test",
        base_url=local_openai_server,
        api_key="test-key",
        n_samples=8,
        concurrency=4,
        mode="chat",
    )
    member = scorer.score("MEMBER: the sixteenth president was born", document_id="member")
    novel = scorer.score("NOVEL: my grocery list for tuesday", document_id="novel")

    assert (member.n_returned, novel.n_returned) == (8, 8)
    assert member.score == 1.0  # every continuation identical -> full concentration
    assert member.score > novel.score
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def test_live_http_completion_mode(local_openai_server):
    scorer = LeakIt(
        model="local-test",
        base_url=local_openai_server,
        api_key="test-key",
        n_samples=4,
        concurrency=2,
        mode="completion",
    )
    result = scorer.score("MEMBER prefix", document_id="m")
    assert (result.n_returned, result.score) == (4, 1.0)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def test_live_http_n_per_request_batch(local_openai_server):
    # 12 samples in batches of 4 -> three requests using the API's `n` parameter.
    scorer = LeakIt(
        model="local-test",
        base_url=local_openai_server,
        api_key="test-key",
        n_samples=12,
        n_per_request=4,
        concurrency=3,
        mode="chat",
    )
    assert scorer.score("MEMBER prefix", document_id="m").n_returned == 12
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Sampler behaviour with an injected fake client (no network)."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from leakit.sampler import Sampler, SamplerConfig, resolve_api_key
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _sampler(fake, **overrides):
    """Build a Sampler for model "m" around the given fake client."""
    return Sampler(SamplerConfig(model="m", **overrides), api_key="k", client=fake)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_chat_collects_all_samples(fake_client):
    sampler = _sampler(fake_client(default=["x y z"]), n_samples=16, concurrency=4)
    got = sampler.sample("some prefix")
    assert len(got) == 16
    assert set(got) <= {"x y z"}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_completion_mode(fake_client):
    client = fake_client(default=["foo bar"])
    sampler = _sampler(client, n_samples=8, mode="completion", concurrency=2)
    got = sampler.sample("prefix")
    assert len(got) == 8
    assert got[0] == "foo bar"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_n_per_request_batches(fake_client):
    client = fake_client(default=["a b", "c d"])
    got = _sampler(client, n_samples=10, n_per_request=5, concurrency=4).sample("prefix")
    assert len(got) == 10
    assert client.calls == 10 // 5  # two API calls of n=5 each
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def test_retry_then_success(fake_client, monkeypatch):
    """Two retryable failures are retried with backoff, then the third call succeeds."""
    # Skip the real exponential backoff (1 s + 2 s) so the test stays fast.
    monkeypatch.setattr("time.sleep", lambda _seconds: None)
    fc = fake_client(default=["ok ok"], fail_times=2)
    s = _sampler(fc, n_samples=1, max_retries=5, concurrency=1)
    out = s.sample("prefix")
    assert out == ["ok ok"]
    assert fc.calls == 3  # 2 failures + 1 success
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_dead_batch_is_skipped_not_fatal(fake_client, monkeypatch):
    """Batches that exhaust their retries are dropped instead of raising."""
    # Skip the real backoff sleep between attempts.
    monkeypatch.setattr("time.sleep", lambda _seconds: None)
    # Every call fails; with retries exhausted the run yields zero completions.
    fc = fake_client(default=["never"], fail_times=99)
    s = _sampler(fc, n_samples=4, max_retries=2, concurrency=2)
    out = s.sample("prefix")
    assert out == []
    assert fc.calls == 4 * 2  # four single-sample batches x two attempts each
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_resolve_api_key_prefers_named(monkeypatch):
    # An explicitly named variable is read even when the default chain is unset.
    monkeypatch.delenv("LEAKIT_API_KEY", raising=False)
    monkeypatch.setenv("MY_KEY", "secret")
    resolved = resolve_api_key("MY_KEY")
    assert resolved == "secret"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_resolve_api_key_missing_raises(monkeypatch):
    for var in ("LEAKIT_API_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(var, raising=False)
    with pytest.raises(RuntimeError):
        resolve_api_key()
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Statistic correctness, cross-checked against the paper's reference behaviour."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from leakit import _stats
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_identical_completions_score_one():
    same = ["the quick brown fox"] * 3
    assert _stats.self_concentration_word_jaccard(same) == 1.0
    assert _stats.self_concentration_kgram(same, k=3) == 1.0
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_disjoint_completions_score_zero():
    disjoint = ["alpha beta gamma", "delta epsilon zeta", "eta theta iota"]
    score = _stats.self_concentration_word_jaccard(disjoint)
    assert score == 0.0
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_word_jaccard_known_value():
    # {a, b} vs {b, c}: intersection 1, union 3 -> 1/3; a single pair, so that is the mean.
    got = _stats.self_concentration_word_jaccard(["a b", "b c"])
    assert math.isclose(got, 1 / 3)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def test_word_jaccard_mean_over_pairs():
    # Pairwise Jaccard: (1,2) = 1/3, (1,3) = 0, (2,3) = 0 -> mean over 3 pairs = 1/9.
    pair_scores = [1 / 3, 0.0, 0.0]
    expected = sum(pair_scores) / len(pair_scores)
    got = _stats.self_concentration_word_jaccard(["a b", "b c", "x y"])
    assert math.isclose(got, expected)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_fewer_than_two_is_zero():
    # No pairs exist with fewer than two completions, so the score is 0.
    for too_few in ([], ["only one"]):
        assert _stats.self_concentration_word_jaccard(too_few) == 0.0
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_empty_strings_do_not_crash():
    word_score = _stats.self_concentration_word_jaccard(["", ""])
    kgram_score = _stats.self_concentration_kgram(["", "abc"], k=5)
    assert word_score == 0.0
    assert kgram_score == 0.0
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_member_scores_above_nonmember():
    member = [
        "born in 1809 in Hardin County",
        "born in 1809 in Hardin County, Kentucky",
        "born in 1809 in Hardin",
    ]
    nonmember = ["the weather today is", "I think that maybe", "purple monkey dishwasher"]
    member_score = _stats.self_concentration_word_jaccard(member)
    nonmember_score = _stats.self_concentration_word_jaccard(nonmember)
    assert member_score > nonmember_score
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_dispatch_unknown_raises():
    caught = None
    try:
        _stats.compute(["a", "b"], statistic="nope")
    except ValueError as err:
        caught = err
    if caught is None:
        raise AssertionError("expected ValueError")
    assert "unknown statistic" in str(caught)
|