sagemaker-core 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sagemaker-core might be problematic; see the registry's advisory notice for this release for more details.

@@ -0,0 +1,314 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+
14
+ import datetime
15
+ import logging
16
+ import os
17
+ import re
18
+
19
+ from boto3.session import Session
20
+ from botocore.config import Config
21
+ from typing import TypeVar, Generic, Type
22
+ from sagemaker_core.code_injection.codec import transform
23
+
24
# Importing this module configures the root logger at INFO (logging.basicConfig is a
# no-op when the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

T = TypeVar("T")

# snake_case keys whose PascalCase form cannot be produced by title-casing each
# "_"-separated word (plain title-casing of "volume_size_in_gb" gives "VolumeSizeInGb").
SPECIAL_SNAKE_TO_PASCAL_MAPPINGS = dict.fromkeys(
    ("volume_size_in_g_b", "volume_size_in_gb"),
    "VolumeSizeInGB",
)
33
+
34
+
35
def configure_logging(log_level=None):
    """Configure the logging configuration based on log level.

    Sets the root logger's level and replaces all of its handlers with a single
    console ``StreamHandler`` that uses a timestamped format.

    Usage:
        Set Environment Variable LOG_LEVEL to DEBUG to see debug logs
        configure_logging()
        configure_logging("DEBUG")

    Args:
        log_level (str): The log level to set (case-insensitive).
            Accepted values are: "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL".
            Defaults to the value of the LOG_LEVEL environment variable.
            If argument/environment variable is not set, defaults to "INFO".

    Raises:
        AttributeError: If the log level is not a name defined on the ``logging`` module.
    """
    if not log_level:
        log_level = os.environ.get("LOG_LEVEL", "INFO")
    # Upper-case both sources so an explicit "debug" resolves to logging.DEBUG,
    # matching the case-insensitive handling the environment variable already had.
    log_level = log_level.upper()

    _logger = logging.getLogger()
    _logger.setLevel(getattr(logging, log_level))
    # Reset any currently associated handlers. Iterate over a snapshot: removing
    # entries from the live ``handlers`` list while looping over it skips every
    # other handler.
    for handler in list(_logger.handlers):
        _logger.removeHandler(handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s : %(levelname)s : %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    )
    _logger.addHandler(console_handler)
65
+
66
+
67
def is_snake_case(s: str):
    """Return True if *s* is a snake_case name.

    A snake_case name is one or more runs of lowercase ASCII letters/digits joined by
    single underscores. It has no uppercase letters, no leading, trailing or doubled
    underscores, and no other characters (hyphens, spaces, ...).

    Args:
        s (str): The string to check. Empty strings and None return False.

    Returns:
        bool: Whether *s* is snake_case.
    """
    if not s:
        return False
    # The previous `not s.islower() and not s.isalnum()` test let mixed-case
    # ("aBc") and punctuated ("a-b", "a b") strings through.
    return re.fullmatch(r"[a-z0-9]+(?:_[a-z0-9]+)*", s) is not None
79
+
80
+
81
def snake_to_pascal(snake_str):
    """
    Convert a snake_case string to PascalCase.

    Strings listed in SPECIAL_SNAKE_TO_PASCAL_MAPPINGS are answered from that table.
    Any other string is split on "_" and each piece is passed through ``str.title``.
    Note that ``str.title`` also upper-cases a letter that follows a digit.

    Args:
        snake_str (str): The snake_case string to be converted.

    Returns:
        str: The PascalCase string.
    """
    special_case = SPECIAL_SNAKE_TO_PASCAL_MAPPINGS.get(snake_str)
    if special_case:
        return special_case
    return "".join(map(str.title, snake_str.split("_")))
96
+
97
+
98
def pascal_to_snake(pascal_str):
    """
    Converts a PascalCase string to snake_case.

    The conversion runs in two regex passes and then lowercases the result:
      1. insert "_" before each capitalised word ("Size" in "VolumeSize");
      2. insert "_" between a lowercase letter/digit and a following capital.

    Args:
        pascal_str (str): The PascalCase string to be converted.

    Returns:
        str: The converted snake_case string.
    """
    words_split = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", pascal_str)
    fully_split = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", words_split)
    return fully_split.lower()
110
+
111
+
112
def is_not_primitive(obj):
    """Return False for int, float, str, bool or complex instances, True otherwise."""
    if isinstance(obj, (bool, complex, float, int, str)):
        return False
    return True
114
+
115
+
116
def is_not_str_dict(obj):
    """Return True unless *obj* is a dict whose keys are all strings (an empty dict qualifies)."""
    if not isinstance(obj, dict):
        return True
    return any(not isinstance(key, str) for key in obj)
118
+
119
+
120
def is_primitive_list(obj):
    """Return True when every element of *obj* is an int, float, str, bool or complex.

    An empty iterable counts as a primitive list.
    """
    for element in obj:
        if not isinstance(element, (int, float, str, bool, complex)):
            return False
    return True
122
+
123
+
124
def is_primitive_class(cls):
    """Return True if *cls* is one of str, int, bool, float or datetime.datetime."""
    primitive_classes = (str, int, bool, float, datetime.datetime)
    found = cls in primitive_classes
    return found
126
+
127
+
128
class Unassigned:
    """A custom type used to signify an undefined optional argument.

    ``Unassigned()`` always returns the same cached object after the first call.
    """

    _instance = None

    def __new__(cls):
        existing = cls._instance
        if existing is None:
            existing = super().__new__(cls)
            cls._instance = existing
        return existing
137
+
138
+
139
class SingletonMeta(type):
    """
    Singleton metaclass. Ensures that a single instance of a class using this metaclass is created.

    Instances are cached per class in ``_instances``. Only the first construction's
    arguments are used; later calls return the cached instance without running
    ``__init__`` again.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        """
        Return the cached instance of ``cls``, creating and caching it on first use.
        """
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        created = super().__call__(*args, **kwargs)
        registry[cls] = created
        return created
155
+
156
+
157
class SageMakerClient(metaclass=SingletonMeta):
    """
    A singleton class for creating a SageMaker client.

    Because the class uses SingletonMeta, only the arguments of the first construction
    take effect; later ``SageMakerClient(...)`` calls return that same instance.
    """

    def __init__(
        self,
        session: Session = None,
        region_name: str = None,
        service_name="sagemaker",
        config: Config = None,
    ):
        """
        Initializes the SageMakerClient with a boto3 session, region name, and service name.
        Creates a boto3 client using the provided session, region, and service.

        Args:
            session: boto3 Session to build the client from; a new ``Session()`` if omitted.
            region_name: Region for the client; falls back to ``session.region_name``.
            service_name: boto3 service name, "sagemaker" by default.
            config: botocore Config; if omitted, standard-mode retries with 10 max attempts.
        """
        if session is None:
            logger.warning("No boto3 session provided. Creating a new session.")
            session = Session()
        self.session = session

        if region_name is None:
            logger.warning("No region provided. Using default region.")
            region_name = self.session.region_name
        self.region_name = region_name
        self.service_name = service_name

        if config is None:
            logger.warning("No config provided. Using default config.")
            config = Config(retries={"max_attempts": 10, "mode": "standard"})

        self.client = self.session.client(self.service_name, self.region_name, config=config)
189
+
190
+
191
class SageMakerRuntimeClient(metaclass=SingletonMeta):
    """
    A singleton class for creating a SageMaker Runtime client.

    Because the class uses SingletonMeta, only the arguments of the first construction
    take effect; later ``SageMakerRuntimeClient(...)`` calls return that same instance.
    """

    def __init__(
        self,
        session: Session = None,
        region_name: str = None,
        service_name="sagemaker-runtime",
        config: Config = None,
    ):
        """
        Initializes the SageMakerRuntimeClient with a boto3 session, region name, and service name.
        Creates a boto3 client using the provided session, region, and service.

        Args:
            session: boto3 Session to build the client from; a new ``Session()`` if omitted.
            region_name: Region for the client; falls back to ``session.region_name``.
            service_name: boto3 service name, "sagemaker-runtime" by default.
            config: botocore Config; if omitted, standard-mode retries with 10 max attempts.
        """
        if session is None:
            logger.warning("No boto3 session provided. Creating a new session.")
            session = Session()

        if region_name is None:
            logger.warning("No region provided. Using default region.")
            region_name = session.region_name

        if config is None:
            logger.warning("No config provided. Using default config.")
            config = Config(retries={"max_attempts": 10, "mode": "standard"})

        self.session = session
        self.region_name = region_name
        self.service_name = service_name
        self.client = session.client(service_name, region_name, config=config)
223
+
224
+
225
class ResourceIterator(Generic[T]):
    """ResourceIterator class to iterate over a list of resources.

    Calls ``getattr(client, list_method)(**list_method_kwargs)`` page by page and follows
    ``NextToken`` until a response carries none. It yields one object per summary found
    under ``summaries_key`` in each response.
    """

    def __init__(
        self,
        client: SageMakerClient,
        summaries_key: str,
        summary_name: str,
        resource_cls: Type[T],
        list_method: str,
        list_method_kwargs: dict = None,
        custom_key_mapping: dict = None,
    ):
        """Initialize a ResourceIterator object

        Args:
            client (SageMakerClient): The object used to make list method calls. It is called
                as ``getattr(client, list_method)(...)``, so it must expose ``list_method``
                directly (presumably a boto3 SageMaker client such as ``SageMakerClient().client``).
            summaries_key (str): The summaries key string used to access the list of summaries in the response.
            summary_name (str): The summary name used to transform list response data.
            resource_cls (Type[T]): The resource class to be instantiated for each resource object.
            list_method (str): The list method string used to make list calls to the client.
            list_method_kwargs (dict, optional): The kwargs used to make list method calls.
                Defaults to None, meaning no extra kwargs.
            custom_key_mapping (dict, optional): The custom key mapping used to map keys from summary object to those expected from resource object during initialization. Defaults to None.
        """
        self.summaries_key = summaries_key
        self.summary_name = summary_name
        self.client = client
        self.list_method = list_method
        # A fresh dict per iterator rather than one mutable default shared by all calls.
        self.list_method_kwargs = list_method_kwargs if list_method_kwargs is not None else {}
        self.custom_key_mapping = custom_key_mapping

        self.resource_cls = resource_cls
        self.index = 0
        self.summary_list = []
        self.next_token = None
        # Distinguishes "not started yet" from "last page consumed"; both have next_token None.
        self._has_fetched_page = False

    def __iter__(self):
        return self

    def __next__(self) -> T:
        # Fetch pages until there is an unconsumed summary or the listing is exhausted.
        # A page can be empty yet still carry a NextToken, so an empty page alone must
        # not end the iteration. Once exhausted, further calls keep raising StopIteration
        # instead of restarting from the first page.
        while self.index >= len(self.summary_list):
            if self._has_fetched_page and self.next_token is None:
                raise StopIteration
            self._fetch_next_page()

        summary = self.summary_list[self.index]
        self.index += 1
        return self._build_resource(summary)

    def _fetch_next_page(self) -> None:
        """Call the list method once and store the returned summaries and continuation token."""
        list_fn = getattr(self.client, self.list_method)
        if self.next_token is not None:
            response = list_fn(NextToken=self.next_token, **self.list_method_kwargs)
        else:
            response = list_fn(**self.list_method_kwargs)

        self.summary_list = response.get(self.summaries_key, [])
        self.next_token = response.get("NextToken", None)
        self.index = 0
        self._has_fetched_page = True

    def _build_resource(self, summary: dict) -> T:
        """Turn one summary from a list response into the object yielded to the caller."""
        if is_primitive_class(self.resource_cls):
            # If the resource class is a primitive class, there will be only one element in the summary
            resource_object = list(summary.values())[0]
        else:
            # Transform the resource summary into format to initialize object
            init_data = transform(summary, self.summary_name)

            if self.custom_key_mapping:
                init_data = {self.custom_key_mapping.get(k, k): v for k, v in init_data.items()}
            resource_object = self.resource_cls(**init_data)

        # If the resource object has refresh method, refresh and return it
        # (presumably a describe call per resource — confirm against the generated resources).
        if hasattr(resource_object, "refresh"):
            resource_object.refresh()
        return resource_object
@@ -0,0 +1 @@
1
+ from ..code_injection.codec import pascal_to_snake
@@ -0,0 +1,56 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Generates the code for the service model."""
14
+ from sagemaker_core.tools.shapes_codegen import ShapesCodeGen
15
+ from sagemaker_core.tools.resources_codegen import ResourcesCodeGen
16
+ from typing import Optional
17
+
18
+ from sagemaker_core.tools.data_extractor import ServiceJsonData, load_service_jsons
19
+ from sagemaker_core.util.util import reformat_file_with_black
20
+
21
+
22
def generate_code(
    shapes_code_gen: Optional[ShapesCodeGen] = None,
    resources_code_gen: Optional[ResourcesCodeGen] = None,
) -> None:
    """
    Generates the code for the given code generators. If a code generator is not
    provided when calling this function, a default one is created.

    Note ordering is important, generate the utils and lower level classes first
    then generate the higher level classes.

    After generation, ``reformat_file_with_black(".")`` is called on the output.

    Args:
        shapes_code_gen (ShapesCodeGen): The code generator for shape classes.
        resources_code_gen (ResourcesCodeGen): The code generator for resource classes.

    Returns:
        None
    """
    if shapes_code_gen is None:
        shapes_code_gen = ShapesCodeGen()
    if resources_code_gen is None:
        # The service JSON is only needed to build the default resources generator.
        service_json_data: ServiceJsonData = load_service_jsons()
        resources_code_gen = ResourcesCodeGen(service_json=service_json_data.sagemaker)

    shapes_code_gen.generate_shapes()
    resources_code_gen.generate_resources()
    reformat_file_with_black(".")
50
+
51
+
52
# Script entry point: build the default code generators and run generation.
if __name__ == "__main__":
    generate_code()
@@ -0,0 +1,96 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Constants used in the code_generator modules."""
14
+ import os
15
+
16
# Method-name groups used by the code generator modules.
CLASS_METHODS = {"create", "add", "start", "register", "import", "list", "get"}
OBJECT_METHODS = {
    "refresh",
    "delete",
    "update",
    "stop",
    "deregister",
    "wait",
    "wait_for_status",
}

# Resource status values treated as terminal.
TERMINAL_STATES = {"Completed", "Stopped", "Deleted", "Failed", "Succeeded", "Cancelled"}
22
+
23
CONFIGURABLE_ATTRIBUTE_SUBSTRINGS = ["kms", "s3", "subnet", "tags", "role", "security_group"]

# Service-JSON shape type -> Python type annotation (as source text).
BASIC_JSON_TYPES_TO_PYTHON_TYPES = {
    "string": "str",
    "integer": "int",
    "boolean": "bool",
    "long": "int",
    "float": "float",
    "map": "dict",
    "double": "float",
    "list": "list",
    "timestamp": "datetime.datetime",
    "blob": "Any",
}

BASIC_RETURN_TYPES = {"str", "int", "bool", "float", "datetime.datetime"}

# Resolved against the current working directory at import time.
SHAPE_DAG_FILE_PATH = f"{os.getcwd()}/src/sagemaker_core/code_injection/shape_dag.py"

# Python type annotation (as source text) -> service-JSON shape type.
PYTHON_TYPES_TO_BASIC_JSON_TYPES = {
    "str": "string",
    "int": "integer",
    "bool": "boolean",
    "float": "double",
    "datetime.datetime": "timestamp",
}
55
+
56
# Source-text snippets written into generated files.
LICENCES_STRING = """
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""

BASIC_IMPORTS_STRING = "\nimport logging\n"

LOGGER_STRING = (
    "\n"
    "logging.basicConfig(level=logging.INFO)\n"
    "logger = logging.getLogger(__name__)\n"
    "\n"
)
80
+
81
# TODO: The file name should be injected, we should update it to be more generic
# All locations below are resolved against the current working directory at import time.
ADDITIONAL_OPERATION_FILE_PATH = f"{os.getcwd()}/src/sagemaker_core/tools/additional_operations.json"
SERVICE_JSON_FILE_PATH = f"{os.getcwd()}/sample/sagemaker/2017-07-24/service-2.json"
RUNTIME_SERVICE_JSON_FILE_PATH = (
    f"{os.getcwd()}/sample/sagemaker-runtime/2017-05-13/service-2.json"
)
GENERATED_CLASSES_LOCATION = f"{os.getcwd()}/src/sagemaker_core/generated"

# Output file names for the generated modules.
UTILS_CODEGEN_FILE_NAME = "utils.py"
INTELLIGENT_DEFAULTS_HELPER_CODEGEN_FILE_NAME = "intelligent_defaults_helper.py"
RESOURCES_CODEGEN_FILE_NAME = "resources.py"
SHAPES_CODEGEN_FILE_NAME = "shapes.py"
CONFIG_SCHEMA_FILE_NAME = "config_schema.py"
@@ -0,0 +1,49 @@
1
+ import json
2
+ from functools import lru_cache
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from sagemaker_core.tools.constants import (
7
+ ADDITIONAL_OPERATION_FILE_PATH,
8
+ SERVICE_JSON_FILE_PATH,
9
+ RUNTIME_SERVICE_JSON_FILE_PATH,
10
+ )
11
+
12
+
13
class ServiceJsonData(BaseModel):
    """Parsed service-definition JSON documents for SageMaker and SageMaker Runtime."""

    # Parsed contents of the file at SERVICE_JSON_FILE_PATH.
    sagemaker: dict
    # Parsed contents of the file at RUNTIME_SERVICE_JSON_FILE_PATH.
    sagemaker_runtime: dict
16
+
17
+
18
@lru_cache(maxsize=1)
def load_service_jsons() -> ServiceJsonData:
    """Load and parse the SageMaker and SageMaker Runtime service JSON files.

    The result is cached, so the files are read once and every caller shares the same
    ServiceJsonData object.

    Returns:
        ServiceJsonData: ``sagemaker`` holds SERVICE_JSON_FILE_PATH's contents and
        ``sagemaker_runtime`` holds RUNTIME_SERVICE_JSON_FILE_PATH's contents.
    """
    # JSON text is UTF-8; without an explicit encoding open() uses the platform locale
    # (e.g. cp1252 on Windows), which can fail on non-ASCII documentation strings.
    with open(SERVICE_JSON_FILE_PATH, "r", encoding="utf-8") as file:
        service_json = json.load(file)
    with open(RUNTIME_SERVICE_JSON_FILE_PATH, "r", encoding="utf-8") as file:
        runtime_service_json = json.load(file)
    return ServiceJsonData(sagemaker=service_json, sagemaker_runtime=runtime_service_json)
25
+
26
+
27
@lru_cache(maxsize=1)
def load_combined_shapes_data() -> dict:
    """Return the SageMaker and SageMaker Runtime ``shapes`` maps merged into one dict.

    When both define a shape with the same name, the SageMaker Runtime entry wins because
    it is merged second. The result is cached, so every caller receives the same dict.
    """
    service_json_data = load_service_jsons()
    combined_shapes = dict(service_json_data.sagemaker["shapes"])
    combined_shapes.update(service_json_data.sagemaker_runtime["shapes"])
    return combined_shapes
34
+
35
+
36
@lru_cache(maxsize=1)
def load_combined_operations_data() -> dict:
    """Return the SageMaker and SageMaker Runtime ``operations`` maps merged into one dict.

    When both define an operation with the same name, the SageMaker Runtime entry wins
    because it is merged second. The result is cached, so every caller receives the same dict.
    """
    service_json_data = load_service_jsons()
    combined_operations = dict(service_json_data.sagemaker["operations"])
    combined_operations.update(service_json_data.sagemaker_runtime["operations"])
    return combined_operations
43
+
44
+
45
@lru_cache(maxsize=1)
def load_additional_operations_data() -> dict:
    """Load and parse the JSON file at ADDITIONAL_OPERATION_FILE_PATH.

    The result is cached, so the file is read once and every caller shares the same dict.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's locale encoding.
    with open(ADDITIONAL_OPERATION_FILE_PATH, "r", encoding="utf-8") as file:
        additional_operation_json = json.load(file)
    return additional_operation_json
@@ -0,0 +1,32 @@
1
+ from enum import Enum
2
+
3
+ from sagemaker_core.util.util import remove_html_tags
4
+
5
+
6
class MethodType(Enum):
    """Kind of a generated method: class-level, object (instance)-level, or static."""

    CLASS = "class"
    OBJECT = "object"
    STATIC = "static"
10
+
11
+
12
class Method:
    """
    A class to store the information of methods to be generated

    Instance attributes come from the keyword arguments passed to ``__init__``; the
    annotations below list the expected names but nothing enforces them.
    """

    operation_name: str
    resource_name: str
    method_name: str
    return_type: str
    method_type: MethodType
    service_name: str
    docstring_title: str

    def __init__(self, **kwargs):
        # Every keyword becomes an instance attribute; no validation is performed.
        self.__dict__.update(kwargs)

    def get_docstring_title(self, operation):
        """Set ``docstring_title`` from the first sentence of an operation's documentation.

        Args:
            operation (dict): An operation entry whose optional ``"documentation"`` value
                is expected to be an HTML string; tags are stripped with ``remove_html_tags``.
                A missing entry is treated as empty documentation.

        Returns:
            str: The stored ``docstring_title``, i.e. the text before the first "." with
            a "." appended.
        """
        title = remove_html_tags(operation.get("documentation", ""))
        self.docstring_title = title.partition(".")[0] + "."
        return self.docstring_title

    # TODO: add some templates for common methods