betterproto2_compiler-0.2.0-py3-none-any.whl

Files changed (32)
  1. betterproto2_compiler/__init__.py +0 -0
  2. betterproto2_compiler/casing.py +140 -0
  3. betterproto2_compiler/compile/__init__.py +0 -0
  4. betterproto2_compiler/compile/importing.py +180 -0
  5. betterproto2_compiler/compile/naming.py +21 -0
  6. betterproto2_compiler/known_types/__init__.py +14 -0
  7. betterproto2_compiler/known_types/any.py +36 -0
  8. betterproto2_compiler/known_types/duration.py +25 -0
  9. betterproto2_compiler/known_types/timestamp.py +45 -0
  10. betterproto2_compiler/lib/__init__.py +0 -0
  11. betterproto2_compiler/lib/google/__init__.py +0 -0
  12. betterproto2_compiler/lib/google/protobuf/__init__.py +3338 -0
  13. betterproto2_compiler/lib/google/protobuf/compiler/__init__.py +235 -0
  14. betterproto2_compiler/lib/message_pool.py +3 -0
  15. betterproto2_compiler/plugin/__init__.py +3 -0
  16. betterproto2_compiler/plugin/__main__.py +3 -0
  17. betterproto2_compiler/plugin/compiler.py +70 -0
  18. betterproto2_compiler/plugin/main.py +47 -0
  19. betterproto2_compiler/plugin/models.py +643 -0
  20. betterproto2_compiler/plugin/module_validation.py +156 -0
  21. betterproto2_compiler/plugin/parser.py +272 -0
  22. betterproto2_compiler/plugin/plugin.bat +2 -0
  23. betterproto2_compiler/plugin/typing_compiler.py +163 -0
  24. betterproto2_compiler/py.typed +0 -0
  25. betterproto2_compiler/settings.py +9 -0
  26. betterproto2_compiler/templates/header.py.j2 +59 -0
  27. betterproto2_compiler/templates/template.py.j2 +258 -0
  28. betterproto2_compiler-0.2.0.dist-info/LICENSE.md +22 -0
  29. betterproto2_compiler-0.2.0.dist-info/METADATA +35 -0
  30. betterproto2_compiler-0.2.0.dist-info/RECORD +32 -0
  31. betterproto2_compiler-0.2.0.dist-info/WHEEL +4 -0
  32. betterproto2_compiler-0.2.0.dist-info/entry_points.txt +3 -0
betterproto2_compiler/__init__.py
File without changes
betterproto2_compiler/casing.py
@@ -0,0 +1,140 @@
+ import keyword
+ import re
+
+ # Word delimiters and symbols that will not be preserved when re-casing.
+ # language=PythonRegExp
+ SYMBOLS = "[^a-zA-Z0-9]*"
+
+ # Optionally capitalized word.
+ # language=PythonRegExp
+ WORD = "[A-Z]*[a-z]*[0-9]*"
+
+ # Uppercase word, not followed by lowercase letters.
+ # language=PythonRegExp
+ WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
+
+
+ def safe_snake_case(value: str) -> str:
+     """Snake case a value taking into account Python keywords."""
+     value = snake_case(value)
+     value = sanitize_name(value)
+     return value
+
+
+ def snake_case(value: str, strict: bool = True) -> str:
+     """
+     Join words with an underscore into lowercase and remove symbols.
+
+     Parameters
+     -----------
+     value: :class:`str`
+         The value to convert.
+     strict: :class:`bool`
+         Whether or not to force single underscores.
+
+     Returns
+     --------
+     :class:`str`
+         The value in snake_case.
+     """
+
+     def substitute_word(symbols: str, word: str, is_start: bool) -> str:
+         if not word:
+             return ""
+         if strict:
+             delimiter_count = 0 if is_start else 1  # Single underscore if strict.
+         elif is_start:
+             delimiter_count = len(symbols)
+         elif word.isupper() or word.islower():
+             delimiter_count = max(1, len(symbols))  # Preserve all delimiters if not strict.
+         else:
+             delimiter_count = len(symbols) + 1  # Extra underscore for leading capital.
+
+         return ("_" * delimiter_count) + word.lower()
+
+     snake = re.sub(
+         f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
+         lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
+         value,
+     )
+     return snake
+
+
+ def pascal_case(value: str, strict: bool = True) -> str:
+     """
+     Capitalize each word and remove symbols.
+
+     Parameters
+     -----------
+     value: :class:`str`
+         The value to convert.
+     strict: :class:`bool`
+         Whether or not to output only alphanumeric characters.
+
+     Returns
+     --------
+     :class:`str`
+         The value in PascalCase.
+     """
+
+     def substitute_word(symbols, word):
+         if strict:
+             return word.capitalize()  # Remove all delimiters
+
+         if word.islower():
+             delimiter_length = len(symbols[:-1])  # Lose one delimiter
+         else:
+             delimiter_length = len(symbols)  # Preserve all delimiters
+
+         return ("_" * delimiter_length) + word.capitalize()
+
+     return re.sub(
+         f"({SYMBOLS})({WORD_UPPER}|{WORD})",
+         lambda groups: substitute_word(groups[1], groups[2]),
+         value,
+     )
+
+
+ def camel_case(value: str, strict: bool = True) -> str:
+     """
+     Capitalize all words except first and remove symbols.
+
+     Parameters
+     -----------
+     value: :class:`str`
+         The value to convert.
+     strict: :class:`bool`
+         Whether or not to output only alphanumeric characters.
+
+     Returns
+     --------
+     :class:`str`
+         The value in camelCase.
+     """
+     return lowercase_first(pascal_case(value, strict=strict))
+
+
+ def lowercase_first(value: str) -> str:
+     """
+     Lower cases the first character of the value.
+
+     Parameters
+     ----------
+     value: :class:`str`
+         The value to lower case.
+
+     Returns
+     -------
+     :class:`str`
+         The lower cased string.
+     """
+     return value[0:1].lower() + value[1:]
+
+
+ def sanitize_name(value: str) -> str:
+     # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
+     if keyword.iskeyword(value):
+         return f"{value}_"
+     if not value.isidentifier():
+         return f"_{value}"
+     return value
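A quick usage sketch of these helpers (the outputs in the comments are my reading of the regexes and strict-mode branches above, not values quoted from the package's tests):

    from betterproto2_compiler.casing import (
        camel_case,
        pascal_case,
        safe_snake_case,
        sanitize_name,
        snake_case,
    )

    # Strict mode collapses delimiters and splits on capitalization.
    print(snake_case("ConnectionManager"))    # connection_manager
    print(snake_case("HTTPResponse"))         # http_response
    print(pascal_case("connection_manager"))  # ConnectionManager
    print(camel_case("connection_manager"))   # connectionManager

    # Python keywords get a trailing underscore, invalid identifiers a leading one.
    print(safe_snake_case("class"))           # class_
    print(sanitize_name("1field"))            # _1field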
betterproto2_compiler/compile/__init__.py
File without changes
betterproto2_compiler/compile/importing.py
@@ -0,0 +1,180 @@
+ from __future__ import annotations
+
+ import os
+ from typing import (
+     TYPE_CHECKING,
+ )
+
+ from betterproto2_compiler.lib.google import protobuf as google_protobuf
+ from betterproto2_compiler.settings import Settings
+
+ from ..casing import safe_snake_case
+ from .naming import pythonize_class_name
+
+ if TYPE_CHECKING:
+     from ..plugin.models import PluginRequestCompiler
+
+ WRAPPER_TYPES: dict[str, type] = {
+     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
+     ".google.protobuf.FloatValue": google_protobuf.FloatValue,
+     ".google.protobuf.Int32Value": google_protobuf.Int32Value,
+     ".google.protobuf.Int64Value": google_protobuf.Int64Value,
+     ".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
+     ".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
+     ".google.protobuf.BoolValue": google_protobuf.BoolValue,
+     ".google.protobuf.StringValue": google_protobuf.StringValue,
+     ".google.protobuf.BytesValue": google_protobuf.BytesValue,
+ }
+
+
+ def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]:
+     """
+     Split full source type name into package and type name.
+     E.g. 'root.package.Message' -> ('root.package', 'Message')
+          'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
+
+     The function goes through the symbols that have been defined (names, enums,
+     packages) to find the actual package and name of the object that is referenced.
+     """
+     if field_type_name[0] != ".":
+         raise RuntimeError("relative names are not supported")
+     field_type_name = field_type_name[1:]
+     parts = field_type_name.split(".")
+
+     answer = None
+
+     # a.b.c:
+     # i=0: "", "a.b.c"
+     # i=1: "a", "b.c"
+     # i=2: "a.b", "c"
+     for i in range(len(parts)):
+         package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:])
+
+         package = request.output_packages.get(package_name)
+
+         if not package:
+             continue
+
+         if object_name in package.messages or object_name in package.enums:
+             if answer:
+                 # This should have already been handled by protoc
+                 raise ValueError(f"ambiguous definition: {field_type_name}")
+             answer = package_name, object_name
+
+     if answer:
+         return answer
+
+     raise ValueError(f"can't find type name: {field_type_name}")
+
+
+ def get_type_reference(
+     *,
+     package: str,
+     imports: set,
+     source_type: str,
+     request: PluginRequestCompiler,
+     unwrap: bool = True,
+     settings: Settings,
+ ) -> str:
+     """
+     Return a Python type name for a proto type reference. Adds the import if
+     necessary. Unwraps well known type if required.
+     """
+     if unwrap:
+         if source_type in WRAPPER_TYPES:
+             wrapped_type = type(WRAPPER_TYPES[source_type]().value)
+             return settings.typing_compiler.optional(wrapped_type.__name__)
+
+         if source_type == ".google.protobuf.Duration":
+             return "datetime.timedelta"
+
+         elif source_type == ".google.protobuf.Timestamp":
+             return "datetime.datetime"
+
+     source_package, source_type = parse_source_type_name(source_type, request)
+
+     current_package: list[str] = package.split(".") if package else []
+     py_package: list[str] = source_package.split(".") if source_package else []
+     py_type: str = pythonize_class_name(source_type)
+
+     if py_package == current_package:
+         return reference_sibling(py_type)
+
+     if py_package[: len(current_package)] == current_package:
+         return reference_descendent(current_package, imports, py_package, py_type)
+
+     if current_package[: len(py_package)] == py_package:
+         return reference_ancestor(current_package, imports, py_package, py_type)
+
+     return reference_cousin(current_package, imports, py_package, py_type)
+
+
+ def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str:
+     """
+     Returns a reference to a python type located in the root, i.e. sys.path.
+     """
+     string_import = ".".join(py_package)
+     string_alias = safe_snake_case(string_import)
+     imports.add(f"import {string_import} as {string_alias}")
+     return f"{string_alias}.{py_type}"
+
+
+ def reference_sibling(py_type: str) -> str:
+     """
+     Returns a reference to a python type within the same package as the current package.
+     """
+     return f"{py_type}"
+
+
+ def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
+     """
+     Returns a reference to a python type in a package that is a descendent of the
+     current package, and adds the required import that is aliased to avoid name
+     conflicts.
+     """
+     importing_descendent = py_package[len(current_package) :]
+     string_from = ".".join(importing_descendent[:-1])
+     string_import = importing_descendent[-1]
+     if string_from:
+         string_alias = "_".join(importing_descendent)
+         imports.add(f"from .{string_from} import {string_import} as {string_alias}")
+         return f"{string_alias}.{py_type}"
+     else:
+         imports.add(f"from . import {string_import}")
+         return f"{string_import}.{py_type}"
+
+
+ def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
+     """
+     Returns a reference to a python type in a package which is an ancestor to the
+     current package, and adds the required import that is aliased (if possible) to avoid
+     name conflicts.
+
+     Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
+     """
+     distance_up = len(current_package) - len(py_package)
+     if py_package:
+         string_import = py_package[-1]
+         string_alias = f"_{'_' * distance_up}{string_import}__"
+         string_from = f"..{'.' * distance_up}"
+         imports.add(f"from {string_from} import {string_import} as {string_alias}")
+         return f"{string_alias}.{py_type}"
+     else:
+         string_alias = f"{'_' * distance_up}{py_type}__"
+         imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
+         return string_alias
+
+
+ def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
+     """
+     Returns a reference to a python type in a package that is not descendent, ancestor
+     or sibling, and adds the required import that is aliased to avoid name conflicts.
+     """
+     shared_ancestry = os.path.commonprefix([current_package, py_package])  # type: ignore
+     distance_up = len(current_package) - len(shared_ancestry)
+     string_from = f".{'.' * distance_up}" + ".".join(py_package[len(shared_ancestry) : -1])
+     string_import = py_package[-1]
+     # Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
+     string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__"
+     imports.add(f"from {string_from} import {string_import} as {string_alias}")
+     return f"{string_alias}.{py_type}"
betterproto2_compiler/compile/naming.py
@@ -0,0 +1,21 @@
+ from betterproto2_compiler import casing
+
+
+ def pythonize_class_name(name: str) -> str:
+     return casing.pascal_case(name)
+
+
+ def pythonize_field_name(name: str) -> str:
+     return casing.safe_snake_case(name)
+
+
+ def pythonize_method_name(name: str) -> str:
+     return casing.safe_snake_case(name)
+
+
+ def pythonize_enum_member_name(name: str, enum_name: str) -> str:
+     enum_name = casing.snake_case(enum_name).upper()
+     find = name.find(enum_name)
+     if find != -1:
+         name = name[find + len(enum_name) :].strip("_")
+     return casing.sanitize_name(name)
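For illustration, here is how the pythonize helpers behave on a few typical proto names (expected outputs follow from casing.py and the prefix-stripping logic above, and are not taken from the package's tests):

    from betterproto2_compiler.compile.naming import (
        pythonize_enum_member_name,
        pythonize_field_name,
    )

    # Field names that collide with Python keywords get a trailing underscore.
    print(pythonize_field_name("from"))  # from_

    # The upper-snake-cased enum name is stripped from member names when present.
    print(pythonize_enum_member_name("COLOR_RED", "Color"))    # RED
    print(pythonize_enum_member_name("UNSPECIFIED", "Color"))  # UNSPECIFIED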
betterproto2_compiler/known_types/__init__.py
@@ -0,0 +1,14 @@
+ from collections.abc import Callable
+
+ from .any import Any
+ from .duration import Duration
+ from .timestamp import Timestamp
+
+ # For each (package, message name), lists the methods that should be added to the message definition.
+ # The source code of the method is read from the `known_types` folder. If imports are needed, they can be directly added
+ # to the template file: they will automatically be removed if not necessary.
+ KNOWN_METHODS: dict[tuple[str, str], list[Callable]] = {
+     ("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict],
+     ("google.protobuf", "Timestamp"): [Timestamp.from_datetime, Timestamp.to_datetime, Timestamp.timestamp_to_json],
+     ("google.protobuf", "Duration"): [Duration.from_timedelta, Duration.to_timedelta, Duration.delta_to_json],
+ }
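Since the comment says the compiler reads the source code of these methods, here is a minimal sketch of that idea using inspect; the lookup and printing below are illustrative only, not the plugin's actual code path:

    import inspect

    from betterproto2_compiler.known_types import KNOWN_METHODS

    # Show the source text of the extra methods registered for google.protobuf.Timestamp,
    # i.e. the kind of text that would be copied into the generated Timestamp class.
    for method in KNOWN_METHODS[("google.protobuf", "Timestamp")]:
        print(f"# {method.__name__}")
        print(inspect.getsource(method))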
betterproto2_compiler/known_types/any.py
@@ -0,0 +1,36 @@
+ import betterproto2
+
+ from betterproto2_compiler.lib.google.protobuf import Any as VanillaAny
+
+ default_message_pool = betterproto2.MessagePool()  # Only for typing purpose
+
+
+ class Any(VanillaAny):
+     def pack(self, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> None:
+         """
+         Pack the given message in the `Any` object.
+
+         The message type must be registered in the message pool, which is done automatically when the module defining
+         the message type is imported.
+         """
+         message_pool = message_pool or default_message_pool
+
+         self.type_url = message_pool.type_to_url[type(message)]
+         self.value = bytes(message)
+
+     def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message:
+         """
+         Return the message packed inside the `Any` object.
+
+         The target message type must be registered in the message pool, which is done automatically when the module
+         defining the message type is imported.
+         """
+         message_pool = message_pool or default_message_pool
+
+         message_type = message_pool.url_to_type[self.type_url]
+
+         return message_type().parse(self.value)
+
+     def to_dict(self) -> dict:  # pyright: ignore [reportIncompatibleMethodOverride]
+         # TODO: improve when dict is updated
+         return {"@type": self.type_url, "value": self.unpack().to_dict()}
betterproto2_compiler/known_types/duration.py
@@ -0,0 +1,25 @@
+ import datetime
+
+ from betterproto2_compiler.lib.google.protobuf import Duration as VanillaDuration
+
+
+ class Duration(VanillaDuration):
+     @classmethod
+     def from_timedelta(
+         cls, delta: datetime.timedelta, *, _1_microsecond: datetime.timedelta = datetime.timedelta(microseconds=1)
+     ) -> "Duration":
+         total_ms = delta // _1_microsecond
+         seconds = int(total_ms / 1e6)
+         nanos = int((total_ms % 1e6) * 1e3)
+         return cls(seconds, nanos)
+
+     def to_timedelta(self) -> datetime.timedelta:
+         return datetime.timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
+
+     @staticmethod
+     def delta_to_json(delta: datetime.timedelta) -> str:
+         parts = str(delta.total_seconds()).split(".")
+         if len(parts) > 1:
+             while len(parts[1]) not in (3, 6, 9):
+                 parts[1] = f"{parts[1]}0"
+         return f"{'.'.join(parts)}s"
betterproto2_compiler/known_types/timestamp.py
@@ -0,0 +1,45 @@
+ import datetime
+
+ from betterproto2_compiler.lib.google.protobuf import Timestamp as VanillaTimestamp
+
+
+ class Timestamp(VanillaTimestamp):
+     @classmethod
+     def from_datetime(cls, dt: datetime.datetime) -> "Timestamp":
+         # manual epoch offset calculation to avoid rounding errors,
+         # to support negative timestamps (before 1970) and skirt
+         # around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
+         offset = dt - datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
+         # below is the same as timedelta.total_seconds() but without dividing by 1e6
+         # so we end up with microseconds as integers instead of seconds as float
+         offset_us = (offset.days * 24 * 60 * 60 + offset.seconds) * 10**6 + offset.microseconds
+         seconds, us = divmod(offset_us, 10**6)
+         return cls(seconds, us * 1000)
+
+     def to_datetime(self) -> datetime.datetime:
+         # datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
+         # if we pass it as a floating point number, we will run into rounding errors
+         # see also #407
+         offset = datetime.timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
+         return datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) + offset
+
+     @staticmethod
+     def timestamp_to_json(dt: datetime.datetime) -> str:
+         nanos = dt.microsecond * 1e3
+         if dt.tzinfo is not None:
+             # change timezone aware datetime objects to utc
+             dt = dt.astimezone(datetime.timezone.utc)
+         copy = dt.replace(microsecond=0, tzinfo=None)
+         result = copy.isoformat()
+         if (nanos % 1e9) == 0:
+             # If there are 0 fractional digits, the fractional
+             # point '.' should be omitted when serializing.
+             return f"{result}Z"
+         if (nanos % 1e6) == 0:
+             # Serialize 3 fractional digits.
+             return f"{result}.{int(nanos // 1e6) :03d}Z"
+         if (nanos % 1e3) == 0:
+             # Serialize 6 fractional digits.
+             return f"{result}.{int(nanos // 1e3) :06d}Z"
+         # Serialize 9 fractional digits.
+         return f"{result}.{nanos:09d}"
betterproto2_compiler/lib/__init__.py
File without changes
betterproto2_compiler/lib/google/__init__.py
File without changes