validmind 2.0.1__py3-none-any.whl → 2.0.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (65)
  1. validmind/__init__.py +4 -1
  2. validmind/__version__.py +1 -1
  3. validmind/ai.py +197 -0
  4. validmind/api_client.py +16 -4
  5. validmind/client.py +23 -3
  6. validmind/datasets/classification/customer_churn.py +2 -2
  7. validmind/datasets/nlp/__init__.py +5 -0
  8. validmind/datasets/nlp/cnn_dailymail.py +98 -0
  9. validmind/datasets/nlp/datasets/cnn_dailymail_100_with_predictions.csv +255 -0
  10. validmind/datasets/nlp/datasets/cnn_dailymail_500_with_predictions.csv +1277 -0
  11. validmind/datasets/nlp/datasets/sentiments_with_predictions.csv +4847 -0
  12. validmind/errors.py +11 -1
  13. validmind/models/huggingface.py +2 -2
  14. validmind/models/pytorch.py +3 -3
  15. validmind/models/sklearn.py +4 -4
  16. validmind/tests/__init__.py +47 -9
  17. validmind/tests/data_validation/DatasetDescription.py +0 -1
  18. validmind/tests/data_validation/nlp/StopWords.py +1 -6
  19. validmind/tests/data_validation/nlp/TextDescription.py +20 -9
  20. validmind/tests/decorator.py +189 -0
  21. validmind/tests/model_validation/MeteorScore.py +92 -0
  22. validmind/tests/model_validation/RegardHistogram.py +5 -6
  23. validmind/tests/model_validation/RegardScore.py +3 -5
  24. validmind/tests/model_validation/RougeMetrics.py +6 -4
  25. validmind/tests/model_validation/SelfCheckNLIScore.py +112 -0
  26. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +17 -22
  27. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +3 -1
  28. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +30 -4
  29. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +9 -3
  30. validmind/tests/model_validation/statsmodels/RegressionModelsPerformance.py +1 -1
  31. validmind/tests/prompt_validation/ai_powered_test.py +2 -0
  32. validmind/unit_metrics/__init__.py +0 -2
  33. validmind/unit_metrics/composite.py +275 -0
  34. validmind/unit_metrics/regression/GiniCoefficient.py +39 -0
  35. validmind/unit_metrics/regression/HuberLoss.py +27 -0
  36. validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +36 -0
  37. validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +22 -0
  38. validmind/unit_metrics/regression/MeanBiasDeviation.py +22 -0
  39. validmind/unit_metrics/regression/QuantileLoss.py +25 -0
  40. validmind/unit_metrics/regression/sklearn/AdjustedRSquaredScore.py +27 -0
  41. validmind/unit_metrics/regression/sklearn/MeanAbsoluteError.py +22 -0
  42. validmind/unit_metrics/regression/sklearn/MeanSquaredError.py +22 -0
  43. validmind/unit_metrics/regression/sklearn/RSquaredScore.py +22 -0
  44. validmind/unit_metrics/regression/sklearn/RootMeanSquaredError.py +23 -0
  45. validmind/unit_metrics/sklearn/classification/Accuracy.py +2 -0
  46. validmind/unit_metrics/sklearn/classification/F1.py +2 -0
  47. validmind/unit_metrics/sklearn/classification/Precision.py +2 -0
  48. validmind/unit_metrics/sklearn/classification/ROC_AUC.py +2 -0
  49. validmind/unit_metrics/sklearn/classification/Recall.py +2 -0
  50. validmind/utils.py +17 -1
  51. validmind/vm_models/dataset.py +376 -21
  52. validmind/vm_models/figure.py +52 -17
  53. validmind/vm_models/test/metric.py +33 -30
  54. validmind/vm_models/test/output_template.py +0 -27
  55. validmind/vm_models/test/result_wrapper.py +57 -24
  56. validmind/vm_models/test/test.py +2 -1
  57. validmind/vm_models/test/threshold_test.py +24 -13
  58. validmind/vm_models/test_context.py +7 -0
  59. validmind/vm_models/test_suite/runner.py +1 -1
  60. validmind/vm_models/test_suite/test.py +1 -1
  61. {validmind-2.0.1.dist-info → validmind-2.0.7.dist-info}/METADATA +9 -13
  62. {validmind-2.0.1.dist-info → validmind-2.0.7.dist-info}/RECORD +65 -44
  63. validmind-2.0.7.dist-info/entry_points.txt +3 -0
  64. {validmind-2.0.1.dist-info → validmind-2.0.7.dist-info}/LICENSE +0 -0
  65. {validmind-2.0.1.dist-info → validmind-2.0.7.dist-info}/WHEEL +0 -0
validmind/vm_models/dataset.py

@@ -11,6 +11,7 @@ from dataclasses import dataclass, field
 
 import numpy as np
 import pandas as pd
+import polars as pl
 
 from validmind.logging import get_logger
 from validmind.vm_models.model import VMModel
@@ -35,9 +36,216 @@ class VMDataset(ABC):
         pass
 
     @abstractmethod
-    def serialize(self):
+    def assign_predictions(
+        self,
+        model,
+        prediction_values: list = None,
+        prediction_column=None,
+    ):
         """
-        Serializes the dataset to a dictionary.
+        Assigns predictions to the dataset for a given model or prediction values.
+        The dataset is updated with a new column containing the predictions.
+        """
+        pass
+
+    @abstractmethod
+    def get_extra_column(self, column_name):
+        """
+        Returns the values of the specified extra column.
+
+        Args:
+            column_name (str): The name of the extra column.
+
+        Returns:
+            np.ndarray: The values of the extra column.
+        """
+        pass
+
+    @abstractmethod
+    def add_extra_column(self, column_name, column_values=None):
+        """
+        Adds an extra column to the dataset without modifying the dataset `features` and `target` columns.
+
+        Args:
+            column_name (str): The name of the extra column.
+            column_values (np.ndarray, optional): The values of the extra column.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def input_id(self) -> str:
+        """
+        Returns the input id of the dataset.
+
+        Returns:
+            str: input_id.
+        """
+        return self.input_id
+
+    @property
+    @abstractmethod
+    def columns(self) -> list:
+        """
+        Returns the list of columns in the dataset.
+
+        Returns:
+            List[str]: The columns list.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def target_column(self) -> str:
+        """
+        Returns the target column name of the dataset.
+
+        Returns:
+            str: The target column name.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def feature_columns(self) -> list:
+        """
+        Returns the feature columns of the dataset. If _feature_columns is None,
+        it returns all columns except the target column.
+
+        Returns:
+            list: The list of feature column names.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def text_column(self) -> str:
+        """
+        Returns the text column of the dataset.
+
+        Returns:
+            str: The text column name.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def x(self) -> np.ndarray:
+        """
+        Returns the input features (X) of the dataset.
+
+        Returns:
+            np.ndarray: The input features.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def y(self) -> np.ndarray:
+        """
+        Returns the target variables (y) of the dataset.
+
+        Returns:
+            np.ndarray: The target variables.
+        """
+        pass
+
+    @abstractmethod
+    def y_pred(self, model_id) -> np.ndarray:
+        """
+        Returns the prediction values (y_pred) of the dataset for a given model_id.
+
+        Returns:
+            np.ndarray: The prediction values.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def df(self):
+        """
+        Returns the dataset as a pandas DataFrame.
+
+        Returns:
+            pd.DataFrame: The dataset as a DataFrame.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def copy(self):
+        """
+        Returns a copy of the raw_dataset dataframe.
+        """
+        pass
+
+    @abstractmethod
+    def x_df(self):
+        """
+        Returns the non-target, non-prediction columns.
+
+        Returns:
+            pd.DataFrame: The non-target, non-prediction columns.
+        """
+        pass
+
+    @abstractmethod
+    def y_df(self):
+        """
+        Returns the target columns (y) of the dataset.
+
+        Returns:
+            pd.DataFrame: The target columns.
+        """
+        pass
+
+    @abstractmethod
+    def y_pred_df(self, model_id):
+        """
+        Returns the prediction columns (y_pred) of the dataset for a given model_id.
+
+        Returns:
+            pd.DataFrame: The prediction columns.
+        """
+        pass
+
+    @abstractmethod
+    def prediction_column(self, model_id) -> str:
+        """
+        Returns the prediction column name of the dataset.
+
+        Returns:
+            str: The prediction column name.
+        """
+        pass
+
+    @abstractmethod
+    def get_features_columns(self):
+        """
+        Returns the column names of the feature variables.
+
+        Returns:
+            List[str]: The column names of the feature variables.
+        """
+        pass
+
+    @abstractmethod
+    def get_numeric_features_columns(self):
+        """
+        Returns the column names of the numeric feature variables.
+
+        Returns:
+            List[str]: The column names of the numeric feature variables.
+        """
+        pass
+
+    @abstractmethod
+    def get_categorical_features_columns(self):
+        """
+        Returns the column names of the categorical feature variables.
+
+        Returns:
+            List[str]: The column names of the categorical feature variables.
         """
         pass
 
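
A minimal usage sketch of the expanded interface (hedged: it assumes the public `vm.init_dataset`/`vm.init_model` entry points, and `df`, `model`, and `precomputed_preds` are placeholders):

    import validmind as vm

    vm_model = vm.init_model(model, input_id="my_model")
    vm_ds = vm.init_dataset(dataset=df, input_id="test_ds", target_column="target")

    # Link precomputed predictions to the model; omit prediction_values to
    # have the model compute them instead.
    vm_ds.assign_predictions(model=vm_model, prediction_values=precomputed_preds)

    preds = vm_ds.y_pred("my_model")  # look up predictions by model input_id
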
@@ -134,7 +342,7 @@ class NumpyDataset(VMDataset):
         # initialize target column
         self._target_column = target_column
         # initialize extra columns
-        self.__set_extra_columns(extra_columns, model)
+        self.__set_extra_columns(extra_columns)
         # initialize feature columns
         self.__set_feature_columns(feature_columns)
         # initialize text column, target class labels and options
@@ -144,7 +352,7 @@ class NumpyDataset(VMDataset):
         if model:
             self.assign_predictions(model)
 
-    def __set_extra_columns(self, extra_columns, model):
+    def __set_extra_columns(self, extra_columns):
         if extra_columns is None:
             extra_columns = {
                 "prediction_columns": {},
@@ -191,14 +399,29 @@ class NumpyDataset(VMDataset):
         return model.input_id in self._extra_columns.get("prediction_columns", {})
 
     def __assign_prediction_values(self, model, pred_column, prediction_values):
+        # Link the prediction column with the model
         self._extra_columns.setdefault("prediction_columns", {})[
             model.input_id
         ] = pred_column
-        self._raw_dataset = np.hstack(
-            (self._raw_dataset, np.array(prediction_values).reshape(-1, 1))
+
+        # Check if the predictions are multi-dimensional (e.g., embeddings)
+        is_multi_dimensional = (
+            isinstance(prediction_values, np.ndarray) and prediction_values.ndim > 1
         )
-        self._columns.append(pred_column)
-        self._df[pred_column] = prediction_values
+
+        if is_multi_dimensional:
+            # For multi-dimensional outputs, convert to a list of lists to store in DataFrame
+            self._df[pred_column] = list(map(list, prediction_values))
+        else:
+            # If not multi-dimensional or a standard numpy array, reshape for compatibility
+            self._raw_dataset = np.hstack(
+                (self._raw_dataset, np.array(prediction_values).reshape(-1, 1))
+            )
+            self._df[pred_column] = prediction_values
+
+        # Update the dataset columns list
+        if pred_column not in self._columns:
+            self._columns.append(pred_column)
 
     def assign_predictions(  # noqa: C901 - we need to simplify this method
         self,
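
The branching above decides how predictions are stored: 2-D arrays such as embeddings are kept as per-row lists in the backing DataFrame (and skip the numpy raw dataset), while 1-D predictions are appended to both. A standalone illustration of the storage rule (not validmind code):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"x": [1, 2, 3]})
    embeddings = np.random.rand(3, 4)  # one 4-dim embedding per row

    if isinstance(embeddings, np.ndarray) and embeddings.ndim > 1:
        # stored as a list of lists, producing an object-dtype column
        df["model_prediction"] = list(map(list, embeddings))

    print(df["model_prediction"].dtype)  # object
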
@@ -262,6 +485,59 @@ class NumpyDataset(VMDataset):
                 f"Prediction column {self._extra_columns['prediction_columns'][model.input_id]} already linked to the {model.input_id}"
             )
 
+    def get_extra_column(self, column_name):
+        """
+        Returns the values of the specified extra column.
+
+        Args:
+            column_name (str): The name of the extra column.
+
+        Returns:
+            np.ndarray: The values of the extra column.
+        """
+        if column_name not in self.extra_columns:
+            raise ValueError(f"Column {column_name} is not an extra column")
+
+        return self._df[column_name]
+
+    def add_extra_column(self, column_name, column_values=None):
+        """
+        Adds an extra column to the dataset without modifying the dataset `features` and `target` columns.
+
+        Args:
+            column_name (str): The name of the extra column.
+            column_values (np.ndarray, optional): The values of the extra column.
+        """
+        if column_name in self.extra_columns:
+            logger.info(f"Column {column_name} already registered as an extra column")
+            return
+
+        # The column name already exists in the dataset so we just assign the extra column
+        if column_name in self.columns:
+            self._extra_columns[column_name] = column_name
+            logger.info(
+                f"Column {column_name} exists in the dataset, registering as an extra column"
+            )
+            return
+
+        if column_values is None:
+            raise ValueError(
+                "Column values must be provided when the column doesn't exist in the dataset"
+            )
+
+        if len(column_values) != self.df.shape[0]:
+            raise ValueError(
+                "Length of column values doesn't match number of rows of the dataset"
+            )
+
+        self._raw_dataset = np.hstack(
+            (self._raw_dataset, np.array(column_values).reshape(-1, 1))
+        )
+        self._columns.append(column_name)
+        self._df[column_name] = column_values
+        self._extra_columns[column_name] = column_name
+        logger.info(f"Column {column_name} added as an extra column")
+
     @property
     def raw_dataset(self) -> np.ndarray:
         """
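
A short, hedged sketch of the concrete extra-column API (`vm_ds` is a NumpyDataset-backed dataset as above; column names are placeholders):

    import numpy as np

    # register values that are neither features nor targets, e.g. sample weights
    vm_ds.add_extra_column("sample_weight", column_values=np.ones(len(vm_ds.df)))
    weights = vm_ds.get_extra_column("sample_weight")

    # if the column already exists in the underlying data, values are optional:
    # the existing column is simply registered as an extra column
    vm_ds.add_extra_column("existing_column")
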
@@ -399,19 +675,42 @@ class NumpyDataset(VMDataset):
 
     def y_pred(self, model_id) -> np.ndarray:
         """
-        Returns the prediction variable (y_pred) of the dataset.
+        Returns the prediction variables for a given model_id, accommodating
+        both scalar predictions and multi-dimensional outputs such as embeddings.
+
+        Args:
+            model_id (str): The ID of the model whose predictions are sought.
 
         Returns:
-            np.ndarray: The prediction variables.
-        """
-        return self.raw_dataset[
-            :,
-            [
-                self.columns.index(name)
-                for name in self.columns
-                if name == self.prediction_column(model_id=model_id)
-            ],
-        ].flatten()
+            np.ndarray: The prediction variables, either as a flattened array for
+            scalar predictions or as an array of arrays for multi-dimensional outputs.
+        """
+        pred_column = self.prediction_column(model_id)
+
+        # First, attempt to retrieve the prediction data from the DataFrame
+        if hasattr(self, "_df") and pred_column in self._df.columns:
+            predictions = self._df[pred_column].to_numpy()
+
+            # Check if the predictions are stored as objects (e.g., lists for embeddings)
+            if self._df[pred_column].dtype == object:
+                # Attempt to convert lists to a numpy array
+                try:
+                    predictions = np.stack(predictions)
+                except ValueError as e:
+                    # Handle cases where predictions cannot be directly stacked
+                    raise ValueError(f"Error stacking prediction arrays: {e}")
+        else:
+            # Fall back to the raw numpy dataset if the DataFrame is not available or suitable
+            try:
+                predictions = self.raw_dataset[
+                    :, self.columns.index(pred_column)
+                ].flatten()
+            except IndexError as e:
+                raise ValueError(
+                    f"Prediction column '{pred_column}' not found in raw dataset: {e}"
+                )
+
+        return predictions
 
     @property
     def type(self) -> str:
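
The `np.stack` call is what turns an object-dtype column of per-row lists back into a proper 2-D array; rows of unequal length cannot be stacked and surface as the re-raised ValueError. A standalone illustration (not validmind code):

    import numpy as np
    import pandas as pd

    stored = pd.Series([[0.1, 0.2], [0.3, 0.4]])
    matrix = np.stack(stored.to_numpy())  # shape (2, 2): recovered embedding matrix

    ragged = pd.Series([[0.1, 0.2], [0.3]])
    try:
        np.stack(ragged.to_numpy())
    except ValueError as e:
        print(f"Error stacking prediction arrays: {e}")
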
@@ -608,10 +907,15 @@ class DataFrameDataset(NumpyDataset):
 
         Args:
             raw_dataset (pd.DataFrame): The raw dataset as a pandas DataFrame.
+            input_id (str, optional): Identifier for the dataset. Defaults to None.
+            model (VMModel, optional): Model associated with the dataset. Defaults to None.
             target_column (str, optional): The target column of the dataset. Defaults to None.
+            extra_columns (dict, optional): Extra columns to include in the dataset. Defaults to None.
             feature_columns (list, optional): The feature columns of the dataset. Defaults to None.
-            text_column (str, optional): The text column name of the dataset for nlp tasks. Defaults to None.
-            target_class_labels (Dict, optional): The class labels for the target columns. Defaults to None.
+            text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
+            target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
+            options (dict, optional): Additional options for the dataset. Defaults to None.
+            date_time_index (bool, optional): Whether to use date-time index. Defaults to False.
         """
         index = None
         if isinstance(raw_dataset.index, pd.Index):
@@ -634,6 +938,57 @@ class DataFrameDataset(NumpyDataset):
         )
 
 
+@dataclass
+class PolarsDataset(NumpyDataset):
+    """
+    VM dataset implementation for Polars DataFrame.
+    """
+
+    def __init__(
+        self,
+        raw_dataset: pl.DataFrame,
+        input_id: str = None,
+        model: VMModel = None,
+        target_column: str = None,
+        extra_columns: dict = None,
+        feature_columns: list = None,
+        text_column: str = None,
+        target_class_labels: dict = None,
+        options: dict = None,
+        date_time_index: bool = False,
+    ):
+        """
+        Initializes a PolarsDataset instance.
+
+        Args:
+            raw_dataset (pl.DataFrame): The raw dataset as a Polars DataFrame.
+            input_id (str, optional): Identifier for the dataset. Defaults to None.
+            model (VMModel, optional): Model associated with the dataset. Defaults to None.
+            target_column (str, optional): The target column of the dataset. Defaults to None.
+            extra_columns (dict, optional): Extra columns to include in the dataset. Defaults to None.
+            feature_columns (list, optional): The feature columns of the dataset. Defaults to None.
+            text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
+            target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
+            options (dict, optional): Additional options for the dataset. Defaults to None.
+            date_time_index (bool, optional): Whether to use date-time index. Defaults to False.
+        """
+        super().__init__(
+            raw_dataset=raw_dataset.to_numpy(),
+            input_id=input_id,
+            model=model,
+            index_name=None,
+            index=None,
+            columns=raw_dataset.columns,
+            target_column=target_column,
+            extra_columns=extra_columns,
+            feature_columns=feature_columns,
+            text_column=text_column,
+            target_class_labels=target_class_labels,
+            options=options,
+            date_time_index=date_time_index,
+        )
+
+
 @dataclass
 class TorchDataset(NumpyDataset):
     """
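
The new PolarsDataset converts the Polars frame to numpy up front, so downstream behavior matches DataFrameDataset. A hedged construction sketch (direct instantiation shown for illustration; column names are placeholders):

    import polars as pl
    from validmind.vm_models.dataset import PolarsDataset

    pl_df = pl.DataFrame({"feat_a": [1, 2], "feat_b": [3.0, 4.0], "target": [0, 1]})

    vm_ds = PolarsDataset(
        raw_dataset=pl_df,
        input_id="polars_ds",
        target_column="target",
    )
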
validmind/vm_models/figure.py

@@ -21,6 +21,18 @@ from ..errors import InvalidFigureForObjectError, UnsupportedFigureError
 from ..utils import get_full_typename
 
 
+def is_matplotlib_figure(figure) -> bool:
+    return isinstance(figure, matplotlib.figure.Figure)
+
+
+def is_plotly_figure(figure) -> bool:
+    return isinstance(figure, (go.Figure, go.FigureWidget))
+
+
+def is_png_image(figure) -> bool:
+    return isinstance(figure, bytes)
+
+
 @dataclass
 class Figure:
     """
26
38
  """
@@ -52,22 +64,10 @@ class Figure:
52
64
  if (
53
65
  not client_config.running_on_colab
54
66
  and self.figure
55
- and self.is_plotly_figure()
67
+ and is_plotly_figure(self.figure)
56
68
  ):
57
69
  self.figure = go.FigureWidget(self.figure)
58
70
 
59
- def is_matplotlib_figure(self) -> bool:
60
- """
61
- Returns True if the figure is a matplotlib figure
62
- """
63
- return isinstance(self.figure, matplotlib.figure.Figure)
64
-
65
- def is_plotly_figure(self) -> bool:
66
- """
67
- Returns True if the figure is a plotly figure
68
- """
69
- return isinstance(self.figure, (go.Figure, go.FigureWidget))
70
-
71
71
  def _get_for_object_type(self):
72
72
  """
73
73
  Returns the type of the object this figure is for
@@ -91,7 +91,7 @@ class Figure:
         we would render images as-is, but Plotly FigureWidgets don't work well
         on Google Colab when they are combined with ipywidgets.
         """
-        if self.is_matplotlib_figure():
+        if is_matplotlib_figure(self.figure):
             tmpfile = BytesIO()
             self.figure.savefig(tmpfile, format="png")
             encoded = base64.b64encode(tmpfile.getvalue()).decode("utf-8")
@@ -101,7 +101,7 @@ class Figure:
                 """
             )
 
-        elif self.is_plotly_figure():
+        elif is_plotly_figure(self.figure):
             # FigureWidget can be displayed as-is but not on Google Colab. In this case
             # we just return the image representation of the figure.
             if client_config.running_on_colab:
@@ -114,6 +114,15 @@ class Figure:
                 )
             else:
                 return self.figure
+
+        elif is_png_image(self.figure):
+            encoded = base64.b64encode(self.figure).decode("utf-8")
+            return widgets.HTML(
+                value=f"""
+                <img style="width:100%; height: auto;" src="data:image/png;base64,{encoded}"/>
+                """
+            )
+
         else:
             raise UnsupportedFigureError(
                 f"Figure type {type(self.figure)} not supported for plotting"
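
The new branch renders raw PNG bytes the same way matplotlib figures are rendered: base64-encoded into an inline <img> tag. A standalone illustration of that encoding step (placeholder file path):

    import base64

    with open("figure.png", "rb") as f:
        png_bytes = f.read()

    encoded = base64.b64encode(png_bytes).decode("utf-8")
    html = f'<img style="width:100%; height: auto;" src="data:image/png;base64,{encoded}"/>'
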
@@ -129,15 +138,38 @@ class Figure:
             "metadata": json.dumps(self.metadata, allow_nan=False),
         }
 
+    def _get_b64_url(self):
+        """
+        Returns a base64 encoded URL for the figure
+        """
+        if is_matplotlib_figure(self.figure):
+            buffer = BytesIO()
+            self.figure.savefig(buffer, format="png")
+            buffer.seek(0)
+
+            b64_data = base64.b64encode(buffer.read()).decode("utf-8")
+
+            return f"data:image/png;base64,{b64_data}"
+
+        elif is_plotly_figure(self.figure):
+            bytes = self.figure.to_image(format="png")
+            b64_data = base64.b64encode(bytes).decode("utf-8")
+
+            return f"data:image/png;base64,{b64_data}"
+
+        raise UnsupportedFigureError(
+            f"Unrecognized figure type: {get_full_typename(self.figure)}"
+        )
+
     def serialize_files(self):
         """Creates a `requests`-compatible files object to be sent to the API"""
-        if self.is_matplotlib_figure():
+        if is_matplotlib_figure(self.figure):
             buffer = BytesIO()
             self.figure.savefig(buffer, bbox_inches="tight")
             buffer.seek(0)
             return {"image": (f"{self.key}.png", buffer, "image/png")}
 
-        elif self.is_plotly_figure():
+        elif is_plotly_figure(self.figure):
             # When using plotly, we will produce two files:
             # - a JSON file that will be used to display the figure in the UI
             # - a PNG file that will be used to display the figure in documents
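
For Plotly figures, `_get_b64_url` relies on `figure.to_image(format="png")`, which requires a static-image engine (typically the kaleido package) to be installed. A hedged sketch of the resulting data URL:

    import base64
    import plotly.graph_objects as go

    fig = go.Figure(data=go.Scatter(y=[1, 3, 2]))
    png = fig.to_image(format="png")  # needs kaleido (or another engine)
    data_url = f"data:image/png;base64,{base64.b64encode(png).decode('utf-8')}"
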
@@ -154,6 +186,9 @@ class Figure:
             ),
         }
 
+        elif is_png_image(self.figure):
+            return {"image": (f"{self.key}.png", self.figure, "image/png")}
+
         raise UnsupportedFigureError(
             f"Unrecognized figure type: {get_full_typename(self.figure)}"
         )
validmind/vm_models/test/metric.py

@@ -6,12 +6,14 @@
 Class for storing ValidMind metric objects and associated
 data for display and reporting purposes
 """
+import os
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import ClassVar, List, Optional, Union
 
 import pandas as pd
 
+from ...ai import generate_description
 from ...errors import MissingCacheResultsArgumentsError
 from ...utils import clean_docstring
 from ..figure import Figure
@@ -74,41 +76,42 @@ class Metric(Test):
                 "Metric must provide a metric value or figures to cache_results"
             )
 
-        # At a minimum, send the metric description
-        result_metadata = [
-            {
-                "content_id": f"metric_description:{self.test_id}",
-                "text": clean_docstring(self.description()),
-            }
-        ]
-
-        result_summary = self.summary(metric_value)
-
-        result_wrapper = MetricResultWrapper(
-            result_id=self.test_id,
-            result_metadata=result_metadata,
-            inputs=self.get_accessed_inputs(),
-            output_template=self.output_template,
-        )
-
-        # We can send an empty result to push an empty metric with a summary and plots
-        metric_result_value = metric_value if metric_value is not None else {}
-
-        result_wrapper.metric = MetricResult(
-            # key=self.key,
-            # Now using the fully qualified test ID as `key`.
-            # Ideally the backend is updated to use `test_id` instead of `key`.
+        metric = MetricResult(
             key=self.test_id,
             ref_id=self._ref_id,
-            value=metric_result_value,
+            value=metric_value if metric_value is not None else {},
             value_formatter=self.value_formatter,
-            summary=result_summary,
+            summary=self.summary(metric_value),
         )
 
-        # Allow metrics to attach figures to the test suite result
-        if figures:
-            result_wrapper.figures = figures
+        if (
+            os.environ.get("VALIDMIND_LLM_DESCRIPTIONS_ENABLED", "false").lower()
+            == "true"
+        ):
+            revision_name = "Generated by ValidMind AI"
+            description = generate_description(
+                test_name=self.test_id,
+                test_description=self.description().splitlines()[0],
+                test_results=metric.serialize()["value"],
+                test_summary=metric.serialize()["summary"],
+                figures=figures,
+            )
+        else:
+            revision_name = "Default Description"
+            description = clean_docstring(self.description())
+
+        description_metadata = {
+            "content_id": f"metric_description:{self.test_id}::{revision_name}",
+            "text": description,
+        }
 
-        self.result = result_wrapper
+        self.result = MetricResultWrapper(
+            result_id=self.test_id,
+            result_metadata=[description_metadata],
+            metric=metric,
+            figures=figures,
+            inputs=self.get_accessed_inputs(),
+            output_template=self.output_template,
+        )
 
         return self.result
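
The rewritten cache_results makes AI-generated descriptions strictly opt-in: they come from the new validmind/ai.py module only when the environment variable is set, and the revision name is embedded in the metadata content_id. A minimal opt-in sketch:

    import os

    # must be set before tests run; any value other than "true" keeps the
    # default docstring-based description
    os.environ["VALIDMIND_LLM_DESCRIPTIONS_ENABLED"] = "true"

    # the resulting content_id takes the form
    # "metric_description:<test_id>::Generated by ValidMind AI"
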