tabullm-1.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tabullm-1.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Alireza S. Mahani and Mansour T.A. Sharabiani
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
tabullm-1.0.0/PKG-INFO ADDED
@@ -0,0 +1,149 @@
+ Metadata-Version: 2.4
+ Name: tabullm
+ Version: 1.0.0
+ Summary: Seamless Feature Extraction and Interpretation of Text Columns in Tabular Data Using Large Language Models
+ Author-email: "Alireza S. Mahani, Mansour T.A. Sharabiani" <alireza.s.mahani@gmail.com>
+ License-Expression: MIT
+ Keywords: text embedding,large language models,feature engineering,cross-validation
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy>=1.24
+ Requires-Dist: pandas>=2.0
+ Requires-Dist: scikit-learn>=1.3
+ Requires-Dist: scipy>=1.10
+ Requires-Dist: pydantic>=2.0
+ Requires-Dist: langchain>=0.1
+ Requires-Dist: langchain-core>=0.1
+ Requires-Dist: langchain-community>=0.1
+ Requires-Dist: pillow>=12.1.1
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
+ Provides-Extra: kaggle
+ Requires-Dist: kaggle>=1.6; extra == "kaggle"
+ Dynamic: license-file
+
+ # TabuLLM
+
+ Python package for feature extraction and interpretation of text columns in tabular data using large language models.
+
+ ## Overview
+
+ TabuLLM integrates LLM-based text embeddings into scikit-learn pipelines for tabular datasets containing text columns. Built on LangChain and scikit-learn, it provides sklearn-compatible transformers for embedding, dimensionality reduction, and cluster interpretation.
+
+ ## Installation
+
+ ```bash
+ pip install tabullm
+ ```
+
+ Optional dependency groups provide extended functionality:
+
+ ```bash
+ pip install tabullm[huggingface]  # HuggingFace embedding models
+ pip install tabullm[aws]          # AWS Bedrock / Cohere
+ pip install tabullm[kaggle]       # Kaggle dataset download (load_fraud)
+ pip install tabullm[viz]          # Visualization utilities
+ ```
+
+ ## Core Components
+
+ **TextColumnTransformer** - Wraps LangChain embedding models (OpenAI, Anthropic, HuggingFace, etc.) with a sklearn interface. Handles multiple text columns with configurable concatenation and optional L2 normalization (`normalize=True`). Use `estimate_tokens()` to preview API cost before embedding.
+
+ **GMMFeatureExtractor** - Extends sklearn's `GaussianMixture` with a `transform()` method that returns per-cluster log-joint features $\log p(\mathbf{x}, c_k)$ — the quantity the GMM maximises for hard assignment — enabling use in sklearn pipelines. An optional `include_log_density` parameter appends the marginal log-density as an explicit outlier score. A companion `assignment_confidence_stats()` method returns per-observation cluster quality diagnostics (`max_posterior`, `entropy`, `log_joint_margin`, `log_density`).
+
+ **SphericalKMeans** - K-means clustering with cosine distance for L2-normalized embeddings. For normalized embeddings, mathematically equivalent to sklearn's `KMeans`. Available as an alternative hard-clustering option when GMM-based features are not needed.
+
+ **ClusterExplainer** - Generates natural language cluster descriptions using LLMs with automatic recursive summarization that scales to arbitrarily large datasets. Supports:
+ - Cost preview (`preview=True`) before LLM calls
+ - Optional outcome-based statistical testing (`y`) to characterize which clusters associate with a target variable
+ - Per-observation covariates (`observation_stats`) — e.g., from `assignment_confidence_stats()` — appended to the association table
+ - A synthesis step (`synthesize=True`) that produces a coherent interpretive narrative across all cluster results
+ - An outcome label (`y_label`) used only in the synthesis prompt; cluster descriptions are generated without knowledge of `y` (*blind labeling* principle)
+
+ **load_fraud()** - Data utility that downloads and caches the fraud detection dataset from Kaggle (requires `tabullm[kaggle]`), returning features, labels, and metadata.
+
+ ## Quick Example
+
+ ```python
+ from tabullm import TextColumnTransformer, GMMFeatureExtractor, ClusterExplainer
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_openai import ChatOpenAI
+ from sklearn.pipeline import Pipeline
+ from sklearn.ensemble import RandomForestClassifier
+
+ # Embed text columns
+ embedding_model = HuggingFaceEmbeddings(
+     model_name='sentence-transformers/all-MiniLM-L6-v2'
+ )
+ text_transformer = TextColumnTransformer(
+     model=embedding_model,
+     colnames={'title': 'Title', 'description': 'Description'}
+ )
+
+ # Build pipeline: Embed → Reduce → Classify
+ pipeline = Pipeline([
+     ('embed', text_transformer),
+     ('reduce', GMMFeatureExtractor(n_components=10)),
+     ('classify', RandomForestClassifier(n_estimators=100))
+ ])
+
+ # Fit and predict
+ pipeline.fit(df[['title', 'description']], y)
+ predictions = pipeline.predict(df_new[['title', 'description']])
+
+ # Interpret clusters
+ explainer = ClusterExplainer(
+     llm=ChatOpenAI(model='gpt-4o-mini'),
+     text_transformer=text_transformer,
+     observations='job postings',
+     text_fields='titles and descriptions'
+ )
+
+ gmm = pipeline.named_steps['reduce']
+ cluster_labels = gmm.labels_
+
+ # Cluster descriptions only
+ result_df = explainer.explain(df, cluster_labels)
+
+ # With outcome association + synthesis narrative
+ result_df, global_stats, synthesis = explainer.explain(
+     df, cluster_labels,
+     y=y,
+     y_label='fraudulent posting (1=fraud, 0=legitimate)',
+     synthesize=True
+ )
+
+ # Include GMM cluster quality diagnostics in the association table
+ obs_stats = gmm.assignment_confidence_stats(
+     pipeline.named_steps['embed'].transform(df)
+ )
+ result_df, global_stats, stat_assoc_df, synthesis = explainer.explain(
+     df, cluster_labels,
+     y=y,
+     y_label='fraudulent posting (1=fraud, 0=legitimate)',
+     observation_stats=obs_stats,
+     synthesize=True
+ )
+ ```
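(Editor's note on the `estimate_tokens()` preview mentioned above: the idea can be sketched offline. The ~4-characters-per-token ratio and the per-million-token price below are illustrative assumptions, not TabuLLM's actual accounting.)

```python
def estimate_tokens_rough(texts, chars_per_token=4.0):
    """Crude token estimate: total characters divided by an assumed chars-per-token ratio."""
    total_chars = sum(len(t) for t in texts)
    return int(round(total_chars / chars_per_token))

def estimate_cost_usd(n_tokens, usd_per_million=0.02):
    """Embedding cost at a hypothetical per-million-token price."""
    return n_tokens / 1_000_000 * usd_per_million

# Example: preview cost for two short text values before calling any API
titles = ["Senior Data Engineer", "Remote customer support role"]
n_tokens = estimate_tokens_rough(titles)
cost = estimate_cost_usd(n_tokens)
```

Real tokenizers vary by model, so the package's own estimator should be preferred when available; this only conveys the shape of the preview step.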
+
+ ## Key Features
+
+ - sklearn-compatible API (Pipeline, ColumnTransformer, GridSearchCV)
+ - Access to 50+ embedding models via LangChain
+ - Multi-column text handling with flexible concatenation
+ - Optional L2 normalization of embedding vectors
+ - Token and cost estimation before embedding API calls
+ - GMM-based dimensionality reduction with per-cluster log-joint features
+ - Optional marginal log-density feature for explicit outlier scoring
+ - Per-observation cluster quality diagnostics (max posterior, entropy, log-joint margin, log density)
+ - Automatic recursive summarization for arbitrarily large datasets
+ - Cost estimation for LLM explanation calls
+ - Outcome-based cluster characterization (binary and continuous outcomes)
+ - User-supplied per-observation covariates in the association table
+ - Synthesis narrative connecting cluster descriptions to outcome patterns
+ - Blind labeling: cluster descriptions generated without knowledge of outcome vector
+
+ ## Citation
+
+ Sharabiani, M.T.A., Mahani, A.S., Bottle, A. et al. (2025). GenAI exceeds clinical experts in predicting acute kidney injury following paediatric cardiopulmonary bypass. Scientific Reports, 15, 20847. https://doi.org/10.1038/s41598-025-04651-8
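(Editor's note: the PKG-INFO description above states two identities about `GMMFeatureExtractor` — posteriors are a softmax of the per-cluster log-joints, and the marginal log-density is their log-sum-exp. A small NumPy check on a hand-built 1-D, two-component mixture; the component parameters are arbitrary and not from the package.)

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    """Log-density of a univariate Gaussian."""
    return -0.5 * np.log(2 * np.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

x = np.array([0.3, 2.1, -1.0])
weights = np.array([0.4, 0.6])   # mixing proportions pi_k
means = np.array([0.0, 2.0])
sigmas = np.array([1.0, 0.5])

# Per-cluster log-joints: l_k(x) = log pi_k + log p(x | c_k), shape (n_samples, K)
log_joints = np.log(weights) + normal_logpdf(x[:, None], means, sigmas)

# Marginal log-density: log p(x) = logsumexp_k l_k(x)
m = log_joints.max(axis=1, keepdims=True)
log_density = (m + np.log(np.exp(log_joints - m).sum(axis=1, keepdims=True))).ravel()

# Posterior probabilities: softmax of the log-joints
posteriors = np.exp(log_joints - log_density[:, None])

# Hard assignment is the argmax over log-joints (same as argmax over posteriors)
hard_assignment = log_joints.argmax(axis=1)
```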
tabullm-1.0.0/README.md ADDED
@@ -0,0 +1,124 @@
+ # TabuLLM
+
+ Python package for feature extraction and interpretation of text columns in tabular data using large language models.
+
+ ## Overview
+
+ TabuLLM integrates LLM-based text embeddings into scikit-learn pipelines for tabular datasets containing text columns. Built on LangChain and scikit-learn, it provides sklearn-compatible transformers for embedding, dimensionality reduction, and cluster interpretation.
+
+ ## Installation
+
+ ```bash
+ pip install tabullm
+ ```
+
+ Optional dependency groups provide extended functionality:
+
+ ```bash
+ pip install tabullm[huggingface]  # HuggingFace embedding models
+ pip install tabullm[aws]          # AWS Bedrock / Cohere
+ pip install tabullm[kaggle]       # Kaggle dataset download (load_fraud)
+ pip install tabullm[viz]          # Visualization utilities
+ ```
+
+ ## Core Components
+
+ **TextColumnTransformer** - Wraps LangChain embedding models (OpenAI, Anthropic, HuggingFace, etc.) with a sklearn interface. Handles multiple text columns with configurable concatenation and optional L2 normalization (`normalize=True`). Use `estimate_tokens()` to preview API cost before embedding.
+
+ **GMMFeatureExtractor** - Extends sklearn's `GaussianMixture` with a `transform()` method that returns per-cluster log-joint features $\log p(\mathbf{x}, c_k)$ — the quantity the GMM maximises for hard assignment — enabling use in sklearn pipelines. An optional `include_log_density` parameter appends the marginal log-density as an explicit outlier score. A companion `assignment_confidence_stats()` method returns per-observation cluster quality diagnostics (`max_posterior`, `entropy`, `log_joint_margin`, `log_density`).
+
+ **SphericalKMeans** - K-means clustering with cosine distance for L2-normalized embeddings. For normalized embeddings, mathematically equivalent to sklearn's `KMeans`. Available as an alternative hard-clustering option when GMM-based features are not needed.
+
+ **ClusterExplainer** - Generates natural language cluster descriptions using LLMs with automatic recursive summarization that scales to arbitrarily large datasets. Supports:
+ - Cost preview (`preview=True`) before LLM calls
+ - Optional outcome-based statistical testing (`y`) to characterize which clusters associate with a target variable
+ - Per-observation covariates (`observation_stats`) — e.g., from `assignment_confidence_stats()` — appended to the association table
+ - A synthesis step (`synthesize=True`) that produces a coherent interpretive narrative across all cluster results
+ - An outcome label (`y_label`) used only in the synthesis prompt; cluster descriptions are generated without knowledge of `y` (*blind labeling* principle)
+
+ **load_fraud()** - Data utility that downloads and caches the fraud detection dataset from Kaggle (requires `tabullm[kaggle]`), returning features, labels, and metadata.
+
+ ## Quick Example
+
+ ```python
+ from tabullm import TextColumnTransformer, GMMFeatureExtractor, ClusterExplainer
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_openai import ChatOpenAI
+ from sklearn.pipeline import Pipeline
+ from sklearn.ensemble import RandomForestClassifier
+
+ # Embed text columns
+ embedding_model = HuggingFaceEmbeddings(
+     model_name='sentence-transformers/all-MiniLM-L6-v2'
+ )
+ text_transformer = TextColumnTransformer(
+     model=embedding_model,
+     colnames={'title': 'Title', 'description': 'Description'}
+ )
+
+ # Build pipeline: Embed → Reduce → Classify
+ pipeline = Pipeline([
+     ('embed', text_transformer),
+     ('reduce', GMMFeatureExtractor(n_components=10)),
+     ('classify', RandomForestClassifier(n_estimators=100))
+ ])
+
+ # Fit and predict
+ pipeline.fit(df[['title', 'description']], y)
+ predictions = pipeline.predict(df_new[['title', 'description']])
+
+ # Interpret clusters
+ explainer = ClusterExplainer(
+     llm=ChatOpenAI(model='gpt-4o-mini'),
+     text_transformer=text_transformer,
+     observations='job postings',
+     text_fields='titles and descriptions'
+ )
+
+ gmm = pipeline.named_steps['reduce']
+ cluster_labels = gmm.labels_
+
+ # Cluster descriptions only
+ result_df = explainer.explain(df, cluster_labels)
+
+ # With outcome association + synthesis narrative
+ result_df, global_stats, synthesis = explainer.explain(
+     df, cluster_labels,
+     y=y,
+     y_label='fraudulent posting (1=fraud, 0=legitimate)',
+     synthesize=True
+ )
+
+ # Include GMM cluster quality diagnostics in the association table
+ obs_stats = gmm.assignment_confidence_stats(
+     pipeline.named_steps['embed'].transform(df)
+ )
+ result_df, global_stats, stat_assoc_df, synthesis = explainer.explain(
+     df, cluster_labels,
+     y=y,
+     y_label='fraudulent posting (1=fraud, 0=legitimate)',
+     observation_stats=obs_stats,
+     synthesize=True
+ )
+ ```
+
+ ## Key Features
+
+ - sklearn-compatible API (Pipeline, ColumnTransformer, GridSearchCV)
+ - Access to 50+ embedding models via LangChain
+ - Multi-column text handling with flexible concatenation
+ - Optional L2 normalization of embedding vectors
+ - Token and cost estimation before embedding API calls
+ - GMM-based dimensionality reduction with per-cluster log-joint features
+ - Optional marginal log-density feature for explicit outlier scoring
+ - Per-observation cluster quality diagnostics (max posterior, entropy, log-joint margin, log density)
+ - Automatic recursive summarization for arbitrarily large datasets
+ - Cost estimation for LLM explanation calls
+ - Outcome-based cluster characterization (binary and continuous outcomes)
+ - User-supplied per-observation covariates in the association table
+ - Synthesis narrative connecting cluster descriptions to outcome patterns
+ - Blind labeling: cluster descriptions generated without knowledge of outcome vector
+
+ ## Citation
+
+ Sharabiani, M.T.A., Mahani, A.S., Bottle, A. et al. (2025). GenAI exceeds clinical experts in predicting acute kidney injury following paediatric cardiopulmonary bypass. Scientific Reports, 15, 20847. https://doi.org/10.1038/s41598-025-04651-8
tabullm-1.0.0/pyproject.toml ADDED
@@ -0,0 +1,37 @@
+ [build-system]
+ requires = ["setuptools>=42", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools]
+ packages = { find = { include = ["tabullm*"] } }
+
+ [project]
+ name = "tabullm"
+ version = "1.0.0"
+ description = "Seamless Feature Extraction and Interpretation of Text Columns in Tabular Data Using Large Language Models"
+ readme = "README.md"
+ license = "MIT"
+ authors = [
+     { name = "Alireza S. Mahani, Mansour T.A. Sharabiani", email = "alireza.s.mahani@gmail.com" }
+ ]
+ keywords = ["text embedding", "large language models", "feature engineering", "cross-validation"]
+ dependencies = [
+     "numpy>=1.24",
+     "pandas>=2.0",
+     "scikit-learn>=1.3",
+     "scipy>=1.10",
+     "pydantic>=2.0",
+     "langchain>=0.1",
+     "langchain-core>=0.1",
+     "langchain-community>=0.1",
+     "pillow>=12.1.1",  # Security: CVE fix
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.0",
+     "pytest-cov>=4.0"
+ ]
+ kaggle = [
+     "kaggle>=1.6"
+ ]
tabullm-1.0.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
tabullm-1.0.0/tabullm/__init__.py ADDED
@@ -0,0 +1,25 @@
+ """
+ TabuLLM: Feature Extraction and Interpretation of Text Columns in Tabular Data Using LLMs
+
+ A Python package for seamless integration of text embeddings into tabular ML pipelines,
+ with tools for clustering and LLM-based interpretation.
+ """
+
+ __version__ = "1.0.0"
+
+ # Core components
+ from .embed import TextColumnTransformer
+ from .cluster import SphericalKMeans, GMMFeatureExtractor
+ from .explain import ClusterExplainer
+
+ # Data utilities
+ from .data import load_fraud
+
+ __all__ = [
+     "TextColumnTransformer",
+     "SphericalKMeans",
+     "GMMFeatureExtractor",
+     "ClusterExplainer",
+     "load_fraud",
+     "__version__",
+ ]
tabullm-1.0.0/tabullm/cluster/__init__.py ADDED
@@ -0,0 +1,32 @@
+ """
+ Clustering and feature extraction for text embeddings
+
+ This module provides clustering algorithms optimized for text embeddings,
+ including spherical k-means (cosine distance) and GMM-based feature extraction.
+
+ Classes
+ -------
+ SphericalKMeans : Clustering with cosine distance for normalized embeddings
+ GMMFeatureExtractor : Transform embeddings to per-cluster log-joint features for ML pipelines
+
+ Examples
+ --------
+ >>> from tabullm.cluster import SphericalKMeans, GMMFeatureExtractor
+ >>> from sklearn.pipeline import Pipeline
+ >>>
+ >>> # Spherical k-means clustering
+ >>> kmeans = SphericalKMeans(n_clusters=5)
+ >>> labels = kmeans.fit_predict(embeddings)
+ >>>
+ >>> # GMM feature extraction in pipeline
+ >>> pipeline = Pipeline([
+ ...     ('gmm', GMMFeatureExtractor(n_components=10)),
+ ...     ('clf', RandomForestClassifier())
+ ... ])
+ """
+
+ # Cluster submodule exports
+ from .spherical_kmeans import SphericalKMeans
+ from .gmm_features import GMMFeatureExtractor
+
+ __all__ = ['SphericalKMeans', 'GMMFeatureExtractor']
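(Editor's note: the module docstring above pairs `SphericalKMeans` with L2-normalized embeddings, and the README states that for normalized embeddings it is equivalent to Euclidean k-means. For the assignment step this follows from ||x − c||² = 2 − 2·x·c on unit vectors. A numeric sanity check with random data — this is not TabuLLM's implementation:)

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))   # stand-in for embedding vectors
C = rng.normal(size=(5, 8))    # stand-in for cluster centroids
X /= np.linalg.norm(X, axis=1, keepdims=True)   # L2-normalize points
C /= np.linalg.norm(C, axis=1, keepdims=True)   # L2-normalize centroids

# Cosine assignment: nearest centroid by maximum cosine similarity
cosine_assign = np.argmax(X @ C.T, axis=1)

# Euclidean assignment: nearest centroid by squared distance
sq_dist = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2)
euclidean_assign = np.argmin(sq_dist, axis=1)
```

On unit vectors the two assignments coincide; the update steps of the two algorithms still differ (spherical k-means renormalizes centroids), which is presumably why the package ships a dedicated class.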
tabullm-1.0.0/tabullm/cluster/gmm_features.py ADDED
@@ -0,0 +1,237 @@
+ """
+ Gaussian Mixture Model feature extraction for embeddings.
+
+ Provides GMMFeatureExtractor, which adds transform() to sklearn's GaussianMixture
+ for feature extraction. Returns per-cluster log-joint features log p(x, c_k),
+ the quantity the GMM maximises for hard assignment.
+
+ Key contribution: sklearn's GaussianMixture lacks transform() - only has predict()
+ and predict_proba(). This fills that gap for ML pipelines.
+ """
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.base import BaseEstimator, TransformerMixin
+ from sklearn.mixture import GaussianMixture
+ from sklearn.utils.validation import check_array, check_is_fitted
+
+
+ class GMMFeatureExtractor(BaseEstimator, TransformerMixin):
+     """
+     Extract features using Gaussian Mixture Model log-joint probabilities.
+
+     ``transform()`` returns an (n_samples, K) array of per-cluster log-joints
+
+         ℓ_k(x) = log p(x, c_k) = log π_k + log p(x | c_k)
+
+     where log p(x | c_k) is the Gaussian log-likelihood under component k.
+     This is the quantity the GMM maximises for hard assignment
+     (k* = argmax_k ℓ_k), so features are in exact correspondence with the
+     model's own criterion. Posterior probabilities are a softmax of these
+     features.
+
+     Adds transform() method that sklearn's GaussianMixture lacks, enabling
+     use in feature extraction pipelines.
+
+     Parameters
+     ----------
+     n_components : int, default=10
+         Number of mixture components.
+
+     covariance_type : {'full', 'tied', 'diag', 'spherical'}, default='full'
+         Type of covariance parameters.
+
+     n_init : int, default=3
+         Number of initializations to perform.
+
+     random_state : int, default=None
+         Random seed for reproducibility.
+
+     include_log_density : bool, default=False
+         If True, append log p(x) = log Σ_k exp(ℓ_k(x)) as a (K+1)-th column.
+         This is the log marginal likelihood — how well the overall mixture
+         explains x. It is a deterministic function of the K log-joints
+         (log-sum-exp), so it adds no information for expressive nonlinear
+         models but provides an explicit outlier score for linear ones.
+
+     **gmm_kwargs
+         Additional parameters for GaussianMixture.
+
+     Attributes
+     ----------
+     gmm_ : GaussianMixture
+         Fitted Gaussian Mixture Model.
+
+     means_ : array, shape (n_components, n_features)
+         Component means.
+
+     covariances_ : array
+         Component covariances.
+
+     labels_ : array, shape (n_samples,)
+         Cluster labels for each sample from the training set.
+
+     Examples
+     --------
+     >>> from tabullm import TextColumnTransformer, GMMFeatureExtractor
+     >>> from sklearn.pipeline import Pipeline
+     >>>
+     >>> pipeline = Pipeline([
+     ...     ('embed', TextColumnTransformer(model=model, colnames={'text': 'Text'})),
+     ...     ('gmm', GMMFeatureExtractor(n_components=10)),
+     ...     ('clf', RandomForestClassifier())
+     ... ])
+     """
+
+     def __init__(self, n_components=10, covariance_type='full',
+                  n_init=3, random_state=None, include_log_density=False,
+                  **gmm_kwargs):
+         self.n_components = n_components
+         self.covariance_type = covariance_type
+         self.n_init = n_init
+         self.random_state = random_state
+         self.include_log_density = include_log_density
+         self.gmm_kwargs = gmm_kwargs
+
+     def fit(self, X, y=None):
+         """Fit GMM to data."""
+         X = check_array(X)
+
+         self.gmm_ = GaussianMixture(
+             n_components=self.n_components,
+             covariance_type=self.covariance_type,
+             n_init=self.n_init,
+             random_state=self.random_state,
+             **self.gmm_kwargs
+         )
+         self.gmm_.fit(X)
+
+         self.means_ = self.gmm_.means_
+         self.covariances_ = self.gmm_.covariances_
+         self.n_components_ = self.n_components
+         self.labels_ = self.gmm_.predict(X)
+
+         return self
+
+     def transform(self, X):
+         """
+         Transform X to per-cluster log-joint features.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+
+         Returns
+         -------
+         ndarray of shape (n_samples, K) or (n_samples, K+1)
+             Per-cluster log p(x, c_k) for k = 0 … K-1, plus log p(x) as
+             the final column when ``include_log_density=True``.
+         """
+         check_is_fitted(self, 'gmm_')
+         X = check_array(X)
+         log_joints = self.gmm_._estimate_weighted_log_prob(X)  # (n_samples, K)
+         if self.include_log_density:
+             log_density = self.gmm_.score_samples(X).reshape(-1, 1)
+             return np.hstack([log_joints, log_density])
+         return log_joints
+
+     def get_feature_names_out(self, input_features=None):
+         """
+         Feature names: ``lj_0 … lj_{K-1}``, plus ``log_density`` when
+         ``include_log_density=True``.
+         """
+         check_is_fitted(self, 'gmm_')
+         names = [f'lj_{k}' for k in range(self.n_components_)]
+         if self.include_log_density:
+             names.append('log_density')
+         return np.array(names)
+
+     def predict(self, X):
+         """Predict cluster labels."""
+         check_is_fitted(self, 'gmm_')
+         return self.gmm_.predict(check_array(X))
+
+     def assignment_confidence_stats(self, X):
+         """
+         Compute per-observation assignment confidence statistics.
+
+         Returns four metrics that characterise how confidently and
+         unambiguously the fitted GMM assigns each observation to a cluster,
+         plus the marginal log-density as an outlier score.
+         All four are scalars per observation and can be aggregated to
+         cluster-level summaries via ``groupby(cluster_labels).mean()``.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Data to score. Must have the same number of features as the
+             training data.
+
+         Returns
+         -------
+         pd.DataFrame of shape (n_samples, 4)
+             Columns:
+
+             ``max_posterior``
+                 Maximum posterior probability :math:`\\max_k p(k \\mid x)`.
+                 Range :math:`[1/K, 1]`. Higher means the model is more
+                 certain about the cluster assignment.
+
+             ``entropy``
+                 Shannon entropy :math:`-\\sum_k p(k \\mid x) \\log p(k \\mid x)`
+                 of the posterior distribution.
+                 Range :math:`[0, \\log K]`. Lower means the probability mass
+                 is concentrated on fewer clusters.
+
+             ``log_joint_margin``
+                 Difference between the top-1 and top-2 per-cluster log joints
+                 :math:`\\ell_{k^*}(x) - \\max_{j \\neq k^*} \\ell_j(x)`.
+                 Range :math:`[0, \\infty)`. Larger means the assigned cluster
+                 is more decisively preferred over its nearest rival.
+
+             ``log_density``
+                 Log marginal likelihood :math:`\\log p(x) = \\log \\sum_k e^{\\ell_k(x)}`.
+                 Unbounded; typically very negative in high dimensions. Captures
+                 how well the overall GMM explains this observation; large
+                 negative values flag outliers.
+                 Orthogonal to the three assignment metrics above: two observations
+                 can share the same ``max_posterior`` yet have very different
+                 ``log_density``.
+
+         Examples
+         --------
+         >>> stats = gmm.assignment_confidence_stats(X_embeddings)
+         >>> stats.describe()
+
+         >>> # Per-cluster rollup
+         >>> stats.groupby(cluster_labels).mean()
+
+         >>> # Feed into ClusterExplainer
+         >>> explainer.explain(X, cluster_labels, y=y, observation_stats=stats)
+         """
+         check_is_fitted(self, 'gmm_')
+         X = check_array(X)
+
+         # (n_samples, K) posterior probabilities
+         posteriors = self.gmm_.predict_proba(X)
+
+         # 1. Assignment confidence: max posterior per observation
+         assignment_confidence = posteriors.max(axis=1)
+
+         # 2. Entropy: -sum_k p(k|x) log p(k|x)
+         log_posteriors = np.log(np.clip(posteriors, 1e-10, 1.0))
+         entropy = -(posteriors * log_posteriors).sum(axis=1)
+
+         # 3. Log-joint margin: top-1 minus top-2 log joint
+         log_joints = self.gmm_._estimate_weighted_log_prob(X)  # (n_samples, K)
+         sorted_lj = np.sort(log_joints, axis=1)[:, ::-1]
+         log_joint_margin = sorted_lj[:, 0] - sorted_lj[:, 1]
+
+         # 4. Log density: log p(x) = log-sum-exp of log joints
+         log_density = self.gmm_.score_samples(X)
+
+         return pd.DataFrame({
+             'max_posterior': assignment_confidence,
+             'entropy': entropy,
+             'log_joint_margin': log_joint_margin,
+             'log_density': log_density,
+         })
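(Editor's note: the documented invariants of `assignment_confidence_stats()` — max posterior in [1/K, 1], entropy in [0, log K], nonnegative log-joint margin, and log-density equal to the log-sum-exp of the log-joints — can be sanity-checked with hand-built log-joints, without fitting a GMM. The numbers below are arbitrary; the metric formulas mirror the method body above:)

```python
import numpy as np

# Hand-built per-cluster log-joints for 3 observations and K = 3 clusters
log_joints = np.array([[-3.0, -1.0, -4.0],
                       [-2.0, -2.1, -2.2],
                       [-9.0, -0.5, -8.0]])
K = log_joints.shape[1]

# log p(x) = logsumexp over clusters (stabilized)
m = log_joints.max(axis=1, keepdims=True)
log_density = (m + np.log(np.exp(log_joints - m).sum(axis=1, keepdims=True))).ravel()

# Posteriors = softmax of log-joints
posteriors = np.exp(log_joints - log_density[:, None])

# The four per-observation metrics
max_posterior = posteriors.max(axis=1)
entropy = -(posteriors * np.log(np.clip(posteriors, 1e-10, 1.0))).sum(axis=1)
top2 = np.sort(log_joints, axis=1)[:, ::-1]
log_joint_margin = top2[:, 0] - top2[:, 1]
```

Row 2 (nearly tied log-joints) shows low confidence and high entropy; row 3 (one dominant cluster) shows the opposite, which is exactly the contrast the diagnostics are meant to surface.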