workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@ from datetime import datetime, timezone
7
7
  import botocore.exceptions
8
8
  import pandas as pd
9
9
  import awswrangler as wr
10
- import numpy as np
11
10
 
12
11
  from sagemaker.feature_store.feature_group import FeatureGroup
13
12
  from sagemaker.feature_store.feature_store import FeatureStore
@@ -17,7 +16,7 @@ from workbench.core.artifacts.artifact import Artifact
17
16
  from workbench.core.artifacts.data_source_factory import DataSourceFactory
18
17
  from workbench.core.artifacts.athena_source import AthenaSource
19
18
 
20
- from typing import TYPE_CHECKING
19
+ from typing import TYPE_CHECKING, List, Dict, Union
21
20
 
22
21
  from workbench.utils.aws_utils import aws_throttle
23
22
 
@@ -194,24 +193,24 @@ class FeatureSetCore(Artifact):
194
193
 
195
194
  return View(self, view_name)
196
195
 
197
- def set_display_columns(self, diplay_columns: list[str]):
196
+ def set_display_columns(self, display_columns: list[str]):
198
197
  """Set the display columns for this Data Source
199
198
 
200
199
  Args:
201
- diplay_columns (list[str]): The display columns for this Data Source
200
+ display_columns (list[str]): The display columns for this Data Source
202
201
  """
203
202
  # Check mismatch of display columns to computation columns
204
203
  c_view = self.view("computation")
205
204
  computation_columns = c_view.columns
206
- mismatch_columns = [col for col in diplay_columns if col not in computation_columns]
205
+ mismatch_columns = [col for col in display_columns if col not in computation_columns]
207
206
  if mismatch_columns:
208
207
  self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")
209
208
 
210
- self.log.important(f"Setting Display Columns...{diplay_columns}")
209
+ self.log.important(f"Setting Display Columns...{display_columns}")
211
210
  from workbench.core.views import DisplayView
212
211
 
213
212
  # Create a NEW display view
214
- DisplayView.create(self, source_table=c_view.table, column_list=diplay_columns)
213
+ DisplayView.create(self, source_table=c_view.table, column_list=display_columns)
215
214
 
216
215
  def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
217
216
  """Set the computation columns for this FeatureSet
@@ -246,7 +245,7 @@ class FeatureSetCore(Artifact):
246
245
 
247
246
  # Set the compressed features in our FeatureSet metadata
248
247
  self.log.important(f"Setting Compressed Columns...{compressed_columns}")
249
- self.upsert_workbench_meta({"comp_features": compressed_columns})
248
+ self.upsert_workbench_meta({"compressed_features": compressed_columns})
250
249
 
251
250
  def get_compressed_features(self) -> list[str]:
252
251
  """Get the compressed features for this FeatureSet
@@ -255,7 +254,7 @@ class FeatureSetCore(Artifact):
255
254
  list[str]: The compressed columns for this FeatureSet
256
255
  """
257
256
  # Get the compressed features from our FeatureSet metadata
258
- return self.workbench_meta().get("comp_features", [])
257
+ return self.workbench_meta().get("compressed_features", [])
259
258
 
260
259
  def num_columns(self) -> int:
261
260
  """Return the number of columns of the Feature Set"""
@@ -482,18 +481,6 @@ class FeatureSetCore(Artifact):
482
481
  time.sleep(1)
483
482
  cls.log.info(f"FeatureSet {feature_group.name} successfully deleted")
484
483
 
485
- def set_training_holdouts(self, holdout_ids: list[str]):
486
- """Set the hold out ids for the training view for this FeatureSet
487
-
488
- Args:
489
- holdout_ids (list[str]): The list of holdout ids.
490
- """
491
- from workbench.core.views import TrainingView
492
-
493
- # Create a NEW training view
494
- self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
495
- TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
496
-
497
484
  def get_training_holdouts(self) -> list[str]:
498
485
  """Get the hold out ids for the training view for this FeatureSet
499
486
 
@@ -509,6 +496,177 @@ class FeatureSetCore(Artifact):
509
496
  ].tolist()
510
497
  return hold_out_ids
511
498
 
499
+ # ---- Public methods for training configuration ----
500
+ def set_training_config(
501
+ self,
502
+ holdout_ids: List[Union[str, int]] = None,
503
+ weight_dict: Dict[Union[str, int], float] = None,
504
+ default_weight: float = 1.0,
505
+ exclude_zero_weights: bool = True,
506
+ ):
507
+ """Configure training view with holdout IDs and/or sample weights.
508
+
509
+ This method creates a training view that can include both:
510
+ - A 'training' column (True/False) based on holdout IDs
511
+ - A 'sample_weight' column for weighted training
512
+
513
+ Args:
514
+ holdout_ids: List of IDs to mark as training=False (validation/holdout set)
515
+ weight_dict: Mapping of ID to sample weight
516
+ - weight > 1.0: oversample/emphasize
517
+ - weight = 1.0: normal (default)
518
+ - 0 < weight < 1.0: downweight/de-emphasize
519
+ - weight = 0.0: exclude from training (filtered out if exclude_zero_weights=True)
520
+ default_weight: Weight for IDs not in weight_dict (default: 1.0)
521
+ exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
522
+
523
+ Example:
524
+ # Temporal split with sample weights
525
+ fs.set_training_config(
526
+ holdout_ids=temporal_hold_out_ids, # IDs after cutoff date
527
+ weight_dict={'compound_42': 0.0, 'compound_99': 2.0}, # exclude/upweight
528
+ )
529
+ """
530
+ from workbench.core.views.training_view import TrainingView
531
+
532
+ # If neither is provided, create a standard training view
533
+ if not holdout_ids and not weight_dict:
534
+ self.log.important("No holdouts or weights specified, creating standard training view")
535
+ TrainingView.create(self, id_column=self.id_column)
536
+ return
537
+
538
+ # If only holdout_ids, delegate to set_training_holdouts
539
+ if holdout_ids and not weight_dict:
540
+ self.set_training_holdouts(holdout_ids)
541
+ return
542
+
543
+ # If only weight_dict, delegate to set_sample_weights
544
+ if weight_dict and not holdout_ids:
545
+ self.set_sample_weights(weight_dict, default_weight, exclude_zero_weights)
546
+ return
547
+
548
+ # Both holdout_ids and weight_dict provided - build combined view
549
+ self.log.important(f"Setting training config: {len(holdout_ids)} holdouts, {len(weight_dict)} weights")
550
+
551
+ # Get column list (excluding AWS-generated columns)
552
+ from workbench.core.views.view_utils import get_column_list
553
+
554
+ aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
555
+ source_columns = get_column_list(self.data_source, self.table)
556
+ column_list = [col for col in source_columns if col not in aws_cols]
557
+ sql_columns = ", ".join([f'"{column}"' for column in column_list])
558
+
559
+ # Build inner query with both columns
560
+ training_case = self._build_holdout_case(holdout_ids)
561
+ weight_case = self._build_weight_case(weight_dict, default_weight)
562
+ inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
563
+
564
+ # Optionally filter out zero weights
565
+ if exclude_zero_weights:
566
+ zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
567
+ if zero_count:
568
+ self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
569
+ sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
570
+ else:
571
+ sql_query = inner_sql
572
+
573
+ self._create_training_view(sql_query)
574
+
575
+ def set_training_holdouts(self, holdout_ids: list[str]):
576
+ """Set the hold out ids for the training view for this FeatureSet
577
+
578
+ Args:
579
+ holdout_ids (list[str]): The list of holdout ids.
580
+ """
581
+ from workbench.core.views import TrainingView
582
+
583
+ self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
584
+ TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
585
+
586
+ def set_sample_weights(
587
+ self,
588
+ weight_dict: Dict[Union[str, int], float],
589
+ default_weight: float = 1.0,
590
+ exclude_zero_weights: bool = True,
591
+ ):
592
+ """Configure training view with sample weights for each ID.
593
+
594
+ Args:
595
+ weight_dict: Mapping of ID to sample weight
596
+ - weight > 1.0: oversample/emphasize
597
+ - weight = 1.0: normal (default)
598
+ - 0 < weight < 1.0: downweight/de-emphasize
599
+ - weight = 0.0: exclude from training
600
+ default_weight: Weight for IDs not in weight_dict (default: 1.0)
601
+ exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
602
+
603
+ Example:
604
+ weights = {
605
+ 'compound_42': 3.0, # oversample 3x
606
+ 'compound_99': 0.1, # noisy, downweight
607
+ 'compound_123': 0.0, # exclude from training
608
+ }
609
+ fs.set_sample_weights(weights) # zeros automatically excluded
610
+ fs.set_sample_weights(weights, exclude_zero_weights=False) # keep zeros
611
+ """
612
+ from workbench.core.views import TrainingView
613
+
614
+ if not weight_dict:
615
+ self.log.important("Empty weight_dict, creating standard training view")
616
+ TrainingView.create(self, id_column=self.id_column)
617
+ return
618
+
619
+ self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
620
+
621
+ # Build inner query with sample weights
622
+ weight_case = self._build_weight_case(weight_dict, default_weight)
623
+ inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
624
+
625
+ # Optionally filter out zero weights
626
+ if exclude_zero_weights:
627
+ zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
628
+ self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
629
+ sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
630
+ else:
631
+ sql_query = inner_sql
632
+
633
+ TrainingView.create_with_sql(self, sql_query=sql_query, id_column=self.id_column)
634
+
635
+ # ---- Internal helpers for training view SQL generation ----
636
+ @staticmethod
637
+ def _format_id_for_sql(id_val: Union[str, int]) -> str:
638
+ """Format an ID value for use in SQL."""
639
+ return repr(id_val)
640
+
641
+ def _build_holdout_case(self, holdout_ids: List[Union[str, int]]) -> str:
642
+ """Build SQL CASE statement for training column based on holdout IDs."""
643
+ if all(isinstance(id_val, str) for id_val in holdout_ids):
644
+ formatted_ids = ", ".join(f"'{id_val}'" for id_val in holdout_ids)
645
+ else:
646
+ formatted_ids = ", ".join(map(str, holdout_ids))
647
+ return f"""CASE
648
+ WHEN {self.id_column} IN ({formatted_ids}) THEN False
649
+ ELSE True
650
+ END AS training"""
651
+
652
+ def _build_weight_case(self, weight_dict: Dict[Union[str, int], float], default_weight: float) -> str:
653
+ """Build SQL CASE statement for sample_weight column."""
654
+ conditions = [
655
+ f"WHEN {self.id_column} = {self._format_id_for_sql(id_val)} THEN {weight}"
656
+ for id_val, weight in weight_dict.items()
657
+ ]
658
+ case_body = "\n ".join(conditions)
659
+ return f"""CASE
660
+ {case_body}
661
+ ELSE {default_weight}
662
+ END AS sample_weight"""
663
+
664
+ def _create_training_view(self, sql_query: str):
665
+ """Create the training view directly from a SQL query."""
666
+ view_table = f"{self.table}___training"
667
+ create_view_query = f"CREATE OR REPLACE VIEW {view_table} AS\n{sql_query}"
668
+ self.data_source.execute_statement(create_view_query)
669
+
512
670
  @classmethod
513
671
  def delete_views(cls, table: str, database: str):
514
672
  """Delete any views associated with this FeatureSet
@@ -558,20 +716,6 @@ class FeatureSetCore(Artifact):
558
716
  """
559
717
  return self.data_source.smart_sample()
560
718
 
561
- def anomalies(self) -> pd.DataFrame:
562
- """Get a set of anomalous data from the underlying DataSource
563
- Returns:
564
- pd.DataFrame: A dataframe of anomalies from the underlying DataSource
565
- """
566
-
567
- # FIXME: Mock this for now
568
- anom_df = self.sample().copy()
569
- anom_df["anomaly_score"] = np.random.rand(anom_df.shape[0])
570
- anom_df["cluster"] = np.random.randint(0, 10, anom_df.shape[0])
571
- anom_df["x"] = np.random.rand(anom_df.shape[0])
572
- anom_df["y"] = np.random.rand(anom_df.shape[0])
573
- return anom_df
574
-
575
719
  def value_counts(self) -> dict:
576
720
  """Get the value counts for the string columns of the underlying DataSource
577
721
 
@@ -667,7 +811,7 @@ if __name__ == "__main__":
667
811
  pd.set_option("display.width", 1000)
668
812
 
669
813
  # Grab a FeatureSet object and pull some information from it
670
- my_features = LocalFeatureSetCore("test_features")
814
+ my_features = LocalFeatureSetCore("abalone_features")
671
815
  if not my_features.exists():
672
816
  print("FeatureSet not found!")
673
817
  sys.exit(1)
@@ -707,7 +851,7 @@ if __name__ == "__main__":
707
851
 
708
852
  # Test getting the holdout ids
709
853
  print("Getting the hold out ids...")
710
- holdout_ids = my_features.get_training_holdouts("id")
854
+ holdout_ids = my_features.get_training_holdouts()
711
855
  print(f"Holdout IDs: {holdout_ids}")
712
856
 
713
857
  # Get a sample of the data
@@ -727,18 +871,78 @@ if __name__ == "__main__":
727
871
  # Set the holdout ids for the training view
728
872
  print("Setting hold out ids...")
729
873
  table = my_features.view("training").table
730
- df = my_features.query(f'SELECT id, name FROM "{table}"')
731
- my_holdout_ids = [id for id in df["id"] if id < 20]
732
- my_features.set_training_holdouts("id", my_holdout_ids)
733
-
734
- # Test the hold out set functionality with strings
735
- print("Setting hold out ids (strings)...")
736
- my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
737
- my_features.set_training_holdouts("name", my_holdout_ids)
874
+ df = my_features.query(f'SELECT auto_id, length FROM "{table}"')
875
+ my_holdout_ids = [id for id in df["auto_id"] if id < 20]
876
+ my_features.set_training_holdouts(my_holdout_ids)
738
877
 
739
878
  # Get the training data
740
879
  print("Getting the training data...")
741
880
  training_data = my_features.get_training_data()
881
+ print(f"Training Data: {training_data.shape}")
882
+
883
+ # Test set_sample_weights
884
+ print("\n--- Testing set_sample_weights ---")
885
+ sample_ids = df["auto_id"].tolist()[:5]
886
+ weight_dict = {sample_ids[0]: 0.0, sample_ids[1]: 0.5, sample_ids[2]: 2.0}
887
+ my_features.set_sample_weights(weight_dict)
888
+ training_view = my_features.view("training")
889
+ training_df = training_view.pull_dataframe()
890
+ print(f"Training view shape after set_sample_weights: {training_df.shape}")
891
+ print(f"Columns: {training_df.columns.tolist()}")
892
+ assert "sample_weight" in training_df.columns, "sample_weight column missing!"
893
+ assert "training" in training_df.columns, "training column missing!"
894
+ # Verify zero-weight row was excluded
895
+ assert sample_ids[0] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
896
+ print("set_sample_weights test passed!")
897
+
898
+ # Test set_training_config with both holdouts and weights
899
+ print("\n--- Testing set_training_config (combined) ---")
900
+ holdout_ids = [id for id in df["auto_id"] if id >= 100 and id < 120]
901
+ weight_dict = {sample_ids[3]: 0.0, sample_ids[4]: 3.0} # exclude one, upweight another
902
+ my_features.set_training_config(holdout_ids=holdout_ids, weight_dict=weight_dict)
903
+ training_view = my_features.view("training")
904
+ training_df = training_view.pull_dataframe()
905
+ print(f"Training view shape after set_training_config: {training_df.shape}")
906
+ print(f"Columns: {training_df.columns.tolist()}")
907
+ assert "sample_weight" in training_df.columns, "sample_weight column missing!"
908
+ assert "training" in training_df.columns, "training column missing!"
909
+ # Verify holdout IDs are marked as training=False
910
+ holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
911
+ assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!" # noqa: E712
912
+ # Verify zero-weight row was excluded
913
+ assert sample_ids[3] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
914
+ # Verify upweighted row has correct weight
915
+ upweight_row = training_df[training_df["auto_id"] == sample_ids[4]]
916
+ assert upweight_row["sample_weight"].iloc[0] == 3.0, "Upweighted ID should have weight=3.0!"
917
+ print("set_training_config (combined) test passed!")
918
+
919
+ # Test set_training_config with only holdouts (should delegate to set_training_holdouts)
920
+ print("\n--- Testing set_training_config (holdouts only) ---")
921
+ my_features.set_training_config(holdout_ids=holdout_ids)
922
+ training_view = my_features.view("training")
923
+ training_df = training_view.pull_dataframe()
924
+ assert "training" in training_df.columns, "training column missing!"
925
+ holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
926
+ assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!" # noqa: E712
927
+ print("set_training_config (holdouts only) test passed!")
928
+
929
+ # Test set_training_config with only weights (should delegate to set_sample_weights)
930
+ print("\n--- Testing set_training_config (weights only) ---")
931
+ my_features.set_training_config(weight_dict={sample_ids[0]: 0.5, sample_ids[1]: 2.0})
932
+ training_view = my_features.view("training")
933
+ training_df = training_view.pull_dataframe()
934
+ assert "sample_weight" in training_df.columns, "sample_weight column missing!"
935
+ print("set_training_config (weights only) test passed!")
936
+
937
+ # Test set_training_config with neither (should create standard training view)
938
+ print("\n--- Testing set_training_config (neither) ---")
939
+ my_features.set_training_config()
940
+ training_view = my_features.view("training")
941
+ training_df = training_view.pull_dataframe()
942
+ assert "training" in training_df.columns, "training column missing!"
943
+ print("set_training_config (neither) test passed!")
944
+
945
+ print("\n=== All training config tests passed! ===")
742
946
 
743
947
  # Now delete the AWS artifacts associated with this Feature Set
744
948
  # print("Deleting Workbench Feature Set...")