mal-toolbox 0.0.28__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. {mal_toolbox-0.0.28.dist-info → mal_toolbox-0.1.12.dist-info}/METADATA +60 -28
  2. mal_toolbox-0.1.12.dist-info/RECORD +32 -0
  3. {mal_toolbox-0.0.28.dist-info → mal_toolbox-0.1.12.dist-info}/WHEEL +1 -1
  4. maltoolbox/__init__.py +31 -31
  5. maltoolbox/__main__.py +80 -4
  6. maltoolbox/attackgraph/__init__.py +8 -0
  7. maltoolbox/attackgraph/analyzers/__init__.py +0 -0
  8. maltoolbox/attackgraph/analyzers/apriori.py +173 -27
  9. maltoolbox/attackgraph/attacker.py +84 -25
  10. maltoolbox/attackgraph/attackgraph.py +503 -215
  11. maltoolbox/attackgraph/node.py +92 -31
  12. maltoolbox/attackgraph/query.py +125 -19
  13. maltoolbox/default.conf +8 -7
  14. maltoolbox/exceptions.py +45 -0
  15. maltoolbox/file_utils.py +66 -0
  16. maltoolbox/ingestors/__init__.py +0 -0
  17. maltoolbox/ingestors/neo4j.py +95 -84
  18. maltoolbox/language/__init__.py +4 -0
  19. maltoolbox/language/classes_factory.py +145 -64
  20. maltoolbox/language/{lexer_parser/__main__.py → compiler/__init__.py} +5 -12
  21. maltoolbox/language/{lexer_parser → compiler}/mal_lexer.py +1 -1
  22. maltoolbox/language/{lexer_parser → compiler}/mal_parser.py +1 -1
  23. maltoolbox/language/{lexer_parser → compiler}/mal_visitor.py +4 -5
  24. maltoolbox/language/languagegraph.py +569 -168
  25. maltoolbox/model.py +858 -0
  26. maltoolbox/translators/__init__.py +0 -0
  27. maltoolbox/translators/securicad.py +76 -52
  28. maltoolbox/translators/updater.py +132 -0
  29. maltoolbox/wrappers.py +62 -0
  30. mal_toolbox-0.0.28.dist-info/RECORD +0 -26
  31. maltoolbox/cl_parser.py +0 -89
  32. maltoolbox/language/specification.py +0 -265
  33. maltoolbox/main.py +0 -84
  34. maltoolbox/model/model.py +0 -282
  35. {mal_toolbox-0.0.28.dist-info → mal_toolbox-0.1.12.dist-info}/AUTHORS +0 -0
  36. {mal_toolbox-0.0.28.dist-info → mal_toolbox-0.1.12.dist-info}/LICENSE +0 -0
  37. {mal_toolbox-0.0.28.dist-info → mal_toolbox-0.1.12.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,44 @@
1
1
  """
2
2
  MAL-Toolbox Attack Graph Module
3
3
  """
4
-
4
+ from __future__ import annotations
5
+ import copy
5
6
  import logging
6
7
  import json
7
8
 
8
- from typing import List, Optional
9
+ from typing import TYPE_CHECKING
10
+
11
+ from .node import AttackGraphNode
12
+ from .attacker import Attacker
13
+ from ..exceptions import AttackGraphStepExpressionError
14
+ from ..model import Model
15
+ from ..exceptions import AttackGraphException
16
+ from ..file_utils import (
17
+ load_dict_from_json_file,
18
+ load_dict_from_yaml_file,
19
+ save_dict_to_file
20
+ )
9
21
 
10
- from maltoolbox.language import specification
11
- from maltoolbox.model import model
12
- from maltoolbox.attackgraph import node
13
- from maltoolbox.attackgraph import attacker
22
+ if TYPE_CHECKING:
23
+ from typing import Any, Optional
24
+ from ..language import LanguageGraph
14
25
 
15
26
  logger = logging.getLogger(__name__)
16
27
 
17
- def _process_step_expression(lang: dict, model: model.Model,
18
- target_assets: List, step_expression: dict):
28
+ # TODO see if (part of) this can be incorporated into the LanguageGraph, so that
29
+ # the LanguageGraph's _lang_spec private property does not need to be accessed
30
+ def _process_step_expression(
31
+ lang_graph: LanguageGraph,
32
+ model: Model,
33
+ target_assets: list[Any],
34
+ step_expression: dict[str, Any]
35
+ ) -> tuple[list, Optional[str]]:
19
36
  """
20
37
  Recursively process an attack step expression.
21
38
 
22
39
  Arguments:
23
- lang - a dictionary representing the MAL language specification
40
+ lang_graph - a language graph representing the MAL language
41
+ specification
24
42
  model - a maltoolbox.model.Model instance from which the attack
25
43
  graph was generated
26
44
  target_assets - the list of assets that this step expression should apply
@@ -32,8 +50,13 @@ def _process_step_expression(lang: dict, model: model.Model,
32
50
  A tuple pair containing a list of all of the target assets and the name of
33
51
  the attack step.
34
52
  """
35
- logger.debug('Processing Step Expression:\n' \
36
- + json.dumps(step_expression, indent = 2))
53
+
54
+ if logger.isEnabledFor(logging.DEBUG):
55
+ # Avoid running json.dumps when not in debug
56
+ logger.debug(
57
+ 'Processing Step Expression:\n%s',
58
+ json.dumps(step_expression, indent = 2)
59
+ )
37
60
 
38
61
  match (step_expression['type']):
39
62
  case 'attackStep':
@@ -45,9 +68,9 @@ def _process_step_expression(lang: dict, model: model.Model,
45
68
  # The set operators are used to combine the left hand and right
46
69
  # hand targets accordingly.
47
70
  lh_targets, lh_attack_steps = _process_step_expression(
48
- lang, model, target_assets, step_expression['lhs'])
71
+ lang_graph, model, target_assets, step_expression['lhs'])
49
72
  rh_targets, rh_attack_steps = _process_step_expression(
50
- lang, model, target_assets, step_expression['rhs'])
73
+ lang_graph, model, target_assets, step_expression['rhs'])
51
74
 
52
75
  new_target_assets = []
53
76
  match (step_expression['type']):
@@ -60,7 +83,7 @@ def _process_step_expression(lang: dict, model: model.Model,
60
83
 
61
84
  case 'intersection':
62
85
  for ag_node in rh_targets:
63
- if next((lnode for lnode in new_target_assets \
86
+ if next((lnode for lnode in lh_targets \
64
87
  if lnode.id == ag_node.id), None):
65
88
  new_target_assets.append(ag_node)
66
89
 
@@ -77,17 +100,19 @@ def _process_step_expression(lang: dict, model: model.Model,
77
100
  # Fetch the step expression associated with the variable from
78
101
  # the language specification and resolve that.
79
102
  for target_asset in target_assets:
80
- if (hasattr(target_asset, 'metaconcept')):
81
- variable_step_expr = specification.\
82
- get_variable_for_class_by_name(lang,
83
- target_asset.metaconcept, step_expression['name'])
103
+ if (hasattr(target_asset, 'type')):
104
+ # TODO how can this info be accessed in the lang_graph
105
+ # directly without going through the private method?
106
+ variable_step_expr = lang_graph._get_variable_for_asset_type_by_name(
107
+ target_asset.type, step_expression['name'])
84
108
  return _process_step_expression(
85
- lang, model, target_assets, variable_step_expr)
109
+ lang_graph, model, target_assets, variable_step_expr)
86
110
 
87
111
  else:
88
- logger.error('Requested variable from non-asset'
89
- f'target node: {target_asset} which cannot be'
90
- 'resolved.')
112
+ logger.error(
113
+ 'Requested variable from non-asset target node:'
114
+ '%s which cannot be resolved.', target_asset
115
+ )
91
116
  return ([], None)
92
117
 
93
118
  case 'field':
@@ -112,7 +137,7 @@ def _process_step_expression(lang: dict, model: model.Model,
112
137
  step_expression['stepExpression']['name']))
113
138
  if new_target_assets:
114
139
  (additional_assets, _) = _process_step_expression(
115
- lang, model, new_target_assets, step_expression)
140
+ lang_graph, model, new_target_assets, step_expression)
116
141
  new_target_assets.extend(additional_assets)
117
142
  return (new_target_assets, None)
118
143
  else:
@@ -122,93 +147,181 @@ def _process_step_expression(lang: dict, model: model.Model,
122
147
  new_target_assets = []
123
148
  for target_asset in target_assets:
124
149
  (assets, _) = _process_step_expression(
125
- lang, model, target_assets, step_expression['stepExpression'])
150
+ lang_graph, model, target_assets,
151
+ step_expression['stepExpression'])
126
152
  new_target_assets.extend(assets)
127
153
 
128
- selected_new_target_assets = (asset for asset in \
129
- new_target_assets if specification.extends_asset(
130
- lang,
131
- asset.metaconcept,
132
- step_expression['subType']))
154
+ selected_new_target_assets = []
155
+ for asset in new_target_assets:
156
+ lang_graph_asset = lang_graph.get_asset_by_name(
157
+ asset.type
158
+ )
159
+ if not lang_graph_asset:
160
+ raise LookupError(
161
+ f'Failed to find asset \"{asset.type}\" in the '
162
+ 'language graph.'
163
+ )
164
+ lang_graph_subtype_asset = lang_graph.get_asset_by_name(
165
+ step_expression['subType']
166
+ )
167
+ if not lang_graph_subtype_asset:
168
+ raise LookupError(
169
+ 'Failed to find asset '
170
+ f'\"{step_expression["subType"]}\" in the '
171
+ 'language graph.'
172
+ )
173
+ if lang_graph_asset.is_subasset_of(lang_graph_subtype_asset):
174
+ selected_new_target_assets.append(asset)
175
+
133
176
  return (selected_new_target_assets, None)
134
177
 
135
178
  case 'collect':
136
179
  # Apply the right hand step expression to left hand step
137
180
  # expression target assets.
138
181
  lh_targets, _ = _process_step_expression(
139
- lang, model, target_assets, step_expression['lhs'])
140
- return _process_step_expression(lang, model, lh_targets,
182
+ lang_graph, model, target_assets, step_expression['lhs'])
183
+ return _process_step_expression(lang_graph, model, lh_targets,
141
184
  step_expression['rhs'])
142
185
 
143
186
 
144
187
  case _:
145
- logger.error('Unknown attack step type: '
146
- f'{step_expression["type"]}')
188
+ logger.error(
189
+ 'Unknown attack step type: %s', step_expression["type"]
190
+ )
147
191
  return ([], None)
148
192
 
193
+ class AttackGraph():
194
+ """Graph representation of attack steps"""
195
+ def __init__(self, lang_graph = None, model: Optional[Model] = None):
196
+ self.nodes: list[AttackGraphNode] = []
197
+ self.attackers: list[Attacker] = []
198
+ # Dictionaries used in optimization to get nodes and attackers by id
199
+ # or full name faster
200
+ self._id_to_node: dict[int, AttackGraphNode] = {}
201
+ self._full_name_to_node: dict[str, AttackGraphNode] = {}
202
+ self._id_to_attacker: dict[int, Attacker] = {}
149
203
 
150
-
151
- class AttackGraph:
152
- def __init__(self, lang_spec = None, model: Optional[model.Model] = None):
153
- self.nodes = []
154
- self.attackers = []
155
204
  self.model = model
156
- self.lang_spec = lang_spec
157
- if self.model is not None and self.lang_spec is not None:
158
- self.generate_graph(self.lang_spec, self.model)
205
+ self.lang_graph = lang_graph
206
+ self.next_node_id = 0
207
+ self.next_attacker_id = 0
208
+ if self.model is not None and self.lang_graph is not None:
209
+ self._generate_graph()
159
210
 
160
211
  def __repr__(self) -> str:
161
212
  return f'AttackGraph({len(self.nodes)} nodes)'
162
213
 
163
- def save_to_file(self, filename: str):
164
- """
165
- Save the attack graph to a json file.
166
-
167
- Arguments:
168
- filename - the name of the output file
169
- """
170
-
171
- logger.info(f'Saving attack graph with {len(self.nodes)} attack step '
172
- f'nodes to {filename} file.')
173
- serialized_graph = []
214
+ def _to_dict(self) -> dict:
215
+ """Convert AttackGraph to dict"""
216
+ serialized_attack_steps = {}
217
+ serialized_attackers = {}
174
218
  for ag_node in self.nodes:
175
- serialized_graph.append(ag_node.to_dict())
176
- with open(filename, 'w', encoding='utf-8') as file:
177
- json.dump(serialized_graph, file, indent=4)
178
-
179
-
180
- def load_from_file(self, filename: str, model: Optional[model.Model] = None):
219
+ serialized_attack_steps[ag_node.full_name] =\
220
+ ag_node.to_dict()
221
+ for attacker in self.attackers:
222
+ serialized_attackers[attacker.name] = attacker.to_dict()
223
+ return {
224
+ 'attack_steps': serialized_attack_steps,
225
+ 'attackers': serialized_attackers,
226
+ }
227
+
228
+ def __deepcopy__(self, memo):
229
+
230
+ # Check if the object is already in the memo dictionary
231
+ if id(self) in memo:
232
+ return memo[id(self)]
233
+
234
+ copied_attackgraph = AttackGraph(self.lang_graph)
235
+ copied_attackgraph.model = self.model
236
+
237
+ copied_attackgraph.nodes = []
238
+
239
+ # Deep copy nodes
240
+ for node in self.nodes:
241
+ copied_node = copy.deepcopy(node, memo)
242
+ copied_attackgraph.nodes.append(copied_node)
243
+
244
+ # Re-link node references
245
+ for node in self.nodes:
246
+ if node.parents:
247
+ memo[id(node)].parents = copy.deepcopy(node.parents, memo)
248
+ if node.children:
249
+ memo[id(node)].children = copy.deepcopy(node.children, memo)
250
+
251
+ # Deep copy attackers and references to them
252
+ copied_attackgraph.attackers = copy.deepcopy(self.attackers, memo)
253
+
254
+ # Re-link attacker references
255
+ for node in self.nodes:
256
+ if node.compromised_by:
257
+ memo[id(node)].compromised_by = copy.deepcopy(
258
+ node.compromised_by, memo)
259
+
260
+ # Copy lookup dicts
261
+ copied_attackgraph._id_to_attacker = \
262
+ copy.deepcopy(self._id_to_attacker, memo)
263
+ copied_attackgraph._id_to_node = \
264
+ copy.deepcopy(self._id_to_node, memo)
265
+ copied_attackgraph._full_name_to_node = \
266
+ copy.deepcopy(self._full_name_to_node, memo)
267
+
268
+ # Copy counters
269
+ copied_attackgraph.next_node_id = self.next_node_id
270
+ copied_attackgraph.next_attacker_id = self.next_attacker_id
271
+
272
+ return copied_attackgraph
273
+
274
+ def save_to_file(self, filename: str) -> None:
275
+ """Save to json/yml depending on extension"""
276
+ logger.debug('Save attack graph to file "%s".', filename)
277
+ return save_dict_to_file(filename, self._to_dict())
278
+
279
+ @classmethod
280
+ def _from_dict(
281
+ cls,
282
+ serialized_object: dict,
283
+ model: Optional[Model]=None
284
+ ) -> AttackGraph:
285
+ """Create AttackGraph from dict
286
+ Args:
287
+ serialized_object - AttackGraph in dict format
288
+ model - Optional Model to add connections to
181
289
  """
182
- Load the attack graph model from a json file.
183
290
 
184
- Arguments:
185
- filename - the name of the input file to parse
186
- model - (optional) the instance model that the attack graph was
187
- generated from. If this given then the attack graph node
188
- and instance model asset link can be re-established. If
189
- this argument is not given the attack graph will still
190
- be created it will just omit the links to the assets.
191
- """
192
-
193
- logger.info(f'Loading attack graph from {filename} file.')
194
- if model:
195
- logger.info(f'Model(\'{model.name}\') was provided will attempt '
196
- 'to establish links to assets.')
197
- else:
198
- logger.info('No model was provided therefore asset links will '
199
- 'not be established.')
291
+ attack_graph = AttackGraph()
292
+ attack_graph.model = model
293
+ serialized_attack_steps = serialized_object['attack_steps']
294
+ serialized_attackers = serialized_object['attackers']
200
295
 
201
- with open(filename, 'r', encoding='utf-8') as file:
202
- serialized_graph = json.load(file)
203
296
  # Create all of the nodes in the imported attack graph.
204
- for node_dict in serialized_graph:
205
- ag_node = node.AttackGraphNode(
206
- id=node_dict['id'],
297
+ for node_full_name, node_dict in serialized_attack_steps.items():
298
+
299
+ # Recreate asset links if model is available.
300
+ node_asset = None
301
+ if model and 'asset' in node_dict:
302
+ node_asset = model.get_asset_by_name(node_dict['asset'])
303
+ if node_asset is None:
304
+ msg = ('Failed to find asset with id %s'
305
+ 'when loading from attack graph dict')
306
+ logger.error(msg, node_dict["asset"])
307
+ raise LookupError(msg % node_dict["asset"])
308
+
309
+ ag_node = AttackGraphNode(
207
310
  type=node_dict['type'],
208
311
  name=node_dict['name'],
209
- ttc=node_dict['ttc']
312
+ ttc=node_dict['ttc'],
313
+ asset=node_asset
210
314
  )
211
315
 
316
+ if node_asset:
317
+ # Add AttackGraphNode to attack_step_nodes of asset
318
+ if hasattr(node_asset, 'attack_step_nodes'):
319
+ node_attack_steps = list(node_asset.attack_step_nodes)
320
+ node_attack_steps.append(ag_node)
321
+ node_asset.attack_step_nodes = node_attack_steps
322
+ else:
323
+ node_asset.attack_step_nodes = [ag_node]
324
+
212
325
  ag_node.defense_status = float(node_dict['defense_status']) if \
213
326
  'defense_status' in node_dict else None
214
327
  ag_node.existence_status = node_dict['existence_status'] \
@@ -220,180 +333,214 @@ class AttackGraph:
220
333
  ag_node.mitre_info = str(node_dict['mitre_info']) if \
221
334
  'mitre_info' in node_dict else None
222
335
  ag_node.tags = node_dict['tags'] if \
223
- 'tags' in node_dict else None
224
- if ag_node.name == 'firstSteps':
225
- # This is an attacker entry point node, recreate the attacker.
226
- attacker_id = ag_node.id.split(':')[1]
227
- ag_attacker = attacker.Attacker(
228
- id = str(attacker_id),
229
- entry_points = [],
230
- reached_attack_steps = [],
231
- node = ag_node
232
- )
233
- self.attackers.append(ag_attacker)
234
- ag_node.attacker = ag_attacker
336
+ 'tags' in node_dict else []
337
+ ag_node.extras = node_dict.get('extras', {})
235
338
 
236
- self.nodes.append(ag_node)
339
+ # Add AttackGraphNode to AttackGraph
340
+ attack_graph.add_node(ag_node, node_id=node_dict['id'])
237
341
 
238
342
  # Re-establish links between nodes.
239
- for node_dict in serialized_graph:
240
- _ag_node: Optional[node.AttackGraphNode] = self.get_node_by_id(node_dict['id'])
241
- if not isinstance(_ag_node, node.AttackGraphNode):
242
- logger.error(f'Failed to find node with id {node_dict["id"]}'
243
- f' when loading from attack graph from file {filename}')
343
+ for node_full_name, node_dict in serialized_attack_steps.items():
344
+ _ag_node = attack_graph.get_node_by_id(node_dict['id'])
345
+ if not isinstance(_ag_node, AttackGraphNode):
346
+ msg = ('Failed to find node with id %s when loading'
347
+ ' attack graph from dict')
348
+ logger.error(msg, node_dict["id"])
349
+ raise LookupError(msg % node_dict["id"])
244
350
  else:
245
351
  for child_id in node_dict['children']:
246
- child = self.get_node_by_id(child_id)
352
+ child = attack_graph.get_node_by_id(int(child_id))
247
353
  if child is None:
248
- logger.error(f'Failed to find child node with id {child_id}'
249
- f' when loading from attack graph from file {filename}')
250
- return None
354
+ msg = ('Failed to find child node with id %s'
355
+ ' when loading from attack graph from dict')
356
+ logger.error(msg, child_id)
357
+ raise LookupError(msg % child_id)
251
358
  _ag_node.children.append(child)
252
359
 
253
- if isinstance(_ag_node.attacker, attacker.Attacker):
254
- # Relink the attacker related connections since the node
255
- # is an attacker entry point node.
256
- ag_attacker = _ag_node.attacker
257
- ag_attacker.entry_points.append(child)
258
- ag_attacker.compromise(child)
259
-
260
360
  for parent_id in node_dict['parents']:
261
- parent = self.get_node_by_id(parent_id)
361
+ parent = attack_graph.get_node_by_id(int(parent_id))
262
362
  if parent is None:
263
- logger.error('Failed to find parent node with id '
264
- f'{parent_id} when loading from attack graph from '
265
- f'file {filename}')
266
- return None
363
+ msg = ('Failed to find parent node with id %s '
364
+ 'when loading from attack graph from dict')
365
+ logger.error(msg, parent_id)
366
+ raise LookupError(msg % parent_id)
267
367
  _ag_node.parents.append(parent)
268
368
 
269
- # Also recreate asset links if model is available.
270
- if model and 'asset' in node_dict:
271
- asset = model.get_asset_by_id(
272
- int(node_dict['asset'].split(':')[1]))
273
- if asset is None:
274
- logger.error('Failed to find asset with id '
275
- f'{node_dict["asset"]} when loading from attack graph '
276
- f'from file {filename}')
277
- return None
278
- _ag_node.asset = asset
279
- if hasattr(asset, 'attack_step_nodes'):
280
- attack_step_nodes = list(asset.attack_step_nodes)
281
- attack_step_nodes.append(_ag_node)
282
- asset.attack_step_nodes = attack_step_nodes
283
- else:
284
- asset.attack_step_nodes = [_ag_node]
285
-
286
-
287
- def get_node_by_id(self, node_id: str) -> Optional[node.AttackGraphNode]:
369
+ for attacker_name, attacker in serialized_attackers.items():
370
+ ag_attacker = Attacker(
371
+ name = attacker['name'],
372
+ entry_points = [],
373
+ reached_attack_steps = []
374
+ )
375
+ attack_graph.add_attacker(
376
+ attacker = ag_attacker,
377
+ attacker_id = int(attacker['id']),
378
+ entry_points = attacker['entry_points'].keys(),
379
+ reached_attack_steps = [
380
+ int(node_id) # Convert to int since they can be strings
381
+ for node_id in attacker['reached_attack_steps'].keys()
382
+ ]
383
+ )
384
+
385
+ return attack_graph
386
+
387
+ @classmethod
388
+ def load_from_file(
389
+ cls,
390
+ filename: str,
391
+ model: Optional[Model]=None
392
+ ) -> AttackGraph:
393
+ """Create from json or yaml file depending on file extension"""
394
+ if model is not None:
395
+ logger.debug('Load attack graph from file "%s" with '
396
+ 'model "%s".', filename, model.name)
397
+ else:
398
+ logger.debug('Load attack graph from file "%s" '
399
+ 'without model.', filename)
400
+ serialized_attack_graph = None
401
+ if filename.endswith(('.yml', '.yaml')):
402
+ serialized_attack_graph = load_dict_from_yaml_file(filename)
403
+ elif filename.endswith('.json'):
404
+ serialized_attack_graph = load_dict_from_json_file(filename)
405
+ else:
406
+ raise ValueError('Unknown file extension, expected json/yml/yaml')
407
+ return cls._from_dict(serialized_attack_graph, model=model)
408
+
409
+ def get_node_by_id(self, node_id: int) -> Optional[AttackGraphNode]:
288
410
  """
289
411
  Return the attack node that matches the id provided.
290
412
 
291
413
  Arguments:
292
- node_id - the id of the attack graph none we are looking for
414
+ node_id - the id of the attack graph node we are looking for
293
415
 
294
416
  Return:
295
417
  The attack step node that matches the given id.
296
418
  """
297
419
 
298
- logger.debug(f'Looking up node with id {node_id}')
299
- return next((ag_node for ag_node in self.nodes \
300
- if ag_node.id == node_id), None)
420
+ logger.debug('Looking up node with id %s', node_id)
421
+ return self._id_to_node.get(node_id)
301
422
 
423
+ def get_node_by_full_name(self, full_name: str) -> Optional[AttackGraphNode]:
424
+ """
425
+ Return the attack node that matches the full name provided.
302
426
 
303
- def attach_attackers(self, model: model.Model):
427
+ Arguments:
428
+ full_name - the full name of the attack graph node we are looking
429
+ for
430
+
431
+ Return:
432
+ The attack step node that matches the given full name.
304
433
  """
305
- Create attackers and their entry point nodes and attach them to the
306
- relevant attack step nodes and to the attackers.
434
+
435
+ logger.debug(f'Looking up node with full name "{full_name}"')
436
+ return self._full_name_to_node.get(full_name)
437
+
438
+ def get_attacker_by_id(self, attacker_id: int) -> Optional[Attacker]:
439
+ """
440
+ Return the attacker that matches the id provided.
307
441
 
308
442
  Arguments:
309
- model - the instance model containing the attackers
443
+ attacker_id - the id of the attacker we are looking for
444
+
445
+ Return:
446
+ The attacker that matches the given id.
310
447
  """
311
448
 
312
- logger.info(f'Attach attackers from \'{model.name}\' model to the '
313
- 'graph.')
314
- for attacker_info in model.attackers:
315
- attacker_node = node.AttackGraphNode(
316
- id = 'Attacker:' + str(attacker_info.id) + ':firstSteps',
317
- type = 'or',
318
- asset = None,
319
- name = 'firstSteps',
320
- ttc = {},
321
- children = [],
322
- parents = [],
323
- compromised_by = []
324
- )
449
+ logger.debug(f'Looking up attacker with id {attacker_id}')
450
+ return self._id_to_attacker.get(attacker_id)
451
+
452
+ def attach_attackers(self) -> None:
453
+ """
454
+ Create attackers and their entry point nodes and attach them to the
455
+ relevant attack step nodes and to the attackers.
456
+ """
457
+
458
+ if not self.model:
459
+ msg = "Can not attach attackers without a model"
460
+ logger.error(msg)
461
+ raise AttackGraphException(msg)
462
+
463
+ logger.info(
464
+ 'Attach attackers from "%s" model to the graph.', self.model.name
465
+ )
466
+
467
+ for attacker_info in self.model.attackers:
468
+
469
+ if not attacker_info.name:
470
+ msg = "Can not attach attacker without name"
471
+ logger.error(msg)
472
+ raise AttackGraphException(msg)
325
473
 
326
- ag_attacker = attacker.Attacker(
327
- id = str(attacker_info.id),
474
+ attacker = Attacker(
475
+ name = attacker_info.name,
328
476
  entry_points = [],
329
- reached_attack_steps = [],
330
- node = attacker_node
477
+ reached_attack_steps = []
331
478
  )
332
- attacker_node.attacker = ag_attacker
333
- self.attackers.append(ag_attacker)
479
+ self.add_attacker(attacker)
334
480
 
335
481
  for (asset, attack_steps) in attacker_info.entry_points:
336
482
  for attack_step in attack_steps:
337
- attack_step_id = asset.metaconcept + ':' \
338
- + str(asset.id) + ':' + attack_step
339
- ag_node = self.get_node_by_id(attack_step_id)
483
+ full_name = asset.name + ':' + attack_step
484
+ ag_node = self.get_node_by_full_name(full_name)
340
485
  if not ag_node:
341
- logger.warning('Failed to find attacker entry point '
342
- + attack_step_id + ' for Attacker:'
343
- + ag_attacker.id + '.')
486
+ logger.warning(
487
+ 'Failed to find attacker entry point '
488
+ '%s for %s.',
489
+ full_name, attacker.name
490
+ )
344
491
  continue
345
- ag_attacker.compromise(ag_node)
492
+ attacker.compromise(ag_node)
346
493
 
347
- ag_attacker.entry_points = ag_attacker.reached_attack_steps
348
- attacker_node.children = ag_attacker.entry_points
349
- self.nodes.append(attacker_node)
494
+ attacker.entry_points = list(attacker.reached_attack_steps)
350
495
 
351
-
352
- def generate_graph(self, lang: Optional[dict] = None, model: Optional[model.Model] = None):
496
+ def _generate_graph(self) -> None:
353
497
  """
354
- Generate attack graph starting from a model instance
355
- and a MAL language specification
356
-
357
- Arguments:
358
- model - a maltoolbox.model.Model instance
359
- lang - a dictionary representing the MAL language specification
498
+ Generate the attack graph based on the original model instance and the
499
+ MAL language specification provided at initialization.
360
500
  """
361
501
 
362
- if model is not None:
363
- self.model = model
364
- if lang is not None:
365
- self.lang_spec = lang
366
- if self.model is None or self.lang_spec is None:
367
- return
502
+ if not self.model:
503
+ msg = "Can not generate AttackGraph without model"
504
+ logger.error(msg)
505
+ raise AttackGraphException(msg)
368
506
 
369
507
  # First, generate all of the nodes of the attack graph.
370
508
  for asset in self.model.assets:
371
- logger.debug(f'Generating attack steps for asset {asset.name} which '\
372
- f'is of class {asset.metaconcept}.')
509
+
510
+ logger.debug(
511
+ 'Generating attack steps for asset %s which is of class %s.',
512
+ asset.name, asset.type
513
+ )
514
+
373
515
  attack_step_nodes = []
374
- attack_steps = specification.get_attacks_for_class(lang,
375
- asset.metaconcept)
516
+
517
+ # TODO probably part of what happens here is already done in lang_graph
518
+ attack_steps = self.lang_graph._get_attacks_for_asset_type(asset.type)
519
+
376
520
  for attack_step_name, attack_step_attribs in attack_steps.items():
377
- logger.debug('Generating attack step node for '\
378
- f'{attack_step_name}.')
521
+ logger.debug(
522
+ 'Generating attack step node for %s.', attack_step_name
523
+ )
379
524
 
380
525
  defense_status = None
381
- existence_status: Optional[bool] = None
382
- node_id = asset.metaconcept + ':' + str(asset.id) + ':' + attack_step_name
526
+ existence_status = None
527
+ node_name = asset.name + ':' + attack_step_name
383
528
 
384
529
  match (attack_step_attribs['type']):
385
530
  case 'defense':
386
531
  # Set the defense status for defenses
387
532
  defense_status = getattr(asset, attack_step_name)
388
- logger.debug('Setting the defense status of '\
389
- f'{node_id} to {defense_status}.')
533
+ logger.debug(
534
+ 'Setting the defense status of %s to %s.',
535
+ node_name, defense_status
536
+ )
390
537
 
391
538
  case 'exist' | 'notExist':
392
539
  # Resolve step expression associated with (non-)existence
393
540
  # attack steps.
394
541
  (target_assets, attack_step) = _process_step_expression(
395
- lang,
396
- model,
542
+ self.lang_graph,
543
+ self.model,
397
544
  [asset],
398
545
  attack_step_attribs['requires']['stepExpressions'][0])
399
546
  # If the step expression resolution yielded the target
@@ -402,8 +549,7 @@ class AttackGraph:
402
549
 
403
550
  mitre_info = attack_step_attribs['meta']['mitre'] if 'mitre' in\
404
551
  attack_step_attribs['meta'] else None
405
- ag_node = node.AttackGraphNode(
406
- id = node_id,
552
+ ag_node = AttackGraphNode(
407
553
  type = attack_step_attribs['type'],
408
554
  asset = asset,
409
555
  name = attack_step_name,
@@ -420,13 +566,16 @@ class AttackGraph:
420
566
  )
421
567
  ag_node.attributes = attack_step_attribs
422
568
  attack_step_nodes.append(ag_node)
423
- self.nodes.append(ag_node)
569
+ self.add_node(ag_node)
424
570
  asset.attack_step_nodes = attack_step_nodes
425
571
 
426
572
  # Then, link all of the nodes according to their associations.
427
573
  for ag_node in self.nodes:
428
- logger.debug('Determining children for attack step '\
429
- f'{ag_node.id}.')
574
+ logger.debug(
575
+ 'Determining children for attack step "%s"(%d)',
576
+ ag_node.full_name,
577
+ ag_node.id
578
+ )
430
579
  step_expressions = \
431
580
  ag_node.attributes['reaches']['stepExpressions'] if \
432
581
  isinstance(ag_node.attributes, dict) and ag_node.attributes['reaches'] else []
@@ -434,21 +583,160 @@ class AttackGraph:
434
583
  for step_expression in step_expressions:
435
584
  # Resolve each of the attack step expressions listed for this
436
585
  # attack step to determine children.
437
- (target_assets, attack_step) = _process_step_expression(lang,
438
- model, [ag_node.asset], step_expression)
586
+ (target_assets, attack_step) = _process_step_expression(
587
+ self.lang_graph,
588
+ self.model,
589
+ [ag_node.asset],
590
+ step_expression)
591
+
439
592
  for target in target_assets:
440
- target_node_id = target.metaconcept + ':' \
441
- + str(target.id) + ':' + attack_step
442
- target_node = self.get_node_by_id(target_node_id)
593
+ target_node_full_name = target.name + ':' + attack_step
594
+ target_node = self.get_node_by_full_name(
595
+ target_node_full_name
596
+ )
443
597
  if not target_node:
444
- logger.error('Failed to find targed node ' \
445
- f'{target_node_id} to link with for attack step ' \
446
- f'{ag_node.id}!')
447
- print('Failed to find targed node ' \
448
- f'{target_node_id} to link with for attack step ' \
449
- f'{ag_node.id}!')
450
- return 1
598
+ msg = ('Failed to find target node '
599
+ '"%s" to link with for attack step "%s"(%d)!')
600
+ logger.error(
601
+ msg,
602
+ target_node_full_name,
603
+ ag_node.full_name,
604
+ ag_node.id
605
+ )
606
+ raise AttackGraphStepExpressionError(
607
+ msg % (
608
+ target_node_full_name,
609
+ ag_node.full_name,
610
+ ag_node.id
611
+ )
612
+ )
451
613
  ag_node.children.append(target_node)
452
614
  target_node.parents.append(ag_node)
453
615
 
454
- return 0
616
+ def regenerate_graph(self) -> None:
617
+ """
618
+ Regenerate the attack graph based on the original model instance and
619
+ the MAL language specification provided at initialization.
620
+ """
621
+
622
+ self.nodes = []
623
+ self.attackers = []
624
+ self._generate_graph()
625
+
626
+ def add_node(
627
+ self,
628
+ node: AttackGraphNode,
629
+ node_id: Optional[int] = None
630
+ ) -> None:
631
+ """Add a node to the graph
632
+ Arguments:
633
+ node - the node to add
634
+ node_id - the id to assign to this node, usually used when loading
635
+ an attack graph from a file
636
+ """
637
+ if logger.isEnabledFor(logging.DEBUG):
638
+ # Avoid running json.dumps when not in debug
639
+ logger.debug(f'Add node \"{node.full_name}\" '
640
+ f'with id:{node_id}:\n' \
641
+ + json.dumps(node.to_dict(), indent = 2))
642
+
643
+ if node.id in self._id_to_node:
644
+ raise ValueError(f'Node index {node_id} already in use.')
645
+
646
+ node.id = node_id if node_id is not None else self.next_node_id
647
+ self.next_node_id = max(node.id + 1, self.next_node_id)
648
+
649
+ self.nodes.append(node)
650
+ self._id_to_node[node.id] = node
651
+ self._full_name_to_node[node.full_name] = node
652
+
653
+ def remove_node(self, node: AttackGraphNode) -> None:
654
+ """Remove node from attack graph
655
+ Arguments:
656
+ node - the node we wish to remove from the attack graph
657
+ """
658
+ if logger.isEnabledFor(logging.DEBUG):
659
+ # Avoid running json.dumps when not in debug
660
+ logger.debug(f'Remove node "%s"(%d).', node.full_name, node.id)
661
+ for child in node.children:
662
+ child.parents.remove(node)
663
+ for parent in node.parents:
664
+ parent.children.remove(node)
665
+ self.nodes.remove(node)
666
+
667
+ if not isinstance(node.id, int):
668
+ raise ValueError(f'Invalid node id.')
669
+ del self._id_to_node[node.id]
670
+ del self._full_name_to_node[node.full_name]
671
+
672
+ def add_attacker(
673
+ self,
674
+ attacker: Attacker,
675
+ attacker_id: Optional[int] = None,
676
+ entry_points: list[int] = [],
677
+ reached_attack_steps: list[int] = []
678
+ ):
679
+ """Add an attacker to the graph
680
+ Arguments:
681
+ attacker - the attacker to add
682
+ attacker_id - the id to assign to this attacker, usually
683
+ used when loading an attack graph from a
684
+ file
685
+ entry_points - list of attack step ids that serve as entry
686
+ points for the attacker
687
+ reached_attack_steps - list of ids of the attack steps that the
688
+ attacker has reached
689
+ """
690
+ if logger.isEnabledFor(logging.DEBUG):
691
+ # Avoid running json.dumps when not in debug
692
+ if attacker_id is not None:
693
+ logger.debug('Add attacker "%s" with id:%d.',
694
+ attacker.name,
695
+ attacker_id)
696
+ else:
697
+ logger.debug('Add attacker "%s" without id.',
698
+ attacker.name)
699
+
700
+
701
+ attacker.id = attacker_id or self.next_attacker_id
702
+ if attacker.id in self._id_to_attacker:
703
+ raise ValueError(f'Attacker index {attacker_id} already in use.')
704
+
705
+ self.next_attacker_id = max(attacker.id + 1, self.next_attacker_id)
706
+ for node_id in reached_attack_steps:
707
+ node = self.get_node_by_id(node_id)
708
+ if node:
709
+ attacker.compromise(node)
710
+ else:
711
+ msg = ("Could not find node with id %d"
712
+ "in reached attack steps.")
713
+ logger.error(msg, node_id)
714
+ raise AttackGraphException(msg % node_id)
715
+ for node_id in entry_points:
716
+ node = self.get_node_by_id(int(node_id))
717
+ if node:
718
+ attacker.entry_points.append(node)
719
+ else:
720
+ msg = ("Could not find node with id %d"
721
+ "in attacker entrypoints.")
722
+ logger.error(msg, node_id)
723
+ raise AttackGraphException(msg % node_id)
724
+ self.attackers.append(attacker)
725
+ self._id_to_attacker[attacker.id] = attacker
726
+
727
+ def remove_attacker(self, attacker: Attacker):
728
+ """Remove attacker from attack graph
729
+ Arguments:
730
+ attacker - the attacker we wish to remove from the attack graph
731
+ """
732
+ if logger.isEnabledFor(logging.DEBUG):
733
+ # Avoid running json.dumps when not in debug
734
+ logger.debug('Remove attacker "%s" with id:%d.',
735
+ attacker.name,
736
+ attacker.id)
737
+ for node in attacker.reached_attack_steps:
738
+ attacker.undo_compromise(node)
739
+ self.attackers.remove(attacker)
740
+ if not isinstance(attacker.id, int):
741
+ raise ValueError(f'Invalid attacker id.')
742
+ del self._id_to_attacker[attacker.id]