validmind 2.8.12__py3-none-any.whl → 2.8.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +6 -5
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +13 -9
- validmind/ai/utils.py +2 -2
- validmind/api_client.py +75 -32
- validmind/client.py +108 -100
- validmind/client_config.py +3 -3
- validmind/datasets/classification/__init__.py +7 -3
- validmind/datasets/credit_risk/lending_club.py +28 -16
- validmind/datasets/nlp/cnn_dailymail.py +10 -4
- validmind/datasets/regression/__init__.py +22 -5
- validmind/errors.py +17 -7
- validmind/input_registry.py +1 -1
- validmind/logging.py +44 -35
- validmind/models/foundation.py +2 -2
- validmind/models/function.py +10 -3
- validmind/template.py +30 -22
- validmind/test_suites/__init__.py +2 -2
- validmind/tests/_store.py +13 -4
- validmind/tests/comparison.py +65 -33
- validmind/tests/data_validation/ClassImbalance.py +3 -1
- validmind/tests/data_validation/DatasetDescription.py +2 -23
- validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
- validmind/tests/data_validation/Skewness.py +7 -6
- validmind/tests/decorator.py +14 -11
- validmind/tests/load.py +38 -24
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
- validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
- validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
- validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
- validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
- validmind/tests/output.py +66 -11
- validmind/tests/run.py +28 -14
- validmind/tests/test_providers.py +28 -35
- validmind/tests/utils.py +17 -4
- validmind/unit_metrics/__init__.py +1 -1
- validmind/utils.py +295 -31
- validmind/vm_models/dataset/dataset.py +19 -16
- validmind/vm_models/dataset/utils.py +5 -3
- validmind/vm_models/figure.py +6 -6
- validmind/vm_models/input.py +6 -5
- validmind/vm_models/model.py +5 -5
- validmind/vm_models/result/result.py +122 -43
- validmind/vm_models/result/utils.py +5 -5
- validmind/vm_models/test_suite/__init__.py +5 -0
- validmind/vm_models/test_suite/runner.py +5 -5
- validmind/vm_models/test_suite/summary.py +20 -2
- validmind/vm_models/test_suite/test.py +6 -6
- validmind/vm_models/test_suite/test_suite.py +10 -10
- {validmind-2.8.12.dist-info → validmind-2.8.20.dist-info}/METADATA +3 -4
- {validmind-2.8.12.dist-info → validmind-2.8.20.dist-info}/RECORD +61 -60
- {validmind-2.8.12.dist-info → validmind-2.8.20.dist-info}/WHEEL +1 -1
- {validmind-2.8.12.dist-info → validmind-2.8.20.dist-info}/LICENSE +0 -0
- {validmind-2.8.12.dist-info → validmind-2.8.20.dist-info}/entry_points.txt +0 -0
validmind/utils.py
CHANGED
@@ -12,7 +12,7 @@ import sys
 import warnings
 from datetime import date, datetime, time
 from platform import python_version
-from typing import Any, Dict, List
+from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar

 import matplotlib.pylab as pylab
 import mistune
@@ -20,6 +20,7 @@ import nest_asyncio
 import numpy as np
 import pandas as pd
 import seaborn as sns
+from bs4 import BeautifulSoup
 from IPython.core import getipython
 from IPython.display import HTML
 from IPython.display import display as ipy_display
@@ -59,23 +60,25 @@ pylab.rcParams.update(params)

 logger = get_logger(__name__)

+T = TypeVar("T")
+

 def parse_version(version: str) -> tuple[int, ...]:
     """
-    Parse a semver version string into a tuple of major, minor, patch integers
+    Parse a semver version string into a tuple of major, minor, patch integers.

     Args:
-        version (str): The semantic version string to parse
+        version (str): The semantic version string to parse.

     Returns:
-        tuple[int, ...]: A tuple of major, minor, patch integers
+        tuple[int, ...]: A tuple of major, minor, patch integers.
     """
     return tuple(int(x) for x in version.split(".")[:3])


 def is_notebook() -> bool:
     """
-    Checks if the code is running in a Jupyter notebook or IPython shell
+    Checks if the code is running in a Jupyter notebook or IPython shell.

     https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
     """
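Illustrative usage (not part of the diff): parse_version keeps only the first three dot-separated components, so suffixes past the patch number are dropped.

    from validmind.utils import parse_version

    assert parse_version("2.8.20") == (2, 8, 20)
    assert parse_version("1.2.3.dev0") == (1, 2, 3)  # components past the third are ignored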
@@ -209,9 +212,7 @@ class HumanReadableEncoder(NumpyEncoder):


 def get_full_typename(o: Any) -> Any:
-    """We determine types based on type names so we don't have to import
-    (and therefore depend on) PyTorch, TensorFlow, etc.
-    """
+    """We determine types based on type names so we don't have to import."""
     instance_name = o.__class__.__module__ + "." + o.__class__.__name__
     if instance_name in ["builtins.module", "__builtin__.module"]:
         return o.__name__
@@ -313,9 +314,9 @@ def format_key_values(key_values: Dict[str, Any]) -> Dict[str, Any]:


 def summarize_data_quality_results(results):
     """
-    TODO: generalize this to work with metrics and test results
+    TODO: generalize this to work with metrics and test results.

-    Summarize the results of the data quality test suite
+    Summarize the results of the data quality test suite.
     """
     test_results = []
     for result in results:
@@ -354,25 +355,31 @@ def format_number(number):


 def format_dataframe(df: pd.DataFrame) -> pd.DataFrame:
-    """Format a pandas DataFrame for display purposes"""
+    """Format a pandas DataFrame for display purposes."""
     df = df.style.set_properties(**{"text-align": "left"}).hide(axis="index")
     return df.set_table_styles([dict(selector="th", props=[("text-align", "left")])])


-def run_async(func, *args, name=None, **kwargs):
-    """Helper function to run functions asynchronously.
+def run_async(
+    func: Callable[..., Awaitable[T]],
+    *args: Any,
+    name: Optional[str] = None,
+    **kwargs: Any,
+) -> T:
+    """Helper function to run functions asynchronously.

     This takes care of the complexity of running the logging functions asynchronously. It will
-    detect the type of environment we are running in (
+    detect the type of environment we are running in (IPython notebook or not) and run the
     function accordingly.

     Args:
-        func
-        *args: The arguments to pass to the function
-
+        func: The function to run asynchronously.
+        *args: The arguments to pass to the function.
+        name: Optional name for the task.
+        **kwargs: The keyword arguments to pass to the function.

     Returns:
-        The result of the function
+        The result of the function.
     """
     try:
         if asyncio.get_event_loop().is_running() and is_notebook():
@@ -390,8 +397,19 @@ def run_async(func, *args, name=None, **kwargs):
     return asyncio.get_event_loop().run_until_complete(func(*args, **kwargs))


-def run_async_check(func, *args, **kwargs):
-    """Helper function to run functions asynchronously if the task doesn't already exist"""
+def run_async_check(
+    func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
+) -> Optional[asyncio.Task[T]]:
+    """Helper function to run functions asynchronously if the task doesn't already exist.
+
+    Args:
+        func: The function to run asynchronously.
+        *args: The arguments to pass to the function.
+        **kwargs: The keyword arguments to pass to the function.
+
+    Returns:
+        Optional[asyncio.Task[T]]: The task if created or found, None otherwise.
+    """
     if __loop:
         return  # we don't need this if we are using our own loop
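Illustrative usage of the newly typed helper (the coroutine below is hypothetical): run_async detects whether a notebook event loop is already running and either schedules the coroutine on it or blocks via run_until_complete.

    import asyncio
    from validmind.utils import run_async

    async def send_log(message: str) -> str:  # hypothetical stand-in for an async logging call
        await asyncio.sleep(0)
        return f"sent: {message}"

    # Blocks in a plain script; inside a notebook it is scheduled on the live loop.
    print(run_async(send_log, "hello", name="log-task"))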
@@ -408,16 +426,16 @@ def run_async_check(func, *args, **kwargs):
         pass


-def fuzzy_match(string: str, search_string: str, threshold=0.7):
-    """Check if a string matches another string using fuzzy matching
+def fuzzy_match(string: str, search_string: str, threshold: float = 0.7) -> bool:
+    """Check if a string matches another string using fuzzy matching.

     Args:
-        string (str): The string to check
-        search_string (str): The string to search for
-        threshold (float): The similarity threshold to use (Default: 0.7)
+        string (str): The string to check.
+        search_string (str): The string to search for.
+        threshold (float): The similarity threshold to use (Default: 0.7).

     Returns:
-        True if the string matches the search string, False otherwise
+        bool: True if the string matches the search string, False otherwise.
     """
     score = difflib.SequenceMatcher(None, string, search_string).ratio()
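The scoring is plain difflib, so the behavior is easy to preview outside the library (illustrative):

    import difflib

    score = difflib.SequenceMatcher(None, "ClassImbalance", "class imbalance").ratio()
    print(round(score, 2))  # roughly 0.83, above the default 0.7 threshold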
@@ -448,7 +466,7 @@ def test_id_to_name(test_id: str) -> str:


 def get_model_info(model):
-    """Attempts to extract all model info from a model object instance"""
+    """Attempts to extract all model info from a model object instance."""
     architecture = model.name
     framework = model.library
     framework_version = model.library_version
@@ -472,7 +490,7 @@ def get_model_info(model):


 def get_dataset_info(dataset):
-    """Attempts to extract all dataset info from a dataset object instance"""
+    """Attempts to extract all dataset info from a dataset object instance."""
     num_rows, num_cols = dataset.df.shape
     schema = dataset.df.dtypes.apply(lambda x: x.name).to_dict()
     description = (
@@ -491,7 +509,7 @@ def preview_test_config(config):
     """Preview test configuration in a collapsible HTML section.

     Args:
-        config (dict): Test configuration dictionary
+        config (dict): Test configuration dictionary.
     """

     try:
@@ -515,7 +533,7 @@ def preview_test_config(config):


 def display(widget_or_html, syntax_highlighting=True, mathjax=True):
-    """Display widgets with extra goodies (syntax highlighting, MathJax, etc.)"""
+    """Display widgets with extra goodies (syntax highlighting, MathJax, etc.)."""
     if isinstance(widget_or_html, str):
         ipy_display(HTML(widget_or_html))
     # if html we can auto-detect if we actually need syntax highlighting or MathJax
@@ -532,7 +550,7 @@ def display(widget_or_html, syntax_highlighting=True, mathjax=True):


 def md_to_html(md: str, mathml=False) -> str:
-    """Converts Markdown to HTML using mistune with plugins"""
+    """Converts Markdown to HTML using mistune with plugins."""
     # use mistune with math plugin to convert to html
     html = mistune.create_markdown(
         plugins=["math", "table", "strikethrough", "footnotes"]
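The mistune call above builds a reusable renderer; the same pattern standalone (illustrative):

    import mistune

    render = mistune.create_markdown(plugins=["math", "table", "strikethrough", "footnotes"])
    print(render("**bold**, ~~struck~~, and $x^2$"))  # emits HTML with math markup preserved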
@@ -559,6 +577,63 @@ def md_to_html(md: str, mathml=False) -> str:
     return html


+def is_html(text: str) -> bool:
+    """Check if a string is HTML.
+
+    Uses more robust heuristics to determine if a string contains HTML content.
+
+    Args:
+        text (str): The string to check
+
+    Returns:
+        bool: True if the string likely contains HTML, False otherwise
+    """
+    # Strip whitespace first
+    text = text.strip()
+
+    # Basic check: Must at least start with < and end with >
+    if not (text.startswith("<") and text.endswith(">")):
+        return False
+
+    # Look for common HTML tags
+    common_html_patterns = [
+        r"<html.*?>",  # HTML tag
+        r"<body.*?>",  # Body tag
+        r"<div.*?>",  # Div tag
+        r"<p>.*?</p>",  # Paragraph with content
+        r"<h[1-6]>.*?</h[1-6]>",  # Headers
+        r"<script.*?>",  # Script tags
+        r"<style.*?>",  # Style tags
+        r"<a href=.*?>",  # Links
+        r"<img.*?>",  # Images
+        r"<table.*?>",  # Tables
+        r"<!DOCTYPE html>",  # DOCTYPE declaration
+    ]
+
+    for pattern in common_html_patterns:
+        if re.search(pattern, text, re.IGNORECASE | re.DOTALL):
+            return True
+
+    # If we have at least 2 matching tags, it's likely HTML
+    # This helps detect custom elements or patterns not in our list
+    tags = re.findall(r"</?[a-zA-Z][a-zA-Z0-9]*.*?>", text)
+    if len(tags) >= 2:
+        return True
+
+    # Try parsing with BeautifulSoup as a last resort
+    try:
+        soup = BeautifulSoup(text, "html.parser")
+        # If we find any tags that weren't in the original text, BeautifulSoup
+        # likely tried to fix broken HTML, meaning it's not valid HTML
+        return len(soup.find_all()) > 0
+
+    except Exception as e:
+        logger.error(f"Error checking if text is HTML: {e}")
+        return False
+
+    return False
+
+
 def inspect_obj(obj):
     # Filtering only attributes
     print(len("Attributes:") * "-")
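Illustrative behavior of the new helper, following the heuristics above:

    from validmind.utils import is_html

    print(is_html("<div>hello</div>"))        # True: matches a common-tag pattern
    print(is_html("a < b and c > d"))         # False: must start with "<" and end with ">"
    print(is_html("<thing><other></thing>"))  # True: two or more tag-like tokens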
@@ -601,3 +676,192 @@ def serialize(obj):
     elif isinstance(obj, (pd.DataFrame, pd.Series)):
         return ""  # Simple empty string for non-serializable objects
     return obj
+
+
+def is_text_column(series, threshold=0.05) -> bool:
+    """
+    Determines if a series is likely to contain text data using heuristics.
+
+    Args:
+        series (pd.Series): The pandas Series to analyze
+        threshold (float): The minimum threshold to classify a pattern match as significant
+
+    Returns:
+        bool: True if the series likely contains text data, False otherwise
+    """
+    # Filter to non-null string values and sample if needed
+    string_series = series.dropna().astype(str)
+    if len(string_series) == 0:
+        return False
+    if len(string_series) > 1000:
+        string_series = string_series.sample(1000, random_state=42)
+
+    # Calculate basic metrics
+    total_values = len(string_series)
+    unique_ratio = len(string_series.unique()) / total_values if total_values > 0 else 0
+    avg_length = string_series.str.len().mean()
+    avg_words = string_series.str.split(r"\s+").str.len().mean()
+
+    # Check for special text patterns
+    patterns = {
+        "url": r"https?://\S+|www\.\S+",
+        "email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
+        "filepath": r'(?:[a-zA-Z]:|[\\/])(?:[\\/][^\\/:*?"<>|]+)+',
+    }
+
+    # Check if any special patterns exceed threshold
+    for pattern in patterns.values():
+        if string_series.str.contains(pattern, regex=True, na=False).mean() > threshold:
+            return True
+
+    # Calculate proportion of alphabetic characters
+    total_chars = string_series.str.len().sum()
+    if total_chars > 0:
+        alpha_ratio = string_series.str.count(r"[a-zA-Z]").sum() / total_chars
+    else:
+        alpha_ratio = 0
+
+    # Check for free-form text indicators
+    text_indicators = [
+        unique_ratio > 0.8 and avg_length > 20,  # High uniqueness and long strings
+        unique_ratio > 0.4
+        and avg_length > 15
+        and string_series.str.contains(r"[.,;:!?]", regex=True, na=False).mean()
+        > 0.3,  # Moderate uniqueness with punctuation
+        string_series.str.contains(
+            r"\b\w+\b\s+\b\w+\b\s+\b\w+\b\s+\b\w+\b", regex=True, na=False
+        ).mean()
+        > 0.3,  # Contains long phrases
+        avg_words > 5 and alpha_ratio > 0.6,  # Many words with mostly letters
+        unique_ratio > 0.95 and avg_length > 10,  # Very high uniqueness
+    ]
+
+    return any(text_indicators)
+
+
+def _get_numeric_type_detail(column, dtype, series):
+    """Helper function to determine numeric type details."""
+    if pd.api.types.is_integer_dtype(dtype):
+        return {"type": "Numeric", "subtype": "Integer"}
+    elif pd.api.types.is_float_dtype(dtype):
+        return {"type": "Numeric", "subtype": "Float"}
+    else:
+        return {"type": "Numeric", "subtype": "Other"}
+
+
+def _get_text_type_detail(series):
+    """Helper function to determine text/categorical type details."""
+    string_series = series.dropna().astype(str)
+
+    if len(string_series) == 0:
+        return {"type": "Categorical"}
+
+    # Check for common patterns
+    url_pattern = r"https?://\S+|www\.\S+"
+    email_pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
+    filepath_pattern = r'(?:[a-zA-Z]:|[\\/])(?:[\\/][^\\/:*?"<>|]+)+'
+
+    url_ratio = string_series.str.contains(url_pattern, regex=True, na=False).mean()
+    email_ratio = string_series.str.contains(email_pattern, regex=True, na=False).mean()
+    filepath_ratio = string_series.str.contains(
+        filepath_pattern, regex=True, na=False
+    ).mean()
+
+    # Check if general text using enhanced function
+    if url_ratio > 0.7:
+        return {"type": "Text", "subtype": "URL"}
+    elif email_ratio > 0.7:
+        return {"type": "Text", "subtype": "Email"}
+    elif filepath_ratio > 0.7:
+        return {"type": "Text", "subtype": "Path"}
+    elif is_text_column(series):
+        return {"type": "Text", "subtype": "FreeText"}
+
+    # Must be categorical
+    n_unique = series.nunique()
+    if n_unique == 2:
+        return {"type": "Categorical", "subtype": "Binary"}
+    else:
+        return {"type": "Categorical", "subtype": "Nominal"}
+
+
+def get_column_type_detail(df, column) -> dict:
+    """
+    Get detailed column type information beyond basic type detection.
+    Similar to ydata-profiling's type system.
+
+    Args:
+        df (pd.DataFrame): DataFrame containing the column
+        column (str): Column name to analyze
+
+    Returns:
+        dict: Detailed type information including primary type and subtype
+    """
+    series = df[column]
+    dtype = series.dtype
+
+    # Initialize result with id and basic type
+    result = {"id": column, "type": "Unknown"}
+
+    # Determine type details based on dtype
+    type_detail = None
+
+    if pd.api.types.is_numeric_dtype(dtype):
+        type_detail = _get_numeric_type_detail(column, dtype, series)
+    elif pd.api.types.is_bool_dtype(dtype):
+        type_detail = {"type": "Boolean"}
+    elif pd.api.types.is_datetime64_any_dtype(dtype):
+        type_detail = {"type": "Datetime"}
+    elif pd.api.types.is_categorical_dtype(dtype) or pd.api.types.is_object_dtype(
+        dtype
+    ):
+        type_detail = _get_text_type_detail(series)
+
+    # Update result with type details
+    if type_detail:
+        result.update(type_detail)
+
+    return result
+
+
+def infer_datatypes(df, detailed=False) -> list:
+    """
+    Infer data types for columns in a DataFrame.
+
+    Args:
+        df (pd.DataFrame): DataFrame to analyze
+        detailed (bool): Whether to return detailed type information including subtypes
+
+    Returns:
+        list: Column type mappings
+    """
+    if detailed:
+        return [get_column_type_detail(df, column) for column in df.columns]
+
+    column_type_mappings = {}
+    # Use pandas to infer data types
+    for column in df.columns:
+        # Check if all values are None
+        if df[column].isna().all():
+            column_type_mappings[column] = {"id": column, "type": "Null"}
+            continue
+
+        dtype = df[column].dtype
+        if pd.api.types.is_numeric_dtype(dtype):
+            column_type_mappings[column] = {"id": column, "type": "Numeric"}
+        elif pd.api.types.is_bool_dtype(dtype):
+            column_type_mappings[column] = {"id": column, "type": "Boolean"}
+        elif pd.api.types.is_datetime64_any_dtype(dtype):
+            column_type_mappings[column] = {"id": column, "type": "Datetime"}
+        elif pd.api.types.is_categorical_dtype(dtype) or pd.api.types.is_object_dtype(
+            dtype
+        ):
+            # Check if this is more likely to be text than categorical
+            if is_text_column(df[column]):
+                column_type_mappings[column] = {"id": column, "type": "Text"}
+            else:
+                column_type_mappings[column] = {"id": column, "type": "Categorical"}
+        else:
+            column_type_mappings[column] = {"id": column, "type": "Unsupported"}
+
+    return list(column_type_mappings.values())
validmind/vm_models/dataset/dataset.py
CHANGED
@@ -8,6 +8,7 @@ Dataset class wrapper

 import warnings
 from copy import deepcopy
+from typing import Any, Dict, List, Optional

 import numpy as np
 import pandas as pd
@@ -24,9 +25,9 @@ logger = get_logger(__name__)


 class VMDataset(VMInput):
-    """Base class for VM datasets
+    """Base class for VM datasets.

-    Child classes should be used to support new dataset types (tensor, polars etc)
+    Child classes should be used to support new dataset types (tensor, polars etc.)
     by converting the user's dataset into a numpy array collecting metadata like
     column names and then call this (parent) class `__init__` method.
@@ -200,7 +201,7 @@ class VMDataset(VMInput):
                 "Cannot use precomputed probabilities without precomputed predictions"
             )

-    def with_options(self, **kwargs) -> "VMDataset":
+    def with_options(self, **kwargs: Dict[str, Any]) -> "VMDataset":
         """Support options provided when passing an input to run_test or run_test_suite

         Example:
@@ -253,23 +254,25 @@ class VMDataset(VMInput):
     def assign_predictions(
         self,
         model: VMModel,
-        prediction_column: str = None,
-        prediction_values: list = None,
-        probability_column: str = None,
-        probability_values: list = None,
-        prediction_probabilities: list = None,
-        **kwargs,
-    ):
+        prediction_column: Optional[str] = None,
+        prediction_values: Optional[List[Any]] = None,
+        probability_column: Optional[str] = None,
+        probability_values: Optional[List[float]] = None,
+        prediction_probabilities: Optional[
+            List[float]
+        ] = None,  # DEPRECATED: use probability_values
+        **kwargs: Dict[str, Any],
+    ) -> None:
         """Assign predictions and probabilities to the dataset.

         Args:
             model (VMModel): The model used to generate the predictions.
-            prediction_column (str, optional): The name of the column containing the predictions.
-            prediction_values (list, optional): The values of the predictions.
-            probability_column (str, optional): The name of the column containing the probabilities.
-            probability_values (list, optional): The values of the probabilities.
-            prediction_probabilities (list, optional): DEPRECATED: The values of the probabilities.
-            kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
+            prediction_column (Optional[str]): The name of the column containing the predictions.
+            prediction_values (Optional[List[Any]]): The values of the predictions.
+            probability_column (Optional[str]): The name of the column containing the probabilities.
+            probability_values (Optional[List[float]]): The values of the probabilities.
+            prediction_probabilities (Optional[List[float]]): DEPRECATED: The values of the probabilities.
+            **kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
         """
         if prediction_probabilities is not None:
             warnings.warn(
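Illustrative usage (vm_model and vm_dataset stand in for objects returned by validmind's init_model and init_dataset): precomputed predictions and probabilities can now be passed via the typed optional arguments.

    vm_dataset.assign_predictions(
        model=vm_model,
        prediction_values=[0, 1, 1],
        probability_values=[0.12, 0.87, 0.66],
    )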
validmind/vm_models/dataset/utils.py
CHANGED
@@ -45,11 +45,11 @@ class ExtraColumns:
         )

     def __contains__(self, key):
-        """Allow checking if a key is `in` the extra columns"""
+        """Allow checking if a key is `in` the extra columns."""
         return key in self.flatten()

     def flatten(self) -> List[str]:
-        """Get a list of all column names"""
+        """Get a list of all column names."""
         return [
             self.group_by_column,
             *self.extras,
@@ -78,13 +78,14 @@ class ExtraColumns:


 def as_df(series_or_frame: Union[pd.Series, pd.DataFrame]) -> pd.DataFrame:
+    """Convert a pandas Series or DataFrame to a DataFrame."""
     if isinstance(series_or_frame, pd.Series):
         return series_or_frame.to_frame()
     return series_or_frame


 def _is_probabilties(output):
-    """Check if the output
+    """Check if the output is a probability array."""
     if not isinstance(output, np.ndarray) or output.ndim > 1:
         return False
@@ -98,6 +99,7 @@ def _is_probabilties(output):


 def compute_predictions(model, X, **kwargs) -> tuple:
+    """Compute predictions and probabilities for a model."""
     probability_values = None

     try:
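The _is_probabilties body is cut off in the diff above; a plausible sketch of such a heuristic (illustrative only, not the package's exact code):

    import numpy as np

    def looks_like_probabilities(output) -> bool:
        # Reject non-arrays and multi-dimensional outputs, as in the diff above
        if not isinstance(output, np.ndarray) or output.ndim > 1:
            return False
        # Illustrative continuation: float values confined to [0, 1] that are
        # not simply the 0/1 class labels themselves
        if not np.issubdtype(output.dtype, np.floating):
            return False
        in_unit_interval = bool((output >= 0).all() and (output <= 1).all())
        not_just_labels = bool(np.any((output > 0) & (output < 1)))
        return in_unit_interval and not_just_labels

    print(looks_like_probabilities(np.array([0.1, 0.8, 0.5])))  # True
    print(looks_like_probabilities(np.array([0.0, 1.0, 1.0])))  # False: looks like labels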
validmind/vm_models/figure.py
CHANGED
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

 """
-Figure objects track the figure schema supported by the ValidMind API
+Figure objects track the figure schema supported by the ValidMind API.
 """

 import base64
@@ -38,7 +38,7 @@ def create_figure(
     key: str,
     ref_id: str,
 ) -> "Figure":
-    """Create a VM Figure object from a raw figure object"""
+    """Create a VM Figure object from a raw figure object."""
     if is_matplotlib_figure(figure) or is_plotly_figure(figure) or is_png_image(figure):
         return Figure(key=key, figure=figure, ref_id=ref_id)
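Illustrative usage (the first positional parameter name is inferred from the Figure(...) call above; key and ref_id values are hypothetical):

    import matplotlib.pyplot as plt
    from validmind.vm_models.figure import create_figure

    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [4, 5, 6])

    vm_figure = create_figure(fig, key="loss_curve", ref_id="result-123")
    payload = vm_figure.serialize()  # dict matching the ValidMind API schema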
@@ -48,7 +48,7 @@ def create_figure(
 @dataclass
 class Figure:
     """
-    Figure objects track the schema supported by the ValidMind API
+    Figure objects track the schema supported by the ValidMind API.
     """

     key: str
@@ -115,7 +115,7 @@ class Figure:

     def serialize(self):
         """
-        Serializes the Figure to a dictionary so it can be sent to the API
+        Serializes the Figure to a dictionary so it can be sent to the API.
         """
         return {
             "type": self._type,
@@ -125,7 +125,7 @@ class Figure:

     def _get_b64_url(self):
         """
-        Returns a base64 encoded URL for the figure
+        Returns a base64 encoded URL for the figure.
         """
         if is_matplotlib_figure(self.figure):
             buffer = BytesIO()
@@ -152,7 +152,7 @@ class Figure:
         )

     def serialize_files(self):
-        """Creates a `requests`-compatible files object to be sent to the API"""
+        """Creates a `requests`-compatible files object to be sent to the API."""
         if is_matplotlib_figure(self.figure):
             buffer = BytesIO()
             self.figure.savefig(buffer, bbox_inches="tight")
validmind/vm_models/input.py
CHANGED
@@ -5,27 +5,28 @@
 """Base class for ValidMind Input types"""

 from abc import ABC
+from typing import Any, Dict


 class VMInput(ABC):
     """
-    Base class for ValidMind Input types
+    Base class for ValidMind Input types.
     """

-    def with_options(self, **kwargs) -> "VMInput":
+    def with_options(self, **kwargs: Dict[str, Any]) -> "VMInput":
         """
         Allows for setting options on the input object that are passed by the user
-        when using the input to run a test or set of tests
+        when using the input to run a test or set of tests.

         To allow options, just override this method in the subclass (see VMDataset)
         and ensure that it returns a new instance of the input with the specified options
         set.

         Args:
-            **kwargs: Arbitrary keyword arguments that will be passed to the input object
+            **kwargs: Arbitrary keyword arguments that will be passed to the input object.

         Returns:
-            VMInput: A new instance of the input with the specified options set
+            VMInput: A new instance of the input with the specified options set.
         """
         if kwargs:
             raise NotImplementedError("This type of input does not support options")
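A minimal illustrative subclass following the override contract described above (MyInput is hypothetical):

    from validmind.vm_models.input import VMInput

    class MyInput(VMInput):
        def __init__(self, columns=None):
            self.columns = columns or []

        def with_options(self, **kwargs) -> "MyInput":
            # Return a NEW instance with the requested options applied,
            # as the base-class docstring requires
            return MyInput(columns=kwargs.get("columns", self.columns))

    narrowed = MyInput(columns=["a", "b"]).with_options(columns=["a"])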
validmind/vm_models/model.py
CHANGED
@@ -40,7 +40,7 @@ R_MODEL_METHODS = [


 class ModelTask(Enum):
-    """Model task enums"""
+    """Model task enums."""

     # TODO: add more tasks
     CLASSIFICATION = "classification"
@@ -67,7 +67,7 @@ class ModelPipeline:
 @dataclass
 class ModelAttributes:
     """
-    Model attributes definition
+    Model attributes definition.
     """

     architecture: str = None
@@ -79,7 +79,7 @@ class ModelAttributes:
     @classmethod
     def from_dict(cls, data):
         """
-        Creates a ModelAttributes instance from a dictionary
+        Creates a ModelAttributes instance from a dictionary.
         """
         return cls(
             architecture=data.get("architecture"),
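Illustrative round trip (only the architecture key is visible in the hunk; unspecified fields simply keep their dataclass defaults):

    from validmind.vm_models.model import ModelAttributes

    attrs = ModelAttributes.from_dict({"architecture": "XGBoost"})
    print(attrs.architecture)  # "XGBoost"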
@@ -235,8 +235,8 @@ def is_model_metadata(model):
     Checks if the model is a dictionary containing metadata about a model.
     We want to check if the metadata dictionary contains at least the following keys:

-    -
-    -
+    - Architecture
+    - Language
     """
     if not isinstance(model, dict):
         return False
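Illustrative checks (the isinstance gate is shown above; the exact key casing is enforced by the rest of the body, which the diff does not show):

    from validmind.vm_models.model import is_model_metadata

    print(is_model_metadata("not a dict"))  # False: fails the isinstance check above
    # Expected True when the dict carries at least the Architecture and Language keys:
    print(is_model_metadata({"architecture": "GPT-4", "language": "Python"}))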