mal-toolbox 0.1.12__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mal_toolbox-0.1.12.dist-info → mal_toolbox-0.3.0.dist-info}/METADATA +43 -25
- mal_toolbox-0.3.0.dist-info/RECORD +29 -0
- mal_toolbox-0.3.0.dist-info/entry_points.txt +2 -0
- maltoolbox/__init__.py +38 -57
- maltoolbox/__main__.py +43 -14
- maltoolbox/attackgraph/__init__.py +1 -1
- maltoolbox/attackgraph/analyzers/apriori.py +10 -6
- maltoolbox/attackgraph/attacker.py +26 -13
- maltoolbox/attackgraph/attackgraph.py +431 -355
- maltoolbox/attackgraph/node.py +72 -54
- maltoolbox/attackgraph/query.py +4 -2
- maltoolbox/file_utils.py +4 -8
- maltoolbox/ingestors/neo4j.py +146 -157
- maltoolbox/language/__init__.py +10 -2
- maltoolbox/language/compiler/__init__.py +485 -17
- maltoolbox/language/compiler/mal_lexer.py +172 -152
- maltoolbox/language/compiler/mal_parser.py +1370 -663
- maltoolbox/language/languagegraph.py +1096 -545
- maltoolbox/model.py +312 -485
- maltoolbox/translators/securicad.py +164 -163
- maltoolbox/translators/updater.py +231 -108
- mal_toolbox-0.1.12.dist-info/RECORD +0 -32
- maltoolbox/default.conf +0 -17
- maltoolbox/language/classes_factory.py +0 -243
- maltoolbox/language/compiler/mal_visitor.py +0 -416
- maltoolbox/wrappers.py +0 -62
- {mal_toolbox-0.1.12.dist-info → mal_toolbox-0.3.0.dist-info}/AUTHORS +0 -0
- {mal_toolbox-0.1.12.dist-info → mal_toolbox-0.3.0.dist-info}/LICENSE +0 -0
- {mal_toolbox-0.1.12.dist-info → mal_toolbox-0.3.0.dist-info}/WHEEL +0 -0
- {mal_toolbox-0.1.12.dist-info → mal_toolbox-0.3.0.dist-info}/top_level.txt +0 -0
maltoolbox/model.py
CHANGED
|
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
|
|
|
7
7
|
import json
|
|
8
8
|
import logging
|
|
9
9
|
from typing import TYPE_CHECKING
|
|
10
|
+
import math
|
|
10
11
|
|
|
11
12
|
from .file_utils import (
|
|
12
13
|
load_dict_from_json_file,
|
|
@@ -15,14 +16,14 @@ from .file_utils import (
|
|
|
15
16
|
)
|
|
16
17
|
|
|
17
18
|
from . import __version__
|
|
18
|
-
from .exceptions import
|
|
19
|
+
from .exceptions import ModelException
|
|
19
20
|
|
|
20
21
|
if TYPE_CHECKING:
|
|
21
|
-
from typing import Any, Optional
|
|
22
|
-
from .language import
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
from typing import Any, Optional
|
|
23
|
+
from .language import (
|
|
24
|
+
LanguageGraph,
|
|
25
|
+
LanguageGraphAsset,
|
|
26
|
+
)
|
|
26
27
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
@@ -31,14 +32,14 @@ class AttackerAttachment:
|
|
|
31
32
|
"""Used to attach attackers to attack step entry points of assets"""
|
|
32
33
|
id: Optional[int] = None
|
|
33
34
|
name: Optional[str] = None
|
|
34
|
-
entry_points: list[tuple[
|
|
35
|
+
entry_points: list[tuple[ModelAsset, list[str]]] = \
|
|
35
36
|
field(default_factory=lambda: [])
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
def get_entry_point_tuple(
|
|
39
40
|
self,
|
|
40
|
-
asset:
|
|
41
|
-
) -> Optional[tuple[
|
|
41
|
+
asset: ModelAsset
|
|
42
|
+
) -> Optional[tuple[ModelAsset, list[str]]]:
|
|
42
43
|
"""Return an entry point tuple of an AttackerAttachment matching the
|
|
43
44
|
asset provided.
|
|
44
45
|
|
|
@@ -57,7 +58,7 @@ class AttackerAttachment:
|
|
|
57
58
|
|
|
58
59
|
|
|
59
60
|
def add_entry_point(
|
|
60
|
-
self, asset:
|
|
61
|
+
self, asset: ModelAsset, attackstep_name: str):
|
|
61
62
|
"""Add an entry point to an AttackerAttachment
|
|
62
63
|
|
|
63
64
|
self.entry_points contain tuples, first element of each tuple
|
|
@@ -92,8 +93,9 @@ class AttackerAttachment:
|
|
|
92
93
|
# point
|
|
93
94
|
self.entry_points.append((asset, [attackstep_name]))
|
|
94
95
|
|
|
96
|
+
|
|
95
97
|
def remove_entry_point(
|
|
96
|
-
self, asset:
|
|
98
|
+
self, asset: ModelAsset, attackstep_name: str):
|
|
97
99
|
"""Remove an entry point from an AttackerAttachment if it exists
|
|
98
100
|
|
|
99
101
|
Arguments:
|
|
@@ -129,25 +131,25 @@ class AttackerAttachment:
|
|
|
129
131
|
|
|
130
132
|
|
|
131
133
|
class Model():
|
|
132
|
-
"""An implementation of a MAL language
|
|
134
|
+
"""An implementation of a MAL language model containing assets"""
|
|
133
135
|
next_id: int = 0
|
|
134
136
|
|
|
135
137
|
def __repr__(self) -> str:
|
|
136
|
-
return f'Model {self.name}'
|
|
138
|
+
return f'Model(name: "{self.name}", language: {self.lang_graph})'
|
|
139
|
+
|
|
137
140
|
|
|
138
141
|
def __init__(
|
|
139
142
|
self,
|
|
140
143
|
name: str,
|
|
141
|
-
|
|
144
|
+
lang_graph: LanguageGraph,
|
|
142
145
|
mt_version: str = __version__
|
|
143
146
|
):
|
|
144
147
|
|
|
145
148
|
self.name = name
|
|
146
|
-
self.assets:
|
|
147
|
-
self.
|
|
148
|
-
self._type_to_association:dict = {} # optimization
|
|
149
|
+
self.assets: dict[int, ModelAsset] = {}
|
|
150
|
+
self._name_to_asset:dict[str, ModelAsset] = {} # optimization
|
|
149
151
|
self.attackers: list[AttackerAttachment] = []
|
|
150
|
-
self.
|
|
152
|
+
self.lang_graph = lang_graph
|
|
151
153
|
self.maltoolbox_version: str = mt_version
|
|
152
154
|
|
|
153
155
|
# Below sets used to check for duplicate names or ids,
|
|
@@ -155,63 +157,84 @@ class Model():
|
|
|
155
157
|
self.asset_ids: set[int] = set()
|
|
156
158
|
self.asset_names: set[str] = set()
|
|
157
159
|
|
|
160
|
+
|
|
158
161
|
def add_asset(
|
|
159
162
|
self,
|
|
160
|
-
|
|
163
|
+
asset_type: str,
|
|
164
|
+
name: Optional[str] = None,
|
|
161
165
|
asset_id: Optional[int] = None,
|
|
166
|
+
defenses: Optional[dict[str, float]] = None,
|
|
167
|
+
extras: Optional[dict] = None,
|
|
162
168
|
allow_duplicate_names: bool = True
|
|
163
|
-
) ->
|
|
164
|
-
"""
|
|
169
|
+
) -> ModelAsset:
|
|
170
|
+
"""
|
|
171
|
+
Create an asset based on the provided parameters and add it to the
|
|
172
|
+
model.
|
|
165
173
|
|
|
166
174
|
Arguments:
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
175
|
+
asset_type - string containing the asset type name
|
|
176
|
+
name - string containing the asset name. If not
|
|
177
|
+
provided the concatenated asset type and id
|
|
178
|
+
will be used as a name.
|
|
179
|
+
asset_id - id to assign to this asset, usually from an
|
|
180
|
+
instance model file. If not provided the id
|
|
181
|
+
will be set to the next highest id
|
|
182
|
+
available.
|
|
183
|
+
defeses - dictionary of defense values
|
|
184
|
+
extras - dictionary of extras
|
|
170
185
|
allow_duplicate_name - allow duplicate names to be used. If allowed
|
|
171
186
|
and a duplicate is encountered the name will
|
|
172
187
|
be appended with the id.
|
|
173
188
|
|
|
174
189
|
Return:
|
|
175
|
-
|
|
190
|
+
The newly created asset.
|
|
176
191
|
"""
|
|
177
192
|
|
|
178
193
|
# Set asset ID and check for duplicates
|
|
179
|
-
|
|
180
|
-
if
|
|
194
|
+
asset_id = asset_id or self.next_id
|
|
195
|
+
if asset_id in self.asset_ids:
|
|
181
196
|
raise ValueError(f'Asset index {asset_id} already in use.')
|
|
182
|
-
self.asset_ids.add(
|
|
197
|
+
self.asset_ids.add(asset_id)
|
|
183
198
|
|
|
184
|
-
self.next_id = max(
|
|
199
|
+
self.next_id = max(asset_id + 1, self.next_id)
|
|
185
200
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
if not hasattr(asset, 'name'):
|
|
189
|
-
asset.name = asset.type + ':' + str(asset.id)
|
|
201
|
+
if not name:
|
|
202
|
+
name = asset_type + ':' + str(asset_id)
|
|
190
203
|
else:
|
|
191
|
-
if
|
|
204
|
+
if name in self.asset_names:
|
|
192
205
|
if allow_duplicate_names:
|
|
193
|
-
|
|
206
|
+
name = name + ':' + str(asset_id)
|
|
194
207
|
else:
|
|
195
208
|
raise ValueError(
|
|
196
|
-
f'Asset name {
|
|
209
|
+
f'Asset name {name} is a duplicate'
|
|
197
210
|
' and we do not allow duplicates.'
|
|
198
211
|
)
|
|
199
|
-
self.asset_names.add(
|
|
212
|
+
self.asset_names.add(name)
|
|
213
|
+
|
|
214
|
+
lg_asset = self.lang_graph.assets[asset_type]
|
|
200
215
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
216
|
+
asset = ModelAsset(
|
|
217
|
+
name = name,
|
|
218
|
+
asset_id = asset_id,
|
|
219
|
+
lg_asset = lg_asset,
|
|
220
|
+
defenses = defenses,
|
|
221
|
+
extras = extras)
|
|
204
222
|
|
|
205
223
|
logger.debug(
|
|
206
|
-
'Add "%s"(%d) to model "%s".',
|
|
224
|
+
'Add "%s"(%d) to model "%s".', name, asset_id, self.name
|
|
207
225
|
)
|
|
208
|
-
self.assets
|
|
226
|
+
self.assets[asset_id] = asset
|
|
227
|
+
self._name_to_asset[name] = asset
|
|
228
|
+
|
|
229
|
+
return asset
|
|
230
|
+
|
|
209
231
|
|
|
210
232
|
def remove_attacker(self, attacker: AttackerAttachment) -> None:
|
|
211
233
|
"""Remove attacker"""
|
|
212
234
|
self.attackers.remove(attacker)
|
|
213
235
|
|
|
214
|
-
|
|
236
|
+
|
|
237
|
+
def remove_asset(self, asset: ModelAsset) -> None:
|
|
215
238
|
"""Remove an asset from the model.
|
|
216
239
|
|
|
217
240
|
Arguments:
|
|
@@ -222,15 +245,15 @@ class Model():
|
|
|
222
245
|
'Remove "%s"(%d) from model "%s".',
|
|
223
246
|
asset.name, asset.id, self.name
|
|
224
247
|
)
|
|
225
|
-
if asset not in self.assets:
|
|
248
|
+
if asset.id not in self.assets:
|
|
226
249
|
raise LookupError(
|
|
227
250
|
f'Asset "{asset.name}"({asset.id}) is not part'
|
|
228
251
|
f' of model"{self.name}".'
|
|
229
252
|
)
|
|
230
253
|
|
|
231
|
-
# First remove all of the
|
|
232
|
-
for
|
|
233
|
-
|
|
254
|
+
# First remove all of the associated assets
|
|
255
|
+
for fieldname, assoc_assets in asset.associated_assets.items():
|
|
256
|
+
asset.remove_associated_assets(fieldname, assoc_assets)
|
|
234
257
|
|
|
235
258
|
# Also remove all of the entry points
|
|
236
259
|
for attacker in self.attackers:
|
|
@@ -238,186 +261,9 @@ class Model():
|
|
|
238
261
|
if entry_point_tuple:
|
|
239
262
|
attacker.entry_points.remove(entry_point_tuple)
|
|
240
263
|
|
|
241
|
-
self.assets.
|
|
242
|
-
|
|
243
|
-
def remove_asset_from_association(
|
|
244
|
-
self,
|
|
245
|
-
asset: SchemaGeneratedClass,
|
|
246
|
-
association: SchemaGeneratedClass
|
|
247
|
-
) -> None:
|
|
248
|
-
"""Remove an asset from an association and remove the association
|
|
249
|
-
if any of the two sides is now empty.
|
|
250
|
-
|
|
251
|
-
Arguments:
|
|
252
|
-
asset - the asset to remove from the given association
|
|
253
|
-
association - the association to remove the asset from
|
|
254
|
-
"""
|
|
255
|
-
|
|
256
|
-
logger.debug(
|
|
257
|
-
'Remove "%s"(%d) from association of type "%s".',
|
|
258
|
-
asset.name, asset.id, type(association)
|
|
259
|
-
)
|
|
260
|
-
|
|
261
|
-
if asset not in self.assets:
|
|
262
|
-
raise LookupError(
|
|
263
|
-
f'Asset "{asset.name}"({asset.id}) is not part of model '
|
|
264
|
-
f'"{self.name}".'
|
|
265
|
-
)
|
|
266
|
-
if association not in self.associations:
|
|
267
|
-
raise LookupError(
|
|
268
|
-
f'Association is not part of model "{self.name}".'
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
left_field_name, right_field_name = \
|
|
272
|
-
self.get_association_field_names(association)
|
|
273
|
-
left_field = getattr(association, left_field_name)
|
|
274
|
-
right_field = getattr(association, right_field_name)
|
|
275
|
-
found = False
|
|
276
|
-
for field in [left_field, right_field]:
|
|
277
|
-
if asset in field:
|
|
278
|
-
found = True
|
|
279
|
-
if len(field) == 1:
|
|
280
|
-
# There are no other assets on this side,
|
|
281
|
-
# so we should remove the entire association.
|
|
282
|
-
self.remove_association(association)
|
|
283
|
-
return
|
|
284
|
-
field.remove(asset)
|
|
285
|
-
|
|
286
|
-
if not found:
|
|
287
|
-
raise LookupError(f'Asset "{asset.name}"({asset.id}) is not '
|
|
288
|
-
'part of the association provided.')
|
|
289
|
-
|
|
290
|
-
def _validate_association(self, association: SchemaGeneratedClass) -> None:
|
|
291
|
-
"""Raise error if association is invalid or already part of the Model.
|
|
292
|
-
|
|
293
|
-
Raises:
|
|
294
|
-
DuplicateAssociationError - same association already exists
|
|
295
|
-
ModelAssociationException - association is not valid
|
|
296
|
-
"""
|
|
297
|
-
|
|
298
|
-
# Optimization: only look for duplicates in associations of same type
|
|
299
|
-
association_type = association.__class__.__name__
|
|
300
|
-
associations_same_type = self._type_to_association.get(
|
|
301
|
-
association_type, []
|
|
302
|
-
)
|
|
303
|
-
|
|
304
|
-
# Check if identical association already exists
|
|
305
|
-
if association in associations_same_type:
|
|
306
|
-
raise DuplicateModelAssociationError(
|
|
307
|
-
f"Identical association {association_type} already exists"
|
|
308
|
-
)
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
# Check for duplicate assets in each field
|
|
312
|
-
left_field_name, right_field_name = \
|
|
313
|
-
self.get_association_field_names(association)
|
|
314
|
-
|
|
315
|
-
for field_name in (left_field_name, right_field_name):
|
|
316
|
-
field_assets = getattr(association, field_name)
|
|
317
|
-
|
|
318
|
-
unique_field_asset_names = {a.name for a in field_assets}
|
|
319
|
-
if len(field_assets) > len(unique_field_asset_names):
|
|
320
|
-
raise ModelAssociationException(
|
|
321
|
-
"More than one asset share same name in field"
|
|
322
|
-
f"{association_type}.{field_name}"
|
|
323
|
-
)
|
|
324
|
-
|
|
325
|
-
# For each asset in left field, go through each assets in right field
|
|
326
|
-
# to find all unique connections. Raise error if a connection between
|
|
327
|
-
# two assets already exist in a previously added association.
|
|
328
|
-
for left_asset in getattr(association, left_field_name):
|
|
329
|
-
for right_asset in getattr(association, right_field_name):
|
|
330
|
-
|
|
331
|
-
if self.association_exists_between_assets(
|
|
332
|
-
association_type, left_asset, right_asset
|
|
333
|
-
):
|
|
334
|
-
# Assets already have the connection in another
|
|
335
|
-
# association with same type
|
|
336
|
-
raise DuplicateModelAssociationError(
|
|
337
|
-
f"Association type {association_type} already exists"
|
|
338
|
-
f" between {left_asset.name} and {right_asset.name}"
|
|
339
|
-
)
|
|
340
|
-
|
|
341
|
-
def add_association(self, association: SchemaGeneratedClass) -> None:
|
|
342
|
-
"""Add an association to the model.
|
|
343
|
-
|
|
344
|
-
An association will have 2 field names, each
|
|
345
|
-
potentially containing several assets.
|
|
346
|
-
|
|
347
|
-
Arguments:
|
|
348
|
-
association - the association to add to the model
|
|
349
|
-
|
|
350
|
-
Raises:
|
|
351
|
-
DuplicateAssociationError - same association already exists
|
|
352
|
-
ModelAssociationException - association is not valid
|
|
353
|
-
|
|
354
|
-
"""
|
|
355
|
-
|
|
356
|
-
# Check association is valid and not duplicate
|
|
357
|
-
self._validate_association(association)
|
|
264
|
+
del self.assets[asset.id]
|
|
265
|
+
del self._name_to_asset[asset.name]
|
|
358
266
|
|
|
359
|
-
# Optional field for extra association data
|
|
360
|
-
association.extras = {}
|
|
361
|
-
|
|
362
|
-
field_names = self.get_association_field_names(association)
|
|
363
|
-
|
|
364
|
-
# Add the association to all of the included assets
|
|
365
|
-
for field_name in field_names:
|
|
366
|
-
for asset in getattr(association, field_name):
|
|
367
|
-
asset_assocs = list(asset.associations)
|
|
368
|
-
asset_assocs.append(association)
|
|
369
|
-
asset.associations = asset_assocs
|
|
370
|
-
|
|
371
|
-
self.associations.append(association)
|
|
372
|
-
|
|
373
|
-
# Add association to type->association mapping
|
|
374
|
-
association_type = association.__class__.__name__
|
|
375
|
-
self._type_to_association.setdefault(
|
|
376
|
-
association_type, []
|
|
377
|
-
).append(association)
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
def remove_association(self, association: SchemaGeneratedClass) -> None:
|
|
381
|
-
"""Remove an association from the model.
|
|
382
|
-
|
|
383
|
-
Arguments:
|
|
384
|
-
association - the association to remove from the model
|
|
385
|
-
"""
|
|
386
|
-
|
|
387
|
-
if association not in self.associations:
|
|
388
|
-
raise LookupError(
|
|
389
|
-
f'Association is not part of model "{self.name}".'
|
|
390
|
-
)
|
|
391
|
-
|
|
392
|
-
left_field_name, right_field_name = \
|
|
393
|
-
self.get_association_field_names(association)
|
|
394
|
-
left_field = getattr(association, left_field_name)
|
|
395
|
-
right_field = getattr(association, right_field_name)
|
|
396
|
-
|
|
397
|
-
for asset in left_field:
|
|
398
|
-
assocs = list(asset.associations)
|
|
399
|
-
assocs.remove(association)
|
|
400
|
-
asset.associations = assocs
|
|
401
|
-
|
|
402
|
-
for asset in right_field:
|
|
403
|
-
# In fringe cases we may have reflexive associations where the
|
|
404
|
-
# association was already removed when processing the left field
|
|
405
|
-
# assets therefore we have to check if it is still in the list.
|
|
406
|
-
if association in asset.associations:
|
|
407
|
-
assocs = list(asset.associations)
|
|
408
|
-
assocs.remove(association)
|
|
409
|
-
asset.associations = assocs
|
|
410
|
-
|
|
411
|
-
self.associations.remove(association)
|
|
412
|
-
|
|
413
|
-
# Remove association from type->association mapping
|
|
414
|
-
association_type = association.__class__.__name__
|
|
415
|
-
self._type_to_association[association_type].remove(
|
|
416
|
-
association
|
|
417
|
-
)
|
|
418
|
-
# Remove type from type->association mapping if mapping empty
|
|
419
|
-
if len(self._type_to_association[association_type]) == 0:
|
|
420
|
-
del self._type_to_association[association_type]
|
|
421
267
|
|
|
422
268
|
def add_attacker(
|
|
423
269
|
self,
|
|
@@ -441,9 +287,10 @@ class Model():
|
|
|
441
287
|
attacker.name = 'Attacker:' + str(attacker.id)
|
|
442
288
|
self.attackers.append(attacker)
|
|
443
289
|
|
|
290
|
+
|
|
444
291
|
def get_asset_by_id(
|
|
445
292
|
self, asset_id: int
|
|
446
|
-
) -> Optional[
|
|
293
|
+
) -> Optional[ModelAsset]:
|
|
447
294
|
"""
|
|
448
295
|
Find an asset in the model based on its id.
|
|
449
296
|
|
|
@@ -457,14 +304,12 @@ class Model():
|
|
|
457
304
|
'Get asset with id %d from model "%s".',
|
|
458
305
|
asset_id, self.name
|
|
459
306
|
)
|
|
460
|
-
return
|
|
461
|
-
|
|
462
|
-
if asset.id == asset_id), None
|
|
463
|
-
)
|
|
307
|
+
return self.assets.get(asset_id, None)
|
|
308
|
+
|
|
464
309
|
|
|
465
310
|
def get_asset_by_name(
|
|
466
311
|
self, asset_name: str
|
|
467
|
-
) -> Optional[
|
|
312
|
+
) -> Optional[ModelAsset]:
|
|
468
313
|
"""
|
|
469
314
|
Find an asset in the model based on its name.
|
|
470
315
|
|
|
@@ -478,10 +323,8 @@ class Model():
|
|
|
478
323
|
'Get asset with name "%s" from model "%s".',
|
|
479
324
|
asset_name, self.name
|
|
480
325
|
)
|
|
481
|
-
return
|
|
482
|
-
|
|
483
|
-
if asset.name == asset_name), None
|
|
484
|
-
)
|
|
326
|
+
return self._name_to_asset.get(asset_name, None)
|
|
327
|
+
|
|
485
328
|
|
|
486
329
|
def get_attacker_by_id(
|
|
487
330
|
self, attacker_id: int
|
|
@@ -504,199 +347,6 @@ class Model():
|
|
|
504
347
|
if attacker.id == attacker_id), None
|
|
505
348
|
)
|
|
506
349
|
|
|
507
|
-
def association_exists_between_assets(
|
|
508
|
-
self,
|
|
509
|
-
association_type: str,
|
|
510
|
-
left_asset: SchemaGeneratedClass,
|
|
511
|
-
right_asset: SchemaGeneratedClass
|
|
512
|
-
):
|
|
513
|
-
"""Return True if the association already exists between the assets"""
|
|
514
|
-
logger.debug(
|
|
515
|
-
'Check to see if an association of type "%s" '
|
|
516
|
-
'already exists between "%s" and "%s".',
|
|
517
|
-
association_type, left_asset.name, right_asset.name
|
|
518
|
-
)
|
|
519
|
-
associations = self._type_to_association.get(association_type, [])
|
|
520
|
-
for association in associations:
|
|
521
|
-
left_field_name, right_field_name = \
|
|
522
|
-
self.get_association_field_names(association)
|
|
523
|
-
if (left_asset.id in [asset.id for asset in \
|
|
524
|
-
getattr(association, left_field_name)] and \
|
|
525
|
-
right_asset.id in [asset.id for asset in \
|
|
526
|
-
getattr(association, right_field_name)]):
|
|
527
|
-
logger.debug(
|
|
528
|
-
'An association of type "%s" '
|
|
529
|
-
'already exists between "%s" and "%s".',
|
|
530
|
-
association_type, left_asset.name, right_asset.name
|
|
531
|
-
)
|
|
532
|
-
return True
|
|
533
|
-
logger.debug(
|
|
534
|
-
'No association of type "%s" '
|
|
535
|
-
'exists between "%s" and "%s".',
|
|
536
|
-
association_type, left_asset.name, right_asset.name
|
|
537
|
-
)
|
|
538
|
-
return False
|
|
539
|
-
|
|
540
|
-
def get_asset_defenses(
|
|
541
|
-
self,
|
|
542
|
-
asset: SchemaGeneratedClass,
|
|
543
|
-
include_defaults: bool = False
|
|
544
|
-
):
|
|
545
|
-
"""
|
|
546
|
-
Get the two field names of the association as a list.
|
|
547
|
-
Arguments:
|
|
548
|
-
asset - the asset to fetch the defenses for
|
|
549
|
-
include_defaults - if not True the defenses that have default
|
|
550
|
-
values will not be included in the list
|
|
551
|
-
|
|
552
|
-
Return:
|
|
553
|
-
A dictionary containing the defenses of the asset
|
|
554
|
-
"""
|
|
555
|
-
|
|
556
|
-
defenses = {}
|
|
557
|
-
for key, value in asset._properties.items():
|
|
558
|
-
property_schema = (
|
|
559
|
-
self.lang_classes_factory.json_schema['definitions']['LanguageAsset']
|
|
560
|
-
['definitions'][asset.type]['properties'][key]
|
|
561
|
-
)
|
|
562
|
-
|
|
563
|
-
if "maximum" not in property_schema:
|
|
564
|
-
# Check if property is a defense by looking up defense
|
|
565
|
-
# specific key. Skip if it is not a defense.
|
|
566
|
-
continue
|
|
567
|
-
|
|
568
|
-
logger.debug(
|
|
569
|
-
'Translating %s: %s defense to dictionary.',
|
|
570
|
-
key,
|
|
571
|
-
value
|
|
572
|
-
)
|
|
573
|
-
|
|
574
|
-
if not include_defaults and value == value.default():
|
|
575
|
-
# Skip the defense values if they are the default ones.
|
|
576
|
-
continue
|
|
577
|
-
|
|
578
|
-
defenses[key] = float(value)
|
|
579
|
-
|
|
580
|
-
return defenses
|
|
581
|
-
|
|
582
|
-
def get_association_field_names(
|
|
583
|
-
self,
|
|
584
|
-
association: SchemaGeneratedClass
|
|
585
|
-
):
|
|
586
|
-
"""
|
|
587
|
-
Get the two field names of the association as a list.
|
|
588
|
-
Arguments:
|
|
589
|
-
association - the association to fetch the field names for
|
|
590
|
-
|
|
591
|
-
Return:
|
|
592
|
-
A two item list containing the field names of the association.
|
|
593
|
-
"""
|
|
594
|
-
|
|
595
|
-
return association._properties.keys()
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
def get_associated_assets_by_field_name(
|
|
599
|
-
self,
|
|
600
|
-
asset: SchemaGeneratedClass,
|
|
601
|
-
field_name: str
|
|
602
|
-
) -> list[SchemaGeneratedClass]:
|
|
603
|
-
"""
|
|
604
|
-
Get a list of associated assets for an asset given a field name.
|
|
605
|
-
|
|
606
|
-
Arguments:
|
|
607
|
-
asset - the asset whose fields we are interested in
|
|
608
|
-
field_name - the field name we are looking for
|
|
609
|
-
|
|
610
|
-
Return:
|
|
611
|
-
A list of assets associated with the asset given that match the
|
|
612
|
-
field_name.
|
|
613
|
-
"""
|
|
614
|
-
|
|
615
|
-
logger.debug(
|
|
616
|
-
'Get associated assets for asset "%s"(%d) by field name %s.',
|
|
617
|
-
asset.name, asset.id, field_name
|
|
618
|
-
)
|
|
619
|
-
associated_assets = []
|
|
620
|
-
for association in asset.associations:
|
|
621
|
-
# Determine which two of the fields matches the asset given.
|
|
622
|
-
# The other field will provide the associated assets.
|
|
623
|
-
left_field_name, right_field_name = \
|
|
624
|
-
self.get_association_field_names(association)
|
|
625
|
-
|
|
626
|
-
if asset in getattr(association, left_field_name):
|
|
627
|
-
opposite_field_name = right_field_name
|
|
628
|
-
else:
|
|
629
|
-
opposite_field_name = left_field_name
|
|
630
|
-
|
|
631
|
-
if opposite_field_name == field_name:
|
|
632
|
-
associated_assets.extend(
|
|
633
|
-
getattr(association, opposite_field_name)
|
|
634
|
-
)
|
|
635
|
-
|
|
636
|
-
return associated_assets
|
|
637
|
-
|
|
638
|
-
def asset_to_dict(self, asset: SchemaGeneratedClass) -> tuple[str, dict]:
|
|
639
|
-
"""Get dictionary representation of the asset.
|
|
640
|
-
|
|
641
|
-
Arguments:
|
|
642
|
-
asset - asset to get dictionary representation of
|
|
643
|
-
|
|
644
|
-
Return: tuple with name of asset and the asset as dict
|
|
645
|
-
"""
|
|
646
|
-
|
|
647
|
-
logger.debug(
|
|
648
|
-
'Translating "%s"(%d) to dictionary.',
|
|
649
|
-
asset.name,
|
|
650
|
-
asset.id
|
|
651
|
-
)
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
asset_dict: dict[str, Any] = {
|
|
655
|
-
'name': str(asset.name),
|
|
656
|
-
'type': str(asset.type)
|
|
657
|
-
}
|
|
658
|
-
|
|
659
|
-
defenses = self.get_asset_defenses(asset)
|
|
660
|
-
|
|
661
|
-
if defenses:
|
|
662
|
-
asset_dict['defenses'] = defenses
|
|
663
|
-
|
|
664
|
-
if asset.extras:
|
|
665
|
-
# Add optional metadata to dict
|
|
666
|
-
asset_dict['extras'] = asset.extras.as_dict()
|
|
667
|
-
|
|
668
|
-
return (asset.id, asset_dict)
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
def association_to_dict(self, association: SchemaGeneratedClass) -> dict:
|
|
672
|
-
"""Get dictionary representation of the association.
|
|
673
|
-
|
|
674
|
-
Arguments:
|
|
675
|
-
association - association to get dictionary representation of
|
|
676
|
-
|
|
677
|
-
Returns the association serialized to a dict
|
|
678
|
-
"""
|
|
679
|
-
|
|
680
|
-
left_field_name, right_field_name = \
|
|
681
|
-
self.get_association_field_names(association)
|
|
682
|
-
left_field = getattr(association, left_field_name)
|
|
683
|
-
right_field = getattr(association, right_field_name)
|
|
684
|
-
|
|
685
|
-
association_dict = {
|
|
686
|
-
association.__class__.__name__ :
|
|
687
|
-
{
|
|
688
|
-
str(left_field_name):
|
|
689
|
-
[int(asset.id) for asset in left_field],
|
|
690
|
-
str(right_field_name):
|
|
691
|
-
[int(asset.id) for asset in right_field]
|
|
692
|
-
}
|
|
693
|
-
}
|
|
694
|
-
|
|
695
|
-
if association.extras:
|
|
696
|
-
# Add optional metadata to dict
|
|
697
|
-
association_dict['extras'] = association.extras
|
|
698
|
-
|
|
699
|
-
return association_dict
|
|
700
350
|
|
|
701
351
|
def attacker_to_dict(
|
|
702
352
|
self, attacker: AttackerAttachment
|
|
@@ -709,42 +359,37 @@ class Model():
|
|
|
709
359
|
|
|
710
360
|
logger.debug('Translating %s to dictionary.', attacker.name)
|
|
711
361
|
attacker_dict: dict[str, Any] = {
|
|
712
|
-
'name':
|
|
362
|
+
'name': attacker.name,
|
|
713
363
|
'entry_points': {},
|
|
714
364
|
}
|
|
715
365
|
for (asset, attack_steps) in attacker.entry_points:
|
|
716
|
-
attacker_dict['entry_points'][
|
|
366
|
+
attacker_dict['entry_points'][asset.name] = {
|
|
367
|
+
'asset_id': asset.id,
|
|
717
368
|
'attack_steps' : attack_steps
|
|
718
369
|
}
|
|
719
370
|
return (attacker.id, attacker_dict)
|
|
720
371
|
|
|
372
|
+
|
|
721
373
|
def _to_dict(self) -> dict:
|
|
722
374
|
"""Get dictionary representation of the model."""
|
|
723
375
|
logger.debug('Translating model to dict.')
|
|
724
376
|
contents: dict[str, Any] = {
|
|
725
377
|
'metadata': {},
|
|
726
378
|
'assets': {},
|
|
727
|
-
'associations': [],
|
|
728
379
|
'attackers' : {}
|
|
729
380
|
}
|
|
730
381
|
contents['metadata'] = {
|
|
731
382
|
'name': self.name,
|
|
732
|
-
'langVersion': self.
|
|
733
|
-
'langID': self.
|
|
383
|
+
'langVersion': self.lang_graph.metadata['version'],
|
|
384
|
+
'langID': self.lang_graph.metadata['id'],
|
|
734
385
|
'malVersion': '0.1.0-SNAPSHOT',
|
|
735
386
|
'MAL-Toolbox Version': __version__,
|
|
736
387
|
'info': 'Created by the mal-toolbox model python module.'
|
|
737
388
|
}
|
|
738
389
|
|
|
739
390
|
logger.debug('Translating assets to dictionary.')
|
|
740
|
-
for asset in self.assets:
|
|
741
|
-
(
|
|
742
|
-
contents['assets'][int(asset_id)] = asset_dict
|
|
743
|
-
|
|
744
|
-
logger.debug('Translating associations to dictionary.')
|
|
745
|
-
for association in self.associations:
|
|
746
|
-
assoc_dict = self.association_to_dict(association)
|
|
747
|
-
contents['associations'].append(assoc_dict)
|
|
391
|
+
for asset in self.assets.values():
|
|
392
|
+
contents['assets'].update(asset._to_dict())
|
|
748
393
|
|
|
749
394
|
logger.debug('Translating attackers to dictionary.')
|
|
750
395
|
for attacker in self.attackers:
|
|
@@ -752,22 +397,24 @@ class Model():
|
|
|
752
397
|
contents['attackers'][attacker_id] = attacker_dict
|
|
753
398
|
return contents
|
|
754
399
|
|
|
400
|
+
|
|
755
401
|
def save_to_file(self, filename: str) -> None:
|
|
756
402
|
"""Save to json/yml depending on extension"""
|
|
757
403
|
logger.debug('Save instance model to file "%s".', filename)
|
|
758
404
|
return save_dict_to_file(filename, self._to_dict())
|
|
759
405
|
|
|
406
|
+
|
|
760
407
|
@classmethod
|
|
761
408
|
def _from_dict(
|
|
762
409
|
cls,
|
|
763
410
|
serialized_object: dict,
|
|
764
|
-
|
|
411
|
+
lang_graph: LanguageGraph,
|
|
765
412
|
) -> Model:
|
|
766
413
|
"""Create a model from dict representation
|
|
767
414
|
|
|
768
415
|
Arguments:
|
|
769
416
|
serialized_object - Model in dict format
|
|
770
|
-
|
|
417
|
+
lang_graph -
|
|
771
418
|
"""
|
|
772
419
|
|
|
773
420
|
maltoolbox_version = serialized_object['metadata']['MAL Toolbox Version'] \
|
|
@@ -775,76 +422,79 @@ class Model():
|
|
|
775
422
|
else __version__
|
|
776
423
|
model = Model(
|
|
777
424
|
serialized_object['metadata']['name'],
|
|
778
|
-
|
|
425
|
+
lang_graph,
|
|
779
426
|
mt_version = maltoolbox_version)
|
|
780
427
|
|
|
781
428
|
# Reconstruct the assets
|
|
782
|
-
for asset_id,
|
|
429
|
+
for asset_id, asset_dict in serialized_object['assets'].items():
|
|
783
430
|
|
|
784
431
|
if logger.isEnabledFor(logging.DEBUG):
|
|
785
432
|
# Avoid running json.dumps when not in debug
|
|
786
433
|
logger.debug(
|
|
787
|
-
"Loading asset:\n%s", json.dumps(
|
|
434
|
+
"Loading asset:\n%s", json.dumps(asset_dict, indent=2)
|
|
788
435
|
)
|
|
789
436
|
|
|
790
437
|
# Allow defining an asset via type only.
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
if isinstance(
|
|
438
|
+
asset_dict = (
|
|
439
|
+
asset_dict
|
|
440
|
+
if isinstance(asset_dict, dict)
|
|
794
441
|
else {
|
|
795
|
-
'type':
|
|
796
|
-
'name': f"{
|
|
442
|
+
'type': asset_dict,
|
|
443
|
+
'name': f"{asset_dict}:{asset_id}"
|
|
797
444
|
}
|
|
798
445
|
)
|
|
799
446
|
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
for field, targets in assoc_fields.items():
|
|
818
|
-
targets = targets if isinstance(targets, list) else [targets]
|
|
819
|
-
setattr(
|
|
820
|
-
association,
|
|
821
|
-
field,
|
|
822
|
-
[model.get_asset_by_id(int(id)) for id in targets]
|
|
447
|
+
model.add_asset(
|
|
448
|
+
asset_type = asset_dict['type'],
|
|
449
|
+
name = asset_dict['name'],
|
|
450
|
+
defenses = {defense: float(value) for defense, value in \
|
|
451
|
+
asset_dict.get('defenses', {}).items()},
|
|
452
|
+
extras = asset_dict.get('extras', {}),
|
|
453
|
+
asset_id = int(asset_id))
|
|
454
|
+
|
|
455
|
+
# Reconstruct the association links
|
|
456
|
+
for asset_id, asset_dict in serialized_object['assets'].items():
|
|
457
|
+
asset = model.assets[int(asset_id)]
|
|
458
|
+
assoc_assets_dict = asset_dict['associated_assets'].items()
|
|
459
|
+
for fieldname, assoc_assets in assoc_assets_dict:
|
|
460
|
+
asset.add_associated_assets(
|
|
461
|
+
fieldname,
|
|
462
|
+
{model.assets[int(assoc_asset_id)]
|
|
463
|
+
for assoc_asset_id in assoc_assets}
|
|
823
464
|
)
|
|
824
465
|
|
|
825
|
-
#TODO Properly handle extras
|
|
826
|
-
|
|
827
|
-
model.add_association(association)
|
|
828
|
-
|
|
829
466
|
# Reconstruct the attackers
|
|
830
467
|
if 'attackers' in serialized_object:
|
|
831
468
|
attackers_info = serialized_object['attackers']
|
|
832
469
|
for attacker_id in attackers_info:
|
|
833
470
|
attacker = AttackerAttachment(name = attackers_info[attacker_id]['name'])
|
|
834
471
|
attacker.entry_points = []
|
|
835
|
-
for
|
|
472
|
+
for asset_name, entry_points_dict in \
|
|
473
|
+
attackers_info[attacker_id]['entry_points'].items():
|
|
474
|
+
target_asset = model.get_asset_by_id(
|
|
475
|
+
entry_points_dict['asset_id'])
|
|
476
|
+
if target_asset is None:
|
|
477
|
+
raise LookupError(
|
|
478
|
+
'Asset "%s"(%d) is not part of model "%s".' % (
|
|
479
|
+
asset_name,
|
|
480
|
+
entry_points_dict['asset_id'],
|
|
481
|
+
model.name)
|
|
482
|
+
)
|
|
836
483
|
attacker.entry_points.append(
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
484
|
+
(
|
|
485
|
+
target_asset,
|
|
486
|
+
entry_points_dict['attack_steps']
|
|
487
|
+
)
|
|
488
|
+
)
|
|
840
489
|
model.add_attacker(attacker, attacker_id = int(attacker_id))
|
|
841
490
|
return model
|
|
842
491
|
|
|
492
|
+
|
|
843
493
|
@classmethod
|
|
844
494
|
def load_from_file(
|
|
845
495
|
cls,
|
|
846
496
|
filename: str,
|
|
847
|
-
|
|
497
|
+
lang_graph: LanguageGraph,
|
|
848
498
|
) -> Model:
|
|
849
499
|
"""Create from json or yaml file depending on file extension"""
|
|
850
500
|
logger.debug('Load instance model from file "%s".', filename)
|
|
@@ -855,4 +505,181 @@ class Model():
|
|
|
855
505
|
serialized_model = load_dict_from_json_file(filename)
|
|
856
506
|
else:
|
|
857
507
|
raise ValueError('Unknown file extension, expected json/yml/yaml')
|
|
858
|
-
|
|
508
|
+
try:
|
|
509
|
+
return cls._from_dict(serialized_model, lang_graph)
|
|
510
|
+
except Exception as e:
|
|
511
|
+
raise ModelException(
|
|
512
|
+
"Could not load model. It might be of an older version. "
|
|
513
|
+
"Try to upgrade it with 'maltoolbox upgrade-model'"
|
|
514
|
+
) from e
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
class ModelAsset:
|
|
518
|
+
def __init__(
|
|
519
|
+
self,
|
|
520
|
+
name: str,
|
|
521
|
+
asset_id: int,
|
|
522
|
+
lg_asset: LanguageGraphAsset,
|
|
523
|
+
defenses: Optional[dict[str, float]] = None,
|
|
524
|
+
extras: Optional[dict] = None
|
|
525
|
+
):
|
|
526
|
+
|
|
527
|
+
self.name: str = name
|
|
528
|
+
self._id: int = asset_id
|
|
529
|
+
self.lg_asset: LanguageGraphAsset = lg_asset
|
|
530
|
+
self.type = self.lg_asset.name
|
|
531
|
+
self.defenses: dict[str, float] = defenses or {}
|
|
532
|
+
self.extras: dict = extras or {}
|
|
533
|
+
self._associated_assets: dict[str, set[ModelAsset]] = {}
|
|
534
|
+
self.attack_step_nodes: list = []
|
|
535
|
+
|
|
536
|
+
for step in self.lg_asset.attack_steps.values():
|
|
537
|
+
if step.type == 'defense' and step.name not in self.defenses:
|
|
538
|
+
self.defenses[step.name] = 1.0 if step.ttc and \
|
|
539
|
+
step.ttc['name'] == 'Enabled' else 0.0
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def _to_dict(self):
|
|
543
|
+
"""Get dictionary representation of the asset."""
|
|
544
|
+
|
|
545
|
+
logger.debug(
|
|
546
|
+
'Translating "%s"(%d) to dictionary.', self.name, self.id)
|
|
547
|
+
|
|
548
|
+
asset_dict: dict[str, Any] = {
|
|
549
|
+
'name': self.name,
|
|
550
|
+
'type': self.type,
|
|
551
|
+
'defenses': {},
|
|
552
|
+
'associated_assets': {}
|
|
553
|
+
}
|
|
554
|
+
|
|
555
|
+
# Only add non-default values for defenses to improve legibility of
|
|
556
|
+
# the model format
|
|
557
|
+
for defense, defense_value in self.defenses.items():
|
|
558
|
+
lg_step = self.lg_asset.attack_steps[defense]
|
|
559
|
+
default_defval = 1.0 if lg_step.ttc and \
|
|
560
|
+
lg_step.ttc['name'] == 'Enabled' else 0.0
|
|
561
|
+
if defense_value != default_defval:
|
|
562
|
+
asset_dict['defenses'][defense] = defense_value
|
|
563
|
+
|
|
564
|
+
for fieldname, assets in self.associated_assets.items():
|
|
565
|
+
asset_dict['associated_assets'][fieldname] = {asset.id: asset.name
|
|
566
|
+
for asset in assets}
|
|
567
|
+
|
|
568
|
+
if len(asset_dict['defenses']) == 0:
|
|
569
|
+
# Do not include an empty defenses dictionary
|
|
570
|
+
del asset_dict['defenses']
|
|
571
|
+
|
|
572
|
+
if self.extras != {}:
|
|
573
|
+
# Add optional metadata to dict
|
|
574
|
+
asset_dict['extras'] = self.extras
|
|
575
|
+
|
|
576
|
+
return {self.id: asset_dict}
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
def __repr__(self):
|
|
580
|
+
return (f'ModelAsset(name: "{self.name}", id: {self.id}, '
|
|
581
|
+
f'type: {self.type})')
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def validate_associated_assets(
|
|
585
|
+
self, fieldname: str, assets_to_add: set[ModelAsset]
|
|
586
|
+
):
|
|
587
|
+
"""
|
|
588
|
+
Validate an association we want to add (through `fieldname`)
|
|
589
|
+
is valid with the assets given in param `assets_to_add`:
|
|
590
|
+
- fieldname is valid for the asset type of this ModelAsset
|
|
591
|
+
- type of `assets_to_add` is valid for the association
|
|
592
|
+
- no more assets than 'field.maximum' are added to the field
|
|
593
|
+
|
|
594
|
+
Raises:
|
|
595
|
+
LookupError - fieldname can not be found for this ModelAsset
|
|
596
|
+
ValueError - there will be too many assets in the field
|
|
597
|
+
if we add this association
|
|
598
|
+
TypeError - if the asset type of `assets_to_add` is not valid
|
|
599
|
+
"""
|
|
600
|
+
|
|
601
|
+
# Validate that the field name is allowed for this asset type
|
|
602
|
+
if fieldname not in self.lg_asset.associations:
|
|
603
|
+
accepted_fieldnames = list(self.lg_asset.associations.keys())
|
|
604
|
+
raise LookupError(
|
|
605
|
+
f"Fieldname '{fieldname}' is not an accepted association "
|
|
606
|
+
f"fieldname from asset type {self.lg_asset.name}. "
|
|
607
|
+
f"Did you mean one of {accepted_fieldnames}?"
|
|
608
|
+
)
|
|
609
|
+
|
|
610
|
+
lg_assoc = self.lg_asset.associations[fieldname]
|
|
611
|
+
assoc_field = lg_assoc.get_field(fieldname)
|
|
612
|
+
|
|
613
|
+
# Validate that the asset to add association to is of correct type
|
|
614
|
+
for asset_to_add in assets_to_add:
|
|
615
|
+
if not asset_to_add.lg_asset.is_subasset_of(assoc_field.asset):
|
|
616
|
+
raise TypeError(
|
|
617
|
+
f"Asset '{asset_to_add.name}' of type "
|
|
618
|
+
f"'{asset_to_add.type}' can not be added to association "
|
|
619
|
+
f"'{self.name}.{fieldname}'. Expected type of "
|
|
620
|
+
f"'{fieldname}' is {assoc_field.asset.name}."
|
|
621
|
+
)
|
|
622
|
+
|
|
623
|
+
# Validate that there will not be too many assets in field
|
|
624
|
+
assets_in_field_before = self.associated_assets.get(fieldname, set())
|
|
625
|
+
assets_in_field_after = assets_in_field_before | set(assets_to_add)
|
|
626
|
+
max_assets_in_field = assoc_field.maximum or math.inf
|
|
627
|
+
|
|
628
|
+
if len(assets_in_field_after) > max_assets_in_field:
|
|
629
|
+
raise ValueError(
|
|
630
|
+
f"You can have maximum {assoc_field.maximum} "
|
|
631
|
+
f"assets for association field {fieldname}"
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
def add_associated_assets(self, fieldname: str, assets: set[ModelAsset]):
|
|
635
|
+
"""
|
|
636
|
+
Add the assets provided as a parameter to the set of associated
|
|
637
|
+
assets dictionary entry corresponding to the given fieldname.
|
|
638
|
+
"""
|
|
639
|
+
|
|
640
|
+
lg_assoc = self.lg_asset.associations[fieldname]
|
|
641
|
+
other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)
|
|
642
|
+
|
|
643
|
+
# Validation from both sides
|
|
644
|
+
self.validate_associated_assets(fieldname, assets)
|
|
645
|
+
for asset in assets:
|
|
646
|
+
asset.validate_associated_assets(other_fieldname, {self})
|
|
647
|
+
|
|
648
|
+
# Add the associated assets to this asset's dictionary
|
|
649
|
+
self._associated_assets.setdefault(
|
|
650
|
+
fieldname, set()
|
|
651
|
+
).update(assets)
|
|
652
|
+
|
|
653
|
+
# Add this asset to the associated assets' corresponding dictionaries
|
|
654
|
+
for asset in assets:
|
|
655
|
+
asset._associated_assets.setdefault(
|
|
656
|
+
other_fieldname, set()
|
|
657
|
+
).add(self)
|
|
658
|
+
|
|
659
|
+
def remove_associated_assets(self, fieldname: str,
|
|
660
|
+
assets: set[ModelAsset]):
|
|
661
|
+
""" Remove the assets provided as a parameter from the set of
|
|
662
|
+
associated assets dictionary entry corresponding to the fieldname
|
|
663
|
+
parameter.
|
|
664
|
+
"""
|
|
665
|
+
self._associated_assets[fieldname] -= set(assets)
|
|
666
|
+
if len(self._associated_assets[fieldname]) == 0:
|
|
667
|
+
del self._associated_assets[fieldname]
|
|
668
|
+
|
|
669
|
+
# Also remove this asset to the associated assets' dictionaries
|
|
670
|
+
lg_assoc = self.lg_asset.associations[fieldname]
|
|
671
|
+
other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)
|
|
672
|
+
for asset in assets:
|
|
673
|
+
asset._associated_assets[other_fieldname].remove(self)
|
|
674
|
+
if len(asset._associated_assets[other_fieldname]) == 0:
|
|
675
|
+
del asset._associated_assets[other_fieldname]
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
@property
|
|
679
|
+
def associated_assets(self):
|
|
680
|
+
return self._associated_assets
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
@property
|
|
684
|
+
def id(self):
|
|
685
|
+
return self._id
|