sagemaker-core 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sagemaker-core might be problematic. Click here for more details.
- sagemaker_core/__init__.py +0 -0
- sagemaker_core/_version.py +11 -0
- sagemaker_core/code_injection/__init__.py +0 -0
- sagemaker_core/code_injection/base.py +42 -0
- sagemaker_core/code_injection/codec.py +241 -0
- sagemaker_core/code_injection/constants.py +18 -0
- sagemaker_core/code_injection/shape_dag.py +14527 -0
- sagemaker_core/generated/__init__.py +0 -0
- sagemaker_core/generated/config_schema.py +870 -0
- sagemaker_core/generated/exceptions.py +147 -0
- sagemaker_core/generated/intelligent_defaults_helper.py +198 -0
- sagemaker_core/generated/resources.py +26998 -0
- sagemaker_core/generated/shapes.py +11584 -0
- sagemaker_core/generated/utils.py +314 -0
- sagemaker_core/tools/__init__.py +1 -0
- sagemaker_core/tools/codegen.py +56 -0
- sagemaker_core/tools/constants.py +96 -0
- sagemaker_core/tools/data_extractor.py +49 -0
- sagemaker_core/tools/method.py +32 -0
- sagemaker_core/tools/resources_codegen.py +2122 -0
- sagemaker_core/tools/resources_extractor.py +373 -0
- sagemaker_core/tools/shapes_codegen.py +284 -0
- sagemaker_core/tools/shapes_extractor.py +259 -0
- sagemaker_core/tools/templates.py +747 -0
- sagemaker_core/util/__init__.py +0 -0
- sagemaker_core/util/util.py +81 -0
- sagemaker_core-0.1.3.dist-info/LICENSE +201 -0
- sagemaker_core-0.1.3.dist-info/METADATA +28 -0
- sagemaker_core-0.1.3.dist-info/RECORD +31 -0
- sagemaker_core-0.1.3.dist-info/WHEEL +5 -0
- sagemaker_core-0.1.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
|
|
14
|
+
import datetime
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
|
|
19
|
+
from boto3.session import Session
|
|
20
|
+
from botocore.config import Config
|
|
21
|
+
from typing import TypeVar, Generic, Type
|
|
22
|
+
from sagemaker_core.code_injection.codec import transform
|
|
23
|
+
|
|
24
|
+
# Configure the root logger at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Type variable for the resource type produced by ResourceIterator.
T = TypeVar("T")

# snake_case names whose PascalCase form is looked up here by snake_to_pascal
# instead of being derived by title-casing each component.
SPECIAL_SNAKE_TO_PASCAL_MAPPINGS = dict(
    volume_size_in_g_b="VolumeSizeInGB",
    volume_size_in_gb="VolumeSizeInGB",
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def configure_logging(log_level=None):
    """Configure the logging configuration based on log level.

    Sets the level of the root logger, removes every handler currently
    attached to it, and installs a single console ``StreamHandler`` with a
    timestamped format.

    Usage:
        Set Environment Variable LOG_LEVEL to DEBUG to see debug logs
        configure_logging()
        configure_logging("DEBUG")

    Args:
        log_level (str): The log level to set.
            Accepted values are: "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL".
            Defaults to the value of the LOG_LEVEL environment variable.
            If argument/environment variable is not set, defaults to "INFO".

    Raises:
        AttributeError: If the log level is invalid.
    """
    if not log_level:
        # Only the environment-provided value is upper-cased; an explicit
        # argument is used exactly as given.
        log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
    _logger = logging.getLogger()
    _logger.setLevel(getattr(logging, log_level))
    # Reset any currently associated handlers. Iterate over a snapshot:
    # removeHandler() mutates _logger.handlers, and removing from the live list
    # while iterating it skips every other handler, leaving stale ones attached.
    for handler in list(_logger.handlers):
        _logger.removeHandler(handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s : %(levelname)s : %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    )
    _logger.addHandler(console_handler)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def is_snake_case(s: str):
    """Report whether ``s`` passes this module's snake_case checks.

    A string is rejected when it is empty, starts with an uppercase
    character, is neither all-lowercase nor purely alphanumeric, starts or
    ends with an underscore, or contains a double underscore.

    NOTE(review): a purely alphanumeric string such as "fooBar" is accepted
    because the case check only fails when *both* ``islower()`` and
    ``isalnum()`` are false -- confirm this is intended.

    Args:
        s (str): The string to check.

    Returns:
        bool: True if every check passes, False otherwise.
    """
    if not s or s[0].isupper():
        return False
    if not (s.islower() or s.isalnum()):
        return False
    has_edge_underscore = s.startswith("_") or s.endswith("_")
    return not has_edge_underscore and "__" not in s
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def snake_to_pascal(snake_str):
    """
    Convert a snake_case string to PascalCase.

    Names listed in SPECIAL_SNAKE_TO_PASCAL_MAPPINGS are returned from that
    table; every other name is split on underscores and each piece is
    title-cased with ``str.title`` before being joined.

    Args:
        snake_str (str): The snake_case string to be converted.

    Returns:
        str: The PascalCase string.

    """
    special_case = SPECIAL_SNAKE_TO_PASCAL_MAPPINGS.get(snake_str)
    if special_case:
        return special_case
    return "".join(map(str.title, snake_str.split("_")))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def pascal_to_snake(pascal_str):
    """
    Converts a PascalCase string to snake_case.

    Args:
        pascal_str (str): The PascalCase string to be converted.

    Returns:
        str: The converted snake_case string.
    """
    # Pass 1: put an underscore before each capitalised word ("JobName" -> "Job_Name").
    word_split = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", pascal_str)
    # Pass 2: split a lowercase letter or digit from a following capital ("InGB" -> "In_GB").
    fully_split = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", word_split)
    return fully_split.lower()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def is_not_primitive(obj):
    """Return True unless ``obj`` is an int, float, str, bool or complex instance."""
    primitive_types = (int, float, str, bool, complex)
    return not isinstance(obj, primitive_types)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def is_not_str_dict(obj):
    """Return True unless ``obj`` is a dict whose keys are all strings.

    An empty dict counts as a string-keyed dict, so it yields False.
    """
    if not isinstance(obj, dict):
        return True
    return any(not isinstance(key, str) for key in obj.keys())
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def is_primitive_list(obj):
    """Return True when no element of the iterable ``obj`` is flagged by is_not_primitive.

    An empty iterable yields True.
    """
    for element in obj:
        if is_not_primitive(element):
            return False
    return True
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def is_primitive_class(cls):
    """Return True when ``cls`` is one of str, int, bool, float or datetime.datetime."""
    primitive_classes = (str, int, bool, float, datetime.datetime)
    return cls in primitive_classes
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class Unassigned:
    """A custom type used to signify an undefined optional argument.

    Every call to ``Unassigned()`` returns the same cached object.
    """

    # The single shared instance; created on first construction.
    _instance = None

    def __new__(cls):
        existing = cls._instance
        if existing is None:
            existing = super().__new__(cls)
            cls._instance = existing
        return existing
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class SingletonMeta(type):
    """
    Singleton metaclass. Ensures that a single instance of a class using this metaclass is created.
    """

    # Maps each class using this metaclass to its one instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """
        Return the instance already stored for ``cls``; if there is none, construct one
        with the given arguments and store it. Arguments passed on later calls are ignored.
        """
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class SageMakerClient(metaclass=SingletonMeta):
    """
    Singleton holder of a boto3 client for the SageMaker service.

    Because the class is built with SingletonMeta, only the first construction
    runs ``__init__``; later calls return that same instance.
    """

    def __init__(
        self,
        session: Session = None,
        region_name: str = None,
        service_name="sagemaker",
        config: Config = None,
    ):
        """
        Resolve the session, region and config, then create the boto3 client.

        Args:
            session (Session): boto3 session to use. A new ``Session()`` is created when None.
            region_name (str): Region for the client. Falls back to the session's region when None.
            service_name (str): Service the client is created for. Defaults to "sagemaker".
            config (Config): botocore config. When None, a config with standard-mode
                retries and ``max_attempts`` of 10 is used.
        """
        if session is None:
            logger.warning("No boto3 session provided. Creating a new session.")
            session = Session()
        self.session = session

        if region_name is None:
            logger.warning("No region provided. Using default region.")
            region_name = self.session.region_name
        self.region_name = region_name

        self.service_name = service_name

        if config is None:
            logger.warning("No config provided. Using default config.")
            config = Config(retries=dict(max_attempts=10, mode="standard"))

        self.client = self.session.client(service_name, region_name, config=config)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class SageMakerRuntimeClient(metaclass=SingletonMeta):
    """
    Singleton holder of a boto3 client for the SageMaker Runtime service.

    Because the class is built with SingletonMeta, only the first construction
    runs ``__init__``; later calls return that same instance.
    """

    def __init__(
        self,
        session: Session = None,
        region_name: str = None,
        service_name="sagemaker-runtime",
        config: Config = None,
    ):
        """
        Resolve the session, region and config, then create the boto3 client.

        Args:
            session (Session): boto3 session to use. A new ``Session()`` is created when None.
            region_name (str): Region for the client. Falls back to the session's region when None.
            service_name (str): Service the client is created for. Defaults to "sagemaker-runtime".
            config (Config): botocore config. When None, a config with standard-mode
                retries and ``max_attempts`` of 10 is used.
        """
        if session is None:
            logger.warning("No boto3 session provided. Creating a new session.")
            session = Session()
        self.session = session

        if region_name is None:
            logger.warning("No region provided. Using default region.")
            region_name = self.session.region_name
        self.region_name = region_name

        self.service_name = service_name

        if config is None:
            logger.warning("No config provided. Using default config.")
            config = Config(retries=dict(max_attempts=10, mode="standard"))

        self.client = self.session.client(service_name, region_name, config=config)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class ResourceIterator(Generic[T]):
    """ResourceIterator class to iterate over a list of resources.

    Pages are fetched lazily: the list method is only called when the current
    page of summaries is exhausted (or before the first item is requested).
    """

    def __init__(
        self,
        client: SageMakerClient,
        summaries_key: str,
        summary_name: str,
        resource_cls: Type[T],
        list_method: str,
        list_method_kwargs: dict = None,
        custom_key_mapping: dict = None,
    ):
        """Initialize a ResourceIterator object

        Args:
            client (SageMakerClient): The sagemaker client object used to make list method calls.
                The list method is looked up on this object by name with ``getattr``.
            summaries_key (str): The summaries key string used to access the list of summaries in the response.
            summary_name (str): The summary name used to transform list response data.
            resource_cls (Type[T]): The resource class to be instantiated for each resource object.
            list_method (str): The list method string used to make list calls to the client.
            list_method_kwargs (dict, optional): The kwargs used to make list method calls.
                Defaults to None, which is treated as an empty dict.
            custom_key_mapping (dict, optional): The custom key mapping used to map keys from summary object to those expected from resource object during initialization. Defaults to None.
        """
        self.summaries_key = summaries_key
        self.summary_name = summary_name
        self.client = client
        self.list_method = list_method
        # Use a fresh dict per instance. A literal ``{}`` default would be one
        # shared object, so a mutation through any iterator's
        # ``list_method_kwargs`` would leak into every other iterator.
        self.list_method_kwargs = {} if list_method_kwargs is None else list_method_kwargs
        self.custom_key_mapping = custom_key_mapping

        self.resource_cls = resource_cls
        self.index = 0
        self.summary_list = []
        self.next_token = None

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def __next__(self) -> T:
        """Return the next resource, fetching another page of summaries when needed.

        Raises:
            StopIteration: When the current page is exhausted and there is no
                next token, or when a fetched page contains no summaries.
        """
        # If there are summaries in the summary_list, return the next summary
        if len(self.summary_list) > 0 and self.index < len(self.summary_list):
            # Get the next summary from the resource summary_list
            summary = self.summary_list[self.index]
            self.index += 1

            # Initialize the resource object
            if is_primitive_class(self.resource_cls):
                # If the resource class is a primitive class, there will be only one element in the summary
                resource_object = list(summary.values())[0]
            else:
                # Transform the resource summary into format to initialize object
                init_data = transform(summary, self.summary_name)

                if self.custom_key_mapping:
                    # Rename keys that have a custom mapping; others pass through unchanged.
                    init_data = {self.custom_key_mapping.get(k, k): v for k, v in init_data.items()}
                resource_object = self.resource_cls(**init_data)

            # If the resource object has refresh method, refresh and return it
            if hasattr(resource_object, "refresh"):
                resource_object.refresh()
            return resource_object

        # If index reached the end of summary_list, and there is no next token, raise StopIteration
        elif (
            len(self.summary_list) > 0
            and self.index >= len(self.summary_list)
            and self.next_token is None
        ):
            raise StopIteration

        # Otherwise, get the next page of summaries by calling the list method with the next token if available
        else:
            if self.next_token is not None:
                response = getattr(self.client, self.list_method)(
                    NextToken=self.next_token, **self.list_method_kwargs
                )
            else:
                response = getattr(self.client, self.list_method)(**self.list_method_kwargs)

            self.summary_list = response.get(self.summaries_key, [])
            self.next_token = response.get("NextToken", None)
            self.index = 0

            # If list_method returned an empty list, raise StopIteration
            if len(self.summary_list) == 0:
                raise StopIteration

            # The new page is non-empty, so this call takes the first branch.
            return self.__next__()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ..code_injection.codec import pascal_to_snake
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Generates the code for the service model."""
|
|
14
|
+
from sagemaker_core.tools.shapes_codegen import ShapesCodeGen
|
|
15
|
+
from sagemaker_core.tools.resources_codegen import ResourcesCodeGen
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
from sagemaker_core.tools.data_extractor import ServiceJsonData, load_service_jsons
|
|
19
|
+
from sagemaker_core.util.util import reformat_file_with_black
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def generate_code(
    shapes_code_gen: Optional[ShapesCodeGen] = None,
    resources_code_gen: Optional[ResourcesCodeGen] = None,
) -> None:
    """
    Generates the code for the given code generators. If any code generator is not
    provided when calling this function, the function will initiate the generator.

    Note ordering is important, generate the utils and lower level classes first
    then generate the higher level classes.

    Args:
        shapes_code_gen (ShapesCodeGen): The code generator for shape classes.
        resources_code_gen (ResourcesCodeGen): The code generator for resource classes.

    Returns:
        None
    """
    # The service JSON is loaded unconditionally, even when a resources
    # generator is supplied by the caller.
    service_json_data: ServiceJsonData = load_service_jsons()

    shapes_code_gen = shapes_code_gen or ShapesCodeGen()
    resources_code_gen = resources_code_gen or ResourcesCodeGen(
        service_json=service_json_data.sagemaker
    )

    # Shapes first, then resources (see the ordering note above).
    shapes_code_gen.generate_shapes()
    resources_code_gen.generate_resources()
    # Format the current directory with black after generation.
    reformat_file_with_black(".")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Script entry point: initializes all the code generator classes with their
# defaults and triggers generation.
if __name__ == "__main__":
    generate_code()
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Constants used in the code_generator modules."""
|
|
14
|
+
import os
|
|
15
|
+
|
|
16
|
+
# Method-name sets used by the code generators.
CLASS_METHODS = {"create", "add", "start", "register", "import", "list", "get"}
OBJECT_METHODS = {
    "refresh",
    "delete",
    "update",
    "stop",
    "deregister",
    "wait",
    "wait_for_status",
}

TERMINAL_STATES = {"Completed", "Stopped", "Deleted", "Failed", "Succeeded", "Cancelled"}

CONFIGURABLE_ATTRIBUTE_SUBSTRINGS = [
    "kms",
    "s3",
    "subnet",
    "tags",
    "role",
    "security_group",
]

# Service-model type names mapped to the Python type names emitted in generated code.
BASIC_JSON_TYPES_TO_PYTHON_TYPES = {
    "string": "str",
    "integer": "int",
    "boolean": "bool",
    "long": "int",
    "float": "float",
    "map": "dict",
    "double": "float",
    "list": "list",
    "timestamp": "datetime.datetime",
    "blob": "Any",
}

BASIC_RETURN_TYPES = {"str", "int", "bool", "float", "datetime.datetime"}

# All paths below are resolved against the working directory at import time.
SHAPE_DAG_FILE_PATH = f"{os.getcwd()}/src/sagemaker_core/code_injection/shape_dag.py"
PYTHON_TYPES_TO_BASIC_JSON_TYPES = {
    "str": "string",
    "int": "integer",
    "bool": "boolean",
    "float": "double",
    "datetime.datetime": "timestamp",
}

LICENCES_STRING = """
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""

BASIC_IMPORTS_STRING = """
import logging
"""

LOGGER_STRING = """
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

"""

# TODO: The file name should be injected, we should update it to be more generic
ADDITIONAL_OPERATION_FILE_PATH = f"{os.getcwd()}/src/sagemaker_core/tools/additional_operations.json"
SERVICE_JSON_FILE_PATH = f"{os.getcwd()}/sample/sagemaker/2017-07-24/service-2.json"
RUNTIME_SERVICE_JSON_FILE_PATH = f"{os.getcwd()}/sample/sagemaker-runtime/2017-05-13/service-2.json"

GENERATED_CLASSES_LOCATION = f"{os.getcwd()}/src/sagemaker_core/generated"
UTILS_CODEGEN_FILE_NAME = "utils.py"
INTELLIGENT_DEFAULTS_HELPER_CODEGEN_FILE_NAME = "intelligent_defaults_helper.py"

RESOURCES_CODEGEN_FILE_NAME = "resources.py"

SHAPES_CODEGEN_FILE_NAME = "shapes.py"

CONFIG_SCHEMA_FILE_NAME = "config_schema.py"
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from functools import lru_cache
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from sagemaker_core.tools.constants import (
|
|
7
|
+
ADDITIONAL_OPERATION_FILE_PATH,
|
|
8
|
+
SERVICE_JSON_FILE_PATH,
|
|
9
|
+
RUNTIME_SERVICE_JSON_FILE_PATH,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Container for the two parsed service JSON documents returned by load_service_jsons.
class ServiceJsonData(BaseModel):
    # Parsed contents of the file at SERVICE_JSON_FILE_PATH.
    sagemaker: dict
    # Parsed contents of the file at RUNTIME_SERVICE_JSON_FILE_PATH.
    sagemaker_runtime: dict
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache(maxsize=1)
def load_service_jsons() -> ServiceJsonData:
    """Load and parse the SageMaker and SageMaker Runtime service JSON files.

    The result is cached, so the files are read only on the first call.

    Returns:
        ServiceJsonData: The parsed contents of both files.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    # Decode explicitly as UTF-8: without ``encoding`` open() uses the locale's
    # preferred encoding, which can mis-decode or reject non-ASCII characters.
    with open(SERVICE_JSON_FILE_PATH, "r", encoding="utf-8") as file:
        service_json = json.load(file)
    with open(RUNTIME_SERVICE_JSON_FILE_PATH, "r", encoding="utf-8") as file:
        runtime_service_json = json.load(file)
    return ServiceJsonData(sagemaker=service_json, sagemaker_runtime=runtime_service_json)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@lru_cache(maxsize=1)
def load_combined_shapes_data() -> dict:
    """Return one dict merging the "shapes" sections of both service JSONs.

    On a name clash the SageMaker Runtime entry wins, because it is applied
    last. The result is cached, so every call returns the same dict object.
    """
    service_data = load_service_jsons()
    combined = dict(service_data.sagemaker["shapes"])
    combined.update(service_data.sagemaker_runtime["shapes"])
    return combined
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@lru_cache(maxsize=1)
def load_combined_operations_data() -> dict:
    """Return one dict merging the "operations" sections of both service JSONs.

    On a name clash the SageMaker Runtime entry wins, because it is applied
    last. The result is cached, so every call returns the same dict object.
    """
    service_data = load_service_jsons()
    combined = dict(service_data.sagemaker["operations"])
    combined.update(service_data.sagemaker_runtime["operations"])
    return combined
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@lru_cache(maxsize=1)
def load_additional_operations_data() -> dict:
    """Load and parse the JSON file at ADDITIONAL_OPERATION_FILE_PATH.

    The result is cached, so the file is read only on the first call.

    Returns:
        dict: The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8: without ``encoding`` open() uses the locale's
    # preferred encoding, which can mis-decode or reject non-ASCII characters.
    with open(ADDITIONAL_OPERATION_FILE_PATH, "r", encoding="utf-8") as file:
        additional_operation_json = json.load(file)
    return additional_operation_json
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
from sagemaker_core.util.util import remove_html_tags
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MethodType(Enum):
    """The kind of method recorded on a ``Method``: class, object or static."""

    CLASS = "class"
    OBJECT = "object"
    STATIC = "static"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Method:
    """
    A class to store the information of methods to be generated
    """

    # Attributes expected on an instance; none is enforced by __init__.
    operation_name: str
    resource_name: str
    method_name: str
    return_type: str
    method_type: MethodType
    service_name: str
    docstring_title: str

    def __init__(self, **kwargs):
        """Store every keyword argument as an instance attribute, without validation."""
        vars(self).update(kwargs)

    def get_docstring_title(self, operation):
        """Set ``docstring_title`` from the operation's documentation.

        The documentation text is passed through ``remove_html_tags``
        (presumably strips HTML markup -- defined in sagemaker_core.util.util),
        then everything before the first "." is kept and a "." is appended.

        Args:
            operation: Mapping with a "documentation" entry.
        """
        cleaned = remove_html_tags(operation["documentation"])
        first_sentence, _, _ = cleaned.partition(".")
        self.docstring_title = f"{first_sentence}."
|
|
31
|
+
|
|
32
|
+
# TODO: add some templates for common methods
|