@autorest/python 6.48.2 → 6.48.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -34,11 +34,7 @@ from .sample_serializer import SampleSerializer
34
34
  from .test_serializer import TestSerializer, TestGeneralSerializer
35
35
  from .types_serializer import TypesSerializer
36
36
  from ...utils import to_snake_case, VALID_PACKAGE_MODE
37
- from .utils import (
38
- extract_sample_name,
39
- get_namespace_from_package_name,
40
- get_namespace_config,
41
- )
37
+ from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import
42
38
 
43
39
  _LOGGER = logging.getLogger(__name__)
44
40
 
@@ -536,39 +532,72 @@ class JinjaSerializer(ReaderAndWriter):
536
532
  def _generated_tests_samples_folder(self, folder_name: str) -> Path:
537
533
  return self._root_of_sdk / folder_name
538
534
 
535
+ def _process_operation_samples(
536
+ self,
537
+ samples: dict,
538
+ env: Environment,
539
+ op_group,
540
+ operation,
541
+ import_sample_cache: dict[tuple[str, str], str],
542
+ out_path: Path,
543
+ sample_additional_folder: Path,
544
+ ) -> None:
545
+ """Process samples for a single operation."""
546
+ for sample_value in samples.values():
547
+ file = sample_value.get("x-ms-original-file", "sample.json")
548
+ file_name = to_snake_case(extract_sample_name(file)) + ".py"
549
+ try:
550
+ sample_ser = SampleSerializer(
551
+ code_model=self.code_model,
552
+ env=env,
553
+ operation_group=op_group,
554
+ operation=operation,
555
+ sample=sample_value,
556
+ file_name=file_name,
557
+ )
558
+ file_import = sample_ser.get_file_import()
559
+ imports_hash_string = hash_file_import(file_import)
560
+ cache_key = (op_group.client.client_namespace, imports_hash_string)
561
+ if cache_key not in import_sample_cache:
562
+ import_sample_cache[cache_key] = sample_ser.get_imports_from_file_import(file_import)
563
+ sample_ser.imports = import_sample_cache[cache_key]
564
+
565
+ content = sample_ser.serialize()
566
+ output_path = out_path / sample_additional_folder / _sample_output_path(file) / file_name
567
+ self.write_file(output_path, content)
568
+ except Exception as e: # pylint: disable=broad-except
569
+ _LOGGER.error("error happens in sample %s: %s", file, e)
570
+
539
571
  def _serialize_and_write_sample(self, env: Environment):
540
572
  out_path = self._generated_tests_samples_folder("generated_samples")
573
+ sample_additional_folder = self.sample_additional_folder
574
+
575
+ # Cache import_test per (client_namespace, imports_hash_string) since it's expensive to compute
576
+ import_sample_cache: dict[tuple[str, str], str] = {}
577
+
541
578
  for client in self.code_model.clients:
542
579
  for op_group in client.operation_groups:
543
580
  for operation in op_group.operations:
544
581
  samples = operation.yaml_data.get("samples")
545
582
  if not samples or operation.name.startswith("_"):
546
583
  continue
547
- for value in samples.values():
548
- file = value.get("x-ms-original-file", "sample.json")
549
- file_name = to_snake_case(extract_sample_name(file)) + ".py"
550
- try:
551
- self.write_file(
552
- out_path / self.sample_additional_folder / _sample_output_path(file) / file_name,
553
- SampleSerializer(
554
- code_model=self.code_model,
555
- env=env,
556
- operation_group=op_group,
557
- operation=operation,
558
- sample=value,
559
- file_name=file_name,
560
- ).serialize(),
561
- )
562
- except Exception as e: # pylint: disable=broad-except
563
- # sample generation shall not block code generation, so just log error
564
- log_error = f"error happens in sample {file}: {e}"
565
- _LOGGER.error(log_error)
584
+ self._process_operation_samples(
585
+ samples,
586
+ env,
587
+ op_group,
588
+ operation,
589
+ import_sample_cache,
590
+ out_path,
591
+ sample_additional_folder,
592
+ )
566
593
 
567
594
  def _serialize_and_write_test(self, env: Environment):
568
595
  self.code_model.for_test = True
569
596
  out_path = self._generated_tests_samples_folder("generated_tests")
597
+
570
598
  general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
571
599
  self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest())
600
+
572
601
  if not self.code_model.options["azure-arm"]:
573
602
  for async_mode in (True, False):
574
603
  async_suffix = "_async" if async_mode else ""
@@ -578,18 +607,24 @@ class JinjaSerializer(ReaderAndWriter):
578
607
  general_serializer.serialize_testpreparer(),
579
608
  )
580
609
 
610
+ # Generate test files - reuse serializer per operation group, toggle async_mode
611
+ # Cache import_test per (client.name, async_mode) since it's expensive to compute
612
+ import_test_cache: dict[tuple[str, bool], str] = {}
581
613
  for client in self.code_model.clients:
582
614
  for og in client.operation_groups:
615
+ # Create serializer once per operation group
583
616
  test_serializer = TestSerializer(self.code_model, env, client=client, operation_group=og)
584
- for async_mode in (True, False):
585
- try:
617
+ try:
618
+ for async_mode in (True, False):
586
619
  test_serializer.async_mode = async_mode
587
- self.write_file(
588
- out_path / f"{to_snake_case(test_serializer.test_class_name)}.py",
589
- test_serializer.serialize_test(),
590
- )
591
- except Exception as e: # pylint: disable=broad-except
592
- # test generation shall not block code generation, so just log error
593
- log_error = f"error happens in test generation for operation group {og.class_name}: {e}"
594
- _LOGGER.error(log_error)
620
+ cache_key = (client.name, async_mode)
621
+ if cache_key not in import_test_cache:
622
+ import_test_cache[cache_key] = test_serializer.get_import_test()
623
+ test_serializer.import_test = import_test_cache[cache_key]
624
+ content = test_serializer.serialize_test()
625
+ output_path = out_path / f"{to_snake_case(test_serializer.test_class_name)}.py"
626
+ self.write_file(output_path, content)
627
+ except Exception as e: # pylint: disable=broad-except
628
+ _LOGGER.error("error happens in test generation for operation group %s: %s", og.class_name, e)
629
+
595
630
  self.code_model.for_test = False
@@ -3,7 +3,6 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from copy import deepcopy
7
6
  from ..models.imports import (
8
7
  ImportType,
9
8
  FileImport,
@@ -76,14 +75,26 @@ class FileImportSerializer:
76
75
  def _get_imports_list(self, baseline_typing_section: TypingSection, add_conditional_typing: bool):
77
76
  # If this is a python 3 file, our regular imports include the CONDITIONAL category
78
77
  # If this is not a python 3 file, our typing imports include the CONDITIONAL category
79
- file_import_copy = deepcopy(self.file_import)
80
- if add_conditional_typing and any(self.file_import.get_imports_from_section(TypingSection.CONDITIONAL)):
81
- # we switch the TypingSection key for the CONDITIONAL typing imports so we can merge
82
- # the imports together
83
- for i in file_import_copy.imports:
84
- if i.typing_section == TypingSection.CONDITIONAL:
85
- i.typing_section = baseline_typing_section
86
- return file_import_copy.get_imports_from_section(baseline_typing_section)
78
+
79
+ # Get imports that already match the baseline_typing_section
80
+ result = list(self.file_import.get_imports_from_section(baseline_typing_section))
81
+
82
+ if add_conditional_typing:
83
+ # Get conditional imports and create new ImportModel instances with modified typing_section
84
+ conditional_imports = self.file_import.get_imports_from_section(TypingSection.CONDITIONAL)
85
+ for i in conditional_imports:
86
+ # Create a new ImportModel with the baseline_typing_section instead of CONDITIONAL
87
+ result.append(
88
+ ImportModel(
89
+ typing_section=baseline_typing_section,
90
+ import_type=i.import_type,
91
+ module_name=i.module_name,
92
+ submodule_name=i.submodule_name,
93
+ alias=i.alias,
94
+ version_modules=i.version_modules,
95
+ )
96
+ )
97
+ return result
87
98
 
88
99
  def _add_type_checking_import(self):
89
100
  if any(self.file_import.get_imports_from_section(TypingSection.TYPING)):
@@ -20,6 +20,7 @@ from ..models import (
20
20
  BodyParameter,
21
21
  FileImport,
22
22
  )
23
+ from .utils import create_fake_value
23
24
 
24
25
  _LOGGER = logging.getLogger(__name__)
25
26
 
@@ -40,8 +41,17 @@ class SampleSerializer(BaseSerializer):
40
41
  self.sample = sample
41
42
  self.file_name = file_name
42
43
  self.sample_params = sample.get("parameters", {})
44
+ self._imports: str = ""
43
45
 
44
- def _imports(self) -> FileImportSerializer:
46
+ @property
47
+ def imports(self) -> str:
48
+ return self._imports
49
+
50
+ @imports.setter
51
+ def imports(self, value: str) -> None:
52
+ self._imports = value
53
+
54
+ def get_file_import(self) -> FileImport:
45
55
  imports = FileImport(self.code_model)
46
56
  client = self.operation_group.client
47
57
  namespace = client.client_namespace
@@ -59,7 +69,12 @@ class SampleSerializer(BaseSerializer):
59
69
  for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
60
70
  if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params:
61
71
  imports.merge(param.type.imports_for_sample())
62
- return FileImportSerializer(imports, True)
72
+
73
+ return imports
74
+
75
+ @staticmethod
76
+ def get_imports_from_file_import(file_import: FileImport) -> str:
77
+ return str(FileImportSerializer(file_import, True))
63
78
 
64
79
  def _client_params(self) -> dict[str, Any]:
65
80
  # client params
@@ -97,19 +112,14 @@ class SampleSerializer(BaseSerializer):
97
112
 
98
113
  # prepare operation parameters
99
114
  def _operation_params(self) -> dict[str, Any]:
100
- params = [
101
- p
102
- for p in (self.operation.parameters.positional + self.operation.parameters.keyword_only)
103
- if not p.client_default_value
104
- ]
105
- failure_info = "fail to find required param named {}"
106
115
  operation_params = {}
107
- for param in params:
108
- if not param.optional:
116
+ for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
117
+ if not param.optional and not param.client_default_value:
109
118
  param_value = self.sample_params.get(param.wire_name)
110
119
  if not param_value:
111
- raise Exception(failure_info.format(param.client_name)) # pylint: disable=broad-exception-raised
112
- operation_params[param.client_name] = self.handle_param(param, param_value)
120
+ operation_params[param.client_name] = create_fake_value(param.type)
121
+ else:
122
+ operation_params[param.client_name] = self.handle_param(param, param_value)
113
123
  return operation_params
114
124
 
115
125
  def _operation_group_name(self) -> str:
@@ -154,7 +164,7 @@ class SampleSerializer(BaseSerializer):
154
164
  operation_params=self._operation_params(),
155
165
  operation_group_name=self._operation_group_name(),
156
166
  operation_name=self._operation_name(),
157
- imports=self._imports(),
167
+ imports=self.imports,
158
168
  client_params=self._client_params(),
159
169
  origin_file=self._origin_file(),
160
170
  return_var=return_var,
@@ -3,7 +3,7 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import Any, Optional
6
+ from typing import Any
7
7
  from jinja2 import Environment
8
8
 
9
9
  from .import_serializer import FileImportSerializer
@@ -14,12 +14,9 @@ from ..models import (
14
14
  OperationGroup,
15
15
  Client,
16
16
  OperationType,
17
- ModelType,
18
- BaseType,
19
- CombinedType,
20
17
  FileImport,
21
18
  )
22
- from .utils import json_dumps_template
19
+ from .utils import create_fake_value
23
20
 
24
21
 
25
22
  def is_lro(operation_type: str) -> bool:
@@ -39,30 +36,15 @@ class TestName:
39
36
  self.code_model = code_model
40
37
  self.client_name = client_name
41
38
  self.async_mode = async_mode
42
-
43
- @property
44
- def async_suffix_capt(self) -> str:
45
- return "Async" if self.async_mode else ""
46
-
47
- @property
48
- def create_client_name(self) -> str:
49
- return "create_async_client" if self.async_mode else "create_client"
50
-
51
- @property
52
- def prefix(self) -> str:
53
- return self.client_name.replace("Client", "")
54
-
55
- @property
56
- def preparer_name(self) -> str:
57
- if self.code_model.options["azure-arm"]:
58
- return "RandomNameResourceGroupPreparer"
59
- return self.prefix + "Preparer"
60
-
61
- @property
62
- def base_test_class_name(self) -> str:
63
- if self.code_model.options["azure-arm"]:
64
- return "AzureMgmtRecordedTestCase"
65
- return f"{self.client_name}TestBase{self.async_suffix_capt}"
39
+ # Pre-compute values for render speed optimization
40
+ self.async_suffix_capt = "Async" if async_mode else ""
41
+ self.create_client_name = "create_async_client" if async_mode else "create_client"
42
+ self.prefix = client_name.replace("Client", "")
43
+ is_azure_arm = code_model.options["azure-arm"]
44
+ self.preparer_name = "RandomNameResourceGroupPreparer" if is_azure_arm else self.prefix + "Preparer"
45
+ self.base_test_class_name = (
46
+ "AzureMgmtRecordedTestCase" if is_azure_arm else f"{client_name}TestBase{self.async_suffix_capt}"
47
+ )
66
48
 
67
49
 
68
50
  class TestCase:
@@ -73,50 +55,52 @@ class TestCase:
73
55
  operation: OperationType,
74
56
  *,
75
57
  async_mode: bool = False,
58
+ is_azure_arm: bool = False,
76
59
  ) -> None:
77
60
  self.operation_groups = operation_groups
78
- self.params = params
79
61
  self.operation = operation
80
62
  self.async_mode = async_mode
81
-
82
- @property
83
- def name(self) -> str:
84
- if self.operation_groups[-1].is_mixin:
85
- return self.operation.name
86
- return "_".join([og.property_name for og in self.operation_groups] + [self.operation.name])
87
-
88
- @property
89
- def operation_group_prefix(self) -> str:
90
- if self.operation_groups[-1].is_mixin:
91
- return ""
92
- return "." + ".".join([og.property_name for og in self.operation_groups])
93
-
94
- @property
95
- def response(self) -> str:
96
- if self.async_mode:
97
- if is_lro(self.operation.operation_type):
98
- return "response = await (await "
99
- if is_common_operation(self.operation.operation_type):
100
- return "response = await "
101
- return "response = "
102
-
103
- @property
104
- def lro_comment(self) -> str:
105
- return " # call '.result()' to poll until service return final result"
106
-
107
- @property
108
- def operation_suffix(self) -> str:
109
- if is_lro(self.operation.operation_type):
110
- extra = ")" if self.async_mode else ""
111
- return f"{extra}.result(){self.lro_comment}"
112
- return ""
113
-
114
- @property
115
- def extra_operation(self) -> str:
116
- if is_paging(self.operation.operation_type):
117
- async_str = "async " if self.async_mode else ""
118
- return f"result = [r {async_str}for r in response]"
119
- return ""
63
+ self.is_azure_arm = is_azure_arm
64
+ # Pre-compute params
65
+ if is_azure_arm:
66
+ self.params = {k: ("resource_group.name" if k == "resource_group_name" else v) for k, v in params.items()}
67
+ else:
68
+ self.params = params
69
+ # Pre-compute name
70
+ if operation_groups[-1].is_mixin:
71
+ self.name = operation.name
72
+ else:
73
+ self.name = "_".join([og.property_name for og in operation_groups] + [operation.name])
74
+ # Pre-compute operation_group_prefix
75
+ if operation_groups[-1].is_mixin:
76
+ self.operation_group_prefix = ""
77
+ else:
78
+ self.operation_group_prefix = "." + ".".join([og.property_name for og in operation_groups])
79
+ # Pre-compute response
80
+ operation_type = operation.operation_type
81
+ if async_mode:
82
+ if is_lro(operation_type):
83
+ self.response = "response = await (await "
84
+ elif is_common_operation(operation_type):
85
+ self.response = "response = await "
86
+ else:
87
+ self.response = "response = "
88
+ else:
89
+ self.response = "response = "
90
+ # Pre-compute lro_comment
91
+ self.lro_comment = " # call '.result()' to poll until service return final result"
92
+ # Pre-compute operation_suffix
93
+ if is_lro(operation_type):
94
+ extra = ")" if async_mode else ""
95
+ self.operation_suffix = f"{extra}.result(){self.lro_comment}"
96
+ else:
97
+ self.operation_suffix = ""
98
+ # Pre-compute extra_operation
99
+ if is_paging(operation_type):
100
+ async_str = "async " if async_mode else ""
101
+ self.extra_operation = f"result = [r {async_str}for r in response]"
102
+ else:
103
+ self.extra_operation = ""
120
104
 
121
105
 
122
106
  class Test(TestName):
@@ -189,9 +173,17 @@ class TestSerializer(TestGeneralSerializer):
189
173
  super().__init__(code_model, env, async_mode=async_mode)
190
174
  self.client = client
191
175
  self.operation_group = operation_group
176
+ self._import_test: str = ""
192
177
 
193
178
  @property
194
- def import_test(self) -> FileImportSerializer:
179
+ def import_test(self) -> str:
180
+ return self._import_test
181
+
182
+ @import_test.setter
183
+ def import_test(self, value: str) -> None:
184
+ self._import_test = value
185
+
186
+ def get_import_test(self) -> str:
195
187
  imports = self.init_file_import()
196
188
  test_name = TestName(self.code_model, self.client.name, async_mode=self.async_mode)
197
189
  async_suffix = "_async" if self.async_mode else ""
@@ -212,7 +204,7 @@ class TestSerializer(TestGeneralSerializer):
212
204
  )
213
205
  if self.code_model.options["azure-arm"]:
214
206
  self.add_import_client(imports)
215
- return FileImportSerializer(imports, self.async_mode)
207
+ return str(FileImportSerializer(imports, self.async_mode))
216
208
 
217
209
  @property
218
210
  def breadth_search_operation_group(self) -> list[list[OperationGroup]]:
@@ -226,26 +218,11 @@ class TestSerializer(TestGeneralSerializer):
226
218
  queue.extend([current + [og] for og in current[-1].operation_groups])
227
219
  return result
228
220
 
229
- def get_sub_type(self, param_type: ModelType) -> ModelType:
230
- if param_type.discriminated_subtypes:
231
- for item in param_type.discriminated_subtypes.values():
232
- return self.get_sub_type(item)
233
- return param_type
234
-
235
- def get_model_type(self, param_type: BaseType) -> Optional[ModelType]:
236
- if isinstance(param_type, ModelType):
237
- return param_type
238
- if isinstance(param_type, CombinedType):
239
- return param_type.target_model_subtype((ModelType,))
240
- return None
241
-
242
221
  def get_operation_params(self, operation: OperationType) -> dict[str, Any]:
243
222
  operation_params = {}
244
223
  required_params = [p for p in operation.parameters.method if not p.optional]
245
224
  for param in required_params:
246
- model_type = self.get_model_type(param.type)
247
- param_type = self.get_sub_type(model_type) if model_type else param.type
248
- operation_params[param.client_name] = json_dumps_template(param_type.get_json_template_representation())
225
+ operation_params[param.client_name] = create_fake_value(param.type)
249
226
  return operation_params
250
227
 
251
228
  def get_test(self) -> Test:
@@ -260,6 +237,7 @@ class TestSerializer(TestGeneralSerializer):
260
237
  params=operation_params,
261
238
  operation=operation,
262
239
  async_mode=self.async_mode,
240
+ is_azure_arm=self.code_model.options["azure-arm"],
263
241
  )
264
242
  testcases.append(testcase)
265
243
  if not testcases:
@@ -283,6 +261,7 @@ class TestSerializer(TestGeneralSerializer):
283
261
  def serialize_test(self) -> str:
284
262
  return self.env.get_template("test.py.jinja2").render(
285
263
  imports=self.import_test,
286
- code_model=self.code_model,
264
+ is_azure_arm=self.code_model.options["azure-arm"],
265
+ license_header=self.code_model.license_header,
287
266
  test=self.get_test(),
288
267
  )
@@ -7,6 +7,15 @@ import json
7
7
  from typing import Optional, Any
8
8
  from pathlib import Path
9
9
 
10
+ from ..models import ModelType, BaseType, CombinedType, FileImport
11
+
12
+
13
+ def get_sub_type(param_type: ModelType) -> ModelType:
14
+ if param_type.discriminated_subtypes:
15
+ for item in param_type.discriminated_subtypes.values():
16
+ return get_sub_type(item)
17
+ return param_type
18
+
10
19
 
11
20
  def method_signature_and_response_type_annotation_template(
12
21
  *,
@@ -52,3 +61,31 @@ def _improve_json_string(template_representation: str) -> Any:
52
61
  def json_dumps_template(template_representation: Any) -> Any:
53
62
  # only for template use, since it wraps everything in strings
54
63
  return _improve_json_string(json.dumps(template_representation, indent=4))
64
+
65
+
66
+ def create_fake_value(param_type: BaseType) -> Any:
67
+ """Create a fake value for a parameter type by getting its JSON template representation.
68
+
69
+ This function generates a fake value suitable for samples and tests.
70
+
71
+ :param param_type: The parameter type to create a fake value for.
72
+ :return: A string representation of the fake value.
73
+ """
74
+
75
+ model_type: Optional[ModelType] = None
76
+ if isinstance(param_type, ModelType):
77
+ model_type = param_type
78
+ elif isinstance(param_type, CombinedType):
79
+ model_type = param_type.target_model_subtype((ModelType,))
80
+ resolved_type = get_sub_type(model_type) if model_type else param_type
81
+ return json_dumps_template(resolved_type.get_json_template_representation())
82
+
83
+
84
+ def hash_file_import(file_import: FileImport) -> str:
85
+ """Generate a hash for a FileImport object based on its imports.
86
+
87
+ :param file_import: The FileImport object to generate a hash for.
88
+ :return: A string representing the hash of the FileImport object.
89
+ """
90
+
91
+ return "".join(sorted({str(hash(i)) for i in file_import.imports}))
@@ -1062,7 +1062,7 @@ def _failsafe_deserialize(
1062
1062
  ) -> typing.Any:
1063
1063
  try:
1064
1064
  return _deserialize(deserializer, response.json(), module, rf, format)
1065
- except DeserializationError:
1065
+ except Exception: # pylint: disable=broad-except
1066
1066
  _LOGGER.warning(
1067
1067
  "Ran into a deserialization error. Ignoring since this is failsafe deserialization",
1068
1068
  exc_info=True
@@ -1072,11 +1072,11 @@ def _failsafe_deserialize(
1072
1072
 
1073
1073
  def _failsafe_deserialize_xml(
1074
1074
  deserializer: typing.Any,
1075
- response: HttpResponse,
1075
+ response: HttpResponse,
1076
1076
  ) -> typing.Any:
1077
1077
  try:
1078
1078
  return _deserialize_xml(deserializer, response.text())
1079
- except DeserializationError:
1079
+ except Exception: # pylint: disable=broad-except
1080
1080
  _LOGGER.warning(
1081
1081
  "Ran into a deserialization error. Ignoring since this is failsafe deserialization",
1082
1082
  exc_info=True
@@ -1,48 +1,51 @@
1
1
  {% set prefix_lower = test.prefix|lower %}
2
- {% set client_var = "self.client" if code_model.options["azure-arm"] else "client" %}
2
+ {% set client_var = "self.client" if is_azure_arm else "client" %}
3
3
  {% set async = "async " if test.async_mode else "" %}
4
4
  {% set async_suffix = "_async" if test.async_mode else "" %}
5
5
  # coding=utf-8
6
- {% if code_model.license_header %}
7
- {{ code_model.license_header }}
6
+ {% if license_header %}
7
+ {{ license_header }}
8
8
  {% endif %}
9
9
  import pytest
10
10
  {{ imports }}
11
11
 
12
- {% if code_model.options["azure-arm"] %}
12
+ {% if is_azure_arm %}
13
13
  AZURE_LOCATION = "eastus"
14
14
  {% endif %}
15
15
 
16
16
  @pytest.mark.skip("you may need to update the auto-generated test case before run it")
17
17
  class {{ test.test_class_name }}({{ test.base_test_class_name }}):
18
- {% if code_model.options["azure-arm"] %}
18
+ {% if is_azure_arm %}
19
19
  def setup_method(self, method):
20
20
  {% if test.async_mode %}
21
21
  self.client = self.create_mgmt_client({{ test.client_name }}, is_async=True)
22
22
  {% else %}
23
23
  self.client = self.create_mgmt_client({{ test.client_name }})
24
24
  {% endif %}
25
- {% endif %}
25
+
26
26
  {% for testcase in test.testcases %}
27
- {% if code_model.options["azure-arm"] %}
28
27
  @{{ test.preparer_name }}(location=AZURE_LOCATION)
29
- {% else %}
30
- @{{ test.preparer_name }}()
31
- {% endif %}
32
28
  @recorded_by_proxy{{ async_suffix }}
33
- {% if code_model.options["azure-arm"] %}
34
29
  {{ async }}def test_{{ testcase.name }}(self, resource_group):
35
- {% else %}
30
+ {{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
31
+ {% for key, value in testcase.params.items() %}
32
+ {{ key }}={{ value }},
33
+ {% endfor %}
34
+ ){{ testcase.operation_suffix }}
35
+ {{ testcase.extra_operation }}
36
+ # please add some check logic here by yourself
37
+ # ...
38
+
39
+ {% endfor %}
40
+ {% else %}
41
+ {% for testcase in test.testcases %}
42
+ @{{ test.preparer_name }}()
43
+ @recorded_by_proxy{{ async_suffix }}
36
44
  {{ async }}def test_{{ testcase.name }}(self, {{ prefix_lower }}_endpoint):
37
45
  {{ client_var }} = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
38
- {% endif %}
39
46
  {{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
40
47
  {% for key, value in testcase.params.items() %}
41
- {% if code_model.options["azure-arm"] and key == "resource_group_name" %}
42
- {{ key }}=resource_group.name,
43
- {% else %}
44
- {{ key }}={{ value|indent(12) }},
45
- {% endif %}
48
+ {{ key }}={{ value }},
46
49
  {% endfor %}
47
50
  ){{ testcase.operation_suffix }}
48
51
  {{ testcase.extra_operation }}
@@ -50,3 +53,4 @@ class {{ test.test_class_name }}({{ test.base_test_class_name }}):
50
53
  # ...
51
54
 
52
55
  {% endfor %}
56
+ {% endif %}
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "version": "1",
3
- "pip_version": "25.3",
3
+ "pip_version": "26.0",
4
4
  "install": [
5
5
  {
6
6
  "download_info": {
@@ -34,11 +34,7 @@ from .sample_serializer import SampleSerializer
34
34
  from .test_serializer import TestSerializer, TestGeneralSerializer
35
35
  from .types_serializer import TypesSerializer
36
36
  from ...utils import to_snake_case, VALID_PACKAGE_MODE
37
- from .utils import (
38
- extract_sample_name,
39
- get_namespace_from_package_name,
40
- get_namespace_config,
41
- )
37
+ from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import
42
38
 
43
39
  _LOGGER = logging.getLogger(__name__)
44
40
 
@@ -536,39 +532,72 @@ class JinjaSerializer(ReaderAndWriter):
536
532
  def _generated_tests_samples_folder(self, folder_name: str) -> Path:
537
533
  return self._root_of_sdk / folder_name
538
534
 
535
+ def _process_operation_samples(
536
+ self,
537
+ samples: dict,
538
+ env: Environment,
539
+ op_group,
540
+ operation,
541
+ import_sample_cache: dict[tuple[str, str], str],
542
+ out_path: Path,
543
+ sample_additional_folder: Path,
544
+ ) -> None:
545
+ """Process samples for a single operation."""
546
+ for sample_value in samples.values():
547
+ file = sample_value.get("x-ms-original-file", "sample.json")
548
+ file_name = to_snake_case(extract_sample_name(file)) + ".py"
549
+ try:
550
+ sample_ser = SampleSerializer(
551
+ code_model=self.code_model,
552
+ env=env,
553
+ operation_group=op_group,
554
+ operation=operation,
555
+ sample=sample_value,
556
+ file_name=file_name,
557
+ )
558
+ file_import = sample_ser.get_file_import()
559
+ imports_hash_string = hash_file_import(file_import)
560
+ cache_key = (op_group.client.client_namespace, imports_hash_string)
561
+ if cache_key not in import_sample_cache:
562
+ import_sample_cache[cache_key] = sample_ser.get_imports_from_file_import(file_import)
563
+ sample_ser.imports = import_sample_cache[cache_key]
564
+
565
+ content = sample_ser.serialize()
566
+ output_path = out_path / sample_additional_folder / _sample_output_path(file) / file_name
567
+ self.write_file(output_path, content)
568
+ except Exception as e: # pylint: disable=broad-except
569
+ _LOGGER.error("error happens in sample %s: %s", file, e)
570
+
539
571
  def _serialize_and_write_sample(self, env: Environment):
540
572
  out_path = self._generated_tests_samples_folder("generated_samples")
573
+ sample_additional_folder = self.sample_additional_folder
574
+
575
+ # Cache import_test per (client_namespace, imports_hash_string) since it's expensive to compute
576
+ import_sample_cache: dict[tuple[str, str], str] = {}
577
+
541
578
  for client in self.code_model.clients:
542
579
  for op_group in client.operation_groups:
543
580
  for operation in op_group.operations:
544
581
  samples = operation.yaml_data.get("samples")
545
582
  if not samples or operation.name.startswith("_"):
546
583
  continue
547
- for value in samples.values():
548
- file = value.get("x-ms-original-file", "sample.json")
549
- file_name = to_snake_case(extract_sample_name(file)) + ".py"
550
- try:
551
- self.write_file(
552
- out_path / self.sample_additional_folder / _sample_output_path(file) / file_name,
553
- SampleSerializer(
554
- code_model=self.code_model,
555
- env=env,
556
- operation_group=op_group,
557
- operation=operation,
558
- sample=value,
559
- file_name=file_name,
560
- ).serialize(),
561
- )
562
- except Exception as e: # pylint: disable=broad-except
563
- # sample generation shall not block code generation, so just log error
564
- log_error = f"error happens in sample {file}: {e}"
565
- _LOGGER.error(log_error)
584
+ self._process_operation_samples(
585
+ samples,
586
+ env,
587
+ op_group,
588
+ operation,
589
+ import_sample_cache,
590
+ out_path,
591
+ sample_additional_folder,
592
+ )
566
593
 
567
594
  def _serialize_and_write_test(self, env: Environment):
568
595
  self.code_model.for_test = True
569
596
  out_path = self._generated_tests_samples_folder("generated_tests")
597
+
570
598
  general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
571
599
  self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest())
600
+
572
601
  if not self.code_model.options["azure-arm"]:
573
602
  for async_mode in (True, False):
574
603
  async_suffix = "_async" if async_mode else ""
@@ -578,18 +607,24 @@ class JinjaSerializer(ReaderAndWriter):
578
607
  general_serializer.serialize_testpreparer(),
579
608
  )
580
609
 
610
+ # Generate test files - reuse serializer per operation group, toggle async_mode
611
+ # Cache import_test per (client.name, async_mode) since it's expensive to compute
612
+ import_test_cache: dict[tuple[str, bool], str] = {}
581
613
  for client in self.code_model.clients:
582
614
  for og in client.operation_groups:
615
+ # Create serializer once per operation group
583
616
  test_serializer = TestSerializer(self.code_model, env, client=client, operation_group=og)
584
- for async_mode in (True, False):
585
- try:
617
+ try:
618
+ for async_mode in (True, False):
586
619
  test_serializer.async_mode = async_mode
587
- self.write_file(
588
- out_path / f"{to_snake_case(test_serializer.test_class_name)}.py",
589
- test_serializer.serialize_test(),
590
- )
591
- except Exception as e: # pylint: disable=broad-except
592
- # test generation shall not block code generation, so just log error
593
- log_error = f"error happens in test generation for operation group {og.class_name}: {e}"
594
- _LOGGER.error(log_error)
620
+ cache_key = (client.name, async_mode)
621
+ if cache_key not in import_test_cache:
622
+ import_test_cache[cache_key] = test_serializer.get_import_test()
623
+ test_serializer.import_test = import_test_cache[cache_key]
624
+ content = test_serializer.serialize_test()
625
+ output_path = out_path / f"{to_snake_case(test_serializer.test_class_name)}.py"
626
+ self.write_file(output_path, content)
627
+ except Exception as e: # pylint: disable=broad-except
628
+ _LOGGER.error("error happens in test generation for operation group %s: %s", og.class_name, e)
629
+
595
630
  self.code_model.for_test = False
@@ -3,7 +3,6 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from copy import deepcopy
7
6
  from ..models.imports import (
8
7
  ImportType,
9
8
  FileImport,
@@ -76,14 +75,26 @@ class FileImportSerializer:
76
75
  def _get_imports_list(self, baseline_typing_section: TypingSection, add_conditional_typing: bool):
77
76
  # If this is a python 3 file, our regular imports include the CONDITIONAL category
78
77
  # If this is not a python 3 file, our typing imports include the CONDITIONAL category
79
- file_import_copy = deepcopy(self.file_import)
80
- if add_conditional_typing and any(self.file_import.get_imports_from_section(TypingSection.CONDITIONAL)):
81
- # we switch the TypingSection key for the CONDITIONAL typing imports so we can merge
82
- # the imports together
83
- for i in file_import_copy.imports:
84
- if i.typing_section == TypingSection.CONDITIONAL:
85
- i.typing_section = baseline_typing_section
86
- return file_import_copy.get_imports_from_section(baseline_typing_section)
78
+
79
+ # Get imports that already match the baseline_typing_section
80
+ result = list(self.file_import.get_imports_from_section(baseline_typing_section))
81
+
82
+ if add_conditional_typing:
83
+ # Get conditional imports and create new ImportModel instances with modified typing_section
84
+ conditional_imports = self.file_import.get_imports_from_section(TypingSection.CONDITIONAL)
85
+ for i in conditional_imports:
86
+ # Create a new ImportModel with the baseline_typing_section instead of CONDITIONAL
87
+ result.append(
88
+ ImportModel(
89
+ typing_section=baseline_typing_section,
90
+ import_type=i.import_type,
91
+ module_name=i.module_name,
92
+ submodule_name=i.submodule_name,
93
+ alias=i.alias,
94
+ version_modules=i.version_modules,
95
+ )
96
+ )
97
+ return result
87
98
 
88
99
  def _add_type_checking_import(self):
89
100
  if any(self.file_import.get_imports_from_section(TypingSection.TYPING)):
@@ -20,6 +20,7 @@ from ..models import (
20
20
  BodyParameter,
21
21
  FileImport,
22
22
  )
23
+ from .utils import create_fake_value
23
24
 
24
25
  _LOGGER = logging.getLogger(__name__)
25
26
 
@@ -40,8 +41,17 @@ class SampleSerializer(BaseSerializer):
40
41
  self.sample = sample
41
42
  self.file_name = file_name
42
43
  self.sample_params = sample.get("parameters", {})
44
+ self._imports: str = ""
43
45
 
44
- def _imports(self) -> FileImportSerializer:
46
+ @property
47
+ def imports(self) -> str:
48
+ return self._imports
49
+
50
+ @imports.setter
51
+ def imports(self, value: str) -> None:
52
+ self._imports = value
53
+
54
+ def get_file_import(self) -> FileImport:
45
55
  imports = FileImport(self.code_model)
46
56
  client = self.operation_group.client
47
57
  namespace = client.client_namespace
@@ -59,7 +69,12 @@ class SampleSerializer(BaseSerializer):
59
69
  for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
60
70
  if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params:
61
71
  imports.merge(param.type.imports_for_sample())
62
- return FileImportSerializer(imports, True)
72
+
73
+ return imports
74
+
75
+ @staticmethod
76
+ def get_imports_from_file_import(file_import: FileImport) -> str:
77
+ return str(FileImportSerializer(file_import, True))
63
78
 
64
79
  def _client_params(self) -> dict[str, Any]:
65
80
  # client params
@@ -97,19 +112,14 @@ class SampleSerializer(BaseSerializer):
97
112
 
98
113
  # prepare operation parameters
99
114
  def _operation_params(self) -> dict[str, Any]:
100
- params = [
101
- p
102
- for p in (self.operation.parameters.positional + self.operation.parameters.keyword_only)
103
- if not p.client_default_value
104
- ]
105
- failure_info = "fail to find required param named {}"
106
115
  operation_params = {}
107
- for param in params:
108
- if not param.optional:
116
+ for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
117
+ if not param.optional and not param.client_default_value:
109
118
  param_value = self.sample_params.get(param.wire_name)
110
119
  if not param_value:
111
- raise Exception(failure_info.format(param.client_name)) # pylint: disable=broad-exception-raised
112
- operation_params[param.client_name] = self.handle_param(param, param_value)
120
+ operation_params[param.client_name] = create_fake_value(param.type)
121
+ else:
122
+ operation_params[param.client_name] = self.handle_param(param, param_value)
113
123
  return operation_params
114
124
 
115
125
  def _operation_group_name(self) -> str:
@@ -154,7 +164,7 @@ class SampleSerializer(BaseSerializer):
154
164
  operation_params=self._operation_params(),
155
165
  operation_group_name=self._operation_group_name(),
156
166
  operation_name=self._operation_name(),
157
- imports=self._imports(),
167
+ imports=self.imports,
158
168
  client_params=self._client_params(),
159
169
  origin_file=self._origin_file(),
160
170
  return_var=return_var,
@@ -3,7 +3,7 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import Any, Optional
6
+ from typing import Any
7
7
  from jinja2 import Environment
8
8
 
9
9
  from .import_serializer import FileImportSerializer
@@ -14,12 +14,9 @@ from ..models import (
14
14
  OperationGroup,
15
15
  Client,
16
16
  OperationType,
17
- ModelType,
18
- BaseType,
19
- CombinedType,
20
17
  FileImport,
21
18
  )
22
- from .utils import json_dumps_template
19
+ from .utils import create_fake_value
23
20
 
24
21
 
25
22
  def is_lro(operation_type: str) -> bool:
@@ -39,30 +36,15 @@ class TestName:
39
36
  self.code_model = code_model
40
37
  self.client_name = client_name
41
38
  self.async_mode = async_mode
42
-
43
- @property
44
- def async_suffix_capt(self) -> str:
45
- return "Async" if self.async_mode else ""
46
-
47
- @property
48
- def create_client_name(self) -> str:
49
- return "create_async_client" if self.async_mode else "create_client"
50
-
51
- @property
52
- def prefix(self) -> str:
53
- return self.client_name.replace("Client", "")
54
-
55
- @property
56
- def preparer_name(self) -> str:
57
- if self.code_model.options["azure-arm"]:
58
- return "RandomNameResourceGroupPreparer"
59
- return self.prefix + "Preparer"
60
-
61
- @property
62
- def base_test_class_name(self) -> str:
63
- if self.code_model.options["azure-arm"]:
64
- return "AzureMgmtRecordedTestCase"
65
- return f"{self.client_name}TestBase{self.async_suffix_capt}"
39
+ # Pre-compute values for render speed optimization
40
+ self.async_suffix_capt = "Async" if async_mode else ""
41
+ self.create_client_name = "create_async_client" if async_mode else "create_client"
42
+ self.prefix = client_name.replace("Client", "")
43
+ is_azure_arm = code_model.options["azure-arm"]
44
+ self.preparer_name = "RandomNameResourceGroupPreparer" if is_azure_arm else self.prefix + "Preparer"
45
+ self.base_test_class_name = (
46
+ "AzureMgmtRecordedTestCase" if is_azure_arm else f"{client_name}TestBase{self.async_suffix_capt}"
47
+ )
66
48
 
67
49
 
68
50
  class TestCase:
@@ -73,50 +55,52 @@ class TestCase:
73
55
  operation: OperationType,
74
56
  *,
75
57
  async_mode: bool = False,
58
+ is_azure_arm: bool = False,
76
59
  ) -> None:
77
60
  self.operation_groups = operation_groups
78
- self.params = params
79
61
  self.operation = operation
80
62
  self.async_mode = async_mode
81
-
82
- @property
83
- def name(self) -> str:
84
- if self.operation_groups[-1].is_mixin:
85
- return self.operation.name
86
- return "_".join([og.property_name for og in self.operation_groups] + [self.operation.name])
87
-
88
- @property
89
- def operation_group_prefix(self) -> str:
90
- if self.operation_groups[-1].is_mixin:
91
- return ""
92
- return "." + ".".join([og.property_name for og in self.operation_groups])
93
-
94
- @property
95
- def response(self) -> str:
96
- if self.async_mode:
97
- if is_lro(self.operation.operation_type):
98
- return "response = await (await "
99
- if is_common_operation(self.operation.operation_type):
100
- return "response = await "
101
- return "response = "
102
-
103
- @property
104
- def lro_comment(self) -> str:
105
- return " # call '.result()' to poll until service return final result"
106
-
107
- @property
108
- def operation_suffix(self) -> str:
109
- if is_lro(self.operation.operation_type):
110
- extra = ")" if self.async_mode else ""
111
- return f"{extra}.result(){self.lro_comment}"
112
- return ""
113
-
114
- @property
115
- def extra_operation(self) -> str:
116
- if is_paging(self.operation.operation_type):
117
- async_str = "async " if self.async_mode else ""
118
- return f"result = [r {async_str}for r in response]"
119
- return ""
63
+ self.is_azure_arm = is_azure_arm
64
+ # Pre-compute params
65
+ if is_azure_arm:
66
+ self.params = {k: ("resource_group.name" if k == "resource_group_name" else v) for k, v in params.items()}
67
+ else:
68
+ self.params = params
69
+ # Pre-compute name
70
+ if operation_groups[-1].is_mixin:
71
+ self.name = operation.name
72
+ else:
73
+ self.name = "_".join([og.property_name for og in operation_groups] + [operation.name])
74
+ # Pre-compute operation_group_prefix
75
+ if operation_groups[-1].is_mixin:
76
+ self.operation_group_prefix = ""
77
+ else:
78
+ self.operation_group_prefix = "." + ".".join([og.property_name for og in operation_groups])
79
+ # Pre-compute response
80
+ operation_type = operation.operation_type
81
+ if async_mode:
82
+ if is_lro(operation_type):
83
+ self.response = "response = await (await "
84
+ elif is_common_operation(operation_type):
85
+ self.response = "response = await "
86
+ else:
87
+ self.response = "response = "
88
+ else:
89
+ self.response = "response = "
90
+ # Pre-compute lro_comment
91
+ self.lro_comment = " # call '.result()' to poll until service return final result"
92
+ # Pre-compute operation_suffix
93
+ if is_lro(operation_type):
94
+ extra = ")" if async_mode else ""
95
+ self.operation_suffix = f"{extra}.result(){self.lro_comment}"
96
+ else:
97
+ self.operation_suffix = ""
98
+ # Pre-compute extra_operation
99
+ if is_paging(operation_type):
100
+ async_str = "async " if async_mode else ""
101
+ self.extra_operation = f"result = [r {async_str}for r in response]"
102
+ else:
103
+ self.extra_operation = ""
120
104
 
121
105
 
122
106
  class Test(TestName):
@@ -189,9 +173,17 @@ class TestSerializer(TestGeneralSerializer):
189
173
  super().__init__(code_model, env, async_mode=async_mode)
190
174
  self.client = client
191
175
  self.operation_group = operation_group
176
+ self._import_test: str = ""
192
177
 
193
178
  @property
194
- def import_test(self) -> FileImportSerializer:
179
+ def import_test(self) -> str:
180
+ return self._import_test
181
+
182
+ @import_test.setter
183
+ def import_test(self, value: str) -> None:
184
+ self._import_test = value
185
+
186
+ def get_import_test(self) -> str:
195
187
  imports = self.init_file_import()
196
188
  test_name = TestName(self.code_model, self.client.name, async_mode=self.async_mode)
197
189
  async_suffix = "_async" if self.async_mode else ""
@@ -212,7 +204,7 @@ class TestSerializer(TestGeneralSerializer):
212
204
  )
213
205
  if self.code_model.options["azure-arm"]:
214
206
  self.add_import_client(imports)
215
- return FileImportSerializer(imports, self.async_mode)
207
+ return str(FileImportSerializer(imports, self.async_mode))
216
208
 
217
209
  @property
218
210
  def breadth_search_operation_group(self) -> list[list[OperationGroup]]:
@@ -226,26 +218,11 @@ class TestSerializer(TestGeneralSerializer):
226
218
  queue.extend([current + [og] for og in current[-1].operation_groups])
227
219
  return result
228
220
 
229
- def get_sub_type(self, param_type: ModelType) -> ModelType:
230
- if param_type.discriminated_subtypes:
231
- for item in param_type.discriminated_subtypes.values():
232
- return self.get_sub_type(item)
233
- return param_type
234
-
235
- def get_model_type(self, param_type: BaseType) -> Optional[ModelType]:
236
- if isinstance(param_type, ModelType):
237
- return param_type
238
- if isinstance(param_type, CombinedType):
239
- return param_type.target_model_subtype((ModelType,))
240
- return None
241
-
242
221
  def get_operation_params(self, operation: OperationType) -> dict[str, Any]:
243
222
  operation_params = {}
244
223
  required_params = [p for p in operation.parameters.method if not p.optional]
245
224
  for param in required_params:
246
- model_type = self.get_model_type(param.type)
247
- param_type = self.get_sub_type(model_type) if model_type else param.type
248
- operation_params[param.client_name] = json_dumps_template(param_type.get_json_template_representation())
225
+ operation_params[param.client_name] = create_fake_value(param.type)
249
226
  return operation_params
250
227
 
251
228
  def get_test(self) -> Test:
@@ -260,6 +237,7 @@ class TestSerializer(TestGeneralSerializer):
260
237
  params=operation_params,
261
238
  operation=operation,
262
239
  async_mode=self.async_mode,
240
+ is_azure_arm=self.code_model.options["azure-arm"],
263
241
  )
264
242
  testcases.append(testcase)
265
243
  if not testcases:
@@ -283,6 +261,7 @@ class TestSerializer(TestGeneralSerializer):
283
261
  def serialize_test(self) -> str:
284
262
  return self.env.get_template("test.py.jinja2").render(
285
263
  imports=self.import_test,
286
- code_model=self.code_model,
264
+ is_azure_arm=self.code_model.options["azure-arm"],
265
+ license_header=self.code_model.license_header,
287
266
  test=self.get_test(),
288
267
  )
@@ -7,6 +7,15 @@ import json
7
7
  from typing import Optional, Any
8
8
  from pathlib import Path
9
9
 
10
+ from ..models import ModelType, BaseType, CombinedType, FileImport
11
+
12
+
13
+ def get_sub_type(param_type: ModelType) -> ModelType:
14
+ if param_type.discriminated_subtypes:
15
+ for item in param_type.discriminated_subtypes.values():
16
+ return get_sub_type(item)
17
+ return param_type
18
+
10
19
 
11
20
  def method_signature_and_response_type_annotation_template(
12
21
  *,
@@ -52,3 +61,31 @@ def _improve_json_string(template_representation: str) -> Any:
52
61
  def json_dumps_template(template_representation: Any) -> Any:
53
62
  # only for template use, since it wraps everything in strings
54
63
  return _improve_json_string(json.dumps(template_representation, indent=4))
64
+
65
+
66
+ def create_fake_value(param_type: BaseType) -> Any:
67
+ """Create a fake value for a parameter type by getting its JSON template representation.
68
+
69
+ This function generates a fake value suitable for samples and tests.
70
+
71
+ :param param_type: The parameter type to create a fake value for.
72
+ :return: A string representation of the fake value.
73
+ """
74
+
75
+ model_type: Optional[ModelType] = None
76
+ if isinstance(param_type, ModelType):
77
+ model_type = param_type
78
+ elif isinstance(param_type, CombinedType):
79
+ model_type = param_type.target_model_subtype((ModelType,))
80
+ resolved_type = get_sub_type(model_type) if model_type else param_type
81
+ return json_dumps_template(resolved_type.get_json_template_representation())
82
+
83
+
84
+ def hash_file_import(file_import: FileImport) -> str:
85
+ """Generate a hash for a FileImport object based on its imports.
86
+
87
+ :param file_import: The FileImport object to generate a hash for.
88
+ :return: A string representing the hash of the FileImport object.
89
+ """
90
+
91
+ return "".join(sorted({str(hash(i)) for i in file_import.imports}))
@@ -1062,7 +1062,7 @@ def _failsafe_deserialize(
1062
1062
  ) -> typing.Any:
1063
1063
  try:
1064
1064
  return _deserialize(deserializer, response.json(), module, rf, format)
1065
- except DeserializationError:
1065
+ except Exception: # pylint: disable=broad-except
1066
1066
  _LOGGER.warning(
1067
1067
  "Ran into a deserialization error. Ignoring since this is failsafe deserialization",
1068
1068
  exc_info=True
@@ -1072,11 +1072,11 @@ def _failsafe_deserialize(
1072
1072
 
1073
1073
  def _failsafe_deserialize_xml(
1074
1074
  deserializer: typing.Any,
1075
- response: HttpResponse,
1075
+ response: HttpResponse,
1076
1076
  ) -> typing.Any:
1077
1077
  try:
1078
1078
  return _deserialize_xml(deserializer, response.text())
1079
- except DeserializationError:
1079
+ except Exception: # pylint: disable=broad-except
1080
1080
  _LOGGER.warning(
1081
1081
  "Ran into a deserialization error. Ignoring since this is failsafe deserialization",
1082
1082
  exc_info=True
@@ -1,48 +1,51 @@
1
1
  {% set prefix_lower = test.prefix|lower %}
2
- {% set client_var = "self.client" if code_model.options["azure-arm"] else "client" %}
2
+ {% set client_var = "self.client" if is_azure_arm else "client" %}
3
3
  {% set async = "async " if test.async_mode else "" %}
4
4
  {% set async_suffix = "_async" if test.async_mode else "" %}
5
5
  # coding=utf-8
6
- {% if code_model.license_header %}
7
- {{ code_model.license_header }}
6
+ {% if license_header %}
7
+ {{ license_header }}
8
8
  {% endif %}
9
9
  import pytest
10
10
  {{ imports }}
11
11
 
12
- {% if code_model.options["azure-arm"] %}
12
+ {% if is_azure_arm %}
13
13
  AZURE_LOCATION = "eastus"
14
14
  {% endif %}
15
15
 
16
16
  @pytest.mark.skip("you may need to update the auto-generated test case before run it")
17
17
  class {{ test.test_class_name }}({{ test.base_test_class_name }}):
18
- {% if code_model.options["azure-arm"] %}
18
+ {% if is_azure_arm %}
19
19
  def setup_method(self, method):
20
20
  {% if test.async_mode %}
21
21
  self.client = self.create_mgmt_client({{ test.client_name }}, is_async=True)
22
22
  {% else %}
23
23
  self.client = self.create_mgmt_client({{ test.client_name }})
24
24
  {% endif %}
25
- {% endif %}
25
+
26
26
  {% for testcase in test.testcases %}
27
- {% if code_model.options["azure-arm"] %}
28
27
  @{{ test.preparer_name }}(location=AZURE_LOCATION)
29
- {% else %}
30
- @{{ test.preparer_name }}()
31
- {% endif %}
32
28
  @recorded_by_proxy{{ async_suffix }}
33
- {% if code_model.options["azure-arm"] %}
34
29
  {{ async }}def test_{{ testcase.name }}(self, resource_group):
35
- {% else %}
30
+ {{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
31
+ {% for key, value in testcase.params.items() %}
32
+ {{ key }}={{ value }},
33
+ {% endfor %}
34
+ ){{ testcase.operation_suffix }}
35
+ {{ testcase.extra_operation }}
36
+ # please add some check logic here by yourself
37
+ # ...
38
+
39
+ {% endfor %}
40
+ {% else %}
41
+ {% for testcase in test.testcases %}
42
+ @{{ test.preparer_name }}()
43
+ @recorded_by_proxy{{ async_suffix }}
36
44
  {{ async }}def test_{{ testcase.name }}(self, {{ prefix_lower }}_endpoint):
37
45
  {{ client_var }} = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
38
- {% endif %}
39
46
  {{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
40
47
  {% for key, value in testcase.params.items() %}
41
- {% if code_model.options["azure-arm"] and key == "resource_group_name" %}
42
- {{ key }}=resource_group.name,
43
- {% else %}
44
- {{ key }}={{ value|indent(12) }},
45
- {% endif %}
48
+ {{ key }}={{ value }},
46
49
  {% endfor %}
47
50
  ){{ testcase.operation_suffix }}
48
51
  {{ testcase.extra_operation }}
@@ -50,3 +53,4 @@ class {{ test.test_class_name }}({{ test.base_test_class_name }}):
50
53
  # ...
51
54
 
52
55
  {% endfor %}
56
+ {% endif %}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.48.2",
3
+ "version": "6.48.3",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {
@@ -19,7 +19,7 @@
19
19
  },
20
20
  "homepage": "https://github.com/Azure/autorest.python/blob/main/README.md",
21
21
  "dependencies": {
22
- "@typespec/http-client-python": "~0.26.2",
22
+ "@typespec/http-client-python": "~0.26.3",
23
23
  "@autorest/system-requirements": "~1.0.2",
24
24
  "fs-extra": "~11.2.0",
25
25
  "tsx": "~4.19.1"