workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/algorithms/sql/outliers.py +3 -3
  6. workbench/api/feature_set.py +0 -1
  7. workbench/core/artifacts/endpoint_core.py +2 -2
  8. workbench/core/artifacts/feature_set_core.py +185 -230
  9. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  11. workbench/model_script_utils/model_script_utils.py +15 -11
  12. workbench/model_scripts/chemprop/chemprop.template +195 -70
  13. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  14. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
  16. workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
  17. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  18. workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
  19. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  20. workbench/scripts/meta_model_sim.py +35 -0
  21. workbench/scripts/ml_pipeline_sqs.py +71 -2
  22. workbench/themes/light/custom.css +7 -1
  23. workbench/themes/midnight_blue/custom.css +34 -0
  24. workbench/utils/chem_utils/fingerprints.py +80 -43
  25. workbench/utils/chem_utils/projections.py +16 -6
  26. workbench/utils/meta_model_simulator.py +41 -13
  27. workbench/utils/model_utils.py +0 -1
  28. workbench/utils/plot_utils.py +146 -28
  29. workbench/utils/shap_utils.py +1 -55
  30. workbench/utils/theme_manager.py +95 -30
  31. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  32. workbench/web_interface/components/settings_menu.py +184 -0
  33. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  34. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
  35. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
  36. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  37. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  38. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  39. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  40. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/endpoint_core.py
@@ -410,7 +410,7 @@ class EndpointCore(Artifact):
         primary_target = targets
 
         # Sanity Check that the target column is present
-        if primary_target and (primary_target not in prediction_df.columns):
+        if primary_target not in prediction_df.columns:
             self.log.important(f"Target Column {primary_target} not found in prediction_df!")
             self.log.important("In order to compute metrics, the target column must be present!")
             metrics = pd.DataFrame()
@@ -432,7 +432,7 @@ class EndpointCore(Artifact):
             print(metrics.head())
 
         # Capture the inference results and metrics
-        if capture_name is not None:
+        if primary_target and capture_name:
 
             # If we don't have an id_column, we'll pull it from the model's FeatureSet
             if id_column is None:
workbench/core/artifacts/feature_set_core.py
@@ -7,7 +7,6 @@ from datetime import datetime, timezone
 import botocore.exceptions
 import pandas as pd
 import awswrangler as wr
-import numpy as np
 
 from sagemaker.feature_store.feature_group import FeatureGroup
 from sagemaker.feature_store.feature_store import FeatureStore
@@ -16,9 +15,8 @@ from sagemaker.feature_store.feature_store import FeatureStore
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource
-from workbench.utils.deprecated_utils import deprecated
 
-from typing import TYPE_CHECKING, Optional, List, Dict, Union
+from typing import TYPE_CHECKING, List, Dict, Union
 
 from workbench.utils.aws_utils import aws_throttle
@@ -247,7 +245,7 @@ class FeatureSetCore(Artifact):
 
         # Set the compressed features in our FeatureSet metadata
         self.log.important(f"Setting Compressed Columns...{compressed_columns}")
-        self.upsert_workbench_meta({"comp_features": compressed_columns})
+        self.upsert_workbench_meta({"compressed_features": compressed_columns})
 
     def get_compressed_features(self) -> list[str]:
         """Get the compressed features for this FeatureSet
@@ -256,7 +254,7 @@ class FeatureSetCore(Artifact):
             list[str]: The compressed columns for this FeatureSet
         """
         # Get the compressed features from our FeatureSet metadata
-        return self.workbench_meta().get("comp_features", [])
+        return self.workbench_meta().get("compressed_features", [])
 
     def num_columns(self) -> int:
         """Return the number of columns of the Feature Set"""
@@ -483,18 +481,6 @@ class FeatureSetCore(Artifact):
             time.sleep(1)
         cls.log.info(f"FeatureSet {feature_group.name} successfully deleted")
 
-    def set_training_holdouts(self, holdout_ids: list[str]):
-        """Set the hold out ids for the training view for this FeatureSet
-
-        Args:
-            holdout_ids (list[str]): The list of holdout ids.
-        """
-        from workbench.core.views import TrainingView
-
-        # Create a NEW training view
-        self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
-        TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
-
     def get_training_holdouts(self) -> list[str]:
         """Get the hold out ids for the training view for this FeatureSet
 
@@ -510,183 +496,176 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids
 
-    def set_sample_weights(
+    # ---- Public methods for training configuration ----
+    def set_training_config(
         self,
-        weight_dict: Dict[Union[str, int], float],
+        holdout_ids: List[Union[str, int]] = None,
+        weight_dict: Dict[Union[str, int], float] = None,
         default_weight: float = 1.0,
         exclude_zero_weights: bool = True,
     ):
-        """Configure training view with sample weights for each ID.
+        """Configure training view with holdout IDs and/or sample weights.
+
+        This method creates a training view that can include both:
+        - A 'training' column (True/False) based on holdout IDs
+        - A 'sample_weight' column for weighted training
 
         Args:
+            holdout_ids: List of IDs to mark as training=False (validation/holdout set)
             weight_dict: Mapping of ID to sample weight
                 - weight > 1.0: oversample/emphasize
                 - weight = 1.0: normal (default)
                 - 0 < weight < 1.0: downweight/de-emphasize
-                - weight = 0.0: exclude from training
+                - weight = 0.0: exclude from training (filtered out if exclude_zero_weights=True)
             default_weight: Weight for IDs not in weight_dict (default: 1.0)
            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
 
         Example:
-            weights = {
-                'compound_42': 3.0,    # oversample 3x
-                'compound_99': 0.1,    # noisy, downweight
-                'compound_123': 0.0,   # exclude from training
-            }
-            model.set_sample_weights(weights)  # zeros automatically excluded
-            model.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
+            # Temporal split with sample weights
+            fs.set_training_config(
+                holdout_ids=temporal_hold_out_ids,  # IDs after cutoff date
+                weight_dict={'compound_42': 0.0, 'compound_99': 2.0},  # exclude/upweight
+            )
         """
-        from workbench.core.views import TrainingView
+        from workbench.core.views.training_view import TrainingView
 
-        if not weight_dict:
-            self.log.important("Empty weight_dict, creating standard training view")
+        # If neither is provided, create a standard training view
+        if not holdout_ids and not weight_dict:
+            self.log.important("No holdouts or weights specified, creating standard training view")
             TrainingView.create(self, id_column=self.id_column)
             return
 
-        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
+        # If only holdout_ids, delegate to set_training_holdouts
+        if holdout_ids and not weight_dict:
+            self.set_training_holdouts(holdout_ids)
+            return
 
-        # Helper to format IDs for SQL
-        def format_id(id_val):
-            return repr(id_val)
+        # If only weight_dict, delegate to set_sample_weights
+        if weight_dict and not holdout_ids:
+            self.set_sample_weights(weight_dict, default_weight, exclude_zero_weights)
+            return
 
-        # Build CASE statement for sample_weight
-        case_conditions = [
-            f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
-        ]
-        case_statement = "\n ".join(case_conditions)
+        # Both holdout_ids and weight_dict provided - build combined view
+        self.log.important(f"Setting training config: {len(holdout_ids)} holdouts, {len(weight_dict)} weights")
 
-        # Build inner query with sample weights
-        inner_sql = f"""SELECT
-            *,
-            CASE
-                {case_statement}
-                ELSE {default_weight}
-            END AS sample_weight
-        FROM {self.table}"""
+        # Get column list (excluding AWS-generated columns)
+        from workbench.core.views.view_utils import get_column_list
+
+        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
+        source_columns = get_column_list(self.data_source, self.table)
+        column_list = [col for col in source_columns if col not in aws_cols]
+        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+
+        # Build inner query with both columns
+        training_case = self._build_holdout_case(holdout_ids)
+        weight_case = self._build_weight_case(weight_dict, default_weight)
+        inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
 
         # Optionally filter out zero weights
         if exclude_zero_weights:
-            zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
-            custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
-            self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+            zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
+            if zero_count:
+                self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+            sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
         else:
-            custom_sql = inner_sql
+            sql_query = inner_sql
 
-        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+        self._create_training_view(sql_query)
 
-    @deprecated(version="0.9")
-    def set_training_filter(self, filter_expression: Optional[str] = None):
-        """Set a filter expression for the training view for this FeatureSet
+    def set_training_holdouts(self, holdout_ids: list[str]):
+        """Set the hold out ids for the training view for this FeatureSet
 
         Args:
-            filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
-                If None or empty string, will reset to training view with no filter
-                (default: None)
+            holdout_ids (list[str]): The list of holdout ids.
         """
         from workbench.core.views import TrainingView
 
-        # Grab the existing holdout ids
-        holdout_ids = self.get_training_holdouts()
-
-        # Create a NEW training view
-        self.log.important(f"Setting Training Filter: {filter_expression}")
-        TrainingView.create(
-            self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
-        )
-
-    @deprecated(version="0.9")
-    def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
-        """Exclude a list of IDs from the training view
-
-        Args:
-            ids (List[Union[str, int]],): List of IDs to exclude from training
-            column_name (Optional[str]): Column name to filter on.
-                If None, uses self.id_column (default: None)
-        """
-        # Use the default id_column if not specified
-        column = column_name or self.id_column
-
-        # Handle empty list case
-        if not ids:
-            self.log.warning("No IDs provided to exclude")
-            return
-
-        # Build the filter expression with proper SQL quoting
-        quoted_ids = ", ".join([repr(id) for id in ids])
-        filter_expression = f"{column} NOT IN ({quoted_ids})"
-
-        # Apply the filter
-        self.set_training_filter(filter_expression)
+        self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
+        TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
 
-    @deprecated(version="0.9")
-    def set_training_sampling(
+    def set_sample_weights(
         self,
-        exclude_ids: Optional[List[Union[str, int]]] = None,
-        replicate_ids: Optional[List[Union[str, int]]] = None,
-        replication_factor: int = 2,
+        weight_dict: Dict[Union[str, int], float],
+        default_weight: float = 1.0,
+        exclude_zero_weights: bool = True,
     ):
-        """Configure training view with ID exclusions and replications (oversampling).
+        """Configure training view with sample weights for each ID.
 
         Args:
-            exclude_ids: List of IDs to exclude from training view
-            replicate_ids: List of IDs to replicate in training view for oversampling
-            replication_factor: Number of times to replicate each ID (default: 2)
+            weight_dict: Mapping of ID to sample weight
+                - weight > 1.0: oversample/emphasize
+                - weight = 1.0: normal (default)
+                - 0 < weight < 1.0: downweight/de-emphasize
+                - weight = 0.0: exclude from training
+            default_weight: Weight for IDs not in weight_dict (default: 1.0)
+            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
 
-        Note:
-            If an ID appears in both lists, exclusion takes precedence.
+        Example:
+            weights = {
+                'compound_42': 3.0,    # oversample 3x
+                'compound_99': 0.1,    # noisy, downweight
+                'compound_123': 0.0,   # exclude from training
+            }
+            fs.set_sample_weights(weights)  # zeros automatically excluded
+            fs.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
         """
         from workbench.core.views import TrainingView
 
-        # Normalize to empty lists if None
-        exclude_ids = exclude_ids or []
-        replicate_ids = replicate_ids or []
-
-        # Remove any replicate_ids that are also in exclude_ids (exclusion wins)
-        replicate_ids = [rid for rid in replicate_ids if rid not in exclude_ids]
-
-        # If no sampling needed, just create normal view
-        if not exclude_ids and not replicate_ids:
-            self.log.important("No sampling specified, creating standard training view")
+        if not weight_dict:
+            self.log.important("Empty weight_dict, creating standard training view")
             TrainingView.create(self, id_column=self.id_column)
             return
 
-        # Build the custom SQL query
-        self.log.important(
-            f"Excluding {len(exclude_ids)} IDs, Replicating {len(replicate_ids)} IDs "
-            f"(factor: {replication_factor}x)"
-        )
-
-        # Helper to format IDs for SQL
-        def format_ids(ids):
-            return ", ".join([repr(id) for id in ids])
-
-        # Start with base query
-        base_query = f"SELECT * FROM {self.table}"
+        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
 
-        # Add exclusions if needed
-        if exclude_ids:
-            base_query += f"\nWHERE {self.id_column} NOT IN ({format_ids(exclude_ids)})"
+        # Build inner query with sample weights
+        weight_case = self._build_weight_case(weight_dict, default_weight)
+        inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
 
-        # Build full query with replication
-        if replicate_ids:
-            # Generate VALUES clause for CROSS JOIN: (1), (2), ..., (N-1)
-            # We want N-1 additional copies since the original row is already in base_query
-            values_clause = ", ".join([f"({i})" for i in range(1, replication_factor)])
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
+            self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+            sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+        else:
+            sql_query = inner_sql
 
-            custom_sql = f"""{base_query}
+        TrainingView.create_with_sql(self, sql_query=sql_query, id_column=self.id_column)
 
-            UNION ALL
+    # ---- Internal helpers for training view SQL generation ----
+    @staticmethod
+    def _format_id_for_sql(id_val: Union[str, int]) -> str:
+        """Format an ID value for use in SQL."""
+        return repr(id_val)
 
-            SELECT t.*
-            FROM {self.table} t
-            CROSS JOIN (VALUES {values_clause}) AS n(num)
-            WHERE t.{self.id_column} IN ({format_ids(replicate_ids)})"""
+    def _build_holdout_case(self, holdout_ids: List[Union[str, int]]) -> str:
+        """Build SQL CASE statement for training column based on holdout IDs."""
+        if all(isinstance(id_val, str) for id_val in holdout_ids):
+            formatted_ids = ", ".join(f"'{id_val}'" for id_val in holdout_ids)
         else:
-            # Only exclusions, no UNION needed
-            custom_sql = base_query
-
-        # Create the training view with our custom SQL
-        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+            formatted_ids = ", ".join(map(str, holdout_ids))
+        return f"""CASE
+            WHEN {self.id_column} IN ({formatted_ids}) THEN False
+            ELSE True
+        END AS training"""
+
+    def _build_weight_case(self, weight_dict: Dict[Union[str, int], float], default_weight: float) -> str:
+        """Build SQL CASE statement for sample_weight column."""
+        conditions = [
+            f"WHEN {self.id_column} = {self._format_id_for_sql(id_val)} THEN {weight}"
+            for id_val, weight in weight_dict.items()
+        ]
+        case_body = "\n ".join(conditions)
+        return f"""CASE
+            {case_body}
+            ELSE {default_weight}
+        END AS sample_weight"""
+
+    def _create_training_view(self, sql_query: str):
+        """Create the training view directly from a SQL query."""
+        view_table = f"{self.table}___training"
+        create_view_query = f"CREATE OR REPLACE VIEW {view_table} AS\n{sql_query}"
+        self.data_source.execute_statement(create_view_query)
 
     @classmethod
     def delete_views(cls, table: str, database: str):
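
Taken together, this hunk replaces the deprecated set_training_filter / exclude_ids_from_training / set_training_sampling trio with a single entry point. A minimal usage sketch (assumes fs is an existing FeatureSet whose api-level class proxies these core methods, with id column "auto_id"; the IDs and weights are illustrative):

    # Holdouts only: delegates to set_training_holdouts(), adds a 'training' column
    fs.set_training_config(holdout_ids=[101, 102, 103])

    # Weights only: delegates to set_sample_weights(), adds a 'sample_weight' column
    fs.set_training_config(weight_dict={42: 0.0, 99: 2.0})  # exclude 42, upweight 99

    # Both: one combined view with 'training' and 'sample_weight' columns;
    # rows with weight 0.0 are filtered out unless exclude_zero_weights=False
    fs.set_training_config(
        holdout_ids=[101, 102, 103],
        weight_dict={42: 0.0, 99: 2.0},
    )

    # Inspect the resulting training view (same pattern as the __main__ tests below)
    training_df = fs.view("training").pull_dataframe()
    print(training_df[["training", "sample_weight"]].head())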
@@ -737,20 +716,6 @@ class FeatureSetCore(Artifact):
         """
         return self.data_source.smart_sample()
 
-    def anomalies(self) -> pd.DataFrame:
-        """Get a set of anomalous data from the underlying DataSource
-        Returns:
-            pd.DataFrame: A dataframe of anomalies from the underlying DataSource
-        """
-
-        # FIXME: Mock this for now
-        anom_df = self.sample().copy()
-        anom_df["anomaly_score"] = np.random.rand(anom_df.shape[0])
-        anom_df["cluster"] = np.random.randint(0, 10, anom_df.shape[0])
-        anom_df["x"] = np.random.rand(anom_df.shape[0])
-        anom_df["y"] = np.random.rand(anom_df.shape[0])
-        return anom_df
-
     def value_counts(self) -> dict:
         """Get the value counts for the string columns of the underlying DataSource
 
@@ -915,81 +880,71 @@ if __name__ == "__main__":
     training_data = my_features.get_training_data()
     print(f"Training Data: {training_data.shape}")
 
-    # Test the filter expression functionality
-    print("Setting a filter expression...")
-    my_features.set_training_filter("auto_id < 50 AND length > 65.0")
-    training_data = my_features.get_training_data()
-    print(f"Training Data: {training_data.shape}")
-    print(training_data)
-
-    # Remove training filter
-    print("Removing the filter expression...")
-    my_features.set_training_filter(None)
-    training_data = my_features.get_training_data()
-    print(f"Training Data: {training_data.shape}")
-    print(training_data)
-
-    # Test excluding ids from training
-    print("Excluding ids from training...")
-    my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
-    training_data = my_features.get_training_data()
-    print(f"Training Data: {training_data.shape}")
-    print(training_data)
+    # Test set_sample_weights
+    print("\n--- Testing set_sample_weights ---")
+    sample_ids = df["auto_id"].tolist()[:5]
+    weight_dict = {sample_ids[0]: 0.0, sample_ids[1]: 0.5, sample_ids[2]: 2.0}
+    my_features.set_sample_weights(weight_dict)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    print(f"Training view shape after set_sample_weights: {training_df.shape}")
+    print(f"Columns: {training_df.columns.tolist()}")
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    assert "training" in training_df.columns, "training column missing!"
+    # Verify zero-weight row was excluded
+    assert sample_ids[0] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
+    print("set_sample_weights test passed!")
+
+    # Test set_training_config with both holdouts and weights
+    print("\n--- Testing set_training_config (combined) ---")
+    holdout_ids = [id for id in df["auto_id"] if id >= 100 and id < 120]
+    weight_dict = {sample_ids[3]: 0.0, sample_ids[4]: 3.0}  # exclude one, upweight another
+    my_features.set_training_config(holdout_ids=holdout_ids, weight_dict=weight_dict)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    print(f"Training view shape after set_training_config: {training_df.shape}")
+    print(f"Columns: {training_df.columns.tolist()}")
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    assert "training" in training_df.columns, "training column missing!"
+    # Verify holdout IDs are marked as training=False
+    holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
+    assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!"  # noqa: E712
+    # Verify zero-weight row was excluded
+    assert sample_ids[3] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
+    # Verify upweighted row has correct weight
+    upweight_row = training_df[training_df["auto_id"] == sample_ids[4]]
+    assert upweight_row["sample_weight"].iloc[0] == 3.0, "Upweighted ID should have weight=3.0!"
+    print("set_training_config (combined) test passed!")
+
+    # Test set_training_config with only holdouts (should delegate to set_training_holdouts)
+    print("\n--- Testing set_training_config (holdouts only) ---")
+    my_features.set_training_config(holdout_ids=holdout_ids)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "training" in training_df.columns, "training column missing!"
+    holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
+    assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!"  # noqa: E712
+    print("set_training_config (holdouts only) test passed!")
+
+    # Test set_training_config with only weights (should delegate to set_sample_weights)
+    print("\n--- Testing set_training_config (weights only) ---")
+    my_features.set_training_config(weight_dict={sample_ids[0]: 0.5, sample_ids[1]: 2.0})
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    print("set_training_config (weights only) test passed!")
+
+    # Test set_training_config with neither (should create standard training view)
+    print("\n--- Testing set_training_config (neither) ---")
+    my_features.set_training_config()
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "training" in training_df.columns, "training column missing!"
+    print("set_training_config (neither) test passed!")
+
+    print("\n=== All training config tests passed! ===")
 
     # Now delete the AWS artifacts associated with this Feature Set
     # print("Deleting Workbench Feature Set...")
     # my_features.delete()
     # print("Done")
-
-    # Test set_training_sampling with exclusions and replications
-    print("\n--- Testing set_training_sampling ---")
-    my_features.set_training_filter(None)  # Reset any existing filters
-    original_count = num_rows
-
-    # Get valid IDs from the table
-    all_data = my_features.query(f'SELECT auto_id, length FROM "{table}"')
-    valid_ids = sorted(all_data["auto_id"].tolist())
-    print(f"Valid IDs range from {valid_ids[0]} to {valid_ids[-1]}")
-
-    exclude_list = valid_ids[0:3]  # First 3 IDs
-    replicate_list = valid_ids[10:13]  # IDs at positions 10, 11, 12
-
-    print(f"Original row count: {original_count}")
-    print(f"Excluding IDs: {exclude_list}")
-    print(f"Replicating IDs: {replicate_list}")
-
-    # Test with default replication factor (2x)
-    print("\n--- Testing with replication_factor=2 (default) ---")
-    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-
-    # Verify exclusions
-    for exc_id in exclude_list:
-        count = len(training_data[training_data["auto_id"] == exc_id])
-        print(f"Excluded ID {exc_id} appears {count} times (should be 0)")
-
-    # Verify replications
-    for rep_id in replicate_list:
-        count = len(training_data[training_data["auto_id"] == rep_id])
-        print(f"Replicated ID {rep_id} appears {count} times (should be 2)")
-
-    # Test with replication factor of 5
-    print("\n--- Testing with replication_factor=5 ---")
-    replicate_list_5x = [20, 21]
-    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list_5x, replication_factor=5)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-
-    # Verify 5x replication
-    for rep_id in replicate_list_5x:
-        count = len(training_data[training_data["auto_id"] == rep_id])
-        print(f"Replicated ID {rep_id} appears {count} times (should be 5)")
-
-    # Test with large replication list (simulate 100 IDs)
-    print("\n--- Testing with large ID list (100 IDs) ---")
-    large_replicate_list = list(range(30, 130))  # 100 IDs
-    my_features.set_training_sampling(replicate_ids=large_replicate_list, replication_factor=3)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-    print(f"Expected extra rows: {len(large_replicate_list) * 3}")
workbench/core/transforms/features_to_model/features_to_model.py
@@ -227,20 +227,14 @@ class FeaturesToModel(Transform):
                 self.log.critical(msg)
                 raise ValueError(msg)
 
-            # Dynamically create the metric definitions
+            # Dynamically create the metric definitions (per-class precision/recall/f1/support)
+            # Note: Confusion matrix metrics are skipped to stay under SageMaker's 40 metric limit
             metrics = ["precision", "recall", "f1", "support"]
             metric_definitions = []
             for t in self.class_labels:
                 for m in metrics:
                     metric_definitions.append({"Name": f"Metrics:{t}:{m}", "Regex": f"Metrics:{t}:{m} ([0-9.]+)"})
 
-            # Add the confusion matrix metrics
-            for row in self.class_labels:
-                for col in self.class_labels:
-                    metric_definitions.append(
-                        {"Name": f"ConfusionMatrix:{row}:{col}", "Regex": f"ConfusionMatrix:{row}:{col} ([0-9.]+)"}
-                    )
-
         # If the model type is UNKNOWN, our metric_definitions will be empty
         else:
             self.log.important(f"ModelType is {self.model_type}, skipping metric_definitions...")
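
Dropping the confusion-matrix definitions keeps classifiers under the 40-metric-definition cap cited in the new comment: per-class metrics cost 4n definitions and the n×n confusion matrix adds n² more, which overflows the cap at five classes. A quick back-of-the-envelope check (illustrative, not workbench code):

    # Metric-definition count: 4 per-class metrics, optionally plus an n x n confusion matrix
    for n_classes in (3, 5, 10):
        base = 4 * n_classes
        with_cm = base + n_classes**2
        print(f"{n_classes} classes: {base} definitions, {with_cm} with confusion matrix")
    # 3 classes: 12 definitions, 21 with confusion matrix
    # 5 classes: 20 definitions, 45 with confusion matrix  (over the 40 limit)
    # 10 classes: 40 definitions, 140 with confusion matrix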
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
@@ -148,6 +148,7 @@ class ModelToEndpoint(Transform):
                 deserializer=CSVDeserializer(),
                 data_capture_config=data_capture_config,
                 tags=aws_tags,
+                container_startup_health_check_timeout=300,
             )
         except ClientError as e:
             # Check if this is the "endpoint config already exists" error
@@ -164,6 +165,7 @@ class ModelToEndpoint(Transform):
                     deserializer=CSVDeserializer(),
                     data_capture_config=data_capture_config,
                     tags=aws_tags,
+                    container_startup_health_check_timeout=300,
                 )
             else:
                 raise
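
Both the create path and the already-exists retry path now pass container_startup_health_check_timeout=300 to deploy(), giving slow-starting inference containers an explicit five-minute window to pass their startup health check. A minimal sketch of the underlying SageMaker SDK call (the image/model/role values are placeholders, not workbench's actual arguments):

    from sagemaker.model import Model
    from sagemaker.serializers import CSVSerializer
    from sagemaker.deserializers import CSVDeserializer

    # Placeholder model for illustration only
    model = Model(image_uri="<image-uri>", model_data="s3://<bucket>/model.tar.gz", role="<role-arn>")
    model.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.large",
        serializer=CSVSerializer(),
        deserializer=CSVDeserializer(),
        container_startup_health_check_timeout=300,  # allow up to 5 minutes for container startup
    )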
workbench/model_script_utils/model_script_utils.py (the same change is mirrored in the model_scripts/*/model_script_utils.py copies)
@@ -148,12 +148,16 @@ def convert_categorical_types(
 def decompress_features(
     df: pd.DataFrame, features: list[str], compressed_features: list[str]
 ) -> tuple[pd.DataFrame, list[str]]:
-    """Decompress bitstring features into individual bit columns.
+    """Decompress compressed features (bitstrings or count vectors) into individual columns.
+
+    Supports two formats (auto-detected):
+    - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
+    - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)
 
     Args:
         df: The features DataFrame
         features: Full list of feature names
-        compressed_features: List of feature names to decompress (bitstrings)
+        compressed_features: List of feature names to decompress
 
     Returns:
         Tuple of (DataFrame with decompressed features, updated feature list)
@@ -178,18 +182,18 @@ def decompress_features(
         # Remove the feature from the list to avoid duplication
         decompressed_features.remove(feature)
 
-        # Handle all compressed features as bitstrings
-        bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-        prefix = feature[:3]
+        # Auto-detect format and parse: comma-separated counts or bitstring
+        sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
+        parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
+        feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
 
-        # Create all new columns at once - avoids fragmentation
-        new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-        new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
+        # Create new columns with prefix from feature name
+        prefix = feature[:3]
+        new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
+        new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)
 
-        # Add to features list
+        # Update features list and dataframe
         decompressed_features.extend(new_col_names)
-
-        # Drop original column and concatenate new ones
         df = df.drop(columns=[feature])
         df = pd.concat([df, new_df], axis=1)
 
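
The auto-detection is purely syntactic: a comma in the first non-null value selects the count-vector parser, otherwise each character is treated as one bit. A self-contained sketch of that behavior on toy data (mirrors the logic above rather than importing workbench):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "morgan_fp": ["1011", "0010"],       # bitstring -> four 0/1 columns
        "count_fp": ["0,3,0,1", "5,0,2,0"],  # comma-separated counts -> four 0-255 columns
    })

    for feature in ["morgan_fp", "count_fp"]:
        sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
        parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
        matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
        print(feature, matrix.tolist())
    # morgan_fp [[1, 0, 1, 1], [0, 0, 1, 0]]
    # count_fp [[0, 3, 0, 1], [5, 0, 2, 0]]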