@autorest/python 5.14.0 → 5.17.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. package/ChangeLog.md +91 -2
  2. package/README.md +30 -4
  3. package/autorest/__init__.py +2 -3
  4. package/autorest/black/__init__.py +12 -5
  5. package/autorest/codegen/__init__.py +130 -179
  6. package/autorest/codegen/models/__init__.py +122 -78
  7. package/autorest/codegen/models/base_builder.py +70 -72
  8. package/autorest/codegen/models/base_model.py +7 -5
  9. package/autorest/codegen/models/{base_schema.py → base_type.py} +62 -49
  10. package/autorest/codegen/models/client.py +195 -36
  11. package/autorest/codegen/models/code_model.py +165 -299
  12. package/autorest/codegen/models/combined_type.py +107 -0
  13. package/autorest/codegen/models/constant_type.py +122 -0
  14. package/autorest/codegen/models/credential_types.py +224 -0
  15. package/autorest/codegen/models/dictionary_type.py +116 -0
  16. package/autorest/codegen/models/enum_type.py +195 -0
  17. package/autorest/codegen/models/imports.py +95 -41
  18. package/autorest/codegen/models/list_type.py +134 -0
  19. package/autorest/codegen/models/lro_operation.py +90 -133
  20. package/autorest/codegen/models/lro_paging_operation.py +28 -12
  21. package/autorest/codegen/models/model_type.py +239 -0
  22. package/autorest/codegen/models/operation.py +415 -241
  23. package/autorest/codegen/models/operation_group.py +82 -88
  24. package/autorest/codegen/models/paging_operation.py +101 -117
  25. package/autorest/codegen/models/parameter.py +307 -322
  26. package/autorest/codegen/models/parameter_list.py +366 -357
  27. package/autorest/codegen/models/primitive_types.py +544 -0
  28. package/autorest/codegen/models/property.py +122 -134
  29. package/autorest/codegen/models/request_builder.py +138 -86
  30. package/autorest/codegen/models/request_builder_parameter.py +122 -79
  31. package/autorest/codegen/models/response.py +325 -0
  32. package/autorest/codegen/models/utils.py +17 -1
  33. package/autorest/codegen/serializers/__init__.py +242 -118
  34. package/autorest/codegen/serializers/builder_serializer.py +863 -1027
  35. package/autorest/codegen/serializers/client_serializer.py +148 -82
  36. package/autorest/codegen/serializers/general_serializer.py +44 -47
  37. package/autorest/codegen/serializers/import_serializer.py +96 -31
  38. package/autorest/codegen/serializers/metadata_serializer.py +39 -79
  39. package/autorest/codegen/serializers/model_base_serializer.py +65 -29
  40. package/autorest/codegen/serializers/model_generic_serializer.py +9 -10
  41. package/autorest/codegen/serializers/model_init_serializer.py +4 -2
  42. package/autorest/codegen/serializers/model_python3_serializer.py +29 -22
  43. package/autorest/codegen/serializers/operation_groups_serializer.py +21 -18
  44. package/autorest/codegen/serializers/operations_init_serializer.py +23 -11
  45. package/autorest/codegen/serializers/parameter_serializer.py +174 -0
  46. package/autorest/codegen/serializers/patch_serializer.py +14 -2
  47. package/autorest/codegen/serializers/request_builders_serializer.py +57 -0
  48. package/autorest/codegen/serializers/utils.py +0 -103
  49. package/autorest/codegen/templates/MANIFEST.in.jinja2 +1 -0
  50. package/autorest/codegen/templates/{service_client.py.jinja2 → client.py.jinja2} +7 -7
  51. package/autorest/codegen/templates/config.py.jinja2 +13 -13
  52. package/autorest/codegen/templates/enum.py.jinja2 +4 -4
  53. package/autorest/codegen/templates/enum_container.py.jinja2 +1 -2
  54. package/autorest/codegen/templates/init.py.jinja2 +9 -6
  55. package/autorest/codegen/templates/keywords.jinja2 +14 -1
  56. package/autorest/codegen/templates/lro_operation.py.jinja2 +6 -5
  57. package/autorest/codegen/templates/lro_paging_operation.py.jinja2 +6 -5
  58. package/autorest/codegen/templates/metadata.json.jinja2 +36 -35
  59. package/autorest/codegen/templates/model.py.jinja2 +23 -29
  60. package/autorest/codegen/templates/model_container.py.jinja2 +2 -1
  61. package/autorest/codegen/templates/model_init.py.jinja2 +9 -8
  62. package/autorest/codegen/templates/operation.py.jinja2 +10 -15
  63. package/autorest/codegen/templates/operation_group.py.jinja2 +14 -13
  64. package/autorest/codegen/templates/operation_groups_container.py.jinja2 +1 -2
  65. package/autorest/codegen/templates/operation_tools.jinja2 +8 -2
  66. package/autorest/codegen/templates/operations_folder_init.py.jinja2 +4 -0
  67. package/autorest/codegen/templates/paging_operation.py.jinja2 +7 -8
  68. package/autorest/codegen/templates/patch.py.jinja2 +18 -29
  69. package/autorest/codegen/templates/request_builder.py.jinja2 +20 -13
  70. package/autorest/codegen/templates/setup.py.jinja2 +9 -3
  71. package/autorest/codegen/templates/vendor.py.jinja2 +12 -2
  72. package/autorest/jsonrpc/__init__.py +7 -12
  73. package/autorest/jsonrpc/localapi.py +4 -3
  74. package/autorest/jsonrpc/server.py +28 -9
  75. package/autorest/jsonrpc/stdstream.py +13 -6
  76. package/autorest/m2r/__init__.py +5 -8
  77. package/autorest/m4reformatter/__init__.py +1108 -0
  78. package/autorest/multiapi/__init__.py +24 -14
  79. package/autorest/multiapi/models/client.py +21 -11
  80. package/autorest/multiapi/models/code_model.py +23 -10
  81. package/autorest/multiapi/models/config.py +4 -1
  82. package/autorest/multiapi/models/constant_global_parameter.py +1 -0
  83. package/autorest/multiapi/models/global_parameter.py +2 -1
  84. package/autorest/multiapi/models/global_parameters.py +14 -8
  85. package/autorest/multiapi/models/imports.py +35 -18
  86. package/autorest/multiapi/models/mixin_operation.py +5 -5
  87. package/autorest/multiapi/models/operation_group.py +2 -1
  88. package/autorest/multiapi/models/operation_mixin_group.py +21 -10
  89. package/autorest/multiapi/serializers/__init__.py +20 -25
  90. package/autorest/multiapi/serializers/import_serializer.py +47 -15
  91. package/autorest/multiapi/serializers/multiapi_serializer.py +17 -17
  92. package/autorest/multiapi/templates/multiapi_config.py.jinja2 +3 -3
  93. package/autorest/multiapi/templates/multiapi_init.py.jinja2 +2 -2
  94. package/autorest/multiapi/templates/multiapi_operations_mixin.py.jinja2 +4 -4
  95. package/autorest/multiapi/templates/multiapi_service_client.py.jinja2 +9 -9
  96. package/autorest/multiapi/utils.py +3 -3
  97. package/autorest/postprocess/__init__.py +202 -0
  98. package/autorest/postprocess/get_all.py +19 -0
  99. package/autorest/postprocess/venvtools.py +73 -0
  100. package/autorest/preprocess/__init__.py +209 -0
  101. package/autorest/preprocess/helpers.py +54 -0
  102. package/autorest/{namer → preprocess}/python_mappings.py +25 -32
  103. package/package.json +3 -3
  104. package/run-python3.js +2 -3
  105. package/venvtools.py +1 -1
  106. package/autorest/codegen/models/constant_schema.py +0 -97
  107. package/autorest/codegen/models/credential_schema.py +0 -90
  108. package/autorest/codegen/models/credential_schema_policy.py +0 -77
  109. package/autorest/codegen/models/dictionary_schema.py +0 -103
  110. package/autorest/codegen/models/enum_schema.py +0 -246
  111. package/autorest/codegen/models/list_schema.py +0 -113
  112. package/autorest/codegen/models/object_schema.py +0 -249
  113. package/autorest/codegen/models/primitive_schemas.py +0 -476
  114. package/autorest/codegen/models/request_builder_parameter_list.py +0 -280
  115. package/autorest/codegen/models/rest.py +0 -42
  116. package/autorest/codegen/models/schema_request.py +0 -45
  117. package/autorest/codegen/models/schema_response.py +0 -123
  118. package/autorest/codegen/serializers/rest_serializer.py +0 -57
  119. package/autorest/namer/__init__.py +0 -25
  120. package/autorest/namer/name_converter.py +0 -412
@@ -0,0 +1,1108 @@
1
+ # pylint: disable=too-many-lines
2
+ # -------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License. See License.txt in the project root for
5
+ # license information.
6
+ # --------------------------------------------------------------------------
7
+ """The modelerfour reformatter autorest plugin.
8
+ """
9
+ import re
10
+ import copy
11
+ import logging
12
+ from typing import Callable, Dict, Any, Iterable, List, Optional, Set
13
+
14
+ from .. import YamlUpdatePlugin
15
+
16
# Matches JSON media types: "application/json", "text/json", and structured
# suffix forms such as "application/merge-patch+json".
JSON_REGEXP = re.compile(r"^(application|text)/(.+\+)?json$")

# Cache mapping id() of an original modelerfour type dict to its reformatted
# dict. Keys are object identities, so entries are only meaningful while the
# original yaml objects are alive.
# NOTE(review): CPython may reuse ids of collected objects; this assumes the
# input yaml is held for the whole run — confirm.
ORIGINAL_ID_TO_UPDATED_TYPE: Dict[int, Dict[str, Any]] = {}

# Credential type discriminators; update_type passes these dicts through as-is.
OAUTH_TYPE = "OAuth2"
KEY_TYPE = "Key"

_LOGGER = logging.getLogger(__name__)

# Shared plain types used when we want a string / binary / any-dict type.
# These instances are handed out by reference (and stored in the cache), so
# treat them as read-only.
KNOWN_TYPES: Dict[str, Dict[str, Any]] = dict(
    string={"type": "string"},
    binary={"type": "binary"},
    anydict={"type": "dict", "elementType": {"type": "any"}},
)
29
+
30
+
31
def is_body(yaml_data: Dict[str, Any]) -> bool:
    """Tell whether a modelerfour parameter is sent in the HTTP body.

    :param yaml_data: Parameter yaml; must contain ``protocol.http.in``.
    :return: True when the parameter's HTTP location is ``"body"``.
    """
    http_location = yaml_data["protocol"]["http"]["in"]
    return http_location == "body"
34
+
35
+
36
def get_body_parameter(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return the first body parameter of a modelerfour request.

    :param yaml_data: Request yaml with a ``parameters`` list.
    :return: The first parameter for which ``is_body`` is true.
    :raises StopIteration: if the request has no body parameter.
    """
    body_params = filter(is_body, yaml_data["parameters"])
    return next(body_params)
39
+
40
+
41
def get_azure_key_credential(key: str) -> Dict[str, Any]:
    """Build a key-credential type that uses ``AzureKeyCredentialPolicy``.

    The new dict is run through ``update_type``. For ``KEY_TYPE`` that caches
    the dict as its own updated form.

    :param key: Value stored under the policy's ``key`` entry.
    :return: The credential type dict.
    """
    credential: Dict[str, Any] = {"type": KEY_TYPE}
    credential["policy"] = {"type": "AzureKeyCredentialPolicy", "key": key}
    update_type(credential)
    return credential
48
+
49
+
50
def get_type(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Look up the reformatted type for ``yaml_data`` without converting it.

    Prefers the cached conversion in ORIGINAL_ID_TO_UPDATED_TYPE. Otherwise it
    falls back to the KNOWN_TYPES entry named by ``yaml_data["type"]``.

    :raises KeyError: if the object was never converted and its type is not known.
    """
    cached = ORIGINAL_ID_TO_UPDATED_TYPE.get(id(yaml_data))
    if cached is not None:
        return cached
    return KNOWN_TYPES[yaml_data["type"]]
55
+
56
+
57
+ def _get_api_versions(api_versions: List[Dict[str, str]]) -> List[str]:
58
+ return list({api_version["version"]: None for api_version in api_versions}.keys())
59
+
60
+
61
def _update_type_base(updated_type: str, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build the fields shared by every reformatted type.

    :param updated_type: Value stored under ``type`` in the result.
    :param yaml_data: Original modelerfour schema. It supplies ``defaultValue``,
        ``serialization.xml`` and ``apiVersions`` when present.
    :return: A new dict with ``type``, ``clientDefaultValue``, ``xmlMetadata`` and
        ``apiVersions`` keys.
    """
    serialization = yaml_data.get("serialization", {})
    base: Dict[str, Any] = {"type": updated_type}
    base["clientDefaultValue"] = yaml_data.get("defaultValue")
    base["xmlMetadata"] = serialization.get("xml", {})
    base["apiVersions"] = _get_api_versions(yaml_data.get("apiVersions", []))
    return base
68
+
69
+
70
def update_list(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a modelerfour ``array`` schema into a ``list`` type.

    The element type is converted recursively. ``maxItems``/``minItems`` default
    to None and ``uniqueItems`` defaults to False.
    """
    list_type = _update_type_base("list", yaml_data)
    list_type.update(
        elementType=update_type(yaml_data["elementType"]),
        maxItems=yaml_data.get("maxItems"),
        minItems=yaml_data.get("minItems"),
        uniqueItems=yaml_data.get("uniqueItems", False),
    )
    return list_type
77
+
78
+
79
def update_dict(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a modelerfour ``dictionary`` schema into a ``dict`` type.

    Only the value (element) type is carried over, converted recursively.
    """
    dict_type = _update_type_base("dict", yaml_data)
    element_type = update_type(yaml_data["elementType"])
    dict_type["elementType"] = element_type
    return dict_type
83
+
84
+
85
def update_constant(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a modelerfour ``constant`` schema into a ``constant`` type.

    The constant's literal is read from the nested ``value.value`` entry. Its
    schema (``valueType``) is converted recursively.
    """
    constant_type = _update_type_base("constant", yaml_data)
    constant_type.update(
        valueType=update_type(yaml_data["valueType"]),
        value=yaml_data["value"]["value"],
    )
    return constant_type
90
+
91
+
92
def update_enum_value(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten one modelerfour choice entry into ``name``/``value``/``description``."""
    language = yaml_data["language"]["default"]
    enum_value: Dict[str, Any] = {"name": language["name"]}
    enum_value["value"] = yaml_data["value"]
    enum_value["description"] = language["description"]
    return enum_value
98
+
99
+
100
def update_enum(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a ``choice``/``sealed-choice`` schema into an ``enum`` type.

    The underlying ``choiceType`` is converted recursively into ``valueType``,
    and every choice is flattened with ``update_enum_value``.
    """
    language = yaml_data["language"]["default"]
    name = language["name"]
    value_type = update_type(yaml_data["choiceType"])
    values = [update_enum_value(choice) for choice in yaml_data["choices"]]
    description = language["description"]
    enum_type = _update_type_base("enum", yaml_data)
    enum_type["name"] = name
    enum_type["valueType"] = value_type
    enum_type["values"] = values
    enum_type["description"] = description
    return enum_type
111
+
112
+
113
def update_property(
    yaml_data: Dict[str, Any], has_additional_properties: bool
) -> Dict[str, Any]:
    """Convert a modelerfour object property into the reformatted property dict.

    :param yaml_data: Property yaml (``language``, ``serializedName``, ``schema``...).
    :param has_additional_properties: Whether the owning model gets a synthesized
        ``additional_properties`` property. A real property with that name is then
        renamed to ``additional_properties1`` to avoid a clash.
    :return: The property dict; its ``type`` is converted via ``update_type``.
    """
    language = yaml_data["language"]["default"]
    client_name = language["name"]
    name_clashes = client_name == "additional_properties"
    if has_additional_properties and name_clashes:
        client_name = "additional_properties1"
    grouped_parameter_names = [
        original["language"]["default"]["name"].lstrip("_")  # TODO: patching m4
        for original in yaml_data.get("originalParameter", [])
    ]
    return {
        "clientName": client_name,
        "restApiName": yaml_data["serializedName"],
        "flattenedNames": yaml_data.get("flattenedNames", []),
        "type": update_type(yaml_data["schema"]),
        "optional": not yaml_data.get("required"),
        "description": language["description"],
        "isDiscriminator": yaml_data.get("isDiscriminator"),
        "readonly": yaml_data.get("readOnly", False),
        "groupedParameterNames": grouped_parameter_names,
    }
133
+
134
+
135
def update_discriminated_subtypes(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Map each immediate subtype's discriminator value to its model name.

    Returns an empty dict when the schema has no ``discriminator.immediate`` entry.
    """
    immediate = yaml_data.get("discriminator", {}).get("immediate", {})
    subtypes: Dict[str, Any] = {}
    for subtype in immediate.values():
        subtypes[subtype["discriminatorValue"]] = subtype["language"]["default"]["name"]
    return subtypes
140
+
141
+
142
def create_model(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Create the bare ``model`` type (name, description, XML flag).

    Properties and parents are added later by ``fill_model``.
    """
    language = yaml_data["language"]["default"]
    model = _update_type_base("model", yaml_data)
    model.update(
        name=language["name"],
        description=language["description"],
        isXml="xml" in yaml_data.get("serializationFormats", []),
    )
    return model
148
+
149
+
150
def fill_model(
    yaml_data: Dict[str, Any], current_model: Dict[str, Any]
) -> Dict[str, Any]:
    """Populate a model created by ``create_model`` with its remaining fields.

    Adds ``properties``, object ``parents``, ``discriminatedSubtypes`` and
    ``discriminatorValue``.

    :param yaml_data: Original modelerfour object schema.
    :param current_model: The model dict to fill; it is updated in place.
    :return: ``current_model``.
    """
    yaml_parents = yaml_data.get("parents", {}).get("immediate", [])
    dict_parents = [parent for parent in yaml_parents if parent["type"] == "dictionary"]
    has_additional_properties = bool(dict_parents)

    properties: List[Dict[str, Any]] = []
    if has_additional_properties:
        # A dictionary parent means unmatched keys are collected into a
        # synthesized property typed as the first dictionary parent.
        properties.append(
            {
                "clientName": "additional_properties",
                "restApiName": "",
                "type": update_type(dict_parents[0]),
                "optional": True,
                "description": "Unmatched properties from the message are deserialized to this collection.",
                "isDiscriminator": False,
                "readonly": False,
            }
        )
    for yaml_property in yaml_data.get("properties", []):
        properties.append(
            update_property(
                yaml_property, has_additional_properties=has_additional_properties
            )
        )

    object_parents = [
        update_type(yaml_data=parent)
        for parent in yaml_parents
        if parent["type"] == "object"
    ]
    discriminated_subtypes = update_discriminated_subtypes(yaml_data)
    discriminator_value = yaml_data.get("discriminatorValue")

    current_model["properties"] = properties
    current_model["parents"] = object_parents
    current_model["discriminatedSubtypes"] = discriminated_subtypes
    current_model["discriminatorValue"] = discriminator_value
    return current_model
186
+
187
+
188
def update_number_type(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a numeric schema into an ``integer`` or ``float`` type.

    Only a yaml ``type`` of exactly ``"integer"`` maps to ``integer``; anything
    else is a ``float``. The numeric constraints are copied as-is (None if absent).
    """
    is_integer = yaml_data["type"] == "integer"
    number_type = _update_type_base("integer" if is_integer else "float", yaml_data)
    constraint_names = (
        "precision",
        "multipleOf",
        "maximum",
        "minimum",
        "exclusiveMaximum",
        "exclusiveMinimum",
    )
    for constraint in constraint_names:
        number_type[constraint] = yaml_data.get(constraint)
    return number_type
202
+
203
+
204
def update_primitive(  # pylint: disable=too-many-return-statements
    type_group: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Convert a primitive modelerfour schema into its reformatted type.

    - ``integer``/``number`` go through ``update_number_type``.
    - ``string``/``uuid``/``uri`` become ``string``. The shared
      ``KNOWN_TYPES["string"]`` is reused unless the schema has constraints, a
      default or serialization info.
    - ``binary`` reuses ``KNOWN_TYPES["binary"]``.
    - ``date-time`` (as ``datetime``) and ``byte-array`` also carry ``format``.
    - Any other group keeps its name, with only the base fields.
    """
    if type_group in ("integer", "number"):
        return update_number_type(yaml_data)
    if type_group in ("string", "uuid", "uri"):
        markers = ("maxLength", "minLength", "pattern", "defaultValue", "serialization")
        if not any(marker in yaml_data for marker in markers):
            return KNOWN_TYPES["string"]
        string_type = _update_type_base("string", yaml_data)
        for constraint in ("maxLength", "minLength", "pattern"):
            string_type[constraint] = yaml_data.get(constraint)
        return string_type
    if type_group == "binary":
        return KNOWN_TYPES["binary"]
    if type_group in ("date-time", "byte-array"):
        renamed = "datetime" if type_group == "date-time" else type_group
        formatted_type = _update_type_base(renamed, yaml_data)
        formatted_type["format"] = yaml_data["format"]
        return formatted_type
    return _update_type_base(type_group, yaml_data)
241
+
242
+
243
def update_types(yaml_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine already-converted types into one cached ``combined`` type.

    :param yaml_data: Converted type dicts. An entry whose ``type`` names a
        KNOWN_TYPES key is replaced by that shared instance. Every other entry
        must be an object already stored as a value in ORIGINAL_ID_TO_UPDATED_TYPE,
        i.e. something ``update_type`` produced.
    :return: ``{"type": "combined", "types": [...]}``. It is also cached in
        ORIGINAL_ID_TO_UPDATED_TYPE under its own id.
    :raises ValueError: if an entry is neither a known type nor a cached
        converted type.
    """
    types: List[Dict[str, Any]] = []
    for member in yaml_data:  # "member" avoids shadowing the builtin ``type``
        known_type = KNOWN_TYPES.get(member["type"])
        if known_type:
            types.append(known_type)
            continue
        # Accept only objects the converter has already produced (by identity).
        # Raise an explicit error instead of letting a bare next() leak
        # StopIteration. StopIteration is opaque to callers, and it turns into
        # RuntimeError if this is ever reached from inside a generator (PEP 479).
        is_registered = any(
            registered is member for registered in ORIGINAL_ID_TO_UPDATED_TYPE.values()
        )
        if not is_registered:
            raise ValueError(
                f"Type {member['type']!r} has not been converted by update_type"
            )
        types.append(member)
    combined = {"type": "combined", "types": types}
    ORIGINAL_ID_TO_UPDATED_TYPE[id(combined)] = combined
    return combined
257
+
258
+
259
def update_type(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a modelerfour schema into its reformatted type, memoized by identity.

    Results are cached in ORIGINAL_ID_TO_UPDATED_TYPE under ``id(yaml_data)``.
    Converting the same schema object again returns the same dict.

    :param yaml_data: A modelerfour schema with a ``type`` key.
    :return: The converted type dict.
    """
    cached = ORIGINAL_ID_TO_UPDATED_TYPE.get(id(yaml_data))
    if cached is not None:
        return cached

    type_group = yaml_data["type"]
    converters: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
        "array": update_list,
        "dictionary": update_dict,
        "constant": update_constant,
        "choice": update_enum,
        "sealed-choice": update_enum,
    }
    converter = converters.get(type_group)
    if converter is not None:
        updated_type = converter(yaml_data)
    elif type_group in (OAUTH_TYPE, KEY_TYPE):
        # credential types are used as-is
        updated_type = yaml_data
    elif type_group in ("object", "group"):
        # Cache the bare model *before* filling it, so self- or mutually-
        # referencing models hit the cache instead of recursing forever.
        updated_type = create_model(yaml_data)
        ORIGINAL_ID_TO_UPDATED_TYPE[id(yaml_data)] = updated_type
        updated_type = fill_model(yaml_data, updated_type)
    else:
        updated_type = update_primitive(type_group, yaml_data)

    ORIGINAL_ID_TO_UPDATED_TYPE[id(yaml_data)] = updated_type
    return updated_type
282
+
283
+
284
def update_parameter_base(
    yaml_data: Dict[str, Any], *, override_client_name: Optional[str] = None
) -> Dict[str, Any]:
    """Build the fields shared by every reformatted parameter.

    The location comes from ``protocol.http.in``. A missing/empty location becomes
    ``"other"`` and ``"uri"`` becomes ``"endpointPath"``.

    :param yaml_data: Parameter yaml.
    :param override_client_name: Used instead of the yaml name when truthy.
    :return: Dict with optional/description/clientName/restApiName/
        clientDefaultValue/location/groupedBy/checkClientInput.
    """
    http_location = yaml_data["protocol"].get("http", {}).get("in") or "other"
    location = "endpointPath" if http_location == "uri" else http_location

    grouped_by_yaml = yaml_data.get("groupedBy")
    grouped_by = grouped_by_yaml["language"]["default"]["name"] if grouped_by_yaml else None

    language = yaml_data["language"]["default"]
    client_name: str = override_client_name or language["name"]
    if grouped_by and client_name[0] != "_":
        # m4 bug: constant grouped params are not hidden; patch by making them private
        client_name = f"_{client_name}"

    return {
        "optional": not yaml_data.get("required", False),
        "description": language["description"],
        "clientName": client_name,
        "restApiName": language.get("serializedName"),
        "clientDefaultValue": yaml_data.get("clientDefaultValue"),
        "location": location,
        "groupedBy": grouped_by,
        "checkClientInput": yaml_data.get("checkClientInput", False),
    }
311
+
312
+
313
def update_parameter_delimiter(style: Optional[str]) -> Optional[str]:
    """Translate an OpenAPI parameter ``style`` into a collection delimiter name.

    ``form``/``simple`` map to ``comma``, and the ``*Delimited`` styles map to
    ``space``/``pipe``/``tab``. Any other (or missing) style gives None.
    """
    if not style:
        return None
    delimiters = {
        "form": "comma",
        "simple": "comma",
        "spaceDelimited": "space",
        "pipeDelimited": "pipe",
        "tabDelimited": "tab",
    }
    return delimiters.get(style)
321
+
322
+
323
def get_all_body_types(yaml_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return the converted body types across all requests, de-duplicated by schema identity.

    :param yaml_data: Mapping whose values are request yamls, each with a body parameter.
    :return: Converted types in first-seen order of their schema objects.
    """
    by_schema_id: Dict[int, Dict[str, Any]] = {}
    for request in yaml_data.values():
        body_schema = get_body_parameter(request)["schema"]
        by_schema_id[id(body_schema)] = update_type(body_schema)
    return [*by_schema_id.values()]
329
+
330
+
331
def get_body_type_for_description(body_parameter: Dict[str, Any]) -> str:
    """Name the body kind for docstrings: ``binary``, ``string``, or ``JSON`` otherwise."""
    type_name = body_parameter["type"]["type"]
    return type_name if type_name in ("binary", "string") else "JSON"
337
+
338
+
339
def add_lro_information(operation: Dict[str, Any], yaml_data: Dict[str, Any]) -> None:
    """Mark ``operation`` as long-running and attach LRO options and custom pollers.

    Mutates ``operation`` in place. Each response receives the sync/async poller
    and polling-method overrides from the ``x-python-custom-*`` extensions
    (None when absent).
    """
    operation["discriminator"] = "lro"
    extensions = yaml_data["extensions"]
    operation["lroOptions"] = extensions.get("x-ms-long-running-operation-options")
    response_overrides = (
        ("pollerSync", "x-python-custom-poller-sync"),
        ("pollerAsync", "x-python-custom-poller-async"),
        ("pollingMethodSync", "x-python-custom-default-polling-method-sync"),
        ("pollingMethodAsync", "x-python-custom-default-polling-method-async"),
    )
    for response in operation["responses"]:
        for response_key, extension_name in response_overrides:
            response[response_key] = extensions.get(extension_name)
352
+
353
+
354
def filter_out_paging_next_operation(
    yaml_data: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Drop operations that only serve as another operation's ``nextOperation``.

    :param yaml_data: Converted operations, each with a ``name``.
    :return: New list without the operations referenced by a ``nextOperation``.
    """
    next_operation_names: Set[str] = {
        operation["nextOperation"]["name"]
        for operation in yaml_data
        if operation.get("nextOperation")
    }
    return [
        operation
        for operation in yaml_data
        if operation["name"] not in next_operation_names
    ]
364
+
365
+
366
def update_response_header(yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a response header yaml into ``restApiName`` plus converted ``type``."""
    header_name = yaml_data["header"]
    header_type = update_type(yaml_data["schema"])
    return {"restApiName": header_name, "type": header_type}
371
+
372
+
373
def update_response(
    yaml_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Convert a modelerfour response (or exception) into the reformatted response dict.

    The body type is the known binary type for binary responses. For responses
    with a schema it is the already-converted schema type (via ``get_type``),
    and otherwise None. Status codes are ints, except the literal ``"default"``.
    """
    if yaml_data.get("binary"):
        response_type = KNOWN_TYPES["binary"]
    else:
        schema = yaml_data.get("schema")
        response_type = get_type(schema) if schema else None
    http = yaml_data["protocol"]["http"]
    headers = [update_response_header(header) for header in http.get("headers", [])]
    status_codes = [
        "default" if code == "default" else int(code) for code in http["statusCodes"]
    ]
    return {
        "headers": headers,
        "statusCodes": status_codes,
        "type": response_type,
        "nullable": yaml_data.get("nullable", False),
    }
394
+
395
+
396
+ def _get_default_content_type( # pylint: disable=too-many-return-statements
397
+ content_types: Iterable[str],
398
+ ) -> Optional[str]:
399
+ json_values = [ct for ct in content_types if JSON_REGEXP.match(ct)]
400
+ if json_values:
401
+ if "application/json" in json_values:
402
+ return "application/json"
403
+ return json_values[0]
404
+
405
+ xml_values = [ct for ct in content_types if "xml" in ct]
406
+ if xml_values:
407
+ if "application/xml" in xml_values:
408
+ return "application/xml"
409
+ return xml_values[0]
410
+
411
+ if "application/octet-stream" in content_types:
412
+ return "application/octet-stream"
413
+ if "application/x-www-form-urlencoded" in content_types:
414
+ return "application/x-www-form-urlencoded"
415
+ return None
416
+
417
+
418
def update_client_url(yaml_data: Dict[str, Any]) -> str:
    """Return the client URL template.

    If a global ``$host`` parameter exists, the host is not parameterized. To
    share code paths, it is then treated as a parameterized host consisting of
    just ``{endpoint}``. Otherwise the URI of the first request of the first
    operation is used.
    """
    has_host_parameter = any(
        param["language"]["default"]["name"] == "$host"
        for param in yaml_data["globalParameters"]
    )
    if has_host_parameter:
        return "{endpoint}"
    # parameterized host: take the first request's uri (crude but sufficient)
    first_operation = yaml_data["operationGroups"][0]["operations"][0]
    return first_operation["requests"][0]["protocol"]["http"]["uri"]
432
+
433
+
434
def update_content_type_parameter(
    yaml_data: Dict[str, Any],
    body_parameter: Optional[Dict[str, Any]],
    request_media_types: List[str],
    *,
    in_overload: bool = False,
    in_overriden: bool = False,
) -> Dict[str, Any]:
    """Rewrite the Content-Type parameter as an optional plain-string parameter.

    Without a body parameter, ``yaml_data`` is returned unchanged (same object).
    Otherwise a deep copy is returned, with the description adjusted:

    - in an overload: it names the overload's body kind;
    - in the overridden (combined-body) operation: it lists the known media types;
    - otherwise: it is hidden from the docstring and defaults to the body's
      default content type.
    """
    if not body_parameter:
        return yaml_data
    param = copy.deepcopy(yaml_data)
    param["schema"] = KNOWN_TYPES["string"]  # content type is always a string
    param["required"] = False
    description = param["language"]["default"]["description"]
    if description and not description.endswith("."):
        description += "."
    if in_overload:
        body_kind = get_body_type_for_description(body_parameter)
        description += f" Content type parameter for {body_kind} body."
    elif in_overriden:
        quoted = "', '".join(request_media_types)
        description += f" Known values are: '{quoted}'."
    else:
        param["inDocstring"] = False
        param["clientDefaultValue"] = body_parameter["defaultContentType"]
    param["language"]["default"]["description"] = description
    return param
465
+
466
+
467
+ class M4Reformatter(YamlUpdatePlugin): # pylint: disable=too-many-public-methods
468
+ """Add Python naming information."""
469
+
470
@property
def azure_arm(self) -> bool:
    """Whether the ``azure-arm`` autorest flag is set."""
    flag = self._autorestapi.get_boolean_value("azure-arm")
    return bool(flag)
473
+
474
@property
def default_optional_constants_to_none(self) -> bool:
    """Whether optional constants should default to None.

    True when either ``default-optional-constants-to-none`` or ``version-tolerant``
    is set. The second flag is only queried if the first is falsy.
    """
    api = self._autorestapi
    if api.get_boolean_value("default-optional-constants-to-none"):
        return True
    return bool(api.get_boolean_value("version-tolerant"))
480
+
481
def update_overloads(
    self,
    group_name: str,
    yaml_data: Dict[str, Any],
    body_parameter: Optional[Dict[str, Any]],
    *,
    content_types: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Build one overload operation per member of a combined body type.

    Each overload's Content-Type parameter defaults to that overload's body
    default content type.

    :return: The overloads. Empty when there is no body or no combined ``types``.
    """
    body_types = (body_parameter["type"].get("types") or []) if body_parameter else []
    overloads: List[Dict[str, Any]] = []
    for body_type in body_types:
        overload = self.update_overload(
            group_name, yaml_data, body_type, content_types=content_types
        )
        for parameter in overload["parameters"]:
            if parameter["restApiName"] != "Content-Type":
                continue
            parameter["clientDefaultValue"] = overload["bodyParameter"][
                "defaultContentType"
            ]
        overloads.append(overload)
    return overloads
506
+
507
def _update_operation_helper(
    self,
    group_name: str,
    yaml_data: Dict[str, Any],
    body_parameter: Optional[Dict[str, Any]],
    *,
    is_overload: bool = False,
) -> Dict[str, Any]:
    """Assemble the operation dict shared by regular operations and overloads.

    Parameters are built with ``in_overriden`` set when the body type is
    ``combined``. Exceptions whose schema is named ``CloudError`` are dropped.
    """
    in_overriden = bool(body_parameter) and body_parameter["type"]["type"] == "combined"
    language = yaml_data["language"]["default"]
    first_http = yaml_data["requests"][0]["protocol"]["http"]

    def is_cloud_error(response: Dict[str, Any]) -> bool:
        schema = response.get("schema")
        return bool(schema) and schema["language"]["default"]["name"] == "CloudError"

    return {
        "name": language["name"],
        "description": language["description"],
        "summary": language.get("summary"),
        "url": first_http["path"],
        "method": first_http["method"].upper(),
        "parameters": self.update_parameters(
            yaml_data,
            body_parameter,
            in_overload=is_overload,
            in_overriden=in_overriden,
        ),
        "bodyParameter": body_parameter,
        "responses": [
            update_response(response) for response in yaml_data.get("responses", [])
        ],
        "exceptions": [
            update_response(exception)
            for exception in yaml_data.get("exceptions", [])
            if not is_cloud_error(exception)
        ],
        "groupName": group_name,
        "discriminator": "operation",
        "isOverload": is_overload,
        "apiVersions": _get_api_versions(yaml_data.get("apiVersions", [])),
    }
545
+
546
def get_operation_creator(
    self, yaml_data: Dict[str, Any]
) -> Callable[[str, Dict[str, Any]], Dict[str, Any]]:
    """Choose the operation builder from the LRO / pageable extensions."""
    extensions = yaml_data.get("extensions", {})
    is_lro = bool(extensions.get("x-ms-long-running-operation"))
    is_paging = bool(extensions.get("x-ms-pageable"))
    creators = {
        (True, True): self.update_lro_paging_operation,
        (True, False): self.update_lro_operation,
        (False, True): self.update_paging_operation,
        (False, False): self.update_operation,
    }
    return creators[(is_lro, is_paging)]
560
+
561
def update_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Build a regular operation and its body overloads.

    A JSON-capable model/dict/list body (non-XML, not flattened, not grouped) has
    its type wrapped in a combined type with binary. ``update_overloads`` then
    emits one overload per member.
    """
    request_media_types = yaml_data.get("requestMediaTypes")
    body_parameter = (
        self.update_body_parameter(request_media_types) if request_media_types else None
    )

    def wants_binary_overload() -> bool:
        # guard order mirrors the checks' dependencies (e.g. xmlMetadata only
        # exists on the model/dict/list types checked just before it)
        if not body_parameter:
            return False
        body_type = body_parameter["type"]
        if body_type["type"] == "combined" or not yaml_data.get("requestMediaTypes"):
            return False
        if not any(JSON_REGEXP.match(ct) for ct in yaml_data["requestMediaTypes"]):
            return False
        if body_type["type"] not in ("model", "dict", "list"):
            return False
        if body_type["xmlMetadata"]:
            return False
        return not body_parameter.get("flattened") and not body_parameter.get("groupedBy")

    content_types = None
    if wants_binary_overload():
        body_parameter["type"] = update_types(
            [body_parameter["type"], KNOWN_TYPES["binary"]]
        )
        content_types = body_parameter["contentTypes"]
    operation = self._update_operation_helper(group_name, yaml_data, body_parameter)
    operation["overloads"] = self.update_overloads(
        group_name, yaml_data, body_parameter, content_types=content_types
    )
    return operation
592
+
593
def add_paging_information(
    self, group_name: str, operation: Dict[str, Any], yaml_data: Dict[str, Any]
) -> None:
    """Mark ``operation`` as pageable and attach paging metadata, in place.

    - The item name defaults to ``"value"``.
    - The continuation token comes from ``nextLinkName``.
    - If modelerfour supplied a next-link operation, it is converted and stored
      as ``nextOperation``.
    - Custom sync/async pagers are copied onto every response.
    """
    operation["discriminator"] = "paging"
    pageable = yaml_data["extensions"]["x-ms-pageable"]
    operation["itemName"] = pageable.get("itemName", "value")
    operation["continuationTokenName"] = pageable.get("nextLinkName")
    next_link_operation = yaml_data["language"]["default"]["paging"].get(
        "nextLinkOperation"
    )
    if next_link_operation:
        operation["nextOperation"] = self.update_operation(
            group_name=group_name, yaml_data=next_link_operation
        )
    extensions = yaml_data["extensions"]
    pager_sync = "x-python-custom-pager-sync"
    pager_async = "x-python-custom-pager-async"
    for response in operation["responses"]:
        response["pagerSync"] = extensions.get(pager_sync)
        response["pagerAsync"] = extensions.get(pager_async)
614
+
615
def update_paging_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Build a pageable operation: a regular operation plus paging metadata."""
    paging_operation = self.update_operation(group_name, yaml_data)
    self.add_paging_information(group_name, paging_operation, yaml_data)
    return paging_operation
621
+
622
def update_lro_paging_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Build an operation that is both long-running and pageable.

    Paging info is layered on the LRO operation. The discriminator is then set
    to ``"lropaging"``, overriding the one set by the paging step.
    """
    lro_paging_operation = self.update_lro_operation(group_name, yaml_data)
    self.add_paging_information(group_name, lro_paging_operation, yaml_data)
    lro_paging_operation["discriminator"] = "lropaging"
    return lro_paging_operation
629
+
630
def update_lro_operation(
    self, group_name: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Build a long-running operation; LRO info is applied to it and every overload."""
    lro_operation = self.update_operation(group_name, yaml_data)
    for target in [lro_operation, *lro_operation["overloads"]]:
        add_lro_information(target, yaml_data)
    return lro_operation
638
+
639
def update_overload(
    self,
    group_name: str,
    yaml_data: Dict[str, Any],
    body_type: Dict[str, Any],
    *,
    content_types: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build a single overload whose body parameter is typed as ``body_type``."""
    overload_body = self.update_body_parameter_overload(
        yaml_data["requestMediaTypes"], body_type, content_types=content_types
    )
    return self._update_operation_helper(
        group_name, yaml_data, overload_body, is_overload=True
    )
653
+
654
def update_operation_group(self, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert an operation group and all its operations.

    Each operation goes through the builder picked by ``get_operation_creator``.
    Operations used only as another's next-link operation are then filtered out.
    """
    group_name = yaml_data["language"]["default"]["name"]
    operations: List[Dict[str, Any]] = []
    for operation_yaml in yaml_data["operations"]:
        create_operation = self.get_operation_creator(operation_yaml)
        operations.append(create_operation(group_name, operation_yaml))
    return {
        "propertyName": group_name,
        "className": group_name,
        "operations": filter_out_paging_next_operation(operations),
    }
666
+
667
def _update_body_parameter_helper(
    self,
    yaml_data: Dict[str, Any],
    body_param: Dict[str, Any],
    body_type: Dict[str, Any],
    *,
    content_types: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build an updated body parameter dict of the given type.

    :param yaml_data: Mapping of content type -> request yaml
        (an operation's ``requestMediaTypes``).
    :param body_param: The raw body parameter yaml.
    :param body_type: The updated type to assign to the body parameter.
    :keyword content_types: Explicit content types. When falsy, every content
        type whose request body schema maps (by identity, through
        ``ORIGINAL_ID_TO_UPDATED_TYPE``) to ``body_type`` is used.
    :return: A new dict built from a deep copy of ``update_parameter_base``'s
        result, with type, content types, default content type, constant
        default value and the raw ``flattened``/``isPartialBody`` flags set.
    """
    # Read the raw flags up front; they are copied onto the result at the end.
    was_flattened = body_param.get("flattened")
    partial_body = body_param.get("isPartialBody")
    updated = copy.deepcopy(update_parameter_base(body_param))
    updated["type"] = body_type
    if not content_types:
        content_types = []
        for media_type, request in yaml_data.items():
            request_schema = get_body_parameter(request)["schema"]
            if ORIGINAL_ID_TO_UPDATED_TYPE[id(request_schema)] is body_type:
                content_types.append(media_type)
    updated["contentTypes"] = content_types
    updated["defaultContentType"] = _get_default_content_type(content_types)
    # A constant body gets its value as client default, except when it is
    # optional and optional constants are configured to default to None.
    if updated["type"]["type"] == "constant" and not (
        updated["optional"] and self.default_optional_constants_to_none
    ):
        updated["clientDefaultValue"] = body_type["value"]
    updated["flattened"] = was_flattened
    updated["isPartialBody"] = partial_body
    return updated
700
+
701
def update_multipart_body_parameter(
    self, yaml_data: Dict[str, Any], client_name: str, description: str
) -> Dict[str, Any]:
    """Build a synthetic body parameter for multipart / form-encoded requests.

    Every body parameter of the first request becomes one of the ``entries``,
    each updated with its own schema's type. The synthetic parameter itself
    is typed ``KNOWN_TYPES["anydict"]``.

    :param yaml_data: Mapping of content type -> request yaml.
    :param client_name: Used for both ``clientName`` and ``restApiName``.
    :param description: Description text for the parameter.
    :return: The synthetic body parameter dict. It is optional unless the
        first request is marked ``required``.
    """
    first_request = list(yaml_data.values())[0]
    entries = []
    for raw_param in first_request["parameters"]:
        if not is_body(raw_param):
            continue
        entries.append(
            self._update_body_parameter_helper(
                yaml_data, raw_param, update_type(raw_param["schema"])
            )
        )
    media_types = list(yaml_data.keys())
    return {
        "optional": not first_request.get("required", False),
        "description": description,
        "clientName": client_name,
        "restApiName": client_name,
        "clientDefaultValue": None,
        "location": "Method",
        "type": KNOWN_TYPES["anydict"],
        "contentTypes": media_types,
        "defaultContentType": _get_default_content_type(yaml_data.keys()),
        "entries": entries,
    }
722
+
723
def update_body_parameter(self, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build the updated body parameter for an operation.

    :param yaml_data: Mapping of content type -> request yaml
        (the operation's ``requestMediaTypes``).
    :return: The updated body parameter dict. Multipart and form-encoded
        requests get a synthetic parameter from
        ``update_multipart_body_parameter``. Otherwise, multiple body types
        are combined with ``update_types`` unless the body is flattened, in
        which case the first body type is used.
    """
    protocol_http = list(yaml_data.values())[0].get("protocol", {}).get("http", {})
    if protocol_http.get("multipart"):
        return self.update_multipart_body_parameter(
            yaml_data, "files", "Multipart input for files."
        )
    if protocol_http.get("knownMediaType") == "form":
        return self.update_multipart_body_parameter(
            yaml_data, "data", "Multipart input for form encoded data."
        )
    raw_body_params = [
        p for sr in yaml_data.values() for p in sr["parameters"] if is_body(p)
    ]
    # yaml_data is keyed by content type, so a "flattened" lookup on it can
    # never match; the flag lives on the body parameters themselves.
    is_flattened = any(p.get("flattened") for p in raw_body_params)
    body_types = get_all_body_types(yaml_data)
    if len(body_types) > 1 and not is_flattened:
        body_type = update_types(body_types)
    else:
        body_type = body_types[0]
    # StopIteration when there is no body parameter, as with a bare next().
    body_param = next(iter(raw_body_params))
    return self._update_body_parameter_helper(yaml_data, body_param, body_type)
742
+
743
def update_body_parameter_overload(
    self,
    yaml_data: Dict[str, Any],
    body_type: Dict[str, Any],
    *,
    content_types: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build the body parameter for an overload whose body type is already chosen.

    The first body parameter found across all requests is used as the raw
    parameter.

    :param yaml_data: Mapping of content type -> request yaml.
    :param body_type: The updated type this overload uses for its body.
    :keyword content_types: Forwarded to ``_update_body_parameter_helper``.
    :raises StopIteration: If no request has a body parameter.
    """
    body_candidates = (
        parameter
        for request in yaml_data.values()
        for parameter in request["parameters"]
        if is_body(parameter)
    )
    return self._update_body_parameter_helper(
        yaml_data, next(body_candidates), body_type, content_types=content_types
    )
757
+
758
def update_flattened_parameter(
    self, yaml_data: Dict[str, Any], body_parameter: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Update a method parameter that stands in for a property of a flattened body.

    The mapping from the body's property name to this parameter's client
    name is recorded in ``body_parameter["propertyToParameterName"]``.

    :param yaml_data: Raw parameter yaml; must carry a ``targetProperty``.
    :param body_parameter: The operation's body parameter.
    :raises ValueError: If ``body_parameter`` is missing.
    :return: The updated parameter, marked ``inFlattenedBody``.
    """
    if not body_parameter:
        raise ValueError("Has to have a body parameter if it's flattened")
    property_name = yaml_data["targetProperty"]["language"]["default"]["name"]
    updated_param = self.update_parameter(yaml_data)
    client_name = updated_param["clientName"]
    body_parameter.setdefault("propertyToParameterName", {})[property_name] = client_name
    updated_param["inFlattenedBody"] = True
    return updated_param
773
+
774
+ def _update_parameters_helper(
775
+ self,
776
+ parameters: List[Dict[str, Any]],
777
+ body_parameter: Optional[Dict[str, Any]],
778
+ seen_rest_api_names: Set[str],
779
+ groupers: Dict[str, Dict[str, Any]],
780
+ request_media_types: List[str],
781
+ *,
782
+ in_overload: bool = False,
783
+ in_overriden: bool = False,
784
+ ) -> List[Dict[str, Any]]:
785
+ retval: List[Dict[str, Any]] = []
786
+ has_flattened_body = body_parameter and body_parameter.get("flattened")
787
+ for param in parameters:
788
+ serialized_name = param["language"]["default"].get("serializedName")
789
+ if param["language"]["default"]["name"] == "$host" or (
790
+ serialized_name and serialized_name in seen_rest_api_names
791
+ ):
792
+ continue
793
+ if param.get("origin") == "modelerfour:synthesized/api-version":
794
+ param["inDocstring"] = False
795
+ param["implementation"] = "Method"
796
+ param["checkClientInput"] = True
797
+ if has_flattened_body and param.get("targetProperty"):
798
+ retval.append(self.update_flattened_parameter(param, body_parameter))
799
+ continue
800
+ if param["schema"]["type"] == "group":
801
+ # this means i'm a parameter group parameter
802
+ param = self.update_parameter(param)
803
+ param["grouper"] = True
804
+ groupers[param["clientName"]] = param
805
+ retval.append(param)
806
+ continue
807
+ if is_body(param):
808
+ continue
809
+ if serialized_name == "Content-Type":
810
+ param = update_content_type_parameter(
811
+ param,
812
+ body_parameter,
813
+ request_media_types,
814
+ in_overload=in_overload,
815
+ in_overriden=in_overriden,
816
+ )
817
+ updated_param = self.update_parameter(
818
+ param, in_overload=in_overload, in_overriden=in_overriden
819
+ )
820
+ retval.append(updated_param)
821
+ seen_rest_api_names.add(updated_param["restApiName"])
822
+ return retval
823
+
824
def update_parameters(
    self,
    yaml_data: Dict[str, Any],
    body_parameter: Optional[Dict[str, Any]],
    *,
    in_overload: bool = False,
    in_overriden: bool = False,
) -> List[Dict[str, Any]]:
    """Update all non-body parameters of an operation.

    Operation-level parameters are processed first. The parameters of each
    sub-request follow, taken from ``requestMediaTypes`` when present and
    otherwise from ``requests``. One shared seen-name set and one grouper
    registry are used across all of them.

    Afterwards each parameter-group ("grouper") parameter gets a
    ``propertyToParameterName`` mapping. It maps the client name of the
    group type's property to the client name of every parameter (including
    the body parameter) whose ``groupedBy`` names that grouper.

    :raises StopIteration: If a grouped parameter matches no property of its
        grouper's type.
    :return: The updated parameters; the body parameter is not included.
    """
    seen_rest_api_names: Set[str] = set()
    groupers: Dict[str, Dict[str, Any]] = {}
    request_media_types = yaml_data.get("requestMediaTypes", [])

    def _process(raw_parameters: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return self._update_parameters_helper(
            raw_parameters,
            body_parameter,
            seen_rest_api_names,
            groupers,
            request_media_types,
            in_overload=in_overload,
            in_overriden=in_overriden,
        )

    updated = list(_process(yaml_data["parameters"]))
    # Only the body parameter's own content types matter, so the per-media-type
    # differences of the sub-requests are ignored here.
    sub_requests = (
        yaml_data["requestMediaTypes"].values()
        if yaml_data.get("requestMediaTypes")
        else yaml_data.get("requests", [])
    )
    for sub_request in sub_requests:
        updated.extend(_process(sub_request.get("parameters", [])))

    every_param = (updated + [body_parameter]) if body_parameter else updated
    for grouper_name, grouper in groupers.items():
        property_to_param: Dict[str, Any] = {}
        for param in every_param:
            if param.get("groupedBy") != grouper_name:
                continue
            # TODO: patching m4. Leading underscores are stripped from the
            # client name before matching it against groupedParameterNames.
            bare_name = param["clientName"].lstrip("_")
            matching_prop = next(
                prop
                for prop in grouper["type"]["properties"]
                if bare_name in prop["groupedParameterNames"]
            )
            property_to_param[matching_prop["clientName"]] = param["clientName"]
        grouper["propertyToParameterName"] = property_to_param
    return updated
880
+
881
def update_parameter(
    self,
    yaml_data: Dict[str, Any],
    *,
    override_client_name: Optional[str] = None,
    in_overload: bool = False,
    in_overriden: bool = False,
) -> Dict[str, Any]:
    """Convert one raw (modelerfour) parameter into its updated dict form.

    :param yaml_data: The raw parameter yaml.
    :keyword override_client_name: Forwarded to ``update_parameter_base``.
    :keyword in_overload: Recorded as ``inOverload`` on the result.
    :keyword in_overriden: Recorded as ``inOverriden`` on the result.
    :return: The dict from ``update_parameter_base``, extended with the
        parameter's type, implementation, explode/skip-url-encoding flags,
        docstring visibility and delimiter.
    """
    param_base = update_parameter_base(
        yaml_data, override_client_name=override_client_name
    )
    # Named param_type (not ``type``) so the builtin is not shadowed.
    param_type = get_type(yaml_data["schema"])
    # A constant's value becomes the client default unless the parameter is
    # optional and optional constants are configured to default to None.
    if param_type["type"] == "constant" and not (
        param_base["optional"] and self.default_optional_constants_to_none
    ):
        param_base["clientDefaultValue"] = param_type["value"]
    protocol_http = yaml_data["protocol"].get("http", {})
    param_base.update(
        {
            "type": param_type,
            "implementation": yaml_data["implementation"],
            "explode": protocol_http.get("explode", False),
            "inOverload": in_overload,
            "skipUrlEncoding": yaml_data.get("extensions", {}).get(
                "x-ms-skip-url-encoding", False
            ),
            "inDocstring": yaml_data.get("inDocstring", True),
            "inOverriden": in_overriden,
            "delimiter": update_parameter_delimiter(protocol_http.get("style")),
        }
    )
    return param_base
914
+
915
def update_global_parameters(
    self, yaml_data: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Update the client-level (global) parameters.

    The non-parameterized ``$host`` endpoint gets the client name
    ``endpoint`` when ``version-tolerant`` or ``low-level-client`` is set and
    ``base_url`` otherwise. Its raw description is overwritten with
    "Service URL.". Every other parameter keeps its own client name.

    :param yaml_data: The raw ``globalParameters`` list.
    :return: The updated parameters, in input order.
    """

    def _client_name_override(raw_param: Dict[str, Any]) -> Optional[str]:
        default_language = raw_param["language"]["default"]
        if default_language["name"] != "$host":
            return None
        version_tolerant = self._autorestapi.get_boolean_value(
            "version-tolerant", False
        )
        low_level_client = self._autorestapi.get_boolean_value(
            "low-level-client", False
        )
        default_language["description"] = "Service URL."
        return "endpoint" if version_tolerant or low_level_client else "base_url"

    return [
        self.update_parameter(
            raw_param, override_client_name=_client_name_override(raw_param)
        )
        for raw_param in yaml_data
    ]
940
+
941
def get_token_credential(self, credential_scopes: List[str]) -> Dict[str, Any]:
    """Build the token-credential type dict for the given scopes.

    The policy type is ``ARMChallengeAuthenticationPolicy`` when generating
    for ARM and ``BearerTokenCredentialPolicy`` otherwise. The dict is run
    through ``update_type`` before being returned.

    :param credential_scopes: Scopes stored under the policy's
        ``credentialScopes``.
    """
    if self.azure_arm:
        policy_type = "ARMChallengeAuthenticationPolicy"
    else:
        policy_type = "BearerTokenCredentialPolicy"
    token_credential = {
        "type": OAUTH_TYPE,
        "policy": {"type": policy_type, "credentialScopes": credential_scopes},
    }
    update_type(token_credential)
    return token_credential
953
+
954
def update_credential_from_security(
    self, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Derive the credential type from the spec's security schemes.

    Schemes are scanned in order. An OAuth scheme yields
    ``get_token_credential(scopes)`` and a key scheme yields
    ``get_azure_key_credential(name)``. When several schemes match, the last
    one wins; other scheme types are ignored.

    :param yaml_data: The raw ``security`` section.
    :return: The credential type dict, or ``{}`` when no scheme matched.
    """
    credential: Dict[str, Any] = {}
    for scheme in yaml_data.get("schemes", []):
        scheme_type = scheme["type"]
        if scheme_type == OAUTH_TYPE:
            credential = self.get_token_credential(scheme["scopes"])
        elif scheme_type == KEY_TYPE:
            credential = get_azure_key_credential(scheme["name"])
    return credential
965
+
966
def get_credential_scopes_from_flags(self, auth_policy: str) -> List[str]:
    """Resolve token-credential scopes from the ``--credential-scopes`` flag.

    ARM generation always uses the management-plane scope. Otherwise a string
    flag value is split on commas; surrounding whitespace is stripped and
    empty entries are dropped, so ``"a, b,"`` yields ``["a", "b"]``. A
    non-string value, such as a bare ``--credential-scopes`` with no value,
    yields no scopes.

    :param auth_policy: Credential policy name, only used in the warning
        emitted when no scopes are configured.
    :raises ValueError: If the ``credential-scopes`` flag is on but no
        scopes could be parsed from it.
    :return: The scopes; an empty list (after logging a warning) when none
        are configured.
    """
    if self.azure_arm:
        return ["https://management.azure.com/.default"]
    raw_scopes = self._autorestapi.get_value("credential-scopes")
    credential_scopes: List[str] = []
    # Only strings can be split; anything else (e.g. a boolean from a bare
    # flag) must fall through to the ValueError below instead of crashing.
    if isinstance(raw_scopes, str):
        credential_scopes = [
            scope.strip() for scope in raw_scopes.split(",") if scope.strip()
        ]
    if (
        self._autorestapi.get_boolean_value("credential-scopes", False)
        and not credential_scopes
    ):
        raise ValueError(
            "--credential-scopes takes a list of scopes in comma separated format. "
            "For example: --credential-scopes=https://cognitiveservices.azure.com/.default"
        )
    if not credential_scopes:
        _LOGGER.warning(
            "You have default credential policy %s "
            "but not the --credential-scopes flag set while generating non-management plane code. "
            "This is not recommend because it forces the customer to pass credential scopes "
            "through kwargs if they want to authenticate.",
            auth_policy,
        )
    return credential_scopes
991
+
992
def update_credential_from_flags(self) -> Dict[str, Any]:
    """Build the credential type from the ``--credential-*`` flags.

    The policy comes from ``credential-default-policy-type``. It defaults to
    ``ARMChallengeAuthenticationPolicy`` for ARM generation and
    ``BearerTokenCredentialPolicy`` otherwise. Those two token policies
    produce a token credential with scopes from
    ``get_credential_scopes_from_flags``. Any other policy is treated as
    AzureKeyCredentialPolicy and produces a key credential whose header name
    defaults to ``api-key``.

    :raises ValueError: If a key header name is given with a token policy,
        or credential scopes are given with the key policy.
    """
    default_auth_policy = (
        "ARMChallengeAuthenticationPolicy"
        if self.azure_arm
        else "BearerTokenCredentialPolicy"
    )
    auth_policy = (
        self._autorestapi.get_value("credential-default-policy-type")
        or default_auth_policy
    )
    key = self._autorestapi.get_value("credential-key-header-name")
    if auth_policy.lower() in (
        "armchallengeauthenticationpolicy",
        "bearertokencredentialpolicy",
    ):
        if key:
            raise ValueError(
                "You have passed in a credential key header name with default credential policy type "
                f"{auth_policy}. This is not allowed, since credential key header "
                "name is tied with AzureKeyCredentialPolicy. Instead, with this policy it is recommend you "
                "pass in --credential-scopes."
            )
        # Scopes only apply to token credentials. They are resolved here so
        # the key-credential path does not log a misleading
        # "missing --credential-scopes" warning.
        credential_scopes = self.get_credential_scopes_from_flags(auth_policy)
        return self.get_token_credential(credential_scopes)
    # Otherwise you have AzureKeyCredentialPolicy
    if self._autorestapi.get_value("credential-scopes"):
        raise ValueError(
            "You have passed in credential scopes with default credential policy type "
            "AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
            f"{default_auth_policy}. Instead, with this policy "
            "you must pass in --credential-key-header-name."
        )
    if not key:
        key = "api-key"
        _LOGGER.info(
            "Defaulting the AzureKeyCredentialPolicy header's name to 'api-key'"
        )
    return get_azure_key_credential(key)
1030
+
1031
def update_credential(
    self, yaml_data: Dict[str, Any], parameters: List[Dict[str, Any]]
) -> None:
    """Add a ``credential`` client parameter to ``parameters`` when one applies.

    The credential type comes from the credential flags when
    ``add-credentials``/``add-credential`` is set or when generating for ARM.
    Otherwise it comes from the spec's security section. Nothing is added
    when that yields an empty type. Version-tolerant and low-level clients
    get the credential appended last; other clients get it inserted first.

    :param yaml_data: The raw ``security`` section.
    :param parameters: The client parameters, modified in place.
    """
    api = self._autorestapi
    use_flags = (
        api.get_boolean_value("add-credentials", False)
        or api.get_boolean_value("add-credential", False)
        or self.azure_arm
    )
    if use_flags:
        credential_type = self.update_credential_from_flags()
    else:
        credential_type = self.update_credential_from_security(yaml_data)
    if not credential_type:
        return
    credential = dict(
        type=credential_type,
        optional=False,
        description="Credential needed for the client to connect to Azure.",
        clientName="credential",
        location="other",
        restApiName="credential",
        implementation="Client",
        skipUrlEncoding=True,
        inOverload=False,
    )
    credential_goes_last = api.get_boolean_value(
        "version-tolerant"
    ) or api.get_boolean_value("low-level-client")
    if credential_goes_last:
        parameters.append(credential)
    else:
        parameters.insert(0, credential)
1063
+
1064
def update_client(self, yaml_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build the client dict from the top-level code model yaml.

    Global parameters are updated first and the credential parameter (if
    any) is added to them. The url comes from ``update_client_url`` only
    when the spec has global parameters and is ``""`` otherwise. The
    namespace is the ``namespace`` flag, falling back to the client's name.

    :param yaml_data: The whole raw code model yaml.
    :return: Dict with ``name``, ``description``, ``parameters``, ``url``
        and ``namespace``.
    """
    client_params = self.update_global_parameters(
        yaml_data.get("globalParameters", [])
    )
    self.update_credential(yaml_data.get("security", {}), client_params)
    client_name = yaml_data["language"]["default"]["name"]
    has_global_params = bool(yaml_data.get("globalParameters"))
    return {
        "name": client_name,
        "description": yaml_data["info"].get("description"),
        "parameters": client_params,
        "url": update_client_url(yaml_data) if has_global_params else "",
        "namespace": self._autorestapi.get_value("namespace") or client_name,
    }
1079
+
1080
def update_yaml(self, yaml_data: Dict[str, Any]) -> None:
    """Convert the modelerfour code model dict in place.

    Schemas are converted to types first, so they can be looked up while
    the client and operation groups are built. The consumed modelerfour
    sections (``globalParameters``, ``info``, ``language``, ``protocol``,
    ``schemas``, ``security``) are then removed, whether or not they are
    empty.

    The module-level ``ORIGINAL_ID_TO_UPDATED_TYPE`` cache is always
    cleared, even when a step raises. Its keys are ``id()`` values, and
    stale entries could alias unrelated objects once those ids are reused.

    :param yaml_data: The raw code model; it gains ``client``, updated
        ``operationGroups`` and ``types``.
    """
    try:
        # First we update the types, so we can access for when we're creating parameters etc.
        for type_group, types in yaml_data["schemas"].items():
            for t in types:
                if (
                    type_group == "objects"
                    and t["language"]["default"]["name"] == "CloudError"
                ):
                    # we don't generate cloud error
                    continue
                update_type(t)
        yaml_data["client"] = self.update_client(yaml_data)
        yaml_data["operationGroups"] = [
            self.update_operation_group(og) for og in yaml_data["operationGroups"]
        ]
        yaml_data["types"] = list(ORIGINAL_ID_TO_UPDATED_TYPE.values()) + list(
            KNOWN_TYPES.values()
        )
        for consumed_key in (
            "globalParameters",
            "info",
            "language",
            "protocol",
            "schemas",
            "security",
        ):
            yaml_data.pop(consumed_key, None)
    finally:
        ORIGINAL_ID_TO_UPDATED_TYPE.clear()