teradataml 20.0.0.6__py3-none-any.whl → 20.0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of teradataml might be problematic. Click here for more details.
- teradataml/README.md +210 -0
- teradataml/__init__.py +1 -1
- teradataml/_version.py +1 -1
- teradataml/analytics/analytic_function_executor.py +162 -76
- teradataml/analytics/byom/__init__.py +1 -1
- teradataml/analytics/json_parser/__init__.py +2 -0
- teradataml/analytics/json_parser/analytic_functions_argument.py +95 -2
- teradataml/analytics/json_parser/metadata.py +22 -4
- teradataml/analytics/sqle/DecisionTreePredict.py +3 -2
- teradataml/analytics/sqle/NaiveBayesPredict.py +3 -2
- teradataml/analytics/sqle/__init__.py +3 -0
- teradataml/analytics/utils.py +4 -1
- teradataml/automl/__init__.py +2369 -464
- teradataml/automl/autodataprep/__init__.py +15 -0
- teradataml/automl/custom_json_utils.py +184 -112
- teradataml/automl/data_preparation.py +113 -58
- teradataml/automl/data_transformation.py +154 -53
- teradataml/automl/feature_engineering.py +113 -53
- teradataml/automl/feature_exploration.py +548 -25
- teradataml/automl/model_evaluation.py +260 -32
- teradataml/automl/model_training.py +399 -206
- teradataml/clients/auth_client.py +2 -2
- teradataml/common/aed_utils.py +11 -2
- teradataml/common/bulk_exposed_utils.py +4 -2
- teradataml/common/constants.py +62 -2
- teradataml/common/garbagecollector.py +50 -21
- teradataml/common/messagecodes.py +47 -2
- teradataml/common/messages.py +19 -1
- teradataml/common/sqlbundle.py +23 -6
- teradataml/common/utils.py +116 -10
- teradataml/context/aed_context.py +16 -10
- teradataml/data/Employee.csv +5 -0
- teradataml/data/Employee_Address.csv +4 -0
- teradataml/data/Employee_roles.csv +5 -0
- teradataml/data/JulesBelvezeDummyData.csv +100 -0
- teradataml/data/byom_example.json +5 -0
- teradataml/data/creditcard_data.csv +284618 -0
- teradataml/data/docs/byom/docs/ONNXSeq2Seq.py +255 -0
- teradataml/data/docs/sqle/docs_17_10/NGramSplitter.py +1 -1
- teradataml/data/docs/sqle/docs_17_20/NGramSplitter.py +1 -1
- teradataml/data/docs/sqle/docs_17_20/TextParser.py +1 -1
- teradataml/data/jsons/byom/ONNXSeq2Seq.json +287 -0
- teradataml/data/jsons/sqle/20.00/AI_AnalyzeSentiment.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_AskLLM.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_DetectLanguage.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_ExtractKeyPhrases.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_MaskPII.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_RecognizeEntities.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_RecognizePIIEntities.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_TextClassifier.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_TextEmbeddings.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_TextSummarize.json +3 -7
- teradataml/data/jsons/sqle/20.00/AI_TextTranslate.json +3 -7
- teradataml/data/jsons/sqle/20.00/TD_API_AzureML.json +151 -0
- teradataml/data/jsons/sqle/20.00/TD_API_Sagemaker.json +182 -0
- teradataml/data/jsons/sqle/20.00/TD_API_VertexAI.json +183 -0
- teradataml/data/load_example_data.py +29 -11
- teradataml/data/payment_fraud_dataset.csv +10001 -0
- teradataml/data/teradataml_example.json +67 -0
- teradataml/dataframe/copy_to.py +714 -54
- teradataml/dataframe/dataframe.py +1153 -33
- teradataml/dataframe/dataframe_utils.py +8 -3
- teradataml/dataframe/functions.py +168 -1
- teradataml/dataframe/setop.py +4 -1
- teradataml/dataframe/sql.py +141 -9
- teradataml/dbutils/dbutils.py +470 -35
- teradataml/dbutils/filemgr.py +1 -1
- teradataml/hyperparameter_tuner/optimizer.py +456 -142
- teradataml/lib/aed_0_1.dll +0 -0
- teradataml/lib/libaed_0_1.dylib +0 -0
- teradataml/lib/libaed_0_1.so +0 -0
- teradataml/lib/libaed_0_1_aarch64.so +0 -0
- teradataml/scriptmgmt/UserEnv.py +234 -34
- teradataml/scriptmgmt/lls_utils.py +43 -17
- teradataml/sdk/_json_parser.py +1 -1
- teradataml/sdk/api_client.py +9 -6
- teradataml/sdk/modelops/_client.py +3 -0
- teradataml/series/series.py +12 -7
- teradataml/store/feature_store/constants.py +601 -234
- teradataml/store/feature_store/feature_store.py +2886 -616
- teradataml/store/feature_store/mind_map.py +639 -0
- teradataml/store/feature_store/models.py +5831 -214
- teradataml/store/feature_store/utils.py +390 -0
- teradataml/table_operators/table_operator_util.py +1 -1
- teradataml/table_operators/templates/dataframe_register.template +6 -2
- teradataml/table_operators/templates/dataframe_udf.template +6 -2
- teradataml/utils/docstring.py +527 -0
- teradataml/utils/dtypes.py +93 -0
- teradataml/utils/internal_buffer.py +2 -2
- teradataml/utils/utils.py +41 -2
- teradataml/utils/validators.py +694 -17
- {teradataml-20.0.0.6.dist-info → teradataml-20.0.0.7.dist-info}/METADATA +213 -2
- {teradataml-20.0.0.6.dist-info → teradataml-20.0.0.7.dist-info}/RECORD +96 -81
- {teradataml-20.0.0.6.dist-info → teradataml-20.0.0.7.dist-info}/WHEEL +0 -0
- {teradataml-20.0.0.6.dist-info → teradataml-20.0.0.7.dist-info}/top_level.txt +0 -0
- {teradataml-20.0.0.6.dist-info → teradataml-20.0.0.7.dist-info}/zip-safe +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# ##################################################################
|
|
2
2
|
#
|
|
3
|
-
# Copyright
|
|
3
|
+
# Copyright 2025 Teradata. All rights reserved.
|
|
4
4
|
# TERADATA CONFIDENTIAL AND TRADE SECRET
|
|
5
5
|
#
|
|
6
6
|
# Primary Owner: Sweta Shaw
|
|
@@ -20,6 +20,8 @@ import ast
|
|
|
20
20
|
# Teradata libraries
|
|
21
21
|
from teradataml.dataframe.dataframe import DataFrame
|
|
22
22
|
from teradataml.automl.model_training import _ModelTraining
|
|
23
|
+
from teradataml.automl.feature_exploration import _FeatureExplore
|
|
24
|
+
from teradataml import Shap
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
class _ModelEvaluator:
|
|
@@ -27,7 +29,8 @@ class _ModelEvaluator:
|
|
|
27
29
|
def __init__(self,
|
|
28
30
|
df=None,
|
|
29
31
|
target_column=None,
|
|
30
|
-
task_type=None
|
|
32
|
+
task_type=None,
|
|
33
|
+
cluster=False):
|
|
31
34
|
"""
|
|
32
35
|
DESCRIPTION:
|
|
33
36
|
Function initializes the data, target column, features and models
|
|
@@ -52,17 +55,26 @@ class _ModelEvaluator:
|
|
|
52
55
|
Permitted Values: "Regression", "Classification"
|
|
53
56
|
Types: str
|
|
54
57
|
|
|
58
|
+
cluster:
|
|
59
|
+
Optional Argument.
|
|
60
|
+
Specifies whether to apply clustering techniques.
|
|
61
|
+
Default Value: False
|
|
62
|
+
Types: bool
|
|
63
|
+
|
|
55
64
|
"""
|
|
56
65
|
self.model_info = df
|
|
57
66
|
self.target_column = target_column
|
|
58
67
|
self.task_type = task_type
|
|
59
|
-
|
|
68
|
+
self.cluster = cluster
|
|
69
|
+
self.shap_results = None
|
|
70
|
+
|
|
60
71
|
def model_evaluation(self,
|
|
61
72
|
rank,
|
|
62
73
|
table_name_mapping,
|
|
63
74
|
data_node_id,
|
|
64
|
-
target_column_ind
|
|
65
|
-
get_metrics
|
|
75
|
+
target_column_ind=True,
|
|
76
|
+
get_metrics=False,
|
|
77
|
+
is_predict=False):
|
|
66
78
|
"""
|
|
67
79
|
DESCRIPTION:
|
|
68
80
|
Function performs the model evaluation on the specified rank in leaderboard.
|
|
@@ -94,7 +106,12 @@ class _ModelEvaluator:
|
|
|
94
106
|
Specifies whether to return metrics or not.
|
|
95
107
|
Default Value: False
|
|
96
108
|
Types: bool
|
|
97
|
-
|
|
109
|
+
|
|
110
|
+
is_predict:
|
|
111
|
+
Optional Argument.
|
|
112
|
+
Specifies whether predict is called or evaluate is called.
|
|
113
|
+
Default Value: False
|
|
114
|
+
Types: bool
|
|
98
115
|
RETURNS:
|
|
99
116
|
tuple containing performance metrics and predictions of the specified rank ML model.
|
|
100
117
|
|
|
@@ -105,8 +122,25 @@ class _ModelEvaluator:
|
|
|
105
122
|
self.data_node_id = data_node_id
|
|
106
123
|
self.get_metrics = get_metrics
|
|
107
124
|
|
|
108
|
-
#
|
|
109
|
-
|
|
125
|
+
# Perform evaluation
|
|
126
|
+
if self.cluster:
|
|
127
|
+
evaluation_results, test_data = self._evaluator(rank)
|
|
128
|
+
else:
|
|
129
|
+
evaluation_results = self._evaluator(rank)
|
|
130
|
+
|
|
131
|
+
# Apply SHAP if applicable
|
|
132
|
+
if is_predict:
|
|
133
|
+
if not self.cluster:
|
|
134
|
+
model_id = self.model_info.loc[rank]['MODEL_ID'].split('_')[0]
|
|
135
|
+
permitted_models = ["XGBOOST", "DECISIONFOREST"]
|
|
136
|
+
if model_id.upper() in permitted_models:
|
|
137
|
+
print("\nApplying SHAP for Model Interpretation...")
|
|
138
|
+
self._apply_shap(rank, isload=False)
|
|
139
|
+
else:
|
|
140
|
+
print(f"\nSHAP is not applied for {model_id}. Only permitted models: {permitted_models}")
|
|
141
|
+
else:
|
|
142
|
+
self._visualize_cluster(test_data)
|
|
143
|
+
return evaluation_results
|
|
110
144
|
|
|
111
145
|
def _evaluator(self,
|
|
112
146
|
rank):
|
|
@@ -130,31 +164,225 @@ class _ModelEvaluator:
|
|
|
130
164
|
|
|
131
165
|
ml_name = self.model_info.loc[rank]['MODEL_ID'].split('_')[0]
|
|
132
166
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
167
|
+
if not self.cluster:
|
|
168
|
+
# Defining eval_params
|
|
169
|
+
eval_params = _ModelTraining._eval_params_generation(ml_name,
|
|
170
|
+
self.target_column,
|
|
171
|
+
self.task_type)
|
|
172
|
+
|
|
173
|
+
# Extracting test data for evaluation based on data node id
|
|
174
|
+
test = DataFrame(self.table_name_mapping[self.data_node_id]['{}_test'.format(model['FEATURE_SELECTION'])])
|
|
175
|
+
|
|
176
|
+
print("\nFollowing model is being picked for evaluation:")
|
|
177
|
+
print("Model ID :", model['MODEL_ID'],
|
|
178
|
+
"\nFeature Selection Method :",model['FEATURE_SELECTION'])
|
|
179
|
+
|
|
180
|
+
if self.task_type.lower() == 'classification':
|
|
181
|
+
params = ast.literal_eval(model['PARAMETERS'])
|
|
182
|
+
eval_params['output_responses'] = params['output_responses']
|
|
183
|
+
|
|
184
|
+
# Mapping data according to model type
|
|
185
|
+
data_map = 'test_data' if ml_name == 'KNN' else 'newdata'
|
|
186
|
+
# Performing evaluation if get_metrics is True else returning predictions
|
|
187
|
+
if self.get_metrics:
|
|
188
|
+
metrics = model['model-obj'].evaluate(**{data_map: test}, **eval_params)
|
|
189
|
+
return metrics
|
|
190
|
+
else:
|
|
191
|
+
# Removing accumulate parameter if target column is not present
|
|
192
|
+
if not self.target_column_ind:
|
|
193
|
+
eval_params.pop("accumulate")
|
|
194
|
+
pred = model['model-obj'].predict(**{data_map: test}, **eval_params)
|
|
195
|
+
return pred
|
|
196
|
+
else:
|
|
197
|
+
print("\nFollowing model is being picked for evaluation of clustering:")
|
|
198
|
+
print("Model ID :", model['MODEL_ID'],
|
|
199
|
+
"\nFeature Selection Method :",model['FEATURE_SELECTION'])
|
|
200
|
+
feature_type = model["FEATURE_SELECTION"]
|
|
201
|
+
test_table_key = f"{feature_type}_test"
|
|
202
|
+
|
|
203
|
+
if test_table_key not in self.table_name_mapping[self.data_node_id]:
|
|
204
|
+
raise KeyError(f"Table key '{test_table_key}' not found in table_name_mapping. Available keys: {self.table_name_mapping[self.data_node_id].keys()}")
|
|
205
|
+
|
|
206
|
+
test_data = DataFrame(self.table_name_mapping[self.data_node_id][test_table_key])
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
if self.get_metrics:
|
|
210
|
+
from teradataml import td_sklearn as skl
|
|
211
|
+
|
|
212
|
+
X = test_data
|
|
213
|
+
result = model["model-obj"].predict(X)
|
|
214
|
+
silhouette = skl.silhouette_score(X=result.select(X.columns), labels=result.select(["gridsearchcv_predict_1"]))
|
|
215
|
+
calinski = skl.calinski_harabasz_score(X=result.select(X.columns), labels=result.select(["gridsearchcv_predict_1"]))
|
|
216
|
+
davies = skl.davies_bouldin_score(X=result.select(X.columns), labels=result.select(["gridsearchcv_predict_1"]))
|
|
217
|
+
|
|
218
|
+
return {
|
|
219
|
+
"SILHOUETTE": silhouette,
|
|
220
|
+
"CALINSKI": calinski,
|
|
221
|
+
"DAVIES": davies
|
|
222
|
+
}, test_data
|
|
223
|
+
else:
|
|
224
|
+
return model["model-obj"].predict(test_data),test_data
|
|
225
|
+
|
|
226
|
+
def _apply_shap(self, rank, isload):
|
|
227
|
+
"""
|
|
228
|
+
DESCRIPTION:
|
|
229
|
+
Applies SHAP analysis to explain model predictions after evaluation.
|
|
230
|
+
|
|
231
|
+
PARAMETERS:
|
|
232
|
+
rank:
|
|
233
|
+
Required Argument.
|
|
234
|
+
Specifies the position (rank) of the ML model for evaluation.
|
|
235
|
+
Types: int
|
|
236
|
+
|
|
237
|
+
isload:
|
|
238
|
+
Required Argument.
|
|
239
|
+
Specifies whether the function is being called by load or not.
|
|
240
|
+
Types: bool
|
|
241
|
+
"""
|
|
137
242
|
|
|
138
|
-
|
|
139
|
-
|
|
243
|
+
test_data = DataFrame(self.table_name_mapping[self.data_node_id]['{}_test'.format(self.model_info.loc[rank]['FEATURE_SELECTION'])])
|
|
244
|
+
id_column = "id"
|
|
245
|
+
input_columns = [col for col in test_data.columns if col != self.target_column and col != id_column]
|
|
140
246
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
247
|
+
if isload:
|
|
248
|
+
result_table_name = self.model_info.loc[rank, 'RESULT_TABLE']
|
|
249
|
+
model_object = DataFrame(result_table_name)
|
|
250
|
+
else:
|
|
251
|
+
model_obj = self.model_info.loc[rank]['model-obj']
|
|
252
|
+
model_object = model_obj.result
|
|
144
253
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
254
|
+
# Extract model training function from MODEL_ID and format it correctly
|
|
255
|
+
raw_model_id = self.model_info.loc[rank]['MODEL_ID'].split('_')[0] # Extract base model name
|
|
256
|
+
formatted_training_function = "TD_" + raw_model_id # Add TD_ prefix
|
|
257
|
+
# There is currently an issue with the default value of model_type: it is not case-insensitive.
|
|
258
|
+
# Hence, task_type is converted to lower case.
|
|
259
|
+
shap_output = Shap(
|
|
260
|
+
data=test_data,
|
|
261
|
+
object=model_object,
|
|
262
|
+
id_column='id',
|
|
263
|
+
training_function=formatted_training_function,
|
|
264
|
+
model_type=self.task_type.lower(),
|
|
265
|
+
input_columns=input_columns,
|
|
266
|
+
detailed=True
|
|
267
|
+
)
|
|
148
268
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
269
|
+
self.shap_results = shap_output.output_data
|
|
270
|
+
print("\nSHAP Analysis Completed. Feature Importance Available.")
|
|
271
|
+
|
|
272
|
+
# Extract SHAP values for visualization
|
|
273
|
+
df = self.shap_results
|
|
274
|
+
data = next(df.itertuples())._asdict()
|
|
275
|
+
|
|
276
|
+
import matplotlib.pyplot as plt
|
|
277
|
+
|
|
278
|
+
# Extract keys and values
|
|
279
|
+
keys = list(data.keys())
|
|
280
|
+
values = list(data.values())
|
|
281
|
+
|
|
282
|
+
# Plot SHAP values as a bar graph
|
|
283
|
+
plt.figure(figsize=(10, 6))
|
|
284
|
+
bars = plt.bar(keys, values, color='skyblue', edgecolor='black')
|
|
285
|
+
for bar in bars:
|
|
286
|
+
yval = bar.get_height()
|
|
287
|
+
plt.text(bar.get_x() + bar.get_width()/2, yval, f'{yval:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
|
|
288
|
+
|
|
289
|
+
plt.xticks(rotation=45, ha='right')
|
|
290
|
+
plt.title('Feature Importance (SHAP Values)', fontsize=14)
|
|
291
|
+
plt.xlabel('Features', fontsize=12)
|
|
292
|
+
plt.ylabel('SHAP Value', fontsize=12)
|
|
293
|
+
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
|
294
|
+
plt.tight_layout()
|
|
295
|
+
plt.show()
|
|
296
|
+
|
|
297
|
+
def _visualize_cluster(self, test_data):
|
|
298
|
+
print("\nVisualizing Clusters for interpretability...")
|
|
299
|
+
|
|
300
|
+
df = test_data.to_pandas()
|
|
301
|
+
print(df.head())
|
|
302
|
+
from sklearn.cluster import KMeans
|
|
303
|
+
import numpy as np
|
|
304
|
+
import matplotlib.pyplot as plt
|
|
305
|
+
|
|
306
|
+
# Automatically pick top 2 high variance numeric features
|
|
307
|
+
numerical_features = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
|
|
308
|
+
if 'id' in numerical_features:
|
|
309
|
+
numerical_features.remove('id')
|
|
310
|
+
|
|
311
|
+
if len(numerical_features) < 2:
|
|
312
|
+
print("Not enough numeric features available for scatter plot.")
|
|
313
|
+
return
|
|
314
|
+
|
|
315
|
+
# Compute correlation matrix
|
|
316
|
+
corr_matrix = df[numerical_features].corr()
|
|
317
|
+
|
|
318
|
+
# Extract upper triangle without diagonal
|
|
319
|
+
mask = np.triu(np.ones_like(corr_matrix, dtype=bool), k=1)
|
|
320
|
+
corr_vals = corr_matrix.where(mask).stack().reset_index()
|
|
321
|
+
corr_vals.columns = ['Feature1', 'Feature2', 'Correlation']
|
|
322
|
+
corr_vals['Abs_Correlation'] = corr_vals['Correlation'].abs()
|
|
323
|
+
|
|
324
|
+
# Sort and select top pair
|
|
325
|
+
corr_vals = corr_vals.sort_values(by='Abs_Correlation', ascending=False)
|
|
326
|
+
filtered = corr_vals[corr_vals['Abs_Correlation'] > 0.1].head(1)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
variances = df[numerical_features].var().sort_values(ascending=False)
|
|
330
|
+
top_features = variances.index[:2].tolist()
|
|
331
|
+
print("Selection Criteria: Top 2 High Variance Features")
|
|
332
|
+
print(f"Selected Features: {top_features[0]}, {top_features[1]}")
|
|
333
|
+
X = df[top_features].values
|
|
334
|
+
|
|
335
|
+
kmeans = KMeans(n_clusters=4, init='k-means++', n_init=10, max_iter=300,
|
|
336
|
+
tol=0.0001, random_state=111, algorithm='elkan')
|
|
337
|
+
kmeans.fit(X)
|
|
338
|
+
|
|
339
|
+
import matplotlib.pyplot as plt
|
|
340
|
+
import matplotlib.patches as mpatches
|
|
341
|
+
import numpy as np
|
|
342
|
+
from matplotlib.colors import ListedColormap
|
|
343
|
+
|
|
344
|
+
# Define a fixed color map
|
|
345
|
+
cmap = ListedColormap(plt.cm.Pastel2.colors)
|
|
346
|
+
n_clusters = len(np.unique(kmeans.labels_))
|
|
347
|
+
colors = cmap.colors[:n_clusters]
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
# Plot decision regions
|
|
351
|
+
h = 0.02
|
|
352
|
+
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
|
|
353
|
+
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
|
|
354
|
+
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
|
|
355
|
+
np.arange(y_min, y_max, h))
|
|
356
|
+
Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])
|
|
357
|
+
Z = Z.reshape(xx.shape)
|
|
358
|
+
|
|
359
|
+
plt.figure(figsize=(14, 7))
|
|
360
|
+
plt.imshow(Z, interpolation='nearest',
|
|
361
|
+
extent=(xx.min(), xx.max(), yy.min(), yy.max()),
|
|
362
|
+
cmap=ListedColormap(colors), aspect='auto', origin='lower', zorder=1)
|
|
363
|
+
|
|
364
|
+
# Plot actual clustered data points (zorder > 1)
|
|
365
|
+
cluster_colors = [colors[label] for label in kmeans.labels_]
|
|
366
|
+
plt.scatter(X[:, 0], X[:, 1], c=cluster_colors, s=100, edgecolor='k', alpha=0.85, zorder=2)
|
|
367
|
+
|
|
368
|
+
# Plot red centroids
|
|
369
|
+
centroids = kmeans.cluster_centers_
|
|
370
|
+
plt.scatter(centroids[:, 0], centroids[:, 1],
|
|
371
|
+
s=300, c='red', alpha=0.7, zorder=3)
|
|
372
|
+
|
|
373
|
+
# Annotate centroids
|
|
374
|
+
for i, (x, y) in enumerate(centroids):
|
|
375
|
+
"""plt.text(x, y + 0.05, f'Cluster {i}', fontsize=11, weight='bold',
|
|
376
|
+
ha='center', va='bottom', zorder=4)"""
|
|
377
|
+
plt.text(x, y - 0.05, f'({x:.2f}, {y:.2f})', fontsize=9,
|
|
378
|
+
ha='center', va='top', bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'), zorder=4)
|
|
379
|
+
|
|
380
|
+
# Legend (manually matched)
|
|
381
|
+
legend_handles = [mpatches.Patch(color=colors[i], label=f'Cluster {i}') for i in range(n_clusters)]
|
|
382
|
+
plt.legend(handles=legend_handles, title="Cluster ID", loc='upper right')
|
|
383
|
+
|
|
384
|
+
# Axis labels and title
|
|
385
|
+
plt.xlabel(top_features[0])
|
|
386
|
+
plt.ylabel(top_features[1])
|
|
387
|
+
plt.title("Cluster Visualization on Test Data")
|
|
388
|
+
plt.show()
|