stix2arango 1.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,81 @@
+ import argparse
+ from stix2arango.stix2arango import Stix2Arango
+
+
+ def parse_bool(value: str):
+     value = value.lower()
+     # anything else ("false", "no", "n", ...) maps to False
+     return value in ["yes", "y", "true", "1"]
+
+
+ def parse_ref(value: str):
+     if not (value.endswith('_ref') or value.endswith('_refs')):
+         raise argparse.ArgumentTypeError('value must end with _ref or _refs')
+     return value
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description="Import STIX JSON into ArangoDB")
+     parser.add_argument("--file", required=True, help="Path to STIX JSON file")
+     parser.add_argument(
+         "--is_large_file",
+         action="store_true",
+         help="Use large-file mode [use this when the bundle is very large; it lets stix2arango chunk the bundle before loading it into memory]",
+     )
+     parser.add_argument("--database", required=True, help="ArangoDB database name")
+     parser.add_argument(
+         "--create_db",
+         default=True,
+         type=parse_bool,
+         help="whether to create the database if it does not exist (requires admin permission); pass false to skip creation",
+     )
+     parser.add_argument("--collection", required=True, help="ArangoDB collection name")
+     parser.add_argument(
+         "--stix2arango_note", required=False, help="Note for the import", default=""
+     )
+     parser.add_argument(
+         "--ignore_embedded_relationships",
+         required=False,
+         help="Ignore embedded relationships for the import",
+         type=parse_bool,
+         default=False,
+     )
+     parser.add_argument(
+         "--ignore_embedded_relationships_sro",
+         required=False,
+         help="Ignore embedded relationships for imported SROs",
+         type=parse_bool,
+         default=False,
+     )
+     parser.add_argument(
+         "--ignore_embedded_relationships_smo",
+         required=False,
+         help="Ignore embedded relationships for imported SMOs",
+         type=parse_bool,
+         default=False,
+     )
+     parser.add_argument(
+         "--include_embedded_relationships_attributes",
+         required=False,
+         help="Only create embedded relationships for the given _ref/_refs keys",
+         action="extend",
+         nargs="+",
+         type=parse_ref,
+     )
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_arguments()
+     stix_obj = Stix2Arango(
+         database=args.database,
+         collection=args.collection,
+         file=args.file,
+         create_db=args.create_db,
+         stix2arango_note=args.stix2arango_note,
+         ignore_embedded_relationships=args.ignore_embedded_relationships,
+         ignore_embedded_relationships_sro=args.ignore_embedded_relationships_sro,
+         ignore_embedded_relationships_smo=args.ignore_embedded_relationships_smo,
+         is_large_file=args.is_large_file,
+         include_embedded_relationships_attributes=args.include_embedded_relationships_attributes,
+     )
+     stix_obj.run()
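
A minimal sketch of driving this module from Python, mirroring main() above (the database, collection, and file values are illustrative; the keyword arguments come straight from the call in main()):

    from stix2arango.stix2arango import Stix2Arango

    s2a = Stix2Arango(
        database="demo",
        collection="my_collection",
        file="bundle.json",
        create_db=True,
        stix2arango_note="",
        ignore_embedded_relationships=False,
        ignore_embedded_relationships_sro=False,
        ignore_embedded_relationships_smo=False,
        is_large_file=False,
        include_embedded_relationships_attributes=None,
    )
    s2a.run()
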
stix2arango/config.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import logging
+ from dotenv import load_dotenv
+ from uuid import UUID
+
+ load_dotenv()
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="[%(asctime)s] %(levelname)s - %(message)s",  # noqa D100 E501
+     datefmt="%Y-%m-%d - %H:%M:%S",
+ )
+ ARANGODB_HOST = os.getenv("ARANGODB_HOST")
+ ARANGODB_PORT = os.getenv("ARANGODB_PORT")
+ ARANGODB_USERNAME = os.getenv("ARANGODB_USERNAME")
+ ARANGODB_PASSWORD = os.getenv("ARANGODB_PASSWORD")
+
+ json_schema = {
+     "type": "object",
+     "properties": {
+         "type": {"type": "string", "const": "bundle"},
+         "id": {"type": "string"},
+         "objects": {"type": "array", "items": {"type": "object"}}
+     },
+     "required": ["type", "id", "objects"]
+ }
+ STIX2ARANGO_IDENTITY = "https://github.com/muchdogesec/stix4doge/raw/main/objects/identity/stix2arango.json"  # the stix2arango identity
+ DOGESEC_IDENTITY = "https://github.com/muchdogesec/stix4doge/raw/main/objects/identity/dogesec.json"  # the dogesec identity
+
+ STIX2ARANGO_MARKING_DEFINITION = "https://raw.githubusercontent.com/muchdogesec/stix4doge/main/objects/marking-definition/stix2arango.json"  # the stix2arango marking-definition
+
+ IDENTITY_REFS = [
+     STIX2ARANGO_IDENTITY,
+     DOGESEC_IDENTITY
+ ]
+ MARKING_DEFINITION_REFS = [
+     STIX2ARANGO_MARKING_DEFINITION
+ ]
+ DEFAULT_OBJECT_URL = MARKING_DEFINITION_REFS + IDENTITY_REFS
+
+ namespace = UUID("72e906ce-ca1b-5d73-adcd-9ea9eb66a1b4")
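
For illustration, a minimal bundle that validates against json_schema above (the UUIDs are placeholders):

    {
        "type": "bundle",
        "id": "bundle--00000000-0000-4000-8000-000000000000",
        "objects": [
            {"type": "indicator", "id": "indicator--00000000-0000-4000-8000-000000000001"}
        ]
    }
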
@@ -0,0 +1 @@
+ from .arangodb_service import ArangoDBService
@@ -0,0 +1,313 @@
+ import contextlib
+ import os
+ import json
+ import logging
+ import re
+ import time
+ from typing import Any
+ import arango.database
+ from arango.collection import StandardCollection
+ from arango import ArangoClient
+ from arango.exceptions import ArangoServerError
+
+ from datetime import datetime, timezone
+ from tqdm import tqdm
+
+ from stix2arango.services.version_annotator import annotate_versions
+
+ from .. import config
+ from .. import utils
+
+ module_logger = logging.getLogger("data_ingestion_service")
+
+
+ class ArangoDBService:
+
+     def __init__(
+         self,
+         db,
+         vertex_collections,
+         edge_collections,
+         relationship=None,
+         create_db=False,
+         create=False,
+         username=None,
+         password=None,
+         host_url=None,
+         **kwargs,
+     ):
+         self.ARANGO_DB = self.get_db_name(db)
+         self.ARANGO_GRAPH = f"{self.ARANGO_DB.split('_database')[0]}_graph"
+         self.COLLECTIONS_VERTEX = vertex_collections
+         self.COLLECTIONS_EDGE = edge_collections
+         self.FORCE_RELATIONSHIP = [relationship] if relationship else None
+         self.missing_collection = True
+
+         module_logger.info("Establishing connection...")
+         client = ArangoClient(hosts=host_url)
+         self._client = client
+
+         if create_db:
+             module_logger.info(f"create db `{self.ARANGO_DB}` if it does not exist")
+             self.sys_db = client.db("_system", username=username, password=password)
+
+             module_logger.info("_system database - OK")
+
+             if not self.sys_db.has_database(self.ARANGO_DB):
+                 self.create_database(self.ARANGO_DB)
+
+         self.db = client.db(
+             self.ARANGO_DB, username=username, password=password, verify=True
+         )
+
+         # reuse the graph if it exists; only create it when create_db is set
+         if self.db.has_graph(self.ARANGO_GRAPH):
+             self.cti2stix_graph = self.db.graph(self.ARANGO_GRAPH)
+         elif create_db:
+             self.cti2stix_graph = self.db.create_graph(self.ARANGO_GRAPH)
+
+         self.collections: dict[str, StandardCollection] = {}
+         for collection in self.COLLECTIONS_VERTEX:
+             if create:
+                 self.collections[collection] = self.create_collection(collection)
+
+             self.collections[collection] = self.db.collection(collection)
+
+         for collection in self.COLLECTIONS_EDGE:
+             if create:
+                 try:
+                     self.cti2stix_objects_relationship = (
+                         self.cti2stix_graph.create_edge_definition(
+                             edge_collection=collection,
+                             from_vertex_collections=self.COLLECTIONS_VERTEX,
+                             to_vertex_collections=self.COLLECTIONS_VERTEX,
+                         )
+                     )
+                 except Exception as e:
+                     module_logger.debug(
+                         f"create edge collection {collection} failed with {e}"
+                     )
+
+             self.cti2stix_objects_relationship = self.cti2stix_graph.edge_collection(
+                 collection
+             )
+             self.collections[collection] = self.cti2stix_objects_relationship
+
+         module_logger.info("ArangoDB Connected now!")
+
+     def create_database(self, db_name):
+         try:
+             self.sys_db.create_database(db_name)
+         except arango.exceptions.DatabaseCreateError as e:
+             module_logger.debug(f"create database {db_name} failed with {e}")
+
+     def create_collection(self, collection_name):
+         try:
+             return self.db.create_collection(collection_name)
+         except arango.exceptions.CollectionCreateError as e:
+             module_logger.warning(
+                 f"create collection {collection_name} failed with {e}"
+             )
+             return self.db.collection(collection_name)
+
+     def execute_raw_query(self, query: str, bind_vars=None, **kwargs) -> list:
+         try:
+             cursor = self.db.aql.execute(query, bind_vars=bind_vars, **kwargs)
+             result = [doc for doc in cursor]
+             return result
+         except arango.exceptions.AQLQueryExecuteError:
+             module_logger.error(f"AQL exception in the query: {query}")
+             raise
+
+     def insert_several_objects(self, objects: list[dict], collection_name: str) -> tuple[list[str], dict]:
+         if not collection_name:
+             module_logger.info(f"Object has unknown type: {objects}")
+             return [], {}
+
+         for obj in objects:
+             now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+             obj["_is_latest"] = False
+             obj["_record_created"] = obj.get("_record_created", now)
+             obj["_record_modified"] = now
+             obj["_key"] = obj.get("_key", f'{obj["id"]}+{now}')
+
+             if obj["type"] == "relationship":
+                 obj.update(
+                     _target_type=obj["target_ref"].split("--")[0],
+                     _source_type=obj["source_ref"].split("--")[0],
+                 )
+         new_insertions = objects  # [obj for obj in objects if f'{obj["id"]};{obj["_record_md5_hash"]}' not in existing_objects]
+         existing_objects = {}
+
+         d = self.db.collection(collection_name).insert_many(new_insertions, overwrite_mode="ignore", sync=True)
+         for i, ret in enumerate(d):
+             obj = objects[i]
+             if isinstance(ret, arango.exceptions.DocumentInsertError):
+                 if ret.error_code == 1210:  # unique constraint violated
+                     existing_objects[f'{obj["id"]};{obj["_record_md5_hash"]}'] = collection_name + '/' + re.search(r'conflicting key: (.*)', ret.message).group(1)
+                 else:
+                     raise ret
+         return [obj["id"] for obj in new_insertions], existing_objects
+
+     def insert_several_objects_chunked(
+         self, objects, collection_name, chunk_size=5000, remove_duplicates=True
+     ):
+         if remove_duplicates:
+             original_length = len(objects)
+             objects = utils.remove_duplicates(objects)
+             logging.info(
+                 "removed {count} duplicates from imported objects.".format(
+                     count=original_length - len(objects)
+                 )
+             )
+
+         progress_bar = tqdm(
+             utils.chunked(objects, chunk_size),
+             total=len(objects),
+             desc="insert_several_objects_chunked",
+         )
+         inserted_objects = []
+         existing_objects = {}
+         for chunk in progress_bar:
+             inserted, existing = self.insert_several_objects(chunk, collection_name)
+             inserted_objects.extend(inserted)
+             existing_objects.update(existing)
+             progress_bar.update(len(chunk))
+         return inserted_objects, existing_objects
+
+     def insert_relationships_chunked(
+         self,
+         relationships: list[dict[str, Any]],
+         id_to_key_map: dict[str, str],
+         collection_name: str,
+         chunk_size=5000,
+     ):
+         for relationship in relationships:
+             source_key = id_to_key_map.get(relationship["source_ref"])
+             target_key = id_to_key_map.get(relationship["target_ref"])
+
+             relationship["_stix2arango_ref_err"] = not (target_key and source_key)
+             relationship["_from"] = self.fix_edge_ref(source_key or relationship["_from"])
+             relationship["_to"] = self.fix_edge_ref(target_key or relationship["_to"])
+             relationship["_record_md5_hash"] = relationship.get(
+                 "_record_md5_hash", utils.generate_md5(relationship)
+             )
+         return self.insert_several_objects_chunked(
+             relationships, collection_name, chunk_size=chunk_size
+         )
+
+     @staticmethod
+     def fix_edge_ref(_id):
+         c, _, _key = _id.rpartition('/')
+         if not c:
+             c = "missing_collection"
+         return f"{c}/{_key}"
+
+     def update_is_latest_several(self, object_ids, collection_name):
+         # returns newly deprecated _ids
+         query = """
+         FOR doc IN @@collection OPTIONS {indexHint: "s2a_search", forceIndexHint: true}
+         FILTER doc.id IN @object_ids
+         RETURN [doc.id, doc._key, doc.modified, doc._record_modified, doc._is_latest, doc._id]
+         """
+         out = self.execute_raw_query(
+             query,
+             bind_vars={
+                 "@collection": collection_name,
+                 "object_ids": object_ids,
+             },
+         )
+         out = [dict(zip(('id', '_key', 'modified', '_record_modified', '_is_latest', '_id'), obj_tuple)) for obj_tuple in out]
+         annotated, deprecated = annotate_versions(out)
+         logging.info(f"Updating annotated versions for {len(annotated)} items, deprecating {len(deprecated)} items")
+         for chunk in utils.chunked(annotated, 5000):
+             self.db.collection(collection_name).update_many(chunk, sync=True, keep_none=False, silent=True)
+         return deprecated
+
+     def update_is_latest_several_chunked(
+         self, object_ids, collection_name, edge_collection=None, chunk_size=5000
+     ):
+         logging.info(f"Updating _is_latest for {len(object_ids)} newly inserted items")
+         progress_bar = tqdm(
+             utils.chunked(object_ids, chunk_size),
+             total=len(object_ids),
+             desc="update_is_latest_several_chunked",
+         )
+         deprecated_key_ids = []  # contains newly deprecated _ids
+         for chunk in progress_bar:
+             deprecated_key_ids.extend(
+                 self.update_is_latest_several(chunk, collection_name)
+             )
+             progress_bar.update(len(chunk))
+
+         logging.info(
+             f"Deprecating _is_latest for {len(deprecated_key_ids)} items"
+         )
+         self.deprecate_relationships(deprecated_key_ids, edge_collection)
+         return deprecated_key_ids
+
+     def deprecate_relationships(
+         self, deprecated_key_ids: list, edge_collection: str, chunk_size=5000
+     ):
+         keys = self.get_relationships_to_deprecate(deprecated_key_ids, edge_collection)
+
+         progress_bar = tqdm(
+             utils.chunked(keys, chunk_size),
+             total=len(keys),
+             desc="deprecate_relationships",
+         )
+         for chunk in progress_bar:
+             self.db.collection(edge_collection).update_many(
+                 tuple(dict(_key=_key, _is_latest=False) for _key in chunk),
+                 silent=True,
+                 raise_on_document_error=True,
+             )
+             progress_bar.update(len(chunk))
+         return len(keys)
+
+     def get_relationships_to_deprecate(
+         self, deprecated_key_ids: list, edge_collection: str
+     ):
+         query = """
+         FOR doc IN @@collection OPTIONS {indexHint: "s2a_search_edge", forceIndexHint: true}
+         FILTER doc._from IN @deprecated_key_ids AND doc._is_latest == TRUE
+         RETURN doc._id
+         """
+         items_to_deprecate_full: set[str] = {*deprecated_key_ids}
+
+         # follow edges transitively until no new latest edges turn up
+         while deprecated_key_ids:
+             deprecated_key_ids = self.execute_raw_query(
+                 query,
+                 bind_vars={
+                     "@collection": edge_collection,
+                     "deprecated_key_ids": deprecated_key_ids,
+                 },
+             )
+             items_to_deprecate_full.update(deprecated_key_ids)
+         return [_id.split("/", 1)[1] for _id in items_to_deprecate_full]
+
+     @staticmethod
+     def get_db_name(name):
+         ENDING = "_database"
+         if name.endswith(ENDING):
+             return name
+         return name + ENDING
+
+     @contextlib.contextmanager
+     def transactional(self, write=None, exclusive=None, sync=True):
+         original_db = self.db
+         transactional_db = self.db.begin_transaction(allow_implicit=True, write=write, exclusive=exclusive, sync=sync, lock_timeout=300)
+         try:
+             logging.info(f"entering transaction: {transactional_db.transaction_status()}")
+             self.db = transactional_db
+             yield self
+             transactional_db.commit_transaction()
+         except BaseException:
+             transactional_db.abort_transaction()
+             raise
+         finally:
+             logging.info(f"exiting transaction: {transactional_db.transaction_status()}")
+             self.db = original_db
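
A minimal sketch of how this service might be driven end to end (collection names, credentials, and the sample object are illustrative; transactional wraps the writes in a single stream transaction as defined above):

    service = ArangoDBService(
        db="demo",
        vertex_collections=["demo_vertex_collection"],
        edge_collections=["demo_edge_collection"],
        create_db=True,
        create=True,
        username="root",
        password="",
        host_url="http://127.0.0.1:8529",
    )

    objects = [{
        "type": "indicator",
        "id": "indicator--00000000-0000-4000-8000-000000000002",
        "_record_md5_hash": "d41d8cd98f00b204e9800998ecf8427e",  # placeholder hash
    }]

    with service.transactional(write=["demo_vertex_collection", "demo_edge_collection"]) as svc:
        inserted_ids, existing = svc.insert_several_objects_chunked(objects, "demo_vertex_collection")
        svc.update_is_latest_several_chunked(inserted_ids, "demo_vertex_collection", edge_collection="demo_edge_collection")
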
@@ -0,0 +1,95 @@
+ from collections import defaultdict
+ from typing import List, Dict
+ import copy
+
+
+ def annotate_versions(objects: List[Dict]):
+     grouped = defaultdict(list)
+
+     # Group by 'id'
+     for obj in objects:
+         grouped[obj["id"]].append(obj)
+
+     annotations: list[dict] = []
+     deprecated = []
+
+     for obj_id, items in grouped.items():
+         # items = [copy.deepcopy(item) for item in items]
+
+         # Separate items with a non-None modified
+         valid_modified = [item for item in items if item.get("modified") is not None]
+
+         # _is_latest: max(modified) -> max(_record_modified)
+         if valid_modified:
+             max_modified = max(item.get("modified") for item in valid_modified)
+             latest_candidates = [
+                 item for item in valid_modified if item.get("modified") == max_modified
+             ]
+             max_record_modified_latest = max(
+                 item["_record_modified"] for item in latest_candidates
+             )
+         else:
+             max_modified = None
+             max_record_modified_latest = max(item["_record_modified"] for item in items)
+
+         # _is_earliest: min(modified) -> max(_record_modified)
+         if valid_modified:
+             min_modified = min(item.get("modified") for item in valid_modified)
+             earliest_candidates = [
+                 item for item in valid_modified if item.get("modified") == min_modified
+             ]
+             max_record_modified_earliest = max(
+                 item["_record_modified"] for item in earliest_candidates
+             )
+         else:
+             min_modified = None
+             max_record_modified_earliest = min(
+                 item["_record_modified"] for item in items
+             )
+
+         # _taxii.visible: for each modified value (including None), select the highest _record_modified
+         taxii_visible_keys = set()
+         modified_groups = defaultdict(list)
+         for item in items:
+             modified_groups[item.get("modified")].append(item)
+
+         for mod_val, group in modified_groups.items():
+             max_rec_mod = max(i["_record_modified"] for i in group)
+             for item in group:
+                 if item["_record_modified"] == max_rec_mod:
+                     taxii_visible_keys.add(item["_key"])
+
+         for item in items:
+             is_latest = (
+                 item.get("modified") == max_modified
+                 and item["_record_modified"] == max_record_modified_latest
+             )
+
+             is_earliest = (
+                 item.get("modified") == min_modified
+                 and item["_record_modified"] == max_record_modified_earliest
+             )
+
+             if item.get("_is_latest") and not is_latest:
+                 deprecated.append(item["_id"])
+             taxii_ = dict(
+                 visible=item["_key"] in taxii_visible_keys,
+                 first=is_earliest,
+                 last=is_latest,
+             )
+             annotations.append(
+                 dict(
+                     _key=item["_key"],
+                     _is_latest=is_latest,
+                     _taxii=taxii_,
+                 )
+             )
+
+     return annotations, deprecated
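
To make the semantics concrete, a small illustrative call (ids, keys, and timestamps are made up):

    rows = [
        {"id": "indicator--x", "_key": "k1", "modified": "2024-01-01T00:00:00.000Z",
         "_record_modified": "2024-06-01T00:00:00.000000Z", "_is_latest": True, "_id": "col/k1"},
        {"id": "indicator--x", "_key": "k2", "modified": "2024-02-01T00:00:00.000Z",
         "_record_modified": "2024-06-02T00:00:00.000000Z", "_is_latest": False, "_id": "col/k2"},
    ]
    annotations, deprecated = annotate_versions(rows)
    # k2 carries the greatest `modified`, so its annotation gets _is_latest=True and
    # _taxii.last=True; k1 has the earliest `modified`, so it gets _taxii.first=True.
    # Because k1 was previously marked _is_latest, its _id ("col/k1") lands in `deprecated`.
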
@@ -0,0 +1 @@
+ from .stix2arango import Stix2Arango
@@ -0,0 +1,143 @@
+ import contextlib
+ import logging
+ import os
+ from pathlib import Path
+ import sqlite3
+ import tempfile
+ import uuid
+ import ijson
+ import json
+
+ from stix2arango.utils import get_embedded_refs
+
+
+ class BundleLoader:
+     def __init__(self, file_path, chunk_size_min=20_000, db_path=""):
+         self.file_path = Path(file_path)
+         self.chunk_size_min = chunk_size_min
+         self.groups = None
+         self.bundle_id = "bundle--" + str(uuid.uuid4())
+
+         self.db_path = db_path
+         if not self.db_path:
+             self.temp_path = tempfile.NamedTemporaryFile(
+                 prefix="s2a_bundle_loader--", suffix=".sqlite"
+             )
+             self.db_path = self.temp_path.name
+         self._init_db()
+
+     def _init_db(self):
+         """Initialize the SQLite DB with an objects table."""
+         self.conn = sqlite3.connect(self.db_path)
+         self.conn.execute(
+             """
+             CREATE TABLE IF NOT EXISTS objects (
+                 id TEXT PRIMARY KEY,
+                 type TEXT,
+                 raw TEXT
+             )
+             """
+         )
+         self.conn.execute("PRAGMA synchronous = OFF;")
+         self.conn.execute("PRAGMA journal_mode = MEMORY;")
+         self.conn.execute("PRAGMA temp_store = MEMORY;")
+         self.conn.commit()
+
+     def save_to_sqlite(self, objects):
+         """Save a batch of STIX objects to the SQLite database."""
+         self.inserted = getattr(self, "inserted", 0)
+
+         try:
+             self.conn.executemany(
+                 "INSERT OR REPLACE INTO objects (id, type, raw) VALUES (?, ?, ?)",
+                 [(obj["id"], obj["type"], json.dumps(obj)) for obj in objects],
+             )
+         except sqlite3.IntegrityError as e:
+             print(f"Failed to insert {len(objects)} objects: {e}")
+         else:
+             self.conn.commit()
+             self.inserted += len(objects)
+             # logging.info(f"inserted {self.inserted}")
+
+     @staticmethod
+     def get_refs(obj):
+         refs = []
+         for _type, targets in get_embedded_refs(obj):
+             refs.extend(targets)
+         return refs
+
+     def build_groups(self):
+         """
+         Iterates the STIX bundle and groups object IDs such that for every
+         relationship (object of type "relationship"), its own id and its
+         source_ref and target_ref end up in the same group.
+         """
+         all_ids: dict[str, list[str]] = dict()  # All object IDs in the file
+         logging.info(f"loading into {self.db_path}")
+
+         with open(self.file_path, "rb") as f:
+             objects = ijson.items(f, "objects.item", use_float=True)
+             to_insert = []
+             for obj in objects:
+                 obj_id = obj.get("id")
+                 to_insert.append(obj)
+                 all_ids.setdefault(obj_id, [])
+                 if obj["type"] == "relationship" and all(
+                     x in obj for x in ["target_ref", "source_ref"]
+                 ):
+                     sr, tr = obj["source_ref"], obj["target_ref"]
+                     all_ids[obj_id].extend([sr, tr])
+                     all_ids.setdefault(sr, []).extend([tr, obj_id])
+                     all_ids.setdefault(tr, []).extend([sr, obj_id])
+                 for ref in self.get_refs(obj):
+                     all_ids[obj_id].append(ref)
+                 if len(to_insert) >= self.chunk_size_min:
+                     self.save_to_sqlite(to_insert)
+                     to_insert.clear()
+             if to_insert:
+                 self.save_to_sqlite(to_insert)
+
+         logging.info(f"loaded {self.inserted} into {self.db_path}")
+         handled = set()
+
+         self.groups = []
+         group = set()
+
+         def from_ids(all_ids):
+             for obj_id in all_ids:
+                 if obj_id in handled:
+                     continue
+                 group_objs = {obj_id, *all_ids[obj_id]}
+                 handled.update(group_objs)
+                 new_group = group.union(group_objs)
+                 if len(new_group) >= self.chunk_size_min:
+                     group.clear()
+                     self.groups.append(tuple(new_group))
+                 else:
+                     group.update(group_objs)
+
+         from_ids(all_ids)
+         if group:
+             self.groups.append(tuple(group))
+         return self.groups
+
+     def load_objects_by_ids(self, ids):
+         """Retrieve a list of STIX objects by their IDs from the SQLite database."""
+         placeholders = ",".join(["?"] * len(ids))
+         query = f"SELECT raw FROM objects WHERE id IN ({placeholders})"
+         cursor = self.conn.execute(query, list(ids))
+         return [json.loads(row[0]) for row in cursor.fetchall()]
+
+     def get_objects(self, group):
+         return list(self.load_objects_by_ids(group))
+
+     @property
+     def chunks(self):
+         for group in self.groups or self.build_groups():
+             yield self.get_objects(group)
+
+     def __del__(self):
+         with contextlib.suppress(Exception):
+             os.remove(self.db_path)
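
A minimal sketch of driving the loader (the path is illustrative, and process is a placeholder for the caller's ingestion step):

    loader = BundleLoader("bundle.json", chunk_size_min=20_000)
    for chunk in loader.chunks:
        # each chunk is a list of STIX objects whose relationship endpoints
        # are grouped together, with groups flushed near chunk_size_min ids
        process(chunk)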