workbench: workbench-0.8.162-py3-none-any.whl → workbench-0.8.220-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. More details are available via the link on the original diff page.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
|
@@ -7,7 +7,6 @@ from datetime import datetime, timezone
|
|
|
7
7
|
import botocore.exceptions
|
|
8
8
|
import pandas as pd
|
|
9
9
|
import awswrangler as wr
|
|
10
|
-
import numpy as np
|
|
11
10
|
|
|
12
11
|
from sagemaker.feature_store.feature_group import FeatureGroup
|
|
13
12
|
from sagemaker.feature_store.feature_store import FeatureStore
|
|
@@ -17,7 +16,7 @@ from workbench.core.artifacts.artifact import Artifact
|
|
|
17
16
|
from workbench.core.artifacts.data_source_factory import DataSourceFactory
|
|
18
17
|
from workbench.core.artifacts.athena_source import AthenaSource
|
|
19
18
|
|
|
20
|
-
from typing import TYPE_CHECKING
|
|
19
|
+
from typing import TYPE_CHECKING, List, Dict, Union
|
|
21
20
|
|
|
22
21
|
from workbench.utils.aws_utils import aws_throttle
|
|
23
22
|
|
|
@@ -194,24 +193,24 @@ class FeatureSetCore(Artifact):
|
|
|
194
193
|
|
|
195
194
|
return View(self, view_name)
|
|
196
195
|
|
|
197
|
-
def set_display_columns(self,
|
|
196
|
+
def set_display_columns(self, display_columns: list[str]):
|
|
198
197
|
"""Set the display columns for this Data Source
|
|
199
198
|
|
|
200
199
|
Args:
|
|
201
|
-
|
|
200
|
+
display_columns (list[str]): The display columns for this Data Source
|
|
202
201
|
"""
|
|
203
202
|
# Check mismatch of display columns to computation columns
|
|
204
203
|
c_view = self.view("computation")
|
|
205
204
|
computation_columns = c_view.columns
|
|
206
|
-
mismatch_columns = [col for col in
|
|
205
|
+
mismatch_columns = [col for col in display_columns if col not in computation_columns]
|
|
207
206
|
if mismatch_columns:
|
|
208
207
|
self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")
|
|
209
208
|
|
|
210
|
-
self.log.important(f"Setting Display Columns...{
|
|
209
|
+
self.log.important(f"Setting Display Columns...{display_columns}")
|
|
211
210
|
from workbench.core.views import DisplayView
|
|
212
211
|
|
|
213
212
|
# Create a NEW display view
|
|
214
|
-
DisplayView.create(self, source_table=c_view.table, column_list=
|
|
213
|
+
DisplayView.create(self, source_table=c_view.table, column_list=display_columns)
|
|
215
214
|
|
|
216
215
|
def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
|
|
217
216
|
"""Set the computation columns for this FeatureSet
|
|
@@ -246,7 +245,7 @@ class FeatureSetCore(Artifact):
|
|
|
246
245
|
|
|
247
246
|
# Set the compressed features in our FeatureSet metadata
|
|
248
247
|
self.log.important(f"Setting Compressed Columns...{compressed_columns}")
|
|
249
|
-
self.upsert_workbench_meta({"
|
|
248
|
+
self.upsert_workbench_meta({"compressed_features": compressed_columns})
|
|
250
249
|
|
|
251
250
|
def get_compressed_features(self) -> list[str]:
|
|
252
251
|
"""Get the compressed features for this FeatureSet
|
|
@@ -255,7 +254,7 @@ class FeatureSetCore(Artifact):
|
|
|
255
254
|
list[str]: The compressed columns for this FeatureSet
|
|
256
255
|
"""
|
|
257
256
|
# Get the compressed features from our FeatureSet metadata
|
|
258
|
-
return self.workbench_meta().get("
|
|
257
|
+
return self.workbench_meta().get("compressed_features", [])
|
|
259
258
|
|
|
260
259
|
def num_columns(self) -> int:
|
|
261
260
|
"""Return the number of columns of the Feature Set"""
|
|
@@ -482,18 +481,6 @@ class FeatureSetCore(Artifact):
|
|
|
482
481
|
time.sleep(1)
|
|
483
482
|
cls.log.info(f"FeatureSet {feature_group.name} successfully deleted")
|
|
484
483
|
|
|
485
|
-
def set_training_holdouts(self, holdout_ids: list[str]):
|
|
486
|
-
"""Set the hold out ids for the training view for this FeatureSet
|
|
487
|
-
|
|
488
|
-
Args:
|
|
489
|
-
holdout_ids (list[str]): The list of holdout ids.
|
|
490
|
-
"""
|
|
491
|
-
from workbench.core.views import TrainingView
|
|
492
|
-
|
|
493
|
-
# Create a NEW training view
|
|
494
|
-
self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
|
|
495
|
-
TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
|
|
496
|
-
|
|
497
484
|
def get_training_holdouts(self) -> list[str]:
|
|
498
485
|
"""Get the hold out ids for the training view for this FeatureSet
|
|
499
486
|
|
|
@@ -509,6 +496,177 @@ class FeatureSetCore(Artifact):
|
|
|
509
496
|
].tolist()
|
|
510
497
|
return hold_out_ids
|
|
511
498
|
|
|
499
|
+
# ---- Public methods for training configuration ----
|
|
500
|
+
def set_training_config(
|
|
501
|
+
self,
|
|
502
|
+
holdout_ids: List[Union[str, int]] = None,
|
|
503
|
+
weight_dict: Dict[Union[str, int], float] = None,
|
|
504
|
+
default_weight: float = 1.0,
|
|
505
|
+
exclude_zero_weights: bool = True,
|
|
506
|
+
):
|
|
507
|
+
"""Configure training view with holdout IDs and/or sample weights.
|
|
508
|
+
|
|
509
|
+
This method creates a training view that can include both:
|
|
510
|
+
- A 'training' column (True/False) based on holdout IDs
|
|
511
|
+
- A 'sample_weight' column for weighted training
|
|
512
|
+
|
|
513
|
+
Args:
|
|
514
|
+
holdout_ids: List of IDs to mark as training=False (validation/holdout set)
|
|
515
|
+
weight_dict: Mapping of ID to sample weight
|
|
516
|
+
- weight > 1.0: oversample/emphasize
|
|
517
|
+
- weight = 1.0: normal (default)
|
|
518
|
+
- 0 < weight < 1.0: downweight/de-emphasize
|
|
519
|
+
- weight = 0.0: exclude from training (filtered out if exclude_zero_weights=True)
|
|
520
|
+
default_weight: Weight for IDs not in weight_dict (default: 1.0)
|
|
521
|
+
exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
|
|
522
|
+
|
|
523
|
+
Example:
|
|
524
|
+
# Temporal split with sample weights
|
|
525
|
+
fs.set_training_config(
|
|
526
|
+
holdout_ids=temporal_hold_out_ids, # IDs after cutoff date
|
|
527
|
+
weight_dict={'compound_42': 0.0, 'compound_99': 2.0}, # exclude/upweight
|
|
528
|
+
)
|
|
529
|
+
"""
|
|
530
|
+
from workbench.core.views.training_view import TrainingView
|
|
531
|
+
|
|
532
|
+
# If neither is provided, create a standard training view
|
|
533
|
+
if not holdout_ids and not weight_dict:
|
|
534
|
+
self.log.important("No holdouts or weights specified, creating standard training view")
|
|
535
|
+
TrainingView.create(self, id_column=self.id_column)
|
|
536
|
+
return
|
|
537
|
+
|
|
538
|
+
# If only holdout_ids, delegate to set_training_holdouts
|
|
539
|
+
if holdout_ids and not weight_dict:
|
|
540
|
+
self.set_training_holdouts(holdout_ids)
|
|
541
|
+
return
|
|
542
|
+
|
|
543
|
+
# If only weight_dict, delegate to set_sample_weights
|
|
544
|
+
if weight_dict and not holdout_ids:
|
|
545
|
+
self.set_sample_weights(weight_dict, default_weight, exclude_zero_weights)
|
|
546
|
+
return
|
|
547
|
+
|
|
548
|
+
# Both holdout_ids and weight_dict provided - build combined view
|
|
549
|
+
self.log.important(f"Setting training config: {len(holdout_ids)} holdouts, {len(weight_dict)} weights")
|
|
550
|
+
|
|
551
|
+
# Get column list (excluding AWS-generated columns)
|
|
552
|
+
from workbench.core.views.view_utils import get_column_list
|
|
553
|
+
|
|
554
|
+
aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
|
|
555
|
+
source_columns = get_column_list(self.data_source, self.table)
|
|
556
|
+
column_list = [col for col in source_columns if col not in aws_cols]
|
|
557
|
+
sql_columns = ", ".join([f'"{column}"' for column in column_list])
|
|
558
|
+
|
|
559
|
+
# Build inner query with both columns
|
|
560
|
+
training_case = self._build_holdout_case(holdout_ids)
|
|
561
|
+
weight_case = self._build_weight_case(weight_dict, default_weight)
|
|
562
|
+
inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
|
|
563
|
+
|
|
564
|
+
# Optionally filter out zero weights
|
|
565
|
+
if exclude_zero_weights:
|
|
566
|
+
zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
|
|
567
|
+
if zero_count:
|
|
568
|
+
self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
|
|
569
|
+
sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
|
|
570
|
+
else:
|
|
571
|
+
sql_query = inner_sql
|
|
572
|
+
|
|
573
|
+
self._create_training_view(sql_query)
|
|
574
|
+
|
|
575
|
+
def set_training_holdouts(self, holdout_ids: list[str]):
|
|
576
|
+
"""Set the hold out ids for the training view for this FeatureSet
|
|
577
|
+
|
|
578
|
+
Args:
|
|
579
|
+
holdout_ids (list[str]): The list of holdout ids.
|
|
580
|
+
"""
|
|
581
|
+
from workbench.core.views import TrainingView
|
|
582
|
+
|
|
583
|
+
self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
|
|
584
|
+
TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
|
|
585
|
+
|
|
586
|
+
def set_sample_weights(
|
|
587
|
+
self,
|
|
588
|
+
weight_dict: Dict[Union[str, int], float],
|
|
589
|
+
default_weight: float = 1.0,
|
|
590
|
+
exclude_zero_weights: bool = True,
|
|
591
|
+
):
|
|
592
|
+
"""Configure training view with sample weights for each ID.
|
|
593
|
+
|
|
594
|
+
Args:
|
|
595
|
+
weight_dict: Mapping of ID to sample weight
|
|
596
|
+
- weight > 1.0: oversample/emphasize
|
|
597
|
+
- weight = 1.0: normal (default)
|
|
598
|
+
- 0 < weight < 1.0: downweight/de-emphasize
|
|
599
|
+
- weight = 0.0: exclude from training
|
|
600
|
+
default_weight: Weight for IDs not in weight_dict (default: 1.0)
|
|
601
|
+
exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
|
|
602
|
+
|
|
603
|
+
Example:
|
|
604
|
+
weights = {
|
|
605
|
+
'compound_42': 3.0, # oversample 3x
|
|
606
|
+
'compound_99': 0.1, # noisy, downweight
|
|
607
|
+
'compound_123': 0.0, # exclude from training
|
|
608
|
+
}
|
|
609
|
+
fs.set_sample_weights(weights) # zeros automatically excluded
|
|
610
|
+
fs.set_sample_weights(weights, exclude_zero_weights=False) # keep zeros
|
|
611
|
+
"""
|
|
612
|
+
from workbench.core.views import TrainingView
|
|
613
|
+
|
|
614
|
+
if not weight_dict:
|
|
615
|
+
self.log.important("Empty weight_dict, creating standard training view")
|
|
616
|
+
TrainingView.create(self, id_column=self.id_column)
|
|
617
|
+
return
|
|
618
|
+
|
|
619
|
+
self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
|
|
620
|
+
|
|
621
|
+
# Build inner query with sample weights
|
|
622
|
+
weight_case = self._build_weight_case(weight_dict, default_weight)
|
|
623
|
+
inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
|
|
624
|
+
|
|
625
|
+
# Optionally filter out zero weights
|
|
626
|
+
if exclude_zero_weights:
|
|
627
|
+
zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
|
|
628
|
+
self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
|
|
629
|
+
sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
|
|
630
|
+
else:
|
|
631
|
+
sql_query = inner_sql
|
|
632
|
+
|
|
633
|
+
TrainingView.create_with_sql(self, sql_query=sql_query, id_column=self.id_column)
|
|
634
|
+
|
|
635
|
+
# ---- Internal helpers for training view SQL generation ----
|
|
636
|
+
@staticmethod
|
|
637
|
+
def _format_id_for_sql(id_val: Union[str, int]) -> str:
|
|
638
|
+
"""Format an ID value for use in SQL."""
|
|
639
|
+
return repr(id_val)
|
|
640
|
+
|
|
641
|
+
def _build_holdout_case(self, holdout_ids: List[Union[str, int]]) -> str:
|
|
642
|
+
"""Build SQL CASE statement for training column based on holdout IDs."""
|
|
643
|
+
if all(isinstance(id_val, str) for id_val in holdout_ids):
|
|
644
|
+
formatted_ids = ", ".join(f"'{id_val}'" for id_val in holdout_ids)
|
|
645
|
+
else:
|
|
646
|
+
formatted_ids = ", ".join(map(str, holdout_ids))
|
|
647
|
+
return f"""CASE
|
|
648
|
+
WHEN {self.id_column} IN ({formatted_ids}) THEN False
|
|
649
|
+
ELSE True
|
|
650
|
+
END AS training"""
|
|
651
|
+
|
|
652
|
+
def _build_weight_case(self, weight_dict: Dict[Union[str, int], float], default_weight: float) -> str:
|
|
653
|
+
"""Build SQL CASE statement for sample_weight column."""
|
|
654
|
+
conditions = [
|
|
655
|
+
f"WHEN {self.id_column} = {self._format_id_for_sql(id_val)} THEN {weight}"
|
|
656
|
+
for id_val, weight in weight_dict.items()
|
|
657
|
+
]
|
|
658
|
+
case_body = "\n ".join(conditions)
|
|
659
|
+
return f"""CASE
|
|
660
|
+
{case_body}
|
|
661
|
+
ELSE {default_weight}
|
|
662
|
+
END AS sample_weight"""
|
|
663
|
+
|
|
664
|
+
def _create_training_view(self, sql_query: str):
|
|
665
|
+
"""Create the training view directly from a SQL query."""
|
|
666
|
+
view_table = f"{self.table}___training"
|
|
667
|
+
create_view_query = f"CREATE OR REPLACE VIEW {view_table} AS\n{sql_query}"
|
|
668
|
+
self.data_source.execute_statement(create_view_query)
|
|
669
|
+
|
|
512
670
|
@classmethod
|
|
513
671
|
def delete_views(cls, table: str, database: str):
|
|
514
672
|
"""Delete any views associated with this FeatureSet
|
|
@@ -558,20 +716,6 @@ class FeatureSetCore(Artifact):
|
|
|
558
716
|
"""
|
|
559
717
|
return self.data_source.smart_sample()
|
|
560
718
|
|
|
561
|
-
def anomalies(self) -> pd.DataFrame:
|
|
562
|
-
"""Get a set of anomalous data from the underlying DataSource
|
|
563
|
-
Returns:
|
|
564
|
-
pd.DataFrame: A dataframe of anomalies from the underlying DataSource
|
|
565
|
-
"""
|
|
566
|
-
|
|
567
|
-
# FIXME: Mock this for now
|
|
568
|
-
anom_df = self.sample().copy()
|
|
569
|
-
anom_df["anomaly_score"] = np.random.rand(anom_df.shape[0])
|
|
570
|
-
anom_df["cluster"] = np.random.randint(0, 10, anom_df.shape[0])
|
|
571
|
-
anom_df["x"] = np.random.rand(anom_df.shape[0])
|
|
572
|
-
anom_df["y"] = np.random.rand(anom_df.shape[0])
|
|
573
|
-
return anom_df
|
|
574
|
-
|
|
575
719
|
def value_counts(self) -> dict:
|
|
576
720
|
"""Get the value counts for the string columns of the underlying DataSource
|
|
577
721
|
|
|
@@ -667,7 +811,7 @@ if __name__ == "__main__":
|
|
|
667
811
|
pd.set_option("display.width", 1000)
|
|
668
812
|
|
|
669
813
|
# Grab a FeatureSet object and pull some information from it
|
|
670
|
-
my_features = LocalFeatureSetCore("
|
|
814
|
+
my_features = LocalFeatureSetCore("abalone_features")
|
|
671
815
|
if not my_features.exists():
|
|
672
816
|
print("FeatureSet not found!")
|
|
673
817
|
sys.exit(1)
|
|
@@ -707,7 +851,7 @@ if __name__ == "__main__":
|
|
|
707
851
|
|
|
708
852
|
# Test getting the holdout ids
|
|
709
853
|
print("Getting the hold out ids...")
|
|
710
|
-
holdout_ids = my_features.get_training_holdouts(
|
|
854
|
+
holdout_ids = my_features.get_training_holdouts()
|
|
711
855
|
print(f"Holdout IDs: {holdout_ids}")
|
|
712
856
|
|
|
713
857
|
# Get a sample of the data
|
|
@@ -727,18 +871,78 @@ if __name__ == "__main__":
|
|
|
727
871
|
# Set the holdout ids for the training view
|
|
728
872
|
print("Setting hold out ids...")
|
|
729
873
|
table = my_features.view("training").table
|
|
730
|
-
df = my_features.query(f'SELECT
|
|
731
|
-
my_holdout_ids = [id for id in df["
|
|
732
|
-
my_features.set_training_holdouts(
|
|
733
|
-
|
|
734
|
-
# Test the hold out set functionality with strings
|
|
735
|
-
print("Setting hold out ids (strings)...")
|
|
736
|
-
my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
|
|
737
|
-
my_features.set_training_holdouts("name", my_holdout_ids)
|
|
874
|
+
df = my_features.query(f'SELECT auto_id, length FROM "{table}"')
|
|
875
|
+
my_holdout_ids = [id for id in df["auto_id"] if id < 20]
|
|
876
|
+
my_features.set_training_holdouts(my_holdout_ids)
|
|
738
877
|
|
|
739
878
|
# Get the training data
|
|
740
879
|
print("Getting the training data...")
|
|
741
880
|
training_data = my_features.get_training_data()
|
|
881
|
+
print(f"Training Data: {training_data.shape}")
|
|
882
|
+
|
|
883
|
+
# Test set_sample_weights
|
|
884
|
+
print("\n--- Testing set_sample_weights ---")
|
|
885
|
+
sample_ids = df["auto_id"].tolist()[:5]
|
|
886
|
+
weight_dict = {sample_ids[0]: 0.0, sample_ids[1]: 0.5, sample_ids[2]: 2.0}
|
|
887
|
+
my_features.set_sample_weights(weight_dict)
|
|
888
|
+
training_view = my_features.view("training")
|
|
889
|
+
training_df = training_view.pull_dataframe()
|
|
890
|
+
print(f"Training view shape after set_sample_weights: {training_df.shape}")
|
|
891
|
+
print(f"Columns: {training_df.columns.tolist()}")
|
|
892
|
+
assert "sample_weight" in training_df.columns, "sample_weight column missing!"
|
|
893
|
+
assert "training" in training_df.columns, "training column missing!"
|
|
894
|
+
# Verify zero-weight row was excluded
|
|
895
|
+
assert sample_ids[0] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
|
|
896
|
+
print("set_sample_weights test passed!")
|
|
897
|
+
|
|
898
|
+
# Test set_training_config with both holdouts and weights
|
|
899
|
+
print("\n--- Testing set_training_config (combined) ---")
|
|
900
|
+
holdout_ids = [id for id in df["auto_id"] if id >= 100 and id < 120]
|
|
901
|
+
weight_dict = {sample_ids[3]: 0.0, sample_ids[4]: 3.0} # exclude one, upweight another
|
|
902
|
+
my_features.set_training_config(holdout_ids=holdout_ids, weight_dict=weight_dict)
|
|
903
|
+
training_view = my_features.view("training")
|
|
904
|
+
training_df = training_view.pull_dataframe()
|
|
905
|
+
print(f"Training view shape after set_training_config: {training_df.shape}")
|
|
906
|
+
print(f"Columns: {training_df.columns.tolist()}")
|
|
907
|
+
assert "sample_weight" in training_df.columns, "sample_weight column missing!"
|
|
908
|
+
assert "training" in training_df.columns, "training column missing!"
|
|
909
|
+
# Verify holdout IDs are marked as training=False
|
|
910
|
+
holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
|
|
911
|
+
assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!" # noqa: E712
|
|
912
|
+
# Verify zero-weight row was excluded
|
|
913
|
+
assert sample_ids[3] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
|
|
914
|
+
# Verify upweighted row has correct weight
|
|
915
|
+
upweight_row = training_df[training_df["auto_id"] == sample_ids[4]]
|
|
916
|
+
assert upweight_row["sample_weight"].iloc[0] == 3.0, "Upweighted ID should have weight=3.0!"
|
|
917
|
+
print("set_training_config (combined) test passed!")
|
|
918
|
+
|
|
919
|
+
# Test set_training_config with only holdouts (should delegate to set_training_holdouts)
|
|
920
|
+
print("\n--- Testing set_training_config (holdouts only) ---")
|
|
921
|
+
my_features.set_training_config(holdout_ids=holdout_ids)
|
|
922
|
+
training_view = my_features.view("training")
|
|
923
|
+
training_df = training_view.pull_dataframe()
|
|
924
|
+
assert "training" in training_df.columns, "training column missing!"
|
|
925
|
+
holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
|
|
926
|
+
assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!" # noqa: E712
|
|
927
|
+
print("set_training_config (holdouts only) test passed!")
|
|
928
|
+
|
|
929
|
+
# Test set_training_config with only weights (should delegate to set_sample_weights)
|
|
930
|
+
print("\n--- Testing set_training_config (weights only) ---")
|
|
931
|
+
my_features.set_training_config(weight_dict={sample_ids[0]: 0.5, sample_ids[1]: 2.0})
|
|
932
|
+
training_view = my_features.view("training")
|
|
933
|
+
training_df = training_view.pull_dataframe()
|
|
934
|
+
assert "sample_weight" in training_df.columns, "sample_weight column missing!"
|
|
935
|
+
print("set_training_config (weights only) test passed!")
|
|
936
|
+
|
|
937
|
+
# Test set_training_config with neither (should create standard training view)
|
|
938
|
+
print("\n--- Testing set_training_config (neither) ---")
|
|
939
|
+
my_features.set_training_config()
|
|
940
|
+
training_view = my_features.view("training")
|
|
941
|
+
training_df = training_view.pull_dataframe()
|
|
942
|
+
assert "training" in training_df.columns, "training column missing!"
|
|
943
|
+
print("set_training_config (neither) test passed!")
|
|
944
|
+
|
|
945
|
+
print("\n=== All training config tests passed! ===")
|
|
742
946
|
|
|
743
947
|
# Now delete the AWS artifacts associated with this Feature Set
|
|
744
948
|
# print("Deleting Workbench Feature Set...")
|