@autorest/python 6.2.4 → 6.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/autorest/black/__init__.py +3 -2
  2. package/autorest/codegen/__init__.py +1 -1
  3. package/autorest/codegen/models/base.py +14 -2
  4. package/autorest/codegen/models/client.py +20 -22
  5. package/autorest/codegen/models/combined_type.py +1 -1
  6. package/autorest/codegen/models/imports.py +2 -2
  7. package/autorest/codegen/models/model_type.py +4 -0
  8. package/autorest/codegen/models/operation.py +6 -6
  9. package/autorest/codegen/models/operation_group.py +12 -11
  10. package/autorest/codegen/models/primitive_types.py +50 -0
  11. package/autorest/codegen/models/property.py +4 -0
  12. package/autorest/codegen/models/request_builder.py +9 -7
  13. package/autorest/codegen/serializers/__init__.py +7 -6
  14. package/autorest/codegen/serializers/builder_serializer.py +67 -30
  15. package/autorest/codegen/serializers/client_serializer.py +22 -8
  16. package/autorest/codegen/serializers/metadata_serializer.py +7 -1
  17. package/autorest/codegen/serializers/model_serializer.py +12 -13
  18. package/autorest/codegen/serializers/operation_groups_serializer.py +1 -0
  19. package/autorest/codegen/serializers/parameter_serializer.py +3 -3
  20. package/autorest/codegen/serializers/patch_serializer.py +2 -4
  21. package/autorest/codegen/serializers/sample_serializer.py +23 -14
  22. package/autorest/codegen/serializers/utils.py +6 -0
  23. package/autorest/codegen/templates/client.py.jinja2 +3 -12
  24. package/autorest/codegen/templates/config.py.jinja2 +2 -5
  25. package/autorest/codegen/templates/keywords.jinja2 +2 -2
  26. package/autorest/codegen/templates/metadata.json.jinja2 +2 -2
  27. package/autorest/codegen/templates/model_base.py.jinja2 +171 -130
  28. package/autorest/codegen/templates/model_dpg.py.jinja2 +1 -1
  29. package/autorest/codegen/templates/packaging_templates/setup.py.jinja2 +1 -0
  30. package/autorest/codegen/templates/request_builder.py.jinja2 +1 -1
  31. package/autorest/codegen/templates/serialization.py.jinja2 +286 -325
  32. package/autorest/jsonrpc/__init__.py +3 -1
  33. package/autorest/jsonrpc/localapi.py +3 -1
  34. package/autorest/jsonrpc/stdstream.py +1 -1
  35. package/autorest/m2r/__init__.py +2 -2
  36. package/autorest/multiapi/models/imports.py +34 -22
  37. package/autorest/multiapi/serializers/import_serializer.py +1 -1
  38. package/autorest/multiapi/templates/multiapi_config.py.jinja2 +2 -8
  39. package/autorest/multiapi/templates/multiapi_service_client.py.jinja2 +1 -1
  40. package/autorest/postprocess/__init__.py +5 -4
  41. package/autorest/preprocess/__init__.py +7 -1
  42. package/autorest/preprocess/helpers.py +14 -2
  43. package/autorest/preprocess/python_mappings.py +27 -0
  44. package/package.json +2 -2
  45. package/setup.py +3 -0
@@ -8,13 +8,14 @@ from pathlib import Path
8
8
  import os
9
9
  from typing import Any, Dict
10
10
  import black
11
+ from black.report import NothingChanged
11
12
 
12
13
  from .. import Plugin, PluginAutorest
13
14
  from .._utils import parse_args
14
15
 
15
16
  logging.getLogger("blib2to3").setLevel(logging.ERROR)
16
17
 
17
- _BLACK_MODE = black.Mode()
18
+ _BLACK_MODE = black.Mode() # pyright: ignore [reportPrivateImportUsage]
18
19
  _BLACK_MODE.line_length = 120
19
20
 
20
21
 
@@ -42,7 +43,7 @@ class BlackScriptPlugin(Plugin): # pylint: disable=abstract-method
42
43
  file_content = black.format_file_contents(
43
44
  file_content, fast=True, mode=_BLACK_MODE
44
45
  )
45
- except black.NothingChanged:
46
+ except NothingChanged:
46
47
  pass
47
48
  self.write_file(file, file_content)
48
49
 
@@ -293,7 +293,7 @@ class CodeGeneratorAutorest(CodeGenerator, PluginAutorest):
293
293
  # Parse the received YAML
294
294
  return yaml.safe_load(file_content)
295
295
 
296
- def get_serializer(self, code_model: CodeModel): # type: ignore
296
+ def get_serializer(self, code_model: CodeModel):
297
297
  return JinjaSerializerAutorest(
298
298
  self._autorestapi,
299
299
  code_model,
@@ -32,7 +32,7 @@ class BaseModel:
32
32
  return f"<{self.__class__.__name__}>"
33
33
 
34
34
 
35
- class BaseType(BaseModel, ABC):
35
+ class BaseType(BaseModel, ABC): # pylint: disable=too-many-public-methods
36
36
  """This is the base class for all types.
37
37
 
38
38
  :param yaml_data: the yaml data for this schema
@@ -60,6 +60,14 @@ class BaseType(BaseModel, ABC):
60
60
  def imports_for_multiapi(self, **kwargs: Any) -> FileImport:
61
61
  return self.imports(**kwargs)
62
62
 
63
+ @staticmethod
64
+ def imports_for_sample() -> FileImport:
65
+ return FileImport()
66
+
67
+ @staticmethod
68
+ def serialize_sample_value(value: Any) -> str:
69
+ return repr(value)
70
+
63
71
  @property
64
72
  def xml_metadata(self) -> Dict[str, Any]:
65
73
  """XML metadata for the type, if the type has it."""
@@ -87,7 +95,6 @@ class BaseType(BaseModel, ABC):
87
95
  return ", ".join(attrs_list)
88
96
 
89
97
  @property
90
- @abstractmethod
91
98
  def serialization_type(self) -> str:
92
99
  """The tag recognized by 'msrest' as a serialization/deserialization.
93
100
 
@@ -100,6 +107,11 @@ class BaseType(BaseModel, ABC):
100
107
  If dict: '{str}'
101
108
  """
102
109
  ...
110
+ raise NotImplementedError()
111
+
112
+ @property
113
+ def msrest_deserialization_key(self) -> str:
114
+ return self.serialization_type
103
115
 
104
116
  @property
105
117
  def client_default_value(self) -> Any:
@@ -74,7 +74,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
74
74
  ]
75
75
  self.format_lro_operations()
76
76
 
77
- def _build_request_builders(self):
77
+ def _build_request_builders(
78
+ self,
79
+ ) -> List[Union[RequestBuilder, OverloadedRequestBuilder]]:
78
80
  request_builders: List[Union[RequestBuilder, OverloadedRequestBuilder]] = []
79
81
  for og_group in self.yaml_data["operationGroups"]:
80
82
  for operation_yaml in og_group["operations"]:
@@ -84,7 +86,7 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
84
86
  client=self,
85
87
  )
86
88
  if request_builder.overloads:
87
- request_builders.extend(request_builder.overloads) # type: ignore
89
+ request_builders.extend(request_builder.overloads)
88
90
  request_builders.append(request_builder)
89
91
  if operation_yaml.get("nextOperation"):
90
92
  # i am a paging operation and i have a next operation.
@@ -174,7 +176,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
174
176
  for gp in self.parameters:
175
177
  if gp.method_location == ParameterMethodLocation.KWARG:
176
178
  continue
177
- file_import.merge(gp.imports(async_mode))
179
+ file_import.merge(
180
+ gp.imports(async_mode, relative_path=".." if async_mode else ".")
181
+ )
178
182
  file_import.add_submodule_import(
179
183
  "._configuration",
180
184
  f"{self.name}Configuration",
@@ -285,23 +289,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
285
289
  and self.code_model.options["models_mode"] == "msrest"
286
290
  ):
287
291
  path_to_models = ".." if async_mode else "."
288
- if len(self.code_model.model_types) != len(
289
- self.code_model.public_model_types
290
- ):
291
- # this means we have hidden models. In that case, we import directly from the models
292
- # file, not the module, bc we don't expose the hidden models in the models module
293
-
294
- # Also in this case, we're in version tolerant, so python3 only is true
295
- file_import.add_submodule_import(
296
- f"{path_to_models}models",
297
- self.code_model.models_filename,
298
- ImportType.LOCAL,
299
- alias="models",
300
- )
301
- else:
302
- file_import.add_submodule_import(
303
- path_to_models, "models", ImportType.LOCAL
304
- )
292
+ file_import.add_submodule_import(
293
+ path_to_models, "models", ImportType.LOCAL, alias="_models"
294
+ )
305
295
  elif self.code_model.options["models_mode"] == "msrest":
306
296
  # in this case, we have client_models = {} in the service client, which needs a type annotation
307
297
  # this import will always be commented, so will always add it to the typing section
@@ -411,7 +401,11 @@ class Config(_ClientConfigBase[ConfigGlobalParameterList]):
411
401
  and gp not in self.parameters.kwargs_to_pop
412
402
  ):
413
403
  continue
414
- file_import.merge(gp.imports(async_mode=async_mode))
404
+ file_import.merge(
405
+ gp.imports(
406
+ async_mode=async_mode, relative_path=".." if async_mode else "."
407
+ )
408
+ )
415
409
  return file_import
416
410
 
417
411
  def imports_for_multiapi(self, async_mode: bool) -> FileImport:
@@ -423,7 +417,11 @@ class Config(_ClientConfigBase[ConfigGlobalParameterList]):
423
417
  and gp.client_name == "api_version"
424
418
  ):
425
419
  continue
426
- file_import.merge(gp.imports_for_multiapi(async_mode=async_mode))
420
+ file_import.merge(
421
+ gp.imports_for_multiapi(
422
+ async_mode=async_mode, relative_path=".." if async_mode else "."
423
+ )
424
+ )
427
425
  return file_import
428
426
 
429
427
  @classmethod
@@ -41,7 +41,7 @@ class CombinedType(BaseType):
41
41
  If list: '[str]'
42
42
  If dict: '{str}'
43
43
  """
44
- ...
44
+ raise ValueError("Shouldn't get serialization type of a combinedtype")
45
45
 
46
46
  @property
47
47
  def client_default_value(self) -> Any:
@@ -69,7 +69,7 @@ class ImportModel:
69
69
  except AttributeError:
70
70
  return False
71
71
 
72
- def __hash__(self):
72
+ def __hash__(self) -> int:
73
73
  retval: int = 0
74
74
  for attr in dir(self):
75
75
  if attr[0] != "_":
@@ -88,7 +88,7 @@ class TypeDefinition:
88
88
 
89
89
 
90
90
  class FileImport:
91
- def __init__(self, imports: List[ImportModel] = None) -> None:
91
+ def __init__(self, imports: Optional[List[ImportModel]] = None) -> None:
92
92
  self.imports = imports or []
93
93
  # has sync and async type definitions
94
94
  self.type_definitions: Dict[str, TypeDefinition] = {}
@@ -81,6 +81,10 @@ class ModelType(
81
81
  return f"{'' if self.is_public else '_models.'}_models.{self.name}"
82
82
  return "object"
83
83
 
84
+ @property
85
+ def msrest_deserialization_key(self) -> str:
86
+ return self.name
87
+
84
88
  @property
85
89
  def is_polymorphic(self) -> bool:
86
90
  return any(p.is_polymorphic for p in self.properties)
@@ -160,10 +160,10 @@ class OperationBase( # pylint: disable=too-many-public-methods
160
160
  isinstance(r.type, ModelType) for r in self.responses
161
161
  ):
162
162
  r = next(r for r in self.responses if isinstance(r.type, ModelType))
163
- type_name = getattr(r, "item_type", getattr(r, "type")).docstring_text(
164
- **kwargs
165
- )
166
- retval += f". The {type_name} is compatible with MutableMapping"
163
+ item_type = getattr(r, "item_type", getattr(r, "type"))
164
+ if item_type:
165
+ type_name = item_type.docstring_text(**kwargs)
166
+ retval += f". The {type_name} is compatible with MutableMapping"
167
167
  return retval
168
168
 
169
169
  def response_docstring_type(self, **kwargs) -> str:
@@ -344,9 +344,9 @@ class OperationBase( # pylint: disable=too-many-public-methods
344
344
  )
345
345
 
346
346
  if self.has_kwargs_to_pop_with_default(
347
- self.parameters.kwargs_to_pop, ParameterLocation.HEADER
347
+ self.parameters.kwargs_to_pop, ParameterLocation.HEADER # type: ignore
348
348
  ) or self.has_kwargs_to_pop_with_default(
349
- self.parameters.kwargs_to_pop, ParameterLocation.QUERY
349
+ self.parameters.kwargs_to_pop, ParameterLocation.QUERY # type: ignore
350
350
  ):
351
351
  file_import.add_submodule_import(
352
352
  "azure.core.utils", "case_insensitive_dict", ImportType.AZURECORE
@@ -48,11 +48,16 @@ class OperationGroup(BaseModel):
48
48
 
49
49
  def imports_for_multiapi(self, async_mode: bool) -> FileImport:
50
50
  file_import = FileImport()
51
+ relative_path = ".." if async_mode else "."
51
52
  for operation in self.operations:
52
53
  file_import.merge(
53
- operation.imports_for_multiapi(
54
- async_mode, relative_path=".." if async_mode else "."
55
- )
54
+ operation.imports_for_multiapi(async_mode, relative_path=relative_path)
55
+ )
56
+ if (
57
+ self.code_model.model_types or self.code_model.enums
58
+ ) and self.code_model.options["models_mode"] == "msrest":
59
+ file_import.add_submodule_import(
60
+ relative_path, "models", ImportType.LOCAL, alias="_models"
56
61
  )
57
62
  return file_import
58
63
 
@@ -63,12 +68,6 @@ class OperationGroup(BaseModel):
63
68
  retval = add_to_pylint_disable(retval, "abstract-class-instantiated")
64
69
  return retval
65
70
 
66
- @property
67
- def mypy_ignore(self) -> str:
68
- if self.has_abstract_operations:
69
- return " # type: ignore"
70
- return ""
71
-
72
71
  @property
73
72
  def need_validation(self) -> bool:
74
73
  """Whether any of its operations need validation"""
@@ -84,8 +83,10 @@ class OperationGroup(BaseModel):
84
83
  )
85
84
  # for multiapi
86
85
  if (
87
- self.code_model.model_types or self.code_model.enums
88
- ) and self.code_model.options["models_mode"] == "msrest":
86
+ (self.code_model.public_model_types)
87
+ and self.code_model.options["models_mode"] == "msrest"
88
+ and not self.is_mixin
89
+ ):
89
90
  file_import.add_submodule_import(
90
91
  relative_path, "models", ImportType.LOCAL, alias="_models"
91
92
  )
@@ -384,6 +384,16 @@ class DatetimeType(PrimitiveType):
384
384
  def instance_check_template(self) -> str:
385
385
  return "isinstance({}, datetime.datetime)"
386
386
 
387
+ @staticmethod
388
+ def imports_for_sample() -> FileImport:
389
+ file_import = super(DatetimeType, DatetimeType).imports_for_sample()
390
+ file_import.add_import("isodate", ImportType.STDLIB)
391
+ return file_import
392
+
393
+ @staticmethod
394
+ def serialize_sample_value(value: Any) -> str:
395
+ return f"isodate.parse_datetime({repr(value)})"
396
+
387
397
 
388
398
  class TimeType(PrimitiveType):
389
399
  @property
@@ -418,6 +428,16 @@ class TimeType(PrimitiveType):
418
428
  def instance_check_template(self) -> str:
419
429
  return "isinstance({}, datetime.time)"
420
430
 
431
+ @staticmethod
432
+ def imports_for_sample() -> FileImport:
433
+ file_import = super(TimeType, TimeType).imports_for_sample()
434
+ file_import.add_import("isodate", ImportType.STDLIB)
435
+ return file_import
436
+
437
+ @staticmethod
438
+ def serialize_sample_value(value: Any) -> str:
439
+ return f"isodate.parse_time({repr(value)})"
440
+
421
441
 
422
442
  class UnixTimeType(PrimitiveType):
423
443
  @property
@@ -452,6 +472,16 @@ class UnixTimeType(PrimitiveType):
452
472
  def instance_check_template(self) -> str:
453
473
  return "isinstance({}, datetime.time)"
454
474
 
475
+ @staticmethod
476
+ def imports_for_sample() -> FileImport:
477
+ file_import = super(UnixTimeType, UnixTimeType).imports_for_sample()
478
+ file_import.add_import("datetime", ImportType.STDLIB)
479
+ return file_import
480
+
481
+ @staticmethod
482
+ def serialize_sample_value(value: Any) -> str:
483
+ return f"datetime.datetime.fromtimestamp({repr(value)}, datetime.timezone.utc)"
484
+
455
485
 
456
486
  class DateType(PrimitiveType):
457
487
  @property
@@ -486,6 +516,16 @@ class DateType(PrimitiveType):
486
516
  def instance_check_template(self) -> str:
487
517
  return "isinstance({}, datetime.date)"
488
518
 
519
+ @staticmethod
520
+ def imports_for_sample() -> FileImport:
521
+ file_import = super(DateType, DateType).imports_for_sample()
522
+ file_import.add_import("isodate", ImportType.STDLIB)
523
+ return file_import
524
+
525
+ @staticmethod
526
+ def serialize_sample_value(value: Any) -> str:
527
+ return f"isodate.parse_date({repr(value)})"
528
+
489
529
 
490
530
  class DurationType(PrimitiveType):
491
531
  @property
@@ -520,6 +560,16 @@ class DurationType(PrimitiveType):
520
560
  def instance_check_template(self) -> str:
521
561
  return "isinstance({}, datetime.timedelta)"
522
562
 
563
+ @staticmethod
564
+ def imports_for_sample() -> FileImport:
565
+ file_import = super(DurationType, DurationType).imports_for_sample()
566
+ file_import.add_import("isodate", ImportType.STDLIB)
567
+ return file_import
568
+
569
+ @staticmethod
570
+ def serialize_sample_value(value: Any) -> str:
571
+ return f"isodate.parse_duration({repr(value)})"
572
+
523
573
 
524
574
  class ByteArraySchema(PrimitiveType):
525
575
  def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
@@ -83,6 +83,10 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
83
83
  def serialization_type(self) -> str:
84
84
  return self.type.serialization_type
85
85
 
86
+ @property
87
+ def msrest_deserialization_key(self) -> str:
88
+ return self.type.msrest_deserialization_key
89
+
86
90
  def type_annotation(self, *, is_operation_file: bool = False) -> str:
87
91
  if self.optional and self.client_default_value is None:
88
92
  return f"Optional[{self.type.type_annotation(is_operation_file=is_operation_file)}]"
@@ -71,10 +71,18 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
71
71
 
72
72
  def imports(self) -> FileImport:
73
73
  file_import = FileImport()
74
+ relative_path = ".."
75
+ if (
76
+ not self.code_model.options["builders_visibility"] == "embedded"
77
+ and self.group_name
78
+ ):
79
+ relative_path = "..." if self.group_name else ".."
74
80
  if self.abstract:
75
81
  return file_import
76
82
  for parameter in self.parameters.method:
77
- file_import.merge(parameter.imports(async_mode=False))
83
+ file_import.merge(
84
+ parameter.imports(async_mode=False, relative_path=relative_path)
85
+ )
78
86
 
79
87
  file_import.add_submodule_import(
80
88
  "azure.core.rest",
@@ -83,12 +91,6 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
83
91
  )
84
92
 
85
93
  if self.parameters.path:
86
- relative_path = ".."
87
- if (
88
- not self.code_model.options["builders_visibility"] == "embedded"
89
- and self.group_name
90
- ):
91
- relative_path = "..." if self.group_name else ".."
92
94
  file_import.add_submodule_import(
93
95
  f"{relative_path}_vendor", "_format_url_section", ImportType.LOCAL
94
96
  )
@@ -4,7 +4,7 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  import logging
7
- from typing import List, Optional, Any, Union, cast
7
+ from typing import List, Optional, Any, Union
8
8
  from pathlib import Path
9
9
  from jinja2 import PackageLoader, Environment, FileSystemLoader, StrictUndefined
10
10
 
@@ -28,6 +28,7 @@ from .request_builders_serializer import RequestBuildersSerializer
28
28
  from .patch_serializer import PatchSerializer
29
29
  from .sample_serializer import SampleSerializer
30
30
  from ..._utils import to_snake_case
31
+ from .utils import extract_sample_name
31
32
 
32
33
  _LOGGER = logging.getLogger(__name__)
33
34
 
@@ -190,7 +191,7 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
190
191
  if self.read_file(namespace_path / Path("models.py")):
191
192
  self.write_file(
192
193
  namespace_path / Path("models.py"),
193
- cast(str, self.read_file(namespace_path / Path("models.py"))),
194
+ self.read_file(namespace_path / Path("models.py")),
194
195
  )
195
196
 
196
197
  def _serialize_and_write_package_files(self, namespace_path: Path) -> None:
@@ -543,8 +544,9 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
543
544
  samples = operation.yaml_data["samples"]
544
545
  if not samples or operation.name.startswith("_"):
545
546
  continue
546
- for key, value in samples.items():
547
- file_name = to_snake_case(key) + ".py"
547
+ for value in samples.values():
548
+ file = value.get("x-ms-original-file", "sample.json")
549
+ file_name = to_snake_case(extract_sample_name(file)) + ".py"
548
550
  try:
549
551
  self.write_file(
550
552
  out_path / file_name,
@@ -555,12 +557,11 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
555
557
  operation=operation,
556
558
  sample=value,
557
559
  file_name=file_name,
558
- sample_origin_name=key,
559
560
  ).serialize(),
560
561
  )
561
562
  except Exception as e: # pylint: disable=broad-except
562
563
  # sample generation shall not block code generation, so just log error
563
- log_error = f"error happens in sample {key}: {e}"
564
+ log_error = f"error happens in sample {file}: {e}"
564
565
  _LOGGER.error(log_error)
565
566
 
566
567