mal-toolbox 0.1.12__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,221 +5,112 @@ from __future__ import annotations
5
5
  import copy
6
6
  import logging
7
7
  import json
8
+ import sys
9
+ import zipfile
8
10
 
11
+ from itertools import chain
9
12
  from typing import TYPE_CHECKING
10
13
 
14
+ from .analyzers.apriori import calculate_viability_and_necessity
11
15
  from .node import AttackGraphNode
12
16
  from .attacker import Attacker
13
- from ..exceptions import AttackGraphStepExpressionError
17
+ from .. import log_configs
18
+ from ..exceptions import AttackGraphStepExpressionError, AttackGraphException
19
+ from ..exceptions import LanguageGraphException
14
20
  from ..model import Model
15
- from ..exceptions import AttackGraphException
21
+ from ..language import (LanguageGraph, ExpressionsChain,
22
+ LanguageGraphAttackStep, disaggregate_attack_step_full_name)
16
23
  from ..file_utils import (
17
24
  load_dict_from_json_file,
18
25
  load_dict_from_yaml_file,
19
26
  save_dict_to_file
20
27
  )
21
28
 
29
+
22
30
  if TYPE_CHECKING:
23
31
  from typing import Any, Optional
24
- from ..language import LanguageGraph
32
+ from ..model import ModelAsset
25
33
 
26
34
  logger = logging.getLogger(__name__)
27
35
 
28
- # TODO see if (part of) this can be incorporated into the LanguageGraph, so that
29
- # the LanguageGraph's _lang_spec private property does not need to be accessed
30
- def _process_step_expression(
31
- lang_graph: LanguageGraph,
32
- model: Model,
33
- target_assets: list[Any],
34
- step_expression: dict[str, Any]
35
- ) -> tuple[list, Optional[str]]:
36
- """
37
- Recursively process an attack step expression.
38
-
39
- Arguments:
40
- lang_graph - a language graph representing the MAL language
41
- specification
42
- model - a maltoolbox.model.Model instance from which the attack
43
- graph was generated
44
- target_assets - the list of assets that this step expression should apply
45
- to. Initially it will contain the asset to which the
46
- attack step belongs
47
- step_expression - a dictionary containing the step expression
48
-
49
- Return:
50
- A tuple pair containing a list of all of the target assets and the name of
51
- the attack step.
36
+
37
+ def create_attack_graph(
38
+ lang_file: str,
39
+ model_file: str,
40
+ attach_attackers=True,
41
+ calc_viability_and_necessity=True
42
+ ) -> AttackGraph:
43
+ """Create and return an attack graph
44
+
45
+ Args:
46
+ lang_file - path to language file (.mar or .mal)
47
+ model_file - path to model file (yaml or json)
48
+ attach_attackers - whether to run attach_attackers or not
49
+ calc_viability_and_necessity - whether to run apriori calculations or not
52
50
  """
51
+ try:
52
+ lang_graph = LanguageGraph.from_mar_archive(lang_file)
53
+ except zipfile.BadZipFile:
54
+ lang_graph = LanguageGraph.from_mal_spec(lang_file)
53
55
 
54
- if logger.isEnabledFor(logging.DEBUG):
55
- # Avoid running json.dumps when not in debug
56
- logger.debug(
57
- 'Processing Step Expression:\n%s',
58
- json.dumps(step_expression, indent = 2)
59
- )
56
+ if log_configs['langspec_file']:
57
+ lang_graph.save_to_file(log_configs['langspec_file'])
60
58
 
61
- match (step_expression['type']):
62
- case 'attackStep':
63
- # The attack step expression just adds the name of the attack
64
- # step. All other step expressions only modify the target assets.
65
- return (target_assets, step_expression['name'])
66
-
67
- case 'union' | 'intersection' | 'difference':
68
- # The set operators are used to combine the left hand and right
69
- # hand targets accordingly.
70
- lh_targets, lh_attack_steps = _process_step_expression(
71
- lang_graph, model, target_assets, step_expression['lhs'])
72
- rh_targets, rh_attack_steps = _process_step_expression(
73
- lang_graph, model, target_assets, step_expression['rhs'])
74
-
75
- new_target_assets = []
76
- match (step_expression['type']):
77
- case 'union':
78
- new_target_assets = lh_targets
79
- for ag_node in rh_targets:
80
- if next((lnode for lnode in new_target_assets \
81
- if lnode.id != ag_node.id), None):
82
- new_target_assets.append(ag_node)
83
-
84
- case 'intersection':
85
- for ag_node in rh_targets:
86
- if next((lnode for lnode in lh_targets \
87
- if lnode.id == ag_node.id), None):
88
- new_target_assets.append(ag_node)
89
-
90
- case 'difference':
91
- new_target_assets = lh_targets
92
- for ag_node in lh_targets:
93
- if next((rnode for rnode in rh_targets \
94
- if rnode.id != ag_node.id), None):
95
- new_target_assets.remove(ag_node)
96
-
97
- return (new_target_assets, None)
98
-
99
- case 'variable':
100
- # Fetch the step expression associated with the variable from
101
- # the language specification and resolve that.
102
- for target_asset in target_assets:
103
- if (hasattr(target_asset, 'type')):
104
- # TODO how can this info be accessed in the lang_graph
105
- # directly without going through the private method?
106
- variable_step_expr = lang_graph._get_variable_for_asset_type_by_name(
107
- target_asset.type, step_expression['name'])
108
- return _process_step_expression(
109
- lang_graph, model, target_assets, variable_step_expr)
59
+ instance_model = Model.load_from_file(model_file, lang_graph)
110
60
 
111
- else:
112
- logger.error(
113
- 'Requested variable from non-asset target node:'
114
- '%s which cannot be resolved.', target_asset
115
- )
116
- return ([], None)
117
-
118
- case 'field':
119
- # Change the target assets from the current ones to the associated
120
- # assets given the specified field name.
121
- new_target_assets = []
122
- for target_asset in target_assets:
123
- new_target_assets.extend(model.\
124
- get_associated_assets_by_field_name(target_asset,
125
- step_expression['name']))
126
- return (new_target_assets, None)
127
-
128
- case 'transitive':
129
- # The transitive expression is very similar to the field
130
- # expression, but it proceeds recursively until no target is
131
- # found and it and it sets the new targets to the entire list
132
- # of assets identified during the entire transitive recursion.
133
- new_target_assets = []
134
- for target_asset in target_assets:
135
- new_target_assets.extend(model.\
136
- get_associated_assets_by_field_name(target_asset,
137
- step_expression['stepExpression']['name']))
138
- if new_target_assets:
139
- (additional_assets, _) = _process_step_expression(
140
- lang_graph, model, new_target_assets, step_expression)
141
- new_target_assets.extend(additional_assets)
142
- return (new_target_assets, None)
143
- else:
144
- return ([], None)
145
-
146
- case 'subType':
147
- new_target_assets = []
148
- for target_asset in target_assets:
149
- (assets, _) = _process_step_expression(
150
- lang_graph, model, target_assets,
151
- step_expression['stepExpression'])
152
- new_target_assets.extend(assets)
153
-
154
- selected_new_target_assets = []
155
- for asset in new_target_assets:
156
- lang_graph_asset = lang_graph.get_asset_by_name(
157
- asset.type
158
- )
159
- if not lang_graph_asset:
160
- raise LookupError(
161
- f'Failed to find asset \"{asset.type}\" in the '
162
- 'language graph.'
163
- )
164
- lang_graph_subtype_asset = lang_graph.get_asset_by_name(
165
- step_expression['subType']
166
- )
167
- if not lang_graph_subtype_asset:
168
- raise LookupError(
169
- 'Failed to find asset '
170
- f'\"{step_expression["subType"]}\" in the '
171
- 'language graph.'
172
- )
173
- if lang_graph_asset.is_subasset_of(lang_graph_subtype_asset):
174
- selected_new_target_assets.append(asset)
61
+ if log_configs['model_file']:
62
+ instance_model.save_to_file(log_configs['model_file'])
63
+
64
+ try:
65
+ attack_graph = AttackGraph(lang_graph, instance_model)
66
+ except AttackGraphStepExpressionError:
67
+ logger.error(
68
+ 'Attack graph generation failed when attempting '
69
+ 'to resolve attack step expression!'
70
+ )
71
+ sys.exit(1)
175
72
 
176
- return (selected_new_target_assets, None)
73
+ if attach_attackers:
74
+ attack_graph.attach_attackers()
177
75
 
178
- case 'collect':
179
- # Apply the right hand step expression to left hand step
180
- # expression target assets.
181
- lh_targets, _ = _process_step_expression(
182
- lang_graph, model, target_assets, step_expression['lhs'])
183
- return _process_step_expression(lang_graph, model, lh_targets,
184
- step_expression['rhs'])
76
+ if calc_viability_and_necessity:
77
+ calculate_viability_and_necessity(attack_graph)
185
78
 
79
+ return attack_graph
186
80
 
187
- case _:
188
- logger.error(
189
- 'Unknown attack step type: %s', step_expression["type"]
190
- )
191
- return ([], None)
192
81
 
193
82
  class AttackGraph():
194
83
  """Graph representation of attack steps"""
195
- def __init__(self, lang_graph = None, model: Optional[Model] = None):
196
- self.nodes: list[AttackGraphNode] = []
197
- self.attackers: list[Attacker] = []
84
+ def __init__(self, lang_graph, model: Optional[Model] = None):
85
+ self.nodes: dict[int, AttackGraphNode] = {}
86
+ self.attackers: dict[int, Attacker] = {}
198
87
  # Dictionaries used in optimization to get nodes and attackers by id
199
88
  # or full name faster
200
- self._id_to_node: dict[int, AttackGraphNode] = {}
201
89
  self._full_name_to_node: dict[str, AttackGraphNode] = {}
202
- self._id_to_attacker: dict[int, Attacker] = {}
203
90
 
204
91
  self.model = model
205
92
  self.lang_graph = lang_graph
206
93
  self.next_node_id = 0
207
94
  self.next_attacker_id = 0
208
- if self.model is not None and self.lang_graph is not None:
95
+ if self.model is not None:
209
96
  self._generate_graph()
210
97
 
211
98
  def __repr__(self) -> str:
212
- return f'AttackGraph({len(self.nodes)} nodes)'
99
+ return (f'AttackGraph(Number of nodes: {len(self.nodes)}, '
100
+ f'model: {self.model}, language: {self.lang_graph}')
213
101
 
214
102
  def _to_dict(self) -> dict:
215
103
  """Convert AttackGraph to dict"""
216
104
  serialized_attack_steps = {}
217
105
  serialized_attackers = {}
218
- for ag_node in self.nodes:
106
+ for ag_node in self.nodes.values():
219
107
  serialized_attack_steps[ag_node.full_name] =\
220
108
  ag_node.to_dict()
221
- for attacker in self.attackers:
109
+ for attacker in self.attackers.values():
222
110
  serialized_attackers[attacker.name] = attacker.to_dict()
111
+ logger.debug('Serialized %d attack steps and %d attackers.' %
112
+ (len(self.nodes), len(self.attackers))
113
+ )
223
114
  return {
224
115
  'attack_steps': serialized_attack_steps,
225
116
  'attackers': serialized_attackers,
@@ -234,34 +125,32 @@ class AttackGraph():
234
125
  copied_attackgraph = AttackGraph(self.lang_graph)
235
126
  copied_attackgraph.model = self.model
236
127
 
237
- copied_attackgraph.nodes = []
128
+ copied_attackgraph.nodes = {}
238
129
 
239
130
  # Deep copy nodes
240
- for node in self.nodes:
131
+ for node_id, node in self.nodes.items():
241
132
  copied_node = copy.deepcopy(node, memo)
242
- copied_attackgraph.nodes.append(copied_node)
133
+ copied_attackgraph.nodes[node_id] = copied_node
243
134
 
244
135
  # Re-link node references
245
- for node in self.nodes:
136
+ for node in self.nodes.values():
246
137
  if node.parents:
247
138
  memo[id(node)].parents = copy.deepcopy(node.parents, memo)
248
139
  if node.children:
249
140
  memo[id(node)].children = copy.deepcopy(node.children, memo)
250
141
 
251
- # Deep copy attackers and references to them
252
- copied_attackgraph.attackers = copy.deepcopy(self.attackers, memo)
142
+ # Deep copy attackers
143
+ for attacker_id, attacker in self.attackers.items():
144
+ copied_attacker = copy.deepcopy(attacker, memo)
145
+ copied_attackgraph.attackers[attacker_id] = copied_attacker
253
146
 
254
147
  # Re-link attacker references
255
- for node in self.nodes:
148
+ for node in self.nodes.values():
256
149
  if node.compromised_by:
257
150
  memo[id(node)].compromised_by = copy.deepcopy(
258
151
  node.compromised_by, memo)
259
152
 
260
153
  # Copy lookup dicts
261
- copied_attackgraph._id_to_attacker = \
262
- copy.deepcopy(self._id_to_attacker, memo)
263
- copied_attackgraph._id_to_node = \
264
- copy.deepcopy(self._id_to_node, memo)
265
154
  copied_attackgraph._full_name_to_node = \
266
155
  copy.deepcopy(self._full_name_to_node, memo)
267
156
 
@@ -280,6 +169,7 @@ class AttackGraph():
280
169
  def _from_dict(
281
170
  cls,
282
171
  serialized_object: dict,
172
+ lang_graph: LanguageGraph,
283
173
  model: Optional[Model]=None
284
174
  ) -> AttackGraph:
285
175
  """Create AttackGraph from dict
@@ -288,30 +178,38 @@ class AttackGraph():
288
178
  model - Optional Model to add connections to
289
179
  """
290
180
 
291
- attack_graph = AttackGraph()
181
+ attack_graph = AttackGraph(lang_graph)
292
182
  attack_graph.model = model
293
183
  serialized_attack_steps = serialized_object['attack_steps']
294
184
  serialized_attackers = serialized_object['attackers']
295
185
 
296
186
  # Create all of the nodes in the imported attack graph.
297
- for node_full_name, node_dict in serialized_attack_steps.items():
187
+ for node_dict in serialized_attack_steps.values():
298
188
 
299
189
  # Recreate asset links if model is available.
300
190
  node_asset = None
301
191
  if model and 'asset' in node_dict:
302
192
  node_asset = model.get_asset_by_name(node_dict['asset'])
303
193
  if node_asset is None:
304
- msg = ('Failed to find asset with id %s'
305
- 'when loading from attack graph dict')
194
+ msg = ('Failed to find asset with name "%s"'
195
+ ' when loading from attack graph dict')
306
196
  logger.error(msg, node_dict["asset"])
307
197
  raise LookupError(msg % node_dict["asset"])
308
198
 
309
- ag_node = AttackGraphNode(
310
- type=node_dict['type'],
311
- name=node_dict['name'],
312
- ttc=node_dict['ttc'],
313
- asset=node_asset
199
+ lg_asset_name, lg_attack_step_name = \
200
+ disaggregate_attack_step_full_name(
201
+ node_dict['lang_graph_attack_step'])
202
+ lg_attack_step = lang_graph.assets[lg_asset_name].\
203
+ attack_steps[lg_attack_step_name]
204
+ ag_node = attack_graph.add_node(
205
+ lg_attack_step = lg_attack_step,
206
+ node_id = node_dict['id'],
207
+ model_asset = node_asset,
208
+ defense_status = node_dict.get('defense_status', None),
209
+ existence_status = node_dict.get('existence_status', None)
314
210
  )
211
+ ag_node.tags = set(node_dict.get('tags', []))
212
+ ag_node.extras = node_dict.get('extras', {})
315
213
 
316
214
  if node_asset:
317
215
  # Add AttackGraphNode to attack_step_nodes of asset
@@ -322,26 +220,10 @@ class AttackGraph():
322
220
  else:
323
221
  node_asset.attack_step_nodes = [ag_node]
324
222
 
325
- ag_node.defense_status = float(node_dict['defense_status']) if \
326
- 'defense_status' in node_dict else None
327
- ag_node.existence_status = node_dict['existence_status'] \
328
- == 'True' if 'existence_status' in node_dict else None
329
- ag_node.is_viable = node_dict['is_viable'] == 'True' if \
330
- 'is_viable' in node_dict else True
331
- ag_node.is_necessary = node_dict['is_necessary'] == 'True' if \
332
- 'is_necessary' in node_dict else True
333
- ag_node.mitre_info = str(node_dict['mitre_info']) if \
334
- 'mitre_info' in node_dict else None
335
- ag_node.tags = node_dict['tags'] if \
336
- 'tags' in node_dict else []
337
- ag_node.extras = node_dict.get('extras', {})
338
-
339
- # Add AttackGraphNode to AttackGraph
340
- attack_graph.add_node(ag_node, node_id=node_dict['id'])
341
223
 
342
224
  # Re-establish links between nodes.
343
- for node_full_name, node_dict in serialized_attack_steps.items():
344
- _ag_node = attack_graph.get_node_by_id(node_dict['id'])
225
+ for node_dict in serialized_attack_steps.values():
226
+ _ag_node = attack_graph.nodes[node_dict['id']]
345
227
  if not isinstance(_ag_node, AttackGraphNode):
346
228
  msg = ('Failed to find node with id %s when loading'
347
229
  ' attack graph from dict')
@@ -349,33 +231,36 @@ class AttackGraph():
349
231
  raise LookupError(msg % node_dict["id"])
350
232
  else:
351
233
  for child_id in node_dict['children']:
352
- child = attack_graph.get_node_by_id(int(child_id))
234
+ child = attack_graph.nodes[int(child_id)]
353
235
  if child is None:
354
236
  msg = ('Failed to find child node with id %s'
355
237
  ' when loading from attack graph from dict')
356
238
  logger.error(msg, child_id)
357
239
  raise LookupError(msg % child_id)
358
- _ag_node.children.append(child)
240
+ _ag_node.children.add(child)
359
241
 
360
242
  for parent_id in node_dict['parents']:
361
- parent = attack_graph.get_node_by_id(int(parent_id))
243
+ parent = attack_graph.nodes[int(parent_id)]
362
244
  if parent is None:
363
245
  msg = ('Failed to find parent node with id %s '
364
246
  'when loading from attack graph from dict')
365
247
  logger.error(msg, parent_id)
366
248
  raise LookupError(msg % parent_id)
367
- _ag_node.parents.append(parent)
249
+ _ag_node.parents.add(parent)
368
250
 
369
- for attacker_name, attacker in serialized_attackers.items():
251
+ for attacker in serialized_attackers.values():
370
252
  ag_attacker = Attacker(
371
253
  name = attacker['name'],
372
- entry_points = [],
373
- reached_attack_steps = []
254
+ entry_points = set(),
255
+ reached_attack_steps = set()
374
256
  )
375
257
  attack_graph.add_attacker(
376
258
  attacker = ag_attacker,
377
259
  attacker_id = int(attacker['id']),
378
- entry_points = attacker['entry_points'].keys(),
260
+ entry_points = [
261
+ int(node_id) # Convert to int since they can be strings
262
+ for node_id in attacker['entry_points'].keys()
263
+ ],
379
264
  reached_attack_steps = [
380
265
  int(node_id) # Convert to int since they can be strings
381
266
  for node_id in attacker['reached_attack_steps'].keys()
@@ -388,7 +273,8 @@ class AttackGraph():
388
273
  def load_from_file(
389
274
  cls,
390
275
  filename: str,
391
- model: Optional[Model]=None
276
+ lang_graph: LanguageGraph,
277
+ model: Optional[Model] = None
392
278
  ) -> AttackGraph:
393
279
  """Create from json or yaml file depending on file extension"""
394
280
  if model is not None:
@@ -404,21 +290,8 @@ class AttackGraph():
404
290
  serialized_attack_graph = load_dict_from_json_file(filename)
405
291
  else:
406
292
  raise ValueError('Unknown file extension, expected json/yml/yaml')
407
- return cls._from_dict(serialized_attack_graph, model=model)
408
-
409
- def get_node_by_id(self, node_id: int) -> Optional[AttackGraphNode]:
410
- """
411
- Return the attack node that matches the id provided.
412
-
413
- Arguments:
414
- node_id - the id of the attack graph node we are looking for
415
-
416
- Return:
417
- The attack step node that matches the given id.
418
- """
419
-
420
- logger.debug('Looking up node with id %s', node_id)
421
- return self._id_to_node.get(node_id)
293
+ return cls._from_dict(serialized_attack_graph,
294
+ lang_graph, model = model)
422
295
 
423
296
  def get_node_by_full_name(self, full_name: str) -> Optional[AttackGraphNode]:
424
297
  """
@@ -432,23 +305,9 @@ class AttackGraph():
432
305
  The attack step node that matches the given full name.
433
306
  """
434
307
 
435
- logger.debug(f'Looking up node with full name "{full_name}"')
308
+ logger.debug(f'Looking up node with full name "%s"', full_name)
436
309
  return self._full_name_to_node.get(full_name)
437
310
 
438
- def get_attacker_by_id(self, attacker_id: int) -> Optional[Attacker]:
439
- """
440
- Return the attacker that matches the id provided.
441
-
442
- Arguments:
443
- attacker_id - the id of the attacker we are looking for
444
-
445
- Return:
446
- The attacker that matches the given id.
447
- """
448
-
449
- logger.debug(f'Looking up attacker with id {attacker_id}')
450
- return self._id_to_attacker.get(attacker_id)
451
-
452
311
  def attach_attackers(self) -> None:
453
312
  """
454
313
  Create attackers and their entry point nodes and attach them to the
@@ -473,8 +332,8 @@ class AttackGraph():
473
332
 
474
333
  attacker = Attacker(
475
334
  name = attacker_info.name,
476
- entry_points = [],
477
- reached_attack_steps = []
335
+ entry_points = set(),
336
+ reached_attack_steps = set()
478
337
  )
479
338
  self.add_attacker(attacker)
480
339
 
@@ -491,7 +350,170 @@ class AttackGraph():
491
350
  continue
492
351
  attacker.compromise(ag_node)
493
352
 
494
- attacker.entry_points = list(attacker.reached_attack_steps)
353
+ attacker.entry_points = set(attacker.reached_attack_steps)
354
+
355
+ def _follow_expr_chain(
356
+ self,
357
+ model: Model,
358
+ target_assets: set[ModelAsset],
359
+ expr_chain: Optional[ExpressionsChain]
360
+ ) -> set[Any]:
361
+ """
362
+ Recursively follow a language graph expressions chain on an instance
363
+ model.
364
+
365
+ Arguments:
366
+ model - a maltoolbox.model.Model on which to follow the
367
+ expressions chain
368
+ target_assets - the set of assets that this expressions chain
369
+ should apply to. Initially it will contain the
370
+ asset to which the attack step belongs
371
+ expr_chain - the expressions chain we are following
372
+
373
+ Return:
374
+ A set of all of the target assets.
375
+ """
376
+
377
+ if expr_chain is None:
378
+ # There is no expressions chain link left to follow; return the
379
+ # current target assets
380
+ return set(target_assets)
381
+
382
+ if logger.isEnabledFor(logging.DEBUG):
383
+ # Avoid running json.dumps when not in debug
384
+ logger.debug(
385
+ 'Following Expressions Chain:\n%s',
386
+ json.dumps(expr_chain.to_dict(), indent = 2)
387
+ )
388
+
389
+ match (expr_chain.type):
390
+ case 'union' | 'intersection' | 'difference':
391
+ # The set operators are used to combine the left hand and
392
+ # right hand targets accordingly.
393
+ if not expr_chain.left_link:
394
+ raise LanguageGraphException('"%s" step expression chain'
395
+ ' is missing the left link.' % expr_chain.type)
396
+ if not expr_chain.right_link:
397
+ raise LanguageGraphException('"%s" step expression chain'
398
+ ' is missing the right link.' % expr_chain.type)
399
+ lh_targets = self._follow_expr_chain(
400
+ model,
401
+ target_assets,
402
+ expr_chain.left_link
403
+ )
404
+ rh_targets = self._follow_expr_chain(
405
+ model,
406
+ target_assets,
407
+ expr_chain.right_link
408
+ )
409
+
410
+ match (expr_chain.type):
411
+ # The assets are hashable, so the built-in set operations are
412
+ # used directly.
413
+ case 'union':
414
+ new_target_assets = lh_targets.union(rh_targets)
415
+
416
+ case 'intersection':
417
+ new_target_assets = lh_targets.intersection(rh_targets)
418
+
419
+ case 'difference':
420
+ new_target_assets = lh_targets.difference(rh_targets)
421
+
422
+ return new_target_assets
423
+
424
+ case 'field':
425
+ # Change the target assets from the current ones to the
426
+ # associated assets given the specified field name.
427
+ if not expr_chain.fieldname:
428
+ raise LanguageGraphException('"field" step expression '
429
+ 'chain is missing fieldname.')
430
+ new_target_assets = set()
431
+ new_target_assets.update(
432
+ *(
433
+ asset.associated_assets.get(
434
+ expr_chain.fieldname, set()
435
+ ) for asset in target_assets
436
+ )
437
+ )
438
+ return new_target_assets
439
+
440
+ case 'transitive':
441
+ if not expr_chain.sub_link:
442
+ raise LanguageGraphException('"transitive" step '
443
+ 'expression chain is missing sub link.')
444
+
445
+ new_assets = target_assets
446
+
447
+ while new_assets := self._follow_expr_chain(
448
+ model, new_assets, expr_chain.sub_link
449
+ ):
450
+ if not (new_assets := new_assets.difference(target_assets)):
451
+ break
452
+
453
+ target_assets.update(new_assets)
454
+
455
+ return target_assets
456
+
457
+ case 'subType':
458
+ if not expr_chain.sub_link:
459
+ raise LanguageGraphException('"subType" step '
460
+ 'expression chain is missing sub link.')
461
+ new_target_assets = set()
462
+ new_target_assets.update(
463
+ self._follow_expr_chain(
464
+ model, target_assets, expr_chain.sub_link
465
+ )
466
+ )
467
+
468
+ selected_new_target_assets = set()
469
+ for asset in new_target_assets:
470
+ lang_graph_asset = self.lang_graph.assets[asset.type]
471
+ if not lang_graph_asset:
472
+ raise LookupError(
473
+ f'Failed to find asset \"{asset.type}\" in the '
474
+ 'language graph.'
475
+ )
476
+ lang_graph_subtype_asset = expr_chain.subtype
477
+ if not lang_graph_subtype_asset:
478
+ raise LookupError(
479
+ 'Failed to find asset "%s" in the '
480
+ 'language graph.' % expr_chain.subtype
481
+ )
482
+ if lang_graph_asset.is_subasset_of(
483
+ lang_graph_subtype_asset):
484
+ selected_new_target_assets.add(asset)
485
+
486
+ return selected_new_target_assets
487
+
488
+ case 'collect':
489
+ if not expr_chain.left_link:
490
+ raise LanguageGraphException('"collect" step expression chain'
491
+ ' is missing the left link.')
492
+ if not expr_chain.right_link:
493
+ raise LanguageGraphException('"collect" step expression chain'
494
+ ' is missing the right link.')
495
+ lh_targets = self._follow_expr_chain(
496
+ model,
497
+ target_assets,
498
+ expr_chain.left_link
499
+ )
500
+ rh_targets = self._follow_expr_chain(
501
+ model,
502
+ lh_targets,
503
+ expr_chain.right_link
504
+ )
505
+ return rh_targets
506
+
507
+ case _:
508
+ msg = 'Unknown attack expressions chain type: %s'
509
+ logger.error(
510
+ msg,
511
+ expr_chain.type
512
+ )
513
+ raise AttackGraphStepExpressionError(
514
+ msg % expr_chain.type
515
+ )
516
+ return None
495
517
 
496
518
  def _generate_graph(self) -> None:
497
519
  """
@@ -505,7 +527,7 @@ class AttackGraph():
505
527
  raise AttackGraphException(msg)
506
528
 
507
529
  # First, generate all of the nodes of the attack graph.
508
- for asset in self.model.assets:
530
+ for asset in self.model.assets.values():
509
531
 
510
532
  logger.debug(
511
533
  'Generating attack steps for asset %s which is of class %s.',
@@ -514,104 +536,128 @@ class AttackGraph():
514
536
 
515
537
  attack_step_nodes = []
516
538
 
517
- # TODO probably part of what happens here is already done in lang_graph
518
- attack_steps = self.lang_graph._get_attacks_for_asset_type(asset.type)
519
-
520
- for attack_step_name, attack_step_attribs in attack_steps.items():
539
+ for attack_step in asset.lg_asset.attack_steps.values():
521
540
  logger.debug(
522
- 'Generating attack step node for %s.', attack_step_name
541
+ 'Generating attack step node for %s.', attack_step.name
523
542
  )
524
543
 
525
544
  defense_status = None
526
545
  existence_status = None
527
- node_name = asset.name + ':' + attack_step_name
546
+ node_name = asset.name + ':' + attack_step.name
528
547
 
529
- match (attack_step_attribs['type']):
548
+ match (attack_step.type):
530
549
  case 'defense':
531
550
  # Set the defense status for defenses
532
- defense_status = getattr(asset, attack_step_name)
551
+ defense_status = asset.defenses[attack_step.name]
533
552
  logger.debug(
534
- 'Setting the defense status of %s to %s.',
553
+ 'Setting the defense status of \"%s\" to "%s".',
535
554
  node_name, defense_status
536
555
  )
537
556
 
538
557
  case 'exist' | 'notExist':
539
- # Resolve step expression associated with (non-)existence
540
- # attack steps.
541
- (target_assets, attack_step) = _process_step_expression(
542
- self.lang_graph,
543
- self.model,
544
- [asset],
545
- attack_step_attribs['requires']['stepExpressions'][0])
546
- # If the step expression resolution yielded the target
547
- # assets then the required assets exist in the model.
548
- existence_status = target_assets != []
549
-
550
- mitre_info = attack_step_attribs['meta']['mitre'] if 'mitre' in\
551
- attack_step_attribs['meta'] else None
552
- ag_node = AttackGraphNode(
553
- type = attack_step_attribs['type'],
554
- asset = asset,
555
- name = attack_step_name,
556
- ttc = attack_step_attribs['ttc'],
557
- children = [],
558
- parents = [],
558
+ # Resolve step expression associated with
559
+ # (non-)existence attack steps.
560
+ existence_status = False
561
+ for requirement in attack_step.requires:
562
+ target_assets = self._follow_expr_chain(
563
+ self.model,
564
+ set([asset]),
565
+ requirement
566
+ )
567
+ # If the step expression resolution yielded
568
+ # the target assets then the required assets
569
+ # exist in the model.
570
+ if target_assets:
571
+ existence_status = True
572
+ break
573
+
574
+ logger.debug(
575
+ 'Setting the existence status of \"%s\" to '
576
+ '%s.',
577
+ node_name, existence_status
578
+ )
579
+
580
+ case _:
581
+ pass
582
+
583
+ ag_node = self.add_node(
584
+ lg_attack_step = attack_step,
585
+ model_asset = asset,
559
586
  defense_status = defense_status,
560
- existence_status = existence_status,
561
- is_viable = True,
562
- is_necessary = True,
563
- mitre_info = mitre_info,
564
- tags = attack_step_attribs['tags'],
565
- compromised_by = []
587
+ existence_status = existence_status
566
588
  )
567
- ag_node.attributes = attack_step_attribs
568
589
  attack_step_nodes.append(ag_node)
569
- self.add_node(ag_node)
590
+
570
591
  asset.attack_step_nodes = attack_step_nodes
571
592
 
572
593
  # Then, link all of the nodes according to their associations.
573
- for ag_node in self.nodes:
594
+ for ag_node in self.nodes.values():
574
595
  logger.debug(
575
596
  'Determining children for attack step "%s"(%d)',
576
597
  ag_node.full_name,
577
598
  ag_node.id
578
599
  )
579
- step_expressions = \
580
- ag_node.attributes['reaches']['stepExpressions'] if \
581
- isinstance(ag_node.attributes, dict) and ag_node.attributes['reaches'] else []
582
-
583
- for step_expression in step_expressions:
584
- # Resolve each of the attack step expressions listed for this
585
- # attack step to determine children.
586
- (target_assets, attack_step) = _process_step_expression(
587
- self.lang_graph,
588
- self.model,
589
- [ag_node.asset],
590
- step_expression)
591
-
592
- for target in target_assets:
593
- target_node_full_name = target.name + ':' + attack_step
594
- target_node = self.get_node_by_full_name(
595
- target_node_full_name
596
- )
597
- if not target_node:
598
- msg = ('Failed to find target node '
599
- '"%s" to link with for attack step "%s"(%d)!')
600
- logger.error(
601
- msg,
602
- target_node_full_name,
603
- ag_node.full_name,
604
- ag_node.id
605
- )
606
- raise AttackGraphStepExpressionError(
607
- msg % (
608
- target_node_full_name,
609
- ag_node.full_name,
610
- ag_node.id
611
- )
600
+
601
+ if not ag_node.model_asset:
602
+ raise AttackGraphException('Attack graph node is missing '
603
+ 'asset link')
604
+ lang_graph_asset = self.lang_graph.assets[
605
+ ag_node.model_asset.type]
606
+
607
+ lang_graph_attack_step = lang_graph_asset.attack_steps[
608
+ ag_node.name]
609
+
610
+ while lang_graph_attack_step:
611
+ for child in lang_graph_attack_step.children.values():
612
+ for target_attack_step, expr_chain in child:
613
+ target_assets = self._follow_expr_chain(
614
+ self.model,
615
+ set([ag_node.model_asset]),
616
+ expr_chain
612
617
  )
613
- ag_node.children.append(target_node)
614
- target_node.parents.append(ag_node)
618
+
619
+ for target_asset in target_assets:
620
+ if target_asset is not None:
621
+ target_node_full_name = target_asset.name + \
622
+ ':' + target_attack_step.name
623
+ target_node = self.get_node_by_full_name(
624
+ target_node_full_name)
625
+ if target_node is None:
626
+ msg = ('Failed to find target node '
627
+ '"%s" to link with for attack '
628
+ 'step "%s"(%d)!')
629
+ logger.error(
630
+ msg,
631
+ target_node_full_name,
632
+ ag_node.full_name,
633
+ ag_node.id
634
+ )
635
+ raise AttackGraphStepExpressionError(
636
+ msg % (
637
+ target_node_full_name,
638
+ ag_node.full_name,
639
+ ag_node.id
640
+ )
641
+ )
642
+
643
+ assert ag_node.id is not None
644
+ assert target_node.id is not None
645
+
646
+ logger.debug('Linking attack step "%s"(%d) '
647
+ 'to attack step "%s"(%d)' %
648
+ (
649
+ ag_node.full_name,
650
+ ag_node.id,
651
+ target_node.full_name,
652
+ target_node.id
653
+ )
654
+ )
655
+ ag_node.children.add(target_node)
656
+ target_node.parents.add(ag_node)
657
+ if lang_graph_attack_step.overrides:
658
+ break
659
+ lang_graph_attack_step = lang_graph_attack_step.inherits
660
+
615
661
 
616
662
  def regenerate_graph(self) -> None:
617
663
  """
@@ -619,37 +665,68 @@ class AttackGraph():
619
665
  the MAL language specification provided at initialization.
620
666
  """
621
667
 
622
- self.nodes = []
623
- self.attackers = []
668
+ self.nodes = {}
669
+ self.attackers = {}
624
670
  self._generate_graph()
625
671
 
626
672
  def add_node(
627
673
  self,
628
- node: AttackGraphNode,
629
- node_id: Optional[int] = None
630
- ) -> None:
631
- """Add a node to the graph
674
+ lg_attack_step: LanguageGraphAttackStep,
675
+ node_id: Optional[int] = None,
676
+ model_asset: Optional[ModelAsset] = None,
677
+ defense_status: Optional[float] = None,
678
+ existence_status: Optional[bool] = None
679
+ ) -> AttackGraphNode:
680
+ """Create and add a node to the graph
632
681
  Arguments:
633
- node - the node to add
634
- node_id - the id to assign to this node, usually used when loading
635
- an attack graph from a file
636
- """
637
- if logger.isEnabledFor(logging.DEBUG):
638
- # Avoid running json.dumps when not in debug
639
- logger.debug(f'Add node \"{node.full_name}\" '
640
- f'with id:{node_id}:\n' \
641
- + json.dumps(node.to_dict(), indent = 2))
682
+ lg_attack_step - the language graph attack step that corresponds
683
+ to the attack graph node to create
684
+ node_id - id to assign to the newly created node, usually
685
+ provided only when loading an existing attack
686
+ graph from a file. If not provided the id will
687
+ be set to the next highest id available.
688
+ model_asset - the model asset that corresponds to the attack
689
+ step node. While optional it is highly
690
+ recommended that this be provided. It should
691
+ only be omitted if the model which was used to
692
+ generate the attack graph is not available when
693
+ loading an attack graph from a file.
694
+ defense_status - the defense status of the node. Only relevant
695
+ for defense type nodes. A value between 0.0 and
696
+ 1.0 is expected.
697
+ existence_status - the existence status of the node. Only relevant
698
+ for exist and notExist type nodes.
642
699
 
643
- if node.id in self._id_to_node:
700
+ Return:
701
+ The newly created attack step node.
702
+ """
703
+ node_id = node_id if node_id is not None else self.next_node_id
704
+ if node_id in self.nodes:
644
705
  raise ValueError(f'Node index {node_id} already in use.')
706
+ self.next_node_id = max(node_id + 1, self.next_node_id)
645
707
 
646
- node.id = node_id if node_id is not None else self.next_node_id
647
- self.next_node_id = max(node.id + 1, self.next_node_id)
708
+ if logger.isEnabledFor(logging.DEBUG):
709
+ # Avoid running json.dumps when not in debug
710
+ logger.debug('Create and add to attackgraph node of type "%s" '
711
+ 'with id:%d.\n' % (
712
+ lg_attack_step.full_name,
713
+ node_id
714
+ ))
715
+
716
+
717
+ node = AttackGraphNode(
718
+ node_id = node_id,
719
+ lg_attack_step = lg_attack_step,
720
+ model_asset = model_asset,
721
+ defense_status = defense_status,
722
+ existence_status = existence_status
723
+ )
648
724
 
649
- self.nodes.append(node)
650
- self._id_to_node[node.id] = node
725
+ self.nodes[node_id] = node
651
726
  self._full_name_to_node[node.full_name] = node
652
727
 
728
+ return node
729
+
653
730
  def remove_node(self, node: AttackGraphNode) -> None:
654
731
  """Remove node from attack graph
655
732
  Arguments:
@@ -662,11 +739,10 @@ class AttackGraph():
662
739
  child.parents.remove(node)
663
740
  for parent in node.parents:
664
741
  parent.children.remove(node)
665
- self.nodes.remove(node)
666
742
 
667
743
  if not isinstance(node.id, int):
668
744
  raise ValueError(f'Invalid node id.')
669
- del self._id_to_node[node.id]
745
+ del self.nodes[node.id]
670
746
  del self._full_name_to_node[node.full_name]
671
747
 
672
748
  def add_attacker(
@@ -687,24 +763,26 @@ class AttackGraph():
687
763
  reached_attack_steps - list of ids of the attack steps that the
688
764
  attacker has reached
689
765
  """
766
+
690
767
  if logger.isEnabledFor(logging.DEBUG):
691
768
  # Avoid running json.dumps when not in debug
692
769
  if attacker_id is not None:
693
770
  logger.debug('Add attacker "%s" with id:%d.',
694
771
  attacker.name,
695
- attacker_id)
772
+ attacker_id
773
+ )
696
774
  else:
697
775
  logger.debug('Add attacker "%s" without id.',
698
- attacker.name)
699
-
776
+ attacker.name
777
+ )
700
778
 
701
779
  attacker.id = attacker_id or self.next_attacker_id
702
- if attacker.id in self._id_to_attacker:
780
+ if attacker.id in self.attackers:
703
781
  raise ValueError(f'Attacker index {attacker_id} already in use.')
704
782
 
705
783
  self.next_attacker_id = max(attacker.id + 1, self.next_attacker_id)
706
784
  for node_id in reached_attack_steps:
707
- node = self.get_node_by_id(node_id)
785
+ node = self.nodes[node_id]
708
786
  if node:
709
787
  attacker.compromise(node)
710
788
  else:
@@ -713,16 +791,15 @@ class AttackGraph():
713
791
  logger.error(msg, node_id)
714
792
  raise AttackGraphException(msg % node_id)
715
793
  for node_id in entry_points:
716
- node = self.get_node_by_id(int(node_id))
794
+ node = self.nodes[node_id]
717
795
  if node:
718
- attacker.entry_points.append(node)
796
+ attacker.entry_points.add(node)
719
797
  else:
720
798
  msg = ("Could not find node with id %d"
721
799
  "in attacker entrypoints.")
722
800
  logger.error(msg, node_id)
723
801
  raise AttackGraphException(msg % node_id)
724
- self.attackers.append(attacker)
725
- self._id_to_attacker[attacker.id] = attacker
802
+ self.attackers[attacker.id] = attacker
726
803
 
727
804
  def remove_attacker(self, attacker: Attacker):
728
805
  """Remove attacker from attack graph
@@ -736,7 +813,6 @@ class AttackGraph():
736
813
  attacker.id)
737
814
  for node in attacker.reached_attack_steps:
738
815
  attacker.undo_compromise(node)
739
- self.attackers.remove(attacker)
740
816
  if not isinstance(attacker.id, int):
741
817
  raise ValueError(f'Invalid attacker id.')
742
- del self._id_to_attacker[attacker.id]
818
+ del self.attackers[attacker.id]