mal-toolbox 1.2.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1070 @@
1
+ from tree_sitter import TreeCursor, Node
2
+ from typing import Any, Callable, Tuple
3
+ from collections.abc import MutableMapping, MutableSequence
4
+ from pathlib import Path
5
+ from .lang import PARSER as parser
6
+ from .mal_analyzer import malAnalyzer
7
+ from .exceptions import (
8
+ MalCompilerError,
9
+ )
10
+
11
+ ASTNode = Tuple[str, object]
12
+
13
+ """
14
+ This function is crucial to use instead of cursor.goto_next_sibling()
15
+
16
+ Although the `comment` node is included as an extra in the
17
+ TreeSitter grammar, it still shows up in the AST.
18
+ For this reason, this function exists to go to the next node,
19
+ while ignoring `comment` nodes.
20
+
21
+ It returns a boolean which states if there are any nodes left
22
+ and if the current node is not a comment.
23
+ """
24
+
25
+
26
def go_to_sibling(cursor: TreeCursor) -> bool:
    """Advance `cursor` to its next sibling, stepping over `comment` nodes.

    Use this instead of calling ``cursor.goto_next_sibling()`` directly,
    because comment nodes appear among the siblings.

    Returns:
        True when the cursor moved and now rests on a non-comment node,
        False when there were no more siblings to move to.
    """
    moved = cursor.goto_next_sibling()
    while True:
        on_comment = assert_node(cursor.node).type == 'comment'
        if not (on_comment and moved):
            break
        moved = cursor.goto_next_sibling()
    return moved and assert_node(cursor.node).type != 'comment'
31
+
32
+
33
+ class ParseTreeVisitor:
34
def __init__(self) -> None:
    """Create the bookkeeping used while compiling and the analyzer."""
    # Files already handed to `compile`; guards against include cycles.
    self.visited_files: set[Path] = set()
    # Directories of the files currently being compiled, innermost last.
    self.path_stack: list[Path] = []
    # Starts out unset.
    self.current_file: Path | None = None
    # Provides the optional `check_<node type>` hooks used by `visit`.
    self.analyzer = malAnalyzer()
39
+
40
def compile(self, malfile: Path | str):
    """Parse `malfile` and return the result of visiting its parse tree.

    A relative path is resolved against the directory on top of
    `self.path_stack` (the directory of the file currently being compiled).
    For the first file the stack is empty and the path is used as given.

    Returns:
        The value produced by `self.visit` for the file's root node, or an
        empty dict when the file was already visited (this stops recursive
        includes from looping forever).

    Raises:
        OSError: if the file cannot be opened.
    """
    current_file = Path(malfile)

    if not current_file.is_absolute() and self.path_stack:
        # Only for the first file will self.path_stack be empty.
        current_file = self.path_stack[-1] / current_file

    if current_file in self.visited_files:
        # Avoid infinite loops due to recursive includes
        return {}

    self.visited_files.add(current_file)
    self.path_stack.append(current_file.parent)

    try:
        with open(current_file, 'rb') as f:
            source = f.read()
        tree = parser.parse(source)
        return self.visit(tree.walk())
    finally:
        # Always unwind the include stack, even when reading, parsing or
        # visiting fails; otherwise a reused visitor would resolve later
        # relative paths against a stale directory.
        self.path_stack.pop()
63
+
64
def visit(self, cursor, params=None):
    """Dispatch the node under `cursor` to its `visit_<type>` handler.

    The cursor is moved to the node's first child (if any) before the
    handler runs and moved back to the node afterwards. When the analyzer
    defines `check_<type>`, it is called after the handler with the node,
    and also with the handler's result if it takes that extra argument.

    Args:
        cursor: tree cursor positioned on the node to visit.
        params: optional extra argument forwarded to the handler; it is
            only forwarded when truthy.

    Returns:
        Whatever the handler returned (None for unhandled node types).

    Raises:
        ValueError: if the analyzer hook takes an unexpected number of
            positional arguments.
    """
    handler = getattr(self, f'visit_{cursor.node.type}', self.skip)

    # Handlers expect to start on the node's first child.
    descended = cursor.goto_first_child()
    if params:
        result = handler(cursor, params)
    else:
        result = handler(cursor)
    if descended:
        cursor.goto_parent()  # back on the visited node

    checker: Callable[..., Any] | None = getattr(
        self.analyzer, f'check_{cursor.node.type}', None
    )
    if not checker:
        return result

    # co_argcount counts `self` too: 2 -> (self, node), 3 -> (self, node, result)
    argc = checker.__code__.co_argcount
    if argc == 3:
        checker(cursor.node, result)
    elif argc == 2:
        checker(cursor.node)
    else:
        raise ValueError(f'Unexpected number of arguments: {argc}')

    return result
89
+
90
def visit_source_file(self, cursor: TreeCursor) -> dict[str, Any]:
    """Assemble the language specification for one parsed source file.

    `visit` calls this with the cursor on the first child of the
    `source_file` node. Each declaration is visited and merged into the
    result; `include` declarations are compiled recursively and merged as
    well. Duplicate categories, assets and associations are dropped,
    keeping the first occurrence.

    Returns:
        A dict with the keys `formatVersion`, `defines`, `categories`,
        `assets` and `associations`.
    """
    langspec: dict[str, Any] = {
        'formatVersion': '1.0.0',
        'defines': {},
        'categories': [],
        'assets': [],
        'associations': [],
    }

    # When `source_file` has no children (e.g. an empty file), `visit`
    # cannot descend and the cursor is still on `source_file` itself.
    # Visiting from here would dispatch straight back into this method and
    # recurse without end, so report an empty specification instead.
    if assert_node(cursor.node).type == 'source_file':
        return langspec

    # Go to first declaration
    while True:
        if assert_node(cursor.node).type == 'comment':
            go_to_sibling(cursor)

        # Obtain node type of declaration
        cursor.goto_first_child()

        # Visit declaration
        result = self.visit(cursor)
        assert result, 'No result from visit'
        key, value = result
        if key == 'categories':
            category, assets = value
            langspec['categories'].extend(category)
            langspec['assets'].extend(assets)
        elif key == 'defines':
            langspec[key].update(value)
        elif key == 'associations':
            langspec[key].extend(value)
        elif key == 'include':
            included_file = self.compile(value)
            # Merge mapping-valued entries by update, list-valued by extend.
            for k, v in langspec.items():
                if isinstance(v, MutableMapping):
                    langspec[k].update(included_file.get(k, {}))
                if isinstance(v, MutableSequence) and k in included_file:
                    langspec[k].extend(included_file[k])

        # Go back to declaration
        cursor.goto_parent()

        # Attempt to move to next declaration. If not possible, done processing
        if not go_to_sibling(cursor):
            break

    # De-duplicate while preserving order. The items are dicts, hence the
    # linear membership test instead of a set.
    for key in ('categories', 'assets', 'associations'):
        unique = []
        for item in langspec[key]:
            if item not in unique:
                unique.append(item)
        langspec[key] = unique

    return langspec
142
+
143
def skip(self, *args, **kwargs):
    """Fallback visitor: accept any arguments, do nothing, return None."""
    return None
145
+
146
def _visit(self, cursor: TreeCursor) -> ASTNode | list[ASTNode] | None:
    """Run the `visit_<type>` handler for the node under `cursor`.

    The cursor is moved onto the node's first child (when it has one) for
    the duration of the handler call and moved back afterwards. Node types
    without a handler (anonymous nodes such as keywords, or rules with no
    visitor yet) fall back to `self.skip`.
    """
    node_type = assert_node(cursor.node).type

    # Step into the node before looking up and calling the handler.
    entered = cursor.goto_first_child()
    handler = getattr(self, f'visit_{node_type}', self.skip)
    outcome = handler(cursor)

    # Step back out so the caller finds the cursor where it left it.
    if entered:
        cursor.goto_parent()

    return outcome
166
+
167
def _skip(self, cursor: TreeCursor) -> ASTNode | list[ASTNode] | None:
    """Visit the current node and all following siblings.

    Falsy visit results are discarded. Returns None when nothing was
    collected, the single result when there is exactly one, and the list
    of results otherwise.
    """
    collected = []

    first = self.visit(cursor)
    if first:
        collected.append(first)

    while go_to_sibling(cursor):
        following = self.visit(cursor)
        if following:
            collected.append(following)

    if not collected:
        return None
    if len(collected) == 1:
        return collected[0]
    return collected
181
+
182
def visit_comment(self, cursor: TreeCursor, params=None):
    """Return a `(None, None)` placeholder pair for a comment node."""
    placeholder = (None, None)
    return placeholder
184
+
185
+
186
+ # Concrete visitor to process function definitions
187
+ class MalCompiler(ParseTreeVisitor):
188
+ # Named visit_{rule name in grammar.js}
189
def visit_define_declaration(self, cursor: TreeCursor) -> ASTNode:
    """Turn a define declaration into `('defines', {name: value})`."""
    ###############################
    # '#' (identity) ':' (string) #
    ###############################

    go_to_sibling(cursor)  # '#' -> (identity)
    raw_name = node_text(cursor, 'id')

    go_to_sibling(cursor)  # (identity) -> ':'
    go_to_sibling(cursor)  # ':' -> (string)
    raw_value = node_text(cursor, 'string')

    define_name = raw_name.decode()
    # Drop the double quotes surrounding the string literal.
    define_value = raw_value.decode().strip('"')
    return ('defines', {define_name: define_value})
206
+
207
def visit_category_declaration(self, cursor: TreeCursor) -> ASTNode:
    """Turn a category declaration into `('categories', ([category], assets))`.

    `category` holds the category's `name` and `meta`; `assets` is the list
    of visited asset declarations, each visited with the category name.
    """
    ############################################
    # 'category' (id) (meta)* '{' (asset)* '}' #
    ############################################

    # 'category' -> (id)
    go_to_sibling(cursor)
    category = {'name': node_text(cursor, 'id').decode()}
    go_to_sibling(cursor)

    # (meta)* is optional, so only consume nodes that are grammar rules.
    collected_meta = {}
    while assert_node(cursor.node).is_named:
        info = self.visit(cursor)
        collected_meta[info[0]] = info[1]
        go_to_sibling(cursor)
    category['meta'] = collected_meta

    # '{' -> first (asset) or '}'
    go_to_sibling(cursor)

    declared_assets = []
    while assert_node(cursor.node).is_named:
        declared_assets.append(self.visit(cursor, category['name']))
        go_to_sibling(cursor)

    # next node and skip '}' node
    go_to_sibling(cursor)
    cursor.goto_next_sibling()

    return ('categories', ([category], declared_assets))
246
+
247
def visit_include_declaration(self, cursor: TreeCursor):
    """Turn an include declaration into `('include', <file path>)`."""
    ####################
    # 'include' (file) #
    ####################

    # 'include' -> (file), which is a (string)
    go_to_sibling(cursor)

    # The surrounding '"' are ASCII and therefore exactly one byte each in
    # UTF-8 (assumed by .decode), so they can be removed by slicing the
    # raw bytes before decoding.
    quoted_path = node_text(cursor, '')
    unquoted_path = quoted_path[1:-1]
    return ('include', unquoted_path.decode())
260
+
261
def visit_meta(self, cursor: TreeCursor) -> ASTNode:
    """Turn a meta entry into an `(id, info string)` pair."""
    ############################
    # (id) 'info' ':' (string) #
    ############################

    meta_key = node_text(cursor, 'id').decode()

    go_to_sibling(cursor)  # (id) -> 'info'
    go_to_sibling(cursor)  # 'info' -> ':'
    go_to_sibling(cursor)  # ':' -> (string)

    # The surrounding '"' are ASCII and therefore exactly one byte each in
    # UTF-8 (assumed by .decode), so slice them off before decoding.
    quoted_info = node_text(cursor, '')
    meta_value = quoted_info[1:-1].decode()

    return (meta_key, meta_value)
283
+
284
def visit_asset_declaration(
    self, cursor: TreeCursor, category: str
) -> dict[str, Any]:
    """Turn an asset declaration into its dict representation.

    Args:
        cursor: positioned on the first child of the asset declaration.
        category: name of the enclosing category, stored in the result.
    """
    ##############################################################################
    # (abstract)? 'asset' (id) (extends id)? (meta)* '{' (asset_definition)* '}' #
    ##############################################################################

    # (abstract)? - when present, step over it to reach 'asset'
    is_abstract = node_text(cursor, 'abstract') == b'abstract'
    if is_abstract:
        go_to_sibling(cursor)

    # 'asset' -> (id)
    go_to_sibling(cursor)
    asset_name = node_text(cursor, '').decode()
    go_to_sibling(cursor)

    # (extends id)?
    parent_asset = None
    if node_text(cursor, '') == b'extends':
        go_to_sibling(cursor)  # 'extends' -> (id)
        parent_asset = node_text(cursor, '').decode()
        go_to_sibling(cursor)  # (id) -> (meta) or '{'

    # (meta)*
    collected_meta = {}
    while assert_node(cursor.node).is_named:
        info = self.visit(cursor)
        collected_meta[info[0]] = info[1]
        go_to_sibling(cursor)

    # '{' -> (asset_definition) or '}'
    go_to_sibling(cursor)

    # An asset body is optional; default to no variables and no steps.
    asset_variables, attack_steps = [], []
    if assert_node(cursor.node).is_named:
        asset_variables, attack_steps = self.visit(cursor)

    return {
        'name': asset_name,
        'meta': collected_meta,
        'category': category,
        'isAbstract': is_abstract,
        'superAsset': parent_asset,
        'variables': asset_variables,
        'attackSteps': attack_steps,
    }
334
+
335
def visit_asset_definition(self, cursor: TreeCursor) -> tuple[list, list]:
    """Collect the variables and attack steps of an asset body.

    Each child is visited and expected to yield a `(kind, value)` pair;
    values tagged `'variable'` or `'step'` are collected, anything else is
    ignored.

    Returns:
        `(variables, steps)`, each in source order.
    """
    #######################
    # (variable)* (step)* #
    #######################

    variables, steps = [], []
    buckets = {'variable': variables, 'step': steps}

    while True:
        kind, value = self.visit(cursor)

        bucket = buckets.get(kind)
        if bucket is not None:
            bucket.append(value)

        if not go_to_sibling(cursor):
            break

    return (variables, steps)
353
+
354
def visit_asset_variable(self, cursor: TreeCursor) -> ASTNode:
    """Turn a `let` variable into `('variable', {name, stepExpression})`."""
    ##########################
    # 'let' (id) '=' (value) #
    ##########################

    go_to_sibling(cursor)  # 'let' -> (id)
    variable_name = node_text(cursor, 'id').decode()

    go_to_sibling(cursor)  # (id) -> '='
    go_to_sibling(cursor)  # '=' -> (value)

    # TODO visit step expression
    expression = self.visit(cursor)

    return ('variable', {'name': variable_name, 'stepExpression': expression})
376
+
377
def visit_attack_step(self, cursor: TreeCursor) -> tuple[str, dict]:
    """Turn an attack step into `('step', <step dict>)`.

    The step dict holds `name`, `meta`, `detectors` (keyed by detector
    name), `type`, `causal_mode`, `tags`, `risk`, `ttc`, `requires` and
    `reaches`; optional parts that are absent are None or empty.
    """
    ##############################################################################################################
    # (step_type) (id) ( '@' (id) )* ( '{' (cias) '}' )? (ttc)? (meta)* (detector)? (preconditions)? (reaches)? #
    ##############################################################################################################

    # grab (step_type)
    # use raw text bytes to avoid decoding as much as possible
    step_type_btext = node_text(cursor, 'step type')
    step_type_bindings = {
        b'&': 'and',
        b'|': 'or',
        b'#': 'defense',
        b'E': 'exist',
        b'!E': 'notExist',
    }

    # decode value only if it's really necessary (no binding found)
    if (step_type := step_type_bindings.get(step_type_btext)) is None:
        step_type = step_type_btext.decode()
    go_to_sibling(cursor)

    # grab optional (causal_mode) before (id)
    causal_mode = None
    current_text = node_text(cursor, 'causal_mode')
    if current_text in (b'action', b'effect'):
        causal_mode = current_text.decode()
        go_to_sibling(cursor)  # skip causal_mode

    # grab (id)
    name = node_text(cursor, 'name').decode()
    go_to_sibling(cursor)

    # process all ( '@' (id) ) we might have
    tags = []

    # TODO change grammar to make (@ id)* instead of (@ id)?
    while node_text(cursor, 'tag at') == b'@':
        go_to_sibling(cursor)  # skip '@'
        tags.append(node_text(cursor, '').decode())  # grab (id)
        if not go_to_sibling(cursor):  # move to next symbol, break if last
            break

    # process all ( '{' (cias) '}' ) we might have
    risk = None
    if assert_node(cursor.node).text == b'{':
        go_to_sibling(cursor)  # skip '{'
        risk = self.visit(cursor)  # grab (cias)
        go_to_sibling(cursor)  # go to '}'
        go_to_sibling(cursor)  # and skip it

    ttc = None
    if assert_node(cursor.node).type == 'ttc':
        # visit ttc
        ttc = self.visit(cursor)
        go_to_sibling(cursor)

    meta = {}
    while assert_node(cursor.node).type == 'meta':
        info = self.visit(cursor)
        meta[info[0]] = info[1]
        if not go_to_sibling(cursor):  # in case there is nothing after the meta
            break

    detectors: dict[str, Any] = {}
    while assert_node(cursor.node).type == 'detector':
        detector = self.visit(cursor)
        # Register the detector under its name in the step's detector map.
        # (Writing into `detector` itself would leave `detectors` empty and
        # make the detector dict contain a reference to itself.)
        detectors[detector['name']] = detector
        if not go_to_sibling(cursor):  # in case there is nothing after the detector
            break

    requires = None
    if assert_node(cursor.node).type == 'preconditions':
        requires = self.visit(cursor)
        go_to_sibling(cursor)

    reaches = None
    if assert_node(cursor.node).type == 'reaching':
        reaches = self.visit(cursor)
        go_to_sibling(cursor)

    ret = {
        'name': name,
        'meta': meta,
        'detectors': detectors,
        'type': step_type,
        'causal_mode': causal_mode,
        'tags': tags,
        'risk': risk,
        'ttc': ttc,
        'requires': requires,
        'reaches': reaches,
    }

    return ('step', ret)
471
+
472
def visit_detector(self, cursor: TreeCursor):
    """Turn a detector into a dict with `name`, `context`, `type`, `tprate`.

    `name` and `type` are None when the corresponding optional part is
    absent. `tprate` is always None because a TTC on a detector is not
    supported yet.

    Raises:
        NotImplementedError: if the detector carries a TTC.
    """
    ####################################################################
    # ('!' | '//!') (detector_name)? (detector_context) (type)? (ttc)? #
    ####################################################################

    # skip bang
    go_to_sibling(cursor)

    # grab detector_name
    detector_name = None
    if cursor.field_name == 'name':
        detector_name = node_text(cursor, 'detector name').decode()
        go_to_sibling(cursor)

    # grab detector_context
    detector_context = self.visit(cursor)
    go_to_sibling(cursor)

    # grab (type); it must go into `detector_type`, not overwrite the name
    detector_type = None
    if cursor.field_name == 'type':
        detector_type = node_text(cursor, 'type').decode()
        go_to_sibling(cursor)

    # grab ttc
    detector_ttc = None
    if cursor.field_name == 'ttc':
        # TODO: visiting the detector's ttc is not implemented yet
        raise NotImplementedError('TTC not implemented for detectors')

    return {
        'name': detector_name,
        'context': detector_context,
        'type': detector_type,
        'tprate': detector_ttc,
    }
510
+
511
def visit_detector_context(self, cursor: TreeCursor):
    """Turn a detector context into a `{label: asset}` dict."""
    ####################################################################
    # '(' (detector_context_asset) (',' (detector_context_asset))* ')' #
    ####################################################################

    # '(' -> first (detector_context_asset)
    go_to_sibling(cursor)

    context = {}
    while True:
        label, asset = self.visit(cursor)
        context[label] = asset
        go_to_sibling(cursor)

        # After an entry comes either the closing ')' or a ','.
        if node_text(cursor, 'char') == b')':
            break
        go_to_sibling(cursor)  # ',' -> next (detector_context_asset)

    return context
534
+
535
def visit_detector_context_asset(self, cursor: TreeCursor):
    """Turn a detector context entry into a `(label, asset)` pair of str.

    The entry consists of two consecutive nodes: the asset type followed
    by the label (id). The cursor has to be advanced between the two
    reads, otherwise both values come from the type node.
    """
    ###############
    # (type) (id) #
    ###############
    asset = node_text(cursor, 'asset').decode()
    go_to_sibling(cursor)  # (type) -> (id)
    label = node_text(cursor, 'label').decode()

    return (label, asset)
543
+
544
def visit_cias(self, cursor: TreeCursor):
    """Turn a CIA list into a dict of the three `is...` flags.

    All flags start as False; each visited (cia) switches its flag on.
    """
    ######################
    # (cia) (',' (cia))* #
    ######################
    risk = dict.fromkeys(
        ('isConfidentiality', 'isIntegrity', 'isAvailability'), False
    )

    risk.update(self.visit(cursor))
    # A further sibling is the ',' separator; the next (cia) follows it.
    while go_to_sibling(cursor):
        go_to_sibling(cursor)
        risk.update(self.visit(cursor))

    return risk
566
+
567
def visit_cia(self, cursor: TreeCursor):
    """Map a single C/I/A letter to `{<flag name>: True}`.

    An unrecognised letter yields `{None: True}`.
    """
    ###############
    # 'C'|'I'|'A' #
    ###############

    letter = node_text(cursor, 'cia')
    if letter == b'C':
        flag = 'isConfidentiality'
    elif letter == b'I':
        flag = 'isIntegrity'
    elif letter == b'A':
        flag = 'isAvailability'
    else:
        flag = None

    return {flag: True}
581
+
582
def visit_ttc(self, cursor: TreeCursor):
    """Return the parsed expression enclosed in a TTC's brackets."""
    ##################################
    # '[' (intermediary_ttc_exp) ']' #
    ##################################

    # '[' -> (intermediary_ttc_exp)
    go_to_sibling(cursor)

    expression = self._visit_intermediary_ttc_expr(cursor)
    return expression
591
+
592
def _visit_intermediary_ttc_expr(self, cursor: TreeCursor):
    """Parse one TTC expression starting at the current node.

    Handles parenthesised expressions, bare identifiers (returned as a
    function without arguments) and numbers here; every other node type is
    delegated to `visit`.
    """
    ###################################################################################################
    # '(' (intermediary_ttc_expr) ')' | (integer) | (float) | (id) | (ttc_distribution) | (ttc_binop) #
    ###################################################################################################

    # Parenthesised expression: unwrap it.
    if node_text(cursor, 'char') == b'(':
        go_to_sibling(cursor)  # '(' -> expression
        inner = self._visit_intermediary_ttc_expr(cursor)
        go_to_sibling(cursor)  # expression -> ')'
        return inner

    node_type = assert_node(cursor.node).type

    # A bare id is a distribution name without arguments.
    if node_type == 'identifier':
        return {
            'type': 'function',
            'name': node_text(cursor, 'id').decode(),
            'arguments': [],
        }

    # Numbers (integer/float) are wrapped in a 'number' dictionary.
    if node_type == 'float' or node_type == 'integer':
        number: dict[str, Any] = {'type': 'number'}
        number['value'] = self.visit(cursor)
        return number

    # Anything else has its own visitor.
    return self.visit(cursor)
621
+
622
def visit_float(self, cursor: TreeCursor):
    """Return the current node's text converted to a float."""
    return float(node_text(cursor, 'float'))
626
+
627
def visit_integer(self, cursor: TreeCursor):
    """Return the current node's text converted to a float.

    Note that integers are returned as float values as well.
    """
    return float(node_text(cursor, 'integer'))
631
+
632
def visit_ttc_binop(self, cursor: TreeCursor):
    """Turn a binary TTC operation into `{'type', 'lhs', 'rhs'}`.

    `type` is the operation's name, or None for an operator that is not
    one of `+ - * / ^`.
    """
    #########################################################################
    # (intermediary_ttc_expr) ('+'|'-'|'*'|'/'|'^') (intermediary_ttc_expr) #
    #########################################################################

    left_operand = self._visit_intermediary_ttc_expr(cursor)
    go_to_sibling(cursor)  # lhs -> operator

    operator_text = assert_node(cursor.node).text
    operator_names = {
        b'+': 'addition',
        b'-': 'subtraction',
        b'*': 'multiplication',
        b'/': 'division',
        b'^': 'exponentiation',
    }
    assert operator_text, 'Operation not found'
    operation_name = operator_names.get(operator_text)
    go_to_sibling(cursor)  # operator -> rhs

    right_operand = self._visit_intermediary_ttc_expr(cursor)

    return {'type': operation_name, 'lhs': left_operand, 'rhs': right_operand}
658
+
659
def visit_ttc_distribution(self, cursor: TreeCursor):
    """Turn a TTC distribution call into a `function` dict.

    The result has the distribution `name` and the list of its numeric
    `arguments`.
    """
    ############################################
    # (id) '(' (number)* ( ',' (number) )* ')' #
    ############################################

    distribution = node_text(cursor, 'name').decode()
    go_to_sibling(cursor)  # (id) -> '('
    go_to_sibling(cursor)  # '(' -> first argument or ')'

    arguments = []
    while True:
        if assert_node(cursor.node).type not in ('float', 'integer'):
            break
        arguments.append(self.visit(cursor))

        # After a number comes either ',' (more arguments) or the end.
        go_to_sibling(cursor)
        if assert_node(cursor.node).text != b',':
            break
        go_to_sibling(cursor)  # ',' -> next argument

    return {'type': 'function', 'name': distribution, 'arguments': arguments}
685
+
686
def visit_preconditions(self, cursor: TreeCursor):
    """Turn preconditions into `{'overrides': True, 'stepExpressions': [...]}`."""
    ##########################################
    # '<-' (asset_expr) (',' (asset_expr) )* #
    ##########################################

    # '<-' -> first (asset_expr)
    go_to_sibling(cursor)

    expressions = [self.visit(cursor)]
    preconditions: dict[str, Any] = {
        'overrides': True,
        'stepExpressions': expressions,
    }

    # A further sibling is the ',' separator; the expression follows it.
    while go_to_sibling(cursor):
        go_to_sibling(cursor)
        expressions.append(self.visit(cursor))

    return preconditions
703
+
704
def visit_reaching(self, cursor: TreeCursor):
    """Turn a reaching clause into `{'overrides', 'stepExpressions'}`.

    `overrides` is True for the `->` arrow and False otherwise.
    """
    ################################################
    # ( '+>' | '->' ) (reaches) ( ',' (reaches) )* #
    ################################################

    # The arrow decides whether the reached steps override.
    overrides = assert_node(cursor.node).text == b'->'
    go_to_sibling(cursor)  # arrow -> first (reaches)

    expressions = [self.visit(cursor)]
    reaching: dict[str, Any] = {
        'overrides': overrides,
        'stepExpressions': expressions,
    }

    # A further sibling is the ',' separator; the expression follows it.
    while go_to_sibling(cursor):
        go_to_sibling(cursor)
        expressions.append(self.visit(cursor))

    return reaching
723
+
724
def visit_asset_expr(self, cursor: TreeCursor):
    """Delegate to `_visit_inline_asset_expr` and return its result."""
    expression = self._visit_inline_asset_expr(cursor)
    return expression
726
+
727
def _visit_inline_asset_expr(self, cursor: TreeCursor):
    """Parse one asset expression starting at the current node.

    This mimics the grammar's `_inline_asset_expr`: it works out what kind
    of node it was handed. Identifiers and parenthesised expressions are
    handled here; every other node type is delegated to `visit`.
    """
    #############################################################################################################################################
    # '(' (_inline_asset_expr) ')' | (id) | (asset_variable_substitution) | (asset_expr_binop) | (asset_expr_unop) | (asset_expr_type) #
    #############################################################################################################################################

    assert cursor.node, 'Missing node'

    # Plain identifier: either a field or an attack step.
    if assert_node(cursor.node).type == 'identifier':
        identifier = {}
        identifier['type'] = self._resolve_part_ID_type(cursor)
        identifier['name'] = node_text(cursor, 'name').decode()
        return identifier

    # Parenthesised expression: unwrap it.
    if node_text(cursor, 'char').decode() == '(':
        go_to_sibling(cursor)  # '(' -> expression
        inner = self._visit_inline_asset_expr(cursor)
        go_to_sibling(cursor)  # expression -> ')'
        return inner

    return self.visit(cursor)
750
+
751
def visit_asset_variable_substitution(self, cursor: TreeCursor):
    """Turn a variable call into `{'type': 'variable', 'name': <id>}`."""
    ################
    # (id) '(' ')' #
    ################

    variable_name = node_text(cursor, 'name').decode()
    return {'type': 'variable', 'name': variable_name}
757
+
758
def visit_asset_expr_type(self, cursor: TreeCursor):
    """Turn a subtype expression into a `subType` dict wrapping its base."""
    #####################################
    # (_inline_asset_expr) '[' (id) ']' #
    #####################################

    # In the ANTLR version the subtypes were visited left to right and had
    # to be stored recursively. With TreeSitter we start from the right,
    # so it is enough to visit the `lhs` and return the current subtype.

    base_expression = self._visit_inline_asset_expr(cursor)
    go_to_sibling(cursor)  # expression -> '['
    go_to_sibling(cursor)  # '[' -> (id)

    sub_type = node_text(cursor, 'subType').decode()

    return {
        'type': 'subType',
        'subType': sub_type,
        'stepExpression': base_expression,
    }
779
+
780
def visit_asset_expr_binop(self, cursor: TreeCursor):
    """Turn a binary asset expression into `{'type', 'lhs', 'rhs'}`.

    `type` is the operation's name, or None for an operator that is not
    one of `.`, `\\/`, `/\\` or `-`.
    """
    ########################################################################
    # (_inline_asset_expr) ( '\/' | '/\' | '-' | '.') (_inline_asset_expr) #
    ########################################################################

    left_operand = self._visit_inline_asset_expr(cursor)
    go_to_sibling(cursor)  # lhs -> operator

    assert cursor.node, 'Missing node for operation type'
    operator_text = node_text(cursor, '')
    assert operator_text, 'Missing text for operation node'
    operator_names = {
        b'.': 'collect',
        b'\\/': 'union',
        b'/\\': 'intersection',
        b'-': 'difference',
    }
    operation_name = operator_names.get(operator_text)
    go_to_sibling(cursor)  # operator -> rhs

    right_operand = self._visit_inline_asset_expr(cursor)
    return {'type': operation_name, 'lhs': left_operand, 'rhs': right_operand}
805
+
806
def visit_asset_expr_unop(self, cursor: TreeCursor):
    """Turn a `*` expression into a `transitive` dict wrapping its operand."""
    #############################
    # (_inline_asset_expr) '*'  #
    #############################

    operand = self._visit_inline_asset_expr(cursor)
    go_to_sibling(cursor)  # operand -> '*'

    return {'type': 'transitive', 'stepExpression': operand}
816
+
817
def _resolve_part_ID_type(self, cursor: TreeCursor):
    """Decide whether the identifier under `cursor` is a field or an attack step.

    Returns:
        `'field'` when the identifier has no `reaching` ancestor, or when a
        `.` follows it on its line before any `,`; `'attackStep'` otherwise.

    Raises:
        ValueError: if the cursor has no node.
    """
    node = cursor.node
    if not node:
        raise ValueError('Missing node for id')

    # Climb until a `reaching` ancestor is found or the tree is exhausted.
    ancestor = node.parent
    while ancestor and ancestor.type != 'reaching':
        ancestor = ancestor.parent

    if not ancestor:
        # No `reaching` above us: we were in a `let` or `precondition`,
        # which only accept fields.
        return 'field'

    # Inside a `reaching`, look at what follows the identifier on its own
    # line: a `.` means attributes are accessed, which only fields have.
    # The ancestor's text is split into the lines as written in the source
    # and the identifier is located in it via the nodes' start points.
    column = node.start_point.column

    assert ancestor.text, 'Missing parent node text'
    source_lines = ancestor.text.decode().split('\n')

    row_offset = node.start_point.row - ancestor.start_point.row
    if row_offset == 0:
        # Same line as the ancestor: its text starts at the ancestor's
        # column, so the identifier's column has to be made relative.
        line = source_lines[0]
        column = node.start_point.column - ancestor.start_point.column
    else:
        # A later line is included with its indent, so the absolute column
        # can be used as is.
        line = source_lines[row_offset]

    assert node.text, 'Missing node text'
    remainder = line[column + len(node.text.decode()):]

    # The first `.` or `,` after the identifier decides.
    for char in remainder:
        if char in '.,':
            # `.`: only a field can have attributes.
            # `,`: a new reaches starts, so this was an attack step.
            return 'field' if char == '.' else 'attackStep'

    return 'attackStep'
888
+
889
def visit_associations_declaration(self, cursor: TreeCursor):
    """Turn an associations block into `('associations', [association, ...])`."""
    #########################################
    # 'associations' '{' (association)* '}' #
    #########################################

    go_to_sibling(cursor)  # 'associations' -> '{'
    go_to_sibling(cursor)  # '{' -> first (association) or '}'

    collected = []
    while True:
        # Stop at the closing brace (or when there is no node at all).
        if not cursor.node or node_text(cursor, '') == b'}':
            break
        collected.append(self.visit(cursor))
        go_to_sibling(cursor)

    return ('associations', collected)
907
+
908
def visit_association(self, cursor: TreeCursor):
    """Turn one association into its dict representation.

    The result holds the association `name`, its `meta`, and for each side
    the asset, field and multiplicity. Multiplicities are normalised by
    `_process_multitudes` before the dict is returned.
    """
    ##############################################################################################
    # (id) '[' (id) ']' (multiplicity) '<--' (id) '-->' (multiplicity) '[' (id) ']' (id) (meta)* #
    ##############################################################################################

    # Left side: (id) '[' (id) ']' (multiplicity)
    asset_l = node_text(cursor, 'left asset').decode()
    go_to_sibling(cursor)  # (id) -> '['
    go_to_sibling(cursor)  # '[' -> (id)

    field_l = node_text(cursor, 'left field').decode()
    go_to_sibling(cursor)  # (id) -> ']'
    go_to_sibling(cursor)  # ']' -> (multiplicity)

    multiplicity_l = self.visit(cursor)
    go_to_sibling(cursor)  # (multiplicity) -> '<--'
    go_to_sibling(cursor)  # '<--' -> (id)

    # Middle: name of the association
    association_name = node_text(cursor, 'name').decode()
    go_to_sibling(cursor)  # (id) -> '-->'
    go_to_sibling(cursor)  # '-->' -> (multiplicity)

    # Right side: (multiplicity) '[' (id) ']' (id)
    multiplicity_r = self.visit(cursor)
    go_to_sibling(cursor)  # (multiplicity) -> '['
    go_to_sibling(cursor)  # '[' -> (id)

    field_r = node_text(cursor, 'right field').decode()
    go_to_sibling(cursor)  # (id) -> ']'
    go_to_sibling(cursor)  # ']' -> (id)

    asset_r = node_text(cursor, 'right asset').decode()

    # Everything that follows the right asset is a (meta).
    collected_meta = {}
    while go_to_sibling(cursor):
        entry = self.visit(cursor)
        assert entry
        collected_meta[entry[0]] = entry[1]

    association = {
        'name': association_name,
        'meta': collected_meta,
        'leftAsset': asset_l,
        'leftField': field_l,
        'leftMultiplicity': multiplicity_l,
        'rightAsset': asset_r,
        'rightField': field_r,
        'rightMultiplicity': multiplicity_r,
    }

    self._process_multitudes(association)

    return association
979
+
980
def visit_multiplicity(self, cursor: TreeCursor):
    """Turn a multiplicity into a `{'min', 'max'}` dict.

    A range is delegated to its own visitor; a single atom becomes the
    `min` with `max` left as None.

    Raises:
        MalCompilerError: if the cursor has no node.
    """
    ###############################################
    # (_multiplicity_atom) | (multiplicity_range) #
    ###############################################

    if cursor.node is None:
        raise MalCompilerError('multiplicity atom missing node')

    if assert_node(cursor.node).type == 'multiplicity_range':
        return self.visit(cursor)

    # Atomic multiplicity expressions have no visitor of their own and go
    # through an intermediary function instead.
    lower_bound = self._visit_multiplicity_atom(cursor)
    return {'min': lower_bound, 'max': None}
998
+
999
def visit_multiplicity_range(self, cursor: TreeCursor):
    """Turn a multiplicity range into a `{'min', 'max'}` dict."""
    ##################################################
    # (_multiplicity_atom) '..' (_multiplicity_atom) #
    ##################################################

    lower_bound = self._visit_multiplicity_atom(cursor)
    go_to_sibling(cursor)  # atom -> '..'
    go_to_sibling(cursor)  # '..' -> atom

    upper_bound = self._visit_multiplicity_atom(cursor)

    return {'min': lower_bound, 'max': upper_bound}
1016
+
1017
def _visit_multiplicity_atom(self, cursor: TreeCursor):
    """Return the text of the current multiplicity atom as a str.

    Raises:
        MalCompilerError: if the cursor has no node.
        ValueError: if the atom's text is empty.
    """
    ######################
    # (integer) | (star) #
    ######################
    if not cursor.node:
        raise MalCompilerError('multiplicity atom missing node')

    atom_text = node_text(cursor, '')
    if not atom_text:
        raise ValueError('multiplicity atom has empty text')
    return atom_text.decode()
1026
+
1027
def _process_multitudes(self, association):
    """Normalise the multiplicities of `association` in place.

    For `leftMultiplicity` and `rightMultiplicity`:
    a missing `max` takes the value of `min`; `'*'` becomes 0 as `min` and
    None (no limit) as `max`; numeric strings become ints.
    """
    # Go from right to left, and handle `max` before `min`, so that the
    # rules below get applied cleanly.
    for side in ('rightMultiplicity', 'leftMultiplicity'):
        for bound in ('max', 'min'):
            limits = association[side]

            # upper limit equals lower limit if not given
            if bound == 'max' and limits[bound] is None:
                limits[bound] = limits['min']

            # 'any' as lower limit means start from 0,
            # 'any' as upper limit means no limit
            if limits[bound] == '*':
                limits[bound] = 0 if bound == 'min' else None

            # cast numerical strings to integers
            current = limits[bound]
            if current and current.isdigit():
                limits[bound] = int(current)
1056
+
1057
+
1058
def assert_node(node: Node | None) -> Node:
    """Return `node` unchanged, raising ValueError when it is None."""
    if node is not None:
        return node
    raise ValueError('Node can not be None')
1062
+
1063
+
1064
def node_text(cursor: TreeCursor, context: str):
    """Return the text of the node under `cursor`.

    Args:
        cursor: cursor whose current node is read.
        context: short description of what was expected, used only in the
            error messages.

    Raises:
        MalCompilerError: if the cursor has no node or the node's text is
            empty.
    """
    current = cursor.node
    if current is None:
        raise MalCompilerError(f'expected node for {context}, found None')

    content = current.text
    if not content:
        raise MalCompilerError(f'expected text for {context}, found empty')
    return content