structural-topic-model 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pystm/__init__.py +31 -0
- pystm/_estep.py +159 -0
- pystm/_mnreg.py +221 -0
- pystm/_mstep.py +96 -0
- pystm/_spectral.py +151 -0
- pystm/_utils.py +41 -0
- pystm/diagnostics.py +168 -0
- pystm/effects.py +203 -0
- pystm/model_selection.py +166 -0
- pystm/stm.py +443 -0
- structural_topic_model-0.2.0.dist-info/METADATA +234 -0
- structural_topic_model-0.2.0.dist-info/RECORD +14 -0
- structural_topic_model-0.2.0.dist-info/WHEEL +4 -0
- structural_topic_model-0.2.0.dist-info/licenses/LICENSE +21 -0
pystm/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""pystm: Python implementation of the Structural Topic Model.
|
|
2
|
+
|
|
3
|
+
A port of the R ``stm`` package (Roberts, Stewart & Tingley) with an API
|
|
4
|
+
modeled on scikit-learn's ``LatentDirichletAllocation``.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .diagnostics import (
|
|
8
|
+
TopicCorrelations,
|
|
9
|
+
check_residuals,
|
|
10
|
+
exclusivity,
|
|
11
|
+
semantic_coherence,
|
|
12
|
+
topic_corr,
|
|
13
|
+
)
|
|
14
|
+
from .effects import EstimatedEffects, estimate_effect
|
|
15
|
+
from .model_selection import eval_heldout, make_heldout, search_k
|
|
16
|
+
from .stm import StructuralTopicModel
|
|
17
|
+
|
|
18
|
+
# Public API re-exported at package level; this list also controls what
# ``from pystm import *`` brings into scope.
__all__ = [
    "StructuralTopicModel",
    "estimate_effect",
    "EstimatedEffects",
    "search_k",
    "make_heldout",
    "eval_heldout",
    "topic_corr",
    "TopicCorrelations",
    "semantic_coherence",
    "exclusivity",
    "check_residuals",
]
# Package version string.
__version__ = "0.2.0"
|
pystm/_estep.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""Variational E-step (port of STMestep.R, STMlncpp.R and STMCfuns.cpp).
|
|
2
|
+
|
|
3
|
+
For each document the variational posterior over eta (the K-1 dimensional
|
|
4
|
+
logistic-normal document-topic parameter) is approximated by a Laplace
|
|
5
|
+
approximation: the mode ``lambda`` is found with BFGS and the covariance
|
|
6
|
+
``nu`` is the inverse Hessian at the mode.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from scipy.linalg import cho_solve, cholesky, solve
|
|
13
|
+
from scipy.optimize import minimize
|
|
14
|
+
|
|
15
|
+
# Numerical guards absent from the R/C++ original: BFGS line searches can
# probe points where exp(eta) over/underflows or a document word has
# (numerically) zero probability under every topic, poisoning the search
# with inf/nan. Clipping eta and flooring the logs/denominators keeps the
# objective finite without affecting values in the normal range.

# Bound on |eta| before exponentiation; exp(200) (~7e86) is still finite
# in float64, so sums of a handful of such terms cannot overflow.
_ETA_CLIP: float = 200.0
# Floor applied to word probabilities and per-word denominators before
# taking logs or dividing.
_TINY: float = 1e-300
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _expeta(eta):
    """Return ``exp(eta)`` (after clipping) with the reference topic's 1.0 appended.

    The last topic is the reference category whose eta is fixed at 0, so
    its unnormalized weight is ``exp(0) == 1``.
    """
    bounded = np.clip(eta, -_ETA_CLIP, _ETA_CLIP)
    return np.concatenate((np.exp(bounded).ravel(), (1.0,)))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def neg_lhood(eta, beta_d, doc_ct, mu_d, siginv):
    """Negative collapsed objective for one document (lhoodcpp).

    The value is the Gaussian prior penalty
    ``0.5 * (eta - mu_d) @ siginv @ (eta - mu_d)`` minus the document's
    multinomial log-likelihood ``sum_v c_v * log(theta @ beta[:, v])``,
    where ``theta`` is the softmax of ``eta`` extended by the reference
    topic. Word probabilities are floored at ``_TINY`` before the log.
    """
    weights = _expeta(eta)
    total_tokens = doc_ct.sum()
    log_word_probs = np.log(np.maximum(weights @ beta_d, _TINY))
    data_term = log_word_probs @ doc_ct - total_tokens * np.log(weights.sum())
    centered = eta - mu_d
    prior_term = 0.5 * centered @ siginv @ centered
    return prior_term - data_term
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def neg_grad(eta, beta_d, doc_ct, mu_d, siginv):
    """Gradient of :func:`neg_lhood` (gradcpp).

    Returns a (K-1,) vector: the prior term ``siginv @ (eta - mu_d)`` minus
    the likelihood gradient, with the reference topic's component dropped.
    """
    weights = _expeta(eta)
    weighted_beta = beta_d * weights[:, None]
    word_norm = np.maximum(weighted_beta.sum(axis=0), _TINY)
    likelihood_grad = (
        weighted_beta @ (doc_ct / word_norm)
        - (doc_ct.sum() / weights.sum()) * weights
    )
    prior_grad = siginv @ (eta - mu_d)
    return prior_grad - likelihood_grad[:-1]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def hessian_phi_bound(eta, beta_d, doc_ct, mu_d, siginv, sigmaentropy):
    """Compute the Laplace covariance, token assignments and bound (hpbcpp).

    Returns ``(phis, nu, bound)`` where ``phis`` is K x V_d expected token
    counts per topic, ``nu`` is the (K-1) x (K-1) posterior covariance of
    eta, and ``bound`` is the document's contribution to the global ELBO.

    Parameters
    ----------
    eta : (K-1,) array
        Mode of the document's variational posterior; the reference topic
        is appended by ``_expeta``.
    beta_d : (K, V_d) array
        Topic-word probabilities restricted to the document's words (as
        sliced in ``estep``).
    doc_ct : (V_d,) array
        Counts of those words.
    mu_d : (K-1,) array
        Prior mean for this document.
    siginv : (K-1, K-1) array
        Inverse of the global covariance.
    sigmaentropy : float
        Half log-determinant of sigma, as returned by ``decompose_sigma``.
    """
    expeta = _expeta(eta)
    theta = expeta / expeta.sum()
    ndoc = doc_ct.sum()
    sqrtct = np.sqrt(doc_ct)

    EB = beta_d * expeta[:, None]
    # Normalize each word column to topic responsibilities and scale by
    # sqrt(count), so that EB @ EB.T is the count-weighted sum of outer
    # products of the responsibilities.
    EB *= (sqrtct / np.maximum(EB.sum(axis=0), _TINY))[None, :]

    # Hessian of the negative likelihood w.r.t. the full K-vector eta
    # (off-diagonal structure first, diagonal corrected below).
    hess = EB @ EB.T - ndoc * np.outer(theta, theta)
    # turn EB into phi (expected token counts per topic and word)
    EB *= sqrtct[None, :]
    np.fill_diagonal(hess, np.diag(hess) - (EB.sum(axis=1) - ndoc * theta))
    # Drop the reference topic and add the prior's curvature.
    hess = hess[:-1, :-1] + siginv

    try:
        L = cholesky(hess, lower=True)
    except np.linalg.LinAlgError:
        # not positive definite: enforce diagonal dominance as in hpbcpp
        dvec = np.diag(hess).copy()
        magnitudes = np.abs(hess).sum(axis=1) - np.abs(dvec)
        dvec = np.maximum(dvec, magnitudes)
        np.fill_diagonal(hess, dvec)
        L = cholesky(hess, lower=True)

    # -sum(log diag(L)) == -0.5 * log det(hess) == 0.5 * log det(nu)
    det_term = -np.log(np.diag(L)).sum()
    # nu = hess^{-1}, the Laplace-approximation covariance
    nu = cho_solve((L, True), np.eye(hess.shape[0]))

    diff = eta - mu_d
    bound = (
        np.log(np.maximum(theta @ beta_d, _TINY)) @ doc_ct
        + det_term
        - 0.5 * diff @ siginv @ diff
        - sigmaentropy
    )
    return EB, nu, bound
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def optimize_document(eta, beta_d, doc_ct, mu_d, siginv, sigmaentropy,
                      max_optim_iter=500):
    """Infer one document's variational parameters (logisticnormalcpp).

    Finds the mode of the document objective with BFGS, starting at
    ``eta``. If the optimizer returns non-finite values, the starting point
    is used instead. Returns ``((phis, nu, bound), eta_hat)``.
    """
    objective_args = (beta_d, doc_ct, mu_d, siginv)
    result = minimize(
        neg_lhood, eta, args=objective_args, jac=neg_grad,
        method="BFGS", options={"maxiter": max_optim_iter},
    )
    mode = result.x
    if not np.isfinite(mode).all():
        mode = eta
    summary = hessian_phi_bound(mode, beta_d, doc_ct, mu_d, siginv,
                                sigmaentropy)
    return summary, mode
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def decompose_sigma(sigma):
    """Return ``(siginv, sigmaentropy)`` shared by every document's E-step.

    ``sigmaentropy`` is half the log-determinant of ``sigma``. A Cholesky
    factorization is used when ``sigma`` is positive definite. Otherwise the
    code falls back to ``slogdet`` and a symmetric linear solve.
    """
    identity = np.eye(sigma.shape[0])
    try:
        upper = cholesky(sigma, lower=False)
    except np.linalg.LinAlgError:
        fallback_entropy = 0.5 * np.linalg.slogdet(sigma)[1]
        return solve(sigma, identity, assume_a="sym"), fallback_entropy
    entropy = np.log(np.diag(upper)).sum()
    return cho_solve((upper, False), identity), entropy
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def estep(docs, beta_index, update_mu, beta, lambda_old, mu, sigma,
          max_optim_iter=500):
    """Run the E-step over all documents and accumulate sufficient stats.

    Parameters mirror estep() in STMestep.R. ``docs`` is a list of
    ``(word_indices, counts)`` pairs, ``beta`` a list with one K x V matrix
    per content level (always length 1 here), ``mu`` is (K-1,) when shared
    or (N, K-1) when document specific. ``beta_index[i]`` selects the
    content level used for document ``i``, and ``lambda_old`` holds the
    (N, K-1) warm starts.

    Returns ``(sigma_ss, beta_ss, bound, lambda_)``.
    """
    n_topics, n_vocab = beta[0].shape
    n_docs = len(docs)

    sigma_ss = np.zeros((n_topics - 1, n_topics - 1))
    beta_ss = [np.zeros((n_topics, n_vocab)) for _ in beta]
    bound = np.empty(n_docs)
    lambda_ = np.empty((n_docs, n_topics - 1))

    siginv, sigmaentropy = decompose_sigma(sigma)

    for doc_idx, (words, counts) in enumerate(docs):
        level = beta_index[doc_idx]
        prior_mean = mu[doc_idx] if update_mu else mu
        local_beta = np.ascontiguousarray(beta[level][:, words])

        (phis, nu, doc_bound), eta_hat = optimize_document(
            lambda_old[doc_idx], local_beta, counts, prior_mean, siginv,
            sigmaentropy, max_optim_iter=max_optim_iter,
        )

        # Accumulate the global sufficient statistics.
        sigma_ss += nu
        beta_ss[level][:, words] += phis
        bound[doc_idx] = doc_bound
        lambda_[doc_idx] = eta_hat

    return sigma_ss, beta_ss, bound, lambda_
|
pystm/_mnreg.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""Content covariate M-step (port of STMmnreg.R).
|
|
2
|
+
|
|
3
|
+
The SAGE-style topic-word update with content covariates is estimated via
|
|
4
|
+
Distributed Multinomial Regression (Taddy 2013): the multinomial is
|
|
5
|
+
factorized into independent Poisson regressions, one per vocabulary word,
|
|
6
|
+
each with an L1 penalty. The R package solves these with glmnet; here we
|
|
7
|
+
implement an equivalent lasso-penalized Poisson solver (IRLS + coordinate
|
|
8
|
+
descent over a regularization path with information-criterion selection).
|
|
9
|
+
|
|
10
|
+
The solver exploits the structure of the problem heavily. The design
|
|
11
|
+
matrix consists of three groups of indicator columns (topic main effects,
|
|
12
|
+
aspect main effects, topic-by-aspect interactions). Columns within a
|
|
13
|
+
group touch disjoint rows, so a coordinate-descent pass over a whole
|
|
14
|
+
group can be performed as one vectorized update; and since every word
|
|
15
|
+
shares the same design, all V regressions are advanced simultaneously.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
from scipy.special import xlogy
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _poisson_deviance(Y, mu):
|
|
25
|
+
"""Per-word Poisson deviance, with the y=0 terms handled."""
|
|
26
|
+
return 2.0 * (xlogy(Y, Y) - xlogy(Y, mu) - (Y - mu)).sum(axis=(0, 1))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _soft_threshold(rho, lam):
|
|
30
|
+
return np.sign(rho) * np.maximum(np.abs(rho) - lam, 0.0)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class _StructuredPoissonLasso:
    """Distributed Poisson lasso with the STM content design.

    Data is held as (A, K, V) arrays. ``lam`` arguments are per-word
    penalty levels of shape (V,).

    The linear predictor for cell (a, k, v) is ``offsets3[a, k, v] +
    b_topic[k, v]``, plus ``b_aspect[a, v]`` when ``use_aspect`` and
    ``b_inter[a, k, v]`` when ``use_inter``. Each word's regression
    therefore has ``n = A * K`` rows.
    """

    def __init__(self, Y3, offsets3, use_aspect, use_inter):
        # Y3: responses, offsets3: fixed offsets; both shaped (A, K, V).
        self.Y3 = Y3
        self.offsets3 = offsets3
        self.A, self.K, self.V = Y3.shape
        # number of rows in every word's regression
        self.n = self.A * self.K
        self.use_aspect = use_aspect
        self.use_inter = use_inter
        # all coefficients start at zero (offsets-only null model)
        self.b_topic = np.zeros((self.K, self.V))
        self.b_aspect = np.zeros((self.A, self.V)) if use_aspect else None
        self.b_inter = (
            np.zeros((self.A, self.K, self.V)) if use_inter else None
        )

    def linpred(self):
        """Return the (A, K, V) linear predictor clipped to [-50, 30]."""
        eta = self.offsets3 + self.b_topic[None, :, :]
        if self.use_aspect:
            eta = eta + self.b_aspect[:, None, :]
        if self.use_inter:
            eta = eta + self.b_inter
        # keep exp(eta) finite and strictly positive
        return np.clip(eta, -50.0, 30.0)

    def _block_update(self, R, W, Wsum, b, lam, axis):
        """One exact CD pass over a group of disjoint indicator columns.

        ``axis`` is the (A, K, V) axis summed over to aggregate a column's
        rows (None for the interaction block where each row is its own
        column). Updates ``R`` in place and returns (new_b, max_delta).
        """
        if axis is None:
            # interaction block: each cell is its own column
            num = R + W * b
            denom = Wsum
        else:
            # main-effect block: partial residual summed over the column's rows
            num = R.sum(axis=axis) + Wsum * b
            denom = Wsum
        # objective and penalty are both scaled by 1/n
        rho = num / self.n
        b_new = _soft_threshold(rho, lam)
        with np.errstate(invalid="ignore", divide="ignore"):
            # columns with (numerically) zero total weight are set to zero
            b_new = np.where(denom / self.n > 1e-12,
                             b_new / (denom / self.n), 0.0)
        delta = b_new - b
        max_delta = np.abs(delta).max() if delta.size else 0.0
        if max_delta > 0:
            # keep the weighted working residual consistent with new coefs
            if axis is None:
                R -= W * delta
            else:
                R -= W * np.expand_dims(delta, axis)
        return b_new, max_delta

    def fit_one_lambda(self, lam, tol, max_irls, max_sweeps):
        """Solve at one penalty level, warm-starting from current state.

        The outer loop runs IRLS rounds for the log-link Poisson model. The
        inner loop runs coordinate-descent sweeps over the coefficient
        blocks. It stops early once an IRLS round moves no coefficient by
        ``tol`` or more.
        """
        for _ in range(max_irls):
            eta = self.linpred()
            mu = np.exp(eta)
            # log link: IRLS weights are mu and the weighted working
            # residual W * (z - eta) reduces to Y - mu
            W = mu
            R = self.Y3 - mu
            Wsum_topic = W.sum(axis=0)
            Wsum_aspect = W.sum(axis=1) if self.use_aspect else None
            outer_delta = 0.0
            for _ in range(max_sweeps):
                d = 0.0
                self.b_topic, d1 = self._block_update(
                    R, W, Wsum_topic, self.b_topic, lam, axis=0)
                d = max(d, d1)
                if self.use_aspect:
                    self.b_aspect, d1 = self._block_update(
                        R, W, Wsum_aspect, self.b_aspect, lam, axis=1)
                    d = max(d, d1)
                if self.use_inter:
                    self.b_inter, d1 = self._block_update(
                        R, W, W, self.b_inter, lam, axis=None)
                    d = max(d, d1)
                outer_delta = max(outer_delta, d)
                if d < tol:
                    break
            if outer_delta < tol:
                break

    def df(self):
        """Per-word count of nonzero coefficients, shape (V,)."""
        out = (self.b_topic != 0).sum(axis=0)
        if self.use_aspect:
            out = out + (self.b_aspect != 0).sum(axis=0)
        if self.use_inter:
            out = out + (self.b_inter != 0).sum(axis=(0, 1))
        return out

    def coef_rows(self):
        """Stack coefficients as the R package's kappa params (p, V)."""
        # row order: K topic rows, then A aspect rows, then A*K interactions
        rows = [self.b_topic]
        if self.use_aspect:
            rows.append(self.b_aspect)
        if self.use_inter:
            rows.append(self.b_inter.reshape(self.n, self.V))
        return np.vstack(rows)

    def state(self):
        """Return copies of ``(b_topic, b_aspect, b_inter)``."""
        return (self.b_topic.copy(),
                None if self.b_aspect is None else self.b_aspect.copy(),
                None if self.b_inter is None else self.b_inter.copy())

    def set_state(self, state):
        """Restore coefficients from a tuple produced by :meth:`state`."""
        self.b_topic, self.b_aspect, self.b_inter = (
            state[0].copy(),
            None if state[1] is None else state[1].copy(),
            None if state[2] is None else state[2].copy(),
        )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def mnreg(beta_ss, wcounts, *, interactions=True, nlambda=250,
          lambda_min_ratio=0.001, ic_k=2.0, tol=1e-4,
          max_irls=4, max_sweeps=8):
    """Update beta and kappa from the E-step expected counts (mnreg in R).

    Only the (default) fixed-intercept variant is implemented: the
    intercept of each word's Poisson regression is fixed to the background
    log-probability ``m``.

    Parameters
    ----------
    beta_ss : list of A arrays of shape (K, V)
        Expected topic-word counts from the E-step, one per content level.
    wcounts : (V,) array
        Corpus-wide word counts; ``m`` is their log relative frequency.
    interactions : bool
        Include topic-by-aspect interaction columns (only used when A > 1).
    nlambda, lambda_min_ratio : int, float
        Length and depth of the geometric penalty path, relative to each
        word's own ``lambda_max``.
    ic_k : float
        Per-nonzero-coefficient penalty in the information criterion
        ``deviance + ic_k * df`` used to pick each word's path point
        (2.0 corresponds to AIC).
    tol, max_irls, max_sweeps
        Convergence controls for the coordinate-descent solver.

    Returns ``(beta, kappa)`` where ``beta`` is a list of A matrices of
    shape (K, V) and ``kappa`` is a dict with the baseline ``m`` and the
    selected deviation coefficients ``params`` of shape (p, V).
    """
    A = len(beta_ss)
    K, V = beta_ss[0].shape
    use_aspect = A > 1
    use_inter = interactions and A > 1

    Y3 = np.stack(beta_ss)  # (A, K, V)
    m = np.log(wcounts) - np.log(wcounts.sum())
    # floor avoids log(0) for (level, topic) rows with no expected tokens
    row_totals = np.maximum(Y3.sum(axis=2), 1e-10)  # (A, K)
    offsets3 = m[None, None, :] + np.log(row_totals)[:, :, None]

    solver = _StructuredPoissonLasso(Y3, offsets3, use_aspect, use_inter)
    n = A * K

    # null model (all deviations zero), clipped exactly like linpred()
    mu0 = np.exp(np.clip(offsets3, -50.0, 30.0))
    nulldev = _poisson_deviance(Y3, mu0)
    # per-word lambda_max: max over columns of |score| at the null model
    R0 = Y3 - mu0
    scores = [np.abs(R0.sum(axis=0))]  # topic columns (K, V)
    if use_aspect:
        scores.append(np.abs(R0.sum(axis=1)))  # aspect columns (A, V)
    if use_inter:
        scores.append(np.abs(R0).reshape(n, V))  # interaction columns
    lambda_max = np.vstack(scores).max(axis=0) / n
    lambda_max = np.maximum(lambda_max, 1e-10)

    rel_path = np.exp(np.linspace(0.0, np.log(lambda_min_ratio), nlambda))
    best_ic = nulldev.copy()  # path point 0: all coefficients zero
    best_state = solver.state()

    # Walk the path with warm starts. For each word, keep the coefficients
    # from the path point with the lowest information criterion.
    for step in rel_path[1:]:
        lam = lambda_max * step
        solver.fit_one_lambda(lam, tol=tol, max_irls=max_irls,
                              max_sweeps=max_sweeps)
        dev = _poisson_deviance(Y3, np.exp(solver.linpred()))
        ic = dev + ic_k * solver.df()
        improved = ic < best_ic
        if not improved.any():
            continue
        cur = solver.state()
        best_state[0][:, improved] = cur[0][:, improved]
        if cur[1] is not None:
            best_state[1][:, improved] = cur[1][:, improved]
        if cur[2] is not None:
            best_state[2][:, :, improved] = cur[2][:, :, improved]
        best_ic[improved] = ic[improved]

    # best_state starts as the all-zero null fit, so installing it is
    # correct whether or not any path point improved on the null model.
    solver.set_state(best_state)

    linpred = solver.linpred().reshape(n, V) - (
        np.log(row_totals).reshape(n)[:, None]
    )
    # row-wise softmax over the vocabulary, shifted for stability
    linpred -= linpred.max(axis=1, keepdims=True)
    explinpred = np.exp(linpred)
    beta_full = explinpred / explinpred.sum(axis=1, keepdims=True)

    beta = [beta_full[a * K:(a + 1) * K] for a in range(A)]
    kappa = {"m": m, "params": solver.coef_rows()}
    return beta, kappa
|
pystm/_mstep.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""M-step updates (port of STMmu.R, STMsigma.R, STMoptbeta.R)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.linalg import cho_solve, cholesky
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PrevalenceRegressionError(RuntimeError):
    """Raised when the variational prevalence regression does not converge."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def vb_variational_reg(Y, X, Xcorr=None, b0=1.0, d0=1.0, max_iter=1000):
    """Variational linear regression with a half-Cauchy hyperprior.

    Port of vb.variational.reg(); the first column of ``X`` is assumed to
    be the (unpenalized) intercept.

    Parameters
    ----------
    Y : (N,) array
        Response (``opt_mu`` passes one column of lambda).
    X : (N, D) array
        Design matrix.
    Xcorr : (D, D) array, optional
        Precomputed ``X.T @ X``; computed here when omitted.
    b0, d0 : float
        Constants in the auxiliary updates for the error precision (``b0``)
        and the coefficient precision (``d0``).
    max_iter : int
        Maximum number of coordinate-ascent iterations.

    Returns
    -------
    w : (D,) array
        Posterior mean of the regression coefficients.

    Raises
    ------
    PrevalenceRegressionError
        If the L1 change in ``w`` is still >= 1e-4 after ``max_iter``
        iterations.
    """
    if Xcorr is None:
        Xcorr = X.T @ X
    XYcorr = X.T @ Y

    N, D = X.shape
    # posterior shape for the error precision; constant across iterations
    an = (1 + N) / 2
    w = np.zeros(D)
    error_prec = 1.0
    # -1 for the intercept and +1 in the update cancel (as in the R code)
    cn = D
    dn = 1.0
    # expected precision of the non-intercept coefficients
    Ea = cn / dn
    ba = 1.0

    for _ in range(max_iter):
        w_old = w

        # coefficient prior; intercept (column 0) is left unpenalized
        prior_diag = np.full(D, Ea)
        prior_diag[0] = 0.0
        invV = error_prec * Xcorr + np.diag(prior_diag)
        L = cholesky(invV, lower=True)
        # posterior covariance of w
        V = cho_solve((L, True), np.eye(D))
        w = error_prec * (V @ XYcorr)

        # update the noise model
        sse = np.sum((X @ w - Y) ** 2)
        bn = 0.5 * (sse + np.trace(Xcorr @ V)) + ba
        error_prec = an / bn
        ba = 1.0 / (error_prec + b0)

        # update the coefficient precision, excluding the intercept
        da = 2.0 / (Ea + d0)
        dn = 2.0 * da + (w[1:] @ w[1:] + np.diag(V)[1:].sum())
        Ea = cn / dn

        if np.abs(w - w_old).sum() < 1e-4:
            return w

    raise PrevalenceRegressionError(
        "Prevalence regression failed to converge within the iteration "
        "limit. You can raise it with gamma_max_iter."
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def opt_mu(lambda_, covar=None, max_iter=1000):
    """Update the prevalence model (opt.mu, modes CTM and Pooled).

    Returns ``(mu, gamma)``. Without covariates (CTM mode) ``mu`` is the
    shared (K-1,) mean and ``gamma`` is None. With covariates ``mu`` is
    (N, K-1), the per-document prior means, and ``gamma`` is (P, K-1),
    with one variational regression fitted per column of ``lambda_``.
    """
    if covar is not None:
        gram_matrix = covar.T @ covar
        per_topic = [
            vb_variational_reg(lambda_[:, col], covar, Xcorr=gram_matrix,
                               max_iter=max_iter)
            for col in range(lambda_.shape[1])
        ]
        gamma = np.column_stack(per_topic)
        return covar @ gamma, gamma
    return lambda_.mean(axis=0), None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def opt_sigma(nu, lambda_, mu, sigprior):
    """Update the global covariance matrix (opt.sigma).

    ``mu`` is (K-1,) for the shared mean or (N, K-1) for covariate models.
    The empirical covariance (plus the summed posterior covariances ``nu``)
    is shrunk toward its diagonal with weight ``sigprior``.
    """
    centered = lambda_ - (mu[None, :] if mu.ndim == 1 else mu)
    empirical = (centered.T @ centered + nu) / lambda_.shape[0]
    diagonal_only = np.diag(np.diag(empirical))
    return sigprior * diagonal_only + (1 - sigprior) * empirical
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def opt_beta(beta_ss):
    """Update the topic-word distributions (opt.beta, LDA-beta mode).

    Row-normalizes the first (only) content level's expected counts. Rows
    that sum to zero are left as zeros instead of producing NaN.
    """
    counts = beta_ss[0]
    totals = counts.sum(axis=1, keepdims=True)
    totals = np.where(totals == 0, 1.0, totals)
    return [counts / totals]
|
pystm/_spectral.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Spectral initialization via anchor words (port of spectral.R).
|
|
2
|
+
|
|
3
|
+
Implements the method of Arora et al. (2013): build the word co-occurrence
|
|
4
|
+
gram matrix, greedily select anchor words, then recover the topic-word
|
|
5
|
+
matrix with RecoverL2. The R package solves the simplex-constrained
|
|
6
|
+
regression exactly with quadprog by default; here we use a penalized NNLS
|
|
7
|
+
formulation which matches it closely. The exponentiated gradient
|
|
8
|
+
algorithm (R's ``recoverEG=TRUE`` option) is available as an alternative.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from scipy.optimize import nnls
|
|
15
|
+
from scipy.sparse import csr_matrix
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def gram(mat: csr_matrix) -> np.ndarray:
    """Word co-occurrence gram matrix from a sparse doc-term matrix.

    Documents with fewer than two tokens are skipped, because their pair
    count ``n_d * (n_d - 1)`` is zero. Each remaining document is scaled by
    ``1 / sqrt(n_d * (n_d - 1))`` before the cross product is taken. The
    diagonal is then corrected by the count-weighted self-pairs. Returns a
    dense (V, V) array.
    """
    doc_lengths = np.asarray(mat.sum(axis=1)).ravel()
    usable = doc_lengths >= 2
    rows = mat[usable]
    lengths = doc_lengths[usable]
    pair_counts = lengths * (lengths - 1)

    diag_correction = np.asarray(
        rows.multiply(1.0 / pair_counts[:, None]).sum(axis=0)
    ).ravel()
    scaled = rows.multiply(1.0 / np.sqrt(pair_counts)[:, None]).tocsr()
    cooc = (scaled.T @ scaled).toarray()
    np.fill_diagonal(cooc, cooc.diagonal() - diag_correction)
    return cooc
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def fast_anchor(Qbar: np.ndarray, K: int) -> np.ndarray:
    """Greedy anchor word selection by stabilized Gram-Schmidt.

    At each of ``K`` steps the row with the largest remaining squared norm
    becomes an anchor. That row is normalized, and its direction is
    projected out of every non-anchor row.

    Parameters
    ----------
    Qbar : 2-d array
        Row-normalized co-occurrence matrix ((V, V) in ``spectral_init``).
        The caller's array is not modified; a copy is worked on.
    K : int
        Number of anchors to select.

    Returns
    -------
    (K,) int64 array of anchor row indices, in selection order.

    NOTE(review): if every remaining row has zero norm (e.g. K exceeds the
    number of usable rows), ``max_val`` is 0 and the normalization divides
    by zero.
    """
    Qbar = Qbar.copy()
    basis = np.zeros(K, dtype=np.int64)
    row_squared_sums = (Qbar**2).sum(axis=1)

    for i in range(K):
        # pick the row farthest from the span of the anchors chosen so far
        basis[i] = int(np.argmax(row_squared_sums))
        max_val = row_squared_sums[basis[i]]
        Qbar[basis[i]] *= 1.0 / np.sqrt(max_val)

        # project every row onto the new (unit) anchor direction ...
        inner_products = Qbar @ Qbar[basis[i]]
        # NOTE(review): np.outer materializes a full (rows x cols) temporary
        # each iteration; large for V ~ 1e4.
        project = np.outer(inner_products, Qbar[basis[i]])
        # ... but leave anchor rows themselves untouched
        project[basis[: i + 1]] = 0.0
        Qbar -= project
        row_squared_sums = (Qbar**2).sum(axis=1)
        # never re-select an existing anchor
        row_squared_sums[basis[: i + 1]] = 0.0
    return basis
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def expgrad(X, y, XtX=None, alpha=None, tol=1e-7, max_iter=500):
    """Exponentiated gradient for simplex-constrained least squares.

    Minimizes ``||y - alpha @ X||^2`` over the probability simplex with a
    fixed step size of 50. Starts from the uniform distribution unless
    ``alpha`` is given. Stops when the residual norm changes by less than
    ``tol`` or after ``max_iter`` iterations.
    """
    n_rows = X.shape[0]
    weights = np.full(n_rows, 1.0 / n_rows) if alpha is None else alpha
    gram_xx = X @ X.T if XtX is None else XtX
    target = y @ X.T

    step = 50.0
    prev_sse = np.inf
    iteration = 0
    while iteration < max_iter:
        residual_grad = target - weights @ gram_xx
        sse = residual_grad @ residual_grad
        scaled = 2.0 * step * residual_grad
        # multiplicative update, shifted by the max for numerical safety
        weights = weights * np.exp(scaled - scaled.max())
        weights = weights / weights.sum()
        if abs(np.sqrt(prev_sse) - np.sqrt(sse)) < tol:
            break
        prev_sse = sse
        iteration += 1
    return weights
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def recover_l2(Qbar, anchors, wprob, solver="nnls"):
    """Recover the K x V topic-word matrix from the anchor rows.

    Each word's row of ``Qbar`` is expressed as a convex combination of
    the anchor rows. ``solver="nnls"`` enforces the sum-to-one constraint
    through a heavily weighted penalty row in a non-negative least-squares
    problem (the analogue of the exact quadprog solve used by the R
    package); any other value uses exponentiated gradient descent
    (R's ``recoverEG=TRUE``). Anchor words get a one-hot weight vector.
    The resulting p(topic | word) weights are inverted with Bayes' rule
    using ``wprob``, and each topic's row is normalized.
    """
    anchor_rows = Qbar[anchors]
    gram_anchor = anchor_rows @ anchor_rows.T
    n_topics = len(anchors)
    position_of = dict(zip(anchors, range(n_topics)))
    use_nnls = solver == "nnls"

    if use_nnls:
        penalty = 1000.0
        design = np.vstack([anchor_rows.T, np.full(n_topics, penalty)])

    floor = np.finfo(np.float64).eps
    n_words = Qbar.shape[0]
    weights = np.empty((n_words, n_topics))
    for word in range(n_words):
        pos = position_of.get(word)
        if pos is not None:
            weights[word] = 0.0
            weights[word, pos] = 1.0
        elif use_nnls:
            coef, _ = nnls(design, np.append(Qbar[word], penalty))
            coef = np.maximum(coef, floor)
            weights[word] = coef / coef.sum()
        else:
            coef = expgrad(anchor_rows, Qbar[word], gram_anchor)
            weights[word] = np.maximum(coef, floor)

    joint = weights * wprob[:, None]
    return joint.T / joint.sum(axis=0)[:, None]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def spectral_init(X: csr_matrix, K: int, max_vocab: int | None = 10000,
|
|
115
|
+
solver: str = "nnls") -> np.ndarray:
|
|
116
|
+
"""Spectral initialization of beta (the Spectral branch of stm.init)."""
|
|
117
|
+
V = X.shape[1]
|
|
118
|
+
if K >= V:
|
|
119
|
+
raise ValueError(
|
|
120
|
+
"Spectral initialization cannot be used when K >= vocabulary size."
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
wprob = np.asarray(X.sum(axis=0)).ravel().astype(np.float64)
|
|
124
|
+
wprob /= wprob.sum()
|
|
125
|
+
|
|
126
|
+
keep = None
|
|
127
|
+
if max_vocab is not None and V > max_vocab:
|
|
128
|
+
keep = np.argsort(wprob)[::-1][:max_vocab]
|
|
129
|
+
X = X[:, keep]
|
|
130
|
+
wprob = wprob[keep]
|
|
131
|
+
|
|
132
|
+
Q = gram(X)
|
|
133
|
+
Qsums = Q.sum(axis=1)
|
|
134
|
+
if np.any(Qsums == 0):
|
|
135
|
+
nonzero = Qsums != 0
|
|
136
|
+
keep = np.where(nonzero)[0] if keep is None else keep[nonzero]
|
|
137
|
+
Q = Q[np.ix_(nonzero, nonzero)]
|
|
138
|
+
Qsums = Qsums[nonzero]
|
|
139
|
+
wprob = wprob[nonzero]
|
|
140
|
+
Qbar = Q / Qsums[:, None]
|
|
141
|
+
|
|
142
|
+
anchors = fast_anchor(Qbar, K)
|
|
143
|
+
beta = recover_l2(Qbar, anchors, wprob, solver=solver)
|
|
144
|
+
|
|
145
|
+
if keep is not None:
|
|
146
|
+
# reintroduce dropped words with a small amount of mass
|
|
147
|
+
beta_new = np.zeros((K, V))
|
|
148
|
+
beta_new[:, keep] = beta
|
|
149
|
+
beta_new += 0.001 / V
|
|
150
|
+
beta = beta_new / beta_new.sum(axis=1, keepdims=True)
|
|
151
|
+
return beta
|
pystm/_utils.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Shared numerical utilities (port of STMfunctions.R)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.sparse import csr_matrix, issparse
|
|
7
|
+
from scipy.special import logsumexp
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def safelog(x: np.ndarray, min_value: float = -1000.0) -> np.ndarray:
    """Natural log of ``x`` with results below ``min_value`` raised to it.

    ``log(0) == -inf`` is therefore returned as ``min_value``, and the
    divide-by-zero warning is suppressed. NaN inputs stay NaN.
    """
    with np.errstate(divide="ignore"):
        logged = np.log(x)
    return np.clip(logged, min_value, None)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def row_softmax(mat: np.ndarray) -> np.ndarray:
    """Row-wise softmax of a 2-d array, computed in log space for stability."""
    log_normalizer = logsumexp(mat, axis=1, keepdims=True)
    log_probs = mat - log_normalizer
    return np.exp(log_probs)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def to_doc_list(X) -> list[tuple[np.ndarray, np.ndarray]]:
    """Convert a document-term count matrix to the stm document format.

    Returns one ``(word_indices, word_counts)`` pair per document, holding
    the column indices of the document's distinct terms and their counts
    (the two rows of the R package's document matrices, zero-indexed).
    Explicitly stored zeros are dropped.

    A non-canonical sparse input (duplicate column entries within a row)
    is canonicalized on a copy first, so each word index appears at most
    once per document. Otherwise fancy-indexed ``+=`` accumulation (as used
    by the E-step) would silently keep only one duplicate's contribution.
    The caller's matrix is never modified.

    Raises
    ------
    ValueError
        If any stored entry is negative or not integer-valued.
    """
    X = csr_matrix(X) if not issparse(X) else X.tocsr()
    # validate the raw entries before any summing of duplicates
    if (X.data < 0).any():
        raise ValueError("X must contain non-negative counts.")
    if np.any(X.data != np.round(X.data)):
        raise ValueError("X must contain integer counts.")
    if not X.has_canonical_format:
        # tocsr() may return the caller's object; never mutate it in place
        X = X.copy()
        X.sum_duplicates()

    docs = []
    for i in range(X.shape[0]):
        start, end = X.indptr[i], X.indptr[i + 1]
        words = X.indices[start:end].astype(np.int64)
        counts = X.data[start:end].astype(np.float64)
        keep = counts > 0
        docs.append((words[keep], counts[keep]))
    return docs
|