mal-toolbox 1.2.1__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/METADATA +8 -75
  2. mal_toolbox-2.1.0.dist-info/RECORD +51 -0
  3. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/WHEEL +1 -1
  4. maltoolbox/__init__.py +2 -2
  5. maltoolbox/attackgraph/__init__.py +2 -2
  6. maltoolbox/attackgraph/attackgraph.py +121 -549
  7. maltoolbox/attackgraph/factories.py +68 -0
  8. maltoolbox/attackgraph/file_utils.py +0 -0
  9. maltoolbox/attackgraph/generate.py +338 -0
  10. maltoolbox/attackgraph/node.py +1 -0
  11. maltoolbox/attackgraph/node_getters.py +36 -0
  12. maltoolbox/attackgraph/ttcs.py +28 -0
  13. maltoolbox/language/__init__.py +2 -2
  14. maltoolbox/language/compiler/__init__.py +4 -499
  15. maltoolbox/language/compiler/distributions.py +158 -0
  16. maltoolbox/language/compiler/exceptions.py +37 -0
  17. maltoolbox/language/compiler/lang.py +5 -0
  18. maltoolbox/language/compiler/mal_analyzer.py +920 -0
  19. maltoolbox/language/compiler/mal_compiler.py +1071 -0
  20. maltoolbox/language/detector.py +43 -0
  21. maltoolbox/language/expression_chain.py +218 -0
  22. maltoolbox/language/language_graph_asset.py +180 -0
  23. maltoolbox/language/language_graph_assoc.py +147 -0
  24. maltoolbox/language/language_graph_attack_step.py +129 -0
  25. maltoolbox/language/language_graph_builder.py +282 -0
  26. maltoolbox/language/language_graph_loaders.py +7 -0
  27. maltoolbox/language/language_graph_lookup.py +140 -0
  28. maltoolbox/language/language_graph_serialization.py +5 -0
  29. maltoolbox/language/languagegraph.py +244 -1536
  30. maltoolbox/language/step_expression_processor.py +491 -0
  31. mal_toolbox-1.2.1.dist-info/RECORD +0 -33
  32. maltoolbox/language/compiler/mal_lexer.py +0 -232
  33. maltoolbox/language/compiler/mal_parser.py +0 -3159
  34. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/entry_points.txt +0 -0
  35. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/licenses/AUTHORS +0 -0
  36. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/licenses/LICENSE +0 -0
  37. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1071 @@
1
+ from tree_sitter import TreeCursor, Node
2
+ from typing import Any, Callable, Tuple
3
+ from collections.abc import MutableMapping, MutableSequence
4
+ from pathlib import Path
5
+ from .lang import PARSER as parser
6
+ from .mal_analyzer import malAnalyzer
7
+ from .exceptions import (
8
+ MalCompilerError,
9
+ )
10
+
11
+ ASTNode = Tuple[str, object]
12
+
13
+ """
14
+ This function is crucial to use instead of cursor.goto_next_sibling()
15
+
16
+ Although the `comment` node is included as an extra in the
17
+ TreeSitter grammar, it still shows up in the AST.
18
+ For this reason, this function exists to go to the next node,
19
+ while ignoring `comment` node.
20
+
21
+ It returns a boolean which states if there are any nodes left
22
+ and if the current node is not a comment.
23
+ """
24
+
25
+
26
def go_to_sibling(cursor: TreeCursor) -> bool:
    """Advance ``cursor`` to the next sibling that is not a ``comment`` node.

    Use this instead of calling ``cursor.goto_next_sibling()`` directly:
    ``comment`` nodes still appear among the siblings and have to be
    stepped over.

    Returns:
        A truthy value when the cursor moved and now rests on a node whose
        type is not ``comment``; a falsy value otherwise.
    """
    moved = cursor.goto_next_sibling()

    # Keep stepping while we are parked on a comment and there is still
    # somewhere to go.
    while True:
        on_comment = assert_node(cursor.node).type == 'comment'
        if not (on_comment and moved):
            break
        moved = cursor.goto_next_sibling()

    if not moved:
        return moved
    return assert_node(cursor.node).type != 'comment'
31
+
32
+
33
+ class ParseTreeVisitor:
34
+ def __init__(self) -> None:
35
+ self.current_file: Path | None = None
36
+ self.visited_files: set[Path] = set()
37
+ self.path_stack: list[Path] = []
38
+ self.analyzer = malAnalyzer()
39
+
40
+ def compile(self, malfile: Path | str):
41
+ current_file = Path(malfile)
42
+
43
+ if not current_file.is_absolute() and self.path_stack:
44
+ # Only for the first file self.path_stack will be empty.
45
+ current_file = self.path_stack[-1] / current_file
46
+
47
+ if current_file in self.visited_files:
48
+ # Avoid infinite loops due to recursive includes
49
+ return {}
50
+
51
+ self.visited_files.add(current_file)
52
+ self.path_stack.append(current_file.parent)
53
+
54
+ result = None
55
+ with open(current_file, 'rb') as f:
56
+ source = f.read()
57
+ tree = parser.parse(source)
58
+ result = self.visit(tree.walk())
59
+
60
+ self.path_stack.pop()
61
+
62
+ return result
63
+
64
+ def visit(self, cursor, params=None):
65
+ function_name = f'visit_{cursor.node.type}' # obtain the appropriate function
66
+ visitor = getattr(self, function_name, self.skip)
67
+ hasChild = cursor.goto_first_child() # enter node's children
68
+ result = visitor(cursor) if not params else visitor(cursor, params)
69
+
70
+ if hasChild:
71
+ cursor.goto_parent() # leave node's children
72
+
73
+ analyzer_method_name: str = f'check_{cursor.node.type}'
74
+ analyzer_method: Callable[..., Any] | None = getattr(
75
+ self.analyzer, analyzer_method_name, None
76
+ )
77
+
78
+ if analyzer_method:
79
+ arguments = analyzer_method.__code__.co_argcount
80
+ if arguments in [2, 3]:
81
+ {
82
+ 3: lambda: analyzer_method(cursor.node, result),
83
+ 2: lambda: analyzer_method(cursor.node),
84
+ }[arguments]()
85
+ else:
86
+ raise ValueError(f'Unexpected number of arguments: {arguments}')
87
+
88
+ return result
89
+
90
+ def visit_source_file(self, cursor: TreeCursor) -> dict[str, Any]:
91
+ langspec: dict[str, Any] = {
92
+ 'formatVersion': '1.0.0',
93
+ 'defines': {},
94
+ 'categories': [],
95
+ 'assets': [],
96
+ 'associations': [],
97
+ }
98
+
99
+ # Go to first declaration
100
+ while True:
101
+ if assert_node(cursor.node).type == 'comment':
102
+ go_to_sibling(cursor)
103
+
104
+ # Obtain node type of declaration
105
+ cursor.goto_first_child()
106
+
107
+ # Visit declaration
108
+ result = self.visit(cursor)
109
+ assert result, 'No result from visit'
110
+ key, value = result
111
+ if key == 'categories':
112
+ category, assets = value
113
+ langspec['categories'].extend(category)
114
+ langspec['assets'].extend(assets)
115
+ elif key == 'defines':
116
+ langspec[key].update(value)
117
+ elif key == 'associations':
118
+ langspec[key].extend(value)
119
+ elif key == 'include':
120
+ included_file = self.compile(value)
121
+ for k, v in langspec.items():
122
+ if isinstance(v, MutableMapping):
123
+ langspec[k].update(included_file.get(k, {}))
124
+ if isinstance(v, MutableSequence) and k in included_file:
125
+ langspec[k].extend(included_file[k])
126
+
127
+ # Go back to declaration
128
+ cursor.goto_parent()
129
+
130
+ # Attempt to move to next declaration. If not possible, done processing
131
+ if not go_to_sibling(cursor):
132
+ break
133
+
134
+ for key in ('categories', 'assets', 'associations'):
135
+ unique = []
136
+ for item in langspec[key]:
137
+ if item not in unique:
138
+ unique.append(item)
139
+ langspec[key] = unique
140
+
141
+ return langspec
142
+
143
+ def skip(self, *args, **kwargs):
144
+ pass
145
+
146
+ def _visit(self, cursor: TreeCursor) -> ASTNode | list[ASTNode] | None:
147
+ # Function name of child class handling the node type
148
+ function_name = f'visit_{assert_node(cursor.node).type}'
149
+
150
+ # Enter into the node
151
+ has_children = cursor.goto_first_child()
152
+
153
+ # Default to skip, in case a specific visitor can't be found
154
+ # Generally the case for anonymous nodes (keywords etc) or
155
+ # named nodes (rules) that do not have a vistor implemented yet.
156
+ visitor = getattr(self, function_name, self.skip)
157
+
158
+ # Use visitor implementation
159
+ visitor_value = visitor(cursor)
160
+
161
+ # Exit the node
162
+ if has_children:
163
+ cursor.goto_parent()
164
+
165
+ return visitor_value
166
+
167
+ def _skip(self, cursor: TreeCursor) -> ASTNode | list[ASTNode] | None:
168
+ values = []
169
+ if visitor_value := self.visit(cursor):
170
+ values.append(visitor_value)
171
+ while go_to_sibling(cursor):
172
+ if visitor_value := self.visit(cursor):
173
+ values.append(visitor_value)
174
+ match len(values):
175
+ case 0:
176
+ return None
177
+ case 1:
178
+ return values[0]
179
+ case _:
180
+ return values
181
+
182
+ def visit_comment(self, cursor: TreeCursor, params=None):
183
+ return (None, None)
184
+
185
+
186
+ # Concrete visitor to process function definitions
187
+ class MalCompiler(ParseTreeVisitor):
188
+ # Named visit_{rule name in grammar.js}
189
def visit_define_declaration(self, cursor: TreeCursor) -> ASTNode:
    """Read a define declaration.

    Grammar: '#' (identity) ':' (string)

    Returns:
        ``('defines', {name: value})`` where surrounding double quotes have
        been stripped from the value.
    """
    # step over '#' and land on the identity
    go_to_sibling(cursor)
    name_bytes = node_text(cursor, 'id')

    # move onto ':' and then past it to the string
    go_to_sibling(cursor)
    go_to_sibling(cursor)
    value_bytes = node_text(cursor, 'string')

    define_name = name_bytes.decode()
    define_value = value_bytes.decode().strip('"')
    return ('defines', {define_name: define_value})
206
+
207
def visit_category_declaration(self, cursor: TreeCursor) -> ASTNode:
    """Read a category together with the assets declared inside it.

    Grammar: 'category' (id) (meta)* '{' (asset)* '}'

    Returns:
        ``('categories', ([category], assets))`` where ``category`` holds the
        ``name`` and ``meta`` of the category.
    """
    # step over the 'category' keyword, then read the name
    go_to_sibling(cursor)
    category_name = node_text(cursor, 'id').decode()
    go_to_sibling(cursor)

    # (meta) is optional and repeatable: consume for as long as the cursor
    # sits on a named node (a grammar rule rather than a literal token).
    meta_entries = {}
    while assert_node(cursor.node).is_named:
        entry = self.visit(cursor)
        meta_entries[entry[0]] = entry[1]
        go_to_sibling(cursor)

    category = {'name': category_name, 'meta': meta_entries}

    # step past '{'
    go_to_sibling(cursor)

    # every remaining named node is an asset of this category
    assets = []
    while assert_node(cursor.node).is_named:
        assets.append(self.visit(cursor, category_name))
        go_to_sibling(cursor)

    # move on and step past '}'
    go_to_sibling(cursor)
    cursor.goto_next_sibling()

    return ('categories', ([category], assets))
246
+
247
def visit_include_declaration(self, cursor: TreeCursor):
    """Read an include declaration.

    Grammar: 'include' (file)

    Returns:
        ``('include', path)`` with the surrounding quotes removed.
    """
    # step over the 'include' keyword
    go_to_sibling(cursor)

    # The file is a quoted string. '"' is ASCII and therefore guaranteed to
    # occupy exactly one byte in UTF-8 (the encoding assumed by .decode()),
    # so the quotes can be dropped by slicing the raw bytes.
    quoted_path = node_text(cursor, '')
    return ('include', quoted_path[1:-1].decode())
260
+
261
def visit_meta(self, cursor: TreeCursor) -> ASTNode:
    """Read a meta entry.

    Grammar: (id) 'info' ':' (string)

    Returns:
        ``(id, text)`` with the surrounding quotes removed from the text.
    """
    meta_key = node_text(cursor, 'id').decode()

    # advance three times: onto 'info', onto ':', onto the string
    for _ in range(3):
        go_to_sibling(cursor)

    # '"' is ASCII and therefore guaranteed to occupy exactly one byte in
    # UTF-8 (the encoding assumed by .decode()), so the quotes can be
    # dropped by slicing the raw bytes.
    quoted_text = node_text(cursor, '')
    return (meta_key, quoted_text[1:-1].decode())
283
+
284
def visit_asset_declaration(
    self, cursor: TreeCursor, category: str
) -> dict[str, Any]:
    """Read an asset declaration belonging to ``category``.

    Grammar:
        (abstract)? 'asset' (id) (extends id)? (meta)* '{' (asset_definition)* '}'

    Returns:
        A dict with the keys ``name``, ``meta``, ``category``,
        ``isAbstract``, ``superAsset``, ``variables`` and ``attackSteps``.
    """
    # optional 'abstract' modifier
    is_abstract = node_text(cursor, 'abstract') == b'abstract'
    if is_abstract:
        go_to_sibling(cursor)  # We must go to 'asset'

    # step over 'asset' and read the name
    go_to_sibling(cursor)
    asset_name = node_text(cursor, '').decode()
    go_to_sibling(cursor)

    # optional 'extends' (id)
    super_asset = None
    if node_text(cursor, '') == b'extends':
        go_to_sibling(cursor)  # move to the id
        super_asset = node_text(cursor, '').decode()
        go_to_sibling(cursor)  # move to the meta

    # any number of meta entries
    meta_entries = {}
    while assert_node(cursor.node).is_named:
        entry = self.visit(cursor)
        meta_entries[entry[0]] = entry[1]
        go_to_sibling(cursor)

    # step past '{'
    go_to_sibling(cursor)

    # the asset body is optional
    variables, attack_steps = [], []
    if assert_node(cursor.node).is_named:
        variables, attack_steps = self.visit(cursor)

    return {
        'name': asset_name,
        'meta': meta_entries,
        'category': category,
        'isAbstract': is_abstract,
        'superAsset': super_asset,
        'variables': variables,
        'attackSteps': attack_steps,
    }
334
+
335
+ def visit_asset_definition(self, cursor: TreeCursor) -> tuple[list, list]:
336
+ #######################
337
+ # (variable)* (step)* #
338
+ #######################
339
+
340
+ variables, steps = [], []
341
+ while True:
342
+ definition, result = self.visit(cursor)
343
+
344
+ if definition == 'variable':
345
+ variables.append(result)
346
+ elif definition == 'step':
347
+ steps.append(result)
348
+
349
+ if not go_to_sibling(cursor):
350
+ break
351
+
352
+ return (variables, steps)
353
+
354
def visit_asset_variable(self, cursor: TreeCursor) -> ASTNode:
    """Read a variable definition inside an asset.

    Grammar: 'let' (id) '=' (value)

    Returns:
        ``('variable', {'name': ..., 'stepExpression': ...})``.
    """
    # step over 'let' and read the variable name
    go_to_sibling(cursor)
    variable_name = node_text(cursor, 'id').decode()

    # move onto '=' and then past it to the value
    go_to_sibling(cursor)
    go_to_sibling(cursor)

    expression = self.visit(cursor)

    return ('variable', {'name': variable_name, 'stepExpression': expression})
376
+
377
def visit_attack_step(self, cursor: TreeCursor) -> tuple[str, dict]:
    """Read an attack step and all of its optional parts.

    The cursor is walked left to right over the children of the step; each
    optional part is only consumed when the cursor currently rests on it.

    Returns:
        ``('step', info)`` where ``info`` has the keys ``name``, ``meta``,
        ``detectors``, ``type``, ``causal_mode``, ``tags``, ``risk``,
        ``ttc``, ``requires`` and ``reaches``.
    """
    ##############################################################################################################
    # (step_type) (id) ( '@' (id) )* ( '{' (cias) '}' )? (ttc)? (meta)* (detector)? (preconditions)? (reaches)? #
    ##############################################################################################################

    # grab (step_type)
    # use raw text bytes to avoid decoding as much as possible
    step_type_btext = node_text(cursor, 'step type')
    step_type_bindings = {
        b'&': 'and',
        b'|': 'or',
        b'#': 'defense',
        b'E': 'exist',
        b'!E': 'notExist',
    }

    # decode value only if it is really necessary (no binding found)
    if (step_type := step_type_bindings.get(step_type_btext)) is None:
        step_type = step_type_btext.decode()
    go_to_sibling(cursor)

    # grab optional (causal_mode) before (id)
    causal_mode = None
    current_text = node_text(cursor, 'causal_mode')
    if current_text in (b'action', b'effect'):
        causal_mode = current_text.decode()
        go_to_sibling(cursor)  # skip causal_mode

    # grab (id)
    name = node_text(cursor, 'name').decode()
    go_to_sibling(cursor)

    # process all ( '@' (id) ) we might have
    tags = []

    # TODO change grammar to make (@ id)* instead of (@ id)?
    while node_text(cursor, 'tag at') == b'@':
        go_to_sibling(cursor)  # skip '@'
        tags.append(node_text(cursor, '').decode())  # grab (id)
        if not go_to_sibling(cursor):  # move to next symbol, break if last
            break

    # process the optional ( '{' (cias) '}' )
    risk = None
    if assert_node(cursor.node).text == b'{':
        go_to_sibling(cursor)  # skip '{'
        risk = self.visit(cursor)  # grab (cias)
        go_to_sibling(cursor)  # go to '}'
        go_to_sibling(cursor)  # and skip it

    ttc = None
    if assert_node(cursor.node).type == 'ttc':
        # visit ttc
        ttc = self.visit(cursor)
        go_to_sibling(cursor)

    # each meta visit yields a (key, text) pair
    meta = {}
    while assert_node(cursor.node).type == 'meta':
        info = self.visit(cursor)
        meta[info[0]] = info[1]
        if not go_to_sibling(cursor):  # in case there is nothing after the meta
            break

    # detectors are keyed by the string form of their 'name' entry
    detectors: dict[str, Any] = {}
    while assert_node(cursor.node).type == 'detector':
        detector = self.visit(cursor)
        detector_name = str(detector['name'])
        detectors[detector_name] = detector
        if not go_to_sibling(cursor):  # in case there is nothing after the detector
            break

    requires = None
    if assert_node(cursor.node).type == 'preconditions':
        requires = self.visit(cursor)
        go_to_sibling(cursor)

    reaches = None
    if assert_node(cursor.node).type == 'reaching':
        reaches = self.visit(cursor)
        go_to_sibling(cursor)

    ret = {
        'name': name,
        'meta': meta,
        'detectors': detectors,
        'type': step_type,
        'causal_mode': causal_mode,
        'tags': tags,
        'risk': risk,
        'ttc': ttc,
        'requires': requires,
        'reaches': reaches,
    }

    return ('step', ret)
472
+
473
def visit_detector(self, cursor: TreeCursor):
    """Read a detector attached to an attack step.

    Grammar: ('!' | '//!') (detector_name)? (detector_context) (type)? (ttc)?

    Returns:
        A dict with the keys ``name``, ``context``, ``type`` and ``tprate``.
        ``name`` and ``type`` are ``None`` when the corresponding optional
        part is absent; ``tprate`` is always ``None`` because TTCs on
        detectors are not supported.

    Raises:
        NotImplementedError: if the detector carries a TTC.
    """
    # skip bang
    go_to_sibling(cursor)

    # grab the optional detector name
    detector_name = None
    if cursor.field_name == 'name':
        detector_name = node_text(cursor, 'detector name').decode()
        go_to_sibling(cursor)

    # grab detector_context
    detector_context = self.visit(cursor)
    go_to_sibling(cursor)

    # grab the optional detector type. This must be stored in
    # ``detector_type``: writing it to ``detector_name`` would overwrite the
    # name and leave the returned 'type' permanently None.
    detector_type = None
    if cursor.field_name == 'type':
        detector_type = node_text(cursor, 'type').decode()
        go_to_sibling(cursor)

    # TODO: TTC on detectors is not implemented yet
    detector_ttc = None
    if cursor.field_name == 'ttc':
        raise NotImplementedError('TTC not implemented for detectors')

    return {
        'name': detector_name,
        'context': detector_context,
        'type': detector_type,
        'tprate': detector_ttc,
    }
511
+
512
def visit_detector_context(self, cursor: TreeCursor):
    """Read the parenthesised asset list of a detector.

    Grammar: '(' (detector_context_asset) (',' (detector_context_asset))* ')'

    Returns:
        A dict mapping each label to its asset.
    """
    # step over '('
    go_to_sibling(cursor)

    assets_by_label = {}
    while True:
        # read one detector_context_asset
        label, asset = self.visit(cursor)
        assets_by_label[label] = asset
        go_to_sibling(cursor)

        # a closing parenthesis ends the list
        if node_text(cursor, 'char') == b')':
            break

        # otherwise we are on ',' - step over it
        go_to_sibling(cursor)

    return assets_by_label
535
+
536
def visit_detector_context_asset(self, cursor: TreeCursor):
    """Read one entry of a detector context.

    Grammar: (type) (id)

    Returns:
        ``(label, asset)`` where ``asset`` is the text of the first child
        (the type) and ``label`` the text of the second child (the id).
    """
    asset = node_text(cursor, 'asset').decode('utf-8')

    # Advance from the (type) node to the (id) node. Without this step both
    # reads would hit the same node and label would always equal asset.
    go_to_sibling(cursor)
    label = node_text(cursor, 'label').decode('utf-8')

    return (label, asset)
544
+
545
def visit_cias(self, cursor: TreeCursor):
    """Read a comma-separated list of CIA markers.

    Grammar: (cia) (',' (cia))*

    Returns:
        A dict with the boolean keys ``isConfidentiality``, ``isIntegrity``
        and ``isAvailability``; all start as ``False`` and are updated with
        whatever each (cia) visit returns.
    """
    flags = dict.fromkeys(
        ('isConfidentiality', 'isIntegrity', 'isAvailability'), False
    )

    flags.update(self.visit(cursor))

    # A successful move means we landed on ','; step over it and read the
    # next marker.
    while go_to_sibling(cursor):
        go_to_sibling(cursor)
        flags.update(self.visit(cursor))

    return flags
567
+
568
def visit_cia(self, cursor: TreeCursor):
    """Translate a single CIA marker into its risk flag.

    Grammar: 'C' | 'I' | 'A'

    Returns:
        A one-entry dict ``{flag_name: True}``; the key is ``None`` when the
        text is not one of the three markers.
    """
    flag_names = {
        b'C': 'isConfidentiality',
        b'I': 'isIntegrity',
        b'A': 'isAvailability',
    }
    marker = node_text(cursor, 'cia')
    return {flag_names.get(marker): True}
582
+
583
def visit_ttc(self, cursor: TreeCursor):
    """Read a TTC block and return its inner expression.

    Grammar: '[' (intermediary_ttc_exp) ']'
    """
    go_to_sibling(cursor)  # step over '['
    expression = self._visit_intermediary_ttc_expr(cursor)
    return expression
592
+
593
def _visit_intermediary_ttc_expr(self, cursor: TreeCursor):
    """Evaluate one TTC sub-expression at the cursor.

    Grammar:
        '(' (intermediary_ttc_expr) ')' | (integer) | (float) | (id)
        | (ttc_distribution) | (ttc_binop)
    """
    # parenthesised expression: unwrap and recurse
    if node_text(cursor, 'char') == b'(':
        go_to_sibling(cursor)  # skip '('
        inner = self._visit_intermediary_ttc_expr(cursor)
        go_to_sibling(cursor)  # skip ')'
        return inner

    node_type = assert_node(cursor.node).type

    # a bare identifier is reported as a function without arguments
    if node_type == 'identifier':
        return {
            'type': 'function',
            'name': node_text(cursor, 'id').decode(),
            'arguments': [],
        }

    # numbers get wrapped in a 'number' dictionary
    if node_type in ('float', 'integer'):
        return {'type': 'number', 'value': self.visit(cursor)}

    # everything else has its own visitor
    return self.visit(cursor)
622
+
623
def visit_float(self, cursor: TreeCursor):
    """Convert the text of a float node to a Python ``float``."""
    return float(node_text(cursor, 'float'))
627
+
628
def visit_integer(self, cursor: TreeCursor):
    """Convert the text of an integer node to a Python ``float``."""
    return float(node_text(cursor, 'integer'))
632
+
633
def visit_ttc_binop(self, cursor: TreeCursor):
    """Read a binary TTC operation.

    Grammar:
        (intermediary_ttc_expr) ('+'|'-'|'*'|'/'|'^') (intermediary_ttc_expr)

    Returns:
        ``{'type': operation_name, 'lhs': ..., 'rhs': ...}``; the type is
        ``None`` when the operator is not one of the five listed.
    """
    operator_names = {
        b'+': 'addition',
        b'-': 'subtraction',
        b'*': 'multiplication',
        b'/': 'division',
        b'^': 'exponentiation',
    }

    # left operand, then move onto the operator
    left = self._visit_intermediary_ttc_expr(cursor)
    go_to_sibling(cursor)

    # operator
    operator_bytes = assert_node(cursor.node).text
    assert operator_bytes, 'Operation not found'
    operation_name = operator_names.get(operator_bytes)
    go_to_sibling(cursor)

    # right operand
    right = self._visit_intermediary_ttc_expr(cursor)

    return {'type': operation_name, 'lhs': left, 'rhs': right}
659
+
660
def visit_ttc_distribution(self, cursor: TreeCursor):
    """Read a TTC distribution call.

    Grammar: (id) '(' (number)* ( ',' (number) )* ')'

    Returns:
        ``{'type': 'function', 'name': ..., 'arguments': [...]}``.
    """
    distribution_name = node_text(cursor, 'name').decode()

    # move onto '(' and then past it
    go_to_sibling(cursor)
    go_to_sibling(cursor)

    numbers = []
    while assert_node(cursor.node).type in ('float', 'integer'):
        numbers.append(self.visit(cursor))

        # a comma means another argument follows; anything else ends the list
        go_to_sibling(cursor)
        if assert_node(cursor.node).text != b',':
            break

        # step over the comma
        go_to_sibling(cursor)

    return {'type': 'function', 'name': distribution_name, 'arguments': numbers}
686
+
687
def visit_preconditions(self, cursor: TreeCursor):
    """Read the preconditions of an attack step.

    Grammar: '<-' (asset_expr) (',' (asset_expr) )*

    Returns:
        ``{'overrides': True, 'stepExpressions': [...]}``.
    """
    # step over '<-'
    go_to_sibling(cursor)

    expressions = [self.visit(cursor)]

    # a successful move means we are on ','; step over it and read on
    while go_to_sibling(cursor):
        go_to_sibling(cursor)
        expressions.append(self.visit(cursor))

    return {'overrides': True, 'stepExpressions': expressions}
704
+
705
def visit_reaching(self, cursor: TreeCursor):
    """Read the steps reached from an attack step.

    Grammar: ( '+>' | '->' ) (reaches) ( ',' (reaches) )*

    Returns:
        ``{'overrides': bool, 'stepExpressions': [...]}`` where
        ``overrides`` is ``True`` exactly when the arrow is ``->``.
    """
    # the arrow decides whether this list overrides
    is_override = assert_node(cursor.node).text == b'->'
    go_to_sibling(cursor)

    expressions = [self.visit(cursor)]

    # a successful move means we are on ','; step over it and read on
    while go_to_sibling(cursor):
        go_to_sibling(cursor)
        expressions.append(self.visit(cursor))

    return {'overrides': is_override, 'stepExpressions': expressions}
724
+
725
def visit_asset_expr(self, cursor: TreeCursor):
    """Delegate an asset expression to the inline-expression handler."""
    expression = self._visit_inline_asset_expr(cursor)
    return expression
727
+
728
def _visit_inline_asset_expr(self, cursor: TreeCursor):
    """Handle the node at the cursor as if it were an inline asset expression.

    Grammar:
        '(' (_inline_asset_expr) ')' | (id) | (asset_variable_substitution)
        | (asset_expr_binop) | (asset_expr_unop) | (asset_expr_type)

    This method works out which of the alternatives the current node is and
    handles it accordingly.
    """
    assert cursor.node, 'Missing node'

    # plain identifier: classify it and report its name
    if assert_node(cursor.node).type == 'identifier':
        return {
            'type': self._resolve_part_ID_type(cursor),
            'name': node_text(cursor, 'name').decode(),
        }

    # parenthesised expression: unwrap and recurse
    if node_text(cursor, 'char').decode() == '(':
        go_to_sibling(cursor)  # ignore the '('
        inner = self._visit_inline_asset_expr(cursor)
        go_to_sibling(cursor)  # ignore the ')'
        return inner

    # every other alternative has its own visitor
    return self.visit(cursor)
751
+
752
def visit_asset_variable_substitution(self, cursor: TreeCursor):
    """Read a variable substitution.

    Grammar: (id) '(' ')'

    Returns:
        ``{'type': 'variable', 'name': ...}``.
    """
    variable_name = node_text(cursor, 'name').decode()
    return {'type': 'variable', 'name': variable_name}
758
+
759
def visit_asset_expr_type(self, cursor: TreeCursor):
    """Read a sub-type restriction on an asset expression.

    Grammar: (_inline_asset_expr) '[' (id) ']'

    In the ANTLR version the subtypes were visited left to right and had to
    be stored recursively. With TreeSitter we start from the right, so it is
    enough to visit the left-hand side and return the current subtype.

    Returns:
        ``{'type': 'subType', 'subType': ..., 'stepExpression': ...}``.
    """
    # left-hand expression
    inner_expression = self._visit_inline_asset_expr(cursor)

    # move onto '[' and then past it
    go_to_sibling(cursor)
    go_to_sibling(cursor)

    sub_type_name = node_text(cursor, 'subType').decode()

    return {
        'type': 'subType',
        'subType': sub_type_name,
        'stepExpression': inner_expression,
    }
780
+
781
def visit_asset_expr_binop(self, cursor: TreeCursor):
    r"""Read a binary operation between two asset expressions.

    Grammar:
        (_inline_asset_expr) ( '\/' | '/\' | '-' | '.') (_inline_asset_expr)

    Returns:
        ``{'type': operation_name, 'lhs': ..., 'rhs': ...}``; the type is
        ``None`` when the operator is not one of the four listed.
    """
    operator_names = {
        b'.': 'collect',
        b'\\/': 'union',
        b'/\\': 'intersection',
        b'-': 'difference',
    }

    # left operand, then move onto the operator
    left = self._visit_inline_asset_expr(cursor)
    go_to_sibling(cursor)

    # operator
    assert cursor.node, 'Missing node for operation type'
    operator_bytes = node_text(cursor, '')
    assert operator_bytes, 'Missing text for operation node'
    operation_name = operator_names.get(operator_bytes)
    go_to_sibling(cursor)

    # right operand
    right = self._visit_inline_asset_expr(cursor)

    return {'type': operation_name, 'lhs': left, 'rhs': right}
806
+
807
def visit_asset_expr_unop(self, cursor: TreeCursor):
    """Read a transitive ('*') asset expression.

    Grammar: (_inline_asset_expr) '*'

    Returns:
        ``{'type': 'transitive', 'stepExpression': ...}``.
    """
    inner_expression = self._visit_inline_asset_expr(cursor)

    # advance past the operand
    go_to_sibling(cursor)

    return {'type': 'transitive', 'stepExpression': inner_expression}
817
+
818
def _resolve_part_ID_type(self, cursor: TreeCursor):
    """Decide whether the identifier at the cursor is a field or an attack step.

    Returns:
        ``'field'`` when the identifier has no ``reaching`` ancestor, or when
        a ``.`` follows it on its line before any ``,``; ``'attackStep'``
        otherwise.

    Raises:
        ValueError: if the cursor has no current node.
    """
    # Figure out if we have a `field` or an `attackStep`
    original_node = cursor.node
    if not original_node:
        raise ValueError('Missing node for id')

    parent_node = original_node.parent

    while parent_node and parent_node.type != 'reaching':
        # The idea is to go up the tree. If we find a "reaching" node,
        # we still need to determine if it's a field or an attackStep
        parent_node = parent_node.parent

    if not parent_node:
        # If we never find a "reaching" node, eventually we will go to
        # the top of the tree, and we won't be able to go further up.
        # In this case, we originally were in a `let` or `precondition`,
        # which only accepts fields
        return 'field'

    # We want to know if there is any `.` after the context.
    # If there is, we have a field (as an attackStep does not
    # have attributes)
    #
    # To do this, we will find the start position of the original
    # node in the text. Each rule matches to one line in the end,
    # so this node will be in the same row as its parent node and in
    # a column inside the range of columns of its parent. So, we
    # just have to split the whole text of the parent starting at the
    # original node's position and iterate from there until the end of
    # the text.

    # The following logic was implemented to deal with how TreeSitter
    # deals with indents and new lines

    # We start by obtaining the column where the target node starts,
    original_node_column = original_node.start_point.column

    # We get the parent's text and split it into the original
    # lines (as written in the code)
    assert parent_node.text, 'Missing parent node text'
    tokenStream = parent_node.text.decode()
    tokenStream = tokenStream.split('\n')
    tokenStream_split = None

    # If the parent and the target are defined in the same line,
    # then we must remove the start point from the original column,
    # since TreeSitter deletes the indent
    if original_node.start_point.row == parent_node.start_point.row:
        tokenStream_split = tokenStream[0]
        original_node_column = (
            original_node.start_point.column - parent_node.start_point.column
        )
    # However, if they are in different rows, the indent must be included,
    # so we use the same column
    else:
        tokenStream_split = tokenStream[
            original_node.start_point.row - parent_node.start_point.row
        ]

    # Afterwards, we just do the normal checks, knowing what column to start in
    # NOTE(review): the column comes from TreeSitter while the slice below
    # indexes the decoded string by character; if TreeSitter columns are
    # byte offsets these can diverge on lines containing non-ASCII text -
    # confirm.
    assert original_node.text, 'Missing node text'
    start_col = original_node_column + len(original_node.text.decode())
    tokenStream_split = tokenStream_split[start_col:]
    for char in tokenStream_split:
        if char == '.':
            return 'field'  # Only a field can have attributes
        if char == ',':
            return 'attackStep'  # A `,` means we are starting a new reaches

    return 'attackStep'
889
+
890
def visit_associations_declaration(self, cursor: TreeCursor):
    """Visit an ``associations`` block and collect its entries.

    Grammar: ``'associations' '{' (association)* '}'``

    The cursor is advanced past the ``associations`` keyword and the
    opening brace, then every following sibling is visited until the
    closing brace is found or no sibling remains.

    Returns:
        A tuple ``('associations', associations)`` where ``associations``
        is the list of values returned by ``self.visit`` for each entry.
    """
    # skip 'associations'
    go_to_sibling(cursor)

    # skip '{'
    go_to_sibling(cursor)

    # visit all associations
    associations = []
    while cursor.node and node_text(cursor, '') != b'}':
        associations.append(self.visit(cursor))
        # When there is no further sibling the cursor does not move, so
        # without this check a block lacking its closing '}' would keep
        # visiting the same node forever.
        if not go_to_sibling(cursor):
            break

    return ('associations', associations)
908
+
909
def visit_association(self, cursor: TreeCursor):
    """Visit a single association and return it as a dictionary.

    Grammar::

        (id) '[' (id) ']' (multiplicity) '<--' (id) '-->'
        (multiplicity) '[' (id) ']' (id) (meta)*

    Returns:
        A dict with the keys ``name``, ``meta``, ``leftAsset``,
        ``leftField``, ``leftMultiplicity``, ``rightAsset``,
        ``rightField`` and ``rightMultiplicity``. The multiplicities are
        normalized in place by ``_process_multitudes`` before returning.
    """

    def skip_token():
        # Step over a punctuation token we do not need to keep.
        go_to_sibling(cursor)

    def read_identifier(label):
        # Decode the current node's text, then advance to the next sibling.
        text = node_text(cursor, label).decode()
        go_to_sibling(cursor)
        return text

    def read_multiplicity():
        # Visit the current multiplicity node, then advance.
        result = self.visit(cursor)
        go_to_sibling(cursor)
        return result

    left_asset = read_identifier('left asset')
    skip_token()  # '['
    left_field = read_identifier('left field')
    skip_token()  # ']'
    left_multiplicity = read_multiplicity()
    skip_token()  # '<--'
    name = read_identifier('name')
    skip_token()  # '-->'
    right_multiplicity = read_multiplicity()
    skip_token()  # '['
    right_field = read_identifier('right field')
    skip_token()  # ']'

    # The right asset is read without advancing: the meta loop below
    # performs the next cursor move itself.
    right_asset = node_text(cursor, 'right asset').decode()

    # Every remaining sibling is a meta entry, visited as a (key, value) pair.
    meta = {}
    while go_to_sibling(cursor):
        entry = self.visit(cursor)
        assert entry
        meta[entry[0]] = entry[1]

    association = {
        'name': name,
        'meta': meta,
        'leftAsset': left_asset,
        'leftField': left_field,
        'leftMultiplicity': left_multiplicity,
        'rightAsset': right_asset,
        'rightField': right_field,
        'rightMultiplicity': right_multiplicity,
    }

    self._process_multitudes(association)

    return association
980
+
981
def visit_multiplicity(self, cursor: TreeCursor):
    """Visit a multiplicity, which is either a single atom or a range.

    Grammar: ``(_multiplicity_atom) | (multiplicity_range)``

    Returns:
        For a ``multiplicity_range`` node, whatever ``self.visit`` returns
        for it. Otherwise a dict ``{'min': <atom text>, 'max': None}``.

    Raises:
        MalCompilerError: if the cursor has no current node.
    """
    current = cursor.node
    if current is None:
        raise MalCompilerError('multiplicity atom missing node')

    # Ranges have their own visitor; delegate to it.
    if current.type == 'multiplicity_range':
        return self.visit(cursor)

    # Atomic multiplicities only provide the lower bound here; the upper
    # bound is left as None.
    lower = self._visit_multiplicity_atom(cursor)
    return {'min': lower, 'max': None}
999
+
1000
def visit_multiplicity_range(self, cursor: TreeCursor):
    """Visit a multiplicity range and return its two bounds.

    Grammar: ``(_multiplicity_atom) '..' (_multiplicity_atom)``

    Returns:
        A dict ``{'min': <first atom text>, 'max': <second atom text>}``.
    """
    lower = self._visit_multiplicity_atom(cursor)

    # Move from the first atom onto '..', then past it onto the second atom.
    go_to_sibling(cursor)
    go_to_sibling(cursor)

    upper = self._visit_multiplicity_atom(cursor)

    return {'min': lower, 'max': upper}
1017
+
1018
def _visit_multiplicity_atom(self, cursor: TreeCursor):
    """Return the decoded text of the multiplicity atom under the cursor.

    Grammar: ``(integer) | (star)``

    Raises:
        MalCompilerError: if the cursor has no current node.
        ValueError: if the atom's text is empty.
    """
    if not cursor.node:
        raise MalCompilerError('multiplicity atom missing node')

    atom_text = node_text(cursor, '')
    # Defensive guard kept from the original implementation.
    if not atom_text:
        raise ValueError('multiplicity atom has empty text')
    return atom_text.decode()
1027
+
1028
+ def _process_multitudes(self, association):
1029
+ mult_keys = [
1030
+ # start the multatoms from right to left to make sure the rules
1031
+ # below get applied cleanly
1032
+ 'rightMultiplicity.max',
1033
+ 'rightMultiplicity.min',
1034
+ 'leftMultiplicity.max',
1035
+ 'leftMultiplicity.min',
1036
+ ]
1037
+
1038
+ for mult_key in mult_keys:
1039
+ key, subkey = mult_key.split('.')
1040
+
1041
+ # upper limit equals lower limit if not given
1042
+ if subkey == 'max' and association[key][subkey] is None:
1043
+ association[key][subkey] = association[key]['min']
1044
+
1045
+ if association[key][subkey] == '*':
1046
+ # 'any' as lower limit means start from 0
1047
+ if subkey == 'min':
1048
+ association[key][subkey] = 0
1049
+
1050
+ # 'any' as upper limit means not limit
1051
+ else:
1052
+ association[key][subkey] = None
1053
+
1054
+ # cast numerical strings to integers
1055
+ if (multatom := association[key][subkey]) and multatom.isdigit():
1056
+ association[key][subkey] = int(association[key][subkey])
1057
+
1058
+
1059
def assert_node(node: Node | None) -> Node:
    """Return ``node`` unchanged, rejecting ``None``.

    Raises:
        ValueError: if ``node`` is ``None``.
    """
    if node is not None:
        return node
    raise ValueError('Node can not be None')
1063
+
1064
+
1065
def node_text(cursor: TreeCursor, context: str):
    """Return the text of the node the cursor currently points at.

    Args:
        cursor: cursor whose current node is read.
        context: short description of what was expected, used only in
            the error messages.

    Raises:
        MalCompilerError: if the cursor has no node, or the node's text
            is missing or empty.
    """
    current = cursor.node
    if current is None:
        raise MalCompilerError(f'expected node for {context}, found None')

    text = current.text
    if not text:
        raise MalCompilerError(f'expected text for {context}, found empty')
    return text