sagemaker-core 0.1.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sagemaker-core might be problematic. Click here for more details.
- sagemaker_core/__init__.py +4 -0
- sagemaker_core/helper/session_helper.py +769 -0
- sagemaker_core/{code_injection → main/code_injection}/codec.py +2 -2
- sagemaker_core/{code_injection → main/code_injection}/constants.py +10 -0
- sagemaker_core/{code_injection → main/code_injection}/shape_dag.py +48 -0
- sagemaker_core/{generated → main}/config_schema.py +47 -0
- sagemaker_core/{generated → main}/intelligent_defaults_helper.py +8 -8
- sagemaker_core/{generated → main}/resources.py +2716 -1284
- sagemaker_core/main/shapes.py +11650 -0
- sagemaker_core/main/user_agent.py +77 -0
- sagemaker_core/{generated → main}/utils.py +246 -10
- sagemaker_core/resources/__init__.py +1 -0
- sagemaker_core/shapes/__init__.py +1 -0
- sagemaker_core/tools/__init__.py +1 -1
- sagemaker_core/tools/codegen.py +1 -2
- sagemaker_core/tools/constants.py +3 -8
- sagemaker_core/tools/method.py +1 -1
- sagemaker_core/tools/resources_codegen.py +30 -28
- sagemaker_core/tools/resources_extractor.py +4 -8
- sagemaker_core/tools/shapes_codegen.py +16 -10
- sagemaker_core/tools/shapes_extractor.py +1 -1
- sagemaker_core/tools/templates.py +109 -122
- sagemaker_core-1.0.1.dist-info/METADATA +81 -0
- sagemaker_core-1.0.1.dist-info/RECORD +34 -0
- {sagemaker_core-0.1.3.dist-info → sagemaker_core-1.0.1.dist-info}/WHEEL +1 -1
- sagemaker_core/generated/shapes.py +0 -11584
- sagemaker_core/util/util.py +0 -81
- sagemaker_core-0.1.3.dist-info/METADATA +0 -28
- sagemaker_core-0.1.3.dist-info/RECORD +0 -31
- /sagemaker_core/{code_injection → helper}/__init__.py +0 -0
- /sagemaker_core/{generated → main}/__init__.py +0 -0
- /sagemaker_core/{util → main/code_injection}/__init__.py +0 -0
- /sagemaker_core/{code_injection → main/code_injection}/base.py +0 -0
- /sagemaker_core/{generated → main}/exceptions.py +0 -0
- {sagemaker_core-0.1.3.dist-info → sagemaker_core-1.0.1.dist-info}/LICENSE +0 -0
- {sagemaker_core-0.1.3.dist-info → sagemaker_core-1.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
from __future__ import absolute_import
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
import importlib_metadata
|
|
19
|
+
|
|
20
|
+
SagemakerCore_PREFIX = "AWS-SageMakerCore"
STUDIO_PREFIX = "AWS-SageMaker-Studio"
NOTEBOOK_PREFIX = "AWS-SageMaker-Notebook-Instance"

NOTEBOOK_METADATA_FILE = "/etc/opt/ml/sagemaker-notebook-instance-version.txt"
STUDIO_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"


def _get_sagemaker_core_version() -> str:
    """Return the installed sagemaker-core distribution version.

    Falls back to ``"unknown"`` when package metadata is unavailable, e.g. when
    running from a source checkout that was never installed.

    Returns:
        str: The distribution version string, or ``"unknown"``.
    """
    from importlib import metadata

    try:
        return metadata.version("sagemaker-core")
    except metadata.PackageNotFoundError:
        return "unknown"


# Read the version from package metadata rather than a hard-coded literal, so the
# User-Agent always reports the installed release instead of a value that can
# drift out of date between releases.
SagemakerCore_VERSION = _get_sagemaker_core_version()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def process_notebook_metadata_file() -> str:
    """Check if the platform is SageMaker Notebook, if yes, return the InstanceType

    Reads ``NOTEBOOK_METADATA_FILE`` and returns its contents with surrounding
    whitespace stripped. A file that cannot be read is treated the same as a
    missing file, so building the User-Agent string never breaks client
    construction.

    Returns:
        str: The InstanceType of the SageMaker Notebook if it exists, otherwise None
    """
    try:
        # Open directly (EAFP) instead of checking os.path.exists first, which
        # left a window for the file to disappear between check and open.
        with open(NOTEBOOK_METADATA_FILE, "r") as sagemaker_nbi_file:
            return sagemaker_nbi_file.read().strip()
    except (OSError, UnicodeDecodeError):
        # Missing file (not a Notebook Instance), permission problems, a
        # directory at that path, or undecodable content: no notebook metadata.
        return None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def process_studio_metadata_file() -> str:
    """Check if the platform is SageMaker Studio, if yes, return the AppType

    Loads ``STUDIO_METADATA_FILE`` as JSON and returns its ``"AppType"`` entry.
    A missing, unreadable, or malformed file (including JSON that is not an
    object) is treated as "not Studio", so building the User-Agent string never
    breaks client construction.

    Returns:
        str: The AppType of the SageMaker Studio if it exists, otherwise None
    """
    try:
        with open(STUDIO_METADATA_FILE, "r") as sagemaker_studio_file:
            metadata = json.load(sagemaker_studio_file)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None

    if not isinstance(metadata, dict):
        return None
    return metadata.get("AppType")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_user_agent_extra_suffix() -> str:
    """Build the SageMakerCore-specific ``user_agent_extra`` suffix.

    Adheres to boto's recommended User-Agent 2.0 header format. The suffix always
    starts with a ``lib/<prefix>#<version>`` token. A space-separated
    ``md/<prefix>#<value>`` metadata token is then appended for each detected
    hosting platform: first SageMaker Notebook Instance, then SageMaker Studio.

    Returns:
        str: The user agent extra suffix string to be appended
    """
    tokens = [f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION}"]

    # Append the notebook instance type, when running on a Notebook Instance.
    instance_type = process_notebook_metadata_file()
    if instance_type:
        tokens.append(f"md/{NOTEBOOK_PREFIX}#{instance_type}")

    # Append the Studio app type, when running inside SageMaker Studio.
    app_type = process_studio_metadata_file()
    if app_type:
        tokens.append(f"md/{STUDIO_PREFIX}#{app_type}")

    return " ".join(tokens)
|
|
@@ -15,14 +15,174 @@ import datetime
|
|
|
15
15
|
import logging
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
|
+
import subprocess
|
|
18
19
|
|
|
19
20
|
from boto3.session import Session
|
|
20
21
|
from botocore.config import Config
|
|
21
|
-
from
|
|
22
|
-
from
|
|
22
|
+
from rich import reconfigure
|
|
23
|
+
from rich.console import Console
|
|
24
|
+
from rich.logging import RichHandler
|
|
25
|
+
from rich.style import Style
|
|
26
|
+
from rich.theme import Theme
|
|
27
|
+
from rich.traceback import install
|
|
28
|
+
from typing import Any, Dict, List, TypeVar, Generic, Type
|
|
29
|
+
from sagemaker_core.main.code_injection.codec import transform
|
|
30
|
+
from sagemaker_core.main.code_injection.constants import Color
|
|
31
|
+
from sagemaker_core.main.user_agent import get_user_agent_extra_suffix
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def add_indent(text, num_spaces=4):
    """
    Prefix every line of a text with a fixed number of spaces.

    Parameters:
        text (str): The text to indent; lines are split on ``"\\n"``.
        num_spaces (int): Number of spaces prepended to each line. Default is 4.
    Returns:
        str: The indented text. Trailing spaces at the very end are removed, so
        a trailing newline does not leave a dangling whitespace-only line.
    """
    prefix = " " * num_spaces
    indented_lines = [f"{prefix}{line}" for line in text.split("\n")]
    return "\n".join(indented_lines).rstrip(" ")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def clean_documentaion(documentation):
    """
    Simplify service-model HTML documentation for use in generated docstrings.

    Opening/closing ``<p>`` tags are removed and opening/closing ``<code>`` tags
    are replaced with single quotes.

    Args:
        documentation (str): The raw documentation string.
    Returns:
        str: The cleaned documentation string.
    """
    without_paragraphs = re.sub(r"<\/?p>", "", documentation)
    return re.sub(r"<\/?code>", "'", without_paragraphs)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def convert_to_snake_case(entity_name):
    """
    Convert a CamelCase/PascalCase string to snake_case.

    Args:
        entity_name (str): The string to convert.
    Returns:
        str: The lower-cased, underscore-separated form of ``entity_name``.
    """
    # Pass 1: insert "_" before each capitalised word that follows any character.
    with_word_breaks = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", entity_name)
    # Pass 2: insert "_" between a lowercase letter/digit and a following capital.
    with_all_breaks = re.sub("([a-z0-9])([A-Z])", r"\1_\2", with_word_breaks)
    return with_all_breaks.lower()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def snake_to_pascal(snake_str):
    """
    Convert a snake_case string to PascalCase.

    Each underscore-separated component is title-cased and the pieces are
    concatenated with no separator.

    Args:
        snake_str (str): The snake_case string to be converted.
    Returns:
        str: The PascalCase string.
    """
    return "".join(map(str.title, snake_str.split("_")))
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def reformat_file_with_black(filename):
    """
    Best-effort reformat of a file or directory with ``black`` (line length 100).

    Failures are reported on stdout rather than raised, including the case where
    the ``black`` executable is not installed or not on ``PATH``.

    Args:
        filename (str): Path to the file or directory to reformat.
    """
    try:
        # Run black with specific options using subprocess
        subprocess.run(["black", "-l", "100", filename], check=True)
        print(f"File '{filename}' reformatted successfully.")
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the "black" executable could not be found. Report it
        # like a formatting failure instead of aborting the caller.
        print(f"An error occurred while reformatting '{filename}': {e}")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def remove_html_tags(text):
    """
    Strip anything that looks like an HTML tag (``<...>``) from a string.

    Matching is non-greedy and does not cross newlines.

    Args:
        text (str): The text to clean.
    Returns:
        str: ``text`` with every ``<...>`` span removed.
    """
    tag_pattern = re.compile("<.*?>")
    return tag_pattern.sub("", text)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def escape_special_rst_characters(text):
    """
    Backslash-escape reStructuredText special characters that follow whitespace.

    Only ``*`` and ``|`` are handled, and only when immediately preceded by a
    whitespace character; occurrences at the start of the text or after a
    non-space character are left alone.

    Args:
        text (str): The text to escape.
    Returns:
        str: The escaped text.
    """
    escaped = text
    for special in ["*", "|"]:
        escaped = re.sub(rf"(?<=\s){re.escape(special)}", rf"\\{special}", escaped)
    return escaped
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_textual_rich_theme() -> Theme:
    """
    Build the rich ``Theme`` used for SageMakerCore console output.

    Returns:
        Theme: A theme mapping rich style names (log levels, repr and JSON
        highlighting, traceback parts) to SageMakerCore colors.
    """
    blue = Color.BLUE.value
    green = Color.GREEN.value
    yellow = Color.YELLOW.value
    red = Color.RED.value
    purple = Color.PURPLE.value
    bright_red = Color.BRIGHT_RED.value

    styles = {
        # Log level badges and highlighted keywords
        "logging.level.info": Style(color=blue, bold=True),
        "logging.level.debug": Style(color=green, bold=True),
        "logging.level.warning": Style(color=yellow, bold=True),
        "logging.level.error": Style(color=red, bold=True),
        "logging.keyword": Style(color=yellow, bold=True),
        # Python repr highlighting
        "repr.attrib_name": Style(color=yellow, italic=False),
        "repr.attrib_value": Style(color=purple, italic=False),
        "repr.bool_true": Style(color=green, italic=True),
        "repr.bool_false": Style(color=red, italic=True),
        "repr.call": Style(color=purple, bold=True),
        "repr.none": Style(color=purple, italic=True),
        "repr.str": Style(color=green),
        "repr.path": Style(color=purple),
        "repr.filename": Style(color=purple),
        "repr.url": Style(color=blue, underline=True),
        "repr.tag_name": Style(color=purple, bold=True),
        # Null styles: no special styling for network/hardware addresses
        "repr.ipv4": Style.null(),
        "repr.ipv6": Style.null(),
        "repr.eui48": Style.null(),
        "repr.eui64": Style.null(),
        # JSON highlighting
        "json.bool_true": Style(color=green, italic=True),
        "json.bool_false": Style(color=red, italic=True),
        "json.null": Style(color=purple, italic=True),
        "json.str": Style(color=green),
        "json.key": Style(color=blue, bold=True),
        # Traceback rendering
        "traceback.error": Style(color=bright_red, italic=True),
        "traceback.border": Style(color=bright_red),
        "traceback.title": Style(color=bright_red, bold=True),
    }
    return Theme(styles)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# Module-level guard so the theme and traceback hook are installed at most once.
textual_rich_console_and_traceback_enabled = False


def enable_textual_rich_console_and_traceback():
    """
    Apply the SageMakerCore theme to the global rich console and install rich
    error tracebacks using a console with the same theme.

    Calls after the first successful one return immediately without doing anything.
    """
    global textual_rich_console_and_traceback_enabled
    if textual_rich_console_and_traceback_enabled:
        return

    rich_theme = get_textual_rich_theme()
    reconfigure(theme=rich_theme)
    install(console=Console(theme=rich_theme))
    textual_rich_console_and_traceback_enabled = True
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def get_textual_rich_logger(name: str, log_level: str = "INFO") -> logging.Logger:
    """
    Get a logger with textual rich handler.

    Enables the rich console theme and tracebacks, then configures the root
    logger with a ``RichHandler`` via ``logging.basicConfig`` and returns the
    named logger.

    Args:
        name (str): The name of the logger
        log_level (str): The log level to set, case-insensitive.
            Accepted values are: "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL".
            Defaults to the value of "INFO".

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so the handler and ``log_level`` only take effect on the first
        call that configures the root logger.

    Raises:
        ValueError: If ``log_level`` is not a recognised logging level name.

    Return:
        logging.Logger: A textual rich logger.

    """
    # Resolve the level name case-insensitively and validate it up front.
    # Previously getattr(logging, log_level) turned "debug" into the
    # logging.debug *function* (a TypeError inside basicConfig), and unknown
    # names surfaced as an opaque AttributeError.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(
            f"Invalid log level: {log_level!r}. Accepted values are: "
            '"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL".'
        )

    enable_textual_rich_console_and_traceback()
    handler = RichHandler(markup=True)
    logging.basicConfig(level=level, handlers=[handler])
    logger = logging.getLogger(name)
    return logger
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
logger = get_textual_rich_logger(__name__)
|
|
26
186
|
|
|
27
187
|
T = TypeVar("T")
|
|
28
188
|
|
|
@@ -57,10 +217,7 @@ def configure_logging(log_level=None):
|
|
|
57
217
|
# reset any currently associated handlers with log level
|
|
58
218
|
for handler in _logger.handlers:
|
|
59
219
|
_logger.removeHandler(handler)
|
|
60
|
-
console_handler =
|
|
61
|
-
console_handler.setFormatter(
|
|
62
|
-
logging.Formatter("%(asctime)s : %(levelname)s : %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
|
63
|
-
)
|
|
220
|
+
console_handler = RichHandler(markup=True)
|
|
64
221
|
_logger.addHandler(console_handler)
|
|
65
222
|
|
|
66
223
|
|
|
@@ -182,10 +339,11 @@ class SageMakerClient(metaclass=SingletonMeta):
|
|
|
182
339
|
logger.warning("No config provided. Using default config.")
|
|
183
340
|
config = Config(retries={"max_attempts": 10, "mode": "standard"})
|
|
184
341
|
|
|
342
|
+
self.config = Config(user_agent_extra=get_user_agent_extra_suffix())
|
|
185
343
|
self.session = session
|
|
186
344
|
self.region_name = region_name
|
|
187
345
|
self.service_name = service_name
|
|
188
|
-
self.client = session.client(service_name, region_name, config=config)
|
|
346
|
+
self.client = session.client(service_name, region_name, config=self.config)
|
|
189
347
|
|
|
190
348
|
|
|
191
349
|
class SageMakerRuntimeClient(metaclass=SingletonMeta):
|
|
@@ -216,10 +374,11 @@ class SageMakerRuntimeClient(metaclass=SingletonMeta):
|
|
|
216
374
|
logger.warning("No config provided. Using default config.")
|
|
217
375
|
config = Config(retries={"max_attempts": 10, "mode": "standard"})
|
|
218
376
|
|
|
377
|
+
self.config = Config(user_agent_extra=get_user_agent_extra_suffix())
|
|
219
378
|
self.session = session
|
|
220
379
|
self.region_name = region_name
|
|
221
380
|
self.service_name = service_name
|
|
222
|
-
self.client = session.client(service_name, region_name, config=config)
|
|
381
|
+
self.client = session.client(service_name, region_name, config=self.config)
|
|
223
382
|
|
|
224
383
|
|
|
225
384
|
class ResourceIterator(Generic[T]):
|
|
@@ -312,3 +471,80 @@ class ResourceIterator(Generic[T]):
|
|
|
312
471
|
raise StopIteration
|
|
313
472
|
|
|
314
473
|
return self.__next__()
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def serialize(value: Any) -> Any:
    """
    Serialize an object recursively by converting all objects to JSON-serializable types

    ``Unassigned`` placeholders become ``None``. Dicts, lists and shape objects are
    walked recursively, and any other value is returned unchanged.

    Args:
        value (Any): The object to be serialized

    Returns:
        Any: The serialized object
    """
    if isinstance(value, Unassigned):
        return None
    elif isinstance(value, Dict):
        # if the value is a dict, use _serialize_dict() to serialize it recursively
        return _serialize_dict(value)
    elif isinstance(value, List):
        # if the value is a list, use _serialize_list() to serialize it recursively
        return _serialize_list(value)
    elif is_not_primitive(value):
        # if the value is a shape object, use _serialize_shape() to serialize it recursively
        return _serialize_shape(value)
    else:
        return value


def _serialize_dict(value: Dict) -> dict:
    """
    Serialize all values in a dict recursively

    Entries whose value serializes to ``None`` (e.g. ``Unassigned``) are omitted.
    Falsy but meaningful values such as ``False``, ``0`` and ``""`` are kept.

    Args:
        value (dict): The dict to be serialized

    Returns:
        dict: The serialized dict
    """
    serialized_dict = {}
    for k, v in value.items():
        # Compare with None explicitly: a truthiness test silently dropped
        # legitimate values such as False or 0.
        if (serialize_result := serialize(v)) is not None:
            serialized_dict.update({k: serialize_result})
    return serialized_dict


def _serialize_list(value: List) -> list:
    """
    Serialize all objects in a list

    Elements that serialize to ``None`` are omitted. Falsy but meaningful values
    such as ``False`` and ``0`` are kept.

    Args:
        value (list): The list to be serialized

    Returns:
        list: The serialized list
    """
    serialized_list = []
    for v in value:
        # Compare with None explicitly so e.g. [0, 1] does not become [1].
        if (serialize_result := serialize(v)) is not None:
            serialized_list.append(serialize_result)
    return serialized_list


def _serialize_shape(value: Any) -> dict:
    """
    Serialize a shape object defined in resource.py or shape.py to a dict

    Attribute names are converted to PascalCase keys when they are snake_case,
    and the first character is upper-cased. Attributes that serialize to
    ``None`` are omitted; falsy but meaningful values are kept.

    Args:
        value (Any): The shape to be serialized

    Returns:
        dict: The dict of serialized shape
    """
    serialized_dict = {}
    for k, v in vars(value).items():
        # Compare with None explicitly so boolean False / numeric 0 fields survive.
        if (serialize_result := serialize(v)) is not None:
            key = snake_to_pascal(k) if is_snake_case(k) else k
            serialized_dict.update({key[0].upper() + key[1:]: serialize_result})
    return serialized_dict
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ..main.resources import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ..main.shapes import *
|
sagemaker_core/tools/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from ..code_injection.codec import pascal_to_snake
|
|
1
|
+
from ..main.code_injection.codec import pascal_to_snake
|
sagemaker_core/tools/codegen.py
CHANGED
|
@@ -11,12 +11,12 @@
|
|
|
11
11
|
# ANY KIND, either express or implied. See the License for the specific
|
|
12
12
|
# language governing permissions and limitations under the License.
|
|
13
13
|
"""Generates the code for the service model."""
|
|
14
|
+
from sagemaker_core.main.utils import reformat_file_with_black
|
|
14
15
|
from sagemaker_core.tools.shapes_codegen import ShapesCodeGen
|
|
15
16
|
from sagemaker_core.tools.resources_codegen import ResourcesCodeGen
|
|
16
17
|
from typing import Optional
|
|
17
18
|
|
|
18
19
|
from sagemaker_core.tools.data_extractor import ServiceJsonData, load_service_jsons
|
|
19
|
-
from sagemaker_core.util.util import reformat_file_with_black
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
def generate_code(
|
|
@@ -45,7 +45,6 @@ def generate_code(
|
|
|
45
45
|
)
|
|
46
46
|
|
|
47
47
|
shapes_code_gen.generate_shapes()
|
|
48
|
-
resources_code_gen.generate_resources()
|
|
49
48
|
reformat_file_with_black(".")
|
|
50
49
|
|
|
51
50
|
|
|
@@ -44,7 +44,7 @@ BASIC_JSON_TYPES_TO_PYTHON_TYPES = {
|
|
|
44
44
|
|
|
45
45
|
BASIC_RETURN_TYPES = {"str", "int", "bool", "float", "datetime.datetime"}
|
|
46
46
|
|
|
47
|
-
SHAPE_DAG_FILE_PATH = os.getcwd() + "/src/sagemaker_core/code_injection/shape_dag.py"
|
|
47
|
+
SHAPE_DAG_FILE_PATH = os.getcwd() + "/src/sagemaker_core/main/code_injection/shape_dag.py"
|
|
48
48
|
PYTHON_TYPES_TO_BASIC_JSON_TYPES = {
|
|
49
49
|
"str": "string",
|
|
50
50
|
"int": "integer",
|
|
@@ -68,13 +68,8 @@ LICENCES_STRING = """
|
|
|
68
68
|
# language governing permissions and limitations under the License.
|
|
69
69
|
"""
|
|
70
70
|
|
|
71
|
-
BASIC_IMPORTS_STRING = """
|
|
72
|
-
import logging
|
|
73
|
-
"""
|
|
74
|
-
|
|
75
71
|
LOGGER_STRING = """
|
|
76
|
-
|
|
77
|
-
logger = logging.getLogger(__name__)
|
|
72
|
+
logger = get_textual_rich_logger(__name__)
|
|
78
73
|
|
|
79
74
|
"""
|
|
80
75
|
|
|
@@ -85,7 +80,7 @@ ADDITIONAL_OPERATION_FILE_PATH = (
|
|
|
85
80
|
SERVICE_JSON_FILE_PATH = os.getcwd() + "/sample/sagemaker/2017-07-24/service-2.json"
|
|
86
81
|
RUNTIME_SERVICE_JSON_FILE_PATH = os.getcwd() + "/sample/sagemaker-runtime/2017-05-13/service-2.json"
|
|
87
82
|
|
|
88
|
-
GENERATED_CLASSES_LOCATION = os.getcwd() + "/src/sagemaker_core/
|
|
83
|
+
GENERATED_CLASSES_LOCATION = os.getcwd() + "/src/sagemaker_core/main"
|
|
89
84
|
UTILS_CODEGEN_FILE_NAME = "utils.py"
|
|
90
85
|
INTELLIGENT_DEFAULTS_HELPER_CODEGEN_FILE_NAME = "intelligent_defaults_helper.py"
|
|
91
86
|
|
sagemaker_core/tools/method.py
CHANGED
|
@@ -11,33 +11,32 @@
|
|
|
11
11
|
# ANY KIND, either express or implied. See the License for the specific
|
|
12
12
|
# language governing permissions and limitations under the License.
|
|
13
13
|
"""Generates the resource classes for the service model."""
|
|
14
|
-
from collections import OrderedDict
|
|
15
|
-
import logging
|
|
16
14
|
from functools import lru_cache
|
|
17
15
|
|
|
18
16
|
import os
|
|
19
17
|
import json
|
|
20
|
-
from sagemaker_core.code_injection.codec import pascal_to_snake
|
|
21
|
-
from sagemaker_core.
|
|
22
|
-
from sagemaker_core.
|
|
18
|
+
from sagemaker_core.main.code_injection.codec import pascal_to_snake
|
|
19
|
+
from sagemaker_core.main.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
|
|
20
|
+
from sagemaker_core.main.exceptions import IntelligentDefaultsError
|
|
21
|
+
from sagemaker_core.main.utils import get_textual_rich_logger
|
|
23
22
|
from sagemaker_core.tools.constants import (
|
|
24
23
|
BASIC_RETURN_TYPES,
|
|
25
24
|
GENERATED_CLASSES_LOCATION,
|
|
26
25
|
RESOURCES_CODEGEN_FILE_NAME,
|
|
27
26
|
LICENCES_STRING,
|
|
28
27
|
TERMINAL_STATES,
|
|
29
|
-
BASIC_IMPORTS_STRING,
|
|
30
28
|
LOGGER_STRING,
|
|
31
29
|
CONFIG_SCHEMA_FILE_NAME,
|
|
32
30
|
PYTHON_TYPES_TO_BASIC_JSON_TYPES,
|
|
33
31
|
CONFIGURABLE_ATTRIBUTE_SUBSTRINGS,
|
|
34
32
|
)
|
|
35
33
|
from sagemaker_core.tools.method import Method, MethodType
|
|
36
|
-
from sagemaker_core.
|
|
34
|
+
from sagemaker_core.main.utils import (
|
|
37
35
|
add_indent,
|
|
38
36
|
convert_to_snake_case,
|
|
39
37
|
snake_to_pascal,
|
|
40
38
|
remove_html_tags,
|
|
39
|
+
escape_special_rst_characters,
|
|
41
40
|
)
|
|
42
41
|
from sagemaker_core.tools.resources_extractor import ResourcesExtractor
|
|
43
42
|
from sagemaker_core.tools.shapes_extractor import ShapesExtractor
|
|
@@ -82,8 +81,7 @@ from sagemaker_core.tools.data_extractor import (
|
|
|
82
81
|
load_combined_operations_data,
|
|
83
82
|
)
|
|
84
83
|
|
|
85
|
-
|
|
86
|
-
log = logging.getLogger(__name__)
|
|
84
|
+
log = get_textual_rich_logger(__name__)
|
|
87
85
|
|
|
88
86
|
TYPE = "type"
|
|
89
87
|
OBJECT = "object"
|
|
@@ -176,21 +174,26 @@ class ResourcesCodeGen:
|
|
|
176
174
|
"""
|
|
177
175
|
# List of import statements
|
|
178
176
|
imports = [
|
|
179
|
-
BASIC_IMPORTS_STRING,
|
|
180
177
|
"import botocore",
|
|
181
178
|
"import datetime",
|
|
182
179
|
"import time",
|
|
183
180
|
"import functools",
|
|
184
|
-
"from pprint import pprint",
|
|
185
181
|
"from pydantic import validate_call",
|
|
186
|
-
"from typing import Dict, List, Literal, Optional, Union\n"
|
|
182
|
+
"from typing import Dict, List, Literal, Optional, Union, Any\n"
|
|
187
183
|
"from boto3.session import Session",
|
|
188
|
-
"from
|
|
189
|
-
"from
|
|
190
|
-
"
|
|
191
|
-
"from
|
|
192
|
-
"from
|
|
193
|
-
"from
|
|
184
|
+
"from rich.console import Group",
|
|
185
|
+
"from rich.live import Live",
|
|
186
|
+
"from rich.panel import Panel",
|
|
187
|
+
"from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn",
|
|
188
|
+
"from rich.status import Status",
|
|
189
|
+
"from rich.style import Style",
|
|
190
|
+
"from sagemaker_core.main.code_injection.codec import transform",
|
|
191
|
+
"from sagemaker_core.main.code_injection.constants import Color",
|
|
192
|
+
"from sagemaker_core.main.utils import SageMakerClient, SageMakerRuntimeClient, ResourceIterator, Unassigned, get_textual_rich_logger, "
|
|
193
|
+
"snake_to_pascal, pascal_to_snake, is_not_primitive, is_not_str_dict, is_snake_case, is_primitive_list, serialize",
|
|
194
|
+
"from sagemaker_core.main.intelligent_defaults_helper import load_default_configs_for_resource_name, get_config_value",
|
|
195
|
+
"from sagemaker_core.main.shapes import *",
|
|
196
|
+
"from sagemaker_core.main.exceptions import *",
|
|
194
197
|
]
|
|
195
198
|
|
|
196
199
|
formated_imports = "\n".join(imports)
|
|
@@ -344,6 +347,7 @@ class ResourcesCodeGen:
|
|
|
344
347
|
class_attributes, class_attributes_string, attributes_and_documentation = (
|
|
345
348
|
class_attribute_info
|
|
346
349
|
)
|
|
350
|
+
|
|
347
351
|
# Start defining the class
|
|
348
352
|
resource_class = f"class {resource_name}(Base):\n"
|
|
349
353
|
|
|
@@ -640,7 +644,8 @@ class ResourcesCodeGen:
|
|
|
640
644
|
else:
|
|
641
645
|
documentation_string += f"{attribute_snake}: {documentation}\n"
|
|
642
646
|
documentation_string = add_indent(documentation_string)
|
|
643
|
-
|
|
647
|
+
documentation_string = remove_html_tags(documentation_string)
|
|
648
|
+
return escape_special_rst_characters(documentation_string)
|
|
644
649
|
|
|
645
650
|
def _generate_create_method_args(
|
|
646
651
|
self, operation_input_shape_name: str, resource_name: str
|
|
@@ -930,7 +935,7 @@ class ResourcesCodeGen:
|
|
|
930
935
|
decorator=decorator,
|
|
931
936
|
method_name="create",
|
|
932
937
|
method_args=add_indent("cls,\n", 4) + create_args,
|
|
933
|
-
return_type='Optional["resource_name"]',
|
|
938
|
+
return_type=f'Optional["{resource_name}"]',
|
|
934
939
|
serialize_operation_input=serialize_operation_input,
|
|
935
940
|
initialize_client=initialize_client,
|
|
936
941
|
call_operation_api=call_operation_api,
|
|
@@ -1758,7 +1763,7 @@ class ResourcesCodeGen:
|
|
|
1758
1763
|
formatted_failed_block = FAILED_STATUS_ERROR_TEMPLATE.format(
|
|
1759
1764
|
resource_name=resource_name, reason=failure_reason
|
|
1760
1765
|
)
|
|
1761
|
-
formatted_failed_block = add_indent(formatted_failed_block,
|
|
1766
|
+
formatted_failed_block = add_indent(formatted_failed_block, 16)
|
|
1762
1767
|
|
|
1763
1768
|
formatted_method = WAIT_METHOD_TEMPLATE.format(
|
|
1764
1769
|
terminal_resource_states=terminal_resource_states,
|
|
@@ -1792,7 +1797,7 @@ class ResourcesCodeGen:
|
|
|
1792
1797
|
formatted_failed_block = FAILED_STATUS_ERROR_TEMPLATE.format(
|
|
1793
1798
|
resource_name=resource_name, reason=failure_reason
|
|
1794
1799
|
)
|
|
1795
|
-
formatted_failed_block = add_indent(formatted_failed_block,
|
|
1800
|
+
formatted_failed_block = add_indent(formatted_failed_block, 12)
|
|
1796
1801
|
|
|
1797
1802
|
formatted_method = WAIT_FOR_STATUS_METHOD_TEMPLATE.format(
|
|
1798
1803
|
resource_states=resource_states,
|
|
@@ -1826,10 +1831,10 @@ class ResourcesCodeGen:
|
|
|
1826
1831
|
formatted_failed_block = DELETE_FAILED_STATUS_CHECK.format(
|
|
1827
1832
|
resource_name=resource_name, reason=failure_reason
|
|
1828
1833
|
)
|
|
1829
|
-
formatted_failed_block = add_indent(formatted_failed_block,
|
|
1834
|
+
formatted_failed_block = add_indent(formatted_failed_block, 16)
|
|
1830
1835
|
|
|
1831
1836
|
if any(state.lower() == "deleted" for state in resource_states):
|
|
1832
|
-
deleted_status_check = add_indent(DELETED_STATUS_CHECK,
|
|
1837
|
+
deleted_status_check = add_indent(DELETED_STATUS_CHECK, 16)
|
|
1833
1838
|
else:
|
|
1834
1839
|
deleted_status_check = ""
|
|
1835
1840
|
|
|
@@ -1962,9 +1967,6 @@ class ResourcesCodeGen:
|
|
|
1962
1967
|
Input for generating the Schema is the service JSON that is already loaded in the class
|
|
1963
1968
|
|
|
1964
1969
|
"""
|
|
1965
|
-
self.resources_extractor = ResourcesExtractor()
|
|
1966
|
-
self.resources_plan = self.resources_extractor.get_resource_plan()
|
|
1967
|
-
|
|
1968
1970
|
resource_properties = {}
|
|
1969
1971
|
|
|
1970
1972
|
for _, row in self.resources_plan.iterrows():
|
|
@@ -2073,7 +2075,7 @@ class ResourcesCodeGen:
|
|
|
2073
2075
|
TYPE: self._get_json_schema_type_from_python_type(value) or value
|
|
2074
2076
|
}
|
|
2075
2077
|
elif value.startswith("List") or value.startswith("Dict"):
|
|
2076
|
-
log.
|
|
2078
|
+
log.debug("Script does not currently support list of objects as configurable")
|
|
2077
2079
|
continue
|
|
2078
2080
|
else:
|
|
2079
2081
|
class_attributes = self.shapes_extractor.generate_shape_members(value)
|
|
@@ -11,11 +11,11 @@
|
|
|
11
11
|
# ANY KIND, either express or implied. See the License for the specific
|
|
12
12
|
# language governing permissions and limitations under the License.
|
|
13
13
|
"""A class for extracting resource information from a service JSON."""
|
|
14
|
-
import logging
|
|
15
14
|
from typing import Optional
|
|
16
15
|
|
|
17
16
|
import pandas as pd
|
|
18
17
|
|
|
18
|
+
from sagemaker_core.main.utils import get_textual_rich_logger
|
|
19
19
|
from sagemaker_core.tools.constants import CLASS_METHODS, OBJECT_METHODS
|
|
20
20
|
from sagemaker_core.tools.data_extractor import (
|
|
21
21
|
load_additional_operations_data,
|
|
@@ -24,8 +24,7 @@ from sagemaker_core.tools.data_extractor import (
|
|
|
24
24
|
)
|
|
25
25
|
from sagemaker_core.tools.method import Method
|
|
26
26
|
|
|
27
|
-
|
|
28
|
-
log = logging.getLogger(__name__)
|
|
27
|
+
log = get_textual_rich_logger(__name__)
|
|
29
28
|
"""
|
|
30
29
|
This class is used to extract the resources and its actions from the service-2.json file.
|
|
31
30
|
"""
|
|
@@ -175,9 +174,9 @@ class ResourcesExtractor:
|
|
|
175
174
|
|
|
176
175
|
log.info(f"Total resource - {len(self.resources)}")
|
|
177
176
|
|
|
178
|
-
log.info(f"
|
|
177
|
+
log.info(f"Supported actions - {len(self.actions_under_resource)}")
|
|
179
178
|
|
|
180
|
-
log.info(f"Unsupported actions
|
|
179
|
+
log.info(f"Unsupported actions - {len(self.actions)}")
|
|
181
180
|
|
|
182
181
|
self._extract_resource_plan_as_dataframe()
|
|
183
182
|
|
|
@@ -368,6 +367,3 @@ class ResourcesExtractor:
|
|
|
368
367
|
resource_methods (dict): The resource methods dict.
|
|
369
368
|
"""
|
|
370
369
|
return self.resource_methods
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
resource_extractor = ResourcesExtractor()
|