sagemaker-core 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sagemaker-core might be problematic. Click here for more details.
- sagemaker_core/__init__.py +0 -0
- sagemaker_core/_version.py +11 -0
- sagemaker_core/code_injection/__init__.py +0 -0
- sagemaker_core/code_injection/base.py +42 -0
- sagemaker_core/code_injection/codec.py +241 -0
- sagemaker_core/code_injection/constants.py +18 -0
- sagemaker_core/code_injection/shape_dag.py +14527 -0
- sagemaker_core/generated/__init__.py +0 -0
- sagemaker_core/generated/config_schema.py +870 -0
- sagemaker_core/generated/exceptions.py +147 -0
- sagemaker_core/generated/intelligent_defaults_helper.py +198 -0
- sagemaker_core/generated/resources.py +26998 -0
- sagemaker_core/generated/shapes.py +11584 -0
- sagemaker_core/generated/utils.py +314 -0
- sagemaker_core/tools/__init__.py +1 -0
- sagemaker_core/tools/codegen.py +56 -0
- sagemaker_core/tools/constants.py +96 -0
- sagemaker_core/tools/data_extractor.py +49 -0
- sagemaker_core/tools/method.py +32 -0
- sagemaker_core/tools/resources_codegen.py +2122 -0
- sagemaker_core/tools/resources_extractor.py +373 -0
- sagemaker_core/tools/shapes_codegen.py +284 -0
- sagemaker_core/tools/shapes_extractor.py +259 -0
- sagemaker_core/tools/templates.py +747 -0
- sagemaker_core/util/__init__.py +0 -0
- sagemaker_core/util/util.py +81 -0
- sagemaker_core-0.1.3.dist-info/LICENSE +201 -0
- sagemaker_core-0.1.3.dist-info/METADATA +28 -0
- sagemaker_core-0.1.3.dist-info/RECORD +31 -0
- sagemaker_core-0.1.3.dist-info/WHEEL +5 -0
- sagemaker_core-0.1.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,747 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Templates for generating resources."""
|
|
14
|
+
|
|
15
|
+
# Skeleton of a generated resource class.
# Placeholders: {class_name}, {data_class_members}, {init_method},
# {class_methods}, {object_methods}.
# NOTE(review): indentation was rebuilt from a flattened listing; the member
# and method placeholders are assumed to arrive already indented one level
# (they sit at column 0 so multi-line values keep their own indentation) —
# confirm against the code generator.
RESOURCE_CLASS_TEMPLATE = """
class {class_name}:
{data_class_members}
{init_method}
{class_methods}
{object_methods}
"""
|
|
22
|
+
|
|
23
|
+
# "Raises:" docstring section describing botocore ClientError, with a short
# example of reading the error message and code from the exception.
# NOTE(review): this name is bound a second time at the end of the file; keep
# both definitions in sync. Indentation was rebuilt from a flattened listing.
RESOURCE_METHOD_EXCEPTION_DOCSTRING = """
Raises:
    botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
        The error message and error code can be parsed from the exception as follows:
        ```
        try:
            # AWS service call here
        except botocore.exceptions.ClientError as e:
            error_message = e.response['Error']['Message']
            error_code = e.response['Error']['Code']
        ```"""
|
|
34
|
+
|
|
35
|
+
# Generated `create` classmethod, wrapped with `populate_inputs_decorator`.
# It builds the request dict, resolves chained attributes, serializes the
# arguments, calls `client.{operation}` and returns `cls.get(...)`.
# Placeholders: {create_args}, {resource_name}, {docstring}, {resource_lower},
# {service_name}, {operation_input_args}, {operation}, {get_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
CREATE_METHOD_TEMPLATE = """
@classmethod
@populate_inputs_decorator
def create(
    cls,
{create_args}
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["{resource_name}"]:
{docstring}
    logger.debug("Creating {resource_lower} resource.")
    client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}')

    operation_input_args = {{
{operation_input_args}
    }}

    operation_input_args = Base.populate_chained_attributes(resource_name='{resource_name}', operation_input_args=operation_input_args)

    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = cls._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # create the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")

    return cls.get({get_args}, session=session, region=region)
"""
|
|
65
|
+
|
|
66
|
+
# Generated `create` classmethod without `populate_inputs_decorator`.
# Otherwise the same flow as the decorated create template: build the request,
# resolve chained attributes, serialize, call `client.{operation}`, then
# return `cls.get(...)`.
# Placeholders: {create_args}, {resource_name}, {docstring}, {resource_lower},
# {service_name}, {operation_input_args}, {operation}, {get_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
CREATE_METHOD_TEMPLATE_WITHOUT_DEFAULTS = """
@classmethod
def create(
    cls,
{create_args}
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["{resource_name}"]:
{docstring}
    logger.debug("Creating {resource_lower} resource.")
    client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}')

    operation_input_args = {{
{operation_input_args}
    }}

    operation_input_args = Base.populate_chained_attributes(resource_name='{resource_name}', operation_input_args=operation_input_args)

    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = cls._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # create the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")

    return cls.get({get_args}, session=session, region=region)
"""
|
|
95
|
+
|
|
96
|
+
# Generated `load` classmethod: builds the request, serializes it, calls
# `client.{operation}` and returns `cls.get(...)`.
# The client is now obtained through `Base.get_sagemaker_client`, as in the
# create/get templates; the Base class template in this file defines that
# helper as `SageMakerClient(session=..., region_name=..., service_name=...).client`,
# so the constructed client is the same.
# Placeholders: {import_args}, {resource_name}, {docstring}, {resource_lower},
# {service_name}, {operation_input_args}, {operation}, {get_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
IMPORT_METHOD_TEMPLATE = """
@classmethod
def load(
    cls,
{import_args}
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["{resource_name}"]:
{docstring}
    logger.debug(f"Importing {resource_lower} resource.")
    client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}')

    operation_input_args = {{
{operation_input_args}
    }}

    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = cls._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # import the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")

    return cls.get({get_args}, session=session, region=region)
"""
|
|
123
|
+
|
|
124
|
+
# Generated `get_name` method. It looks through the instance attributes for
# one called `name` or matching any underscore-suffix of
# "{resource_lower}_name" (e.g. "training_job_name", "job_name", "name") and
# returns the first match; it logs an error and returns None when none exists.
# Placeholder: {resource_lower}.
# The generated code builds the candidate list with a comprehension instead
# of an index loop over a variable named `l`; the candidates are unchanged.
# NOTE(review): indentation was rebuilt from a flattened listing.
GET_NAME_METHOD_TEMPLATE = """
def get_name(self) -> str:
    attributes = vars(self)
    resource_name = '{resource_lower}_name'
    resource_name_split = resource_name.split('_')
    attribute_name_candidates = [
        "_".join(resource_name_split[i:])
        for i in range(len(resource_name_split))
    ]

    for attribute, value in attributes.items():
        if attribute == 'name' or attribute in attribute_name_candidates:
            return value
    logger.error("Name attribute not found for object {resource_lower}")
    return None
"""
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Generated `update` instance method, wrapped with `populate_inputs_decorator`.
# It builds and serializes the request, calls `client.{operation}`, refreshes
# the instance and returns it.
# Placeholders: {update_args}, {resource_name}, {docstring}, {resource_lower},
# {operation_input_args}, {operation}.
# The in-template comment previously read "create the resource" (copied from
# the create template); it now describes the update call.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
UPDATE_METHOD_TEMPLATE = """
@populate_inputs_decorator
def update(
    self,
{update_args}
) -> Optional["{resource_name}"]:
{docstring}
    logger.debug("Updating {resource_lower} resource.")
    client = Base.get_sagemaker_client()

    operation_input_args = {{
{operation_input_args}
    }}
    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = {resource_name}._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # update the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")
    self.refresh()

    return self
"""
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# Generated `update` instance method without `populate_inputs_decorator`.
# It builds and serializes the request, calls `client.{operation}`, refreshes
# the instance and returns it.
# Placeholders: {update_args}, {resource_name}, {docstring}, {resource_lower},
# {operation_input_args}, {operation}.
# The in-template comment previously read "create the resource" (copied from
# the create template); it now describes the update call.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
UPDATE_METHOD_TEMPLATE_WITHOUT_DECORATOR = """
def update(
    self,
{update_args}
) -> Optional["{resource_name}"]:
{docstring}
    logger.debug("Updating {resource_lower} resource.")
    client = Base.get_sagemaker_client()

    operation_input_args = {{
{operation_input_args}
    }}
    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = {resource_name}._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # update the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")
    self.refresh()

    return self
"""
|
|
194
|
+
|
|
195
|
+
# Generated `invoke` instance method: builds and serializes the request,
# calls `client.{operation}` on a SageMakerRuntimeClient and returns the raw
# response.
# Placeholders: {invoke_args}, {docstring}, {resource_lower}, {service_name},
# {operation_input_args}, {resource_name}, {operation}.
# The in-template comment previously read "create the resource" (copied from
# the create template); it now describes the invoke call.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
INVOKE_METHOD_TEMPLATE = """
def invoke(self,
{invoke_args}
) -> Optional[object]:
{docstring}
    logger.debug(f"Invoking {resource_lower} resource.")
    client = SageMakerRuntimeClient(service_name="{service_name}").client
    operation_input_args = {{
{operation_input_args}
    }}
    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = {resource_name}._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # invoke the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")

    return response
"""
|
|
216
|
+
|
|
217
|
+
# Generated `invoke_async` instance method: builds and serializes the
# request, calls `client.{operation}` on a SageMakerRuntimeClient and returns
# the raw response.
# Placeholders: {create_args}, {docstring}, {resource_lower}, {service_name},
# {operation_input_args}, {resource_name}, {operation}.
# The in-template comment previously read "create the resource" (copied from
# the create template); it now describes the invoke call.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
INVOKE_ASYNC_METHOD_TEMPLATE = """
def invoke_async(self,
{create_args}
) -> Optional[object]:
{docstring}
    logger.debug(f"Invoking {resource_lower} resource Async.")
    client = SageMakerRuntimeClient(service_name="{service_name}").client

    operation_input_args = {{
{operation_input_args}
    }}
    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = {resource_name}._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # invoke the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")

    return response
"""
|
|
239
|
+
|
|
240
|
+
# Generated `invoke_with_response_stream` instance method: builds and
# serializes the request, calls `client.{operation}` on a
# SageMakerRuntimeClient and returns the raw response.
# Placeholders: {create_args}, {docstring}, {resource_lower}, {service_name},
# {operation_input_args}, {resource_name}, {operation}.
# The in-template comment previously read "create the resource" (copied from
# the create template); it now describes the invoke call.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
INVOKE_WITH_RESPONSE_STREAM_METHOD_TEMPLATE = """
def invoke_with_response_stream(self,
{create_args}
) -> Optional[object]:
{docstring}
    logger.debug(f"Invoking {resource_lower} resource with Response Stream.")
    client = SageMakerRuntimeClient(service_name="{service_name}").client

    operation_input_args = {{
{operation_input_args}
    }}
    logger.debug(f"Input request: {{operation_input_args}}")
    # serialize the input request
    operation_input_args = {resource_name}._serialize_args(operation_input_args)
    logger.debug(f"Serialized input request: {{operation_input_args}}")

    # invoke the resource
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")

    return response
"""
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# Generated `populate_inputs_decorator`: wraps a function so its keyword
# arguments pass through `Base.get_updated_kwargs_with_configured_attributes`
# together with the resource's config schema before the wrapped call.
# Placeholders: {config_schema_for_resource}, {resource_name}.
# The `\\` renders as a single backslash, so the schema literal on the next
# line continues the assignment.
# NOTE(review): indentation was rebuilt from a flattened listing; the schema
# placeholder sits at column 0 and is assumed to arrive pre-indented —
# confirm against the code generator.
POPULATE_DEFAULTS_DECORATOR_TEMPLATE = """
def populate_inputs_decorator(create_func):
    @functools.wraps(create_func)
    def wrapper(*args, **kwargs):
        config_schema_for_resource = \\
{config_schema_for_resource}
        return create_func(*args, **Base.get_updated_kwargs_with_configured_attributes(config_schema_for_resource, "{resource_name}", **kwargs))
    return wrapper
"""
|
|
273
|
+
|
|
274
|
+
# Generated `get` classmethod: calls `client.{operation}` with the describe
# arguments, transforms the response using the named output shape and builds
# the resource from the transformed fields.
# Placeholders: {describe_args}, {resource_name}, {docstring},
# {operation_input_args}, {service_name}, {operation},
# {describe_operation_output_shape}, {resource_lower}.
# The template used to emit `pprint(response)`, which printed every describe
# response to stdout; it now logs at debug level like the other templates.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
GET_METHOD_TEMPLATE = """
@classmethod
def get(
    cls,
{describe_args}
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["{resource_name}"]:
{docstring}
    operation_input_args = {{
{operation_input_args}
    }}
    client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}')
    response = client.{operation}(**operation_input_args)

    logger.debug(f"Response: {{response}}")

    # deserialize the response
    transformed_response = transform(response, '{describe_operation_output_shape}')
    {resource_lower} = cls(**transformed_response)
    return {resource_lower}
"""
|
|
296
|
+
|
|
297
|
+
# Generated `refresh` instance method: calls `client.{operation}` and passes
# the response, the output shape name and `self` to `transform`, then returns
# `self`.
# Placeholders: {refresh_args}, {resource_name}, {docstring},
# {operation_input_args}, {operation}, {describe_operation_output_shape}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
REFRESH_METHOD_TEMPLATE = """
def refresh(
    self,
{refresh_args}
) -> Optional["{resource_name}"]:
{docstring}
    operation_input_args = {{
{operation_input_args}
    }}
    client = Base.get_sagemaker_client()
    response = client.{operation}(**operation_input_args)

    # deserialize response and update self
    transform(response, '{describe_operation_output_shape}', self)
    return self
"""
|
|
313
|
+
|
|
314
|
+
# Snippet inserted into the wait templates: raises FailedStatusError when the
# lower-cased `current_status` contains "failed".
# Placeholders: {resource_name}, {reason} (an expression, emitted unquoted).
# NOTE(review): indentation was rebuilt from a flattened listing; the snippet
# starts at column 0 and the caller is assumed to indent it to the insertion
# depth — confirm against the code generator.
FAILED_STATUS_ERROR_TEMPLATE = """
if "failed" in current_status.lower():
    raise FailedStatusError(resource_type="{resource_name}", status=current_status, reason={reason})
"""
|
|
318
|
+
|
|
319
|
+
# Generated `wait` instance method: polls `self.refresh()` every `poll`
# seconds until the status is one of the terminal states, or raises
# TimeoutExceededError once `timeout` seconds have elapsed.
# Placeholders: {resource_name}, {terminal_resource_states},
# {status_key_path} (appended directly to `self`), {failed_error_block}.
# Fix: the timeout error was raised with the misspelled keyword
# `resouce_type`; the other error templates in this file pass
# `resource_type`, so this one now does too.
# NOTE(review): confirm TimeoutExceededError's parameter is spelled
# `resource_type`. Indentation was rebuilt from a flattened listing;
# {failed_error_block} sits at column 0 and is assumed to arrive pre-indented.
WAIT_METHOD_TEMPLATE = '''
def wait(
    self,
    poll: int = 5,
    timeout: Optional[int] = None
) -> None:
    """
    Wait for a {resource_name} resource.

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.

    Raises:
        TimeoutExceededError:  If the resource does not reach a terminal state before the timeout.
        FailedStatusError:   If the resource reaches a failed state.
        WaiterError: Raised when an error occurs while waiting.

    """
    terminal_states = {terminal_resource_states}
    start_time = time.time()

    while True:
        self.refresh()
        current_status = self{status_key_path}

        if current_status in terminal_states:
            print(f"\\nFinal Resource Status: {{current_status}}")
{failed_error_block}
            return

        if timeout is not None and time.time() - start_time >= timeout:
            raise TimeoutExceededError(resource_type="{resource_name}", status=current_status)
        print("-", end="")
        time.sleep(poll)
'''
|
|
355
|
+
|
|
356
|
+
# Generated `wait_for_status` instance method: polls `self.refresh()` every
# `poll` seconds until the status equals `status`, or raises
# TimeoutExceededError once `timeout` seconds have elapsed.
# Placeholders: {resource_states} (appended directly to `Literal`),
# {resource_name}, {status_key_path} (appended directly to `self`),
# {failed_error_block}.
# Fix: the timeout error was raised with the misspelled keyword
# `resouce_type`; the other error templates in this file pass
# `resource_type`, so this one now does too.
# NOTE(review): confirm TimeoutExceededError's parameter is spelled
# `resource_type`. Indentation was rebuilt from a flattened listing;
# {failed_error_block} sits at column 0 and is assumed to arrive pre-indented.
WAIT_FOR_STATUS_METHOD_TEMPLATE = '''
def wait_for_status(
    self,
    status: Literal{resource_states},
    poll: int = 5,
    timeout: Optional[int] = None
) -> None:
    """
    Wait for a {resource_name} resource to reach certain status.

    Parameters:
        status: The status to wait for.
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.

    Raises:
        TimeoutExceededError:  If the resource does not reach a terminal state before the timeout.
        FailedStatusError:   If the resource reaches a failed state.
        WaiterError: Raised when an error occurs while waiting.
    """
    start_time = time.time()

    while True:
        self.refresh()
        current_status = self{status_key_path}

        if status == current_status:
            print(f"\\nFinal Resource Status: {{current_status}}")
            return
{failed_error_block}
        if timeout is not None and time.time() - start_time >= timeout:
            raise TimeoutExceededError(resource_type="{resource_name}", status=current_status)
        print("-", end="")
        time.sleep(poll)
'''
|
|
391
|
+
|
|
392
|
+
# Generated `wait_for_delete` instance method: polls `self.refresh()` every
# `poll` seconds. It returns when refresh raises a ClientError whose code
# contains "ResourceNotFound" or "ValidationException", re-raises any other
# ClientError, and raises TimeoutExceededError once `timeout` seconds have
# elapsed.
# Placeholders: {resource_name}, {status_key_path} (appended directly to
# `self`), {delete_failed_error_block}, {deleted_status_check}.
# Fixes: the timeout error was raised with the misspelled keyword
# `resouce_type` (the other error templates pass `resource_type`); and the
# unexpected ClientError is re-raised with a bare `raise`.
# NOTE(review): confirm TimeoutExceededError's parameter is spelled
# `resource_type`. Indentation was rebuilt from a flattened listing; the two
# block placeholders sit at column 0 and are assumed to arrive pre-indented.
WAIT_FOR_DELETE_METHOD_TEMPLATE = '''
def wait_for_delete(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Wait for a {resource_name} resource to be deleted.

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.

    Raises:
        botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
            The error message and error code can be parsed from the exception as follows:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        TimeoutExceededError:  If the resource does not reach a terminal state before the timeout.
        DeleteFailedStatusError:   If the resource reaches a failed state.
        WaiterError: Raised when an error occurs while waiting.
    """
    start_time = time.time()

    while True:
        try:
            self.refresh()
            current_status = self{status_key_path}
{delete_failed_error_block}
{deleted_status_check}

            if timeout is not None and time.time() - start_time >= timeout:
                raise TimeoutExceededError(resource_type="{resource_name}", status=current_status)
        except botocore.exceptions.ClientError as e:
            error_code = e.response["Error"]["Code"]

            if "ResourceNotFound" in error_code or "ValidationException" in error_code:
                print("Resource was not found. It may have been deleted.")
                return
            raise

        print("-", end="")
        time.sleep(poll)
'''
|
|
441
|
+
|
|
442
|
+
# Snippet inserted into the wait-for-delete template: raises
# DeleteFailedStatusError when the lower-cased `current_status` contains
# "delete_failed" or "deletefailed".
# Placeholders: {resource_name}, {reason} (an expression, emitted unquoted).
# NOTE(review): indentation was rebuilt from a flattened listing; the snippet
# starts at column 0 and the caller is assumed to indent it to the insertion
# depth — confirm against the code generator.
DELETE_FAILED_STATUS_CHECK = """
if "delete_failed" in current_status.lower() or "deletefailed" in current_status.lower():
    raise DeleteFailedStatusError(resource_type="{resource_name}", reason={reason})
"""
|
|
446
|
+
|
|
447
|
+
# Snippet inserted into the wait-for-delete template: when `current_status`
# is "deleted" (case-insensitive) it prints a message and returns.
# NOTE(review): indentation was rebuilt from a flattened listing; the snippet
# starts at column 0 and the caller is assumed to indent it to the insertion
# depth — confirm against the code generator.
DELETED_STATUS_CHECK = """
if current_status.lower() == "deleted":
    print("Resource was deleted.")
    return
"""
|
|
452
|
+
|
|
453
|
+
# Generated `delete` instance method: calls `client.{operation}` with the
# request dict, then logs the class name and `self.get_name()` at info level.
# Placeholders: {delete_args}, {docstring}, {operation_input_args},
# {operation}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
DELETE_METHOD_TEMPLATE = """
def delete(
    self,
{delete_args}
) -> None:
{docstring}
    client = Base.get_sagemaker_client()

    operation_input_args = {{
{operation_input_args}
    }}
    client.{operation}(**operation_input_args)

    logger.info(f"Deleting {{self.__class__.__name__}} - {{self.get_name()}}")
"""
|
|
468
|
+
|
|
469
|
+
# Generated `stop` instance method: calls `client.{operation}` with the
# request dict on a default-constructed SageMakerClient.
# Placeholders: {docstring}, {operation_input_args}, {operation}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
STOP_METHOD_TEMPLATE = """
def stop(self) -> None:
{docstring}
    client = SageMakerClient().client

    operation_input_args = {{
{operation_input_args}
    }}
    client.{operation}(**operation_input_args)
"""
|
|
479
|
+
|
|
480
|
+
# Generated `get_all` classmethod that takes filter arguments: builds the
# request dict, drops entries that are None or Unassigned, and returns a
# ResourceIterator built from {resource_iterator_args}.
# Placeholders: {get_all_args}, {resource}, {docstring}, {service_name},
# {operation_input_args}, {custom_key_mapping}, {resource_iterator_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
GET_ALL_METHOD_WITH_ARGS_TEMPLATE = """
@classmethod
def get_all(
    cls,
{get_all_args}
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["{resource}"]:
{docstring}
    client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}")

    operation_input_args = {{
{operation_input_args}
    }}
{custom_key_mapping}
    operation_input_args = {{k: v for k, v in operation_input_args.items() if v is not None and not isinstance(v, Unassigned)}}

    return ResourceIterator(
{resource_iterator_args}
    )
"""
|
|
501
|
+
|
|
502
|
+
# Generated `get_all` classmethod without filter arguments: obtains a client
# and returns a ResourceIterator built from {resource_iterator_args}.
# Placeholders: {resource}, {service_name}, {custom_key_mapping},
# {resource_iterator_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; multi-line
# placeholders sit at column 0 and are assumed to arrive pre-indented —
# confirm against the code generator.
GET_ALL_METHOD_NO_ARGS_TEMPLATE = '''
@classmethod
def get_all(
    cls,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["{resource}"]:
    """
    Get all {resource} resources.

    Parameters:
        session: Boto3 session.
        region: Region name.

    Returns:
        Iterator for listed {resource} resources.

    """
    client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}")
{custom_key_mapping}
    return ResourceIterator(
{resource_iterator_args}
    )
'''
|
|
526
|
+
|
|
527
|
+
# Skeleton for an arbitrary generated method, assembled from a decorator
# line, a signature, a docstring and four body fragments.
# Placeholders: {decorator}, {method_name}, {method_args}, {return_type},
# {docstring}, {serialize_operation_input}, {initialize_client},
# {call_operation_api}, {deserialize_response}.
# NOTE(review): indentation was rebuilt from a flattened listing; every body
# placeholder sits at column 0 and is assumed to arrive pre-indented —
# confirm against the code generator.
GENERIC_METHOD_TEMPLATE = """
{decorator}
def {method_name}(
{method_args}
) -> {return_type}:
{docstring}
{serialize_operation_input}
{initialize_client}
{call_operation_api}
{deserialize_response}
"""
|
|
538
|
+
|
|
539
|
+
# Method-body fragment: builds `operation_input_args` from the given entries
# and logs it at debug level.
# Placeholder: {operation_input_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body) and the dict entries are
# assumed to arrive pre-indented — confirm against the code generator.
SERIALIZE_INPUT_TEMPLATE = """
    operation_input_args = {{
{operation_input_args}
    }}
    logger.debug(f"Input request: {{operation_input_args}}")"""
|
|
544
|
+
|
|
545
|
+
# Method-body fragment: binds `client` from `Base.get_sagemaker_client` using
# the enclosing method's `session` and `region` and the given service name.
# Placeholder: {service_name}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body).
INITIALIZE_CLIENT_TEMPLATE = """
    client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}')"""
|
|
547
|
+
|
|
548
|
+
# Method-body fragment: calls `client.{operation}` with `operation_input_args`
# and logs the call and the response at debug level.
# Placeholder: {operation}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body).
CALL_OPERATION_API_TEMPLATE = """
    logger.debug(f"Calling {operation} API")
    response = client.{operation}(**operation_input_args)
    logger.debug(f"Response: {{response}}")"""
|
|
552
|
+
|
|
553
|
+
# Method-body fragment: calls `client.{operation}()` with no arguments and
# logs the call and the response at debug level.
# Placeholder: {operation}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body).
CALL_OPERATION_API_NO_ARG_TEMPLATE = """
    logger.debug(f"Calling {operation} API")
    response = client.{operation}()
    logger.debug(f"Response: {{response}}")"""
|
|
557
|
+
|
|
558
|
+
# Method-body fragment: transforms `response` using the named output shape
# and returns `{return_type_conversion}(**transformed_response)`.
# Placeholders: {operation_output_shape}, {return_type_conversion}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body).
DESERIALIZE_RESPONSE_TEMPLATE = """
    transformed_response = transform(response, '{operation_output_shape}')
    return {return_type_conversion}(**transformed_response)"""
|
|
561
|
+
|
|
562
|
+
# Method-body fragment: returns the first value of the `response` mapping.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body).
DESERIALIZE_RESPONSE_TO_BASIC_TYPE_TEMPLATE = """
    return list(response.values())[0]"""
|
|
564
|
+
|
|
565
|
+
# Method-body fragment for list-style operations: builds
# `operation_input_args`, drops entries whose value is None or an Unassigned
# instance, and logs the result at debug level.
# Placeholder: {operation_input_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body) and the dict entries are
# assumed to arrive pre-indented — confirm against the code generator.
SERIALIZE_LIST_INPUT_TEMPLATE = """
    operation_input_args = {{
{operation_input_args}
    }}
    operation_input_args = {{k: v for k, v in operation_input_args.items() if v is not None and not isinstance(v, Unassigned)}}
    logger.debug(f"Input request: {{operation_input_args}}")"""
|
|
571
|
+
|
|
572
|
+
# Method-body fragment: returns a ResourceIterator built from the given
# keyword arguments.
# Placeholder: {resource_iterator_args}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body) and the arguments are assumed
# to arrive pre-indented — confirm against the code generator.
RETURN_ITERATOR_TEMPLATE = """
    return ResourceIterator(
{resource_iterator_args}
    )"""
|
|
576
|
+
|
|
577
|
+
# Method-body fragment: transforms `response` using the named output shape
# and returns `cls(...)` built from the request arguments merged with the
# transformed response fields.
# Placeholder: {operation_output_shape}.
# NOTE(review): indentation was rebuilt from a flattened listing; the fragment
# is written at one indent level (method body).
DESERIALIZE_INPUT_AND_RESPONSE_TO_CLS_TEMPLATE = """
    transformed_response = transform(response, '{operation_output_shape}')
    return cls(**operation_input_args, **transformed_response)"""
|
|
580
|
+
|
|
581
|
+
# Source text of the `Base` class that generated resources build on: client
# lookup, request serialization, config-default injection and resolution of
# chained resource attributes.
# The text contains literal `{}` / `{k: ...}` braces, so it has to be emitted
# verbatim rather than passed through `str.format`.
#
# Fixes relative to the previous text:
#   * `_serialize_args`, `_serialize_list` and `_serialize_object` tested the
#     serialized value for truthiness, which dropped legitimate falsy values
#     (False, 0, "") from the request; they now skip only None, which is what
#     `_serialize` returns for Unassigned.
#     NOTE(review): empty lists/dicts are now forwarded as well — confirm the
#     service accepts them.
#   * `get_updated_kwargs_with_configured_attributes` overwrote its
#     `resource_name` parameter inside the loop, so every attribute after the
#     first configured one loaded defaults for the wrong resource; the class
#     name now lives in its own local.
#   * The best-effort `except BaseException` also swallowed KeyboardInterrupt
#     and SystemExit; it is narrowed to `except Exception`.
# NOTE(review): the `type(...) == object` checks in
# `populate_chained_attributes` are left as found — they are only true for a
# bare `object()` instance; confirm whether `isinstance` was intended.
# NOTE(review): indentation was rebuilt from a flattened listing.
RESOURCE_BASE_CLASS_TEMPLATE = """
class Base(BaseModel):
    model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)

    @classmethod
    def get_sagemaker_client(cls, session = None, region_name = None, service_name = 'sagemaker'):
        return SageMakerClient(session=session, region_name=region_name, service_name=service_name).client

    @classmethod
    def _serialize_args(cls, value: dict) -> dict:
        serialized_dict = {}
        for k, v in value.items():
            if (serialize_result := cls._serialize(v)) is not None:
                serialized_dict.update({k: serialize_result})
        return serialized_dict

    @classmethod
    def _serialize(cls, value: any) -> any:
        if isinstance(value, Unassigned):
            return None
        elif isinstance(value, List):
            return cls._serialize_list(value)
        elif is_not_primitive(value) and not isinstance(value, dict):
            return cls._serialize_object(value)
        elif hasattr(value, "serialize"):
            return value.serialize()
        else:
            return value

    @classmethod
    def _serialize_list(cls, value: List):
        serialized_list = []
        for v in value:
            if (serialize_result := cls._serialize(v)) is not None:
                serialized_list.append(serialize_result)
        return serialized_list

    @classmethod
    def _serialize_object(cls, value: any):
        serialized_dict = {}
        for k, v in vars(value).items():
            if (serialize_result := cls._serialize(v)) is not None:
                key = snake_to_pascal(k) if is_snake_case(k) else k
                serialized_dict.update({key[0].upper() + key[1:]: serialize_result})
        return serialized_dict

    @staticmethod
    def get_updated_kwargs_with_configured_attributes(
        config_schema_for_resource: dict, resource_name: str, **kwargs
    ):
        try:
            for configurable_attribute in config_schema_for_resource:
                if kwargs.get(configurable_attribute) is None:
                    resource_defaults = load_default_configs_for_resource_name(
                        resource_name=resource_name
                    )
                    global_defaults = load_default_configs_for_resource_name(
                        resource_name="GlobalDefaults"
                    )
                    if config_value := get_config_value(
                        configurable_attribute, resource_defaults, global_defaults
                    ):
                        attribute_class_name = snake_to_pascal(configurable_attribute)
                        class_object = globals()[attribute_class_name]
                        kwargs[configurable_attribute] = class_object(**config_value)
        except Exception:
            logger.info("Could not load Default Configs. Continuing.", exc_info=True)
            # Continue with existing kwargs if no default configs found
        return kwargs


    @staticmethod
    def populate_chained_attributes(resource_name: str, operation_input_args: Union[dict, object]):
        resource_name_in_snake_case = pascal_to_snake(resource_name)
        updated_args = vars(operation_input_args) if type(operation_input_args) == object else operation_input_args
        unassigned_args = []
        keys = operation_input_args.keys()
        for arg in keys:
            value = operation_input_args.get(arg)
            arg_snake = pascal_to_snake(arg)

            if value == Unassigned() :
                unassigned_args.append(arg)
            elif value == None or not value:
                continue
            elif (
                arg_snake.endswith("name")
                and arg_snake[: -len("_name")] != resource_name_in_snake_case
                and arg_snake != "name"
            ):
                if value and value != Unassigned() and type(value) != str:
                    updated_args[arg] = value.get_name()
            elif isinstance(value, list) and is_primitive_list(value):
                continue
            elif isinstance(value, list) and value != []:
                updated_args[arg] = [
                    Base._get_chained_attribute(list_item)
                    for list_item in value
                ]
            elif is_not_primitive(value) and is_not_str_dict(value) and type(value) == object:
                updated_args[arg] = Base._get_chained_attribute(item_value=value)

        for unassigned_arg in unassigned_args:
            del updated_args[unassigned_arg]
        return updated_args

    @staticmethod
    def _get_chained_attribute(item_value: any):
        resource_name = type(item_value).__name__
        class_object = globals()[resource_name]
        return class_object(**Base.populate_chained_attributes(
            resource_name=resource_name,
            operation_input_args=vars(item_value)
        ))


"""
|
|
698
|
+
|
|
699
|
+
# Base class for generated shapes. Its `serialize()` turns the instance
# `__dict__` into a dict keyed by PascalCase attribute names, skipping
# Unassigned values and serializing nested lists, dicts and objects that
# expose `serialize()`.
# Placeholder: {class_name}.
# The redundant `components[0:]` slice in the generated code is now just
# `components`; the produced keys are unchanged.
# NOTE(review): indentation was rebuilt from a flattened listing.
SHAPE_BASE_CLASS_TEMPLATE = """
class {class_name}:
    model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)

    def serialize(self):
        result = {{}}
        for attr, value in self.__dict__.items():
            if isinstance(value, Unassigned):
                continue

            components = attr.split('_')
            pascal_attr = ''.join(x.title() for x in components)
            if isinstance(value, List):
                result[pascal_attr] = self._serialize_list(value)
            elif isinstance(value, Dict):
                result[pascal_attr] = self._serialize_dict(value)
            elif hasattr(value, 'serialize'):
                result[pascal_attr] = value.serialize()
            else:
                result[pascal_attr] = value
        return result

    def _serialize_list(self, value: List):
        return [v.serialize() if hasattr(v, 'serialize') else v for v in value]

    def _serialize_dict(self, value: Dict):
        return {{k: v.serialize() if hasattr(v, 'serialize') else v for k, v in value.items()}}
"""
|
|
727
|
+
|
|
728
|
+
# Skeleton of a generated shape class: a class statement, a docstring and the
# data-class members.
# Placeholders: {class_name}, {docstring}, {data_class_members}.
# NOTE(review): indentation was rebuilt from a flattened listing;
# {data_class_members} sits at column 0 and is assumed to arrive pre-indented
# — confirm against the code generator.
SHAPE_CLASS_TEMPLATE = '''
class {class_name}:
    """
    {docstring}
    """
{data_class_members}

'''
|
|
736
|
+
|
|
737
|
+
# "Raises:" docstring section describing botocore ClientError, with a short
# example of reading the error message and code from the exception.
# NOTE(review): this rebinds a name already defined near the top of the file;
# the repeat is kept so the module-level binding does not change, but the
# two definitions should be merged. Indentation was rebuilt from a flattened
# listing.
RESOURCE_METHOD_EXCEPTION_DOCSTRING = """
Raises:
    botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
        The error message and error code can be parsed from the exception as follows:
        ```
        try:
            # AWS service call here
        except botocore.exceptions.ClientError as e:
            error_message = e.response['Error']['Message']
            error_code = e.response['Error']['Code']
        ```"""
|