mal-toolbox 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/METADATA +25 -2
  2. mal_toolbox-1.1.3.dist-info/RECORD +32 -0
  3. maltoolbox/__init__.py +6 -7
  4. maltoolbox/__main__.py +17 -9
  5. maltoolbox/attackgraph/__init__.py +2 -3
  6. maltoolbox/attackgraph/attackgraph.py +379 -362
  7. maltoolbox/attackgraph/node.py +14 -19
  8. maltoolbox/exceptions.py +7 -10
  9. maltoolbox/file_utils.py +10 -4
  10. maltoolbox/language/__init__.py +1 -1
  11. maltoolbox/language/compiler/__init__.py +4 -4
  12. maltoolbox/language/compiler/mal_lexer.py +154 -154
  13. maltoolbox/language/compiler/mal_parser.py +784 -1136
  14. maltoolbox/language/languagegraph.py +487 -639
  15. maltoolbox/model.py +64 -77
  16. maltoolbox/patternfinder/attackgraph_patterns.py +17 -8
  17. maltoolbox/translators/__init__.py +8 -0
  18. maltoolbox/translators/networkx.py +42 -0
  19. maltoolbox/translators/updater.py +18 -25
  20. maltoolbox/visualization/__init__.py +4 -4
  21. maltoolbox/visualization/draw_io_utils.py +6 -5
  22. maltoolbox/visualization/graphviz_utils.py +4 -2
  23. maltoolbox/visualization/neo4j_utils.py +13 -14
  24. maltoolbox/visualization/utils.py +2 -3
  25. mal_toolbox-1.1.1.dist-info/RECORD +0 -32
  26. maltoolbox/translators/securicad.py +0 -179
  27. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/WHEEL +0 -0
  28. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/entry_points.txt +0 -0
  29. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/licenses/AUTHORS +0 -0
  30. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/licenses/LICENSE +0 -0
  31. {mal_toolbox-1.1.1.dist-info → mal_toolbox-1.1.3.dist-info}/top_level.txt +0 -0
maltoolbox/model.py CHANGED
@@ -1,41 +1,37 @@
1
- """
2
- MAL-Toolbox Model Module
1
+ """MAL-Toolbox Model Module
3
2
  """
4
3
 
5
4
  from __future__ import annotations
5
+
6
6
  import json
7
7
  import logging
8
- from typing import TYPE_CHECKING
9
8
  import math
9
+ from typing import TYPE_CHECKING
10
10
 
11
+ from . import __version__
12
+ from .exceptions import ModelException
11
13
  from .file_utils import (
12
14
  load_dict_from_json_file,
13
15
  load_dict_from_yaml_file,
14
- save_dict_to_file
16
+ save_dict_to_file,
15
17
  )
16
-
17
- from . import __version__
18
- from .exceptions import ModelException
19
-
18
+ from .language import LanguageGraph
20
19
  if TYPE_CHECKING:
21
- from typing import Any, Optional
22
- from .language import (
23
- LanguageGraph,
24
- LanguageGraphAsset,
25
- LanguageGraphAssociation
26
- )
20
+ from typing import Any
21
+
22
+ from .language import LanguageGraphAsset, LanguageGraphAssociation
27
23
 
28
24
  logger = logging.getLogger(__name__)
29
25
 
30
26
 
31
- class Model():
27
+ class Model:
32
28
  """An implementation of a MAL language model containing assets"""
29
+
33
30
  next_id: int = 0
34
31
 
35
32
  def __repr__(self) -> str:
36
33
  return f'Model(name: "{self.name}", language: {self.lang_graph})'
37
34
 
38
-
39
35
  def __init__(
40
36
  self,
41
37
  name: str,
@@ -45,25 +41,24 @@ class Model():
45
41
 
46
42
  self.name = name
47
43
  self.assets: dict[int, ModelAsset] = {}
48
- self._name_to_asset:dict[str, ModelAsset] = {} # optimization
44
+ self._name_to_asset: dict[str, ModelAsset] = {} # optimization
49
45
  self.lang_graph = lang_graph
50
46
  self.maltoolbox_version: str = mt_version
51
47
 
52
-
53
48
  def add_asset(
54
49
  self,
55
50
  asset_type: str,
56
- name: Optional[str] = None,
57
- asset_id: Optional[int] = None,
58
- defenses: Optional[dict[str, float]] = None,
59
- extras: Optional[dict] = None,
51
+ name: str | None = None,
52
+ asset_id: int | None = None,
53
+ defenses: dict[str, float] | None = None,
54
+ extras: dict | None = None,
60
55
  allow_duplicate_names: bool = True
61
56
  ) -> ModelAsset:
62
- """
63
- Create an asset based on the provided parameters and add it to the
57
+ """Create an asset based on the provided parameters and add it to the
64
58
  model.
65
59
 
66
60
  Arguments:
61
+ ---------
67
62
  asset_type - string containing the asset type name
68
63
  name - string containing the asset name. If not
69
64
  provided the concatenated asset type and id
@@ -79,9 +74,10 @@ class Model():
79
74
  be appended with the id.
80
75
 
81
76
  Return:
77
+ ------
82
78
  The newly created asset.
83
- """
84
79
 
80
+ """
85
81
  # Set asset ID and check for duplicates
86
82
  asset_id = asset_id or self.next_id
87
83
  if asset_id in self.assets:
@@ -91,15 +87,14 @@ class Model():
91
87
 
92
88
  if not name:
93
89
  name = asset_type + ':' + str(asset_id)
94
- else:
95
- if name in self._name_to_asset:
96
- if allow_duplicate_names:
97
- name = name + ':' + str(asset_id)
98
- else:
99
- raise ValueError(
100
- f'Asset name {name} is a duplicate'
101
- ' and we do not allow duplicates.'
102
- )
90
+ elif name in self._name_to_asset:
91
+ if allow_duplicate_names:
92
+ name = name + ':' + str(asset_id)
93
+ else:
94
+ raise ValueError(
95
+ f'Asset name {name} is a duplicate'
96
+ ' and we do not allow duplicates.'
97
+ )
103
98
 
104
99
  if asset_type not in self.lang_graph.assets:
105
100
  raise ValueError(
@@ -110,11 +105,11 @@ class Model():
110
105
  lg_asset = self.lang_graph.assets[asset_type]
111
106
 
112
107
  asset = ModelAsset(
113
- name = name,
114
- asset_id = asset_id,
115
- lg_asset = lg_asset,
116
- defenses = defenses,
117
- extras = extras)
108
+ name=name,
109
+ asset_id=asset_id,
110
+ lg_asset=lg_asset,
111
+ defenses=defenses,
112
+ extras=extras)
118
113
 
119
114
  logger.debug(
120
115
  'Add "%s"(%d) to model "%s".', name, asset_id, self.name
@@ -124,14 +119,14 @@ class Model():
124
119
 
125
120
  return asset
126
121
 
127
-
128
122
  def remove_asset(self, asset: ModelAsset) -> None:
129
123
  """Remove an asset from the model.
130
124
 
131
125
  Arguments:
126
+ ---------
132
127
  asset - the asset to remove
133
- """
134
128
 
129
+ """
135
130
  logger.debug(
136
131
  'Remove "%s"(%d) from model "%s".',
137
132
  asset.name, asset.id, self.name
@@ -152,18 +147,19 @@ class Model():
152
147
  del self.assets[asset.id]
153
148
  del self._name_to_asset[asset.name]
154
149
 
155
-
156
150
  def get_asset_by_id(
157
151
  self, asset_id: int
158
- ) -> Optional[ModelAsset]:
159
- """
160
- Find an asset in the model based on its id.
152
+ ) -> ModelAsset | None:
153
+ """Find an asset in the model based on its id.
161
154
 
162
155
  Arguments:
156
+ ---------
163
157
  asset_id - the id of the asset we are looking for
164
158
 
165
159
  Return:
160
+ ------
166
161
  An asset matching the id if it exists in the model.
162
+
167
163
  """
168
164
  logger.debug(
169
165
  'Get asset with id %d from model "%s".',
@@ -171,18 +167,19 @@ class Model():
171
167
  )
172
168
  return self.assets.get(asset_id, None)
173
169
 
174
-
175
170
  def get_asset_by_name(
176
171
  self, asset_name: str
177
- ) -> Optional[ModelAsset]:
178
- """
179
- Find an asset in the model based on its name.
172
+ ) -> ModelAsset | None:
173
+ """Find an asset in the model based on its name.
180
174
 
181
175
  Arguments:
176
+ ---------
182
177
  asset_name - the name of the asset we are looking for
183
178
 
184
179
  Return:
180
+ ------
185
181
  An asset matching the name if it exists in the model.
182
+
186
183
  """
187
184
  logger.debug(
188
185
  'Get asset with name "%s" from model "%s".',
@@ -216,13 +213,11 @@ class Model():
216
213
 
217
214
  return contents
218
215
 
219
-
220
216
  def save_to_file(self, filename: str) -> None:
221
217
  """Save to json/yml depending on extension"""
222
218
  logger.debug('Save instance model to file "%s".', filename)
223
219
  return save_dict_to_file(filename, self._to_dict())
224
220
 
225
-
226
221
  @classmethod
227
222
  def _from_dict(
228
223
  cls,
@@ -232,17 +227,18 @@ class Model():
232
227
  """Create a model from dict representation
233
228
 
234
229
  Arguments:
230
+ ---------
235
231
  serialized_object - Model in dict format
236
232
  lang_graph -
237
- """
238
233
 
234
+ """
239
235
  maltoolbox_version = serialized_object['metadata']['MAL Toolbox Version'] \
240
236
  if 'MAL Toolbox Version' in serialized_object['metadata'] \
241
237
  else __version__
242
238
  model = Model(
243
239
  serialized_object['metadata']['name'],
244
240
  lang_graph,
245
- mt_version = maltoolbox_version)
241
+ mt_version=maltoolbox_version)
246
242
 
247
243
  # Reconstruct the assets
248
244
  for asset_id, asset_dict in serialized_object['assets'].items():
@@ -264,12 +260,12 @@ class Model():
264
260
  )
265
261
 
266
262
  model.add_asset(
267
- asset_type = asset_dict['type'],
268
- name = asset_dict['name'],
269
- defenses = {defense: float(value) for defense, value in \
263
+ asset_type=asset_dict['type'],
264
+ name=asset_dict['name'],
265
+ defenses={defense: float(value) for defense, value in
270
266
  asset_dict.get('defenses', {}).items()},
271
- extras = asset_dict.get('extras', {}),
272
- asset_id = int(asset_id))
267
+ extras=asset_dict.get('extras', {}),
268
+ asset_id=int(asset_id))
273
269
 
274
270
  # Reconstruct the association links
275
271
  for asset_id, asset_dict in serialized_object['assets'].items():
@@ -291,7 +287,6 @@ class Model():
291
287
 
292
288
  return model
293
289
 
294
-
295
290
  @classmethod
296
291
  def load_from_file(
297
292
  cls,
@@ -322,8 +317,8 @@ class ModelAsset:
322
317
  name: str,
323
318
  asset_id: int,
324
319
  lg_asset: LanguageGraphAsset,
325
- defenses: Optional[dict[str, float]] = None,
326
- extras: Optional[dict] = None
320
+ defenses: dict[str, float] | None = None,
321
+ extras: dict | None = None
327
322
  ):
328
323
 
329
324
  self.name: str = name
@@ -335,9 +330,8 @@ class ModelAsset:
335
330
  self._associated_assets: dict[str, set[ModelAsset]] = {}
336
331
  self.attack_step_nodes: list = []
337
332
 
338
- def _to_dict(self):
333
+ def _to_dict(self) -> dict[int, dict[str, Any]]:
339
334
  """Get dictionary representation of the asset."""
340
-
341
335
  logger.debug(
342
336
  'Translating "%s"(%d) to dictionary.', self.name, self.id)
343
337
 
@@ -365,7 +359,6 @@ class ModelAsset:
365
359
 
366
360
  return {self.id: asset_dict}
367
361
 
368
-
369
362
  def __repr__(self):
370
363
  return (f'ModelAsset(name: "{self.name}", id: {self.id}, '
371
364
  f'type: {self.type})')
@@ -388,10 +381,8 @@ class ModelAsset:
388
381
  return assocs_in_common
389
382
 
390
383
  def has_association_with(self, b: ModelAsset, assoc_name: str) -> bool:
384
+ """Returns True if association `assoc_name` exists between self and `b`
391
385
  """
392
- Returns True if association `assoc_name` exists between self and `b`
393
- """
394
-
395
386
  for fieldname, associated_assets in self.associated_assets.items():
396
387
  assoc = self.lg_asset.associations[fieldname]
397
388
  if assoc.name == assoc_name and b in associated_assets:
@@ -402,20 +393,20 @@ class ModelAsset:
402
393
  def validate_associated_assets(
403
394
  self, fieldname: str, assets_to_add: set[ModelAsset]
404
395
  ):
405
- """
406
- Validate an association we want to add (through `fieldname`)
396
+ """Validate an association we want to add (through `fieldname`)
407
397
  is valid with the assets given in param `assets_to_add`:
408
398
  - fieldname is valid for the asset type of this ModelAsset
409
399
  - type of `assets_to_add` is valid for the association
410
400
  - no more assets than 'field.maximum' are added to the field
411
401
 
412
- Raises:
402
+ Raises
403
+ ------
413
404
  LookupError - fieldname can not be found for this ModelAsset
414
405
  ValueError - there will be too many assets in the field
415
406
  if we add this association
416
407
  TypeError - if the asset type of `assets_to_add` is not valid
417
- """
418
408
 
409
+ """
419
410
  # Validate that the field name is allowed for this asset type
420
411
  if fieldname not in self.lg_asset.associations:
421
412
  accepted_fieldnames = list(self.lg_asset.associations.keys())
@@ -450,11 +441,9 @@ class ModelAsset:
450
441
  )
451
442
 
452
443
  def add_associated_assets(self, fieldname: str, assets: set[ModelAsset]):
453
- """
454
- Add the assets provided as a parameter to the set of associated
444
+ """Add the assets provided as a parameter to the set of associated
455
445
  assets dictionary entry corresponding to the given fieldname.
456
446
  """
457
-
458
447
  if fieldname not in self.lg_asset.associations:
459
448
  if assets:
460
449
  to_asset_type = next(iter(assets)).lg_asset
@@ -490,7 +479,7 @@ class ModelAsset:
490
479
 
491
480
  def remove_associated_assets(
492
481
  self, fieldname: str, assets: set[ModelAsset]):
493
- """ Remove the assets provided as a parameter from the set of
482
+ """Remove the assets provided as a parameter from the set of
494
483
  associated assets dictionary entry corresponding to the fieldname
495
484
  parameter.
496
485
  """
@@ -507,12 +496,10 @@ class ModelAsset:
507
496
  if len(self._associated_assets[fieldname]) == 0:
508
497
  del self._associated_assets[fieldname]
509
498
 
510
-
511
499
  @property
512
500
  def associated_assets(self):
513
501
  return self._associated_assets
514
502
 
515
-
516
503
  @property
517
504
  def id(self):
518
505
  return self._id
@@ -1,13 +1,18 @@
1
1
  """Utilities for finding patterns in the AttackGraph"""
2
2
 
3
3
  from __future__ import annotations
4
+
5
+ from collections.abc import Callable
4
6
  from dataclasses import dataclass
5
- from typing import Callable
7
+
6
8
  from maltoolbox.attackgraph import AttackGraph, AttackGraphNode
7
9
 
10
+
8
11
  class SearchPattern:
9
12
  """A pattern consists of conditions, the conditions are used
10
- to find all matching sequences of nodes in an AttackGraph."""
13
+ to find all matching sequences of nodes in an AttackGraph.
14
+ """
15
+
11
16
  conditions: list[SearchCondition]
12
17
 
13
18
  def __init__(self, conditions):
@@ -19,11 +24,12 @@ class SearchPattern:
19
24
  that match all the conditions in the pattern
20
25
 
21
26
  Args:
27
+ ----
22
28
  graph - The AttackGraph to search in
23
29
 
24
30
  Return: list[list[AttackGraphNode]] matching paths of Nodes
25
- """
26
31
 
32
+ """
27
33
  # Find the starting nodes which match the first condition
28
34
  condition = self.conditions[0]
29
35
  matching_paths = []
@@ -34,12 +40,14 @@ class SearchPattern:
34
40
  )
35
41
  return matching_paths
36
42
 
37
- @dataclass
43
+
44
+ @dataclass(frozen=True, eq=True)
38
45
  class SearchCondition:
39
46
  """A condition that has to be true for a node to match"""
40
47
 
41
48
  # Predefined search conditions
42
- ANY = lambda _: True
49
+ @staticmethod
50
+ def ANY(_): return True
43
51
 
44
52
  # `matches` should be a lambda that takes node as input and returns bool
45
53
  # If lamdba returns True for a node, the node matches
@@ -64,7 +72,7 @@ def find_matches_recursively(
64
72
  node: AttackGraphNode,
65
73
  condition_list: list[SearchCondition],
66
74
  current_path: list[AttackGraphNode] | None = None,
67
- matching_paths: set[tuple[AttackGraphNode,...]] | None = None,
75
+ matching_paths: set[tuple[AttackGraphNode, ...]] | None = None,
68
76
  condition_match_count: int = 0
69
77
  ):
70
78
  """Find all paths of nodes that match the list of conditions.
@@ -73,6 +81,7 @@ def find_matches_recursively(
73
81
  The function runs recursively down all paths of children nodes.
74
82
 
75
83
  Args:
84
+ ----
76
85
  node - node to check if current `condition` matches for
77
86
  condition_list - first condition in list will attempt match `node`
78
87
  current_path - list of matched nodes so far (recursively built)
@@ -80,8 +89,8 @@ def find_matches_recursively(
80
89
  condition_match_count - number of matches on current condition so far
81
90
 
82
91
  Return: set of tuples (paths) of AttackGraphNodes that match the condition
83
- """
84
92
 
93
+ """
85
94
  # Init path lists if None, or copy/init into new lists for each iteration
86
95
  current_path = [] if current_path is None else list(current_path)
87
96
  matching_paths = set() if matching_paths is None else matching_paths
@@ -129,6 +138,6 @@ def find_matches_recursively(
129
138
 
130
139
  if not next_conds:
131
140
  # Congrats - matched a full unique search pattern!
132
- matching_paths.add(tuple(current_path)) # tuple is hashable
141
+ matching_paths.add(tuple(current_path)) # tuple is hashable
133
142
 
134
143
  return matching_paths
@@ -0,0 +1,8 @@
1
+ from .networkx import attack_graph_to_nx, model_to_nx
2
+ from .updater import load_model_from_older_version
3
+
4
+ __all__ = [
5
+ 'attack_graph_to_nx',
6
+ 'load_model_from_older_version',
7
+ 'model_to_nx'
8
+ ]
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+
5
+ import networkx as nx
6
+
7
+ from maltoolbox.attackgraph import AttackGraph, AttackGraphNode
8
+ from maltoolbox.model import Model
9
+
10
+
11
+ def attack_graph_to_nx(nodes: AttackGraph | Iterable[AttackGraphNode]) -> nx.DiGraph:
12
+ """Convert an attack graph to a networkx DiGraph"""
13
+ if isinstance(nodes, AttackGraph):
14
+ nodes = list(nodes.nodes.values())
15
+ G: nx.DiGraph = nx.DiGraph()
16
+
17
+ for node in nodes:
18
+ G.add_node(node.id, **node.to_dict())
19
+ G.nodes[node.id]["full_name"] = node.full_name
20
+
21
+ edges = [(node.id, child.id) for node in nodes for child in node.children]
22
+ edges += [(parent.id, node.id) for node in nodes for parent in node.parents]
23
+ G.add_edges_from(edges)
24
+
25
+ return G
26
+
27
+
28
+ def model_to_nx(model: Model) -> nx.Graph:
29
+ """Convert a MAL Model to a networkx GRaph"""
30
+ G: nx.Graph = nx.Graph()
31
+
32
+ for id, asset in model.assets.items():
33
+ asset_dict = asset._to_dict()[id]
34
+ asset_dict["id"] = id
35
+ G.add_node(id, **asset_dict)
36
+
37
+ for id, asset in model.assets.items():
38
+ for fieldname, associated_assets in asset.associated_assets.items():
39
+ for associated_asset in associated_assets:
40
+ G.add_edge(id, associated_asset.id, name=asset.lg_asset.associations[fieldname].name)
41
+
42
+ return G
@@ -1,25 +1,20 @@
1
- import json
2
1
  import logging
3
2
 
4
- import yaml
5
-
6
- import logging
7
- from ..model import Model
8
- from ..language import LanguageGraph
9
3
  from ..file_utils import load_dict_from_json_file, load_dict_from_yaml_file
4
+ from ..language import LanguageGraph
5
+ from ..model import Model
10
6
 
11
7
  logger = logging.getLogger(__name__)
12
8
 
9
+
13
10
  def load_model_from_older_version(
14
11
  filename: str, lang_graph: LanguageGraph,
15
12
  ) -> Model:
16
-
17
- """ Load an older Model file
13
+ """Load an older Model file
18
14
 
19
15
  Load an older model from given `filename` (yml/json)
20
16
  convert the model to the new format and return a Model object.
21
17
  """
22
-
23
18
  model_dict = load_model_dict_from_file(filename)
24
19
 
25
20
  # Get the version of the model, default to 0.0
@@ -54,7 +49,6 @@ def load_model_dict_from_file(
54
49
  filename: str,
55
50
  ) -> dict:
56
51
  """Load a json or yaml file to dict"""
57
-
58
52
  model_dict = {}
59
53
  if filename.endswith('.yml') or filename.endswith('.yaml'):
60
54
  model_dict = load_dict_from_yaml_file(filename)
@@ -68,16 +62,17 @@ def load_model_dict_from_file(
68
62
 
69
63
 
70
64
  def convert_model_dict_from_version_0_0(model_dict: dict) -> dict:
71
- """
72
- Convert model dict version 0.0 to 0.1
65
+ """Convert model dict version 0.0 to 0.1
73
66
 
74
67
  Arguments:
68
+ ---------
75
69
  model_dict - the dictionary containing the serialized model
76
70
 
77
71
  Returns:
72
+ -------
78
73
  A dictionary containing the version 0.1 equivalent serialized model
79
- """
80
74
 
75
+ """
81
76
  new_model_dict = {}
82
77
 
83
78
  # Meta data and attackers did not change
@@ -121,24 +116,21 @@ def convert_model_dict_from_version_0_0(model_dict: dict) -> dict:
121
116
  # Add new assoc dict to new model dict
122
117
  new_model_dict['associations'] = new_assoc_list
123
118
 
124
- # Reconstruct the attackers
125
- if 'attackers' in model_dict:
126
- attackers_info = model_dict['attackers']
127
-
128
119
  return new_model_dict
129
120
 
130
121
 
131
122
  def convert_model_dict_from_version_0_1(model_dict: dict) -> dict:
132
- """
133
- Convert model dict version 0.1 to 0.2
123
+ """Convert model dict version 0.1 to 0.2
134
124
 
135
125
  Arguments:
126
+ ---------
136
127
  model_dict - the dictionary containing the serialized model
137
128
 
138
129
  Returns:
130
+ -------
139
131
  A dictionary containing the version 0.2 equivalent serialized model
140
- """
141
132
 
133
+ """
142
134
  new_model_dict = {}
143
135
 
144
136
  # Meta data and assets format did not change from version 0.1
@@ -176,14 +168,14 @@ def convert_model_dict_from_version_0_1(model_dict: dict) -> dict:
176
168
  new_attackers_dict: dict[int, dict] = {}
177
169
  attackers_dict: dict = model_dict.get('attackers', {})
178
170
  for attacker_id, attacker_dict in attackers_dict.items():
179
- attacker_id = int(attacker_id) # JSON compatibility
171
+ attacker_id = int(attacker_id) # JSON compatibility
180
172
  new_attackers_dict[attacker_id] = {}
181
173
  new_attackers_dict[attacker_id]['name'] = attacker_dict['name']
182
174
  new_entry_points_dict = {}
183
175
 
184
176
  entry_points_dict = attacker_dict['entry_points']
185
177
  for asset_id, attack_steps in entry_points_dict.items():
186
- asset_id = int(asset_id) # JSON compatibility
178
+ asset_id = int(asset_id) # JSON compatibility
187
179
  asset_name = new_assets_dict[asset_id]['name']
188
180
  new_entry_points_dict[asset_name] = {
189
181
  'asset_id': asset_id,
@@ -198,16 +190,17 @@ def convert_model_dict_from_version_0_1(model_dict: dict) -> dict:
198
190
 
199
191
 
200
192
  def convert_model_dict_from_version_0_2(model_dict: dict) -> dict:
201
- """
202
- Convert model dict version 0.2 to 0.3
193
+ """Convert model dict version 0.2 to 0.3
203
194
 
204
195
  Arguments:
196
+ ---------
205
197
  model_dict - the dictionary containing the serialized model
206
198
 
207
199
  Returns:
200
+ -------
208
201
  A dictionary containing the version 0.3 equivalent serialized model
209
- """
210
202
 
203
+ """
211
204
  new_model_dict = {}
212
205
 
213
206
  # Meta data and assets format did not change from version 0.1
@@ -1,11 +1,11 @@
1
+ from .draw_io_utils import create_drawio_file_with_images
1
2
  from .graphviz_utils import render_attack_graph, render_model
2
3
  from .neo4j_utils import ingest_attack_graph_neo4j, ingest_model_neo4j
3
- from .draw_io_utils import create_drawio_file_with_images
4
4
 
5
5
  __all__ = [
6
- 'render_attack_graph',
7
- 'render_model',
6
+ 'create_drawio_file_with_images',
8
7
  'ingest_attack_graph_neo4j',
9
8
  'ingest_model_neo4j',
10
- 'create_drawio_file_with_images'
9
+ 'render_attack_graph',
10
+ 'render_model',
11
11
  ]
@@ -1,7 +1,7 @@
1
1
  """DrawIO exporter made by Sandor"""
2
+ import math
2
3
  import xml.etree.ElementTree as ET
3
4
  from xml.dom import minidom
4
- import math
5
5
 
6
6
  from maltoolbox.model import Model
7
7
 
@@ -29,6 +29,7 @@ type2iconURL = {
29
29
  "HardwareVulnerability": "https://uxwing.com/wp-content/themes/uxwing/download/crime-security-military-law/shield-sedo-line-icon.png",
30
30
  }
31
31
 
32
+
32
33
  def create_drawio_file_with_images(
33
34
  model: Model,
34
35
  show_edge_labels=True,
@@ -36,17 +37,17 @@ def create_drawio_file_with_images(
36
37
  coordinate_scale=0.75,
37
38
  output_filename=None
38
39
  ):
39
- """
40
- Create a draw.io file with all model assets as boxes using their actual positions and images
40
+ """Create a draw.io file with all model assets as boxes using their actual positions and images
41
41
 
42
42
  Args:
43
+ ----
43
44
  model: The model containing assets and associations
44
45
  output_filename: Name of the output draw.io file
45
46
  show_edge_labels: If True, show association type as text on edges. If False, edges will have no labels.
46
47
  line_thickness: Thickness of the edges in pixels (default: 2)
47
48
  coordinate_scale: Scale factor for model coordinates (default: 1.0, use 0.5 for half size, 2.0 for double size)
48
- """
49
49
 
50
+ """
50
51
  if not all(a.extras.get('position') for a in model.assets.values()):
51
52
  # Give assets positions if not already set
52
53
  position_assets(model)
@@ -314,4 +315,4 @@ def create_drawio_file_with_images(
314
315
 
315
316
  print("\nAsset type distribution:")
316
317
  for asset_type, count in sorted(type_counts.items()):
317
- print(f" {asset_type}: {count}")
318
+ print(f" {asset_type}: {count}")