mal-toolbox 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mal_toolbox-0.2.0.dist-info → mal_toolbox-0.3.0.dist-info}/METADATA +43 -25
- mal_toolbox-0.3.0.dist-info/RECORD +29 -0
- mal_toolbox-0.3.0.dist-info/entry_points.txt +2 -0
- maltoolbox/__init__.py +38 -57
- maltoolbox/__main__.py +43 -14
- maltoolbox/attackgraph/__init__.py +1 -1
- maltoolbox/attackgraph/analyzers/apriori.py +6 -5
- maltoolbox/attackgraph/attacker.py +26 -13
- maltoolbox/attackgraph/attackgraph.py +175 -148
- maltoolbox/attackgraph/node.py +56 -54
- maltoolbox/attackgraph/query.py +4 -2
- maltoolbox/file_utils.py +0 -8
- maltoolbox/ingestors/neo4j.py +146 -157
- maltoolbox/language/__init__.py +7 -3
- maltoolbox/language/compiler/__init__.py +485 -17
- maltoolbox/language/compiler/mal_lexer.py +172 -152
- maltoolbox/language/compiler/mal_parser.py +1370 -663
- maltoolbox/language/languagegraph.py +103 -99
- maltoolbox/model.py +306 -488
- maltoolbox/translators/securicad.py +164 -163
- maltoolbox/translators/updater.py +231 -108
- mal_toolbox-0.2.0.dist-info/RECORD +0 -32
- maltoolbox/default.conf +0 -17
- maltoolbox/language/classes_factory.py +0 -259
- maltoolbox/language/compiler/mal_visitor.py +0 -416
- maltoolbox/wrappers.py +0 -62
- {mal_toolbox-0.2.0.dist-info → mal_toolbox-0.3.0.dist-info}/AUTHORS +0 -0
- {mal_toolbox-0.2.0.dist-info → mal_toolbox-0.3.0.dist-info}/LICENSE +0 -0
- {mal_toolbox-0.2.0.dist-info → mal_toolbox-0.3.0.dist-info}/WHEEL +0 -0
- {mal_toolbox-0.2.0.dist-info → mal_toolbox-0.3.0.dist-info}/top_level.txt +0 -0
maltoolbox/model.py
CHANGED
|
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
|
|
|
7
7
|
import json
|
|
8
8
|
import logging
|
|
9
9
|
from typing import TYPE_CHECKING
|
|
10
|
+
import math
|
|
10
11
|
|
|
11
12
|
from .file_utils import (
|
|
12
13
|
load_dict_from_json_file,
|
|
@@ -15,14 +16,14 @@ from .file_utils import (
|
|
|
15
16
|
)
|
|
16
17
|
|
|
17
18
|
from . import __version__
|
|
18
|
-
from .exceptions import
|
|
19
|
+
from .exceptions import ModelException
|
|
19
20
|
|
|
20
21
|
if TYPE_CHECKING:
|
|
21
|
-
from typing import Any, Optional
|
|
22
|
-
from .language import
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
from typing import Any, Optional
|
|
23
|
+
from .language import (
|
|
24
|
+
LanguageGraph,
|
|
25
|
+
LanguageGraphAsset,
|
|
26
|
+
)
|
|
26
27
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
@@ -31,14 +32,14 @@ class AttackerAttachment:
|
|
|
31
32
|
"""Used to attach attackers to attack step entry points of assets"""
|
|
32
33
|
id: Optional[int] = None
|
|
33
34
|
name: Optional[str] = None
|
|
34
|
-
entry_points: list[tuple[
|
|
35
|
+
entry_points: list[tuple[ModelAsset, list[str]]] = \
|
|
35
36
|
field(default_factory=lambda: [])
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
def get_entry_point_tuple(
|
|
39
40
|
self,
|
|
40
|
-
asset:
|
|
41
|
-
) -> Optional[tuple[
|
|
41
|
+
asset: ModelAsset
|
|
42
|
+
) -> Optional[tuple[ModelAsset, list[str]]]:
|
|
42
43
|
"""Return an entry point tuple of an AttackerAttachment matching the
|
|
43
44
|
asset provided.
|
|
44
45
|
|
|
@@ -57,7 +58,7 @@ class AttackerAttachment:
|
|
|
57
58
|
|
|
58
59
|
|
|
59
60
|
def add_entry_point(
|
|
60
|
-
self, asset:
|
|
61
|
+
self, asset: ModelAsset, attackstep_name: str):
|
|
61
62
|
"""Add an entry point to an AttackerAttachment
|
|
62
63
|
|
|
63
64
|
self.entry_points contain tuples, first element of each tuple
|
|
@@ -92,8 +93,9 @@ class AttackerAttachment:
|
|
|
92
93
|
# point
|
|
93
94
|
self.entry_points.append((asset, [attackstep_name]))
|
|
94
95
|
|
|
96
|
+
|
|
95
97
|
def remove_entry_point(
|
|
96
|
-
self, asset:
|
|
98
|
+
self, asset: ModelAsset, attackstep_name: str):
|
|
97
99
|
"""Remove an entry point from an AttackerAttachment if it exists
|
|
98
100
|
|
|
99
101
|
Arguments:
|
|
@@ -129,25 +131,25 @@ class AttackerAttachment:
|
|
|
129
131
|
|
|
130
132
|
|
|
131
133
|
class Model():
|
|
132
|
-
"""An implementation of a MAL language
|
|
134
|
+
"""An implementation of a MAL language model containing assets"""
|
|
133
135
|
next_id: int = 0
|
|
134
136
|
|
|
135
137
|
def __repr__(self) -> str:
|
|
136
|
-
return f'Model {self.name}'
|
|
138
|
+
return f'Model(name: "{self.name}", language: {self.lang_graph})'
|
|
139
|
+
|
|
137
140
|
|
|
138
141
|
def __init__(
|
|
139
142
|
self,
|
|
140
143
|
name: str,
|
|
141
|
-
|
|
144
|
+
lang_graph: LanguageGraph,
|
|
142
145
|
mt_version: str = __version__
|
|
143
146
|
):
|
|
144
147
|
|
|
145
148
|
self.name = name
|
|
146
|
-
self.assets:
|
|
147
|
-
self.
|
|
148
|
-
self._type_to_association:dict = {} # optimization
|
|
149
|
+
self.assets: dict[int, ModelAsset] = {}
|
|
150
|
+
self._name_to_asset:dict[str, ModelAsset] = {} # optimization
|
|
149
151
|
self.attackers: list[AttackerAttachment] = []
|
|
150
|
-
self.
|
|
152
|
+
self.lang_graph = lang_graph
|
|
151
153
|
self.maltoolbox_version: str = mt_version
|
|
152
154
|
|
|
153
155
|
# Below sets used to check for duplicate names or ids,
|
|
@@ -155,63 +157,84 @@ class Model():
|
|
|
155
157
|
self.asset_ids: set[int] = set()
|
|
156
158
|
self.asset_names: set[str] = set()
|
|
157
159
|
|
|
160
|
+
|
|
158
161
|
def add_asset(
|
|
159
162
|
self,
|
|
160
|
-
|
|
163
|
+
asset_type: str,
|
|
164
|
+
name: Optional[str] = None,
|
|
161
165
|
asset_id: Optional[int] = None,
|
|
166
|
+
defenses: Optional[dict[str, float]] = None,
|
|
167
|
+
extras: Optional[dict] = None,
|
|
162
168
|
allow_duplicate_names: bool = True
|
|
163
|
-
) ->
|
|
164
|
-
"""
|
|
169
|
+
) -> ModelAsset:
|
|
170
|
+
"""
|
|
171
|
+
Create an asset based on the provided parameters and add it to the
|
|
172
|
+
model.
|
|
165
173
|
|
|
166
174
|
Arguments:
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
175
|
+
asset_type - string containing the asset type name
|
|
176
|
+
name - string containing the asset name. If not
|
|
177
|
+
provided the concatenated asset type and id
|
|
178
|
+
will be used as a name.
|
|
179
|
+
asset_id - id to assign to this asset, usually from an
|
|
180
|
+
instance model file. If not provided the id
|
|
181
|
+
will be set to the next highest id
|
|
182
|
+
available.
|
|
183
|
+
defeses - dictionary of defense values
|
|
184
|
+
extras - dictionary of extras
|
|
170
185
|
allow_duplicate_name - allow duplicate names to be used. If allowed
|
|
171
186
|
and a duplicate is encountered the name will
|
|
172
187
|
be appended with the id.
|
|
173
188
|
|
|
174
189
|
Return:
|
|
175
|
-
|
|
190
|
+
The newly created asset.
|
|
176
191
|
"""
|
|
177
192
|
|
|
178
193
|
# Set asset ID and check for duplicates
|
|
179
|
-
|
|
180
|
-
if
|
|
194
|
+
asset_id = asset_id or self.next_id
|
|
195
|
+
if asset_id in self.asset_ids:
|
|
181
196
|
raise ValueError(f'Asset index {asset_id} already in use.')
|
|
182
|
-
self.asset_ids.add(
|
|
197
|
+
self.asset_ids.add(asset_id)
|
|
183
198
|
|
|
184
|
-
self.next_id = max(
|
|
199
|
+
self.next_id = max(asset_id + 1, self.next_id)
|
|
185
200
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
if not hasattr(asset, 'name'):
|
|
189
|
-
asset.name = asset.type + ':' + str(asset.id)
|
|
201
|
+
if not name:
|
|
202
|
+
name = asset_type + ':' + str(asset_id)
|
|
190
203
|
else:
|
|
191
|
-
if
|
|
204
|
+
if name in self.asset_names:
|
|
192
205
|
if allow_duplicate_names:
|
|
193
|
-
|
|
206
|
+
name = name + ':' + str(asset_id)
|
|
194
207
|
else:
|
|
195
208
|
raise ValueError(
|
|
196
|
-
f'Asset name {
|
|
209
|
+
f'Asset name {name} is a duplicate'
|
|
197
210
|
' and we do not allow duplicates.'
|
|
198
211
|
)
|
|
199
|
-
self.asset_names.add(
|
|
212
|
+
self.asset_names.add(name)
|
|
213
|
+
|
|
214
|
+
lg_asset = self.lang_graph.assets[asset_type]
|
|
200
215
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
216
|
+
asset = ModelAsset(
|
|
217
|
+
name = name,
|
|
218
|
+
asset_id = asset_id,
|
|
219
|
+
lg_asset = lg_asset,
|
|
220
|
+
defenses = defenses,
|
|
221
|
+
extras = extras)
|
|
204
222
|
|
|
205
223
|
logger.debug(
|
|
206
|
-
'Add "%s"(%d) to model "%s".',
|
|
224
|
+
'Add "%s"(%d) to model "%s".', name, asset_id, self.name
|
|
207
225
|
)
|
|
208
|
-
self.assets
|
|
226
|
+
self.assets[asset_id] = asset
|
|
227
|
+
self._name_to_asset[name] = asset
|
|
228
|
+
|
|
229
|
+
return asset
|
|
230
|
+
|
|
209
231
|
|
|
210
232
|
def remove_attacker(self, attacker: AttackerAttachment) -> None:
|
|
211
233
|
"""Remove attacker"""
|
|
212
234
|
self.attackers.remove(attacker)
|
|
213
235
|
|
|
214
|
-
|
|
236
|
+
|
|
237
|
+
def remove_asset(self, asset: ModelAsset) -> None:
|
|
215
238
|
"""Remove an asset from the model.
|
|
216
239
|
|
|
217
240
|
Arguments:
|
|
@@ -222,15 +245,15 @@ class Model():
|
|
|
222
245
|
'Remove "%s"(%d) from model "%s".',
|
|
223
246
|
asset.name, asset.id, self.name
|
|
224
247
|
)
|
|
225
|
-
if asset not in self.assets:
|
|
248
|
+
if asset.id not in self.assets:
|
|
226
249
|
raise LookupError(
|
|
227
250
|
f'Asset "{asset.name}"({asset.id}) is not part'
|
|
228
251
|
f' of model"{self.name}".'
|
|
229
252
|
)
|
|
230
253
|
|
|
231
|
-
# First remove all of the
|
|
232
|
-
for
|
|
233
|
-
|
|
254
|
+
# First remove all of the associated assets
|
|
255
|
+
for fieldname, assoc_assets in asset.associated_assets.items():
|
|
256
|
+
asset.remove_associated_assets(fieldname, assoc_assets)
|
|
234
257
|
|
|
235
258
|
# Also remove all of the entry points
|
|
236
259
|
for attacker in self.attackers:
|
|
@@ -238,186 +261,9 @@ class Model():
|
|
|
238
261
|
if entry_point_tuple:
|
|
239
262
|
attacker.entry_points.remove(entry_point_tuple)
|
|
240
263
|
|
|
241
|
-
self.assets.
|
|
242
|
-
|
|
243
|
-
def remove_asset_from_association(
|
|
244
|
-
self,
|
|
245
|
-
asset: SchemaGeneratedClass,
|
|
246
|
-
association: SchemaGeneratedClass
|
|
247
|
-
) -> None:
|
|
248
|
-
"""Remove an asset from an association and remove the association
|
|
249
|
-
if any of the two sides is now empty.
|
|
250
|
-
|
|
251
|
-
Arguments:
|
|
252
|
-
asset - the asset to remove from the given association
|
|
253
|
-
association - the association to remove the asset from
|
|
254
|
-
"""
|
|
255
|
-
|
|
256
|
-
logger.debug(
|
|
257
|
-
'Remove "%s"(%d) from association of type "%s".',
|
|
258
|
-
asset.name, asset.id, type(association)
|
|
259
|
-
)
|
|
260
|
-
|
|
261
|
-
if asset not in self.assets:
|
|
262
|
-
raise LookupError(
|
|
263
|
-
f'Asset "{asset.name}"({asset.id}) is not part of model '
|
|
264
|
-
f'"{self.name}".'
|
|
265
|
-
)
|
|
266
|
-
if association not in self.associations:
|
|
267
|
-
raise LookupError(
|
|
268
|
-
f'Association is not part of model "{self.name}".'
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
left_field_name, right_field_name = \
|
|
272
|
-
self.get_association_field_names(association)
|
|
273
|
-
left_field = getattr(association, left_field_name)
|
|
274
|
-
right_field = getattr(association, right_field_name)
|
|
275
|
-
found = False
|
|
276
|
-
for field in [left_field, right_field]:
|
|
277
|
-
if asset in field:
|
|
278
|
-
found = True
|
|
279
|
-
if len(field) == 1:
|
|
280
|
-
# There are no other assets on this side,
|
|
281
|
-
# so we should remove the entire association.
|
|
282
|
-
self.remove_association(association)
|
|
283
|
-
return
|
|
284
|
-
field.remove(asset)
|
|
285
|
-
|
|
286
|
-
if not found:
|
|
287
|
-
raise LookupError(f'Asset "{asset.name}"({asset.id}) is not '
|
|
288
|
-
'part of the association provided.')
|
|
289
|
-
|
|
290
|
-
def _validate_association(self, association: SchemaGeneratedClass) -> None:
|
|
291
|
-
"""Raise error if association is invalid or already part of the Model.
|
|
292
|
-
|
|
293
|
-
Raises:
|
|
294
|
-
DuplicateAssociationError - same association already exists
|
|
295
|
-
ModelAssociationException - association is not valid
|
|
296
|
-
"""
|
|
297
|
-
|
|
298
|
-
# Optimization: only look for duplicates in associations of same type
|
|
299
|
-
association_type = association.type
|
|
300
|
-
associations_same_type = self._type_to_association.get(
|
|
301
|
-
association_type, []
|
|
302
|
-
)
|
|
303
|
-
|
|
304
|
-
# Check if identical association already exists
|
|
305
|
-
if association in associations_same_type:
|
|
306
|
-
raise DuplicateModelAssociationError(
|
|
307
|
-
f"Identical association {association_type} already exists"
|
|
308
|
-
)
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
# Check for duplicate assets in each field
|
|
312
|
-
left_field_name, right_field_name = \
|
|
313
|
-
self.get_association_field_names(association)
|
|
314
|
-
|
|
315
|
-
for field_name in (left_field_name, right_field_name):
|
|
316
|
-
field_assets = getattr(association, field_name)
|
|
317
|
-
|
|
318
|
-
unique_field_asset_names = {a.name for a in field_assets}
|
|
319
|
-
if len(field_assets) > len(unique_field_asset_names):
|
|
320
|
-
raise ModelAssociationException(
|
|
321
|
-
"More than one asset share same name in field"
|
|
322
|
-
f"{association_type}.{field_name}"
|
|
323
|
-
)
|
|
324
|
-
|
|
325
|
-
# For each asset in left field, go through each assets in right field
|
|
326
|
-
# to find all unique connections. Raise error if a connection between
|
|
327
|
-
# two assets already exist in a previously added association.
|
|
328
|
-
for left_asset in getattr(association, left_field_name):
|
|
329
|
-
for right_asset in getattr(association, right_field_name):
|
|
330
|
-
|
|
331
|
-
if self.association_exists_between_assets(
|
|
332
|
-
association_type, left_asset, right_asset
|
|
333
|
-
):
|
|
334
|
-
# Assets already have the connection in another
|
|
335
|
-
# association with same type
|
|
336
|
-
raise DuplicateModelAssociationError(
|
|
337
|
-
f"Association type {association_type} already exists"
|
|
338
|
-
f" between {left_asset.name} and {right_asset.name}"
|
|
339
|
-
)
|
|
340
|
-
|
|
341
|
-
def add_association(self, association: SchemaGeneratedClass) -> None:
|
|
342
|
-
"""Add an association to the model.
|
|
343
|
-
|
|
344
|
-
An association will have 2 field names, each
|
|
345
|
-
potentially containing several assets.
|
|
346
|
-
|
|
347
|
-
Arguments:
|
|
348
|
-
association - the association to add to the model
|
|
349
|
-
|
|
350
|
-
Raises:
|
|
351
|
-
DuplicateAssociationError - same association already exists
|
|
352
|
-
ModelAssociationException - association is not valid
|
|
353
|
-
|
|
354
|
-
"""
|
|
264
|
+
del self.assets[asset.id]
|
|
265
|
+
del self._name_to_asset[asset.name]
|
|
355
266
|
|
|
356
|
-
# Check association is valid and not duplicate
|
|
357
|
-
self._validate_association(association)
|
|
358
|
-
|
|
359
|
-
# Optional field for extra association data
|
|
360
|
-
association.extras = {}
|
|
361
|
-
|
|
362
|
-
field_names = self.get_association_field_names(association)
|
|
363
|
-
|
|
364
|
-
# Add the association to all of the included assets
|
|
365
|
-
for field_name in field_names:
|
|
366
|
-
for asset in getattr(association, field_name):
|
|
367
|
-
asset_assocs = list(asset.associations)
|
|
368
|
-
asset_assocs.append(association)
|
|
369
|
-
asset.associations = asset_assocs
|
|
370
|
-
|
|
371
|
-
self.associations.append(association)
|
|
372
|
-
|
|
373
|
-
# Add association to type->association mapping
|
|
374
|
-
association_type = association.type
|
|
375
|
-
self._type_to_association.setdefault(
|
|
376
|
-
association_type, []
|
|
377
|
-
).append(association)
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
def remove_association(self, association: SchemaGeneratedClass) -> None:
|
|
381
|
-
"""Remove an association from the model.
|
|
382
|
-
|
|
383
|
-
Arguments:
|
|
384
|
-
association - the association to remove from the model
|
|
385
|
-
"""
|
|
386
|
-
|
|
387
|
-
if association not in self.associations:
|
|
388
|
-
raise LookupError(
|
|
389
|
-
f'Association is not part of model "{self.name}".'
|
|
390
|
-
)
|
|
391
|
-
|
|
392
|
-
left_field_name, right_field_name = \
|
|
393
|
-
self.get_association_field_names(association)
|
|
394
|
-
left_field = getattr(association, left_field_name)
|
|
395
|
-
right_field = getattr(association, right_field_name)
|
|
396
|
-
|
|
397
|
-
for asset in left_field:
|
|
398
|
-
assocs = list(asset.associations)
|
|
399
|
-
assocs.remove(association)
|
|
400
|
-
asset.associations = assocs
|
|
401
|
-
|
|
402
|
-
for asset in right_field:
|
|
403
|
-
# In fringe cases we may have reflexive associations where the
|
|
404
|
-
# association was already removed when processing the left field
|
|
405
|
-
# assets therefore we have to check if it is still in the list.
|
|
406
|
-
if association in asset.associations:
|
|
407
|
-
assocs = list(asset.associations)
|
|
408
|
-
assocs.remove(association)
|
|
409
|
-
asset.associations = assocs
|
|
410
|
-
|
|
411
|
-
self.associations.remove(association)
|
|
412
|
-
|
|
413
|
-
# Remove association from type->association mapping
|
|
414
|
-
association_type = association.type
|
|
415
|
-
self._type_to_association[association_type].remove(
|
|
416
|
-
association
|
|
417
|
-
)
|
|
418
|
-
# Remove type from type->association mapping if mapping empty
|
|
419
|
-
if len(self._type_to_association[association_type]) == 0:
|
|
420
|
-
del self._type_to_association[association_type]
|
|
421
267
|
|
|
422
268
|
def add_attacker(
|
|
423
269
|
self,
|
|
@@ -441,9 +287,10 @@ class Model():
|
|
|
441
287
|
attacker.name = 'Attacker:' + str(attacker.id)
|
|
442
288
|
self.attackers.append(attacker)
|
|
443
289
|
|
|
290
|
+
|
|
444
291
|
def get_asset_by_id(
|
|
445
292
|
self, asset_id: int
|
|
446
|
-
) -> Optional[
|
|
293
|
+
) -> Optional[ModelAsset]:
|
|
447
294
|
"""
|
|
448
295
|
Find an asset in the model based on its id.
|
|
449
296
|
|
|
@@ -457,14 +304,12 @@ class Model():
|
|
|
457
304
|
'Get asset with id %d from model "%s".',
|
|
458
305
|
asset_id, self.name
|
|
459
306
|
)
|
|
460
|
-
return
|
|
461
|
-
|
|
462
|
-
if asset.id == asset_id), None
|
|
463
|
-
)
|
|
307
|
+
return self.assets.get(asset_id, None)
|
|
308
|
+
|
|
464
309
|
|
|
465
310
|
def get_asset_by_name(
|
|
466
311
|
self, asset_name: str
|
|
467
|
-
) -> Optional[
|
|
312
|
+
) -> Optional[ModelAsset]:
|
|
468
313
|
"""
|
|
469
314
|
Find an asset in the model based on its name.
|
|
470
315
|
|
|
@@ -478,10 +323,8 @@ class Model():
|
|
|
478
323
|
'Get asset with name "%s" from model "%s".',
|
|
479
324
|
asset_name, self.name
|
|
480
325
|
)
|
|
481
|
-
return
|
|
482
|
-
|
|
483
|
-
if asset.name == asset_name), None
|
|
484
|
-
)
|
|
326
|
+
return self._name_to_asset.get(asset_name, None)
|
|
327
|
+
|
|
485
328
|
|
|
486
329
|
def get_attacker_by_id(
|
|
487
330
|
self, attacker_id: int
|
|
@@ -504,187 +347,6 @@ class Model():
|
|
|
504
347
|
if attacker.id == attacker_id), None
|
|
505
348
|
)
|
|
506
349
|
|
|
507
|
-
def association_exists_between_assets(
|
|
508
|
-
self,
|
|
509
|
-
association_type: str,
|
|
510
|
-
left_asset: SchemaGeneratedClass,
|
|
511
|
-
right_asset: SchemaGeneratedClass
|
|
512
|
-
):
|
|
513
|
-
"""Return True if the association already exists between the assets"""
|
|
514
|
-
logger.debug(
|
|
515
|
-
'Check to see if an association of type "%s" '
|
|
516
|
-
'already exists between "%s" and "%s".',
|
|
517
|
-
association_type, left_asset.name, right_asset.name
|
|
518
|
-
)
|
|
519
|
-
associations = self._type_to_association.get(association_type, [])
|
|
520
|
-
for association in associations:
|
|
521
|
-
left_field_name, right_field_name = \
|
|
522
|
-
self.get_association_field_names(association)
|
|
523
|
-
if (left_asset.id in [asset.id for asset in \
|
|
524
|
-
getattr(association, left_field_name)] and \
|
|
525
|
-
right_asset.id in [asset.id for asset in \
|
|
526
|
-
getattr(association, right_field_name)]):
|
|
527
|
-
logger.debug(
|
|
528
|
-
'An association of type "%s" '
|
|
529
|
-
'already exists between "%s" and "%s".',
|
|
530
|
-
association_type, left_asset.name, right_asset.name
|
|
531
|
-
)
|
|
532
|
-
return True
|
|
533
|
-
logger.debug(
|
|
534
|
-
'No association of type "%s" '
|
|
535
|
-
'exists between "%s" and "%s".',
|
|
536
|
-
association_type, left_asset.name, right_asset.name
|
|
537
|
-
)
|
|
538
|
-
return False
|
|
539
|
-
|
|
540
|
-
def get_asset_defenses(
|
|
541
|
-
self,
|
|
542
|
-
asset: SchemaGeneratedClass,
|
|
543
|
-
include_defaults: bool = False
|
|
544
|
-
):
|
|
545
|
-
"""
|
|
546
|
-
Get the two field names of the association as a list.
|
|
547
|
-
Arguments:
|
|
548
|
-
asset - the asset to fetch the defenses for
|
|
549
|
-
include_defaults - if not True the defenses that have default
|
|
550
|
-
values will not be included in the list
|
|
551
|
-
|
|
552
|
-
Return:
|
|
553
|
-
A dictionary containing the defenses of the asset
|
|
554
|
-
"""
|
|
555
|
-
|
|
556
|
-
defenses = {}
|
|
557
|
-
for key, value in asset._properties.items():
|
|
558
|
-
property_schema = (
|
|
559
|
-
self.lang_classes_factory.json_schema['definitions']
|
|
560
|
-
['LanguageAsset'] ['definitions']
|
|
561
|
-
['Asset_' + asset.type]['properties'][key]
|
|
562
|
-
)
|
|
563
|
-
|
|
564
|
-
if "maximum" not in property_schema:
|
|
565
|
-
# Check if property is a defense by looking up defense
|
|
566
|
-
# specific key. Skip if it is not a defense.
|
|
567
|
-
continue
|
|
568
|
-
|
|
569
|
-
logger.debug(
|
|
570
|
-
'Translating %s: %s defense to dictionary.',
|
|
571
|
-
key,
|
|
572
|
-
value
|
|
573
|
-
)
|
|
574
|
-
|
|
575
|
-
if not include_defaults and value == value.default():
|
|
576
|
-
# Skip the defense values if they are the default ones.
|
|
577
|
-
continue
|
|
578
|
-
|
|
579
|
-
defenses[key] = float(value)
|
|
580
|
-
|
|
581
|
-
return defenses
|
|
582
|
-
|
|
583
|
-
def get_association_field_names(
|
|
584
|
-
self,
|
|
585
|
-
association: SchemaGeneratedClass
|
|
586
|
-
):
|
|
587
|
-
"""
|
|
588
|
-
Get the two field names of the association as a list.
|
|
589
|
-
Arguments:
|
|
590
|
-
association - the association to fetch the field names for
|
|
591
|
-
|
|
592
|
-
Return:
|
|
593
|
-
A two item list containing the field names of the association.
|
|
594
|
-
"""
|
|
595
|
-
|
|
596
|
-
return list(association._properties.keys())[1:]
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
def get_associated_assets_by_field_name(
|
|
600
|
-
self,
|
|
601
|
-
asset: SchemaGeneratedClass,
|
|
602
|
-
field_name: str
|
|
603
|
-
) -> list[SchemaGeneratedClass]:
|
|
604
|
-
"""
|
|
605
|
-
Get a list of associated assets for an asset given a field name.
|
|
606
|
-
|
|
607
|
-
Arguments:
|
|
608
|
-
asset - the asset whose fields we are interested in
|
|
609
|
-
field_name - the field name we are looking for
|
|
610
|
-
|
|
611
|
-
Return:
|
|
612
|
-
A list of assets associated with the asset given that match the
|
|
613
|
-
field_name.
|
|
614
|
-
"""
|
|
615
|
-
logger.debug(
|
|
616
|
-
'Get associated assets for asset "%s"(%d) by field name %s.',
|
|
617
|
-
asset.name, asset.id, field_name
|
|
618
|
-
)
|
|
619
|
-
associated_assets = []
|
|
620
|
-
for association in asset.associations:
|
|
621
|
-
if hasattr(association, field_name):
|
|
622
|
-
associated_assets.extend(getattr(association, field_name))
|
|
623
|
-
|
|
624
|
-
return associated_assets
|
|
625
|
-
|
|
626
|
-
def asset_to_dict(self, asset: SchemaGeneratedClass) -> tuple[str, dict]:
|
|
627
|
-
"""Get dictionary representation of the asset.
|
|
628
|
-
|
|
629
|
-
Arguments:
|
|
630
|
-
asset - asset to get dictionary representation of
|
|
631
|
-
|
|
632
|
-
Return: tuple with name of asset and the asset as dict
|
|
633
|
-
"""
|
|
634
|
-
|
|
635
|
-
logger.debug(
|
|
636
|
-
'Translating "%s"(%d) to dictionary.',
|
|
637
|
-
asset.name,
|
|
638
|
-
asset.id
|
|
639
|
-
)
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
asset_dict: dict[str, Any] = {
|
|
643
|
-
'name': str(asset.name),
|
|
644
|
-
'type': str(asset.type)
|
|
645
|
-
}
|
|
646
|
-
|
|
647
|
-
defenses = self.get_asset_defenses(asset)
|
|
648
|
-
|
|
649
|
-
if defenses:
|
|
650
|
-
asset_dict['defenses'] = defenses
|
|
651
|
-
|
|
652
|
-
if asset.extras:
|
|
653
|
-
# Add optional metadata to dict
|
|
654
|
-
asset_dict['extras'] = asset.extras.as_dict()
|
|
655
|
-
|
|
656
|
-
return (asset.id, asset_dict)
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
def association_to_dict(self, association: SchemaGeneratedClass) -> dict:
|
|
660
|
-
"""Get dictionary representation of the association.
|
|
661
|
-
|
|
662
|
-
Arguments:
|
|
663
|
-
association - association to get dictionary representation of
|
|
664
|
-
|
|
665
|
-
Returns the association serialized to a dict
|
|
666
|
-
"""
|
|
667
|
-
|
|
668
|
-
left_field_name, right_field_name = \
|
|
669
|
-
self.get_association_field_names(association)
|
|
670
|
-
left_field = getattr(association, left_field_name)
|
|
671
|
-
right_field = getattr(association, right_field_name)
|
|
672
|
-
|
|
673
|
-
association_dict = {
|
|
674
|
-
str(association.type) :
|
|
675
|
-
{
|
|
676
|
-
str(left_field_name):
|
|
677
|
-
{int(asset.id): str(asset.name) for asset in left_field},
|
|
678
|
-
str(right_field_name):
|
|
679
|
-
{int(asset.id): str(asset.name) for asset in right_field}
|
|
680
|
-
}
|
|
681
|
-
}
|
|
682
|
-
|
|
683
|
-
if association.extras:
|
|
684
|
-
# Add optional metadata to dict
|
|
685
|
-
association_dict['extras'] = association.extras
|
|
686
|
-
|
|
687
|
-
return association_dict
|
|
688
350
|
|
|
689
351
|
def attacker_to_dict(
|
|
690
352
|
self, attacker: AttackerAttachment
|
|
@@ -697,43 +359,37 @@ class Model():
|
|
|
697
359
|
|
|
698
360
|
logger.debug('Translating %s to dictionary.', attacker.name)
|
|
699
361
|
attacker_dict: dict[str, Any] = {
|
|
700
|
-
'name':
|
|
362
|
+
'name': attacker.name,
|
|
701
363
|
'entry_points': {},
|
|
702
364
|
}
|
|
703
365
|
for (asset, attack_steps) in attacker.entry_points:
|
|
704
|
-
attacker_dict['entry_points'][
|
|
705
|
-
'asset_id':
|
|
366
|
+
attacker_dict['entry_points'][asset.name] = {
|
|
367
|
+
'asset_id': asset.id,
|
|
706
368
|
'attack_steps' : attack_steps
|
|
707
369
|
}
|
|
708
370
|
return (attacker.id, attacker_dict)
|
|
709
371
|
|
|
372
|
+
|
|
710
373
|
def _to_dict(self) -> dict:
|
|
711
374
|
"""Get dictionary representation of the model."""
|
|
712
375
|
logger.debug('Translating model to dict.')
|
|
713
376
|
contents: dict[str, Any] = {
|
|
714
377
|
'metadata': {},
|
|
715
378
|
'assets': {},
|
|
716
|
-
'associations': [],
|
|
717
379
|
'attackers' : {}
|
|
718
380
|
}
|
|
719
381
|
contents['metadata'] = {
|
|
720
382
|
'name': self.name,
|
|
721
|
-
'langVersion': self.
|
|
722
|
-
'langID': self.
|
|
383
|
+
'langVersion': self.lang_graph.metadata['version'],
|
|
384
|
+
'langID': self.lang_graph.metadata['id'],
|
|
723
385
|
'malVersion': '0.1.0-SNAPSHOT',
|
|
724
386
|
'MAL-Toolbox Version': __version__,
|
|
725
387
|
'info': 'Created by the mal-toolbox model python module.'
|
|
726
388
|
}
|
|
727
389
|
|
|
728
390
|
logger.debug('Translating assets to dictionary.')
|
|
729
|
-
for asset in self.assets:
|
|
730
|
-
(
|
|
731
|
-
contents['assets'][int(asset_id)] = asset_dict
|
|
732
|
-
|
|
733
|
-
logger.debug('Translating associations to dictionary.')
|
|
734
|
-
for association in self.associations:
|
|
735
|
-
assoc_dict = self.association_to_dict(association)
|
|
736
|
-
contents['associations'].append(assoc_dict)
|
|
391
|
+
for asset in self.assets.values():
|
|
392
|
+
contents['assets'].update(asset._to_dict())
|
|
737
393
|
|
|
738
394
|
logger.debug('Translating attackers to dictionary.')
|
|
739
395
|
for attacker in self.attackers:
|
|
@@ -741,22 +397,24 @@ class Model():
|
|
|
741
397
|
contents['attackers'][attacker_id] = attacker_dict
|
|
742
398
|
return contents
|
|
743
399
|
|
|
400
|
+
|
|
744
401
|
def save_to_file(self, filename: str) -> None:
|
|
745
402
|
"""Save to json/yml depending on extension"""
|
|
746
403
|
logger.debug('Save instance model to file "%s".', filename)
|
|
747
404
|
return save_dict_to_file(filename, self._to_dict())
|
|
748
405
|
|
|
406
|
+
|
|
749
407
|
@classmethod
|
|
750
408
|
def _from_dict(
|
|
751
409
|
cls,
|
|
752
410
|
serialized_object: dict,
|
|
753
|
-
|
|
411
|
+
lang_graph: LanguageGraph,
|
|
754
412
|
) -> Model:
|
|
755
413
|
"""Create a model from dict representation
|
|
756
414
|
|
|
757
415
|
Arguments:
|
|
758
416
|
serialized_object - Model in dict format
|
|
759
|
-
|
|
417
|
+
lang_graph -
|
|
760
418
|
"""
|
|
761
419
|
|
|
762
420
|
maltoolbox_version = serialized_object['metadata']['MAL Toolbox Version'] \
|
|
@@ -764,72 +422,46 @@ class Model():
|
|
|
764
422
|
else __version__
|
|
765
423
|
model = Model(
|
|
766
424
|
serialized_object['metadata']['name'],
|
|
767
|
-
|
|
425
|
+
lang_graph,
|
|
768
426
|
mt_version = maltoolbox_version)
|
|
769
427
|
|
|
770
428
|
# Reconstruct the assets
|
|
771
|
-
for asset_id,
|
|
429
|
+
for asset_id, asset_dict in serialized_object['assets'].items():
|
|
772
430
|
|
|
773
431
|
if logger.isEnabledFor(logging.DEBUG):
|
|
774
432
|
# Avoid running json.dumps when not in debug
|
|
775
433
|
logger.debug(
|
|
776
|
-
"Loading asset:\n%s", json.dumps(
|
|
434
|
+
"Loading asset:\n%s", json.dumps(asset_dict, indent=2)
|
|
777
435
|
)
|
|
778
436
|
|
|
779
437
|
# Allow defining an asset via type only.
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
if isinstance(
|
|
438
|
+
asset_dict = (
|
|
439
|
+
asset_dict
|
|
440
|
+
if isinstance(asset_dict, dict)
|
|
783
441
|
else {
|
|
784
|
-
'type':
|
|
785
|
-
'name': f"{
|
|
442
|
+
'type': asset_dict,
|
|
443
|
+
'name': f"{asset_dict}:{asset_id}"
|
|
786
444
|
}
|
|
787
445
|
)
|
|
788
446
|
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
model.add_asset(asset, asset_id = int(asset_id))
|
|
807
|
-
|
|
808
|
-
# Reconstruct the associations
|
|
809
|
-
for assoc_entry in serialized_object.get('associations', []):
|
|
810
|
-
[(assoc, assoc_fields)] = assoc_entry.items()
|
|
811
|
-
assoc_keys_iter = iter(assoc_fields)
|
|
812
|
-
field1 = next(assoc_keys_iter)
|
|
813
|
-
field2 = next(assoc_keys_iter)
|
|
814
|
-
assoc_type_class = model.lang_classes_factory.\
|
|
815
|
-
get_association_class_by_fieldnames(assoc, field1, field2)
|
|
816
|
-
if assoc_type_class is None:
|
|
817
|
-
raise LookupError('Failed to find association "%s" with '
|
|
818
|
-
'fields "%s" and "%s" in language classes factory' %
|
|
819
|
-
(assoc, field1, field2)
|
|
447
|
+
model.add_asset(
|
|
448
|
+
asset_type = asset_dict['type'],
|
|
449
|
+
name = asset_dict['name'],
|
|
450
|
+
defenses = {defense: float(value) for defense, value in \
|
|
451
|
+
asset_dict.get('defenses', {}).items()},
|
|
452
|
+
extras = asset_dict.get('extras', {}),
|
|
453
|
+
asset_id = int(asset_id))
|
|
454
|
+
|
|
455
|
+
# Reconstruct the association links
|
|
456
|
+
for asset_id, asset_dict in serialized_object['assets'].items():
|
|
457
|
+
asset = model.assets[int(asset_id)]
|
|
458
|
+
assoc_assets_dict = asset_dict['associated_assets'].items()
|
|
459
|
+
for fieldname, assoc_assets in assoc_assets_dict:
|
|
460
|
+
asset.add_associated_assets(
|
|
461
|
+
fieldname,
|
|
462
|
+
{model.assets[int(assoc_asset_id)]
|
|
463
|
+
for assoc_asset_id in assoc_assets}
|
|
820
464
|
)
|
|
821
|
-
association = assoc_type_class()
|
|
822
|
-
|
|
823
|
-
for field, targets in assoc_fields.items():
|
|
824
|
-
setattr(
|
|
825
|
-
association,
|
|
826
|
-
field,
|
|
827
|
-
[model.get_asset_by_id(int(id)) for id in targets]
|
|
828
|
-
)
|
|
829
|
-
|
|
830
|
-
#TODO Properly handle extras
|
|
831
|
-
|
|
832
|
-
model.add_association(association)
|
|
833
465
|
|
|
834
466
|
# Reconstruct the attackers
|
|
835
467
|
if 'attackers' in serialized_object:
|
|
@@ -839,21 +471,30 @@ class Model():
|
|
|
839
471
|
attacker.entry_points = []
|
|
840
472
|
for asset_name, entry_points_dict in \
|
|
841
473
|
attackers_info[attacker_id]['entry_points'].items():
|
|
474
|
+
target_asset = model.get_asset_by_id(
|
|
475
|
+
entry_points_dict['asset_id'])
|
|
476
|
+
if target_asset is None:
|
|
477
|
+
raise LookupError(
|
|
478
|
+
'Asset "%s"(%d) is not part of model "%s".' % (
|
|
479
|
+
asset_name,
|
|
480
|
+
entry_points_dict['asset_id'],
|
|
481
|
+
model.name)
|
|
482
|
+
)
|
|
842
483
|
attacker.entry_points.append(
|
|
843
484
|
(
|
|
844
|
-
|
|
845
|
-
entry_points_dict['asset_id']),
|
|
485
|
+
target_asset,
|
|
846
486
|
entry_points_dict['attack_steps']
|
|
847
487
|
)
|
|
848
488
|
)
|
|
849
489
|
model.add_attacker(attacker, attacker_id = int(attacker_id))
|
|
850
490
|
return model
|
|
851
491
|
|
|
492
|
+
|
|
852
493
|
@classmethod
|
|
853
494
|
def load_from_file(
|
|
854
495
|
cls,
|
|
855
496
|
filename: str,
|
|
856
|
-
|
|
497
|
+
lang_graph: LanguageGraph,
|
|
857
498
|
) -> Model:
|
|
858
499
|
"""Create from json or yaml file depending on file extension"""
|
|
859
500
|
logger.debug('Load instance model from file "%s".', filename)
|
|
@@ -864,4 +505,181 @@ class Model():
|
|
|
864
505
|
serialized_model = load_dict_from_json_file(filename)
|
|
865
506
|
else:
|
|
866
507
|
raise ValueError('Unknown file extension, expected json/yml/yaml')
|
|
867
|
-
|
|
508
|
+
try:
|
|
509
|
+
return cls._from_dict(serialized_model, lang_graph)
|
|
510
|
+
except Exception as e:
|
|
511
|
+
raise ModelException(
|
|
512
|
+
"Could not load model. It might be of an older version. "
|
|
513
|
+
"Try to upgrade it with 'maltoolbox upgrade-model'"
|
|
514
|
+
) from e
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
class ModelAsset:
|
|
518
|
+
def __init__(
|
|
519
|
+
self,
|
|
520
|
+
name: str,
|
|
521
|
+
asset_id: int,
|
|
522
|
+
lg_asset: LanguageGraphAsset,
|
|
523
|
+
defenses: Optional[dict[str, float]] = None,
|
|
524
|
+
extras: Optional[dict] = None
|
|
525
|
+
):
|
|
526
|
+
|
|
527
|
+
self.name: str = name
|
|
528
|
+
self._id: int = asset_id
|
|
529
|
+
self.lg_asset: LanguageGraphAsset = lg_asset
|
|
530
|
+
self.type = self.lg_asset.name
|
|
531
|
+
self.defenses: dict[str, float] = defenses or {}
|
|
532
|
+
self.extras: dict = extras or {}
|
|
533
|
+
self._associated_assets: dict[str, set[ModelAsset]] = {}
|
|
534
|
+
self.attack_step_nodes: list = []
|
|
535
|
+
|
|
536
|
+
for step in self.lg_asset.attack_steps.values():
|
|
537
|
+
if step.type == 'defense' and step.name not in self.defenses:
|
|
538
|
+
self.defenses[step.name] = 1.0 if step.ttc and \
|
|
539
|
+
step.ttc['name'] == 'Enabled' else 0.0
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def _to_dict(self):
|
|
543
|
+
"""Get dictionary representation of the asset."""
|
|
544
|
+
|
|
545
|
+
logger.debug(
|
|
546
|
+
'Translating "%s"(%d) to dictionary.', self.name, self.id)
|
|
547
|
+
|
|
548
|
+
asset_dict: dict[str, Any] = {
|
|
549
|
+
'name': self.name,
|
|
550
|
+
'type': self.type,
|
|
551
|
+
'defenses': {},
|
|
552
|
+
'associated_assets': {}
|
|
553
|
+
}
|
|
554
|
+
|
|
555
|
+
# Only add non-default values for defenses to improve legibility of
|
|
556
|
+
# the model format
|
|
557
|
+
for defense, defense_value in self.defenses.items():
|
|
558
|
+
lg_step = self.lg_asset.attack_steps[defense]
|
|
559
|
+
default_defval = 1.0 if lg_step.ttc and \
|
|
560
|
+
lg_step.ttc['name'] == 'Enabled' else 0.0
|
|
561
|
+
if defense_value != default_defval:
|
|
562
|
+
asset_dict['defenses'][defense] = defense_value
|
|
563
|
+
|
|
564
|
+
for fieldname, assets in self.associated_assets.items():
|
|
565
|
+
asset_dict['associated_assets'][fieldname] = {asset.id: asset.name
|
|
566
|
+
for asset in assets}
|
|
567
|
+
|
|
568
|
+
if len(asset_dict['defenses']) == 0:
|
|
569
|
+
# Do not include an empty defenses dictionary
|
|
570
|
+
del asset_dict['defenses']
|
|
571
|
+
|
|
572
|
+
if self.extras != {}:
|
|
573
|
+
# Add optional metadata to dict
|
|
574
|
+
asset_dict['extras'] = self.extras
|
|
575
|
+
|
|
576
|
+
return {self.id: asset_dict}
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
def __repr__(self):
|
|
580
|
+
return (f'ModelAsset(name: "{self.name}", id: {self.id}, '
|
|
581
|
+
f'type: {self.type})')
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def validate_associated_assets(
|
|
585
|
+
self, fieldname: str, assets_to_add: set[ModelAsset]
|
|
586
|
+
):
|
|
587
|
+
"""
|
|
588
|
+
Validate an association we want to add (through `fieldname`)
|
|
589
|
+
is valid with the assets given in param `assets_to_add`:
|
|
590
|
+
- fieldname is valid for the asset type of this ModelAsset
|
|
591
|
+
- type of `assets_to_add` is valid for the association
|
|
592
|
+
- no more assets than 'field.maximum' are added to the field
|
|
593
|
+
|
|
594
|
+
Raises:
|
|
595
|
+
LookupError - fieldname can not be found for this ModelAsset
|
|
596
|
+
ValueError - there will be too many assets in the field
|
|
597
|
+
if we add this association
|
|
598
|
+
TypeError - if the asset type of `assets_to_add` is not valid
|
|
599
|
+
"""
|
|
600
|
+
|
|
601
|
+
# Validate that the field name is allowed for this asset type
|
|
602
|
+
if fieldname not in self.lg_asset.associations:
|
|
603
|
+
accepted_fieldnames = list(self.lg_asset.associations.keys())
|
|
604
|
+
raise LookupError(
|
|
605
|
+
f"Fieldname '{fieldname}' is not an accepted association "
|
|
606
|
+
f"fieldname from asset type {self.lg_asset.name}. "
|
|
607
|
+
f"Did you mean one of {accepted_fieldnames}?"
|
|
608
|
+
)
|
|
609
|
+
|
|
610
|
+
lg_assoc = self.lg_asset.associations[fieldname]
|
|
611
|
+
assoc_field = lg_assoc.get_field(fieldname)
|
|
612
|
+
|
|
613
|
+
# Validate that the asset to add association to is of correct type
|
|
614
|
+
for asset_to_add in assets_to_add:
|
|
615
|
+
if not asset_to_add.lg_asset.is_subasset_of(assoc_field.asset):
|
|
616
|
+
raise TypeError(
|
|
617
|
+
f"Asset '{asset_to_add.name}' of type "
|
|
618
|
+
f"'{asset_to_add.type}' can not be added to association "
|
|
619
|
+
f"'{self.name}.{fieldname}'. Expected type of "
|
|
620
|
+
f"'{fieldname}' is {assoc_field.asset.name}."
|
|
621
|
+
)
|
|
622
|
+
|
|
623
|
+
# Validate that there will not be too many assets in field
|
|
624
|
+
assets_in_field_before = self.associated_assets.get(fieldname, set())
|
|
625
|
+
assets_in_field_after = assets_in_field_before | set(assets_to_add)
|
|
626
|
+
max_assets_in_field = assoc_field.maximum or math.inf
|
|
627
|
+
|
|
628
|
+
if len(assets_in_field_after) > max_assets_in_field:
|
|
629
|
+
raise ValueError(
|
|
630
|
+
f"You can have maximum {assoc_field.maximum} "
|
|
631
|
+
f"assets for association field {fieldname}"
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
def add_associated_assets(self, fieldname: str, assets: set[ModelAsset]):
|
|
635
|
+
"""
|
|
636
|
+
Add the assets provided as a parameter to the set of associated
|
|
637
|
+
assets dictionary entry corresponding to the given fieldname.
|
|
638
|
+
"""
|
|
639
|
+
|
|
640
|
+
lg_assoc = self.lg_asset.associations[fieldname]
|
|
641
|
+
other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)
|
|
642
|
+
|
|
643
|
+
# Validation from both sides
|
|
644
|
+
self.validate_associated_assets(fieldname, assets)
|
|
645
|
+
for asset in assets:
|
|
646
|
+
asset.validate_associated_assets(other_fieldname, {self})
|
|
647
|
+
|
|
648
|
+
# Add the associated assets to this asset's dictionary
|
|
649
|
+
self._associated_assets.setdefault(
|
|
650
|
+
fieldname, set()
|
|
651
|
+
).update(assets)
|
|
652
|
+
|
|
653
|
+
# Add this asset to the associated assets' corresponding dictionaries
|
|
654
|
+
for asset in assets:
|
|
655
|
+
asset._associated_assets.setdefault(
|
|
656
|
+
other_fieldname, set()
|
|
657
|
+
).add(self)
|
|
658
|
+
|
|
659
|
+
def remove_associated_assets(self, fieldname: str,
|
|
660
|
+
assets: set[ModelAsset]):
|
|
661
|
+
""" Remove the assets provided as a parameter from the set of
|
|
662
|
+
associated assets dictionary entry corresponding to the fieldname
|
|
663
|
+
parameter.
|
|
664
|
+
"""
|
|
665
|
+
self._associated_assets[fieldname] -= set(assets)
|
|
666
|
+
if len(self._associated_assets[fieldname]) == 0:
|
|
667
|
+
del self._associated_assets[fieldname]
|
|
668
|
+
|
|
669
|
+
# Also remove this asset to the associated assets' dictionaries
|
|
670
|
+
lg_assoc = self.lg_asset.associations[fieldname]
|
|
671
|
+
other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)
|
|
672
|
+
for asset in assets:
|
|
673
|
+
asset._associated_assets[other_fieldname].remove(self)
|
|
674
|
+
if len(asset._associated_assets[other_fieldname]) == 0:
|
|
675
|
+
del asset._associated_assets[other_fieldname]
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
@property
|
|
679
|
+
def associated_assets(self):
|
|
680
|
+
return self._associated_assets
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
@property
|
|
684
|
+
def id(self):
|
|
685
|
+
return self._id
|