vnnlib 0.0.1a0__py3-none-any.whl → 0.0.1a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vnnlib might be problematic. Click here for more details.
- vnnlib/__init__.py +3 -0
- vnnlib/__main__.py +38 -4
- vnnlib/__version__.py +1 -1
- vnnlib/compat.py +238 -0
- vnnlib/errors.py +11 -0
- vnnlib/parser.py +368 -0
- {vnnlib-0.0.1a0.dist-info → vnnlib-0.0.1a2.dist-info}/METADATA +36 -4
- vnnlib-0.0.1a2.dist-info/RECORD +11 -0
- vnnlib-0.0.1a0.dist-info/RECORD +0 -8
- {vnnlib-0.0.1a0.dist-info → vnnlib-0.0.1a2.dist-info}/LICENSE.txt +0 -0
- {vnnlib-0.0.1a0.dist-info → vnnlib-0.0.1a2.dist-info}/WHEEL +0 -0
- {vnnlib-0.0.1a0.dist-info → vnnlib-0.0.1a2.dist-info}/top_level.txt +0 -0
vnnlib/__init__.py
CHANGED
vnnlib/__main__.py
CHANGED
|
@@ -1,18 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import argparse
|
|
4
|
+
import pickle
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from .compat import CompatTransformer
|
|
8
|
+
from .parser import parse_file
|
|
2
9
|
|
|
3
10
|
|
|
4
|
-
def parse_args() -> argparse.Namespace:
    """Build the ``vnnlib`` command-line interface and parse ``sys.argv``.

    Returns:
        A namespace with ``file`` (Path), ``compat`` (bool), ``strict``
        (bool, default True) and ``output`` (str or None).
    """
    parser = argparse.ArgumentParser(
        "vnnlib",
        description=(
            "Parse a VNN-LIB property specification and optionally save the "
            "compiled result."
        ),
    )
    parser.add_argument("file", type=Path)
    parser.add_argument(
        "--compat", action="store_true", help="Use the VNN-COMP-1 output format"
    )
    parser.add_argument(
        "--strict",
        action=argparse.BooleanOptionalAction,
        default=True,
        # BooleanOptionalAction appends " (default: ...)" to the help text on
        # its own (Python 3.9-3.13), so the default is not repeated here.
        help="Whether or not to strictly follow VNN-LIB",
    )
    parser.add_argument(
        "-o", "--output", type=str, help="The path to save the compiled output"
    )
    return parser.parse_args()
|
|
11
30
|
|
|
12
31
|
|
|
13
|
-
def __main__() -> None:
    """Entry point for ``python -m vnnlib``.

    Parses the file named on the command line and, with ``--compat``, compiles
    it to the VNN-COMP-1 format. The compiled result is pickled to the path
    given by ``--output``; without ``--output`` it is discarded.

    Raises:
        NotImplementedError: if ``--compat`` was not requested.
        RuntimeError: if ``.vnnlib`` is not among the file's suffixes.
    """
    options = parse_args()
    spec_path: Path = options.file
    print(f"parsing file: {options.file}")

    if not options.compat:
        raise NotImplementedError(
            "Currently only the VNN-COMP-1 output format is supported"
        )
    # ``suffixes`` also covers compressed names such as ``prop.vnnlib.gz``.
    if ".vnnlib" not in spec_path.suffixes:
        raise RuntimeError(f"Unsupported file type: {spec_path.suffix}")

    tree = parse_file(spec_path, strict=options.strict)
    compiled = CompatTransformer("X", "Y").transform(tree)
    if not options.output:
        return
    with open(options.output, "wb+") as handle:
        pickle.dump(compiled, handle)
|
|
16
50
|
|
|
17
51
|
|
|
18
52
|
if __name__ == "__main__":
|
vnnlib/__version__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.0.
|
|
1
|
+
# Package version string (PEP 440 pre-release "0.0.1a2"; matches METADATA).
__version__ = "0.0.1a2"
|
vnnlib/compat.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import operator
|
|
4
|
+
import pathlib
|
|
5
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from .parser import AstNodeTransformer, Real, parse_file
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CompatTransformer(AstNodeTransformer):
    """Compile a parsed VNN-LIB script into the VNN-COMP-1 (nnenum) format.

    ``transform(script)`` returns a list of ``[box, polytopes]`` entries:

    * ``box`` is a list of ``[lower, upper]`` bounds, one per input variable.
    * ``polytopes`` is a list of ``(mat, rhs)`` numpy-array pairs, each meaning
      ``mat @ y <= rhs`` over the outputs ``y``. Each pair is one disjunct of
      the output specification, and disjuncts sharing a box are grouped.

    Linear terms are represented as dicts mapping ``(row, var_type, index)``
    keys to coefficients. ``var_type`` is ``0`` for an input variable, ``1``
    for an output variable and ``-1`` (with ``index == -1``) for the constant
    term. A row encodes the constraint ``sum(coef * var) + const <= 0``.
    """

    def __init__(
        self,
        input_name: str,
        output_name: str,
        input_size: Optional[int] = None,
        output_size: Optional[int] = None,
    ) -> None:
        self.input_name = input_name
        self.output_name = output_name

        self.input_size = input_size or 0
        self.output_size = output_size or 0

        # Sizes not given explicitly are inferred from the declared
        # ``<name>_<index>`` constants (see transform_DeclareConst).
        self.infer_input_size = input_size is None
        self.infer_output_size = output_size is None

        # Ids (starting at 2) for identifiers that are neither inputs nor outputs.
        self._id_map: Dict[str, int] = {}
        # Constraint rows shared by every disjunct.
        self._assertions: Dict[Tuple[int, ...], Real] = {}
        self._num_assertions = 0
        # Alternative constraint sets (disjunctive normal form).
        self._disjunctions: List[Dict[Tuple[int, ...], Real]] = [{}]

    def transform_Assert(
        self,
        term: Union[List[Dict[Tuple[int, ...], Real]], Dict[Tuple[int, ...], Real]],
    ) -> List[Dict[Tuple[int, ...], Real]]:
        """Record an asserted term and return it as a list of conjunctions.

        A single comparison or conjunction is added to the constraints common
        to every disjunct. A disjunction with several alternatives is
        distributed over the disjuncts recorded so far.

        Raises:
            RuntimeError: if ``term`` is neither a dict nor a list.
        """
        if isinstance(term, dict):
            term = [term]
        if not isinstance(term, list):
            raise RuntimeError("unexpected term for assert")
        if len(term) == 1:
            self._add_common_constraints(term[0])
        else:
            self._add_disjunction(term)
        return term

    def _add_common_constraints(self, conjunct: Dict[Tuple[int, ...], Real]) -> None:
        """Append the rows of ``conjunct`` after the existing common rows."""
        row_offset = self._num_assertions
        max_row = -1
        for (row, *index), value in conjunct.items():
            self._assertions[(row + row_offset, *index)] = value
            max_row = max(max_row, row)
        self._num_assertions += max_row + 1

    def _add_disjunction(self, disjuncts: List[Dict[Tuple[int, ...], Real]]) -> None:
        """Replace the recorded disjuncts by their product with ``disjuncts``.

        Every new alternative is combined with every previously recorded
        disjunct. Rows of the previous disjunct are shifted past the rows of
        the new alternative so that they do not collide.
        """
        combined: List[Dict[Tuple[int, ...], Real]] = []
        for disjunct in disjuncts:
            row_offset = max(disjunct, default=(-1,))[0] + 1
            for previous in self._disjunctions:
                merged = disjunct.copy()
                for (row, *index), value in previous.items():
                    merged[(row + row_offset, *index)] = value
                combined.append(merged)
        self._disjunctions = combined

    def transform_Constant(self, value) -> Dict[Tuple[int, ...], Real]:
        """Represent a numeric literal as a constant term in row 0.

        Raises:
            NotImplementedError: for non-numeric literals (e.g. strings).
        """
        if not isinstance(value, (Real, int)):
            raise NotImplementedError(
                f"Constant {value!r} is not supported by the legacy parser"
            )
        return {(0, -1, -1): value}

    def transform_DeclareConst(self, symbol: str, sort: str) -> None:
        """Grow the inferred input/output size to cover ``<name>_<index>``."""
        if self.infer_input_size and symbol.startswith(f"{self.input_name}_"):
            _, index = symbol.split("_")
            self.input_size = max(self.input_size, int(index) + 1)
        elif self.infer_output_size and symbol.startswith(f"{self.output_name}_"):
            _, index = symbol.split("_")
            self.output_size = max(self.output_size, int(index) + 1)

    def transform_FunctionApplication(
        self,
        symbol: str,
        *terms: Union[List[Dict[Tuple[int, ...], Real]], Dict[Tuple[int, ...], Real]],
    ) -> Union[List[Dict[Tuple[int, ...], Real]], Dict[Tuple[int, ...], Real]]:
        """Combine already-transformed arguments of ``<=``, ``>=``, ``and``, ``or``.

        Comparisons return a single-row dict. ``and`` returns a one-element
        list holding one multi-row conjunction. ``or`` returns a list of
        conjunctions (its alternatives).

        Raises:
            NotImplementedError: for any other function symbol.
        """
        if symbol in ("<=", ">="):
            lhs, rhs = terms
            assert isinstance(lhs, dict)
            assert isinstance(rhs, dict)
            # Normalise to ``expr <= 0``: lhs - rhs for <=, rhs - lhs for >=.
            if symbol == ">=":
                lhs, rhs = rhs, lhs
            result = lhs.copy()
            for key, value in rhs.items():
                result[key] = result.get(key, 0) - value
            return result
        if symbol == "and":
            conjuncts = {}
            for i, term in enumerate(terms):
                assert isinstance(term, dict)
                for (row, *index), value in term.items():
                    assert row == 0
                    assert len(index) == 2, "please open a bug report"
                    # Each conjunct becomes its own row.
                    conjuncts[(i, *index)] = value
            return [conjuncts]
        if symbol == "or":
            or_result = []
            for term in terms:
                if isinstance(term, dict):
                    # A bare comparison is a single-row conjunction.
                    or_result.append(term)
                    continue
                assert isinstance(term, list), "please open a bug report"
                for subterm in term:
                    assert isinstance(subterm, dict), "please open a bug report"
                or_result.extend(term)
            return or_result
        raise NotImplementedError(
            f"Function {symbol!r} is not supported by the legacy parser"
        )

    def transform_Identifier(
        self, value: str
    ) -> Union[str, Dict[Tuple[int, ...], Real]]:
        """Map an input/output variable to a unit-coefficient term.

        Operator names are passed through unchanged. Any other identifier is
        assigned an id starting at 2.
        """
        if value.startswith(f"{self.input_name}_"):
            _, *str_index = value.split("_")
            return {(0, 0, *tuple(map(int, str_index))): 1}
        elif value.startswith(f"{self.output_name}_"):
            _, *str_index = value.split("_")
            return {(0, 1, *tuple(map(int, str_index))): 1}
        elif value in {"<=", ">=", "and", "or"}:
            return value
        if value not in self._id_map:
            self._id_map[value] = len(self._id_map) + 2
        return {(0, self._id_map[value]): 1}

    def _apply_constraints(
        self,
        constraints: Dict[Tuple[int, ...], Real],
        box: List[List[Real]],
        polytope: List[List[List[Real]]],
    ) -> None:
        """Fold row-encoded constraints into ``box`` and ``polytope`` in place.

        Each row means ``sum(coef * var) + const <= 0``.

        * A row over one input variable with coefficient +1 or -1 tightens that
          input's upper or lower bound.
        * A row over output variables is appended to ``polytope`` as
          ``[coefs, [-const]]``, i.e. ``coefs @ y <= -const``.

        Raises:
            RuntimeError: for rows the VNN-COMP-1 format cannot express.
        """
        rows: Dict[int, Dict[Tuple[int, int], Real]] = {}
        for (row, var_type, index), value in constraints.items():
            rows.setdefault(row, {})[(var_type, index)] = value
        for row in sorted(rows):
            coeffs = rows[row]
            const = coeffs.pop((-1, -1), 0)
            if not coeffs:
                # Constant-only row: there is no variable to constrain.
                continue
            var_types = {var_type for var_type, _ in coeffs}
            if var_types == {0}:
                if len(coeffs) != 1:
                    raise RuntimeError(
                        "box constraints must involve exactly one input variable"
                    )
                (_, index), value = next(iter(coeffs.items()))
                if value == 1:
                    box[index][1] = min(-const, box[index][1])
                elif value == -1:
                    box[index][0] = max(const, box[index][0])
                else:
                    raise RuntimeError(f"unexpected lhs coeff for box: {value}")
            elif var_types == {1}:
                lhs = [0 for _ in range(self.output_size)]
                for (_, index), value in coeffs.items():
                    lhs[index] = value
                polytope.append([lhs, [-const]])
            elif var_types <= {0, 1}:
                raise RuntimeError(
                    "a constraint cannot mix input and output variables"
                )
            else:
                raise RuntimeError(
                    f"unexpected variable type {min(var_types - {0, 1})}"
                )

    def transform_Script(self, *commands):
        """Build the ``[box, [(mat, rhs), ...]]`` entries for the whole script."""
        common_box = [[float("-inf"), float("inf")] for _ in range(self.input_size)]
        common_polytope: List[List[List[Real]]] = []
        self._apply_constraints(self._assertions, common_box, common_polytope)

        results = {}
        for disjunct in self._disjunctions:
            box = [interval.copy() for interval in common_box]
            polytope = [[lhs.copy(), rhs.copy()] for lhs, rhs in common_polytope]
            self._apply_constraints(disjunct, box, polytope)
            # Disjuncts with identical input boxes share one entry.
            box_str = str(box)
            polytope_arr = (
                np.array([lhs for lhs, _ in polytope]),
                np.array([rhs for _, rhs in polytope]),
            )
            if box_str not in results:
                results[box_str] = [box, [polytope_arr]]
            else:
                results[box_str][1].append(polytope_arr)
        return list(results.values())
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def read_vnnlib_simple(
    vnnlib_filename: Union[str, pathlib.Path], num_inputs: int, num_outputs: int
) -> List[Tuple[List[List[Real]], List[Tuple[np.ndarray, np.ndarray]]]]:
    """Read a VNN-LIB file into the VNN-COMP ``read_vnnlib_simple`` format.

    The file is parsed in non-strict mode. Compressed files (``.gz``,
    ``.bz2``, ``.xz``, ...) are accepted, as handled by ``parse_file``.

    Args:
        vnnlib_filename: path of the specification.
        num_inputs: number of network inputs (length of each box).
        num_outputs: number of network outputs (columns of each ``mat``).

    Returns:
        A list of 2-tuples ``(box, spec)``, one per distinct input box:

        1. ``box``: input ranges, a list of ``[lower, upper]`` pairs, one per
           input variable.
        2. ``spec``: a list of ``(mat, rhs)`` numpy-array pairs meaning
           ``mat * y <= rhs``, where ``y`` is the output. Each element of the
           list is a term in a disjunction for the specification.
    """
    ast_node = parse_file(vnnlib_filename, strict=False)
    transformer = CompatTransformer("X", "Y", num_inputs, num_outputs)
    result = transformer.transform(ast_node)
    # The transformer yields [box, spec] lists; return the documented tuples.
    return [(box, spec) for box, spec in result]
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# Names exported by ``from vnnlib.compat import *``. CompatTransformer is not
# listed but can still be imported explicitly (as vnnlib.__main__ does).
__all__ = [
    "read_vnnlib_simple",
]
|
vnnlib/errors.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
class ParserError(Exception):
    """Error raised when VNN-LIB text cannot be tokenized or parsed.

    Optional position information is prepended to the message, as
    ``"line L, column C: msg"``, ``"line L: msg"`` or ``"column C: msg"``.
    It is also kept on the ``lineno`` and ``col_offset`` attributes, which
    are None when not supplied.
    """

    def __init__(self, msg: str, *args: object, lineno=None, col_offset=None) -> None:
        self.lineno = lineno
        self.col_offset = col_offset
        location = []
        if lineno is not None:
            location.append(f"line {lineno}")
        if col_offset is not None:
            # Report the column even when the line is unknown.
            location.append(f"column {col_offset}")
        if location:
            msg = f"{', '.join(location)}: {msg}"
        super().__init__(msg, *args)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Public names of this module.
__all__ = ["ParserError"]
|
vnnlib/parser.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import bz2
|
|
4
|
+
import gzip
|
|
5
|
+
import lzma
|
|
6
|
+
import re
|
|
7
|
+
import warnings
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import (
|
|
10
|
+
Callable,
|
|
11
|
+
Dict,
|
|
12
|
+
Iterator,
|
|
13
|
+
List,
|
|
14
|
+
NamedTuple,
|
|
15
|
+
Optional,
|
|
16
|
+
Set,
|
|
17
|
+
TextIO,
|
|
18
|
+
Union,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from .errors import ParserError
|
|
22
|
+
|
|
23
|
+
# Python type used for real-valued (DECIMAL) literals; an alias of float.
Real = float
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Meta(NamedTuple):
    """Source span of a token as character offsets into the tokenized text.

    The offsets come from ``re.Match.span()``, so ``end_pos`` is exclusive.
    The ``EOF`` sentinel uses ``(-1, -1)``.
    """

    start_pos: int  # offset of the token's first character
    end_pos: int  # offset just past the token's last character
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Token(NamedTuple):
    """A lexical token: its category, the matched text, and its source span."""

    token_type: str  # regex group name, e.g. "LPAREN", "SYMBOL", "DECIMAL"
    value: str  # the exact text that was matched
    meta: Meta  # where in the source text the token was found
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def as_dict(t: NamedTuple):
    """Return the fields of the named tuple ``t`` as a ``{field: value}`` dict."""
    return dict(zip(t._fields, t))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Placeholder held in VnnLibParser.curr_token before the first token is read.
_DUMMY_TOKEN = Token("_", "", Meta(0, 0))
# Sentinel produced by VnnLibParser.advance_token_stream once the stream is exhausted.
EOF = Token("EOF", "", Meta(-1, -1))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def tokenize(text: str, skip: Set[str], strict=True) -> Iterator[Token]:
    """Split VNN-LIB (SMT-LIB) source text into tokens.

    Args:
        text: the complete source text.
        skip: token types to leave out of the output, e.g. ``{"WS", "COMMENT"}``.
        strict: if True, DECIMAL literals follow SMT-LIB exactly. If False
            they may also carry an exponent, as in ``1.5e-3``.

    Yields:
        ``Token`` objects in source order, excluding the types in ``skip``.

    Raises:
        ParserError: if some part of ``text`` matches no token pattern. The
            error carries the 1-based line number and 0-based column offset
            of that part.
    """
    if len(text) == 0:
        return
    tokens: Dict[str, str] = {
        # SMT-LIB comments run from ';' to the end of the line and may contain
        # any character except a line break.
        "COMMENT": r";[^\r\n]*(?:[\r\n]|$)",
        "WS": r"\x09|\x0a|\x0d|\x20",
        "LPAREN": r"\(",
        "RPAREN": r"\)",
        "BINARY": r"#b[01]+",
        "HEXADECIMAL": r"#x[0-9A-Fa-f]+",
        "_DECIMAL_strict": r"(?:{NUMERAL})\.0*(?:{NUMERAL})",
        "_DECIMAL_extended": r"(?:(?:{NUMERAL})\.0*(?:{NUMERAL})(?:[eE][+-]?0*(?:{NUMERAL}))?)",
        "DECIMAL": f"(?:{{_DECIMAL_{'strict' if strict else 'extended'}}})",
        "NUMERAL": r"(?:(?:[1-9][0-9]*)|0)",
        "STRING": r"\x22(?:(?:{WS})|(?:{_PRINTABLE_CHAR}))*\x22",
        "SYMBOL": r"(?:(?:(?:{_LETTER})|(?:{_CHARACTER}))(?:[0-9]|(?:{_LETTER})|(?:{_CHARACTER}))*)|(?:\x7c(?:[\x20-\x5b]|[\x5d-\x7b]|[\x7d\x7e]|[\x80-\xff])*\x7c)",
        "_PRINTABLE_CHAR": r"[\x20-\x7e]|[\x80-\xff]",
        "_LETTER": r"[A-Za-z]",
        "_CHARACTER": r"[~!@$%^&*+=<>.?/_-]",
    }
    # Expand {NAME} references between patterns. A referenced pattern either
    # contains no references itself or appears earlier in the dict.
    for key, value in tokens.items():
        tokens[key] = value.format(**tokens)
    token_pattern = re.compile(
        "|".join(
            f"(?P<{token_type}>{pattern})"
            for token_type, pattern in tokens.items()
            # Names starting with "_" are fragments, not tokens.
            if not token_type.startswith("_")
        )
    )

    def unknown_string(start: int, end: int) -> ParserError:
        """Build an error for the unmatched span ``text[start:end]``."""
        lineno = text.count("\n", 0, start) + 1
        col_offset = start - (text.rfind("\n", 0, start) + 1)
        return ParserError(
            f"Unknown string: {text[start:end]!r}",
            lineno=lineno,
            col_offset=col_offset,
        )

    pos: int = 0
    for match in token_pattern.finditer(text):
        start_pos, end_pos = match.span()
        if start_pos != pos:
            # finditer skipped characters that no token pattern accepts.
            raise unknown_string(pos, start_pos)
        assert match.lastgroup is not None
        if match.lastgroup not in skip:
            yield Token(match.lastgroup, match.group(), Meta(start_pos, end_pos))
        pos = end_pos
    if pos != len(text):
        raise unknown_string(pos, len(text))
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class AstNode:
    """Base class for VNN-LIB syntax-tree nodes."""

    @property
    def _type(self) -> str:
        """Name of the node's concrete class; AstNodeTransformer dispatches on it."""
        return self.__class__.__name__

    def __repr__(self) -> str:
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{self._type}({fields})"


class Script(AstNode):
    """A whole VNN-LIB file: the sequence of its top-level commands."""

    def __init__(self, *commands: Command):
        self.commands = commands


class Command(AstNode):
    """Base class for top-level commands."""


class Declare(Command):
    """Base class for declaration commands."""


class DeclareConst(Declare):
    """``(declare-const <symbol> <sort>)``; both parts are kept as source text."""

    def __init__(self, symbol: str, sort: str):
        self.symbol = symbol
        self.sort = sort


class Assert(Command):
    """``(assert <term>)``."""

    def __init__(self, term: Term):
        self.term = term


class Term(AstNode):
    """Base class for terms (expressions)."""


class FunctionApplication(Term):
    """``(<function> <term> ...)``: a function identifier applied to arguments."""

    def __init__(self, function: Identifier, *terms: Term):
        self.function = function
        self.terms = terms


class Constant(Term):
    """A literal value, already converted from its token text."""

    def __init__(self, value: float | int | str | Real):
        self.value = value


class Sort(AstNode):
    """A sort (type), given by its name or signature string."""

    def __init__(self, value: str):
        self.value = value


class Identifier(Term):
    """A named symbol together with its sort."""

    def __init__(self, value: str, sort: Sort):
        self.value = value
        # Keep the sort so that callers can inspect it.
        self.sort = sort
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _hex_to_int(x: str) -> int:
|
|
143
|
+
return int(x[2:], 16)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _bin_to_int(x: str) -> int:
|
|
147
|
+
return int(x[2:], 2)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _identity(x: str) -> str:
|
|
151
|
+
return x
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# Maps each literal token type to the function converting its text to a value.
# Membership in this table is also how parse_term recognises literal tokens.
LITERAL_CONVERTERS: Dict[str, Callable[[str], float | int | str | Real]] = dict(
    DECIMAL=Real,
    NUMERAL=int,
    HEXADECIMAL=_hex_to_int,
    BINARY=_bin_to_int,
    STRING=_identity,
)
|
|
161
|
+
# Identifiers that are predefined for every parser (copied by VnnLibParser).
# Each signature reads "(argument sorts) result sort"; "A" is presumably a sort
# parameter standing for the argument sort -- NOTE(review): confirm.
CORE_IDS: Dict[str, Identifier] = {
    # arithmetic
    "+": Identifier("+", Sort("(A A) A")),
    "-": Identifier("-", Sort("(A A) A")),
    "*": Identifier("*", Sort("(A A) A")),
    "/": Identifier("/", Sort("(A A) A")),
    # comparisons
    ">": Identifier(">", Sort("(A A) Bool")),
    ">=": Identifier(">=", Sort("(A A) Bool")),
    "<": Identifier("<", Sort("(A A) Bool")),
    "<=": Identifier("<=", Sort("(A A) Bool")),
    "=": Identifier("=", Sort("(A A) Bool")),
    # logic
    "true": Identifier("true", Sort("Bool")),
    "false": Identifier("false", Sort("Bool")),
    # "not" is unary: one Bool argument, Bool result.
    "not": Identifier("not", Sort("(Bool) Bool")),
    "=>": Identifier("=>", Sort("(Bool Bool) Bool")),
    "or": Identifier("or", Sort("(Bool Bool) Bool")),
    "and": Identifier("and", Sort("(Bool Bool) Bool")),
    "xor": Identifier("xor", Sort("(Bool Bool) Bool")),
    "ite": Identifier("ite", Sort("(Bool A A) A")),
}
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class VnnLibParser:
    """Recursive-descent parser that builds an AST from a VNN-LIB token stream.

    Only the ``declare-const`` and ``assert`` commands are recognised.
    Symbols must be declared with ``declare-const`` before use in a term,
    unless they are one of ``CORE_IDS``.
    """

    def __init__(self, token_stream: Iterator[Token]):
        self.token_stream = token_stream
        # Lookahead token; a placeholder until advance_token_stream() runs.
        self.curr_token = _DUMMY_TOKEN
        self.sorts = {"Bool": Sort("Bool"), "Int": Sort("Int"), "Real": Sort("Real")}
        # Copied so that declarations do not modify the module-level CORE_IDS.
        self.identifiers: Dict[str, Identifier] = CORE_IDS.copy()

    def advance_token_stream(self) -> Token:
        """Read the next token into ``curr_token`` (``EOF`` when exhausted) and return it."""
        self.curr_token = next_token = next(self.token_stream, EOF)
        return next_token

    def expect_token_type(
        self,
        type: str,
        *,
        expected_value: Optional[str] = None,
        msg: str = "unexpected token: {token_type}({value!r})",
    ) -> bool:
        """Check that the current token has type ``type``.

        Returns:
            True if it does.

        Raises:
            ParserError: otherwise. If ``expected_value`` is given, the
                message names it and the token actually found; otherwise
                ``msg`` is formatted with the current token's fields.
        """
        # ``type`` shadows the builtin but is part of the keyword interface.
        if self.curr_token.token_type == type:
            return True
        if expected_value:
            if self.curr_token.token_type == "EOF":
                found = "end of input"
            else:
                found = repr(self.curr_token.value)
            raise ParserError(f"Expected {expected_value!r}, found {found}")
        raise ParserError(msg.format(**as_dict(self.curr_token)))

    def lookup_identifier(self, identifier: str) -> Identifier:
        """Return the declared identifier named ``identifier``.

        Raises:
            ParserError: if it is undeclared. The message hints at exponent
                notation for names like ``e-5`` or ``E+3``.
        """
        if identifier in self.identifiers:
            return self.identifiers[identifier]
        # Text such as "1.5e-3" in strict mode tokenizes as "1.5" then "e-3".
        looks_like_exponent = (
            len(identifier) >= 2
            and identifier[0] in "eE"
            and not set(identifier[1:]).difference("0123456789+-")
        )
        if looks_like_exponent:
            raise ParserError(
                f"undeclared identifier: {identifier!r}."
                "\n\tIt looks like this may be exponential notation, which is not SMT-LIB compliant."
                "\n\tTry turning off strict mode."
            )
        raise ParserError(f"undeclared identifier: {identifier!r}")

    def lookup_sort(self, name: str) -> Sort:
        """Return the sort called ``name``, raising ParserError if it is unknown."""
        if name not in self.sorts:
            raise ParserError(f"undeclared sort: {name!r}")
        return self.sorts[name]

    @classmethod
    def parse(cls, text: str, strict=True) -> Script:
        """Tokenize and parse a complete VNN-LIB document into a ``Script``."""
        parser = cls(tokenize(text, {"WS", "COMMENT"}, strict=strict))
        parser.advance_token_stream()
        commands = []
        while parser.curr_token.token_type != "EOF":
            commands.append(parser.parse_command())
        return Script(*commands)

    def parse_command(self) -> Command:
        """Parse one ``( <command> ... )`` form and advance past its ``)``."""
        self.expect_token_type(type="LPAREN", expected_value="(")
        command = self.advance_token_stream().value
        command_parsers = {
            "assert": self.parse_assert,
            "declare-const": self.parse_declare_const,
        }
        if command not in command_parsers:
            raise ParserError(f"Unknown command: {command!r}")
        node = command_parsers[command]()
        self.expect_token_type(type="RPAREN", expected_value=")")
        self.advance_token_stream()
        return node

    def parse_declare_const(self) -> Declare:
        """Parse ``declare-const <symbol> <sort>`` and register the symbol."""
        symbol = self.advance_token_stream()
        self.expect_token_type("SYMBOL")
        sort = self.advance_token_stream()
        self.expect_token_type("SYMBOL")
        self.advance_token_stream()
        self.identifiers[symbol.value] = Identifier(
            symbol.value, self.lookup_sort(sort.value)
        )
        return DeclareConst(symbol.value, sort.value)

    def parse_assert(self) -> Assert:
        """Parse ``assert <term>``."""
        self.advance_token_stream()
        return Assert(self.parse_term())

    def parse_term(self) -> Term:
        """Parse one term: a symbol, a literal, or a parenthesised application."""
        curr_token = self.curr_token
        token_type = curr_token.token_type
        if token_type == "SYMBOL":
            self.advance_token_stream()
            if curr_token.value.startswith("-"):
                # e.g. "-0.5" or "-X_0"; SMT-LIB spells these (- 0.5) / (- X_0).
                warnings.warn("literal negation does not strictly follow SMT-LIB")
                try:
                    return Constant(Real(curr_token.value))
                except ValueError:
                    pass  # not a number: treat it as negating an identifier
                return FunctionApplication(
                    self.lookup_identifier("-"),
                    self.lookup_identifier(curr_token.value[1:]),
                )
            return self.lookup_identifier(curr_token.value)
        if token_type == "LPAREN":
            children: List[Term] = []
            function_id_token = self.advance_token_stream()
            self.expect_token_type("SYMBOL")
            self.advance_token_stream()
            children.append(self.parse_term())
            while self.curr_token.token_type != "RPAREN":
                children.append(self.parse_term())
            self.advance_token_stream()
            return FunctionApplication(
                self.lookup_identifier(function_id_token.value), *children
            )
        if token_type in LITERAL_CONVERTERS:
            value = LITERAL_CONVERTERS[token_type](curr_token.value)
            self.advance_token_stream()
            return Constant(value)
        raise ParserError(f"Unexpected token: {curr_token}")
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _identity_args(*args):
|
|
307
|
+
return args
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class _Discard:
|
|
311
|
+
pass
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
Discard = _Discard()
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class AstNodeTransformer:
    """Bottom-up AST transformer.

    ``transform(node)`` first gathers the node's child values with the
    matching ``_visit_<NodeType>`` method, which transforms children
    recursively. It then passes those values positionally to
    ``transform_<NodeType>`` if a subclass defines it; otherwise the values
    are returned as a tuple.
    """

    def transform(self, node: AstNode):
        kind = node._type
        children = getattr(self, f"_visit_{kind}")(node)
        handler = getattr(self, f"transform_{kind}", _identity_args)
        return handler(*children)

    def _visit_Assert(self, node: Assert):
        return (self.transform(node.term),)

    def _visit_Constant(self, node: Constant):
        return (node.value,)

    def _visit_DeclareConst(self, node: DeclareConst):
        return (node.symbol, node.sort)

    def _visit_FunctionApplication(self, node: FunctionApplication):
        # The function identifier is transformed before its arguments.
        head = self.transform(node.function)
        return (head, *(self.transform(argument) for argument in node.terms))

    def _visit_Identifier(self, node: Identifier):
        return (node.value,)

    def _visit_Script(self, node: Script):
        # Results equal to the Discard sentinel are omitted.
        transformed = (self.transform(command) for command in node.commands)
        return [outcome for outcome in transformed if outcome is not Discard]
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def parse_file(filename: Union[str, Path], strict=True) -> AstNode:
    """Read a VNN-LIB file, optionally compressed, and parse it into an AST.

    The decompressor is chosen from the final suffix, ignoring case:

    * ``.gz`` / ``.gzip``: gzip
    * ``.bz2`` / ``.bzip2``: bzip2
    * ``.xz``: lzma

    Any other file is read as plain text. Text is always decoded as UTF-8,
    independent of the platform locale.

    Args:
        filename: path of the specification.
        strict: passed to ``VnnLibParser.parse``.

    Returns:
        The parsed ``Script`` node.

    Raises:
        ParserError: if the text is not valid VNN-LIB.
        OSError: if the file cannot be opened or decompressed.
    """
    path = Path(filename)
    openers: Dict[str, Callable[..., TextIO]] = {
        ".gz": gzip.open,
        ".gzip": gzip.open,
        ".bz2": bz2.open,
        ".bzip2": bz2.open,
        ".xz": lzma.open,
    }
    open_func = openers.get(path.suffix.lower(), open)
    with open_func(path, "rt", encoding="utf-8") as f:
        text = f.read()
    return VnnLibParser.parse(text, strict=strict)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
# Names exported by ``from vnnlib.parser import *``. Other helpers (e.g.
# AstNodeTransformer, Real) are imported explicitly where needed.
__all__ = ["VnnLibParser", "parse_file"]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vnnlib
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.1a2
|
|
4
4
|
Author-email: David Shriver <davidshriver@outlook.com>
|
|
5
5
|
License: MIT
|
|
6
6
|
Project-URL: Documentation, https://github.com/dlshriver/vnnlib#readme
|
|
@@ -47,17 +47,49 @@ Requires-Dist: pytest (~=7.2.1) ; extra == 'test'
|
|
|
47
47
|
|
|
48
48
|
-----
|
|
49
49
|
|
|
50
|
-
|
|
50
|
+
A python package for parsing neural network properties in the [VNN-LIB format](https://www.vnnlib.org/).
|
|
51
|
+
It should currently parse a superset of the VNN-LIB features supported by the [example parser](https://github.com/stanleybak/nnenum/blob/master/src/nnenum/vnnlib.py) written by Stan Bak for [VNN-COMP](https://sites.google.com/view/vnn2023), and it produces compiled specs in the same format.
|
|
52
|
+
Additionally, gzip-, bzip2-, and lzma (xz)-compressed specs can be parsed directly.
|
|
53
|
+
|
|
54
|
+
> Our parser is currently slower than the previous scripts because it supports more of the specification. However, we expect that significant optimization opportunities exist and that this overhead will decrease over time.
|
|
55
|
+
|
|
56
|
+
> This package is still alpha software and APIs other than the compatibility API may change before the first release. We hope to have a stable release out before or during the benchmark proposal phase of VNN-COMP 2023.
|
|
51
57
|
|
|
52
|
-
- [Installation](#installation)
|
|
53
|
-
- [License](#license)
|
|
54
58
|
|
|
55
59
|
## Installation
|
|
56
60
|
|
|
61
|
+
For the latest stable version, you can install from PyPI with:
|
|
62
|
+
|
|
57
63
|
```console
|
|
58
64
|
pip install vnnlib
|
|
59
65
|
```
|
|
60
66
|
|
|
67
|
+
> PyPI currently only has pre-releases of `vnnlib`. To install a pre-release version, add the `--pre` option to the above command.
|
|
68
|
+
|
|
69
|
+
For the latest updates of `vnnlib`, you can pip install directly from the GitHub repo with:
|
|
70
|
+
|
|
71
|
+
```console
|
|
72
|
+
pip install git+https://github.com/dlshriver/vnnlib.git@main
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Usage
|
|
76
|
+
|
|
77
|
+
This package can be used as a drop-in replacement for the VNN-COMP utility script by importing
|
|
78
|
+
|
|
79
|
+
```python
|
|
80
|
+
from vnnlib.compat import read_vnnlib_simple
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
wherever you previously imported `read_vnnlib_simple`.
|
|
84
|
+
|
|
85
|
+
### Standalone
|
|
86
|
+
|
|
87
|
+
The parser can also be used to compile VNN-LIB specifications ahead of time to reduce future property read times. The compiled result is pickled and saved to the path given with `-o`.
|
|
88
|
+
|
|
89
|
+
```console
|
|
90
|
+
python -m vnnlib [FILE] --compat -o [OUTPUTFILE]
|
|
91
|
+
```
|
|
92
|
+
|
|
61
93
|
## License
|
|
62
94
|
|
|
63
95
|
`vnnlib` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
vnnlib/__init__.py,sha256=omfedQ-RoZiKP9S45H_m-HymipFJN__PxlX-_gzQr7A,87
|
|
2
|
+
vnnlib/__main__.py,sha256=A1FMFZQf5RZaceQ1p0QkneOIkRZK1HZ_93L73QC4wBY,1465
|
|
3
|
+
vnnlib/__version__.py,sha256=3EXaFnXv1z-8AVCGrDM9MpCrpg8I15YhcHCsHzX3iGI,24
|
|
4
|
+
vnnlib/compat.py,sha256=zpEsdqLf1oWBoDAZlZ_Etp2IYtiZfFNb44zrrWuIhCg,9725
|
|
5
|
+
vnnlib/errors.py,sha256=1FNvtN977kBeNE6l0VwqF4nPZuTfZyS6V2NsdMmmtW4,385
|
|
6
|
+
vnnlib/parser.py,sha256=4AHssf44RRiYh83TwTu9jehMD6UK-cq1n3YhUGAiSvA,11227
|
|
7
|
+
vnnlib-0.0.1a2.dist-info/LICENSE.txt,sha256=-XL5dew5SUKD_JytQcjChPYEUOkRzbKGjV2NOQfiWJk,1105
|
|
8
|
+
vnnlib-0.0.1a2.dist-info/METADATA,sha256=B8r76gi17n59twGajuErGpuKLrm02E1R3AV6YZNCvX8,3850
|
|
9
|
+
vnnlib-0.0.1a2.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
|
|
10
|
+
vnnlib-0.0.1a2.dist-info/top_level.txt,sha256=lYFo9TDo2Pzkutd_gdOp3mv-cQ9HwisEMB8GC9mX9uI,7
|
|
11
|
+
vnnlib-0.0.1a2.dist-info/RECORD,,
|
vnnlib-0.0.1a0.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
vnnlib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
vnnlib/__main__.py,sha256=_AJDw4czH0hCK-fAZqjnOwpJo3TmCBEWPA0WOLBx-e0,316
|
|
3
|
-
vnnlib/__version__.py,sha256=508eDu4nPCyqYyg_NQBRjRZmK_68ggYywDmqDMXRF1I,24
|
|
4
|
-
vnnlib-0.0.1a0.dist-info/LICENSE.txt,sha256=-XL5dew5SUKD_JytQcjChPYEUOkRzbKGjV2NOQfiWJk,1105
|
|
5
|
-
vnnlib-0.0.1a0.dist-info/METADATA,sha256=dMmYNMksAuyiHeakhviSbfL69t_c3VeKWZuY8cLzOoA,2178
|
|
6
|
-
vnnlib-0.0.1a0.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
|
|
7
|
-
vnnlib-0.0.1a0.dist-info/top_level.txt,sha256=lYFo9TDo2Pzkutd_gdOp3mv-cQ9HwisEMB8GC9mX9uI,7
|
|
8
|
-
vnnlib-0.0.1a0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|