workbench-0.8.162-py3-none-any.whl → workbench-0.8.202-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This release of workbench has been flagged as potentially problematic.
Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
workbench/model_scripts/xgb_model/generated_model_script.py

@@ -28,14 +28,16 @@ from typing import List, Tuple
 
 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "classifier",
-    "target_column": "solubility_class",
-    "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct', 'fingerprint'],
-    "compressed_features": ['fingerprint'],
-    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-fingerprints-plus-class/training",
-    "train_all_data": True
+    "model_type": "regressor",
+    "target": "class_number_of_rings",
+    "features": ['length', 'diameter', 'height', 'whole_weight', 'shucked_weight', 'viscera_weight', 'shell_weight', 'sex'],
+    "compressed_features": [],
+    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/abalone-regression/training",
+    "train_all_data": False,
+    "hyperparameters": {},
 }
 
+
 # Function to check if dataframe is empty
 def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
     """

@@ -75,7 +77,7 @@ def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFrame:
     proba_df = pd.DataFrame(df[proba_column].tolist(), columns=proba_splits)
 
     # Drop any proba columns and reset the index in prep for the concat
-    df = df.drop(columns=[proba_column]+proba_splits, errors="ignore")
+    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
     df = df.reset_index(drop=True)
 
     # Concatenate the new columns with the original DataFrame

@@ -88,13 +90,12 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
     rename_dict = {}
     missing = []
-
     for feature in model_features:
         if feature in df.columns:
             continue  # Exact match

@@ -102,10 +103,11 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
             rename_dict[df_columns_lower[feature.lower()]] = feature
         else:
             missing.append(feature)
-
+
     if missing:
         raise ValueError(f"Features not found: {missing}")
-
+
+    # Rename the DataFrame columns to match the model features
     return df.rename(columns=rename_dict)
 
 
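As a quick illustration of the helper above (the column and feature names here are made up for the example, not taken from the package):

    import pandas as pd

    # Columns arrive as "MolWt"/"TPSA" but the model was trained on lowercase names
    df = pd.DataFrame({"MolWt": [180.2], "TPSA": [63.6]})
    df = match_features_case_insensitive(df, model_features=["molwt", "tpsa"])
    print(df.columns.tolist())  # ['molwt', 'tpsa']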
@@ -140,8 +142,10 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping
     return df, category_mappings
 
 
-def decompress_features(df: pd.DataFrame, features: List[str], compressed_features: List[str]) -> Tuple[pd.DataFrame, List[str]]:
-    """Prepare features for the XGBoost model
+def decompress_features(
+    df: pd.DataFrame, features: List[str], compressed_features: List[str]
+) -> Tuple[pd.DataFrame, List[str]]:
+    """Prepare features for the model by decompressing bitstring features
 
     Args:
         df (pd.DataFrame): The features DataFrame
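The hunk shows the new signature and docstring but not the expansion itself, so here is a loose sketch of what decompressing a bitstring feature could look like (hypothetical column name and logic, not the package's implementation):

    import pandas as pd

    # Hypothetical: a 'fingerprint' column holding fixed-length bitstrings like "1011"
    df = pd.DataFrame({"fingerprint": ["1011", "0110"]})

    # Expand each bit into its own integer column (fingerprint_0 .. fingerprint_3)
    bits = df["fingerprint"].apply(lambda s: pd.Series([int(b) for b in s]))
    bits.columns = [f"fingerprint_{i}" for i in bits.columns]
    df = pd.concat([df.drop(columns=["fingerprint"]), bits], axis=1)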
@@ -166,7 +170,7 @@ def decompress_features(df: pd.DataFrame, features: List[str], compressed_features
     )
 
     # Decompress the specified compressed features
-    decompressed_features = features
+    decompressed_features = features.copy()
     for feature in compressed_features:
         if (feature not in df.columns) or (feature not in features):
             print(f"Feature '{feature}' not in the features list, skipping decompression.")

@@ -197,13 +201,14 @@ if __name__ == "__main__":
     """The main function is for training the XGBoost model"""
 
     # Harness Template Parameters
-    target = TEMPLATE_PARAMS["target_column"]
+    target = TEMPLATE_PARAMS["target"]
     features = TEMPLATE_PARAMS["features"]
     orig_features = features.copy()
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
     model_type = TEMPLATE_PARAMS["model_type"]
     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
     train_all_data = TEMPLATE_PARAMS["train_all_data"]
+    hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
     validation_split = 0.2
 
     # Script arguments for input/output directories

@@ -216,11 +221,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Read the training data into DataFrames
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train)
-        if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     print(f"Training Files: {training_files}")
 
     # Combine files and read them all into a single pandas dataframe

@@ -255,15 +256,16 @@ if __name__ == "__main__":
     else:
         # Just do a random training Split
         print("WARNING: No training column found, splitting data with random state=42")
-        df_train, df_val = train_test_split(
-            all_df, test_size=validation_split, random_state=42
-        )
+        df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
     print(f"FIT/TRAIN: {df_train.shape}")
     print(f"VALIDATION: {df_val.shape}")
 
+    # Use any hyperparameters to set up both the trainer and model configurations
+    print(f"Hyperparameters: {hyperparameters}")
+
     # Now spin up our XGB Model
     if model_type == "classifier":
-        xgb_model = xgb.XGBClassifier(enable_categorical=True)
+        xgb_model = xgb.XGBClassifier(enable_categorical=True, **hyperparameters)
 
         # Encode the target column
         label_encoder = LabelEncoder()
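The `**hyperparameters` splat means whatever dict the harness injects is passed straight to the XGBoost constructor. A hedged illustration of what that dict might carry (the keys are standard xgboost sklearn-API parameters; the values are made up):

    import xgboost as xgb

    # Illustrative values only; an empty dict (the template default) leaves XGBoost defaults intact
    hyperparameters = {"n_estimators": 300, "max_depth": 6, "learning_rate": 0.1}
    xgb_model = xgb.XGBRegressor(enable_categorical=True, **hyperparameters)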
@@ -271,12 +273,12 @@ if __name__ == "__main__":
         df_val[target] = label_encoder.transform(df_val[target])
 
     else:
-        xgb_model = xgb.XGBRegressor(enable_categorical=True)
+        xgb_model = xgb.XGBRegressor(enable_categorical=True, **hyperparameters)
         label_encoder = None  # We don't need this for regression
 
     # Grab our Features, Target and Train the Model
     y_train = df_train[target]
-    X_train= df_train[features]
+    X_train = df_train[features]
     xgb_model.fit(X_train, y_train)
 
     # Make Predictions on the Validation Set

@@ -315,9 +317,7 @@ if __name__ == "__main__":
         label_names = label_encoder.classes_
 
         # Calculate various model performance metrics
-        scores = precision_recall_fscore_support(
-            y_validate, preds, average=None, labels=label_names
-        )
+        scores = precision_recall_fscore_support(y_validate, preds, average=None, labels=label_names)
 
         # Put the scores into a dataframe
         score_df = pd.DataFrame(

@@ -325,13 +325,13 @@ if __name__ == "__main__":
                 target: label_names,
                 "precision": scores[0],
                 "recall": scores[1],
-                "fscore": scores[2],
+                "f1": scores[2],
                 "support": scores[3],
             }
         )
 
         # We need to get creative with the Classification Metrics
-        metrics = ["precision", "recall", "fscore", "support"]
+        metrics = ["precision", "recall", "f1", "support"]
         for t in label_names:
             for m in metrics:
                 value = score_df.loc[score_df[target] == t, m].iloc[0]

@@ -355,7 +355,9 @@ if __name__ == "__main__":
         print(f"NumRows: {len(df_val)}")
 
     # Now save the model to the standard place/name
-    xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
+    joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
+
+    # Save the label encoder if we have one
     if label_encoder:
         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
 

@@ -370,19 +372,8 @@ if __name__ == "__main__":
 
 def model_fn(model_dir):
     """Deserialize and return fitted XGBoost model"""
-
-    model_path = os.path.join(model_dir, "xgb_model.json")
-
-    with open(model_path, "r") as f:
-        model_json = json.load(f)
-
-    sklearn_data = model_json['learner']['attributes']['scikit_learn']
-    model_type = json.loads(sklearn_data)['_estimator_type']
-
-    model_class = xgb.XGBClassifier if model_type == "classifier" else xgb.XGBRegressor
-    model = model_class(enable_categorical=True)
-    model.load_model(model_path)
-
+    model_path = os.path.join(model_dir, "xgb_model.joblib")
+    model = joblib.load(model_path)
     return model
 
 
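The switch from `save_model` (booster JSON) to joblib explains why the old JSON introspection block could be deleted: the pickle preserves the sklearn wrapper class itself, so the classifier-vs-regressor detection is no longer needed at load time. A minimal round-trip sketch of the idea:

    import os
    import joblib
    import xgboost as xgb

    # The wrapper class (XGBRegressor vs XGBClassifier), its enable_categorical
    # flag, and the fitted booster all round-trip together in one pickle
    model = xgb.XGBRegressor(enable_categorical=True)
    joblib.dump(model, os.path.join("/tmp", "xgb_model.joblib"))
    restored = joblib.load("/tmp/xgb_model.joblib")
    assert isinstance(restored, xgb.XGBRegressor)

The trade-off is that pickles are sensitive to version skew, so the serving container should pin the same xgboost/scikit-learn versions used at training time.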
@@ -390,7 +381,7 @@ def input_fn(input_data, content_type):
     """Parse input data and return a DataFrame."""
     if not input_data:
         raise ValueError("Empty input data is not supported!")
-
+
     # Decode bytes to string if necessary
     if isinstance(input_data, bytes):
         input_data = input_data.decode("utf-8")
workbench/model_scripts/xgb_model/xgb_model.template

@@ -6,11 +6,13 @@ import numpy as np
 # Model Performance Scores
 from sklearn.metrics import (
     mean_absolute_error,
+    median_absolute_error,
     r2_score,
     root_mean_squared_error,
     precision_recall_fscore_support,
     confusion_matrix,
 )
+from scipy.stats import spearmanr
 
 # Classification Encoder
 from sklearn.preprocessing import LabelEncoder

@@ -29,13 +31,15 @@ from typing import List, Tuple
 # Template Parameters
 TEMPLATE_PARAMS = {
     "model_type": "{{model_type}}",
-    "target_column": "{{target_column}}",
+    "target": "{{target_column}}",
     "features": "{{feature_list}}",
     "compressed_features": "{{compressed_features}}",
     "model_metrics_s3_path": "{{model_metrics_s3_path}}",
-    "train_all_data": "{{train_all_data}}"
+    "train_all_data": "{{train_all_data}}",
+    "hyperparameters": "{{hyperparameters}}",
 }
 
+
 # Function to check if dataframe is empty
 def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
     """
@@ -75,7 +79,7 @@ def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFrame:
     proba_df = pd.DataFrame(df[proba_column].tolist(), columns=proba_splits)
 
     # Drop any proba columns and reset the index in prep for the concat
-    df = df.drop(columns=[proba_column]+proba_splits, errors="ignore")
+    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
     df = df.reset_index(drop=True)
 
     # Concatenate the new columns with the original DataFrame

@@ -88,13 +92,12 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
     rename_dict = {}
     missing = []
-
     for feature in model_features:
         if feature in df.columns:
             continue  # Exact match

@@ -102,10 +105,11 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
             rename_dict[df_columns_lower[feature.lower()]] = feature
         else:
             missing.append(feature)
-
+
     if missing:
         raise ValueError(f"Features not found: {missing}")
-
+
+    # Rename the DataFrame columns to match the model features
     return df.rename(columns=rename_dict)
 
 

@@ -140,8 +144,10 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping
     return df, category_mappings
 
 
-def decompress_features(df: pd.DataFrame, features: List[str], compressed_features: List[str]) -> Tuple[pd.DataFrame, List[str]]:
-    """Prepare features for the XGBoost model
+def decompress_features(
+    df: pd.DataFrame, features: List[str], compressed_features: List[str]
+) -> Tuple[pd.DataFrame, List[str]]:
+    """Prepare features for the model by decompressing bitstring features
 
     Args:
         df (pd.DataFrame): The features DataFrame

@@ -166,7 +172,7 @@ def decompress_features(df: pd.DataFrame, features: List[str], compressed_features
     )
 
     # Decompress the specified compressed features
-    decompressed_features = features
+    decompressed_features = features.copy()
     for feature in compressed_features:
         if (feature not in df.columns) or (feature not in features):
             print(f"Feature '{feature}' not in the features list, skipping decompression.")
@@ -197,13 +203,14 @@ if __name__ == "__main__":
     """The main function is for training the XGBoost model"""
 
     # Harness Template Parameters
-    target = TEMPLATE_PARAMS["target_column"]
+    target = TEMPLATE_PARAMS["target"]
     features = TEMPLATE_PARAMS["features"]
     orig_features = features.copy()
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
     model_type = TEMPLATE_PARAMS["model_type"]
     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
     train_all_data = TEMPLATE_PARAMS["train_all_data"]
+    hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
     validation_split = 0.2
 
     # Script arguments for input/output directories

@@ -216,11 +223,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Read the training data into DataFrames
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train)
-        if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     print(f"Training Files: {training_files}")
 
     # Combine files and read them all into a single pandas dataframe

@@ -255,15 +258,16 @@ if __name__ == "__main__":
     else:
         # Just do a random training Split
         print("WARNING: No training column found, splitting data with random state=42")
-        df_train, df_val = train_test_split(
-            all_df, test_size=validation_split, random_state=42
-        )
+        df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
     print(f"FIT/TRAIN: {df_train.shape}")
     print(f"VALIDATION: {df_val.shape}")
 
+    # Use any hyperparameters to set up both the trainer and model configurations
+    print(f"Hyperparameters: {hyperparameters}")
+
     # Now spin up our XGB Model
     if model_type == "classifier":
-        xgb_model = xgb.XGBClassifier(enable_categorical=True)
+        xgb_model = xgb.XGBClassifier(enable_categorical=True, **hyperparameters)
 
         # Encode the target column
         label_encoder = LabelEncoder()

@@ -271,12 +275,12 @@ if __name__ == "__main__":
         df_val[target] = label_encoder.transform(df_val[target])
 
     else:
-        xgb_model = xgb.XGBRegressor(enable_categorical=True)
+        xgb_model = xgb.XGBRegressor(enable_categorical=True, **hyperparameters)
         label_encoder = None  # We don't need this for regression
 
     # Grab our Features, Target and Train the Model
     y_train = df_train[target]
-    X_train= df_train[features]
+    X_train = df_train[features]
     xgb_model.fit(X_train, y_train)
 
     # Make Predictions on the Validation Set
@@ -315,9 +319,7 @@ if __name__ == "__main__":
         label_names = label_encoder.classes_
 
         # Calculate various model performance metrics
-        scores = precision_recall_fscore_support(
-            y_validate, preds, average=None, labels=label_names
-        )
+        scores = precision_recall_fscore_support(y_validate, preds, average=None, labels=label_names)
 
         # Put the scores into a dataframe
         score_df = pd.DataFrame(

@@ -325,13 +327,13 @@ if __name__ == "__main__":
                 target: label_names,
                 "precision": scores[0],
                 "recall": scores[1],
-                "fscore": scores[2],
+                "f1": scores[2],
                 "support": scores[3],
             }
         )
 
         # We need to get creative with the Classification Metrics
-        metrics = ["precision", "recall", "fscore", "support"]
+        metrics = ["precision", "recall", "f1", "support"]
         for t in label_names:
             for m in metrics:
                 value = score_df.loc[score_df[target] == t, m].iloc[0]

@@ -348,14 +350,21 @@ if __name__ == "__main__":
         # Calculate various model performance metrics (regression)
         rmse = root_mean_squared_error(y_validate, preds)
         mae = mean_absolute_error(y_validate, preds)
+        medae = median_absolute_error(y_validate, preds)
         r2 = r2_score(y_validate, preds)
-        print(f"RMSE: {rmse:.3f}")
-        print(f"MAE: {mae:.3f}")
-        print(f"R2: {r2:.3f}")
-        print(f"NumRows: {len(df_val)}")
+        spearman_corr = spearmanr(y_validate, preds).correlation
+        support = len(df_val)
+        print(f"rmse: {rmse:.3f}")
+        print(f"mae: {mae:.3f}")
+        print(f"medae: {medae:.3f}")
+        print(f"r2: {r2:.3f}")
+        print(f"spearmanr: {spearman_corr:.3f}")
+        print(f"support: {support}")
 
     # Now save the model to the standard place/name
-    xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
+    joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
+
+    # Save the label encoder if we have one
     if label_encoder:
         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
 
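For reference, the two new regression metrics can be exercised standalone (synthetic values; `.correlation` is the attribute SciPy exposes on the spearmanr result, as used above):

    import numpy as np
    from scipy.stats import spearmanr
    from sklearn.metrics import median_absolute_error

    y_true = np.array([1.0, 2.0, 3.0, 4.0])
    y_pred = np.array([1.1, 1.9, 3.2, 3.8])
    print(f"medae: {median_absolute_error(y_true, y_pred):.3f}")       # robust to outlier errors
    print(f"spearmanr: {spearmanr(y_true, y_pred).correlation:.3f}")   # rank correlation; 1.000 here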
@@ -370,19 +379,8 @@ if __name__ == "__main__":
 
 def model_fn(model_dir):
     """Deserialize and return fitted XGBoost model"""
-
-    model_path = os.path.join(model_dir, "xgb_model.json")
-
-    with open(model_path, "r") as f:
-        model_json = json.load(f)
-
-    sklearn_data = model_json['learner']['attributes']['scikit_learn']
-    model_type = json.loads(sklearn_data)['_estimator_type']
-
-    model_class = xgb.XGBClassifier if model_type == "classifier" else xgb.XGBRegressor
-    model = model_class(enable_categorical=True)
-    model.load_model(model_path)
-
+    model_path = os.path.join(model_dir, "xgb_model.joblib")
+    model = joblib.load(model_path)
     return model
 
 

@@ -390,7 +388,7 @@ def input_fn(input_data, content_type):
     """Parse input data and return a DataFrame."""
     if not input_data:
         raise ValueError("Empty input data is not supported!")
-
+
     # Decode bytes to string if necessary
     if isinstance(input_data, bytes):
         input_data = input_data.decode("utf-8")
workbench/repl/workbench_shell.py

@@ -1,14 +1,25 @@
+# flake8: noqa: E402
+import os
+import sys
+import logging
+import importlib
+import webbrowser
+import readline  # noqa: F401
+
+# Disable OpenMP parallelism to avoid segfaults with PyTorch in iPython
+# This is a known issue on macOS where libomp crashes during thread synchronization
+# Must be set before importing numpy/pandas/torch or any library that uses OpenMP
+os.environ.setdefault("OMP_NUM_THREADS", "1")
+os.environ.setdefault("MKL_NUM_THREADS", "1")
+
+import IPython
 from IPython import start_ipython
+from distutils.version import LooseVersion
 from IPython.terminal.prompts import Prompts
 from IPython.terminal.ipapp import load_default_config
 from pygments.token import Token
-import sys
-import logging
-import importlib
 import botocore
-import webbrowser
 import pandas as pd
-import readline  # noqa
 
 try:
     import matplotlib.pyplot as plt  # noqa
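The ordering constraint in the new comments is the whole trick: OpenMP-backed libraries read these variables when their native runtime initializes, so the setdefault calls must execute before the first numpy/pandas/torch import anywhere in the process. A minimal reproduction of the pattern:

    import os

    # setdefault() respects a value the user already exported in their shell
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")

    import numpy as np  # noqa: E402  (the OpenMP thread pool is sized from the env var at load time)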
@@ -39,7 +50,7 @@ from workbench.cached.cached_meta import CachedMeta
 try:
     import rdkit  # noqa
     import mordred  # noqa
-    from workbench.utils import chem_utils
+    from workbench.utils.chem_utils import vis
 
     HAVE_CHEM_UTILS = True
 except ImportError:

@@ -70,7 +81,7 @@ if not ConfigManager().config_okay():
 
 # Set the log level to important
 log = logging.getLogger("workbench")
-log.setLevel(IMPORTANT_LEVEL_NUM)
+log.setLevel(logging.INFO)
 log.addFilter(
     lambda record: not (
         record.getMessage().startswith("Async: Metadata") or record.getMessage().startswith("Updated Metadata")

@@ -176,12 +187,12 @@ class WorkbenchShell:
 
         # Add cheminformatics utils if available
         if HAVE_CHEM_UTILS:
-            self.commands["show"] = chem_utils.show
+            self.commands["show"] = vis.show
 
     def start(self):
         """Start the Workbench IPython shell"""
         cprint("magenta", "\nWelcome to Workbench!")
-        if self.aws_status is False:
+        if not self.aws_status:
             cprint("red", "AWS Account Connection Failed...Review/Fix the Workbench Config:")
             cprint("red", f"Path: {self.cm.site_config_path}")
             self.show_config()
@@ -202,7 +213,10 @@ class WorkbenchShell:
 
         # Start IPython with the config and commands in the namespace
         try:
-            ipython_argv = ["--no-tip", "--theme", "linux"]
+            if LooseVersion(IPython.__version__) >= LooseVersion("9.0.0"):
+                ipython_argv = ["--no-tip", "--theme", "linux"]
+            else:
+                ipython_argv = []
             start_ipython(ipython_argv, user_ns=locs, config=config)
         finally:
             spinner = self.spinner_start("Goodbye to AWS:")
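One caveat on this hunk: `distutils.version.LooseVersion` is deprecated, and `distutils` itself is removed from the stdlib in Python 3.12+, so the same gate could be written with the third-party `packaging` library. A sketch of that alternative (not what the wheel actually ships):

    from packaging.version import Version
    import IPython

    # "--no-tip" / "--theme" are only understood by IPython >= 9
    if Version(IPython.__version__) >= Version("9.0.0"):
        ipython_argv = ["--no-tip", "--theme", "linux"]
    else:
        ipython_argv = []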
@@ -520,7 +534,7 @@ class WorkbenchShell:
     def get_meta(self):
         return self.meta
 
-    def plot_manager(self, data, plot_type: str = "table", **kwargs):
+    def plot_manager(self, data, plot_type: str = "scatter", **kwargs):
         """Plot Manager for Workbench"""
         from workbench.web_interface.components.plugins import ag_table, graph_plot, scatter_plot
 

@@ -555,14 +569,14 @@ class WorkbenchShell:
         from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
 
         # Get kwargs
-        theme = kwargs.get("theme", "dark")
+        theme = kwargs.get("theme", "midnight_blue")
 
         plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
 
-        # Run the server and open in the browser
-        plugin_test.run()
+        # Open the browser and run the dash server
         url = f"http://127.0.0.1:{plugin_test.port}"
         webbrowser.open(url)
+        plugin_test.run()
 
 
 # Launch Shell Entry Point
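The reorder in the last hunk matters because a Dash server's run() blocks the calling thread, so the browser has to be opened before it starts. A minimal sketch of the pattern (a hypothetical Dash app, not the PluginUnitTest internals):

    import webbrowser
    from dash import Dash, html

    app = Dash(__name__)
    app.layout = html.Div("hello")
    port = 8050

    webbrowser.open(f"http://127.0.0.1:{port}")  # open first: run() below blocks
    app.run(port=port, debug=False)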