workbench-0.8.219-py3-none-any.whl → workbench-0.8.231-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  5. workbench/algorithms/dataframe/projection_2d.py +8 -2
  6. workbench/algorithms/dataframe/proximity.py +3 -0
  7. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  8. workbench/algorithms/sql/column_stats.py +0 -1
  9. workbench/algorithms/sql/correlations.py +0 -1
  10. workbench/algorithms/sql/descriptive_stats.py +0 -1
  11. workbench/api/feature_set.py +0 -1
  12. workbench/api/meta.py +0 -1
  13. workbench/cached/cached_meta.py +0 -1
  14. workbench/cached/cached_model.py +37 -7
  15. workbench/core/artifacts/endpoint_core.py +12 -2
  16. workbench/core/artifacts/feature_set_core.py +238 -225
  17. workbench/core/cloud_platform/cloud_meta.py +0 -1
  18. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  19. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  20. workbench/model_script_utils/model_script_utils.py +30 -0
  21. workbench/model_script_utils/uq_harness.py +0 -1
  22. workbench/model_scripts/chemprop/chemprop.template +196 -68
  23. workbench/model_scripts/chemprop/generated_model_script.py +197 -72
  24. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  26. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  27. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
  29. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  30. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  31. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  32. workbench/model_scripts/script_generation.py +0 -1
  33. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  34. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  35. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  36. workbench/scripts/ml_pipeline_sqs.py +71 -2
  37. workbench/themes/dark/custom.css +85 -8
  38. workbench/themes/dark/plotly.json +6 -6
  39. workbench/themes/light/custom.css +172 -64
  40. workbench/themes/light/plotly.json +9 -9
  41. workbench/themes/midnight_blue/custom.css +82 -29
  42. workbench/themes/midnight_blue/plotly.json +1 -1
  43. workbench/utils/aws_utils.py +0 -1
  44. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  45. workbench/utils/chem_utils/projections.py +16 -6
  46. workbench/utils/chem_utils/vis.py +137 -27
  47. workbench/utils/clientside_callbacks.py +41 -0
  48. workbench/utils/markdown_utils.py +57 -0
  49. workbench/utils/model_utils.py +0 -1
  50. workbench/utils/pipeline_utils.py +0 -1
  51. workbench/utils/plot_utils.py +52 -36
  52. workbench/utils/theme_manager.py +95 -30
  53. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  54. workbench/web_interface/components/model_plot.py +2 -0
  55. workbench/web_interface/components/plugin_unit_test.py +0 -1
  56. workbench/web_interface/components/plugins/ag_table.py +2 -4
  57. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  58. workbench/web_interface/components/plugins/model_details.py +10 -6
  59. workbench/web_interface/components/plugins/scatter_plot.py +184 -85
  60. workbench/web_interface/components/settings_menu.py +185 -0
  61. workbench/web_interface/page_views/main_page.py +0 -1
  62. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
  63. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
  64. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
  65. workbench/themes/quartz/base_css.url +0 -1
  66. workbench/themes/quartz/custom.css +0 -117
  67. workbench/themes/quartz/plotly.json +0 -642
  68. workbench/themes/quartz_dark/base_css.url +0 -1
  69. workbench/themes/quartz_dark/custom.css +0 -131
  70. workbench/themes/quartz_dark/plotly.json +0 -642
  71. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
  72. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
  73. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@ from datetime import datetime, timezone
 import botocore.exceptions
 import pandas as pd
 import awswrangler as wr
-import numpy as np
 
 from sagemaker.feature_store.feature_group import FeatureGroup
 from sagemaker.feature_store.feature_store import FeatureStore
@@ -16,9 +15,8 @@ from sagemaker.feature_store.feature_store import FeatureStore
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource
-from workbench.utils.deprecated_utils import deprecated
 
-from typing import TYPE_CHECKING, Optional, List, Dict, Union
+from typing import TYPE_CHECKING, List, Dict, Union
 
 from workbench.utils.aws_utils import aws_throttle
 
@@ -483,18 +481,6 @@ class FeatureSetCore(Artifact):
             time.sleep(1)
         cls.log.info(f"FeatureSet {feature_group.name} successfully deleted")
 
-    def set_training_holdouts(self, holdout_ids: list[str]):
-        """Set the hold out ids for the training view for this FeatureSet
-
-        Args:
-            holdout_ids (list[str]): The list of holdout ids.
-        """
-        from workbench.core.views import TrainingView
-
-        # Create a NEW training view
-        self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
-        TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
-
     def get_training_holdouts(self) -> list[str]:
         """Get the hold out ids for the training view for this FeatureSet
 
@@ -510,183 +496,234 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids
 
-    def set_sample_weights(
+    # ---- Public methods for training configuration ----
+    def set_training_config(
         self,
-        weight_dict: Dict[Union[str, int], float],
+        holdout_ids: List[Union[str, int]] = None,
+        weight_dict: Dict[Union[str, int], float] = None,
         default_weight: float = 1.0,
         exclude_zero_weights: bool = True,
     ):
-        """Configure training view with sample weights for each ID.
+        """Configure training view with holdout IDs and/or sample weights.
+
+        This method creates a training view that can include both:
+        - A 'training' column (True/False) based on holdout IDs
+        - A 'sample_weight' column for weighted training
 
         Args:
+            holdout_ids: List of IDs to mark as training=False (validation/holdout set)
             weight_dict: Mapping of ID to sample weight
                 - weight > 1.0: oversample/emphasize
                 - weight = 1.0: normal (default)
                 - 0 < weight < 1.0: downweight/de-emphasize
-                - weight = 0.0: exclude from training
+                - weight = 0.0: exclude from training (filtered out if exclude_zero_weights=True)
            default_weight: Weight for IDs not in weight_dict (default: 1.0)
            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
 
        Example:
-            weights = {
-                'compound_42': 3.0,   # oversample 3x
-                'compound_99': 0.1,   # noisy, downweight
-                'compound_123': 0.0,  # exclude from training
-            }
-            model.set_sample_weights(weights)  # zeros automatically excluded
-            model.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
+            # Temporal split with sample weights
+            fs.set_training_config(
+                holdout_ids=temporal_hold_out_ids,  # IDs after cutoff date
+                weight_dict={'compound_42': 0.0, 'compound_99': 2.0},  # exclude/upweight
+            )
        """
-        from workbench.core.views import TrainingView
+        from workbench.core.views.training_view import TrainingView
 
-        if not weight_dict:
-            self.log.important("Empty weight_dict, creating standard training view")
+        # If neither is provided, create a standard training view
+        if not holdout_ids and not weight_dict:
+            self.log.important("No holdouts or weights specified, creating standard training view")
             TrainingView.create(self, id_column=self.id_column)
             return
 
-        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
+        # If only holdout_ids, delegate to set_training_holdouts
+        if holdout_ids and not weight_dict:
+            self.set_training_holdouts(holdout_ids)
+            return
 
-        # Helper to format IDs for SQL
-        def format_id(id_val):
-            return repr(id_val)
+        # If only weight_dict, delegate to set_sample_weights
+        if weight_dict and not holdout_ids:
+            self.set_sample_weights(weight_dict, default_weight, exclude_zero_weights)
+            return
 
-        # Build CASE statement for sample_weight
-        case_conditions = [
-            f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
-        ]
-        case_statement = "\n ".join(case_conditions)
+        # Both holdout_ids and weight_dict provided - build combined view
+        self.log.important(f"Setting training config: {len(holdout_ids)} holdouts, {len(weight_dict)} weights")
 
-        # Build inner query with sample weights
-        inner_sql = f"""SELECT
-            *,
-            CASE
-                {case_statement}
-                ELSE {default_weight}
-            END AS sample_weight
-        FROM {self.table}"""
+        # Get column list (excluding AWS-generated columns)
+        from workbench.core.views.view_utils import get_column_list
 
-        # Optionally filter out zero weights
-        if exclude_zero_weights:
-            zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
-            custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
-            self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
-        else:
-            custom_sql = inner_sql
+        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
+        source_columns = get_column_list(self.data_source, self.table)
+        column_list = [col for col in source_columns if col not in aws_cols]
 
-        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+        # Build the training column CASE statement
+        training_case = self._build_holdout_case(holdout_ids)
 
-    @deprecated(version="0.9")
-    def set_training_filter(self, filter_expression: Optional[str] = None):
-        """Set a filter expression for the training view for this FeatureSet
+        # For large weight_dict, use supplemental table + JOIN
+        if len(weight_dict) >= 100:
+            self.log.info("Using supplemental table approach for large weight_dict")
+            weights_table = self._create_weights_table(weight_dict)
 
-        Args:
-            filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
-                If None or empty string, will reset to training view with no filter
-                (default: None)
-        """
-        from workbench.core.views import TrainingView
+            # Build column selection with table alias
+            sql_columns = ", ".join([f't."{col}"' for col in column_list])
 
-        # Grab the existing holdout ids
-        holdout_ids = self.get_training_holdouts()
+            # Build JOIN query with training CASE and weight from joined table
+            training_case_aliased = training_case.replace(f"WHEN {self.id_column} IN", f"WHEN t.{self.id_column} IN")
+            inner_sql = f"""SELECT {sql_columns}, {training_case_aliased},
+                COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                FROM {self.table} t
+                LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+        else:
+            # For small weight_dict, use CASE statement
+            sql_columns = ", ".join([f'"{column}"' for column in column_list])
+            weight_case = self._build_weight_case(weight_dict, default_weight)
+            inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
 
-        # Create a NEW training view
-        self.log.important(f"Setting Training Filter: {filter_expression}")
-        TrainingView.create(
-            self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
-        )
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
+            if zero_count:
+                self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+            sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+        else:
+            sql_query = inner_sql
 
-    @deprecated(version="0.9")
-    def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
-        """Exclude a list of IDs from the training view
+        self._create_training_view(sql_query)
+
+    def set_training_holdouts(self, holdout_ids: list[str]):
+        """Set the hold out ids for the training view for this FeatureSet
 
         Args:
-            ids (List[Union[str, int]],): List of IDs to exclude from training
-            column_name (Optional[str]): Column name to filter on.
-                If None, uses self.id_column (default: None)
+            holdout_ids (list[str]): The list of holdout ids.
         """
-        # Use the default id_column if not specified
-        column = column_name or self.id_column
-
-        # Handle empty list case
-        if not ids:
-            self.log.warning("No IDs provided to exclude")
-            return
-
-        # Build the filter expression with proper SQL quoting
-        quoted_ids = ", ".join([repr(id) for id in ids])
-        filter_expression = f"{column} NOT IN ({quoted_ids})"
+        from workbench.core.views import TrainingView
 
-        # Apply the filter
-        self.set_training_filter(filter_expression)
+        self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
+        TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
 
-    @deprecated(version="0.9")
-    def set_training_sampling(
+    def set_sample_weights(
         self,
-        exclude_ids: Optional[List[Union[str, int]]] = None,
-        replicate_ids: Optional[List[Union[str, int]]] = None,
-        replication_factor: int = 2,
+        weight_dict: Dict[Union[str, int], float],
+        default_weight: float = 1.0,
+        exclude_zero_weights: bool = True,
     ):
-        """Configure training view with ID exclusions and replications (oversampling).
+        """Configure training view with sample weights for each ID.
 
         Args:
-            exclude_ids: List of IDs to exclude from training view
-            replicate_ids: List of IDs to replicate in training view for oversampling
-            replication_factor: Number of times to replicate each ID (default: 2)
+            weight_dict: Mapping of ID to sample weight
+                - weight > 1.0: oversample/emphasize
+                - weight = 1.0: normal (default)
+                - 0 < weight < 1.0: downweight/de-emphasize
+                - weight = 0.0: exclude from training
+            default_weight: Weight for IDs not in weight_dict (default: 1.0)
+            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
+
+        Example:
+            weights = {
+                'compound_42': 3.0,   # oversample 3x
+                'compound_99': 0.1,   # noisy, downweight
+                'compound_123': 0.0,  # exclude from training
+            }
+            fs.set_sample_weights(weights)  # zeros automatically excluded
+            fs.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
 
         Note:
-            If an ID appears in both lists, exclusion takes precedence.
+            For large weight_dict (100+ entries), weights are stored as a supplemental
+            table and joined to avoid Athena query size limits.
         """
         from workbench.core.views import TrainingView
 
-        # Normalize to empty lists if None
-        exclude_ids = exclude_ids or []
-        replicate_ids = replicate_ids or []
-
-        # Remove any replicate_ids that are also in exclude_ids (exclusion wins)
-        replicate_ids = [rid for rid in replicate_ids if rid not in exclude_ids]
-
-        # If no sampling needed, just create normal view
-        if not exclude_ids and not replicate_ids:
-            self.log.important("No sampling specified, creating standard training view")
+        if not weight_dict:
+            self.log.important("Empty weight_dict, creating standard training view")
             TrainingView.create(self, id_column=self.id_column)
             return
 
-        # Build the custom SQL query
-        self.log.important(
-            f"Excluding {len(exclude_ids)} IDs, Replicating {len(replicate_ids)} IDs "
-            f"(factor: {replication_factor}x)"
-        )
-
-        # Helper to format IDs for SQL
-        def format_ids(ids):
-            return ", ".join([repr(id) for id in ids])
+        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
 
-        # Start with base query
-        base_query = f"SELECT * FROM {self.table}"
+        # For large weight_dict, use supplemental table + JOIN to avoid query size limits
+        if len(weight_dict) >= 100:
+            self.log.info("Using supplemental table approach for large weight_dict")
+            weights_table = self._create_weights_table(weight_dict)
 
-        # Add exclusions if needed
-        if exclude_ids:
-            base_query += f"\nWHERE {self.id_column} NOT IN ({format_ids(exclude_ids)})"
+            # Build JOIN query with COALESCE for default weight
+            inner_sql = f"""SELECT t.*, COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                FROM {self.table} t
+                LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+        else:
+            # For small weight_dict, use CASE statement (simpler, no extra table)
+            weight_case = self._build_weight_case(weight_dict, default_weight)
+            inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
 
-        # Build full query with replication
-        if replicate_ids:
-            # Generate VALUES clause for CROSS JOIN: (1), (2), ..., (N-1)
-            # We want N-1 additional copies since the original row is already in base_query
-            values_clause = ", ".join([f"({i})" for i in range(1, replication_factor)])
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
+            if zero_count:
+                self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+            sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+        else:
+            sql_query = inner_sql
 
-            custom_sql = f"""{base_query}
+        TrainingView.create_with_sql(self, sql_query=sql_query, id_column=self.id_column)
 
-            UNION ALL
+    # ---- Internal helpers for training view SQL generation ----
+    @staticmethod
+    def _format_id_for_sql(id_val: Union[str, int]) -> str:
+        """Format an ID value for use in SQL."""
+        return repr(id_val)
 
-            SELECT t.*
-            FROM {self.table} t
-            CROSS JOIN (VALUES {values_clause}) AS n(num)
-            WHERE t.{self.id_column} IN ({format_ids(replicate_ids)})"""
+    def _build_holdout_case(self, holdout_ids: List[Union[str, int]]) -> str:
+        """Build SQL CASE statement for training column based on holdout IDs."""
+        if all(isinstance(id_val, str) for id_val in holdout_ids):
+            formatted_ids = ", ".join(f"'{id_val}'" for id_val in holdout_ids)
         else:
-            # Only exclusions, no UNION needed
-            custom_sql = base_query
+            formatted_ids = ", ".join(map(str, holdout_ids))
+        return f"""CASE
+            WHEN {self.id_column} IN ({formatted_ids}) THEN False
+            ELSE True
+        END AS training"""
+
+    def _build_weight_case(self, weight_dict: Dict[Union[str, int], float], default_weight: float) -> str:
+        """Build SQL CASE statement for sample_weight column."""
+        conditions = [
+            f"WHEN {self.id_column} = {self._format_id_for_sql(id_val)} THEN {weight}"
+            for id_val, weight in weight_dict.items()
+        ]
+        case_body = "\n ".join(conditions)
+        return f"""CASE
+            {case_body}
+            ELSE {default_weight}
+        END AS sample_weight"""
+
+    def _create_training_view(self, sql_query: str):
+        """Create the training view directly from a SQL query."""
+        view_table = f"{self.table}___training"
+        create_view_query = f"CREATE OR REPLACE VIEW {view_table} AS\n{sql_query}"
+        self.data_source.execute_statement(create_view_query)
+
+    def _create_weights_table(self, weight_dict: Dict[Union[str, int], float]) -> str:
+        """Store sample weights as a supplemental data table.
+
+        Args:
+            weight_dict: Mapping of ID to sample weight
 
-        # Create the training view with our custom SQL
-        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+        Returns:
+            str: The name of the created supplemental table
+        """
+        from workbench.core.views.view_utils import dataframe_to_table
+
+        # Create DataFrame from weight_dict
+        df = pd.DataFrame(
+            [(id_val, weight) for id_val, weight in weight_dict.items()],
+            columns=[self.id_column, "sample_weight"],
+        )
+
+        # Supplemental table name follows convention: _{base_table}___sample_weights
+        weights_table = f"_{self.table}___sample_weights"
+
+        # Store as supplemental data table
+        self.log.info(f"Creating supplemental weights table: {weights_table}")
+        dataframe_to_table(self.data_source, df, weights_table)
+
+        return weights_table
 
     @classmethod
     def delete_views(cls, table: str, database: str):
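
The new set_training_config API above consolidates the deprecated set_training_filter, exclude_ids_from_training, and set_training_sampling methods into two concepts: holdout IDs and per-ID sample weights. A minimal usage sketch, assuming a FeatureSet named "aqsol_features" and the compound IDs shown (both hypothetical, not from this diff):

    from workbench.api.feature_set import FeatureSet

    fs = FeatureSet("aqsol_features")  # hypothetical FeatureSet name

    # Temporal split: later compounds become holdout (training=False);
    # one noisy compound is de-emphasized, one excluded outright (weight 0.0)
    fs.set_training_config(
        holdout_ids=["compound_500", "compound_501"],  # hypothetical IDs
        weight_dict={"compound_99": 0.1, "compound_123": 0.0},
    )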
@@ -737,20 +774,6 @@ class FeatureSetCore(Artifact):
         """
         return self.data_source.smart_sample()
 
-    def anomalies(self) -> pd.DataFrame:
-        """Get a set of anomalous data from the underlying DataSource
-        Returns:
-            pd.DataFrame: A dataframe of anomalies from the underlying DataSource
-        """
-
-        # FIXME: Mock this for now
-        anom_df = self.sample().copy()
-        anom_df["anomaly_score"] = np.random.rand(anom_df.shape[0])
-        anom_df["cluster"] = np.random.randint(0, 10, anom_df.shape[0])
-        anom_df["x"] = np.random.rand(anom_df.shape[0])
-        anom_df["y"] = np.random.rand(anom_df.shape[0])
-        return anom_df
-
     def value_counts(self) -> dict:
         """Get the value counts for the string columns of the underlying DataSource
 
@@ -915,81 +938,71 @@ if __name__ == "__main__":
     training_data = my_features.get_training_data()
     print(f"Training Data: {training_data.shape}")
 
-    # Test the filter expression functionality
-    print("Setting a filter expression...")
-    my_features.set_training_filter("auto_id < 50 AND length > 65.0")
-    training_data = my_features.get_training_data()
-    print(f"Training Data: {training_data.shape}")
-    print(training_data)
-
-    # Remove training filter
-    print("Removing the filter expression...")
-    my_features.set_training_filter(None)
-    training_data = my_features.get_training_data()
-    print(f"Training Data: {training_data.shape}")
-    print(training_data)
-
-    # Test excluding ids from training
-    print("Excluding ids from training...")
-    my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
-    training_data = my_features.get_training_data()
-    print(f"Training Data: {training_data.shape}")
-    print(training_data)
+    # Test set_sample_weights
+    print("\n--- Testing set_sample_weights ---")
+    sample_ids = df["auto_id"].tolist()[:5]
+    weight_dict = {sample_ids[0]: 0.0, sample_ids[1]: 0.5, sample_ids[2]: 2.0}
+    my_features.set_sample_weights(weight_dict)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    print(f"Training view shape after set_sample_weights: {training_df.shape}")
+    print(f"Columns: {training_df.columns.tolist()}")
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    assert "training" in training_df.columns, "training column missing!"
+    # Verify zero-weight row was excluded
+    assert sample_ids[0] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
+    print("set_sample_weights test passed!")
+
+    # Test set_training_config with both holdouts and weights
+    print("\n--- Testing set_training_config (combined) ---")
+    holdout_ids = [id for id in df["auto_id"] if id >= 100 and id < 120]
+    weight_dict = {sample_ids[3]: 0.0, sample_ids[4]: 3.0}  # exclude one, upweight another
+    my_features.set_training_config(holdout_ids=holdout_ids, weight_dict=weight_dict)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    print(f"Training view shape after set_training_config: {training_df.shape}")
+    print(f"Columns: {training_df.columns.tolist()}")
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    assert "training" in training_df.columns, "training column missing!"
+    # Verify holdout IDs are marked as training=False
+    holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
+    assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!"  # noqa: E712
+    # Verify zero-weight row was excluded
+    assert sample_ids[3] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
+    # Verify upweighted row has correct weight
+    upweight_row = training_df[training_df["auto_id"] == sample_ids[4]]
+    assert upweight_row["sample_weight"].iloc[0] == 3.0, "Upweighted ID should have weight=3.0!"
+    print("set_training_config (combined) test passed!")
+
+    # Test set_training_config with only holdouts (should delegate to set_training_holdouts)
+    print("\n--- Testing set_training_config (holdouts only) ---")
+    my_features.set_training_config(holdout_ids=holdout_ids)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "training" in training_df.columns, "training column missing!"
+    holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
+    assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!"  # noqa: E712
+    print("set_training_config (holdouts only) test passed!")
+
+    # Test set_training_config with only weights (should delegate to set_sample_weights)
+    print("\n--- Testing set_training_config (weights only) ---")
+    my_features.set_training_config(weight_dict={sample_ids[0]: 0.5, sample_ids[1]: 2.0})
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    print("set_training_config (weights only) test passed!")
+
+    # Test set_training_config with neither (should create standard training view)
+    print("\n--- Testing set_training_config (neither) ---")
+    my_features.set_training_config()
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "training" in training_df.columns, "training column missing!"
+    print("set_training_config (neither) test passed!")
+
+    print("\n=== All training config tests passed! ===")
 
     # Now delete the AWS artifacts associated with this Feature Set
     # print("Deleting Workbench Feature Set...")
     # my_features.delete()
     # print("Done")
-
-    # Test set_training_sampling with exclusions and replications
-    print("\n--- Testing set_training_sampling ---")
-    my_features.set_training_filter(None)  # Reset any existing filters
-    original_count = num_rows
-
-    # Get valid IDs from the table
-    all_data = my_features.query(f'SELECT auto_id, length FROM "{table}"')
-    valid_ids = sorted(all_data["auto_id"].tolist())
-    print(f"Valid IDs range from {valid_ids[0]} to {valid_ids[-1]}")
-
-    exclude_list = valid_ids[0:3]  # First 3 IDs
-    replicate_list = valid_ids[10:13]  # IDs at positions 10, 11, 12
-
-    print(f"Original row count: {original_count}")
-    print(f"Excluding IDs: {exclude_list}")
-    print(f"Replicating IDs: {replicate_list}")
-
-    # Test with default replication factor (2x)
-    print("\n--- Testing with replication_factor=2 (default) ---")
-    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-
-    # Verify exclusions
-    for exc_id in exclude_list:
-        count = len(training_data[training_data["auto_id"] == exc_id])
-        print(f"Excluded ID {exc_id} appears {count} times (should be 0)")
-
-    # Verify replications
-    for rep_id in replicate_list:
-        count = len(training_data[training_data["auto_id"] == rep_id])
-        print(f"Replicated ID {rep_id} appears {count} times (should be 2)")
-
-    # Test with replication factor of 5
-    print("\n--- Testing with replication_factor=5 ---")
-    replicate_list_5x = [20, 21]
-    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list_5x, replication_factor=5)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-
-    # Verify 5x replication
-    for rep_id in replicate_list_5x:
-        count = len(training_data[training_data["auto_id"] == rep_id])
-        print(f"Replicated ID {rep_id} appears {count} times (should be 5)")
-
-    # Test with large replication list (simulate 100 IDs)
-    print("\n--- Testing with large ID list (100 IDs) ---")
-    large_replicate_list = list(range(30, 130))  # 100 IDs
-    my_features.set_training_sampling(replicate_ids=large_replicate_list, replication_factor=3)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-    print(f"Expected extra rows: {len(large_replicate_list) * 3}")
@@ -7,7 +7,6 @@ import logging
 from typing import Union
 import pandas as pd
 
-
 # Workbench Imports
 from workbench.core.cloud_platform.aws.aws_meta import AWSMeta
 
@@ -227,20 +227,14 @@ class FeaturesToModel(Transform):
                 self.log.critical(msg)
                 raise ValueError(msg)
 
-            # Dynamically create the metric definitions
+            # Dynamically create the metric definitions (per-class precision/recall/f1/support)
+            # Note: Confusion matrix metrics are skipped to stay under SageMaker's 40 metric limit
             metrics = ["precision", "recall", "f1", "support"]
             metric_definitions = []
             for t in self.class_labels:
                 for m in metrics:
                     metric_definitions.append({"Name": f"Metrics:{t}:{m}", "Regex": f"Metrics:{t}:{m} ([0-9.]+)"})
 
-            # Add the confusion matrix metrics
-            for row in self.class_labels:
-                for col in self.class_labels:
-                    metric_definitions.append(
-                        {"Name": f"ConfusionMatrix:{row}:{col}", "Regex": f"ConfusionMatrix:{row}:{col} ([0-9.]+)"}
-                    )
-
         # If the model type is UNKNOWN, our metric_definitions will be empty
         else:
             self.log.important(f"ModelType is {self.model_type}, skipping metric_definitions...")
@@ -148,6 +148,7 @@ class ModelToEndpoint(Transform):
                 deserializer=CSVDeserializer(),
                 data_capture_config=data_capture_config,
                 tags=aws_tags,
+                container_startup_health_check_timeout=300,
             )
         except ClientError as e:
             # Check if this is the "endpoint config already exists" error
@@ -164,6 +165,7 @@ class ModelToEndpoint(Transform):
                 deserializer=CSVDeserializer(),
                 data_capture_config=data_capture_config,
                 tags=aws_tags,
+                container_startup_health_check_timeout=300,
            )
        else:
            raise
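
Both deploy calls now pass container_startup_health_check_timeout=300, a standard SageMaker SDK deploy argument giving the inference container up to 300 seconds to pass its startup ping check before endpoint creation fails; useful when the container loads large model artifacts (e.g., the PyTorch or ChemProp models elsewhere in this release) at startup. In isolation the knob looks like this (model construction details are placeholders):

    from sagemaker.model import Model

    model = Model(image_uri="...", model_data="...", role="...")  # placeholders
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.large",
        container_startup_health_check_timeout=300,  # seconds allowed for startup ping
    )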
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
     raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
 def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
     """Compute standard regression metrics.
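
A quick numeric check of the new cap_std_outliers helper on a 1-D input (values chosen for illustration):

    import numpy as np

    std = np.array([0.1, 0.2, 0.3, 0.4, 5.0])  # one extreme value
    # q1 = 0.2, q3 = 0.4, IQR = 0.2 -> upper fence = 0.4 + 1.5 * 0.2 = 0.7
    print(cap_std_outliers(std))  # [0.1 0.2 0.3 0.4 0.7] -- ordering preserved, outlier capped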