sagemaker-core 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sagemaker-core might be problematic. Click here for more details.
- sagemaker_core/__init__.py +0 -0
- sagemaker_core/_version.py +11 -0
- sagemaker_core/code_injection/__init__.py +0 -0
- sagemaker_core/code_injection/base.py +42 -0
- sagemaker_core/code_injection/codec.py +241 -0
- sagemaker_core/code_injection/constants.py +18 -0
- sagemaker_core/code_injection/shape_dag.py +14527 -0
- sagemaker_core/generated/__init__.py +0 -0
- sagemaker_core/generated/config_schema.py +870 -0
- sagemaker_core/generated/exceptions.py +147 -0
- sagemaker_core/generated/intelligent_defaults_helper.py +198 -0
- sagemaker_core/generated/resources.py +26998 -0
- sagemaker_core/generated/shapes.py +11584 -0
- sagemaker_core/generated/utils.py +314 -0
- sagemaker_core/tools/__init__.py +1 -0
- sagemaker_core/tools/codegen.py +56 -0
- sagemaker_core/tools/constants.py +96 -0
- sagemaker_core/tools/data_extractor.py +49 -0
- sagemaker_core/tools/method.py +32 -0
- sagemaker_core/tools/resources_codegen.py +2122 -0
- sagemaker_core/tools/resources_extractor.py +373 -0
- sagemaker_core/tools/shapes_codegen.py +284 -0
- sagemaker_core/tools/shapes_extractor.py +259 -0
- sagemaker_core/tools/templates.py +747 -0
- sagemaker_core/util/__init__.py +0 -0
- sagemaker_core/util/util.py +81 -0
- sagemaker_core-0.1.3.dist-info/LICENSE +201 -0
- sagemaker_core-0.1.3.dist-info/METADATA +28 -0
- sagemaker_core-0.1.3.dist-info/RECORD +31 -0
- sagemaker_core-0.1.3.dist-info/WHEEL +5 -0
- sagemaker_core-0.1.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
class SageMakerCoreError(Exception):
    """Base class for all exceptions in SageMaker Core.

    Subclasses override ``fmt`` with a ``str.format`` template; the keyword
    arguments passed to ``__init__`` fill its placeholders to build the message.
    """

    fmt = "An unspecified error occurred."

    def __init__(self, **kwargs):
        """Initialize a SageMakerCoreError exception.

        Args:
            **kwargs: Keyword arguments to be formatted into the custom error message
                template. They are also kept on ``self.kwargs`` so callers can inspect
                the structured details of the error programmatically.

        Raises:
            KeyError: If ``fmt`` references a placeholder that is not in ``kwargs``.
        """
        msg = self.fmt.format(**kwargs)
        Exception.__init__(self, msg)
        # Keep the raw template fields; the formatted string alone loses structure.
        self.kwargs = kwargs
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
### Generic Validation Errors
|
|
17
|
+
class ValidationError(SageMakerCoreError):
    """Raised when user input or setup fails validation."""

    fmt = "An error occurred while validating user input/setup. {message}"

    def __init__(self, message="", **kwargs):
        """Build the exception, appending ``message`` to the validation preamble.

        Args:
            message (str): A message describing the error.
            **kwargs: Additional template fields, passed through to the base class.
        """
        template_fields = dict(kwargs, message=message)
        super().__init__(**template_fields)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
### Waiter Errors
|
|
32
|
+
class WaiterError(SageMakerCoreError):
    """Raised when an error occurs while waiting."""

    fmt = "An error occurred while waiting for {resource_type}. Final Resource State: {status}."

    def __init__(self, resource_type="(Unknown)", status="(Unknown)", **kwargs):
        """Initialize a WaiterError exception.

        Args:
            resource_type (str): The type of resource being waited on.
            status (str): The final status of the resource.
            **kwargs: Additional template fields forwarded to the base class
                (used by subclasses whose ``fmt`` has extra placeholders).
        """
        super().__init__(resource_type=resource_type, status=status, **kwargs)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class FailedStatusError(WaiterError):
    """Raised when a resource enters a failed state."""

    fmt = "Encountered unexpected failed state while waiting for {resource_type}. Final Resource State: {status}. Failure Reason: {reason}"

    def __init__(self, resource_type="(Unknown)", status="(Unknown)", reason="(Unknown)"):
        """Initialize a FailedStatusError exception.

        Args:
            resource_type (str): The type of resource being waited on.
            status (str): The final status of the resource.
            reason (str): The reason the resource entered a failed state.
        """
        super().__init__(resource_type=resource_type, status=status, reason=reason)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class DeleteFailedStatusError(WaiterError):
    """Raised when a resource enters a delete_failed state."""

    fmt = "Encountered unexpected delete_failed state while deleting {resource_type}. Failure Reason: {reason}"

    def __init__(self, resource_type="(Unknown)", reason="(Unknown)"):
        """Initialize a DeleteFailedStatusError exception.

        Args:
            resource_type (str): The type of resource being deleted.
            reason (str): The reason the deletion failed.
        """
        # No status is passed; ``fmt`` does not render one for delete failures.
        super().__init__(resource_type=resource_type, reason=reason)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class TimeoutExceededError(WaiterError):
    """Raised when a specified timeout is exceeded"""

    fmt = "Timeout exceeded while waiting for {resource_type}. Final Resource State: {status}. Increase the timeout and try again."

    def __init__(self, resource_type="(Unknown)", status="(Unknown)", reason="(Unknown)"):
        """Initialize a TimeoutExceededError exception.

        Args:
            resource_type (str): The type of resource being waited on.
            status (str): The last observed status of the resource.
            reason (str): Accepted and forwarded to the base class, but not shown in
                the message (``fmt`` has no ``{reason}`` placeholder).
        """
        super().__init__(resource_type=resource_type, status=status, reason=reason)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
### Intelligent Defaults Errors
|
|
95
|
+
class IntelligentDefaultsError(SageMakerCoreError):
    """Base error for failures while loading Intelligent Defaults configuration."""

    fmt = "An error occurred while loading Intelligent Default. {message}"

    def __init__(self, message="", **kwargs):
        """Create the error with a free-form detail message.

        Args:
            message (str): A message describing the error.
            **kwargs: Extra template fields, forwarded unchanged to the next class in
                the MRO (subclasses use this for placeholders such as file paths).
        """
        kwargs["message"] = message
        super().__init__(**kwargs)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LocalConfigNotFoundError(IntelligentDefaultsError):
    """Raised when a configuration file is not found in local file system"""

    fmt = "Failed to load configuration file from location: {file_path}. {message}"

    def __init__(self, file_path="(Unknown)", message=""):
        """Initialize a LocalConfigNotFoundError exception.

        Args:
            file_path (str): The path to the configuration file.
            message (str): A message describing the error.
        """
        super().__init__(file_path=file_path, message=message)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class S3ConfigNotFoundError(IntelligentDefaultsError):
    """Raised when a configuration file is not found in S3"""

    fmt = "Failed to load configuration file from S3 location: {s3_uri}. {message}"

    def __init__(self, s3_uri="(Unknown)", message=""):
        """Initialize an S3ConfigNotFoundError exception.

        Args:
            s3_uri (str): The S3 URI path to the configuration file.
            message (str): A message describing the error.
        """
        super().__init__(s3_uri=s3_uri, message=message)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class ConfigSchemaValidationError(IntelligentDefaultsError, ValidationError):
    """Raised when a configuration file does not adhere to the schema.

    Inherits from both IntelligentDefaultsError and ValidationError, so it can be
    caught as either.
    """

    fmt = "Failed to validate configuration file from location: {file_path}. {message}"

    def __init__(self, file_path="(Unknown)", message=""):
        """Initialize a ConfigSchemaValidationError exception.

        Args:
            file_path (str): The path to the configuration file.
            message (str): A message describing the error.
        """
        super().__init__(file_path=file_path, message=message)
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import jsonschema
|
|
18
|
+
import boto3
|
|
19
|
+
import yaml
|
|
20
|
+
import pathlib
|
|
21
|
+
|
|
22
|
+
from functools import lru_cache
|
|
23
|
+
from typing import List
|
|
24
|
+
from platformdirs import site_config_dir, user_config_dir
|
|
25
|
+
|
|
26
|
+
from botocore.utils import merge_dicts
|
|
27
|
+
from six.moves.urllib.parse import urlparse
|
|
28
|
+
from sagemaker_core.generated.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
|
|
29
|
+
from sagemaker_core.generated.exceptions import (
|
|
30
|
+
LocalConfigNotFoundError,
|
|
31
|
+
S3ConfigNotFoundError,
|
|
32
|
+
IntelligentDefaultsError,
|
|
33
|
+
ConfigSchemaValidationError,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# NOTE(review): basicConfig() configures the *root* logger at import time, which
# affects every application that imports this module. Libraries usually leave this
# to the application; kept unchanged here to preserve existing logging output.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment variables that, when set, replace the default admin / user config paths.
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"

# Application name handed to platformdirs to locate the config directories.
_APP_NAME = "sagemaker"
# Default config file name; also the name looked up inside a config directory.
_CONFIG_FILE_NAME = "config.yaml"

# Administrator-provided config in the site-wide config directory
# (overridable via SAGEMAKER_ADMIN_CONFIG_OVERRIDE).
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
# User-provided config in the per-user config directory
# (overridable via SAGEMAKER_USER_CONFIG_OVERRIDE).
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
# Default config file location for local mode: ~/.sagemaker/config.yaml.
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH = os.path.join(
    os.path.expanduser("~"), ".sagemaker", _CONFIG_FILE_NAME
)

# Paths starting with this prefix are fetched from S3 instead of the local disk.
S3_PREFIX = "s3://"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def load_default_configs(additional_config_paths: List[str] = None, s3_resource=None):
    """Load, validate and merge SageMaker default configuration files.

    Files are processed in this order, each one merged on top of the previous
    result with ``botocore.utils.merge_dicts``:

    1. The admin config (``SAGEMAKER_ADMIN_CONFIG_OVERRIDE`` or the site config dir).
    2. The user config (``SAGEMAKER_USER_CONFIG_OVERRIDE`` or the user config dir).
    3. Each entry of ``additional_config_paths``, in order.

    Paths starting with ``s3://`` are read from S3; anything else is read from the
    local file system.

    Args:
        additional_config_paths: Extra local paths or S3 URIs to load after the admin
            and user configs. ``None`` entries are skipped.
        s3_resource: Optional boto3 S3 resource used for S3 paths; when omitted, one
            is built from the default boto3 session.

    Returns:
        dict: The merged configuration (empty when no config was found).

    Raises:
        LocalConfigNotFoundError: If a local path other than the default admin/user
            locations cannot be loaded.
        ConfigSchemaValidationError: If a loaded config does not match the schema.
        S3ConfigNotFoundError: If an S3 URI does not resolve to a config file.
    """
    default_config_path = os.getenv(
        ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH
    )
    user_config_path = os.getenv(ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH)

    config_paths = [default_config_path, user_config_path]
    if additional_config_paths:
        config_paths.extend(additional_config_paths)
    config_paths = [path for path in config_paths if path is not None]

    merged_config = {}
    for file_path in config_paths:
        config_from_file = {}
        if file_path.startswith(S3_PREFIX):
            config_from_file = _load_config_from_s3(file_path, s3_resource)
        else:
            try:
                config_from_file = _load_config_from_file(file_path)
            except ValueError as load_error:
                not_found_error = LocalConfigNotFoundError(
                    file_path=file_path, message=str(load_error)
                )
                if file_path not in (
                    _DEFAULT_ADMIN_CONFIG_FILE_PATH,
                    _DEFAULT_USER_CONFIG_FILE_PATH,
                ):
                    # Throw exception only when User provided file path is invalid.
                    # If there are no files in the Default config file locations, don't
                    # throw exceptions.
                    raise not_found_error from load_error

                logger.debug(not_found_error)
        if config_from_file:
            try:
                validate_sagemaker_config(config_from_file)
            except jsonschema.exceptions.ValidationError as schema_error:
                raise ConfigSchemaValidationError(
                    file_path=file_path, message=str(schema_error)
                ) from schema_error
            merge_dicts(merged_config, config_from_file)
            logger.debug("Fetched defaults config from location: %s", file_path)
        else:
            # Missing default files and empty YAML documents both end up here.
            logger.debug("Not applying SDK defaults from location: %s", file_path)

    return merged_config
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def validate_sagemaker_config(sagemaker_config: dict = None):
    """Check a config dictionary against the SageMaker Python SDK config schema.

    Args:
        sagemaker_config: A dictionary containing default values for the
            SageMaker Python SDK. (default: None).

    Raises:
        jsonschema.exceptions.ValidationError: If ``sagemaker_config`` does not
            satisfy ``SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA``.
    """
    schema = SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
    jsonschema.validate(instance=sagemaker_config, schema=schema)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:
    """Fetch a YAML config file from S3 and parse it.

    Args:
        s3_uri: S3 URI of the config file, or of a "directory" prefix that contains
            a ``config.yaml`` object.
        s3_resource_for_config: boto3 S3 resource to use. When falsy, one is built
            from ``boto3.DEFAULT_SESSION`` (or a new ``boto3.Session``).

    Returns:
        The parsed YAML content (``None`` if the object is empty).

    Raises:
        IntelligentDefaultsError: If a default session must be built and it has no region.
    """
    if not s3_resource_for_config:
        # Constructing a default Boto3 S3 Resource from a default Boto3 session.
        boto_session = boto3.DEFAULT_SESSION or boto3.Session()
        boto_region_name = boto_session.region_name
        if boto_region_name is None:
            raise IntelligentDefaultsError(
                message=(
                    "Valid region is not provided in the Boto3 session. "
                    "Setup local AWS configuration with a valid region supported by SageMaker."
                )
            )
        s3_resource_for_config = boto_session.resource("s3", region_name=boto_region_name)

    logger.debug("Fetching defaults config from location: %s", s3_uri)
    inferred_s3_uri = _get_inferred_s3_uri(s3_uri, s3_resource_for_config)
    parsed_url = urlparse(inferred_s3_uri)
    bucket, key = parsed_url.netloc, parsed_url.path.lstrip("/")
    s3_file_content = s3_resource_for_config.Object(bucket, key).get()["Body"].read()
    # Decoded as UTF-8, the encoding YAML files are expected to use.
    return yaml.safe_load(s3_file_content.decode("utf-8"))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _get_inferred_s3_uri(s3_uri, s3_resource_for_config):
    """Resolve an S3 URI to the config object that should be loaded.

    If ``s3_uri`` names an existing object exactly, it is returned unchanged.
    Otherwise it is treated as a directory prefix, and ``<prefix>/config.yaml``
    must exist.

    Args:
        s3_uri: User-provided S3 URI (an object or a directory-like prefix).
        s3_resource_for_config: boto3 S3 resource used to list the bucket.

    Returns:
        str: The S3 URI of the config object to load.

    Raises:
        S3ConfigNotFoundError: If no object matches the prefix, or the prefix has no
            config file inside it.
    """
    parsed_url = urlparse(s3_uri)
    bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/")
    s3_bucket = s3_resource_for_config.Bucket(name=bucket)
    s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
    s3_files_with_same_prefix = {
        f"{S3_PREFIX}{bucket}/{s3_object.key}" for s3_object in s3_objects
    }
    if not s3_files_with_same_prefix:
        # Customer provided us with an incorrect s3 path.
        raise S3ConfigNotFoundError(
            s3_uri=s3_uri,
            message=f"Provide a valid S3 URI in the format s3://<bucket>/<key-prefix>/{_CONFIG_FILE_NAME}.",
        )
    # An exact object match wins, even if other keys share the prefix. URIs ending in
    # "/" are excluded so a zero-byte "folder" marker object is never picked.
    if not s3_uri.endswith("/") and s3_uri in s3_files_with_same_prefix:
        return s3_uri
    # Otherwise the URI points to a directory: search for <key-prefix>/config.yaml.
    inferred_s3_uri = (
        f"{S3_PREFIX}{bucket}/{pathlib.PurePosixPath(key_prefix, _CONFIG_FILE_NAME)}"
    )
    if inferred_s3_uri not in s3_files_with_same_prefix:
        # We don't know which file we should be operating with.
        raise S3ConfigNotFoundError(
            s3_uri=s3_uri,
            message=f"Provide an S3 URI pointing to a directory that contains a {_CONFIG_FILE_NAME} file.",
        )
    return inferred_s3_uri
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _load_config_from_file(file_path: str) -> dict:
    """Read and parse a local YAML config file.

    Args:
        file_path: Path to a config file, or to a directory containing ``config.yaml``.

    Returns:
        The parsed YAML content (``None`` for an empty file).

    Raises:
        ValueError: If the resolved config file path does not exist.
    """
    inferred_file_path = file_path
    if os.path.isdir(file_path):
        inferred_file_path = os.path.join(file_path, _CONFIG_FILE_NAME)
    if not os.path.exists(inferred_file_path):
        raise ValueError(f"No configuration file found at {inferred_file_path}")
    logger.debug("Fetching defaults config from location: %s", inferred_file_path)
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding
    # (the S3 loader decodes UTF-8 as well).
    with open(inferred_file_path, "r", encoding="utf-8") as config_file:
        return yaml.safe_load(config_file)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@lru_cache(maxsize=None)
def load_default_configs_for_resource_name(resource_name: str):
    """Return the configured defaults for a single resource.

    The result is cached per ``resource_name`` for the lifetime of the process, so
    later changes to the config files are not picked up.

    Args:
        resource_name: Key under ``SageMaker.PythonSDK.Resources`` to look up.

    Returns:
        ``{}`` when no configuration was found at all. Otherwise the resource's
        defaults, or ``None`` when the config has no entry for ``resource_name``
        (including when the ``SageMaker``/``PythonSDK``/``Resources`` sections are absent).
    """
    configs_data = load_default_configs()
    if not configs_data:
        logger.debug("No default configurations found for resource: %s", resource_name)
        return {}
    # Walk the nested sections defensively: a schema-valid config need not define them.
    resources = configs_data
    for section_name in ("SageMaker", "PythonSDK", "Resources"):
        resources = resources.get(section_name) or {}
    return resources.get(resource_name)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def get_config_value(attribute, resource_defaults, global_defaults):
    """Look up a configurable attribute, preferring resource-level defaults.

    Args:
        attribute: Name of the attribute to look up.
        resource_defaults: Mapping of resource-specific defaults, or a falsy value.
        global_defaults: Mapping of global defaults, or a falsy value.

    Returns:
        The value from ``resource_defaults`` if the key is present there, else the
        value from ``global_defaults`` if present, else ``None``.
    """
    if resource_defaults and attribute in resource_defaults:
        return resource_defaults[attribute]
    if global_defaults and attribute in global_defaults:
        return global_defaults[attribute]
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info(
        "Configurable value %s not entered in parameters or present in the Config", attribute
    )
    return None
|