@autorest/python 6.13.14 → 6.13.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/autorest/codegen/__init__.py +3 -0
- package/autorest/codegen/models/code_model.py +1 -0
- package/autorest/codegen/models/enum_type.py +3 -1
- package/autorest/codegen/models/model_type.py +21 -10
- package/autorest/codegen/models/operation.py +1 -1
- package/autorest/codegen/models/primitive_types.py +3 -1
- package/autorest/codegen/models/property.py +2 -1
- package/autorest/codegen/serializers/__init__.py +43 -0
- package/autorest/codegen/serializers/builder_serializer.py +4 -26
- package/autorest/codegen/serializers/test_serializer.py +283 -0
- package/autorest/codegen/serializers/utils.py +23 -1
- package/autorest/codegen/templates/conftest.py.jinja2 +28 -0
- package/autorest/codegen/templates/model_base.py.jinja2 +19 -8
- package/autorest/codegen/templates/test.py.jinja2 +26 -0
- package/autorest/codegen/templates/testpreparer.py.jinja2 +26 -0
- package/package.json +1 -1
|
@@ -37,6 +37,7 @@ class OptionsRetriever:
|
|
|
37
37
|
"multiapi": False,
|
|
38
38
|
"polymorphic-examples": 5,
|
|
39
39
|
"generate-sample": False,
|
|
40
|
+
"generate-test": False,
|
|
40
41
|
"from-typespec": False,
|
|
41
42
|
"emit-cross-language-definition-file": False,
|
|
42
43
|
}
|
|
@@ -332,6 +333,7 @@ class CodeGenerator(Plugin):
|
|
|
332
333
|
"packaging_files_config",
|
|
333
334
|
"default_optional_constants_to_none",
|
|
334
335
|
"generate_sample",
|
|
336
|
+
"generate_test",
|
|
335
337
|
"default_api_version",
|
|
336
338
|
"from_typespec",
|
|
337
339
|
"flavor",
|
|
@@ -436,6 +438,7 @@ class CodeGeneratorAutorest(CodeGenerator, PluginAutorest):
|
|
|
436
438
|
"default-optional-constants-to-none"
|
|
437
439
|
),
|
|
438
440
|
"generate-sample": self._autorestapi.get_boolean_value("generate-sample"),
|
|
441
|
+
"generate-test": self._autorestapi.get_boolean_value("generate-test"),
|
|
439
442
|
"default-api-version": self._autorestapi.get_value("default-api-version"),
|
|
440
443
|
}
|
|
441
444
|
return {k: v for k, v in options.items() if v is not None}
|
|
@@ -76,6 +76,7 @@ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-in
|
|
|
76
76
|
t for t in self.types_map.values() if isinstance(t, CombinedType) and t.name
|
|
77
77
|
]
|
|
78
78
|
self.cross_language_package_id = self.yaml_data.get("crossLanguagePackageId")
|
|
79
|
+
self.for_test: bool = False
|
|
79
80
|
|
|
80
81
|
@property
|
|
81
82
|
def has_form_data(self) -> bool:
|
|
@@ -175,7 +175,9 @@ class EnumType(BaseType):
|
|
|
175
175
|
:rtype: str
|
|
176
176
|
"""
|
|
177
177
|
if self.code_model.options["models_mode"]:
|
|
178
|
-
|
|
178
|
+
module_name = "_models." if kwargs.get("need_module_name", True) else ""
|
|
179
|
+
file_name = f"{self.code_model.enums_filename}." if self.internal else ""
|
|
180
|
+
model_name = module_name + file_name + self.name
|
|
179
181
|
# we don't need quoted annotation in operation files, and need it in model folder files.
|
|
180
182
|
if not kwargs.get("is_operation_file", False):
|
|
181
183
|
model_name = f'"{model_name}"'
|
|
@@ -6,7 +6,10 @@
|
|
|
6
6
|
from collections import OrderedDict
|
|
7
7
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast
|
|
8
8
|
import sys
|
|
9
|
-
from autorest.codegen.models.utils import
|
|
9
|
+
from autorest.codegen.models.utils import (
|
|
10
|
+
add_to_pylint_disable,
|
|
11
|
+
NAME_LENGTH_LIMIT,
|
|
12
|
+
)
|
|
10
13
|
from .base import BaseType
|
|
11
14
|
from .constant_type import ConstantType
|
|
12
15
|
from .property import Property
|
|
@@ -148,7 +151,11 @@ class ModelType( # pylint: disable=abstract-method
|
|
|
148
151
|
if self.discriminated_subtypes:
|
|
149
152
|
# we will instead print the discriminated subtypes
|
|
150
153
|
self._created_json_template_representation = False
|
|
151
|
-
return
|
|
154
|
+
return (
|
|
155
|
+
f'"{self.snake_case_name}"'
|
|
156
|
+
if self.code_model.for_test
|
|
157
|
+
else self.snake_case_name
|
|
158
|
+
)
|
|
152
159
|
|
|
153
160
|
# don't add additional properties, because there's not really a concept of
|
|
154
161
|
# additional properties in the template
|
|
@@ -308,13 +315,14 @@ class JSONModelType(ModelType):
|
|
|
308
315
|
class GeneratedModelType(ModelType): # pylint: disable=abstract-method
|
|
309
316
|
def type_annotation(self, **kwargs: Any) -> str:
|
|
310
317
|
is_operation_file = kwargs.pop("is_operation_file", False)
|
|
311
|
-
|
|
312
|
-
if
|
|
313
|
-
|
|
314
|
-
|
|
318
|
+
skip_quote = kwargs.get("skip_quote", False)
|
|
319
|
+
module_name = "_models." if kwargs.get("need_module_name", True) else ""
|
|
320
|
+
file_name = f"{self.code_model.models_filename}." if self.internal else ""
|
|
321
|
+
retval = module_name + file_name + self.name
|
|
322
|
+
return retval if is_operation_file or skip_quote else f'"{retval}"'
|
|
315
323
|
|
|
316
324
|
def docstring_type(self, **kwargs: Any) -> str:
|
|
317
|
-
return f"~{self.code_model.namespace}.models.{self.
|
|
325
|
+
return f"~{self.code_model.namespace}.models.{self.type_annotation(need_module_name=False, skip_quote=True)}"
|
|
318
326
|
|
|
319
327
|
def docstring_text(self, **kwargs: Any) -> str:
|
|
320
328
|
return self.name
|
|
@@ -358,8 +366,7 @@ class MsrestModelType(GeneratedModelType):
|
|
|
358
366
|
|
|
359
367
|
@property
|
|
360
368
|
def serialization_type(self) -> str:
|
|
361
|
-
|
|
362
|
-
return f"{private_model_path if self.internal else ''}{self.name}"
|
|
369
|
+
return self.type_annotation(skip_quote=True) if self.internal else self.name
|
|
363
370
|
|
|
364
371
|
@property
|
|
365
372
|
def instance_check_template(self) -> str:
|
|
@@ -378,7 +385,11 @@ class DPGModelType(GeneratedModelType):
|
|
|
378
385
|
|
|
379
386
|
@property
|
|
380
387
|
def serialization_type(self) -> str:
|
|
381
|
-
return
|
|
388
|
+
return (
|
|
389
|
+
self.type_annotation(skip_quote=True)
|
|
390
|
+
if self.internal
|
|
391
|
+
else self.type_annotation(need_module_name=False, skip_quote=True)
|
|
392
|
+
)
|
|
382
393
|
|
|
383
394
|
@property
|
|
384
395
|
def instance_check_template(self) -> str:
|
|
@@ -217,7 +217,7 @@ class OperationBase( # pylint: disable=too-many-public-methods,too-many-instanc
|
|
|
217
217
|
return None
|
|
218
218
|
excep_schema = default_exceptions[0].type
|
|
219
219
|
if isinstance(excep_schema, ModelType):
|
|
220
|
-
return
|
|
220
|
+
return excep_schema.type_annotation(skip_quote=True)
|
|
221
221
|
# in this case, it's just an AnyType
|
|
222
222
|
return "'object'"
|
|
223
223
|
|
|
@@ -62,7 +62,9 @@ class PrimitiveType(BaseType): # pylint: disable=abstract-method
|
|
|
62
62
|
comment = add_to_description(comment, description)
|
|
63
63
|
if comment:
|
|
64
64
|
comment = f"# {comment}"
|
|
65
|
-
return
|
|
65
|
+
return client_default_value_declaration + (
|
|
66
|
+
"" if self.code_model.for_test else comment
|
|
67
|
+
)
|
|
66
68
|
|
|
67
69
|
@property
|
|
68
70
|
def default_template_representation_declaration(self) -> str:
|
|
@@ -134,7 +134,8 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
|
|
|
134
134
|
description: Optional[str] = None,
|
|
135
135
|
) -> Any:
|
|
136
136
|
if self.is_multipart_file_input:
|
|
137
|
-
|
|
137
|
+
file_type_str = '"filetype"' if self.code_model.for_test else "filetype"
|
|
138
|
+
return f"[{file_type_str}]" if self.type.type == "list" else file_type_str
|
|
138
139
|
if self.client_default_value:
|
|
139
140
|
client_default_value_declaration = self.get_declaration(
|
|
140
141
|
self.client_default_value
|
|
@@ -27,6 +27,7 @@ from .metadata_serializer import MetadataSerializer
|
|
|
27
27
|
from .request_builders_serializer import RequestBuildersSerializer
|
|
28
28
|
from .patch_serializer import PatchSerializer
|
|
29
29
|
from .sample_serializer import SampleSerializer
|
|
30
|
+
from .test_serializer import TestSerializer, TestGeneralSerializer
|
|
30
31
|
from .types_serializer import TypesSerializer
|
|
31
32
|
from ..._utils import to_snake_case
|
|
32
33
|
from .._utils import VALID_PACKAGE_MODE
|
|
@@ -146,6 +147,14 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
|
|
|
146
147
|
):
|
|
147
148
|
self._serialize_and_write_sample(env, namespace_path)
|
|
148
149
|
|
|
150
|
+
if (
|
|
151
|
+
self.code_model.options["show_operations"]
|
|
152
|
+
and self.code_model.has_operations
|
|
153
|
+
and self.code_model.options["generate_test"]
|
|
154
|
+
and not self.code_model.options["azure_arm"]
|
|
155
|
+
):
|
|
156
|
+
self._serialize_and_write_test(env, namespace_path)
|
|
157
|
+
|
|
149
158
|
def serialize(self) -> None:
|
|
150
159
|
env = Environment(
|
|
151
160
|
loader=PackageLoader("autorest.codegen", "templates"),
|
|
@@ -631,6 +640,40 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
|
|
|
631
640
|
log_error = f"error happens in sample {file}: {e}"
|
|
632
641
|
_LOGGER.error(log_error)
|
|
633
642
|
|
|
643
|
+
def _serialize_and_write_test(self, env: Environment, namespace_path: Path):
    """Write the ``generated_tests`` folder: conftest, test preparers and test files.

    ``code_model.for_test`` is switched on while the test files are rendered and
    is always switched back off afterwards, even when rendering fails, so that a
    failure here cannot leave the shared code model in "test" mode.

    :param env: Jinja2 environment used to load the test templates.
    :param namespace_path: Namespace folder; the package root is derived from it.
    """
    self.code_model.for_test = True
    try:
        out_path = self._package_root_folder(namespace_path) / Path("generated_tests")
        general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
        self.write_file(
            out_path / "conftest.py", general_serializer.serialize_conftest()
        )
        # one preparer module per flavor: testpreparer_async.py and testpreparer.py
        for is_async in (True, False):
            async_suffix = "_async" if is_async else ""
            general_serializer.is_async = is_async
            self.write_file(
                out_path / f"testpreparer{async_suffix}.py",
                general_serializer.serialize_testpreparer(),
            )

        for client in self.code_model.clients:
            for og in client.operation_groups:
                test_serializer = TestSerializer(
                    self.code_model, env, client=client, operation_group=og
                )
                for is_async in (True, False):
                    try:
                        test_serializer.is_async = is_async
                        self.write_file(
                            out_path
                            / f"{to_snake_case(test_serializer.test_class_name)}.py",
                            test_serializer.serialize_test(),
                        )
                    except Exception as e:  # pylint: disable=broad-except
                        # test generation shall not block code generation, so just log error
                        log_error = f"error happens in test generation for operation group {og.class_name}: {e}"
                        _LOGGER.error(log_error)
    finally:
        # reset unconditionally; previously an exception above left the flag set
        self.code_model.for_test = False
|
|
676
|
+
|
|
634
677
|
|
|
635
678
|
class JinjaSerializerAutorest(JinjaSerializer, ReaderAndWriterAutorest):
|
|
636
679
|
def __init__(
|
|
@@ -4,10 +4,9 @@
|
|
|
4
4
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
5
5
|
# license information.
|
|
6
6
|
# --------------------------------------------------------------------------
|
|
7
|
-
import json
|
|
8
7
|
from abc import abstractmethod
|
|
9
8
|
from collections import defaultdict
|
|
10
|
-
from typing import
|
|
9
|
+
from typing import Generic, List, Type, TypeVar, Dict, Union, Optional, cast
|
|
11
10
|
|
|
12
11
|
from ..models import (
|
|
13
12
|
Operation,
|
|
@@ -84,34 +83,13 @@ def _escape_str(input_str: str) -> str:
|
|
|
84
83
|
return f'"{replace}"'
|
|
85
84
|
|
|
86
85
|
|
|
87
|
-
def _improve_json_string(template_representation: str) -> Any:
|
|
88
|
-
origin = template_representation.split("\n")
|
|
89
|
-
final = []
|
|
90
|
-
for line in origin:
|
|
91
|
-
idx0 = line.find("#")
|
|
92
|
-
idx1 = line.rfind('"')
|
|
93
|
-
modified_line = ""
|
|
94
|
-
if idx0 > -1 and idx1 > -1:
|
|
95
|
-
modified_line = line[:idx0] + line[idx1:] + " " + line[idx0:idx1] + "\n"
|
|
96
|
-
else:
|
|
97
|
-
modified_line = line + "\n"
|
|
98
|
-
modified_line = modified_line.replace('"', "").replace("\\", '"')
|
|
99
|
-
final.append(modified_line)
|
|
100
|
-
return "".join(final)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
def _json_dumps_template(template_representation: Any) -> Any:
|
|
104
|
-
# only for template use, since it wraps everything in strings
|
|
105
|
-
return _improve_json_string(json.dumps(template_representation, indent=4))
|
|
106
|
-
|
|
107
|
-
|
|
108
86
|
def _get_polymorphic_subtype_template(polymorphic_subtype: ModelType) -> List[str]:
|
|
109
87
|
retval: List[str] = []
|
|
110
88
|
retval.append("")
|
|
111
89
|
retval.append(
|
|
112
90
|
f'# JSON input template for discriminator value "{polymorphic_subtype.discriminator_value}":'
|
|
113
91
|
)
|
|
114
|
-
subtype_template =
|
|
92
|
+
subtype_template = utils.json_dumps_template(
|
|
115
93
|
polymorphic_subtype.get_json_template_representation(),
|
|
116
94
|
)
|
|
117
95
|
|
|
@@ -229,7 +207,7 @@ def _get_json_response_template_to_status_codes(
|
|
|
229
207
|
if not json_template:
|
|
230
208
|
continue
|
|
231
209
|
status_codes = [str(status_code) for status_code in response.status_codes]
|
|
232
|
-
response_json =
|
|
210
|
+
response_json = utils.json_dumps_template(json_template)
|
|
233
211
|
retval[response_json].extend(status_codes)
|
|
234
212
|
return retval
|
|
235
213
|
|
|
@@ -440,7 +418,7 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
|
|
|
440
418
|
template.append(
|
|
441
419
|
"# JSON input template you can fill out and use as your body input."
|
|
442
420
|
)
|
|
443
|
-
json_template =
|
|
421
|
+
json_template = utils.json_dumps_template(
|
|
444
422
|
json_type.get_json_template_representation(),
|
|
445
423
|
)
|
|
446
424
|
template.extend(
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
# pylint: disable=too-many-lines
|
|
2
|
+
# -------------------------------------------------------------------------
|
|
3
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
4
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
5
|
+
# license information.
|
|
6
|
+
# --------------------------------------------------------------------------
|
|
7
|
+
from typing import Dict, Any, List, Optional
|
|
8
|
+
from jinja2 import Environment
|
|
9
|
+
|
|
10
|
+
from .import_serializer import FileImportSerializer
|
|
11
|
+
from .base_serializer import BaseSerializer
|
|
12
|
+
from ..models import (
|
|
13
|
+
CodeModel,
|
|
14
|
+
ImportType,
|
|
15
|
+
OperationGroup,
|
|
16
|
+
Client,
|
|
17
|
+
OperationType,
|
|
18
|
+
ModelType,
|
|
19
|
+
BaseType,
|
|
20
|
+
CombinedType,
|
|
21
|
+
)
|
|
22
|
+
from .utils import get_namespace_from_package_name, json_dumps_template
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TestName:
    """Derives the names used in generated test code from a client name."""

    def __init__(self, client_name: str, *, is_async: bool = False) -> None:
        self.client_name = client_name
        self.is_async = is_async

    @property
    def async_suffix_capt(self) -> str:
        """``"Async"`` for the async flavor, otherwise an empty string."""
        if self.is_async:
            return "Async"
        return ""

    @property
    def create_client_name(self) -> str:
        """Name of the client factory method on the test base class."""
        if self.is_async:
            return "create_async_client"
        return "create_client"

    @property
    def prefix(self) -> str:
        """Client name with every occurrence of ``Client`` removed."""
        return "".join(self.client_name.split("Client"))

    @property
    def preparer_name(self) -> str:
        """Name of the preparer object generated for this client."""
        return f"{self.prefix}Preparer"

    @property
    def base_test_class_name(self) -> str:
        """Name of the base test class generated for this client and flavor."""
        return self.client_name + "TestBase" + self.async_suffix_capt
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class TestCase:
    """Text fragments needed to render one generated test method.

    :param operation_groups: Path of operation groups leading to the operation;
        the last entry is the group that owns it.
    :param params: Mapping of parameter name to the text rendered as its value.
    :param operation: The operation under test; its ``operation_type`` selects
        which fragments are produced.
    :keyword is_async: Whether fragments are for the async flavor.
    """

    def __init__(
        self,
        operation_groups: List[OperationGroup],
        params: Dict[str, Any],
        operation: OperationType,
        *,
        is_async: bool = False,
    ) -> None:
        self.operation_groups = operation_groups
        self.params = params
        self.operation = operation
        self.is_async = is_async

    @property
    def operation_group_prefix(self) -> str:
        """Dotted attribute path to the owning operation group, or ``""`` for a mixin."""
        if self.operation_groups[-1].is_mixin:
            return ""
        return "." + ".".join([og.property_name for og in self.operation_groups])

    @property
    def response(self) -> str:
        """Text placed in front of the operation call (assignment plus any ``await``)."""
        if self.is_async:
            if self.operation.operation_type == "lropaging":
                # the opening parenthesis is closed by ``operation_suffix``
                return "response = await (await "
            # NOTE(review): plain async "paging" also gets an ``await`` here;
            # confirm generated async paging methods are awaitable.
            return "response = await "
        return "response = "

    @property
    def lro_comment(self) -> str:
        """Trailing comment appended to polling calls."""
        return " # poll until service return final result"

    @property
    def operation_suffix(self) -> str:
        """Text placed right after the operation call's closing parenthesis."""
        if self.operation.operation_type == "lropaging":
            extra = ")" if self.is_async else ""
            return f"{extra}.result(){self.lro_comment}"
        return ""

    @property
    def extra_operation(self) -> str:
        """Follow-up statement that consumes ``response``, or ``""`` if none is needed.

        The async branch used to compare ``operation_type`` to a tuple with
        ``==``, which is never true for a string, so async paging tests got no
        follow-up statement. Both flavors now use a membership test.
        """
        operation_type = self.operation.operation_type
        if operation_type == "lro":
            awaited = "await " if self.is_async else ""
            return f"result = {awaited}response.result(){self.lro_comment}"
        if operation_type in ("lropaging", "paging"):
            async_kw = "async " if self.is_async else ""
            return f"result = [r {async_kw}for r in response]"
        return ""
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class Test(TestName):
    """One generated test class: its naming data plus the test cases it contains."""

    def __init__(
        self,
        client_name: str,
        operation_group: OperationGroup,
        testcases: List[TestCase],
        test_class_name: str,
        *,
        is_async: bool = False,
    ) -> None:
        # naming attributes (client_name, is_async) are handled by TestName
        super().__init__(client_name, is_async=is_async)
        self.test_class_name = test_class_name
        self.testcases = testcases
        self.operation_group = operation_group
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class TestGeneralSerializer(BaseSerializer):
    """Renders the test files shared by all operation groups: conftest and test preparers."""

    def __init__(
        self, code_model: CodeModel, env: Environment, *, is_async: bool = False
    ) -> None:
        super().__init__(code_model, env)
        self.is_async = is_async

    @property
    def aio_str(self) -> str:
        """``".aio"`` for the async flavor, otherwise an empty string."""
        if self.is_async:
            return ".aio"
        return ""

    @property
    def test_names(self) -> List[TestName]:
        """One ``TestName`` per client in the code model, in the current flavor."""
        names = []
        for client in self.code_model.clients:
            names.append(TestName(client.name, is_async=self.is_async))
        return names

    @property
    def import_clients(self) -> FileImportSerializer:
        """Imports for the test preparer module: test utilities plus every client class."""
        file_import = self.init_file_import()
        namespace = get_namespace_from_package_name(
            self.code_model.options["package_name"]
        )

        file_import.add_submodule_import(
            "devtools_testutils", "AzureRecordedTestCase", ImportType.STDLIB
        )
        # functools and PowerShellPreparer are only added for the sync flavor
        if not self.is_async:
            file_import.add_import("functools", ImportType.STDLIB)
            file_import.add_submodule_import(
                "devtools_testutils", "PowerShellPreparer", ImportType.STDLIB
            )
        for client in self.code_model.clients:
            client_module = namespace + self.aio_str
            file_import.add_submodule_import(
                client_module, client.name, ImportType.STDLIB
            )
        return FileImportSerializer(file_import, self.is_async)

    def serialize_conftest(self) -> str:
        """Render ``conftest.py.jinja2``."""
        template = self.env.get_template("conftest.py.jinja2")
        return template.render(
            test_names=self.test_names,
            code_model=self.code_model,
        )

    def serialize_testpreparer(self) -> str:
        """Render ``testpreparer.py.jinja2`` for the current flavor."""
        template = self.env.get_template("testpreparer.py.jinja2")
        return template.render(
            test_names=self.test_names,
            imports=self.import_clients,
            code_model=self.code_model,
        )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class TestSerializer(TestGeneralSerializer):
    """Renders the test file for one operation group of one client."""

    def __init__(
        self,
        code_model: CodeModel,
        env: Environment,
        *,
        client: Client,
        operation_group: OperationGroup,
        is_async: bool = False,
    ) -> None:
        super().__init__(code_model, env, is_async=is_async)
        self.client = client
        self.operation_group = operation_group

    @property
    def import_test(self) -> FileImportSerializer:
        """Imports for the test file: base test class, preparer and recording decorator."""
        file_import = self.init_file_import()
        test_name = TestName(self.client.name, is_async=self.is_async)
        async_suffix = "_async" if self.is_async else ""
        # the preparer is always imported from the sync "testpreparer" module
        local_imports = (
            ("testpreparer" + async_suffix, test_name.base_test_class_name),
            ("testpreparer", test_name.preparer_name),
            ("devtools_testutils" + self.aio_str, "recorded_by_proxy" + async_suffix),
        )
        for module_name, imported_name in local_imports:
            file_import.add_submodule_import(
                module_name, imported_name, ImportType.LOCAL
            )
        return FileImportSerializer(file_import, self.is_async)

    @property
    def breadth_search_operation_group(self) -> List[List[OperationGroup]]:
        """Breadth-first list of group paths, rooted at ``self.operation_group``.

        Only paths whose last group has operations are returned.
        """
        paths = []
        pending = [[self.operation_group]]
        while pending:
            path = pending.pop(0)
            tail = path[-1]
            if tail.operations:
                paths.append(path)
            if tail.operation_groups:
                for child in tail.operation_groups:
                    pending.append(path + [child])
        return paths

    def get_sub_type(self, param_type: ModelType) -> ModelType:
        """Follow the first discriminated subtype recursively until a model has none."""
        subtypes = param_type.discriminated_subtypes
        if not subtypes:
            return param_type
        for first_subtype in subtypes.values():
            return self.get_sub_type(first_subtype)
        return param_type

    def get_model_type(self, param_type: BaseType) -> Optional[ModelType]:
        """Return the model behind ``param_type``, or ``None`` if it is neither a model nor a combined type."""
        if isinstance(param_type, ModelType):
            return param_type
        if not isinstance(param_type, CombinedType):
            return None
        return param_type.target_model_subtype((ModelType,))

    def get_operation_params(self, operation: OperationType) -> Dict[str, Any]:
        """Map each required method parameter's client name to its rendered JSON template."""
        rendered = {}
        for param in operation.parameters.method:
            if param.optional:
                continue
            model_type = self.get_model_type(param.type)
            if model_type:
                resolved_type = self.get_sub_type(model_type)
            else:
                resolved_type = param.type
            template = resolved_type.get_json_template_representation()
            rendered[param.client_name] = json_dumps_template(template)
        return rendered

    def get_test(self) -> Test:
        """Build the ``Test`` for this operation group.

        Internal operations and LRO initial operations are left out.

        :raises Exception: if no operation remains to be tested.
        """
        testcases = [
            TestCase(
                operation_groups=path,
                params=self.get_operation_params(operation),
                operation=operation,
                is_async=self.is_async,
            )
            for path in self.breadth_search_operation_group
            for operation in path[-1].operations
            if not (operation.internal or operation.is_lro_initial_operation)
        ]
        if not testcases:
            raise Exception(  # pylint: disable=broad-exception-raised
                "no public operation to test"
            )

        return Test(
            client_name=self.client.name,
            operation_group=self.operation_group,
            testcases=testcases,
            test_class_name=self.test_class_name,
            is_async=self.is_async,
        )

    @property
    def test_class_name(self) -> str:
        """``Test`` + client prefix + group class name (omitted for mixins) + async suffix."""
        test_name = TestName(self.client.name, is_async=self.is_async)
        if self.operation_group.is_mixin:
            class_name = ""
        else:
            class_name = self.operation_group.class_name
        return f"Test{test_name.prefix}{class_name}{test_name.async_suffix_capt}"

    def serialize_test(self) -> str:
        """Render ``test.py.jinja2`` for this operation group and flavor."""
        template = self.env.get_template("test.py.jinja2")
        return template.render(
            imports=self.import_test,
            code_model=self.code_model,
            test=self.get_test(),
        )
|
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
4
|
# license information.
|
|
5
5
|
# --------------------------------------------------------------------------
|
|
6
|
-
|
|
6
|
+
import json
|
|
7
|
+
from typing import Optional, List, Any
|
|
7
8
|
from pathlib import Path
|
|
8
9
|
|
|
9
10
|
from ..models import Client, OperationGroup
|
|
@@ -44,3 +45,24 @@ def get_all_operation_groups_recursively(clients: List[Client]) -> List[Operatio
|
|
|
44
45
|
if operation_groups[-1].operation_groups:
|
|
45
46
|
queue.extend(operation_groups[-1].operation_groups)
|
|
46
47
|
return operation_groups
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _improve_json_string(template_representation: str) -> Any:
|
|
51
|
+
origin = template_representation.split("\n")
|
|
52
|
+
final = []
|
|
53
|
+
for line in origin:
|
|
54
|
+
idx0 = line.find("#")
|
|
55
|
+
idx1 = line.rfind('"')
|
|
56
|
+
modified_line = ""
|
|
57
|
+
if idx0 > -1 and idx1 > -1:
|
|
58
|
+
modified_line = line[:idx0] + line[idx1:] + " " + line[idx0:idx1] + "\n"
|
|
59
|
+
else:
|
|
60
|
+
modified_line = line + "\n"
|
|
61
|
+
modified_line = modified_line.replace('"', "").replace("\\", '"')
|
|
62
|
+
final.append(modified_line)
|
|
63
|
+
return "".join(final)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def json_dumps_template(template_representation: Any) -> Any:
    """Dump ``template_representation`` as 4-space-indented JSON and post-process it for templates."""
    # only for template use, since it wraps everything in strings
    raw_json = json.dumps(template_representation, indent=4)
    return _improve_json_string(raw_json)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
{{ code_model.options['license_header'] }}
|
|
3
|
+
import os
|
|
4
|
+
import pytest
|
|
5
|
+
from dotenv import load_dotenv
|
|
6
|
+
from devtools_testutils import test_proxy, add_general_regex_sanitizer, add_body_key_sanitizer, add_header_regex_sanitizer
|
|
7
|
+
|
|
8
|
+
load_dotenv()
|
|
9
|
+
|
|
10
|
+
# avoid recording sensitive identity information in recordings
|
|
11
|
+
@pytest.fixture(scope="session", autouse=True)
|
|
12
|
+
def add_sanitizers(test_proxy):
|
|
13
|
+
{% for test_name in test_names %}
|
|
14
|
+
{% set prefix_upper = test_name.prefix|upper %}
|
|
15
|
+
{% set prefix_lower = test_name.prefix|lower %}
|
|
16
|
+
{{ prefix_lower }}_subscription_id = os.environ.get("{{ prefix_upper }}_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000")
|
|
17
|
+
{{ prefix_lower }}_tenant_id = os.environ.get("{{ prefix_upper }}_TENANT_ID", "00000000-0000-0000-0000-000000000000")
|
|
18
|
+
{{ prefix_lower }}_client_id = os.environ.get("{{ prefix_upper }}_CLIENT_ID", "00000000-0000-0000-0000-000000000000")
|
|
19
|
+
{{ prefix_lower }}_client_secret = os.environ.get("{{ prefix_upper }}_CLIENT_SECRET", "00000000-0000-0000-0000-000000000000")
|
|
20
|
+
add_general_regex_sanitizer(regex={{ prefix_lower }}_subscription_id, value="00000000-0000-0000-0000-000000000000")
|
|
21
|
+
add_general_regex_sanitizer(regex={{ prefix_lower }}_tenant_id, value="00000000-0000-0000-0000-000000000000")
|
|
22
|
+
add_general_regex_sanitizer(regex={{ prefix_lower }}_client_id, value="00000000-0000-0000-0000-000000000000")
|
|
23
|
+
add_general_regex_sanitizer(regex={{ prefix_lower }}_client_secret, value="00000000-0000-0000-0000-000000000000")
|
|
24
|
+
|
|
25
|
+
{% endfor %}
|
|
26
|
+
add_header_regex_sanitizer(key="Set-Cookie", value="[set-cookie;]")
|
|
27
|
+
add_header_regex_sanitizer(key="Cookie", value="cookie;")
|
|
28
|
+
add_body_key_sanitizer(json_path="$..access_token", value="access_token")
|
|
@@ -6,6 +6,7 @@
|
|
|
6
6
|
# --------------------------------------------------------------------------
|
|
7
7
|
# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
|
|
8
8
|
|
|
9
|
+
import copy
|
|
9
10
|
import calendar
|
|
10
11
|
import decimal
|
|
11
12
|
import functools
|
|
@@ -655,6 +656,12 @@ def _deserialize_sequence(
|
|
|
655
656
|
return obj
|
|
656
657
|
return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
|
|
657
658
|
|
|
659
|
+
def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.Any]:
|
|
660
|
+
return sorted(
|
|
661
|
+
types,
|
|
662
|
+
key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"),
|
|
663
|
+
)
|
|
664
|
+
|
|
658
665
|
def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
|
|
659
666
|
annotation: typing.Any,
|
|
660
667
|
module: typing.Optional[str],
|
|
@@ -696,21 +703,25 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
|
|
|
696
703
|
# is it optional?
|
|
697
704
|
try:
|
|
698
705
|
if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
706
|
+
if len(annotation.__args__) <= 2: # pyright: ignore
|
|
707
|
+
if_obj_deserializer = _get_deserialize_callable_from_annotation(
|
|
708
|
+
next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
|
|
709
|
+
)
|
|
710
|
+
|
|
711
|
+
return functools.partial(_deserialize_with_optional, if_obj_deserializer)
|
|
712
|
+
# the type is Optional[Union[...]], we need to remove the None type from the Union
|
|
713
|
+
annotation_copy = copy.copy(annotation)
|
|
714
|
+
annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore
|
|
715
|
+
return _get_deserialize_callable_from_annotation(annotation_copy, module, rf)
|
|
704
716
|
except AttributeError:
|
|
705
717
|
pass
|
|
706
718
|
|
|
719
|
+
# is it union?
|
|
707
720
|
if getattr(annotation, "__origin__", None) is typing.Union:
|
|
708
721
|
# initial ordering is we make `string` the last deserialization option, because it is often them most generic
|
|
709
722
|
deserializers = [
|
|
710
723
|
_get_deserialize_callable_from_annotation(arg, module, rf)
|
|
711
|
-
for arg in
|
|
712
|
-
annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
|
|
713
|
-
)
|
|
724
|
+
for arg in _sorted_annotations(annotation.__args__) # pyright: ignore
|
|
714
725
|
]
|
|
715
726
|
|
|
716
727
|
return functools.partial(_deserialize_with_union, deserializers)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
{% set prefix_lower = test.prefix|lower %}
|
|
2
|
+
{% set async = "async " if test.is_async else "" %}
|
|
3
|
+
{% set async_suffix = "_async" if test.is_async else "" %}
|
|
4
|
+
# coding=utf-8
|
|
5
|
+
{{ code_model.options['license_header'] }}
|
|
6
|
+
import pytest
|
|
7
|
+
{{ imports }}
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@pytest.mark.skip("you may need to update the auto-generated test case before run it")
|
|
11
|
+
class {{ test.test_class_name }}({{ test.base_test_class_name }}):
|
|
12
|
+
{% for testcase in test.testcases %}
|
|
13
|
+
@{{ test.preparer_name }}()
|
|
14
|
+
@recorded_by_proxy{{ async_suffix }}
|
|
15
|
+
{{ async }}def test_{{ testcase.operation.name }}(self, {{ prefix_lower }}_endpoint):
|
|
16
|
+
client = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
|
|
17
|
+
{{testcase.response }}client{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
|
|
18
|
+
{% for key, value in testcase.params.items() %}
|
|
19
|
+
{{ key }}={{ value|indent(12) }},
|
|
20
|
+
{% endfor %}
|
|
21
|
+
){{ testcase.operation_suffix }}
|
|
22
|
+
{{ testcase.extra_operation }}
|
|
23
|
+
# please add some check logic here by yourself
|
|
24
|
+
# ...
|
|
25
|
+
|
|
26
|
+
{% endfor %}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
{{ code_model.options['license_header'] }}
|
|
3
|
+
{{ imports }}
|
|
4
|
+
|
|
5
|
+
{% for test_name in test_names %}
|
|
6
|
+
{% set extra_async = ", is_async=True" if test_name.is_async else ""%}
|
|
7
|
+
{% set prefix_lower = test_name.prefix|lower %}
|
|
8
|
+
class {{ test_name.base_test_class_name }}(AzureRecordedTestCase):
|
|
9
|
+
|
|
10
|
+
def {{ test_name.create_client_name }}(self, endpoint):
|
|
11
|
+
credential = self.get_credential({{ test_name.client_name }}{{ extra_async }})
|
|
12
|
+
return self.create_client_from_credential(
|
|
13
|
+
{{ test_name.client_name }},
|
|
14
|
+
credential=credential,
|
|
15
|
+
endpoint=endpoint,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
{% if not test_name.is_async %}
|
|
19
|
+
{{ test_name.preparer_name }} = functools.partial(
|
|
20
|
+
PowerShellPreparer,
|
|
21
|
+
"{{ prefix_lower }}",
|
|
22
|
+
{{ prefix_lower }}_endpoint="https://fake_{{ prefix_lower }}_endpoint.com"
|
|
23
|
+
)
|
|
24
|
+
{% endif %}
|
|
25
|
+
|
|
26
|
+
{% endfor %}
|