validmind 2.8.28__py3-none-any.whl → 2.8.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/models/function.py +11 -3
- validmind/tests/data_validation/ACFandPACFPlot.py +3 -1
- validmind/tests/data_validation/ADF.py +3 -1
- validmind/tests/data_validation/AutoAR.py +3 -1
- validmind/tests/data_validation/AutoMA.py +5 -1
- validmind/tests/data_validation/AutoStationarity.py +5 -1
- validmind/tests/data_validation/BivariateScatterPlots.py +3 -1
- validmind/tests/data_validation/BoxPierce.py +4 -1
- validmind/tests/data_validation/ChiSquaredFeaturesTable.py +1 -1
- validmind/tests/data_validation/ClassImbalance.py +1 -1
- validmind/tests/data_validation/DatasetDescription.py +4 -1
- validmind/tests/data_validation/DatasetSplit.py +3 -2
- validmind/tests/data_validation/DescriptiveStatistics.py +3 -1
- validmind/tests/data_validation/DickeyFullerGLS.py +3 -1
- validmind/tests/data_validation/Duplicates.py +3 -1
- validmind/tests/data_validation/EngleGrangerCoint.py +6 -1
- validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +3 -1
- validmind/tests/data_validation/HighCardinality.py +3 -1
- validmind/tests/data_validation/HighPearsonCorrelation.py +4 -1
- validmind/tests/data_validation/IQROutliersBarPlot.py +4 -1
- validmind/tests/data_validation/IQROutliersTable.py +6 -1
- validmind/tests/data_validation/IsolationForestOutliers.py +3 -1
- validmind/tests/data_validation/JarqueBera.py +3 -1
- validmind/tests/data_validation/KPSS.py +3 -1
- validmind/tests/data_validation/LJungBox.py +3 -1
- validmind/tests/data_validation/LaggedCorrelationHeatmap.py +6 -1
- validmind/tests/data_validation/MissingValues.py +5 -1
- validmind/tests/data_validation/MissingValuesBarPlot.py +3 -1
- validmind/tests/data_validation/MutualInformation.py +4 -1
- validmind/tests/data_validation/PearsonCorrelationMatrix.py +3 -1
- validmind/tests/data_validation/PhillipsPerronArch.py +3 -1
- validmind/tests/data_validation/ProtectedClassesCombination.py +5 -1
- validmind/tests/data_validation/ProtectedClassesDescription.py +5 -1
- validmind/tests/data_validation/ProtectedClassesDisparity.py +5 -3
- validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +9 -2
- validmind/tests/data_validation/RollingStatsPlot.py +5 -1
- validmind/tests/data_validation/RunsTest.py +1 -1
- validmind/tests/data_validation/ScatterPlot.py +2 -1
- validmind/tests/data_validation/ScoreBandDefaultRates.py +3 -1
- validmind/tests/data_validation/SeasonalDecompose.py +6 -1
- validmind/tests/data_validation/ShapiroWilk.py +4 -1
- validmind/tests/data_validation/Skewness.py +3 -1
- validmind/tests/data_validation/SpreadPlot.py +3 -1
- validmind/tests/data_validation/TabularCategoricalBarPlots.py +4 -1
- validmind/tests/data_validation/TabularDateTimeHistograms.py +3 -1
- validmind/tests/data_validation/TabularDescriptionTables.py +4 -1
- validmind/tests/data_validation/TabularNumericalHistograms.py +3 -1
- validmind/tests/data_validation/TargetRateBarPlots.py +4 -1
- validmind/tests/data_validation/TimeSeriesDescription.py +1 -1
- validmind/tests/data_validation/TimeSeriesDescriptiveStatistics.py +1 -1
- validmind/tests/data_validation/TimeSeriesFrequency.py +5 -1
- validmind/tests/data_validation/TimeSeriesHistogram.py +4 -1
- validmind/tests/data_validation/TimeSeriesLinePlot.py +3 -1
- validmind/tests/data_validation/TimeSeriesMissingValues.py +6 -1
- validmind/tests/data_validation/TimeSeriesOutliers.py +5 -1
- validmind/tests/data_validation/TooManyZeroValues.py +6 -1
- validmind/tests/data_validation/UniqueRows.py +5 -1
- validmind/tests/data_validation/WOEBinPlots.py +4 -1
- validmind/tests/data_validation/WOEBinTable.py +5 -1
- validmind/tests/data_validation/ZivotAndrewsArch.py +3 -1
- validmind/tests/data_validation/nlp/CommonWords.py +2 -1
- validmind/tests/data_validation/nlp/Hashtags.py +2 -1
- validmind/tests/data_validation/nlp/LanguageDetection.py +4 -1
- validmind/tests/data_validation/nlp/Mentions.py +3 -1
- validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +6 -1
- validmind/tests/data_validation/nlp/Punctuations.py +2 -1
- validmind/tests/data_validation/nlp/Sentiment.py +3 -1
- validmind/tests/data_validation/nlp/StopWords.py +2 -1
- validmind/tests/data_validation/nlp/TextDescription.py +3 -1
- validmind/tests/data_validation/nlp/Toxicity.py +3 -1
- validmind/tests/load.py +91 -17
- validmind/tests/model_validation/BertScore.py +6 -3
- validmind/tests/model_validation/BleuScore.py +6 -1
- validmind/tests/model_validation/ClusterSizeDistribution.py +5 -1
- validmind/tests/model_validation/ContextualRecall.py +6 -1
- validmind/tests/model_validation/FeaturesAUC.py +5 -1
- validmind/tests/model_validation/MeteorScore.py +6 -1
- validmind/tests/model_validation/ModelMetadata.py +2 -1
- validmind/tests/model_validation/ModelPredictionResiduals.py +10 -2
- validmind/tests/model_validation/RegardScore.py +7 -1
- validmind/tests/model_validation/RegressionResidualsPlot.py +5 -1
- validmind/tests/model_validation/RougeScore.py +8 -1
- validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +8 -1
- validmind/tests/model_validation/TimeSeriesPredictionsPlot.py +7 -1
- validmind/tests/model_validation/TimeSeriesR2SquareBySegments.py +6 -1
- validmind/tests/model_validation/TokenDisparity.py +6 -1
- validmind/tests/model_validation/ToxicityScore.py +6 -1
- validmind/tests/model_validation/embeddings/ClusterDistribution.py +6 -1
- validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +6 -1
- validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +6 -1
- validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +7 -3
- validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +6 -1
- validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +4 -3
- validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +6 -1
- validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +7 -3
- validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +6 -1
- validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +5 -2
- validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +5 -1
- validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +4 -1
- validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +5 -1
- validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +9 -6
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +8 -5
- validmind/tests/model_validation/ragas/AspectCritic.py +11 -8
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +5 -2
- validmind/tests/model_validation/ragas/ContextPrecision.py +5 -2
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +5 -2
- validmind/tests/model_validation/ragas/ContextRecall.py +6 -2
- validmind/tests/model_validation/ragas/Faithfulness.py +9 -5
- validmind/tests/model_validation/ragas/NoiseSensitivity.py +10 -7
- validmind/tests/model_validation/ragas/ResponseRelevancy.py +9 -6
- validmind/tests/model_validation/ragas/SemanticSimilarity.py +7 -4
- validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +5 -1
- validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +5 -1
- validmind/tests/model_validation/sklearn/CalibrationCurve.py +5 -1
- validmind/tests/model_validation/sklearn/ClassifierPerformance.py +5 -1
- validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +5 -1
- validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +5 -1
- validmind/tests/model_validation/sklearn/CompletenessScore.py +5 -1
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +4 -1
- validmind/tests/model_validation/sklearn/FeatureImportance.py +5 -1
- validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +5 -1
- validmind/tests/model_validation/sklearn/HomogeneityScore.py +5 -1
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +2 -4
- validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +3 -3
- validmind/tests/model_validation/sklearn/MinimumAccuracy.py +5 -1
- validmind/tests/model_validation/sklearn/MinimumF1Score.py +5 -1
- validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +5 -1
- validmind/tests/model_validation/sklearn/ModelParameters.py +6 -1
- validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +5 -1
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -2
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +4 -4
- validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +2 -2
- validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +5 -1
- validmind/tests/model_validation/sklearn/ROCCurve.py +3 -1
- validmind/tests/model_validation/sklearn/RegressionErrors.py +6 -1
- validmind/tests/model_validation/sklearn/RegressionErrorsComparison.py +6 -1
- validmind/tests/model_validation/sklearn/RegressionPerformance.py +5 -1
- validmind/tests/model_validation/sklearn/RegressionR2Square.py +6 -1
- validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +6 -1
- validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +2 -2
- validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +3 -1
- validmind/tests/model_validation/sklearn/SilhouettePlot.py +6 -1
- validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +2 -2
- validmind/tests/model_validation/sklearn/VMeasure.py +5 -1
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +6 -5
- validmind/tests/model_validation/statsmodels/AutoARIMA.py +3 -1
- validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +6 -1
- validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +6 -1
- validmind/tests/model_validation/statsmodels/GINITable.py +4 -1
- validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +5 -1
- validmind/tests/model_validation/statsmodels/Lilliefors.py +3 -1
- validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +6 -2
- validmind/tests/model_validation/statsmodels/RegressionCoeffs.py +4 -1
- validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +7 -2
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +5 -4
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +4 -1
- validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +3 -2
- validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +5 -1
- validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +3 -1
- validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +6 -1
- validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py +2 -2
- validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py +2 -2
- validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py +2 -2
- validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py +2 -2
- validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py +2 -2
- validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py +2 -2
- validmind/tests/ongoing_monitoring/FeatureDrift.py +5 -2
- validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +6 -1
- validmind/tests/ongoing_monitoring/PredictionCorrelation.py +8 -1
- validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py +2 -2
- validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py +6 -1
- validmind/tests/ongoing_monitoring/ROCCurveDrift.py +4 -2
- validmind/tests/ongoing_monitoring/ScoreBandsDrift.py +2 -2
- validmind/tests/ongoing_monitoring/ScorecardHistogramDrift.py +2 -2
- validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +8 -1
- validmind/tests/prompt_validation/Bias.py +5 -1
- validmind/tests/prompt_validation/Clarity.py +5 -1
- validmind/tests/prompt_validation/Conciseness.py +5 -1
- validmind/tests/prompt_validation/Delimitation.py +5 -1
- validmind/tests/prompt_validation/NegativeInstruction.py +5 -1
- validmind/tests/prompt_validation/Robustness.py +5 -1
- validmind/tests/prompt_validation/Specificity.py +5 -1
- validmind/unit_metrics/classification/Accuracy.py +2 -1
- validmind/unit_metrics/classification/F1.py +2 -1
- validmind/unit_metrics/classification/Precision.py +2 -1
- validmind/unit_metrics/classification/ROC_AUC.py +2 -1
- validmind/unit_metrics/classification/Recall.py +2 -1
- validmind/unit_metrics/regression/AdjustedRSquaredScore.py +2 -1
- validmind/unit_metrics/regression/GiniCoefficient.py +2 -1
- validmind/unit_metrics/regression/HuberLoss.py +2 -1
- validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +2 -1
- validmind/unit_metrics/regression/MeanAbsoluteError.py +2 -1
- validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +2 -1
- validmind/unit_metrics/regression/MeanBiasDeviation.py +2 -1
- validmind/unit_metrics/regression/MeanSquaredError.py +2 -1
- validmind/unit_metrics/regression/QuantileLoss.py +1 -1
- validmind/unit_metrics/regression/RSquaredScore.py +2 -1
- validmind/unit_metrics/regression/RootMeanSquaredError.py +2 -1
- validmind/vm_models/dataset/dataset.py +145 -38
- {validmind-2.8.28.dist-info → validmind-2.8.29.dist-info}/METADATA +1 -1
- {validmind-2.8.28.dist-info → validmind-2.8.29.dist-info}/RECORD +204 -204
- {validmind-2.8.28.dist-info → validmind-2.8.29.dist-info}/LICENSE +0 -0
- {validmind-2.8.28.dist-info → validmind-2.8.29.dist-info}/WHEEL +0 -0
- {validmind-2.8.28.dist-info → validmind-2.8.29.dist-info}/entry_points.txt +0 -0
@@ -8,7 +8,7 @@ Dataset class wrapper
|
|
8
8
|
|
9
9
|
import warnings
|
10
10
|
from copy import deepcopy
|
11
|
-
from typing import Any, Dict,
|
11
|
+
from typing import Any, Dict, Optional
|
12
12
|
|
13
13
|
import numpy as np
|
14
14
|
import pandas as pd
|
@@ -258,69 +258,91 @@ class VMDataset(VMInput):
|
|
258
258
|
f"Options {kwargs} are not supported for this input"
|
259
259
|
)
|
260
260
|
|
261
|
-
def
|
262
|
-
self,
|
263
|
-
|
264
|
-
|
265
|
-
prediction_values: Optional[List[Any]] = None,
|
266
|
-
probability_column: Optional[str] = None,
|
267
|
-
probability_values: Optional[List[float]] = None,
|
268
|
-
prediction_probabilities: Optional[
|
269
|
-
List[float]
|
270
|
-
] = None, # DEPRECATED: use probability_values
|
271
|
-
**kwargs: Dict[str, Any],
|
272
|
-
) -> None:
|
273
|
-
"""Assign predictions and probabilities to the dataset.
|
274
|
-
|
275
|
-
Args:
|
276
|
-
model (VMModel): The model used to generate the predictions.
|
277
|
-
prediction_column (Optional[str]): The name of the column containing the predictions.
|
278
|
-
prediction_values (Optional[List[Any]]): The values of the predictions.
|
279
|
-
probability_column (Optional[str]): The name of the column containing the probabilities.
|
280
|
-
probability_values (Optional[List[float]]): The values of the probabilities.
|
281
|
-
prediction_probabilities (Optional[List[float]]): DEPRECATED: The values of the probabilities.
|
282
|
-
**kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
|
283
|
-
"""
|
261
|
+
def _handle_deprecated_parameters(
|
262
|
+
self, prediction_probabilities, probability_values
|
263
|
+
):
|
264
|
+
"""Handle deprecated parameters and return the correct probability values."""
|
284
265
|
if prediction_probabilities is not None:
|
285
266
|
warnings.warn(
|
286
267
|
"The `prediction_probabilities` argument is deprecated. Use `probability_values` instead.",
|
287
268
|
DeprecationWarning,
|
288
269
|
)
|
289
|
-
|
290
|
-
|
291
|
-
self._validate_assign_predictions(
|
292
|
-
model,
|
293
|
-
prediction_column,
|
294
|
-
prediction_values,
|
295
|
-
probability_column,
|
296
|
-
probability_values,
|
297
|
-
)
|
270
|
+
return prediction_probabilities
|
271
|
+
return probability_values
|
298
272
|
|
273
|
+
def _check_existing_predictions(self, model):
|
274
|
+
"""Check for existing predictions and probabilities, warn if overwriting."""
|
299
275
|
if self.prediction_column(model):
|
300
276
|
logger.warning("Model predictions already assigned... Overwriting.")
|
301
277
|
|
302
278
|
if self.probability_column(model):
|
303
279
|
logger.warning("Model probabilities already assigned... Overwriting.")
|
304
280
|
|
305
|
-
|
281
|
+
def _get_precomputed_values(self, prediction_column, probability_column):
|
282
|
+
"""Get precomputed prediction and probability values from existing columns."""
|
283
|
+
prediction_values = None
|
284
|
+
probability_values = None
|
285
|
+
|
306
286
|
if prediction_column:
|
307
287
|
prediction_values = self._df[prediction_column].values
|
308
288
|
|
309
289
|
if probability_column:
|
310
290
|
probability_values = self._df[probability_column].values
|
311
291
|
|
292
|
+
return prediction_values, probability_values
|
293
|
+
|
294
|
+
def _compute_predictions_if_needed(self, model, prediction_values, **kwargs):
|
295
|
+
"""Compute predictions if not provided."""
|
312
296
|
if prediction_values is None:
|
313
297
|
X = self.df if isinstance(model, (FunctionModel, PipelineModel)) else self.x
|
314
|
-
|
315
|
-
|
298
|
+
return compute_predictions(model, X, **kwargs)
|
299
|
+
return None, prediction_values
|
300
|
+
|
301
|
+
def _handle_dictionary_predictions(self, model, prediction_values):
|
302
|
+
"""Handle dictionary predictions by converting to separate columns."""
|
303
|
+
if (
|
304
|
+
prediction_values is not None
|
305
|
+
and len(prediction_values) > 0
|
306
|
+
and isinstance(prediction_values[0], dict)
|
307
|
+
):
|
308
|
+
df_prediction_values = pd.DataFrame.from_dict(
|
309
|
+
prediction_values, orient="columns"
|
316
310
|
)
|
317
311
|
|
318
|
-
|
312
|
+
for column_name in df_prediction_values.columns.tolist():
|
313
|
+
values = df_prediction_values[column_name].values
|
314
|
+
|
315
|
+
if column_name == "prediction":
|
316
|
+
prediction_column = f"{model.input_id}_prediction"
|
317
|
+
self._add_column(prediction_column, values)
|
318
|
+
self.prediction_column(model, prediction_column)
|
319
|
+
else:
|
320
|
+
self._add_column(f"{model.input_id}_{column_name}", values)
|
321
|
+
|
322
|
+
return (
|
323
|
+
True,
|
324
|
+
None,
|
325
|
+
) # Return True to indicate dictionary handled, None for prediction_column
|
326
|
+
return False, None
|
327
|
+
|
328
|
+
def _add_prediction_columns(
|
329
|
+
self,
|
330
|
+
model,
|
331
|
+
prediction_column,
|
332
|
+
prediction_values,
|
333
|
+
probability_column,
|
334
|
+
probability_values,
|
335
|
+
):
|
336
|
+
"""Add prediction and probability columns to the dataset."""
|
337
|
+
if prediction_column is None:
|
338
|
+
prediction_column = f"{model.input_id}_prediction"
|
339
|
+
|
319
340
|
self._add_column(prediction_column, prediction_values)
|
320
341
|
self.prediction_column(model, prediction_column)
|
321
342
|
|
322
343
|
if probability_values is not None:
|
323
|
-
|
344
|
+
if probability_column is None:
|
345
|
+
probability_column = f"{model.input_id}_probabilities"
|
324
346
|
self._add_column(probability_column, probability_values)
|
325
347
|
self.probability_column(model, probability_column)
|
326
348
|
else:
|
@@ -329,6 +351,91 @@ class VMDataset(VMInput):
|
|
329
351
|
"Not adding probability column to the dataset."
|
330
352
|
)
|
331
353
|
|
354
|
+
def assign_predictions(
|
355
|
+
self,
|
356
|
+
model: VMModel,
|
357
|
+
prediction_column: Optional[str] = None,
|
358
|
+
prediction_values: Optional[Any] = None,
|
359
|
+
probability_column: Optional[str] = None,
|
360
|
+
probability_values: Optional[Any] = None,
|
361
|
+
prediction_probabilities: Optional[
|
362
|
+
Any
|
363
|
+
] = None, # DEPRECATED: use probability_values
|
364
|
+
**kwargs: Dict[str, Any],
|
365
|
+
) -> None:
|
366
|
+
"""Assign predictions and probabilities to the dataset.
|
367
|
+
|
368
|
+
Args:
|
369
|
+
model (VMModel): The model used to generate the predictions.
|
370
|
+
prediction_column (Optional[str]): The name of the column containing the predictions.
|
371
|
+
prediction_values (Optional[Any]): The values of the predictions. Can be array-like (list, numpy array, pandas Series, etc.).
|
372
|
+
probability_column (Optional[str]): The name of the column containing the probabilities.
|
373
|
+
probability_values (Optional[Any]): The values of the probabilities. Can be array-like (list, numpy array, pandas Series, etc.).
|
374
|
+
prediction_probabilities (Optional[Any]): DEPRECATED: The values of the probabilities. Use probability_values instead.
|
375
|
+
**kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
|
376
|
+
"""
|
377
|
+
# Handle deprecated parameters
|
378
|
+
probability_values = self._handle_deprecated_parameters(
|
379
|
+
prediction_probabilities, probability_values
|
380
|
+
)
|
381
|
+
|
382
|
+
# Convert pandas Series to numpy array for prediction_values
|
383
|
+
if (
|
384
|
+
hasattr(prediction_values, "values")
|
385
|
+
and hasattr(prediction_values, "index")
|
386
|
+
and hasattr(prediction_values, "dtype")
|
387
|
+
):
|
388
|
+
prediction_values = prediction_values.values
|
389
|
+
|
390
|
+
# Convert pandas Series to numpy array for probability_values
|
391
|
+
if (
|
392
|
+
hasattr(probability_values, "values")
|
393
|
+
and hasattr(probability_values, "index")
|
394
|
+
and hasattr(probability_values, "dtype")
|
395
|
+
):
|
396
|
+
probability_values = probability_values.values
|
397
|
+
|
398
|
+
# Validate input parameters
|
399
|
+
self._validate_assign_predictions(
|
400
|
+
model,
|
401
|
+
prediction_column,
|
402
|
+
prediction_values,
|
403
|
+
probability_column,
|
404
|
+
probability_values,
|
405
|
+
)
|
406
|
+
|
407
|
+
# Check for existing predictions and warn if overwriting
|
408
|
+
self._check_existing_predictions(model)
|
409
|
+
|
410
|
+
# Get precomputed values if column names are provided
|
411
|
+
if prediction_column or probability_column:
|
412
|
+
prediction_values, prob_values_from_column = self._get_precomputed_values(
|
413
|
+
prediction_column, probability_column
|
414
|
+
)
|
415
|
+
if prob_values_from_column is not None:
|
416
|
+
probability_values = prob_values_from_column
|
417
|
+
|
418
|
+
# Compute predictions if not provided
|
419
|
+
if prediction_values is None:
|
420
|
+
probability_values, prediction_values = self._compute_predictions_if_needed(
|
421
|
+
model, prediction_values, **kwargs
|
422
|
+
)
|
423
|
+
|
424
|
+
# Handle dictionary predictions
|
425
|
+
is_dict_handled, _ = self._handle_dictionary_predictions(
|
426
|
+
model, prediction_values
|
427
|
+
)
|
428
|
+
|
429
|
+
# Add prediction and probability columns (skip if dictionary was handled)
|
430
|
+
if not is_dict_handled:
|
431
|
+
self._add_prediction_columns(
|
432
|
+
model,
|
433
|
+
prediction_column,
|
434
|
+
prediction_values,
|
435
|
+
probability_column,
|
436
|
+
probability_values,
|
437
|
+
)
|
438
|
+
|
332
439
|
def prediction_column(self, model: VMModel, column_name: str = None) -> str:
|
333
440
|
"""Get or set the prediction column for a model."""
|
334
441
|
if column_name and column_name not in self.columns:
|