workbench-0.8.219-py3-none-any.whl → workbench-0.8.231-py3-none-any.whl
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/feature_set.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +238 -225
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +196 -68
- workbench/model_scripts/chemprop/generated_model_script.py +197 -72
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -64
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +82 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/model_utils.py +0 -1
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +52 -36
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +184 -85
- workbench/web_interface/components/settings_menu.py +185 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/feature_set_core.py:

@@ -7,7 +7,6 @@ from datetime import datetime, timezone
 import botocore.exceptions
 import pandas as pd
 import awswrangler as wr
-import numpy as np
 
 from sagemaker.feature_store.feature_group import FeatureGroup
 from sagemaker.feature_store.feature_store import FeatureStore
@@ -16,9 +15,8 @@ from sagemaker.feature_store.feature_store import FeatureStore
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource
-from workbench.utils.deprecated_utils import deprecated
 
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, List, Dict, Union
 
 from workbench.utils.aws_utils import aws_throttle
 
@@ -483,18 +481,6 @@ class FeatureSetCore(Artifact):
             time.sleep(1)
         cls.log.info(f"FeatureSet {feature_group.name} successfully deleted")
 
-    def set_training_holdouts(self, holdout_ids: list[str]):
-        """Set the hold out ids for the training view for this FeatureSet
-
-        Args:
-            holdout_ids (list[str]): The list of holdout ids.
-        """
-        from workbench.core.views import TrainingView
-
-        # Create a NEW training view
-        self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
-        TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
-
     def get_training_holdouts(self) -> list[str]:
         """Get the hold out ids for the training view for this FeatureSet
 
@@ -510,183 +496,234 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids
 
-
+    # ---- Public methods for training configuration ----
+    def set_training_config(
         self,
-
+        holdout_ids: List[Union[str, int]] = None,
+        weight_dict: Dict[Union[str, int], float] = None,
         default_weight: float = 1.0,
         exclude_zero_weights: bool = True,
     ):
-        """Configure training view with
+        """Configure training view with holdout IDs and/or sample weights.
+
+        This method creates a training view that can include both:
+        - A 'training' column (True/False) based on holdout IDs
+        - A 'sample_weight' column for weighted training
 
         Args:
+            holdout_ids: List of IDs to mark as training=False (validation/holdout set)
             weight_dict: Mapping of ID to sample weight
                 - weight > 1.0: oversample/emphasize
                 - weight = 1.0: normal (default)
                 - 0 < weight < 1.0: downweight/de-emphasize
-                - weight = 0.0: exclude from training
+                - weight = 0.0: exclude from training (filtered out if exclude_zero_weights=True)
             default_weight: Weight for IDs not in weight_dict (default: 1.0)
             exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
 
         Example:
-
-
-
-            '
-
-            model.set_sample_weights(weights)  # zeros automatically excluded
-            model.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
+            # Temporal split with sample weights
+            fs.set_training_config(
+                holdout_ids=temporal_hold_out_ids,  # IDs after cutoff date
+                weight_dict={'compound_42': 0.0, 'compound_99': 2.0},  # exclude/upweight
+            )
         """
-        from workbench.core.views import TrainingView
+        from workbench.core.views.training_view import TrainingView
 
-
-
+        # If neither is provided, create a standard training view
+        if not holdout_ids and not weight_dict:
+            self.log.important("No holdouts or weights specified, creating standard training view")
             TrainingView.create(self, id_column=self.id_column)
             return
 
-
+        # If only holdout_ids, delegate to set_training_holdouts
+        if holdout_ids and not weight_dict:
+            self.set_training_holdouts(holdout_ids)
+            return
 
-        #
-
-            f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
-        ]
-        case_statement = "\n    ".join(case_conditions)
+        # If only weight_dict, delegate to set_sample_weights
+        if weight_dict and not holdout_ids:
+            self.set_sample_weights(weight_dict, default_weight, exclude_zero_weights)
+            return
 
-        #
-
-            *,
-            CASE
-            {case_statement}
-            ELSE {default_weight}
-            END AS sample_weight
-            FROM {self.table}"""
+        # Both holdout_ids and weight_dict provided - build combined view
+        self.log.important(f"Setting training config: {len(holdout_ids)} holdouts, {len(weight_dict)} weights")
 
-
-
-
-        custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
-        self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
-        else:
-            custom_sql = inner_sql
+        # Get column list (excluding AWS-generated columns)
+        from workbench.core.views.view_utils import get_column_list
 
-
+        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
+        source_columns = get_column_list(self.data_source, self.table)
+        column_list = [col for col in source_columns if col not in aws_cols]
 
-
+        # Build the training column CASE statement
+        training_case = self._build_holdout_case(holdout_ids)
 
-
-
-
+        # For large weight_dict, use supplemental table + JOIN
+        if len(weight_dict) >= 100:
+            self.log.info("Using supplemental table approach for large weight_dict")
+            weights_table = self._create_weights_table(weight_dict)
 
-
-
-            If None or empty string, will reset to training view with no filter
-            (default: None)
-        """
-        from workbench.core.views import TrainingView
+            # Build column selection with table alias
+            sql_columns = ", ".join([f't."{col}"' for col in column_list])
 
-
-
+            # Build JOIN query with training CASE and weight from joined table
+            training_case_aliased = training_case.replace(f"WHEN {self.id_column} IN", f"WHEN t.{self.id_column} IN")
+            inner_sql = f"""SELECT {sql_columns}, {training_case_aliased},
+                COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                FROM {self.table} t
+                LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+        else:
+            # For small weight_dict, use CASE statement
+            sql_columns = ", ".join([f'"{column}"' for column in column_list])
+            weight_case = self._build_weight_case(weight_dict, default_weight)
+            inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
 
-        #
-
-
-
-
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
+            if zero_count:
+                self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+            sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+        else:
+            sql_query = inner_sql
 
-
-
-
+        self._create_training_view(sql_query)
+
+    def set_training_holdouts(self, holdout_ids: list[str]):
+        """Set the hold out ids for the training view for this FeatureSet
 
         Args:
-
-            column_name (Optional[str]): Column name to filter on.
-                If None, uses self.id_column (default: None)
+            holdout_ids (list[str]): The list of holdout ids.
         """
-
-        column = column_name or self.id_column
-
-        # Handle empty list case
-        if not ids:
-            self.log.warning("No IDs provided to exclude")
-            return
-
-        # Build the filter expression with proper SQL quoting
-        quoted_ids = ", ".join([repr(id) for id in ids])
-        filter_expression = f"{column} NOT IN ({quoted_ids})"
+        from workbench.core.views import TrainingView
 
-
-        self.
+        self.log.important(f"Setting Training Holdouts: {len(holdout_ids)} ids...")
+        TrainingView.create(self, id_column=self.id_column, holdout_ids=holdout_ids)
 
-
-    def set_training_sampling(
+    def set_sample_weights(
         self,
-
-
-
+        weight_dict: Dict[Union[str, int], float],
+        default_weight: float = 1.0,
+        exclude_zero_weights: bool = True,
     ):
-        """Configure training view with
+        """Configure training view with sample weights for each ID.
 
         Args:
-
-
-
+            weight_dict: Mapping of ID to sample weight
+                - weight > 1.0: oversample/emphasize
+                - weight = 1.0: normal (default)
+                - 0 < weight < 1.0: downweight/de-emphasize
+                - weight = 0.0: exclude from training
+            default_weight: Weight for IDs not in weight_dict (default: 1.0)
+            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
+
+        Example:
+            weights = {
+                'compound_42': 3.0,   # oversample 3x
+                'compound_99': 0.1,   # noisy, downweight
+                'compound_123': 0.0,  # exclude from training
+            }
+            fs.set_sample_weights(weights)  # zeros automatically excluded
+            fs.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
 
         Note:
-
+            For large weight_dict (100+ entries), weights are stored as a supplemental
+            table and joined to avoid Athena query size limits.
         """
         from workbench.core.views import TrainingView
 
-
-
-        replicate_ids = replicate_ids or []
-
-        # Remove any replicate_ids that are also in exclude_ids (exclusion wins)
-        replicate_ids = [rid for rid in replicate_ids if rid not in exclude_ids]
-
-        # If no sampling needed, just create normal view
-        if not exclude_ids and not replicate_ids:
-            self.log.important("No sampling specified, creating standard training view")
+        if not weight_dict:
+            self.log.important("Empty weight_dict, creating standard training view")
             TrainingView.create(self, id_column=self.id_column)
             return
 
-
-        self.log.important(
-            f"Excluding {len(exclude_ids)} IDs, Replicating {len(replicate_ids)} IDs "
-            f"(factor: {replication_factor}x)"
-        )
-
-        # Helper to format IDs for SQL
-        def format_ids(ids):
-            return ", ".join([repr(id) for id in ids])
+        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
 
-        #
-
+        # For large weight_dict, use supplemental table + JOIN to avoid query size limits
+        if len(weight_dict) >= 100:
+            self.log.info("Using supplemental table approach for large weight_dict")
+            weights_table = self._create_weights_table(weight_dict)
 
-
-
-
+            # Build JOIN query with COALESCE for default weight
+            inner_sql = f"""SELECT t.*, COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                FROM {self.table} t
+                LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+        else:
+            # For small weight_dict, use CASE statement (simpler, no extra table)
+            weight_case = self._build_weight_case(weight_dict, default_weight)
+            inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
 
-        #
-        if
-
-
-
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
+            if zero_count:
+                self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+            sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+        else:
+            sql_query = inner_sql
 
-
+        TrainingView.create_with_sql(self, sql_query=sql_query, id_column=self.id_column)
 
-
+    # ---- Internal helpers for training view SQL generation ----
+    @staticmethod
+    def _format_id_for_sql(id_val: Union[str, int]) -> str:
+        """Format an ID value for use in SQL."""
+        return repr(id_val)
 
-
-
-
-
+    def _build_holdout_case(self, holdout_ids: List[Union[str, int]]) -> str:
+        """Build SQL CASE statement for training column based on holdout IDs."""
+        if all(isinstance(id_val, str) for id_val in holdout_ids):
+            formatted_ids = ", ".join(f"'{id_val}'" for id_val in holdout_ids)
         else:
-
-
+            formatted_ids = ", ".join(map(str, holdout_ids))
+        return f"""CASE
+            WHEN {self.id_column} IN ({formatted_ids}) THEN False
+            ELSE True
+        END AS training"""
+
+    def _build_weight_case(self, weight_dict: Dict[Union[str, int], float], default_weight: float) -> str:
+        """Build SQL CASE statement for sample_weight column."""
+        conditions = [
+            f"WHEN {self.id_column} = {self._format_id_for_sql(id_val)} THEN {weight}"
+            for id_val, weight in weight_dict.items()
+        ]
+        case_body = "\n    ".join(conditions)
+        return f"""CASE
+            {case_body}
+            ELSE {default_weight}
+        END AS sample_weight"""
+
+    def _create_training_view(self, sql_query: str):
+        """Create the training view directly from a SQL query."""
+        view_table = f"{self.table}___training"
+        create_view_query = f"CREATE OR REPLACE VIEW {view_table} AS\n{sql_query}"
+        self.data_source.execute_statement(create_view_query)
+
+    def _create_weights_table(self, weight_dict: Dict[Union[str, int], float]) -> str:
+        """Store sample weights as a supplemental data table.
+
+        Args:
+            weight_dict: Mapping of ID to sample weight
 
-
-
+        Returns:
+            str: The name of the created supplemental table
+        """
+        from workbench.core.views.view_utils import dataframe_to_table
+
+        # Create DataFrame from weight_dict
+        df = pd.DataFrame(
+            [(id_val, weight) for id_val, weight in weight_dict.items()],
+            columns=[self.id_column, "sample_weight"],
+        )
+
+        # Supplemental table name follows convention: _{base_table}___sample_weights
+        weights_table = f"_{self.table}___sample_weights"
+
+        # Store as supplemental data table
+        self.log.info(f"Creating supplemental weights table: {weights_table}")
+        dataframe_to_table(self.data_source, df, weights_table)
+
+        return weights_table
 
     @classmethod
     def delete_views(cls, table: str, database: str):
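Taken together, this hunk replaces the old filter/sampling API (`set_training_filter` and `set_training_sampling`, whose remnants are deleted above) with three entry points: `set_training_config`, `set_training_holdouts`, and `set_sample_weights`. A minimal usage sketch based on the docstrings above, assuming the `FeatureSet` class exported by workbench/api/feature_set.py; the feature-set name and compound IDs are illustrative placeholders:

from workbench.api.feature_set import FeatureSet

fs = FeatureSet("aqsol_features")  # hypothetical FeatureSet name

# Holdouts only: delegates to set_training_holdouts (training=False for these IDs)
fs.set_training_config(holdout_ids=["compound_7", "compound_8"])

# Weights only: delegates to set_sample_weights; zero-weight rows dropped by default
fs.set_training_config(weight_dict={"compound_42": 3.0, "compound_99": 0.1, "compound_123": 0.0})

# Both: one training view with a 'training' flag column plus a 'sample_weight' column
fs.set_training_config(
    holdout_ids=["compound_7", "compound_8"],
    weight_dict={"compound_42": 0.0, "compound_99": 2.0},  # exclude / upweight
)

The 100-entry threshold is where the generated CASE statement is swapped for a supplemental weights table joined via LEFT JOIN, keeping the Athena query under its size limits.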
@@ -737,20 +774,6 @@ class FeatureSetCore(Artifact):
         """
         return self.data_source.smart_sample()
 
-    def anomalies(self) -> pd.DataFrame:
-        """Get a set of anomalous data from the underlying DataSource
-        Returns:
-            pd.DataFrame: A dataframe of anomalies from the underlying DataSource
-        """
-
-        # FIXME: Mock this for now
-        anom_df = self.sample().copy()
-        anom_df["anomaly_score"] = np.random.rand(anom_df.shape[0])
-        anom_df["cluster"] = np.random.randint(0, 10, anom_df.shape[0])
-        anom_df["x"] = np.random.rand(anom_df.shape[0])
-        anom_df["y"] = np.random.rand(anom_df.shape[0])
-        return anom_df
-
     def value_counts(self) -> dict:
         """Get the value counts for the string columns of the underlying DataSource
 
@@ -915,81 +938,71 @@ if __name__ == "__main__":
     training_data = my_features.get_training_data()
     print(f"Training Data: {training_data.shape}")
 
-    # Test
-    print("
-
-
-
-
-
-    print("
-
-
-
-
-
-
-
-
-
+    # Test set_sample_weights
+    print("\n--- Testing set_sample_weights ---")
+    sample_ids = df["auto_id"].tolist()[:5]
+    weight_dict = {sample_ids[0]: 0.0, sample_ids[1]: 0.5, sample_ids[2]: 2.0}
+    my_features.set_sample_weights(weight_dict)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    print(f"Training view shape after set_sample_weights: {training_df.shape}")
+    print(f"Columns: {training_df.columns.tolist()}")
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    assert "training" in training_df.columns, "training column missing!"
+    # Verify zero-weight row was excluded
+    assert sample_ids[0] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
+    print("set_sample_weights test passed!")
+
+    # Test set_training_config with both holdouts and weights
+    print("\n--- Testing set_training_config (combined) ---")
+    holdout_ids = [id for id in df["auto_id"] if id >= 100 and id < 120]
+    weight_dict = {sample_ids[3]: 0.0, sample_ids[4]: 3.0}  # exclude one, upweight another
+    my_features.set_training_config(holdout_ids=holdout_ids, weight_dict=weight_dict)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    print(f"Training view shape after set_training_config: {training_df.shape}")
+    print(f"Columns: {training_df.columns.tolist()}")
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    assert "training" in training_df.columns, "training column missing!"
+    # Verify holdout IDs are marked as training=False
+    holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
+    assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!"  # noqa: E712
+    # Verify zero-weight row was excluded
+    assert sample_ids[3] not in training_df["auto_id"].values, "Zero-weight ID should be excluded!"
+    # Verify upweighted row has correct weight
+    upweight_row = training_df[training_df["auto_id"] == sample_ids[4]]
+    assert upweight_row["sample_weight"].iloc[0] == 3.0, "Upweighted ID should have weight=3.0!"
+    print("set_training_config (combined) test passed!")
+
+    # Test set_training_config with only holdouts (should delegate to set_training_holdouts)
+    print("\n--- Testing set_training_config (holdouts only) ---")
+    my_features.set_training_config(holdout_ids=holdout_ids)
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "training" in training_df.columns, "training column missing!"
+    holdout_rows = training_df[training_df["auto_id"].isin(holdout_ids)]
+    assert all(holdout_rows["training"] == False), "Holdout IDs should have training=False!"  # noqa: E712
+    print("set_training_config (holdouts only) test passed!")
+
+    # Test set_training_config with only weights (should delegate to set_sample_weights)
+    print("\n--- Testing set_training_config (weights only) ---")
+    my_features.set_training_config(weight_dict={sample_ids[0]: 0.5, sample_ids[1]: 2.0})
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "sample_weight" in training_df.columns, "sample_weight column missing!"
+    print("set_training_config (weights only) test passed!")
+
+    # Test set_training_config with neither (should create standard training view)
+    print("\n--- Testing set_training_config (neither) ---")
+    my_features.set_training_config()
+    training_view = my_features.view("training")
+    training_df = training_view.pull_dataframe()
+    assert "training" in training_df.columns, "training column missing!"
+    print("set_training_config (neither) test passed!")
+
+    print("\n=== All training config tests passed! ===")
 
     # Now delete the AWS artifacts associated with this Feature Set
     # print("Deleting Workbench Feature Set...")
     # my_features.delete()
     # print("Done")
-
-    # Test set_training_sampling with exclusions and replications
-    print("\n--- Testing set_training_sampling ---")
-    my_features.set_training_filter(None)  # Reset any existing filters
-    original_count = num_rows
-
-    # Get valid IDs from the table
-    all_data = my_features.query(f'SELECT auto_id, length FROM "{table}"')
-    valid_ids = sorted(all_data["auto_id"].tolist())
-    print(f"Valid IDs range from {valid_ids[0]} to {valid_ids[-1]}")
-
-    exclude_list = valid_ids[0:3]  # First 3 IDs
-    replicate_list = valid_ids[10:13]  # IDs at positions 10, 11, 12
-
-    print(f"Original row count: {original_count}")
-    print(f"Excluding IDs: {exclude_list}")
-    print(f"Replicating IDs: {replicate_list}")
-
-    # Test with default replication factor (2x)
-    print("\n--- Testing with replication_factor=2 (default) ---")
-    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-
-    # Verify exclusions
-    for exc_id in exclude_list:
-        count = len(training_data[training_data["auto_id"] == exc_id])
-        print(f"Excluded ID {exc_id} appears {count} times (should be 0)")
-
-    # Verify replications
-    for rep_id in replicate_list:
-        count = len(training_data[training_data["auto_id"] == rep_id])
-        print(f"Replicated ID {rep_id} appears {count} times (should be 2)")
-
-    # Test with replication factor of 5
-    print("\n--- Testing with replication_factor=5 ---")
-    replicate_list_5x = [20, 21]
-    my_features.set_training_sampling(exclude_ids=exclude_list, replicate_ids=replicate_list_5x, replication_factor=5)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-
-    # Verify 5x replication
-    for rep_id in replicate_list_5x:
-        count = len(training_data[training_data["auto_id"] == rep_id])
-        print(f"Replicated ID {rep_id} appears {count} times (should be 5)")
-
-    # Test with large replication list (simulate 100 IDs)
-    print("\n--- Testing with large ID list (100 IDs) ---")
-    large_replicate_list = list(range(30, 130))  # 100 IDs
-    my_features.set_training_sampling(replicate_ids=large_replicate_list, replication_factor=3)
-    training_data = my_features.get_training_data()
-    print(f"Training Data after sampling: {training_data.shape}")
-    print(f"Expected extra rows: {len(large_replicate_list) * 3}")
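For reference, the training-view SQL these tests exercise can be previewed without AWS access. A standalone sketch mirroring `_build_holdout_case` and `_build_weight_case` from the hunk above; the table and column names are illustrative:

# Mirror of the CASE-statement builders for a small (non-supplemental-table) config
id_column, table = "auto_id", "test_features"
holdout_ids = [100, 101, 102]   # integer IDs: no quoting needed
weight_dict = {3: 0.0, 4: 3.0}
default_weight = 1.0

formatted_ids = ", ".join(map(str, holdout_ids))
training_case = f"CASE WHEN {id_column} IN ({formatted_ids}) THEN False ELSE True END AS training"

conditions = "\n    ".join(
    f"WHEN {id_column} = {id_val} THEN {weight}" for id_val, weight in weight_dict.items()
)
weight_case = f"CASE\n    {conditions}\n    ELSE {default_weight}\nEND AS sample_weight"

inner_sql = f"SELECT *, {training_case}, {weight_case} FROM {table}"
print(f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0")  # zero weights filtered out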
workbench/core/transforms/features_to_model/features_to_model.py:

@@ -227,20 +227,14 @@ class FeaturesToModel(Transform):
             self.log.critical(msg)
             raise ValueError(msg)
 
-            # Dynamically create the metric definitions
+            # Dynamically create the metric definitions (per-class precision/recall/f1/support)
+            # Note: Confusion matrix metrics are skipped to stay under SageMaker's 40 metric limit
             metrics = ["precision", "recall", "f1", "support"]
             metric_definitions = []
             for t in self.class_labels:
                 for m in metrics:
                     metric_definitions.append({"Name": f"Metrics:{t}:{m}", "Regex": f"Metrics:{t}:{m} ([0-9.]+)"})
 
-            # Add the confusion matrix metrics
-            for row in self.class_labels:
-                for col in self.class_labels:
-                    metric_definitions.append(
-                        {"Name": f"ConfusionMatrix:{row}:{col}", "Regex": f"ConfusionMatrix:{row}:{col} ([0-9.]+)"}
-                    )
-
         # If the model type is UNKNOWN, our metric_definitions will be empty
         else:
             self.log.important(f"ModelType is {self.model_type}, skipping metric_definitions...")
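The deleted ConfusionMatrix definitions matter because SageMaker allows at most 40 metric definitions per training job, as the new comment notes. With C class labels, the per-class metrics cost 4*C entries and a full confusion matrix adds C^2 more, so the combined set already exceeds the cap at five classes:

# Metric-definition budget: 4*C per-class entries vs 4*C + C*C with a confusion matrix
SAGEMAKER_METRIC_LIMIT = 40  # max metric definitions per training job
for num_classes in (3, 4, 5, 6):
    per_class = 4 * num_classes                  # precision/recall/f1/support per class
    with_confusion = per_class + num_classes**2  # plus one ConfusionMatrix:row:col per cell
    print(f"{num_classes} classes: {per_class} per-class only, "
          f"{with_confusion} with confusion matrix "
          f"(fits: {with_confusion <= SAGEMAKER_METRIC_LIMIT})")
# 5 classes: 20 per-class only, 45 with confusion matrix (fits: False)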
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py:

@@ -148,6 +148,7 @@ class ModelToEndpoint(Transform):
                 deserializer=CSVDeserializer(),
                 data_capture_config=data_capture_config,
                 tags=aws_tags,
+                container_startup_health_check_timeout=300,
             )
         except ClientError as e:
             # Check if this is the "endpoint config already exists" error
@@ -164,6 +165,7 @@ class ModelToEndpoint(Transform):
                 deserializer=CSVDeserializer(),
                 data_capture_config=data_capture_config,
                 tags=aws_tags,
+                container_startup_health_check_timeout=300,
             )
         else:
             raise
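Both deploy paths now pass `container_startup_health_check_timeout=300`, giving the serving container up to five minutes to pass SageMaker's startup health check before endpoint creation fails. A sketch of the resulting `Model.deploy()` call shape; everything except the timeout argument is a placeholder, not a value from this diff:

from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import CSVDeserializer

# `model` is an already-constructed sagemaker.model.Model; the endpoint name and
# instance settings below are illustrative placeholders.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.t2.medium",
    endpoint_name="my-endpoint",
    serializer=CSVSerializer(),
    deserializer=CSVDeserializer(),
    container_startup_health_check_timeout=300,  # seconds allowed for the container to become healthy
)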
workbench/model_script_utils/model_script_utils.py (the same 30-line addition appears in the chemprop, pytorch_model, and xgb_model copies listed above):

@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
     raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
 def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
     """Compute standard regression metrics.
 
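A quick demonstration of the new `cap_std_outliers` helper on a synthetic 1-D array, assuming the function above is in scope. With enough inlier points, the IQR fence Q3 + 1.5*IQR pulls a lone extreme value down while leaving the rest untouched:

import numpy as np

# Assumes cap_std_outliers from the block above is in scope.
stds = np.array([0.10, 0.12, 0.11, 0.13, 0.09, 0.10, 0.12, 0.11, 0.10, 5.00])
capped = cap_std_outliers(stds)
print(capped.max())   # ~0.15: the 5.00 outlier is capped at Q3 + 1.5*IQR
print(capped[:5])     # inliers unchanged; relative ordering preserved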