mal-toolbox 1.2.1__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/METADATA +8 -75
  2. mal_toolbox-2.1.0.dist-info/RECORD +51 -0
  3. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/WHEEL +1 -1
  4. maltoolbox/__init__.py +2 -2
  5. maltoolbox/attackgraph/__init__.py +2 -2
  6. maltoolbox/attackgraph/attackgraph.py +121 -549
  7. maltoolbox/attackgraph/factories.py +68 -0
  8. maltoolbox/attackgraph/file_utils.py +0 -0
  9. maltoolbox/attackgraph/generate.py +338 -0
  10. maltoolbox/attackgraph/node.py +1 -0
  11. maltoolbox/attackgraph/node_getters.py +36 -0
  12. maltoolbox/attackgraph/ttcs.py +28 -0
  13. maltoolbox/language/__init__.py +2 -2
  14. maltoolbox/language/compiler/__init__.py +4 -499
  15. maltoolbox/language/compiler/distributions.py +158 -0
  16. maltoolbox/language/compiler/exceptions.py +37 -0
  17. maltoolbox/language/compiler/lang.py +5 -0
  18. maltoolbox/language/compiler/mal_analyzer.py +920 -0
  19. maltoolbox/language/compiler/mal_compiler.py +1071 -0
  20. maltoolbox/language/detector.py +43 -0
  21. maltoolbox/language/expression_chain.py +218 -0
  22. maltoolbox/language/language_graph_asset.py +180 -0
  23. maltoolbox/language/language_graph_assoc.py +147 -0
  24. maltoolbox/language/language_graph_attack_step.py +129 -0
  25. maltoolbox/language/language_graph_builder.py +282 -0
  26. maltoolbox/language/language_graph_loaders.py +7 -0
  27. maltoolbox/language/language_graph_lookup.py +140 -0
  28. maltoolbox/language/language_graph_serialization.py +5 -0
  29. maltoolbox/language/languagegraph.py +244 -1536
  30. maltoolbox/language/step_expression_processor.py +491 -0
  31. mal_toolbox-1.2.1.dist-info/RECORD +0 -33
  32. maltoolbox/language/compiler/mal_lexer.py +0 -232
  33. maltoolbox/language/compiler/mal_parser.py +0 -3159
  34. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/entry_points.txt +0 -0
  35. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/licenses/AUTHORS +0 -0
  36. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/licenses/LICENSE +0 -0
  37. {mal_toolbox-1.2.1.dist-info → mal_toolbox-2.1.0.dist-info}/top_level.txt +0 -0
@@ -3,100 +3,129 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import copy
6
- import json
7
6
  import logging
8
- import sys
9
- import zipfile
10
- from typing import TYPE_CHECKING
11
-
12
- from .. import log_configs
13
- from ..exceptions import (
14
- AttackGraphException,
15
- AttackGraphStepExpressionError,
16
- LanguageGraphException,
17
- )
18
- from ..file_utils import (
19
- load_dict_from_json_file,
20
- load_dict_from_yaml_file,
21
- save_dict_to_file,
22
- )
23
- from ..language import (
24
- ExpressionsChain,
25
- LanguageGraph,
26
- LanguageGraphAttackStep,
27
- disaggregate_attack_step_full_name,
28
- )
29
-
30
- from ..str_utils import levenshtein_distance
7
+ from typing import TYPE_CHECKING, Optional
8
+
9
+ from maltoolbox.attackgraph.generate import generate_graph
10
+ from maltoolbox.attackgraph.node_getters import get_node_by_full_name
11
+ from maltoolbox.language.languagegraph import disaggregate_attack_step_full_name
12
+
13
+ from ..file_utils import load_dict_from_json_file, load_dict_from_yaml_file, save_dict_to_file
14
+ from ..language import LanguageGraph, LanguageGraphAttackStep
31
15
  from ..model import Model
32
16
  from .node import AttackGraphNode
33
17
 
34
18
  if TYPE_CHECKING:
35
- from typing import Any
36
-
37
19
  from ..model import ModelAsset
38
20
 
39
21
  logger = logging.getLogger(__name__)
40
22
 
41
23
 
42
- def create_attack_graph(
43
- lang: str | LanguageGraph,
44
- model: str | Model,
45
- ) -> AttackGraph:
46
- """Create and return an attack graph
47
-
48
- Args:
49
- ----
50
- lang - path to language file (.mar or .mal) or a LanguageGraph object
51
- model - path to model file (yaml or json) or a Model object
52
-
53
- """
54
- # Load language
55
- if isinstance(lang, LanguageGraph):
56
- lang_graph = lang
57
- elif isinstance(lang, str):
58
- # Load from path
59
- try:
60
- lang_graph = LanguageGraph.from_mar_archive(lang)
61
- except zipfile.BadZipFile:
62
- lang_graph = LanguageGraph.from_mal_spec(lang)
63
- else:
64
- raise TypeError("`lang` must be either string or LanguageGraph")
24
+ def attack_graph_from_dict(
25
+ serialized_object: dict, lang_graph: LanguageGraph, model: Optional[Model]
26
+ ):
27
+ attack_graph = AttackGraph(lang_graph)
28
+ attack_graph.model = model
29
+ serialized_attack_steps: dict[str, dict] = serialized_object['attack_steps']
30
+
31
+ # Create all of the nodes in the imported attack graph.
32
+ for node_full_name, node_dict in serialized_attack_steps.items():
33
+
34
+ # Recreate asset links if model is available.
35
+ node_asset = None
36
+ if model and 'asset' in node_dict:
37
+ node_asset = model.get_asset_by_name(node_dict['asset'])
38
+ if node_asset is None:
39
+ msg = (
40
+ 'Failed to find asset with name "%s"'
41
+ ' when loading from attack graph dict'
42
+ )
43
+ logger.error(msg, node_dict["asset"])
44
+ raise LookupError(msg % node_dict["asset"])
65
45
 
66
- if 'langspec_file' in log_configs:
67
- lang_graph.save_language_specification_to_json(
68
- log_configs['langspec_file']
46
+ lg_asset_name, lg_attack_step_name = (
47
+ disaggregate_attack_step_full_name(
48
+ node_dict['lang_graph_attack_step']
49
+ )
69
50
  )
70
-
71
- if 'langgraph_file' in log_configs:
72
- lang_graph.save_to_file(log_configs['langgraph_file'])
73
-
74
- # Load model
75
- if isinstance(model, Model):
76
- instance_model = model
77
- elif isinstance(model, str):
78
- # Load from path
79
- instance_model = Model.load_from_file(model, lang_graph)
80
- else:
81
- raise TypeError("`model` must be either string or Model")
82
-
83
- if log_configs['model_file']:
84
- instance_model.save_to_file(log_configs['model_file'])
85
-
86
- try:
87
- attack_graph = AttackGraph(lang_graph, instance_model)
88
- except AttackGraphStepExpressionError:
89
- logger.error(
90
- 'Attack graph generation failed when attempting '
91
- 'to resolve attack step expression!'
51
+ lg_attack_step = (
52
+ lang_graph.assets[lg_asset_name].attack_steps[lg_attack_step_name]
53
+ )
54
+ ag_node = attack_graph.add_node(
55
+ lg_attack_step=lg_attack_step,
56
+ node_id=node_dict['id'],
57
+ model_asset=node_asset,
58
+ ttc_dist=node_dict['ttc'],
59
+ existence_status=(
60
+ bool(node_dict['existence_status'])
61
+ if 'existence_status' in node_dict else None
62
+ ),
63
+ # Give explicit full name if model is missing, otherwise
64
+ # it will generate automatically in node.full_name
65
+ full_name=node_full_name if not model else None
92
66
  )
93
- sys.exit(1)
67
+ ag_node.tags = list(node_dict.get('tags', []))
68
+ ag_node.extras = node_dict.get('extras', {})
69
+
70
+ if node_asset:
71
+ # Add AttackGraphNode to attack_step_nodes of asset
72
+ if hasattr(node_asset, 'attack_step_nodes'):
73
+ node_attack_steps = list(node_asset.attack_step_nodes)
74
+ node_attack_steps.append(ag_node)
75
+ node_asset.attack_step_nodes = node_attack_steps
76
+ else:
77
+ node_asset.attack_step_nodes = [ag_node]
78
+
79
+ # Re-establish links between nodes.
80
+ for node_dict in serialized_attack_steps.values():
81
+ _ag_node = attack_graph.nodes[node_dict['id']]
82
+ if not isinstance(_ag_node, AttackGraphNode):
83
+ msg = ('Failed to find node with id %s when loading'
84
+ ' attack graph from dict')
85
+ logger.error(msg, node_dict["id"])
86
+ raise LookupError(msg % node_dict["id"])
87
+ for child_id in node_dict['children']:
88
+ child = attack_graph.nodes[int(child_id)]
89
+ if child is None:
90
+ msg = ('Failed to find child node with id %s'
91
+ ' when loading from attack graph from dict')
92
+ logger.error(msg, child_id)
93
+ raise LookupError(msg % child_id)
94
+ _ag_node.children.add(child)
95
+
96
+ for parent_id in node_dict['parents']:
97
+ parent = attack_graph.nodes[int(parent_id)]
98
+ if parent is None:
99
+ msg = ('Failed to find parent node with id %s '
100
+ 'when loading from attack graph from dict')
101
+ logger.error(msg, parent_id)
102
+ raise LookupError(msg % parent_id)
103
+ _ag_node.parents.add(parent)
94
104
 
95
105
  return attack_graph
96
106
 
97
107
 
108
+ def attack_graph_from_file(
109
+ filename: str, lang_graph: LanguageGraph, model: Optional[Model]
110
+ ):
111
+ if model is not None:
112
+ logger.debug('Load attack graph from file "%s" with '
113
+ 'model "%s".', filename, model.name)
114
+ else:
115
+ logger.debug('Load attack graph from file "%s" '
116
+ 'without model.', filename)
117
+ serialized_attack_graph = None
118
+ if filename.endswith(('.yml', '.yaml')):
119
+ serialized_attack_graph = load_dict_from_yaml_file(filename)
120
+ elif filename.endswith('.json'):
121
+ serialized_attack_graph = load_dict_from_json_file(filename)
122
+ else:
123
+ raise ValueError('Unknown file extension, expected json/yml/yaml')
124
+ return attack_graph_from_dict(serialized_attack_graph, lang_graph, model)
125
+
126
+
98
127
  class AttackGraph:
99
- """Graph representation of attack steps"""
128
+ """Graph representation of attack and defense steps"""
100
129
 
101
130
  def __init__(self, lang_graph: LanguageGraph, model: Model | None = None):
102
131
  self.nodes: dict[int, AttackGraphNode] = {}
@@ -105,12 +134,12 @@ class AttackGraph:
105
134
  self.model = model
106
135
  self.lang_graph = lang_graph
107
136
  self.next_node_id = 0
108
-
109
- # Dictionary used in optimization to get nodes by full name faster
110
- self._full_name_to_node: dict[str, AttackGraphNode] = {}
137
+ self.full_name_to_node: dict[str, AttackGraphNode] = {}
111
138
 
112
139
  if self.model is not None:
113
- self._generate_graph(self.model)
140
+ self.nodes, self.attack_steps, self.defense_steps, self.full_name_to_node = (
141
+ generate_graph(self.model)
142
+ )
114
143
 
115
144
  def __repr__(self) -> str:
116
145
  return (
@@ -150,12 +179,11 @@ class AttackGraph:
150
179
  memo[id(node)].children = copy.deepcopy(node.children, memo)
151
180
 
152
181
  # Copy lookup dicts
153
- copied_attackgraph._full_name_to_node = \
154
- copy.deepcopy(self._full_name_to_node, memo)
182
+ copied_attackgraph.full_name_to_node = \
183
+ copy.deepcopy(self.full_name_to_node, memo)
155
184
 
156
185
  # Copy counters
157
186
  copied_attackgraph.next_node_id = self.next_node_id
158
-
159
187
  return copied_attackgraph
160
188
 
161
189
  def save_to_file(self, filename: str) -> None:
@@ -163,98 +191,6 @@ class AttackGraph:
163
191
  logger.debug('Save attack graph to file "%s".', filename)
164
192
  return save_dict_to_file(filename, self._to_dict())
165
193
 
166
- @classmethod
167
- def _from_dict(
168
- cls,
169
- serialized_object: dict,
170
- lang_graph: LanguageGraph,
171
- model: Model | None = None
172
- ) -> AttackGraph:
173
- """Create AttackGraph from dict
174
- Args:
175
- serialized_object - AttackGraph in dict format
176
- model - Optional Model to add connections to
177
- """
178
- attack_graph = AttackGraph(lang_graph)
179
- attack_graph.model = model
180
- serialized_attack_steps: dict[str, dict] = serialized_object['attack_steps']
181
-
182
- # Create all of the nodes in the imported attack graph.
183
- for node_full_name, node_dict in serialized_attack_steps.items():
184
-
185
- # Recreate asset links if model is available.
186
- node_asset = None
187
- if model and 'asset' in node_dict:
188
- node_asset = model.get_asset_by_name(node_dict['asset'])
189
- if node_asset is None:
190
- msg = (
191
- 'Failed to find asset with name "%s"'
192
- ' when loading from attack graph dict'
193
- )
194
- logger.error(msg, node_dict["asset"])
195
- raise LookupError(msg % node_dict["asset"])
196
-
197
- lg_asset_name, lg_attack_step_name = (
198
- disaggregate_attack_step_full_name(
199
- node_dict['lang_graph_attack_step']
200
- )
201
- )
202
- lg_attack_step = (
203
- lang_graph.assets[lg_asset_name].attack_steps[lg_attack_step_name]
204
- )
205
- ag_node = attack_graph.add_node(
206
- lg_attack_step=lg_attack_step,
207
- node_id=node_dict['id'],
208
- model_asset=node_asset,
209
- ttc_dist=node_dict['ttc'],
210
- existence_status=(
211
- bool(node_dict['existence_status'])
212
- if 'existence_status' in node_dict else None
213
- ),
214
- # Give explicit full name if model is missing, otherwise
215
- # it will generate automatically in node.full_name
216
- full_name=node_full_name if not model else None
217
- )
218
- ag_node.tags = list(node_dict.get('tags', []))
219
- ag_node.extras = node_dict.get('extras', {})
220
-
221
- if node_asset:
222
- # Add AttackGraphNode to attack_step_nodes of asset
223
- if hasattr(node_asset, 'attack_step_nodes'):
224
- node_attack_steps = list(node_asset.attack_step_nodes)
225
- node_attack_steps.append(ag_node)
226
- node_asset.attack_step_nodes = node_attack_steps
227
- else:
228
- node_asset.attack_step_nodes = [ag_node]
229
-
230
- # Re-establish links between nodes.
231
- for node_dict in serialized_attack_steps.values():
232
- _ag_node = attack_graph.nodes[node_dict['id']]
233
- if not isinstance(_ag_node, AttackGraphNode):
234
- msg = ('Failed to find node with id %s when loading'
235
- ' attack graph from dict')
236
- logger.error(msg, node_dict["id"])
237
- raise LookupError(msg % node_dict["id"])
238
- for child_id in node_dict['children']:
239
- child = attack_graph.nodes[int(child_id)]
240
- if child is None:
241
- msg = ('Failed to find child node with id %s'
242
- ' when loading from attack graph from dict')
243
- logger.error(msg, child_id)
244
- raise LookupError(msg % child_id)
245
- _ag_node.children.add(child)
246
-
247
- for parent_id in node_dict['parents']:
248
- parent = attack_graph.nodes[int(parent_id)]
249
- if parent is None:
250
- msg = ('Failed to find parent node with id %s '
251
- 'when loading from attack graph from dict')
252
- logger.error(msg, parent_id)
253
- raise LookupError(msg % parent_id)
254
- _ag_node.parents.add(parent)
255
-
256
- return attack_graph
257
-
258
194
  @classmethod
259
195
  def load_from_file(
260
196
  cls,
@@ -263,21 +199,7 @@ class AttackGraph:
263
199
  model: Model | None = None
264
200
  ) -> AttackGraph:
265
201
  """Create from json or yaml file depending on file extension"""
266
- if model is not None:
267
- logger.debug('Load attack graph from file "%s" with '
268
- 'model "%s".', filename, model.name)
269
- else:
270
- logger.debug('Load attack graph from file "%s" '
271
- 'without model.', filename)
272
- serialized_attack_graph = None
273
- if filename.endswith(('.yml', '.yaml')):
274
- serialized_attack_graph = load_dict_from_yaml_file(filename)
275
- elif filename.endswith('.json'):
276
- serialized_attack_graph = load_dict_from_json_file(filename)
277
- else:
278
- raise ValueError('Unknown file extension, expected json/yml/yaml')
279
- return cls._from_dict(serialized_attack_graph,
280
- lang_graph, model=model)
202
+ return attack_graph_from_file(filename, lang_graph, model)
281
203
 
282
204
  def get_node_by_full_name(self, full_name: str) -> AttackGraphNode:
283
205
  """Return the attack node that matches the full name provided.
@@ -292,365 +214,16 @@ class AttackGraph:
292
214
  The attack step node that matches the given full name.
293
215
 
294
216
  """
295
- logger.debug('Looking up node with full name "%s"', full_name)
296
- if full_name not in self._full_name_to_node:
297
- similar_names = self._get_similar_full_names(full_name)
298
- raise LookupError(
299
- f'Could not find node with name "{full_name}". '
300
- f'Did you mean: {", ".join(similar_names)}?'
301
- )
302
- return self._full_name_to_node[full_name]
303
-
304
- def _follow_field_expr_chain(
305
- self, target_assets: set[ModelAsset], expr_chain: ExpressionsChain
306
- ):
307
- # Change the target assets from the current ones to the
308
- # associated assets given the specified field name.
309
- if not expr_chain.fieldname:
310
- raise LanguageGraphException(
311
- '"field" step expression chain is missing fieldname.'
312
- )
313
- new_target_assets: set[ModelAsset] = set()
314
- new_target_assets.update(
315
- *(
316
- asset.associated_assets.get(expr_chain.fieldname, set())
317
- for asset in target_assets
318
- )
319
- )
320
- return new_target_assets
321
-
322
- def _follow_transitive_expr_chain(
323
- self,
324
- model: Model,
325
- target_assets: set[ModelAsset],
326
- expr_chain: ExpressionsChain
327
- ):
328
- if not expr_chain.sub_link:
329
- raise LanguageGraphException(
330
- '"transitive" step expression chain is missing sub link.'
331
- )
332
-
333
- new_assets = target_assets
334
- while new_assets := self._follow_expr_chain(
335
- model, new_assets, expr_chain.sub_link
336
- ):
337
- new_assets = new_assets.difference(target_assets)
338
- if not new_assets:
339
- break
340
- target_assets.update(new_assets)
341
- return target_assets
342
-
343
- def _follow_subtype_expr_chain(
344
- self,
345
- model: Model,
346
- target_assets: set[ModelAsset],
347
- expr_chain: ExpressionsChain
348
- ):
349
- if not expr_chain.sub_link:
350
- raise LanguageGraphException(
351
- '"subType" step expression chain is missing sub link.'
352
- )
353
- new_target_assets = set()
354
- new_target_assets.update(
355
- self._follow_expr_chain(
356
- model, target_assets, expr_chain.sub_link
357
- )
358
- )
359
- selected_new_target_assets = set()
360
- for asset in new_target_assets:
361
- lang_graph_asset = self.lang_graph.assets[asset.type]
362
- if not lang_graph_asset:
363
- raise LookupError(
364
- f'Failed to find asset "{asset.type}" in the '
365
- 'language graph.'
366
- )
367
- lang_graph_subtype_asset = expr_chain.subtype
368
- if not lang_graph_subtype_asset:
369
- raise LookupError(
370
- 'Failed to find asset "{expr_chain.subtype}" in '
371
- 'the language graph.'
372
- )
373
- if lang_graph_asset.is_subasset_of(lang_graph_subtype_asset):
374
- selected_new_target_assets.add(asset)
375
-
376
- return selected_new_target_assets
377
-
378
- def _follow_union_intersection_difference_expr_chain(
379
- self,
380
- model: Model,
381
- target_assets: set[ModelAsset],
382
- expr_chain: ExpressionsChain
383
- ) -> set[Any]:
384
- # The set operators are used to combine the left hand and
385
- # right hand targets accordingly.
386
- if not expr_chain.left_link:
387
- raise LanguageGraphException(
388
- '"%s" step expression chain is missing the left link.',
389
- expr_chain.type
390
- )
391
- if not expr_chain.right_link:
392
- raise LanguageGraphException(
393
- '"%s" step expression chain is missing the right link.',
394
- expr_chain.type
395
- )
396
- lh_targets = self._follow_expr_chain(
397
- model, target_assets, expr_chain.left_link
398
- )
399
- rh_targets = self._follow_expr_chain(
400
- model, target_assets, expr_chain.right_link
401
- )
402
-
403
- if expr_chain.type == 'union':
404
- # Once the assets become hashable set operations should be
405
- # used instead.
406
- return lh_targets.union(rh_targets)
407
-
408
- if expr_chain.type == 'intersection':
409
- return lh_targets.intersection(rh_targets)
410
-
411
- if expr_chain.type == 'difference':
412
- return lh_targets.difference(rh_targets)
413
-
414
- raise ValueError("Expr chain must be of type union, intersectin or difference")
415
-
416
- def _follow_collect_expr_chain(
417
- self,
418
- model: Model,
419
- target_assets: set[ModelAsset],
420
- expr_chain: ExpressionsChain
421
- ) -> set[Any]:
422
- if not expr_chain.left_link:
423
- raise LanguageGraphException(
424
- '"collect" step expression chain missing the left link.'
425
- )
426
- if not expr_chain.right_link:
427
- raise LanguageGraphException(
428
- '"collect" step expression chain missing the right link.'
429
- )
430
- lh_targets = self._follow_expr_chain(
431
- model,
432
- target_assets,
433
- expr_chain.left_link
434
- )
435
- rh_targets = set()
436
- for lh_target in lh_targets:
437
- rh_targets |= self._follow_expr_chain(
438
- model,
439
- {lh_target},
440
- expr_chain.right_link
441
- )
442
- return rh_targets
443
-
444
- def _follow_expr_chain(
445
- self,
446
- model: Model,
447
- target_assets: set[ModelAsset],
448
- expr_chain: ExpressionsChain | None
449
- ) -> set[Any]:
450
- """Recursively follow a language graph expressions chain on an instance
451
- model.
452
-
453
- Arguments:
454
- ---------
455
- model - a maltoolbox.model.Model on which to follow the
456
- expressions chain
457
- target_assets - the set of assets that this expressions chain
458
- should apply to. Initially it will contain the
459
- asset to which the attack step belongs
460
- expr_chain - the expressions chain we are following
461
-
462
- Return:
463
- ------
464
- A list of all of the target assets.
465
-
466
- """
467
- if expr_chain is None:
468
- # There is no expressions chain link left to follow return the
469
- # current target assets
470
- return set(target_assets)
471
-
472
- if logger.isEnabledFor(logging.DEBUG):
473
- # Avoid running json.dumps when not in debug
474
- logger.debug(
475
- 'Following Expressions Chain:\n%s',
476
- json.dumps(expr_chain.to_dict(), indent=2)
477
- )
478
-
479
- match (expr_chain.type):
480
- case 'union' | 'intersection' | 'difference':
481
- return self._follow_union_intersection_difference_expr_chain(
482
- model, target_assets, expr_chain
483
- )
484
-
485
- case 'field':
486
- return self._follow_field_expr_chain(target_assets, expr_chain)
487
-
488
- case 'transitive':
489
- return self._follow_transitive_expr_chain(model, target_assets, expr_chain)
490
-
491
- case 'subType':
492
- return self._follow_subtype_expr_chain(model, target_assets, expr_chain)
493
-
494
- case 'collect':
495
- return self._follow_collect_expr_chain(model, target_assets, expr_chain)
496
-
497
- case _:
498
- msg = 'Unknown attack expressions chain type: %s'
499
- logger.error(
500
- msg,
501
- expr_chain.type
502
- )
503
- raise AttackGraphStepExpressionError(
504
- msg % expr_chain.type
505
- )
506
-
507
- def _get_existance_status(
508
- self,
509
- model: Model,
510
- asset: ModelAsset,
511
- attack_step: LanguageGraphAttackStep
512
- ) -> bool | None:
513
- """Get existance status of a step"""
514
- if attack_step.type not in ('exist', 'notExist'):
515
- # No existence status for other type of steps
516
- return None
517
-
518
- existence_status = False
519
- for requirement in attack_step.requires:
520
- target_assets = self._follow_expr_chain(
521
- model, set([asset]), requirement
522
- )
523
- # If the step expression resolution yielded
524
- # the target assets then the required assets
525
- # exist in the model.
526
- if target_assets:
527
- existence_status = True
528
- break
529
-
530
- return existence_status
531
-
532
- def _get_ttc_dist(
533
- self,
534
- asset: ModelAsset,
535
- attack_step: LanguageGraphAttackStep
536
- ):
537
- """Get step ttc distribution based on language
538
- and possibly overriding defense status
539
- """
540
- ttc_dist = copy.deepcopy(attack_step.ttc)
541
- if attack_step.type == 'defense':
542
- if attack_step.name in asset.defenses:
543
- # If defense status was set in model, set ttc accordingly
544
- defense_value = float(asset.defenses[attack_step.name])
545
- ttc_dist = {
546
- 'arguments': [defense_value],
547
- 'name': 'Bernoulli',
548
- 'type': 'function'
549
- }
550
- logger.debug(
551
- 'Setting defense \"%s\" to "%s".',
552
- asset.name + ":" + attack_step.name, defense_value
553
- )
554
- return ttc_dist
555
-
556
- def _generate_graph(self, model: Model) -> None:
557
- """Generate the attack graph from model and MAL language."""
558
- self.nodes = {}
559
- self._full_name_to_node = {}
560
-
561
- self._create_nodes_from_model(model)
562
- self._link_nodes_by_language(model)
563
-
564
- def _create_nodes_from_model(self, model: Model) -> None:
565
- """Create attack graph nodes for all model assets."""
566
- for asset in model.assets.values():
567
- asset.attack_step_nodes = []
568
- for attack_step in asset.lg_asset.attack_steps.values():
569
- node = self.add_node(
570
- lg_attack_step=attack_step,
571
- model_asset=asset,
572
- ttc_dist=self._get_ttc_dist(asset, attack_step),
573
- existence_status=(
574
- self._get_existance_status(model, asset, attack_step)
575
- ),
576
- )
577
- asset.attack_step_nodes.append(node)
578
-
579
- def _link_nodes_by_language(self, model: Model) -> None:
580
- """Establish parent-child links between nodes."""
581
- for ag_node in self.nodes.values():
582
- self._link_node_children(model, ag_node)
583
-
584
- def _link_node_children(self, model: Model, ag_node: AttackGraphNode) -> None:
585
- """Link one node to its children."""
586
- if not ag_node.model_asset:
587
- raise AttackGraphException('Attack graph node is missing asset link')
588
-
589
- lg_asset = self.lang_graph.assets[ag_node.model_asset.type]
590
- lg_attack_step: LanguageGraphAttackStep | None = (
591
- lg_asset.attack_steps[ag_node.name]
592
- )
593
- while lg_attack_step:
594
- for child_type, expr_chains in lg_attack_step.children.items():
595
- for expr_chain in expr_chains:
596
- self._link_from_expr_chain(model, ag_node, child_type, expr_chain)
597
- if lg_attack_step.overrides:
598
- break
599
- lg_attack_step = lg_attack_step.inherits
600
-
601
- def _link_from_expr_chain(
602
- self,
603
- model: Model,
604
- ag_node: AttackGraphNode,
605
- child_type: LanguageGraphAttackStep,
606
- expr_chain: ExpressionsChain | None,
607
- ) -> None:
608
- """Link a node to targets from a specific expression chain."""
609
- if not ag_node.model_asset:
610
- raise AttackGraphException(
611
- "Need model asset connection to generate graph"
612
- )
613
-
614
- target_assets = self._follow_expr_chain(model, {ag_node.model_asset}, expr_chain)
615
- for target_asset in target_assets:
616
- if not target_asset:
617
- continue
618
- target_node = self.get_node_by_full_name(
619
- f"{target_asset.name}:{child_type.name}"
620
- )
621
- if not target_node:
622
- raise AttackGraphStepExpressionError(
623
- f'Failed to find target node "{target_asset.name}:{child_type.name}" '
624
- f'for "{ag_node.full_name}"({ag_node.id})'
625
- )
626
- logger.debug(
627
- 'Linking attack step "%s"(%d) to attack step "%s"(%d)',
628
- ag_node.full_name, ag_node.id,
629
- target_node.full_name, target_node.id
630
- )
631
- ag_node.children.add(target_node)
632
- target_node.parents.add(ag_node)
633
-
634
- def _get_similar_full_names(self, q: str) -> list[str]:
635
- """Return a list of node full names that are similar to `q`"""
636
- shortest_dist = 100
637
- similar_names = []
638
- for full_name in self._full_name_to_node:
639
- dist = levenshtein_distance(q, full_name)
640
- if dist == shortest_dist:
641
- similar_names.append(full_name)
642
- elif dist < shortest_dist:
643
- similar_names = [full_name]
644
- shortest_dist = dist
645
- return similar_names
217
+ return get_node_by_full_name(self.full_name_to_node, full_name)
646
218
 
647
219
  def regenerate_graph(self) -> None:
648
220
  """Regenerate the attack graph based on the original model instance and
649
221
  the MAL language specification provided at initialization.
650
222
  """
651
- self.nodes = {}
652
223
  assert self.model, "Model required to generate graph"
653
- self._generate_graph(self.model)
224
+ self.nodes, self.attack_steps, self.defense_steps, self.full_name_to_node = (
225
+ generate_graph(self.model)
226
+ )
654
227
 
655
228
  def add_node(
656
229
  self,
@@ -687,10 +260,11 @@ class AttackGraph:
687
260
  The newly created attack step node.
688
261
 
689
262
  """
263
+
690
264
  node_id = node_id if node_id is not None else self.next_node_id
691
265
  if node_id in self.nodes:
692
266
  raise ValueError(f'Node index {node_id} already in use.')
693
- self.next_node_id = max(node_id + 1, self.next_node_id)
267
+ self.next_node_id = node_id + 1
694
268
 
695
269
  logger.debug(
696
270
  'Create and add to attackgraph node of type "%s" with id:%d.\n',
@@ -714,7 +288,7 @@ class AttackGraph:
714
288
  self.defense_steps.append(node)
715
289
 
716
290
  self.nodes[node_id] = node
717
- self._full_name_to_node[node.full_name] = node
291
+ self.full_name_to_node[node.full_name] = node
718
292
 
719
293
  return node
720
294
 
@@ -730,8 +304,6 @@ class AttackGraph:
730
304
  child.parents.remove(node)
731
305
  for parent in node.parents:
732
306
  parent.children.remove(node)
733
-
734
- if not isinstance(node.id, int):
735
- raise ValueError('Invalid node id.')
736
307
  del self.nodes[node.id]
737
- del self._full_name_to_node[node.full_name]
308
+ del self.full_name_to_node[node.full_name]
309
+