validmind 2.8.12__py3-none-any.whl → 2.8.22__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- validmind/__init__.py +6 -5
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +13 -9
- validmind/ai/utils.py +2 -2
- validmind/api_client.py +75 -32
- validmind/client.py +111 -100
- validmind/client_config.py +3 -3
- validmind/datasets/classification/__init__.py +7 -3
- validmind/datasets/credit_risk/lending_club.py +28 -16
- validmind/datasets/nlp/cnn_dailymail.py +10 -4
- validmind/datasets/regression/__init__.py +22 -5
- validmind/errors.py +17 -7
- validmind/input_registry.py +1 -1
- validmind/logging.py +44 -35
- validmind/models/foundation.py +2 -2
- validmind/models/function.py +10 -3
- validmind/template.py +33 -24
- validmind/test_suites/__init__.py +2 -2
- validmind/tests/_store.py +13 -4
- validmind/tests/comparison.py +65 -33
- validmind/tests/data_validation/ClassImbalance.py +3 -1
- validmind/tests/data_validation/DatasetDescription.py +2 -23
- validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
- validmind/tests/data_validation/Skewness.py +7 -6
- validmind/tests/decorator.py +14 -11
- validmind/tests/load.py +38 -24
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
- validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
- validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
- validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
- validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
- validmind/tests/output.py +66 -11
- validmind/tests/run.py +28 -14
- validmind/tests/test_providers.py +28 -35
- validmind/tests/utils.py +17 -4
- validmind/unit_metrics/__init__.py +1 -1
- validmind/utils.py +295 -31
- validmind/vm_models/dataset/dataset.py +83 -43
- validmind/vm_models/dataset/utils.py +5 -3
- validmind/vm_models/figure.py +6 -6
- validmind/vm_models/input.py +6 -5
- validmind/vm_models/model.py +5 -5
- validmind/vm_models/result/result.py +122 -43
- validmind/vm_models/result/utils.py +5 -5
- validmind/vm_models/test_suite/__init__.py +5 -0
- validmind/vm_models/test_suite/runner.py +5 -5
- validmind/vm_models/test_suite/summary.py +20 -2
- validmind/vm_models/test_suite/test.py +6 -6
- validmind/vm_models/test_suite/test_suite.py +10 -10
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/METADATA +3 -4
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/RECORD +61 -60
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/WHEEL +1 -1
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/LICENSE +0 -0
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/entry_points.txt +0 -0
validmind/utils.py
CHANGED
@@ -12,7 +12,7 @@ import sys
 import warnings
 from datetime import date, datetime, time
 from platform import python_version
-from typing import Any, Dict, List
+from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar
 
 import matplotlib.pylab as pylab
 import mistune
@@ -20,6 +20,7 @@ import nest_asyncio
 import numpy as np
 import pandas as pd
 import seaborn as sns
+from bs4 import BeautifulSoup
 from IPython.core import getipython
 from IPython.display import HTML
 from IPython.display import display as ipy_display
@@ -59,23 +60,25 @@ pylab.rcParams.update(params)
 
 logger = get_logger(__name__)
 
+T = TypeVar("T")
+
 
 def parse_version(version: str) -> tuple[int, ...]:
     """
-    Parse a semver version string into a tuple of major, minor, patch integers
+    Parse a semver version string into a tuple of major, minor, patch integers.
 
     Args:
-        version (str): The semantic version string to parse
+        version (str): The semantic version string to parse.
 
     Returns:
-        tuple[int, ...]: A tuple of major, minor, patch integers
+        tuple[int, ...]: A tuple of major, minor, patch integers.
     """
     return tuple(int(x) for x in version.split(".")[:3])
 
 
 def is_notebook() -> bool:
     """
-    Checks if the code is running in a Jupyter notebook or IPython shell
+    Checks if the code is running in a Jupyter notebook or IPython shell.
 
     https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
     """
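
A quick illustration of parse_version as shown above: the [:3] slice keeps only the first three dot-separated components, and non-numeric components raise ValueError.

from validmind.utils import parse_version

parse_version("2.8.22")        # (2, 8, 22)
parse_version("2.8.22.post1")  # (2, 8, 22); the fourth component is dropped
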
@@ -209,9 +212,7 @@ class HumanReadableEncoder(NumpyEncoder):
 
 
 def get_full_typename(o: Any) -> Any:
-    """We determine types based on type names so we don't have to import
-    (and therefore depend on) PyTorch, TensorFlow, etc.
-    """
+    """We determine types based on type names so we don't have to import."""
     instance_name = o.__class__.__module__ + "." + o.__class__.__name__
     if instance_name in ["builtins.module", "__builtin__.module"]:
         return o.__name__
@@ -313,9 +314,9 @@ def format_key_values(key_values: Dict[str, Any]) -> Dict[str, Any]:
 
 def summarize_data_quality_results(results):
     """
-    TODO: generalize this to work with metrics and test results
+    TODO: generalize this to work with metrics and test results.
 
-    Summarize the results of the data quality test suite
+    Summarize the results of the data quality test suite.
     """
     test_results = []
     for result in results:
@@ -354,25 +355,31 @@ def format_number(number):
 
 
 def format_dataframe(df: pd.DataFrame) -> pd.DataFrame:
-    """Format a pandas DataFrame for display purposes"""
+    """Format a pandas DataFrame for display purposes."""
     df = df.style.set_properties(**{"text-align": "left"}).hide(axis="index")
     return df.set_table_styles([dict(selector="th", props=[("text-align", "left")])])
 
 
-def run_async(func, *args, name=None, **kwargs):
-    """Helper function to run functions asynchronously
+def run_async(
+    func: Callable[..., Awaitable[T]],
+    *args: Any,
+    name: Optional[str] = None,
+    **kwargs: Any,
+) -> T:
+    """Helper function to run functions asynchronously.
 
     This takes care of the complexity of running the logging functions asynchronously. It will
-    detect the type of environment we are running in (
+    detect the type of environment we are running in (IPython notebook or not) and run the
     function accordingly.
 
     Args:
-        func
-        *args: The arguments to pass to the function
-
+        func: The function to run asynchronously.
+        *args: The arguments to pass to the function.
+        name: Optional name for the task.
+        **kwargs: The keyword arguments to pass to the function.
 
     Returns:
-        The result of the function
+        The result of the function.
     """
     try:
         if asyncio.get_event_loop().is_running() and is_notebook():
@@ -390,8 +397,19 @@ def run_async(func, *args, name=None, **kwargs):
     return asyncio.get_event_loop().run_until_complete(func(*args, **kwargs))
 
 
-def run_async_check(func, *args, **kwargs):
-    """Helper function to run functions asynchronously if the task doesn't already exist"""
+def run_async_check(
+    func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
+) -> Optional[asyncio.Task[T]]:
+    """Helper function to run functions asynchronously if the task doesn't already exist.
+
+    Args:
+        func: The function to run asynchronously.
+        *args: The arguments to pass to the function.
+        **kwargs: The keyword arguments to pass to the function.
+
+    Returns:
+        Optional[asyncio.Task[T]]: The task if created or found, None otherwise.
+    """
     if __loop:
         return  # we don't need this if we are using our own loop
 
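
A minimal usage sketch of the newly typed run_async; the `add` coroutine below is illustrative, not part of validmind.

import asyncio

from validmind.utils import run_async

async def add(x: int, y: int) -> int:
    await asyncio.sleep(0)  # stand-in for real async work
    return x + y

# With Callable[..., Awaitable[T]] -> T, type checkers can infer `total: int` rather than Any.
total = run_async(add, 1, 2, name="demo-task")
print(total)  # 3
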
@@ -408,16 +426,16 @@ def run_async_check(func, *args, **kwargs):
         pass
 
 
-def fuzzy_match(string: str, search_string: str, threshold=0.7):
-    """Check if a string matches another string using fuzzy matching
+def fuzzy_match(string: str, search_string: str, threshold: float = 0.7) -> bool:
+    """Check if a string matches another string using fuzzy matching.
 
     Args:
-        string (str): The string to check
-        search_string (str): The string to search for
-        threshold (float): The similarity threshold to use (Default: 0.7)
+        string (str): The string to check.
+        search_string (str): The string to search for.
+        threshold (float): The similarity threshold to use (Default: 0.7).
 
     Returns:
-        True if the string matches the search string, False otherwise
+        bool: True if the string matches the search string, False otherwise.
     """
     score = difflib.SequenceMatcher(None, string, search_string).ratio()
 
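
A small sketch of fuzzy_match, which compares difflib's SequenceMatcher ratio against the threshold:

from validmind.utils import fuzzy_match

fuzzy_match("class_imbalance", "ClassImbalance")       # True; the ratio is roughly 0.83
fuzzy_match("skewness", "correlation", threshold=0.9)  # False; the strings barely overlap
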
@@ -448,7 +466,7 @@ def test_id_to_name(test_id: str) -> str:
 
 
 def get_model_info(model):
-    """Attempts to extract all model info from a model object instance"""
+    """Attempts to extract all model info from a model object instance."""
     architecture = model.name
     framework = model.library
     framework_version = model.library_version
@@ -472,7 +490,7 @@
 
 
 def get_dataset_info(dataset):
-    """Attempts to extract all dataset info from a dataset object instance"""
+    """Attempts to extract all dataset info from a dataset object instance."""
     num_rows, num_cols = dataset.df.shape
     schema = dataset.df.dtypes.apply(lambda x: x.name).to_dict()
     description = (
@@ -491,7 +509,7 @@ def preview_test_config(config):
     """Preview test configuration in a collapsible HTML section.
 
     Args:
-        config (dict): Test configuration dictionary
+        config (dict): Test configuration dictionary.
     """
 
     try:
@@ -515,7 +533,7 @@ def preview_test_config(config):
 
 
 def display(widget_or_html, syntax_highlighting=True, mathjax=True):
-    """Display widgets with extra goodies (syntax highlighting, MathJax, etc.)"""
+    """Display widgets with extra goodies (syntax highlighting, MathJax, etc.)."""
     if isinstance(widget_or_html, str):
         ipy_display(HTML(widget_or_html))
         # if html we can auto-detect if we actually need syntax highlighting or MathJax
@@ -532,7 +550,7 @@ def display(widget_or_html, syntax_highlighting=True, mathjax=True):
 
 
 def md_to_html(md: str, mathml=False) -> str:
-    """Converts Markdown to HTML using mistune with plugins"""
+    """Converts Markdown to HTML using mistune with plugins."""
     # use mistune with math plugin to convert to html
     html = mistune.create_markdown(
         plugins=["math", "table", "strikethrough", "footnotes"]
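
A brief usage sketch of md_to_html, using the plugin list visible above:

from validmind.utils import md_to_html

html = md_to_html("**bold** text with inline math like $x^2$")
# returns an HTML string; the math, table, strikethrough and footnotes
# plugins are enabled on the mistune renderer
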
@@ -559,6 +577,63 @@ def md_to_html(md: str, mathml=False) -> str:
     return html
 
 
+def is_html(text: str) -> bool:
+    """Check if a string is HTML.
+
+    Uses more robust heuristics to determine if a string contains HTML content.
+
+    Args:
+        text (str): The string to check
+
+    Returns:
+        bool: True if the string likely contains HTML, False otherwise
+    """
+    # Strip whitespace first
+    text = text.strip()
+
+    # Basic check: Must at least start with < and end with >
+    if not (text.startswith("<") and text.endswith(">")):
+        return False
+
+    # Look for common HTML tags
+    common_html_patterns = [
+        r"<html.*?>",  # HTML tag
+        r"<body.*?>",  # Body tag
+        r"<div.*?>",  # Div tag
+        r"<p>.*?</p>",  # Paragraph with content
+        r"<h[1-6]>.*?</h[1-6]>",  # Headers
+        r"<script.*?>",  # Script tags
+        r"<style.*?>",  # Style tags
+        r"<a href=.*?>",  # Links
+        r"<img.*?>",  # Images
+        r"<table.*?>",  # Tables
+        r"<!DOCTYPE html>",  # DOCTYPE declaration
+    ]
+
+    for pattern in common_html_patterns:
+        if re.search(pattern, text, re.IGNORECASE | re.DOTALL):
+            return True
+
+    # If we have at least 2 matching tags, it's likely HTML
+    # This helps detect custom elements or patterns not in our list
+    tags = re.findall(r"</?[a-zA-Z][a-zA-Z0-9]*.*?>", text)
+    if len(tags) >= 2:
+        return True
+
+    # Try parsing with BeautifulSoup as a last resort
+    try:
+        soup = BeautifulSoup(text, "html.parser")
+        # If we find any tags that weren't in the original text, BeautifulSoup
+        # likely tried to fix broken HTML, meaning it's not valid HTML
+        return len(soup.find_all()) > 0
+
+    except Exception as e:
+        logger.error(f"Error checking if text is HTML: {e}")
+        return False
+
+    return False
+
+
 def inspect_obj(obj):
     # Filtering only attributes
     print(len("Attributes:") * "-")
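
Illustrative calls against the new is_html heuristic, one per branch of the logic above:

from validmind.utils import is_html

is_html("<div><p>Hello</p></div>")  # True: matches a common-tag pattern
is_html("<span>hi</span>")          # True: the two-or-more-tags fallback
is_html("a < b and c > d")          # False: fails the starts-with-< envelope check
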
@@ -601,3 +676,192 @@ def serialize(obj):
     elif isinstance(obj, (pd.DataFrame, pd.Series)):
         return ""  # Simple empty string for non-serializable objects
     return obj
+
+
+def is_text_column(series, threshold=0.05) -> bool:
+    """
+    Determines if a series is likely to contain text data using heuristics.
+
+    Args:
+        series (pd.Series): The pandas Series to analyze
+        threshold (float): The minimum threshold to classify a pattern match as significant
+
+    Returns:
+        bool: True if the series likely contains text data, False otherwise
+    """
+    # Filter to non-null string values and sample if needed
+    string_series = series.dropna().astype(str)
+    if len(string_series) == 0:
+        return False
+    if len(string_series) > 1000:
+        string_series = string_series.sample(1000, random_state=42)
+
+    # Calculate basic metrics
+    total_values = len(string_series)
+    unique_ratio = len(string_series.unique()) / total_values if total_values > 0 else 0
+    avg_length = string_series.str.len().mean()
+    avg_words = string_series.str.split(r"\s+").str.len().mean()
+
+    # Check for special text patterns
+    patterns = {
+        "url": r"https?://\S+|www\.\S+",
+        "email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
+        "filepath": r'(?:[a-zA-Z]:|[\\/])(?:[\\/][^\\/:*?"<>|]+)+',
+    }
+
+    # Check if any special patterns exceed threshold
+    for pattern in patterns.values():
+        if string_series.str.contains(pattern, regex=True, na=False).mean() > threshold:
+            return True
+
+    # Calculate proportion of alphabetic characters
+    total_chars = string_series.str.len().sum()
+    if total_chars > 0:
+        alpha_ratio = string_series.str.count(r"[a-zA-Z]").sum() / total_chars
+    else:
+        alpha_ratio = 0
+
+    # Check for free-form text indicators
+    text_indicators = [
+        unique_ratio > 0.8 and avg_length > 20,  # High uniqueness and long strings
+        unique_ratio > 0.4
+        and avg_length > 15
+        and string_series.str.contains(r"[.,;:!?]", regex=True, na=False).mean()
+        > 0.3,  # Moderate uniqueness with punctuation
+        string_series.str.contains(
+            r"\b\w+\b\s+\b\w+\b\s+\b\w+\b\s+\b\w+\b", regex=True, na=False
+        ).mean()
+        > 0.3,  # Contains long phrases
+        avg_words > 5 and alpha_ratio > 0.6,  # Many words with mostly letters
+        unique_ratio > 0.95 and avg_length > 10,  # Very high uniqueness
+    ]
+
+    return any(text_indicators)
+
+
+def _get_numeric_type_detail(column, dtype, series):
+    """Helper function to determine numeric type details."""
+    if pd.api.types.is_integer_dtype(dtype):
+        return {"type": "Numeric", "subtype": "Integer"}
+    elif pd.api.types.is_float_dtype(dtype):
+        return {"type": "Numeric", "subtype": "Float"}
+    else:
+        return {"type": "Numeric", "subtype": "Other"}
+
+
+def _get_text_type_detail(series):
+    """Helper function to determine text/categorical type details."""
+    string_series = series.dropna().astype(str)
+
+    if len(string_series) == 0:
+        return {"type": "Categorical"}
+
+    # Check for common patterns
+    url_pattern = r"https?://\S+|www\.\S+"
+    email_pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
+    filepath_pattern = r'(?:[a-zA-Z]:|[\\/])(?:[\\/][^\\/:*?"<>|]+)+'
+
+    url_ratio = string_series.str.contains(url_pattern, regex=True, na=False).mean()
+    email_ratio = string_series.str.contains(email_pattern, regex=True, na=False).mean()
+    filepath_ratio = string_series.str.contains(
+        filepath_pattern, regex=True, na=False
+    ).mean()
+
+    # Check if general text using enhanced function
+    if url_ratio > 0.7:
+        return {"type": "Text", "subtype": "URL"}
+    elif email_ratio > 0.7:
+        return {"type": "Text", "subtype": "Email"}
+    elif filepath_ratio > 0.7:
+        return {"type": "Text", "subtype": "Path"}
+    elif is_text_column(series):
+        return {"type": "Text", "subtype": "FreeText"}
+
+    # Must be categorical
+    n_unique = series.nunique()
+    if n_unique == 2:
+        return {"type": "Categorical", "subtype": "Binary"}
+    else:
+        return {"type": "Categorical", "subtype": "Nominal"}
+
+
+def get_column_type_detail(df, column) -> dict:
+    """
+    Get detailed column type information beyond basic type detection.
+    Similar to ydata-profiling's type system.
+
+    Args:
+        df (pd.DataFrame): DataFrame containing the column
+        column (str): Column name to analyze
+
+    Returns:
+        dict: Detailed type information including primary type and subtype
+    """
+    series = df[column]
+    dtype = series.dtype
+
+    # Initialize result with id and basic type
+    result = {"id": column, "type": "Unknown"}
+
+    # Determine type details based on dtype
+    type_detail = None
+
+    if pd.api.types.is_numeric_dtype(dtype):
+        type_detail = _get_numeric_type_detail(column, dtype, series)
+    elif pd.api.types.is_bool_dtype(dtype):
+        type_detail = {"type": "Boolean"}
+    elif pd.api.types.is_datetime64_any_dtype(dtype):
+        type_detail = {"type": "Datetime"}
+    elif pd.api.types.is_categorical_dtype(dtype) or pd.api.types.is_object_dtype(
+        dtype
+    ):
+        type_detail = _get_text_type_detail(series)
+
+    # Update result with type details
+    if type_detail:
+        result.update(type_detail)
+
+    return result
+
+
+def infer_datatypes(df, detailed=False) -> list:
+    """
+    Infer data types for columns in a DataFrame.
+
+    Args:
+        df (pd.DataFrame): DataFrame to analyze
+        detailed (bool): Whether to return detailed type information including subtypes
+
+    Returns:
+        list: Column type mappings
+    """
+    if detailed:
+        return [get_column_type_detail(df, column) for column in df.columns]
+
+    column_type_mappings = {}
+    # Use pandas to infer data types
+    for column in df.columns:
+        # Check if all values are None
+        if df[column].isna().all():
+            column_type_mappings[column] = {"id": column, "type": "Null"}
+            continue
+
+        dtype = df[column].dtype
+        if pd.api.types.is_numeric_dtype(dtype):
+            column_type_mappings[column] = {"id": column, "type": "Numeric"}
+        elif pd.api.types.is_bool_dtype(dtype):
+            column_type_mappings[column] = {"id": column, "type": "Boolean"}
+        elif pd.api.types.is_datetime64_any_dtype(dtype):
+            column_type_mappings[column] = {"id": column, "type": "Datetime"}
+        elif pd.api.types.is_categorical_dtype(dtype) or pd.api.types.is_object_dtype(
+            dtype
+        ):
+            # Check if this is more likely to be text than categorical
+            if is_text_column(df[column]):
+                column_type_mappings[column] = {"id": column, "type": "Text"}
+            else:
+                column_type_mappings[column] = {"id": column, "type": "Categorical"}
+        else:
+            column_type_mappings[column] = {"id": column, "type": "Unsupported"}
+
+    return list(column_type_mappings.values())
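
A hedged sketch of the new type-inference helpers on a toy frame; the column names are illustrative and the printed output is paraphrased:

import pandas as pd

from validmind.utils import infer_datatypes

df = pd.DataFrame({
    "age": [34, 29, 41],
    "balance": [1200.5, 801.25, 430.0],
    "segment": ["A", "B", "A"],
    "bio": [
        "Joined in 2019 and works on risk models for retail portfolios.",
        "Data scientist focused on NLP pipelines and document processing.",
        "Manages the credit review team and writes validation reports.",
    ],
})

print(infer_datatypes(df))
# roughly: [{'id': 'age', 'type': 'Numeric'}, {'id': 'balance', 'type': 'Numeric'},
#           {'id': 'segment', 'type': 'Categorical'}, {'id': 'bio', 'type': 'Text'}]

print(infer_datatypes(df, detailed=True))
# adds subtypes via get_column_type_detail, e.g. Integer, Float, Binary, FreeText
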
validmind/vm_models/dataset/dataset.py
CHANGED
@@ -8,6 +8,7 @@ Dataset class wrapper
 
 import warnings
 from copy import deepcopy
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import pandas as pd
@@ -24,9 +25,9 @@ logger = get_logger(__name__)
 
 
 class VMDataset(VMInput):
-    """Base class for VM datasets
+    """Base class for VM datasets.
 
-    Child classes should be used to support new dataset types (tensor, polars etc)
+    Child classes should be used to support new dataset types (tensor, polars etc.)
     by converting the user's dataset into a numpy array collecting metadata like
     column names and then call this (parent) class `__init__` method.
 
@@ -46,6 +47,7 @@ class VMDataset(VMInput):
         target_class_labels (Dict): The class labels for the target columns.
         df (pd.DataFrame): The dataset as a pandas DataFrame.
         extra_columns (Dict): Extra columns to include in the dataset.
+        copy_data (bool): Whether to copy the data. Defaults to True.
     """
 
     def __repr__(self):
@@ -65,6 +67,7 @@ class VMDataset(VMInput):
         text_column: str = None,
         extra_columns: dict = None,
         target_class_labels: dict = None,
+        copy_data: bool = True,
     ):
         """
         Initializes a VMDataset instance.
@@ -81,6 +84,7 @@ class VMDataset(VMInput):
             feature_columns (str, optional): The feature column names of the dataset. Defaults to None.
             text_column (str, optional): The text column name of the dataset for nlp tasks. Defaults to None.
             target_class_labels (Dict, optional): The class labels for the target columns. Defaults to None.
+            copy_data (bool, optional): Whether to copy the data. Defaults to True.
         """
         # initialize input_id
         self.input_id = input_id
@@ -111,6 +115,7 @@ class VMDataset(VMInput):
         self.target_class_labels = target_class_labels
         self.extra_columns = ExtraColumns.from_dict(extra_columns)
         self._set_feature_columns(feature_columns)
+        self._copy_data = copy_data
 
         if model:
             self.assign_predictions(model)
@@ -128,16 +133,19 @@ class VMDataset(VMInput):
         excluded = [self.target_column, *self.extra_columns.flatten()]
         self.feature_columns = [col for col in self.columns if col not in excluded]
 
-
-
-
-
-
-
-
-
-        .
-
+        # Get dtypes without loading data into memory
+        feature_dtypes = self._df[self.feature_columns].dtypes
+
+        self.feature_columns_numeric = feature_dtypes[
+            feature_dtypes.apply(lambda x: pd.api.types.is_numeric_dtype(x))
+        ].index.tolist()
+
+        self.feature_columns_categorical = feature_dtypes[
+            feature_dtypes.apply(
+                lambda x: pd.api.types.is_categorical_dtype(x)
+                or pd.api.types.is_object_dtype(x)
+            )
+        ].index.tolist()
 
     def _add_column(self, column_name, column_values):
         column_values = np.array(column_values)
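
The dtype-based split above can be reproduced with plain pandas; this mirrors the same calls on a toy frame:

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"], "c": [0.5, 1.5]})
dtypes = df[["a", "b", "c"]].dtypes  # a Series of dtypes; no row data is materialized

numeric = dtypes[dtypes.apply(pd.api.types.is_numeric_dtype)].index.tolist()
categorical = dtypes[
    dtypes.apply(
        lambda x: pd.api.types.is_categorical_dtype(x) or pd.api.types.is_object_dtype(x)
    )
].index.tolist()

print(numeric, categorical)  # ['a', 'c'] ['b']
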
@@ -200,7 +208,7 @@ class VMDataset(VMInput):
                 "Cannot use precomputed probabilities without precomputed predictions"
             )
 
-    def with_options(self, **kwargs) -> "VMDataset":
+    def with_options(self, **kwargs: Dict[str, Any]) -> "VMDataset":
         """Support options provided when passing an input to run_test or run_test_suite
 
         Example:
@@ -253,23 +261,25 @@ class VMDataset(VMInput):
     def assign_predictions(
         self,
         model: VMModel,
-        prediction_column: str = None,
-        prediction_values:
-        probability_column: str = None,
-        probability_values:
-        prediction_probabilities:
-        **kwargs,
-    ):
+        prediction_column: Optional[str] = None,
+        prediction_values: Optional[List[Any]] = None,
+        probability_column: Optional[str] = None,
+        probability_values: Optional[List[float]] = None,
+        prediction_probabilities: Optional[
+            List[float]
+        ] = None,  # DEPRECATED: use probability_values
+        **kwargs: Dict[str, Any],
+    ) -> None:
         """Assign predictions and probabilities to the dataset.
 
         Args:
             model (VMModel): The model used to generate the predictions.
-            prediction_column (str
-            prediction_values (
-            probability_column (str
-            probability_values (
-            prediction_probabilities (
-            kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
+            prediction_column (Optional[str]): The name of the column containing the predictions.
+            prediction_values (Optional[List[Any]]): The values of the predictions.
+            probability_column (Optional[str]): The name of the column containing the probabilities.
+            probability_values (Optional[List[float]]): The values of the probabilities.
+            prediction_probabilities (Optional[List[float]]): DEPRECATED: The values of the probabilities.
+            **kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
         """
         if prediction_probabilities is not None:
             warnings.warn(
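
A hedged usage sketch of the updated assign_predictions signature. It assumes an initialized validmind environment; vm.init_dataset and vm.init_model are the usual public entry points, and the data is illustrative:

import pandas as pd
import validmind as vm
from sklearn.linear_model import LogisticRegression

df = pd.DataFrame({"x": [0.1, 0.9, 0.4, 0.7], "y": [0, 1, 0, 1]})
vm_ds = vm.init_dataset(dataset=df, input_id="demo_ds", target_column="y")

clf = LogisticRegression().fit(df[["x"]], df["y"])
vm_model = vm.init_model(clf, input_id="demo_model")

# probability_values replaces the deprecated prediction_probabilities argument,
# which still works but now emits a deprecation warning.
vm_ds.assign_predictions(
    vm_model,
    prediction_values=[0, 1, 0, 1],
    probability_values=[0.2, 0.8, 0.4, 0.7],
)
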
@@ -394,8 +404,18 @@ class VMDataset(VMInput):
         assert self.target_column not in columns
         columns.append(self.target_column)
 
-        #
-
+        # Check if all columns in self._df are requested
+        all_columns = set(columns) == set(self._df.columns)
+
+        # For copy_data=False and all columns: return exact same DataFrame object
+        if not self._copy_data and all_columns:
+            return self._df
+        # For copy_data=False and subset of columns: return view with shared data
+        elif not self._copy_data:
+            return as_df(self._df[columns])
+        # For copy_data=True: return independent copy with duplicated data
+        else:
+            return as_df(self._df[columns]).copy()
 
     @property
     def x(self) -> np.ndarray:
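
The three access paths above can be illustrated with plain pandas (the middle case is subject to pandas copy-on-write behavior):

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

same = df                     # copy_data=False, all columns: the identical object
view = df[["a"]]              # copy_data=False, subset: may share underlying data
copy = df[["a", "b"]].copy()  # copy_data=True: a fully independent copy

assert same is df
assert copy is not df
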
@@ -519,9 +539,10 @@ class DataFrameDataset(VMDataset):
         text_column: str = None,
         target_class_labels: dict = None,
         date_time_index: bool = False,
+        copy_data: bool = True,
     ):
         """
-        Initializes a DataFrameDataset instance.
+        Initializes a DataFrameDataset instance, preserving original pandas dtypes.
 
         Args:
             raw_dataset (pd.DataFrame): The raw dataset as a pandas DataFrame.
@@ -533,25 +554,44 @@ class DataFrameDataset(VMDataset):
             text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
             target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
             date_time_index (bool, optional): Whether to use date-time index. Defaults to False.
+            copy_data (bool, optional): Whether to create a copy of the input data. Defaults to True.
         """
+
+        VMInput.__init__(self)
+
+        self.input_id = input_id
+
         index = None
         if isinstance(raw_dataset.index, pd.Index):
             index = raw_dataset.index.values
+        self.index = index
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Store the DataFrame directly
+        self._df = raw_dataset
+
+        if date_time_index:
+            self._df = convert_index_to_datetime(self._df)
+
+        self.columns = raw_dataset.columns.tolist()
+        self.column_aliases = {}
+        self.target_column = target_column
+        self.text_column = text_column
+        self.target_class_labels = target_class_labels
+        self.extra_columns = ExtraColumns.from_dict(extra_columns)
+        self._copy_data = copy_data
+
+        # Add warning when copy_data is False
+        if not copy_data:
+            logger.warning(
+                "Dataset initialized with copy_data=False. Changes to the original DataFrame "
+                "may affect this dataset. Use this option only when memory efficiency is critical "
+                "and you won't modify the source data."
+            )
+
+        self._set_feature_columns(feature_columns)
+
+        if model:
+            self.assign_predictions(model)
 
 
 class PolarsDataset(VMDataset):
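
Putting the DataFrameDataset changes together, a hedged sketch of the memory-efficient path; the import path follows this diff, though constructing datasets via vm.init_dataset remains the supported route:

import pandas as pd

from validmind.vm_models.dataset.dataset import DataFrameDataset

raw = pd.DataFrame({"x": range(1_000_000), "y": 0.0})

# Emits the warning added above, then shares the frame instead of copying it.
ds = DataFrameDataset(raw_dataset=raw, input_id="big_ds", target_column="y", copy_data=False)

# Per the df access logic earlier in this diff, requesting all columns with
# copy_data=False should return the original object itself.
assert ds.df is raw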