workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -1,83 +0,0 @@
- # Model: tautomerization_processor
- #
- # Description: The tautomerization_processor model uses RDKit to perform tautomer enumeration
- # and canonicalization of chemical compounds. Tautomerization is the chemical process where
- # compounds can interconvert between structurally distinct forms, often affecting their
- # chemical properties and reactivity. This model provides a robust approach to identifying
- # and processing tautomers, crucial for improving molecular modeling and cheminformatics tasks
- # like virtual screening, QSAR modeling, and property prediction.
- #
- import argparse
- import os
- import joblib
- from io import StringIO
- import pandas as pd
- import json
-
- # Local imports
- from local_utils import tautomerize_smiles
-
-
- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the job and save the model artifacts.
- #
- if __name__ == "__main__":
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # This model doesn't get trained; it's a feature processing 'model'
-
-     # SageMaker expects a model artifact, so we'll save a placeholder
-     placeholder_model = {}
-     joblib.dump(placeholder_model, os.path.join(args.model_dir, "model.joblib"))
-
-
- # Model loading and prediction functions
- def model_fn(model_dir):
-     return joblib.load(os.path.join(model_dir, "model.joblib"))
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     use_explicit_na = False
-     if "text/csv" in accept_type:
-         if use_explicit_na:
-             csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-         else:
-             csv_output = output_df.to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- # Prediction function
- def predict_fn(df, model):
-     # Perform Tautomerization
-     df = tautomerize_smiles(df)
-     return df
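
Note: the removed script delegates the actual chemistry to local_utils.tautomerize_smiles, which is not shown in this diff. For context, a minimal sketch of the same idea using RDKit's rdMolStandardize follows; the function names and the "smiles" column are illustrative assumptions, not the removed implementation.

# Hypothetical sketch of SMILES tautomer canonicalization with RDKit (not the removed local_utils code).
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
import pandas as pd


def canonical_tautomer(smiles: str) -> str:
    """Return the canonical tautomer SMILES, or the input unchanged if parsing fails."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return smiles
    enumerator = rdMolStandardize.TautomerEnumerator()
    return Chem.MolToSmiles(enumerator.Canonicalize(mol))


def tautomerize_frame(df: pd.DataFrame, smiles_col: str = "smiles") -> pd.DataFrame:
    """Canonicalize the tautomer of every SMILES in the given column (returns a copy)."""
    out = df.copy()
    out[smiles_col] = out[smiles_col].apply(canonical_tautomer)
    return out
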
@@ -1,53 +0,0 @@
- # Model: Meta Endpoint Example
- # This script is a template for creating a custom meta endpoint in AWS Workbench.
- from io import StringIO
- import pandas as pd
- import json
-
- # Workbench Bridges imports
- try:
-     from workbench_bridges.endpoints.fast_inference import fast_inference
- except ImportError:
-     print("workbench_bridges not found, this is fine for training...")
-
-
- # Not Used: We need to define this function for SageMaker
- def model_fn(model_dir):
-     return None
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     # Support CSV and JSON input formats
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         csv_output = output_df.to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- # Prediction function
- def predict_fn(df, model):
-
-     # Call inference on an endpoint
-     df = fast_inference("abalone-regression", df)
-     return df
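
Note: these model scripts all follow the same SageMaker inference-handler contract: input_fn parses the request body into a DataFrame, predict_fn transforms it, and output_fn serializes the response. A rough local smoke test of that chain is sketched below; the payload columns are assumptions, and for this particular script predict_fn calls fast_inference, which needs AWS credentials and a live "abalone-regression" endpoint (or a stub) to actually run.

# Hypothetical local exercise of the handler chain above; not part of the package.
payload = "length,diameter,height\n0.455,0.365,0.095\n"  # assumed CSV request body
model = model_fn("/opt/ml/model")                         # returns None for this meta endpoint
df = input_fn(payload, "text/csv")                        # request body -> DataFrame
result_df = predict_fn(df, model)                         # here: fast_inference("abalone-regression", df)
body, content_type = output_fn(result_df, "text/csv")     # DataFrame -> response body
print(content_type)
print(body)
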
@@ -1,138 +0,0 @@
- # Model: feature_space_proximity
- #
- # Description: The feature_space_proximity model computes nearest neighbors for the given feature space
- #
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "id_column": "udm_mol_bat_id",
- "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v'],
-     "target": "udm_asy_res_value",
-     "track_columns": None
- }
-
- from io import StringIO
- import json
- import argparse
- import os
- import pandas as pd
-
- # Local Imports
- from proximity import Proximity
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """Check if the DataFrame is empty and raise an error if so."""
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- # Function to match DataFrame columns to model features (case-insensitive)
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """Match and rename DataFrame columns to match the model's features, case-insensitively."""
-     # Create a set of exact matches from the DataFrame columns
-     exact_match_set = set(df.columns)
-
-     # Create a case-insensitive map of DataFrame columns
-     column_map = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-
-     # Build a dictionary for renaming columns based on case-insensitive matching
-     for feature in model_features:
-         if feature in exact_match_set:
-             rename_dict[feature] = feature
-         elif feature.lower() in column_map:
-             rename_dict[column_map[feature.lower()]] = feature
-
-     # Rename columns in the DataFrame to match model features
-     return df.rename(columns=rename_dict)
-
-
- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the training job
- # and save the model artifacts to the model directory.
- #
- if __name__ == "__main__":
-     # Template Parameters
-     id_column = TEMPLATE_PARAMS["id_column"]
-     features = TEMPLATE_PARAMS["features"]
-     target = TEMPLATE_PARAMS["target"]  # Can be None for unsupervised models
-     track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
-     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the DataFrame is empty
-     check_dataframe(all_df, "training_df")
-
-     # Create the Proximity model
-     model = Proximity(all_df, id_column, features, target, track_columns=track_columns)
-
-     # Now serialize the model
-     model.serialize(args.model_dir)
-
- # Model loading and prediction functions
- def model_fn(model_dir):
-
-     # Deserialize the model
-     model = Proximity.deserialize(model_dir)
-     return model
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     use_explicit_na = False
-     if "text/csv" in accept_type:
-         if use_explicit_na:
-             csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-         else:
-             csv_output = output_df.to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- # Prediction function
- def predict_fn(df, model):
-     # Match column names before prediction if needed
-     df = match_features_case_insensitive(df, model.features + [model.id_column])
-
-     # Compute Nearest neighbors
-     df = model.neighbors(df)
-     return df
@@ -1,384 +0,0 @@
- import pandas as pd
- import numpy as np
- from sklearn.preprocessing import StandardScaler
- from sklearn.neighbors import NearestNeighbors
- from typing import List, Dict
- import logging
- import pickle
- import os
- import json
- from pathlib import Path
- from enum import Enum
-
- # Set up logging
- log = logging.getLogger("workbench")
-
-
- # Enumerated Proximity Types (distance or similarity)
- class ProximityType(Enum):
-     DISTANCE = "distance"
-     SIMILARITY = "similarity"
-
-
- class Proximity:
-     def __init__(
-         self,
-         df: pd.DataFrame,
-         id_column: str,
-         features: List[str],
-         target: str = None,
-         track_columns: List[str] = None,
-         n_neighbors: int = 10,
-     ):
-         """
-         Initialize the Proximity class.
-
-         Args:
-             df (pd.DataFrame): DataFrame containing data for neighbor computations.
-             id_column (str): Name of the column used as the identifier.
-             features (List[str]): List of feature column names to be used for neighbor computations.
-             target (str, optional): Name of the target column. Defaults to None.
-             track_columns (List[str], optional): Additional columns to track in results. Defaults to None.
-             n_neighbors (int): Number of neighbors to compute. Defaults to 10.
-         """
-         self.df = df.dropna(subset=features).copy()
-         self.id_column = id_column
-         self.n_neighbors = min(n_neighbors, len(self.df) - 1)
-         self.target = target
-         self.features = features
-         self.scaler = None
-         self.X = None
-         self.nn = None
-         self.proximity_type = None
-         self.track_columns = track_columns or []
-
-         # Right now we only support numeric features, so remove any columns that are not numeric
-         non_numeric_features = self.df[self.features].select_dtypes(exclude=["number"]).columns.tolist()
-         if non_numeric_features:
-             log.warning(f"Non-numeric features {non_numeric_features} aren't currently supported...")
-             self.features = [f for f in self.features if f not in non_numeric_features]
-
-         # Build the proximity model
-         self.build_proximity_model()
-
-     def build_proximity_model(self) -> None:
-         """Standardize features and fit Nearest Neighbors model.
-         Note: This method can be overridden in subclasses for custom behavior."""
-         self.proximity_type = ProximityType.DISTANCE
-         self.scaler = StandardScaler()
-         self.X = self.scaler.fit_transform(self.df[self.features])
-         self.nn = NearestNeighbors(n_neighbors=self.n_neighbors + 1).fit(self.X)
-
-     def all_neighbors(self) -> pd.DataFrame:
-         """
-         Compute nearest neighbors for all rows in the dataset.
-
-         Returns:
-             pd.DataFrame: A DataFrame of neighbors and their distances.
-         """
-         distances, indices = self.nn.kneighbors(self.X)
-         results = []
-
-         for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-             query_id = self.df.iloc[i][self.id_column]
-
-             # Process neighbors
-             for neighbor_idx, dist in zip(nbrs, dists):
-                 # Skip self (neighbor index == current row index)
-                 if neighbor_idx == i:
-                     continue
-                 results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))
-
-         return pd.DataFrame(results)
-
-     def neighbors(
-         self,
-         query_df: pd.DataFrame,
-         radius: float = None,
-         include_self: bool = True,
-     ) -> pd.DataFrame:
-         """
-         Return neighbors for rows in a query DataFrame.
-
-         Args:
-             query_df: DataFrame containing query points
-             radius: If provided, find all neighbors within this radius
-             include_self: Whether to include self in results (if present)
-
-         Returns:
-             DataFrame containing neighbors and distances
-
-         Note: The query DataFrame must include the feature columns. The id_column is optional.
-         """
-         # Check if all required features are present
-         missing = set(self.features) - set(query_df.columns)
-         if missing:
-             raise ValueError(f"Query DataFrame is missing required feature columns: {missing}")
-
-         # Check if id_column is present
-         id_column_present = self.id_column in query_df.columns
-
-         # None of the features can be NaNs, so report rows with NaNs and then drop them
-         rows_with_nan = query_df[self.features].isna().any(axis=1)
-
-         # Print the ID column for rows with NaNs
-         if rows_with_nan.any():
-             log.warning(f"Found {rows_with_nan.sum()} rows with NaNs in feature columns:")
-             log.warning(query_df.loc[rows_with_nan, self.id_column])
-
-         # Drop rows with NaNs in feature columns and reassign to query_df
-         query_df = query_df.dropna(subset=self.features)
-
-         # Transform the query features using the model's scaler
-         X_query = self.scaler.transform(query_df[self.features])
-
-         # Get neighbors using either radius or k-nearest neighbors
-         if radius is not None:
-             distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
-         else:
-             distances, indices = self.nn.kneighbors(X_query)
-
-         # Build results
-         all_results = []
-         for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-             # Use the ID from the query DataFrame if available, otherwise use the row index
-             query_id = query_df.iloc[i][self.id_column] if id_column_present else f"query_{i}"
-
-             for neighbor_idx, dist in zip(nbrs, dists):
-                 # Skip if the neighbor is the query itself and include_self is False
-                 neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
-                 if not include_self and neighbor_id == query_id:
-                     continue
-
-                 all_results.append(
-                     self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist)
-                 )
-
-         return pd.DataFrame(all_results)
-
-     def _build_neighbor_result(self, query_id, neighbor_idx: int, distance: float) -> Dict:
-         """
-         Internal: Build a result dictionary for a single neighbor.
-
-         Args:
-             query_id: ID of the query point
-             neighbor_idx: Index of the neighbor in the original DataFrame
-             distance: Distance between query and neighbor
-
-         Returns:
-             Dictionary containing neighbor information
-         """
-         neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
-
-         # Basic neighbor info
-         neighbor_info = {
-             self.id_column: query_id,
-             "neighbor_id": neighbor_id,
-             "distance": distance,
-         }
-
-         # Determine which additional columns to include
-         relevant_cols = [self.target, "prediction"] if self.target else []
-         relevant_cols += [c for c in self.df.columns if "_proba" in c or "residual" in c]
-         relevant_cols += ["outlier"]
-
-         # Add user-specified columns
-         relevant_cols += self.track_columns
-
-         # Add values for each relevant column that exists in the dataframe
-         for col in filter(lambda c: c in self.df.columns, relevant_cols):
-             neighbor_info[col] = self.df.iloc[neighbor_idx][col]
-
-         return neighbor_info
-
-     def serialize(self, directory: str) -> None:
-         """
-         Serialize the Proximity model to a directory.
-
-         Args:
-             directory: Directory path to save the model components
-         """
-         # Create directory if it doesn't exist
-         os.makedirs(directory, exist_ok=True)
-
-         # Save metadata
-         metadata = {
-             "id_column": self.id_column,
-             "features": self.features,
-             "target": self.target,
-             "track_columns": self.track_columns,
-             "n_neighbors": self.n_neighbors,
-         }
-
-         with open(os.path.join(directory, "metadata.json"), "w") as f:
-             json.dump(metadata, f)
-
-         # Save the DataFrame
-         self.df.to_pickle(os.path.join(directory, "df.pkl"))
-
-         # Save the scaler and nearest neighbors model
-         with open(os.path.join(directory, "scaler.pkl"), "wb") as f:
-             pickle.dump(self.scaler, f)
-
-         with open(os.path.join(directory, "nn_model.pkl"), "wb") as f:
-             pickle.dump(self.nn, f)
-
-         log.info(f"Proximity model serialized to {directory}")
-
-     @classmethod
-     def deserialize(cls, directory: str) -> "Proximity":
-         """
-         Deserialize a Proximity model from a directory.
-
-         Args:
-             directory: Directory path containing the serialized model components
-
-         Returns:
-             Proximity: A new Proximity instance
-         """
-         directory_path = Path(directory)
-         if not directory_path.exists() or not directory_path.is_dir():
-             raise ValueError(f"Directory {directory} does not exist or is not a directory")
-
-         # Load metadata
-         with open(os.path.join(directory, "metadata.json"), "r") as f:
-             metadata = json.load(f)
-
-         # Load DataFrame
-         df_path = os.path.join(directory, "df.pkl")
-         if not os.path.exists(df_path):
-             raise FileNotFoundError(f"DataFrame file not found at {df_path}")
-         df = pd.read_pickle(df_path)
-
-         # Create instance but skip _prepare_data
-         instance = cls.__new__(cls)
-         instance.df = df
-         instance.id_column = metadata["id_column"]
-         instance.features = metadata["features"]
-         instance.target = metadata["target"]
-         instance.track_columns = metadata["track_columns"]
-         instance.n_neighbors = metadata["n_neighbors"]
-
-         # Load scaler and nn model
-         with open(os.path.join(directory, "scaler.pkl"), "rb") as f:
-             instance.scaler = pickle.load(f)
-
-         with open(os.path.join(directory, "nn_model.pkl"), "rb") as f:
-             instance.nn = pickle.load(f)
-
-         # Load X from scaler transform
-         instance.X = instance.scaler.transform(instance.df[instance.features])
-
-         log.info(f"Proximity model deserialized from {directory}")
-         return instance
-
-
- # Testing the Proximity class
- if __name__ == "__main__":
-
-     pd.set_option("display.max_columns", None)
-     pd.set_option("display.width", 1000)
-
-     # Create a sample DataFrame
-     data = {
-         "ID": [1, 2, 3, 4, 5],
-         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-         "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
-     }
-     df = pd.DataFrame(data)
-
-     # Test the Proximity class
-     features = ["Feature1", "Feature2", "Feature3"]
-     prox = Proximity(df, id_column="ID", features=features, n_neighbors=3)
-     print(prox.all_neighbors())
-
-     # Test the neighbors method
-     print(prox.neighbors(query_df=df.iloc[[0]]))
-
-     # Test the neighbors method with radius
-     print(prox.neighbors(query_df=df.iloc[0:2], radius=2.0))
-
-     # Test with data that isn't in the 'train' dataframe
-     query_data = {
-         "ID": [6],
-         "Feature1": [0.31],
-         "Feature2": [0.31],
-         "Feature3": [2.31],
-     }
-     query_df = pd.DataFrame(query_data)
-     print(prox.neighbors(query_df=query_df))
-
-     # Test with Features list
-     prox = Proximity(df, id_column="ID", features=["Feature1"], n_neighbors=2)
-     print(prox.all_neighbors())
-
-     # Create a sample DataFrame
-     data = {
-         "foo_id": ["a", "b", "c", "d", "e"],  # Testing string IDs
-         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-         "target": [1, 0, 1, 0, 5],
-     }
-     df = pd.DataFrame(data)
-
-     # Test with String Ids
-     prox = Proximity(
-         df,
-         id_column="foo_id",
-         features=["Feature1", "Feature2"],
-         target="target",
-         track_columns=["Feature1", "Feature2"],
-         n_neighbors=3,
-     )
-     print(prox.all_neighbors())
-
-     # Test the neighbors method
-     print(prox.neighbors(query_df=df.iloc[0:2]))
-
-     # Time neighbors with all IDs versus calling all_neighbors
-     import time
-
-     start_time = time.time()
-     prox_df = prox.neighbors(query_df=df, include_self=False)
-     end_time = time.time()
-     print(f"Time taken for neighbors: {end_time - start_time:.4f} seconds")
-     start_time = time.time()
-     prox_df_all = prox.all_neighbors()
-     end_time = time.time()
-     print(f"Time taken for all_neighbors: {end_time - start_time:.4f} seconds")
-
-     # Now compare the two dataframes
-     print("Neighbors DataFrame:")
-     print(prox_df)
-     print("\nAll Neighbors DataFrame:")
-     print(prox_df_all)
-     # Check for any discrepancies
-     if prox_df.equals(prox_df_all):
-         print("The two DataFrames are equal :)")
-     else:
-         print("ERROR: The two DataFrames are not equal!")
-
-     # Test querying without the id_column
-     df_no_id = df.drop(columns=["foo_id"])
-     print(prox.neighbors(query_df=df_no_id, include_self=False))
-
-     # Test duplicate IDs
-     data = {
-         "foo_id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
-         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-         "target": [1, 0, 1, 0, 5],
-     }
-     df = pd.DataFrame(data)
-     prox = Proximity(df, id_column="foo_id", features=["Feature1", "Feature2"], target="target", n_neighbors=3)
-     print(df.equals(prox.df))
-
-     # Test with a categorical feature
-     from workbench.api import FeatureSet, Model
-
-     fs = FeatureSet("abalone_features")
-     model = Model("abalone-regression")
-     df = fs.pull_dataframe()
-     prox = Proximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
-     print(prox.neighbors(query_df=df[0:2]))
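
Note: the removed Proximity class persists its metadata, reference DataFrame, scaler, and fitted NearestNeighbors model to a directory, and deserialize rebuilds an instance without re-fitting. A short sketch of that round trip, assuming a fitted instance and DataFrame like the prox/df built in the __main__ block above:

# Sketch: serialize/deserialize round trip for a fitted Proximity instance
import tempfile

model_dir = tempfile.mkdtemp()
prox.serialize(model_dir)                    # writes metadata.json, df.pkl, scaler.pkl, nn_model.pkl
restored = Proximity.deserialize(model_dir)  # rebuilds the instance without re-fitting
print(restored.neighbors(query_df=df.iloc[0:2], include_self=False))
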