tritopic 0.1.0-py3-none-any.whl → 1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tritopic/__init__.py +22 -32
- tritopic/config.py +305 -0
- tritopic/core/__init__.py +0 -17
- tritopic/core/clustering.py +229 -243
- tritopic/core/embeddings.py +151 -157
- tritopic/core/graph.py +435 -0
- tritopic/core/keywords.py +213 -249
- tritopic/core/refinement.py +231 -0
- tritopic/core/representatives.py +560 -0
- tritopic/labeling.py +313 -0
- tritopic/model.py +718 -0
- tritopic/multilingual/__init__.py +38 -0
- tritopic/multilingual/detection.py +208 -0
- tritopic/multilingual/stopwords.py +467 -0
- tritopic/multilingual/tokenizers.py +275 -0
- tritopic/visualization.py +371 -0
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/METADATA +92 -48
- tritopic-1.0.0.dist-info/RECORD +20 -0
- tritopic/core/graph_builder.py +0 -493
- tritopic/core/model.py +0 -810
- tritopic/labeling/__init__.py +0 -5
- tritopic/labeling/llm_labeler.py +0 -279
- tritopic/utils/__init__.py +0 -13
- tritopic/utils/metrics.py +0 -254
- tritopic/visualization/__init__.py +0 -5
- tritopic/visualization/plotter.py +0 -523
- tritopic-0.1.0.dist-info/RECORD +0 -18
- tritopic-0.1.0.dist-info/licenses/LICENSE +0 -21
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/WHEEL +0 -0
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/top_level.txt +0 -0
tritopic/labeling/__init__.py DELETED
tritopic/labeling/llm_labeler.py DELETED
```diff
@@ -1,279 +0,0 @@
-"""
-LLM-based Topic Labeling
-=========================
-
-Generate human-readable topic labels using LLMs:
-- Claude (Anthropic)
-- GPT-4 (OpenAI)
-"""
-
-from __future__ import annotations
-
-from typing import Literal
-import json
-
-
-class LLMLabeler:
-    """
-    Generate topic labels using Large Language Models.
-
-    Uses LLMs to create meaningful, human-readable labels for topics
-    based on their keywords and representative documents.
-
-    Parameters
-    ----------
-    provider : str
-        LLM provider: "anthropic" or "openai"
-    api_key : str
-        API key for the provider.
-    model : str, optional
-        Model name. Defaults to best available model.
-    max_tokens : int
-        Maximum tokens in response. Default: 200
-    temperature : float
-        Sampling temperature. Default: 0.3
-    language : str
-        Output language. Default: "english"
-    """
-
-    def __init__(
-        self,
-        provider: Literal["anthropic", "openai"] = "anthropic",
-        api_key: str | None = None,
-        model: str | None = None,
-        max_tokens: int = 200,
-        temperature: float = 0.3,
-        language: str = "english",
-    ):
-        self.provider = provider
-        self.api_key = api_key
-        self.model = model or self._default_model()
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-        self.language = language
-
-        self._client = None
-
-    def _default_model(self) -> str:
-        """Get default model for provider."""
-        if self.provider == "anthropic":
-            return "claude-3-haiku-20240307"
-        else:
-            return "gpt-4o-mini"
-
-    def _init_client(self):
-        """Initialize API client."""
-        if self._client is not None:
-            return
-
-        if self.provider == "anthropic":
-            try:
-                from anthropic import Anthropic
-                self._client = Anthropic(api_key=self.api_key)
-            except ImportError:
-                raise ImportError(
-                    "anthropic package not installed. "
-                    "Install with: pip install anthropic"
-                )
-        else:
-            try:
-                from openai import OpenAI
-                self._client = OpenAI(api_key=self.api_key)
-            except ImportError:
-                raise ImportError(
-                    "openai package not installed. "
-                    "Install with: pip install openai"
-                )
-
-    def generate_label(
-        self,
-        keywords: list[str],
-        representative_docs: list[str],
-        domain_hint: str | None = None,
-    ) -> tuple[str, str]:
-        """
-        Generate a label for a topic.
-
-        Parameters
-        ----------
-        keywords : list[str]
-            Topic keywords (top 10 recommended).
-        representative_docs : list[str]
-            Representative documents for the topic.
-        domain_hint : str, optional
-            Domain context (e.g., "tourism", "technology").
-
-        Returns
-        -------
-        label : str
-            Short topic label (2-5 words).
-        description : str
-            Brief description of the topic.
-        """
-        self._init_client()
-
-        # Build prompt
-        prompt = self._build_prompt(keywords, representative_docs, domain_hint)
-
-        # Call API
-        if self.provider == "anthropic":
-            response = self._call_anthropic(prompt)
-        else:
-            response = self._call_openai(prompt)
-
-        # Parse response
-        label, description = self._parse_response(response)
-
-        return label, description
-
-    def _build_prompt(
-        self,
-        keywords: list[str],
-        representative_docs: list[str],
-        domain_hint: str | None = None,
-    ) -> str:
-        """Build the labeling prompt."""
-        # Truncate long documents
-        docs_text = ""
-        for i, doc in enumerate(representative_docs[:5], 1):
-            truncated = doc[:500] + "..." if len(doc) > 500 else doc
-            docs_text += f"\nDocument {i}: {truncated}\n"
-
-        domain_context = ""
-        if domain_hint:
-            domain_context = f"\nDomain context: This is about {domain_hint}.\n"
-
-        prompt = f"""You are an expert at creating concise, meaningful topic labels.
-
-Given the following information about a topic, create:
-1. A SHORT LABEL (2-5 words, title case, no special characters)
-2. A BRIEF DESCRIPTION (1-2 sentences explaining what this topic is about)
-
-Keywords (most representative words for this topic):
-{', '.join(keywords[:10])}
-
-Representative Documents:
-{docs_text}
-{domain_context}
-Requirements:
-- The label should be specific and descriptive, not generic
-- The label should capture the main theme, not just list keywords
-- The description should explain what documents in this topic discuss
-- Output in {self.language}
-
-Respond in this exact JSON format:
-{{"label": "Your Topic Label", "description": "Your brief description."}}
-
-JSON response:"""
-
-        return prompt
-
-    def _call_anthropic(self, prompt: str) -> str:
-        """Call Anthropic API."""
-        response = self._client.messages.create(
-            model=self.model,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
-            messages=[{"role": "user", "content": prompt}],
-        )
-        return response.content[0].text
-
-    def _call_openai(self, prompt: str) -> str:
-        """Call OpenAI API."""
-        response = self._client.chat.completions.create(
-            model=self.model,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
-            messages=[{"role": "user", "content": prompt}],
-        )
-        return response.choices[0].message.content
-
-    def _parse_response(self, response: str) -> tuple[str, str]:
-        """Parse LLM response to extract label and description."""
-        try:
-            # Try to parse as JSON
-            # Find JSON in response
-            start = response.find("{")
-            end = response.rfind("}") + 1
-
-            if start != -1 and end > start:
-                json_str = response[start:end]
-                data = json.loads(json_str)
-
-                label = data.get("label", "Unknown Topic")
-                description = data.get("description", "")
-
-                return label, description
-        except (json.JSONDecodeError, KeyError):
-            pass
-
-        # Fallback: extract from text
-        lines = response.strip().split("\n")
-        label = lines[0] if lines else "Unknown Topic"
-        description = " ".join(lines[1:]) if len(lines) > 1 else ""
-
-        # Clean up
-        label = label.replace('"', "").replace("Label:", "").strip()
-        description = description.replace('"', "").replace("Description:", "").strip()
-
-        return label, description
-
-    def generate_labels_batch(
-        self,
-        topics_data: list[dict],
-        domain_hint: str | None = None,
-    ) -> list[tuple[str, str]]:
-        """
-        Generate labels for multiple topics.
-
-        Parameters
-        ----------
-        topics_data : list[dict]
-            List of dicts with "keywords" and "representative_docs".
-        domain_hint : str, optional
-            Domain context.
-
-        Returns
-        -------
-        labels : list[tuple[str, str]]
-            List of (label, description) tuples.
-        """
-        results = []
-
-        for topic in topics_data:
-            label, desc = self.generate_label(
-                keywords=topic["keywords"],
-                representative_docs=topic["representative_docs"],
-                domain_hint=domain_hint,
-            )
-            results.append((label, desc))
-
-        return results
-
-
-class SimpleLabeler:
-    """
-    Simple rule-based labeler (no LLM required).
-
-    Creates labels from top keywords.
-    """
-
-    def __init__(self, n_words: int = 3):
-        self.n_words = n_words
-
-    def generate_label(
-        self,
-        keywords: list[str],
-        **kwargs,
-    ) -> tuple[str, str]:
-        """Generate label from top keywords."""
-        # Take top n keywords
-        top_keywords = keywords[:self.n_words]
-
-        # Title case
-        label = " & ".join(kw.title() for kw in top_keywords)
-
-        # Description from more keywords
-        description = f"Topics related to: {', '.join(keywords[:6])}"
-
-        return label, description
```
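For orientation, here is a minimal, hypothetical sketch of how the removed 0.1.0 API above was called; the keywords, documents, and API key are invented placeholders, and the commented `SimpleLabeler` output follows directly from the deleted code. Per the file summary, this functionality appears to be superseded by the new top-level `tritopic/labeling.py` in 1.0.0.

```python
# Hypothetical usage of the deleted 0.1.0 labeling API shown above.
from tritopic.labeling.llm_labeler import LLMLabeler, SimpleLabeler

# Rule-based fallback: joins the top-n keywords, no API key required.
simple = SimpleLabeler(n_words=3)
label, desc = simple.generate_label(
    keywords=["hotel", "beach", "resort", "booking", "pool", "spa"],
)
# label == "Hotel & Beach & Resort"
# desc  == "Topics related to: hotel, beach, resort, booking, pool, spa"

# LLM-backed labeling (requires `pip install anthropic` and a real key).
labeler = LLMLabeler(provider="anthropic", api_key="<YOUR_KEY>")
label, desc = labeler.generate_label(
    keywords=["hotel", "beach", "resort", "booking", "pool"],
    representative_docs=["We stayed at a beachfront resort with two pools."],
    domain_hint="tourism",
)
```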
tritopic/utils/__init__.py DELETED
tritopic/utils/metrics.py DELETED
```diff
@@ -1,254 +0,0 @@
-"""
-Evaluation Metrics for Topic Models
-====================================
-
-Provides standard metrics for evaluating topic model quality:
-- Coherence (NPMI, CV)
-- Diversity
-- Stability (ARI between runs)
-"""
-
-from __future__ import annotations
-
-import numpy as np
-from collections import Counter
-from itertools import combinations
-
-
-def compute_coherence(
-    keywords: list[str],
-    documents: list[str],
-    method: str = "npmi",
-    window_size: int = 10,
-) -> float:
-    """
-    Compute topic coherence based on keyword co-occurrence.
-
-    Parameters
-    ----------
-    keywords : list[str]
-        Topic keywords.
-    documents : list[str]
-        Documents used to compute co-occurrence.
-    method : str
-        Coherence method: "npmi" (default), "uci", "umass"
-    window_size : int
-        Window size for co-occurrence. Default: 10
-
-    Returns
-    -------
-    coherence : float
-        Coherence score (higher is better).
-    """
-    if len(keywords) < 2:
-        return 0.0
-
-    # Tokenize documents
-    def tokenize(text):
-        import re
-        return set(re.findall(r'\b\w+\b', text.lower()))
-
-    doc_tokens = [tokenize(doc) for doc in documents]
-    n_docs = len(documents)
-
-    # Count document frequencies
-    word_doc_freq = Counter()
-    for tokens in doc_tokens:
-        for word in tokens:
-            word_doc_freq[word] += 1
-
-    # Count co-occurrences (document-level)
-    pair_doc_freq = Counter()
-    for tokens in doc_tokens:
-        for w1, w2 in combinations(keywords, 2):
-            if w1.lower() in tokens and w2.lower() in tokens:
-                pair_doc_freq[(w1.lower(), w2.lower())] += 1
-
-    # Compute coherence
-    coherence_scores = []
-
-    for w1, w2 in combinations(keywords, 2):
-        w1_lower, w2_lower = w1.lower(), w2.lower()
-
-        freq_w1 = word_doc_freq.get(w1_lower, 0)
-        freq_w2 = word_doc_freq.get(w2_lower, 0)
-        freq_pair = pair_doc_freq.get((w1_lower, w2_lower), 0)
-
-        if freq_w1 == 0 or freq_w2 == 0:
-            continue
-
-        if method == "npmi":
-            # Normalized Pointwise Mutual Information
-            p_w1 = freq_w1 / n_docs
-            p_w2 = freq_w2 / n_docs
-            p_pair = (freq_pair + 1) / n_docs  # Add-1 smoothing
-
-            pmi = np.log(p_pair / (p_w1 * p_w2 + 1e-10))
-            npmi = pmi / (-np.log(p_pair + 1e-10) + 1e-10)
-            coherence_scores.append(npmi)
-
-        elif method == "uci":
-            # UCI coherence
-            p_pair = (freq_pair + 1) / n_docs
-            p_w1 = freq_w1 / n_docs
-            p_w2 = freq_w2 / n_docs
-
-            pmi = np.log(p_pair / (p_w1 * p_w2 + 1e-10))
-            coherence_scores.append(pmi)
-
-        elif method == "umass":
-            # UMass coherence
-            if freq_w2 > 0:
-                score = np.log((freq_pair + 1) / freq_w2)
-                coherence_scores.append(score)
-
-    return float(np.mean(coherence_scores)) if coherence_scores else 0.0
-
-
-def compute_diversity(
-    all_keywords: list[str],
-    n_topics: int,
-) -> float:
-    """
-    Compute topic diversity (proportion of unique keywords).
-
-    Diversity measures how different topics are from each other.
-    A model where every topic has the same keywords has diversity 0.
-
-    Parameters
-    ----------
-    all_keywords : list[str]
-        All keywords from all topics (flattened).
-    n_topics : int
-        Number of topics.
-
-    Returns
-    -------
-    diversity : float
-        Diversity score between 0 and 1 (higher is better).
-    """
-    if not all_keywords or n_topics == 0:
-        return 0.0
-
-    unique_keywords = set(kw.lower() for kw in all_keywords)
-
-    # Diversity = unique keywords / total keywords
-    diversity = len(unique_keywords) / len(all_keywords)
-
-    return float(diversity)
-
-
-def compute_stability(
-    partitions: list[np.ndarray],
-) -> float:
-    """
-    Compute clustering stability as average pairwise ARI.
-
-    Parameters
-    ----------
-    partitions : list[np.ndarray]
-        Multiple cluster assignments from different runs.
-
-    Returns
-    -------
-    stability : float
-        Average Adjusted Rand Index between partitions.
-    """
-    from sklearn.metrics import adjusted_rand_score
-
-    if len(partitions) < 2:
-        return 1.0
-
-    ari_scores = []
-    for i in range(len(partitions)):
-        for j in range(i + 1, len(partitions)):
-            ari = adjusted_rand_score(partitions[i], partitions[j])
-            ari_scores.append(ari)
-
-    return float(np.mean(ari_scores))
-
-
-def compute_silhouette(
-    embeddings: np.ndarray,
-    labels: np.ndarray,
-) -> float:
-    """
-    Compute silhouette score for cluster quality.
-
-    Parameters
-    ----------
-    embeddings : np.ndarray
-        Document embeddings.
-    labels : np.ndarray
-        Cluster assignments.
-
-    Returns
-    -------
-    silhouette : float
-        Silhouette score between -1 and 1 (higher is better).
-    """
-    from sklearn.metrics import silhouette_score
-
-    # Filter out outliers
-    mask = labels != -1
-    if mask.sum() < 2:
-        return 0.0
-
-    unique_labels = np.unique(labels[mask])
-    if len(unique_labels) < 2:
-        return 0.0
-
-    return float(silhouette_score(embeddings[mask], labels[mask]))
-
-
-def compute_downstream_score(
-    embeddings: np.ndarray,
-    labels: np.ndarray,
-    y_true: np.ndarray,
-    task: str = "classification",
-) -> float:
-    """
-    Evaluate topic model by downstream task performance.
-
-    Parameters
-    ----------
-    embeddings : np.ndarray
-        Document embeddings.
-    labels : np.ndarray
-        Topic assignments.
-    y_true : np.ndarray
-        True labels for downstream task.
-    task : str
-        Task type: "classification" or "clustering"
-
-    Returns
-    -------
-    score : float
-        Task-specific score.
-    """
-    from sklearn.linear_model import LogisticRegression
-    from sklearn.metrics import f1_score, adjusted_rand_score
-    from sklearn.model_selection import cross_val_score
-
-    # Create topic features (one-hot + embedding)
-    n_topics = len(np.unique(labels[labels != -1]))
-
-    # One-hot encode topics
-    topic_features = np.zeros((len(labels), n_topics + 1))
-    for i, label in enumerate(labels):
-        if label == -1:
-            topic_features[i, -1] = 1  # Outlier feature
-        else:
-            topic_features[i, label] = 1
-
-    # Combine with embeddings
-    features = np.hstack([embeddings, topic_features])
-
-    if task == "classification":
-        # Cross-validated F1
-        clf = LogisticRegression(max_iter=1000, random_state=42)
-        scores = cross_val_score(clf, features, y_true, cv=5, scoring="f1_macro")
-        return float(np.mean(scores))
-    else:
-        # Clustering ARI
-        return float(adjusted_rand_score(labels, y_true))
```
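For orientation, a minimal, hypothetical sketch of the removed metrics helpers on invented toy data; the commented outputs follow directly from the definitions above.

```python
# Hypothetical usage of the deleted 0.1.0 metrics API shown above.
import numpy as np
from tritopic.utils.metrics import (
    compute_coherence,
    compute_diversity,
    compute_stability,
)

docs = [
    "the hotel pool was clean and the beach was close",
    "booked a beach resort with a great pool",
    "the flight was delayed and the airline lost my bag",
]

# NPMI coherence over document-level co-occurrence, averaged across
# keyword pairs (higher is better).
score = compute_coherence(["beach", "pool", "resort"], docs, method="npmi")

# Diversity = unique keywords / total keywords across all topics.
compute_diversity(["beach", "pool", "flight", "pool"], n_topics=2)  # 0.75

# Stability = mean pairwise ARI; ARI ignores label permutations, so two
# identical partitions with swapped labels score 1.0.
compute_stability([np.array([0, 0, 1]), np.array([1, 1, 0])])  # 1.0
```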