mal-toolbox 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/METADATA +25 -2
  2. mal_toolbox-1.1.3.dist-info/RECORD +32 -0
  3. maltoolbox/__init__.py +6 -7
  4. maltoolbox/__main__.py +17 -9
  5. maltoolbox/attackgraph/__init__.py +2 -3
  6. maltoolbox/attackgraph/attackgraph.py +379 -362
  7. maltoolbox/attackgraph/node.py +14 -19
  8. maltoolbox/exceptions.py +7 -10
  9. maltoolbox/file_utils.py +10 -4
  10. maltoolbox/language/__init__.py +1 -1
  11. maltoolbox/language/compiler/__init__.py +4 -4
  12. maltoolbox/language/compiler/mal_lexer.py +154 -154
  13. maltoolbox/language/compiler/mal_parser.py +784 -1136
  14. maltoolbox/language/languagegraph.py +487 -639
  15. maltoolbox/model.py +64 -77
  16. maltoolbox/patternfinder/attackgraph_patterns.py +17 -8
  17. maltoolbox/translators/__init__.py +8 -0
  18. maltoolbox/translators/networkx.py +42 -0
  19. maltoolbox/translators/updater.py +18 -25
  20. maltoolbox/visualization/__init__.py +4 -4
  21. maltoolbox/visualization/draw_io_utils.py +6 -5
  22. maltoolbox/visualization/graphviz_utils.py +4 -2
  23. maltoolbox/visualization/neo4j_utils.py +13 -14
  24. maltoolbox/visualization/utils.py +2 -3
  25. mal_toolbox-1.1.1.dist-info/RECORD +0 -32
  26. maltoolbox/translators/securicad.py +0 -179
  27. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/WHEEL +0 -0
  28. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/entry_points.txt +0 -0
  29. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/licenses/AUTHORS +0 -0
  30. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/licenses/LICENSE +0 -0
  31. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/top_level.txt +0 -0
@@ -1,32 +1,37 @@
1
- """
2
- MAL-Toolbox Attack Graph Module
1
+ """MAL-Toolbox Attack Graph Module
3
2
  """
4
3
  from __future__ import annotations
4
+
5
5
  import copy
6
- import logging
7
6
  import json
7
+ import logging
8
8
  import sys
9
9
  import zipfile
10
-
11
- from itertools import chain
12
10
  from typing import TYPE_CHECKING
13
11
 
14
- from .node import AttackGraphNode
15
12
  from .. import log_configs
16
- from ..exceptions import AttackGraphStepExpressionError, AttackGraphException
17
- from ..exceptions import LanguageGraphException
18
- from ..model import Model
19
- from ..language import (LanguageGraph, ExpressionsChain,
20
- LanguageGraphAttackStep, disaggregate_attack_step_full_name)
13
+ from ..exceptions import (
14
+ AttackGraphException,
15
+ AttackGraphStepExpressionError,
16
+ LanguageGraphException,
17
+ )
21
18
  from ..file_utils import (
22
19
  load_dict_from_json_file,
23
20
  load_dict_from_yaml_file,
24
- save_dict_to_file
21
+ save_dict_to_file,
25
22
  )
26
-
23
+ from ..language import (
24
+ ExpressionsChain,
25
+ LanguageGraph,
26
+ LanguageGraphAttackStep,
27
+ disaggregate_attack_step_full_name,
28
+ )
29
+ from ..model import Model
30
+ from .node import AttackGraphNode
27
31
 
28
32
  if TYPE_CHECKING:
29
- from typing import Any, Optional
33
+ from typing import Any
34
+
30
35
  from ..model import ModelAsset
31
36
 
32
37
  logger = logging.getLogger(__name__)
@@ -39,10 +44,11 @@ def create_attack_graph(
39
44
  """Create and return an attack graph
40
45
 
41
46
  Args:
47
+ ----
42
48
  lang - path to language file (.mar or .mal) or a LanguageGraph object
43
49
  model - path to model file (yaml or json) or a Model object
44
- """
45
50
 
51
+ """
46
52
  # Load language
47
53
  if isinstance(lang, LanguageGraph):
48
54
  lang_graph = lang
@@ -87,46 +93,46 @@ def create_attack_graph(
87
93
  return attack_graph
88
94
 
89
95
 
90
- class AttackGraph():
96
+ class AttackGraph:
91
97
  """Graph representation of attack steps"""
92
- def __init__(self, lang_graph: LanguageGraph, model: Optional[Model] = None):
98
+
99
+ def __init__(self, lang_graph: LanguageGraph, model: Model | None = None):
93
100
  self.nodes: dict[int, AttackGraphNode] = {}
94
101
  self.attack_steps: list[AttackGraphNode] = []
95
102
  self.defense_steps: list[AttackGraphNode] = []
96
-
97
- # Dictionaries used in optimization to get nodes by id or full name
98
- # faster
99
- self._full_name_to_node: dict[str, AttackGraphNode] = {}
100
-
101
103
  self.model = model
102
104
  self.lang_graph = lang_graph
103
105
  self.next_node_id = 0
106
+
107
+ # Dictionary used in optimization to get nodes by full name faster
108
+ self._full_name_to_node: dict[str, AttackGraphNode] = {}
109
+
104
110
  if self.model is not None:
105
- self._generate_graph()
111
+ self._generate_graph(self.model)
106
112
 
107
113
  def __repr__(self) -> str:
108
- return (f'AttackGraph(Number of nodes: {len(self.nodes)}, '
109
- f'model: {self.model}, language: {self.lang_graph}')
114
+ return (
115
+ f'AttackGraph(Number of nodes: {len(self.nodes)}, '
116
+ f'model: {self.model}, language: {self.lang_graph}'
117
+ )
110
118
 
111
119
  def _to_dict(self) -> dict:
112
120
  """Convert AttackGraph to dict"""
113
121
  serialized_attack_steps = {}
114
122
  for ag_node in self.nodes.values():
115
- serialized_attack_steps[ag_node.full_name] =\
116
- ag_node.to_dict()
123
+ serialized_attack_steps[ag_node.full_name] = ag_node.to_dict()
117
124
  return {
118
125
  'attack_steps': serialized_attack_steps
119
126
  }
120
127
 
121
128
  def __deepcopy__(self, memo):
122
-
129
+ """Custom deepcopy implementation for attack graph"""
123
130
  # Check if the object is already in the memo dictionary
124
131
  if id(self) in memo:
125
132
  return memo[id(self)]
126
133
 
127
134
  copied_attackgraph = AttackGraph(self.lang_graph)
128
135
  copied_attackgraph.model = self.model
129
-
130
136
  copied_attackgraph.nodes = {}
131
137
 
132
138
  # Deep copy nodes
@@ -157,17 +163,16 @@ class AttackGraph():
157
163
 
158
164
  @classmethod
159
165
  def _from_dict(
160
- cls,
161
- serialized_object: dict,
162
- lang_graph: LanguageGraph,
163
- model: Optional[Model]=None
164
- ) -> AttackGraph:
166
+ cls,
167
+ serialized_object: dict,
168
+ lang_graph: LanguageGraph,
169
+ model: Model | None = None
170
+ ) -> AttackGraph:
165
171
  """Create AttackGraph from dict
166
172
  Args:
167
173
  serialized_object - AttackGraph in dict format
168
174
  model - Optional Model to add connections to
169
175
  """
170
-
171
176
  attack_graph = AttackGraph(lang_graph)
172
177
  attack_graph.model = model
173
178
  serialized_attack_steps: dict[str, dict] = serialized_object['attack_steps']
@@ -180,22 +185,27 @@ class AttackGraph():
180
185
  if model and 'asset' in node_dict:
181
186
  node_asset = model.get_asset_by_name(node_dict['asset'])
182
187
  if node_asset is None:
183
- msg = ('Failed to find asset with name "%s"'
184
- ' when loading from attack graph dict')
188
+ msg = (
189
+ 'Failed to find asset with name "%s"'
190
+ ' when loading from attack graph dict'
191
+ )
185
192
  logger.error(msg, node_dict["asset"])
186
193
  raise LookupError(msg % node_dict["asset"])
187
194
 
188
- lg_asset_name, lg_attack_step_name = \
195
+ lg_asset_name, lg_attack_step_name = (
189
196
  disaggregate_attack_step_full_name(
190
- node_dict['lang_graph_attack_step'])
191
- lg_attack_step = lang_graph.assets[lg_asset_name].\
192
- attack_steps[lg_attack_step_name]
197
+ node_dict['lang_graph_attack_step']
198
+ )
199
+ )
200
+ lg_attack_step = (
201
+ lang_graph.assets[lg_asset_name].attack_steps[lg_attack_step_name]
202
+ )
193
203
  ag_node = attack_graph.add_node(
194
- lg_attack_step = lg_attack_step,
195
- node_id = node_dict['id'],
196
- model_asset = node_asset,
197
- ttc_dist = node_dict['ttc'],
198
- existence_status = (
204
+ lg_attack_step=lg_attack_step,
205
+ node_id=node_dict['id'],
206
+ model_asset=node_asset,
207
+ ttc_dist=node_dict['ttc'],
208
+ existence_status=(
199
209
  bool(node_dict['existence_status'])
200
210
  if 'existence_status' in node_dict else None
201
211
  ),
@@ -215,7 +225,6 @@ class AttackGraph():
215
225
  else:
216
226
  node_asset.attack_step_nodes = [ag_node]
217
227
 
218
-
219
228
  # Re-establish links between nodes.
220
229
  for node_dict in serialized_attack_steps.values():
221
230
  _ag_node = attack_graph.nodes[node_dict['id']]
@@ -224,24 +233,23 @@ class AttackGraph():
224
233
  ' attack graph from dict')
225
234
  logger.error(msg, node_dict["id"])
226
235
  raise LookupError(msg % node_dict["id"])
227
- else:
228
- for child_id in node_dict['children']:
229
- child = attack_graph.nodes[int(child_id)]
230
- if child is None:
231
- msg = ('Failed to find child node with id %s'
232
- ' when loading from attack graph from dict')
233
- logger.error(msg, child_id)
234
- raise LookupError(msg % child_id)
235
- _ag_node.children.add(child)
236
-
237
- for parent_id in node_dict['parents']:
238
- parent = attack_graph.nodes[int(parent_id)]
239
- if parent is None:
240
- msg = ('Failed to find parent node with id %s '
241
- 'when loading from attack graph from dict')
242
- logger.error(msg, parent_id)
243
- raise LookupError(msg % parent_id)
244
- _ag_node.parents.add(parent)
236
+ for child_id in node_dict['children']:
237
+ child = attack_graph.nodes[int(child_id)]
238
+ if child is None:
239
+ msg = ('Failed to find child node with id %s'
240
+ ' when loading from attack graph from dict')
241
+ logger.error(msg, child_id)
242
+ raise LookupError(msg % child_id)
243
+ _ag_node.children.add(child)
244
+
245
+ for parent_id in node_dict['parents']:
246
+ parent = attack_graph.nodes[int(parent_id)]
247
+ if parent is None:
248
+ msg = ('Failed to find parent node with id %s '
249
+ 'when loading from attack graph from dict')
250
+ logger.error(msg, parent_id)
251
+ raise LookupError(msg % parent_id)
252
+ _ag_node.parents.add(parent)
245
253
 
246
254
  return attack_graph
247
255
 
@@ -250,7 +258,7 @@ class AttackGraph():
250
258
  cls,
251
259
  filename: str,
252
260
  lang_graph: LanguageGraph,
253
- model: Optional[Model] = None
261
+ model: Model | None = None
254
262
  ) -> AttackGraph:
255
263
  """Create from json or yaml file depending on file extension"""
256
264
  if model is not None:
@@ -267,34 +275,175 @@ class AttackGraph():
267
275
  else:
268
276
  raise ValueError('Unknown file extension, expected json/yml/yaml')
269
277
  return cls._from_dict(serialized_attack_graph,
270
- lang_graph, model = model)
278
+ lang_graph, model=model)
271
279
 
272
- def get_node_by_full_name(self, full_name: str) -> Optional[AttackGraphNode]:
273
- """
274
- Return the attack node that matches the full name provided.
280
+ def get_node_by_full_name(self, full_name: str) -> AttackGraphNode | None:
281
+ """Return the attack node that matches the full name provided.
275
282
 
276
283
  Arguments:
284
+ ---------
277
285
  full_name - the full name of the attack graph node we are looking
278
286
  for
279
287
 
280
288
  Return:
289
+ ------
281
290
  The attack step node that matches the given full name.
282
- """
283
291
 
284
- logger.debug(f'Looking up node with full name "%s"', full_name)
292
+ """
293
+ logger.debug('Looking up node with full name "%s"', full_name)
285
294
  return self._full_name_to_node.get(full_name)
286
295
 
296
+ def _follow_field_expr_chain(
297
+ self, target_assets: set[ModelAsset], expr_chain: ExpressionsChain
298
+ ):
299
+ # Change the target assets from the current ones to the
300
+ # associated assets given the specified field name.
301
+ if not expr_chain.fieldname:
302
+ raise LanguageGraphException(
303
+ '"field" step expression chain is missing fieldname.'
304
+ )
305
+ new_target_assets: set[ModelAsset] = set()
306
+ new_target_assets.update(
307
+ *(
308
+ asset.associated_assets.get(expr_chain.fieldname, set())
309
+ for asset in target_assets
310
+ )
311
+ )
312
+ return new_target_assets
313
+
314
+ def _follow_transitive_expr_chain(
315
+ self,
316
+ model: Model,
317
+ target_assets: set[ModelAsset],
318
+ expr_chain: ExpressionsChain
319
+ ):
320
+ if not expr_chain.sub_link:
321
+ raise LanguageGraphException(
322
+ '"transitive" step expression chain is missing sub link.'
323
+ )
324
+
325
+ new_assets = target_assets
326
+ while new_assets := self._follow_expr_chain(
327
+ model, new_assets, expr_chain.sub_link
328
+ ):
329
+ new_assets = new_assets.difference(target_assets)
330
+ if not new_assets:
331
+ break
332
+ target_assets.update(new_assets)
333
+ return target_assets
334
+
335
+ def _follow_subtype_expr_chain(
336
+ self,
337
+ model: Model,
338
+ target_assets: set[ModelAsset],
339
+ expr_chain: ExpressionsChain
340
+ ):
341
+ if not expr_chain.sub_link:
342
+ raise LanguageGraphException(
343
+ '"subType" step expression chain is missing sub link.'
344
+ )
345
+ new_target_assets = set()
346
+ new_target_assets.update(
347
+ self._follow_expr_chain(
348
+ model, target_assets, expr_chain.sub_link
349
+ )
350
+ )
351
+ selected_new_target_assets = set()
352
+ for asset in new_target_assets:
353
+ lang_graph_asset = self.lang_graph.assets[asset.type]
354
+ if not lang_graph_asset:
355
+ raise LookupError(
356
+ f'Failed to find asset "{asset.type}" in the '
357
+ 'language graph.'
358
+ )
359
+ lang_graph_subtype_asset = expr_chain.subtype
360
+ if not lang_graph_subtype_asset:
361
+ raise LookupError(
362
+ 'Failed to find asset "{expr_chain.subtype}" in '
363
+ 'the language graph.'
364
+ )
365
+ if lang_graph_asset.is_subasset_of(lang_graph_subtype_asset):
366
+ selected_new_target_assets.add(asset)
367
+
368
+ return selected_new_target_assets
369
+
370
+ def _follow_union_intersection_difference_expr_chain(
371
+ self,
372
+ model: Model,
373
+ target_assets: set[ModelAsset],
374
+ expr_chain: ExpressionsChain
375
+ ) -> set[Any]:
376
+ # The set operators are used to combine the left hand and
377
+ # right hand targets accordingly.
378
+ if not expr_chain.left_link:
379
+ raise LanguageGraphException(
380
+ '"%s" step expression chain is missing the left link.',
381
+ expr_chain.type
382
+ )
383
+ if not expr_chain.right_link:
384
+ raise LanguageGraphException(
385
+ '"%s" step expression chain is missing the right link.',
386
+ expr_chain.type
387
+ )
388
+ lh_targets = self._follow_expr_chain(
389
+ model, target_assets, expr_chain.left_link
390
+ )
391
+ rh_targets = self._follow_expr_chain(
392
+ model, target_assets, expr_chain.right_link
393
+ )
394
+
395
+ if expr_chain.type == 'union':
396
+ # Once the assets become hashable set operations should be
397
+ # used instead.
398
+ return lh_targets.union(rh_targets)
399
+
400
+ if expr_chain.type == 'intersection':
401
+ return lh_targets.intersection(rh_targets)
402
+
403
+ if expr_chain.type == 'difference':
404
+ return lh_targets.difference(rh_targets)
405
+
406
+ raise ValueError("Expr chain must be of type union, intersectin or difference")
407
+
408
+ def _follow_collect_expr_chain(
409
+ self,
410
+ model: Model,
411
+ target_assets: set[ModelAsset],
412
+ expr_chain: ExpressionsChain
413
+ ) -> set[Any]:
414
+ if not expr_chain.left_link:
415
+ raise LanguageGraphException(
416
+ '"collect" step expression chain missing the left link.'
417
+ )
418
+ if not expr_chain.right_link:
419
+ raise LanguageGraphException(
420
+ '"collect" step expression chain missing the right link.'
421
+ )
422
+ lh_targets = self._follow_expr_chain(
423
+ model,
424
+ target_assets,
425
+ expr_chain.left_link
426
+ )
427
+ rh_targets = set()
428
+ for lh_target in lh_targets:
429
+ rh_targets |= self._follow_expr_chain(
430
+ model,
431
+ {lh_target},
432
+ expr_chain.right_link
433
+ )
434
+ return rh_targets
435
+
287
436
  def _follow_expr_chain(
288
- self,
289
- model: Model,
290
- target_assets: set[ModelAsset],
291
- expr_chain: Optional[ExpressionsChain]
292
- ) -> set[Any]:
293
- """
294
- Recursively follow a language graph expressions chain on an instance
437
+ self,
438
+ model: Model,
439
+ target_assets: set[ModelAsset],
440
+ expr_chain: ExpressionsChain | None
441
+ ) -> set[Any]:
442
+ """Recursively follow a language graph expressions chain on an instance
295
443
  model.
296
444
 
297
445
  Arguments:
446
+ ---------
298
447
  model - a maltoolbox.model.Model on which to follow the
299
448
  expressions chain
300
449
  target_assets - the set of assets that this expressions chain
@@ -303,9 +452,10 @@ class AttackGraph():
303
452
  expr_chain - the expressions chain we are following
304
453
 
305
454
  Return:
455
+ ------
306
456
  A list of all of the target assets.
307
- """
308
457
 
458
+ """
309
459
  if expr_chain is None:
310
460
  # There is no expressions chain link left to follow return the
311
461
  # current target assets
@@ -315,128 +465,26 @@ class AttackGraph():
315
465
  # Avoid running json.dumps when not in debug
316
466
  logger.debug(
317
467
  'Following Expressions Chain:\n%s',
318
- json.dumps(expr_chain.to_dict(), indent = 2)
468
+ json.dumps(expr_chain.to_dict(), indent=2)
319
469
  )
320
470
 
321
471
  match (expr_chain.type):
322
472
  case 'union' | 'intersection' | 'difference':
323
- # The set operators are used to combine the left hand and
324
- # right hand targets accordingly.
325
- if not expr_chain.left_link:
326
- raise LanguageGraphException('"%s" step expression chain'
327
- ' is missing the left link.' % expr_chain.type)
328
- if not expr_chain.right_link:
329
- raise LanguageGraphException('"%s" step expression chain'
330
- ' is missing the right link.' % expr_chain.type)
331
- lh_targets = self._follow_expr_chain(
332
- model,
333
- target_assets,
334
- expr_chain.left_link
335
- )
336
- rh_targets = self._follow_expr_chain(
337
- model,
338
- target_assets,
339
- expr_chain.right_link
473
+ return self._follow_union_intersection_difference_expr_chain(
474
+ model, target_assets, expr_chain
340
475
  )
341
476
 
342
- match (expr_chain.type):
343
- # Once the assets become hashable set operations should be
344
- # used instead.
345
- case 'union':
346
- new_target_assets = lh_targets.union(rh_targets)
347
-
348
- case 'intersection':
349
- new_target_assets = lh_targets.intersection(rh_targets)
350
-
351
- case 'difference':
352
- new_target_assets = lh_targets.difference(rh_targets)
353
-
354
- return new_target_assets
355
-
356
477
  case 'field':
357
- # Change the target assets from the current ones to the
358
- # associated assets given the specified field name.
359
- if not expr_chain.fieldname:
360
- raise LanguageGraphException('"field" step expression '
361
- 'chain is missing fieldname.')
362
- new_target_assets = set()
363
- new_target_assets.update(
364
- *(
365
- asset.associated_assets.get(
366
- expr_chain.fieldname, set()
367
- ) for asset in target_assets
368
- )
369
- )
370
- return new_target_assets
478
+ return self._follow_field_expr_chain(target_assets, expr_chain)
371
479
 
372
480
  case 'transitive':
373
- if not expr_chain.sub_link:
374
- raise LanguageGraphException('"transitive" step '
375
- 'expression chain is missing sub link.')
376
-
377
- new_assets = target_assets
378
-
379
- while new_assets := self._follow_expr_chain(
380
- model, new_assets, expr_chain.sub_link
381
- ):
382
- if not (new_assets := new_assets.difference(target_assets)):
383
- break
384
-
385
- target_assets.update(new_assets)
386
-
387
- return target_assets
481
+ return self._follow_transitive_expr_chain(model, target_assets, expr_chain)
388
482
 
389
483
  case 'subType':
390
- if not expr_chain.sub_link:
391
- raise LanguageGraphException('"subType" step '
392
- 'expression chain is missing sub link.')
393
- new_target_assets = set()
394
- new_target_assets.update(
395
- self._follow_expr_chain(
396
- model, target_assets, expr_chain.sub_link
397
- )
398
- )
399
-
400
- selected_new_target_assets = set()
401
- for asset in new_target_assets:
402
- lang_graph_asset = self.lang_graph.assets[asset.type]
403
- if not lang_graph_asset:
404
- raise LookupError(
405
- f'Failed to find asset \"{asset.type}\" in the '
406
- 'language graph.'
407
- )
408
- lang_graph_subtype_asset = expr_chain.subtype
409
- if not lang_graph_subtype_asset:
410
- raise LookupError(
411
- 'Failed to find asset "%s" in the '
412
- 'language graph.' % expr_chain.subtype
413
- )
414
- if lang_graph_asset.is_subasset_of(
415
- lang_graph_subtype_asset):
416
- selected_new_target_assets.add(asset)
417
-
418
- return selected_new_target_assets
484
+ return self._follow_subtype_expr_chain(model, target_assets, expr_chain)
419
485
 
420
486
  case 'collect':
421
- if not expr_chain.left_link:
422
- raise LanguageGraphException('"collect" step expression chain'
423
- ' is missing the left link.')
424
- if not expr_chain.right_link:
425
- raise LanguageGraphException('"collect" step expression chain'
426
- ' is missing the right link.')
427
- lh_targets = self._follow_expr_chain(
428
- model,
429
- target_assets,
430
- expr_chain.left_link
431
- )
432
- rh_targets = set()
433
- for lh_target in lh_targets:
434
- rh_targets |= self._follow_expr_chain(
435
- model,
436
- {lh_target},
437
- expr_chain.right_link
438
- )
439
- return rh_targets
487
+ return self._follow_collect_expr_chain(model, target_assets, expr_chain)
440
488
 
441
489
  case _:
442
490
  msg = 'Unknown attack expressions chain type: %s'
@@ -447,179 +495,151 @@ class AttackGraph():
447
495
  raise AttackGraphStepExpressionError(
448
496
  msg % expr_chain.type
449
497
  )
450
- return None
451
498
 
452
- def _generate_graph(self) -> None:
453
- """
454
- Generate the attack graph based on the original model instance and the
455
- MAL language specification provided at initialization.
499
+ def _get_existance_status(
500
+ self,
501
+ model: Model,
502
+ asset: ModelAsset,
503
+ attack_step: LanguageGraphAttackStep
504
+ ) -> bool | None:
505
+ """Get existance status of a step"""
506
+ if attack_step.type not in ('exist', 'notExist'):
507
+ # No existence status for other type of steps
508
+ return None
509
+
510
+ existence_status = False
511
+ for requirement in attack_step.requires:
512
+ target_assets = self._follow_expr_chain(
513
+ model, set([asset]), requirement
514
+ )
515
+ # If the step expression resolution yielded
516
+ # the target assets then the required assets
517
+ # exist in the model.
518
+ if target_assets:
519
+ existence_status = True
520
+ break
521
+
522
+ return existence_status
523
+
524
+ def _get_ttc_dist(
525
+ self,
526
+ asset: ModelAsset,
527
+ attack_step: LanguageGraphAttackStep
528
+ ):
529
+ """Get step ttc distribution based on language
530
+ and possibly overriding defense status
456
531
  """
532
+ ttc_dist = copy.deepcopy(attack_step.ttc)
533
+ if attack_step.type == 'defense':
534
+ if attack_step.name in asset.defenses:
535
+ # If defense status was set in model, set ttc accordingly
536
+ defense_value = float(asset.defenses[attack_step.name])
537
+ ttc_dist = {
538
+ 'arguments': [defense_value],
539
+ 'name': 'Bernoulli',
540
+ 'type': 'function'
541
+ }
542
+ logger.debug(
543
+ 'Setting defense \"%s\" to "%s".',
544
+ asset.name + ":" + attack_step.name, defense_value
545
+ )
546
+ return ttc_dist
457
547
 
458
- if not self.model:
459
- msg = "Can not generate AttackGraph without model"
460
- logger.error(msg)
461
- raise AttackGraphException(msg)
462
-
463
- # First, generate all of the nodes of the attack graph.
464
- for asset in self.model.assets.values():
465
-
466
- logger.debug(
467
- 'Generating attack steps for asset %s which is of class %s.',
468
- asset.name, asset.type
469
- )
548
+ def _generate_graph(self, model: Model) -> None:
549
+ """Generate the attack graph from model and MAL language."""
550
+ self.nodes = {}
551
+ self._full_name_to_node = {}
470
552
 
471
- attack_step_nodes = []
553
+ self._create_nodes_from_model(model)
554
+ self._link_nodes_by_language(model)
472
555
 
556
+ def _create_nodes_from_model(self, model: Model) -> None:
557
+ """Create attack graph nodes for all model assets."""
558
+ for asset in model.assets.values():
559
+ asset.attack_step_nodes = []
473
560
  for attack_step in asset.lg_asset.attack_steps.values():
474
- logger.debug(
475
- 'Generating attack step node for %s.', attack_step.name
561
+ node = self.add_node(
562
+ lg_attack_step=attack_step,
563
+ model_asset=asset,
564
+ ttc_dist=self._get_ttc_dist(asset, attack_step),
565
+ existence_status=(
566
+ self._get_existance_status(model, asset, attack_step)
567
+ ),
476
568
  )
569
+ asset.attack_step_nodes.append(node)
477
570
 
478
- existence_status = None
479
- node_name = asset.name + ':' + attack_step.name
480
-
481
- ttc_dist = copy.deepcopy(attack_step.ttc)
482
- match (attack_step.type):
483
- case 'defense':
484
- # Set the TTC probability for defenses
485
- # that were explicitly set in model
486
- if attack_step.name in asset.defenses:
487
- defense_value = float(
488
- asset.defenses[attack_step.name]
489
- )
490
- ttc_dist = {
491
- 'arguments': [defense_value],
492
- 'name': 'Bernoulli',
493
- 'type': 'function'
494
- }
495
- logger.debug(
496
- 'Setting defense \"%s\" to "%s".',
497
- node_name, defense_value
498
- )
499
-
500
- case 'exist' | 'notExist':
501
- # Resolve step expression associated with
502
- # (non-)existence attack steps.
503
- existence_status = False
504
- for requirement in attack_step.requires:
505
- target_assets = self._follow_expr_chain(
506
- self.model,
507
- set([asset]),
508
- requirement
509
- )
510
- # If the step expression resolution yielded
511
- # the target assets then the required assets
512
- # exist in the model.
513
- if target_assets:
514
- existence_status = True
515
- break
516
-
517
- logger.debug(
518
- 'Setting the existence status of \"%s\" to '
519
- '%s.',
520
- node_name, existence_status
521
- )
522
-
523
- case _:
524
- pass
525
-
526
- ag_node = self.add_node(
527
- lg_attack_step = attack_step,
528
- model_asset = asset,
529
- ttc_dist = ttc_dist,
530
- existence_status = existence_status
531
- )
532
- attack_step_nodes.append(ag_node)
533
-
534
- asset.attack_step_nodes = attack_step_nodes
535
-
536
- # Then, link all of the nodes according to their associations.
571
+ def _link_nodes_by_language(self, model: Model) -> None:
572
+ """Establish parent-child links between nodes."""
537
573
  for ag_node in self.nodes.values():
538
- logger.debug(
539
- 'Determining children for attack step "%s"(%d)',
540
- ag_node.full_name,
541
- ag_node.id
542
- )
574
+ self._link_node_children(model, ag_node)
543
575
 
544
- if not ag_node.model_asset:
545
- raise AttackGraphException('Attack graph node is missing '
546
- 'asset link')
576
+ def _link_node_children(self, model: Model, ag_node: AttackGraphNode) -> None:
577
+ """Link one node to its children."""
578
+ if not ag_node.model_asset:
579
+ raise AttackGraphException('Attack graph node is missing asset link')
547
580
 
548
- lang_graph_asset = self.lang_graph.assets[ag_node.model_asset.type]
549
- lang_graph_attack_step: Optional[LanguageGraphAttackStep] = (
550
- lang_graph_asset.attack_steps[ag_node.name]
581
+ lg_asset = self.lang_graph.assets[ag_node.model_asset.type]
582
+ lg_attack_step: LanguageGraphAttackStep | None = (
583
+ lg_asset.attack_steps[ag_node.name]
584
+ )
585
+ while lg_attack_step:
586
+ for child_type, expr_chains in lg_attack_step.children.items():
587
+ for expr_chain in expr_chains:
588
+ self._link_from_expr_chain(model, ag_node, child_type, expr_chain)
589
+ if lg_attack_step.overrides:
590
+ break
591
+ lg_attack_step = lg_attack_step.inherits
592
+
593
+ def _link_from_expr_chain(
594
+ self,
595
+ model: Model,
596
+ ag_node: AttackGraphNode,
597
+ child_type: LanguageGraphAttackStep,
598
+ expr_chain: ExpressionsChain | None,
599
+ ) -> None:
600
+ """Link a node to targets from a specific expression chain."""
601
+ if not ag_node.model_asset:
602
+ raise AttackGraphException(
603
+ "Need model asset connection to generate graph"
551
604
  )
552
605
 
553
- while lang_graph_attack_step:
554
- for target_attack_step, expr_chains in lang_graph_attack_step.children.items():
555
- for expr_chain in expr_chains:
556
- target_assets = self._follow_expr_chain(
557
- self.model,
558
- set([ag_node.model_asset]),
559
- expr_chain
560
- )
561
-
562
- for target_asset in target_assets:
563
- if target_asset is not None:
564
- target_node_full_name = target_asset.name + \
565
- ':' + target_attack_step.name
566
- target_node = self.get_node_by_full_name(
567
- target_node_full_name)
568
- if target_node is None:
569
- msg = ('Failed to find target node '
570
- '"%s" to link with for attack '
571
- 'step "%s"(%d)!')
572
- logger.error(
573
- msg,
574
- target_node_full_name,
575
- ag_node.full_name,
576
- ag_node.id
577
- )
578
- raise AttackGraphStepExpressionError(
579
- msg % (
580
- target_node_full_name,
581
- ag_node.full_name,
582
- ag_node.id
583
- )
584
- )
585
-
586
- assert ag_node.id is not None
587
- assert target_node.id is not None
588
-
589
- logger.debug('Linking attack step "%s"(%d) '
590
- 'to attack step "%s"(%d)' %
591
- (
592
- ag_node.full_name,
593
- ag_node.id,
594
- target_node.full_name,
595
- target_node.id
596
- )
597
- )
598
- ag_node.children.add(target_node)
599
- target_node.parents.add(ag_node)
600
- if lang_graph_attack_step.overrides:
601
- break
602
- lang_graph_attack_step = lang_graph_attack_step.inherits
603
-
606
+ target_assets = self._follow_expr_chain(model, {ag_node.model_asset}, expr_chain)
607
+ for target_asset in target_assets:
608
+ if not target_asset:
609
+ continue
610
+ target_node = self.get_node_by_full_name(
611
+ f"{target_asset.name}:{child_type.name}"
612
+ )
613
+ if not target_node:
614
+ raise AttackGraphStepExpressionError(
615
+ f'Failed to find target node "{target_asset.name}:{child_type.name}" '
616
+ f'for "{ag_node.full_name}"({ag_node.id})'
617
+ )
618
+ logger.debug(
619
+ 'Linking attack step "%s"(%d) to attack step "%s"(%d)',
620
+ ag_node.full_name, ag_node.id,
621
+ target_node.full_name, target_node.id
622
+ )
623
+ ag_node.children.add(target_node)
624
+ target_node.parents.add(ag_node)
604
625
 
605
626
  def regenerate_graph(self) -> None:
606
- """
607
- Regenerate the attack graph based on the original model instance and
627
+ """Regenerate the attack graph based on the original model instance and
608
628
  the MAL language specification provided at initialization.
609
629
  """
610
-
611
630
  self.nodes = {}
612
- self._generate_graph()
631
+ assert self.model, "Model required to generate graph"
632
+ self._generate_graph(self.model)
613
633
 
614
634
  def add_node(
615
- self,
616
- lg_attack_step: LanguageGraphAttackStep,
617
- node_id: Optional[int] = None,
618
- model_asset: Optional[ModelAsset] = None,
619
- ttc_dist: Optional[dict] = None,
620
- existence_status: Optional[bool] = None,
621
- full_name: Optional[str] = None
622
- ) -> AttackGraphNode:
635
+ self,
636
+ lg_attack_step: LanguageGraphAttackStep,
637
+ node_id: int | None = None,
638
+ model_asset: ModelAsset | None = None,
639
+ ttc_dist: dict | None = None,
640
+ existence_status: bool | None = None,
641
+ full_name: str | None = None
642
+ ) -> AttackGraphNode:
623
643
  """Create and add a node to the graph
624
644
  Arguments:
625
645
  lg_attack_step - the language graph attack step that corresponds
@@ -642,33 +662,29 @@ class AttackGraph():
642
662
  for exist and notExist type nodes.
643
663
 
644
664
  Return:
665
+ ------
645
666
  The newly created attack step node.
667
+
646
668
  """
647
669
  node_id = node_id if node_id is not None else self.next_node_id
648
670
  if node_id in self.nodes:
649
671
  raise ValueError(f'Node index {node_id} already in use.')
650
672
  self.next_node_id = max(node_id + 1, self.next_node_id)
651
673
 
652
- if logger.isEnabledFor(logging.DEBUG):
653
- # Avoid running json.dumps when not in debug
654
- logger.debug('Create and add to attackgraph node of type "%s" '
655
- 'with id:%d.\n' % (
656
- lg_attack_step.full_name,
657
- node_id
658
- ))
659
-
674
+ logger.debug(
675
+ 'Create and add to attackgraph node of type "%s" with id:%d.\n',
676
+ lg_attack_step.full_name, node_id
677
+ )
660
678
 
661
679
  node = AttackGraphNode(
662
- node_id = node_id,
663
- lg_attack_step = lg_attack_step,
664
- model_asset = model_asset,
665
- ttc_dist = ttc_dist,
666
- existence_status = existence_status,
667
- full_name = full_name
680
+ node_id=node_id,
681
+ lg_attack_step=lg_attack_step,
682
+ model_asset=model_asset,
683
+ ttc_dist=ttc_dist,
684
+ existence_status=existence_status,
685
+ full_name=full_name
668
686
  )
669
687
 
670
- self.nodes[node_id] = node
671
-
672
688
  # Add to different lists depending on types
673
689
  # Useful but not vital for functionality
674
690
  if node.type in ('or', 'and'):
@@ -676,6 +692,7 @@ class AttackGraph():
676
692
  if node.type == 'defense':
677
693
  self.defense_steps.append(node)
678
694
 
695
+ self.nodes[node_id] = node
679
696
  self._full_name_to_node[node.full_name] = node
680
697
 
681
698
  return node
@@ -685,15 +702,15 @@ class AttackGraph():
685
702
  Arguments:
686
703
  node - the node we wish to remove from the attack graph
687
704
  """
688
- if logger.isEnabledFor(logging.DEBUG):
689
- # Avoid running json.dumps when not in debug
690
- logger.debug(f'Remove node "%s"(%d).', node.full_name, node.id)
705
+ logger.debug(
706
+ 'Remove node "%s"(%d).', node.full_name, node.id
707
+ )
691
708
  for child in node.children:
692
709
  child.parents.remove(node)
693
710
  for parent in node.parents:
694
711
  parent.children.remove(node)
695
712
 
696
713
  if not isinstance(node.id, int):
697
- raise ValueError(f'Invalid node id.')
714
+ raise ValueError('Invalid node id.')
698
715
  del self.nodes[node.id]
699
716
  del self._full_name_to_node[node.full_name]