vnnlib 0.0.1a0__py3-none-any.whl → 0.0.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vnnlib might be problematic. Click here for more details.

vnnlib/__init__.py CHANGED
@@ -0,0 +1,3 @@
1
+ from .parser import VnnLibParser, parse_file
2
+
3
+ __all__ = ["VnnLibParser", "parse_file"]
vnnlib/__main__.py CHANGED
@@ -1,18 +1,52 @@
1
+ from __future__ import annotations
2
+
1
3
  import argparse
4
+ import pickle
5
+ from pathlib import Path
6
+
7
+ from .compat import CompatTransformer
8
+ from .parser import parse_file
2
9
 
3
10
 
4
def parse_args() -> argparse.Namespace:
    """Build the ``vnnlib`` command line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace with attributes ``file`` (a ``Path``),
        ``compat`` (bool), ``strict`` (bool, toggled with
        ``--strict`` / ``--no-strict``) and ``output`` (``str`` or ``None``).
    """
    cli = argparse.ArgumentParser("vnnlib", description="")

    # Positional: the specification to parse.
    cli.add_argument("file", type=Path)

    # Optional switches.
    compat_help = "Use the VNN-COMP-1 output format"
    cli.add_argument("--compat", action="store_true", help=compat_help)

    strict_help = "Whether or not to strictly follow VNN-LIB (default: True)"
    cli.add_argument(
        "--strict",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=strict_help,
    )

    output_help = "The path to save the compiled output"
    cli.add_argument("-o", "--output", type=str, help=output_help)

    return cli.parse_args()
11
30
 
12
31
 
13
def __main__() -> None:
    """Command line entry point: parse a VNN-LIB file and optionally pickle it.

    Raises:
        NotImplementedError: if ``--compat`` was not requested (it is the only
            output format handled here).
        RuntimeError: if ``.vnnlib`` is not among the file's suffixes.
    """
    args = parse_args()
    file: Path = args.file
    print(f"parsing file: {args.file}")

    # Guard clauses: reject unsupported modes before doing any work.
    if not args.compat:
        raise NotImplementedError(
            "Currently only the VNN-COMP-1 output format is supported"
        )
    if ".vnnlib" not in file.suffixes:
        raise RuntimeError(f"Unsupported file type: {file.suffix}")

    tree = parse_file(file, strict=args.strict)
    compiled = CompatTransformer("X", "Y").transform(tree)

    # Without an output path the compiled result is simply dropped.
    if not args.output:
        return
    with open(args.output, "wb+") as f:
        pickle.dump(compiled, f)
16
50
 
17
51
 
18
52
  if __name__ == "__main__":
vnnlib/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.0.1a0"
1
+ __version__ = "0.0.1a2"
vnnlib/compat.py ADDED
@@ -0,0 +1,238 @@
1
+ from __future__ import annotations
2
+
3
+ import operator
4
+ import pathlib
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+
9
+ from .parser import AstNodeTransformer, Real, parse_file
10
+
11
+
12
class CompatTransformer(AstNodeTransformer):
    """Compile a parsed VNN-LIB script into ``[box, [(mat, rhs), ...]]`` entries.

    Linear constraints are carried as sparse dicts keyed by
    ``(row, var_type, index)``.  ``var_type`` is ``0`` for an input variable,
    ``1`` for an output variable and ``-1`` (with index ``-1``) for the
    constant term.  Every row encodes ``sum(coeff * var) + const <= 0``
    (see the ``<=`` / ``>=`` handling in ``transform_FunctionApplication``).
    """

    def __init__(
        self,
        input_name: str,
        output_name: str,
        input_size: Optional[int] = None,
        output_size: Optional[int] = None,
    ) -> None:
        """Create a transformer.

        Args:
            input_name: Prefix of input variables (``f"{input_name}_<i>"``).
            output_name: Prefix of output variables.
            input_size: Number of inputs; inferred from declarations if ``None``.
            output_size: Number of outputs; inferred from declarations if ``None``.
        """
        self.input_name = input_name
        self.output_name = output_name

        self.input_size = input_size or 0
        self.output_size = output_size or 0

        self.infer_input_size = input_size is None
        self.infer_output_size = output_size is None

        self._id_map: Dict[str, int] = {}
        # Constraints shared by every disjunct.
        self._assertions: Dict[Tuple[int, ...], Real] = {}
        self._num_assertions = 0
        # One dict of constraints per disjunct; a single empty dict means
        # "no disjunction seen yet".
        self._disjunctions: List[Dict[Tuple[int, ...], Real]] = [{}]

    def _add_common_constraints(
        self, constraints: Dict[Tuple[int, ...], Real]
    ) -> None:
        """Append ``constraints`` to the shared assertions, shifting their rows."""
        row_offset = self._num_assertions
        max_row = 0
        for (row, *index), value in constraints.items():
            self._assertions[(row + row_offset, *index)] = value
            max_row = max(max_row, row)
        if constraints:
            self._num_assertions += max_row + 1

    def transform_Assert(
        self,
        term: Union[List[Dict[Tuple[int, ...], Real]], Dict[Tuple[int, ...], Real]],
    ) -> List[Dict[Tuple[int, ...], Real]]:
        """Record an asserted term.

        A dict, or a single-element list, is a plain conjunction and is added
        to the shared assertions.  A longer list is a disjunction and replaces
        the current (single) set of disjuncts.

        Raises:
            NotImplementedError: if a second disjunctive assertion is seen.
            RuntimeError: if ``term`` is neither a list nor a dict.
        """
        if isinstance(term, list):
            if len(term) == 1:
                self._add_common_constraints(term[0])
            elif len(self._disjunctions) == 1:
                new_disjunctions = []
                for disjunct in term:
                    disjunct = disjunct.copy()
                    new_disjunctions.append(disjunct)
                    # Rows already held in the single existing disjunct are
                    # appended after this disjunct's own rows.
                    row_offset = max(disjunct, default=(-1,))[0] + 1
                    for (row, *index), value in self._disjunctions[0].items():
                        disjunct[(row + row_offset, *index)] = value
                self._disjunctions = new_disjunctions
            else:
                # Previously ``assert False``, which is stripped under ``-O``
                # and would then silently drop the constraint.
                raise NotImplementedError(
                    "multiple disjunctive assertions are not supported"
                )
            return term
        if isinstance(term, dict):
            self._add_common_constraints(term)
            return [term]
        raise RuntimeError("unexpected term for assert")

    def transform_Constant(self, value) -> Dict[Tuple[int, ...], Real]:
        """Represent a numeric literal as the constant term of row 0."""
        assert isinstance(value, (Real, int))
        return {(0, -1, -1): value}

    def transform_DeclareConst(self, symbol: str, sort: str) -> None:
        """Grow the inferred input/output size from a declared variable name."""
        if self.infer_input_size and symbol.startswith(f"{self.input_name}_"):
            _, index = symbol.split("_")
            self.input_size = max(self.input_size, int(index) + 1)
        elif self.infer_output_size and symbol.startswith(f"{self.output_name}_"):
            _, index = symbol.split("_")
            self.output_size = max(self.output_size, int(index) + 1)

    def transform_FunctionApplication(
        self,
        symbol: str,
        *terms: Union[List[Dict[Tuple[int, ...], Real]], Dict[Tuple[int, ...], Real]],
    ) -> Union[List[Dict[Tuple[int, ...], Real]], Dict[Tuple[int, ...], Real]]:
        """Combine already-transformed operands of ``<=``, ``>=``, ``and`` or ``or``.

        Raises:
            NotImplementedError: for any other function symbol.
        """
        if symbol == "<=":
            lhs, rhs = terms
            assert isinstance(lhs, dict)
            assert isinstance(rhs, dict)
            # lhs <= rhs  ->  lhs - rhs <= 0
            result = lhs.copy()
            for key, value in rhs.items():
                result[key] = result.get(key, 0) - value
            return result
        elif symbol == ">=":
            lhs, rhs = terms
            assert isinstance(lhs, dict)
            assert isinstance(rhs, dict)
            # lhs >= rhs  ->  rhs - lhs <= 0
            result = rhs.copy()
            for key, value in lhs.items():
                result[key] = result.get(key, 0) - value
            return result
        elif symbol == "and":
            conjuncts = {}
            for i, term in enumerate(terms):
                assert isinstance(term, dict)
                for (row, *index), value in term.items():
                    assert row == 0
                    assert len(index) == 2, "please open a bug report"
                    # The i-th conjunct becomes row i.
                    conjuncts[(i, *index)] = value
            return [conjuncts]
        elif symbol == "or":
            or_result = []
            for term in terms:
                assert isinstance(term, list), "please open a bug report"
                for subterm in term:
                    assert isinstance(subterm, dict), "please open a bug report"
                or_result.extend(term)
            return or_result
        else:
            raise NotImplementedError(
                f"Function {symbol!r} is not supported by the legacy parser"
            )

    def transform_Identifier(
        self, value: str
    ) -> Union[str, Dict[Tuple[int, ...], Real]]:
        """Map an identifier to a unit-coefficient linear term.

        Supported function symbols are returned unchanged as strings.
        """
        if value.startswith(f"{self.input_name}_"):
            _, *str_index = value.split("_")
            return {(0, 0, *tuple(map(int, str_index))): 1}
        elif value.startswith(f"{self.output_name}_"):
            _, *str_index = value.split("_")
            return {(0, 1, *tuple(map(int, str_index))): 1}
        elif value in {"<=", ">=", "and", "or"}:
            return value
        if value not in self._id_map:
            # Ids 0 and 1 are taken by the input and output variable types.
            self._id_map[value] = len(self._id_map) + 2
        return {(0, self._id_map[value]): 1}

    def _accumulate_rows(self, items, box, polytope, row_base: int) -> None:
        """Fold sparse constraint ``items`` into ``box`` and ``polytope`` in place.

        Args:
            items: ``((row, var_type, index), value)`` pairs.
            box: per-input ``[lower, upper]`` intervals, tightened in place.
            polytope: list of ``[coefficients, [rhs]]`` rows, extended in place.
            row_base: number of polytope rows that precede these items.
        """
        input_box_rows = set()
        output_polytope_rows = set()
        rhs = 0
        # Sorting by key puts each row's constant (var_type -1) before the
        # row's variables, so ``rhs`` is known when a variable is reached.
        for (row, var_type, index), value in sorted(
            items, key=operator.itemgetter(0)
        ):
            assert var_type <= 0 or row not in input_box_rows
            if var_type == -1:
                rhs = value
                continue
            if var_type == 0:
                input_box_rows.add(row)
                if value == 1:
                    # x + c <= 0  ->  x <= -c
                    box[index][1] = min(-rhs, box[index][1])
                elif value == -1:
                    # -x + c <= 0  ->  x >= c
                    box[index][0] = max(rhs, box[index][0])
                else:
                    raise RuntimeError(f"unexpected lhs coeff for box: {value}")
                rhs = 0
                continue
            if var_type == 1:
                output_polytope_rows.add(row)
                polytope_row = row - min(output_polytope_rows) + row_base
                if len(polytope) <= polytope_row:
                    # a.y + c <= 0  ->  a.y <= -c, so the stored right-hand
                    # side is the negated constant (matches the box handling).
                    polytope.append(
                        [[0 for _ in range(self.output_size)], [-rhs]]
                    )
                    rhs = 0
                polytope[polytope_row][0][index] = value
                continue
            raise RuntimeError(f"unexpected variable type {var_type}")

    def transform_Script(self, *commands):
        """Assemble the final list of ``[box, [(mat, rhs), ...]]`` entries.

        Disjuncts that produce the same input box share one entry, with one
        ``(mat, rhs)`` pair of numpy arrays per disjunct.
        """
        common_box = [[float("-inf"), float("inf")] for _ in range(self.input_size)]
        common_polytope = []
        self._accumulate_rows(self._assertions.items(), common_box, common_polytope, 0)

        results = {}
        for disjunct in self._disjunctions:
            # Every disjunct starts from a private copy of the shared part.
            box = [interval.copy() for interval in common_box]
            polytope = [[lhs.copy(), rhs.copy()] for lhs, rhs in common_polytope]
            self._accumulate_rows(
                disjunct.items(), box, polytope, len(common_polytope)
            )
            box_str = str(box)
            polytope_arr = (
                np.array([lhs for lhs, _ in polytope]),
                np.array([rhs for _, rhs in polytope]),
            )
            if box_str not in results:
                results[box_str] = [box, [polytope_arr]]
            else:
                results[box_str][1].append(polytope_arr)
        return list(results.values())
219
+
220
+
221
def read_vnnlib_simple(
    vnnlib_filename: Union[str, pathlib.Path], num_inputs: int, num_outputs: int
) -> List[Tuple[List[Real], Tuple[np.ndarray, np.ndarray]]]:
    """Parse a VNN-LIB file (non-strict mode) and compile it with CompatTransformer.

    Args:
        vnnlib_filename: Path of the specification to read.
        num_inputs: Number of network inputs (variables named ``X_<i>``).
        num_outputs: Number of network outputs (variables named ``Y_<i>``).

    Returns:
        A list with one entry per distinct input box.  Each entry pairs

        1. the input ranges (box), one ``[lower, upper]`` pair per input, and
        2. the specification as a list of ``(mat, rhs)`` pairs, as in
           ``mat * y <= rhs`` where ``y`` is the output.  Each element of the
           list is one term of the disjunction.
    """
    transformer = CompatTransformer("X", "Y", num_inputs, num_outputs)
    return transformer.transform(parse_file(vnnlib_filename, strict=False))
234
+
235
+
236
+ __all__ = [
237
+ "read_vnnlib_simple",
238
+ ]
vnnlib/errors.py ADDED
@@ -0,0 +1,11 @@
1
class ParserError(Exception):
    """Error raised while tokenizing or parsing, optionally tagged with a location."""

    def __init__(self, msg: str, *args: object, lineno=None, col_offset=None) -> None:
        """Build the message, prefixing ``line N[, column M]: `` when known.

        The column is only reported when a line number is also given.
        """
        if lineno is not None:
            location = [f"line {lineno}"]
            if col_offset is not None:
                location.append(f"column {col_offset}")
            msg = f"{', '.join(location)}: {msg}"
        super().__init__(msg, *args)
9
+
10
+
11
+ __all__ = ["ParserError"]
vnnlib/parser.py ADDED
@@ -0,0 +1,368 @@
1
+ from __future__ import annotations
2
+
3
+ import bz2
4
+ import gzip
5
+ import lzma
6
+ import re
7
+ import warnings
8
+ from pathlib import Path
9
+ from typing import (
10
+ Callable,
11
+ Dict,
12
+ Iterator,
13
+ List,
14
+ NamedTuple,
15
+ Optional,
16
+ Set,
17
+ TextIO,
18
+ Union,
19
+ )
20
+
21
+ from .errors import ParserError
22
+
23
+ Real = float
24
+
25
+
26
# Source span of a token: start and end offsets into the tokenized text.
Meta = NamedTuple("Meta", [("start_pos", int), ("end_pos", int)])
29
+
30
+
31
class Token(NamedTuple):
    """A lexical token: its type name, matched text and source span."""

    # Name of the token pattern that matched (e.g. "SYMBOL", "LPAREN").
    token_type: str
    # The exact text that was matched.
    value: str
    # Where in the input the text was found.
    meta: Meta
35
+
36
+
37
def as_dict(t: NamedTuple):
    """Return the fields of named tuple ``t`` as a field-name -> value mapping."""
    fields = t._asdict()
    return fields
39
+
40
+
41
+ _DUMMY_TOKEN = Token("_", "", Meta(0, 0))
42
+ EOF = Token("EOF", "", Meta(-1, -1))
43
+
44
+
45
def tokenize(text: str, skip: Set[str], strict=True) -> Iterator[Token]:
    """Lazily split ``text`` into tokens.

    Args:
        text: The source text to tokenize.
        skip: Token type names that are matched but not yielded.
        strict: When true, decimals may not carry an exponent part.

    Yields:
        One ``Token`` per match whose type is not in ``skip``.

    Raises:
        ParserError: if some stretch of ``text`` matches no token pattern.
    """
    if not text:
        return

    decimal_flavour = "strict" if strict else "extended"
    # Entries whose name starts with "_" are building blocks only; ``{NAME}``
    # placeholders are expanded below.
    patterns: Dict[str, str] = {
        "COMMENT": r";[\t -~]*(?:[\r\n]|$)",
        "WS": r"\x09|\x0a|\x0d|\x20",
        "LPAREN": r"\(",
        "RPAREN": r"\)",
        "BINARY": r"#b[01]+",
        "HEXADECIMAL": r"#x[0-9A-Fa-f]+",
        "_DECIMAL_strict": r"(?:{NUMERAL})\.0*(?:{NUMERAL})",
        "_DECIMAL_extended": r"(?:(?:{NUMERAL})\.0*(?:{NUMERAL})(?:[eE][+-]?0*(?:{NUMERAL}))?)",
        "DECIMAL": "(?:{_DECIMAL_" + decimal_flavour + "})",
        "NUMERAL": r"(?:(?:[1-9][0-9]*)|0)",
        "STRING": r"\x22(?:(?:{WS})|(?:{_PRINTABLE_CHAR}))*\x22",
        "SYMBOL": r"(?:(?:(?:{_LETTER})|(?:{_CHARACTER}))(?:[0-9]|(?:{_LETTER})|(?:{_CHARACTER}))*)|(?:\x7c(?:[\x20-\x5b]|[\x5d-\x7b]|[\x7d\x7e]|[\x80-\xff])*\x7c)",
        "_PRINTABLE_CHAR": r"[\x20-\x7e]|[\x80-\xff]",
        "_LETTER": r"[A-Za-z]",
        "_CHARACTER": r"[~!@$%^&*+=<>.?/_-]",
    }
    # Expand in declaration order against the live table, so a pattern sees
    # the already-expanded form of any entry listed before it.
    for name in list(patterns):
        patterns[name] = patterns[name].format(**patterns)

    alternatives = []
    for name, pattern in patterns.items():
        if name.startswith("_"):
            continue
        alternatives.append(f"(?P<{name}>{pattern})")
    scanner = re.compile("|".join(alternatives))

    cursor: int = 0
    for found in scanner.finditer(text):
        begin, end = found.start(), found.end()
        # finditer silently skips unmatched text; a gap is a lexing error.
        if begin != cursor:
            raise ParserError(f"Unknown string: {text[cursor:begin]!r}")
        kind = found.lastgroup
        assert kind is not None
        if kind not in skip:
            yield Token(kind, found.group(), Meta(begin, end))
        cursor = end
    if cursor != len(text):
        raise ParserError(f"Unknown string: {text[cursor:]!r}")
85
+
86
+
87
class AstNode:
    """Base class of all syntax tree nodes."""

    @property
    def _type(self) -> str:
        """Name of the node's concrete class."""
        return type(self).__name__
91
+
92
+
93
class Script(AstNode):
    """Root node: the sequence of commands making up a script."""

    def __init__(self, *commands: Command):
        # Stored as the tuple of positional arguments, in order.
        self.commands = commands
96
+
97
+
98
class Command(AstNode):
    """Base class for command nodes."""
100
+
101
+
102
class Declare(Command):
    """Base class for declaration commands."""
104
+
105
+
106
class DeclareConst(Declare):
    """A ``declare-const`` command: a symbol name and the name of its sort."""

    def __init__(self, symbol: str, sort: str):
        self.symbol, self.sort = symbol, sort
110
+
111
+
112
class Assert(Command):
    """An ``assert`` command wrapping the asserted term."""

    def __init__(self, term: Term):
        self.term = term
115
+
116
+
117
class Term(AstNode):
    """Base class for term nodes."""
119
+
120
+
121
class FunctionApplication(Term):
    """A function identifier applied to zero or more argument terms."""

    def __init__(self, function: Identifier, *terms: Term):
        self.function = function
        # Arguments are kept as a tuple, in call order.
        self.terms = terms
125
+
126
+
127
class Constant(Term):
    """A literal term holding an already-converted Python value."""

    def __init__(self, value: float | int | str | Real):
        self.value = value
130
+
131
+
132
class Sort(AstNode):
    """A sort, identified by its textual description."""

    def __init__(self, value: str):
        self.value = value
135
+
136
+
137
class Identifier(Term):
    """A named symbol together with its sort."""

    def __init__(self, value: str, sort: Sort):
        self.value = value
        # Previously accepted but discarded; keep it so the sort supplied at
        # declaration time is still available on the node.
        self.sort = sort
140
+
141
+
142
def _hex_to_int(x: str) -> int:
    """Convert a ``#x...`` literal to an int by parsing everything after the 2-char prefix as base 16."""
    digits = x[2:]
    return int(digits, 16)
144
+
145
+
146
def _bin_to_int(x: str) -> int:
    """Convert a ``#b...`` literal to an int by parsing everything after the 2-char prefix as base 2."""
    digits = x[2:]
    return int(digits, 2)
148
+
149
+
150
def _identity(x: str) -> str:
    """Return ``x`` unchanged (converter for literals that need no conversion)."""
    unchanged = x
    return unchanged
152
+
153
+
154
# Maps a literal token type to the callable that converts its text to a value.
LITERAL_CONVERTERS: Dict[str, Callable[[str], float | int | str | Real]] = {
    "DECIMAL": Real,
    "NUMERAL": int,
    "HEXADECIMAL": _hex_to_int,
    "BINARY": _bin_to_int,
    "STRING": _identity,
}
# Identifiers available without declaration.  Sort strings are written as
# "(argument sorts) result sort", or just the sort for constants.
CORE_IDS: Dict[str, Identifier] = {
    # arithmetic
    "+": Identifier("+", Sort("(A A) A")),
    "-": Identifier("-", Sort("(A A) A")),
    "*": Identifier("*", Sort("(A A) A")),
    "/": Identifier("/", Sort("(A A) A")),
    # comparisons
    ">": Identifier(">", Sort("(A A) Bool")),
    ">=": Identifier(">=", Sort("(A A) Bool")),
    "<": Identifier("<", Sort("(A A) Bool")),
    "<=": Identifier("<=", Sort("(A A) Bool")),
    "=": Identifier("=", Sort("(A A) Bool")),
    # logic
    "true": Identifier("true", Sort("Bool")),
    "false": Identifier("false", Sort("Bool")),
    # Unary: one Bool argument, Bool result (was the malformed "(Bool Bool)").
    "not": Identifier("not", Sort("(Bool) Bool")),
    "=>": Identifier("=>", Sort("(Bool Bool) Bool")),
    "or": Identifier("or", Sort("(Bool Bool) Bool")),
    "and": Identifier("and", Sort("(Bool Bool) Bool")),
    "xor": Identifier("xor", Sort("(Bool Bool) Bool")),
    "ite": Identifier("ite", Sort("(Bool A A) A")),
}
183
+
184
+
185
class VnnLibParser:
    """Recursive-descent parser turning a token stream into an AST.

    Only the ``assert`` and ``declare-const`` commands are recognised.
    """

    def __init__(self, token_stream: Iterator[Token]):
        self.token_stream = token_stream
        # Placeholder until the first advance_token_stream() call.
        self.curr_token = _DUMMY_TOKEN
        self.sorts = {"Bool": Sort("Bool"), "Int": Sort("Int"), "Real": Sort("Real")}
        # Copy so that declarations made while parsing stay per-parser.
        self.identifiers: Dict[str, Identifier] = CORE_IDS.copy()

    def advance_token_stream(self) -> Token:
        """Move to the next token (``EOF`` once exhausted) and return it."""
        self.curr_token = next_token = next(self.token_stream, EOF)
        return next_token

    def expect_token_type(
        self,
        type: str,
        *,
        expected_value: Optional[str] = None,
        msg: str = "unexpected token: {token_type}({value!r})",
    ) -> bool:
        """Check that the current token has the given type.

        Returns:
            ``True`` when the type matches.

        Raises:
            ParserError: otherwise; the message names ``expected_value`` when
                given, else it is ``msg`` formatted with the token's fields.
        """
        if self.curr_token.token_type == type:
            return True
        if expected_value:
            raise ParserError(f"Expected {expected_value!r}")
        raise ParserError(msg.format(**as_dict(self.curr_token)))

    def lookup_identifier(self, identifier: str) -> Identifier:
        """Return the known identifier named ``identifier``.

        Raises:
            ParserError: if it is neither a core nor a declared identifier.
        """
        if identifier not in self.identifiers:
            # Something like "e-5" is most likely the tail of a number written
            # in exponential notation, which strict mode splits off as a symbol.
            if (
                identifier.startswith("e")
                and len(identifier) >= 2
                and not set(identifier[1:]).difference(set("0123456789+-"))
            ):
                raise ParserError(
                    (
                        f"undeclared identifier: {identifier!r}."
                        "\n\tIt looks like this may be exponential notation, which is not SMT-LIB compliant."
                        "\n\tTry turning off strict mode."
                    )
                )
            raise ParserError(f"undeclared identifier: {identifier!r}")
        return self.identifiers[identifier]

    def lookup_sort(self, name: str) -> Sort:
        """Return the sort called ``name``.

        Raises:
            ParserError: if no such sort is known.
        """
        if name not in self.sorts:
            raise ParserError(f"undeclared sort: {name!r}")
        return self.sorts[name]

    @classmethod
    def parse(cls, text: str, strict=True) -> Script:
        """Parse ``text`` into a ``Script``; whitespace and comments are skipped."""
        # Use ``cls`` so subclasses get an instance of themselves.
        parser = cls(tokenize(text, {"WS", "COMMENT"}, strict=strict))
        parser.advance_token_stream()
        commands = []
        while parser.curr_token != EOF:
            commands.append(parser.parse_command())
        return Script(*commands)

    def parse_command(self) -> Command:
        """Parse one parenthesised command starting at the current token.

        Raises:
            ParserError: on a missing parenthesis or an unknown command name.
        """
        self.expect_token_type(type="LPAREN", expected_value="(")
        curr_token = self.advance_token_stream()
        command = curr_token.value
        command_parsers = {
            "assert": self.parse_assert,
            "declare-const": self.parse_declare_const,
        }
        if command not in command_parsers:
            raise ParserError(f"Unknown command: {command!r}")
        node = command_parsers[command]()
        self.expect_token_type(type="RPAREN", expected_value=")")
        self.advance_token_stream()
        return node

    def parse_declare_const(self) -> Declare:
        """Parse ``declare-const <symbol> <sort>`` and register the identifier."""
        symbol = self.advance_token_stream()
        self.expect_token_type("SYMBOL")
        sort = self.advance_token_stream()
        self.expect_token_type("SYMBOL")
        self.advance_token_stream()
        self.identifiers[symbol.value] = Identifier(
            symbol.value, self.lookup_sort(sort.value)
        )
        return DeclareConst(symbol.value, sort.value)

    def parse_assert(self) -> Assert:
        """Parse the term following the ``assert`` keyword."""
        self.advance_token_stream()
        return Assert(self.parse_term())

    def parse_term(self) -> Term:
        """Parse a symbol, a literal or a parenthesised function application.

        Raises:
            ParserError: on an unexpected token or an undeclared identifier.
        """
        curr_token = self.curr_token
        token_type = curr_token.token_type
        if token_type == "SYMBOL":
            self.advance_token_stream()
            if curr_token.value.startswith("-"):
                warnings.warn("literal negation does not strictly follow SMT-LIB")
                try:
                    float_value = Real(curr_token.value)
                except ValueError:
                    # Not a number: treat "-name" as negation of identifier "name".
                    return FunctionApplication(
                        self.lookup_identifier("-"),
                        self.lookup_identifier(curr_token.value[1:]),
                    )
                return Constant(float_value)
            return self.lookup_identifier(curr_token.value)
        if token_type == "LPAREN":
            children: List[Term] = []
            function_id_token = self.advance_token_stream()
            self.expect_token_type("SYMBOL")
            self.advance_token_stream()
            # At least one argument is required.
            children.append(self.parse_term())
            while self.curr_token.token_type != "RPAREN":
                children.append(self.parse_term())
            self.advance_token_stream()
            return FunctionApplication(
                self.lookup_identifier(function_id_token.value), *children
            )
        if token_type in LITERAL_CONVERTERS:
            value = LITERAL_CONVERTERS[token_type](curr_token.value)
            self.advance_token_stream()
            return Constant(value)
        raise ParserError(f"Unexpected token: {curr_token}")
304
+
305
+
306
def _identity_args(*args):
    """Return the positional arguments unchanged, as a tuple."""
    passthrough = args
    return passthrough
308
+
309
+
310
class _Discard:
    """Marker type; results that are its module-level instance are dropped from a script."""
312
+
313
+
314
+ Discard = _Discard()
315
+
316
+
317
class AstNodeTransformer:
    """Bottom-up tree transformer dispatching on the node's class name.

    For a node of class ``Foo``, ``_visit_Foo`` produces the arguments
    (transforming children first) and ``transform_Foo`` consumes them.  When a
    subclass defines no ``transform_Foo``, the argument tuple itself is the
    result.
    """

    def transform(self, node: AstNode):
        """Transform ``node`` and, through the visitors, its children."""
        kind = node._type
        visit = getattr(self, f"_visit_{kind}")
        args = visit(node)
        handler = getattr(self, f"transform_{kind}", _identity_args)
        return handler(*args)

    def _visit_Assert(self, node: Assert):
        return (self.transform(node.term),)

    def _visit_Constant(self, node: Constant):
        return (node.value,)

    def _visit_DeclareConst(self, node: DeclareConst):
        return (node.symbol, node.sort)

    def _visit_FunctionApplication(self, node: FunctionApplication):
        # The function identifier is transformed before its arguments.
        head = self.transform(node.function)
        operands = [self.transform(operand) for operand in node.terms]
        return (head, *operands)

    def _visit_Identifier(self, node: Identifier):
        return (node.value,)

    def _visit_Script(self, node: Script):
        # Commands whose result is the Discard marker are left out.
        return [
            outcome
            for outcome in map(self.transform, node.commands)
            if outcome is not Discard
        ]
347
+
348
+
349
def parse_file(filename: Union[str, Path], strict=True) -> AstNode:
    """Read a specification file, decompressing by suffix, and parse it.

    Files ending in ``.gz``/``.gzip``, ``.bz2``/``.bzip2`` or ``.xz`` are
    opened in text mode through ``gzip``, ``bz2`` or ``lzma`` respectively;
    anything else is opened as a plain text file.

    Args:
        filename: Path of the file to read.
        strict: Forwarded to ``VnnLibParser.parse``.

    Returns:
        The parsed syntax tree.
    """
    path = Path(filename) if isinstance(filename, str) else filename

    # Only the final suffix selects the decompressor.
    openers = {
        ".gz": gzip.open,
        ".gzip": gzip.open,
        ".bz2": bz2.open,
        ".bzip2": bz2.open,
        ".xz": lzma.open,
    }
    opener = openers.get(path.suffix)
    stream: TextIO = open(path) if opener is None else opener(path, "rt")

    with stream as f:
        text = f.read()
    return VnnLibParser.parse(text, strict=strict)
366
+
367
+
368
+ __all__ = ["VnnLibParser", "parse_file"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vnnlib
3
- Version: 0.0.1a0
3
+ Version: 0.0.1a2
4
4
  Author-email: David Shriver <davidshriver@outlook.com>
5
5
  License: MIT
6
6
  Project-URL: Documentation, https://github.com/dlshriver/vnnlib#readme
@@ -47,17 +47,49 @@ Requires-Dist: pytest (~=7.2.1) ; extra == 'test'
47
47
 
48
48
  -----
49
49
 
50
- **Table of Contents**
50
+ A Python package for parsing neural network properties in the [VNN-LIB format](https://www.vnnlib.org/).
51
+ It should currently parse a superset of the VNN-LIB spec supported by the [example parser](https://github.com/stanleybak/nnenum/blob/master/src/nnenum/vnnlib.py) written by Stan Bak for [VNN-COMP](https://sites.google.com/view/vnn2023), and will produce compiled specs in the same format.
52
+ Additionally, we allow parsing of gzip, bzip2, and lzma compressed specs.
53
+
54
+ > Our parser is currently slower than the previous scripts due to the increased specification support. However, we expect significant optimization opportunities are available, and that overhead will decrease over time.
55
+
56
+ > This package is still alpha software and APIs other than the compatibility API may change before the first release. We hope to have a stable release out before or during the benchmark proposal phase of VNN-COMP 2023.
51
57
 
52
- - [Installation](#installation)
53
- - [License](#license)
54
58
 
55
59
  ## Installation
56
60
 
61
+ For the latest stable version, you can install from PyPI with:
62
+
57
63
  ```console
58
64
  pip install vnnlib
59
65
  ```
60
66
 
67
+ > PyPI currently only has pre-releases of `vnnlib`. To install a pre-release version, add the `--pre` option to the above command.
68
+
69
+ For the latest updates of `vnnlib`, you can pip install directly from the GitHub repo with:
70
+
71
+ ```console
72
+ pip install git+https://github.com/dlshriver/vnnlib.git@main
73
+ ```
74
+
75
+ ## Usage
76
+
77
+ This package can be used as a drop-in replacement for the VNN-COMP utility script by importing
78
+
79
+ ```python
80
+ from vnnlib.compat import read_vnnlib_simple
81
+ ```
82
+
83
+ wherever you previously imported `read_vnnlib_simple`.
84
+
85
+ ### Standalone
86
+
87
+ The parser can also be used to compile vnnlib ahead of time to reduce future property read times. The result of parsing will be pickled and saved to the location specified.
88
+
89
+ ```console
90
+ python -m vnnlib [FILE] --compat -o [OUTPUTFILE]
91
+ ```
92
+
61
93
  ## License
62
94
 
63
95
  `vnnlib` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
@@ -0,0 +1,11 @@
1
+ vnnlib/__init__.py,sha256=omfedQ-RoZiKP9S45H_m-HymipFJN__PxlX-_gzQr7A,87
2
+ vnnlib/__main__.py,sha256=A1FMFZQf5RZaceQ1p0QkneOIkRZK1HZ_93L73QC4wBY,1465
3
+ vnnlib/__version__.py,sha256=3EXaFnXv1z-8AVCGrDM9MpCrpg8I15YhcHCsHzX3iGI,24
4
+ vnnlib/compat.py,sha256=zpEsdqLf1oWBoDAZlZ_Etp2IYtiZfFNb44zrrWuIhCg,9725
5
+ vnnlib/errors.py,sha256=1FNvtN977kBeNE6l0VwqF4nPZuTfZyS6V2NsdMmmtW4,385
6
+ vnnlib/parser.py,sha256=4AHssf44RRiYh83TwTu9jehMD6UK-cq1n3YhUGAiSvA,11227
7
+ vnnlib-0.0.1a2.dist-info/LICENSE.txt,sha256=-XL5dew5SUKD_JytQcjChPYEUOkRzbKGjV2NOQfiWJk,1105
8
+ vnnlib-0.0.1a2.dist-info/METADATA,sha256=B8r76gi17n59twGajuErGpuKLrm02E1R3AV6YZNCvX8,3850
9
+ vnnlib-0.0.1a2.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
10
+ vnnlib-0.0.1a2.dist-info/top_level.txt,sha256=lYFo9TDo2Pzkutd_gdOp3mv-cQ9HwisEMB8GC9mX9uI,7
11
+ vnnlib-0.0.1a2.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- vnnlib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- vnnlib/__main__.py,sha256=_AJDw4czH0hCK-fAZqjnOwpJo3TmCBEWPA0WOLBx-e0,316
3
- vnnlib/__version__.py,sha256=508eDu4nPCyqYyg_NQBRjRZmK_68ggYywDmqDMXRF1I,24
4
- vnnlib-0.0.1a0.dist-info/LICENSE.txt,sha256=-XL5dew5SUKD_JytQcjChPYEUOkRzbKGjV2NOQfiWJk,1105
5
- vnnlib-0.0.1a0.dist-info/METADATA,sha256=dMmYNMksAuyiHeakhviSbfL69t_c3VeKWZuY8cLzOoA,2178
6
- vnnlib-0.0.1a0.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
7
- vnnlib-0.0.1a0.dist-info/top_level.txt,sha256=lYFo9TDo2Pzkutd_gdOp3mv-cQ9HwisEMB8GC9mX9uI,7
8
- vnnlib-0.0.1a0.dist-info/RECORD,,