sqlglotc 28.10.1.dev130__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlglotc-28.10.1.dev130/MANIFEST.in +3 -0
- sqlglotc-28.10.1.dev130/PKG-INFO +9 -0
- sqlglotc-28.10.1.dev130/pyproject.toml +27 -0
- sqlglotc-28.10.1.dev130/setup.cfg +4 -0
- sqlglotc-28.10.1.dev130/setup.py +68 -0
- sqlglotc-28.10.1.dev130/sqlglot/__init__.py +0 -0
- sqlglotc-28.10.1.dev130/sqlglot/expression_core.py +449 -0
- sqlglotc-28.10.1.dev130/sqlglot/helper.py +465 -0
- sqlglotc-28.10.1.dev130/sqlglot/tokenizer_core.py +1081 -0
- sqlglotc-28.10.1.dev130/sqlglot/trie.py +81 -0
- sqlglotc-28.10.1.dev130/sqlglotc.egg-info/PKG-INFO +9 -0
- sqlglotc-28.10.1.dev130/sqlglotc.egg-info/SOURCES.txt +12 -0
- sqlglotc-28.10.1.dev130/sqlglotc.egg-info/dependency_links.txt +1 -0
- sqlglotc-28.10.1.dev130/sqlglotc.egg-info/top_level.txt +2 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sqlglotc
|
|
3
|
+
Version: 28.10.1.dev130
|
|
4
|
+
Summary: mypyc-compiled extensions for sqlglot
|
|
5
|
+
Author-email: Toby Mao <toby.mao@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://sqlglot.com/
|
|
8
|
+
Project-URL: Repository, https://github.com/tobymao/sqlglot
|
|
9
|
+
Requires-Python: >=3.9
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "sqlglotc"
|
|
3
|
+
dynamic = ["version"]
|
|
4
|
+
description = "mypyc-compiled extensions for sqlglot"
|
|
5
|
+
authors = [{ name = "Toby Mao", email = "toby.mao@gmail.com" }]
|
|
6
|
+
license = {text = "MIT"}
|
|
7
|
+
requires-python = ">= 3.9"
|
|
8
|
+
|
|
9
|
+
[project.urls]
|
|
10
|
+
Homepage = "https://sqlglot.com/"
|
|
11
|
+
Repository = "https://github.com/tobymao/sqlglot"
|
|
12
|
+
|
|
13
|
+
[build-system]
|
|
14
|
+
requires = ["setuptools >= 61.0", "setuptools_scm", "mypy>=1.0", "types-python-dateutil"]
|
|
15
|
+
build-backend = "setuptools.build_meta"
|
|
16
|
+
|
|
17
|
+
[tool.setuptools]
|
|
18
|
+
include-package-data = false
|
|
19
|
+
|
|
20
|
+
[tool.setuptools_scm]
|
|
21
|
+
root = ".."
|
|
22
|
+
fallback_version = "0.0.0"
|
|
23
|
+
local_scheme = "no-local-version"
|
|
24
|
+
|
|
25
|
+
[tool.mypy]
|
|
26
|
+
# Allow mypyc to resolve sqlglot.* from the repo root (../sqlglot/) or sdist root (./sqlglot/).
|
|
27
|
+
mypy_path = [".", ".."]
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
|
|
4
|
+
from setuptools import setup
|
|
5
|
+
from setuptools.command.build_ext import build_ext as _build_ext
|
|
6
|
+
from setuptools.command.sdist import sdist as _sdist
|
|
7
|
+
|
|
8
|
+
# Absolute directory that holds this setup.py; the other paths are resolved from it.
here = os.path.dirname(os.path.abspath(__file__))
# The sqlglot package of the enclosing git checkout (a sibling of this directory).
sqlglot_src = os.path.join(here, os.pardir, "sqlglot")
|
|
10
|
+
|
|
11
|
+
from mypyc.build import mypycify
|
|
12
|
+
|
|
13
|
+
# sqlglot modules that are compiled with mypyc.
SOURCE_FILES = [
    "expression_core.py",
    "helper.py",
    "trie.py",
    "tokenizer_core.py",
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _source_paths():
    """Return the paths of the modules mypyc should compile.

    In a git checkout the files are taken straight from ``../sqlglot`` (no copies);
    otherwise, as in an unpacked sdist, the copies bundled in ``./sqlglot`` are used.
    """
    base_dir = sqlglot_src if os.path.isdir(sqlglot_src) else os.path.join(here, "sqlglot")
    return [os.path.join(base_dir, name) for name in SOURCE_FILES]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class build_ext(_build_ext):
    """``build_ext`` that decides where in-place (editable) copies of extensions go."""

    def copy_extensions_to_source(self):
        """For editable installs, put sqlglot.* .so files in the sqlglot source dir.

        ``sqlglot.<module>`` extensions are copied next to their sources in
        ``../sqlglot`` when that directory exists. Every other extension (e.g. the
        mypyc runtime helper such as ``HASH__mypyc``) goes into the directory that
        ``build_py`` reports for its package, or the current directory when that
        is empty.
        """
        build_py = self.get_finalized_command("build_py")

        for ext in self.extensions:
            fullname = self.get_ext_fullname(ext.name)
            filename = self.get_ext_filename(fullname)
            # "sqlglot.helper" -> package "sqlglot", module "helper".
            package, _, module = fullname.rpartition(".")

            if package == "sqlglot" and os.path.isdir(sqlglot_src):
                # Compiled sqlglot.* module: place it directly in the source package.
                target = os.path.join(sqlglot_src, self.get_ext_filename(module))
            else:
                package_dir = build_py.get_package_dir(package)
                basename = os.path.basename(filename)
                target = os.path.join(package_dir, basename) if package_dir else basename

            self.copy_file(os.path.join(self.build_lib, filename), target, level=self.verbose)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class sdist(_sdist):
    """Bundle sqlglot source files into the sdist so sqlglotc can compile on install."""

    def run(self):
        """Stage the sqlglot sources in ``./sqlglot``, build the sdist, then clean up.

        Staging only happens when ``../sqlglot`` exists (building from the git repo).
        Without it we are presumably inside an unpacked sdist whose ``./sqlglot``
        already holds the bundled copies (the same fallback ``_source_paths`` uses).
        Those copies are used as-is and never deleted.
        """
        local_sqlglot = os.path.join(here, "sqlglot")

        if not os.path.isdir(sqlglot_src):
            # Nothing to copy from; the bundled ./sqlglot must survive the build.
            super().run()
            return

        try:
            # Staging sits inside the try so that a failed copy is still cleaned up
            # instead of leaving a half-populated ./sqlglot behind.
            os.makedirs(local_sqlglot, exist_ok=True)
            with open(os.path.join(local_sqlglot, "__init__.py"), "w"):
                pass
            for fname in SOURCE_FILES:
                shutil.copy2(os.path.join(sqlglot_src, fname), os.path.join(local_sqlglot, fname))
            super().run()
        finally:
            shutil.rmtree(local_sqlglot, ignore_errors=True)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Compile the selected sqlglot modules with mypyc at optimisation level 3. No
# pure-Python packages are declared; the custom commands handle in-place copies
# of the extensions and bundling the sources into the sdist.
setup(
    name="sqlglotc",
    cmdclass={"build_ext": build_ext, "sdist": sdist},
    packages=[],
    ext_modules=mypycify(_source_paths(), opt_level="3"),
)
|
|
File without changes
|
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import typing as t
|
|
5
|
+
from collections import deque
|
|
6
|
+
from copy import deepcopy
|
|
7
|
+
|
|
8
|
+
try:
    from mypy_extensions import mypyc_attr
except ImportError:
    # mypy_extensions is not installed: provide a stand-in with the same call
    # signature whose decorator returns its target untouched.

    def mypyc_attr(*attrs: str, **kwattrs: object) -> t.Callable[[t.Any], t.Any]:  # type: ignore[misc]
        """Return a decorator that hands back whatever it decorates."""

        def _passthrough(target: t.Any) -> t.Any:
            return target

        return _passthrough
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from sqlglot.helper import to_bool
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Type variable bound to ExpressionCore, used by methods that return the receiver's type.
EC = t.TypeVar("EC", bound="ExpressionCore")


# Meta keys that record a node's source position (used by update_positions).
POSITION_META_KEYS: t.Tuple[str, ...] = ("line", "col", "start", "end")
# Marker that introduces ``key=value`` metadata inside a comment (used by add_comments).
SQLGLOT_META: str = "sqlglot.meta"
# True when unittest or pytest has already been imported; enables stricter arg checks.
UNITTEST: bool = any(module in sys.modules for module in ("unittest", "pytest"))
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@mypyc_attr(allow_interpreted_subclasses=True)
|
|
28
|
+
class ExpressionCore:
    """Base node of a sqlglot expression tree.

    A node keeps its children in ``args``. Every ``ExpressionCore`` stored there,
    directly or inside a list, is linked back to its owner through ``parent``,
    ``arg_key`` and ``index``. The structural hash is cached in ``_hash``. ``set``
    and ``append`` clear it on the node and on each ancestor that still caches one.
    """

    __slots__ = (
        "args",
        "parent",
        "arg_key",
        "index",
        "comments",
        "_type",
        "_meta",
        "_hash",
    )

    # Per-subclass identifier; it is the seed of the structural hash.
    key: t.ClassVar[str]
    # Argument names the node accepts (only the keys are read in this class).
    arg_types: t.ClassVar[t.Dict[str, bool]] = {}
    # Arguments that error_messages reports when missing or an empty list.
    required_args: t.ClassVar[t.Set[str]] = set()
    # Function-style node that may take more arguments than arg_types lists.
    is_var_len_args: t.ClassVar[bool] = False
    # Function-style node: enables the argument-count check in error_messages.
    is_func: t.ClassVar[bool] = False
    # Hash truthy arg values as-is (no lowercasing, no list expansion).
    _hash_raw_args: t.ClassVar[bool] = False

    def __init__(self, **args: t.Any) -> None:
        """Create a node from keyword ``args`` and adopt every child expression."""
        self.args: t.Dict[str, t.Any] = args
        self.parent: t.Optional[ExpressionCore] = None
        self.arg_key: t.Optional[str] = None
        self.index: t.Optional[int] = None
        self.comments: t.Optional[t.List[str]] = None
        self._type: t.Optional[t.Any] = None
        self._meta: t.Optional[t.Dict[str, t.Any]] = None
        self._hash: t.Optional[int] = None

        for arg_key, value in self.args.items():
            self._set_parent(arg_key, value)

    def _set_parent(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None:
        """Link ``value`` (or each expression inside a list ``value``) to this node.

        A single child receives ``index``; list items receive their list position.
        Values that are not expressions are ignored.
        """
        if isinstance(value, ExpressionCore):
            value.parent = self
            value.arg_key = arg_key
            value.index = index
        elif isinstance(value, list):
            for i, v in enumerate(value):
                if isinstance(v, ExpressionCore):
                    v.parent = self
                    v.arg_key = arg_key
                    v.index = i

    def _invalidate_hash(self) -> None:
        """Clear cached hashes on this node and its ancestors.

        The walk stops at the first node that has no cached hash.
        """
        node: t.Optional[ExpressionCore] = self
        while node is not None and node._hash is not None:
            node._hash = None
            node = node.parent

    def iter_expressions(self: EC, reverse: bool = False) -> t.Iterator[EC]:
        """Yield the direct child expressions in ``args`` order (or reversed).

        List-valued args are flattened; non-expression values are skipped.
        """
        for vs in reversed(self.args.values()) if reverse else self.args.values():
            if isinstance(vs, list):
                for v in reversed(vs) if reverse else vs:
                    if isinstance(v, ExpressionCore):
                        yield t.cast("EC", v)
            elif isinstance(vs, ExpressionCore):
                yield t.cast("EC", vs)

    def bfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
        """Breadth-first walk starting at this node.

        A node is yielded before ``prune`` is consulted; when ``prune`` returns True
        its children are not visited.
        """
        queue: t.Deque[EC] = deque()
        queue.append(self)
        while queue:
            node = queue.popleft()
            yield node
            if prune and prune(node):
                continue
            for v in node.iter_expressions():
                queue.append(v)

    def dfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
        """Depth-first, pre-order walk starting at this node.

        A node is yielded before ``prune`` is consulted; when ``prune`` returns True
        its children are not visited.
        """
        stack: t.List[EC] = [self]
        while stack:
            node = stack.pop()
            yield node
            if prune and prune(node):
                continue
            # Push in reverse so the first child is popped (visited) first.
            for v in node.iter_expressions(reverse=True):
                stack.append(v)

    @property
    def meta(self) -> t.Dict[str, t.Any]:
        """Metadata dict, created lazily on first access."""
        if self._meta is None:
            self._meta = {}
        return self._meta

    @property
    def this(self) -> t.Any:
        """The ``this`` arg, or None."""
        return self.args.get("this")

    @property
    def expression(self) -> t.Any:
        """The ``expression`` arg, or None."""
        return self.args.get("expression")

    @property
    def expressions(self) -> t.List[t.Any]:
        """The ``expressions`` arg, or an empty list when it is unset or falsy."""
        return self.args.get("expressions") or []

    def pop_comments(self) -> t.List[str]:
        """Detach and return this node's comments (an empty list if there are none)."""
        comments = self.comments or []
        self.comments = None
        return comments

    def append(self, arg_key: str, value: t.Any) -> None:
        """Append ``value`` to the list arg ``arg_key``, creating the list if needed.

        Any non-list value previously stored under ``arg_key`` is discarded. Cached
        hashes are invalidated, as in ``set``, so equality reflects the new child.
        """
        self._invalidate_hash()
        if type(self.args.get(arg_key)) is not list:
            self.args[arg_key] = []
        self._set_parent(arg_key, value)
        values = self.args[arg_key]
        if hasattr(value, "parent"):
            value.index = len(values)
        values.append(value)

    @property
    def depth(self) -> int:
        """Number of ancestors above this node (0 for a root)."""
        if self.parent:
            return self.parent.depth + 1
        return 0

    def find_ancestor(self, *expression_types: t.Any) -> t.Optional[t.Any]:
        """Return the nearest ancestor that is an instance of ``expression_types``, or None."""
        ancestor = self.parent
        while ancestor and not isinstance(ancestor, expression_types):
            ancestor = ancestor.parent
        return ancestor

    @property
    def same_parent(self) -> bool:
        """True when the parent has exactly the same class as this node."""
        return type(self.parent) is self.__class__

    def root(self) -> ExpressionCore:
        """Return the top-most ancestor (this node itself when it has no parent)."""
        expression = self
        while expression.parent:
            expression = expression.parent
        return expression

    def __eq__(self, other: object) -> bool:
        """Identity, or same concrete type with the same structural hash."""
        return self is other or (type(self) is type(other) and hash(self) == hash(other))

    def __hash__(self) -> int:
        """Return the structural hash, computing it for uncached nodes first.

        Uncached nodes are gathered breadth-first; subtrees whose root already has
        a cached hash are not entered. They are then hashed in reverse order, so a
        child's hash is ready before its parent folds it in. Unless
        ``_hash_raw_args`` is set, str values are lowercased, list items are hashed
        one by one, and None/False values add nothing.
        """
        if self._hash is None:
            nodes: t.List[ExpressionCore] = []
            queue: t.Deque[ExpressionCore] = deque()
            queue.append(self)

            while queue:
                node = queue.popleft()
                nodes.append(node)

                for child in node.iter_expressions():
                    if child._hash is None:
                        queue.append(child)

            for node in reversed(nodes):
                hash_ = hash(node.key)

                if node._hash_raw_args:
                    for k, v in sorted(node.args.items()):
                        if v:
                            hash_ = hash((hash_, k, v))
                else:
                    for k, v in sorted(node.args.items()):
                        vt = type(v)

                        if vt is list:
                            for x in v:
                                if x is not None and x is not False:
                                    hash_ = hash((hash_, k, x.lower() if type(x) is str else x))
                                else:
                                    hash_ = hash((hash_, k))
                        elif v is not None and v is not False:
                            hash_ = hash((hash_, k, v.lower() if vt is str else v))

                node._hash = hash_

        result = self._hash
        # Check for None, not truthiness: 0 is a legitimate hash value.
        assert result is not None
        return result

    def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
        """Return validation errors for this node.

        Reports required args that are None or an empty list. For function-style
        nodes without variable-length args, it also reports when ``args`` (the
        provided arguments) outnumber ``arg_types``. When UNITTEST is true, an arg
        key missing from ``arg_types`` raises TypeError instead.
        """
        errors: t.List[str] = []

        if UNITTEST:
            for k in self.args:
                if k not in self.arg_types:
                    raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}")

        for k in self.required_args:
            v = self.args.get(k)
            if v is None or (isinstance(v, list) and not v):
                errors.append(f"Required keyword: '{k}' missing for {self.__class__}")

        if args and self.is_func and len(args) > len(self.arg_types) and not self.is_var_len_args:
            errors.append(
                f"The number of provided arguments ({len(args)}) is greater than "
                f"the maximum number of supported arguments ({len(self.arg_types)})"
            )

        return errors

    def update_positions(
        self: EC,
        other: t.Optional[t.Any] = None,
        line: t.Optional[int] = None,
        col: t.Optional[int] = None,
        start: t.Optional[int] = None,
        end: t.Optional[int] = None,
    ) -> EC:
        """Record source-position metadata and return ``self``.

        With no ``other``, the explicit ``line``/``col``/``start``/``end`` are stored.
        With an expression, its position keys that are present are copied. Any
        other object is read as a token exposing ``.line``, ``.col``, ``.start``
        and ``.end``.
        """
        if other is None:
            self.meta["line"] = line
            self.meta["col"] = col
            self.meta["start"] = start
            self.meta["end"] = end
        elif isinstance(other, ExpressionCore):
            for k in POSITION_META_KEYS:
                if k in other.meta:
                    self.meta[k] = other.meta[k]
        else:
            # Token: has .line, .col, .start, .end attributes
            self.meta["line"] = other.line
            self.meta["col"] = other.col
            self.meta["start"] = other.start
            self.meta["end"] = other.end
        return self

    def to_py(self) -> t.Any:
        """Convert to a Python value; the base implementation always raises ValueError."""
        raise ValueError(f"{self} cannot be converted to a Python object.")

    def text(self, key: str) -> str:
        """Return arg ``key`` when it is a str, otherwise ``""``."""
        field = self.args.get(key)
        if isinstance(field, str):
            return field
        return ""

    @property
    def name(self) -> str:
        """The ``this`` arg when it is a string, else ``""``."""
        return self.text("this")

    @property
    def alias(self) -> str:
        """The alias name: the ``name`` of an expression alias, or a string alias."""
        alias = self.args.get("alias")
        if isinstance(alias, ExpressionCore):
            return alias.name
        return self.text("alias")

    @property
    def alias_column_names(self) -> t.List[str]:
        """Names of the ``columns`` declared on an expression-valued alias.

        Returns an empty list when there is no alias or the alias is a plain
        value (such as a string) with no ``args`` to read columns from.
        """
        table_alias = self.args.get("alias")
        if not isinstance(table_alias, ExpressionCore):
            return []
        return [c.name for c in table_alias.args.get("columns") or []]

    @property
    def alias_or_name(self) -> str:
        """The alias when non-empty, otherwise the name."""
        return self.alias or self.name

    @property
    def output_name(self) -> str:
        """Output column name; the base implementation returns ``""``."""
        return ""

    def is_leaf(self) -> bool:
        """True when no arg holds a child expression or a non-empty list."""
        return not any(
            (isinstance(v, ExpressionCore) or type(v) is list) and v for v in self.args.values()
        )

    def __deepcopy__(self, memo: t.Any) -> ExpressionCore:
        """Iteratively deep-copy the tree rooted at this node (no recursion limit).

        Comments, type and meta are deep-copied. Non-expression arg values are
        shared with the original. A cached hash is carried over to each copy.
        """
        root = self.__class__()
        stack: t.List[t.Tuple[ExpressionCore, ExpressionCore]] = [(self, root)]

        while stack:
            node, copy = stack.pop()

            if node.comments is not None:
                copy.comments = deepcopy(node.comments)
            if node._type is not None:
                copy._type = deepcopy(node._type)
            if node._meta is not None:
                copy._meta = deepcopy(node._meta)

            for k, vs in node.args.items():
                if hasattr(vs, "parent"):
                    stack.append((vs, vs.__class__()))
                    copy.set(k, stack[-1][-1])
                elif type(vs) is list:
                    copy.args[k] = []

                    for v in vs:
                        if hasattr(v, "parent"):
                            stack.append((v, v.__class__()))
                            copy.append(k, stack[-1][-1])
                        else:
                            copy.append(k, v)
                else:
                    copy.args[k] = vs

            # Copy the cached hash only after the children are attached. The
            # set/append calls above (and the ones run later when these children
            # are processed) invalidate cached hashes up the parent chain and would
            # otherwise wipe the hash we just carried over.
            if node._hash is not None:
                copy._hash = node._hash

        return root

    def copy(self: EC) -> EC:
        """Return a deep copy of this node."""
        return deepcopy(self)

    def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None:
        """Attach ``comments`` to this node, appending by default or prepending.

        A comment containing SQLGLOT_META has the text after the marker parsed as
        comma-separated ``key=value`` pairs (a bare key means True). The values are
        passed through ``to_bool`` and stored in ``meta``. ``self.comments`` is
        always initialised to a list, even when ``comments`` is empty or None.
        """
        if self.comments is None:
            self.comments = []

        if comments:
            for comment in comments:
                _, *meta = comment.split(SQLGLOT_META)
                if meta:
                    for kv in "".join(meta).split(","):
                        k, *v = kv.split("=")
                        value: t.Any = v[0].strip() if v else True
                        self.meta[k.strip()] = to_bool(value)

                if not prepend:
                    self.comments.append(comment)

            # Inside ``if comments`` so that comments=None with prepend=True is a no-op
            # rather than ``None + list``.
            if prepend:
                self.comments = comments + self.comments

    def set(
        self,
        arg_key: str,
        value: t.Any,
        index: t.Optional[int] = None,
        overwrite: bool = True,
    ) -> None:
        """Set arg ``arg_key`` to ``value`` and adopt it as a child.

        Args:
            arg_key: name of the arg to modify.
            value: the new value. ``None`` removes the arg, or the list item when
                ``index`` is given.
            index: position inside a list-valued arg; negative values count from
                the end. Out-of-range positions and ``None`` items are left alone.
            overwrite: with ``index``, replace the item when True, insert before it
                when False. A list ``value`` with ``index`` replaces that item with
                the list's elements.
        """
        self._invalidate_hash()

        if index is not None:
            expressions = self.args.get(arg_key) or []

            try:
                if expressions[index] is None:
                    return
            except IndexError:
                return

            # Normalise negative positions so the pop/slice/insert below and the
            # sibling re-indexing all refer to the same absolute slot.
            if index < 0:
                index += len(expressions)

            if value is None:
                expressions.pop(index)
                # Later siblings shift one slot to the left.
                for v in expressions[index:]:
                    v.index = v.index - 1
                return

            if isinstance(value, list):
                expressions.pop(index)
                expressions[index:index] = value
            elif overwrite:
                expressions[index] = value
            else:
                expressions.insert(index, value)

            # The whole list is re-linked below, which also renumbers every index.
            value = expressions
        elif value is None:
            self.args.pop(arg_key, None)
            return

        self.args[arg_key] = value
        self._set_parent(arg_key, value, index)

    def find(self, *expression_types: t.Any, bfs: bool = True) -> t.Optional[t.Any]:
        """Return the first node (this one included) of ``expression_types``, or None."""
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types: t.Any, bfs: bool = True) -> t.Iterator[t.Any]:
        """Yield every node (this one included) that is an instance of ``expression_types``."""
        for expression in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def walk(
        self: EC,
        bfs: bool = True,
        prune: t.Optional[t.Callable[[EC], bool]] = None,
    ) -> t.Iterator[EC]:
        """Walk the tree breadth-first (default) or depth-first; see ``bfs``/``dfs``."""
        if bfs:
            yield from self.bfs(prune=prune)
        else:
            yield from self.dfs(prune=prune)

    def replace(self, expression: t.Any) -> t.Any:
        """Put ``expression`` into this node's slot in its parent and detach this node.

        Returns ``expression``. Without a parent (or arg key) nothing is modified.
        When a list replaces a single-expression slot, the list replaces the parent
        node instead.
        """
        parent = self.parent

        if not parent or parent is expression:
            return expression

        key = self.arg_key
        if not key:
            return expression

        value = parent.args.get(key)

        if type(expression) is list and isinstance(value, ExpressionCore):
            if value.parent:
                value.parent.replace(expression)
        else:
            parent.set(key, expression, self.index)

        if expression is not self:
            self.parent = None
            self.arg_key = None
            self.index = None

        return expression

    def pop(self: EC) -> EC:
        """Remove this node from its parent and return it."""
        self.replace(None)
        return self

    def assert_is(self, type_: t.Any) -> t.Any:
        """Return ``self`` if it is an instance of ``type_``, else raise AssertionError."""
        if not isinstance(self, type_):
            raise AssertionError(f"{self} is not {type_}.")
        return self

    def transform(self, fun: t.Callable, *args: t.Any, copy: bool = True, **kwargs: t.Any) -> t.Any:
        """Apply ``fun`` to every node depth-first (pre-order) and splice in replacements.

        ``fun(node, *args, **kwargs)`` returns the node to use. A returned node that
        differs from its input is set into the parent's slot, and its subtree is
        not visited. Works on a copy unless ``copy`` is False. Returns the
        (possibly new) root.
        """
        root: t.Optional[t.Any] = None
        new_node: t.Optional[t.Any] = None

        for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node):
            parent, arg_key, index = node.parent, node.arg_key, node.index
            new_node = fun(node, *args, **kwargs)

            if not root:
                root = new_node
            elif parent and arg_key and new_node is not node:
                parent.set(arg_key, new_node, index)

        assert root
        return root
|