@autorest/python 6.48.2 → 6.49.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/generator/build/lib/pygen/codegen/serializers/__init__.py +69 -34
- package/generator/build/lib/pygen/codegen/serializers/import_serializer.py +20 -9
- package/generator/build/lib/pygen/codegen/serializers/sample_serializer.py +23 -13
- package/generator/build/lib/pygen/codegen/serializers/test_serializer.py +67 -88
- package/generator/build/lib/pygen/codegen/serializers/utils.py +37 -0
- package/generator/build/lib/pygen/codegen/templates/model_base.py.jinja2 +3 -3
- package/generator/build/lib/pygen/codegen/templates/test.py.jinja2 +22 -18
- package/generator/component-detection-pip-report.json +1 -1
- package/generator/dist/pygen-0.1.0-py3-none-any.whl +0 -0
- package/generator/pygen/codegen/serializers/__init__.py +69 -34
- package/generator/pygen/codegen/serializers/import_serializer.py +20 -9
- package/generator/pygen/codegen/serializers/sample_serializer.py +23 -13
- package/generator/pygen/codegen/serializers/test_serializer.py +67 -88
- package/generator/pygen/codegen/serializers/utils.py +37 -0
- package/generator/pygen/codegen/templates/model_base.py.jinja2 +3 -3
- package/generator/pygen/codegen/templates/test.py.jinja2 +22 -18
- package/package.json +2 -2
- package/scripts/__pycache__/venvtools.cpython-310.pyc +0 -0
|
@@ -34,11 +34,7 @@ from .sample_serializer import SampleSerializer
|
|
|
34
34
|
from .test_serializer import TestSerializer, TestGeneralSerializer
|
|
35
35
|
from .types_serializer import TypesSerializer
|
|
36
36
|
from ...utils import to_snake_case, VALID_PACKAGE_MODE
|
|
37
|
-
from .utils import
|
|
38
|
-
extract_sample_name,
|
|
39
|
-
get_namespace_from_package_name,
|
|
40
|
-
get_namespace_config,
|
|
41
|
-
)
|
|
37
|
+
from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import
|
|
42
38
|
|
|
43
39
|
_LOGGER = logging.getLogger(__name__)
|
|
44
40
|
|
|
@@ -536,39 +532,72 @@ class JinjaSerializer(ReaderAndWriter):
|
|
|
536
532
|
def _generated_tests_samples_folder(self, folder_name: str) -> Path:
|
|
537
533
|
return self._root_of_sdk / folder_name
|
|
538
534
|
|
|
535
|
+
def _process_operation_samples(
|
|
536
|
+
self,
|
|
537
|
+
samples: dict,
|
|
538
|
+
env: Environment,
|
|
539
|
+
op_group,
|
|
540
|
+
operation,
|
|
541
|
+
import_sample_cache: dict[tuple[str, str], str],
|
|
542
|
+
out_path: Path,
|
|
543
|
+
sample_additional_folder: Path,
|
|
544
|
+
) -> None:
|
|
545
|
+
"""Process samples for a single operation."""
|
|
546
|
+
for sample_value in samples.values():
|
|
547
|
+
file = sample_value.get("x-ms-original-file", "sample.json")
|
|
548
|
+
file_name = to_snake_case(extract_sample_name(file)) + ".py"
|
|
549
|
+
try:
|
|
550
|
+
sample_ser = SampleSerializer(
|
|
551
|
+
code_model=self.code_model,
|
|
552
|
+
env=env,
|
|
553
|
+
operation_group=op_group,
|
|
554
|
+
operation=operation,
|
|
555
|
+
sample=sample_value,
|
|
556
|
+
file_name=file_name,
|
|
557
|
+
)
|
|
558
|
+
file_import = sample_ser.get_file_import()
|
|
559
|
+
imports_hash_string = hash_file_import(file_import)
|
|
560
|
+
cache_key = (op_group.client.client_namespace, imports_hash_string)
|
|
561
|
+
if cache_key not in import_sample_cache:
|
|
562
|
+
import_sample_cache[cache_key] = sample_ser.get_imports_from_file_import(file_import)
|
|
563
|
+
sample_ser.imports = import_sample_cache[cache_key]
|
|
564
|
+
|
|
565
|
+
content = sample_ser.serialize()
|
|
566
|
+
output_path = out_path / sample_additional_folder / _sample_output_path(file) / file_name
|
|
567
|
+
self.write_file(output_path, content)
|
|
568
|
+
except Exception as e: # pylint: disable=broad-except
|
|
569
|
+
_LOGGER.error("error happens in sample %s: %s", file, e)
|
|
570
|
+
|
|
539
571
|
def _serialize_and_write_sample(self, env: Environment):
|
|
540
572
|
out_path = self._generated_tests_samples_folder("generated_samples")
|
|
573
|
+
sample_additional_folder = self.sample_additional_folder
|
|
574
|
+
|
|
575
|
+
# Cache import_test per (client_namespace, imports_hash_string) since it's expensive to compute
|
|
576
|
+
import_sample_cache: dict[tuple[str, str], str] = {}
|
|
577
|
+
|
|
541
578
|
for client in self.code_model.clients:
|
|
542
579
|
for op_group in client.operation_groups:
|
|
543
580
|
for operation in op_group.operations:
|
|
544
581
|
samples = operation.yaml_data.get("samples")
|
|
545
582
|
if not samples or operation.name.startswith("_"):
|
|
546
583
|
continue
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
operation_group=op_group,
|
|
557
|
-
operation=operation,
|
|
558
|
-
sample=value,
|
|
559
|
-
file_name=file_name,
|
|
560
|
-
).serialize(),
|
|
561
|
-
)
|
|
562
|
-
except Exception as e: # pylint: disable=broad-except
|
|
563
|
-
# sample generation shall not block code generation, so just log error
|
|
564
|
-
log_error = f"error happens in sample {file}: {e}"
|
|
565
|
-
_LOGGER.error(log_error)
|
|
584
|
+
self._process_operation_samples(
|
|
585
|
+
samples,
|
|
586
|
+
env,
|
|
587
|
+
op_group,
|
|
588
|
+
operation,
|
|
589
|
+
import_sample_cache,
|
|
590
|
+
out_path,
|
|
591
|
+
sample_additional_folder,
|
|
592
|
+
)
|
|
566
593
|
|
|
567
594
|
def _serialize_and_write_test(self, env: Environment):
|
|
568
595
|
self.code_model.for_test = True
|
|
569
596
|
out_path = self._generated_tests_samples_folder("generated_tests")
|
|
597
|
+
|
|
570
598
|
general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
|
|
571
599
|
self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest())
|
|
600
|
+
|
|
572
601
|
if not self.code_model.options["azure-arm"]:
|
|
573
602
|
for async_mode in (True, False):
|
|
574
603
|
async_suffix = "_async" if async_mode else ""
|
|
@@ -578,18 +607,24 @@ class JinjaSerializer(ReaderAndWriter):
|
|
|
578
607
|
general_serializer.serialize_testpreparer(),
|
|
579
608
|
)
|
|
580
609
|
|
|
610
|
+
# Generate test files - reuse serializer per operation group, toggle async_mode
|
|
611
|
+
# Cache import_test per (client.name, async_mode) since it's expensive to compute
|
|
612
|
+
import_test_cache: dict[tuple[str, bool], str] = {}
|
|
581
613
|
for client in self.code_model.clients:
|
|
582
614
|
for og in client.operation_groups:
|
|
615
|
+
# Create serializer once per operation group
|
|
583
616
|
test_serializer = TestSerializer(self.code_model, env, client=client, operation_group=og)
|
|
584
|
-
|
|
585
|
-
|
|
617
|
+
try:
|
|
618
|
+
for async_mode in (True, False):
|
|
586
619
|
test_serializer.async_mode = async_mode
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
test_serializer.
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
620
|
+
cache_key = (client.name, async_mode)
|
|
621
|
+
if cache_key not in import_test_cache:
|
|
622
|
+
import_test_cache[cache_key] = test_serializer.get_import_test()
|
|
623
|
+
test_serializer.import_test = import_test_cache[cache_key]
|
|
624
|
+
content = test_serializer.serialize_test()
|
|
625
|
+
output_path = out_path / f"{to_snake_case(test_serializer.test_class_name)}.py"
|
|
626
|
+
self.write_file(output_path, content)
|
|
627
|
+
except Exception as e: # pylint: disable=broad-except
|
|
628
|
+
_LOGGER.error("error happens in test generation for operation group %s: %s", og.class_name, e)
|
|
629
|
+
|
|
595
630
|
self.code_model.for_test = False
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
4
|
# license information.
|
|
5
5
|
# --------------------------------------------------------------------------
|
|
6
|
-
from copy import deepcopy
|
|
7
6
|
from ..models.imports import (
|
|
8
7
|
ImportType,
|
|
9
8
|
FileImport,
|
|
@@ -76,14 +75,26 @@ class FileImportSerializer:
|
|
|
76
75
|
def _get_imports_list(self, baseline_typing_section: TypingSection, add_conditional_typing: bool):
|
|
77
76
|
# If this is a python 3 file, our regular imports include the CONDITIONAL category
|
|
78
77
|
# If this is not a python 3 file, our typing imports include the CONDITIONAL category
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
78
|
+
|
|
79
|
+
# Get imports that already match the baseline_typing_section
|
|
80
|
+
result = list(self.file_import.get_imports_from_section(baseline_typing_section))
|
|
81
|
+
|
|
82
|
+
if add_conditional_typing:
|
|
83
|
+
# Get conditional imports and create new ImportModel instances with modified typing_section
|
|
84
|
+
conditional_imports = self.file_import.get_imports_from_section(TypingSection.CONDITIONAL)
|
|
85
|
+
for i in conditional_imports:
|
|
86
|
+
# Create a new ImportModel with the baseline_typing_section instead of CONDITIONAL
|
|
87
|
+
result.append(
|
|
88
|
+
ImportModel(
|
|
89
|
+
typing_section=baseline_typing_section,
|
|
90
|
+
import_type=i.import_type,
|
|
91
|
+
module_name=i.module_name,
|
|
92
|
+
submodule_name=i.submodule_name,
|
|
93
|
+
alias=i.alias,
|
|
94
|
+
version_modules=i.version_modules,
|
|
95
|
+
)
|
|
96
|
+
)
|
|
97
|
+
return result
|
|
87
98
|
|
|
88
99
|
def _add_type_checking_import(self):
|
|
89
100
|
if any(self.file_import.get_imports_from_section(TypingSection.TYPING)):
|
|
@@ -20,6 +20,7 @@ from ..models import (
|
|
|
20
20
|
BodyParameter,
|
|
21
21
|
FileImport,
|
|
22
22
|
)
|
|
23
|
+
from .utils import create_fake_value
|
|
23
24
|
|
|
24
25
|
_LOGGER = logging.getLogger(__name__)
|
|
25
26
|
|
|
@@ -40,8 +41,17 @@ class SampleSerializer(BaseSerializer):
|
|
|
40
41
|
self.sample = sample
|
|
41
42
|
self.file_name = file_name
|
|
42
43
|
self.sample_params = sample.get("parameters", {})
|
|
44
|
+
self._imports: str = ""
|
|
43
45
|
|
|
44
|
-
|
|
46
|
+
@property
|
|
47
|
+
def imports(self) -> str:
|
|
48
|
+
return self._imports
|
|
49
|
+
|
|
50
|
+
@imports.setter
|
|
51
|
+
def imports(self, value: str) -> None:
|
|
52
|
+
self._imports = value
|
|
53
|
+
|
|
54
|
+
def get_file_import(self) -> FileImport:
|
|
45
55
|
imports = FileImport(self.code_model)
|
|
46
56
|
client = self.operation_group.client
|
|
47
57
|
namespace = client.client_namespace
|
|
@@ -59,7 +69,12 @@ class SampleSerializer(BaseSerializer):
|
|
|
59
69
|
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
|
|
60
70
|
if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params:
|
|
61
71
|
imports.merge(param.type.imports_for_sample())
|
|
62
|
-
|
|
72
|
+
|
|
73
|
+
return imports
|
|
74
|
+
|
|
75
|
+
@staticmethod
|
|
76
|
+
def get_imports_from_file_import(file_import: FileImport) -> str:
|
|
77
|
+
return str(FileImportSerializer(file_import, True))
|
|
63
78
|
|
|
64
79
|
def _client_params(self) -> dict[str, Any]:
|
|
65
80
|
# client params
|
|
@@ -97,19 +112,14 @@ class SampleSerializer(BaseSerializer):
|
|
|
97
112
|
|
|
98
113
|
# prepare operation parameters
|
|
99
114
|
def _operation_params(self) -> dict[str, Any]:
|
|
100
|
-
params = [
|
|
101
|
-
p
|
|
102
|
-
for p in (self.operation.parameters.positional + self.operation.parameters.keyword_only)
|
|
103
|
-
if not p.client_default_value
|
|
104
|
-
]
|
|
105
|
-
failure_info = "fail to find required param named {}"
|
|
106
115
|
operation_params = {}
|
|
107
|
-
for param in
|
|
108
|
-
if not param.optional:
|
|
116
|
+
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
|
|
117
|
+
if not param.optional and not param.client_default_value:
|
|
109
118
|
param_value = self.sample_params.get(param.wire_name)
|
|
110
119
|
if not param_value:
|
|
111
|
-
|
|
112
|
-
|
|
120
|
+
operation_params[param.client_name] = create_fake_value(param.type)
|
|
121
|
+
else:
|
|
122
|
+
operation_params[param.client_name] = self.handle_param(param, param_value)
|
|
113
123
|
return operation_params
|
|
114
124
|
|
|
115
125
|
def _operation_group_name(self) -> str:
|
|
@@ -154,7 +164,7 @@ class SampleSerializer(BaseSerializer):
|
|
|
154
164
|
operation_params=self._operation_params(),
|
|
155
165
|
operation_group_name=self._operation_group_name(),
|
|
156
166
|
operation_name=self._operation_name(),
|
|
157
|
-
imports=self.
|
|
167
|
+
imports=self.imports,
|
|
158
168
|
client_params=self._client_params(),
|
|
159
169
|
origin_file=self._origin_file(),
|
|
160
170
|
return_var=return_var,
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
4
|
# license information.
|
|
5
5
|
# --------------------------------------------------------------------------
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
from jinja2 import Environment
|
|
8
8
|
|
|
9
9
|
from .import_serializer import FileImportSerializer
|
|
@@ -14,12 +14,9 @@ from ..models import (
|
|
|
14
14
|
OperationGroup,
|
|
15
15
|
Client,
|
|
16
16
|
OperationType,
|
|
17
|
-
ModelType,
|
|
18
|
-
BaseType,
|
|
19
|
-
CombinedType,
|
|
20
17
|
FileImport,
|
|
21
18
|
)
|
|
22
|
-
from .utils import
|
|
19
|
+
from .utils import create_fake_value
|
|
23
20
|
|
|
24
21
|
|
|
25
22
|
def is_lro(operation_type: str) -> bool:
|
|
@@ -39,30 +36,15 @@ class TestName:
|
|
|
39
36
|
self.code_model = code_model
|
|
40
37
|
self.client_name = client_name
|
|
41
38
|
self.async_mode = async_mode
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
@property
|
|
52
|
-
def prefix(self) -> str:
|
|
53
|
-
return self.client_name.replace("Client", "")
|
|
54
|
-
|
|
55
|
-
@property
|
|
56
|
-
def preparer_name(self) -> str:
|
|
57
|
-
if self.code_model.options["azure-arm"]:
|
|
58
|
-
return "RandomNameResourceGroupPreparer"
|
|
59
|
-
return self.prefix + "Preparer"
|
|
60
|
-
|
|
61
|
-
@property
|
|
62
|
-
def base_test_class_name(self) -> str:
|
|
63
|
-
if self.code_model.options["azure-arm"]:
|
|
64
|
-
return "AzureMgmtRecordedTestCase"
|
|
65
|
-
return f"{self.client_name}TestBase{self.async_suffix_capt}"
|
|
39
|
+
# Pre-compute values for render speed optimization
|
|
40
|
+
self.async_suffix_capt = "Async" if async_mode else ""
|
|
41
|
+
self.create_client_name = "create_async_client" if async_mode else "create_client"
|
|
42
|
+
self.prefix = client_name.replace("Client", "")
|
|
43
|
+
is_azure_arm = code_model.options["azure-arm"]
|
|
44
|
+
self.preparer_name = "RandomNameResourceGroupPreparer" if is_azure_arm else self.prefix + "Preparer"
|
|
45
|
+
self.base_test_class_name = (
|
|
46
|
+
"AzureMgmtRecordedTestCase" if is_azure_arm else f"{client_name}TestBase{self.async_suffix_capt}"
|
|
47
|
+
)
|
|
66
48
|
|
|
67
49
|
|
|
68
50
|
class TestCase:
|
|
@@ -73,50 +55,52 @@ class TestCase:
|
|
|
73
55
|
operation: OperationType,
|
|
74
56
|
*,
|
|
75
57
|
async_mode: bool = False,
|
|
58
|
+
is_azure_arm: bool = False,
|
|
76
59
|
) -> None:
|
|
77
60
|
self.operation_groups = operation_groups
|
|
78
|
-
self.params = params
|
|
79
61
|
self.operation = operation
|
|
80
62
|
self.async_mode = async_mode
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
63
|
+
self.is_azure_arm = is_azure_arm
|
|
64
|
+
# Pre-compute params
|
|
65
|
+
if is_azure_arm:
|
|
66
|
+
self.params = {k: ("resource_group.name" if k == "resource_group_name" else v) for k, v in params.items()}
|
|
67
|
+
else:
|
|
68
|
+
self.params = params
|
|
69
|
+
# Pre-compute name
|
|
70
|
+
if operation_groups[-1].is_mixin:
|
|
71
|
+
self.name = operation.name
|
|
72
|
+
else:
|
|
73
|
+
self.name = "_".join([og.property_name for og in operation_groups] + [operation.name])
|
|
74
|
+
# Pre-compute operation_group_prefix
|
|
75
|
+
if operation_groups[-1].is_mixin:
|
|
76
|
+
self.operation_group_prefix = ""
|
|
77
|
+
else:
|
|
78
|
+
self.operation_group_prefix = "." + ".".join([og.property_name for og in operation_groups])
|
|
79
|
+
# Pre-compute response
|
|
80
|
+
operation_type = operation.operation_type
|
|
81
|
+
if async_mode:
|
|
82
|
+
if is_lro(operation_type):
|
|
83
|
+
self.response = "response = await (await "
|
|
84
|
+
elif is_common_operation(operation_type):
|
|
85
|
+
self.response = "response = await "
|
|
86
|
+
else:
|
|
87
|
+
self.response = "response = "
|
|
88
|
+
else:
|
|
89
|
+
self.response = "response = "
|
|
90
|
+
# Pre-compute lro_comment
|
|
91
|
+
self.lro_comment = " # call '.result()' to poll until service return final result"
|
|
92
|
+
# Pre-compute operation_suffix
|
|
93
|
+
if is_lro(operation_type):
|
|
94
|
+
extra = ")" if async_mode else ""
|
|
95
|
+
self.operation_suffix = f"{extra}.result(){self.lro_comment}"
|
|
96
|
+
else:
|
|
97
|
+
self.operation_suffix = ""
|
|
98
|
+
# Pre-compute extra_operation
|
|
99
|
+
if is_paging(operation_type):
|
|
100
|
+
async_str = "async " if async_mode else ""
|
|
101
|
+
self.extra_operation = f"result = [r {async_str}for r in response]"
|
|
102
|
+
else:
|
|
103
|
+
self.extra_operation = ""
|
|
120
104
|
|
|
121
105
|
|
|
122
106
|
class Test(TestName):
|
|
@@ -189,9 +173,17 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
189
173
|
super().__init__(code_model, env, async_mode=async_mode)
|
|
190
174
|
self.client = client
|
|
191
175
|
self.operation_group = operation_group
|
|
176
|
+
self._import_test: str = ""
|
|
192
177
|
|
|
193
178
|
@property
|
|
194
|
-
def import_test(self) ->
|
|
179
|
+
def import_test(self) -> str:
|
|
180
|
+
return self._import_test
|
|
181
|
+
|
|
182
|
+
@import_test.setter
|
|
183
|
+
def import_test(self, value: str) -> None:
|
|
184
|
+
self._import_test = value
|
|
185
|
+
|
|
186
|
+
def get_import_test(self) -> str:
|
|
195
187
|
imports = self.init_file_import()
|
|
196
188
|
test_name = TestName(self.code_model, self.client.name, async_mode=self.async_mode)
|
|
197
189
|
async_suffix = "_async" if self.async_mode else ""
|
|
@@ -212,7 +204,7 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
212
204
|
)
|
|
213
205
|
if self.code_model.options["azure-arm"]:
|
|
214
206
|
self.add_import_client(imports)
|
|
215
|
-
return FileImportSerializer(imports, self.async_mode)
|
|
207
|
+
return str(FileImportSerializer(imports, self.async_mode))
|
|
216
208
|
|
|
217
209
|
@property
|
|
218
210
|
def breadth_search_operation_group(self) -> list[list[OperationGroup]]:
|
|
@@ -226,26 +218,11 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
226
218
|
queue.extend([current + [og] for og in current[-1].operation_groups])
|
|
227
219
|
return result
|
|
228
220
|
|
|
229
|
-
def get_sub_type(self, param_type: ModelType) -> ModelType:
|
|
230
|
-
if param_type.discriminated_subtypes:
|
|
231
|
-
for item in param_type.discriminated_subtypes.values():
|
|
232
|
-
return self.get_sub_type(item)
|
|
233
|
-
return param_type
|
|
234
|
-
|
|
235
|
-
def get_model_type(self, param_type: BaseType) -> Optional[ModelType]:
|
|
236
|
-
if isinstance(param_type, ModelType):
|
|
237
|
-
return param_type
|
|
238
|
-
if isinstance(param_type, CombinedType):
|
|
239
|
-
return param_type.target_model_subtype((ModelType,))
|
|
240
|
-
return None
|
|
241
|
-
|
|
242
221
|
def get_operation_params(self, operation: OperationType) -> dict[str, Any]:
|
|
243
222
|
operation_params = {}
|
|
244
223
|
required_params = [p for p in operation.parameters.method if not p.optional]
|
|
245
224
|
for param in required_params:
|
|
246
|
-
|
|
247
|
-
param_type = self.get_sub_type(model_type) if model_type else param.type
|
|
248
|
-
operation_params[param.client_name] = json_dumps_template(param_type.get_json_template_representation())
|
|
225
|
+
operation_params[param.client_name] = create_fake_value(param.type)
|
|
249
226
|
return operation_params
|
|
250
227
|
|
|
251
228
|
def get_test(self) -> Test:
|
|
@@ -260,6 +237,7 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
260
237
|
params=operation_params,
|
|
261
238
|
operation=operation,
|
|
262
239
|
async_mode=self.async_mode,
|
|
240
|
+
is_azure_arm=self.code_model.options["azure-arm"],
|
|
263
241
|
)
|
|
264
242
|
testcases.append(testcase)
|
|
265
243
|
if not testcases:
|
|
@@ -283,6 +261,7 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
283
261
|
def serialize_test(self) -> str:
|
|
284
262
|
return self.env.get_template("test.py.jinja2").render(
|
|
285
263
|
imports=self.import_test,
|
|
286
|
-
|
|
264
|
+
is_azure_arm=self.code_model.options["azure-arm"],
|
|
265
|
+
license_header=self.code_model.license_header,
|
|
287
266
|
test=self.get_test(),
|
|
288
267
|
)
|
|
@@ -7,6 +7,15 @@ import json
|
|
|
7
7
|
from typing import Optional, Any
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
|
|
10
|
+
from ..models import ModelType, BaseType, CombinedType, FileImport
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_sub_type(param_type: ModelType) -> ModelType:
|
|
14
|
+
if param_type.discriminated_subtypes:
|
|
15
|
+
for item in param_type.discriminated_subtypes.values():
|
|
16
|
+
return get_sub_type(item)
|
|
17
|
+
return param_type
|
|
18
|
+
|
|
10
19
|
|
|
11
20
|
def method_signature_and_response_type_annotation_template(
|
|
12
21
|
*,
|
|
@@ -52,3 +61,31 @@ def _improve_json_string(template_representation: str) -> Any:
|
|
|
52
61
|
def json_dumps_template(template_representation: Any) -> Any:
|
|
53
62
|
# only for template use, since it wraps everything in strings
|
|
54
63
|
return _improve_json_string(json.dumps(template_representation, indent=4))
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def create_fake_value(param_type: BaseType) -> Any:
|
|
67
|
+
"""Create a fake value for a parameter type by getting its JSON template representation.
|
|
68
|
+
|
|
69
|
+
This function generates a fake value suitable for samples and tests.
|
|
70
|
+
|
|
71
|
+
:param param_type: The parameter type to create a fake value for.
|
|
72
|
+
:return: A string representation of the fake value.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
model_type: Optional[ModelType] = None
|
|
76
|
+
if isinstance(param_type, ModelType):
|
|
77
|
+
model_type = param_type
|
|
78
|
+
elif isinstance(param_type, CombinedType):
|
|
79
|
+
model_type = param_type.target_model_subtype((ModelType,))
|
|
80
|
+
resolved_type = get_sub_type(model_type) if model_type else param_type
|
|
81
|
+
return json_dumps_template(resolved_type.get_json_template_representation())
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def hash_file_import(file_import: FileImport) -> str:
|
|
85
|
+
"""Generate a hash for a FileImport object based on its imports.
|
|
86
|
+
|
|
87
|
+
:param file_import: The FileImport object to generate a hash for.
|
|
88
|
+
:return: A string representing the hash of the FileImport object.
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
return "".join(sorted({str(hash(i)) for i in file_import.imports}))
|
|
@@ -1062,7 +1062,7 @@ def _failsafe_deserialize(
|
|
|
1062
1062
|
) -> typing.Any:
|
|
1063
1063
|
try:
|
|
1064
1064
|
return _deserialize(deserializer, response.json(), module, rf, format)
|
|
1065
|
-
except
|
|
1065
|
+
except Exception: # pylint: disable=broad-except
|
|
1066
1066
|
_LOGGER.warning(
|
|
1067
1067
|
"Ran into a deserialization error. Ignoring since this is failsafe deserialization",
|
|
1068
1068
|
exc_info=True
|
|
@@ -1072,11 +1072,11 @@ def _failsafe_deserialize(
|
|
|
1072
1072
|
|
|
1073
1073
|
def _failsafe_deserialize_xml(
|
|
1074
1074
|
deserializer: typing.Any,
|
|
1075
|
-
response: HttpResponse,
|
|
1075
|
+
response: HttpResponse,
|
|
1076
1076
|
) -> typing.Any:
|
|
1077
1077
|
try:
|
|
1078
1078
|
return _deserialize_xml(deserializer, response.text())
|
|
1079
|
-
except
|
|
1079
|
+
except Exception: # pylint: disable=broad-except
|
|
1080
1080
|
_LOGGER.warning(
|
|
1081
1081
|
"Ran into a deserialization error. Ignoring since this is failsafe deserialization",
|
|
1082
1082
|
exc_info=True
|
|
@@ -1,48 +1,51 @@
|
|
|
1
1
|
{% set prefix_lower = test.prefix|lower %}
|
|
2
|
-
{% set client_var = "self.client" if
|
|
2
|
+
{% set client_var = "self.client" if is_azure_arm else "client" %}
|
|
3
3
|
{% set async = "async " if test.async_mode else "" %}
|
|
4
4
|
{% set async_suffix = "_async" if test.async_mode else "" %}
|
|
5
5
|
# coding=utf-8
|
|
6
|
-
{% if
|
|
7
|
-
{{
|
|
6
|
+
{% if license_header %}
|
|
7
|
+
{{ license_header }}
|
|
8
8
|
{% endif %}
|
|
9
9
|
import pytest
|
|
10
10
|
{{ imports }}
|
|
11
11
|
|
|
12
|
-
{% if
|
|
12
|
+
{% if is_azure_arm %}
|
|
13
13
|
AZURE_LOCATION = "eastus"
|
|
14
14
|
{% endif %}
|
|
15
15
|
|
|
16
16
|
@pytest.mark.skip("you may need to update the auto-generated test case before run it")
|
|
17
17
|
class {{ test.test_class_name }}({{ test.base_test_class_name }}):
|
|
18
|
-
{% if
|
|
18
|
+
{% if is_azure_arm %}
|
|
19
19
|
def setup_method(self, method):
|
|
20
20
|
{% if test.async_mode %}
|
|
21
21
|
self.client = self.create_mgmt_client({{ test.client_name }}, is_async=True)
|
|
22
22
|
{% else %}
|
|
23
23
|
self.client = self.create_mgmt_client({{ test.client_name }})
|
|
24
24
|
{% endif %}
|
|
25
|
-
|
|
25
|
+
|
|
26
26
|
{% for testcase in test.testcases %}
|
|
27
|
-
{% if code_model.options["azure-arm"] %}
|
|
28
27
|
@{{ test.preparer_name }}(location=AZURE_LOCATION)
|
|
29
|
-
{% else %}
|
|
30
|
-
@{{ test.preparer_name }}()
|
|
31
|
-
{% endif %}
|
|
32
28
|
@recorded_by_proxy{{ async_suffix }}
|
|
33
|
-
{% if code_model.options["azure-arm"] %}
|
|
34
29
|
{{ async }}def test_{{ testcase.name }}(self, resource_group):
|
|
35
|
-
|
|
30
|
+
{{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
|
|
31
|
+
{% for key, value in testcase.params.items() %}
|
|
32
|
+
{{ key }}={{ value }},
|
|
33
|
+
{% endfor %}
|
|
34
|
+
){{ testcase.operation_suffix }}
|
|
35
|
+
{{ testcase.extra_operation }}
|
|
36
|
+
# please add some check logic here by yourself
|
|
37
|
+
# ...
|
|
38
|
+
|
|
39
|
+
{% endfor %}
|
|
40
|
+
{% else %}
|
|
41
|
+
{% for testcase in test.testcases %}
|
|
42
|
+
@{{ test.preparer_name }}()
|
|
43
|
+
@recorded_by_proxy{{ async_suffix }}
|
|
36
44
|
{{ async }}def test_{{ testcase.name }}(self, {{ prefix_lower }}_endpoint):
|
|
37
45
|
{{ client_var }} = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
|
|
38
|
-
{% endif %}
|
|
39
46
|
{{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
|
|
40
47
|
{% for key, value in testcase.params.items() %}
|
|
41
|
-
{
|
|
42
|
-
{{ key }}=resource_group.name,
|
|
43
|
-
{% else %}
|
|
44
|
-
{{ key }}={{ value|indent(12) }},
|
|
45
|
-
{% endif %}
|
|
48
|
+
{{ key }}={{ value }},
|
|
46
49
|
{% endfor %}
|
|
47
50
|
){{ testcase.operation_suffix }}
|
|
48
51
|
{{ testcase.extra_operation }}
|
|
@@ -50,3 +53,4 @@ class {{ test.test_class_name }}({{ test.base_test_class_name }}):
|
|
|
50
53
|
# ...
|
|
51
54
|
|
|
52
55
|
{% endfor %}
|
|
56
|
+
{% endif %}
|
|
Binary file
|
|
@@ -34,11 +34,7 @@ from .sample_serializer import SampleSerializer
|
|
|
34
34
|
from .test_serializer import TestSerializer, TestGeneralSerializer
|
|
35
35
|
from .types_serializer import TypesSerializer
|
|
36
36
|
from ...utils import to_snake_case, VALID_PACKAGE_MODE
|
|
37
|
-
from .utils import
|
|
38
|
-
extract_sample_name,
|
|
39
|
-
get_namespace_from_package_name,
|
|
40
|
-
get_namespace_config,
|
|
41
|
-
)
|
|
37
|
+
from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import
|
|
42
38
|
|
|
43
39
|
_LOGGER = logging.getLogger(__name__)
|
|
44
40
|
|
|
@@ -536,39 +532,72 @@ class JinjaSerializer(ReaderAndWriter):
|
|
|
536
532
|
def _generated_tests_samples_folder(self, folder_name: str) -> Path:
|
|
537
533
|
return self._root_of_sdk / folder_name
|
|
538
534
|
|
|
535
|
+
def _process_operation_samples(
|
|
536
|
+
self,
|
|
537
|
+
samples: dict,
|
|
538
|
+
env: Environment,
|
|
539
|
+
op_group,
|
|
540
|
+
operation,
|
|
541
|
+
import_sample_cache: dict[tuple[str, str], str],
|
|
542
|
+
out_path: Path,
|
|
543
|
+
sample_additional_folder: Path,
|
|
544
|
+
) -> None:
|
|
545
|
+
"""Process samples for a single operation."""
|
|
546
|
+
for sample_value in samples.values():
|
|
547
|
+
file = sample_value.get("x-ms-original-file", "sample.json")
|
|
548
|
+
file_name = to_snake_case(extract_sample_name(file)) + ".py"
|
|
549
|
+
try:
|
|
550
|
+
sample_ser = SampleSerializer(
|
|
551
|
+
code_model=self.code_model,
|
|
552
|
+
env=env,
|
|
553
|
+
operation_group=op_group,
|
|
554
|
+
operation=operation,
|
|
555
|
+
sample=sample_value,
|
|
556
|
+
file_name=file_name,
|
|
557
|
+
)
|
|
558
|
+
file_import = sample_ser.get_file_import()
|
|
559
|
+
imports_hash_string = hash_file_import(file_import)
|
|
560
|
+
cache_key = (op_group.client.client_namespace, imports_hash_string)
|
|
561
|
+
if cache_key not in import_sample_cache:
|
|
562
|
+
import_sample_cache[cache_key] = sample_ser.get_imports_from_file_import(file_import)
|
|
563
|
+
sample_ser.imports = import_sample_cache[cache_key]
|
|
564
|
+
|
|
565
|
+
content = sample_ser.serialize()
|
|
566
|
+
output_path = out_path / sample_additional_folder / _sample_output_path(file) / file_name
|
|
567
|
+
self.write_file(output_path, content)
|
|
568
|
+
except Exception as e: # pylint: disable=broad-except
|
|
569
|
+
_LOGGER.error("error happens in sample %s: %s", file, e)
|
|
570
|
+
|
|
539
571
|
def _serialize_and_write_sample(self, env: Environment):
|
|
540
572
|
out_path = self._generated_tests_samples_folder("generated_samples")
|
|
573
|
+
sample_additional_folder = self.sample_additional_folder
|
|
574
|
+
|
|
575
|
+
# Cache import_test per (client_namespace, imports_hash_string) since it's expensive to compute
|
|
576
|
+
import_sample_cache: dict[tuple[str, str], str] = {}
|
|
577
|
+
|
|
541
578
|
for client in self.code_model.clients:
|
|
542
579
|
for op_group in client.operation_groups:
|
|
543
580
|
for operation in op_group.operations:
|
|
544
581
|
samples = operation.yaml_data.get("samples")
|
|
545
582
|
if not samples or operation.name.startswith("_"):
|
|
546
583
|
continue
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
operation_group=op_group,
|
|
557
|
-
operation=operation,
|
|
558
|
-
sample=value,
|
|
559
|
-
file_name=file_name,
|
|
560
|
-
).serialize(),
|
|
561
|
-
)
|
|
562
|
-
except Exception as e: # pylint: disable=broad-except
|
|
563
|
-
# sample generation shall not block code generation, so just log error
|
|
564
|
-
log_error = f"error happens in sample {file}: {e}"
|
|
565
|
-
_LOGGER.error(log_error)
|
|
584
|
+
self._process_operation_samples(
|
|
585
|
+
samples,
|
|
586
|
+
env,
|
|
587
|
+
op_group,
|
|
588
|
+
operation,
|
|
589
|
+
import_sample_cache,
|
|
590
|
+
out_path,
|
|
591
|
+
sample_additional_folder,
|
|
592
|
+
)
|
|
566
593
|
|
|
567
594
|
def _serialize_and_write_test(self, env: Environment):
|
|
568
595
|
self.code_model.for_test = True
|
|
569
596
|
out_path = self._generated_tests_samples_folder("generated_tests")
|
|
597
|
+
|
|
570
598
|
general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
|
|
571
599
|
self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest())
|
|
600
|
+
|
|
572
601
|
if not self.code_model.options["azure-arm"]:
|
|
573
602
|
for async_mode in (True, False):
|
|
574
603
|
async_suffix = "_async" if async_mode else ""
|
|
@@ -578,18 +607,24 @@ class JinjaSerializer(ReaderAndWriter):
|
|
|
578
607
|
general_serializer.serialize_testpreparer(),
|
|
579
608
|
)
|
|
580
609
|
|
|
610
|
+
# Generate test files - reuse serializer per operation group, toggle async_mode
|
|
611
|
+
# Cache import_test per (client.name, async_mode) since it's expensive to compute
|
|
612
|
+
import_test_cache: dict[tuple[str, bool], str] = {}
|
|
581
613
|
for client in self.code_model.clients:
|
|
582
614
|
for og in client.operation_groups:
|
|
615
|
+
# Create serializer once per operation group
|
|
583
616
|
test_serializer = TestSerializer(self.code_model, env, client=client, operation_group=og)
|
|
584
|
-
|
|
585
|
-
|
|
617
|
+
try:
|
|
618
|
+
for async_mode in (True, False):
|
|
586
619
|
test_serializer.async_mode = async_mode
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
test_serializer.
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
620
|
+
cache_key = (client.name, async_mode)
|
|
621
|
+
if cache_key not in import_test_cache:
|
|
622
|
+
import_test_cache[cache_key] = test_serializer.get_import_test()
|
|
623
|
+
test_serializer.import_test = import_test_cache[cache_key]
|
|
624
|
+
content = test_serializer.serialize_test()
|
|
625
|
+
output_path = out_path / f"{to_snake_case(test_serializer.test_class_name)}.py"
|
|
626
|
+
self.write_file(output_path, content)
|
|
627
|
+
except Exception as e: # pylint: disable=broad-except
|
|
628
|
+
_LOGGER.error("error happens in test generation for operation group %s: %s", og.class_name, e)
|
|
629
|
+
|
|
595
630
|
self.code_model.for_test = False
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
4
|
# license information.
|
|
5
5
|
# --------------------------------------------------------------------------
|
|
6
|
-
from copy import deepcopy
|
|
7
6
|
from ..models.imports import (
|
|
8
7
|
ImportType,
|
|
9
8
|
FileImport,
|
|
@@ -76,14 +75,26 @@ class FileImportSerializer:
|
|
|
76
75
|
def _get_imports_list(self, baseline_typing_section: TypingSection, add_conditional_typing: bool):
|
|
77
76
|
# If this is a python 3 file, our regular imports include the CONDITIONAL category
|
|
78
77
|
# If this is not a python 3 file, our typing imports include the CONDITIONAL category
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
78
|
+
|
|
79
|
+
# Get imports that already match the baseline_typing_section
|
|
80
|
+
result = list(self.file_import.get_imports_from_section(baseline_typing_section))
|
|
81
|
+
|
|
82
|
+
if add_conditional_typing:
|
|
83
|
+
# Get conditional imports and create new ImportModel instances with modified typing_section
|
|
84
|
+
conditional_imports = self.file_import.get_imports_from_section(TypingSection.CONDITIONAL)
|
|
85
|
+
for i in conditional_imports:
|
|
86
|
+
# Create a new ImportModel with the baseline_typing_section instead of CONDITIONAL
|
|
87
|
+
result.append(
|
|
88
|
+
ImportModel(
|
|
89
|
+
typing_section=baseline_typing_section,
|
|
90
|
+
import_type=i.import_type,
|
|
91
|
+
module_name=i.module_name,
|
|
92
|
+
submodule_name=i.submodule_name,
|
|
93
|
+
alias=i.alias,
|
|
94
|
+
version_modules=i.version_modules,
|
|
95
|
+
)
|
|
96
|
+
)
|
|
97
|
+
return result
|
|
87
98
|
|
|
88
99
|
def _add_type_checking_import(self):
|
|
89
100
|
if any(self.file_import.get_imports_from_section(TypingSection.TYPING)):
|
|
@@ -20,6 +20,7 @@ from ..models import (
|
|
|
20
20
|
BodyParameter,
|
|
21
21
|
FileImport,
|
|
22
22
|
)
|
|
23
|
+
from .utils import create_fake_value
|
|
23
24
|
|
|
24
25
|
_LOGGER = logging.getLogger(__name__)
|
|
25
26
|
|
|
@@ -40,8 +41,17 @@ class SampleSerializer(BaseSerializer):
|
|
|
40
41
|
self.sample = sample
|
|
41
42
|
self.file_name = file_name
|
|
42
43
|
self.sample_params = sample.get("parameters", {})
|
|
44
|
+
self._imports: str = ""
|
|
43
45
|
|
|
44
|
-
|
|
46
|
+
@property
|
|
47
|
+
def imports(self) -> str:
|
|
48
|
+
return self._imports
|
|
49
|
+
|
|
50
|
+
@imports.setter
|
|
51
|
+
def imports(self, value: str) -> None:
|
|
52
|
+
self._imports = value
|
|
53
|
+
|
|
54
|
+
def get_file_import(self) -> FileImport:
|
|
45
55
|
imports = FileImport(self.code_model)
|
|
46
56
|
client = self.operation_group.client
|
|
47
57
|
namespace = client.client_namespace
|
|
@@ -59,7 +69,12 @@ class SampleSerializer(BaseSerializer):
|
|
|
59
69
|
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
|
|
60
70
|
if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params:
|
|
61
71
|
imports.merge(param.type.imports_for_sample())
|
|
62
|
-
|
|
72
|
+
|
|
73
|
+
return imports
|
|
74
|
+
|
|
75
|
+
@staticmethod
|
|
76
|
+
def get_imports_from_file_import(file_import: FileImport) -> str:
|
|
77
|
+
return str(FileImportSerializer(file_import, True))
|
|
63
78
|
|
|
64
79
|
def _client_params(self) -> dict[str, Any]:
|
|
65
80
|
# client params
|
|
@@ -97,19 +112,14 @@ class SampleSerializer(BaseSerializer):
|
|
|
97
112
|
|
|
98
113
|
# prepare operation parameters
|
|
99
114
|
def _operation_params(self) -> dict[str, Any]:
|
|
100
|
-
params = [
|
|
101
|
-
p
|
|
102
|
-
for p in (self.operation.parameters.positional + self.operation.parameters.keyword_only)
|
|
103
|
-
if not p.client_default_value
|
|
104
|
-
]
|
|
105
|
-
failure_info = "fail to find required param named {}"
|
|
106
115
|
operation_params = {}
|
|
107
|
-
for param in
|
|
108
|
-
if not param.optional:
|
|
116
|
+
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
|
|
117
|
+
if not param.optional and not param.client_default_value:
|
|
109
118
|
param_value = self.sample_params.get(param.wire_name)
|
|
110
119
|
if not param_value:
|
|
111
|
-
|
|
112
|
-
|
|
120
|
+
operation_params[param.client_name] = create_fake_value(param.type)
|
|
121
|
+
else:
|
|
122
|
+
operation_params[param.client_name] = self.handle_param(param, param_value)
|
|
113
123
|
return operation_params
|
|
114
124
|
|
|
115
125
|
def _operation_group_name(self) -> str:
|
|
@@ -154,7 +164,7 @@ class SampleSerializer(BaseSerializer):
|
|
|
154
164
|
operation_params=self._operation_params(),
|
|
155
165
|
operation_group_name=self._operation_group_name(),
|
|
156
166
|
operation_name=self._operation_name(),
|
|
157
|
-
imports=self.
|
|
167
|
+
imports=self.imports,
|
|
158
168
|
client_params=self._client_params(),
|
|
159
169
|
origin_file=self._origin_file(),
|
|
160
170
|
return_var=return_var,
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
4
|
# license information.
|
|
5
5
|
# --------------------------------------------------------------------------
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
from jinja2 import Environment
|
|
8
8
|
|
|
9
9
|
from .import_serializer import FileImportSerializer
|
|
@@ -14,12 +14,9 @@ from ..models import (
|
|
|
14
14
|
OperationGroup,
|
|
15
15
|
Client,
|
|
16
16
|
OperationType,
|
|
17
|
-
ModelType,
|
|
18
|
-
BaseType,
|
|
19
|
-
CombinedType,
|
|
20
17
|
FileImport,
|
|
21
18
|
)
|
|
22
|
-
from .utils import
|
|
19
|
+
from .utils import create_fake_value
|
|
23
20
|
|
|
24
21
|
|
|
25
22
|
def is_lro(operation_type: str) -> bool:
|
|
@@ -39,30 +36,15 @@ class TestName:
|
|
|
39
36
|
self.code_model = code_model
|
|
40
37
|
self.client_name = client_name
|
|
41
38
|
self.async_mode = async_mode
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
@property
|
|
52
|
-
def prefix(self) -> str:
|
|
53
|
-
return self.client_name.replace("Client", "")
|
|
54
|
-
|
|
55
|
-
@property
|
|
56
|
-
def preparer_name(self) -> str:
|
|
57
|
-
if self.code_model.options["azure-arm"]:
|
|
58
|
-
return "RandomNameResourceGroupPreparer"
|
|
59
|
-
return self.prefix + "Preparer"
|
|
60
|
-
|
|
61
|
-
@property
|
|
62
|
-
def base_test_class_name(self) -> str:
|
|
63
|
-
if self.code_model.options["azure-arm"]:
|
|
64
|
-
return "AzureMgmtRecordedTestCase"
|
|
65
|
-
return f"{self.client_name}TestBase{self.async_suffix_capt}"
|
|
39
|
+
# Pre-compute values for render speed optimization
|
|
40
|
+
self.async_suffix_capt = "Async" if async_mode else ""
|
|
41
|
+
self.create_client_name = "create_async_client" if async_mode else "create_client"
|
|
42
|
+
self.prefix = client_name.replace("Client", "")
|
|
43
|
+
is_azure_arm = code_model.options["azure-arm"]
|
|
44
|
+
self.preparer_name = "RandomNameResourceGroupPreparer" if is_azure_arm else self.prefix + "Preparer"
|
|
45
|
+
self.base_test_class_name = (
|
|
46
|
+
"AzureMgmtRecordedTestCase" if is_azure_arm else f"{client_name}TestBase{self.async_suffix_capt}"
|
|
47
|
+
)
|
|
66
48
|
|
|
67
49
|
|
|
68
50
|
class TestCase:
|
|
@@ -73,50 +55,52 @@ class TestCase:
|
|
|
73
55
|
operation: OperationType,
|
|
74
56
|
*,
|
|
75
57
|
async_mode: bool = False,
|
|
58
|
+
is_azure_arm: bool = False,
|
|
76
59
|
) -> None:
|
|
77
60
|
self.operation_groups = operation_groups
|
|
78
|
-
self.params = params
|
|
79
61
|
self.operation = operation
|
|
80
62
|
self.async_mode = async_mode
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
63
|
+
self.is_azure_arm = is_azure_arm
|
|
64
|
+
# Pre-compute params
|
|
65
|
+
if is_azure_arm:
|
|
66
|
+
self.params = {k: ("resource_group.name" if k == "resource_group_name" else v) for k, v in params.items()}
|
|
67
|
+
else:
|
|
68
|
+
self.params = params
|
|
69
|
+
# Pre-compute name
|
|
70
|
+
if operation_groups[-1].is_mixin:
|
|
71
|
+
self.name = operation.name
|
|
72
|
+
else:
|
|
73
|
+
self.name = "_".join([og.property_name for og in operation_groups] + [operation.name])
|
|
74
|
+
# Pre-compute operation_group_prefix
|
|
75
|
+
if operation_groups[-1].is_mixin:
|
|
76
|
+
self.operation_group_prefix = ""
|
|
77
|
+
else:
|
|
78
|
+
self.operation_group_prefix = "." + ".".join([og.property_name for og in operation_groups])
|
|
79
|
+
# Pre-compute response
|
|
80
|
+
operation_type = operation.operation_type
|
|
81
|
+
if async_mode:
|
|
82
|
+
if is_lro(operation_type):
|
|
83
|
+
self.response = "response = await (await "
|
|
84
|
+
elif is_common_operation(operation_type):
|
|
85
|
+
self.response = "response = await "
|
|
86
|
+
else:
|
|
87
|
+
self.response = "response = "
|
|
88
|
+
else:
|
|
89
|
+
self.response = "response = "
|
|
90
|
+
# Pre-compute lro_comment
|
|
91
|
+
self.lro_comment = " # call '.result()' to poll until service return final result"
|
|
92
|
+
# Pre-compute operation_suffix
|
|
93
|
+
if is_lro(operation_type):
|
|
94
|
+
extra = ")" if async_mode else ""
|
|
95
|
+
self.operation_suffix = f"{extra}.result(){self.lro_comment}"
|
|
96
|
+
else:
|
|
97
|
+
self.operation_suffix = ""
|
|
98
|
+
# Pre-compute extra_operation
|
|
99
|
+
if is_paging(operation_type):
|
|
100
|
+
async_str = "async " if async_mode else ""
|
|
101
|
+
self.extra_operation = f"result = [r {async_str}for r in response]"
|
|
102
|
+
else:
|
|
103
|
+
self.extra_operation = ""
|
|
120
104
|
|
|
121
105
|
|
|
122
106
|
class Test(TestName):
|
|
@@ -189,9 +173,17 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
189
173
|
super().__init__(code_model, env, async_mode=async_mode)
|
|
190
174
|
self.client = client
|
|
191
175
|
self.operation_group = operation_group
|
|
176
|
+
self._import_test: str = ""
|
|
192
177
|
|
|
193
178
|
@property
|
|
194
|
-
def import_test(self) ->
|
|
179
|
+
def import_test(self) -> str:
|
|
180
|
+
return self._import_test
|
|
181
|
+
|
|
182
|
+
@import_test.setter
|
|
183
|
+
def import_test(self, value: str) -> None:
|
|
184
|
+
self._import_test = value
|
|
185
|
+
|
|
186
|
+
def get_import_test(self) -> str:
|
|
195
187
|
imports = self.init_file_import()
|
|
196
188
|
test_name = TestName(self.code_model, self.client.name, async_mode=self.async_mode)
|
|
197
189
|
async_suffix = "_async" if self.async_mode else ""
|
|
@@ -212,7 +204,7 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
212
204
|
)
|
|
213
205
|
if self.code_model.options["azure-arm"]:
|
|
214
206
|
self.add_import_client(imports)
|
|
215
|
-
return FileImportSerializer(imports, self.async_mode)
|
|
207
|
+
return str(FileImportSerializer(imports, self.async_mode))
|
|
216
208
|
|
|
217
209
|
@property
|
|
218
210
|
def breadth_search_operation_group(self) -> list[list[OperationGroup]]:
|
|
@@ -226,26 +218,11 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
226
218
|
queue.extend([current + [og] for og in current[-1].operation_groups])
|
|
227
219
|
return result
|
|
228
220
|
|
|
229
|
-
def get_sub_type(self, param_type: ModelType) -> ModelType:
|
|
230
|
-
if param_type.discriminated_subtypes:
|
|
231
|
-
for item in param_type.discriminated_subtypes.values():
|
|
232
|
-
return self.get_sub_type(item)
|
|
233
|
-
return param_type
|
|
234
|
-
|
|
235
|
-
def get_model_type(self, param_type: BaseType) -> Optional[ModelType]:
|
|
236
|
-
if isinstance(param_type, ModelType):
|
|
237
|
-
return param_type
|
|
238
|
-
if isinstance(param_type, CombinedType):
|
|
239
|
-
return param_type.target_model_subtype((ModelType,))
|
|
240
|
-
return None
|
|
241
|
-
|
|
242
221
|
def get_operation_params(self, operation: OperationType) -> dict[str, Any]:
|
|
243
222
|
operation_params = {}
|
|
244
223
|
required_params = [p for p in operation.parameters.method if not p.optional]
|
|
245
224
|
for param in required_params:
|
|
246
|
-
|
|
247
|
-
param_type = self.get_sub_type(model_type) if model_type else param.type
|
|
248
|
-
operation_params[param.client_name] = json_dumps_template(param_type.get_json_template_representation())
|
|
225
|
+
operation_params[param.client_name] = create_fake_value(param.type)
|
|
249
226
|
return operation_params
|
|
250
227
|
|
|
251
228
|
def get_test(self) -> Test:
|
|
@@ -260,6 +237,7 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
260
237
|
params=operation_params,
|
|
261
238
|
operation=operation,
|
|
262
239
|
async_mode=self.async_mode,
|
|
240
|
+
is_azure_arm=self.code_model.options["azure-arm"],
|
|
263
241
|
)
|
|
264
242
|
testcases.append(testcase)
|
|
265
243
|
if not testcases:
|
|
@@ -283,6 +261,7 @@ class TestSerializer(TestGeneralSerializer):
|
|
|
283
261
|
def serialize_test(self) -> str:
|
|
284
262
|
return self.env.get_template("test.py.jinja2").render(
|
|
285
263
|
imports=self.import_test,
|
|
286
|
-
|
|
264
|
+
is_azure_arm=self.code_model.options["azure-arm"],
|
|
265
|
+
license_header=self.code_model.license_header,
|
|
287
266
|
test=self.get_test(),
|
|
288
267
|
)
|
|
@@ -7,6 +7,15 @@ import json
|
|
|
7
7
|
from typing import Optional, Any
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
|
|
10
|
+
from ..models import ModelType, BaseType, CombinedType, FileImport
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_sub_type(param_type: ModelType) -> ModelType:
|
|
14
|
+
if param_type.discriminated_subtypes:
|
|
15
|
+
for item in param_type.discriminated_subtypes.values():
|
|
16
|
+
return get_sub_type(item)
|
|
17
|
+
return param_type
|
|
18
|
+
|
|
10
19
|
|
|
11
20
|
def method_signature_and_response_type_annotation_template(
|
|
12
21
|
*,
|
|
@@ -52,3 +61,31 @@ def _improve_json_string(template_representation: str) -> Any:
|
|
|
52
61
|
def json_dumps_template(template_representation: Any) -> Any:
|
|
53
62
|
# only for template use, since it wraps everything in strings
|
|
54
63
|
return _improve_json_string(json.dumps(template_representation, indent=4))
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def create_fake_value(param_type: BaseType) -> Any:
|
|
67
|
+
"""Create a fake value for a parameter type by getting its JSON template representation.
|
|
68
|
+
|
|
69
|
+
This function generates a fake value suitable for samples and tests.
|
|
70
|
+
|
|
71
|
+
:param param_type: The parameter type to create a fake value for.
|
|
72
|
+
:return: A string representation of the fake value.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
model_type: Optional[ModelType] = None
|
|
76
|
+
if isinstance(param_type, ModelType):
|
|
77
|
+
model_type = param_type
|
|
78
|
+
elif isinstance(param_type, CombinedType):
|
|
79
|
+
model_type = param_type.target_model_subtype((ModelType,))
|
|
80
|
+
resolved_type = get_sub_type(model_type) if model_type else param_type
|
|
81
|
+
return json_dumps_template(resolved_type.get_json_template_representation())
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def hash_file_import(file_import: FileImport) -> str:
|
|
85
|
+
"""Generate a hash for a FileImport object based on its imports.
|
|
86
|
+
|
|
87
|
+
:param file_import: The FileImport object to generate a hash for.
|
|
88
|
+
:return: A string representing the hash of the FileImport object.
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
return "".join(sorted({str(hash(i)) for i in file_import.imports}))
|
|
@@ -1062,7 +1062,7 @@ def _failsafe_deserialize(
|
|
|
1062
1062
|
) -> typing.Any:
|
|
1063
1063
|
try:
|
|
1064
1064
|
return _deserialize(deserializer, response.json(), module, rf, format)
|
|
1065
|
-
except
|
|
1065
|
+
except Exception: # pylint: disable=broad-except
|
|
1066
1066
|
_LOGGER.warning(
|
|
1067
1067
|
"Ran into a deserialization error. Ignoring since this is failsafe deserialization",
|
|
1068
1068
|
exc_info=True
|
|
@@ -1072,11 +1072,11 @@ def _failsafe_deserialize(
|
|
|
1072
1072
|
|
|
1073
1073
|
def _failsafe_deserialize_xml(
|
|
1074
1074
|
deserializer: typing.Any,
|
|
1075
|
-
response: HttpResponse,
|
|
1075
|
+
response: HttpResponse,
|
|
1076
1076
|
) -> typing.Any:
|
|
1077
1077
|
try:
|
|
1078
1078
|
return _deserialize_xml(deserializer, response.text())
|
|
1079
|
-
except
|
|
1079
|
+
except Exception: # pylint: disable=broad-except
|
|
1080
1080
|
_LOGGER.warning(
|
|
1081
1081
|
"Ran into a deserialization error. Ignoring since this is failsafe deserialization",
|
|
1082
1082
|
exc_info=True
|
|
@@ -1,48 +1,51 @@
|
|
|
1
1
|
{% set prefix_lower = test.prefix|lower %}
|
|
2
|
-
{% set client_var = "self.client" if
|
|
2
|
+
{% set client_var = "self.client" if is_azure_arm else "client" %}
|
|
3
3
|
{% set async = "async " if test.async_mode else "" %}
|
|
4
4
|
{% set async_suffix = "_async" if test.async_mode else "" %}
|
|
5
5
|
# coding=utf-8
|
|
6
|
-
{% if
|
|
7
|
-
{{
|
|
6
|
+
{% if license_header %}
|
|
7
|
+
{{ license_header }}
|
|
8
8
|
{% endif %}
|
|
9
9
|
import pytest
|
|
10
10
|
{{ imports }}
|
|
11
11
|
|
|
12
|
-
{% if
|
|
12
|
+
{% if is_azure_arm %}
|
|
13
13
|
AZURE_LOCATION = "eastus"
|
|
14
14
|
{% endif %}
|
|
15
15
|
|
|
16
16
|
@pytest.mark.skip("you may need to update the auto-generated test case before run it")
|
|
17
17
|
class {{ test.test_class_name }}({{ test.base_test_class_name }}):
|
|
18
|
-
{% if
|
|
18
|
+
{% if is_azure_arm %}
|
|
19
19
|
def setup_method(self, method):
|
|
20
20
|
{% if test.async_mode %}
|
|
21
21
|
self.client = self.create_mgmt_client({{ test.client_name }}, is_async=True)
|
|
22
22
|
{% else %}
|
|
23
23
|
self.client = self.create_mgmt_client({{ test.client_name }})
|
|
24
24
|
{% endif %}
|
|
25
|
-
|
|
25
|
+
|
|
26
26
|
{% for testcase in test.testcases %}
|
|
27
|
-
{% if code_model.options["azure-arm"] %}
|
|
28
27
|
@{{ test.preparer_name }}(location=AZURE_LOCATION)
|
|
29
|
-
{% else %}
|
|
30
|
-
@{{ test.preparer_name }}()
|
|
31
|
-
{% endif %}
|
|
32
28
|
@recorded_by_proxy{{ async_suffix }}
|
|
33
|
-
{% if code_model.options["azure-arm"] %}
|
|
34
29
|
{{ async }}def test_{{ testcase.name }}(self, resource_group):
|
|
35
|
-
|
|
30
|
+
{{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
|
|
31
|
+
{% for key, value in testcase.params.items() %}
|
|
32
|
+
{{ key }}={{ value }},
|
|
33
|
+
{% endfor %}
|
|
34
|
+
){{ testcase.operation_suffix }}
|
|
35
|
+
{{ testcase.extra_operation }}
|
|
36
|
+
# please add some check logic here by yourself
|
|
37
|
+
# ...
|
|
38
|
+
|
|
39
|
+
{% endfor %}
|
|
40
|
+
{% else %}
|
|
41
|
+
{% for testcase in test.testcases %}
|
|
42
|
+
@{{ test.preparer_name }}()
|
|
43
|
+
@recorded_by_proxy{{ async_suffix }}
|
|
36
44
|
{{ async }}def test_{{ testcase.name }}(self, {{ prefix_lower }}_endpoint):
|
|
37
45
|
{{ client_var }} = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
|
|
38
|
-
{% endif %}
|
|
39
46
|
{{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
|
|
40
47
|
{% for key, value in testcase.params.items() %}
|
|
41
|
-
{
|
|
42
|
-
{{ key }}=resource_group.name,
|
|
43
|
-
{% else %}
|
|
44
|
-
{{ key }}={{ value|indent(12) }},
|
|
45
|
-
{% endif %}
|
|
48
|
+
{{ key }}={{ value }},
|
|
46
49
|
{% endfor %}
|
|
47
50
|
){{ testcase.operation_suffix }}
|
|
48
51
|
{{ testcase.extra_operation }}
|
|
@@ -50,3 +53,4 @@ class {{ test.test_class_name }}({{ test.base_test_class_name }}):
|
|
|
50
53
|
# ...
|
|
51
54
|
|
|
52
55
|
{% endfor %}
|
|
56
|
+
{% endif %}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@autorest/python",
|
|
3
|
-
"version": "6.
|
|
3
|
+
"version": "6.49.0",
|
|
4
4
|
"description": "The Python extension for generators in AutoRest.",
|
|
5
5
|
"main": "index.js",
|
|
6
6
|
"repository": {
|
|
@@ -19,7 +19,7 @@
|
|
|
19
19
|
},
|
|
20
20
|
"homepage": "https://github.com/Azure/autorest.python/blob/main/README.md",
|
|
21
21
|
"dependencies": {
|
|
22
|
-
"@typespec/http-client-python": "~0.
|
|
22
|
+
"@typespec/http-client-python": "~0.27.0",
|
|
23
23
|
"@autorest/system-requirements": "~1.0.2",
|
|
24
24
|
"fs-extra": "~11.2.0",
|
|
25
25
|
"tsx": "~4.19.1"
|
|
Binary file
|