likelihood 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
likelihood/graph/nn.py CHANGED
@@ -61,6 +61,8 @@ def cal_adjacency_matrix(
61
61
  ----------
62
62
  similarity: `int`
63
63
  The minimum number of features that must be the same in both arrays to be considered similar.
64
+ threshold : `float`
65
+ The threshold value used in the `compare_similarity` function. Default is 0.05.
64
66
 
65
67
  Returns
66
68
  -------
@@ -79,6 +81,7 @@ def cal_adjacency_matrix(
79
81
  assert len(df_) > 0
80
82
 
81
83
  similarity = kwargs.get("similarity", len(df_.columns) - 1)
84
+ threshold = kwargs.get("threshold", 0.05)
82
85
  assert similarity <= df_.shape[1]
83
86
 
84
87
  adj_dict = {index: row.tolist() for index, row in df_.iterrows()}
@@ -87,7 +90,7 @@ def cal_adjacency_matrix(
87
90
 
88
91
  for i in range(len(df_)):
89
92
  for j in range(len(df_)):
90
- if compare_similarity(adj_dict[i], adj_dict[j]) >= similarity:
93
+ if compare_similarity(adj_dict[i], adj_dict[j], threshold=threshold) >= similarity:
91
94
  adjacency_matrix[i][j] = 1
92
95
 
93
96
  if sparse:
@@ -114,7 +117,10 @@ class Data:
114
117
  **kwargs,
115
118
  ):
116
119
  sparse = kwargs.get("sparse", True)
117
- _, adjacency = cal_adjacency_matrix(df, exclude_subset=exclude_subset, sparse=sparse)
120
+ threshold = kwargs.get("threshold", 0.05)
121
+ _, adjacency = cal_adjacency_matrix(
122
+ df, exclude_subset=exclude_subset, sparse=sparse, threshold=threshold
123
+ )
118
124
  if target is not None:
119
125
  X = df.drop(columns=[target] + exclude_subset)
120
126
  else:
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
  import os
3
3
  import random
4
+ import warnings
4
5
  from functools import partial
5
6
  from shutil import rmtree
6
7
 
@@ -14,8 +15,8 @@ from pandas.plotting import radviz
14
15
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
15
16
  logging.getLogger("tensorflow").setLevel(logging.ERROR)
16
17
 
17
- import warnings
18
- from functools import wraps
18
+
19
+ from typing import List
19
20
 
20
21
  import keras_tuner
21
22
  import tensorflow as tf
@@ -24,21 +25,11 @@ from sklearn.manifold import TSNE
24
25
  from tensorflow.keras.layers import InputLayer
25
26
  from tensorflow.keras.regularizers import l2
26
27
 
27
- from likelihood.tools import LoRALayer, OneHotEncoder
28
+ from likelihood.tools import LoRALayer, OneHotEncoder, suppress_warnings
28
29
 
29
30
  tf.get_logger().setLevel("ERROR")
30
31
 
31
32
 
32
- def suppress_warnings(func):
33
- @wraps(func)
34
- def wrapper(*args, **kwargs):
35
- with warnings.catch_warnings():
36
- warnings.simplefilter("ignore")
37
- return func(*args, **kwargs)
38
-
39
- return wrapper
40
-
41
-
42
33
  class EarlyStopping:
43
34
  def __init__(self, patience=10, min_delta=0.001):
44
35
  self.patience = patience
@@ -246,7 +237,7 @@ class AutoClassifier(tf.keras.Model):
246
237
  Additional keyword arguments to pass to the model.
247
238
 
248
239
  classifier_activation : `str`
249
- The activation function to use for the classifier layer. Default is "softmax". If the activation function is not a classification function, the model can be used in regression problems.
240
+ The activation function to use for the classifier layer. Default is `softmax`. If the activation function is not a classification function, the model can be used in regression problems.
250
241
  num_layers : `int`
251
242
  The number of hidden layers in the classifier. Default is 1.
252
243
  dropout : `float`
@@ -373,7 +364,6 @@ class AutoClassifier(tf.keras.Model):
373
364
  else:
374
365
  self.build_encoder_decoder(input_shape)
375
366
 
376
- # Classifier with L2 regularization
377
367
  self.classifier = tf.keras.Sequential()
378
368
  if self.num_layers > 1 and not self.lora_mode:
379
369
  for _ in range(self.num_layers - 1):
@@ -527,7 +517,6 @@ class AutoClassifier(tf.keras.Model):
527
517
  if not isinstance(source_model, AutoClassifier):
528
518
  raise ValueError("Source model must be an instance of AutoClassifier.")
529
519
 
530
- # Check compatibility in input shape and units
531
520
  if self.input_shape_parm != source_model.input_shape_parm:
532
521
  raise ValueError(
533
522
  f"Incompatible input shape. Expected {self.input_shape_parm}, got {source_model.input_shape_parm}."
@@ -537,9 +526,8 @@ class AutoClassifier(tf.keras.Model):
537
526
  f"Incompatible number of units. Expected {self.units}, got {source_model.units}."
538
527
  )
539
528
  self.encoder, self.decoder = tf.keras.Sequential(), tf.keras.Sequential()
540
- # Copy the encoder layers
541
529
  for i, layer in enumerate(source_model.encoder.layers):
542
- if isinstance(layer, tf.keras.layers.Dense): # Make sure it's a Dense layer
530
+ if isinstance(layer, tf.keras.layers.Dense):
543
531
  dummy_input = tf.convert_to_tensor(tf.random.normal([1, layer.input_shape[1]]))
544
532
  dense_layer = tf.keras.layers.Dense(
545
533
  units=layer.units,
@@ -548,14 +536,12 @@ class AutoClassifier(tf.keras.Model):
548
536
  )
549
537
  dense_layer.build(dummy_input.shape)
550
538
  self.encoder.add(dense_layer)
551
- # Set the weights correctly
552
539
  self.encoder.layers[i].set_weights(layer.get_weights())
553
540
  elif not isinstance(layer, InputLayer):
554
541
  raise ValueError(f"Layer type {type(layer)} not supported for copying.")
555
542
 
556
- # Copy the decoder layers
557
543
  for i, layer in enumerate(source_model.decoder.layers):
558
- if isinstance(layer, tf.keras.layers.Dense): # Ensure it's a Dense layer
544
+ if isinstance(layer, tf.keras.layers.Dense):
559
545
  dummy_input = tf.convert_to_tensor(tf.random.normal([1, layer.input_shape[1]]))
560
546
  dense_layer = tf.keras.layers.Dense(
561
547
  units=layer.units,
@@ -564,7 +550,6 @@ class AutoClassifier(tf.keras.Model):
564
550
  )
565
551
  dense_layer.build(dummy_input.shape)
566
552
  self.decoder.add(dense_layer)
567
- # Set the weights correctly
568
553
  self.decoder.layers[i].set_weights(layer.get_weights())
569
554
  elif not isinstance(layer, InputLayer):
570
555
  raise ValueError(f"Layer type {type(layer)} not supported for copying.")
@@ -907,62 +892,220 @@ def setup_model(
907
892
 
908
893
 
909
894
  class GetInsights:
895
+ """
896
+ A class to analyze the output of a neural network model, including visualizations
897
+ of the weights, t-SNE representation, and feature statistics.
898
+
899
+ Parameters
900
+ ----------
901
+ model : `AutoClassifier`
902
+ The trained model to analyze.
903
+ inputs : `np.ndarray`
904
+ The input data for analysis.
905
+ """
906
+
910
907
  def __init__(self, model: AutoClassifier, inputs: np.ndarray) -> None:
908
+ """
909
+ Initializes the GetInsights class.
910
+
911
+ Parameters
912
+ ----------
913
+ model : `AutoClassifier`
914
+ The trained model to analyze.
915
+ inputs : `np.ndarray`
916
+ The input data for analysis.
917
+ """
911
918
  self.inputs = inputs
912
919
  self.model = model
913
- if isinstance(self.model.encoder.layers[0], InputLayer):
914
- self.encoder_layer = self.model.encoder.layers[1]
915
- else:
916
- self.encoder_layer = self.model.encoder.layers[0]
920
+
921
+ self.encoder_layer = (
922
+ self.model.encoder.layers[1]
923
+ if isinstance(self.model.encoder.layers[0], InputLayer)
924
+ else self.model.encoder.layers[0]
925
+ )
917
926
  self.decoder_layer = self.model.decoder.layers[0]
927
+
918
928
  self.encoder_weights = self.encoder_layer.get_weights()[0]
919
929
  self.decoder_weights = self.decoder_layer.get_weights()[0]
920
- colors = dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS)
921
930
 
931
+ self.sorted_names = self._generate_sorted_color_names()
932
+
933
+ def _generate_sorted_color_names(self) -> list:
934
+ """
935
+ Generate sorted color names based on their HSV values.
936
+
937
+ Parameters
938
+ ----------
939
+ `None`
940
+
941
+ Returns
942
+ -------
943
+ `list` : Sorted color names.
944
+ """
945
+ colors = dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS)
922
946
  by_hsv = sorted(
923
947
  (tuple(mcolors.rgb_to_hsv(mcolors.to_rgba(color)[:3])), name)
924
948
  for name, color in colors.items()
925
949
  )
926
- self.sorted_names = [name for hsv, name in by_hsv if hsv[1] > 0.4 and hsv[2] >= 0.4]
927
- random.shuffle(self.sorted_names)
950
+ sorted_names = [name for hsv, name in by_hsv if hsv[1] > 0.4 and hsv[2] >= 0.4]
951
+ random.shuffle(sorted_names)
952
+ return sorted_names
928
953
 
929
954
  def predictor_analyzer(
930
955
  self,
931
- frac=None,
956
+ frac: float = None,
932
957
  cmap: str = "viridis",
933
958
  aspect: str = "auto",
934
959
  highlight: bool = True,
935
960
  **kwargs,
936
961
  ) -> None:
962
+ """
963
+ Analyze the model's predictions and visualize data.
964
+
965
+ Parameters
966
+ ----------
967
+ frac : `float`, optional
968
+ Fraction of data to use for analysis (default is `None`).
969
+ cmap : `str`, optional
970
+ The colormap for visualization (default is `"viridis"`).
971
+ aspect : `str`, optional
972
+ Aspect ratio for the visualization (default is `"auto"`).
973
+ highlight : `bool`, optional
974
+ Whether to highlight the maximum weights (default is `True`).
975
+ **kwargs : `dict`, optional
976
+ Additional keyword arguments for customization.
977
+
978
+ Returns
979
+ -------
980
+ `DataFrame` : The statistical summary of the input data.
981
+ """
937
982
  self._viz_weights(cmap=cmap, aspect=aspect, highlight=highlight, **kwargs)
938
983
  inputs = self.inputs.copy()
984
+ inputs = self._prepare_inputs(inputs, frac)
939
985
  y_labels = kwargs.get("y_labels", None)
986
+ encoded, reconstructed = self._encode_decode(inputs)
987
+ self._visualize_data(inputs, reconstructed, cmap, aspect)
988
+ self._prepare_data_for_analysis(inputs, reconstructed, encoded, y_labels)
989
+
990
+ try:
991
+ self._get_tsne_repr(inputs, frac)
992
+ self._viz_tsne_repr(c=self.classification)
993
+
994
+ self._viz_radviz(self.data, "class", "Radviz Visualization of Latent Space")
995
+ self._viz_radviz(self.data_input, "class", "Radviz Visualization of Input Data")
996
+ except ValueError:
997
+ warnings.warn(
998
+ "Some functions or processes will not be executed for regression problems.",
999
+ UserWarning,
1000
+ )
1001
+
1002
+ return self._statistics(self.data_input)
1003
+
1004
+ def _prepare_inputs(self, inputs: np.ndarray, frac: float) -> np.ndarray:
1005
+ """
1006
+ Prepare the input data, possibly selecting a fraction of it.
1007
+
1008
+ Parameters
1009
+ ----------
1010
+ inputs : `np.ndarray`
1011
+ The input data.
1012
+ frac : `float`
1013
+ Fraction of data to use.
1014
+
1015
+ Returns
1016
+ -------
1017
+ `np.ndarray` : The prepared input data.
1018
+ """
940
1019
  if frac:
941
1020
  n = int(frac * self.inputs.shape[0])
942
1021
  indexes = np.random.choice(np.arange(inputs.shape[0]), n, replace=False)
943
1022
  inputs = inputs[indexes]
944
1023
  inputs[np.isnan(inputs)] = 0.0
945
- # check if self.model.encoder(inputs) has two outputs
1024
+ return inputs
1025
+
1026
+ def _encode_decode(self, inputs: np.ndarray) -> tuple:
1027
+ """
1028
+ Perform encoding and decoding on the input data.
1029
+
1030
+ Parameters
1031
+ ----------
1032
+ inputs : `np.ndarray`
1033
+ The input data.
1034
+
1035
+ Returns
1036
+ -------
1037
+ `tuple` : The encoded and reconstructed data.
1038
+ """
946
1039
  try:
947
1040
  mean, log_var = self.model.encoder(inputs)
948
1041
  encoded = sampling(mean, log_var)
949
1042
  except:
950
1043
  encoded = self.model.encoder(inputs)
951
1044
  reconstructed = self.model.decoder(encoded)
952
- combined = tf.concat([reconstructed, encoded], axis=1)
953
- self.classification = self.model.classifier(combined).numpy().argmax(axis=1)
1045
+ return encoded, reconstructed
1046
+
1047
+ def _visualize_data(
1048
+ self, inputs: np.ndarray, reconstructed: np.ndarray, cmap: str, aspect: str
1049
+ ) -> None:
1050
+ """
1051
+ Visualize the original data and the reconstructed data.
1052
+
1053
+ Parameters
1054
+ ----------
1055
+ inputs : `np.ndarray`
1056
+ The input data.
1057
+ reconstructed : `np.ndarray`
1058
+ The reconstructed data.
1059
+ cmap : `str`
1060
+ The colormap for visualization.
1061
+ aspect : `str`
1062
+ Aspect ratio for the visualization.
1063
+
1064
+ Returns
1065
+ -------
1066
+ `None`
1067
+ """
954
1068
  ax = plt.subplot(1, 2, 1)
955
- plt.imshow(self.inputs, cmap=cmap, aspect=aspect)
1069
+ plt.imshow(inputs, cmap=cmap, aspect=aspect)
956
1070
  plt.colorbar()
957
1071
  plt.title("Original Data")
1072
+
958
1073
  plt.subplot(1, 2, 2, sharex=ax, sharey=ax)
959
1074
  plt.imshow(reconstructed, cmap=cmap, aspect=aspect)
960
1075
  plt.colorbar()
961
1076
  plt.title("Decoder Layer Reconstruction")
962
1077
  plt.show()
963
1078
 
964
- self._get_tsne_repr(inputs=inputs, frac=frac)
965
- self._viz_tsne_repr(c=self.classification)
1079
+ def _prepare_data_for_analysis(
1080
+ self,
1081
+ inputs: np.ndarray,
1082
+ reconstructed: np.ndarray,
1083
+ encoded: np.ndarray,
1084
+ y_labels: List[str],
1085
+ ) -> None:
1086
+ """
1087
+ Prepare data for statistical analysis.
1088
+
1089
+ Parameters
1090
+ ----------
1091
+ inputs : `np.ndarray`
1092
+ The input data.
1093
+ reconstructed : `np.ndarray`
1094
+ The reconstructed data.
1095
+ encoded : `np.ndarray`
1096
+ The encoded data.
1097
+ y_labels : `List[str]`
1098
+ The labels of features.
1099
+
1100
+ Returns
1101
+ -------
1102
+ `None`
1103
+ """
1104
+ self.classification = (
1105
+ self.model.classifier(tf.concat([reconstructed, encoded], axis=1))
1106
+ .numpy()
1107
+ .argmax(axis=1)
1108
+ )
966
1109
 
967
1110
  self.data = pd.DataFrame(encoded, columns=[f"Feature {i}" for i in range(encoded.shape[1])])
968
1111
  self.data_input = pd.DataFrame(
@@ -971,84 +1114,25 @@ class GetInsights:
971
1114
  [f"Feature {i}" for i in range(inputs.shape[1])] if y_labels is None else y_labels
972
1115
  ),
973
1116
  )
1117
+
974
1118
  self.data["class"] = self.classification
975
1119
  self.data_input["class"] = self.classification
976
1120
 
977
- self.data_normalized = self.data.copy(deep=True)
978
- self.data_normalized.iloc[:, :-1] = (
979
- 2.0
980
- * (self.data_normalized.iloc[:, :-1] - self.data_normalized.iloc[:, :-1].min())
981
- / (self.data_normalized.iloc[:, :-1].max() - self.data_normalized.iloc[:, :-1].min())
982
- - 1
983
- )
984
- radviz(self.data_normalized, "class", color=self.colors)
985
- plt.title("Radviz Visualization of Latent Space")
986
- plt.show()
987
- self.data_input_normalized = self.data_input.copy(deep=True)
988
- self.data_input_normalized.iloc[:, :-1] = (
989
- 2.0
990
- * (
991
- self.data_input_normalized.iloc[:, :-1]
992
- - self.data_input_normalized.iloc[:, :-1].min()
993
- )
994
- / (
995
- self.data_input_normalized.iloc[:, :-1].max()
996
- - self.data_input_normalized.iloc[:, :-1].min()
997
- )
998
- - 1
999
- )
1000
- radviz(self.data_input_normalized, "class", color=self.colors)
1001
- plt.title("Radviz Visualization of Input Data")
1002
- plt.show()
1003
- return self._statistics(self.data_input)
1004
-
1005
- def _statistics(self, data_input: DataFrame, **kwargs) -> DataFrame:
1006
- data = data_input.copy(deep=True)
1007
-
1008
- if not pd.api.types.is_string_dtype(data["class"]):
1009
- data["class"] = data["class"].astype(str)
1010
-
1011
- data.ffill(inplace=True)
1012
- grouped_data = data.groupby("class")
1013
-
1014
- numerical_stats = grouped_data.agg(["mean", "min", "max", "std", "median"])
1015
- numerical_stats.columns = ["_".join(col).strip() for col in numerical_stats.columns.values]
1016
-
1017
- def get_mode(x):
1018
- mode_series = x.mode()
1019
- return mode_series.iloc[0] if not mode_series.empty else None
1020
-
1021
- mode_stats = grouped_data.apply(get_mode, include_groups=False)
1022
- mode_stats.columns = [f"{col}_mode" for col in mode_stats.columns]
1023
- combined_stats = pd.concat([numerical_stats, mode_stats], axis=1)
1024
-
1025
- return combined_stats.T
1026
-
1027
- def _viz_weights(
1028
- self, cmap: str = "viridis", aspect: str = "auto", highlight: bool = True, **kwargs
1029
- ) -> None:
1030
- title = kwargs.get("title", "Encoder Layer Weights (Dense Layer)")
1031
- y_labels = kwargs.get("y_labels", None)
1032
- cmap_highlight = kwargs.get("cmap_highlight", "Pastel1")
1033
- highlight_mask = np.zeros_like(self.encoder_weights, dtype=bool)
1121
+ def _get_tsne_repr(self, inputs: np.ndarray = None, frac: float = None) -> None:
1122
+ """
1123
+ Perform t-SNE dimensionality reduction on the input data.
1034
1124
 
1035
- plt.imshow(self.encoder_weights, cmap=cmap, aspect=aspect)
1036
- plt.colorbar()
1037
- plt.title(title)
1038
- if y_labels is not None:
1039
- plt.yticks(ticks=np.arange(self.encoder_weights.shape[0]), labels=y_labels)
1040
- if highlight:
1041
- for i, j in enumerate(self.encoder_weights.argmax(axis=1)):
1042
- highlight_mask[i, j] = True
1043
- plt.imshow(
1044
- np.ma.masked_where(~highlight_mask, self.encoder_weights),
1045
- cmap=cmap_highlight,
1046
- alpha=0.5,
1047
- aspect=aspect,
1048
- )
1049
- plt.show()
1125
+ Parameters
1126
+ ----------
1127
+ inputs : `np.ndarray`
1128
+ The input data.
1129
+ frac : `float`
1130
+ Fraction of data to use.
1050
1131
 
1051
- def _get_tsne_repr(self, inputs=None, frac=None) -> None:
1132
+ Returns
1133
+ -------
1134
+ `None`
1135
+ """
1052
1136
  if inputs is None:
1053
1137
  inputs = self.inputs.copy()
1054
1138
  if frac:
@@ -1062,26 +1146,145 @@ class GetInsights:
1062
1146
  self.reduced_data_tsne = tsne.fit_transform(self.latent_representations)
1063
1147
 
1064
1148
  def _viz_tsne_repr(self, **kwargs) -> None:
1149
+ """
1150
+ Visualize the t-SNE representation of the latent space.
1151
+
1152
+ Parameters
1153
+ ----------
1154
+ **kwargs : `dict`
1155
+ Additional keyword arguments for customization.
1156
+
1157
+ Returns
1158
+ -------
1159
+ `None`
1160
+ """
1065
1161
  c = kwargs.get("c", None)
1066
1162
  self.colors = (
1067
1163
  kwargs.get("colors", self.sorted_names[: len(np.unique(c))]) if c is not None else None
1068
1164
  )
1165
+
1069
1166
  plt.scatter(
1070
1167
  self.reduced_data_tsne[:, 0],
1071
1168
  self.reduced_data_tsne[:, 1],
1072
1169
  cmap=matplotlib.colors.ListedColormap(self.colors) if c is not None else None,
1073
1170
  c=c,
1074
1171
  )
1172
+
1075
1173
  if c is not None:
1076
1174
  cb = plt.colorbar()
1077
1175
  loc = np.arange(0, max(c), max(c) / float(len(self.colors)))
1078
1176
  cb.set_ticks(loc)
1079
1177
  cb.set_ticklabels(np.unique(c))
1178
+
1080
1179
  plt.title("t-SNE Visualization of Latent Space")
1081
1180
  plt.xlabel("t-SNE 1")
1082
1181
  plt.ylabel("t-SNE 2")
1083
1182
  plt.show()
1084
1183
 
1184
+ def _viz_radviz(self, data: pd.DataFrame, color_column: str, title: str) -> None:
1185
+ """
1186
+ Visualize the data using RadViz.
1187
+
1188
+ Parameters
1189
+ ----------
1190
+ data : `pd.DataFrame`
1191
+ The data to visualize.
1192
+ color_column : `str`
1193
+ The column to use for coloring.
1194
+ title : `str`
1195
+ The title of the plot.
1196
+
1197
+ Returns
1198
+ -------
1199
+ `None`
1200
+ """
1201
+ data_normalized = data.copy(deep=True)
1202
+ data_normalized.iloc[:, :-1] = (
1203
+ 2.0
1204
+ * (data_normalized.iloc[:, :-1] - data_normalized.iloc[:, :-1].min())
1205
+ / (data_normalized.iloc[:, :-1].max() - data_normalized.iloc[:, :-1].min())
1206
+ - 1
1207
+ )
1208
+ radviz(data_normalized, color_column, color=self.colors)
1209
+ plt.title(title)
1210
+ plt.show()
1211
+
1212
+ def _viz_weights(
1213
+ self, cmap: str = "viridis", aspect: str = "auto", highlight: bool = True, **kwargs
1214
+ ) -> None:
1215
+ """
1216
+ Visualize the encoder layer weights of the model.
1217
+
1218
+ Parameters
1219
+ ----------
1220
+ cmap : `str`, optional
1221
+ The colormap for visualization (default is `"viridis"`).
1222
+ aspect : `str`, optional
1223
+ Aspect ratio for the visualization (default is `"auto"`).
1224
+ highlight : `bool`, optional
1225
+ Whether to highlight the maximum weights (default is `True`).
1226
+ **kwargs : `dict`, optional
1227
+ Additional keyword arguments for customization.
1228
+
1229
+ Returns
1230
+ -------
1231
+ `None`
1232
+ """
1233
+ title = kwargs.get("title", "Encoder Layer Weights (Dense Layer)")
1234
+ y_labels = kwargs.get("y_labels", None)
1235
+ cmap_highlight = kwargs.get("cmap_highlight", "Pastel1")
1236
+ highlight_mask = np.zeros_like(self.encoder_weights, dtype=bool)
1237
+
1238
+ plt.imshow(self.encoder_weights, cmap=cmap, aspect=aspect)
1239
+ plt.colorbar()
1240
+ plt.title(title)
1241
+ if y_labels is not None:
1242
+ plt.yticks(ticks=np.arange(self.encoder_weights.shape[0]), labels=y_labels)
1243
+ if highlight:
1244
+ for i, j in enumerate(self.encoder_weights.argmax(axis=1)):
1245
+ highlight_mask[i, j] = True
1246
+ plt.imshow(
1247
+ np.ma.masked_where(~highlight_mask, self.encoder_weights),
1248
+ cmap=cmap_highlight,
1249
+ alpha=0.5,
1250
+ aspect=aspect,
1251
+ )
1252
+ plt.show()
1253
+
1254
+ def _statistics(self, data_input: DataFrame) -> DataFrame:
1255
+ """
1256
+ Compute statistical summaries of the input data.
1257
+
1258
+ Parameters
1259
+ ----------
1260
+ data_input : `DataFrame`
1261
+ The data to compute statistics for.
1262
+
1263
+ Returns
1264
+ -------
1265
+ `DataFrame` : The statistical summary of the input data.
1266
+ """
1267
+ data = data_input.copy(deep=True)
1268
+
1269
+ if not pd.api.types.is_string_dtype(data["class"]):
1270
+ data["class"] = data["class"].astype(str)
1271
+
1272
+ data.ffill(inplace=True)
1273
+ grouped_data = data.groupby("class")
1274
+
1275
+ numerical_stats = grouped_data.agg(["mean", "min", "max", "std", "median"])
1276
+ numerical_stats.columns = ["_".join(col).strip() for col in numerical_stats.columns.values]
1277
+
1278
+ def get_mode(x):
1279
+ mode_series = x.mode()
1280
+ return mode_series.iloc[0] if not mode_series.empty else None
1281
+
1282
+ mode_stats = grouped_data.apply(get_mode, include_groups=False)
1283
+ mode_stats.columns = [f"{col}_mode" for col in mode_stats.columns]
1284
+ combined_stats = pd.concat([numerical_stats, mode_stats], axis=1)
1285
+
1286
+ return combined_stats.T
1287
+
1085
1288
 
1086
1289
  ########################################################################################
1087
1290