segmentae-1.5.20-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- segmentae/__init__.py +83 -0
- segmentae/anomaly_detection.py +20 -0
- segmentae/autoencoders/__init__.py +16 -0
- segmentae/autoencoders/batch_norm.py +208 -0
- segmentae/autoencoders/dense.py +211 -0
- segmentae/autoencoders/ensemble.py +219 -0
- segmentae/clusters/__init__.py +18 -0
- segmentae/clusters/clustering.py +171 -0
- segmentae/clusters/models.py +438 -0
- segmentae/clusters/registry.py +75 -0
- segmentae/core/__init__.py +65 -0
- segmentae/core/base.py +108 -0
- segmentae/core/constants.py +91 -0
- segmentae/core/exceptions.py +60 -0
- segmentae/core/types.py +55 -0
- segmentae/data_sources/__init__.py +3 -0
- segmentae/data_sources/examples.py +198 -0
- segmentae/metrics/__init__.py +6 -0
- segmentae/metrics/performance_metrics.py +119 -0
- segmentae/optimization/__init__.py +6 -0
- segmentae/optimization/optimizer.py +375 -0
- segmentae/pipeline/__init__.py +21 -0
- segmentae/pipeline/reconstruction.py +214 -0
- segmentae/pipeline/segmentae.py +562 -0
- segmentae/processing/__init__.py +21 -0
- segmentae/processing/preprocessing.py +263 -0
- segmentae/processing/simplifier.py +74 -0
- segmentae/utils/__init__.py +17 -0
- segmentae/utils/validation.py +94 -0
- segmentae-1.5.20.dist-info/METADATA +393 -0
- segmentae-1.5.20.dist-info/RECORD +34 -0
- segmentae-1.5.20.dist-info/WHEEL +5 -0
- segmentae-1.5.20.dist-info/licenses/LICENSE +21 -0
- segmentae-1.5.20.dist-info/top_level.txt +1 -0

segmentae/clusters/models.py
ADDED

@@ -0,0 +1,438 @@
+from typing import List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from pydantic import BaseModel, field_validator
+from sklearn.cluster import AgglomerativeClustering, KMeans, MiniBatchKMeans
+from sklearn.mixture import GaussianMixture
+
+from segmentae.clusters.registry import ClusterRegistry
+from segmentae.core.base import AbstractClusterModel
+from segmentae.core.constants import ClusterModel
+
+
+class KMeansCluster(AbstractClusterModel):
+    """
+    K-Means clustering implementation.
+
+    K-Means partitions data into n_clusters by minimizing within-cluster
+    variance. It's efficient and works well for spherical clusters.
+    """
+
+    class Config(BaseModel):
+        """Pydantic configuration for K-Means parameters."""
+        n_clusters: int = 3
+        random_state: int = 0
+        max_iter: int = 300
+
+        @field_validator('n_clusters')
+        def validate_n_clusters(cls, v):
+            if v < 1:
+                raise ValueError("n_clusters must be >= 1")
+            return v
+
+        @field_validator('max_iter')
+        def validate_max_iter(cls, v):
+            if v < 1:
+                raise ValueError("max_iter must be >= 1")
+            return v
+
+        class Config:
+            use_enum_values = True
+
+    def __init__(self,
+                 n_clusters: int = 3,
+                 random_state: int = 0,
+                 max_iter: int = 300):
+        """
+        Initialize K-Means clustering model.
+
+        Args:
+            n_clusters: Number of clusters to form
+            random_state: Random seed for reproducibility
+            max_iter: Maximum iterations for convergence
+        """
+        self.config = self.Config(
+            n_clusters=n_clusters,
+            random_state=random_state,
+            max_iter=max_iter
+        )
+        self._model: Optional[KMeans] = None
+        self._is_fitted: bool = False
+
+    def fit(self, X: pd.DataFrame) -> None:
+        """Fit K-Means model to data."""
+        self._validate_input(X, "Training data")
+
+        self._model = KMeans(
+            n_clusters=self.config.n_clusters,
+            random_state=self.config.random_state,
+            max_iter=self.config.max_iter
+        )
+        self._model.fit(X)
+        self._is_fitted = True
+
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        """Predict cluster labels for data."""
+        self._validate_fitted()
+        self._validate_input(X, "Prediction data")
+        return self._model.predict(X)
+
+    def fit_predict(self, X: pd.DataFrame) -> np.ndarray:
+        """Fit model and predict labels in one step."""
+        self.fit(X)
+        return self.predict(X)
+
+    @property
+    def n_clusters(self) -> int:
+        """Return number of clusters."""
+        return self.config.n_clusters
+
+    @property
+    def is_fitted(self) -> bool:
+        """Check if model is fitted."""
+        return self._is_fitted
+
+
+class MiniBatchKMeansCluster(AbstractClusterModel):
+    """
+    MiniBatch K-Means clustering implementation.
+
+    A variant of K-Means that uses mini-batches to reduce computation time
+    while approximating the standard K-Means algorithm. Ideal for large datasets.
+    """
+
+    class Config(BaseModel):
+        """Pydantic configuration for MiniBatch K-Means parameters."""
+        n_clusters: int = 3
+        random_state: int = 0
+        max_iter: int = 150
+
+        @field_validator('n_clusters')
+        def validate_n_clusters(cls, v):
+            if v < 1:
+                raise ValueError("n_clusters must be >= 1")
+            return v
+
+        @field_validator('max_iter')
+        def validate_max_iter(cls, v):
+            if v < 1:
+                raise ValueError("max_iter must be >= 1")
+            return v
+
+        class Config:
+            use_enum_values = True
+
+    def __init__(self,
+                 n_clusters: int = 3,
+                 random_state: int = 0,
+                 max_iter: int = 150):
+        """
+        Initialize MiniBatch K-Means clustering model.
+
+        Args:
+            n_clusters: Number of clusters to form
+            random_state: Random seed for reproducibility
+            max_iter: Maximum iterations for convergence
+        """
+        self.config = self.Config(
+            n_clusters=n_clusters,
+            random_state=random_state,
+            max_iter=max_iter
+        )
+        self._model: Optional[MiniBatchKMeans] = None
+        self._is_fitted: bool = False
+
+    def fit(self, X: pd.DataFrame) -> None:
+        """Fit MiniBatch K-Means model to data."""
+        self._validate_input(X, "Training data")
+
+        self._model = MiniBatchKMeans(
+            n_clusters=self.config.n_clusters,
+            random_state=self.config.random_state,
+            max_iter=self.config.max_iter
+        )
+        self._model.fit(X)
+        self._is_fitted = True
+
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        """Predict cluster labels for data."""
+        self._validate_fitted()
+        self._validate_input(X, "Prediction data")
+        return self._model.predict(X)
+
+    def fit_predict(self, X: pd.DataFrame) -> np.ndarray:
+        """Fit model and predict labels in one step."""
+        self.fit(X)
+        return self.predict(X)
+
+    @property
+    def n_clusters(self) -> int:
+        """Return number of clusters."""
+        return self.config.n_clusters
+
+    @property
+    def is_fitted(self) -> bool:
+        """Check if model is fitted."""
+        return self._is_fitted
+
+
+class GaussianMixtureCluster(AbstractClusterModel):
+    """
+    Gaussian Mixture Model clustering implementation.
+
+    GMM assumes data is generated from a mixture of Gaussian distributions.
+    Uses Expectation-Maximization algorithm for parameter estimation.
+    Provides probabilistic cluster assignments.
+    """
+
+    class Config(BaseModel):
+        """Pydantic configuration for GMM parameters."""
+        n_components: int = 3
+        covariance_type: str = "full"
+        max_iter: int = 150
+        init_params: str = "k-means++"
+
+        @field_validator('n_components')
+        def validate_n_components(cls, v):
+            if v < 1:
+                raise ValueError("n_components must be >= 1")
+            return v
+
+        @field_validator('covariance_type')
+        def validate_covariance_type(cls, v):
+            valid_types = ['full', 'tied', 'diag', 'spherical']
+            if v not in valid_types:
+                raise ValueError(
+                    f"covariance_type must be one of {valid_types}, got '{v}'"
+                )
+            return v
+
+        @field_validator('max_iter')
+        def validate_max_iter(cls, v):
+            if v < 1:
+                raise ValueError("max_iter must be >= 1")
+            return v
+
+        class Config:
+            use_enum_values = True
+
+    def __init__(self,
+                 n_components: int = 3,
+                 covariance_type: str = "full",
+                 max_iter: int = 150,
+                 init_params: str = "k-means++"):
+        """
+        Initialize Gaussian Mixture Model.
+
+        Args:
+            n_components: Number of mixture components (clusters)
+            covariance_type: Type of covariance ('full', 'tied', 'diag', 'spherical')
+            max_iter: Maximum EM iterations
+            init_params: Initialization method for parameters
+        """
+        self.config = self.Config(
+            n_components=n_components,
+            covariance_type=covariance_type,
+            max_iter=max_iter,
+            init_params=init_params
+        )
+        self._model: Optional[GaussianMixture] = None
+        self._is_fitted: bool = False
+
+    def fit(self, X: pd.DataFrame) -> None:
+        """Fit Gaussian Mixture Model to data."""
+        self._validate_input(X, "Training data")
+
+        self._model = GaussianMixture(
+            n_components=self.config.n_components,
+            covariance_type=self.config.covariance_type,
+            max_iter=self.config.max_iter
+        )
+        self._model.fit(X)
+        self._is_fitted = True
+
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        """Predict cluster labels for data."""
+        self._validate_fitted()
+        self._validate_input(X, "Prediction data")
+        return self._model.predict(X)
+
+    def fit_predict(self, X: pd.DataFrame) -> np.ndarray:
+        """Fit model and predict labels in one step."""
+        self.fit(X)
+        return self.predict(X)
+
+    @property
+    def n_clusters(self) -> int:
+        """Return number of clusters."""
+        return self.config.n_components
+
+    @property
+    def is_fitted(self) -> bool:
+        """Check if model is fitted."""
+        return self._is_fitted
+
+
+class AgglomerativeCluster(AbstractClusterModel):
+    """
+    Agglomerative (Hierarchical) clustering implementation.
+
+    Builds nested clusters by successively merging or splitting them.
+    Works bottom-up, starting with each sample as its own cluster.
+    """
+
+    class Config(BaseModel):
+        """Pydantic configuration for Agglomerative clustering parameters."""
+        n_clusters: int = 3
+        linkage: str = "ward"
+        distance_threshold: Optional[float] = None
+
+        @field_validator('n_clusters')
+        def validate_n_clusters(cls, v):
+            if v < 1:
+                raise ValueError("n_clusters must be >= 1")
+            return v
+
+        @field_validator('linkage')
+        def validate_linkage(cls, v):
+            valid_linkage = ['ward', 'complete', 'average', 'single']
+            if v not in valid_linkage:
+                raise ValueError(
+                    f"linkage must be one of {valid_linkage}, got '{v}'"
+                )
+            return v
+
+        class Config:
+            use_enum_values = True
+
+    def __init__(self,
+                 n_clusters: int = 3,
+                 linkage: str = "ward",
+                 distance_threshold: Optional[float] = None):
+        """
+        Initialize Agglomerative clustering model.
+
+        Args:
+            n_clusters: Number of clusters to find
+            linkage: Linkage criterion to use
+            distance_threshold: Distance threshold for merging clusters
+        """
+        self.config = self.Config(
+            n_clusters=n_clusters,
+            linkage=linkage,
+            distance_threshold=distance_threshold
+        )
+        self._model: Optional[AgglomerativeClustering] = None
+        self._labels: Optional[np.ndarray] = None
+        self._is_fitted: bool = False
+
+    def fit(self, X: pd.DataFrame) -> None:
+        """Fit Agglomerative clustering model to data."""
+        self._validate_input(X, "Training data")
+
+        self._model = AgglomerativeClustering(
+            n_clusters=self.config.n_clusters,
+            linkage=self.config.linkage,
+            distance_threshold=self.config.distance_threshold
+        )
+        self._labels = self._model.fit_predict(X)
+        self._is_fitted = True
+
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        """
+        Predict cluster labels for data.
+
+        Note: Agglomerative clustering doesn't support prediction on new data
+        after fitting. This method re-fits the model on the new data.
+        """
+        self._validate_input(X, "Prediction data")
+
+        model = AgglomerativeClustering(
+            n_clusters=self.config.n_clusters,
+            linkage=self.config.linkage,
+            distance_threshold=self.config.distance_threshold
+        )
+        return model.fit_predict(X)
+
+    def fit_predict(self, X: pd.DataFrame) -> np.ndarray:
+        """Fit model and predict labels in one step."""
+        self.fit(X)
+        return self._labels
+
+    @property
+    def n_clusters(self) -> int:
+        """Return number of clusters."""
+        return self.config.n_clusters
+
+    @property
+    def is_fitted(self) -> bool:
+        """Check if model is fitted."""
+        return self._is_fitted
+
+class ClusteringConfig(BaseModel):
+    """
+    Configuration for clustering pipeline.
+
+    Attributes:
+        cluster_models: List of clustering algorithms to use
+        n_clusters: Number of clusters to form
+        random_state: Random seed for reproducibility
+        covariance_type: Covariance type for GMM
+    """
+
+    cluster_models: List[Union[ClusterModel, str]]
+    n_clusters: int = 3
+    random_state: int = 0
+    covariance_type: str = "full"
+
+    @field_validator('cluster_models')
+    def convert_to_enum_list(cls, v):
+        """Convert cluster models to enum list."""
+        if isinstance(v, str):
+            v = [v]
+
+        result = []
+        for model in v:
+            if isinstance(model, ClusterModel):
+                result.append(model)
+            elif isinstance(model, str):
+                try:
+                    result.append(ClusterModel(model))
+                except ValueError:
+                    valid_models = [m.value for m in ClusterModel]
+                    raise ValueError(
+                        f"Invalid cluster model: '{model}'. "
+                        f"Valid options: {valid_models}"
+                    )
+            else:
+                raise ValueError(f"Invalid cluster model type: {type(model)}")
+
+        return result
+
+    @field_validator('n_clusters')
+    def validate_n_clusters(cls, v):
+        """Validate n_clusters is positive."""
+        if v < 1:
+            raise ValueError("n_clusters must be >= 1")
+        return v
+
+    @field_validator('covariance_type')
+    def validate_covariance_type(cls, v):
+        """Validate covariance type for GMM."""
+        valid_types = ['full', 'tied', 'diag', 'spherical']
+        if v not in valid_types:
+            raise ValueError(
+                f"covariance_type must be one of {valid_types}, got '{v}'"
+            )
+        return v
+
+    class Config:
+        use_enum_values = False
+
+
+# Auto-register all clustering models
+ClusterRegistry.register(ClusterModel.KMEANS, KMeansCluster)
+ClusterRegistry.register(ClusterModel.MINIBATCH_KMEANS, MiniBatchKMeansCluster)
+ClusterRegistry.register(ClusterModel.GMM, GaussianMixtureCluster)
+ClusterRegistry.register(ClusterModel.AGGLOMERATIVE, AgglomerativeCluster)
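
For orientation, here is a minimal usage sketch for the wrappers added in this file. The class names, constructor parameters, and methods are taken from the hunk above; the synthetic DataFrame and variable names are assumptions made purely for illustration.

import numpy as np
import pandas as pd

from segmentae.clusters.models import ClusteringConfig, GaussianMixtureCluster, KMeansCluster
from segmentae.core.constants import ClusterModel

# Synthetic data, assumed for the example only.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(200, 4)), columns=["f1", "f2", "f3", "f4"])

km = KMeansCluster(n_clusters=3, random_state=0)
labels = km.fit_predict(X)            # fit, then predict, with input validation from the base class
print(km.n_clusters, km.is_fitted)    # 3 True

gmm = GaussianMixtureCluster(n_components=3, covariance_type="diag")
gmm.fit(X)
print(gmm.predict(X)[:5])

# ClusteringConfig normalizes and validates the requested model list.
cfg = ClusteringConfig(cluster_models=[ClusterModel.KMEANS, ClusterModel.GMM], n_clusters=3)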

segmentae/clusters/registry.py
ADDED

@@ -0,0 +1,75 @@
+from typing import Dict, List, Type
+
+from segmentae.core.base import AbstractClusterModel
+from segmentae.core.constants import ClusterModel
+from segmentae.core.exceptions import ConfigurationError
+
+
+class ClusterRegistry:
+    """
+    Registry for clustering model classes.
+
+    This class implements the Registry pattern for managing clustering
+    algorithm implementations. Models are registered at module load time
+    and can be instantiated dynamically by type.
+    """
+
+    _models: Dict[ClusterModel, Type[AbstractClusterModel]] = {}
+
+    @classmethod
+    def register(cls,
+                 model_type: ClusterModel,
+                 model_class: Type[AbstractClusterModel]) -> None:
+        """
+        Register a clustering model class.
+        """
+        if model_type in cls._models:
+            raise ConfigurationError(
+                f"Cluster model '{model_type.value}' is already registered"
+            )
+
+        # Validate that model_class implements AbstractClusterModel
+        if not issubclass(model_class, AbstractClusterModel):
+            raise ConfigurationError(
+                f"Model class must inherit from AbstractClusterModel"
+            )
+
+        cls._models[model_type] = model_class
+
+    @classmethod
+    def create(cls,
+               model_type: ClusterModel,
+               **kwargs) -> AbstractClusterModel:
+        """
+        Create a clustering model instance.
+        """
+        if model_type not in cls._models:
+            available = [m.value for m in cls.list_available()]
+            raise ConfigurationError(
+                f"Unknown cluster model: '{model_type.value}'",
+                valid_options=available
+            )
+
+        model_class = cls._models[model_type]
+        return model_class(**kwargs)
+
+    @classmethod
+    def list_available(cls) -> List[ClusterModel]:
+        """
+        List all registered clustering models.
+        """
+        return list(cls._models.keys())
+
+    @classmethod
+    def is_registered(cls, model_type: ClusterModel) -> bool:
+        """
+        Check if a model type is registered.
+        """
+        return model_type in cls._models
+
+    @classmethod
+    def clear(cls) -> None:
+        """
+        Clear all registered models.
+        """
+        cls._models.clear()
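
A short, assumed illustration of the registry in use: importing segmentae.clusters.models is what executes the ClusterRegistry.register(...) calls at the bottom of that module, after which models can be listed and instantiated by enum member.

import segmentae.clusters.models  # noqa: F401 -- triggers the auto-registration shown above
from segmentae.clusters.registry import ClusterRegistry
from segmentae.core.constants import ClusterModel

print(ClusterRegistry.list_available())                    # registered ClusterModel members
print(ClusterRegistry.is_registered(ClusterModel.KMEANS))  # True once models.py has been imported

kmeans = ClusterRegistry.create(ClusterModel.KMEANS, n_clusters=4, random_state=42)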

segmentae/core/__init__.py
ADDED

@@ -0,0 +1,65 @@
+from segmentae.core.base import AbstractClusterModel, AbstractPreprocessor
+from segmentae.core.constants import (
+    METRIC_COLUMN_MAP,
+    ClusterModel,
+    EncoderType,
+    ImputerType,
+    PhaseType,
+    ScalerType,
+    ThresholdMetric,
+    get_metric_column_name,
+    parse_threshold_metric,
+)
+from segmentae.core.exceptions import (
+    AutoencoderError,
+    ClusteringError,
+    ConfigurationError,
+    ModelNotFittedError,
+    ReconstructionError,
+    SegmentAEError,
+    ValidationError,
+)
+from segmentae.core.types import (
+    AutoencoderProtocol,
+    ClusterModelProtocol,
+    DataFrame,
+    DictStrAny,
+    NDArray,
+    PreprocessorProtocol,
+    Series,
+)
+
+__all__ = [
+    # Constants
+    'PhaseType',
+    'ClusterModel',
+    'ThresholdMetric',
+    'EncoderType',
+    'ScalerType',
+    'ImputerType',
+    'METRIC_COLUMN_MAP',
+    'get_metric_column_name',
+    'parse_threshold_metric',
+
+    # Exceptions
+    'SegmentAEError',
+    'ClusteringError',
+    'ReconstructionError',
+    'ValidationError',
+    'ModelNotFittedError',
+    'ConfigurationError',
+    'AutoencoderError',
+
+    # Base Classes
+    'AbstractClusterModel',
+    'AbstractPreprocessor',
+
+    # Types
+    'DataFrame',
+    'Series',
+    'NDArray',
+    'DictStrAny',
+    'AutoencoderProtocol',
+    'ClusterModelProtocol',
+    'PreprocessorProtocol'
+]
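
The practical effect of this module is a flatter import path; as an assumed example, the names collected in __all__ above can be imported directly from the subpackage.

from segmentae.core import AbstractClusterModel, ClusterModel, ValidationError
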
segmentae/core/base.py
ADDED

@@ -0,0 +1,108 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import pandas as pd
+
+from segmentae.core.exceptions import ModelNotFittedError, ValidationError
+
+
+class AbstractClusterModel(ABC):
+    """
+    Abstract base class for all clustering implementations.
+    """
+
+    @abstractmethod
+    def fit(self, X: pd.DataFrame) -> None:
+        """
+        Fit clustering model to data.
+        """
+        pass
+
+    @abstractmethod
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        """
+        Predict cluster assignments for data.
+        """
+        pass
+
+    @abstractmethod
+    def fit_predict(self, X: pd.DataFrame) -> np.ndarray:
+        """
+        Fit model and predict cluster assignments in one step.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def n_clusters(self) -> int:
+        """
+        Return the number of clusters.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def is_fitted(self) -> bool:
+        """
+        Check if model has been fitted.
+        """
+        pass
+
+    def _validate_input(self, X: pd.DataFrame, context: str = "Input") -> None:
+        """
+        Validate input DataFrame.
+        """
+        if not isinstance(X, pd.DataFrame):
+            raise ValidationError(
+                f"{context} must be a pandas DataFrame, got {type(X).__name__}",
+                suggestion="Convert your data to a pandas DataFrame using pd.DataFrame()"
+            )
+
+        if X.empty:
+            raise ValidationError(
+                f"{context} DataFrame is empty",
+                suggestion="Ensure your dataset contains data before fitting"
+            )
+
+        if X.isnull().any().any():
+            raise ValidationError(
+                f"{context} contains missing values",
+                suggestion="Handle missing values using preprocessing before clustering"
+            )
+
+    def _validate_fitted(self) -> None:
+        """
+        Check if model is fitted, raise error if not.
+        """
+        if not self.is_fitted:
+            raise ModelNotFittedError(
+                component=self.__class__.__name__,
+                message=f"{self.__class__.__name__} must be fitted before prediction. "
+                        f"Call fit(X) method first."
+            )
+
+
+class AbstractPreprocessor(ABC):
+    """
+    Abstract base class for preprocessing components.
+    """
+
+    @abstractmethod
+    def fit(self, X: pd.DataFrame) -> 'AbstractPreprocessor':
+        """
+        Fit preprocessor to data.
+        """
+        pass
+
+    @abstractmethod
+    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
+        """
+        Transform data using fitted preprocessor.
+        """
+        pass
+
+    def fit_transform(self, X: pd.DataFrame) -> pd.DataFrame:
+        """
+        Fit and transform data in one step.
+        """
+        return self.fit(X).transform(X)