mal-toolbox 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
maltoolbox/model.py CHANGED
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
7
7
  import json
8
8
  import logging
9
9
  from typing import TYPE_CHECKING
10
+ import math
10
11
 
11
12
  from .file_utils import (
12
13
  load_dict_from_json_file,
@@ -15,14 +16,14 @@ from .file_utils import (
15
16
  )
16
17
 
17
18
  from . import __version__
18
- from .exceptions import DuplicateModelAssociationError, ModelAssociationException
19
+ from .exceptions import ModelException
19
20
 
20
21
  if TYPE_CHECKING:
21
- from typing import Any, Optional, TypeAlias
22
- from .language import LanguageClassesFactory
23
- from python_jsonschema_objects.classbuilder import ProtocolBase
24
-
25
- SchemaGeneratedClass: TypeAlias = ProtocolBase
22
+ from typing import Any, Optional
23
+ from .language import (
24
+ LanguageGraph,
25
+ LanguageGraphAsset,
26
+ )
26
27
 
27
28
  logger = logging.getLogger(__name__)
28
29
 
@@ -31,14 +32,14 @@ class AttackerAttachment:
31
32
  """Used to attach attackers to attack step entry points of assets"""
32
33
  id: Optional[int] = None
33
34
  name: Optional[str] = None
34
- entry_points: list[tuple[SchemaGeneratedClass, list[str]]] = \
35
+ entry_points: list[tuple[ModelAsset, list[str]]] = \
35
36
  field(default_factory=lambda: [])
36
37
 
37
38
 
38
39
  def get_entry_point_tuple(
39
40
  self,
40
- asset: SchemaGeneratedClass
41
- ) -> Optional[tuple[SchemaGeneratedClass, list[str]]]:
41
+ asset: ModelAsset
42
+ ) -> Optional[tuple[ModelAsset, list[str]]]:
42
43
  """Return an entry point tuple of an AttackerAttachment matching the
43
44
  asset provided.
44
45
 
@@ -57,7 +58,7 @@ class AttackerAttachment:
57
58
 
58
59
 
59
60
  def add_entry_point(
60
- self, asset: SchemaGeneratedClass, attackstep_name: str):
61
+ self, asset: ModelAsset, attackstep_name: str):
61
62
  """Add an entry point to an AttackerAttachment
62
63
 
63
64
  self.entry_points contain tuples, first element of each tuple
@@ -92,8 +93,9 @@ class AttackerAttachment:
92
93
  # point
93
94
  self.entry_points.append((asset, [attackstep_name]))
94
95
 
96
+
95
97
  def remove_entry_point(
96
- self, asset: SchemaGeneratedClass, attackstep_name: str):
98
+ self, asset: ModelAsset, attackstep_name: str):
97
99
  """Remove an entry point from an AttackerAttachment if it exists
98
100
 
99
101
  Arguments:
@@ -129,25 +131,25 @@ class AttackerAttachment:
129
131
 
130
132
 
131
133
  class Model():
132
- """An implementation of a MAL language with assets and associations"""
134
+ """An implementation of a MAL language model containing assets"""
133
135
  next_id: int = 0
134
136
 
135
137
  def __repr__(self) -> str:
136
- return f'Model {self.name}'
138
+ return f'Model(name: "{self.name}", language: {self.lang_graph})'
139
+
137
140
 
138
141
  def __init__(
139
142
  self,
140
143
  name: str,
141
- lang_classes_factory: LanguageClassesFactory,
144
+ lang_graph: LanguageGraph,
142
145
  mt_version: str = __version__
143
146
  ):
144
147
 
145
148
  self.name = name
146
- self.assets: list[SchemaGeneratedClass] = []
147
- self.associations: list[SchemaGeneratedClass] = []
148
- self._type_to_association:dict = {} # optimization
149
+ self.assets: dict[int, ModelAsset] = {}
150
+ self._name_to_asset:dict[str, ModelAsset] = {} # optimization
149
151
  self.attackers: list[AttackerAttachment] = []
150
- self.lang_classes_factory: LanguageClassesFactory = lang_classes_factory
152
+ self.lang_graph = lang_graph
151
153
  self.maltoolbox_version: str = mt_version
152
154
 
153
155
  # Below sets used to check for duplicate names or ids,
@@ -155,63 +157,84 @@ class Model():
155
157
  self.asset_ids: set[int] = set()
156
158
  self.asset_names: set[str] = set()
157
159
 
160
+
158
161
  def add_asset(
159
162
  self,
160
- asset: SchemaGeneratedClass,
163
+ asset_type: str,
164
+ name: Optional[str] = None,
161
165
  asset_id: Optional[int] = None,
166
+ defenses: Optional[dict[str, float]] = None,
167
+ extras: Optional[dict] = None,
162
168
  allow_duplicate_names: bool = True
163
- ) -> None:
164
- """Add an asset to the model.
169
+ ) -> ModelAsset:
170
+ """
171
+ Create an asset based on the provided parameters and add it to the
172
+ model.
165
173
 
166
174
  Arguments:
167
- asset - the asset to add to the model
168
- asset_id - the id to assign to this asset, usually
169
- from an instance model file
175
+ asset_type - string containing the asset type name
176
+ name - string containing the asset name. If not
177
+ provided the concatenated asset type and id
178
+ will be used as a name.
179
+ asset_id - id to assign to this asset, usually from an
180
+ instance model file. If not provided the id
181
+ will be set to the next highest id
182
+ available.
183
+ defeses - dictionary of defense values
184
+ extras - dictionary of extras
170
185
  allow_duplicate_name - allow duplicate names to be used. If allowed
171
186
  and a duplicate is encountered the name will
172
187
  be appended with the id.
173
188
 
174
189
  Return:
175
- An asset matching the name if it exists in the model.
190
+ The newly created asset.
176
191
  """
177
192
 
178
193
  # Set asset ID and check for duplicates
179
- asset.id = asset_id or self.next_id
180
- if asset.id in self.asset_ids:
194
+ asset_id = asset_id or self.next_id
195
+ if asset_id in self.asset_ids:
181
196
  raise ValueError(f'Asset index {asset_id} already in use.')
182
- self.asset_ids.add(asset.id)
197
+ self.asset_ids.add(asset_id)
183
198
 
184
- self.next_id = max(asset.id + 1, self.next_id)
199
+ self.next_id = max(asset_id + 1, self.next_id)
185
200
 
186
- asset.associations = []
187
-
188
- if not hasattr(asset, 'name'):
189
- asset.name = asset.type + ':' + str(asset.id)
201
+ if not name:
202
+ name = asset_type + ':' + str(asset_id)
190
203
  else:
191
- if asset.name in self.asset_names:
204
+ if name in self.asset_names:
192
205
  if allow_duplicate_names:
193
- asset.name = asset.name + ':' + str(asset.id)
206
+ name = name + ':' + str(asset_id)
194
207
  else:
195
208
  raise ValueError(
196
- f'Asset name {asset.name} is a duplicate'
209
+ f'Asset name {name} is a duplicate'
197
210
  ' and we do not allow duplicates.'
198
211
  )
199
- self.asset_names.add(asset.name)
212
+ self.asset_names.add(name)
213
+
214
+ lg_asset = self.lang_graph.assets[asset_type]
200
215
 
201
- # Optional field for extra asset data
202
- if not hasattr(asset, 'extras'):
203
- asset.extras = {}
216
+ asset = ModelAsset(
217
+ name = name,
218
+ asset_id = asset_id,
219
+ lg_asset = lg_asset,
220
+ defenses = defenses,
221
+ extras = extras)
204
222
 
205
223
  logger.debug(
206
- 'Add "%s"(%d) to model "%s".', asset.name, asset.id, self.name
224
+ 'Add "%s"(%d) to model "%s".', name, asset_id, self.name
207
225
  )
208
- self.assets.append(asset)
226
+ self.assets[asset_id] = asset
227
+ self._name_to_asset[name] = asset
228
+
229
+ return asset
230
+
209
231
 
210
232
  def remove_attacker(self, attacker: AttackerAttachment) -> None:
211
233
  """Remove attacker"""
212
234
  self.attackers.remove(attacker)
213
235
 
214
- def remove_asset(self, asset: SchemaGeneratedClass) -> None:
236
+
237
+ def remove_asset(self, asset: ModelAsset) -> None:
215
238
  """Remove an asset from the model.
216
239
 
217
240
  Arguments:
@@ -222,15 +245,15 @@ class Model():
222
245
  'Remove "%s"(%d) from model "%s".',
223
246
  asset.name, asset.id, self.name
224
247
  )
225
- if asset not in self.assets:
248
+ if asset.id not in self.assets:
226
249
  raise LookupError(
227
250
  f'Asset "{asset.name}"({asset.id}) is not part'
228
251
  f' of model"{self.name}".'
229
252
  )
230
253
 
231
- # First remove all of the associations
232
- for association in asset.associations:
233
- self.remove_asset_from_association(asset, association)
254
+ # First remove all of the associated assets
255
+ for fieldname, assoc_assets in asset.associated_assets.items():
256
+ asset.remove_associated_assets(fieldname, assoc_assets)
234
257
 
235
258
  # Also remove all of the entry points
236
259
  for attacker in self.attackers:
@@ -238,186 +261,9 @@ class Model():
238
261
  if entry_point_tuple:
239
262
  attacker.entry_points.remove(entry_point_tuple)
240
263
 
241
- self.assets.remove(asset)
242
-
243
- def remove_asset_from_association(
244
- self,
245
- asset: SchemaGeneratedClass,
246
- association: SchemaGeneratedClass
247
- ) -> None:
248
- """Remove an asset from an association and remove the association
249
- if any of the two sides is now empty.
250
-
251
- Arguments:
252
- asset - the asset to remove from the given association
253
- association - the association to remove the asset from
254
- """
255
-
256
- logger.debug(
257
- 'Remove "%s"(%d) from association of type "%s".',
258
- asset.name, asset.id, type(association)
259
- )
260
-
261
- if asset not in self.assets:
262
- raise LookupError(
263
- f'Asset "{asset.name}"({asset.id}) is not part of model '
264
- f'"{self.name}".'
265
- )
266
- if association not in self.associations:
267
- raise LookupError(
268
- f'Association is not part of model "{self.name}".'
269
- )
270
-
271
- left_field_name, right_field_name = \
272
- self.get_association_field_names(association)
273
- left_field = getattr(association, left_field_name)
274
- right_field = getattr(association, right_field_name)
275
- found = False
276
- for field in [left_field, right_field]:
277
- if asset in field:
278
- found = True
279
- if len(field) == 1:
280
- # There are no other assets on this side,
281
- # so we should remove the entire association.
282
- self.remove_association(association)
283
- return
284
- field.remove(asset)
285
-
286
- if not found:
287
- raise LookupError(f'Asset "{asset.name}"({asset.id}) is not '
288
- 'part of the association provided.')
289
-
290
- def _validate_association(self, association: SchemaGeneratedClass) -> None:
291
- """Raise error if association is invalid or already part of the Model.
292
-
293
- Raises:
294
- DuplicateAssociationError - same association already exists
295
- ModelAssociationException - association is not valid
296
- """
297
-
298
- # Optimization: only look for duplicates in associations of same type
299
- association_type = association.type
300
- associations_same_type = self._type_to_association.get(
301
- association_type, []
302
- )
303
-
304
- # Check if identical association already exists
305
- if association in associations_same_type:
306
- raise DuplicateModelAssociationError(
307
- f"Identical association {association_type} already exists"
308
- )
309
-
310
-
311
- # Check for duplicate assets in each field
312
- left_field_name, right_field_name = \
313
- self.get_association_field_names(association)
314
-
315
- for field_name in (left_field_name, right_field_name):
316
- field_assets = getattr(association, field_name)
317
-
318
- unique_field_asset_names = {a.name for a in field_assets}
319
- if len(field_assets) > len(unique_field_asset_names):
320
- raise ModelAssociationException(
321
- "More than one asset share same name in field"
322
- f"{association_type}.{field_name}"
323
- )
324
-
325
- # For each asset in left field, go through each assets in right field
326
- # to find all unique connections. Raise error if a connection between
327
- # two assets already exist in a previously added association.
328
- for left_asset in getattr(association, left_field_name):
329
- for right_asset in getattr(association, right_field_name):
330
-
331
- if self.association_exists_between_assets(
332
- association_type, left_asset, right_asset
333
- ):
334
- # Assets already have the connection in another
335
- # association with same type
336
- raise DuplicateModelAssociationError(
337
- f"Association type {association_type} already exists"
338
- f" between {left_asset.name} and {right_asset.name}"
339
- )
340
-
341
- def add_association(self, association: SchemaGeneratedClass) -> None:
342
- """Add an association to the model.
343
-
344
- An association will have 2 field names, each
345
- potentially containing several assets.
346
-
347
- Arguments:
348
- association - the association to add to the model
349
-
350
- Raises:
351
- DuplicateAssociationError - same association already exists
352
- ModelAssociationException - association is not valid
353
-
354
- """
264
+ del self.assets[asset.id]
265
+ del self._name_to_asset[asset.name]
355
266
 
356
- # Check association is valid and not duplicate
357
- self._validate_association(association)
358
-
359
- # Optional field for extra association data
360
- association.extras = {}
361
-
362
- field_names = self.get_association_field_names(association)
363
-
364
- # Add the association to all of the included assets
365
- for field_name in field_names:
366
- for asset in getattr(association, field_name):
367
- asset_assocs = list(asset.associations)
368
- asset_assocs.append(association)
369
- asset.associations = asset_assocs
370
-
371
- self.associations.append(association)
372
-
373
- # Add association to type->association mapping
374
- association_type = association.type
375
- self._type_to_association.setdefault(
376
- association_type, []
377
- ).append(association)
378
-
379
-
380
- def remove_association(self, association: SchemaGeneratedClass) -> None:
381
- """Remove an association from the model.
382
-
383
- Arguments:
384
- association - the association to remove from the model
385
- """
386
-
387
- if association not in self.associations:
388
- raise LookupError(
389
- f'Association is not part of model "{self.name}".'
390
- )
391
-
392
- left_field_name, right_field_name = \
393
- self.get_association_field_names(association)
394
- left_field = getattr(association, left_field_name)
395
- right_field = getattr(association, right_field_name)
396
-
397
- for asset in left_field:
398
- assocs = list(asset.associations)
399
- assocs.remove(association)
400
- asset.associations = assocs
401
-
402
- for asset in right_field:
403
- # In fringe cases we may have reflexive associations where the
404
- # association was already removed when processing the left field
405
- # assets therefore we have to check if it is still in the list.
406
- if association in asset.associations:
407
- assocs = list(asset.associations)
408
- assocs.remove(association)
409
- asset.associations = assocs
410
-
411
- self.associations.remove(association)
412
-
413
- # Remove association from type->association mapping
414
- association_type = association.type
415
- self._type_to_association[association_type].remove(
416
- association
417
- )
418
- # Remove type from type->association mapping if mapping empty
419
- if len(self._type_to_association[association_type]) == 0:
420
- del self._type_to_association[association_type]
421
267
 
422
268
  def add_attacker(
423
269
  self,
@@ -441,9 +287,10 @@ class Model():
441
287
  attacker.name = 'Attacker:' + str(attacker.id)
442
288
  self.attackers.append(attacker)
443
289
 
290
+
444
291
  def get_asset_by_id(
445
292
  self, asset_id: int
446
- ) -> Optional[SchemaGeneratedClass]:
293
+ ) -> Optional[ModelAsset]:
447
294
  """
448
295
  Find an asset in the model based on its id.
449
296
 
@@ -457,14 +304,12 @@ class Model():
457
304
  'Get asset with id %d from model "%s".',
458
305
  asset_id, self.name
459
306
  )
460
- return next(
461
- (asset for asset in self.assets
462
- if asset.id == asset_id), None
463
- )
307
+ return self.assets.get(asset_id, None)
308
+
464
309
 
465
310
  def get_asset_by_name(
466
311
  self, asset_name: str
467
- ) -> Optional[SchemaGeneratedClass]:
312
+ ) -> Optional[ModelAsset]:
468
313
  """
469
314
  Find an asset in the model based on its name.
470
315
 
@@ -478,10 +323,8 @@ class Model():
478
323
  'Get asset with name "%s" from model "%s".',
479
324
  asset_name, self.name
480
325
  )
481
- return next(
482
- (asset for asset in self.assets
483
- if asset.name == asset_name), None
484
- )
326
+ return self._name_to_asset.get(asset_name, None)
327
+
485
328
 
486
329
  def get_attacker_by_id(
487
330
  self, attacker_id: int
@@ -504,187 +347,6 @@ class Model():
504
347
  if attacker.id == attacker_id), None
505
348
  )
506
349
 
507
- def association_exists_between_assets(
508
- self,
509
- association_type: str,
510
- left_asset: SchemaGeneratedClass,
511
- right_asset: SchemaGeneratedClass
512
- ):
513
- """Return True if the association already exists between the assets"""
514
- logger.debug(
515
- 'Check to see if an association of type "%s" '
516
- 'already exists between "%s" and "%s".',
517
- association_type, left_asset.name, right_asset.name
518
- )
519
- associations = self._type_to_association.get(association_type, [])
520
- for association in associations:
521
- left_field_name, right_field_name = \
522
- self.get_association_field_names(association)
523
- if (left_asset.id in [asset.id for asset in \
524
- getattr(association, left_field_name)] and \
525
- right_asset.id in [asset.id for asset in \
526
- getattr(association, right_field_name)]):
527
- logger.debug(
528
- 'An association of type "%s" '
529
- 'already exists between "%s" and "%s".',
530
- association_type, left_asset.name, right_asset.name
531
- )
532
- return True
533
- logger.debug(
534
- 'No association of type "%s" '
535
- 'exists between "%s" and "%s".',
536
- association_type, left_asset.name, right_asset.name
537
- )
538
- return False
539
-
540
- def get_asset_defenses(
541
- self,
542
- asset: SchemaGeneratedClass,
543
- include_defaults: bool = False
544
- ):
545
- """
546
- Get the two field names of the association as a list.
547
- Arguments:
548
- asset - the asset to fetch the defenses for
549
- include_defaults - if not True the defenses that have default
550
- values will not be included in the list
551
-
552
- Return:
553
- A dictionary containing the defenses of the asset
554
- """
555
-
556
- defenses = {}
557
- for key, value in asset._properties.items():
558
- property_schema = (
559
- self.lang_classes_factory.json_schema['definitions']
560
- ['LanguageAsset'] ['definitions']
561
- ['Asset_' + asset.type]['properties'][key]
562
- )
563
-
564
- if "maximum" not in property_schema:
565
- # Check if property is a defense by looking up defense
566
- # specific key. Skip if it is not a defense.
567
- continue
568
-
569
- logger.debug(
570
- 'Translating %s: %s defense to dictionary.',
571
- key,
572
- value
573
- )
574
-
575
- if not include_defaults and value == value.default():
576
- # Skip the defense values if they are the default ones.
577
- continue
578
-
579
- defenses[key] = float(value)
580
-
581
- return defenses
582
-
583
- def get_association_field_names(
584
- self,
585
- association: SchemaGeneratedClass
586
- ):
587
- """
588
- Get the two field names of the association as a list.
589
- Arguments:
590
- association - the association to fetch the field names for
591
-
592
- Return:
593
- A two item list containing the field names of the association.
594
- """
595
-
596
- return list(association._properties.keys())[1:]
597
-
598
-
599
- def get_associated_assets_by_field_name(
600
- self,
601
- asset: SchemaGeneratedClass,
602
- field_name: str
603
- ) -> list[SchemaGeneratedClass]:
604
- """
605
- Get a list of associated assets for an asset given a field name.
606
-
607
- Arguments:
608
- asset - the asset whose fields we are interested in
609
- field_name - the field name we are looking for
610
-
611
- Return:
612
- A list of assets associated with the asset given that match the
613
- field_name.
614
- """
615
- logger.debug(
616
- 'Get associated assets for asset "%s"(%d) by field name %s.',
617
- asset.name, asset.id, field_name
618
- )
619
- associated_assets = []
620
- for association in asset.associations:
621
- if hasattr(association, field_name):
622
- associated_assets.extend(getattr(association, field_name))
623
-
624
- return associated_assets
625
-
626
- def asset_to_dict(self, asset: SchemaGeneratedClass) -> tuple[str, dict]:
627
- """Get dictionary representation of the asset.
628
-
629
- Arguments:
630
- asset - asset to get dictionary representation of
631
-
632
- Return: tuple with name of asset and the asset as dict
633
- """
634
-
635
- logger.debug(
636
- 'Translating "%s"(%d) to dictionary.',
637
- asset.name,
638
- asset.id
639
- )
640
-
641
-
642
- asset_dict: dict[str, Any] = {
643
- 'name': str(asset.name),
644
- 'type': str(asset.type)
645
- }
646
-
647
- defenses = self.get_asset_defenses(asset)
648
-
649
- if defenses:
650
- asset_dict['defenses'] = defenses
651
-
652
- if asset.extras:
653
- # Add optional metadata to dict
654
- asset_dict['extras'] = asset.extras.as_dict()
655
-
656
- return (asset.id, asset_dict)
657
-
658
-
659
- def association_to_dict(self, association: SchemaGeneratedClass) -> dict:
660
- """Get dictionary representation of the association.
661
-
662
- Arguments:
663
- association - association to get dictionary representation of
664
-
665
- Returns the association serialized to a dict
666
- """
667
-
668
- left_field_name, right_field_name = \
669
- self.get_association_field_names(association)
670
- left_field = getattr(association, left_field_name)
671
- right_field = getattr(association, right_field_name)
672
-
673
- association_dict = {
674
- str(association.type) :
675
- {
676
- str(left_field_name):
677
- {int(asset.id): str(asset.name) for asset in left_field},
678
- str(right_field_name):
679
- {int(asset.id): str(asset.name) for asset in right_field}
680
- }
681
- }
682
-
683
- if association.extras:
684
- # Add optional metadata to dict
685
- association_dict['extras'] = association.extras
686
-
687
- return association_dict
688
350
 
689
351
  def attacker_to_dict(
690
352
  self, attacker: AttackerAttachment
@@ -697,43 +359,37 @@ class Model():
697
359
 
698
360
  logger.debug('Translating %s to dictionary.', attacker.name)
699
361
  attacker_dict: dict[str, Any] = {
700
- 'name': str(attacker.name),
362
+ 'name': attacker.name,
701
363
  'entry_points': {},
702
364
  }
703
365
  for (asset, attack_steps) in attacker.entry_points:
704
- attacker_dict['entry_points'][str(asset.name)] = {
705
- 'asset_id': int(asset.id),
366
+ attacker_dict['entry_points'][asset.name] = {
367
+ 'asset_id': asset.id,
706
368
  'attack_steps' : attack_steps
707
369
  }
708
370
  return (attacker.id, attacker_dict)
709
371
 
372
+
710
373
  def _to_dict(self) -> dict:
711
374
  """Get dictionary representation of the model."""
712
375
  logger.debug('Translating model to dict.')
713
376
  contents: dict[str, Any] = {
714
377
  'metadata': {},
715
378
  'assets': {},
716
- 'associations': [],
717
379
  'attackers' : {}
718
380
  }
719
381
  contents['metadata'] = {
720
382
  'name': self.name,
721
- 'langVersion': self.lang_classes_factory.lang_graph.metadata['version'],
722
- 'langID': self.lang_classes_factory.lang_graph.metadata['id'],
383
+ 'langVersion': self.lang_graph.metadata['version'],
384
+ 'langID': self.lang_graph.metadata['id'],
723
385
  'malVersion': '0.1.0-SNAPSHOT',
724
386
  'MAL-Toolbox Version': __version__,
725
387
  'info': 'Created by the mal-toolbox model python module.'
726
388
  }
727
389
 
728
390
  logger.debug('Translating assets to dictionary.')
729
- for asset in self.assets:
730
- (asset_id, asset_dict) = self.asset_to_dict(asset)
731
- contents['assets'][int(asset_id)] = asset_dict
732
-
733
- logger.debug('Translating associations to dictionary.')
734
- for association in self.associations:
735
- assoc_dict = self.association_to_dict(association)
736
- contents['associations'].append(assoc_dict)
391
+ for asset in self.assets.values():
392
+ contents['assets'].update(asset._to_dict())
737
393
 
738
394
  logger.debug('Translating attackers to dictionary.')
739
395
  for attacker in self.attackers:
@@ -741,22 +397,24 @@ class Model():
741
397
  contents['attackers'][attacker_id] = attacker_dict
742
398
  return contents
743
399
 
400
+
744
401
  def save_to_file(self, filename: str) -> None:
745
402
  """Save to json/yml depending on extension"""
746
403
  logger.debug('Save instance model to file "%s".', filename)
747
404
  return save_dict_to_file(filename, self._to_dict())
748
405
 
406
+
749
407
  @classmethod
750
408
  def _from_dict(
751
409
  cls,
752
410
  serialized_object: dict,
753
- lang_classes_factory: LanguageClassesFactory
411
+ lang_graph: LanguageGraph,
754
412
  ) -> Model:
755
413
  """Create a model from dict representation
756
414
 
757
415
  Arguments:
758
416
  serialized_object - Model in dict format
759
- lang_classes_factory -
417
+ lang_graph -
760
418
  """
761
419
 
762
420
  maltoolbox_version = serialized_object['metadata']['MAL Toolbox Version'] \
@@ -764,72 +422,46 @@ class Model():
764
422
  else __version__
765
423
  model = Model(
766
424
  serialized_object['metadata']['name'],
767
- lang_classes_factory,
425
+ lang_graph,
768
426
  mt_version = maltoolbox_version)
769
427
 
770
428
  # Reconstruct the assets
771
- for asset_id, asset_object in serialized_object['assets'].items():
429
+ for asset_id, asset_dict in serialized_object['assets'].items():
772
430
 
773
431
  if logger.isEnabledFor(logging.DEBUG):
774
432
  # Avoid running json.dumps when not in debug
775
433
  logger.debug(
776
- "Loading asset:\n%s", json.dumps(asset_object, indent=2)
434
+ "Loading asset:\n%s", json.dumps(asset_dict, indent=2)
777
435
  )
778
436
 
779
437
  # Allow defining an asset via type only.
780
- asset_object = (
781
- asset_object
782
- if isinstance(asset_object, dict)
438
+ asset_dict = (
439
+ asset_dict
440
+ if isinstance(asset_dict, dict)
783
441
  else {
784
- 'type': asset_object,
785
- 'name': f"{asset_object}:{asset_id}"
442
+ 'type': asset_dict,
443
+ 'name': f"{asset_dict}:{asset_id}"
786
444
  }
787
445
  )
788
446
 
789
- asset_type_class = model.lang_classes_factory.get_asset_class(
790
- asset_object['type'])
791
-
792
- # TODO: remove this when factory goes away
793
- asset_type_class.__hash__ = lambda self: hash(self.name) # type: ignore[method-assign,misc]
794
-
795
- if asset_type_class is None:
796
- raise LookupError('Failed to find asset "%s" in language'
797
- ' classes factory' % asset_object['type'])
798
- asset = asset_type_class(name = asset_object['name'])
799
-
800
- if 'extras' in asset_object:
801
- asset.extras = asset_object['extras']
802
-
803
- for defense in (defenses:=asset_object.get('defenses', [])):
804
- setattr(asset, defense, float(defenses[defense]))
805
-
806
- model.add_asset(asset, asset_id = int(asset_id))
807
-
808
- # Reconstruct the associations
809
- for assoc_entry in serialized_object.get('associations', []):
810
- [(assoc, assoc_fields)] = assoc_entry.items()
811
- assoc_keys_iter = iter(assoc_fields)
812
- field1 = next(assoc_keys_iter)
813
- field2 = next(assoc_keys_iter)
814
- assoc_type_class = model.lang_classes_factory.\
815
- get_association_class_by_fieldnames(assoc, field1, field2)
816
- if assoc_type_class is None:
817
- raise LookupError('Failed to find association "%s" with '
818
- 'fields "%s" and "%s" in language classes factory' %
819
- (assoc, field1, field2)
447
+ model.add_asset(
448
+ asset_type = asset_dict['type'],
449
+ name = asset_dict['name'],
450
+ defenses = {defense: float(value) for defense, value in \
451
+ asset_dict.get('defenses', {}).items()},
452
+ extras = asset_dict.get('extras', {}),
453
+ asset_id = int(asset_id))
454
+
455
+ # Reconstruct the association links
456
+ for asset_id, asset_dict in serialized_object['assets'].items():
457
+ asset = model.assets[int(asset_id)]
458
+ assoc_assets_dict = asset_dict['associated_assets'].items()
459
+ for fieldname, assoc_assets in assoc_assets_dict:
460
+ asset.add_associated_assets(
461
+ fieldname,
462
+ {model.assets[int(assoc_asset_id)]
463
+ for assoc_asset_id in assoc_assets}
820
464
  )
821
- association = assoc_type_class()
822
-
823
- for field, targets in assoc_fields.items():
824
- setattr(
825
- association,
826
- field,
827
- [model.get_asset_by_id(int(id)) for id in targets]
828
- )
829
-
830
- #TODO Properly handle extras
831
-
832
- model.add_association(association)
833
465
 
834
466
  # Reconstruct the attackers
835
467
  if 'attackers' in serialized_object:
@@ -839,21 +471,30 @@ class Model():
839
471
  attacker.entry_points = []
840
472
  for asset_name, entry_points_dict in \
841
473
  attackers_info[attacker_id]['entry_points'].items():
474
+ target_asset = model.get_asset_by_id(
475
+ entry_points_dict['asset_id'])
476
+ if target_asset is None:
477
+ raise LookupError(
478
+ 'Asset "%s"(%d) is not part of model "%s".' % (
479
+ asset_name,
480
+ entry_points_dict['asset_id'],
481
+ model.name)
482
+ )
842
483
  attacker.entry_points.append(
843
484
  (
844
- model.get_asset_by_id(
845
- entry_points_dict['asset_id']),
485
+ target_asset,
846
486
  entry_points_dict['attack_steps']
847
487
  )
848
488
  )
849
489
  model.add_attacker(attacker, attacker_id = int(attacker_id))
850
490
  return model
851
491
 
492
+
852
493
  @classmethod
853
494
  def load_from_file(
854
495
  cls,
855
496
  filename: str,
856
- lang_classes_factory: LanguageClassesFactory
497
+ lang_graph: LanguageGraph,
857
498
  ) -> Model:
858
499
  """Create from json or yaml file depending on file extension"""
859
500
  logger.debug('Load instance model from file "%s".', filename)
@@ -864,4 +505,181 @@ class Model():
864
505
  serialized_model = load_dict_from_json_file(filename)
865
506
  else:
866
507
  raise ValueError('Unknown file extension, expected json/yml/yaml')
867
- return cls._from_dict(serialized_model, lang_classes_factory)
508
+ try:
509
+ return cls._from_dict(serialized_model, lang_graph)
510
+ except Exception as e:
511
+ raise ModelException(
512
+ "Could not load model. It might be of an older version. "
513
+ "Try to upgrade it with 'maltoolbox upgrade-model'"
514
+ ) from e
515
+
516
+
517
+ class ModelAsset:
518
+ def __init__(
519
+ self,
520
+ name: str,
521
+ asset_id: int,
522
+ lg_asset: LanguageGraphAsset,
523
+ defenses: Optional[dict[str, float]] = None,
524
+ extras: Optional[dict] = None
525
+ ):
526
+
527
+ self.name: str = name
528
+ self._id: int = asset_id
529
+ self.lg_asset: LanguageGraphAsset = lg_asset
530
+ self.type = self.lg_asset.name
531
+ self.defenses: dict[str, float] = defenses or {}
532
+ self.extras: dict = extras or {}
533
+ self._associated_assets: dict[str, set[ModelAsset]] = {}
534
+ self.attack_step_nodes: list = []
535
+
536
+ for step in self.lg_asset.attack_steps.values():
537
+ if step.type == 'defense' and step.name not in self.defenses:
538
+ self.defenses[step.name] = 1.0 if step.ttc and \
539
+ step.ttc['name'] == 'Enabled' else 0.0
540
+
541
+
542
def _to_dict(self):
    """Serialize this asset as a single-entry mapping `{id: asset_dict}`.

    `asset_dict` holds, in this key order:
    - 'name' and 'type'
    - 'defenses': only defenses whose value differs from the language
      default; the key is left out entirely when there are none
    - 'associated_assets': fieldname -> {associated asset id: its name}
    - 'extras': only when `self.extras` is not an empty dict
    """
    logger.debug(
        'Translating "%s"(%d) to dictionary.', self.name, self.id)

    # Keep only non-default defense values to improve legibility of the
    # model format. The default is 1.0 for steps whose ttc is 'Enabled',
    # otherwise 0.0.
    changed_defenses = {}
    for defense_name, value in self.defenses.items():
        lg_step = self.lg_asset.attack_steps[defense_name]
        if lg_step.ttc and lg_step.ttc['name'] == 'Enabled':
            default_value = 1.0
        else:
            default_value = 0.0
        if value != default_value:
            changed_defenses[defense_name] = value

    serialized: dict[str, Any] = {'name': self.name, 'type': self.type}
    if changed_defenses:
        serialized['defenses'] = changed_defenses
    serialized['associated_assets'] = {
        fieldname: {peer.id: peer.name for peer in peers}
        for fieldname, peers in self.associated_assets.items()
    }

    if self.extras != {}:
        # Optional metadata is only written when present
        serialized['extras'] = self.extras

    return {self.id: serialized}
+
578
+
579
+ def __repr__(self):
580
+ return (f'ModelAsset(name: "{self.name}", id: {self.id}, '
581
+ f'type: {self.type})')
582
+
583
+
584
+ def validate_associated_assets(
585
+ self, fieldname: str, assets_to_add: set[ModelAsset]
586
+ ):
587
+ """
588
+ Validate an association we want to add (through `fieldname`)
589
+ is valid with the assets given in param `assets_to_add`:
590
+ - fieldname is valid for the asset type of this ModelAsset
591
+ - type of `assets_to_add` is valid for the association
592
+ - no more assets than 'field.maximum' are added to the field
593
+
594
+ Raises:
595
+ LookupError - fieldname can not be found for this ModelAsset
596
+ ValueError - there will be too many assets in the field
597
+ if we add this association
598
+ TypeError - if the asset type of `assets_to_add` is not valid
599
+ """
600
+
601
+ # Validate that the field name is allowed for this asset type
602
+ if fieldname not in self.lg_asset.associations:
603
+ accepted_fieldnames = list(self.lg_asset.associations.keys())
604
+ raise LookupError(
605
+ f"Fieldname '{fieldname}' is not an accepted association "
606
+ f"fieldname from asset type {self.lg_asset.name}. "
607
+ f"Did you mean one of {accepted_fieldnames}?"
608
+ )
609
+
610
+ lg_assoc = self.lg_asset.associations[fieldname]
611
+ assoc_field = lg_assoc.get_field(fieldname)
612
+
613
+ # Validate that the asset to add association to is of correct type
614
+ for asset_to_add in assets_to_add:
615
+ if not asset_to_add.lg_asset.is_subasset_of(assoc_field.asset):
616
+ raise TypeError(
617
+ f"Asset '{asset_to_add.name}' of type "
618
+ f"'{asset_to_add.type}' can not be added to association "
619
+ f"'{self.name}.{fieldname}'. Expected type of "
620
+ f"'{fieldname}' is {assoc_field.asset.name}."
621
+ )
622
+
623
+ # Validate that there will not be too many assets in field
624
+ assets_in_field_before = self.associated_assets.get(fieldname, set())
625
+ assets_in_field_after = assets_in_field_before | set(assets_to_add)
626
+ max_assets_in_field = assoc_field.maximum or math.inf
627
+
628
+ if len(assets_in_field_after) > max_assets_in_field:
629
+ raise ValueError(
630
+ f"You can have maximum {assoc_field.maximum} "
631
+ f"assets for association field {fieldname}"
632
+ )
633
+
634
+ def add_associated_assets(self, fieldname: str, assets: set[ModelAsset]):
635
+ """
636
+ Add the assets provided as a parameter to the set of associated
637
+ assets dictionary entry corresponding to the given fieldname.
638
+ """
639
+
640
+ lg_assoc = self.lg_asset.associations[fieldname]
641
+ other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)
642
+
643
+ # Validation from both sides
644
+ self.validate_associated_assets(fieldname, assets)
645
+ for asset in assets:
646
+ asset.validate_associated_assets(other_fieldname, {self})
647
+
648
+ # Add the associated assets to this asset's dictionary
649
+ self._associated_assets.setdefault(
650
+ fieldname, set()
651
+ ).update(assets)
652
+
653
+ # Add this asset to the associated assets' corresponding dictionaries
654
+ for asset in assets:
655
+ asset._associated_assets.setdefault(
656
+ other_fieldname, set()
657
+ ).add(self)
658
+
659
+ def remove_associated_assets(self, fieldname: str,
660
+ assets: set[ModelAsset]):
661
+ """ Remove the assets provided as a parameter from the set of
662
+ associated assets dictionary entry corresponding to the fieldname
663
+ parameter.
664
+ """
665
+ self._associated_assets[fieldname] -= set(assets)
666
+ if len(self._associated_assets[fieldname]) == 0:
667
+ del self._associated_assets[fieldname]
668
+
669
+ # Also remove this asset to the associated assets' dictionaries
670
+ lg_assoc = self.lg_asset.associations[fieldname]
671
+ other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)
672
+ for asset in assets:
673
+ asset._associated_assets[other_fieldname].remove(self)
674
+ if len(asset._associated_assets[other_fieldname]) == 0:
675
+ del asset._associated_assets[other_fieldname]
676
+
677
+
678
+ @property
679
+ def associated_assets(self):
680
+ return self._associated_assets
681
+
682
+
683
+ @property
684
+ def id(self):
685
+ return self._id