workbench 0.8.162-py3-none-any.whl → 0.8.220-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
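
A note on the biggest single refactor visible above: the monolithic workbench/utils/chem_utils.py (item 141, 1,556 lines removed) has been split into the workbench/utils/chem_utils/ subpackage (items 89–99). Downstream code that imported the old module will need to target the new submodules. The exact symbol locations are not shown in this diff, but the module paths below are taken directly from the file list:

    # Hypothetical migration sketch -- the submodule names come from the file list
    # above, but which symbols live where is an assumption until you check the
    # 0.8.220 source.
    import workbench.utils.chem_utils.fingerprints as fingerprints
    import workbench.utils.chem_utils.mol_descriptors as mol_descriptors
    import workbench.utils.chem_utils.mol_standardize as mol_standardize

The full contents of the four removed files follow.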
workbench/utils/execution_environment.py (removed)
@@ -1,211 +0,0 @@
- """ExecutionEnvironment provides logic/functionality to figure out the current execution environment"""
-
- import os
- import sys
- import logging
- import requests
- from typing import Union
- import boto3
- from datetime import datetime, timezone
-
- # Workbench imports
- from workbench.utils.glue_utils import get_resolved_options
- from workbench.utils.deprecated_utils import deprecated
-
- # Set up the logger
- log = logging.getLogger("workbench")
-
-
- def running_on_glue():
-     """
-     Check if the current execution environment is an AWS Glue job.
-
-     Returns:
-         bool: True if running in AWS Glue environment, False otherwise.
-     """
-     # Check if GLUE_VERSION or GLUE_PYTHON_VERSION is in the environment
-     if "GLUE_VERSION" in os.environ or "GLUE_PYTHON_VERSION" in os.environ:
-         log.info("Running in AWS Glue Environment...")
-         return True
-     else:
-         return False
-
-
- def running_on_lambda():
-     """
-     Check if the current execution environment is an AWS Lambda function.
-
-     Returns:
-         bool: True if running in AWS Lambda environment, False otherwise.
-     """
-     if "AWS_LAMBDA_FUNCTION_NAME" in os.environ:
-         log.info("Running in AWS Lambda Environment...")
-         return True
-     else:
-         return False
-
-
- def running_on_docker() -> bool:
-     """Check if the current environment is running on a Docker container.
-
-     Returns:
-         bool: True if running in a Docker container, False otherwise.
-     """
-     try:
-         # Docker creates a .dockerenv file at the root of the directory tree inside the container.
-         # If this file exists, it is very likely that we are running inside a Docker container.
-         with open("/.dockerenv") as f:
-             return True
-     except FileNotFoundError:
-         pass
-
-     try:
-         # Another method is to check the contents of /proc/self/cgroup which should be different
-         # inside a Docker container.
-         with open("/proc/self/cgroup") as f:
-             if any("docker" in line for line in f):
-                 return True
-     except FileNotFoundError:
-         pass
-
-     # Check if we are running on ECS
-     if running_on_ecs():
-         return True
-
-     # Probably not running in a Docker container
-     return False
-
-
- def running_on_ecs() -> bool:
-     """
-     Check if the current environment is running on AWS ECS.
-
-     Returns:
-         bool: True if running on AWS ECS, False otherwise.
-     """
-     indicators = [
-         "ECS_SERVICE_NAME",
-         "ECS_CONTAINER_METADATA_URI",
-         "ECS_CONTAINER_METADATA_URI_V4",
-         "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
-         "AWS_EXECUTION_ENV",
-     ]
-     return any(indicator in os.environ for indicator in indicators)
-
-
- def running_as_service() -> bool:
-     """
-     Check if the current environment is running as a service (e.g. Docker, ECS, Glue, Lambda).
-
-     Returns:
-         bool: True if running as a service, False otherwise.
-     """
-     return running_on_docker() or running_on_glue() or running_on_lambda()
-
-
- def _glue_job_from_script_name(args):
-     """Get the Glue Job Name from the script name"""
-     try:
-         script_name = args["scriptLocation"]
-         return os.path.splitext(os.path.basename(script_name))[0]
-     except Exception:
-         return "unknown"
-
-
- def glue_job_name():
-     """Get the Glue Job Name from the environment or script name"""
-     # Define the required argument
-     args = get_resolved_options(sys.argv)
-
-     # Get the job name
-     job_name = args.get("JOB_NAME") or _glue_job_from_script_name(args)
-     return job_name
-
-
- @deprecated(version=0.9)
- def glue_job_run_id(job_name: str, session: boto3.Session) -> Union[str, None]:
-     """Retrieve the Glue Job Run ID closest to the current time for the given job name.
-     Note: This mostly doesn't work; it will grab a Glue job ID, but often not the correct one.
-     For now, it is best to skip using this.
-     """
-     try:
-         # Set current time in UTC
-         current_time = datetime.now(timezone.utc)
-
-         job_runs = session.client("glue").get_job_runs(JobName=job_name)
-         if job_runs["JobRuns"]:
-             # Find the job run with the StartedOn time closest to the current time
-             closest_job_run = min(job_runs["JobRuns"], key=lambda run: abs(run["StartedOn"] - current_time))
-             job_id = closest_job_run["Id"]
-             return job_id[:9]  # Shorten the Job Run ID to 9 characters
-
-         log.error(f"No runs found for Glue Job '{job_name}', returning None for Job Run ID.")
-         return None
-
-     except session.client("glue").exceptions.EntityNotFoundException:
-         log.error(f"Glue Job '{job_name}' not found, returning None for Job Run ID.")
-         return None
-     except Exception as e:
-         log.error(f"An error occurred while retrieving job run ID: {e}")
-         return None
-
-
- def ecs_job_name():
-     """Get the ECS Job Name from the metadata endpoint or environment variables."""
-     # Attempt to get the job name from ECS metadata
-     ecs_metadata_uri = os.environ.get("ECS_CONTAINER_METADATA_URI_V4")
-
-     if ecs_metadata_uri:
-         try:
-             response = requests.get(f"{ecs_metadata_uri}/task")
-             if response.status_code == 200:
-                 metadata = response.json()
-                 job_name = metadata.get("Family")  # 'Family' represents the ECS task definition family name
-                 if job_name:
-                     return job_name
-         except requests.RequestException as e:
-             # Log the error or handle it as needed
-             log.error(f"Failed to fetch ECS metadata: {e}")
-
-     # Fallback to environment variables if metadata is not available
-     job_name = os.environ.get("ECS_SERVICE_NAME", "unknown")
-     return job_name
-
-
- if __name__ == "__main__":
-     """Test the Execution Environment utilities"""
-
-     # Test running_on_glue
-     assert running_on_glue() is False
-     os.environ["GLUE_VERSION"] = "1.0"
-     assert running_on_glue() is True
-     del os.environ["GLUE_VERSION"]
-
-     # Test running_on_lambda
-     assert running_on_lambda() is False
-     os.environ["AWS_LAMBDA_FUNCTION_NAME"] = "my_lambda_function"
-     assert running_on_lambda() is True
-     del os.environ["AWS_LAMBDA_FUNCTION_NAME"]
-
-     # Test running_on_docker
-     assert running_on_docker() is False
-     os.environ["ECS_CONTAINER_METADATA_URI"] = "http://localhost:8080"
-     assert running_on_docker() is True
-     del os.environ["ECS_CONTAINER_METADATA_URI"]
-
-     # Test running_on_ecs
-     assert running_on_ecs() is False
-     os.environ["ECS_CONTAINER_METADATA_URI"] = "http://localhost:8080"
-     assert running_on_ecs() is True
-     del os.environ["ECS_CONTAINER_METADATA_URI"]
-
-     # Test getting the Glue Job Name
-     print(glue_job_name())
-
-     # Test getting the Glue Job Run ID
-     from workbench.core.cloud_platform.aws.aws_session import AWSSession
-
-     session = AWSSession().boto3_session
-     print(glue_job_run_id("Test_Workbench_Shell", session))
-
-     print("All tests passed!")
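
The removed module's checks are all simple environment probes, so projects that imported helpers like running_on_glue() or running_as_service() can carry a local stand-in. A minimal sketch distilled from the deleted code above (these names are illustrative and not part of the 0.8.220 API):

    import os

    def running_on_glue() -> bool:
        # Glue jobs expose GLUE_VERSION / GLUE_PYTHON_VERSION in the environment
        return "GLUE_VERSION" in os.environ or "GLUE_PYTHON_VERSION" in os.environ

    def running_on_lambda() -> bool:
        # Lambda sets AWS_LAMBDA_FUNCTION_NAME for every invocation
        return "AWS_LAMBDA_FUNCTION_NAME" in os.environ

    def running_on_ecs() -> bool:
        # ECS injects metadata/credentials URIs into task containers
        ecs_vars = (
            "ECS_SERVICE_NAME",
            "ECS_CONTAINER_METADATA_URI",
            "ECS_CONTAINER_METADATA_URI_V4",
            "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
        )
        return any(v in os.environ for v in ecs_vars)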
workbench/utils/fast_inference.py (removed)
@@ -1,167 +0,0 @@
- """Fast Inference on SageMaker Endpoints"""
-
- import pandas as pd
- from io import StringIO
- import logging
- from concurrent.futures import ThreadPoolExecutor
-
- # Sagemaker Imports
- import sagemaker
- from sagemaker.serializers import CSVSerializer
- from sagemaker.deserializers import CSVDeserializer
- from sagemaker import Predictor
-
- log = logging.getLogger("workbench")
-
- _CACHED_SM_SESSION = None
-
-
- def get_or_create_sm_session():
-     global _CACHED_SM_SESSION
-     if _CACHED_SM_SESSION is None:
-         _CACHED_SM_SESSION = sagemaker.Session()
-     return _CACHED_SM_SESSION
-
-
- def fast_inference(endpoint_name: str, eval_df: pd.DataFrame, sm_session=None, threads: int = 4) -> pd.DataFrame:
-     """Run inference on the Endpoint using the provided DataFrame
-
-     Args:
-         endpoint_name (str): The name of the Endpoint
-         eval_df (pd.DataFrame): The DataFrame to run predictions on
-         sm_session (sagemaker.session.Session, optional): SageMaker Session. If None, a cached session is created.
-         threads (int): The number of threads to use (default: 4)
-
-     Returns:
-         pd.DataFrame: The DataFrame with predictions
-     """
-     # Use cached session if none is provided
-     if sm_session is None:
-         sm_session = get_or_create_sm_session()
-
-     predictor = Predictor(
-         endpoint_name,
-         sagemaker_session=sm_session,
-         serializer=CSVSerializer(),
-         deserializer=CSVDeserializer(),
-     )
-
-     total_rows = len(eval_df)
-
-     def process_chunk(chunk_df: pd.DataFrame, start_index: int) -> pd.DataFrame:
-         log.info(f"Processing {start_index}:{min(start_index + chunk_size, total_rows)} out of {total_rows} rows...")
-         csv_buffer = StringIO()
-         chunk_df.to_csv(csv_buffer, index=False)
-         response = predictor.predict(csv_buffer.getvalue())
-         # CSVDeserializer returns a nested list: first row is headers
-         return pd.DataFrame.from_records(response[1:], columns=response[0])
-
-     # Sagemaker has a connection pool limit of 10
-     if threads > 10:
-         log.warning("Sagemaker has a connection pool limit of 10. Reducing threads to 10.")
-         threads = 10
-
-     # Compute the chunk size (divide number of threads)
-     chunk_size = max(1, total_rows // threads)
-
-     # We also need to ensure that the chunk size is not too big
-     if chunk_size > 100:
-         chunk_size = 100
-
-     # Split DataFrame into chunks and process them concurrently
-     chunks = [(eval_df[i : i + chunk_size], i) for i in range(0, total_rows, chunk_size)]
-     with ThreadPoolExecutor(max_workers=threads) as executor:
-         df_list = list(executor.map(lambda p: process_chunk(*p), chunks))
-
-     combined_df = pd.concat(df_list, ignore_index=True)
-
-     # Convert the types of the dataframe
-     combined_df = df_type_conversions(combined_df)
-     return combined_df
-
-
- def df_type_conversions(df: pd.DataFrame) -> pd.DataFrame:
-     """Convert the types of the dataframe that we get from an endpoint
-
-     Args:
-         df (pd.DataFrame): DataFrame to convert
-
-     Returns:
-         pd.DataFrame: Converted DataFrame
-     """
-     # Some endpoints will put in "N/A" values (for CSV serialization)
-     # We need to convert these to NaN and then run the conversions below
-     # Report on the number of N/A values in each column in the DataFrame
-     # For any count above 0 list the column name and the number of N/A values
-     na_counts = df.isin(["N/A"]).sum()
-     for column, count in na_counts.items():
-         if count > 0:
-             log.warning(f"{column} has {count} N/A values, converting to NaN")
-     pd.set_option("future.no_silent_downcasting", True)
-     df = df.replace("N/A", float("nan"))
-
-     # Convert data to numeric
-     # Note: Since we're using CSV serializers numeric columns often get changed to generic 'object' types
-
-     # Hard Conversion
-     # Note: We explicitly catch exceptions for columns that cannot be converted to numeric
-     for column in df.columns:
-         try:
-             df[column] = pd.to_numeric(df[column])
-         except ValueError:
-             # If a ValueError is raised, the column cannot be converted to numeric, so we keep it as is
-             pass
-         except TypeError:
-             # This typically means a duplicated column name, so confirm duplicate (more than 1) and log it
-             column_count = (df.columns == column).sum()
-             log.critical(f"{column} occurs {column_count} times in the DataFrame.")
-             pass
-
-     # Soft Conversion
-     # Convert columns to the best possible dtype that supports the pd.NA missing value.
-     df = df.convert_dtypes()
-
-     # Convert pd.NA placeholders to pd.NA
-     # Note: CSV serialization converts pd.NA to blank strings, so we have to put in placeholders
-     df.replace("__NA__", pd.NA, inplace=True)
-
-     # Check for True/False values in the string columns
-     for column in df.select_dtypes(include=["string"]).columns:
-         if df[column].str.lower().isin(["true", "false"]).all():
-             df[column] = df[column].str.lower().map({"true": True, "false": False})
-
-     # Return the Dataframe
-     return df
-
-
- if __name__ == "__main__":
-     """Exercise the Endpoint Utilities"""
-     import time
-     from workbench.api.endpoint import Endpoint
-     from workbench.utils.endpoint_utils import fs_training_data, fs_evaluation_data
-
-     # Create an Endpoint
-     my_endpoint_name = "abalone-regression"
-     my_endpoint = Endpoint(my_endpoint_name)
-     if not my_endpoint.exists():
-         print(f"Endpoint {my_endpoint_name} does not exist.")
-         exit(1)
-
-     # Get the training data
-     my_train_df = fs_training_data(my_endpoint)
-     print(my_train_df)
-
-     # Run Fast Inference and time it
-     my_sm_session = my_endpoint.sm_session
-     my_eval_df = fs_evaluation_data(my_endpoint)
-     start_time = time.time()
-     my_results_df = fast_inference(my_endpoint_name, my_eval_df, my_sm_session)
-     end_time = time.time()
-     print(f"Fast Inference took {end_time - start_time} seconds")
-     print(my_results_df)
-     print(my_results_df.info())
-
-     # Test with no session
-     my_results_df = fast_inference(my_endpoint_name, my_eval_df)
-     print(my_results_df)
-     print(my_results_df.info())
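
fast_inference implemented a small fan-out pattern worth keeping in mind if you replace it: cap the thread count at SageMaker's connection-pool limit of 10, split the DataFrame into chunks of at most 100 rows, and post the chunks concurrently. A minimal sketch of the same pattern, with predict_chunk standing in for whatever per-chunk endpoint call you use (it is a placeholder, not a workbench API):

    from concurrent.futures import ThreadPoolExecutor
    import pandas as pd

    def parallel_predict(eval_df: pd.DataFrame, predict_chunk, threads: int = 4) -> pd.DataFrame:
        threads = min(threads, 10)  # SageMaker's connection pool holds 10 connections
        # Divide rows across threads, but keep each payload small (<= 100 rows)
        chunk_size = min(max(1, len(eval_df) // threads), 100)
        chunks = [eval_df[i : i + chunk_size] for i in range(0, len(eval_df), chunk_size)]
        with ThreadPoolExecutor(max_workers=threads) as pool:
            results = list(pool.map(predict_chunk, chunks))
        return pd.concat(results, ignore_index=True)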
workbench/utils/resource_utils.py (removed)
@@ -1,39 +0,0 @@
- """Resource utilities for Workbench"""
-
- import sys
- import importlib.resources as resources
- import pathlib
- import pkg_resources
-
-
- def get_resource_path(package: str, resource: str) -> pathlib.Path:
-     """Get the path to a resource file, compatible with Python 3.9 and higher.
-
-     Args:
-         package (str): The package where the resource is located.
-         resource (str): The name of the resource file.
-
-     Returns:
-         pathlib.Path: The path to the resource file.
-     """
-     if sys.version_info >= (3, 10):
-         # Python 3.10 and higher: use importlib.resources.path
-         with resources.path(package, resource) as path:
-             return path
-     else:
-         # Python 3.9 and lower: manually construct the path based on package location
-         # Get the location of the installed package
-         package_location = pathlib.Path(pkg_resources.get_distribution(package.split(".")[0]).location)
-         resource_path = package_location / package.replace(".", "/") / resource
-
-         if resource_path.exists():
-             return resource_path
-         else:
-             raise FileNotFoundError(f"Resource '{resource}' not found in package '{package}'.")
-
-
- if __name__ == "__main__":
-     # Test the resource utilities
-     with get_resource_path("workbench.resources", "open_source_api.key") as open_source_key_path:
-         with open(open_source_key_path, "r") as key_file:
-             print(key_file.read().strip())
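
Since Python 3.9 the same lookup can be done with importlib.resources.files, which avoids both the version branch and the deprecated pkg_resources fallback. A sketch of the stdlib replacement (not necessarily what workbench 0.8.220 does internally):

    from importlib.resources import files

    # Resolve and read a data file bundled inside an installed package
    key_text = files("workbench.resources").joinpath("open_source_api.key").read_text()
    print(key_text.strip())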
workbench-0.8.162.dist-info/entry_points.txt (removed)
@@ -1,5 +0,0 @@
- [console_scripts]
- cloud_watch = workbench.scripts.monitor_cloud_watch:main
- glue_launcher = workbench.scripts.glue_launcher:main
- workbench = workbench.repl.workbench_shell:launch_shell
- workbench_config = workbench.scripts.show_config:main
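
Each console_scripts line maps a shell command to a Python callable, so under 0.8.162 the `workbench` command was equivalent to:

    # What the `workbench` console script resolved to in 0.8.162
    from workbench.repl.workbench_shell import launch_shell

    launch_shell()

The 0.8.220 dist-info ships a new entry_points.txt (11 lines, item 125 above), so the installed command set changes along with this removal.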