sagemaker-core 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sagemaker-core might be problematic. Click here for more details.
- sagemaker_core/__init__.py +0 -0
- sagemaker_core/_version.py +11 -0
- sagemaker_core/code_injection/__init__.py +0 -0
- sagemaker_core/code_injection/base.py +42 -0
- sagemaker_core/code_injection/codec.py +241 -0
- sagemaker_core/code_injection/constants.py +18 -0
- sagemaker_core/code_injection/shape_dag.py +14527 -0
- sagemaker_core/generated/__init__.py +0 -0
- sagemaker_core/generated/config_schema.py +870 -0
- sagemaker_core/generated/exceptions.py +147 -0
- sagemaker_core/generated/intelligent_defaults_helper.py +198 -0
- sagemaker_core/generated/resources.py +26998 -0
- sagemaker_core/generated/shapes.py +11584 -0
- sagemaker_core/generated/utils.py +314 -0
- sagemaker_core/tools/__init__.py +1 -0
- sagemaker_core/tools/codegen.py +56 -0
- sagemaker_core/tools/constants.py +96 -0
- sagemaker_core/tools/data_extractor.py +49 -0
- sagemaker_core/tools/method.py +32 -0
- sagemaker_core/tools/resources_codegen.py +2122 -0
- sagemaker_core/tools/resources_extractor.py +373 -0
- sagemaker_core/tools/shapes_codegen.py +284 -0
- sagemaker_core/tools/shapes_extractor.py +259 -0
- sagemaker_core/tools/templates.py +747 -0
- sagemaker_core/util/__init__.py +0 -0
- sagemaker_core/util/util.py +81 -0
- sagemaker_core-0.1.3.dist-info/LICENSE +201 -0
- sagemaker_core-0.1.3.dist-info/METADATA +28 -0
- sagemaker_core-0.1.3.dist-info/RECORD +31 -0
- sagemaker_core-0.1.3.dist-info/WHEEL +5 -0
- sagemaker_core-0.1.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""A class for extracting resource information from a service JSON."""
|
|
14
|
+
import logging
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import pandas as pd
|
|
18
|
+
|
|
19
|
+
from sagemaker_core.tools.constants import CLASS_METHODS, OBJECT_METHODS
|
|
20
|
+
from sagemaker_core.tools.data_extractor import (
|
|
21
|
+
load_additional_operations_data,
|
|
22
|
+
load_combined_operations_data,
|
|
23
|
+
load_combined_shapes_data,
|
|
24
|
+
)
|
|
25
|
+
from sagemaker_core.tools.method import Method
|
|
26
|
+
|
|
27
|
+
# NOTE(review): this configures the root logger at INFO level as a side effect of importing
# this module.
logging.basicConfig(level=logging.INFO)
|
|
28
|
+
# Module-level logger named after this module.
log = logging.getLogger(__name__)
|
|
29
|
+
"""
|
|
30
|
+
This class is used to extract the resources and its actions from the service-2.json file.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ResourcesExtractor:
    """
    Extracts resources and their associated actions from the combined service model.

    A resource is any name that follows a ``Create``, ``Add``, ``Start``, ``Register`` or
    ``Import`` operation prefix, plus every resource listed in the additional operations
    data. Each remaining operation is then assigned to the longest resource name it matches.
    The whole extraction runs from ``__init__``.

    Args:
        combined_shapes (Optional[dict]): Shape definitions keyed by shape name. When falsy,
            ``load_combined_shapes_data()`` is used.
        combined_operations (Optional[dict]): Operation definitions keyed by operation name.
            When falsy, ``load_combined_operations_data()`` is used.

    Attributes:
        operations (dict): The operation definitions.
        shapes (dict): The shape definitions.
        additional_operations (dict): Result of ``load_additional_operations_data()``.
        resource_methods (dict): resource name -> {method name -> Method}; currently only
            populated from the additional operations.
        resource_actions (dict): resource name -> set of operation names assigned to it.
        actions_under_resource (set): All operation names assigned to some resource.
        actions (set): Operation names not (yet) assigned to a resource.
        create_resources (set): Names that follow a ``Create`` prefix.
        add_resources (set): Names that follow an ``Add`` prefix.
        start_resources (set): Names that follow a ``Start`` prefix.
        register_resources (set): Names that follow a ``Register`` prefix.
        import_resources (set): Names that follow an ``Import`` prefix.
        resources (set): All resource names.
        df (DataFrame): The resource plan, one row per resource.
    """

    RESOURCE_TO_ADDITIONAL_METHODS = {
        "Cluster": ["DescribeClusterNode", "ListClusterNodes"],
    }

    def __init__(
        self,
        combined_shapes: Optional[dict] = None,
        combined_operations: Optional[dict] = None,
    ):
        """
        Initializes a ResourcesExtractor and immediately extracts the resource plan.

        Args:
            combined_shapes (Optional[dict]): Shape definitions; loaded from the data
                extractor when falsy.
            combined_operations (Optional[dict]): Operation definitions; loaded from the
                data extractor when falsy.
        """
        self.operations = combined_operations or load_combined_operations_data()
        self.shapes = combined_shapes or load_combined_shapes_data()
        self.additional_operations = load_additional_operations_data()
        # contains information about additional methods only now.
        # TODO: replace resource_actions with resource_methods to include all methods
        self.resource_methods = {}
        self.resource_actions = {}
        self.actions_under_resource = set()

        self._extract_resources_plan()

    def _filter_additional_operations(self):
        """
        Registers the resources and methods described by the additional operations data.

        For every listed operation a ``Method`` is created and stored in
        ``resource_methods``, the operation is recorded in ``actions_under_resource`` and
        it is removed from the pool of unassigned ``actions``.

        Returns:
            None

        Raises:
            KeyError: If an operation is not present in ``operations`` / ``actions``.
        """
        for resource_name, resource_operations in self.additional_operations.items():
            self.resources.add(resource_name)
            methods = self.resource_methods.setdefault(resource_name, {})
            for operation_name, operation in resource_operations.items():
                # add(), not update(): update() with a str inserts each character of the
                # operation name as a separate element.
                self.actions_under_resource.add(operation_name)
                method = Method(**operation)
                method.get_docstring_title(self.operations[operation_name])
                methods[operation["method_name"]] = method
                self.actions.remove(operation_name)

    def _filter_actions_for_resources(self, resources):
        """
        Assigns the unassigned actions to the given resources.

        Resources are processed from longest to shortest name so that an action is claimed
        by the most specific resource it matches. An action matches a resource when it ends
        with the resource name, is a ``List`` action ending with the pluralised name, or
        starts with ``Invoke`` followed by the resource name.

        Args:
            resources (set): A set of resource names.

        Returns:
            None
        """
        for resource in sorted(resources, key=len, reverse=True):
            filtered_actions = {
                a
                for a in self.actions
                if a.endswith(resource)
                or (a.startswith("List") and a.endswith(resource + "s"))
                or a.startswith("Invoke" + resource)
            }
            self.actions_under_resource.update(filtered_actions)
            self.resource_actions[resource] = filtered_actions

            # Claimed actions are no longer available to shorter resource names.
            self.actions = self.actions - filtered_actions

    def _resources_with_prefix(self, prefix):
        """Returns the names that follow ``prefix`` among the still-unassigned actions."""
        return {key[len(prefix) :] for key in self.actions if key.startswith(prefix)}

    def _extract_resources_plan(self):
        """
        Extracts resources, assigns actions to them and builds the resource plan DataFrame.

        Returns:
            None
        """
        self.actions = set(self.operations.keys())
        log.info(f"Total actions - {len(self.actions)}")

        # Filter out additional operations and resources first
        self.resources = set()
        self._filter_additional_operations()

        self.create_resources = self._resources_with_prefix("Create")
        self.add_resources = self._resources_with_prefix("Add")
        self.start_resources = self._resources_with_prefix("Start")
        self.register_resources = self._resources_with_prefix("Register")
        self.import_resources = self._resources_with_prefix("Import")

        self.resources.update(
            self.create_resources
            | self.add_resources
            | self.start_resources
            | self.register_resources
            | self.import_resources
        )

        self._filter_actions_for_resources(self.resources)

        log.info(f"Total resource - {len(self.resources)}")
        log.info(f"Total actions_under_resource - {len(self.actions_under_resource)}")
        log.info(f"Unsupported actions: - {len(self.actions)}")

        self._extract_resource_plan_as_dataframe()

    def get_status_chain_and_states(self, resource_name):
        """
        Extract the status chain and states for a given resource.

        The output shape of the resource's ``Describe`` operation is inspected. When that
        shape has exactly one member, the member is treated as a wrapper: it becomes the
        first link of the chain and the search continues inside its shape.

        Args:
            resource_name (str): The name of the resource

        Returns:
            status_chain (list): The status chain for the resource.
            resource_states (list): The states associated with the resource.
        """
        resource_operation = self.operations["Describe" + resource_name]
        resource_operation_output_shape_name = resource_operation["output"]["shape"]
        output_members_data = self.shapes[resource_operation_output_shape_name]["members"]

        if len(output_members_data) != 1:
            return self._get_status_chain_and_states(resource_operation_output_shape_name)

        single_member_name = next(iter(output_members_data))
        single_member_shape_name = output_members_data[single_member_name]["shape"]
        status_chain = [{"name": single_member_name, "shape_name": single_member_shape_name}]
        return self._get_status_chain_and_states(single_member_shape_name, status_chain)

    def _get_status_chain_and_states(self, shape_name, status_chain: list = None):
        """
        Follows "status" members from a shape down to an enum of states.

        At each shape the first member whose name contains "status" (case-insensitive) is
        appended to the chain. The walk ends when that member's shape defines an ``enum``.

        Args:
            shape_name (str): The name of the shape to start from.
            status_chain (list): The current status chain; appended to in place.

        Returns:
            status_chain (list): The status chain for the shape.
            resource_states (list): The states associated with the shape.
            Both are empty lists when no enum can be reached, including when the walk hits
            a shape without members or revisits a shape.
        """
        if status_chain is None:
            status_chain = []

        visited = set()
        current_shape_name = shape_name
        while True:
            # Non-structure shapes (e.g. plain strings) carry no "members" entry.
            member_data = self.shapes[current_shape_name].get("members", {})
            status_name = next(
                (member for member in member_data if "status" in member.lower()), None
            )
            if status_name is None:
                return [], []

            status_shape_name = member_data[status_name]["shape"]
            status_chain.append({"name": status_name, "shape_name": status_shape_name})

            status_shape = self.shapes[status_shape_name]
            if "enum" in status_shape:
                return status_chain, status_shape["enum"]

            # Guard against shapes that (directly or indirectly) reference themselves.
            visited.add(current_shape_name)
            if status_shape_name in visited:
                return [], []
            current_shape_name = status_shape_name

    def _extract_resource_plan_as_dataframe(self):
        """
        Builds the resource plan DataFrame and writes it to ``resource_plan.csv``.

        One row is produced per entry of ``resource_actions`` (sorted by resource name).
        The result is stored in ``self.df`` and written to ``resource_plan.csv`` in the
        current working directory.

        Returns:
            None
        """
        columns = [
            "resource_name",
            "type",
            "class_methods",
            "object_methods",
            "chain_resource_name",
            "additional_methods",
            "raw_actions",
            "resource_status_chain",
            "resource_states",
        ]
        # Rows are collected first and turned into a DataFrame once; concatenating inside
        # the loop would copy the whole frame for every resource.
        rows = []

        for resource, actions in sorted(self.resource_actions.items()):
            class_methods = set()
            object_methods = set()
            additional_methods = set()
            chain_resource_names = set()
            resource_status_chain = []
            resource_states = []
            resource_low = resource.lower()

            for action in actions:
                action_split = action.lower().split(resource_low)
                prefix = action_split[0]

                if prefix == "describe":
                    class_methods.add("get")
                    object_methods.add("refresh")

                    resource_status_chain, resource_states = self.get_status_chain_and_states(
                        resource
                    )

                    if resource_low.endswith("job") or resource_low.endswith("jobv2"):
                        object_methods.add("wait")
                    elif resource_states and resource_low != "action":
                        object_methods.add("wait_for_status")

                    if "Deleting" in resource_states or "DELETING" in resource_states:
                        object_methods.add("wait_for_delete")

                    continue

                if prefix == "create":
                    create_input = self.shapes[self.operations[action]["input"]["shape"]]
                    for member in create_input["members"]:
                        if member.endswith("Name") or member.endswith("Names"):
                            # NOTE(review): four characters are stripped for both suffixes,
                            # so a "...Names" member yields a candidate ending in "N".
                            # Confirm whether plural members are meant to match resources.
                            chain_resource_name = member[: -len("Name")]

                            if (
                                chain_resource_name != resource
                                and chain_resource_name in self.resources
                            ):
                                chain_resource_names.add(chain_resource_name)

                if prefix == "invoke":
                    # The text after the resource name selects the invoke variant.
                    if not action_split[1]:
                        invoke_method = "invoke"
                    elif action_split[1] == "async":
                        invoke_method = "invoke_async"
                    else:
                        invoke_method = "invoke_with_response_stream"
                    object_methods.add(invoke_method)
                elif prefix in CLASS_METHODS:
                    class_methods.add("get_all" if prefix == "list" else prefix)
                elif prefix in OBJECT_METHODS:
                    object_methods.add(prefix)
                else:
                    additional_methods.add(action)

            if resource in self.RESOURCE_TO_ADDITIONAL_METHODS:
                additional_methods.update(self.RESOURCE_TO_ADDITIONAL_METHODS[resource])

            rows.append(
                {
                    "resource_name": resource,
                    "type": "resource",
                    "class_methods": sorted(class_methods),
                    "object_methods": sorted(object_methods),
                    "chain_resource_name": sorted(chain_resource_names),
                    "additional_methods": sorted(additional_methods),
                    "raw_actions": sorted(actions),
                    "resource_status_chain": list(resource_status_chain),
                    "resource_states": list(resource_states),
                }
            )

        self.df = pd.DataFrame(rows, columns=columns)
        self.df.to_csv("resource_plan.csv", index=False)

    def get_resource_plan(self):
        """
        Returns the resource plan DataFrame.

        Returns:
            df (DataFrame): The resource plan DataFrame.
        """
        return self.df

    def get_resource_methods(self):
        """
        Returns the resource methods dict.

        Returns:
            resource_methods (dict): The resource methods dict.
        """
        return self.resource_methods
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# NOTE(review): instantiating here runs the full extraction whenever this module is imported:
# __init__ calls the data loaders and _extract_resources_plan(), which ends by writing
# resource_plan.csv to the current working directory.
resource_extractor = ResourcesExtractor()
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""A class for generating class structure from Service Model JSON.
|
|
14
|
+
|
|
15
|
+
To run the script be sure to set the PYTHONPATH
|
|
16
|
+
export PYTHONPATH=<sagemaker-code-gen repo directory>:$PYTHONPATH
|
|
17
|
+
"""
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
from sagemaker_core.code_injection.codec import pascal_to_snake
|
|
21
|
+
from sagemaker_core.tools.constants import (
|
|
22
|
+
LICENCES_STRING,
|
|
23
|
+
GENERATED_CLASSES_LOCATION,
|
|
24
|
+
SHAPES_CODEGEN_FILE_NAME,
|
|
25
|
+
)
|
|
26
|
+
from sagemaker_core.tools.shapes_extractor import ShapesExtractor
|
|
27
|
+
from sagemaker_core.util.util import add_indent, convert_to_snake_case, remove_html_tags
|
|
28
|
+
from sagemaker_core.tools.templates import SHAPE_CLASS_TEMPLATE, SHAPE_BASE_CLASS_TEMPLATE
|
|
29
|
+
from sagemaker_core.tools.data_extractor import (
|
|
30
|
+
load_combined_shapes_data,
|
|
31
|
+
load_combined_operations_data,
|
|
32
|
+
)
|
|
33
|
+
from .resources_extractor import ResourcesExtractor
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ShapesCodeGen:
    """
    Generates shape classes from the combined Botocore service model.

    The constructor takes no arguments; shapes and operations are loaded through the data
    extractor helpers, and a ``ShapesExtractor`` and a ``ResourcesExtractor`` are created.

    Attributes:
        combined_shapes (dict): Shape definitions keyed by shape name.
        combined_operations (dict): Operation definitions keyed by operation name.
        shapes_extractor (ShapesExtractor): An instance of the ShapesExtractor class.
        shape_dag (dict): Shape DAG returned by ``shapes_extractor.get_shapes_dag()``.
        resources_extractor (ResourcesExtractor): An instance of the ResourcesExtractor class.
        resources_plan (DataFrame): The resource plan from ``resources_extractor``.
        resource_methods (dict): The resource methods from ``resources_extractor``.

    Methods:
        build_graph(): Builds a graph of the dependencies between shapes.
        topological_sort(): Orders shapes so that dependencies come before dependents.
        generate_data_class_for_shape(shape): Generates a data class for a given shape.
        generate_license(): Returns the license header.
        generate_imports(): Generates the import statements for the generated file.
        generate_base_class(): Generates the base class for the shape classes.
        generate_shapes(output_folder, file_name): Writes the generated shapes file.
    """

    def __init__(self):
        self.combined_shapes = load_combined_shapes_data()
        self.combined_operations = load_combined_operations_data()
        self.shapes_extractor = ShapesExtractor()
        self.shape_dag = self.shapes_extractor.get_shapes_dag()
        self.resources_extractor = ResourcesExtractor()
        self.resources_plan = self.resources_extractor.get_resource_plan()
        self.resource_methods = self.resources_extractor.get_resource_methods()

    def build_graph(self):
        """
        Builds a directed graph representing the dependencies between shapes.

        Steps:
        1. Loop over the Service Json shapes.
            1.1. If dependency(members) found, add association of node -> dependency.
                1.1.1. Sometimes members are not shape themselves, but have associated links
                        to actual shapes. In that case add link to node -> dependency (actual)
                            CreateExperimentRequest -> [ExperimentEntityName,
                                                        ExperimentDescription, TagList]
            1.2. else leaf node found (no dependent members), add association of node -> None.

        For list members the element shape is recorded, and for map members the key and
        value shapes are recorded, instead of the container shape itself. A shape that has
        "members" but none whose shape is a known shape gets no entry in the graph.

        :return: A dict which defines the structure of the graph in the format:
            {key : [dependencies]}
            Example output:
                {'CreateExperimentRequest': ['ExperimentEntityName', 'ExperimentEntityName',
                            'ExperimentDescription', 'TagList'],
                'CreateExperimentResponse': ['ExperimentArn'],
                'DeleteExperimentRequest': ['ExperimentEntityName'],
                'DeleteExperimentResponse': ['ExperimentArn']}
        """
        graph = {}

        for node, attributes in self.combined_shapes.items():
            if "members" not in attributes:
                graph[node] = None
                continue

            for member_attributes in attributes["members"].values():
                member_shape_name = member_attributes["shape"]
                # add shapes and not shape attribute
                # i.e. ExperimentEntityName taken over ExperimentName
                if member_shape_name not in self.combined_shapes:
                    continue

                node_deps = graph.setdefault(node, [])
                # evaluate the member shape and then append to node deps
                member_shape = self.combined_shapes[member_shape_name]
                if member_shape["type"] == "list":
                    node_deps.append(member_shape["member"]["shape"])
                elif member_shape["type"] == "map":
                    node_deps.append(member_shape["key"]["shape"])
                    node_deps.append(member_shape["value"]["shape"])
                else:
                    node_deps.append(member_shape_name)
        return graph

    def topological_sort(self):
        """
        Performs a depth-first, post-order traversal of the shape graph.

        Every shape appears after the shapes it depends on (already-visited shapes are
        skipped, so the traversal terminates on cyclic references).

        :return: A list of shape names in the order of topological sort.
        """
        graph = self.build_graph()
        visited = set()
        stack = []

        def dfs(node):
            visited.add(node)
            # unless leaf node is reached do dfs
            if graph.get(node) is not None:
                for neighbor in graph.get(node, []):
                    if neighbor not in visited:
                        dfs(neighbor)
            stack.append(node)

        for node in graph:
            if node not in visited:
                dfs(node)

        return stack

    def generate_data_class_for_shape(self, shape):
        """
        Generates a data class for a given shape.

        :param shape: The name of the shape.
        :return: The generated data class as a string.
        """
        class_name = shape
        init_data = self.shapes_extractor.generate_data_shape_string_body(
            shape, self.resources_plan
        )
        try:
            data_class_members = add_indent(init_data, 4)
        except Exception:
            # Dump the offending body before re-raising to help locate the bad shape.
            print("DEBUG HELP\n", init_data)
            raise
        return SHAPE_CLASS_TEMPLATE.format(
            class_name=class_name + "(Base)",
            data_class_members=data_class_members,
            docstring=self._generate_doc_string_for_shape(shape),
            class_name_snake=pascal_to_snake(class_name),
        )

    def _generate_doc_string_for_shape(self, shape):
        """
        Generates the docstring for a given shape.

        The docstring holds the shape name, the shape documentation (when present) and one
        line per member with its documentation (when present); HTML tags are removed.

        :param shape: The name of the shape.
        :return: The generated docstring as a string.
        """
        shape_dict = self.combined_shapes[shape]

        docstring = f" {shape}"
        if "documentation" in shape_dict:
            docstring += f"\n \t {shape_dict['documentation']}"

        docstring += "\n\n \t Attributes"
        docstring += "\n\t----------------------"

        if "members" in shape_dict:
            for member, member_attributes in shape_dict["members"].items():
                docstring += f"\n \t{convert_to_snake_case(member)}"
                if "documentation" in member_attributes:
                    docstring += f": \t {member_attributes['documentation']}"

        return remove_html_tags(docstring)

    def generate_license(self):
        """
        Generates the license string.

        Returns:
            str: The license string.
        """
        return LICENCES_STRING

    def generate_imports(self):
        """
        Generates the import statements for the generated shape classes.

        :return: The generated import statements as a string.
        """
        imports = "import datetime\n"
        imports += "\n"
        imports += "from pydantic import BaseModel, ConfigDict\n"
        imports += "from typing import List, Dict, Optional, Any, Union\n"
        imports += "from sagemaker_core.generated.utils import Unassigned"
        imports += "\n"
        return imports

    def generate_base_class(self):
        """
        Generates the base class for the shape classes.

        :return: The generated base class as a string.
        """
        # more customizations would be added later
        return SHAPE_BASE_CLASS_TEMPLATE.format(
            class_name="Base(BaseModel)",
        )

    def _excluded_input_output_shapes(self):
        """
        Collects the operation input/output shapes that must not be generated.

        :return: The set of shape names used as an operation input or output, minus the
            shapes used as the return type of a resource method.
        """
        operation_input_output_shapes = set()
        for attrs in self.combined_operations.values():
            if attrs.get("input"):
                operation_input_output_shapes.add(attrs["input"]["shape"])
            if attrs.get("output"):
                operation_input_output_shapes.add(attrs["output"]["shape"])

        required_output_shapes = {
            method.return_type
            for methods in self.resource_methods.values()
            for method in methods.values()
        }
        return operation_input_output_shapes - required_output_shapes

    def _filter_input_output_shapes(self, shape):
        """
        Filters out shapes that are used as input or output for operations.

        Shapes that are the return type of a resource method are kept even when they are
        an operation input or output.

        :param shape: The name of the shape.
        :return: True if the shape should be generated, False otherwise.
        """
        return shape not in self._excluded_input_output_shapes()

    def generate_shapes(
        self,
        output_folder=GENERATED_CLASSES_LOCATION,
        file_name=SHAPES_CODEGEN_FILE_NAME,
    ) -> None:
        """
        Generates the shape classes and writes them to the specified output folder.

        The file contains, in order: the license header, the imports, the base class and
        one class per "structure" shape in topological order.

        :param output_folder: The path to the output folder; created if missing.
        :param file_name: The name of the file to write inside ``output_folder``.
        """
        # Check if the output folder exists, if not, create it
        os.makedirs(output_folder, exist_ok=True)

        # Create the full path for the output file
        output_file = os.path.join(output_folder, file_name)

        # The output is Python source, which is read as UTF-8 by default, so write it as
        # UTF-8 regardless of the platform's locale encoding.
        with open(output_file, "w", encoding="utf-8") as file:
            # Generate and write the license to the file
            file.write(self.generate_license())

            # Generate and write the imports to the file
            file.write(self.generate_imports())

            # Generate and write Base Class
            file.write(self.generate_base_class())
            file.write("\n\n")

            # The exclusion set does not depend on the shape, so build it once instead of
            # once per shape.
            excluded_shapes = self._excluded_input_output_shapes()

            # Iterate through shapes in topological order and generate classes
            for shape in self.topological_sort():
                if shape in excluded_shapes:
                    continue
                if self.combined_shapes[shape]["type"] == "structure":
                    # Generate and write data class for shape
                    file.write(self.generate_data_class_for_shape(shape))
|