semaxis 0.16.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- semaxis-0.16.0/LICENSE +21 -0
- semaxis-0.16.0/PKG-INFO +14 -0
- semaxis-0.16.0/pyproject.toml +38 -0
- semaxis-0.16.0/src/semaxis/__init__.py +11 -0
- semaxis-0.16.0/src/semaxis/llm.py +129 -0
- semaxis-0.16.0/src/semaxis/nli.py +46 -0
- semaxis-0.16.0/src/semaxis/prompts/__init__.py +0 -0
- semaxis-0.16.0/src/semaxis/prompts/collection_description.py +42 -0
- semaxis-0.16.0/src/semaxis/prompts/discriminative_features.py +57 -0
- semaxis-0.16.0/src/semaxis/sampling.py +176 -0
- semaxis-0.16.0/src/semaxis/supervised.py +197 -0
- semaxis-0.16.0/src/semaxis/unsupervised.py +139 -0
semaxis-0.16.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 pillyshi
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
semaxis-0.16.0/PKG-INFO
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: semaxis
|
|
3
|
+
Version: 0.16.0
|
|
4
|
+
Summary: Interpretable NLI-based text features for scikit-learn
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
10
|
+
Requires-Dist: numpy (>=1.26)
|
|
11
|
+
Requires-Dist: openai (>=1.30)
|
|
12
|
+
Requires-Dist: scikit-learn (>=1.4)
|
|
13
|
+
Requires-Dist: sentence-transformers (>=3.0)
|
|
14
|
+
Requires-Dist: tiktoken (>=0.7)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "semaxis"
|
|
3
|
+
version = "0.16.0"
|
|
4
|
+
description = "Interpretable NLI-based text features for scikit-learn"
|
|
5
|
+
authors = []
|
|
6
|
+
packages = [{include = "semaxis", from = "src"}]
|
|
7
|
+
|
|
8
|
+
[tool.poetry.dependencies]
|
|
9
|
+
python = ">=3.11"
|
|
10
|
+
openai = ">=1.30"
|
|
11
|
+
tiktoken = ">=0.7"
|
|
12
|
+
sentence-transformers = ">=3.0"
|
|
13
|
+
scikit-learn = ">=1.4"
|
|
14
|
+
numpy = ">=1.26"
|
|
15
|
+
|
|
16
|
+
[tool.poetry.group.dev.dependencies]
|
|
17
|
+
pytest = "*"
|
|
18
|
+
ruff = "*"
|
|
19
|
+
mypy = "*"
|
|
20
|
+
pyalex = "^0.21"
|
|
21
|
+
|
|
22
|
+
[tool.poetry.group.langchain]
|
|
23
|
+
optional = true
|
|
24
|
+
|
|
25
|
+
[tool.poetry.group.langchain.dependencies]
|
|
26
|
+
langchain-core = ">=0.3"
|
|
27
|
+
langchain-ollama = ">=0.3"
|
|
28
|
+
|
|
29
|
+
[build-system]
|
|
30
|
+
requires = ["poetry-core"]
|
|
31
|
+
build-backend = "poetry.core.masonry.api"
|
|
32
|
+
|
|
33
|
+
[tool.mypy]
|
|
34
|
+
python_version = "3.11"
|
|
35
|
+
|
|
36
|
+
[[tool.mypy.overrides]]
|
|
37
|
+
module = ["sklearn.*", "sentence_transformers.*"]
|
|
38
|
+
ignore_missing_imports = true
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .llm import LangChainLLMClient, LLMClient
|
|
2
|
+
from .supervised import FeatureMeta, SupervisedTransformer
|
|
3
|
+
from .unsupervised import UnsupervisedTransformer
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"SupervisedTransformer",
|
|
7
|
+
"UnsupervisedTransformer",
|
|
8
|
+
"FeatureMeta",
|
|
9
|
+
"LLMClient",
|
|
10
|
+
"LangChainLLMClient",
|
|
11
|
+
]
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
import tiktoken
|
|
8
|
+
from openai import OpenAI
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@runtime_checkable
class BaseLLMClient(Protocol):
    """Structural interface that every LLM client implementation must satisfy.

    The protocol is ``runtime_checkable``, so ``isinstance`` checks verify
    that the three methods exist (not their signatures).
    """

    def complete(self, messages: list[dict[str, str]]) -> str:
        """Return the model's reply to ``messages`` as a string."""
        ...

    def complete_json(self, messages: list[dict[str, str]]) -> Any:
        """Return the model's reply to ``messages`` parsed as JSON."""
        ...

    def count_tokens(self, text: str) -> int:
        """Return the token count of ``text``."""
        ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LLMClient:
|
|
21
|
+
"""Thin wrapper around the OpenAI chat completions API with token counting."""
|
|
22
|
+
|
|
23
|
+
def __init__(self, model: str, api_key: str | None = None) -> None:
|
|
24
|
+
self.model = model
|
|
25
|
+
self._client = OpenAI(api_key=api_key)
|
|
26
|
+
try:
|
|
27
|
+
self._encoding = tiktoken.encoding_for_model(model)
|
|
28
|
+
except KeyError:
|
|
29
|
+
self._encoding = tiktoken.get_encoding("cl100k_base")
|
|
30
|
+
|
|
31
|
+
def complete(self, messages: list[dict[str, str]]) -> str:
|
|
32
|
+
"""Send a chat completion request and return the content string."""
|
|
33
|
+
response = self._client.chat.completions.create(
|
|
34
|
+
model=self.model,
|
|
35
|
+
messages=messages, # type: ignore[arg-type]
|
|
36
|
+
)
|
|
37
|
+
return response.choices[0].message.content or ""
|
|
38
|
+
|
|
39
|
+
def complete_json(self, messages: list[dict[str, str]]) -> Any:
|
|
40
|
+
"""Send a request expecting JSON output and return the parsed object."""
|
|
41
|
+
response = self._client.chat.completions.create(
|
|
42
|
+
model=self.model,
|
|
43
|
+
messages=messages, # type: ignore[arg-type]
|
|
44
|
+
response_format={"type": "json_object"}, # type: ignore[call-overload]
|
|
45
|
+
)
|
|
46
|
+
content = response.choices[0].message.content or ""
|
|
47
|
+
return json.loads(content)
|
|
48
|
+
|
|
49
|
+
def count_tokens(self, text: str) -> int:
|
|
50
|
+
"""Return the number of tokens in text using the model's tokenizer."""
|
|
51
|
+
return len(self._encoding.encode(text))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class LangChainLLMClient:
    """LLM client backed by any LangChain BaseChatModel.

    Supports Ollama, llama.cpp, and any other LangChain-compatible provider.

    Example::

        from langchain_ollama import ChatOllama
        client = LangChainLLMClient(ChatOllama(model="llama3.2", format="json"))
    """

    def __init__(self, model: Any) -> None:
        self._model = model

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a LangChain message ``content`` value into a plain string.

        LangChain message content can be a string or a list of content blocks
        (plain strings or dicts). Text blocks are concatenated in order; blocks
        without a string ``"text"`` entry (e.g. images) are skipped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return "" if content is None else str(content)

    def complete(self, messages: list[dict[str, str]]) -> str:
        """Invoke the LangChain model and return the response as a string.

        Args:
            messages: OpenAI-style dicts with ``role`` and ``content`` keys.
                ``system`` and ``assistant`` roles map to the matching
                LangChain message types; any other (or missing) role is sent
                as a human message.
        """
        # Imported lazily so langchain stays an optional dependency.
        from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

        message_types = {"system": SystemMessage, "assistant": AIMessage}
        lc_messages = []
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            lc_messages.append(message_types.get(role, HumanMessage)(content=content))

        response = self._model.invoke(lc_messages)
        # ``content`` may be a list of blocks; normalise to honour ``-> str``.
        return self._content_to_text(response.content)

    def complete_json(self, messages: list[dict[str, str]]) -> Any:
        """Invoke the model and extract a JSON value from the response.

        Raises:
            ValueError: If no JSON can be extracted from the reply.
        """
        content = self.complete(messages)
        return _extract_json(content)

    def count_tokens(self, text: str) -> int:
        """Return an approximate token count for the given text."""
        try:
            return self._model.get_num_tokens(text)
        except (NotImplementedError, AttributeError, ImportError):
            # ImportError: LangChain's default tokenizer appears to need an
            # optional package — NOTE(review): confirm against langchain-core.
            # Fallback: ~4 characters per token (reasonable for English)
            return len(text) // 4
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _extract_json(text: str) -> Any:
|
|
99
|
+
"""Extract a JSON object from a string, tolerating surrounding prose.
|
|
100
|
+
|
|
101
|
+
Tries in order:
|
|
102
|
+
1. Direct json.loads (model output is clean JSON)
|
|
103
|
+
2. Extract from a markdown code block (```json ... ```)
|
|
104
|
+
3. Extract the first {...} or [...] block via regex
|
|
105
|
+
"""
|
|
106
|
+
text = text.strip()
|
|
107
|
+
|
|
108
|
+
try:
|
|
109
|
+
return json.loads(text)
|
|
110
|
+
except json.JSONDecodeError:
|
|
111
|
+
pass
|
|
112
|
+
|
|
113
|
+
# Try markdown code block
|
|
114
|
+
code_block = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text)
|
|
115
|
+
if code_block:
|
|
116
|
+
try:
|
|
117
|
+
return json.loads(code_block.group(1))
|
|
118
|
+
except json.JSONDecodeError:
|
|
119
|
+
pass
|
|
120
|
+
|
|
121
|
+
# Try first {...} or [...] block
|
|
122
|
+
brace_match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
|
|
123
|
+
if brace_match:
|
|
124
|
+
try:
|
|
125
|
+
return json.loads(brace_match.group(1))
|
|
126
|
+
except json.JSONDecodeError:
|
|
127
|
+
pass
|
|
128
|
+
|
|
129
|
+
raise ValueError(f"Could not extract JSON from model response:\n{text}")
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NLIModel:
|
|
7
|
+
"""Wrapper around a sentence-transformers CrossEncoder for NLI scoring.
|
|
8
|
+
|
|
9
|
+
The model returns three logits per pair. We apply sigmoid to the entailment
|
|
10
|
+
logit and zero out pairs where the model does not predict entailment.
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
model_name: sentence-transformers CrossEncoder model name.
|
|
14
|
+
entailment_idx: Column index of the entailment class in the model output.
|
|
15
|
+
Defaults to 0, which matches ``cross-encoder/nli-deberta-v3-large``
|
|
16
|
+
(label order: [entailment, neutral, contradiction]).
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(self, model_name: str, entailment_idx: int = 0) -> None:
|
|
20
|
+
from sentence_transformers import CrossEncoder
|
|
21
|
+
|
|
22
|
+
self.model_name = model_name
|
|
23
|
+
self.entailment_idx = entailment_idx
|
|
24
|
+
self._model = CrossEncoder(model_name)
|
|
25
|
+
|
|
26
|
+
def score(self, texts: list[str], hypotheses: list[str]) -> np.ndarray:
|
|
27
|
+
"""Score (text, hypothesis) pairs.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
texts: List of n texts.
|
|
31
|
+
hypotheses: List of n hypotheses, parallel to texts.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
np.ndarray of shape (n,) with values in [0, 1]. Pairs where the
|
|
35
|
+
model does not predict entailment are scored as 0.
|
|
36
|
+
"""
|
|
37
|
+
pairs = list(zip(texts, hypotheses))
|
|
38
|
+
logits = self._model.predict(pairs) # shape (n, 3) or (n,) depending on model
|
|
39
|
+
logits = np.array(logits)
|
|
40
|
+
if logits.ndim == 1:
|
|
41
|
+
# Binary model — return sigmoid scores directly
|
|
42
|
+
return 1.0 / (1.0 + np.exp(-logits))
|
|
43
|
+
entail_logits = logits[:, self.entailment_idx]
|
|
44
|
+
scores = 1.0 / (1.0 + np.exp(-entail_logits))
|
|
45
|
+
scores[logits.argmax(axis=1) != self.entailment_idx] = 0.0
|
|
46
|
+
return scores
|
|
File without changes
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
SYSTEM = """\
|
|
4
|
+
You are an expert text analyst. Your task is to generate features that describe \
|
|
5
|
+
properties of individual texts in a collection.
|
|
6
|
+
|
|
7
|
+
Each feature must be defined by:
|
|
8
|
+
- hypothesis: a declarative statement about a single text, suitable for NLI scoring \
|
|
9
|
+
(e.g. "This text expresses satisfaction with the product." or \
|
|
10
|
+
"This text mentions issues with build quality.")
|
|
11
|
+
|
|
12
|
+
Requirements:
|
|
13
|
+
- Each hypothesis must be a statement about a single text, starting with "This text"
|
|
14
|
+
- Features must capture properties that are meaningfully present in some texts \
|
|
15
|
+
but not all — avoid trivially universal or trivially absent properties
|
|
16
|
+
- The hypothesis must be self-contained
|
|
17
|
+
- Aim for diverse, non-redundant features
|
|
18
|
+
|
|
19
|
+
Respond with JSON only:
|
|
20
|
+
{"features": [{"hypothesis": "..."}, ...]}
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
_USER_TEMPLATE = """\
|
|
24
|
+
Here are sample texts from the collection:
|
|
25
|
+
---
|
|
26
|
+
{texts_block}
|
|
27
|
+
|
|
28
|
+
Generate exactly {n} features that describe properties of individual texts \
|
|
29
|
+
commonly found in this collection.{language_instruction}
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def build_user_message(texts: list[str], n: int, language: str | None = None) -> str:
    """Render the user prompt that asks for ``n`` features over sample texts.

    Args:
        texts: Sample texts; joined with ``---`` separator lines. An empty
            list is rendered as ``(none)``.
        n: Number of features to request.
        language: When given (and non-empty), appends an instruction to write
            the hypotheses in that language.

    Returns:
        The formatted user message.
    """
    if texts:
        texts_block = "\n---\n".join(texts)
    else:
        texts_block = "(none)"

    language_instruction = ""
    if language:
        language_instruction = f"\nGenerate the hypothesis for each feature in {language}."

    return _USER_TEMPLATE.format(
        n=n,
        texts_block=texts_block,
        language_instruction=language_instruction,
    )
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
SYSTEM = """\
|
|
4
|
+
You are an expert text analyst. Your task is to generate features that distinguish \
|
|
5
|
+
one group of texts from another.
|
|
6
|
+
|
|
7
|
+
Each feature must be defined by:
|
|
8
|
+
- hypothesis: a declarative statement about a single text, suitable for NLI scoring \
|
|
9
|
+
(e.g. "This text expresses satisfaction with the product." or \
|
|
10
|
+
"This text mentions issues with build quality.")
|
|
11
|
+
|
|
12
|
+
Requirements:
|
|
13
|
+
- Each hypothesis must be a statement about a single text, starting with "This text"
|
|
14
|
+
- Hypotheses must capture properties that are more characteristic of the positive group \
|
|
15
|
+
than the negative group — avoid properties shared equally by both groups
|
|
16
|
+
- The hypothesis must be self-contained
|
|
17
|
+
- Aim for diverse, non-redundant features
|
|
18
|
+
|
|
19
|
+
Respond with JSON only:
|
|
20
|
+
{"features": [{"hypothesis": "..."}, ...]}
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
_USER_TEMPLATE = """\
|
|
24
|
+
Positive texts (labeled "{pos_label}"):
|
|
25
|
+
---
|
|
26
|
+
{pos_block}
|
|
27
|
+
|
|
28
|
+
Negative texts (labeled "{neg_label}"):
|
|
29
|
+
---
|
|
30
|
+
{neg_block}
|
|
31
|
+
|
|
32
|
+
Generate exactly {n} features whose hypotheses are more likely to be true for \
|
|
33
|
+
"{pos_label}" texts than "{neg_label}" texts.{language_instruction}
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def build_user_message(
    pos_texts: list[str],
    neg_texts: list[str],
    pos_label: str,
    neg_label: str,
    n: int,
    language: str | None = None,
) -> str:
    """Render the user prompt contrasting a positive and a negative group.

    Args:
        pos_texts: Sample texts of the positive group.
        neg_texts: Sample texts of the negative group.
        pos_label: Label shown for the positive group.
        neg_label: Label shown for the negative group.
        n: Number of features to request.
        language: When given (and non-empty), appends an instruction to write
            the hypotheses in that language.

    Returns:
        The formatted user message. Texts within a group are joined with
        ``---`` separator lines; an empty group is rendered as ``(none)``.
    """

    def _as_block(group: list[str]) -> str:
        if not group:
            return "(none)"
        return "\n---\n".join(group)

    language_instruction = ""
    if language:
        language_instruction = f"\nGenerate the hypothesis for each feature in {language}."

    return _USER_TEMPLATE.format(
        pos_label=pos_label,
        neg_label=neg_label,
        pos_block=_as_block(pos_texts),
        neg_block=_as_block(neg_texts),
        n=n,
        language_instruction=language_instruction,
    )
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def sample_texts_within_budget(
|
|
10
|
+
texts: list[str],
|
|
11
|
+
token_budget: int,
|
|
12
|
+
tokenizer_fn: Callable[[str], int],
|
|
13
|
+
rng: random.Random | None = None,
|
|
14
|
+
) -> list[str]:
|
|
15
|
+
"""Return a random subset of texts that fits within the token budget.
|
|
16
|
+
|
|
17
|
+
Texts are shuffled, then added one by one until adding the next text
|
|
18
|
+
would exceed token_budget. The order of the returned list follows the
|
|
19
|
+
shuffled order.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
texts: Source texts to sample from.
|
|
23
|
+
token_budget: Maximum total tokens allowed.
|
|
24
|
+
tokenizer_fn: Function that returns the token count for a single text.
|
|
25
|
+
rng: Optional Random instance for reproducibility.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
A list of texts whose total token count does not exceed token_budget.
|
|
29
|
+
"""
|
|
30
|
+
if rng is None:
|
|
31
|
+
rng = random.Random()
|
|
32
|
+
|
|
33
|
+
indices = list(range(len(texts)))
|
|
34
|
+
rng.shuffle(indices)
|
|
35
|
+
|
|
36
|
+
selected: list[str] = []
|
|
37
|
+
total_tokens = 0
|
|
38
|
+
|
|
39
|
+
for idx in indices:
|
|
40
|
+
text = texts[idx]
|
|
41
|
+
count = tokenizer_fn(text)
|
|
42
|
+
if total_tokens + count > token_budget:
|
|
43
|
+
break
|
|
44
|
+
selected.append(text)
|
|
45
|
+
total_tokens += count
|
|
46
|
+
|
|
47
|
+
return selected
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def sample_texts_kmeans(
|
|
51
|
+
texts: list[str],
|
|
52
|
+
n: int,
|
|
53
|
+
embeddings: np.ndarray,
|
|
54
|
+
rng: random.Random | None = None,
|
|
55
|
+
) -> list[str]:
|
|
56
|
+
"""Return up to n texts by selecting the closest text to each K-Means centroid.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
texts: Source texts.
|
|
60
|
+
n: Number of texts to select (capped at len(texts)).
|
|
61
|
+
embeddings: Pre-computed embedding matrix of shape (len(texts), dim).
|
|
62
|
+
rng: Optional Random instance for reproducibility.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
A list of up to n representative texts, one per cluster.
|
|
66
|
+
"""
|
|
67
|
+
from sklearn.cluster import KMeans
|
|
68
|
+
|
|
69
|
+
n = min(n, len(texts))
|
|
70
|
+
if n == 0:
|
|
71
|
+
return []
|
|
72
|
+
if n == len(texts):
|
|
73
|
+
return list(texts)
|
|
74
|
+
|
|
75
|
+
seed = rng.randint(0, 2**31 - 1) if rng is not None else None
|
|
76
|
+
km = KMeans(n_clusters=n, random_state=seed, n_init="auto")
|
|
77
|
+
km.fit(embeddings)
|
|
78
|
+
|
|
79
|
+
selected_indices: list[int] = []
|
|
80
|
+
for center in km.cluster_centers_:
|
|
81
|
+
dists = np.linalg.norm(embeddings - center, axis=1)
|
|
82
|
+
# Exclude already-selected indices to avoid duplicates when clusters share a medoid
|
|
83
|
+
dists[selected_indices] = np.inf
|
|
84
|
+
selected_indices.append(int(np.argmin(dists)))
|
|
85
|
+
|
|
86
|
+
return [texts[i] for i in selected_indices]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def sample_texts_votek(
|
|
90
|
+
texts: list[str],
|
|
91
|
+
n: int,
|
|
92
|
+
embeddings: np.ndarray,
|
|
93
|
+
k: int = 10,
|
|
94
|
+
rng: random.Random | None = None,
|
|
95
|
+
) -> list[str]:
|
|
96
|
+
"""Return up to n texts using the Vote-K algorithm (Su et al. 2022).
|
|
97
|
+
|
|
98
|
+
Vote-K balances representativeness (high vote count = many neighbours)
|
|
99
|
+
and diversity (selected texts' neighbours are suppressed from future picks).
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
texts: Source texts.
|
|
103
|
+
n: Number of texts to select (capped at len(texts)).
|
|
104
|
+
embeddings: Pre-computed embedding matrix of shape (len(texts), dim).
|
|
105
|
+
k: Number of neighbours to consider for voting and suppression.
|
|
106
|
+
rng: Optional Random instance (unused; kept for API symmetry).
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
A list of up to n texts.
|
|
110
|
+
"""
|
|
111
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
|
112
|
+
|
|
113
|
+
n = min(n, len(texts))
|
|
114
|
+
if n == 0:
|
|
115
|
+
return []
|
|
116
|
+
if n == len(texts):
|
|
117
|
+
return list(texts)
|
|
118
|
+
|
|
119
|
+
k = min(k, len(texts) - 1)
|
|
120
|
+
sim = cosine_similarity(embeddings) # (N, N)
|
|
121
|
+
|
|
122
|
+
# votes[i] = how many texts consider i among their top-k neighbours
|
|
123
|
+
votes = np.zeros(len(texts), dtype=float)
|
|
124
|
+
for i in range(len(texts)):
|
|
125
|
+
sim_row = sim[i].copy()
|
|
126
|
+
sim_row[i] = -np.inf # exclude self
|
|
127
|
+
top_k = np.argpartition(sim_row, -k)[-k:]
|
|
128
|
+
votes[top_k] += 1.0
|
|
129
|
+
|
|
130
|
+
selected_indices: list[int] = []
|
|
131
|
+
remaining = np.ones(len(texts), dtype=bool)
|
|
132
|
+
|
|
133
|
+
while len(selected_indices) < n and remaining.any():
|
|
134
|
+
# Among remaining texts, pick the one with the most votes
|
|
135
|
+
masked_votes = np.where(remaining, votes, -np.inf)
|
|
136
|
+
best = int(np.argmax(masked_votes))
|
|
137
|
+
selected_indices.append(best)
|
|
138
|
+
remaining[best] = False
|
|
139
|
+
|
|
140
|
+
# Suppress votes of k nearest neighbours
|
|
141
|
+
sim_row = sim[best].copy()
|
|
142
|
+
sim_row[best] = -np.inf
|
|
143
|
+
top_k = np.argpartition(sim_row, -k)[-k:]
|
|
144
|
+
votes[top_k] = 0.0
|
|
145
|
+
|
|
146
|
+
return [texts[i] for i in selected_indices]
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _estimate_n(
|
|
150
|
+
texts: list[str],
|
|
151
|
+
token_budget: int,
|
|
152
|
+
tokenizer_fn: Callable[[str], int],
|
|
153
|
+
) -> int:
|
|
154
|
+
"""Estimate how many texts fit in token_budget based on average token length."""
|
|
155
|
+
if not texts:
|
|
156
|
+
return 0
|
|
157
|
+
probe = texts[:min(20, len(texts))]
|
|
158
|
+
avg = sum(tokenizer_fn(t) for t in probe) / len(probe)
|
|
159
|
+
return max(1, int(token_budget / avg))
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _trim_to_budget(
|
|
163
|
+
texts: list[str],
|
|
164
|
+
token_budget: int,
|
|
165
|
+
tokenizer_fn: Callable[[str], int],
|
|
166
|
+
) -> list[str]:
|
|
167
|
+
"""Trim a list of texts to fit within the token budget."""
|
|
168
|
+
result: list[str] = []
|
|
169
|
+
total = 0
|
|
170
|
+
for t in texts:
|
|
171
|
+
count = tokenizer_fn(t)
|
|
172
|
+
if total + count > token_budget:
|
|
173
|
+
break
|
|
174
|
+
result.append(t)
|
|
175
|
+
total += count
|
|
176
|
+
return result
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from itertools import combinations
|
|
5
|
+
from typing import Any, NamedTuple
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
9
|
+
from sklearn.preprocessing import LabelEncoder
|
|
10
|
+
|
|
11
|
+
from .llm import BaseLLMClient, LLMClient
|
|
12
|
+
from .nli import NLIModel
|
|
13
|
+
from .sampling import (
|
|
14
|
+
_estimate_n,
|
|
15
|
+
_trim_to_budget,
|
|
16
|
+
sample_texts_kmeans,
|
|
17
|
+
sample_texts_votek,
|
|
18
|
+
sample_texts_within_budget,
|
|
19
|
+
)
|
|
20
|
+
from .prompts import discriminative_features as prompts
|
|
21
|
+
|
|
22
|
+
_SAMPLE_METHODS = ("random", "kmeans", "votek")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _sample_group(
|
|
26
|
+
texts: list[str],
|
|
27
|
+
budget: int,
|
|
28
|
+
tokenizer_fn: Any,
|
|
29
|
+
method: str,
|
|
30
|
+
embedding_model: str,
|
|
31
|
+
rng: random.Random,
|
|
32
|
+
) -> list[str]:
|
|
33
|
+
if method == "random":
|
|
34
|
+
return sample_texts_within_budget(texts, budget, tokenizer_fn, rng)
|
|
35
|
+
|
|
36
|
+
from sentence_transformers import SentenceTransformer
|
|
37
|
+
embeddings = SentenceTransformer(embedding_model).encode(
|
|
38
|
+
texts, show_progress_bar=False, convert_to_numpy=True
|
|
39
|
+
)
|
|
40
|
+
n = _estimate_n(texts, budget, tokenizer_fn)
|
|
41
|
+
if method == "kmeans":
|
|
42
|
+
sampled = sample_texts_kmeans(texts, n, embeddings, rng)
|
|
43
|
+
else:
|
|
44
|
+
sampled = sample_texts_votek(texts, n, embeddings, rng=rng)
|
|
45
|
+
return _trim_to_budget(sampled, budget, tokenizer_fn)
|
|
46
|
+
|
|
47
|
+
_PROMPT_OVERHEAD = 500
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class FeatureMeta(NamedTuple):
    """Records which (positive, negative) class pair a hypothesis was generated for."""

    # Class label the hypothesis should be more true for.
    positive: Any
    # Contrasting class label, or "rest" for one-vs-rest features.
    negative: Any
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class SupervisedTransformer(BaseEstimator, TransformerMixin):
    """Sklearn-compatible transformer that generates discriminative NLI features.

    Fits by generating hypotheses (via LLM) that distinguish between classes,
    then scores texts against those hypotheses using an NLI model.

    Supports binary and multi-class labels (numeric or string).
    For multi-class, use ``strategy="ovr"`` (one-vs-rest) or ``strategy="ovo"``
    (one-vs-one). Binary classification ignores ``strategy``.

    Fitted attributes:
        classes_: Unique class labels in sorted order.
        features_: Hypotheses as plain strings, parallel to ``feature_meta_``.
        feature_meta_: FeatureMeta(positive, negative) for each hypothesis,
            where negative is the original class label or ``"rest"`` for OvR.

    Example::

        from sklearn.pipeline import Pipeline
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import cross_val_score

        pipe = Pipeline([
            ("vect", SupervisedTransformer(llm="gpt-4o", nli_model="cross-encoder/nli-deberta-v3-large")),
            ("clf", LogisticRegression()),
        ])
        cross_val_score(pipe, texts, labels, cv=5)
    """

    def __init__(
        self,
        llm: BaseLLMClient | str,
        nli_model: str = "cross-encoder/nli-deberta-v3-large",
        n_features: int = 20,
        strategy: str = "ovr",
        context_limit: int = 100_000,
        language: str | None = None,
        seed: int | None = None,
        sample_method: str = "random",
        embedding_model: str = "paraphrase-albert-small-v2",
    ) -> None:
        self.llm = llm
        self.nli_model = nli_model
        self.n_features = n_features
        self.strategy = strategy
        self.context_limit = context_limit
        self.language = language
        self.seed = seed
        self.sample_method = sample_method
        self.embedding_model = embedding_model

    @staticmethod
    def _parse_hypotheses(result: Any) -> list[str]:
        """Pull hypothesis strings out of an LLM JSON reply, skipping malformed items.

        Accepts the requested ``{"features": [{"hypothesis": "..."}, ...]}``
        shape, and also a bare list in place of the outer object or plain
        strings in place of the inner objects. Anything else is ignored.
        """
        items = result.get("features", []) if isinstance(result, dict) else result
        if not isinstance(items, list):
            return []
        hypotheses: list[str] = []
        for item in items:
            text = item.get("hypothesis") if isinstance(item, dict) else item
            if isinstance(text, str) and text.strip():
                hypotheses.append(text)
        return hypotheses

    def fit(self, texts: list[str], y: Any) -> "SupervisedTransformer":
        """Generate discriminative hypotheses from training texts and labels.

        Args:
            texts: Training texts.
            y: Class labels (numeric or string).

        Returns:
            self

        Raises:
            ValueError: If ``strategy`` or ``sample_method`` is not a
                supported value.
        """
        if self.strategy not in ("ovr", "ovo"):
            raise ValueError(f"strategy must be 'ovr' or 'ovo', got {self.strategy!r}")
        if self.sample_method not in _SAMPLE_METHODS:
            raise ValueError(f"sample_method must be one of {_SAMPLE_METHODS}, got {self.sample_method!r}")

        # A string is treated as an OpenAI model name.
        _llm = LLMClient(self.llm) if isinstance(self.llm, str) else self.llm
        _rng = random.Random(self.seed)

        le = LabelEncoder()
        y_enc: np.ndarray = le.fit_transform(y)
        classes = le.classes_

        n_classes = len(classes)
        # Reserve room for the fixed prompt text around the sampled texts.
        budget = self.context_limit - _PROMPT_OVERHEAD

        if n_classes == 2:
            pairs: list[tuple[int, int | str]] = [(0, 1)]
        elif self.strategy == "ovr":
            pairs = [(i, "rest") for i in range(n_classes)]
        else:
            pairs = list(combinations(range(n_classes), 2))

        # Collected locally and published only after every LLM call succeeded,
        # so a failure midway never leaves a half-fitted estimator behind.
        features: list[str] = []
        feature_meta: list[FeatureMeta] = []

        for pos_idx, neg_idx in pairs:
            pos_label = classes[pos_idx]
            pos_texts = [t for t, yi in zip(texts, y_enc) if yi == pos_idx]

            if neg_idx == "rest":
                neg_texts = [t for t, yi in zip(texts, y_enc) if yi != pos_idx]
                neg_label: Any = "rest"
            else:
                neg_texts = [t for t, yi in zip(texts, y_enc) if yi == neg_idx]
                neg_label = classes[neg_idx]

            # Each group gets half of the text budget.
            pos_sampled = _sample_group(
                pos_texts, budget // 2, _llm.count_tokens,
                self.sample_method, self.embedding_model, _rng,
            )
            neg_sampled = _sample_group(
                neg_texts, budget // 2, _llm.count_tokens,
                self.sample_method, self.embedding_model, _rng,
            )

            messages = [
                {"role": "system", "content": prompts.SYSTEM},
                {"role": "user", "content": prompts.build_user_message(
                    pos_texts=pos_sampled,
                    neg_texts=neg_sampled,
                    pos_label=str(pos_label),
                    neg_label=str(neg_label),
                    n=self.n_features,
                    language=self.language,
                )},
            ]
            hypotheses = self._parse_hypotheses(_llm.complete_json(messages))

            features.extend(hypotheses)
            feature_meta.extend(
                FeatureMeta(positive=pos_label, negative=neg_label)
                for _ in hypotheses
            )

        nli = NLIModel(self.nli_model)

        self.classes_ = classes
        self.features_: list[str] = features
        self.feature_meta_: list[FeatureMeta] = feature_meta
        self._nli = nli
        return self

    def transform(self, texts: list[str]) -> np.ndarray:
        """Score texts against fitted hypotheses using NLI.

        Args:
            texts: Texts to score.

        Returns:
            np.ndarray of shape (n_texts, n_features) with entailment scores in [0, 1].
            The array has zero columns when fitting produced no hypotheses.

        Raises:
            sklearn.exceptions.NotFittedError: If called before ``fit``.
        """
        from sklearn.utils.validation import check_is_fitted

        check_is_fitted(self, "features_")

        if not self.features_:
            # np.column_stack([]) raises; return an explicit empty matrix.
            return np.empty((len(texts), 0), dtype=float)

        columns = [
            self._nli.score(texts, [h] * len(texts))
            for h in self.features_
        ]
        return np.column_stack(columns)
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
8
|
+
|
|
9
|
+
from .llm import BaseLLMClient, LLMClient
|
|
10
|
+
from .nli import NLIModel
|
|
11
|
+
from .sampling import (
|
|
12
|
+
_estimate_n,
|
|
13
|
+
_trim_to_budget,
|
|
14
|
+
sample_texts_kmeans,
|
|
15
|
+
sample_texts_votek,
|
|
16
|
+
sample_texts_within_budget,
|
|
17
|
+
)
|
|
18
|
+
from .prompts import collection_description as prompts
|
|
19
|
+
|
|
20
|
+
_SAMPLE_METHODS = ("random", "kmeans", "votek")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _sample_group(
|
|
24
|
+
texts: list[str],
|
|
25
|
+
budget: int,
|
|
26
|
+
tokenizer_fn: Any,
|
|
27
|
+
method: str,
|
|
28
|
+
embedding_model: str,
|
|
29
|
+
rng: random.Random,
|
|
30
|
+
) -> list[str]:
|
|
31
|
+
if method == "random":
|
|
32
|
+
return sample_texts_within_budget(texts, budget, tokenizer_fn, rng)
|
|
33
|
+
|
|
34
|
+
from sentence_transformers import SentenceTransformer
|
|
35
|
+
embeddings = SentenceTransformer(embedding_model).encode(
|
|
36
|
+
texts, show_progress_bar=False, convert_to_numpy=True
|
|
37
|
+
)
|
|
38
|
+
n = _estimate_n(texts, budget, tokenizer_fn)
|
|
39
|
+
if method == "kmeans":
|
|
40
|
+
sampled = sample_texts_kmeans(texts, n, embeddings, rng)
|
|
41
|
+
else:
|
|
42
|
+
sampled = sample_texts_votek(texts, n, embeddings, rng=rng)
|
|
43
|
+
return _trim_to_budget(sampled, budget, tokenizer_fn)
|
|
44
|
+
|
|
45
|
+
_PROMPT_OVERHEAD = 500
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class UnsupervisedTransformer(BaseEstimator, TransformerMixin):
    """Sklearn-compatible transformer that generates NLI features from unlabeled texts.

    Fits by generating hypotheses (via LLM) that characterize the text collection,
    then scores texts against those hypotheses using an NLI model.

    Fitted attributes:
        features_: Hypotheses as plain strings.

    Example::

        from sklearn.pipeline import Pipeline
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import cross_val_score

        pipe = Pipeline([
            ("vect", UnsupervisedTransformer(llm="gpt-4o", nli_model="cross-encoder/nli-deberta-v3-large")),
            ("clf", LogisticRegression()),
        ])
        cross_val_score(pipe, texts, labels, cv=5)
    """

    def __init__(
        self,
        llm: BaseLLMClient | str,
        nli_model: str = "cross-encoder/nli-deberta-v3-large",
        n_features: int = 20,
        context_limit: int = 100_000,
        language: str | None = None,
        seed: int | None = None,
        sample_method: str = "random",
        embedding_model: str = "paraphrase-albert-small-v2",
    ) -> None:
        self.llm = llm
        self.nli_model = nli_model
        self.n_features = n_features
        self.context_limit = context_limit
        self.language = language
        self.seed = seed
        self.sample_method = sample_method
        self.embedding_model = embedding_model

    @staticmethod
    def _parse_hypotheses(result: Any) -> list[str]:
        """Pull hypothesis strings out of an LLM JSON reply, skipping malformed items.

        Accepts the requested ``{"features": [{"hypothesis": "..."}, ...]}``
        shape, and also a bare list in place of the outer object or plain
        strings in place of the inner objects. Anything else is ignored.
        """
        items = result.get("features", []) if isinstance(result, dict) else result
        if not isinstance(items, list):
            return []
        hypotheses: list[str] = []
        for item in items:
            text = item.get("hypothesis") if isinstance(item, dict) else item
            if isinstance(text, str) and text.strip():
                hypotheses.append(text)
        return hypotheses

    def fit(self, texts: list[str], y: Any = None) -> "UnsupervisedTransformer":
        """Generate hypotheses from texts using LLM.

        Args:
            texts: Texts to characterize.
            y: Ignored. Present for sklearn API compatibility.

        Returns:
            self

        Raises:
            ValueError: If ``sample_method`` is not a supported value.
        """
        if self.sample_method not in _SAMPLE_METHODS:
            raise ValueError(f"sample_method must be one of {_SAMPLE_METHODS}, got {self.sample_method!r}")

        # A string is treated as an OpenAI model name.
        _llm = LLMClient(self.llm) if isinstance(self.llm, str) else self.llm
        _rng = random.Random(self.seed)

        # Reserve room for the fixed prompt text around the sampled texts.
        budget = self.context_limit - _PROMPT_OVERHEAD
        sampled = _sample_group(
            texts, budget, _llm.count_tokens,
            self.sample_method, self.embedding_model, _rng,
        )

        messages = [
            {"role": "system", "content": prompts.SYSTEM},
            {"role": "user", "content": prompts.build_user_message(
                sampled, n=self.n_features, language=self.language
            )},
        ]
        features = self._parse_hypotheses(_llm.complete_json(messages))
        nli = NLIModel(self.nli_model)

        self.features_: list[str] = features
        self._nli = nli
        return self

    def transform(self, texts: list[str]) -> np.ndarray:
        """Score texts against fitted hypotheses using NLI.

        Args:
            texts: Texts to score.

        Returns:
            np.ndarray of shape (n_texts, n_features) with entailment scores in [0, 1].
            The array has zero columns when fitting produced no hypotheses.

        Raises:
            sklearn.exceptions.NotFittedError: If called before ``fit``.
        """
        from sklearn.utils.validation import check_is_fitted

        check_is_fitted(self, "features_")

        if not self.features_:
            # np.column_stack([]) raises; return an explicit empty matrix.
            return np.empty((len(texts), 0), dtype=float)

        columns = [
            self._nli.score(texts, [h] * len(texts))
            for h in self.features_
        ]
        return np.column_stack(columns)
|