@autorest/python 6.13.14 → 6.13.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,6 +37,7 @@ class OptionsRetriever:
37
37
  "multiapi": False,
38
38
  "polymorphic-examples": 5,
39
39
  "generate-sample": False,
40
+ "generate-test": False,
40
41
  "from-typespec": False,
41
42
  "emit-cross-language-definition-file": False,
42
43
  }
@@ -332,6 +333,7 @@ class CodeGenerator(Plugin):
332
333
  "packaging_files_config",
333
334
  "default_optional_constants_to_none",
334
335
  "generate_sample",
336
+ "generate_test",
335
337
  "default_api_version",
336
338
  "from_typespec",
337
339
  "flavor",
@@ -436,6 +438,7 @@ class CodeGeneratorAutorest(CodeGenerator, PluginAutorest):
436
438
  "default-optional-constants-to-none"
437
439
  ),
438
440
  "generate-sample": self._autorestapi.get_boolean_value("generate-sample"),
441
+ "generate-test": self._autorestapi.get_boolean_value("generate-test"),
439
442
  "default-api-version": self._autorestapi.get_value("default-api-version"),
440
443
  }
441
444
  return {k: v for k, v in options.items() if v is not None}
@@ -76,6 +76,7 @@ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-in
76
76
  t for t in self.types_map.values() if isinstance(t, CombinedType) and t.name
77
77
  ]
78
78
  self.cross_language_package_id = self.yaml_data.get("crossLanguagePackageId")
79
+ self.for_test: bool = False
79
80
 
80
81
  @property
81
82
  def has_form_data(self) -> bool:
@@ -175,7 +175,9 @@ class EnumType(BaseType):
175
175
  :rtype: str
176
176
  """
177
177
  if self.code_model.options["models_mode"]:
178
- model_name = f"_models.{self.name}"
178
+ module_name = "_models." if kwargs.get("need_module_name", True) else ""
179
+ file_name = f"{self.code_model.enums_filename}." if self.internal else ""
180
+ model_name = module_name + file_name + self.name
179
181
  # we don't need quoted annotation in operation files, and need it in model folder files.
180
182
  if not kwargs.get("is_operation_file", False):
181
183
  model_name = f'"{model_name}"'
@@ -6,7 +6,10 @@
6
6
  from collections import OrderedDict
7
7
  from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast
8
8
  import sys
9
- from autorest.codegen.models.utils import add_to_pylint_disable, NAME_LENGTH_LIMIT
9
+ from autorest.codegen.models.utils import (
10
+ add_to_pylint_disable,
11
+ NAME_LENGTH_LIMIT,
12
+ )
10
13
  from .base import BaseType
11
14
  from .constant_type import ConstantType
12
15
  from .property import Property
@@ -148,7 +151,11 @@ class ModelType( # pylint: disable=abstract-method
148
151
  if self.discriminated_subtypes:
149
152
  # we will instead print the discriminated subtypes
150
153
  self._created_json_template_representation = False
151
- return self.snake_case_name
154
+ return (
155
+ f'"{self.snake_case_name}"'
156
+ if self.code_model.for_test
157
+ else self.snake_case_name
158
+ )
152
159
 
153
160
  # don't add additional properties, because there's not really a concept of
154
161
  # additional properties in the template
@@ -308,13 +315,14 @@ class JSONModelType(ModelType):
308
315
  class GeneratedModelType(ModelType): # pylint: disable=abstract-method
309
316
def type_annotation(self, **kwargs: Any) -> str:
    """Build the type-annotation string for this generated model.

    Keyword options consulted:
      * ``is_operation_file`` (removed from ``kwargs``) -- when truthy the
        annotation is returned without surrounding quotes.
      * ``skip_quote`` -- also returns the annotation unquoted.
      * ``need_module_name`` (default ``True``) -- prefix with ``_models.``.

    Internal models additionally get ``<code_model.models_filename>.`` in front
    of the model name.

    :return: e.g. ``'"_models.Foo"'``, ``"_models.Foo"`` or ``"Foo"``.
    :rtype: str
    """
    is_operation_file = kwargs.pop("is_operation_file", False)
    unquoted = is_operation_file or kwargs.get("skip_quote", False)

    pieces = []
    if kwargs.get("need_module_name", True):
        pieces.append("_models.")
    if self.internal:
        pieces.append(f"{self.code_model.models_filename}.")
    pieces.append(self.name)
    annotation = "".join(pieces)

    if unquoted:
        return annotation
    # quoted form is needed outside operation files (forward references)
    return f'"{annotation}"'
315
323
 
316
324
def docstring_type(self, **kwargs: Any) -> str:
    """Sphinx cross-reference for this model: ``~<namespace>.models.<path>``.

    ``<path>`` is the unquoted annotation without the ``_models.`` prefix, so
    internal models keep their private file segment.
    """
    namespace = self.code_model.namespace
    model_path = self.type_annotation(need_module_name=False, skip_quote=True)
    return f"~{namespace}.models.{model_path}"


def docstring_text(self, **kwargs: Any) -> str:
    """Plain model name used in docstring prose."""
    model_name = self.name
    return model_name
@@ -358,8 +366,7 @@ class MsrestModelType(GeneratedModelType):
358
366
 
359
367
@property
def serialization_type(self) -> str:
    """Name used by the msrest serializer for this model.

    Public models are referenced by bare name; internal models use their
    unquoted, fully qualified annotation.
    """
    if not self.internal:
        return self.name
    return self.type_annotation(skip_quote=True)
363
370
 
364
371
  @property
365
372
  def instance_check_template(self) -> str:
@@ -378,7 +385,11 @@ class DPGModelType(GeneratedModelType):
378
385
 
379
386
@property
def serialization_type(self) -> str:
    """Name used by the DPG serializer for this model.

    Always unquoted; public models drop the ``_models.`` module prefix while
    internal models keep it.
    """
    annotation_options = {"skip_quote": True}
    if not self.internal:
        annotation_options["need_module_name"] = False
    return self.type_annotation(**annotation_options)
382
393
 
383
394
  @property
384
395
  def instance_check_template(self) -> str:
@@ -217,7 +217,7 @@ class OperationBase( # pylint: disable=too-many-public-methods,too-many-instanc
217
217
  return None
218
218
  excep_schema = default_exceptions[0].type
219
219
  if isinstance(excep_schema, ModelType):
220
- return f"_models.{excep_schema.name}"
220
+ return excep_schema.type_annotation(skip_quote=True)
221
221
  # in this case, it's just an AnyType
222
222
  return "'object'"
223
223
 
@@ -62,7 +62,9 @@ class PrimitiveType(BaseType): # pylint: disable=abstract-method
62
62
  comment = add_to_description(comment, description)
63
63
  if comment:
64
64
  comment = f"# {comment}"
65
- return f"{client_default_value_declaration}{comment}"
65
+ return client_default_value_declaration + (
66
+ "" if self.code_model.for_test else comment
67
+ )
66
68
 
67
69
  @property
68
70
  def default_template_representation_declaration(self) -> str:
@@ -134,7 +134,8 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
134
134
  description: Optional[str] = None,
135
135
  ) -> Any:
136
136
  if self.is_multipart_file_input:
137
- return "[filetype]" if self.type.type == "list" else "filetype"
137
+ file_type_str = '"filetype"' if self.code_model.for_test else "filetype"
138
+ return f"[{file_type_str}]" if self.type.type == "list" else file_type_str
138
139
  if self.client_default_value:
139
140
  client_default_value_declaration = self.get_declaration(
140
141
  self.client_default_value
@@ -27,6 +27,7 @@ from .metadata_serializer import MetadataSerializer
27
27
  from .request_builders_serializer import RequestBuildersSerializer
28
28
  from .patch_serializer import PatchSerializer
29
29
  from .sample_serializer import SampleSerializer
30
+ from .test_serializer import TestSerializer, TestGeneralSerializer
30
31
  from .types_serializer import TypesSerializer
31
32
  from ..._utils import to_snake_case
32
33
  from .._utils import VALID_PACKAGE_MODE
@@ -146,6 +147,14 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
146
147
  ):
147
148
  self._serialize_and_write_sample(env, namespace_path)
148
149
 
150
+ if (
151
+ self.code_model.options["show_operations"]
152
+ and self.code_model.has_operations
153
+ and self.code_model.options["generate_test"]
154
+ and not self.code_model.options["azure_arm"]
155
+ ):
156
+ self._serialize_and_write_test(env, namespace_path)
157
+
149
158
  def serialize(self) -> None:
150
159
  env = Environment(
151
160
  loader=PackageLoader("autorest.codegen", "templates"),
@@ -631,6 +640,40 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
631
640
  log_error = f"error happens in sample {file}: {e}"
632
641
  _LOGGER.error(log_error)
633
642
 
643
def _serialize_and_write_test(self, env: Environment, namespace_path: Path):
    """Write the ``generated_tests`` folder for the package.

    Produces ``conftest.py``, ``testpreparer.py`` / ``testpreparer_async.py`` and
    one sync plus one async test file per operation group of every client.

    ``code_model.for_test`` is switched on while rendering; model, primitive and
    property template representations check it to emit test-friendly values.
    The flag is restored in ``finally`` so that a failure while writing the
    shared files cannot leave the code model stuck in test mode.

    :param env: Jinja environment used to load the test templates.
    :param namespace_path: package namespace path; tests go under the package root.
    """
    self.code_model.for_test = True
    try:
        out_path = self._package_root_folder(namespace_path) / Path("generated_tests")
        general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
        self.write_file(
            out_path / "conftest.py", general_serializer.serialize_conftest()
        )
        for is_async in (True, False):
            async_suffix = "_async" if is_async else ""
            general_serializer.is_async = is_async
            self.write_file(
                out_path / f"testpreparer{async_suffix}.py",
                general_serializer.serialize_testpreparer(),
            )

        for client in self.code_model.clients:
            for og in client.operation_groups:
                test_serializer = TestSerializer(
                    self.code_model, env, client=client, operation_group=og
                )
                for is_async in (True, False):
                    try:
                        test_serializer.is_async = is_async
                        self.write_file(
                            out_path
                            / f"{to_snake_case(test_serializer.test_class_name)}.py",
                            test_serializer.serialize_test(),
                        )
                    except Exception as e:  # pylint: disable=broad-except
                        # test generation shall not block code generation, so just log error
                        log_error = f"error happens in test generation for operation group {og.class_name}: {e}"
                        _LOGGER.error(log_error)
    finally:
        # always leave the shared code model in normal (non-test) rendering mode
        self.code_model.for_test = False
676
+
634
677
 
635
678
  class JinjaSerializerAutorest(JinjaSerializer, ReaderAndWriterAutorest):
636
679
  def __init__(
@@ -4,10 +4,9 @@
4
4
  # Licensed under the MIT License. See License.txt in the project root for
5
5
  # license information.
6
6
  # --------------------------------------------------------------------------
7
- import json
8
7
  from abc import abstractmethod
9
8
  from collections import defaultdict
10
- from typing import Any, Generic, List, Type, TypeVar, Dict, Union, Optional, cast
9
+ from typing import Generic, List, Type, TypeVar, Dict, Union, Optional, cast
11
10
 
12
11
  from ..models import (
13
12
  Operation,
@@ -84,34 +83,13 @@ def _escape_str(input_str: str) -> str:
84
83
  return f'"{replace}"'
85
84
 
86
85
 
87
- def _improve_json_string(template_representation: str) -> Any:
88
- origin = template_representation.split("\n")
89
- final = []
90
- for line in origin:
91
- idx0 = line.find("#")
92
- idx1 = line.rfind('"')
93
- modified_line = ""
94
- if idx0 > -1 and idx1 > -1:
95
- modified_line = line[:idx0] + line[idx1:] + " " + line[idx0:idx1] + "\n"
96
- else:
97
- modified_line = line + "\n"
98
- modified_line = modified_line.replace('"', "").replace("\\", '"')
99
- final.append(modified_line)
100
- return "".join(final)
101
-
102
-
103
- def _json_dumps_template(template_representation: Any) -> Any:
104
- # only for template use, since it wraps everything in strings
105
- return _improve_json_string(json.dumps(template_representation, indent=4))
106
-
107
-
108
86
  def _get_polymorphic_subtype_template(polymorphic_subtype: ModelType) -> List[str]:
109
87
  retval: List[str] = []
110
88
  retval.append("")
111
89
  retval.append(
112
90
  f'# JSON input template for discriminator value "{polymorphic_subtype.discriminator_value}":'
113
91
  )
114
- subtype_template = _json_dumps_template(
92
+ subtype_template = utils.json_dumps_template(
115
93
  polymorphic_subtype.get_json_template_representation(),
116
94
  )
117
95
 
@@ -229,7 +207,7 @@ def _get_json_response_template_to_status_codes(
229
207
  if not json_template:
230
208
  continue
231
209
  status_codes = [str(status_code) for status_code in response.status_codes]
232
- response_json = _json_dumps_template(json_template)
210
+ response_json = utils.json_dumps_template(json_template)
233
211
  retval[response_json].extend(status_codes)
234
212
  return retval
235
213
 
@@ -440,7 +418,7 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
440
418
  template.append(
441
419
  "# JSON input template you can fill out and use as your body input."
442
420
  )
443
- json_template = _json_dumps_template(
421
+ json_template = utils.json_dumps_template(
444
422
  json_type.get_json_template_representation(),
445
423
  )
446
424
  template.extend(
@@ -0,0 +1,283 @@
1
+ # pylint: disable=too-many-lines
2
+ # -------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License. See License.txt in the project root for
5
+ # license information.
6
+ # --------------------------------------------------------------------------
7
+ from typing import Dict, Any, List, Optional
8
+ from jinja2 import Environment
9
+
10
+ from .import_serializer import FileImportSerializer
11
+ from .base_serializer import BaseSerializer
12
+ from ..models import (
13
+ CodeModel,
14
+ ImportType,
15
+ OperationGroup,
16
+ Client,
17
+ OperationType,
18
+ ModelType,
19
+ BaseType,
20
+ CombinedType,
21
+ )
22
+ from .utils import get_namespace_from_package_name, json_dumps_template
23
+
24
+
25
class TestName:
    """Naming helpers shared by the generated test files of one client.

    :param client_name: generated client class name, e.g. ``"WidgetClient"``.
    :param is_async: whether names are produced for the async flavour.
    """

    def __init__(self, client_name: str, *, is_async: bool = False) -> None:
        self.client_name = client_name
        self.is_async = is_async

    @property
    def async_suffix_capt(self) -> str:
        """Capitalised suffix appended to async class names (``"Async"`` or empty)."""
        if self.is_async:
            return "Async"
        return ""

    @property
    def create_client_name(self) -> str:
        """Name of the factory method defined on the base test class."""
        if self.is_async:
            return "create_async_client"
        return "create_client"

    @property
    def prefix(self) -> str:
        """Client name with every ``"Client"`` occurrence removed."""
        return self.client_name.replace("Client", "")

    @property
    def preparer_name(self) -> str:
        """Name of the preparer object defined in ``testpreparer.py``."""
        return f"{self.prefix}Preparer"

    @property
    def base_test_class_name(self) -> str:
        """Name of the recorded base test class for this client/flavour."""
        return "".join((self.client_name, "TestBase", self.async_suffix_capt))
49
+
50
+
51
class TestCase:
    """A single generated test method: which operation to call and how to consume its result.

    The ``response``, ``operation_group_prefix``, ``operation_suffix`` and
    ``extra_operation`` properties are spliced around the client call by the
    ``test.py.jinja2`` template.
    """

    def __init__(
        self,
        operation_groups: List[OperationGroup],
        params: Dict[str, Any],
        operation: OperationType,
        *,
        is_async: bool = False,
    ) -> None:
        # path of operation groups from the client down to the group owning ``operation``
        self.operation_groups = operation_groups
        # keyword-argument name -> rendered value (built by the test serializer)
        self.params = params
        self.operation = operation
        self.is_async = is_async

    @property
    def operation_group_prefix(self) -> str:
        """Attribute path from the client to the operation, e.g. ``.widgets.parts``.

        Empty when the innermost group is a mixin, so the operation is called
        directly on the client.
        """
        if self.operation_groups[-1].is_mixin:
            return ""
        return "." + ".".join([og.property_name for og in self.operation_groups])

    @property
    def response(self) -> str:
        """Assignment (plus any ``await``) placed in front of the client call.

        Async LRO-pageable calls need a double await: one for the poller and one
        for ``poller.result()`` (the closing parenthesis comes from
        ``operation_suffix``). Async pageable calls return an async iterator
        directly, so awaiting them would fail.
        """
        if self.is_async:
            if self.operation.operation_type == "lropaging":
                return "response = await (await "
            if self.operation.operation_type == "paging":
                return "response = "
            return "response = await "
        return "response = "

    @property
    def lro_comment(self) -> str:
        """Trailing comment attached to generated ``.result()`` calls."""
        return " # poll until service return final result"

    @property
    def operation_suffix(self) -> str:
        """Text appended right after the call's closing parenthesis."""
        if self.operation.operation_type == "lropaging":
            # close the extra "(await " opened by ``response`` in the async case
            extra = ")" if self.is_async else ""
            return f"{extra}.result(){self.lro_comment}"
        return ""

    @property
    def extra_operation(self) -> str:
        """Follow-up statement that drives the response to completion, or ``""``."""
        if self.is_async:
            if self.operation.operation_type == "lro":
                return f"result = await response.result(){self.lro_comment}"
            # membership test: a plain ``==`` against the tuple never matched
            if self.operation.operation_type in ("lropaging", "paging"):
                return "result = [r async for r in response]"
        else:
            if self.operation.operation_type == "lro":
                return f"result = response.result(){self.lro_comment}"
            if self.operation.operation_type in ("lropaging", "paging"):
                return "result = [r for r in response]"
        return ""
103
+
104
+
105
class Test(TestName):
    """Render context for one generated test file, consumed by ``test.py.jinja2``.

    Extends the client naming helpers with the operation group under test, its
    test cases and the name of the generated test class.
    """

    def __init__(
        self,
        client_name: str,
        operation_group: OperationGroup,
        testcases: List[TestCase],
        test_class_name: str,
        *,
        is_async: bool = False,
    ) -> None:
        super().__init__(client_name, is_async=is_async)
        self.test_class_name = test_class_name
        self.testcases = testcases
        self.operation_group = operation_group
119
+
120
+
121
class TestGeneralSerializer(BaseSerializer):
    """Serializer for files shared by all generated tests.

    Renders ``conftest.py`` and the ``testpreparer`` modules. ``is_async`` picks
    the async (``.aio``) or sync flavour and may be toggled between renders.
    """

    def __init__(
        self, code_model: CodeModel, env: Environment, *, is_async: bool = False
    ) -> None:
        super().__init__(code_model, env)
        self.is_async = is_async

    @property
    def aio_str(self) -> str:
        """Suffix appended to the package namespace when importing async clients."""
        if self.is_async:
            return ".aio"
        return ""

    @property
    def test_names(self) -> List[TestName]:
        """One naming helper per client of the code model, in client order."""
        flavour_is_async = self.is_async
        return [
            TestName(client.name, is_async=flavour_is_async)
            for client in self.code_model.clients
        ]

    @property
    def import_clients(self) -> FileImportSerializer:
        """Imports for a ``testpreparer`` module: test utilities plus every client class."""
        file_import = self.init_file_import()
        package_namespace = get_namespace_from_package_name(
            self.code_model.options["package_name"]
        )

        file_import.add_submodule_import(
            "devtools_testutils", "AzureRecordedTestCase", ImportType.STDLIB
        )
        if not self.is_async:
            # the functools.partial-based preparers are only emitted in the sync module
            file_import.add_import("functools", ImportType.STDLIB)
            file_import.add_submodule_import(
                "devtools_testutils", "PowerShellPreparer", ImportType.STDLIB
            )

        client_module = package_namespace + self.aio_str
        for client in self.code_model.clients:
            file_import.add_submodule_import(
                client_module, client.name, ImportType.STDLIB
            )
        return FileImportSerializer(file_import, self.is_async)

    def serialize_conftest(self) -> str:
        """Render ``conftest.py`` (sanitizer fixture for every client)."""
        template = self.env.get_template("conftest.py.jinja2")
        return template.render(
            test_names=self.test_names,
            code_model=self.code_model,
        )

    def serialize_testpreparer(self) -> str:
        """Render the ``testpreparer`` module for the current ``is_async`` flavour."""
        template = self.env.get_template("testpreparer.py.jinja2")
        return template.render(
            test_names=self.test_names,
            imports=self.import_clients,
            code_model=self.code_model,
        )
171
+
172
+
173
class TestSerializer(TestGeneralSerializer):
    """Serializer for one generated ``test_*.py`` file covering a single operation group.

    Nested operation groups are walked breadth-first. Every operation that is
    neither internal nor an LRO initial operation becomes one ``TestCase``.
    """

    def __init__(
        self,
        code_model: CodeModel,
        env: Environment,
        *,
        client: Client,
        operation_group: OperationGroup,
        is_async: bool = False,
    ) -> None:
        super().__init__(code_model, env, is_async=is_async)
        self.client = client
        self.operation_group = operation_group

    @property
    def import_test(self) -> FileImportSerializer:
        """Imports for the test module: base test class, preparer and recording decorator."""
        file_import = self.init_file_import()
        naming = TestName(self.client.name, is_async=self.is_async)
        suffix = "_async" if self.is_async else ""
        file_import.add_submodule_import(
            "testpreparer" + suffix,
            naming.base_test_class_name,
            ImportType.LOCAL,
        )
        # the preparer is always taken from the sync preparer module
        file_import.add_submodule_import(
            "testpreparer", naming.preparer_name, ImportType.LOCAL
        )
        file_import.add_submodule_import(
            "devtools_testutils" + self.aio_str,
            "recorded_by_proxy" + suffix,
            ImportType.LOCAL,
        )
        return FileImportSerializer(file_import, self.is_async)

    @property
    def breadth_search_operation_group(self) -> List[List[OperationGroup]]:
        """Breadth-first paths from ``self.operation_group`` to every group that has operations.

        Each path lists the groups from the root group down to the one owning
        the operations.
        """
        paths_with_operations: List[List[OperationGroup]] = []
        pending = [[self.operation_group]]
        cursor = 0
        while cursor < len(pending):
            path = pending[cursor]
            cursor += 1
            leaf = path[-1]
            if leaf.operations:
                paths_with_operations.append(path)
            if leaf.operation_groups:
                pending.extend(path + [child] for child in leaf.operation_groups)
        return paths_with_operations

    def get_sub_type(self, param_type: ModelType) -> ModelType:
        """Follow the first discriminated subtype repeatedly until reaching a leaf model."""
        subtypes = param_type.discriminated_subtypes
        if not subtypes:
            return param_type
        first_subtype = next(iter(subtypes.values()))
        return self.get_sub_type(first_subtype)

    def get_model_type(self, param_type: BaseType) -> Optional[ModelType]:
        """Model type carried by ``param_type`` (directly or inside a union), else ``None``."""
        if isinstance(param_type, CombinedType):
            return param_type.target_model_subtype((ModelType,))
        return param_type if isinstance(param_type, ModelType) else None

    def get_operation_params(self, operation: OperationType) -> Dict[str, Any]:
        """Map each required method parameter's client name to a rendered JSON template."""
        rendered: Dict[str, Any] = {}
        for param in operation.parameters.method:
            if param.optional:
                continue
            model_type = self.get_model_type(param.type)
            effective_type = (
                self.get_sub_type(model_type) if model_type else param.type
            )
            rendered[param.client_name] = json_dumps_template(
                effective_type.get_json_template_representation()
            )
        return rendered

    def get_test(self) -> Test:
        """Build the render context for this operation group.

        :raises Exception: when the group (and its children) expose no public operation.
        """
        testcases = [
            TestCase(
                operation_groups=groups,
                params=self.get_operation_params(operation),
                operation=operation,
                is_async=self.is_async,
            )
            for groups in self.breadth_search_operation_group
            for operation in groups[-1].operations
            if not (operation.internal or operation.is_lro_initial_operation)
        ]
        if not testcases:
            raise Exception(  # pylint: disable=broad-exception-raised
                "no public operation to test"
            )

        return Test(
            client_name=self.client.name,
            operation_group=self.operation_group,
            testcases=testcases,
            test_class_name=self.test_class_name,
            is_async=self.is_async,
        )

    @property
    def test_class_name(self) -> str:
        """``Test<prefix><GroupClass><Async>``; the group part is omitted for mixins."""
        naming = TestName(self.client.name, is_async=self.is_async)
        if self.operation_group.is_mixin:
            group_part = ""
        else:
            group_part = self.operation_group.class_name
        return f"Test{naming.prefix}{group_part}{naming.async_suffix_capt}"

    def serialize_test(self) -> str:
        """Render the test module for the current ``is_async`` flavour."""
        template = self.env.get_template("test.py.jinja2")
        return template.render(
            imports=self.import_test,
            code_model=self.code_model,
            test=self.get_test(),
        )
@@ -3,7 +3,8 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import Optional, List
6
+ import json
7
+ from typing import Optional, List, Any
7
8
  from pathlib import Path
8
9
 
9
10
  from ..models import Client, OperationGroup
@@ -44,3 +45,24 @@ def get_all_operation_groups_recursively(clients: List[Client]) -> List[Operatio
44
45
  if operation_groups[-1].operation_groups:
45
46
  queue.extend(operation_groups[-1].operation_groups)
46
47
  return operation_groups
48
+
49
+
50
+ def _improve_json_string(template_representation: str) -> Any:
51
+ origin = template_representation.split("\n")
52
+ final = []
53
+ for line in origin:
54
+ idx0 = line.find("#")
55
+ idx1 = line.rfind('"')
56
+ modified_line = ""
57
+ if idx0 > -1 and idx1 > -1:
58
+ modified_line = line[:idx0] + line[idx1:] + " " + line[idx0:idx1] + "\n"
59
+ else:
60
+ modified_line = line + "\n"
61
+ modified_line = modified_line.replace('"', "").replace("\\", '"')
62
+ final.append(modified_line)
63
+ return "".join(final)
64
+
65
+
66
+ def json_dumps_template(template_representation: Any) -> Any:
67
+ # only for template use, since it wraps everything in strings
68
+ return _improve_json_string(json.dumps(template_representation, indent=4))
@@ -0,0 +1,28 @@
1
+ # coding=utf-8
2
+ {{ code_model.options['license_header'] }}
3
+ import os
4
+ import pytest
5
+ from dotenv import load_dotenv
6
+ from devtools_testutils import test_proxy, add_general_regex_sanitizer, add_body_key_sanitizer, add_header_regex_sanitizer
7
+
8
+ load_dotenv()
9
+
10
+ # avoid recording sensitive identity information in recordings
11
+ @pytest.fixture(scope="session", autouse=True)
12
+ def add_sanitizers(test_proxy):
13
+ {% for test_name in test_names %}
14
+ {% set prefix_upper = test_name.prefix|upper %}
15
+ {% set prefix_lower = test_name.prefix|lower %}
16
+ {{ prefix_lower }}_subscription_id = os.environ.get("{{ prefix_upper }}_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000")
17
+ {{ prefix_lower }}_tenant_id = os.environ.get("{{ prefix_upper }}_TENANT_ID", "00000000-0000-0000-0000-000000000000")
18
+ {{ prefix_lower }}_client_id = os.environ.get("{{ prefix_upper }}_CLIENT_ID", "00000000-0000-0000-0000-000000000000")
19
+ {{ prefix_lower }}_client_secret = os.environ.get("{{ prefix_upper }}_CLIENT_SECRET", "00000000-0000-0000-0000-000000000000")
20
+ add_general_regex_sanitizer(regex={{ prefix_lower }}_subscription_id, value="00000000-0000-0000-0000-000000000000")
21
+ add_general_regex_sanitizer(regex={{ prefix_lower }}_tenant_id, value="00000000-0000-0000-0000-000000000000")
22
+ add_general_regex_sanitizer(regex={{ prefix_lower }}_client_id, value="00000000-0000-0000-0000-000000000000")
23
+ add_general_regex_sanitizer(regex={{ prefix_lower }}_client_secret, value="00000000-0000-0000-0000-000000000000")
24
+
25
+ {% endfor %}
26
+ add_header_regex_sanitizer(key="Set-Cookie", value="[set-cookie;]")
27
+ add_header_regex_sanitizer(key="Cookie", value="cookie;")
28
+ add_body_key_sanitizer(json_path="$..access_token", value="access_token")
@@ -6,6 +6,7 @@
6
6
  # --------------------------------------------------------------------------
7
7
  # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
8
8
 
9
+ import copy
9
10
  import calendar
10
11
  import decimal
11
12
  import functools
@@ -655,6 +656,12 @@ def _deserialize_sequence(
655
656
  return obj
656
657
  return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
657
658
 
659
+ def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.Any]:
660
+ return sorted(
661
+ types,
662
+ key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"),
663
+ )
664
+
658
665
  def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
659
666
  annotation: typing.Any,
660
667
  module: typing.Optional[str],
@@ -696,21 +703,25 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
696
703
  # is it optional?
697
704
  try:
698
705
  if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
699
- if_obj_deserializer = _get_deserialize_callable_from_annotation(
700
- next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
701
- )
702
-
703
- return functools.partial(_deserialize_with_optional, if_obj_deserializer)
706
+ if len(annotation.__args__) <= 2: # pyright: ignore
707
+ if_obj_deserializer = _get_deserialize_callable_from_annotation(
708
+ next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
709
+ )
710
+
711
+ return functools.partial(_deserialize_with_optional, if_obj_deserializer)
712
+ # the type is Optional[Union[...]], we need to remove the None type from the Union
713
+ annotation_copy = copy.copy(annotation)
714
+ annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore
715
+ return _get_deserialize_callable_from_annotation(annotation_copy, module, rf)
704
716
  except AttributeError:
705
717
  pass
706
718
 
719
+ # is it union?
707
720
  if getattr(annotation, "__origin__", None) is typing.Union:
708
721
  # initial ordering is we make `string` the last deserialization option, because it is often them most generic
709
722
  deserializers = [
710
723
  _get_deserialize_callable_from_annotation(arg, module, rf)
711
- for arg in sorted(
712
- annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
713
- )
724
+ for arg in _sorted_annotations(annotation.__args__) # pyright: ignore
714
725
  ]
715
726
 
716
727
  return functools.partial(_deserialize_with_union, deserializers)
@@ -0,0 +1,26 @@
1
+ {% set prefix_lower = test.prefix|lower %}
2
+ {% set async = "async " if test.is_async else "" %}
3
+ {% set async_suffix = "_async" if test.is_async else "" %}
4
+ # coding=utf-8
5
+ {{ code_model.options['license_header'] }}
6
+ import pytest
7
+ {{ imports }}
8
+
9
+
10
+ @pytest.mark.skip("you may need to update the auto-generated test case before run it")
11
+ class {{ test.test_class_name }}({{ test.base_test_class_name }}):
12
+ {% for testcase in test.testcases %}
13
+ @{{ test.preparer_name }}()
14
+ @recorded_by_proxy{{ async_suffix }}
15
+ {{ async }}def test_{{ testcase.operation.name }}(self, {{ prefix_lower }}_endpoint):
16
+ client = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
17
+ {{testcase.response }}client{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
18
+ {% for key, value in testcase.params.items() %}
19
+ {{ key }}={{ value|indent(12) }},
20
+ {% endfor %}
21
+ ){{ testcase.operation_suffix }}
22
+ {{ testcase.extra_operation }}
23
+ # please add some check logic here by yourself
24
+ # ...
25
+
26
+ {% endfor %}
@@ -0,0 +1,26 @@
1
+ # coding=utf-8
2
+ {{ code_model.options['license_header'] }}
3
+ {{ imports }}
4
+
5
+ {% for test_name in test_names %}
6
+ {% set extra_async = ", is_async=True" if test_name.is_async else ""%}
7
+ {% set prefix_lower = test_name.prefix|lower %}
8
+ class {{ test_name.base_test_class_name }}(AzureRecordedTestCase):
9
+
10
+ def {{ test_name.create_client_name }}(self, endpoint):
11
+ credential = self.get_credential({{ test_name.client_name }}{{ extra_async }})
12
+ return self.create_client_from_credential(
13
+ {{ test_name.client_name }},
14
+ credential=credential,
15
+ endpoint=endpoint,
16
+ )
17
+
18
+ {% if not test_name.is_async %}
19
+ {{ test_name.preparer_name }} = functools.partial(
20
+ PowerShellPreparer,
21
+ "{{ prefix_lower }}",
22
+ {{ prefix_lower }}_endpoint="https://fake_{{ prefix_lower }}_endpoint.com"
23
+ )
24
+ {% endif %}
25
+
26
+ {% endfor %}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.13.14",
3
+ "version": "6.13.15",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {