workbench 0.8.198__py3-none-any.whl → 0.8.203__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +11 -4
- workbench/api/__init__.py +2 -1
- workbench/api/df_store.py +17 -108
- workbench/api/feature_set.py +48 -11
- workbench/api/model.py +1 -1
- workbench/api/parameter_store.py +3 -52
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +261 -78
- workbench/core/artifacts/feature_set_core.py +69 -1
- workbench/core/artifacts/model_core.py +48 -14
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/transforms/features_to_model/features_to_model.py +50 -33
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
- workbench/core/views/view.py +2 -2
- workbench/model_scripts/chemprop/chemprop.template +933 -0
- workbench/model_scripts/chemprop/generated_model_script.py +933 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
- workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
- workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
- workbench/model_scripts/pytorch_model/generated_model_script.py +365 -173
- workbench/model_scripts/pytorch_model/pytorch.template +362 -170
- workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
- workbench/model_scripts/script_generation.py +10 -7
- workbench/model_scripts/uq_models/generated_model_script.py +43 -27
- workbench/model_scripts/uq_models/mapie.template +40 -24
- workbench/model_scripts/xgb_model/generated_model_script.py +36 -7
- workbench/model_scripts/xgb_model/xgb_model.template +36 -7
- workbench/repl/workbench_shell.py +14 -5
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
- workbench/utils/chemprop_utils.py +761 -0
- workbench/utils/pytorch_utils.py +527 -0
- workbench/utils/xgboost_model_utils.py +10 -5
- workbench/web_interface/components/model_plot.py +7 -1
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/METADATA +3 -3
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/RECORD +49 -43
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/entry_points.txt +2 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
- workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/WHEEL +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/top_level.txt +0 -0
|
@@ -6,11 +6,13 @@ import numpy as np
|
|
|
6
6
|
# Model Performance Scores
|
|
7
7
|
from sklearn.metrics import (
|
|
8
8
|
mean_absolute_error,
|
|
9
|
+
median_absolute_error,
|
|
9
10
|
r2_score,
|
|
10
11
|
root_mean_squared_error,
|
|
11
12
|
precision_recall_fscore_support,
|
|
12
13
|
confusion_matrix,
|
|
13
14
|
)
|
|
15
|
+
from scipy.stats import spearmanr
|
|
14
16
|
|
|
15
17
|
# Classification Encoder
|
|
16
18
|
from sklearn.preprocessing import LabelEncoder
|
|
@@ -26,6 +28,28 @@ import os
|
|
|
26
28
|
import pandas as pd
|
|
27
29
|
from typing import List, Tuple
|
|
28
30
|
|
|
31
|
+
# Default XGBoost hyperparameters. Values supplied in TEMPLATE_PARAMS["hyperparameters"]
# are merged on top of these, so any key given there overrides the default here.
DEFAULT_HYPERPARAMETERS = dict(
    # --- Boosting / tree shape ---
    n_estimators=200,  # more boosting rounds, paired with the lower learning rate below
    max_depth=6,  # medium tree depth
    learning_rate=0.05,  # small step size for smoother learning across the extra rounds
    # --- Row / column subsampling ---
    subsample=0.7,  # fraction of rows sampled per tree, to reduce overfitting
    colsample_bytree=0.6,  # fraction of features sampled per tree (aggressive for wide data)
    colsample_bylevel=0.8,  # further feature sampling at each depth level
    # --- Regularization ---
    min_child_weight=5,  # larger minimum child weight guards against fitting tiny groups
    gamma=0.2,  # moderate minimum loss reduction required to split (light pruning)
    reg_alpha=0.5,  # L1 penalty on leaf weights
    reg_lambda=2.0,  # strong L2 penalty to smooth predictions
    # --- Reproducibility ---
    random_state=42,
)
|
|
52
|
+
|
|
29
53
|
# Template Parameters
|
|
30
54
|
TEMPLATE_PARAMS = {
|
|
31
55
|
"model_type": "classifier",
|
|
@@ -208,7 +232,7 @@ if __name__ == "__main__":
|
|
|
208
232
|
model_type = TEMPLATE_PARAMS["model_type"]
|
|
209
233
|
model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
|
|
210
234
|
train_all_data = TEMPLATE_PARAMS["train_all_data"]
|
|
211
|
-
hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
|
|
235
|
+
hyperparameters = {**DEFAULT_HYPERPARAMETERS, **TEMPLATE_PARAMS["hyperparameters"]}
|
|
212
236
|
validation_split = 0.2
|
|
213
237
|
|
|
214
238
|
# Script arguments for input/output directories
|
|
@@ -325,13 +349,13 @@ if __name__ == "__main__":
|
|
|
325
349
|
target: label_names,
|
|
326
350
|
"precision": scores[0],
|
|
327
351
|
"recall": scores[1],
|
|
328
|
-
"
|
|
352
|
+
"f1": scores[2],
|
|
329
353
|
"support": scores[3],
|
|
330
354
|
}
|
|
331
355
|
)
|
|
332
356
|
|
|
333
357
|
# We need to get creative with the Classification Metrics
|
|
334
|
-
metrics = ["precision", "recall", "
|
|
358
|
+
metrics = ["precision", "recall", "f1", "support"]
|
|
335
359
|
for t in label_names:
|
|
336
360
|
for m in metrics:
|
|
337
361
|
value = score_df.loc[score_df[target] == t, m].iloc[0]
|
|
@@ -348,11 +372,16 @@ if __name__ == "__main__":
|
|
|
348
372
|
# Calculate various model performance metrics (regression)
|
|
349
373
|
rmse = root_mean_squared_error(y_validate, preds)
|
|
350
374
|
mae = mean_absolute_error(y_validate, preds)
|
|
375
|
+
medae = median_absolute_error(y_validate, preds)
|
|
351
376
|
r2 = r2_score(y_validate, preds)
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
print(f"
|
|
355
|
-
print(f"
|
|
377
|
+
spearman_corr = spearmanr(y_validate, preds).correlation
|
|
378
|
+
support = len(df_val)
|
|
379
|
+
print(f"rmse: {rmse:.3f}")
|
|
380
|
+
print(f"mae: {mae:.3f}")
|
|
381
|
+
print(f"medae: {medae:.3f}")
|
|
382
|
+
print(f"r2: {r2:.3f}")
|
|
383
|
+
print(f"spearmanr: {spearman_corr:.3f}")
|
|
384
|
+
print(f"support: {support}")
|
|
356
385
|
|
|
357
386
|
# Now save the model to the standard place/name
|
|
358
387
|
joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
|
|
@@ -6,11 +6,13 @@ import numpy as np
|
|
|
6
6
|
# Model Performance Scores
|
|
7
7
|
from sklearn.metrics import (
|
|
8
8
|
mean_absolute_error,
|
|
9
|
+
median_absolute_error,
|
|
9
10
|
r2_score,
|
|
10
11
|
root_mean_squared_error,
|
|
11
12
|
precision_recall_fscore_support,
|
|
12
13
|
confusion_matrix,
|
|
13
14
|
)
|
|
15
|
+
from scipy.stats import spearmanr
|
|
14
16
|
|
|
15
17
|
# Classification Encoder
|
|
16
18
|
from sklearn.preprocessing import LabelEncoder
|
|
@@ -26,6 +28,28 @@ import os
|
|
|
26
28
|
import pandas as pd
|
|
27
29
|
from typing import List, Tuple
|
|
28
30
|
|
|
31
|
+
# Default XGBoost hyperparameters. Values supplied in TEMPLATE_PARAMS["hyperparameters"]
# are merged on top of these, so any key given there overrides the default here.
DEFAULT_HYPERPARAMETERS = dict(
    # --- Boosting / tree shape ---
    n_estimators=200,  # more boosting rounds, paired with the lower learning rate below
    max_depth=6,  # medium tree depth
    learning_rate=0.05,  # small step size for smoother learning across the extra rounds
    # --- Row / column subsampling ---
    subsample=0.7,  # fraction of rows sampled per tree, to reduce overfitting
    colsample_bytree=0.6,  # fraction of features sampled per tree (aggressive for wide data)
    colsample_bylevel=0.8,  # further feature sampling at each depth level
    # --- Regularization ---
    min_child_weight=5,  # larger minimum child weight guards against fitting tiny groups
    gamma=0.2,  # moderate minimum loss reduction required to split (light pruning)
    reg_alpha=0.5,  # L1 penalty on leaf weights
    reg_lambda=2.0,  # strong L2 penalty to smooth predictions
    # --- Reproducibility ---
    random_state=42,
)
|
|
52
|
+
|
|
29
53
|
# Template Parameters
|
|
30
54
|
TEMPLATE_PARAMS = {
|
|
31
55
|
"model_type": "{{model_type}}",
|
|
@@ -208,7 +232,7 @@ if __name__ == "__main__":
|
|
|
208
232
|
model_type = TEMPLATE_PARAMS["model_type"]
|
|
209
233
|
model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
|
|
210
234
|
train_all_data = TEMPLATE_PARAMS["train_all_data"]
|
|
211
|
-
hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
|
|
235
|
+
hyperparameters = {**DEFAULT_HYPERPARAMETERS, **TEMPLATE_PARAMS["hyperparameters"]}
|
|
212
236
|
validation_split = 0.2
|
|
213
237
|
|
|
214
238
|
# Script arguments for input/output directories
|
|
@@ -325,13 +349,13 @@ if __name__ == "__main__":
|
|
|
325
349
|
target: label_names,
|
|
326
350
|
"precision": scores[0],
|
|
327
351
|
"recall": scores[1],
|
|
328
|
-
"
|
|
352
|
+
"f1": scores[2],
|
|
329
353
|
"support": scores[3],
|
|
330
354
|
}
|
|
331
355
|
)
|
|
332
356
|
|
|
333
357
|
# We need to get creative with the Classification Metrics
|
|
334
|
-
metrics = ["precision", "recall", "
|
|
358
|
+
metrics = ["precision", "recall", "f1", "support"]
|
|
335
359
|
for t in label_names:
|
|
336
360
|
for m in metrics:
|
|
337
361
|
value = score_df.loc[score_df[target] == t, m].iloc[0]
|
|
@@ -348,11 +372,16 @@ if __name__ == "__main__":
|
|
|
348
372
|
# Calculate various model performance metrics (regression)
|
|
349
373
|
rmse = root_mean_squared_error(y_validate, preds)
|
|
350
374
|
mae = mean_absolute_error(y_validate, preds)
|
|
375
|
+
medae = median_absolute_error(y_validate, preds)
|
|
351
376
|
r2 = r2_score(y_validate, preds)
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
print(f"
|
|
355
|
-
print(f"
|
|
377
|
+
spearman_corr = spearmanr(y_validate, preds).correlation
|
|
378
|
+
support = len(df_val)
|
|
379
|
+
print(f"rmse: {rmse:.3f}")
|
|
380
|
+
print(f"mae: {mae:.3f}")
|
|
381
|
+
print(f"medae: {medae:.3f}")
|
|
382
|
+
print(f"r2: {r2:.3f}")
|
|
383
|
+
print(f"spearmanr: {spearman_corr:.3f}")
|
|
384
|
+
print(f"support: {support}")
|
|
356
385
|
|
|
357
386
|
# Now save the model to the standard place/name
|
|
358
387
|
joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
|
|
@@ -1,16 +1,25 @@
|
|
|
1
|
+
# flake8: noqa: E402
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
import logging
|
|
5
|
+
import importlib
|
|
6
|
+
import webbrowser
|
|
7
|
+
import readline # noqa: F401
|
|
8
|
+
|
|
9
|
+
# Disable OpenMP parallelism to avoid segfaults with PyTorch in IPython.
# This is a known issue on macOS where libomp crashes during thread synchronization.
# Must be set before importing numpy/pandas/torch or any library that uses OpenMP,
# which is why only standard-library modules are imported above this point.
# setdefault() leaves any value the user has already exported untouched.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
|
14
|
+
|
|
1
15
|
import IPython
|
|
2
16
|
from IPython import start_ipython
|
|
3
17
|
from distutils.version import LooseVersion
|
|
4
18
|
from IPython.terminal.prompts import Prompts
|
|
5
19
|
from IPython.terminal.ipapp import load_default_config
|
|
6
20
|
from pygments.token import Token
|
|
7
|
-
import sys
|
|
8
|
-
import logging
|
|
9
|
-
import importlib
|
|
10
21
|
import botocore
|
|
11
|
-
import webbrowser
|
|
12
22
|
import pandas as pd
|
|
13
|
-
import readline # noqa
|
|
14
23
|
|
|
15
24
|
try:
|
|
16
25
|
import matplotlib.pyplot as plt # noqa
|
|
@@ -1 +1 @@
|
|
|
1
|
-
|
|
1
|
+
eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTEyLTA1In1IsmpkuybFALADkRj_RfmkQ0LAIsQeXRE7Uoc3DL1UrDr-rSnwu-PDqsKBUkX6jPRFZV3DLxNjBapxPeEIFhfvxvjzz_sc6CwtxNpZ3bPmxSPs2W-j3xZS4-XyEqIilcwSkWh-NU1u27gCuuivn5eiUmIYJGAp0wdVkeE6_Z9dlg==
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local test harness for SageMaker model scripts.
|
|
3
|
+
|
|
4
|
+
Usage:
    python endpoint_test.py <local_script.py> <model_name>

Example:
    python endpoint_test.py pytorch.py aqsol-pytorch-reg
|
|
9
|
+
|
|
10
|
+
This allows you to test LOCAL changes to a model script against deployed model artifacts.
|
|
11
|
+
Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
|
|
12
|
+
|
|
13
|
+
Optional: testing/env.json with additional environment variables
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import os
import sys
import json
import importlib.util
import tempfile
import shutil

# Force CPU mode BEFORE any PyTorch imports to avoid MPS/CUDA issues on Mac.
# These must be in the environment before torch -- or anything that imports it,
# such as the Workbench PyTorch utilities below -- is loaded, so that no device
# initialization can happen before they are in place.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import pandas as pd  # noqa: E402
import torch  # noqa: E402

# Workbench Imports
from workbench.api import Model, FeatureSet  # noqa: E402
from workbench.utils.pytorch_utils import download_and_extract_model  # noqa: E402

torch.set_default_device("cpu")

# Disable MPS entirely: report it as unavailable so model code that checks
# torch.backends.mps.is_available() takes its non-MPS path.
if hasattr(torch.backends, "mps"):
    torch.backends.mps.is_available = lambda: False
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_eval_data(workbench_model: Model) -> pd.DataFrame:
    """Pull the hold-out rows (training = FALSE) for a model from its FeatureSet.

    Args:
        workbench_model: The Workbench Model whose input FeatureSet supplies the data.

    Returns:
        pd.DataFrame: All rows of the model's training view whose ``training`` flag is FALSE.

    Raises:
        ValueError: If the model's input FeatureSet does not exist.
    """
    feature_set_name = workbench_model.get_input()
    feature_set = FeatureSet(feature_set_name)
    if not feature_set.exists():
        raise ValueError(f"No FeatureSet found: {feature_set_name}")

    # The model's training view carries a boolean `training` column; keep only hold-out rows
    view_table = workbench_model.training_view().table
    print(f"Querying evaluation data from {view_table}...")
    holdout_query = f'SELECT * FROM "{view_table}" WHERE training = FALSE'
    holdout_df = feature_set.query(holdout_query)
    print(f"Retrieved {len(holdout_df)} evaluation rows")

    return holdout_df
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def load_model_script(script_path: str):
    """Dynamically import a model script from a file path.

    The module is registered in ``sys.modules`` under the name ``"model_script"``
    so that imports within the script work. If executing the script fails, that
    registration is rolled back so no half-initialized module is left behind.

    Args:
        script_path: Path to the model script on local disk.

    Returns:
        module: The executed module object.

    Raises:
        FileNotFoundError: If ``script_path`` does not exist.
        ImportError: If no import spec can be built for the file (e.g. the path
            does not have a recognized Python source suffix such as ``.py``).
        Exception: Anything raised while executing the script is propagated.
    """
    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Script not found: {script_path}")

    module_name = "model_script"
    spec = importlib.util.spec_from_file_location(module_name, script_path)
    # spec_from_file_location returns None when no loader matches the file suffix;
    # fail with a clear message instead of an AttributeError from module_from_spec
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for: {script_path} (expected a .py file)")
    module = importlib.util.module_from_spec(spec)

    # Add to sys.modules so imports within the script work
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a partially-executed module registered after a failed import
        sys.modules.pop(module_name, None)
        raise
    return module
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _print_usage() -> None:
    """Print command-line usage for this harness to stdout."""
    # Derive the program name from argv rather than hard-coding a file name, so the
    # message matches however the harness was invoked.
    prog = os.path.basename(sys.argv[0]) if sys.argv and sys.argv[0] else "endpoint_test.py"
    command = f"python {prog}" if prog.endswith(".py") else prog
    print(f"Usage: {command} <local_script.py> <model_name>")
    print("\nArguments:")
    print(" local_script.py - Path to your LOCAL model script to test")
    print(" model_name - Workbench model name (e.g., aqsol-pytorch-reg)")
    print("\nOptional: testing/env.json with additional environment variables")


def _load_env_file(env_path: str = "testing/env.json") -> None:
    """Export the key/value pairs of a JSON object file into ``os.environ``, if the file exists.

    The path is resolved relative to the current working directory.

    Raises:
        ValueError: If the file does not contain a JSON object.
    """
    if not os.path.exists(env_path):
        return
    print(f"Loading environment variables from {env_path}")
    with open(env_path) as f:
        env_vars = json.load(f)
    if not isinstance(env_vars, dict):
        raise ValueError(f"{env_path} must contain a JSON object mapping variable names to values")
    for key, value in env_vars.items():
        # os.environ only accepts strings; JSON numbers/booleans would raise TypeError
        os.environ[key] = str(value)
        print(f" Set {key} = {value}")
    print()


def main():
    """Run a LOCAL model script against a deployed model's artifacts.

    Command line: ``<local_script.py> <model_name>``. Downloads the model's
    artifacts into a temporary directory exported as SM_MODEL_DIR, imports the
    local script, calls its ``model_fn`` and ``predict_fn`` on the FeatureSet's
    hold-out rows, and prints the predictions. The temporary directory is always
    removed on exit. Exits with status 1 on bad arguments or a missing script.
    """
    if len(sys.argv) < 3:
        _print_usage()
        sys.exit(1)

    script_path = sys.argv[1]
    model_name = sys.argv[2]

    # Validate local script exists (fail fast, before any artifact download)
    if not os.path.exists(script_path):
        print(f"Error: Local script not found: {script_path}")
        sys.exit(1)

    # Initialize Workbench model
    print(f"Loading Workbench model: {model_name}")
    workbench_model = Model(model_name)
    print(f"Model Framework: {workbench_model.model_framework}")
    print()

    # Create a temporary model directory (removed in the finally block below)
    model_dir = tempfile.mkdtemp(prefix="model_harness_")
    print(f"Using model directory: {model_dir}")

    try:
        # Load environment variables from env.json if it exists
        _load_env_file("testing/env.json")

        # Set up SageMaker environment variables
        os.environ["SM_MODEL_DIR"] = model_dir
        print(f"Set SM_MODEL_DIR = {model_dir}")

        # Download and extract model artifacts
        s3_uri = workbench_model.model_data_url()
        download_and_extract_model(s3_uri, model_dir)
        print()

        # Load the LOCAL model script
        print(f"Loading LOCAL model script: {script_path}")
        module = load_model_script(script_path)
        print()

        # Check for required functions
        if not hasattr(module, "model_fn"):
            raise AttributeError("Model script must have a model_fn function")
        if not hasattr(module, "predict_fn"):
            raise AttributeError("Model script must have a predict_fn function")

        # Load the model
        print("Calling model_fn...")
        print("-" * 50)
        model = module.model_fn(model_dir)
        print("-" * 50)
        print(f"Model loaded: {type(model)}")
        print()

        # Get evaluation data from FeatureSet
        print("Pulling evaluation data from FeatureSet...")
        df = get_eval_data(workbench_model)
        print(f"Input shape: {df.shape}")
        print(f"Columns: {df.columns.tolist()}")
        print()

        print("Calling predict_fn...")
        print("-" * 50)
        result = module.predict_fn(df, model)
        print("-" * 50)
        print()

        print("Prediction result:")
        print(f"Output shape: {result.shape}")
        print(f"Output columns: {result.columns.tolist()}")
        print()
        print(result.head(10).to_string())

    finally:
        # Cleanup
        print(f"\nCleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)


if __name__ == "__main__":
    main()
|