mal-toolbox 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {mal_toolbox-1.1.0.dist-info → mal_toolbox-1.1.2.dist-info}/METADATA +26 -2
  2. mal_toolbox-1.1.2.dist-info/RECORD +32 -0
  3. maltoolbox/__init__.py +6 -7
  4. maltoolbox/__main__.py +17 -9
  5. maltoolbox/attackgraph/__init__.py +2 -3
  6. maltoolbox/attackgraph/attackgraph.py +379 -362
  7. maltoolbox/attackgraph/node.py +14 -19
  8. maltoolbox/exceptions.py +7 -10
  9. maltoolbox/file_utils.py +10 -4
  10. maltoolbox/language/__init__.py +1 -1
  11. maltoolbox/language/compiler/__init__.py +4 -4
  12. maltoolbox/language/compiler/mal_lexer.py +154 -154
  13. maltoolbox/language/compiler/mal_parser.py +784 -1136
  14. maltoolbox/language/languagegraph.py +491 -636
  15. maltoolbox/model.py +85 -77
  16. maltoolbox/patternfinder/attackgraph_patterns.py +17 -8
  17. maltoolbox/translators/__init__.py +8 -0
  18. maltoolbox/translators/networkx.py +42 -0
  19. maltoolbox/translators/updater.py +18 -25
  20. maltoolbox/visualization/__init__.py +4 -4
  21. maltoolbox/visualization/draw_io_utils.py +6 -5
  22. maltoolbox/visualization/graphviz_utils.py +4 -2
  23. maltoolbox/visualization/neo4j_utils.py +13 -14
  24. maltoolbox/visualization/utils.py +2 -3
  25. mal_toolbox-1.1.0.dist-info/RECORD +0 -32
  26. maltoolbox/translators/securicad.py +0 -179
  27. {mal_toolbox-1.1.0.dist-info → mal_toolbox-1.1.2.dist-info}/WHEEL +0 -0
  28. {mal_toolbox-1.1.0.dist-info → mal_toolbox-1.1.2.dist-info}/entry_points.txt +0 -0
  29. {mal_toolbox-1.1.0.dist-info → mal_toolbox-1.1.2.dist-info}/licenses/AUTHORS +0 -0
  30. {mal_toolbox-1.1.0.dist-info → mal_toolbox-1.1.2.dist-info}/licenses/LICENSE +0 -0
  31. {mal_toolbox-1.1.0.dist-info → mal_toolbox-1.1.2.dist-info}/top_level.txt +0 -0
maltoolbox/model.py CHANGED
@@ -1,41 +1,37 @@
1
- """
2
- MAL-Toolbox Model Module
1
+ """MAL-Toolbox Model Module
3
2
  """
4
3
 
5
4
  from __future__ import annotations
5
+
6
6
  import json
7
7
  import logging
8
- from typing import TYPE_CHECKING
9
8
  import math
9
+ from typing import TYPE_CHECKING
10
10
 
11
+ from . import __version__
12
+ from .exceptions import ModelException
11
13
  from .file_utils import (
12
14
  load_dict_from_json_file,
13
15
  load_dict_from_yaml_file,
14
- save_dict_to_file
16
+ save_dict_to_file,
15
17
  )
16
-
17
- from . import __version__
18
- from .exceptions import ModelException
19
-
18
+ from .language import LanguageGraph
20
19
  if TYPE_CHECKING:
21
- from typing import Any, Optional
22
- from .language import (
23
- LanguageGraph,
24
- LanguageGraphAsset,
25
- LanguageGraphAssociation
26
- )
20
+ from typing import Any
21
+
22
+ from .language import LanguageGraphAsset, LanguageGraphAssociation
27
23
 
28
24
  logger = logging.getLogger(__name__)
29
25
 
30
26
 
31
- class Model():
27
+ class Model:
32
28
  """An implementation of a MAL language model containing assets"""
29
+
33
30
  next_id: int = 0
34
31
 
35
32
  def __repr__(self) -> str:
36
33
  return f'Model(name: "{self.name}", language: {self.lang_graph})'
37
34
 
38
-
39
35
  def __init__(
40
36
  self,
41
37
  name: str,
@@ -45,25 +41,24 @@ class Model():
45
41
 
46
42
  self.name = name
47
43
  self.assets: dict[int, ModelAsset] = {}
48
- self._name_to_asset:dict[str, ModelAsset] = {} # optimization
44
+ self._name_to_asset: dict[str, ModelAsset] = {} # optimization
49
45
  self.lang_graph = lang_graph
50
46
  self.maltoolbox_version: str = mt_version
51
47
 
52
-
53
48
  def add_asset(
54
49
  self,
55
50
  asset_type: str,
56
- name: Optional[str] = None,
57
- asset_id: Optional[int] = None,
58
- defenses: Optional[dict[str, float]] = None,
59
- extras: Optional[dict] = None,
51
+ name: str | None = None,
52
+ asset_id: int | None = None,
53
+ defenses: dict[str, float] | None = None,
54
+ extras: dict | None = None,
60
55
  allow_duplicate_names: bool = True
61
56
  ) -> ModelAsset:
62
- """
63
- Create an asset based on the provided parameters and add it to the
57
+ """Create an asset based on the provided parameters and add it to the
64
58
  model.
65
59
 
66
60
  Arguments:
61
+ ---------
67
62
  asset_type - string containing the asset type name
68
63
  name - string containing the asset name. If not
69
64
  provided the concatenated asset type and id
@@ -79,9 +74,10 @@ class Model():
79
74
  be appended with the id.
80
75
 
81
76
  Return:
77
+ ------
82
78
  The newly created asset.
83
- """
84
79
 
80
+ """
85
81
  # Set asset ID and check for duplicates
86
82
  asset_id = asset_id or self.next_id
87
83
  if asset_id in self.assets:
@@ -91,15 +87,14 @@ class Model():
91
87
 
92
88
  if not name:
93
89
  name = asset_type + ':' + str(asset_id)
94
- else:
95
- if name in self._name_to_asset:
96
- if allow_duplicate_names:
97
- name = name + ':' + str(asset_id)
98
- else:
99
- raise ValueError(
100
- f'Asset name {name} is a duplicate'
101
- ' and we do not allow duplicates.'
102
- )
90
+ elif name in self._name_to_asset:
91
+ if allow_duplicate_names:
92
+ name = name + ':' + str(asset_id)
93
+ else:
94
+ raise ValueError(
95
+ f'Asset name {name} is a duplicate'
96
+ ' and we do not allow duplicates.'
97
+ )
103
98
 
104
99
  if asset_type not in self.lang_graph.assets:
105
100
  raise ValueError(
@@ -110,11 +105,11 @@ class Model():
110
105
  lg_asset = self.lang_graph.assets[asset_type]
111
106
 
112
107
  asset = ModelAsset(
113
- name = name,
114
- asset_id = asset_id,
115
- lg_asset = lg_asset,
116
- defenses = defenses,
117
- extras = extras)
108
+ name=name,
109
+ asset_id=asset_id,
110
+ lg_asset=lg_asset,
111
+ defenses=defenses,
112
+ extras=extras)
118
113
 
119
114
  logger.debug(
120
115
  'Add "%s"(%d) to model "%s".', name, asset_id, self.name
@@ -124,14 +119,14 @@ class Model():
124
119
 
125
120
  return asset
126
121
 
127
-
128
122
  def remove_asset(self, asset: ModelAsset) -> None:
129
123
  """Remove an asset from the model.
130
124
 
131
125
  Arguments:
126
+ ---------
132
127
  asset - the asset to remove
133
- """
134
128
 
129
+ """
135
130
  logger.debug(
136
131
  'Remove "%s"(%d) from model "%s".',
137
132
  asset.name, asset.id, self.name
@@ -152,18 +147,19 @@ class Model():
152
147
  del self.assets[asset.id]
153
148
  del self._name_to_asset[asset.name]
154
149
 
155
-
156
150
  def get_asset_by_id(
157
151
  self, asset_id: int
158
- ) -> Optional[ModelAsset]:
159
- """
160
- Find an asset in the model based on its id.
152
+ ) -> ModelAsset | None:
153
+ """Find an asset in the model based on its id.
161
154
 
162
155
  Arguments:
156
+ ---------
163
157
  asset_id - the id of the asset we are looking for
164
158
 
165
159
  Return:
160
+ ------
166
161
  An asset matching the id if it exists in the model.
162
+
167
163
  """
168
164
  logger.debug(
169
165
  'Get asset with id %d from model "%s".',
@@ -171,18 +167,19 @@ class Model():
171
167
  )
172
168
  return self.assets.get(asset_id, None)
173
169
 
174
-
175
170
  def get_asset_by_name(
176
171
  self, asset_name: str
177
- ) -> Optional[ModelAsset]:
178
- """
179
- Find an asset in the model based on its name.
172
+ ) -> ModelAsset | None:
173
+ """Find an asset in the model based on its name.
180
174
 
181
175
  Arguments:
176
+ ---------
182
177
  asset_name - the name of the asset we are looking for
183
178
 
184
179
  Return:
180
+ ------
185
181
  An asset matching the name if it exists in the model.
182
+
186
183
  """
187
184
  logger.debug(
188
185
  'Get asset with name "%s" from model "%s".',
@@ -216,13 +213,11 @@ class Model():
216
213
 
217
214
  return contents
218
215
 
219
-
220
216
  def save_to_file(self, filename: str) -> None:
221
217
  """Save to json/yml depending on extension"""
222
218
  logger.debug('Save instance model to file "%s".', filename)
223
219
  return save_dict_to_file(filename, self._to_dict())
224
220
 
225
-
226
221
  @classmethod
227
222
  def _from_dict(
228
223
  cls,
@@ -232,17 +227,18 @@ class Model():
232
227
  """Create a model from dict representation
233
228
 
234
229
  Arguments:
230
+ ---------
235
231
  serialized_object - Model in dict format
236
232
  lang_graph -
237
- """
238
233
 
234
+ """
239
235
  maltoolbox_version = serialized_object['metadata']['MAL Toolbox Version'] \
240
236
  if 'MAL Toolbox Version' in serialized_object['metadata'] \
241
237
  else __version__
242
238
  model = Model(
243
239
  serialized_object['metadata']['name'],
244
240
  lang_graph,
245
- mt_version = maltoolbox_version)
241
+ mt_version=maltoolbox_version)
246
242
 
247
243
  # Reconstruct the assets
248
244
  for asset_id, asset_dict in serialized_object['assets'].items():
@@ -264,12 +260,12 @@ class Model():
264
260
  )
265
261
 
266
262
  model.add_asset(
267
- asset_type = asset_dict['type'],
268
- name = asset_dict['name'],
269
- defenses = {defense: float(value) for defense, value in \
263
+ asset_type=asset_dict['type'],
264
+ name=asset_dict['name'],
265
+ defenses={defense: float(value) for defense, value in
270
266
  asset_dict.get('defenses', {}).items()},
271
- extras = asset_dict.get('extras', {}),
272
- asset_id = int(asset_id))
267
+ extras=asset_dict.get('extras', {}),
268
+ asset_id=int(asset_id))
273
269
 
274
270
  # Reconstruct the association links
275
271
  for asset_id, asset_dict in serialized_object['assets'].items():
@@ -291,7 +287,6 @@ class Model():
291
287
 
292
288
  return model
293
289
 
294
-
295
290
  @classmethod
296
291
  def load_from_file(
297
292
  cls,
@@ -315,6 +310,27 @@ class Model():
315
310
  "Try to upgrade it with 'maltoolbox upgrade-model'"
316
311
  ) from e
317
312
 
313
+ def __getstate__(self):
314
+ lang_state = self.lang_graph.__getstate__()
315
+ state = self._to_dict()
316
+ return {
317
+ 'model_state': state,
318
+ 'lang_graph': lang_state
319
+ }
320
+
321
+ def __setstate__(self, state):
322
+ # Restore the language graph first
323
+ lang_graph = LanguageGraph.__new__(LanguageGraph)
324
+ lang_graph.__setstate__(state['lang_graph'])
325
+ self.lang_graph = lang_graph
326
+
327
+ # Restore the model state by creating a temporary model and copying attributes
328
+ temp_model = self._from_dict(state['model_state'], self.lang_graph)
329
+ self.name = temp_model.name
330
+ self.assets = temp_model.assets
331
+ self._name_to_asset = temp_model._name_to_asset
332
+ self.maltoolbox_version = temp_model.maltoolbox_version
333
+ self.next_id = temp_model.next_id
318
334
 
319
335
  class ModelAsset:
320
336
  def __init__(
@@ -322,8 +338,8 @@ class ModelAsset:
322
338
  name: str,
323
339
  asset_id: int,
324
340
  lg_asset: LanguageGraphAsset,
325
- defenses: Optional[dict[str, float]] = None,
326
- extras: Optional[dict] = None
341
+ defenses: dict[str, float] | None = None,
342
+ extras: dict | None = None
327
343
  ):
328
344
 
329
345
  self.name: str = name
@@ -335,9 +351,8 @@ class ModelAsset:
335
351
  self._associated_assets: dict[str, set[ModelAsset]] = {}
336
352
  self.attack_step_nodes: list = []
337
353
 
338
- def _to_dict(self):
354
+ def _to_dict(self) -> dict[int, dict[str, Any]]:
339
355
  """Get dictionary representation of the asset."""
340
-
341
356
  logger.debug(
342
357
  'Translating "%s"(%d) to dictionary.', self.name, self.id)
343
358
 
@@ -365,7 +380,6 @@ class ModelAsset:
365
380
 
366
381
  return {self.id: asset_dict}
367
382
 
368
-
369
383
  def __repr__(self):
370
384
  return (f'ModelAsset(name: "{self.name}", id: {self.id}, '
371
385
  f'type: {self.type})')
@@ -388,10 +402,8 @@ class ModelAsset:
388
402
  return assocs_in_common
389
403
 
390
404
  def has_association_with(self, b: ModelAsset, assoc_name: str) -> bool:
405
+ """Returns True if association `assoc_name` exists between self and `b`
391
406
  """
392
- Returns True if association `assoc_name` exists between self and `b`
393
- """
394
-
395
407
  for fieldname, associated_assets in self.associated_assets.items():
396
408
  assoc = self.lg_asset.associations[fieldname]
397
409
  if assoc.name == assoc_name and b in associated_assets:
@@ -402,20 +414,20 @@ class ModelAsset:
402
414
  def validate_associated_assets(
403
415
  self, fieldname: str, assets_to_add: set[ModelAsset]
404
416
  ):
405
- """
406
- Validate an association we want to add (through `fieldname`)
417
+ """Validate an association we want to add (through `fieldname`)
407
418
  is valid with the assets given in param `assets_to_add`:
408
419
  - fieldname is valid for the asset type of this ModelAsset
409
420
  - type of `assets_to_add` is valid for the association
410
421
  - no more assets than 'field.maximum' are added to the field
411
422
 
412
- Raises:
423
+ Raises
424
+ ------
413
425
  LookupError - fieldname can not be found for this ModelAsset
414
426
  ValueError - there will be too many assets in the field
415
427
  if we add this association
416
428
  TypeError - if the asset type of `assets_to_add` is not valid
417
- """
418
429
 
430
+ """
419
431
  # Validate that the field name is allowed for this asset type
420
432
  if fieldname not in self.lg_asset.associations:
421
433
  accepted_fieldnames = list(self.lg_asset.associations.keys())
@@ -450,11 +462,9 @@ class ModelAsset:
450
462
  )
451
463
 
452
464
  def add_associated_assets(self, fieldname: str, assets: set[ModelAsset]):
453
- """
454
- Add the assets provided as a parameter to the set of associated
465
+ """Add the assets provided as a parameter to the set of associated
455
466
  assets dictionary entry corresponding to the given fieldname.
456
467
  """
457
-
458
468
  if fieldname not in self.lg_asset.associations:
459
469
  if assets:
460
470
  to_asset_type = next(iter(assets)).lg_asset
@@ -490,7 +500,7 @@ class ModelAsset:
490
500
 
491
501
  def remove_associated_assets(
492
502
  self, fieldname: str, assets: set[ModelAsset]):
493
- """ Remove the assets provided as a parameter from the set of
503
+ """Remove the assets provided as a parameter from the set of
494
504
  associated assets dictionary entry corresponding to the fieldname
495
505
  parameter.
496
506
  """
@@ -507,12 +517,10 @@ class ModelAsset:
507
517
  if len(self._associated_assets[fieldname]) == 0:
508
518
  del self._associated_assets[fieldname]
509
519
 
510
-
511
520
  @property
512
521
  def associated_assets(self):
513
522
  return self._associated_assets
514
523
 
515
-
516
524
  @property
517
525
  def id(self):
518
526
  return self._id
@@ -1,13 +1,18 @@
1
1
  """Utilities for finding patterns in the AttackGraph"""
2
2
 
3
3
  from __future__ import annotations
4
+
5
+ from collections.abc import Callable
4
6
  from dataclasses import dataclass
5
- from typing import Callable
7
+
6
8
  from maltoolbox.attackgraph import AttackGraph, AttackGraphNode
7
9
 
10
+
8
11
  class SearchPattern:
9
12
  """A pattern consists of conditions, the conditions are used
10
- to find all matching sequences of nodes in an AttackGraph."""
13
+ to find all matching sequences of nodes in an AttackGraph.
14
+ """
15
+
11
16
  conditions: list[SearchCondition]
12
17
 
13
18
  def __init__(self, conditions):
@@ -19,11 +24,12 @@ class SearchPattern:
19
24
  that match all the conditions in the pattern
20
25
 
21
26
  Args:
27
+ ----
22
28
  graph - The AttackGraph to search in
23
29
 
24
30
  Return: list[list[AttackGraphNode]] matching paths of Nodes
25
- """
26
31
 
32
+ """
27
33
  # Find the starting nodes which match the first condition
28
34
  condition = self.conditions[0]
29
35
  matching_paths = []
@@ -34,12 +40,14 @@ class SearchPattern:
34
40
  )
35
41
  return matching_paths
36
42
 
37
- @dataclass
43
+
44
+ @dataclass(frozen=True, eq=True)
38
45
  class SearchCondition:
39
46
  """A condition that has to be true for a node to match"""
40
47
 
41
48
  # Predefined search conditions
42
- ANY = lambda _: True
49
+ @staticmethod
50
+ def ANY(_): return True
43
51
 
44
52
  # `matches` should be a lambda that takes node as input and returns bool
45
53
  # If lamdba returns True for a node, the node matches
@@ -64,7 +72,7 @@ def find_matches_recursively(
64
72
  node: AttackGraphNode,
65
73
  condition_list: list[SearchCondition],
66
74
  current_path: list[AttackGraphNode] | None = None,
67
- matching_paths: set[tuple[AttackGraphNode,...]] | None = None,
75
+ matching_paths: set[tuple[AttackGraphNode, ...]] | None = None,
68
76
  condition_match_count: int = 0
69
77
  ):
70
78
  """Find all paths of nodes that match the list of conditions.
@@ -73,6 +81,7 @@ def find_matches_recursively(
73
81
  The function runs recursively down all paths of children nodes.
74
82
 
75
83
  Args:
84
+ ----
76
85
  node - node to check if current `condition` matches for
77
86
  condition_list - first condition in list will attempt match `node`
78
87
  current_path - list of matched nodes so far (recursively built)
@@ -80,8 +89,8 @@ def find_matches_recursively(
80
89
  condition_match_count - number of matches on current condition so far
81
90
 
82
91
  Return: set of tuples (paths) of AttackGraphNodes that match the condition
83
- """
84
92
 
93
+ """
85
94
  # Init path lists if None, or copy/init into new lists for each iteration
86
95
  current_path = [] if current_path is None else list(current_path)
87
96
  matching_paths = set() if matching_paths is None else matching_paths
@@ -129,6 +138,6 @@ def find_matches_recursively(
129
138
 
130
139
  if not next_conds:
131
140
  # Congrats - matched a full unique search pattern!
132
- matching_paths.add(tuple(current_path)) # tuple is hashable
141
+ matching_paths.add(tuple(current_path)) # tuple is hashable
133
142
 
134
143
  return matching_paths
@@ -0,0 +1,8 @@
1
+ from .networkx import attack_graph_to_nx, model_to_nx
2
+ from .updater import load_model_from_older_version
3
+
4
+ __all__ = [
5
+ 'attack_graph_to_nx',
6
+ 'load_model_from_older_version',
7
+ 'model_to_nx'
8
+ ]
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+
5
+ import networkx as nx
6
+
7
+ from maltoolbox.attackgraph import AttackGraph, AttackGraphNode
8
+ from maltoolbox.model import Model
9
+
10
+
11
+ def attack_graph_to_nx(nodes: AttackGraph | Iterable[AttackGraphNode]) -> nx.DiGraph:
12
+ """Convert an attack graph to a networkx DiGraph"""
13
+ if isinstance(nodes, AttackGraph):
14
+ nodes = list(nodes.nodes.values())
15
+ G: nx.DiGraph = nx.DiGraph()
16
+
17
+ for node in nodes:
18
+ G.add_node(node.id, **node.to_dict())
19
+ G.nodes[node.id]["full_name"] = node.full_name
20
+
21
+ edges = [(node.id, child.id) for node in nodes for child in node.children]
22
+ edges += [(parent.id, node.id) for node in nodes for parent in node.parents]
23
+ G.add_edges_from(edges)
24
+
25
+ return G
26
+
27
+
28
+ def model_to_nx(model: Model) -> nx.Graph:
29
+ """Convert a MAL Model to a networkx GRaph"""
30
+ G: nx.Graph = nx.Graph()
31
+
32
+ for id, asset in model.assets.items():
33
+ asset_dict = asset._to_dict()[id]
34
+ asset_dict["id"] = id
35
+ G.add_node(id, **asset_dict)
36
+
37
+ for id, asset in model.assets.items():
38
+ for fieldname, associated_assets in asset.associated_assets.items():
39
+ for associated_asset in associated_assets:
40
+ G.add_edge(id, associated_asset.id, name=asset.lg_asset.associations[fieldname].name)
41
+
42
+ return G
@@ -1,25 +1,20 @@
1
- import json
2
1
  import logging
3
2
 
4
- import yaml
5
-
6
- import logging
7
- from ..model import Model
8
- from ..language import LanguageGraph
9
3
  from ..file_utils import load_dict_from_json_file, load_dict_from_yaml_file
4
+ from ..language import LanguageGraph
5
+ from ..model import Model
10
6
 
11
7
  logger = logging.getLogger(__name__)
12
8
 
9
+
13
10
  def load_model_from_older_version(
14
11
  filename: str, lang_graph: LanguageGraph,
15
12
  ) -> Model:
16
-
17
- """ Load an older Model file
13
+ """Load an older Model file
18
14
 
19
15
  Load an older model from given `filename` (yml/json)
20
16
  convert the model to the new format and return a Model object.
21
17
  """
22
-
23
18
  model_dict = load_model_dict_from_file(filename)
24
19
 
25
20
  # Get the version of the model, default to 0.0
@@ -54,7 +49,6 @@ def load_model_dict_from_file(
54
49
  filename: str,
55
50
  ) -> dict:
56
51
  """Load a json or yaml file to dict"""
57
-
58
52
  model_dict = {}
59
53
  if filename.endswith('.yml') or filename.endswith('.yaml'):
60
54
  model_dict = load_dict_from_yaml_file(filename)
@@ -68,16 +62,17 @@ def load_model_dict_from_file(
68
62
 
69
63
 
70
64
  def convert_model_dict_from_version_0_0(model_dict: dict) -> dict:
71
- """
72
- Convert model dict version 0.0 to 0.1
65
+ """Convert model dict version 0.0 to 0.1
73
66
 
74
67
  Arguments:
68
+ ---------
75
69
  model_dict - the dictionary containing the serialized model
76
70
 
77
71
  Returns:
72
+ -------
78
73
  A dictionary containing the version 0.1 equivalent serialized model
79
- """
80
74
 
75
+ """
81
76
  new_model_dict = {}
82
77
 
83
78
  # Meta data and attackers did not change
@@ -121,24 +116,21 @@ def convert_model_dict_from_version_0_0(model_dict: dict) -> dict:
121
116
  # Add new assoc dict to new model dict
122
117
  new_model_dict['associations'] = new_assoc_list
123
118
 
124
- # Reconstruct the attackers
125
- if 'attackers' in model_dict:
126
- attackers_info = model_dict['attackers']
127
-
128
119
  return new_model_dict
129
120
 
130
121
 
131
122
  def convert_model_dict_from_version_0_1(model_dict: dict) -> dict:
132
- """
133
- Convert model dict version 0.1 to 0.2
123
+ """Convert model dict version 0.1 to 0.2
134
124
 
135
125
  Arguments:
126
+ ---------
136
127
  model_dict - the dictionary containing the serialized model
137
128
 
138
129
  Returns:
130
+ -------
139
131
  A dictionary containing the version 0.2 equivalent serialized model
140
- """
141
132
 
133
+ """
142
134
  new_model_dict = {}
143
135
 
144
136
  # Meta data and assets format did not change from version 0.1
@@ -176,14 +168,14 @@ def convert_model_dict_from_version_0_1(model_dict: dict) -> dict:
176
168
  new_attackers_dict: dict[int, dict] = {}
177
169
  attackers_dict: dict = model_dict.get('attackers', {})
178
170
  for attacker_id, attacker_dict in attackers_dict.items():
179
- attacker_id = int(attacker_id) # JSON compatibility
171
+ attacker_id = int(attacker_id) # JSON compatibility
180
172
  new_attackers_dict[attacker_id] = {}
181
173
  new_attackers_dict[attacker_id]['name'] = attacker_dict['name']
182
174
  new_entry_points_dict = {}
183
175
 
184
176
  entry_points_dict = attacker_dict['entry_points']
185
177
  for asset_id, attack_steps in entry_points_dict.items():
186
- asset_id = int(asset_id) # JSON compatibility
178
+ asset_id = int(asset_id) # JSON compatibility
187
179
  asset_name = new_assets_dict[asset_id]['name']
188
180
  new_entry_points_dict[asset_name] = {
189
181
  'asset_id': asset_id,
@@ -198,16 +190,17 @@ def convert_model_dict_from_version_0_1(model_dict: dict) -> dict:
198
190
 
199
191
 
200
192
  def convert_model_dict_from_version_0_2(model_dict: dict) -> dict:
201
- """
202
- Convert model dict version 0.2 to 0.3
193
+ """Convert model dict version 0.2 to 0.3
203
194
 
204
195
  Arguments:
196
+ ---------
205
197
  model_dict - the dictionary containing the serialized model
206
198
 
207
199
  Returns:
200
+ -------
208
201
  A dictionary containing the version 0.3 equivalent serialized model
209
- """
210
202
 
203
+ """
211
204
  new_model_dict = {}
212
205
 
213
206
  # Meta data and assets format did not change from version 0.1
@@ -1,11 +1,11 @@
1
+ from .draw_io_utils import create_drawio_file_with_images
1
2
  from .graphviz_utils import render_attack_graph, render_model
2
3
  from .neo4j_utils import ingest_attack_graph_neo4j, ingest_model_neo4j
3
- from .draw_io_utils import create_drawio_file_with_images
4
4
 
5
5
  __all__ = [
6
- 'render_attack_graph',
7
- 'render_model',
6
+ 'create_drawio_file_with_images',
8
7
  'ingest_attack_graph_neo4j',
9
8
  'ingest_model_neo4j',
10
- 'create_drawio_file_with_images'
9
+ 'render_attack_graph',
10
+ 'render_model',
11
11
  ]