validmind 2.1.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. validmind/__version__.py +1 -1
  2. validmind/ai.py +3 -3
  3. validmind/api_client.py +2 -3
  4. validmind/client.py +68 -25
  5. validmind/datasets/llm/rag/__init__.py +11 -0
  6. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_1.csv +30 -0
  7. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_2.csv +30 -0
  8. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_3.csv +53 -0
  9. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_4.csv +53 -0
  10. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_5.csv +53 -0
  11. validmind/datasets/llm/rag/rfp.py +41 -0
  12. validmind/html_templates/__init__.py +0 -0
  13. validmind/html_templates/content_blocks.py +89 -14
  14. validmind/models/__init__.py +7 -4
  15. validmind/models/foundation.py +8 -34
  16. validmind/models/function.py +51 -0
  17. validmind/models/huggingface.py +16 -46
  18. validmind/models/metadata.py +42 -0
  19. validmind/models/pipeline.py +66 -0
  20. validmind/models/pytorch.py +8 -42
  21. validmind/models/r_model.py +33 -82
  22. validmind/models/sklearn.py +39 -38
  23. validmind/template.py +8 -26
  24. validmind/tests/__init__.py +43 -20
  25. validmind/tests/data_validation/ANOVAOneWayTable.py +1 -1
  26. validmind/tests/data_validation/ChiSquaredFeaturesTable.py +1 -1
  27. validmind/tests/data_validation/DescriptiveStatistics.py +2 -4
  28. validmind/tests/data_validation/Duplicates.py +1 -1
  29. validmind/tests/data_validation/IsolationForestOutliers.py +2 -2
  30. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +1 -1
  31. validmind/tests/data_validation/TargetRateBarPlots.py +1 -1
  32. validmind/tests/data_validation/nlp/LanguageDetection.py +59 -0
  33. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +48 -0
  34. validmind/tests/data_validation/nlp/Punctuations.py +11 -12
  35. validmind/tests/data_validation/nlp/Sentiment.py +57 -0
  36. validmind/tests/data_validation/nlp/Toxicity.py +45 -0
  37. validmind/tests/decorator.py +2 -2
  38. validmind/tests/model_validation/BertScore.py +100 -98
  39. validmind/tests/model_validation/BleuScore.py +93 -64
  40. validmind/tests/model_validation/ContextualRecall.py +74 -91
  41. validmind/tests/model_validation/MeteorScore.py +86 -74
  42. validmind/tests/model_validation/RegardScore.py +103 -121
  43. validmind/tests/model_validation/RougeScore.py +118 -0
  44. validmind/tests/model_validation/TokenDisparity.py +84 -121
  45. validmind/tests/model_validation/ToxicityScore.py +109 -123
  46. validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +96 -0
  47. validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +71 -0
  48. validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +92 -0
  49. validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +69 -0
  50. validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +78 -0
  51. validmind/tests/model_validation/embeddings/StabilityAnalysis.py +35 -23
  52. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +3 -0
  53. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +7 -1
  54. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +3 -0
  55. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +3 -0
  56. validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +99 -0
  57. validmind/tests/model_validation/ragas/AnswerCorrectness.py +131 -0
  58. validmind/tests/model_validation/ragas/AnswerRelevance.py +134 -0
  59. validmind/tests/model_validation/ragas/AnswerSimilarity.py +119 -0
  60. validmind/tests/model_validation/ragas/AspectCritique.py +167 -0
  61. validmind/tests/model_validation/ragas/ContextEntityRecall.py +133 -0
  62. validmind/tests/model_validation/ragas/ContextPrecision.py +123 -0
  63. validmind/tests/model_validation/ragas/ContextRecall.py +123 -0
  64. validmind/tests/model_validation/ragas/ContextRelevancy.py +114 -0
  65. validmind/tests/model_validation/ragas/Faithfulness.py +119 -0
  66. validmind/tests/model_validation/ragas/utils.py +66 -0
  67. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -7
  68. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +8 -9
  69. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +5 -10
  70. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +3 -2
  71. validmind/tests/model_validation/sklearn/ROCCurve.py +2 -1
  72. validmind/tests/model_validation/sklearn/RegressionR2Square.py +1 -1
  73. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +2 -3
  74. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +7 -11
  75. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +3 -4
  76. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +1 -1
  77. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +1 -1
  78. validmind/tests/model_validation/statsmodels/RegressionModelInsampleComparison.py +1 -1
  79. validmind/tests/model_validation/statsmodels/RegressionModelOutsampleComparison.py +1 -1
  80. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +1 -1
  81. validmind/tests/model_validation/statsmodels/RegressionModelsCoeffs.py +1 -1
  82. validmind/tests/model_validation/statsmodels/RegressionModelsPerformance.py +1 -1
  83. validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +5 -6
  84. validmind/unit_metrics/__init__.py +26 -49
  85. validmind/unit_metrics/composite.py +5 -1
  86. validmind/unit_metrics/regression/sklearn/AdjustedRSquaredScore.py +1 -1
  87. validmind/utils.py +56 -6
  88. validmind/vm_models/__init__.py +1 -1
  89. validmind/vm_models/dataset/__init__.py +7 -0
  90. validmind/vm_models/dataset/dataset.py +558 -0
  91. validmind/vm_models/dataset/utils.py +146 -0
  92. validmind/vm_models/model.py +97 -72
  93. validmind/vm_models/test/result_wrapper.py +61 -24
  94. validmind/vm_models/test_context.py +1 -1
  95. validmind/vm_models/test_suite/summary.py +3 -4
  96. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/METADATA +5 -3
  97. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/RECORD +100 -75
  98. validmind/models/catboost.py +0 -33
  99. validmind/models/statsmodels.py +0 -50
  100. validmind/models/xgboost.py +0 -30
  101. validmind/tests/model_validation/BertScoreAggregate.py +0 -90
  102. validmind/tests/model_validation/RegardHistogram.py +0 -148
  103. validmind/tests/model_validation/RougeMetrics.py +0 -147
  104. validmind/tests/model_validation/RougeMetricsAggregate.py +0 -133
  105. validmind/tests/model_validation/SelfCheckNLIScore.py +0 -112
  106. validmind/tests/model_validation/ToxicityHistogram.py +0 -136
  107. validmind/vm_models/dataset.py +0 -1303
  108. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/LICENSE +0 -0
  109. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/WHEEL +0 -0
  110. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py CHANGED
@@ -109,12 +109,11 @@ class RobustnessDiagnosis(ThresholdTest):
 
         features_list = self.params["features_columns"]
         if features_list is None:
-            features_list = self.inputs.datasets[0].get_numeric_features_columns()
+            features_list = self.inputs.datasets[0].feature_columns
 
         # Check if all elements from features_list are present in the numerical feature columns
         all_present = all(
-            elem in self.inputs.datasets[0].get_numeric_features_columns()
-            for elem in features_list
+            elem in self.inputs.datasets[0].feature_columns for elem in features_list
         )
         if not all_present:
             raise ValueError(
validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py CHANGED
@@ -11,6 +11,7 @@ import shap
 
 from validmind.errors import UnsupportedModelForSHAPError
 from validmind.logging import get_logger
+from validmind.models import CatBoostModel, SKlearnModel, StatsModelsModel
 from validmind.vm_models import Figure, Metric
 
 logger = get_logger(__name__)
@@ -131,20 +132,14 @@ class SHAPGlobalImportance(Metric):
         )
 
     def run(self):
-        model_library = self.inputs.model.model_library()
-        if model_library in [
-            "statsmodels",
-            "pytorch",
-            "catboost",
-            "transformers",
-            "FoundationModel",
-            "R",
-        ]:
-            logger.info(f"Skiping SHAP for {model_library} models")
+        if not isinstance(self.inputs.model, SKlearnModel) or isinstance(
+            self.inputs.model, (CatBoostModel, StatsModelsModel)
+        ):
+            logger.info(f"Skiping SHAP for {self.inputs.model.library} models")
             return
 
         trained_model = self.inputs.model.model
-        model_class = self.inputs.model.model_class()
+        model_class = self.inputs.model.class_
 
         # the shap library generates a bunch of annoying warnings that we don't care about
         warnings.filterwarnings("ignore", category=UserWarning)
@@ -176,6 +171,7 @@ class SHAPGlobalImportance(Metric):
                 ),
             )
         else:
+            model_class = "<ExternalModel>" if model_class is None else model_class
            raise UnsupportedModelForSHAPError(
                f"Model {model_class} not supported for SHAP importance."
            )
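The run() gate above now keys off isinstance checks against the wrapper classes imported from validmind.models rather than comparing model_library() strings. A minimal standalone sketch of the same check (the helper name supports_shap is illustrative, not part of the package):

from validmind.models import CatBoostModel, SKlearnModel, StatsModelsModel


def supports_shap(vm_model) -> bool:
    # SHAP runs only for scikit-learn style wrappers; the CatBoost and
    # statsmodels subclasses are still skipped, mirroring the condition above
    return isinstance(vm_model, SKlearnModel) and not isinstance(
        vm_model, (CatBoostModel, StatsModelsModel)
    )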
validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py CHANGED
@@ -113,7 +113,7 @@ class WeakspotsDiagnosis(ThresholdTest):
                 raise ValueError(f"Threshold for metric {metric} is missing")
 
         if self.params["features_columns"] is None:
-            features_list = self.inputs.datasets[0].get_features_columns()
+            features_list = self.inputs.datasets[0].feature_columns
         else:
             features_list = self.params["features_columns"]
 
@@ -124,8 +124,7 @@ class WeakspotsDiagnosis(ThresholdTest):
 
         # Check if all elements from features_list are present in the feature columns
         all_present = all(
-            elem in self.inputs.datasets[0].get_features_columns()
-            for elem in features_list
+            elem in self.inputs.datasets[0].feature_columns for elem in features_list
         )
         if not all_present:
             raise ValueError(
@@ -150,7 +149,7 @@ class WeakspotsDiagnosis(ThresholdTest):
         results_headers.extend(self.default_metrics.keys())
         for feature in features_list:
             bins = 10
-            if feature in self.inputs.datasets[0].get_categorical_features_columns():
+            if feature in self.inputs.datasets[0].feature_columns_categorical:
                 bins = len(train_df[feature].unique())
             train_df["bin"] = pd.cut(train_df[feature], bins=bins)
 
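These hunks are part of the wider switch from VMDataset getter methods to properties: get_features_columns() becomes feature_columns and get_categorical_features_columns() becomes feature_columns_categorical. A hedged sketch of the binning logic above rewritten against the new properties (the helper name is illustrative):

def bins_for_feature(dataset, feature: str) -> int:
    # categorical features get one bin per observed level, numeric features a
    # fixed 10, matching the WeakspotsDiagnosis logic shown above
    if feature in dataset.feature_columns_categorical:
        return len(dataset.df[feature].unique())
    return 10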
validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py CHANGED
@@ -89,7 +89,7 @@ class RegressionModelForecastPlot(Metric):
         figures = []
 
         for i, fitted_model in enumerate(model_list):
-            feature_columns = datasets[0].get_features_columns()
+            feature_columns = datasets[0].feature_columns
 
             train_ds = datasets[0]
             test_ds = datasets[1]
validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py CHANGED
@@ -98,7 +98,7 @@ class RegressionModelForecastPlotLevels(Metric):
         figures = []
 
         for i, fitted_model in enumerate(model_list):
-            feature_columns = datasets[0].get_features_columns()
+            feature_columns = datasets[0].feature_columns
 
             train_ds = datasets[0]
             test_ds = datasets[1]
validmind/tests/model_validation/statsmodels/RegressionModelInsampleComparison.py CHANGED
@@ -106,7 +106,7 @@ class RegressionModelInsampleComparison(Metric):
         evaluation_results = []
 
         for i, model in enumerate(models):
-            X_columns = dataset.get_features_columns()
+            X_columns = dataset.feature_columns
             y_true = dataset.y
             y_pred = dataset.y_pred(model)
 
validmind/tests/model_validation/statsmodels/RegressionModelOutsampleComparison.py CHANGED
@@ -96,7 +96,7 @@ class RegressionModelOutsampleComparison(Metric):
 
         for fitted_model in model_list:
             # Extract the column names of the independent variables from the model
-            independent_vars = dataset.get_features_columns()
+            independent_vars = dataset.feature_columns
 
             # Separate the target variable and features in the test dataset
             y_test = dataset.y
validmind/tests/model_validation/statsmodels/RegressionModelSummary.py CHANGED
@@ -57,7 +57,7 @@ class RegressionModelSummary(Metric):
         }
 
     def run(self):
-        X_columns = self.inputs.dataset.get_features_columns()
+        X_columns = self.inputs.dataset.feature_columns
 
         y_true = self.inputs.dataset.y
         y_pred = self.inputs.dataset.y_pred(self.inputs.model)
validmind/tests/model_validation/statsmodels/RegressionModelsCoeffs.py CHANGED
@@ -73,7 +73,7 @@ class RegressionModelsCoeffs(Metric):
             raise ValueError("List of models must be provided in the models parameter")
 
         for model in self.inputs.models:
-            if model.model_class() != "statsmodels" and model.model_class() != "R":
+            if model.class_ != "statsmodels" and model.class_ != "R":
                 raise SkipTestError(
                     "Only statsmodels and R models are supported for this metric"
                 )
validmind/tests/model_validation/statsmodels/RegressionModelsPerformance.py CHANGED
@@ -80,7 +80,7 @@ class RegressionModelsPerformance(Metric):
         evaluation_results = []
 
         for model, dataset in zip(models, datasets):
-            X_columns = dataset.get_features_columns()
+            X_columns = dataset.feature_columns
             y_true = dataset.y
             y_pred = dataset.y_pred(model)
 
validmind/tests/model_validation/statsmodels/ScorecardHistogram.py CHANGED
@@ -112,16 +112,15 @@ class ScorecardHistogram(Metric):
         dataframes = []
         metric_value = {"score_histogram": {}}
         for dataset in self.inputs.datasets:
-            df = dataset.df.copy()
-            # Check if the score_column exists in the DataFrame
-            if score_column not in df.columns:
+            if score_column not in dataset.df.columns:
                 raise ValueError(
                     f"The required column '{score_column}' is not present in the dataset with input_id {dataset.input_id}"
                 )
 
-            df[score_column] = dataset.get_extra_column(score_column)
-            dataframes.append(df)
-            metric_value["score_histogram"][dataset.input_id] = list(df[score_column])
+            dataframes.append(dataset.df.copy())
+            metric_value["score_histogram"][dataset.input_id] = list(
+                dataset.df[score_column]
+            )
 
         figures = self.plot_score_histogram(
             dataframes, dataset_titles, score_column, target_column, title
validmind/unit_metrics/__init__.py CHANGED
@@ -6,8 +6,6 @@ import hashlib
 import json
 from importlib import import_module
 
-import numpy as np
-
 from ..tests.decorator import _build_result, _inspect_signature
 from ..utils import get_model_info, test_id_to_name
 
@@ -58,7 +56,7 @@ def _serialize_model(model):
     return hash_object.hexdigest()
 
 
-def _serialize_dataset(dataset, model_id):
+def _serialize_dataset(dataset, model):
     """
     Serialize the description of the dataset input to a unique hash.
 
@@ -68,11 +66,11 @@ def _serialize_dataset(dataset, model_id):
 
     Args:
         dataset: The dataset object, which should have properties like _df (pandas DataFrame),
-            target_column (string), feature_columns (list of strings), and _extra_columns (dict).
-        model_id (str): The ID of the model associated with the prediction column.
+            target_column (string), feature_columns (list of strings), and extra_columns (dict).
+        model (VMModel): The model whose predictions will be included in the serialized dataset
 
     Returns:
-        str: A SHA-256 hash representing the dataset.
+        str: MD5 hash of the dataset
 
     Note:
         Including the model ID and prediction column name in the hash calculation ensures uniqueness,
@@ -80,57 +78,33 @@ def _serialize_dataset(dataset, model_id):
     This approach guarantees that the hash will distinguish between model-generated predictions
     and pre-computed prediction columns, addressing potential hash collisions.
     """
-
-    # Access the prediction column for the given model ID from the dataset's extra columns
-    prediction_column_name = dataset._extra_columns["prediction_columns"][model_id]
-
-    # Include model ID and prediction column name directly in the hash calculation
-    model_and_prediction_info = f"{model_id}_{prediction_column_name}".encode()
-
-    # Start with target and feature columns, and include the prediction column
-    columns = (
-        [dataset._target_column] + dataset._feature_columns + [prediction_column_name]
+    return _fast_hash(
+        dataset.df[
+            [
+                *dataset.feature_columns,
+                dataset.target_column,
+                dataset.prediction_column(model),
+            ]
+        ]
     )
 
-    # Use _fast_hash function and include model_and_prediction_info in the hash calculation
-    hash_digest = _fast_hash(
-        dataset._df[columns], model_and_prediction_info=model_and_prediction_info
-    )
-
-    return hash_digest
-
 
-def _fast_hash(df, sample_size=1000, model_and_prediction_info=None):
+def _fast_hash(df, sample_size=1000):
     """
-    Generates a hash for a DataFrame by sampling and combining its size, content,
-    and optionally model and prediction information.
+    Generates a fast hash by sampling, converting to string and md5 hashing.
 
     Args:
         df (pd.DataFrame): The DataFrame to hash.
         sample_size (int): The maximum number of rows to include in the sample.
-        model_and_prediction_info (bytes, optional): Additional information to include in the hash.
 
     Returns:
-        str: A SHA-256 hash of the DataFrame's sample and additional information.
+        str: MD5 hash of the DataFrame.
     """
-    # Convert the number of rows to bytes and include it in the hash calculation
-    rows_bytes = str(len(df)).encode()
+    df_sample = df.sample(n=min(sample_size, len(df)), random_state=42)
 
-    # Sample rows if DataFrame is larger than sample_size, ensuring reproducibility
-    if len(df) > sample_size:
-        df_sample = df.sample(n=sample_size, random_state=42)
-    else:
-        df_sample = df
-
-    # Convert the sampled DataFrame to a byte array. np.asarray ensures compatibility with various DataFrame contents.
-    byte_array = np.asarray(df_sample).data.tobytes()
-
-    # Initialize the hash object and update it with the row count, data bytes, and additional info
-    hash_obj = hashlib.sha256(
-        rows_bytes + byte_array + (model_and_prediction_info or b"")
-    )
-
-    return hash_obj.hexdigest()
+    return hashlib.md5(
+        df_sample.to_string(header=True, index=True).encode()
+    ).hexdigest()
 
 
 def get_metric_cache_key(metric_id, params, inputs):
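The rewritten _fast_hash samples up to 1,000 rows deterministically, renders them with to_string, and MD5-hashes the result, and _serialize_dataset now feeds it the feature, target, and per-model prediction columns of dataset.df directly. A standalone sketch of the same hashing scheme on a plain DataFrame (column names are made up for illustration):

import hashlib

import pandas as pd


def fast_frame_hash(df: pd.DataFrame, sample_size: int = 1000) -> str:
    # fixed random_state keeps the sample, and therefore the digest, reproducible
    sample = df.sample(n=min(sample_size, len(df)), random_state=42)
    return hashlib.md5(sample.to_string(header=True, index=True).encode()).hexdigest()


df = pd.DataFrame({"feature_a": [1, 2, 3], "target": [0, 1, 0]})
print(fast_frame_hash(df))  # same frame -> same digest across runs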
@@ -150,9 +124,8 @@ def get_metric_cache_key(metric_id, params, inputs):
 
     dataset = inputs["dataset"]
     model = inputs["model"]
-    model_id = model.input_id
 
-    cache_elements.append(_serialize_dataset(dataset, model_id))
+    cache_elements.append(_serialize_dataset(dataset, model))
 
     cache_elements.append(_serialize_model(model))
 
@@ -197,7 +170,11 @@ def run_metric(metric_id, inputs=None, params=None, show=True, value_only=False)
        **{k: v for k, v in inputs.items() if k in _inputs.keys()},
        **{k: v for k, v in params.items() if k in _params.keys()},
    )
-    unit_metric_results_cache[cache_key] = (result, list(_inputs.keys()))
+    unit_metric_results_cache[cache_key] = (
+        result,
+        # store the input ids that were used to calculate the result
+        [v.input_id for v in inputs.values()],
+    )
 
     value = unit_metric_results_cache[cache_key][0]
 
@@ -235,7 +212,7 @@ def run_metric(metric_id, inputs=None, params=None, show=True, value_only=False)
     )
 
     # in case the user tries to log the result object
-    def log(self):
+    def log():
        raise Exception(
            "Cannot log unit metrics directly..."
            "You can run this unit metric as part of a composite metric and log that"
validmind/unit_metrics/composite.py CHANGED
@@ -37,6 +37,7 @@ class CompositeMetric(Metric):
             metric_ids=self.unit_metrics,
             description=self.description(),
             inputs=self._get_input_dict(),
+            accessed_inputs=self.get_accessed_inputs(),
             params=self.params,
             output_template=self.output_template,
             show=False,
@@ -103,6 +104,7 @@ def run_metrics(
     description: str = None,
     output_template: str = None,
     inputs: dict = None,
+    accessed_inputs: List[str] = None,
     params: dict = None,
     test_id: str = None,
     show: bool = True,
@@ -128,6 +130,8 @@ def run_metrics(
         output_template (_type_, optional): Output template to customize the result
             table.
         inputs (_type_, optional): Inputs to pass to the unit metrics. Defaults to None
+        accessed_inputs (_type_, optional): Inputs that were accessed when running the
+            unit metrics - used for input tracking. Defaults to None.
         params (_type_, optional): Parameters to pass to the unit metrics. Defaults to
             None.
         test_id (str, optional): Test ID of the composite metric. Required if name is
@@ -212,7 +216,7 @@ def run_metrics(
                "json": {"output_template": output_template},
            },
        ],
-        inputs=list(inputs.keys()),
+        inputs=accessed_inputs,
        output_template=output_template,
        metric=MetricResult(
            key=test_id,
validmind/unit_metrics/regression/sklearn/AdjustedRSquaredScore.py CHANGED
@@ -16,6 +16,6 @@ def AdjustedRSquaredScore(model, dataset):
    )
 
    row_count = len(dataset.y)
-    feature_count = len(dataset.get_features_columns())
+    feature_count = len(dataset.feature_columns)
 
    return 1 - (1 - r2_score) * (row_count - 1) / (row_count - feature_count)
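For reference, the return expression above applies the adjusted R-squared correction 1 - (1 - R^2) * (n - 1) / (n - p), with n taken from len(dataset.y) and p from len(dataset.feature_columns). A quick standalone check with made-up numbers:

def adjusted_r2(r2: float, row_count: int, feature_count: int) -> float:
    # mirrors the return expression in AdjustedRSquaredScore
    return 1 - (1 - r2) * (row_count - 1) / (row_count - feature_count)


print(adjusted_r2(r2=0.90, row_count=100, feature_count=5))  # ~0.8958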
validmind/utils.py CHANGED
@@ -12,16 +12,21 @@ from platform import python_version
 from typing import Any
 
 import matplotlib.pylab as pylab
+import mistune
 import nest_asyncio
 import numpy as np
 import pandas as pd
 import seaborn as sns
 from IPython.core import getipython
-from IPython.display import HTML, display
+from IPython.display import HTML
+from IPython.display import display as ipy_display
+from latex2mathml.converter import convert
 from matplotlib.axes._axes import _log as matplotlib_axes_logger
 from numpy import ndarray
 from tabulate import tabulate
 
+from .html_templates.content_blocks import math_jax_snippet, python_syntax_highlighting
+
 DEFAULT_BIG_NUMBER_DECIMALS = 2
 DEFAULT_SMALL_NUMBER_DECIMALS = 4
 
@@ -97,6 +102,8 @@ class NumpyEncoder(json.JSONEncoder):
            return bool(obj)
        if isinstance(obj, pd.Timestamp):
            return str(obj)
+        if isinstance(obj, set):
+            return list(obj)
        return super().default(obj)
 
    def encode(self, obj):
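With this change NumpyEncoder also serializes Python sets (emitted as lists), alongside the existing NumPy scalar, bool, and Timestamp handling. A small usage sketch, assuming the encoder is imported from validmind.utils where it is defined:

import json

from validmind.utils import NumpyEncoder

payload = {"tags": {"credit_risk", "pd_model"}, "threshold": 0.5}
print(json.dumps(payload, cls=NumpyEncoder))
# e.g. {"tags": ["pd_model", "credit_risk"], "threshold": 0.5}  (set order is not guaranteed)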
@@ -345,10 +352,10 @@ def test_id_to_name(test_id: str) -> str:
 
 def get_model_info(model):
     """Attempts to extract all model info from a model object instance"""
-    architecture = model.model_name()
-    framework = model.model_library()
-    framework_version = model.model_library_version()
-    language = model.model_language()
+    architecture = model.name
+    framework = model.library
+    framework_version = model.library_version
+    language = model.language
 
     if language is None:
         language = f"Python {python_version()}"
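get_model_info now reads plain properties (name, library, library_version, language) instead of the old model_*() methods. A hedged sketch of the property-based access on a wrapped model (the function name is illustrative):

from platform import python_version


def describe_model(vm_model) -> dict:
    # mirrors get_model_info: attribute access instead of model_name()/model_library()/...
    return {
        "architecture": vm_model.name,
        "framework": vm_model.library,
        "framework_version": vm_model.library_version,
        "language": vm_model.language or f"Python {python_version()}",
    }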
@@ -402,4 +409,47 @@ def preview_test_config(config):
     <div id="collapsibleContent" style="display:none;"><pre>{formatted_json}</pre></div>
     """
 
-    display(HTML(collapsible_html))
+    ipy_display(HTML(collapsible_html))
+
+
+def display(widget_or_html, syntax_highlighting=True, mathjax=True):
+    """Display widgets with extra goodies (syntax highlighting, MathJax, etc.)"""
+    if isinstance(widget_or_html, str):
+        ipy_display(HTML(widget_or_html))
+        # if html we can auto-detect if we actually need syntax highlighting or MathJax
+        syntax_highlighting = 'class="language-' in widget_or_html
+        mathjax = "$$" in widget_or_html
+    else:
+        ipy_display(widget_or_html)
+
+    if syntax_highlighting:
+        ipy_display(HTML(python_syntax_highlighting))
+
+    if mathjax:
+        ipy_display(HTML(math_jax_snippet))
+
+
+def md_to_html(md: str, mathml=False) -> str:
+    """Converts Markdown to HTML using mistune with plugins"""
+    # use mistune with math plugin to convert to html
+    html = mistune.create_markdown(plugins=["math"])(md)
+
+    if not mathml:
+        # return the html as is (with latex that will be rendered by MathJax)
+        return html
+
+    # convert the latex to MathML which CKeditor can render
+    math_block_pattern = re.compile(r'<div class="math">\$\$([\s\S]*?)\$\$</div>')
+    html = math_block_pattern.sub(
+        lambda match: "<p>{}</p>".format(convert(match.group(1), display="block")), html
+    )
+
+    inline_math_pattern = re.compile(r'<span class="math">\\\((.*?)\\\)</span>')
+    html = inline_math_pattern.sub(
+        lambda match: "<span>{}</span>".format(
+            convert(match.group(1), display="inline")
+        ),
+        html,
+    )
+
+    return html
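The new md_to_html helper renders Markdown with mistune's math plugin and, when mathml=True, converts the resulting $$...$$ blocks and \( ... \) spans to MathML via latex2mathml so they survive in editors without MathJax. A small usage sketch, assuming the function is imported from validmind.utils and that the math plugin picks up the block below:

from validmind.utils import md_to_html

md = "Adjusted R-squared:\n\n$$1 - (1 - R^2) \\frac{n - 1}{n - p}$$\n"
print(md_to_html(md))               # HTML with the LaTeX left in place for MathJax
print(md_to_html(md, mathml=True))  # HTML with the formula converted to MathML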
validmind/vm_models/__init__.py CHANGED
@@ -6,7 +6,7 @@
 Models entrypoint
 """
 
-from .dataset import VMDataset
+from .dataset.dataset import VMDataset
 from .figure import Figure
 from .model import R_MODEL_TYPES, ModelAttributes, VMModel
 from .test.metric import Metric
validmind/vm_models/dataset/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from .dataset import DataFrameDataset, PolarsDataset, TorchDataset, VMDataset
+
+__all__ = ["VMDataset", "DataFrameDataset", "PolarsDataset", "TorchDataset"]
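The monolithic validmind/vm_models/dataset.py (removed above, -1303 lines) is replaced by a dataset/ package whose __init__ re-exports the concrete dataset wrappers. A sketch of the imports the new layout supports, per the __all__ shown here:

# the package re-exports the wrappers listed in __all__
from validmind.vm_models.dataset import (
    DataFrameDataset,
    PolarsDataset,
    TorchDataset,
    VMDataset,
)

# VMDataset also remains reachable via validmind.vm_models, which now imports it
# from .dataset.dataset as shown in the hunk above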