@autorest/python 5.15.0 → 5.18.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. package/ChangeLog.md +98 -4
  2. package/README.md +30 -4
  3. package/autorest/__init__.py +2 -3
  4. package/autorest/black/__init__.py +12 -5
  5. package/autorest/codegen/__init__.py +122 -211
  6. package/autorest/codegen/models/__init__.py +122 -78
  7. package/autorest/codegen/models/base_builder.py +70 -72
  8. package/autorest/codegen/models/base_model.py +7 -5
  9. package/autorest/codegen/models/{base_schema.py → base_type.py} +68 -45
  10. package/autorest/codegen/models/client.py +193 -40
  11. package/autorest/codegen/models/code_model.py +145 -245
  12. package/autorest/codegen/models/combined_type.py +107 -0
  13. package/autorest/codegen/models/constant_type.py +122 -0
  14. package/autorest/codegen/models/credential_types.py +224 -0
  15. package/autorest/codegen/models/dictionary_type.py +131 -0
  16. package/autorest/codegen/models/enum_type.py +195 -0
  17. package/autorest/codegen/models/imports.py +93 -41
  18. package/autorest/codegen/models/list_type.py +149 -0
  19. package/autorest/codegen/models/lro_operation.py +90 -133
  20. package/autorest/codegen/models/lro_paging_operation.py +28 -12
  21. package/autorest/codegen/models/model_type.py +262 -0
  22. package/autorest/codegen/models/operation.py +412 -259
  23. package/autorest/codegen/models/operation_group.py +80 -91
  24. package/autorest/codegen/models/paging_operation.py +101 -117
  25. package/autorest/codegen/models/parameter.py +302 -341
  26. package/autorest/codegen/models/parameter_list.py +373 -357
  27. package/autorest/codegen/models/primitive_types.py +544 -0
  28. package/autorest/codegen/models/property.py +136 -134
  29. package/autorest/codegen/models/request_builder.py +138 -86
  30. package/autorest/codegen/models/request_builder_parameter.py +122 -86
  31. package/autorest/codegen/models/response.py +325 -0
  32. package/autorest/codegen/models/utils.py +13 -17
  33. package/autorest/codegen/serializers/__init__.py +212 -112
  34. package/autorest/codegen/serializers/builder_serializer.py +931 -1040
  35. package/autorest/codegen/serializers/client_serializer.py +140 -84
  36. package/autorest/codegen/serializers/general_serializer.py +26 -50
  37. package/autorest/codegen/serializers/import_serializer.py +96 -31
  38. package/autorest/codegen/serializers/metadata_serializer.py +39 -79
  39. package/autorest/codegen/serializers/model_base_serializer.py +62 -34
  40. package/autorest/codegen/serializers/model_generic_serializer.py +9 -10
  41. package/autorest/codegen/serializers/model_init_serializer.py +4 -2
  42. package/autorest/codegen/serializers/model_python3_serializer.py +29 -22
  43. package/autorest/codegen/serializers/operation_groups_serializer.py +21 -19
  44. package/autorest/codegen/serializers/operations_init_serializer.py +23 -11
  45. package/autorest/codegen/serializers/parameter_serializer.py +174 -0
  46. package/autorest/codegen/serializers/patch_serializer.py +4 -1
  47. package/autorest/codegen/serializers/request_builders_serializer.py +57 -0
  48. package/autorest/codegen/serializers/utils.py +0 -126
  49. package/autorest/codegen/templates/MANIFEST.in.jinja2 +1 -0
  50. package/autorest/codegen/templates/{service_client.py.jinja2 → client.py.jinja2} +7 -7
  51. package/autorest/codegen/templates/config.py.jinja2 +13 -13
  52. package/autorest/codegen/templates/enum.py.jinja2 +4 -4
  53. package/autorest/codegen/templates/enum_container.py.jinja2 +1 -1
  54. package/autorest/codegen/templates/init.py.jinja2 +3 -3
  55. package/autorest/codegen/templates/lro_operation.py.jinja2 +6 -5
  56. package/autorest/codegen/templates/lro_paging_operation.py.jinja2 +6 -5
  57. package/autorest/codegen/templates/metadata.json.jinja2 +36 -35
  58. package/autorest/codegen/templates/model.py.jinja2 +23 -24
  59. package/autorest/codegen/templates/model_container.py.jinja2 +2 -1
  60. package/autorest/codegen/templates/model_init.py.jinja2 +3 -5
  61. package/autorest/codegen/templates/operation.py.jinja2 +10 -14
  62. package/autorest/codegen/templates/operation_group.py.jinja2 +9 -15
  63. package/autorest/codegen/templates/operation_groups_container.py.jinja2 +1 -1
  64. package/autorest/codegen/templates/operation_tools.jinja2 +8 -2
  65. package/autorest/codegen/templates/paging_operation.py.jinja2 +7 -8
  66. package/autorest/codegen/templates/request_builder.py.jinja2 +19 -10
  67. package/autorest/codegen/templates/setup.py.jinja2 +9 -3
  68. package/autorest/codegen/templates/vendor.py.jinja2 +1 -1
  69. package/autorest/jsonrpc/__init__.py +7 -12
  70. package/autorest/jsonrpc/localapi.py +4 -3
  71. package/autorest/jsonrpc/server.py +28 -9
  72. package/autorest/jsonrpc/stdstream.py +13 -6
  73. package/autorest/m2r/__init__.py +5 -8
  74. package/autorest/m4reformatter/__init__.py +1126 -0
  75. package/autorest/multiapi/__init__.py +24 -14
  76. package/autorest/multiapi/models/client.py +21 -11
  77. package/autorest/multiapi/models/code_model.py +23 -10
  78. package/autorest/multiapi/models/config.py +4 -1
  79. package/autorest/multiapi/models/constant_global_parameter.py +1 -0
  80. package/autorest/multiapi/models/global_parameter.py +2 -1
  81. package/autorest/multiapi/models/global_parameters.py +14 -8
  82. package/autorest/multiapi/models/imports.py +24 -17
  83. package/autorest/multiapi/models/mixin_operation.py +5 -5
  84. package/autorest/multiapi/models/operation_group.py +2 -1
  85. package/autorest/multiapi/models/operation_mixin_group.py +21 -10
  86. package/autorest/multiapi/serializers/__init__.py +20 -25
  87. package/autorest/multiapi/serializers/import_serializer.py +47 -17
  88. package/autorest/multiapi/serializers/multiapi_serializer.py +17 -17
  89. package/autorest/multiapi/templates/multiapi_config.py.jinja2 +3 -3
  90. package/autorest/multiapi/templates/multiapi_init.py.jinja2 +2 -2
  91. package/autorest/multiapi/templates/multiapi_operations_mixin.py.jinja2 +4 -4
  92. package/autorest/multiapi/templates/multiapi_service_client.py.jinja2 +9 -9
  93. package/autorest/multiapi/utils.py +3 -3
  94. package/autorest/postprocess/__init__.py +202 -0
  95. package/autorest/postprocess/get_all.py +19 -0
  96. package/autorest/postprocess/venvtools.py +73 -0
  97. package/autorest/preprocess/__init__.py +210 -0
  98. package/autorest/preprocess/helpers.py +54 -0
  99. package/autorest/{namer → preprocess}/python_mappings.py +25 -32
  100. package/package.json +3 -3
  101. package/run-python3.js +2 -3
  102. package/venvtools.py +1 -1
  103. package/autorest/codegen/models/constant_schema.py +0 -101
  104. package/autorest/codegen/models/credential_model.py +0 -47
  105. package/autorest/codegen/models/credential_schema.py +0 -91
  106. package/autorest/codegen/models/credential_schema_policy.py +0 -77
  107. package/autorest/codegen/models/dictionary_schema.py +0 -103
  108. package/autorest/codegen/models/enum_schema.py +0 -215
  109. package/autorest/codegen/models/list_schema.py +0 -123
  110. package/autorest/codegen/models/object_schema.py +0 -253
  111. package/autorest/codegen/models/primitive_schemas.py +0 -466
  112. package/autorest/codegen/models/request_builder_parameter_list.py +0 -280
  113. package/autorest/codegen/models/rest.py +0 -42
  114. package/autorest/codegen/models/schema_request.py +0 -45
  115. package/autorest/codegen/models/schema_response.py +0 -136
  116. package/autorest/codegen/serializers/rest_serializer.py +0 -57
  117. package/autorest/namer/__init__.py +0 -25
  118. package/autorest/namer/name_converter.py +0 -412
@@ -0,0 +1,1126 @@
1
+ # pylint: disable=too-many-lines
2
+ # -------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License. See License.txt in the project root for
5
+ # license information.
6
+ # --------------------------------------------------------------------------
7
+ """The modelerfour reformatter autorest plugin.
8
+ """
9
+ import re
10
+ import copy
11
+ import logging
12
+ from typing import Callable, Dict, Any, Iterable, List, Optional, Set
13
+
14
+ from .. import YamlUpdatePlugin
15
+
16
# Media types treated as JSON, e.g. "application/json" or "text/x+json".
JSON_REGEXP = re.compile(r"^(application|text)/(.+\+)?json$")

# Maps id() of an original yaml type node to its reformatted type.
ORIGINAL_ID_TO_UPDATED_TYPE: Dict[int, Dict[str, Any]] = {}

# "type" values used for credential types.
OAUTH_TYPE = "OAuth2"
KEY_TYPE = "Key"

_LOGGER = logging.getLogger(__name__)

# used if we want to get a string / binary type etc
KNOWN_TYPES: Dict[str, Dict[str, Any]] = dict(
    string={"type": "string"},
    binary={"type": "binary"},
    anydict={"type": "dict", "elementType": {"type": "any"}},
)
29
+
30
+
31
def is_body(yaml_data: Dict[str, Any]) -> bool:
    """Return true if the parameter's http location is the request body."""
    http_info = yaml_data["protocol"]["http"]
    return http_info["in"] == "body"
34
+
35
+
36
def get_body_parameter(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return the first body parameter listed under a request's "parameters".

    Raises StopIteration when none of the parameters is a body parameter.
    """
    body_parameters = filter(is_body, yaml_data["parameters"])
    return next(body_parameters)
39
+
40
+
41
def get_azure_key_credential(key: str) -> Dict[str, Any]:
    """Build a key credential type whose policy is AzureKeyCredentialPolicy."""
    policy = {"type": "AzureKeyCredentialPolicy", "key": key}
    credential = {"type": KEY_TYPE, "policy": policy}
    # Passing it through update_type registers it in the shared type cache.
    update_type(credential)
    return credential
48
+
49
+
50
def get_type(yaml_data: Dict[str, Any]):
    """Return the reformatted type previously recorded for a yaml type node.

    Nodes that were never recorded fall back to the KNOWN_TYPES entry named
    by the node's "type" (KeyError if there is none).
    """
    node_id = id(yaml_data)
    if node_id in ORIGINAL_ID_TO_UPDATED_TYPE:
        return ORIGINAL_ID_TO_UPDATED_TYPE[node_id]
    return KNOWN_TYPES[yaml_data["type"]]
55
+
56
+
57
+ def _get_api_versions(api_versions: List[Dict[str, str]]) -> List[str]:
58
+ return list({api_version["version"]: None for api_version in api_versions}.keys())
59
+
60
+
61
def _update_type_base(updated_type: str, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Create the fields shared by every reformatted type.

    The result carries the given type name, the yaml "defaultValue", the xml
    serialization metadata (empty dict when absent) and the api versions.
    """
    serialization = yaml_data.get("serialization", {})
    base: Dict[str, Any] = {"type": updated_type}
    base["clientDefaultValue"] = yaml_data.get("defaultValue")
    base["xmlMetadata"] = serialization.get("xml", {})
    base["apiVersions"] = _get_api_versions(yaml_data.get("apiVersions", []))
    return base
68
+
69
+
70
def update_list(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat an array type, including its element type and item limits."""
    list_type = _update_type_base("list", yaml_data)
    element_type = update_type(yaml_data["elementType"])
    list_type.update(
        elementType=element_type,
        maxItems=yaml_data.get("maxItems"),
        minItems=yaml_data.get("minItems"),
        uniqueItems=yaml_data.get("uniqueItems", False),
    )
    return list_type
77
+
78
+
79
def update_dict(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat a dictionary type together with its element type."""
    dict_type = _update_type_base("dict", yaml_data)
    element_type = update_type(yaml_data["elementType"])
    dict_type["elementType"] = element_type
    return dict_type
83
+
84
+
85
def update_constant(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat a constant type: its value type plus the unwrapped value."""
    constant_type = _update_type_base("constant", yaml_data)
    value_type = update_type(yaml_data["valueType"])
    # The yaml nests the actual constant one level down, under "value".
    constant_type.update(valueType=value_type, value=yaml_data["value"]["value"])
    return constant_type
90
+
91
+
92
def update_enum_value(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat one enum choice into its name, value and description."""
    language = yaml_data["language"]["default"]
    enum_value = {"name": language["name"]}
    enum_value["value"] = yaml_data["value"]
    enum_value["description"] = language["description"]
    return enum_value
98
+
99
+
100
def update_enum(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat a choice type into an enum with its value type and values."""
    enum_type = _update_type_base("enum", yaml_data)
    language = yaml_data["language"]["default"]
    name = language["name"]
    value_type = update_type(yaml_data["choiceType"])
    values = list(map(update_enum_value, yaml_data["choices"]))
    enum_type.update(
        name=name,
        valueType=value_type,
        values=values,
        description=language["description"],
    )
    return enum_type
111
+
112
+
113
def update_property(
    yaml_data: Dict[str, Any], has_additional_properties: bool
) -> Dict[str, Any]:
    """Reformat a model property.

    When the owning model also gets an "additional_properties" property, a
    declared property with that same client name is renamed to
    "additional_properties1".
    """
    language = yaml_data["language"]["default"]
    client_name = language["name"]
    if has_additional_properties and client_name == "additional_properties":
        client_name = "additional_properties1"

    updated: Dict[str, Any] = {"clientName": client_name}
    updated["restApiName"] = yaml_data["serializedName"]
    updated["flattenedNames"] = yaml_data.get("flattenedNames", [])
    updated["type"] = update_type(yaml_data["schema"])
    updated["optional"] = not yaml_data.get("required")
    updated["description"] = language["description"]
    updated["isDiscriminator"] = yaml_data.get("isDiscriminator")
    updated["readonly"] = yaml_data.get("readOnly", False)
    grouped_names = []
    for original in yaml_data.get("originalParameter", []):
        # TODO: patching m4
        grouped_names.append(original["language"]["default"]["name"].lstrip("_"))
    updated["groupedParameterNames"] = grouped_names
    return updated
133
+
134
+
135
def update_discriminated_subtypes(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Map each immediate subtype's discriminator value to its updated type."""
    immediate = yaml_data.get("discriminator", {}).get("immediate", {})
    subtypes: Dict[str, Any] = {}
    for child in immediate.values():
        discriminator_value = child["discriminatorValue"]
        subtypes[discriminator_value] = update_type(child)
    return subtypes
140
+
141
+
142
def create_model(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a model type without properties or parents yet.

    Only the name, description and xml flag are set here.
    """
    model = _update_type_base("model", yaml_data)
    language = yaml_data["language"]["default"]
    model.update(
        name=language["name"],
        description=language["description"],
        isXml="xml" in yaml_data.get("serializationFormats", []),
    )
    return model
148
+
149
+
150
def fill_model(
    yaml_data: Dict[str, Any], current_model: Dict[str, Any]
) -> Dict[str, Any]:
    """Add properties, parents and discriminator info to ``current_model``.

    A dictionary-typed immediate parent is turned into a synthetic
    "additional_properties" property; object-typed immediate parents become
    the model's parents. The model is updated in place and returned.
    """
    immediate_parents = yaml_data.get("parents", {}).get("immediate", [])
    dictionary_parents = [
        parent for parent in immediate_parents if parent["type"] == "dictionary"
    ]
    has_additional_properties = bool(dictionary_parents)

    model_properties: List[Dict[str, Any]] = []
    if has_additional_properties:
        # add additional properties property
        additional_properties = {
            "clientName": "additional_properties",
            "restApiName": "",
            "type": update_type(dictionary_parents[0]),
            "optional": True,
            "description": "Unmatched properties from the message are deserialized to this collection.",
            "isDiscriminator": False,
            "readonly": False,
        }
        model_properties.append(additional_properties)
    for yaml_property in yaml_data.get("properties", []):
        model_properties.append(
            update_property(
                yaml_property, has_additional_properties=has_additional_properties
            )
        )

    object_parents = [
        update_type(yaml_data=parent)
        for parent in immediate_parents
        if parent["type"] == "object"
    ]
    current_model.update(
        properties=model_properties,
        parents=object_parents,
        discriminatedSubtypes=update_discriminated_subtypes(yaml_data),
        discriminatorValue=yaml_data.get("discriminatorValue"),
    )
    return current_model
186
+
187
+
188
def update_number_type(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat a numeric type as "integer" or "float" with its constraints."""
    kind = "float"
    if yaml_data["type"] == "integer":
        kind = "integer"
    number_type = _update_type_base(kind, yaml_data)
    constraint_names = (
        "precision",
        "multipleOf",
        "maximum",
        "minimum",
        "exclusiveMaximum",
        "exclusiveMinimum",
    )
    # Constraints that are absent from the yaml are stored as None.
    for constraint in constraint_names:
        number_type[constraint] = yaml_data.get(constraint)
    return number_type
202
+
203
+
204
def update_primitive(  # pylint: disable=too-many-return-statements
    type_group: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Reformat a primitive type.

    String-like types that carry none of the customizing keys, and binary
    types, reuse the shared KNOWN_TYPES entries; everything else gets a fresh
    type built from the common base.
    """
    if type_group in ("integer", "number"):
        return update_number_type(yaml_data)

    if type_group in ("string", "uuid", "uri"):
        customizing_keys = (
            "maxLength",
            "minLength",
            "pattern",
            "defaultValue",
            "serialization",
        )
        if not any(key in yaml_data for key in customizing_keys):
            return KNOWN_TYPES["string"]
        string_type = _update_type_base("string", yaml_data)
        for key in ("maxLength", "minLength", "pattern"):
            string_type[key] = yaml_data.get(key)
        return string_type

    if type_group == "binary":
        return KNOWN_TYPES["binary"]

    if type_group in ("date-time", "byte-array"):
        updated_name = "datetime" if type_group == "date-time" else "byte-array"
        formatted_type = _update_type_base(updated_name, yaml_data)
        formatted_type["format"] = yaml_data["format"]
        return formatted_type

    return _update_type_base(type_group, yaml_data)
241
+
242
+
243
def update_types(yaml_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build a "combined" type out of already-updated member types.

    A member whose "type" names a KNOWN_TYPES entry is replaced by that shared
    entry; any other member must be one of the objects stored in
    ORIGINAL_ID_TO_UPDATED_TYPE (StopIteration otherwise). The combined type
    is recorded in that cache under its own id.
    """
    members: List[Dict[str, Any]] = []
    for candidate in yaml_data:
        known = KNOWN_TYPES.get(candidate["type"])
        if known:
            members.append(known)
            continue
        members.append(
            next(
                updated
                for updated in ORIGINAL_ID_TO_UPDATED_TYPE.values()
                if updated is candidate
            )
        )
    combined = {"type": "combined", "types": members}
    ORIGINAL_ID_TO_UPDATED_TYPE[id(combined)] = combined
    return combined
257
+
258
+
259
def update_type(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return the reformatted type for a yaml type node, caching by node id."""
    node_id = id(yaml_data)
    try:
        return ORIGINAL_ID_TO_UPDATED_TYPE[node_id]
    except KeyError:
        pass

    type_group = yaml_data["type"]
    if type_group in ("object", "group"):
        # avoiding infinite loop: the partly built model is cached before its
        # properties and parents are resolved
        model = create_model(yaml_data)
        ORIGINAL_ID_TO_UPDATED_TYPE[node_id] = model
        updated_type = fill_model(yaml_data, model)
    elif type_group in (OAUTH_TYPE, KEY_TYPE):
        # credential types are stored as-is
        updated_type = yaml_data
    elif type_group in ("choice", "sealed-choice"):
        updated_type = update_enum(yaml_data)
    elif type_group == "constant":
        updated_type = update_constant(yaml_data)
    elif type_group == "dictionary":
        updated_type = update_dict(yaml_data)
    elif type_group == "array":
        updated_type = update_list(yaml_data)
    else:
        updated_type = update_primitive(type_group, yaml_data)
    ORIGINAL_ID_TO_UPDATED_TYPE[node_id] = updated_type
    return updated_type
282
+
283
+
284
def update_parameter_base(
    yaml_data: Dict[str, Any], *, override_client_name: Optional[str] = None
) -> Dict[str, Any]:
    """Build the fields shared by every reformatted parameter.

    :param yaml_data: The parameter as it appears in the input yaml.
    :keyword override_client_name: Client name to use instead of the yaml
     name; a falsy value falls back to the yaml name.
    :return: The base parameter dict. The location is "other" when the yaml
     gives no http location and "endpointPath" when it is "uri".
    """
    location = yaml_data["protocol"].get("http", {}).get("in")
    if not location:
        location = "other"
    if location == "uri":
        location = "endpointPath"
    grouped_by = (
        yaml_data["groupedBy"]["language"]["default"]["name"]
        if yaml_data.get("groupedBy")
        else None
    )
    client_name: str = override_client_name or yaml_data["language"]["default"]["name"]
    # startswith (rather than indexing [0]) keeps an empty name from raising
    # IndexError.
    if grouped_by and not client_name.startswith("_"):
        # this is an m4 bug, doesn't hide constant grouped params, patching m4 for now
        client_name = "_" + client_name
    return {
        "optional": not yaml_data.get("required", False),
        "description": yaml_data["language"]["default"]["description"],
        "clientName": client_name,
        "restApiName": yaml_data["language"]["default"].get("serializedName"),
        "clientDefaultValue": yaml_data.get("clientDefaultValue"),
        "location": location,
        "groupedBy": grouped_by,
        "checkClientInput": yaml_data.get("checkClientInput", False),
    }
311
+
312
+
313
def update_parameter_delimiter(style: Optional[str]) -> Optional[str]:
    """Translate a parameter style into a delimiter name.

    Returns None for a missing or unrecognized style.
    """
    if not style:
        return None
    delimiters = {
        "form": "comma",
        "simple": "comma",
        "spaceDelimited": "space",
        "pipeDelimited": "pipe",
        "tabDelimited": "tab",
    }
    return delimiters.get(style)
321
+
322
+
323
def get_all_body_types(yaml_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return the updated body types of all requests, one per distinct schema.

    Schemas are told apart by object identity.
    """
    types_by_schema_id: Dict[int, Dict[str, Any]] = {}
    for request in yaml_data.values():
        schema = get_body_parameter(request)["schema"]
        types_by_schema_id[id(schema)] = update_type(schema)
    return list(types_by_schema_id.values())
329
+
330
+
331
def get_body_type_for_description(body_parameter: Dict[str, Any]) -> str:
    """Name the body's kind for descriptions: "binary", "string" or "JSON"."""
    type_name = body_parameter["type"]["type"]
    for described in ("binary", "string"):
        if type_name == described:
            return described
    return "JSON"
337
+
338
+
339
def add_lro_information(operation: Dict[str, Any], yaml_data: Dict[str, Any]) -> None:
    """Mark ``operation`` as long-running and copy the poller extensions.

    The operation and each of its responses are modified in place; extension
    values that are absent are stored as None.
    """
    operation["discriminator"] = "lro"
    extensions = yaml_data["extensions"]
    operation["lroOptions"] = extensions.get("x-ms-long-running-operation-options")
    extension_by_field = (
        ("pollerSync", "x-python-custom-poller-sync"),
        ("pollerAsync", "x-python-custom-poller-async"),
        ("pollingMethodSync", "x-python-custom-default-polling-method-sync"),
        ("pollingMethodAsync", "x-python-custom-default-polling-method-async"),
    )
    for response in operation["responses"]:
        for field, extension_name in extension_by_field:
            response[field] = extensions.get(extension_name)
352
+
353
+
354
def filter_out_paging_next_operation(
    yaml_data: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Drop every operation that another operation names as its next operation."""
    next_operation_names: Set[str] = {
        operation["nextOperation"]["name"]
        for operation in yaml_data
        if operation.get("nextOperation")
    }
    kept = []
    for operation in yaml_data:
        if operation["name"] in next_operation_names:
            continue
        kept.append(operation)
    return kept
364
+
365
+
366
def update_response_header(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat a response header into its rest api name and updated type."""
    rest_api_name = yaml_data["header"]
    header_type = update_type(yaml_data["schema"])
    return {"restApiName": rest_api_name, "type": header_type}
371
+
372
+
373
def update_response(
    yaml_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Reformat a response: headers, status codes, type and nullability.

    The type is the shared binary type for binary responses, the recorded
    type of the schema when there is one, and None otherwise.
    """
    response_type = None
    if yaml_data.get("binary"):
        response_type = KNOWN_TYPES["binary"]
    elif yaml_data.get("schema"):
        response_type = get_type(yaml_data["schema"])

    http_info = yaml_data["protocol"]["http"]
    headers = [update_response_header(header) for header in http_info.get("headers", [])]
    # every status code becomes an int except the literal "default"
    status_codes = [
        "default" if code == "default" else int(code)
        for code in yaml_data["protocol"]["http"]["statusCodes"]
    ]
    return {
        "headers": headers,
        "statusCodes": status_codes,
        "type": response_type,
        "nullable": yaml_data.get("nullable", False),
    }
394
+
395
+
396
def _get_default_content_type(  # pylint: disable=too-many-return-statements
    content_types: Iterable[str],
) -> Optional[str]:
    """Pick the default content type out of the given candidates.

    Preference order: JSON types ("application/json" first), then types
    containing "xml" ("application/xml" first), then
    "application/octet-stream", then "application/x-www-form-urlencoded".

    :param content_types: Candidate content types; any iterable is accepted.
    :return: The chosen content type, or None when no candidate qualifies.
    """
    # The candidates are scanned several times below. Snapshot them so that a
    # one-shot iterator is not exhausted after the first pass.
    candidates = list(content_types)

    json_values = [ct for ct in candidates if JSON_REGEXP.match(ct)]
    if json_values:
        if "application/json" in json_values:
            return "application/json"
        return json_values[0]

    xml_values = [ct for ct in candidates if "xml" in ct]
    if xml_values:
        if "application/xml" in xml_values:
            return "application/xml"
        return xml_values[0]

    if "application/octet-stream" in candidates:
        return "application/octet-stream"
    if "application/x-www-form-urlencoded" in candidates:
        return "application/x-www-form-urlencoded"
    return None
416
+
417
+
418
def update_client_url(yaml_data: Dict[str, Any]) -> str:
    """Return the url template for the client.

    With a "$host" global parameter the url is just "{endpoint}"; otherwise
    the uri of the very first request in the code model is used.
    """
    for parameter in yaml_data["globalParameters"]:
        if parameter["language"]["default"]["name"] == "$host":
            # this means we DO NOT have a parameterized host
            # in order to share code better, going to make it a "parameterized host" of
            # just the endpoint parameter
            return "{endpoint}"
    # we have a parameterized host. Return first url from first request, quite gross
    first_group = yaml_data["operationGroups"][0]
    first_request = first_group["operations"][0]["requests"][0]
    return first_request["protocol"]["http"]["uri"]
432
+
433
+
434
+ class M4Reformatter(YamlUpdatePlugin): # pylint: disable=too-many-public-methods
435
+ """Add Python naming information."""
436
+
437
@property
def azure_arm(self) -> bool:
    """The "azure-arm" autorest setting, coerced to bool."""
    setting = self._autorestapi.get_boolean_value("azure-arm")
    return bool(setting)
440
+
441
@property
def version_tolerant(self) -> bool:
    """The "version-tolerant" autorest setting, coerced to bool."""
    setting = self._autorestapi.get_boolean_value("version-tolerant")
    return bool(setting)
444
+
445
@property
def low_level_client(self) -> bool:
    """The "low-level-client" autorest setting, coerced to bool."""
    setting = self._autorestapi.get_boolean_value("low-level-client")
    return bool(setting)
448
+
449
@property
def legacy(self) -> bool:
    """True when neither version-tolerant nor low-level-client is enabled."""
    return not self.version_tolerant and not self.low_level_client
452
+
453
@property
def default_optional_constants_to_none(self) -> bool:
    """True if "default-optional-constants-to-none" or version-tolerant is on."""
    if self._autorestapi.get_boolean_value("default-optional-constants-to-none"):
        return True
    return bool(self.version_tolerant)
459
+
460
def update_overloads(
    self,
    group_name: str,
    yaml_data: Dict[str, Any],
    body_parameter: Optional[Dict[str, Any]],
    *,
    content_types: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Create one overload per member type of the body parameter's type.

    Returns an empty list when there is no body parameter or its type has no
    "types". Each overload's "Content-Type" parameter defaults to that
    overload's default body content type.
    """
    overloads: List[Dict[str, Any]] = []
    if not body_parameter:
        return overloads
    for member_type in body_parameter["type"].get("types", []):
        overload = self.update_overload(
            group_name, yaml_data, member_type, content_types=content_types
        )
        content_type_parameters = (
            param
            for param in overload["parameters"]
            if param["restApiName"] == "Content-Type"
        )
        for param in content_type_parameters:
            param["clientDefaultValue"] = overload["bodyParameter"][
                "defaultContentType"
            ]
        overloads.append(overload)
    return overloads
485
+
486
def _update_operation_helper(
    self,
    group_name: str,
    yaml_data: Dict[str, Any],
    body_parameter: Optional[Dict[str, Any]],
    *,
    is_overload: bool = False,
) -> Dict[str, Any]:
    """Build the operation dict shared by operations and their overloads.

    The url and method come from the operation's first request. Exceptions
    whose schema is named "CloudError" are left out.
    """
    in_overriden = bool(body_parameter) and (
        body_parameter["type"]["type"] == "combined"
    )
    language = yaml_data["language"]["default"]
    operation: Dict[str, Any] = {"name": language["name"]}
    operation["description"] = language["description"]
    operation["summary"] = language.get("summary")
    operation["url"] = yaml_data["requests"][0]["protocol"]["http"]["path"]
    operation["method"] = yaml_data["requests"][0]["protocol"]["http"][
        "method"
    ].upper()
    operation["parameters"] = self.update_parameters(
        yaml_data,
        body_parameter,
        in_overload=is_overload,
        in_overriden=in_overriden,
    )
    operation["bodyParameter"] = body_parameter
    operation["responses"] = [
        update_response(response) for response in yaml_data.get("responses", [])
    ]
    operation["exceptions"] = [
        update_response(exception)
        for exception in yaml_data.get("exceptions", [])
        if not exception.get("schema")
        or exception["schema"]["language"]["default"]["name"] != "CloudError"
    ]
    operation["groupName"] = group_name
    operation["discriminator"] = "operation"
    operation["isOverload"] = is_overload
    operation["apiVersions"] = _get_api_versions(yaml_data.get("apiVersions", []))
    return operation
524
+
525
def get_operation_creator(
    self, yaml_data: Dict[str, Any]
) -> Callable[[str, Dict[str, Any]], Dict[str, Any]]:
    """Choose the update method matching the operation's lro/paging extensions."""
    extensions = yaml_data.get("extensions", {})
    is_lro = extensions.get("x-ms-long-running-operation")
    is_paging = yaml_data.get("extensions", {}).get("x-ms-pageable")
    if is_lro:
        if is_paging:
            return self.update_lro_paging_operation
        return self.update_lro_operation
    if is_paging:
        return self.update_paging_operation
    return self.update_operation
539
+
540
def update_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Reformat an operation together with its body-type overloads.

    A JSON model/dict/list body that is neither xml, flattened nor grouped
    has its type widened to a combined type that also accepts binary input.
    """
    body_parameter = None
    if yaml_data.get("requestMediaTypes"):
        body_parameter = self.update_body_parameter(yaml_data["requestMediaTypes"])

    content_types = None
    widen_to_binary = False
    if body_parameter and yaml_data.get("requestMediaTypes"):
        body_type = body_parameter["type"]
        widen_to_binary = bool(
            any(JSON_REGEXP.match(ct) for ct in yaml_data["requestMediaTypes"])
            and body_type["type"] in ("model", "dict", "list")
            and not body_type["xmlMetadata"]
            and not body_parameter.get("flattened")
            and not body_parameter.get("groupedBy")
        )
    if widen_to_binary:
        body_parameter["type"] = update_types(
            [body_parameter["type"], KNOWN_TYPES["binary"]]
        )
        content_types = body_parameter["contentTypes"]

    operation = self._update_operation_helper(group_name, yaml_data, body_parameter)
    operation["overloads"] = self.update_overloads(
        group_name, yaml_data, body_parameter, content_types=content_types
    )
    return operation
571
+
572
def add_paging_information(
    self, group_name: str, operation: Dict[str, Any], yaml_data: Dict[str, Any]
) -> None:
    """Mark ``operation`` as paging and copy the paging settings onto it.

    The item name defaults to "value". When the yaml names a next-link
    operation, it is reformatted and stored as "nextOperation". The operation
    and its responses are modified in place.
    """
    operation["discriminator"] = "paging"
    pageable = yaml_data["extensions"]["x-ms-pageable"]
    operation["itemName"] = pageable.get("itemName", "value")
    operation["continuationTokenName"] = yaml_data["extensions"][
        "x-ms-pageable"
    ].get("nextLinkName")

    paging = yaml_data["language"]["default"]["paging"]
    if paging.get("nextLinkOperation"):
        operation["nextOperation"] = self.update_operation(
            group_name=group_name,
            yaml_data=yaml_data["language"]["default"]["paging"]["nextLinkOperation"],
        )

    extensions = yaml_data["extensions"]
    pager_extensions = (
        ("pagerSync", "x-python-custom-pager-sync"),
        ("pagerAsync", "x-python-custom-pager-async"),
    )
    for response in operation["responses"]:
        for field, extension_name in pager_extensions:
            response[field] = extensions.get(extension_name)
593
+
594
def update_paging_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Reformat an operation and then add the paging information to it."""
    paging_operation = self.update_operation(group_name, yaml_data)
    self.add_paging_information(group_name, paging_operation, yaml_data)
    return paging_operation
600
+
601
def update_lro_paging_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Reformat an operation that is both long-running and paging."""
    lro_paging_operation = self.update_lro_operation(group_name, yaml_data)
    self.add_paging_information(group_name, lro_paging_operation, yaml_data)
    # Set last: add_paging_information writes its own discriminator.
    lro_paging_operation["discriminator"] = "lropaging"
    return lro_paging_operation
608
+
609
def update_lro_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Reformat a long-running operation and each of its overloads."""
    lro_operation = self.update_operation(group_name, yaml_data)
    # The operation itself comes first, followed by its overloads.
    for target in [lro_operation, *lro_operation["overloads"]]:
        add_lro_information(target, yaml_data)
    return lro_operation
617
+
618
def update_overload(
    self,
    group_name: str,
    yaml_data: Dict[str, Any],
    body_type: Dict[str, Any],
    *,
    content_types: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Create the overload of an operation for one specific body type."""
    media_types = yaml_data["requestMediaTypes"]
    overload_body = self.update_body_parameter_overload(
        media_types, body_type, content_types=content_types
    )
    return self._update_operation_helper(
        group_name, yaml_data, overload_body, is_overload=True
    )
632
+
633
def update_operation_group(self, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reformat an operation group and all of its operations.

    Operations that only serve as another operation's paging "next" operation
    are filtered out of the result.
    """
    property_name = yaml_data["language"]["default"]["name"]
    group: Dict[str, Any] = {
        "propertyName": property_name,
        "className": property_name,
    }
    operations = []
    for yaml_operation in yaml_data["operations"]:
        create_operation = self.get_operation_creator(yaml_operation)
        operations.append(create_operation(property_name, yaml_operation))
    group["operations"] = filter_out_paging_next_operation(operations)
    return group
645
+
646
def _update_body_parameter_helper(
    self,
    yaml_data: Dict[str, Any],
    body_param: Dict[str, Any],
    body_type: Dict[str, Any],
    *,
    content_types: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build a reformatted body parameter for the given body type.

    Unless ``content_types`` is given, the content types are those whose
    request body schema was updated to this very ``body_type`` object; for a
    binary body type, every content type of ``yaml_data`` is accepted.
    """
    flattened = body_param.get("flattened")
    is_partial_body = body_param.get("isPartialBody")

    updated = copy.deepcopy(update_parameter_base(body_param))
    updated["type"] = body_type
    updated["contentTypes"] = content_types or [
        content_type
        for content_type, request in yaml_data.items()
        if body_type
        is ORIGINAL_ID_TO_UPDATED_TYPE[id(get_body_parameter(request)["schema"])]
    ]
    # get default content type
    updated["defaultContentType"] = _get_default_content_type(
        updated["contentTypes"]
    )
    # python supports IO input with all kinds of content_types
    if body_type["type"] == "binary":
        updated["contentTypes"] = content_types or list(yaml_data.keys())
    if updated["type"]["type"] == "constant":
        # Only an optional constant with default-optional-constants-to-none
        # enabled keeps its existing client default.
        if not updated["optional"] or not self.default_optional_constants_to_none:
            updated["clientDefaultValue"] = body_type["value"]
    updated["flattened"] = flattened
    updated["isPartialBody"] = is_partial_body
    return updated
682
+
683
def update_multipart_body_parameter(
    self, yaml_data: Dict[str, Any], client_name: str, description: str
) -> Dict[str, Any]:
    """Build a dict-typed method parameter that wraps the body entries.

    The body parameters of the first request in ``yaml_data`` are converted
    individually and stored under ``"entries"``; the wrapper itself is typed
    as ``KNOWN_TYPES["anydict"]`` and advertises every content type key of
    ``yaml_data``.
    """
    first_value = list(yaml_data.values())[0]

    entries = []
    for candidate in first_value["parameters"]:
        if not is_body(candidate):
            continue
        entries.append(
            self._update_body_parameter_helper(
                yaml_data, candidate, update_type(candidate["schema"])
            )
        )

    is_optional = not first_value.get("required", False)
    return {
        "optional": is_optional,
        "description": description,
        "clientName": client_name,
        "restApiName": client_name,
        "clientDefaultValue": None,
        "location": "Method",
        "type": KNOWN_TYPES["anydict"],
        "contentTypes": list(yaml_data.keys()),
        "defaultContentType": _get_default_content_type(yaml_data.keys()),
        "entries": entries,
    }
704
+
705
def update_body_parameter(self, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build the body parameter for an operation.

    ``yaml_data`` maps content types to requests. The first request's
    ``protocol.http`` settings decide whether a multipart ("files") or form
    ("data") wrapper parameter is produced; otherwise the first body
    parameter found is converted, using a combined type when there are
    several body types and ``yaml_data`` has no truthy ``"flattened"`` entry.
    """
    first_request = list(yaml_data.values())[0]
    protocol_http = first_request.get("protocol", {}).get("http", {})

    if protocol_http.get("multipart"):
        return self.update_multipart_body_parameter(
            yaml_data, "files", "Multipart input for files."
        )
    if protocol_http.get("knownMediaType") == "form":
        return self.update_multipart_body_parameter(
            yaml_data, "data", "Multipart input for form encoded data."
        )

    body_types = get_all_body_types(yaml_data)
    needs_combined_type = len(body_types) > 1 and not yaml_data.get("flattened")
    body_type = update_types(body_types) if needs_combined_type else body_types[0]

    every_parameter = (
        parameter
        for request in yaml_data.values()
        for parameter in request["parameters"]
    )
    # first body parameter across all requests; StopIteration if none exists
    body_param = next(filter(is_body, every_parameter))
    return self._update_body_parameter_helper(yaml_data, body_param, body_type)
724
+
725
def update_body_parameter_overload(
    self,
    yaml_data: Dict[str, Any],
    body_type: Dict[str, Any],
    *,
    content_types: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build the body parameter for an overload.

    Unlike ``update_body_parameter``, the body type is supplied by the
    caller; only the first body parameter across the requests in
    ``yaml_data`` has to be located.
    """
    every_parameter = (
        parameter
        for request in yaml_data.values()
        for parameter in request["parameters"]
    )
    # StopIteration propagates when no request carries a body parameter
    first_body_param = next(filter(is_body, every_parameter))
    return self._update_body_parameter_helper(
        yaml_data, first_body_param, body_type, content_types=content_types
    )
739
+
740
def update_flattened_parameter(
    self, yaml_data: Dict[str, Any], body_parameter: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Convert a parameter that is a property of a flattened body model.

    The converted parameter is marked with ``inFlattenedBody`` and its client
    name is recorded in ``body_parameter["propertyToParameterName"]`` under
    the target property's name.

    Raises:
        ValueError: if ``body_parameter`` is missing or empty.
    """
    if not body_parameter:
        raise ValueError("Has to have a body parameter if it's flattened")

    # this parameter is a property that is part of a flattened model
    target_defaults = yaml_data["targetProperty"]["language"]["default"]
    target_property_name = target_defaults["name"]

    flattened_param = self.update_parameter(yaml_data)
    parameter_client_name = flattened_param["clientName"]
    name_map = body_parameter.setdefault("propertyToParameterName", {})
    name_map[target_property_name] = parameter_client_name

    flattened_param["inFlattenedBody"] = True
    return flattened_param
755
+
756
+ def _update_content_type_parameter(
757
+ self,
758
+ yaml_data: Dict[str, Any],
759
+ body_parameter: Optional[Dict[str, Any]],
760
+ request_media_types: List[str],
761
+ *,
762
+ in_overload: bool = False,
763
+ in_overriden: bool = False,
764
+ ) -> Dict[str, Any]:
765
+ # override content type type to string
766
+ if not body_parameter:
767
+ return yaml_data
768
+ param = copy.deepcopy(yaml_data)
769
+ param["schema"] = KNOWN_TYPES["string"] # override to string type
770
+ if (
771
+ body_parameter["type"]["type"] == "binary"
772
+ and not body_parameter["defaultContentType"]
773
+ and not self.legacy
774
+ ):
775
+ param["required"] = True
776
+ else:
777
+ param["required"] = False
778
+ description = param["language"]["default"]["description"]
779
+ if description and description[-1] != ".":
780
+ description += "."
781
+ if not (in_overriden or in_overload):
782
+ param["inDocstring"] = False
783
+ elif in_overload:
784
+ description += (
785
+ " Content type parameter for "
786
+ f"{get_body_type_for_description(body_parameter)} body."
787
+ )
788
+ if not in_overload or (
789
+ body_parameter["type"]["type"] == "binary" and len(request_media_types) > 1
790
+ ):
791
+ content_types = "'" + "', '".join(request_media_types) + "'"
792
+ description += f" Known values are: {content_types}."
793
+ if not in_overload and not in_overriden:
794
+ param["clientDefaultValue"] = body_parameter["defaultContentType"]
795
+ param["language"]["default"]["description"] = description
796
+ return param
797
+
798
+ def _update_parameters_helper(
799
+ self,
800
+ parameters: List[Dict[str, Any]],
801
+ body_parameter: Optional[Dict[str, Any]],
802
+ seen_client_names: Set[str],
803
+ groupers: Dict[str, Dict[str, Any]],
804
+ request_media_types: List[str],
805
+ *,
806
+ in_overload: bool = False,
807
+ in_overriden: bool = False,
808
+ ) -> List[Dict[str, Any]]:
809
# Convert one list of raw parameters and return the converted ones.
#
# ``seen_client_names`` and ``groupers`` are mutated in place, so they carry
# state across several calls: a name that was already seen is skipped, and
# every parameter-group parameter is registered in ``groupers`` under its
# client name.
# Body parameters are skipped here (see the ``is_body`` check below).
+ retval: List[Dict[str, Any]] = []
810
+ has_flattened_body = body_parameter and body_parameter.get("flattened")
811
+ for param in parameters:
812
+ client_name = param["language"]["default"]["name"]
813
# skip the "$host" parameter and names that were already processed
+ if param["language"]["default"]["name"] == "$host" or (
814
+ client_name in seen_client_names
815
+ ):
816
+ continue
817
+ seen_client_names.add(client_name)
818
# synthesized api-version parameters are kept out of the docstring; the raw
# parameter dict is modified in place before it is converted
# NOTE(review): indentation is not visible in this rendering, so it cannot be
# told whether "checkClientInput" is set only when self.legacy is true or for
# every synthesized api-version parameter -- confirm against the real file.
+ if param.get("origin") == "modelerfour:synthesized/api-version":
819
+ param["inDocstring"] = False
820
+ if self.legacy:
821
+ param["implementation"] = "Method"
822
+ param["checkClientInput"] = True
823
# properties of a flattened body are converted by update_flattened_parameter
+ if has_flattened_body and param.get("targetProperty"):
824
+ retval.append(self.update_flattened_parameter(param, body_parameter))
825
+ continue
826
+ if param["schema"]["type"] == "group":
827
+ # this means i'm a parameter group parameter
828
+ param = self.update_parameter(param)
829
+ param["grouper"] = True
830
+ groupers[param["clientName"]] = param
831
+ retval.append(param)
832
+ continue
833
+ if is_body(param):
834
+ continue
835
# the Content-Type header is first adjusted to the body parameter, then
# converted like any other parameter
+ if param["language"]["default"].get("serializedName") == "Content-Type":
836
+ param = self._update_content_type_parameter(
837
+ param,
838
+ body_parameter,
839
+ request_media_types,
840
+ in_overload=in_overload,
841
+ in_overriden=in_overriden,
842
+ )
843
+ updated_param = self.update_parameter(
844
+ param, in_overload=in_overload, in_overriden=in_overriden
845
+ )
846
+ retval.append(updated_param)
847
+ return retval
848
+
849
def update_parameters(
    self,
    yaml_data: Dict[str, Any],
    body_parameter: Optional[Dict[str, Any]],
    *,
    in_overload: bool = False,
    in_overriden: bool = False,
) -> List[Dict[str, Any]]:
    """Convert all non-body parameters of an operation.

    Top-level parameters are converted first, then the parameters of each
    sub-request (taken from ``requestMediaTypes`` when present, otherwise
    from ``requests``). Afterwards every parameter-group parameter receives a
    ``propertyToParameterName`` mapping from its property client names to the
    client names of the parameters grouped by it.
    """
    seen_client_names: Set[str] = set()
    groupers: Dict[str, Dict[str, Any]] = {}
    request_media_types = yaml_data.get("requestMediaTypes", [])

    def convert(raw_parameters: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # all calls share the seen-name set and the grouper registry
        return self._update_parameters_helper(
            raw_parameters,
            body_parameter,
            seen_client_names,
            groupers,
            request_media_types,
            in_overload=in_overload,
            in_overriden=in_overriden,
        )

    # first update top level parameters
    retval: List[Dict[str, Any]] = list(convert(yaml_data["parameters"]))

    # now we handle content type and accept headers.
    # We only care about the content types on the body parameter itself,
    # so ignoring the different content types for now
    media_type_requests = yaml_data.get("requestMediaTypes")
    if media_type_requests:
        sub_requests = media_type_requests.values()
    else:
        sub_requests = yaml_data.get("requests", [])
    for sub_request in sub_requests:
        retval += convert(sub_request.get("parameters", []))

    all_params = (retval + [body_parameter]) if body_parameter else retval
    for grouper_name, grouper in groupers.items():
        property_to_parameter: Dict[str, Any] = {}
        for member in all_params:
            if member.get("groupedBy") != grouper_name:
                continue
            owning_property = next(
                prop
                for prop in grouper["type"]["properties"]
                if member["clientName"].lstrip("_")
                in prop["groupedParameterNames"]  # TODO: patching m4
            )
            property_to_parameter[owning_property["clientName"]] = member[
                "clientName"
            ]
        grouper["propertyToParameterName"] = property_to_parameter
    return retval
905
+
906
def update_parameter(
    self,
    yaml_data: Dict[str, Any],
    *,
    override_client_name: Optional[str] = None,
    in_overload: bool = False,
    in_overriden: bool = False,
) -> Dict[str, Any]:
    """Convert a single (non-body) parameter.

    Starts from ``update_parameter_base`` and adds the resolved type, the
    implementation, HTTP protocol details (``explode``, ``delimiter``), the
    ``x-ms-skip-url-encoding`` extension and the overload/override flags.
    Constant-typed parameters get their constant value as client default,
    except for optional ones when the default_optional_constants_to_none
    setting is on.
    """
    param_base = update_parameter_base(
        yaml_data, override_client_name=override_client_name
    )
    param_type = get_type(yaml_data["schema"])
    if param_type["type"] == "constant":
        if not (param_base["optional"] and self.default_optional_constants_to_none):
            param_base["clientDefaultValue"] = param_type["value"]

    protocol_http = yaml_data["protocol"].get("http", {})
    extensions = yaml_data.get("extensions", {})
    extra_fields = {
        "type": param_type,
        "implementation": yaml_data["implementation"],
        "explode": protocol_http.get("explode", False),
        "inOverload": in_overload,
        "skipUrlEncoding": extensions.get("x-ms-skip-url-encoding", False),
        "inDocstring": yaml_data.get("inDocstring", True),
        "inOverriden": in_overriden,
        "delimiter": update_parameter_delimiter(protocol_http.get("style")),
    }
    param_base.update(extra_fields)
    return param_base
939
+
940
def update_global_parameters(
    self, yaml_data: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Convert the client-level (global) parameters.

    The ``$host`` parameter is the non-parameterized endpoint: it is renamed
    to ``base_url`` (legacy) or ``endpoint`` and its description in the input
    data is replaced with "Service URL.".
    """
    converted: List[Dict[str, Any]] = []
    for raw_parameter in yaml_data:
        defaults = raw_parameter["language"]["default"]
        override_name: Optional[str] = None
        if defaults["name"] == "$host":
            # pick the client name based off of the legacy flag
            override_name = "base_url" if self.legacy else "endpoint"
            defaults["description"] = "Service URL."
        converted.append(
            self.update_parameter(raw_parameter, override_client_name=override_name)
        )
    return converted
958
+
959
def get_token_credential(self, credential_scopes: List[str]) -> Dict[str, Any]:
    """Build the OAuth (token) credential type for the given scopes.

    The authentication policy is ``ARMChallengeAuthenticationPolicy`` when
    ``self.azure_arm`` is set and ``BearerTokenCredentialPolicy`` otherwise.
    The new type dict is passed to ``update_type`` before being returned.
    """
    if self.azure_arm:
        policy_type = "ARMChallengeAuthenticationPolicy"
    else:
        policy_type = "BearerTokenCredentialPolicy"
    token_credential = {
        "type": OAUTH_TYPE,
        "policy": {
            "type": policy_type,
            "credentialScopes": credential_scopes,
        },
    }
    update_type(token_credential)
    return token_credential
971
+
972
def update_credential_from_security(
    self, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Derive the credential type from the ``schemes`` of the security data.

    OAuth schemes yield a token credential, key schemes a key credential.
    When several schemes match, the last one wins; with no matching scheme an
    empty dict is returned.
    """
    credential: Dict[str, Any] = {}
    for scheme in yaml_data.get("schemes", []):
        scheme_kind = scheme["type"]
        if scheme_kind == OAUTH_TYPE:
            # TokenCredential
            credential = self.get_token_credential(scheme["scopes"])
        elif scheme_kind == KEY_TYPE:
            credential = get_azure_key_credential(scheme["name"])
    return credential
983
+
984
def get_credential_scopes_from_flags(self, auth_policy: str) -> List[str]:
    """Resolve the credential scopes requested through generator flags.

    For ARM generation the management scope is always returned. Otherwise the
    ``credential-scopes`` value is read as a comma separated list; whitespace
    around each scope is stripped and empty entries are dropped, so
    ``"a, b,"`` yields ``["a", "b"]``.

    Args:
        auth_policy: Name of the credential policy in use; only used in the
            warning emitted when no scopes are configured.

    Returns:
        The list of scopes, possibly empty.

    Raises:
        ValueError: if the ``credential-scopes`` flag is set (truthy as a
            boolean) but does not provide any scope.
    """
    if self.azure_arm:
        return ["https://management.azure.com/.default"]

    raw_scopes = self._autorestapi.get_value("credential-scopes")
    # Only a string can carry a scope list. Any other value (e.g. a bare flag
    # without a value) falls through to the explicit ValueError below instead
    # of failing on ``.split``.
    credential_scopes: List[str] = (
        [scope.strip() for scope in raw_scopes.split(",") if scope.strip()]
        if isinstance(raw_scopes, str)
        else []
    )

    if (
        self._autorestapi.get_boolean_value("credential-scopes", False)
        and not credential_scopes
    ):
        raise ValueError(
            "--credential-scopes takes a list of scopes in comma separated format. "
            "For example: --credential-scopes=https://cognitiveservices.azure.com/.default"
        )
    if not credential_scopes:
        _LOGGER.warning(
            "You have default credential policy %s "
            "but not the --credential-scopes flag set while generating non-management plane code. "
            "This is not recommend because it forces the customer to pass credential scopes "
            "through kwargs if they want to authenticate.",
            auth_policy,
        )
    return credential_scopes
1009
+
1010
def update_credential_from_flags(self) -> Dict[str, Any]:
    """Build the credential type from the credential-related flags.

    The policy comes from ``credential-default-policy-type`` and falls back
    to the ARM or bearer token policy. Token policies produce a token
    credential; any other policy name is treated as the key credential
    policy, whose header name defaults to ``api-key``.

    Raises:
        ValueError: if a key header name is combined with a token policy, or
            credential scopes are combined with the key policy.
    """
    if self.azure_arm:
        default_auth_policy = "ARMChallengeAuthenticationPolicy"
    else:
        default_auth_policy = "BearerTokenCredentialPolicy"
    configured_policy = self._autorestapi.get_value("credential-default-policy-type")
    auth_policy = configured_policy or default_auth_policy

    credential_scopes = self.get_credential_scopes_from_flags(auth_policy)
    key = self._autorestapi.get_value("credential-key-header-name")

    token_policies = (
        "armchallengeauthenticationpolicy",
        "bearertokencredentialpolicy",
    )
    if auth_policy.lower() in token_policies:
        if not key:
            return self.get_token_credential(credential_scopes)
        raise ValueError(
            "You have passed in a credential key header name with default credential policy type "
            f"{auth_policy}. This is not allowed, since credential key header "
            "name is tied with AzureKeyCredentialPolicy. Instead, with this policy it is recommend you "
            "pass in --credential-scopes."
        )

    # Otherwise you have AzureKeyCredentialPolicy
    if self._autorestapi.get_value("credential-scopes"):
        raise ValueError(
            "You have passed in credential scopes with default credential policy type "
            "AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
            f"{default_auth_policy}. Instead, with this policy "
            "you must pass in --credential-key-header-name."
        )
    if not key:
        key = "api-key"
        _LOGGER.info(
            "Defaulting the AzureKeyCredentialPolicy header's name to 'api-key'"
        )
    return get_azure_key_credential(key)
1048
+
1049
def update_credential(
    self, yaml_data: Dict[str, Any], parameters: List[Dict[str, Any]]
) -> None:
    """Add a ``credential`` client parameter to ``parameters`` in place.

    The credential type comes from the flags when ``add-credentials`` /
    ``add-credential`` is set or ARM code is generated, otherwise from the
    security data in ``yaml_data``. Nothing is added when no credential type
    results. The parameter is appended for version-tolerant / low-level
    clients and inserted at the front otherwise.
    """
    api = self._autorestapi
    # credential flags take precedence over the security data
    use_flags = (
        api.get_boolean_value("add-credentials", False)
        or api.get_boolean_value("add-credential", False)
        or self.azure_arm
    )
    credential_type = (
        self.update_credential_from_flags()
        if use_flags
        else self.update_credential_from_security(yaml_data)
    )
    if not credential_type:
        return

    credential = {
        "type": credential_type,
        "optional": False,
        "description": "Credential needed for the client to connect to Azure.",
        "clientName": "credential",
        "location": "other",
        "restApiName": "credential",
        "implementation": "Client",
        "skipUrlEncoding": True,
        "inOverload": False,
    }
    goes_last = api.get_boolean_value("version-tolerant") or api.get_boolean_value(
        "low-level-client"
    )
    if goes_last:
        parameters.append(credential)
    else:
        parameters.insert(0, credential)
1081
+
1082
def update_client(self, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build the ``client`` section from the code model.

    Global parameters are converted and extended with the credential
    parameter (if any). The URL is only computed when global parameters
    exist, and the namespace falls back to the client name when the
    ``namespace`` flag is not set.
    """
    client_name_source = yaml_data["language"]["default"]
    parameters = self.update_global_parameters(
        yaml_data.get("globalParameters", [])
    )
    self.update_credential(yaml_data.get("security", {}), parameters)

    client: Dict[str, Any] = {
        "name": client_name_source["name"],
        "description": yaml_data["info"].get("description"),
        "parameters": parameters,
    }
    if yaml_data.get("globalParameters"):
        client["url"] = update_client_url(yaml_data)
    else:
        client["url"] = ""
    client["namespace"] = (
        self._autorestapi.get_value("namespace") or client_name_source["name"]
    )
    return client
1097
+
1098
def update_yaml(self, yaml_data: Dict[str, Any]) -> None:
    """Rewrite the code model dict ``yaml_data`` in place.

    Types are updated first so they are available while the client and the
    operation groups are converted. Afterwards the raw input sections are
    removed and the ``ORIGINAL_ID_TO_UPDATED_TYPE`` registry is cleared.
    """
    # First we update the types, so we can access for when we're creating parameters etc.
    for type_group, schemas in yaml_data["schemas"].items():
        for schema in schemas:
            is_cloud_error = (
                type_group == "objects"
                and schema["language"]["default"]["name"] == "CloudError"
            )
            # we don't generate cloud error
            if not is_cloud_error:
                update_type(schema)

    yaml_data["client"] = self.update_client(yaml_data)

    updated_groups = []
    for operation_group in yaml_data["operationGroups"]:
        updated_groups.append(self.update_operation_group(operation_group))
    yaml_data["operationGroups"] = updated_groups

    yaml_data["types"] = [
        *ORIGINAL_ID_TO_UPDATED_TYPE.values(),
        *KNOWN_TYPES.values(),
    ]

    if yaml_data.get("globalParameters"):
        del yaml_data["globalParameters"]
    # these sections are removed unconditionally
    for section in ("info", "language", "protocol"):
        del yaml_data[section]
    # these sections are removed only when present with a truthy value
    for section in ("schemas", "security"):
        if yaml_data.get(section):
            del yaml_data[section]
    ORIGINAL_ID_TO_UPDATED_TYPE.clear()