sagemaker-core 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sagemaker-core might be problematic. Click here for more details.

@@ -0,0 +1,259 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Extracts the shapes to DAG structure."""
14
+ import textwrap
15
+ import pprint
16
+ from functools import lru_cache
17
+ from typing import Optional, Any
18
+
19
+ from sagemaker_core.tools.constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES, SHAPE_DAG_FILE_PATH
20
+ from sagemaker_core.util.util import (
21
+ reformat_file_with_black,
22
+ convert_to_snake_case,
23
+ snake_to_pascal,
24
+ )
25
+ from sagemaker_core.tools.data_extractor import load_combined_shapes_data
26
+
27
+
28
class ShapesExtractor:
    """Builds the shape DAG and Python type-hint strings from SageMaker service shapes.

    Constructing an instance computes the shape DAG (see ``get_shapes_dag``) and writes
    it to ``SHAPE_DAG_FILE_PATH`` as ``SHAPE_DAG=<dict literal>``, then passes that file
    to ``reformat_file_with_black``. The remaining methods translate structure/list/map
    shapes into type-hint strings (e.g. ``"List[str]"``, ``"Dict[str, Foo]"``) and into
    attribute declarations for generated data classes.
    """

    def __init__(self, combined_shapes: Optional[dict] = None):
        """
        Initializes a new instance of the ShapesExtractor class.

        Side effect: (over)writes ``SHAPE_DAG_FILE_PATH`` with the computed shape DAG and
        runs ``reformat_file_with_black`` on it.

        :param combined_shapes: All the shapes put together from all SageMaker service
            JSONs, keyed by shape name. If omitted (or empty), the shapes are loaded with
            ``load_combined_shapes_data()``.
        """
        self.combined_shapes = combined_shapes or load_combined_shapes_data()

        self.shape_dag = self.get_shapes_dag()
        # width=1 forces pprint to split every multi-element container across lines.
        with open(SHAPE_DAG_FILE_PATH, "w", encoding="utf-8") as f:
            f.write("SHAPE_DAG=")
            f.write(textwrap.indent(pprint.pformat(self.shape_dag, width=1), "") + "\n")
        reformat_file_with_black(SHAPE_DAG_FILE_PATH)

    # @property
    def get_shapes_dag(self):
        """
        Parses the Service Json and generates the Shape DAG.

        The DAG is stored in a dictionary, and each key denotes a DAG node. Nodes are only
        created for composite types: structure, list, map. Basic types (e.g. string,
        integer) are omitted for compactness; they can be inferred from the composite
        nodes that reference them. Shapes without a "type" entry are skipped.

        The connections between nodes can be followed through the member/key/value shapes.

        Possible scenarios of nested associations:

        1. StructA → StructB → basic_type_member.
        2. StructA → list → basic_type_member
        3. StructA → list → StructB → basic_type_member
        4. StructA → map → basic_type_member
        5. StructA → map → StructBMapValue → basic_type_member
        6. StructA → map → map → basic_type_member
        7. StructA → map → list → basic_type_member

        Example:

        "ContainerDefinition": {  # type: structure
            "type": "structure",
            "members": [
                {"name": "ModelName", "shape": "ModelName", "type": "string"},
                {"name": "ContainerDefinition", "shape": "ContainerDefinitionList", "type": "list"},
                {"name": "CustomerMetadata", "shape": "CustomerMetadataMap", "type": "map"},
            ],
        },
        "ContainerDefinitionList": {  # type: list
            "type": "list",
            "member_shape": "ContainerDefinition",
            "member_type": "structure",  # potential types: string, structure
        },
        "CustomerMetadataMap": {  # type: map
            "type": "map",
            "key_shape": "CustomerMetadataKey",
            "key_type": "string",  # allowed types: string
            "value_shape": "CustomerMetadataValue",
            "value_type": "string",  # potential types: string, structure, list, map
        },

        :return: The generated Shape DAG.
        :rtype: dict
        :raises KeyError: If a member/key/value refers to a shape missing from
            ``combined_shapes``.
        """
        shapes = self.combined_shapes
        dag = {}
        for shape, shape_data in shapes.items():
            # Shapes without a declared type cannot be classified; skip them.
            if "type" not in shape_data:
                continue
            shape_type = shape_data["type"]
            if shape_type == "structure":
                dag[shape] = {
                    "type": "structure",
                    "members": [
                        {
                            "name": member,
                            "shape": member_attrs["shape"],
                            "type": shapes[member_attrs["shape"]]["type"],
                        }
                        for member, member_attrs in shape_data["members"].items()
                    ],
                }
            elif shape_type == "list":
                member_shape = shape_data["member"]["shape"]
                dag[shape] = {
                    "type": "list",
                    "member_shape": member_shape,
                    "member_type": shapes[member_shape]["type"],
                }
            elif shape_type == "map":
                key_shape = shape_data["key"]["shape"]
                value_shape = shape_data["value"]["shape"]
                dag[shape] = {
                    "type": "map",
                    "key_shape": key_shape,
                    "value_shape": value_shape,
                    "key_type": shapes[key_shape]["type"],
                    "value_type": shapes[value_shape]["type"],
                }
            # Any other (basic) type is intentionally not stored as a node.
        return dag

    def _evaluate_list_type(self, member_shape):
        """
        Return the Python type-hint string for a list shape.

        :param member_shape: The list shape definition (must contain ``member.shape``).
        :return: e.g. ``"List[str]"`` or ``"List[SomeStructure]"``.
        :raises Exception: If the list element is itself a list or a map (not supported).
        """
        list_shape_name = member_shape["member"]["shape"]
        list_shape_type = self.combined_shapes[list_shape_name]["type"]
        if list_shape_type in ["list", "map"]:
            raise Exception(
                "Unhandled list shape member type encountered, needs extra logic to handle this"
            )
        if list_shape_type == "structure":
            # Edge case: SearchExpression is a nested structure, so it is emitted as a
            # quoted forward reference (presumably because its class is not yet defined
            # where this hint is used).
            if list_shape_name == "SearchExpression":
                return f"List['{list_shape_name}']"
            return f"List[{list_shape_name}]"
        return f"List[{BASIC_JSON_TYPES_TO_PYTHON_TYPES[list_shape_type]}]"

    def _evaluate_map_type(self, member_shape):
        """
        Return the Python type-hint string for a map shape.

        Values may be basic types, structures, lists or (recursively) maps.

        :param member_shape: The map shape definition (must contain ``key.shape`` and
            ``value.shape``).
        :return: e.g. ``"Dict[str, str]"`` or ``"Dict[str, List[Foo]]"``.
        :raises Exception: If the key shape is not of type ``string``.
        """
        key_shape_name = member_shape["key"]["shape"]
        value_shape_name = member_shape["value"]["shape"]
        key_shape = self.combined_shapes[key_shape_name]
        value_shape = self.combined_shapes[value_shape_name]
        key_type = key_shape["type"]
        value_type = value_shape["type"]
        # Map keys are always expected to be "string" type
        if key_type != "string":
            raise Exception(
                "Unhandled map shape key type encountered, needs extra logic to handle this"
            )
        key_hint = BASIC_JSON_TYPES_TO_PYTHON_TYPES[key_type]
        if value_type == "structure":
            value_hint = value_shape_name
        elif value_type == "list":
            value_hint = self._evaluate_list_type(value_shape)
        elif value_type == "map":
            value_hint = self._evaluate_map_type(value_shape)
        else:
            value_hint = BASIC_JSON_TYPES_TO_PYTHON_TYPES[value_type]
        return f"Dict[{key_hint}, {value_hint}]"

    def generate_data_shape_members_and_string_body(
        self, shape, resource_plan: Optional[Any] = None, required_override=()
    ):
        """
        Generate a shape's members and the attribute-declaration body of its data class.

        An attribute named ``<prefix>_name`` is typed ``Union[str, object]`` when
        ``snake_to_pascal(<prefix>)`` is one of the resource names in ``resource_plan``
        (and ``<prefix>`` differs from ``shape``). An attribute named ``lambda`` is
        emitted commented out because ``lambda`` is a Python keyword.

        :param shape: Name of the structure shape in ``combined_shapes``.
        :param resource_plan: Optional table with a ``resource_name`` column, iterated
            with ``iterrows()`` (presumably a pandas DataFrame).
        :param required_override: Member names to treat as required instead of the
            shape's own ``required`` list. Must be hashable (it is an lru_cache key).
        :return: Tuple of (snake_case member name -> type-hint string, body string with
            one ``name: hint`` declaration per line).
        """
        shape_members = self.generate_shape_members(shape, required_override)
        resource_names = None
        if resource_plan is not None:
            resource_names = [row["resource_name"] for _, row in resource_plan.iterrows()]

        suffix = "_name"
        body_lines = []
        for attr, value in shape_members.items():
            resource_prefix = attr[: -len(suffix)]
            if (
                resource_names
                and attr.endswith(suffix)
                and resource_prefix != shape
                and snake_to_pascal(resource_prefix) in resource_names
            ):
                # Allow either the resource's name or the resource object itself.
                if value.startswith("Optional"):
                    body_lines.append(f"{attr}: Optional[Union[str, object]] = Unassigned()\n")
                else:
                    body_lines.append(f"{attr}: Union[str, object]\n")
            elif attr == "lambda":
                # "lambda" is a Python keyword and cannot be an attribute name.
                body_lines.append(f"# {attr}: {value}\n")
            else:
                body_lines.append(f"{attr}: {value}\n")
        return shape_members, "".join(body_lines)

    def generate_data_shape_string_body(self, shape, resource_plan, required_override=()):
        """Return only the attribute-declaration body for ``shape`` (see
        ``generate_data_shape_members_and_string_body``)."""
        return self.generate_data_shape_members_and_string_body(
            shape, resource_plan, required_override
        )[1]

    def generate_data_shape_members(self, shape, resource_plan, required_override=()):
        """Return only the member -> type-hint mapping for ``shape`` (see
        ``generate_data_shape_members_and_string_body``)."""
        return self.generate_data_shape_members_and_string_body(
            shape, resource_plan, required_override
        )[0]

    def _get_ordered_members(self, shape, required_override=()):
        """Return ``(ordered_members, required_args)`` for a structure shape.

        ``ordered_members`` holds the shape's members with the required ones first (in
        ``required_args`` order) followed by the rest in their original order.
        ``required_args`` is ``required_override`` if non-empty, else the shape's own
        ``required`` list.
        """
        shape_dict = self.combined_shapes[shape]
        members = shape_dict["members"]
        required_args = list(required_override) or shape_dict.get("required", [])
        # bring the required members in front; update() keeps their positions
        ordered_members = {key: members[key] for key in required_args if key in members}
        ordered_members.update(members)
        return ordered_members, required_args

    # NOTE: lru_cache on a method keys on (self, args) and keeps `self` alive for the
    # cache's lifetime; repeated calls return the *same* dict object, so callers should
    # not mutate the result. `required_override` must be hashable (e.g. a tuple).
    @lru_cache
    def generate_shape_members(self, shape, required_override=()):
        """
        Map each member of a structure shape to its Python type-hint string.

        Required members come first and are typed plainly (``"str"``); optional members
        are typed ``"Optional[<hint>] = Unassigned()"``.

        :param shape: Name of the structure shape in ``combined_shapes``.
        :param required_override: Member names to treat as required instead of the
            shape's own ``required`` list.
        :return: Dict of snake_case member name -> type-hint string.
        :raises Exception: If a member refers to a shape that is missing or empty.
        """
        ordered_members, required_args = self._get_ordered_members(shape, required_override)
        init_data_body = {}
        for member_name, member_attrs in ordered_members.items():
            member_shape_name = member_attrs["shape"]
            member_shape = self.combined_shapes.get(member_shape_name)
            if not member_shape:
                raise Exception(
                    "The Shape definition must exist. The Json Data might be corrupt: "
                    f"shape {member_shape_name!r} (member {member_name!r} of {shape!r})"
                )
            member_shape_type = member_shape["type"]
            if member_shape_type == "structure":
                member_type = member_shape_name
            elif member_shape_type == "list":
                member_type = self._evaluate_list_type(member_shape)
            elif member_shape_type == "map":
                member_type = self._evaluate_map_type(member_shape)
            else:
                # Shape is a simple type like string
                member_type = BASIC_JSON_TYPES_TO_PYTHON_TYPES[member_shape_type]
            member_name_snake_case = convert_to_snake_case(member_name)
            if member_name in required_args:
                init_data_body[member_name_snake_case] = member_type
            else:
                init_data_body[member_name_snake_case] = f"Optional[{member_type}] = Unassigned()"
        return init_data_body

    @lru_cache
    def fetch_shape_members_and_doc_strings(self, shape, required_override=()):
        """
        Map each member of a structure shape to its ``documentation`` string.

        Members are ordered required-first, like ``generate_shape_members``; members
        without documentation map to ``None``. Keys are the original (not snake_case)
        member names. Results are cached and shared, so do not mutate them.
        """
        ordered_members, _ = self._get_ordered_members(shape, required_override)
        return {
            member_name: member_attrs.get("documentation")
            for member_name, member_attrs in ordered_members.items()
        }

    def get_required_members(self, shape):
        """Return the shape's ``required`` member names converted to snake_case."""
        shape_dict = self.combined_shapes[shape]
        required_args = shape_dict.get("required", [])

        return [convert_to_snake_case(arg) for arg in required_args]