workbench 0.8.189__py3-none-any.whl → 0.8.190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

@@ -1,11 +1,13 @@
1
1
  import sys
2
2
  import os
3
+ import json
3
4
  import importlib.util
4
5
 
5
6
 
6
7
  def main():
7
8
  if len(sys.argv) != 2:
8
9
  print("Usage: lambda_launcher <handler_module_name>")
10
+ print("\nOptional: Create event.json with test event")
9
11
  sys.exit(1)
10
12
 
11
13
  handler_file = sys.argv[1]
@@ -19,6 +21,15 @@ def main():
19
21
  print(f"Error: File '{handler_file}' not found")
20
22
  sys.exit(1)
21
23
 
24
+ # Load event configuration
25
+ if os.path.exists("event.json"):
26
+ print("Loading event from event.json")
27
+ with open("event.json") as f:
28
+ event = json.load(f)
29
+ else:
30
+ print("No event.json found, using empty event")
31
+ event = {}
32
+
22
33
  # Load the module dynamically
23
34
  spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
24
35
  lambda_module = importlib.util.module_from_spec(spec)
@@ -27,12 +38,14 @@ def main():
27
38
  # Call the lambda_handler
28
39
  print(f"Invoking lambda_handler from {handler_file}...")
29
40
  print("-" * 50)
41
+ print(f"Event: {json.dumps(event, indent=2)}")
42
+ print("-" * 50)
30
43
 
31
- result = lambda_module.lambda_handler({}, {})
44
+ result = lambda_module.lambda_handler(event, {})
32
45
 
33
46
  print("-" * 50)
34
47
  print("Result:")
35
- print(result)
48
+ print(json.dumps(result, indent=2))
36
49
 
37
50
 
38
51
  if __name__ == "__main__":
@@ -14,7 +14,12 @@ workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
14
14
 
15
15
 
16
16
  def submit_to_sqs(
17
- script_path: str, size: str = "small", realtime: bool = False, dt: bool = False, promote: bool = False
17
+ script_path: str,
18
+ size: str = "small",
19
+ realtime: bool = False,
20
+ dt: bool = False,
21
+ promote: bool = False,
22
+ model_names: str = None,
18
23
  ) -> None:
19
24
  """
20
25
  Upload script to S3 and submit message to SQS queue for processing.
@@ -25,9 +30,10 @@ def submit_to_sqs(
25
30
  realtime: If True, sets serverless=False for real-time processing (default: False)
26
31
  dt: If True, sets DT=True in environment (default: False)
27
32
  promote: If True, sets PROMOTE=True in environment (default: False)
33
+ model_names: Comma-separated model names (required if dt=True)
28
34
 
29
35
  Raises:
30
- ValueError: If size is invalid or script file not found
36
+ ValueError: If size is invalid, script file not found, or dt=True without model_names
31
37
  """
32
38
  print(f"\n{'=' * 60}")
33
39
  print("🚀 SUBMITTING ML PIPELINE JOB")
@@ -35,6 +41,11 @@ def submit_to_sqs(
35
41
 
36
42
  if size not in ["small", "medium", "large"]:
37
43
  raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
44
+
45
+ # Validate dt requirements
46
+ if dt and not model_names:
47
+ raise ValueError("model_names is required when dt=True")
48
+
38
49
  # Validate script exists
39
50
  script_file = Path(script_path)
40
51
  if not script_file.exists():
@@ -45,6 +56,8 @@ def submit_to_sqs(
45
56
  print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
46
57
  print(f"🔄 DynamicTraining: {dt}")
47
58
  print(f"🆕 Promote: {promote}")
59
+ if model_names:
60
+ print(f"🏷️ Model names: {model_names}")
48
61
  print(f"🪣 Bucket: {workbench_bucket}")
49
62
  sqs = AWSAccountClamp().boto3_session.client("sqs")
50
63
  script_name = script_file.name
@@ -108,6 +121,10 @@ def submit_to_sqs(
108
121
  "PROMOTE": str(promote),
109
122
  }
110
123
 
124
+ # Add MODEL_NAMES if provided
125
+ if model_names:
126
+ message["environment"]["MODEL_NAMES"] = model_names
127
+
111
128
  # Send the message to SQS
112
129
  try:
113
130
  print("\n📨 Sending message to SQS...")
@@ -132,6 +149,8 @@ def submit_to_sqs(
132
149
  print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
133
150
  print(f"🔄 DynamicTraining: {dt}")
134
151
  print(f"🆕 Promote: {promote}")
152
+ if model_names:
153
+ print(f"🏷️ Model names: {model_names}")
135
154
  print(f"🆔 Message ID: {message_id}")
136
155
  print("\n🔍 MONITORING LOCATIONS:")
137
156
  print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
@@ -163,9 +182,20 @@ def main():
163
182
  action="store_true",
164
183
  help="Set Promote=True (models and endpoints will use promoted naming",
165
184
  )
185
+ parser.add_argument(
186
+ "--model-names",
187
+ help="Comma-separated model names (required if --dt is set)",
188
+ )
166
189
  args = parser.parse_args()
167
190
  try:
168
- submit_to_sqs(args.script_file, args.size, realtime=args.realtime, dt=args.dt, promote=args.promote)
191
+ submit_to_sqs(
192
+ args.script_file,
193
+ args.size,
194
+ realtime=args.realtime,
195
+ dt=args.dt,
196
+ promote=args.promote,
197
+ model_names=args.model_names,
198
+ )
169
199
  except Exception as e:
170
200
  print(f"\n❌ ERROR: {e}")
171
201
  log.error(f"Error: {e}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: workbench
3
- Version: 0.8.189
3
+ Version: 0.8.190
4
4
  Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
5
5
  Author-email: SuperCowPowers LLC <support@supercowpowers.com>
6
6
  License-Expression: MIT
@@ -133,14 +133,12 @@ workbench/model_scripts/custom_models/meta_endpoints/example.py,sha256=hzOAuLhIG
133
133
  workbench/model_scripts/custom_models/network_security/Readme.md,sha256=Z2gtiu0hLHvEJ1x-_oFq3qJZcsK81sceBAGAGltpqQ8,222
134
134
  workbench/model_scripts/custom_models/proximity/Readme.md,sha256=RlMFAJZgAT2mCgDk-UwR_R0Y_NbCqeI5-8DUsxsbpWQ,289
135
135
  workbench/model_scripts/custom_models/proximity/feature_space_proximity.template,sha256=eOllmqB20BWtTiV53dgpIqXKtgSbPFDW_zf8PvM3oF0,4813
136
- workbench/model_scripts/custom_models/proximity/generated_model_script.py,sha256=Zk170ztSM_rNSxgbY6ofb5NaqkEdQdhYg0UZprYqRyk,9056
137
136
  workbench/model_scripts/custom_models/proximity/proximity.py,sha256=P8f3GHRhuc4QHj5KkKW0JMrHhIo2QdBiFG-JituTV1U,14633
138
137
  workbench/model_scripts/custom_models/proximity/requirements.txt,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
139
138
  workbench/model_scripts/custom_models/uq_models/Readme.md,sha256=UVpL-lvtTrLqwBeQFinLhd_uNrEw4JUlggIdUSDrd-w,188
140
139
  workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=ca3CaAk6HVuNv1HnPgABTzRY3oDrRxomjgD4V1ZDwoc,6448
141
140
  workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=xlKLHeLQkScONnrlbAGIsrCm2wwsvcfv4Vdrw4nlc_8,13457
142
141
  workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=3nMlCi8nEbc4N-MQTzjfIcljfDQkUmWeLBfmd18m5fg,6632
143
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=Y89qD3gJ8wx9klXXDUQNfoLTImVFcdYLfRz-SA8mppE,21461
144
142
  workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=XTfhODRaHlI1jZGo9pSe-TqNsk2_nuSw0xMO2fKzDv8,14011
145
143
  workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=v1rviYTJGJnQRGgAyveXhOQlS-WFCTlc2vdnWq6HIXk,8241
146
144
  workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=P8f3GHRhuc4QHj5KkKW0JMrHhIo2QdBiFG-JituTV1U,14633
@@ -148,18 +146,13 @@ workbench/model_scripts/custom_models/uq_models/requirements.txt,sha256=fw7T7t_Y
148
146
  workbench/model_scripts/custom_script_example/custom_model_script.py,sha256=T8aydawgRVAdSlDimoWpXxG2YuWWQkbcjBVjAeSG2_0,6408
149
147
  workbench/model_scripts/custom_script_example/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
150
148
  workbench/model_scripts/ensemble_xgb/ensemble_xgb.template,sha256=pWmuo-EVz0owvkRI-h9mUTYt1-ouyD-_yyQu6SQbYZ4,10350
151
- workbench/model_scripts/ensemble_xgb/generated_model_script.py,sha256=dsjUGm22xI1ThGn97HPKtooyEPK-HOQnf5chnZ7-MXk,10675
152
149
  workbench/model_scripts/ensemble_xgb/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
153
- workbench/model_scripts/pytorch_model/generated_model_script.py,sha256=Mr1IMQJE_ML899qjzhjkrP521IjvcAvqU0pk--FB7KY,22356
154
150
  workbench/model_scripts/pytorch_model/pytorch.template,sha256=_gRp6DH294FLxF21UpSTq7s9RFfrLjViKvjXQ4yDfBQ,21999
155
151
  workbench/model_scripts/pytorch_model/requirements.txt,sha256=ICS5nW0wix44EJO2tJszJSaUrSvhSfdedn6FcRInGx4,181
156
- workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=c73ZpJBlU5k13Nx-ZDkLXu7da40CYyhwjwwmuPq6uLg,12870
157
152
  workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
158
153
  workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=QQvqx-eX9ZTbYmyupq6R6vIQwosmsmY_MRBPaHyfjdk,12586
159
- workbench/model_scripts/uq_models/generated_model_script.py,sha256=U4_41APyNISnJ3EHnXiaSIEdb3E1M1JT7ECNjsoX4fI,21197
160
154
  workbench/model_scripts/uq_models/mapie.template,sha256=2HIwB_658IsZiLIV1RViIZBIGgXxDsJPZinDUu8SchU,18961
161
155
  workbench/model_scripts/uq_models/requirements.txt,sha256=fw7T7t_YJAXK3T6Ysbesxh_Agx_tv0oYx72cEBTqRDY,98
162
- workbench/model_scripts/xgb_model/generated_model_script.py,sha256=W3koc4swpjOncpMKWqBHnTGDie0CoDpY9U1oj4OUJrI,17990
163
156
  workbench/model_scripts/xgb_model/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
164
157
  workbench/model_scripts/xgb_model/xgb_model.template,sha256=0uXknIEqgUaIFUfu2gfkxa3WHUr8HBBqBepGUTDvrhQ,17917
165
158
  workbench/repl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -168,9 +161,9 @@ workbench/resources/open_source_api.key,sha256=3S0OTblsmC0msUPdE_dbBmI83xJNmYscu
168
161
  workbench/resources/signature_verify_pub.pem,sha256=V3-u-3_z2PH-805ybkKvzDOBwAbvHxcKn0jLBImEtzM,272
169
162
  workbench/scripts/check_double_bond_stereo.py,sha256=p5hnL54Weq77ES0HCELq9JeoM-PyUGkvVSeWYF2dKyo,7776
170
163
  workbench/scripts/glue_launcher.py,sha256=bIKQvfGxpAhzbeNvTnHfRW_5kQhY-169_868ZnCejJk,10692
171
- workbench/scripts/lambda_launcher.py,sha256=U5HevvWdwN0SUrN2kpbkf0doY-5Ih_LzjJTH-45LBJ8,925
164
+ workbench/scripts/lambda_launcher.py,sha256=qnwgxmCeiiMWlbuqh04_Ubp55PTLgI3ADoJNzjY4mnU,1368
172
165
  workbench/scripts/ml_pipeline_batch.py,sha256=1T5JnLlUJR7bwAGBLHmLPOuj1xFRqVIQX8PsuDhHy8o,4907
173
- workbench/scripts/ml_pipeline_sqs.py,sha256=ebe8clE6dMONF43_JiX5Qx1WESPfGlF2-AifvJOde50,6578
166
+ workbench/scripts/ml_pipeline_sqs.py,sha256=A-v_x5TOLSaPaR36WZaoQxmYPvTKTjzIeGPHtCgvkhc,7354
174
167
  workbench/scripts/monitor_cloud_watch.py,sha256=s7MY4bsHts0nup9G0lWESCvgJZ9Mw1Eo-c8aKRgLjMw,9235
175
168
  workbench/scripts/redis_expire.py,sha256=DxI_RKSNlrW2BsJZXcsSbaWGBgPZdPhtzHjV9SUtElE,1120
176
169
  workbench/scripts/redis_report.py,sha256=iaJSuGPyLCs6e0TMcZDoT0YyJ43xJ1u74YD8FLnnUg4,990
@@ -287,9 +280,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
287
280
  workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
288
281
  workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
289
282
  workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
290
- workbench-0.8.189.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
291
- workbench-0.8.189.dist-info/METADATA,sha256=J9H9FvKMQ7q84F5PEZb0kOLGYlfjjcjO4WTMchNpcB8,9261
292
- workbench-0.8.189.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
293
- workbench-0.8.189.dist-info/entry_points.txt,sha256=o7ohD4D2oygnHp7i9-C0LfcHDuPW5Tv0JXGAg97DpGk,413
294
- workbench-0.8.189.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
295
- workbench-0.8.189.dist-info/RECORD,,
283
+ workbench-0.8.190.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
284
+ workbench-0.8.190.dist-info/METADATA,sha256=Qv7v2gWbBQkfpV3w6RCOphkaovGN29Oa6GuecCOlsok,9261
285
+ workbench-0.8.190.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
286
+ workbench-0.8.190.dist-info/entry_points.txt,sha256=o7ohD4D2oygnHp7i9-C0LfcHDuPW5Tv0JXGAg97DpGk,413
287
+ workbench-0.8.190.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
288
+ workbench-0.8.190.dist-info/RECORD,,
@@ -1,136 +0,0 @@
1
- # Model: feature_space_proximity
2
- #
3
- # Description: The feature_space_proximity model computes nearest neighbors for the given feature space
4
- #
5
-
6
- # Template Placeholders
7
- TEMPLATE_PARAMS = {
8
- "id_column": "udm_mol_bat_id",
9
- "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
10
- "target": "udm_asy_res_free_percent",
11
- "track_columns": None,
12
- }
13
-
14
- from io import StringIO
15
- import json
16
- import argparse
17
- import os
18
- import pandas as pd
19
-
20
- # Local Imports
21
- from proximity import Proximity
22
-
23
-
24
- # Function to check if dataframe is empty
25
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
26
- """Check if the DataFrame is empty and raise an error if so."""
27
- if df.empty:
28
- msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
29
- print(msg)
30
- raise ValueError(msg)
31
-
32
-
33
- # Function to match DataFrame columns to model features (case-insensitive)
34
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
35
- """Match and rename DataFrame columns to match the model's features, case-insensitively."""
36
- # Create a set of exact matches from the DataFrame columns
37
- exact_match_set = set(df.columns)
38
-
39
- # Create a case-insensitive map of DataFrame columns
40
- column_map = {col.lower(): col for col in df.columns}
41
- rename_dict = {}
42
-
43
- # Build a dictionary for renaming columns based on case-insensitive matching
44
- for feature in model_features:
45
- if feature in exact_match_set:
46
- rename_dict[feature] = feature
47
- elif feature.lower() in column_map:
48
- rename_dict[column_map[feature.lower()]] = feature
49
-
50
- # Rename columns in the DataFrame to match model features
51
- return df.rename(columns=rename_dict)
52
-
53
-
54
- # TRAINING SECTION
55
- #
56
- # This section (__main__) is where SageMaker will execute the training job
57
- # and save the model artifacts to the model directory.
58
- #
59
- if __name__ == "__main__":
60
- # Template Parameters
61
- id_column = TEMPLATE_PARAMS["id_column"]
62
- features = TEMPLATE_PARAMS["features"]
63
- target = TEMPLATE_PARAMS["target"] # Can be None for unsupervised models
64
- track_columns = TEMPLATE_PARAMS["track_columns"] # Can be None
65
-
66
- # Script arguments for input/output directories
67
- parser = argparse.ArgumentParser()
68
- parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
69
- parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
70
- parser.add_argument(
71
- "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
72
- )
73
- args = parser.parse_args()
74
-
75
- # Load training data from the specified directory
76
- training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
77
- all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
78
-
79
- # Check if the DataFrame is empty
80
- check_dataframe(all_df, "training_df")
81
-
82
- # Create the Proximity model
83
- model = Proximity(all_df, id_column, features, target, track_columns=track_columns)
84
-
85
- # Now serialize the model
86
- model.serialize(args.model_dir)
87
-
88
-
89
- # Model loading and prediction functions
90
- def model_fn(model_dir):
91
-
92
- # Deserialize the model
93
- model = Proximity.deserialize(model_dir)
94
- return model
95
-
96
-
97
- def input_fn(input_data, content_type):
98
- """Parse input data and return a DataFrame."""
99
- if not input_data:
100
- raise ValueError("Empty input data is not supported!")
101
-
102
- # Decode bytes to string if necessary
103
- if isinstance(input_data, bytes):
104
- input_data = input_data.decode("utf-8")
105
-
106
- if "text/csv" in content_type:
107
- return pd.read_csv(StringIO(input_data))
108
- elif "application/json" in content_type:
109
- return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
110
- else:
111
- raise ValueError(f"{content_type} not supported!")
112
-
113
-
114
- def output_fn(output_df, accept_type):
115
- """Supports both CSV and JSON output formats."""
116
- use_explicit_na = False
117
- if "text/csv" in accept_type:
118
- if use_explicit_na:
119
- csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
120
- else:
121
- csv_output = output_df.to_csv(index=False)
122
- return csv_output, "text/csv"
123
- elif "application/json" in accept_type:
124
- return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
125
- else:
126
- raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
127
-
128
-
129
- # Prediction function
130
- def predict_fn(df, model):
131
- # Match column names before prediction if needed
132
- df = match_features_case_insensitive(df, model.features + [model.id_column])
133
-
134
- # Compute Nearest neighbors
135
- df = model.neighbors(df)
136
- return df