structural-topic-model 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pystm/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ """pystm: Python implementation of the Structural Topic Model.
2
+
3
+ A port of the R ``stm`` package (Roberts, Stewart & Tingley) with an API
4
+ modeled on scikit-learn's ``LatentDirichletAllocation``.
5
+ """
6
+
7
+ from .diagnostics import (
8
+ TopicCorrelations,
9
+ check_residuals,
10
+ exclusivity,
11
+ semantic_coherence,
12
+ topic_corr,
13
+ )
14
+ from .effects import EstimatedEffects, estimate_effect
15
+ from .model_selection import eval_heldout, make_heldout, search_k
16
+ from .stm import StructuralTopicModel
17
+
18
+ __all__ = [
19
+ "StructuralTopicModel",
20
+ "estimate_effect",
21
+ "EstimatedEffects",
22
+ "search_k",
23
+ "make_heldout",
24
+ "eval_heldout",
25
+ "topic_corr",
26
+ "TopicCorrelations",
27
+ "semantic_coherence",
28
+ "exclusivity",
29
+ "check_residuals",
30
+ ]
31
+ __version__ = "0.2.0"
pystm/_estep.py ADDED
@@ -0,0 +1,159 @@
1
+ """Variational E-step (port of STMestep.R, STMlncpp.R and STMCfuns.cpp).
2
+
3
+ For each document the variational posterior over eta (the K-1 dimensional
4
+ logistic-normal document-topic parameter) is approximated by a Laplace
5
+ approximation: the mode ``lambda`` is found with BFGS and the covariance
6
+ ``nu`` is the inverse Hessian at the mode.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import numpy as np
12
+ from scipy.linalg import cho_solve, cholesky, solve
13
+ from scipy.optimize import minimize
14
+
15
# Numerical safeguards that the R/C++ original does not have.  A BFGS line
# search may probe points at which exp(eta) overflows or underflows, or at
# which a word of the document has (numerically) zero probability under every
# topic; either would feed inf/nan into the search.  Clipping eta and putting
# a floor under logs and denominators keeps the objective finite while leaving
# values in the normal range untouched.
_ETA_CLIP = 200.0
_TINY = 1e-300
22
+
23
+
24
def _expeta(eta):
    """Return ``exp(eta)`` with a trailing 1.0 appended.

    ``eta`` is clipped to ``[-_ETA_CLIP, _ETA_CLIP]`` before exponentiation
    so the result stays finite.
    """
    bounded = np.clip(eta, -_ETA_CLIP, _ETA_CLIP)
    return np.append(np.exp(bounded), 1.0)
26
+
27
+
28
def neg_lhood(eta, beta_d, doc_ct, mu_d, siginv):
    """Negative collapsed objective for a single document (lhoodcpp).

    The value is the Gaussian prior penalty ``0.5 (eta-mu)' siginv (eta-mu)``
    minus the count-weighted log word probabilities (normalised by the sum
    of ``exp(eta)``).
    """
    weights = _expeta(eta)
    # floor the mixture probabilities so the log stays finite
    token_mix = np.maximum(weights @ beta_d, _TINY)
    n_tokens = doc_ct.sum()
    data_term = np.log(token_mix) @ doc_ct - n_tokens * np.log(weights.sum())

    centered = eta - mu_d
    prior_term = 0.5 * centered @ siginv @ centered
    return prior_term - data_term
37
+
38
+
39
def neg_grad(eta, beta_d, doc_ct, mu_d, siginv):
    """Gradient of :func:`neg_lhood` with respect to ``eta`` (gradcpp)."""
    weights = _expeta(eta)
    scaled_beta = beta_d * weights[:, None]
    # floor the per-word totals so the division below stays finite
    col_totals = np.maximum(scaled_beta.sum(axis=0), _TINY)

    data_grad = (
        scaled_beta @ (doc_ct / col_totals)
        - (doc_ct.sum() / weights.sum()) * weights
    )
    prior_grad = siginv @ (eta - mu_d)
    # the last entry belongs to the fixed reference component, not to eta
    return prior_grad - data_grad[:-1]
47
+
48
+
49
def hessian_phi_bound(eta, beta_d, doc_ct, mu_d, siginv, sigmaentropy):
    """Laplace covariance, token assignments and bound for one document (hpbcpp).

    Returns ``(phis, nu, bound)``: ``phis`` is the K x V_d matrix of expected
    token counts per topic, ``nu`` the (K-1) x (K-1) posterior covariance of
    eta, and ``bound`` the document's contribution to the global ELBO.
    """
    weights = _expeta(eta)
    theta = weights / weights.sum()
    n_tokens = doc_ct.sum()
    root_ct = np.sqrt(doc_ct)

    scaled = beta_d * weights[:, None]
    col_totals = np.maximum(scaled.sum(axis=0), _TINY)
    half = scaled * (root_ct / col_totals)[None, :]

    hess = half @ half.T - n_tokens * np.outer(theta, theta)
    # a second sqrt(count) factor turns the scaled matrix into phi
    phis = half * root_ct[None, :]
    diag_shift = phis.sum(axis=1) - n_tokens * theta
    np.fill_diagonal(hess, np.diag(hess) - diag_shift)
    hess = hess[:-1, :-1] + siginv

    try:
        chol_lower = cholesky(hess, lower=True)
    except np.linalg.LinAlgError:
        # not positive definite: enforce diagonal dominance as in hpbcpp
        diag = np.diag(hess).copy()
        off_diag_mass = np.abs(hess).sum(axis=1) - np.abs(diag)
        np.fill_diagonal(hess, np.maximum(diag, off_diag_mass))
        chol_lower = cholesky(hess, lower=True)

    det_term = -np.log(np.diag(chol_lower)).sum()
    nu = cho_solve((chol_lower, True), np.eye(hess.shape[0]))

    centered = eta - mu_d
    bound = (
        np.log(np.maximum(theta @ beta_d, _TINY)) @ doc_ct
        + det_term
        - 0.5 * centered @ siginv @ centered
        - sigmaentropy
    )
    return phis, nu, bound
91
+
92
+
93
def optimize_document(eta, beta_d, doc_ct, mu_d, siginv, sigmaentropy,
                      max_optim_iter=500):
    """Infer one document's variational parameters (logisticnormalcpp).

    Minimises :func:`neg_lhood` with BFGS starting from ``eta`` and then
    evaluates :func:`hessian_phi_bound` at the result.  Returns
    ``((phis, nu, bound), eta_hat)``.
    """
    fit = minimize(
        neg_lhood,
        eta,
        args=(beta_d, doc_ct, mu_d, siginv),
        jac=neg_grad,
        method="BFGS",
        options={"maxiter": max_optim_iter},
    )
    # keep the starting point if the optimizer returned non-finite values
    if np.isfinite(fit.x).all():
        eta_hat = fit.x
    else:
        eta_hat = eta
    summary = hessian_phi_bound(eta_hat, beta_d, doc_ct, mu_d, siginv,
                                sigmaentropy)
    return summary, eta_hat
107
+
108
+
109
def decompose_sigma(sigma):
    """Precompute the quantities of ``sigma`` shared by every document.

    Returns ``(siginv, sigmaentropy)``: the inverse of ``sigma`` and half its
    log-determinant.  A Cholesky factorisation is tried first; if it fails,
    ``slogdet`` and a symmetric linear solve are used instead.
    """
    identity = np.eye(sigma.shape[0])
    try:
        upper = cholesky(sigma, lower=False)
    except np.linalg.LinAlgError:
        sigmaentropy = 0.5 * np.linalg.slogdet(sigma)[1]
        siginv = solve(sigma, identity, assume_a="sym")
    else:
        # sum of log-diagonal of the Cholesky factor == 0.5 * log det(sigma)
        sigmaentropy = np.log(np.diag(upper)).sum()
        siginv = cho_solve((upper, False), identity)
    return siginv, sigmaentropy
119
+
120
+
121
def estep(docs, beta_index, update_mu, beta, lambda_old, mu, sigma,
          max_optim_iter=500):
    """Run the E-step over every document and accumulate sufficient stats.

    Parameters mirror estep() in STMestep.R.  ``docs`` is a list of
    ``(word_indices, counts)`` pairs, ``beta`` a list holding one K x V
    matrix per content level, and ``mu`` is (K-1,) when shared or (N, K-1)
    when document specific (selected by ``update_mu``).

    Returns ``(sigma_ss, beta_ss, bound, lambda_)``.
    """
    n_topics, n_vocab = beta[0].shape
    n_docs = len(docs)

    sigma_ss = np.zeros((n_topics - 1, n_topics - 1))
    beta_ss = [np.zeros((n_topics, n_vocab)) for _ in beta]
    bound = np.empty(n_docs)
    lambda_ = np.empty((n_docs, n_topics - 1))

    siginv, sigmaentropy = decompose_sigma(sigma)

    for doc_id, (terms, term_counts) in enumerate(docs):
        level = beta_index[doc_id]
        prior_mean = mu[doc_id] if update_mu else mu
        # only the columns of beta for words present in this document
        topic_word = np.ascontiguousarray(beta[level][:, terms])

        (phis, nu, doc_bound), mode = optimize_document(
            lambda_old[doc_id], topic_word, term_counts, prior_mean,
            siginv, sigmaentropy, max_optim_iter=max_optim_iter,
        )

        sigma_ss += nu
        beta_ss[level][:, terms] += phis
        bound[doc_id] = doc_bound
        lambda_[doc_id] = mode

    return sigma_ss, beta_ss, bound, lambda_
pystm/_mnreg.py ADDED
@@ -0,0 +1,221 @@
1
+ """Content covariate M-step (port of STMmnreg.R).
2
+
3
+ The SAGE-style topic-word update with content covariates is estimated via
4
+ Distributed Multinomial Regression (Taddy 2013): the multinomial is
5
+ factorized into independent Poisson regressions, one per vocabulary word,
6
+ each with an L1 penalty. The R package solves these with glmnet; here we
7
+ implement an equivalent lasso-penalized Poisson solver (IRLS + coordinate
8
+ descent over a regularization path with information-criterion selection).
9
+
10
+ The solver exploits the structure of the problem heavily. The design
11
+ matrix consists of three groups of indicator columns (topic main effects,
12
+ aspect main effects, topic-by-aspect interactions). Columns within a
13
+ group touch disjoint rows, so a coordinate-descent pass over a whole
14
+ group can be performed as one vectorized update; and since every word
15
+ shares the same design, all V regressions are advanced simultaneously.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import numpy as np
21
+ from scipy.special import xlogy
22
+
23
+
24
def _poisson_deviance(Y, mu):
    """Poisson deviance per word, summed over the first two axes.

    ``xlogy`` makes the ``y * log(y)`` terms vanish where ``y == 0``.
    """
    unit_deviance = xlogy(Y, Y) - xlogy(Y, mu) - (Y - mu)
    return 2.0 * unit_deviance.sum(axis=(0, 1))
27
+
28
+
29
def _soft_threshold(rho, lam):
    """Shrink ``rho`` towards zero by ``lam``, zeroing anything within ``lam``."""
    excess = np.abs(rho) - lam
    return np.sign(rho) * np.maximum(excess, 0.0)
31
+
32
+
33
class _StructuredPoissonLasso:
    """Distributed Poisson lasso with the STM content design.

    Counts and offsets are stored as (A, K, V) arrays.  Coefficients are
    kept per group: topic main effects (K, V), optional aspect main effects
    (A, V) and optional topic-by-aspect interactions (A, K, V).  ``lam``
    arguments are per-word penalty levels of shape (V,).
    """

    def __init__(self, Y3, offsets3, use_aspect, use_inter):
        n_aspects, n_topics, n_words = Y3.shape
        self.Y3 = Y3
        self.offsets3 = offsets3
        self.A, self.K, self.V = n_aspects, n_topics, n_words
        self.n = n_aspects * n_topics
        self.use_aspect = use_aspect
        self.use_inter = use_inter
        # every coefficient starts at zero; unused groups are None
        self.b_topic = np.zeros((n_topics, n_words))
        self.b_aspect = None
        if use_aspect:
            self.b_aspect = np.zeros((n_aspects, n_words))
        self.b_inter = None
        if use_inter:
            self.b_inter = np.zeros((n_aspects, n_topics, n_words))

    def linpred(self):
        """Return the (A, K, V) linear predictor, clipped to [-50, 30]."""
        total = self.offsets3 + self.b_topic[None, :, :]
        if self.use_aspect:
            total = total + self.b_aspect[:, None, :]
        if self.use_inter:
            total = total + self.b_inter
        return np.clip(total, -50.0, 30.0)

    def _block_update(self, R, W, Wsum, b, lam, axis):
        """One exact CD pass over a group of disjoint indicator columns.

        ``axis`` is the (A, K, V) axis summed over to pool the rows of a
        column (None for the interaction block, where each row is its own
        column).  ``R`` is updated in place; returns ``(new_b, max_delta)``.
        """
        if axis is None:
            pooled = R + W * b
        else:
            pooled = R.sum(axis=axis) + Wsum * b
        shrunk = _soft_threshold(pooled / self.n, lam)
        with np.errstate(invalid="ignore", divide="ignore"):
            curvature = Wsum / self.n
            # columns with (numerically) no weight get a zero coefficient
            updated = np.where(curvature > 1e-12, shrunk / curvature, 0.0)
        step = updated - b
        largest = np.abs(step).max() if step.size else 0.0
        if largest > 0:
            if axis is None:
                R -= W * step
            else:
                R -= W * np.expand_dims(step, axis)
        return updated, largest

    def fit_one_lambda(self, lam, tol, max_irls, max_sweeps):
        """Solve at one penalty level, warm-starting from current state."""
        for _irls in range(max_irls):
            # Poisson working weights equal the fitted means
            weights = np.exp(self.linpred())
            resid = self.Y3 - weights
            topic_weight = weights.sum(axis=0)
            aspect_weight = weights.sum(axis=1) if self.use_aspect else None

            biggest_move = 0.0
            for _sweep in range(max_sweeps):
                sweep_move = 0.0
                self.b_topic, moved = self._block_update(
                    resid, weights, topic_weight, self.b_topic, lam, axis=0)
                sweep_move = max(sweep_move, moved)
                if self.use_aspect:
                    self.b_aspect, moved = self._block_update(
                        resid, weights, aspect_weight, self.b_aspect, lam,
                        axis=1)
                    sweep_move = max(sweep_move, moved)
                if self.use_inter:
                    self.b_inter, moved = self._block_update(
                        resid, weights, weights, self.b_inter, lam,
                        axis=None)
                    sweep_move = max(sweep_move, moved)
                biggest_move = max(biggest_move, sweep_move)
                if sweep_move < tol:
                    break
            if biggest_move < tol:
                break

    def df(self):
        """Count the non-zero coefficients of each word, shape (V,)."""
        active = (self.b_topic != 0).sum(axis=0)
        if self.use_aspect:
            active = active + (self.b_aspect != 0).sum(axis=0)
        if self.use_inter:
            active = active + (self.b_inter != 0).sum(axis=(0, 1))
        return active

    def coef_rows(self):
        """Stack coefficients as the R package's kappa params (p, V)."""
        groups = [self.b_topic]
        if self.use_aspect:
            groups.append(self.b_aspect)
        if self.use_inter:
            groups.append(self.b_inter.reshape(self.n, self.V))
        return np.vstack(groups)

    def state(self):
        """Return copies of the three coefficient groups (None if unused)."""
        aspect = self.b_aspect.copy() if self.b_aspect is not None else None
        inter = self.b_inter.copy() if self.b_inter is not None else None
        return (self.b_topic.copy(), aspect, inter)

    def set_state(self, state):
        """Load coefficient groups from a ``state()`` tuple, copying each."""
        topic, aspect, inter = state[0], state[1], state[2]
        self.b_topic = topic.copy()
        self.b_aspect = aspect.copy() if aspect is not None else None
        self.b_inter = inter.copy() if inter is not None else None
145
+
146
+
147
def mnreg(beta_ss, wcounts, *, interactions=True, nlambda=250,
          lambda_min_ratio=0.001, ic_k=2.0, tol=1e-4,
          max_irls=4, max_sweeps=8):
    """Update beta and kappa from the E-step expected counts (mnreg in R).

    Only the (default) fixed-intercept variant is implemented: the
    intercept of each word's Poisson regression is fixed to the background
    log-probability ``m``.  The penalty path is walked from the per-word
    ``lambda_max`` down to ``lambda_min_ratio`` times that value, and for
    each word the coefficients with the lowest ``deviance + ic_k * df`` are
    kept.

    Returns ``(beta, kappa)`` where ``beta`` is a list of A matrices of
    shape (K, V) and ``kappa`` is a dict with the baseline ``m`` and the
    selected deviation coefficients ``params`` of shape (p, V).
    """
    n_aspects = len(beta_ss)
    n_topics, n_words = beta_ss[0].shape
    n_rows = n_aspects * n_topics
    use_aspect = n_aspects > 1
    use_inter = interactions and n_aspects > 1

    counts = np.stack(beta_ss)  # (A, K, V)
    m = np.log(wcounts) - np.log(wcounts.sum())
    row_totals = np.maximum(counts.sum(axis=2), 1e-10)  # (A, K)
    log_totals = np.log(row_totals)
    offsets = m[None, None, :] + log_totals[:, :, None]

    solver = _StructuredPoissonLasso(counts, offsets, use_aspect, use_inter)

    # null model: all coefficients zero, only the offsets contribute
    null_mu = np.exp(np.clip(offsets, -50.0, 30.0))
    null_resid = counts - null_mu
    # per-word lambda_max: largest |score| over all columns at the null model
    score_blocks = [np.abs(null_resid.sum(axis=0))]  # topic columns (K, V)
    if use_aspect:
        score_blocks.append(np.abs(null_resid.sum(axis=1)))  # aspect (A, V)
    if use_inter:
        score_blocks.append(np.abs(null_resid).reshape(n_rows, n_words))
    lambda_max = np.vstack(score_blocks).max(axis=0) / n_rows
    lambda_max = np.maximum(lambda_max, 1e-10)

    # path point 0 is the null model itself
    best_ic = _poisson_deviance(counts, null_mu)
    best_topic, best_aspect, best_inter = solver.state()

    fractions = np.exp(np.linspace(0.0, np.log(lambda_min_ratio), nlambda))
    for fraction in fractions[1:]:
        solver.fit_one_lambda(lambda_max * fraction, tol=tol,
                              max_irls=max_irls, max_sweeps=max_sweeps)
        fitted = np.exp(solver.linpred())
        ic = _poisson_deviance(counts, fitted) + ic_k * solver.df()
        better = ic < best_ic
        if not better.any():
            continue
        # keep, word by word, the coefficients of the best path point so far
        best_topic[:, better] = solver.b_topic[:, better]
        if solver.b_aspect is not None:
            best_aspect[:, better] = solver.b_aspect[:, better]
        if solver.b_inter is not None:
            best_inter[:, :, better] = solver.b_inter[:, :, better]
        best_ic[better] = ic[better]

    final = _StructuredPoissonLasso(counts, offsets, use_aspect, use_inter)
    final.set_state((best_topic, best_aspect, best_inter))

    # drop the row-total offset, then softmax each row over the vocabulary
    log_beta = final.linpred().reshape(n_rows, n_words) - (
        log_totals.reshape(n_rows)[:, None]
    )
    log_beta -= log_beta.max(axis=1, keepdims=True)
    unnormalised = np.exp(log_beta)
    beta_full = unnormalised / unnormalised.sum(axis=1, keepdims=True)

    beta = [
        beta_full[a * n_topics:(a + 1) * n_topics] for a in range(n_aspects)
    ]
    kappa = {"m": m, "params": final.coef_rows()}
    return beta, kappa
pystm/_mstep.py ADDED
@@ -0,0 +1,96 @@
1
+ """M-step updates (port of STMmu.R, STMsigma.R, STMoptbeta.R)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ from scipy.linalg import cho_solve, cholesky
7
+
8
+
9
class PrevalenceRegressionError(RuntimeError):
    """Raised when the prevalence regression fails to converge."""
11
+
12
+
13
def vb_variational_reg(Y, X, Xcorr=None, b0=1.0, d0=1.0, max_iter=1000):
    """Variational linear regression with a half-Cauchy hyperprior.

    Port of vb.variational.reg(); the first column of ``X`` is assumed to
    be the (unpenalized) intercept.  ``Xcorr`` may carry a precomputed
    ``X.T @ X``.  Returns the coefficient vector once the summed absolute
    change between iterations drops below 1e-4.

    Raises :class:`PrevalenceRegressionError` if that does not happen within
    ``max_iter`` iterations.
    """
    n_obs, n_coef = X.shape
    if Xcorr is None:
        Xcorr = X.T @ X
    XYcorr = X.T @ Y
    identity = np.eye(n_coef)

    an = (1 + n_obs) / 2
    cn = n_coef
    dn = 1.0
    coef = np.zeros(n_coef)
    error_prec = 1.0
    Ea = cn / dn
    ba = 1.0

    for _ in range(max_iter):
        previous = coef

        # prior precision on every coefficient except the intercept
        shrinkage = np.full(n_coef, Ea)
        shrinkage[0] = 0.0
        precision = error_prec * Xcorr + np.diag(shrinkage)
        chol_lower = cholesky(precision, lower=True)
        V = cho_solve((chol_lower, True), identity)
        coef = error_prec * (V @ XYcorr)

        sse = np.sum((X @ coef - Y) ** 2)
        bn = 0.5 * (sse + np.trace(Xcorr @ V)) + ba
        error_prec = an / bn
        ba = 1.0 / (error_prec + b0)

        da = 2.0 / (Ea + d0)
        dn = 2.0 * da + (coef[1:] @ coef[1:] + np.diag(V)[1:].sum())
        Ea = cn / dn

        if np.abs(coef - previous).sum() < 1e-4:
            return coef

    raise PrevalenceRegressionError(
        "Prevalence regression failed to converge within the iteration "
        "limit. You can raise it with gamma_max_iter."
    )
58
+
59
+
60
def opt_mu(lambda_, covar=None, max_iter=1000):
    """Update the prevalence model (opt.mu, modes CTM and Pooled).

    Returns ``(mu, gamma)``.  With ``covar=None`` (CTM mode) ``mu`` is the
    shared (K-1,) column mean of ``lambda_`` and ``gamma`` is None.  With
    covariates, each column of ``lambda_`` is regressed on ``covar``;
    ``gamma`` is (P, K-1) and ``mu = covar @ gamma`` is (N, K-1).
    """
    if covar is None:
        return lambda_.mean(axis=0), None

    # the cross-product is shared by every per-topic regression
    shared_gram = covar.T @ covar
    per_topic = [
        vb_variational_reg(lambda_[:, topic], covar, Xcorr=shared_gram,
                           max_iter=max_iter)
        for topic in range(lambda_.shape[1])
    ]
    gamma = np.column_stack(per_topic)
    return covar @ gamma, gamma
77
+
78
+
79
def opt_sigma(nu, lambda_, mu, sigprior):
    """Update the global covariance matrix (opt.sigma).

    ``mu`` is (K-1,) for the shared mean or (N, K-1) for covariate models.
    The result blends the estimated covariance with its own diagonal using
    weight ``sigprior`` on the diagonal part.
    """
    center = mu[None, :] if mu.ndim == 1 else mu
    residual = lambda_ - center
    estimate = (residual.T @ residual + nu) / lambda_.shape[0]
    diagonal_only = np.diag(np.diag(estimate))
    return sigprior * diagonal_only + (1 - sigprior) * estimate
90
+
91
+
92
def opt_beta(beta_ss):
    """Update the topic-word distributions (opt.beta, LDA-beta mode).

    Normalises each row of ``beta_ss[0]`` to sum to one; rows that sum to
    zero are left as all zeros.  Returns a one-element list.
    """
    counts = beta_ss[0]
    totals = counts.sum(axis=1, keepdims=True)
    # divide empty rows by 1 instead of 0
    safe_totals = np.where(totals == 0, 1.0, totals)
    return [counts / safe_totals]
pystm/_spectral.py ADDED
@@ -0,0 +1,151 @@
1
+ """Spectral initialization via anchor words (port of spectral.R).
2
+
3
+ Implements the method of Arora et al. (2013): build the word co-occurrence
4
+ gram matrix, greedily select anchor words, then recover the topic-word
5
+ matrix with RecoverL2. The R package solves the simplex-constrained
6
+ regression exactly with quadprog by default; here we use a penalized NNLS
7
+ formulation which matches it closely. The exponentiated gradient
8
+ algorithm (R's ``recoverEG=TRUE`` option) is available as an alternative.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import numpy as np
14
+ from scipy.optimize import nnls
15
+ from scipy.sparse import csr_matrix
16
+
17
+
18
def gram(mat: csr_matrix) -> np.ndarray:
    """Word co-occurrence gram matrix from a sparse doc-term matrix.

    Documents with fewer than two tokens are dropped, since the estimator
    divides by ``n_d * (n_d - 1)``.
    """
    doc_lengths = np.asarray(mat.sum(axis=1)).ravel()
    usable = doc_lengths >= 2
    mat = mat[usable]
    doc_lengths = doc_lengths[usable]
    pair_counts = doc_lengths * (doc_lengths - 1)

    row_scaled = mat.multiply(1.0 / np.sqrt(pair_counts)[:, None]).tocsr()
    Q = (row_scaled.T @ row_scaled).toarray()

    diag_correction = np.asarray(
        mat.multiply(1.0 / pair_counts[:, None]).sum(axis=0)
    ).ravel()
    on_diag = np.arange(Q.shape[0])
    Q[on_diag, on_diag] -= diag_correction
    return Q
31
+
32
+
33
def fast_anchor(Qbar: np.ndarray, K: int) -> np.ndarray:
    """Greedy anchor word selection by stabilized Gram-Schmidt.

    Repeatedly picks the row with the largest squared norm, normalises it,
    and projects it out of every row not yet selected.  Works on a copy of
    ``Qbar``; returns the K selected row indices.
    """
    work = Qbar.copy()
    chosen = np.zeros(K, dtype=np.int64)
    sq_norms = (work**2).sum(axis=1)

    for step in range(K):
        pick = int(np.argmax(sq_norms))
        chosen[step] = pick
        work[pick] *= 1.0 / np.sqrt(sq_norms[pick])

        direction = work[pick]
        projection = np.outer(work @ direction, direction)
        taken = chosen[: step + 1]
        # rows already selected are left untouched and never re-picked
        projection[taken] = 0.0
        work -= projection
        sq_norms = (work**2).sum(axis=1)
        sq_norms[taken] = 0.0
    return chosen
51
+
52
+
53
def expgrad(X, y, XtX=None, alpha=None, tol=1e-7, max_iter=500):
    """Exponentiated gradient for simplex-constrained least squares.

    ``alpha`` defaults to the uniform weights over the rows of ``X`` and
    ``XtX`` to ``X @ X.T``.  Iterates multiplicative updates followed by
    renormalisation until the root-SSE changes by less than ``tol`` or
    ``max_iter`` iterations have run.
    """
    n_rows = X.shape[0]
    if alpha is None:
        alpha = np.full(n_rows, 1.0 / n_rows)
    if XtX is None:
        XtX = X @ X.T
    target = y @ X.T

    step_size = 50.0
    previous_sse = np.inf
    for _ in range(max_iter):
        residual = target - alpha @ XtX
        sse = residual @ residual
        scaled = 2.0 * step_size * residual
        # subtracting the max keeps the exponentials from overflowing
        alpha = alpha * np.exp(scaled - scaled.max())
        alpha = alpha / alpha.sum()
        if abs(np.sqrt(previous_sse) - np.sqrt(sse)) < tol:
            break
        previous_sse = sse
    return alpha
73
+
74
+
75
def recover_l2(Qbar, anchors, wprob, solver="nnls"):
    """Recover the K x V topic-word matrix from the anchor rows.

    Each word's row of ``Qbar`` is expressed as a convex combination of
    the anchor rows.  ``solver="nnls"`` enforces the sum-to-one constraint
    through a heavily weighted penalty row in a non-negative least-squares
    problem (the analogue of the exact quadprog solve used by the R
    package); ``solver="expgrad"`` uses exponentiated gradient descent
    (R's ``recoverEG=TRUE``).

    Raises ``ValueError`` for any other ``solver`` value, so a misspelled
    name fails loudly instead of silently running the expgrad branch.
    """
    if solver not in ("nnls", "expgrad"):
        raise ValueError(
            f"Unknown solver {solver!r}; expected 'nnls' or 'expgrad'."
        )

    X = Qbar[anchors]
    K = len(anchors)
    eps = np.finfo(np.float64).eps
    anchor_pos = {int(a): idx for idx, a in enumerate(anchors)}

    # per-solver setup is done once, outside the loop over words
    if solver == "nnls":
        penalty = 1000.0
        X_aug = np.vstack([X.T, np.full(K, penalty)])

        def solve_row(row):
            solution, _ = nnls(X_aug, np.append(row, penalty))
            solution = np.maximum(solution, eps)
            return solution / solution.sum()
    else:
        XtX = X @ X.T

        def solve_row(row):
            return np.maximum(expgrad(X, row, XtX), eps)

    weights = np.empty((Qbar.shape[0], K))
    for i, row in enumerate(Qbar):
        pos = anchor_pos.get(i)
        if pos is None:
            weights[i] = solve_row(row)
        else:
            # an anchor word belongs entirely to its own topic
            weights[i] = 0.0
            weights[i, pos] = 1.0

    # weight by word probability, then normalise each topic over the words
    A = weights * wprob[:, None]
    return A.T / A.sum(axis=0)[:, None]
112
+
113
+
114
def spectral_init(X: csr_matrix, K: int, max_vocab: int | None = 10000,
                  solver: str = "nnls") -> np.ndarray:
    """Spectral initialization of beta (the Spectral branch of stm.init).

    Builds the row-normalised gram matrix of ``X`` (restricted to the
    ``max_vocab`` most frequent words when the vocabulary is larger), picks
    K anchor words and recovers a K x V topic-word matrix.  Words dropped
    along the way are reintroduced with a small amount of mass.

    Raises ``ValueError`` when ``K`` is not smaller than the vocabulary size.
    """
    n_vocab = X.shape[1]
    if K >= n_vocab:
        raise ValueError(
            "Spectral initialization cannot be used when K >= vocabulary size."
        )

    word_mass = np.asarray(X.sum(axis=0)).ravel().astype(np.float64)
    word_mass /= word_mass.sum()

    # ``keep`` maps retained columns back to the full vocabulary (None = all)
    keep = None
    if max_vocab is not None and n_vocab > max_vocab:
        keep = np.argsort(word_mass)[::-1][:max_vocab]
        X = X[:, keep]
        word_mass = word_mass[keep]

    Q = gram(X)
    row_totals = Q.sum(axis=1)
    alive = row_totals != 0
    if not alive.all():
        # rows summing to zero cannot be normalised; drop those words
        keep = np.where(alive)[0] if keep is None else keep[alive]
        Q = Q[np.ix_(alive, alive)]
        row_totals = row_totals[alive]
        word_mass = word_mass[alive]
    Qbar = Q / row_totals[:, None]

    anchors = fast_anchor(Qbar, K)
    beta = recover_l2(Qbar, anchors, word_mass, solver=solver)
    if keep is None:
        return beta

    # reintroduce dropped words with a small amount of mass
    padded = np.zeros((K, n_vocab))
    padded[:, keep] = beta
    padded += 0.001 / n_vocab
    return padded / padded.sum(axis=1, keepdims=True)
pystm/_utils.py ADDED
@@ -0,0 +1,41 @@
1
+ """Shared numerical utilities (port of STMfunctions.R)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ from scipy.sparse import csr_matrix, issparse
7
+ from scipy.special import logsumexp
8
+
9
+
10
def safelog(x: np.ndarray, min_value: float = -1000.0) -> np.ndarray:
    """Return ``log(x)`` floored at ``min_value``.

    The divide-by-zero warning from ``log(0)`` is suppressed; the resulting
    ``-inf`` (like any value below ``min_value``) becomes ``min_value``.
    """
    with np.errstate(divide="ignore"):
        logged = np.log(x)
    floored = np.maximum(logged, min_value)
    return floored
15
+
16
+
17
def row_softmax(mat: np.ndarray) -> np.ndarray:
    """Softmax applied independently to each row of a 2-d array."""
    log_norm = logsumexp(mat, axis=1, keepdims=True)
    return np.exp(mat - log_norm)
20
+
21
+
22
def to_doc_list(X) -> list[tuple[np.ndarray, np.ndarray]]:
    """Convert a document-term count matrix to the stm document format.

    Returns one ``(word_indices, word_counts)`` pair per document, holding
    the column indices of the document's distinct terms and their counts
    (the two rows of the R package's document matrices, zero-indexed).

    A sparse input whose rows hold repeated or unsorted column indices is
    canonicalized first (duplicates summed), so every returned index array
    really contains distinct terms.  The caller's matrix is never modified.

    Raises ``ValueError`` if any stored value is negative or non-integer.
    """
    X = X.tocsr() if issparse(X) else csr_matrix(X)
    # validate the raw stored values, before any duplicates are merged
    if (X.data < 0).any():
        raise ValueError("X must contain non-negative counts.")
    if np.any(X.data != np.round(X.data)):
        raise ValueError("X must contain integer counts.")
    if not X.has_canonical_format:
        # tocsr() can hand back the caller's own object, so merge duplicates
        # on a private copy
        X = X.copy()
        X.sum_duplicates()

    docs = []
    for start, end in zip(X.indptr[:-1], X.indptr[1:]):
        words = X.indices[start:end].astype(np.int64)
        counts = X.data[start:end].astype(np.float64)
        # explicitly stored zeros are not part of the document
        present = counts > 0
        docs.append((words[present], counts[present]))
    return docs