validmind 2.8.12__py3-none-any.whl → 2.8.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. validmind/__init__.py +6 -5
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +13 -9
  4. validmind/ai/utils.py +2 -2
  5. validmind/api_client.py +75 -32
  6. validmind/client.py +111 -100
  7. validmind/client_config.py +3 -3
  8. validmind/datasets/classification/__init__.py +7 -3
  9. validmind/datasets/credit_risk/lending_club.py +28 -16
  10. validmind/datasets/nlp/cnn_dailymail.py +10 -4
  11. validmind/datasets/regression/__init__.py +22 -5
  12. validmind/errors.py +17 -7
  13. validmind/input_registry.py +1 -1
  14. validmind/logging.py +44 -35
  15. validmind/models/foundation.py +2 -2
  16. validmind/models/function.py +10 -3
  17. validmind/template.py +33 -24
  18. validmind/test_suites/__init__.py +2 -2
  19. validmind/tests/_store.py +13 -4
  20. validmind/tests/comparison.py +65 -33
  21. validmind/tests/data_validation/ClassImbalance.py +3 -1
  22. validmind/tests/data_validation/DatasetDescription.py +2 -23
  23. validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
  24. validmind/tests/data_validation/Skewness.py +7 -6
  25. validmind/tests/decorator.py +14 -11
  26. validmind/tests/load.py +38 -24
  27. validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
  28. validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
  29. validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
  30. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
  31. validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
  32. validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
  33. validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
  34. validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
  35. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
  36. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
  37. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
  38. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
  39. validmind/tests/output.py +66 -11
  40. validmind/tests/run.py +28 -14
  41. validmind/tests/test_providers.py +28 -35
  42. validmind/tests/utils.py +17 -4
  43. validmind/unit_metrics/__init__.py +1 -1
  44. validmind/utils.py +295 -31
  45. validmind/vm_models/dataset/dataset.py +83 -43
  46. validmind/vm_models/dataset/utils.py +5 -3
  47. validmind/vm_models/figure.py +6 -6
  48. validmind/vm_models/input.py +6 -5
  49. validmind/vm_models/model.py +5 -5
  50. validmind/vm_models/result/result.py +122 -43
  51. validmind/vm_models/result/utils.py +5 -5
  52. validmind/vm_models/test_suite/__init__.py +5 -0
  53. validmind/vm_models/test_suite/runner.py +5 -5
  54. validmind/vm_models/test_suite/summary.py +20 -2
  55. validmind/vm_models/test_suite/test.py +6 -6
  56. validmind/vm_models/test_suite/test_suite.py +10 -10
  57. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/METADATA +3 -4
  58. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/RECORD +61 -60
  59. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/WHEEL +1 -1
  60. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/LICENSE +0 -0
  61. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/entry_points.txt +0 -0
validmind/utils.py CHANGED
@@ -12,7 +12,7 @@ import sys
  import warnings
  from datetime import date, datetime, time
  from platform import python_version
- from typing import Any, Dict, List
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar

  import matplotlib.pylab as pylab
  import mistune
@@ -20,6 +20,7 @@ import nest_asyncio
  import numpy as np
  import pandas as pd
  import seaborn as sns
+ from bs4 import BeautifulSoup
  from IPython.core import getipython
  from IPython.display import HTML
  from IPython.display import display as ipy_display
@@ -59,23 +60,25 @@ pylab.rcParams.update(params)

  logger = get_logger(__name__)

+ T = TypeVar("T")
+

  def parse_version(version: str) -> tuple[int, ...]:
      """
-     Parse a semver version string into a tuple of major, minor, patch integers
+     Parse a semver version string into a tuple of major, minor, patch integers.

      Args:
-         version (str): The semantic version string to parse
+         version (str): The semantic version string to parse.

      Returns:
-         tuple[int, ...]: A tuple of major, minor, patch integers
+         tuple[int, ...]: A tuple of major, minor, patch integers.
      """
      return tuple(int(x) for x in version.split(".")[:3])


  def is_notebook() -> bool:
      """
-     Checks if the code is running in a Jupyter notebook or IPython shell
+     Checks if the code is running in a Jupyter notebook or IPython shell.

      https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
      """
@@ -209,9 +212,7 @@ class HumanReadableEncoder(NumpyEncoder):


  def get_full_typename(o: Any) -> Any:
-     """We determine types based on type names so we don't have to import
-     (and therefore depend on) PyTorch, TensorFlow, etc.
-     """
+     """We determine types based on type names so we don't have to import."""
      instance_name = o.__class__.__module__ + "." + o.__class__.__name__
      if instance_name in ["builtins.module", "__builtin__.module"]:
          return o.__name__
@@ -313,9 +314,9 @@ def format_key_values(key_values: Dict[str, Any]) -> Dict[str, Any]:

  def summarize_data_quality_results(results):
      """
-     TODO: generalize this to work with metrics and test results
+     TODO: generalize this to work with metrics and test results.

-     Summarize the results of the data quality test suite
+     Summarize the results of the data quality test suite.
      """
      test_results = []
      for result in results:
@@ -354,25 +355,31 @@ def format_number(number):


  def format_dataframe(df: pd.DataFrame) -> pd.DataFrame:
-     """Format a pandas DataFrame for display purposes"""
+     """Format a pandas DataFrame for display purposes."""
      df = df.style.set_properties(**{"text-align": "left"}).hide(axis="index")
      return df.set_table_styles([dict(selector="th", props=[("text-align", "left")])])


- def run_async(func, *args, name=None, **kwargs):
-     """Helper function to run functions asynchronously
+ def run_async(
+     func: Callable[..., Awaitable[T]],
+     *args: Any,
+     name: Optional[str] = None,
+     **kwargs: Any,
+ ) -> T:
+     """Helper function to run functions asynchronously.

      This takes care of the complexity of running the logging functions asynchronously. It will
-     detect the type of environment we are running in (ipython notebook or not) and run the
+     detect the type of environment we are running in (IPython notebook or not) and run the
      function accordingly.

      Args:
-         func (function): The function to run asynchronously
-         *args: The arguments to pass to the function
-         **kwargs: The keyword arguments to pass to the function
+         func: The function to run asynchronously.
+         *args: The arguments to pass to the function.
+         name: Optional name for the task.
+         **kwargs: The keyword arguments to pass to the function.

      Returns:
-         The result of the function
+         The result of the function.
      """
      try:
          if asyncio.get_event_loop().is_running() and is_notebook():
@@ -390,8 +397,19 @@ def run_async(func, *args, name=None, **kwargs):
      return asyncio.get_event_loop().run_until_complete(func(*args, **kwargs))


- def run_async_check(func, *args, **kwargs):
-     """Helper function to run functions asynchronously if the task doesn't already exist"""
+ def run_async_check(
+     func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
+ ) -> Optional[asyncio.Task[T]]:
+     """Helper function to run functions asynchronously if the task doesn't already exist.
+
+     Args:
+         func: The function to run asynchronously.
+         *args: The arguments to pass to the function.
+         **kwargs: The keyword arguments to pass to the function.
+
+     Returns:
+         Optional[asyncio.Task[T]]: The task if created or found, None otherwise.
+     """
      if __loop:
          return # we don't need this if we are using our own loop

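For illustration, a minimal sketch of how the newly typed helper is called (the coroutine send_metadata is a hypothetical stand-in for the SDK's internal logging coroutines):

    from validmind.utils import run_async

    async def send_metadata(payload):  # hypothetical coroutine
        return payload

    # Runs the coroutine to completion whether or not an event loop is already
    # running (e.g. inside a Jupyter notebook), and returns its result.
    result = run_async(send_metadata, {"key": "value"}, name="send-metadata")
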
@@ -408,16 +426,16 @@ def run_async_check(func, *args, **kwargs):
          pass


- def fuzzy_match(string: str, search_string: str, threshold=0.7):
-     """Check if a string matches another string using fuzzy matching
+ def fuzzy_match(string: str, search_string: str, threshold: float = 0.7) -> bool:
+     """Check if a string matches another string using fuzzy matching.

      Args:
-         string (str): The string to check
-         search_string (str): The string to search for
-         threshold (float): The similarity threshold to use (Default: 0.7)
+         string (str): The string to check.
+         search_string (str): The string to search for.
+         threshold (float): The similarity threshold to use (Default: 0.7).

      Returns:
-         True if the string matches the search string, False otherwise
+         bool: True if the string matches the search string, False otherwise.
      """
      score = difflib.SequenceMatcher(None, string, search_string).ratio()

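The stricter fuzzy_match annotations do not change its behaviour; a quick illustrative sketch:

    from validmind.utils import fuzzy_match

    fuzzy_match("class_imbalance", "classimbalance")  # True: ratio well above the 0.7 default
    fuzzy_match("skewness", "kurtosis")               # False: strings are too dissimilar
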
@@ -448,7 +466,7 @@ def test_id_to_name(test_id: str) -> str:


  def get_model_info(model):
-     """Attempts to extract all model info from a model object instance"""
+     """Attempts to extract all model info from a model object instance."""
      architecture = model.name
      framework = model.library
      framework_version = model.library_version
@@ -472,7 +490,7 @@ def get_model_info(model):


  def get_dataset_info(dataset):
-     """Attempts to extract all dataset info from a dataset object instance"""
+     """Attempts to extract all dataset info from a dataset object instance."""
      num_rows, num_cols = dataset.df.shape
      schema = dataset.df.dtypes.apply(lambda x: x.name).to_dict()
      description = (
@@ -491,7 +509,7 @@ def preview_test_config(config):
      """Preview test configuration in a collapsible HTML section.

      Args:
-         config (dict): Test configuration dictionary
+         config (dict): Test configuration dictionary.
      """

      try:
@@ -515,7 +533,7 @@ def preview_test_config(config):


  def display(widget_or_html, syntax_highlighting=True, mathjax=True):
-     """Display widgets with extra goodies (syntax highlighting, MathJax, etc.)"""
+     """Display widgets with extra goodies (syntax highlighting, MathJax, etc.)."""
      if isinstance(widget_or_html, str):
          ipy_display(HTML(widget_or_html))
          # if html we can auto-detect if we actually need syntax highlighting or MathJax
@@ -532,7 +550,7 @@ def display(widget_or_html, syntax_highlighting=True, mathjax=True):


  def md_to_html(md: str, mathml=False) -> str:
-     """Converts Markdown to HTML using mistune with plugins"""
+     """Converts Markdown to HTML using mistune with plugins."""
      # use mistune with math plugin to convert to html
      html = mistune.create_markdown(
          plugins=["math", "table", "strikethrough", "footnotes"]
@@ -559,6 +577,63 @@ def md_to_html(md: str, mathml=False) -> str:
      return html


+ def is_html(text: str) -> bool:
+     """Check if a string is HTML.
+
+     Uses more robust heuristics to determine if a string contains HTML content.
+
+     Args:
+         text (str): The string to check
+
+     Returns:
+         bool: True if the string likely contains HTML, False otherwise
+     """
+     # Strip whitespace first
+     text = text.strip()
+
+     # Basic check: Must at least start with < and end with >
+     if not (text.startswith("<") and text.endswith(">")):
+         return False
+
+     # Look for common HTML tags
+     common_html_patterns = [
+         r"<html.*?>", # HTML tag
+         r"<body.*?>", # Body tag
+         r"<div.*?>", # Div tag
+         r"<p>.*?</p>", # Paragraph with content
+         r"<h[1-6]>.*?</h[1-6]>", # Headers
+         r"<script.*?>", # Script tags
+         r"<style.*?>", # Style tags
+         r"<a href=.*?>", # Links
+         r"<img.*?>", # Images
+         r"<table.*?>", # Tables
+         r"<!DOCTYPE html>", # DOCTYPE declaration
+     ]
+
+     for pattern in common_html_patterns:
+         if re.search(pattern, text, re.IGNORECASE | re.DOTALL):
+             return True
+
+     # If we have at least 2 matching tags, it's likely HTML
+     # This helps detect custom elements or patterns not in our list
+     tags = re.findall(r"</?[a-zA-Z][a-zA-Z0-9]*.*?>", text)
+     if len(tags) >= 2:
+         return True
+
+     # Try parsing with BeautifulSoup as a last resort
+     try:
+         soup = BeautifulSoup(text, "html.parser")
+         # If we find any tags that weren't in the original text, BeautifulSoup
+         # likely tried to fix broken HTML, meaning it's not valid HTML
+         return len(soup.find_all()) > 0
+
+     except Exception as e:
+         logger.error(f"Error checking if text is HTML: {e}")
+         return False
+
+     return False
+
+
  def inspect_obj(obj):
      # Filtering only attributes
      print(len("Attributes:") * "-")
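A quick usage sketch of the new helper, with illustrative inputs:

    from validmind.utils import is_html

    is_html("<div><p>Hello</p></div>")  # True: matches a common tag pattern
    is_html("x < y and y > z")          # False: does not start with "<"
    is_html("no markup here")           # False
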
@@ -601,3 +676,192 @@ def serialize(obj):
      elif isinstance(obj, (pd.DataFrame, pd.Series)):
          return "" # Simple empty string for non-serializable objects
      return obj
+
+
+ def is_text_column(series, threshold=0.05) -> bool:
+     """
+     Determines if a series is likely to contain text data using heuristics.
+
+     Args:
+         series (pd.Series): The pandas Series to analyze
+         threshold (float): The minimum threshold to classify a pattern match as significant
+
+     Returns:
+         bool: True if the series likely contains text data, False otherwise
+     """
+     # Filter to non-null string values and sample if needed
+     string_series = series.dropna().astype(str)
+     if len(string_series) == 0:
+         return False
+     if len(string_series) > 1000:
+         string_series = string_series.sample(1000, random_state=42)
+
+     # Calculate basic metrics
+     total_values = len(string_series)
+     unique_ratio = len(string_series.unique()) / total_values if total_values > 0 else 0
+     avg_length = string_series.str.len().mean()
+     avg_words = string_series.str.split(r"\s+").str.len().mean()
+
+     # Check for special text patterns
+     patterns = {
+         "url": r"https?://\S+|www\.\S+",
+         "email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
+         "filepath": r'(?:[a-zA-Z]:|[\\/])(?:[\\/][^\\/:*?"<>|]+)+',
+     }
+
+     # Check if any special patterns exceed threshold
+     for pattern in patterns.values():
+         if string_series.str.contains(pattern, regex=True, na=False).mean() > threshold:
+             return True
+
+     # Calculate proportion of alphabetic characters
+     total_chars = string_series.str.len().sum()
+     if total_chars > 0:
+         alpha_ratio = string_series.str.count(r"[a-zA-Z]").sum() / total_chars
+     else:
+         alpha_ratio = 0
+
+     # Check for free-form text indicators
+     text_indicators = [
+         unique_ratio > 0.8 and avg_length > 20, # High uniqueness and long strings
+         unique_ratio > 0.4
+         and avg_length > 15
+         and string_series.str.contains(r"[.,;:!?]", regex=True, na=False).mean()
+         > 0.3, # Moderate uniqueness with punctuation
+         string_series.str.contains(
+             r"\b\w+\b\s+\b\w+\b\s+\b\w+\b\s+\b\w+\b", regex=True, na=False
+         ).mean()
+         > 0.3, # Contains long phrases
+         avg_words > 5 and alpha_ratio > 0.6, # Many words with mostly letters
+         unique_ratio > 0.95 and avg_length > 10, # Very high uniqueness
+     ]
+
+     return any(text_indicators)
+
+
+ def _get_numeric_type_detail(column, dtype, series):
+     """Helper function to determine numeric type details."""
+     if pd.api.types.is_integer_dtype(dtype):
+         return {"type": "Numeric", "subtype": "Integer"}
+     elif pd.api.types.is_float_dtype(dtype):
+         return {"type": "Numeric", "subtype": "Float"}
+     else:
+         return {"type": "Numeric", "subtype": "Other"}
+
+
+ def _get_text_type_detail(series):
+     """Helper function to determine text/categorical type details."""
+     string_series = series.dropna().astype(str)
+
+     if len(string_series) == 0:
+         return {"type": "Categorical"}
+
+     # Check for common patterns
+     url_pattern = r"https?://\S+|www\.\S+"
+     email_pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
+     filepath_pattern = r'(?:[a-zA-Z]:|[\\/])(?:[\\/][^\\/:*?"<>|]+)+'
+
+     url_ratio = string_series.str.contains(url_pattern, regex=True, na=False).mean()
+     email_ratio = string_series.str.contains(email_pattern, regex=True, na=False).mean()
+     filepath_ratio = string_series.str.contains(
+         filepath_pattern, regex=True, na=False
+     ).mean()
+
+     # Check if general text using enhanced function
+     if url_ratio > 0.7:
+         return {"type": "Text", "subtype": "URL"}
+     elif email_ratio > 0.7:
+         return {"type": "Text", "subtype": "Email"}
+     elif filepath_ratio > 0.7:
+         return {"type": "Text", "subtype": "Path"}
+     elif is_text_column(series):
+         return {"type": "Text", "subtype": "FreeText"}
+
+     # Must be categorical
+     n_unique = series.nunique()
+     if n_unique == 2:
+         return {"type": "Categorical", "subtype": "Binary"}
+     else:
+         return {"type": "Categorical", "subtype": "Nominal"}
+
+
+ def get_column_type_detail(df, column) -> dict:
+     """
+     Get detailed column type information beyond basic type detection.
+     Similar to ydata-profiling's type system.
+
+     Args:
+         df (pd.DataFrame): DataFrame containing the column
+         column (str): Column name to analyze
+
+     Returns:
+         dict: Detailed type information including primary type and subtype
+     """
+     series = df[column]
+     dtype = series.dtype
+
+     # Initialize result with id and basic type
+     result = {"id": column, "type": "Unknown"}
+
+     # Determine type details based on dtype
+     type_detail = None
+
+     if pd.api.types.is_numeric_dtype(dtype):
+         type_detail = _get_numeric_type_detail(column, dtype, series)
+     elif pd.api.types.is_bool_dtype(dtype):
+         type_detail = {"type": "Boolean"}
+     elif pd.api.types.is_datetime64_any_dtype(dtype):
+         type_detail = {"type": "Datetime"}
+     elif pd.api.types.is_categorical_dtype(dtype) or pd.api.types.is_object_dtype(
+         dtype
+     ):
+         type_detail = _get_text_type_detail(series)
+
+     # Update result with type details
+     if type_detail:
+         result.update(type_detail)
+
+     return result
+
+
+ def infer_datatypes(df, detailed=False) -> list:
+     """
+     Infer data types for columns in a DataFrame.
+
+     Args:
+         df (pd.DataFrame): DataFrame to analyze
+         detailed (bool): Whether to return detailed type information including subtypes
+
+     Returns:
+         list: Column type mappings
+     """
+     if detailed:
+         return [get_column_type_detail(df, column) for column in df.columns]
+
+     column_type_mappings = {}
+     # Use pandas to infer data types
+     for column in df.columns:
+         # Check if all values are None
+         if df[column].isna().all():
+             column_type_mappings[column] = {"id": column, "type": "Null"}
+             continue
+
+         dtype = df[column].dtype
+         if pd.api.types.is_numeric_dtype(dtype):
+             column_type_mappings[column] = {"id": column, "type": "Numeric"}
+         elif pd.api.types.is_bool_dtype(dtype):
+             column_type_mappings[column] = {"id": column, "type": "Boolean"}
+         elif pd.api.types.is_datetime64_any_dtype(dtype):
+             column_type_mappings[column] = {"id": column, "type": "Datetime"}
+         elif pd.api.types.is_categorical_dtype(dtype) or pd.api.types.is_object_dtype(
+             dtype
+         ):
+             # Check if this is more likely to be text than categorical
+             if is_text_column(df[column]):
+                 column_type_mappings[column] = {"id": column, "type": "Text"}
+             else:
+                 column_type_mappings[column] = {"id": column, "type": "Categorical"}
+         else:
+             column_type_mappings[column] = {"id": column, "type": "Unsupported"}
+
+     return list(column_type_mappings.values())
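To make the new type-inference helpers concrete, a small usage sketch with illustrative data (the labels shown follow from the heuristics above):

    import pandas as pd
    from validmind.utils import infer_datatypes

    df = pd.DataFrame({
        "age": [34, 29, 41],
        "notes": [
            "Customer called about a billing issue and asked for a refund.",
            "Requested an upgrade to the premium plan after the trial ended.",
            "Reported that the mobile app crashes when uploading documents.",
        ],
    })

    infer_datatypes(df)
    # -> [{"id": "age", "type": "Numeric"}, {"id": "notes", "type": "Text"}]

    infer_datatypes(df, detailed=True)
    # -> adds subtypes, e.g. {"id": "age", "type": "Numeric", "subtype": "Integer"}
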
validmind/vm_models/dataset/dataset.py CHANGED
@@ -8,6 +8,7 @@ Dataset class wrapper

  import warnings
  from copy import deepcopy
+ from typing import Any, Dict, List, Optional

  import numpy as np
  import pandas as pd
@@ -24,9 +25,9 @@ logger = get_logger(__name__)


  class VMDataset(VMInput):
-     """Base class for VM datasets
+     """Base class for VM datasets.

-     Child classes should be used to support new dataset types (tensor, polars etc)
+     Child classes should be used to support new dataset types (tensor, polars etc.)
      by converting the user's dataset into a numpy array collecting metadata like
      column names and then call this (parent) class `__init__` method.

@@ -46,6 +47,7 @@ class VMDataset(VMInput):
          target_class_labels (Dict): The class labels for the target columns.
          df (pd.DataFrame): The dataset as a pandas DataFrame.
          extra_columns (Dict): Extra columns to include in the dataset.
+         copy_data (bool): Whether to copy the data. Defaults to True.
      """

      def __repr__(self):
@@ -65,6 +67,7 @@ class VMDataset(VMInput):
          text_column: str = None,
          extra_columns: dict = None,
          target_class_labels: dict = None,
+         copy_data: bool = True,
      ):
          """
          Initializes a VMDataset instance.
@@ -81,6 +84,7 @@ class VMDataset(VMInput):
              feature_columns (str, optional): The feature column names of the dataset. Defaults to None.
              text_column (str, optional): The text column name of the dataset for nlp tasks. Defaults to None.
              target_class_labels (Dict, optional): The class labels for the target columns. Defaults to None.
+             copy_data (bool, optional): Whether to copy the data. Defaults to True.
          """
          # initialize input_id
          self.input_id = input_id
@@ -111,6 +115,7 @@ class VMDataset(VMInput):
          self.target_class_labels = target_class_labels
          self.extra_columns = ExtraColumns.from_dict(extra_columns)
          self._set_feature_columns(feature_columns)
+         self._copy_data = copy_data

          if model:
              self.assign_predictions(model)
@@ -128,16 +133,19 @@ class VMDataset(VMInput):
          excluded = [self.target_column, *self.extra_columns.flatten()]
          self.feature_columns = [col for col in self.columns if col not in excluded]

-         self.feature_columns_numeric = (
-             self._df[self.feature_columns]
-             .select_dtypes(include=[np.number])
-             .columns.tolist()
-         )
-         self.feature_columns_categorical = (
-             self._df[self.feature_columns]
-             .select_dtypes(include=[object, pd.Categorical])
-             .columns.tolist()
-         )
+         # Get dtypes without loading data into memory
+         feature_dtypes = self._df[self.feature_columns].dtypes
+
+         self.feature_columns_numeric = feature_dtypes[
+             feature_dtypes.apply(lambda x: pd.api.types.is_numeric_dtype(x))
+         ].index.tolist()
+
+         self.feature_columns_categorical = feature_dtypes[
+             feature_dtypes.apply(
+                 lambda x: pd.api.types.is_categorical_dtype(x)
+                 or pd.api.types.is_object_dtype(x)
+             )
+         ].index.tolist()

      def _add_column(self, column_name, column_values):
          column_values = np.array(column_values)
@@ -200,7 +208,7 @@ class VMDataset(VMInput):
                  "Cannot use precomputed probabilities without precomputed predictions"
              )

-     def with_options(self, **kwargs) -> "VMDataset":
+     def with_options(self, **kwargs: Dict[str, Any]) -> "VMDataset":
          """Support options provided when passing an input to run_test or run_test_suite

          Example:
@@ -253,23 +261,25 @@ class VMDataset(VMInput):
      def assign_predictions(
          self,
          model: VMModel,
-         prediction_column: str = None,
-         prediction_values: list = None,
-         probability_column: str = None,
-         probability_values: list = None,
-         prediction_probabilities: list = None, # DEPRECATED: use probability_values
-         **kwargs,
-     ):
+         prediction_column: Optional[str] = None,
+         prediction_values: Optional[List[Any]] = None,
+         probability_column: Optional[str] = None,
+         probability_values: Optional[List[float]] = None,
+         prediction_probabilities: Optional[
+             List[float]
+         ] = None, # DEPRECATED: use probability_values
+         **kwargs: Dict[str, Any],
+     ) -> None:
          """Assign predictions and probabilities to the dataset.

          Args:
              model (VMModel): The model used to generate the predictions.
-             prediction_column (str, optional): The name of the column containing the predictions. Defaults to None.
-             prediction_values (list, optional): The values of the predictions. Defaults to None.
-             probability_column (str, optional): The name of the column containing the probabilities. Defaults to None.
-             probability_values (list, optional): The values of the probabilities. Defaults to None.
-             prediction_probabilities (list, optional): DEPRECATED: The values of the probabilities. Defaults to None.
-             kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
+             prediction_column (Optional[str]): The name of the column containing the predictions.
+             prediction_values (Optional[List[Any]]): The values of the predictions.
+             probability_column (Optional[str]): The name of the column containing the probabilities.
+             probability_values (Optional[List[float]]): The values of the probabilities.
+             prediction_probabilities (Optional[List[float]]): DEPRECATED: The values of the probabilities.
+             **kwargs: Additional keyword arguments that will get passed through to the model's `predict` method.
          """
          if prediction_probabilities is not None:
              warnings.warn(
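For illustration, a minimal sketch of the updated method with precomputed outputs (the dataset and model objects are assumed to come from the usual init_dataset / init_model setup, and df, model, preds, and probs are placeholders):

    import validmind as vm

    vm_ds = vm.init_dataset(dataset=df, target_column="target")  # assumed setup
    vm_model = vm.init_model(model, input_id="my_model")         # assumed setup

    # Attach precomputed predictions instead of re-running model.predict()
    vm_ds.assign_predictions(
        model=vm_model,
        prediction_values=preds,    # e.g. a list of class labels
        probability_values=probs,   # replaces the deprecated prediction_probabilities
    )
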
@@ -394,8 +404,18 @@ class VMDataset(VMInput):
          assert self.target_column not in columns
          columns.append(self.target_column)

-         # return a copy to prevent accidental modification
-         return as_df(self._df[columns]).copy()
+         # Check if all columns in self._df are requested
+         all_columns = set(columns) == set(self._df.columns)
+
+         # For copy_data=False and all columns: return exact same DataFrame object
+         if not self._copy_data and all_columns:
+             return self._df
+         # For copy_data=False and subset of columns: return view with shared data
+         elif not self._copy_data:
+             return as_df(self._df[columns])
+         # For copy_data=True: return independent copy with duplicated data
+         else:
+             return as_df(self._df[columns]).copy()

      @property
      def x(self) -> np.ndarray:
@@ -519,9 +539,10 @@ class DataFrameDataset(VMDataset):
          text_column: str = None,
          target_class_labels: dict = None,
          date_time_index: bool = False,
+         copy_data: bool = True,
      ):
          """
-         Initializes a DataFrameDataset instance.
+         Initializes a DataFrameDataset instance, preserving original pandas dtypes.

          Args:
              raw_dataset (pd.DataFrame): The raw dataset as a pandas DataFrame.
@@ -533,25 +554,44 @@ class DataFrameDataset(VMDataset):
              text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
              target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
              date_time_index (bool, optional): Whether to use date-time index. Defaults to False.
+             copy_data (bool, optional): Whether to create a copy of the input data. Defaults to True.
          """
+
+         VMInput.__init__(self)
+
+         self.input_id = input_id
+
          index = None
          if isinstance(raw_dataset.index, pd.Index):
              index = raw_dataset.index.values
+         self.index = index

-         super().__init__(
-             raw_dataset=raw_dataset.values,
-             input_id=input_id,
-             model=model,
-             index_name=raw_dataset.index.name,
-             index=index,
-             columns=raw_dataset.columns.to_list(),
-             target_column=target_column,
-             extra_columns=extra_columns,
-             feature_columns=feature_columns,
-             text_column=text_column,
-             target_class_labels=target_class_labels,
-             date_time_index=date_time_index,
-         )
+         # Store the DataFrame directly
+         self._df = raw_dataset
+
+         if date_time_index:
+             self._df = convert_index_to_datetime(self._df)
+
+         self.columns = raw_dataset.columns.tolist()
+         self.column_aliases = {}
+         self.target_column = target_column
+         self.text_column = text_column
+         self.target_class_labels = target_class_labels
+         self.extra_columns = ExtraColumns.from_dict(extra_columns)
+         self._copy_data = copy_data
+
+         # Add warning when copy_data is False
+         if not copy_data:
+             logger.warning(
+                 "Dataset initialized with copy_data=False. Changes to the original DataFrame "
+                 "may affect this dataset. Use this option only when memory efficiency is critical "
+                 "and you won't modify the source data."
+             )
+
+         self._set_feature_columns(feature_columns)
+
+         if model:
+             self.assign_predictions(model)


  class PolarsDataset(VMDataset):
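Taken together, the new copy_data flag trades isolation for memory. A usage sketch assuming direct construction of DataFrameDataset (column names are illustrative):

    import pandas as pd
    from validmind.vm_models.dataset.dataset import DataFrameDataset

    raw_df = pd.DataFrame({"feature_a": [1.0, 2.0, 3.0], "target": [0, 1, 0]})

    # copy_data=False avoids duplicating a large frame; the dataset then shares
    # storage with raw_df, so later mutations of raw_df are visible through it.
    vm_ds = DataFrameDataset(
        raw_dataset=raw_df,
        input_id="train_ds",
        target_column="target",
        copy_data=False,  # a warning is logged to flag the aliasing
    )

    # With copy_data=False and no extra columns, .df hands back the original object
    assert vm_ds.df is raw_df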