cyvest 0.1.0__py3-none-any.whl → 5.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1682 @@
1
+ """
2
+ Investigation core - central state management for cybersecurity investigations.
3
+
4
+ Handles all object storage, merging, scoring, and statistics in a unified way.
5
+ Provides automatic merge-on-create for all object types.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from copy import deepcopy
11
+ from datetime import datetime, timezone
12
+ from decimal import Decimal
13
+ from typing import TYPE_CHECKING, Any, Literal
14
+
15
+ from logurich import logger
16
+
17
+ from cyvest import keys
18
+ from cyvest.level_score_rules import recalculate_level_for_score
19
+ from cyvest.levels import Level, normalize_level
20
+ from cyvest.model import (
21
+ AuditEvent,
22
+ Check,
23
+ Enrichment,
24
+ InvestigationWhitelist,
25
+ Observable,
26
+ ObservableLink,
27
+ ObservableType,
28
+ Relationship,
29
+ Tag,
30
+ Taxonomy,
31
+ ThreatIntel,
32
+ )
33
+ from cyvest.model_enums import PropagationMode, RelationshipDirection, RelationshipType
34
+ from cyvest.score import ScoreEngine, ScoreMode
35
+ from cyvest.stats import InvestigationStats
36
+ from cyvest.ulid import generate_ulid
37
+
38
+ if TYPE_CHECKING:
39
+ from cyvest.model_schema import StatisticsSchema
40
+
41
+
42
+ class Investigation:
43
+ """
44
+ Core investigation state and operations.
45
+
46
+ Manages all investigation objects (observables, checks, threat intel, etc.),
47
+ handles automatic merging on creation, score propagation, and statistics tracking.
48
+ """
49
+
50
+ _MODEL_METADATA_RULES: dict[str, dict[str, set[str]]] = {
51
+ "observable": {
52
+ "fields": {"comment", "extra", "internal", "whitelisted"},
53
+ "dict_fields": {"extra"},
54
+ },
55
+ "check": {
56
+ "fields": {"comment", "extra", "description"},
57
+ "dict_fields": {"extra"},
58
+ },
59
+ "threat_intel": {
60
+ "fields": {"comment", "extra", "level", "taxonomies"},
61
+ "dict_fields": {"extra"},
62
+ },
63
+ "enrichment": {
64
+ "fields": {"context", "data"},
65
+ "dict_fields": {"data"},
66
+ },
67
+ "tag": {
68
+ "fields": {"description"},
69
+ "dict_fields": set(),
70
+ },
71
+ }
72
+
73
+ def __init__(
74
+ self,
75
+ root_data: Any = None,
76
+ root_type: ObservableType | Literal["file", "artifact"] = ObservableType.FILE,
77
+ score_mode_obs: ScoreMode | Literal["max", "sum"] = ScoreMode.MAX,
78
+ *,
79
+ investigation_id: str | None = None,
80
+ investigation_name: str | None = None,
81
+ ) -> None:
82
+ """
83
+ Initialize a new investigation.
84
+
85
+ Args:
86
+ root_data: Data stored on the root observable (optional)
87
+ root_type: Root observable type (ObservableType.FILE or ObservableType.ARTIFACT)
88
+ score_mode_obs: Observable score calculation mode (MAX or SUM)
89
+ investigation_id: Optional pre-assigned investigation ID; a new ULID is generated when omitted
+ investigation_name: Optional human-readable investigation name
90
+ """
91
+ self.investigation_id = investigation_id or generate_ulid()
92
+ self.investigation_name = investigation_name
93
+ self._audit_log: list[AuditEvent] = []
94
+ self._audit_enabled = True
95
+
96
+ # Record investigation start as the first event
97
+ self._record_event(
98
+ event_type="INVESTIGATION_STARTED",
99
+ object_type="investigation",
100
+ object_key=self.investigation_id,
101
+ )
102
+
103
+ # Object collections
104
+ self._observables: dict[str, Observable] = {}
105
+ self._checks: dict[str, Check] = {}
106
+ self._threat_intels: dict[str, ThreatIntel] = {}
107
+ self._enrichments: dict[str, Enrichment] = {}
108
+ self._tags: dict[str, Tag] = {}
109
+
110
+ # Internal components
111
+ normalized_score_mode_obs = ScoreMode.normalize(score_mode_obs)
112
+ self._score_engine = ScoreEngine(score_mode_obs=normalized_score_mode_obs, sink=self)
113
+ self._stats = InvestigationStats()
114
+ self._whitelists: dict[str, InvestigationWhitelist] = {}
115
+
116
+ # Create root observable
117
+ obj_type = ObservableType.normalize_root_type(root_type)
118
+
119
+ self._root_observable = Observable(
120
+ obs_type=obj_type,
121
+ value="root",
122
+ internal=False,
123
+ whitelisted=False,
124
+ comment="Root observable for investigation",
125
+ extra=root_data,
126
+ score=Decimal("0"),
127
+ level=Level.INFO,
128
+ )
129
+ self._observables[self._root_observable.key] = self._root_observable
130
+ self._score_engine.register_observable(self._root_observable)
131
+ self._stats.register_observable(self._root_observable)
132
+ self._record_event(
133
+ event_type="OBSERVABLE_CREATED",
134
+ object_type="observable",
135
+ object_key=self._root_observable.key,
136
+ )
137
+
138
+ def _record_event(
139
+ self,
140
+ *,
141
+ event_type: str,
142
+ object_type: str | None = None,
143
+ object_key: str | None = None,
144
+ reason: str | None = None,
145
+ actor: str | None = None,
146
+ tool: str | None = None,
147
+ details: dict[str, Any] | None = None,
148
+ timestamp: datetime | None = None,
149
+ ) -> AuditEvent | None:
150
+ if not self._audit_enabled:
151
+ return None
152
+
153
+ event = AuditEvent(
154
+ event_id=generate_ulid(),
155
+ timestamp=timestamp or datetime.now(timezone.utc),
156
+ event_type=event_type,
157
+ actor=actor,
158
+ reason=reason,
159
+ tool=tool,
160
+ object_type=object_type,
161
+ object_key=object_key,
162
+ details=deepcopy(details) if details else {},
163
+ )
164
+ self._audit_log.append(event)
165
+ return event
166
+
167
+ @property
168
+ def started_at(self) -> datetime:
169
+ """Return the investigation start time from the first event in the audit log."""
170
+ for event in self._audit_log:
171
+ if event.event_type == "INVESTIGATION_STARTED":
172
+ return event.timestamp
173
+ # Fallback if no INVESTIGATION_STARTED event (shouldn't happen)
174
+ return datetime.now(timezone.utc)
175
+
176
+ def _link_threat_intel_to_observable(self, observable: Observable, ti: ThreatIntel) -> None:
177
+ if any(existing.key == ti.key for existing in observable.threat_intels):
178
+ return
179
+ observable.threat_intels.append(ti)
180
+
181
+ def _create_relationship(
182
+ self,
183
+ source_obs: Observable,
184
+ target_key: str,
185
+ relationship_type: RelationshipType | str,
186
+ direction: RelationshipDirection | str | None = None,
187
+ ) -> None:
188
+ rel = Relationship(target_key=target_key, relationship_type=relationship_type, direction=direction)
189
+ rel_tuple = (rel.target_key, rel.relationship_type, rel.direction)
190
+ existing_rels = {(r.target_key, r.relationship_type, r.direction) for r in source_obs.relationships}
191
+ if rel_tuple not in existing_rels:
192
+ source_obs.relationships.append(rel)
193
+
194
+ def _link_check_to_observable(self, check: Check, link: ObservableLink) -> bool:
195
+ existing: dict[tuple[str, PropagationMode], int] = {}
196
+ for idx, existing_link in enumerate(check.observable_links):
197
+ existing[(existing_link.observable_key, existing_link.propagation_mode)] = idx
198
+ link_tuple = (link.observable_key, link.propagation_mode)
199
+ if link_tuple in existing:
200
+ return False
201
+
202
+ check.observable_links.append(link)
203
+ return True
204
+
205
+ def _link_check_to_tag(self, tag: Tag, check: Check) -> None:
206
+ if any(existing.key == check.key for existing in tag.checks):
207
+ return
208
+ tag.checks.append(check)
209
+
210
+ def _get_object_type(self, obj: Any) -> str | None:
211
+ if isinstance(obj, Observable):
212
+ return "observable"
213
+ if isinstance(obj, Check):
214
+ return "check"
215
+ if isinstance(obj, ThreatIntel):
216
+ return "threat_intel"
217
+ if isinstance(obj, Enrichment):
218
+ return "enrichment"
219
+ if isinstance(obj, Tag):
220
+ return "tag"
221
+ return None
222
+
223
+ @staticmethod
224
+ def _normalize_taxonomies(value: Any) -> list[Taxonomy]:
225
+ if value is None:
226
+ return []
227
+ if not isinstance(value, list):
228
+ raise TypeError("taxonomies must be a list of taxonomy objects.")
229
+ taxonomies = [Taxonomy.model_validate(item) for item in value]
230
+ seen: set[str] = set()
231
+ duplicates: set[str] = set()
232
+ for taxonomy in taxonomies:
233
+ if taxonomy.name in seen:
234
+ duplicates.add(taxonomy.name)
235
+ seen.add(taxonomy.name)
236
+ if duplicates:
237
+ dupes = ", ".join(sorted(duplicates))
238
+ raise ValueError(f"Duplicate taxonomy name(s): {dupes}")
239
+ return taxonomies
240
+
241
+ def apply_score_change(
242
+ self,
243
+ obj: Any,
244
+ new_score: Decimal,
245
+ *,
246
+ reason: str = "",
247
+ event_type: str = "SCORE_CHANGED",
248
+ contributing_investigation_ids: set[str] | None = None,
249
+ ) -> bool:
250
+ """Apply a score change and emit an audit event."""
251
+ if not isinstance(new_score, Decimal):
252
+ new_score = Decimal(str(new_score))
253
+
254
+ old_score = obj.score
255
+ old_level = obj.level
256
+ new_level = recalculate_level_for_score(old_level, new_score)
257
+
258
+ if new_score == old_score and new_level == old_level:
259
+ return False
260
+
261
+ obj.score = new_score
262
+ obj.level = new_level
263
+
264
+ if event_type == "SCORE_RECALCULATED":
265
+ # Skip audit log entry for recalculated scores.
266
+ return True
267
+
268
+ details = {
269
+ "old_score": float(old_score),
270
+ "new_score": float(new_score),
271
+ "old_level": old_level.value,
272
+ "new_level": new_level.value,
273
+ }
274
+ if contributing_investigation_ids:
275
+ details["contributing_investigation_ids"] = sorted(contributing_investigation_ids)
276
+
277
+ self._record_event(
278
+ event_type=event_type,
279
+ object_type=self._get_object_type(obj),
280
+ object_key=getattr(obj, "key", None),
281
+ reason=reason,
282
+ details=details,
283
+ )
284
+ return True
285
+
286
+ def apply_level_change(
287
+ self,
288
+ obj: Any,
289
+ level: Level | str,
290
+ *,
291
+ reason: str = "",
292
+ event_type: str = "LEVEL_UPDATED",
293
+ ) -> bool:
294
+ """Apply a level change and emit an audit event."""
295
+ new_level = normalize_level(level)
296
+ old_level = obj.level
297
+ if new_level == old_level:
298
+ return False
299
+
300
+ obj.level = new_level
301
+ self._record_event(
302
+ event_type=event_type,
303
+ object_type=self._get_object_type(obj),
304
+ object_key=getattr(obj, "key", None),
305
+ reason=reason,
306
+ details={
307
+ "old_level": old_level.value,
308
+ "new_level": new_level.value,
309
+ "score": float(obj.score),
310
+ },
311
+ )
312
+ return True
313
+
314
+ def _update_observable_check_links(self, observable_key: str) -> None:
315
+ obs = self._observables.get(observable_key)
316
+ if not obs:
317
+ return
318
+ check_keys = self._score_engine.get_check_links_for_observable(observable_key)
319
+ obs._check_links = check_keys
320
+
321
+ def _rebuild_all_check_links(self) -> None:
322
+ for observable_key in self._observables:
323
+ self._update_observable_check_links(observable_key)
324
+
325
+ def get_audit_log(self) -> list[AuditEvent]:
326
+ """Return a deep copy of the audit log."""
327
+ return [event.model_copy(deep=True) for event in self._audit_log]
328
+
329
+ def get_audit_events(
330
+ self,
331
+ *,
332
+ object_type: str | None = None,
333
+ object_key: str | None = None,
334
+ event_type: str | None = None,
335
+ ) -> list[AuditEvent]:
336
+ """Filter audit events by optional object type/key and event type."""
337
+ events = self._audit_log
338
+ if object_type is not None:
339
+ events = [event for event in events if event.object_type == object_type]
340
+ if object_key is not None:
341
+ events = [event for event in events if event.object_key == object_key]
342
+ if event_type is not None:
343
+ events = [event for event in events if event.event_type == event_type]
344
+ return [event.model_copy(deep=True) for event in events]
345
+
346
+ def set_investigation_name(self, name: str | None, *, reason: str | None = None) -> None:
347
+ """Set or clear the human-readable investigation name."""
348
+ name = str(name).strip() if name is not None else None
349
+ if name == self.investigation_name:
350
+ return
351
+ old_name = self.investigation_name
352
+ self.investigation_name = name
353
+ self._record_event(
354
+ event_type="INVESTIGATION_NAME_UPDATED",
355
+ object_type="investigation",
356
+ object_key=self.investigation_id,
357
+ reason=reason,
358
+ details={"old_name": old_name, "new_name": name},
359
+ )
360
+
361
+ def _merge_observable(self, existing: Observable, incoming: Observable) -> tuple[Observable, list]:
362
+ """
363
+ Merge an incoming observable into an existing observable.
364
+
365
+ Strategy:
366
+ - Update score (take maximum)
367
+ - Update level (take maximum)
368
+ - Update extra (merge dicts)
369
+ - Overwrite comment (if incoming is non-empty)
+ - Merge whitelisted/internal flags (whitelisted is OR-ed, internal is AND-ed)
370
+ - Merge threat intels
371
+ - Merge relationships (defer if target missing)
372
+ - Preserve provenance metadata
373
+
374
+ Args:
375
+ existing: The existing observable
376
+ incoming: The incoming observable to merge
377
+
378
+ Returns:
379
+ Tuple of (merged observable, deferred relationships)
380
+ """
381
+ # Normal merge logic for scores and levels (SAFE level protection in Observable.update_score)
382
+ # Take the higher score
383
+ if incoming.score > existing.score:
384
+ self.apply_score_change(
385
+ existing,
386
+ incoming.score,
387
+ reason=f"Merged from {incoming.key}",
388
+ )
389
+
390
+ # Take the higher level
391
+ if incoming.level > existing.level:
392
+ self.apply_level_change(existing, incoming.level, reason=f"Merged from {incoming.key}")
393
+
394
+ # Update extra (merge dictionaries)
395
+ if existing.extra:
396
+ existing.extra.update(incoming.extra)
397
+ elif incoming.extra:
398
+ existing.extra = dict(incoming.extra)
399
+
400
+ # Overwrite comment if incoming is non-empty
401
+ if incoming.comment:
402
+ existing.comment = incoming.comment
403
+
404
+ # Merge whitelisted status (if either is whitelisted, result is whitelisted)
405
+ existing.whitelisted = existing.whitelisted or incoming.whitelisted
406
+
407
+ # Merge internal status (if either is external, result is external)
408
+ existing.internal = existing.internal and incoming.internal
409
+
410
+ # Merge threat intels (avoid duplicates by key)
411
+ existing_ti_keys = {ti.key for ti in existing.threat_intels}
412
+ for ti in incoming.threat_intels:
413
+ if ti.key not in existing_ti_keys:
414
+ self._link_threat_intel_to_observable(existing, ti)
415
+ existing_ti_keys.add(ti.key)
416
+
417
+ # Merge relationships (defer if target not yet available)
418
+ deferred_relationships = []
419
+ for rel in incoming.relationships:
420
+ if rel.target_key in self._observables:
421
+ # Target exists - add relationship immediately
422
+ self._create_relationship(existing, rel.target_key, rel.relationship_type, rel.direction)
423
+ else:
424
+ # Target doesn't exist yet - defer for Pass 2 of merge_investigation()
425
+ deferred_relationships.append((existing.key, rel))
426
+
427
+ return existing, deferred_relationships
428
+
429
+ def _merge_check(self, existing: Check, incoming: Check) -> Check:
430
+ """
431
+ Merge an incoming check into an existing check.
432
+
433
+ Strategy:
434
+ - Update score (take maximum)
435
+ - Update level (take maximum)
436
+ - Update extra (merge dicts)
437
+ - Overwrite description (if incoming is non-empty)
438
+ - Overwrite comment (if incoming is non-empty)
439
+ - Merge observable links (tuple-based deduplication, provenance-preserving)
440
+
441
+ Args:
442
+ existing: The existing check
443
+ incoming: The incoming check to merge
444
+
445
+ Returns:
446
+ The merged check (existing is modified in place)
447
+ """
448
+ if not incoming.origin_investigation_id:
449
+ incoming.origin_investigation_id = self.investigation_id
450
+ if not existing.origin_investigation_id:
451
+ existing.origin_investigation_id = incoming.origin_investigation_id
452
+
453
+ # Take the higher score
454
+ if incoming.score > existing.score:
455
+ self.apply_score_change(
456
+ existing,
457
+ incoming.score,
458
+ reason=f"Merged from {incoming.key}",
459
+ )
460
+
461
+ # Take the higher level
462
+ if incoming.level > existing.level:
463
+ self.apply_level_change(existing, incoming.level, reason=f"Merged from {incoming.key}")
464
+
465
+ # Update extra (merge dictionaries)
466
+ existing.extra.update(incoming.extra)
467
+
468
+ # Overwrite description if incoming is non-empty
469
+ if incoming.description:
470
+ existing.description = incoming.description
471
+
472
+ # Overwrite comment if incoming is non-empty
473
+ if incoming.comment:
474
+ existing.comment = incoming.comment
475
+
476
+ existing_by_tuple: dict[tuple[str, PropagationMode], int] = {}
477
+ for idx, existing_link in enumerate(existing.observable_links):
478
+ existing_by_tuple[(existing_link.observable_key, existing_link.propagation_mode)] = idx
479
+ for incoming_link in incoming.observable_links:
480
+ link_tuple = (incoming_link.observable_key, incoming_link.propagation_mode)
481
+ existing_idx = existing_by_tuple.get(link_tuple)
482
+ if existing_idx is None:
483
+ existing.observable_links.append(incoming_link)
484
+ existing_by_tuple[link_tuple] = len(existing.observable_links) - 1
485
+ continue
486
+
487
+ return existing
488
+
489
+ def _merge_threat_intel(self, existing: ThreatIntel, incoming: ThreatIntel) -> ThreatIntel:
490
+ """
491
+ Merge an incoming threat intel into an existing threat intel.
492
+
493
+ Strategy:
494
+ - Update score (take maximum)
495
+ - Update level (take maximum)
496
+ - Update extra (merge dicts)
497
+ - Concatenate comments
498
+ - Merge taxonomies
499
+
500
+ Args:
501
+ existing: The existing threat intel
502
+ incoming: The incoming threat intel to merge
503
+
504
+ Returns:
505
+ The merged threat intel (existing is modified in place)
506
+ """
507
+ # Take the higher score
508
+ if incoming.score > existing.score:
509
+ self.apply_score_change(
510
+ existing,
511
+ incoming.score,
512
+ reason=f"Merged from {incoming.key}",
513
+ )
514
+
515
+ # Take the higher level
516
+ if incoming.level > existing.level:
517
+ self.apply_level_change(existing, incoming.level, reason=f"Merged from {incoming.key}")
518
+
519
+ # Update extra (merge dictionaries)
520
+ existing.extra.update(incoming.extra)
521
+
522
+ # Concatenate comments
523
+ if incoming.comment:
524
+ if existing.comment:
525
+ existing.comment += "\n\n" + incoming.comment
526
+ else:
527
+ existing.comment = incoming.comment
528
+
529
+ # Merge taxonomies (ensure unique names)
530
+ existing_by_name: dict[str, int] = {taxonomy.name: idx for idx, taxonomy in enumerate(existing.taxonomies)}
531
+ for taxonomy in incoming.taxonomies:
532
+ existing_idx = existing_by_name.get(taxonomy.name)
533
+ if existing_idx is None:
534
+ existing.taxonomies.append(taxonomy)
535
+ existing_by_name[taxonomy.name] = len(existing.taxonomies) - 1
536
+ else:
537
+ existing.taxonomies[existing_idx] = taxonomy
538
+
539
+ return existing
540
+
541
+ def _merge_enrichment(self, existing: Enrichment, incoming: Enrichment) -> Enrichment:
542
+ """
543
+ Merge an incoming enrichment into an existing enrichment.
544
+
545
+ Strategy:
546
+ - Deep merge data structure (merge dictionaries recursively)
+ - Overwrite context (if incoming is non-empty)
547
+
548
+ Args:
549
+ existing: The existing enrichment
550
+ incoming: The incoming enrichment to merge
551
+
552
+ Returns:
553
+ The merged enrichment (existing is modified in place)
554
+ """
555
+
556
+ def deep_merge(base: dict, update: dict) -> dict:
557
+ """Recursively merge dictionaries."""
558
+ for key, value in update.items():
559
+ if key in base and isinstance(base[key], dict) and isinstance(value, dict):
560
+ deep_merge(base[key], value)
561
+ else:
562
+ base[key] = value
563
+ return base
564
+
565
+ # Deep merge data structures
566
+ if isinstance(existing.data, dict) and isinstance(incoming.data, dict):
567
+ deep_merge(existing.data, incoming.data)
568
+ else:
569
+ existing.data = deepcopy(incoming.data)
570
+
571
+ # Update context if incoming has one
572
+ if incoming.context:
573
+ existing.context = incoming.context
574
+
575
+ return existing
576
+
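+ # Hedged illustration of deep_merge() above: nested dicts are merged key by key,
+ # while non-dict values are overwritten by the incoming side.
+ #
+ #     existing.data = {"whois": {"registrar": "X"}, "ttl": 300}
+ #     incoming.data = {"whois": {"country": "FR"}, "ttl": 600}
+ #     merged.data   = {"whois": {"registrar": "X", "country": "FR"}, "ttl": 600}
+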
577
+ def _merge_tag(self, existing: Tag, incoming: Tag) -> Tag:
578
+ """
579
+ Merge an incoming tag into an existing tag.
580
+
581
+ Strategy:
582
+ - Update description (if incoming is non-empty)
+ - Merge checks (dict-based lookup for efficiency)
583
+
584
+ Args:
585
+ existing: The existing tag
586
+ incoming: The incoming tag to merge
587
+
588
+ Returns:
589
+ The merged tag (existing is modified in place)
590
+ """
591
+ # Update description if incoming has one
592
+ if incoming.description:
593
+ existing.description = incoming.description
594
+
595
+ # Merge checks using dict-based lookup (more efficient)
596
+ existing_checks_dict = {check.key: check for check in existing.checks}
597
+
598
+ for incoming_check in incoming.checks:
599
+ if incoming_check.key in existing_checks_dict:
600
+ # Merge existing check
601
+ self._merge_check(existing_checks_dict[incoming_check.key], incoming_check)
602
+ else:
603
+ # Add new check
604
+ self._link_check_to_tag(existing, incoming_check)
605
+
606
+ return existing
607
+
608
+ def _clone_for_merge(
609
+ self, other: Investigation
610
+ ) -> tuple[
611
+ dict[str, Observable],
612
+ dict[str, ThreatIntel],
613
+ dict[str, Check],
614
+ dict[str, Enrichment],
615
+ dict[str, Tag],
616
+ ]:
617
+ """Clone incoming models while preserving shared object references."""
618
+ incoming_threat_intels = {key: ti.model_copy(deep=True) for key, ti in other._threat_intels.items()}
619
+ incoming_checks = {key: check.model_copy(deep=True) for key, check in other._checks.items()}
620
+ incoming_enrichments = {key: enrichment.model_copy(deep=True) for key, enrichment in other._enrichments.items()}
621
+
622
+ orphan_threat_intels: dict[str, ThreatIntel] = {}
623
+
624
+ def _copy_threat_intel(ti: ThreatIntel) -> ThreatIntel:
625
+ if ti.key in incoming_threat_intels:
626
+ return incoming_threat_intels[ti.key]
627
+ existing = orphan_threat_intels.get(ti.key)
628
+ if existing:
629
+ return existing
630
+ copied = ti.model_copy(deep=True)
631
+ orphan_threat_intels[ti.key] = copied
632
+ return copied
633
+
634
+ incoming_observables: dict[str, Observable] = {}
635
+ for obs in other._observables.values():
636
+ copied_obs = obs.model_copy(deep=True)
637
+ if obs.threat_intels:
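+ # Illustrative usage (hedged sketch, not part of the module): constructing an
+ # Investigation registers a root observable and records an INVESTIGATION_STARTED
+ # audit event. Both enum members and their string forms are accepted here.
+ #
+ #     inv = Investigation(root_type="artifact", score_mode_obs="sum",
+ #                         investigation_name="Phishing triage")
+ #     inv.get_root().obs_type     # ObservableType.ARTIFACT
+ #     inv.get_global_score()      # expected to start at Decimal("0") (assumption)
+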
638
+ copied_obs.threat_intels = [_copy_threat_intel(ti) for ti in obs.threat_intels]
639
+ incoming_observables[obs.key] = copied_obs
640
+
641
+ orphan_checks: dict[str, Check] = {}
642
+
643
+ def _copy_check(check: Check) -> Check:
644
+ if check.key in incoming_checks:
645
+ return incoming_checks[check.key]
646
+ existing = orphan_checks.get(check.key)
647
+ if existing:
648
+ return existing
649
+ copied = check.model_copy(deep=True)
650
+ orphan_checks[check.key] = copied
651
+ return copied
652
+
653
+ incoming_tags: dict[str, Tag] = {}
654
+
655
+ def _copy_tag(tag: Tag) -> Tag:
656
+ existing = incoming_tags.get(tag.key)
657
+ if existing:
658
+ return existing
659
+ copied = Tag(
660
+ name=tag.name,
661
+ description=tag.description,
662
+ checks=[_copy_check(check) for check in tag.checks],
663
+ key=tag.key,
664
+ )
665
+ incoming_tags[tag.key] = copied
666
+ return copied
667
+
668
+ for tag in other._tags.values():
669
+ _copy_tag(tag)
670
+
671
+ return (
672
+ incoming_observables,
673
+ incoming_threat_intels,
674
+ incoming_checks,
675
+ incoming_enrichments,
676
+ incoming_tags,
677
+ )
678
+
679
+ def add_observable(self, obs: Observable) -> tuple[Observable, list]:
680
+ """
681
+ Add or merge an observable.
682
+
683
+ Args:
684
+ obs: Observable to add or merge
685
+
686
+ Returns:
687
+ Tuple of (resulting observable, deferred relationships)
688
+ """
689
+ if obs.key in self._observables:
690
+ r = self._merge_observable(self._observables[obs.key], obs)
691
+ self._score_engine.recalculate_all()
692
+ return r
693
+
694
+ # Register new observable
695
+ self._observables[obs.key] = obs
696
+ self._score_engine.register_observable(obs)
697
+ self._stats.register_observable(obs)
698
+ self._update_observable_check_links(obs.key)
699
+ self._record_event(
700
+ event_type="OBSERVABLE_CREATED",
701
+ object_type="observable",
702
+ object_key=obs.key,
703
+ )
704
+ return obs, []
705
+
706
+ def add_check(self, check: Check) -> Check:
707
+ """
708
+ Add or merge a check.
709
+
710
+ Args:
711
+ check: Check to add or merge
712
+
713
+ Returns:
714
+ The resulting check (either new or merged)
715
+ """
716
+ if check.key in self._checks:
717
+ r = self._merge_check(self._checks[check.key], check)
718
+ self._score_engine.rebuild_link_index()
719
+ self._score_engine.recalculate_all()
720
+ for link in r.observable_links:
721
+ self._update_observable_check_links(link.observable_key)
722
+ return r
723
+
724
+ if not getattr(check, "origin_investigation_id", None):
725
+ check.origin_investigation_id = self.investigation_id
726
+
727
+ # Register new check
728
+ self._checks[check.key] = check
729
+ self._score_engine.register_check(check)
730
+ self._stats.register_check(check)
731
+ for link in check.observable_links:
732
+ self._update_observable_check_links(link.observable_key)
733
+ self._record_event(
734
+ event_type="CHECK_CREATED",
735
+ object_type="check",
736
+ object_key=check.key,
737
+ )
738
+ return check
739
+
740
+ def add_threat_intel(self, ti: ThreatIntel, observable: Observable) -> ThreatIntel:
741
+ """
742
+ Add or merge threat intel and link to observable.
743
+
744
+ Args:
745
+ ti: Threat intel to add or merge
746
+ observable: Observable to link to
747
+
748
+ Returns:
749
+ The resulting threat intel (either new or merged)
750
+ """
751
+ if ti.key in self._threat_intels:
752
+ merged_ti = self._merge_threat_intel(self._threat_intels[ti.key], ti)
753
+ # Propagate score to observable
754
+ self._score_engine.propagate_threat_intel_to_observable(merged_ti, observable)
755
+ self._record_event(
756
+ event_type="THREAT_INTEL_ATTACHED",
757
+ object_type="observable",
758
+ object_key=observable.key,
759
+ details={
760
+ "threat_intel_key": merged_ti.key,
761
+ "source": merged_ti.source,
762
+ "score": merged_ti.score,
763
+ "level": merged_ti.level,
764
+ },
765
+ )
766
+ return merged_ti
767
+
768
+ # Register new threat intel
769
+ self._threat_intels[ti.key] = ti
770
+ self._stats.register_threat_intel(ti)
771
+
772
+ # Add to observable
773
+ self._link_threat_intel_to_observable(observable, ti)
774
+
775
+ # Propagate score
776
+ self._score_engine.propagate_threat_intel_to_observable(ti, observable)
777
+
778
+ self._record_event(
779
+ event_type="THREAT_INTEL_ATTACHED",
780
+ object_type="observable",
781
+ object_key=observable.key,
782
+ details={
783
+ "threat_intel_key": ti.key,
784
+ "source": ti.source,
785
+ "score": ti.score,
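+ # Hedged sketch: apply_score_change() coerces non-Decimal inputs, recomputes the level
+ # via recalculate_level_for_score(), and only records an audit event when the score or
+ # level actually changes (SCORE_RECALCULATED updates stay out of the audit log).
+ #
+ #     obs = inv.get_root()                                          # `inv` as in the sketch above
+ #     inv.apply_score_change(obs, 7.5, reason="manual override")    # float coerced to Decimal
+ #     inv.get_audit_events(event_type="SCORE_CHANGED")[-1].details["new_score"]   # 7.5
+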
786
+ "level": ti.level,
787
+ },
788
+ )
789
+ return ti
790
+
791
+ def add_threat_intel_taxonomy(self, threat_intel_key: str, taxonomy: Taxonomy) -> ThreatIntel:
792
+ """
793
+ Add or replace a taxonomy entry on a threat intel by name.
794
+
795
+ Args:
796
+ threat_intel_key: Threat intel key
797
+ taxonomy: Taxonomy entry to add or replace
798
+
799
+ Returns:
800
+ The updated threat intel
801
+ """
802
+ ti = self._threat_intels.get(threat_intel_key)
803
+ if ti is None:
804
+ raise KeyError(f"threat_intel '{threat_intel_key}' not found in investigation.")
805
+
806
+ updated_taxonomies = list(ti.taxonomies)
807
+ replaced = False
808
+ for idx, existing in enumerate(updated_taxonomies):
809
+ if existing.name == taxonomy.name:
810
+ updated_taxonomies[idx] = taxonomy
811
+ replaced = True
812
+ break
813
+
814
+ if not replaced:
815
+ updated_taxonomies.append(taxonomy)
816
+
817
+ return self.update_model_metadata("threat_intel", threat_intel_key, {"taxonomies": updated_taxonomies})
818
+
819
+ def remove_threat_intel_taxonomy(self, threat_intel_key: str, name: str) -> ThreatIntel:
820
+ """
821
+ Remove a taxonomy entry from a threat intel by name.
822
+
823
+ Args:
824
+ threat_intel_key: Threat intel key
825
+ name: Taxonomy name to remove
826
+
827
+ Returns:
828
+ The updated threat intel
829
+ """
830
+ ti = self._threat_intels.get(threat_intel_key)
831
+ if ti is None:
832
+ raise KeyError(f"threat_intel '{threat_intel_key}' not found in investigation.")
833
+
834
+ updated_taxonomies = [taxonomy for taxonomy in ti.taxonomies if taxonomy.name != name]
835
+ if len(updated_taxonomies) == len(ti.taxonomies):
836
+ return ti
837
+
838
+ return self.update_model_metadata("threat_intel", threat_intel_key, {"taxonomies": updated_taxonomies})
839
+
840
+ def add_enrichment(self, enrichment: Enrichment) -> Enrichment:
841
+ """
842
+ Add or merge enrichment.
843
+
844
+ Args:
845
+ enrichment: Enrichment to add or merge
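+ # Hedged sketch: the audit accessors return deep copies, so events can be inspected or
+ # serialized without mutating investigation state.
+ #
+ #     events = inv.get_audit_events(object_type="observable", object_key=inv.get_root().key)
+ #     [e.event_type for e in events]   # e.g. ["OBSERVABLE_CREATED", "SCORE_CHANGED", ...]
+ #     inv.started_at                   # timestamp of the INVESTIGATION_STARTED event
+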
846
+
847
+ Returns:
848
+ The resulting enrichment (either new or merged)
849
+ """
850
+ if enrichment.key in self._enrichments:
851
+ return self._merge_enrichment(self._enrichments[enrichment.key], enrichment)
852
+
853
+ # Register new enrichment
854
+ self._enrichments[enrichment.key] = enrichment
855
+ self._record_event(
856
+ event_type="ENRICHMENT_CREATED",
857
+ object_type="enrichment",
858
+ object_key=enrichment.key,
859
+ )
860
+ return enrichment
861
+
862
+ def add_tag(self, tag: Tag) -> Tag:
863
+ """
864
+ Add or merge a tag, automatically creating ancestor tags.
865
+
866
+ When adding a tag with a hierarchical name (using ":" delimiter),
867
+ ancestor tags are automatically created if they don't exist.
868
+ For example, adding "header:auth:dkim" will auto-create
869
+ "header" and "header:auth" tags.
870
+
871
+ Args:
872
+ tag: Tag to add or merge
873
+
874
+ Returns:
875
+ The resulting tag (either new or merged)
876
+ """
877
+ # Auto-create ancestor tags
878
+ ancestor_names = keys.get_tag_ancestors(tag.name)
879
+ for ancestor_name in ancestor_names:
880
+ ancestor_key = keys.generate_tag_key(ancestor_name)
881
+ if ancestor_key not in self._tags:
882
+ ancestor_tag = Tag(name=ancestor_name)
883
+ self._tags[ancestor_key] = ancestor_tag
884
+ self._stats.register_tag(ancestor_tag)
885
+ self._record_event(
886
+ event_type="TAG_CREATED",
887
+ object_type="tag",
888
+ object_key=ancestor_key,
889
+ details={"auto_created": True, "descendant": tag.name},
890
+ )
891
+
892
+ # Add or merge the tag itself
893
+ if tag.key in self._tags:
894
+ r = self._merge_tag(self._tags[tag.key], tag)
895
+ self._score_engine.recalculate_all()
896
+ return r
897
+
898
+ # Register new tag
899
+ self._tags[tag.key] = tag
900
+ self._stats.register_tag(tag)
901
+ self._record_event(
902
+ event_type="TAG_CREATED",
903
+ object_type="tag",
904
+ object_key=tag.key,
905
+ )
906
+ return tag
907
+
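+ # Hedged sketch of the ancestor auto-creation described above, assuming ":" is the
+ # hierarchy delimiter handled by cyvest.keys and starting from a fresh investigation:
+ #
+ #     inv.add_tag(Tag(name="header:auth:dkim"))
+ #     sorted(t.name for t in inv.get_all_tags().values())
+ #     # ['header', 'header:auth', 'header:auth:dkim']
+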
908
+ def add_relationship(
909
+ self,
910
+ source: Observable | str,
911
+ target: Observable | str,
912
+ relationship_type: RelationshipType | str,
913
+ direction: RelationshipDirection | str | None = None,
914
+ ) -> Observable:
915
+ """
916
+ Add a relationship between observables.
917
+
918
+ Args:
919
+ source: Source observable or its key
920
+ target: Target observable or its key
921
+ relationship_type: Type of relationship
922
+ direction: Direction of the relationship (None = use semantic default)
923
+
924
+ Returns:
925
+ The source observable
926
+
927
+ Raises:
928
+ KeyError: If the source or target observable does not exist
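+ # Hedged illustration of the merge rules above: scores and levels take the maximum,
+ # `whitelisted` is OR-ed, `internal` is AND-ed, and `extra` dicts are shallow-merged
+ # with incoming values winning on key conflicts.
+ #
+ #     existing: score=3, internal=True,  whitelisted=False, extra={"a": 1}
+ #     incoming: score=5, internal=False, whitelisted=True,  extra={"a": 2, "b": 3}
+ #     merged:   score=5, internal=False, whitelisted=True,  extra={"a": 2, "b": 3}
+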
929
+ """
930
+
931
+ # Extract keys from Observable objects if needed
932
+ source_key = source.key if isinstance(source, Observable) else source
933
+ target_key = target.key if isinstance(target, Observable) else target
934
+
935
+ # Check if target is a copy from shared context (anti-pattern)
936
+ if isinstance(target, Observable) and getattr(target, "_from_shared_context", False):
937
+ obs_type_name = target.obs_type.name
938
+ raise ValueError(
939
+ f"Cannot use observable from shared_context.observable_get() directly in relationships.\n"
940
+ f"Observable '{target_key}' is a read-only copy not registered in this investigation.\n\n"
941
+ f"Incorrect pattern:\n"
942
+ f" source.relate_to(shared_context.observable_get(...), RelationshipType.{relationship_type})\n\n"
943
+ f"Correct pattern (and use reconcile or merge):\n"
944
+ f" # Use cy.observable() to create/get observable in local investigation\n"
945
+ f" source.relate_to(\n"
946
+ f" cy.observable(ObservableType.{obs_type_name}, '{target.value}'),\n"
947
+ f" RelationshipType.{relationship_type}\n"
948
+ f" )"
949
+ )
950
+
951
+ # Validate both source and target exist
952
+ source_obs = self._observables.get(source_key)
953
+ target_obs = self._observables.get(target_key)
954
+
955
+ if not source_obs:
956
+ raise KeyError(f"observable '{source_key}' not found in investigation.")
957
+
958
+ if not target_obs:
959
+ raise KeyError(f"observable '{target_key}' not found in investigation.")
960
+
961
+ # Add relationship using internal method
962
+ self._create_relationship(source_obs, target_key, relationship_type, direction)
963
+
964
+ self._record_event(
965
+ event_type="RELATIONSHIP_CREATED",
966
+ object_type="observable",
967
+ object_key=source_obs.key,
968
+ details={
969
+ "target_key": target_key,
970
+ "relationship_type": relationship_type,
971
+ "direction": direction,
972
+ },
973
+ )
974
+
975
+ # Recalculate scores after adding relationship
976
+ self._score_engine.recalculate_all()
977
+
978
+ return source_obs
979
+
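+ # Hedged sketch: both endpoints must already be registered in this investigation, either
+ # Observable instances or their keys are accepted, and each new relationship triggers a
+ # full score recalculation. `parent_obs` and `child_obs` are assumed pre-built Observables.
+ #
+ #     parent, _ = inv.add_observable(parent_obs)
+ #     child, _ = inv.add_observable(child_obs)
+ #     inv.add_relationship(parent, child.key, RelationshipType.RELATED_TO)
+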
980
+ def link_check_observable(
981
+ self,
982
+ check_key: str,
983
+ observable_key: str,
984
+ propagation_mode: PropagationMode | str = PropagationMode.LOCAL_ONLY,
985
+ ) -> Check:
986
+ """
987
+ Link an observable to a check.
988
+
989
+ Args:
990
+ check_key: Key of the check
991
+ observable_key: Key of the observable
992
+ propagation_mode: Propagation behavior for this link
993
+
994
+ Returns:
995
+ The check
996
+
997
+ Raises:
998
+ KeyError: If the check or observable does not exist
999
+ """
1000
+ check = self._checks.get(check_key)
1001
+ observable = self._observables.get(observable_key)
1002
+
1003
+ if check is None:
1004
+ raise KeyError(f"check '{check_key}' not found in investigation.")
1005
+ if observable is None:
1006
+ raise KeyError(f"observable '{observable_key}' not found in investigation.")
1007
+
1008
+ if check and observable:
1009
+ propagation_mode = PropagationMode(propagation_mode)
1010
+ link = ObservableLink(
1011
+ observable_key=observable_key,
1012
+ propagation_mode=propagation_mode,
1013
+ )
1014
+ created = self._link_check_to_observable(check, link)
1015
+ if created:
1016
+ self._score_engine.register_check_observable_link(check_key=check.key, observable_key=observable_key)
1017
+ self._update_observable_check_links(observable_key)
1018
+ self._record_event(
1019
+ event_type="CHECK_LINKED_TO_OBSERVABLE",
1020
+ object_type="check",
1021
+ object_key=check.key,
1022
+ details={
1023
+ "observable_key": observable_key,
1024
+ "propagation_mode": propagation_mode.value,
1025
+ },
1026
+ )
1027
+ is_effective = (
1028
+ propagation_mode == PropagationMode.GLOBAL or self.investigation_id == check.origin_investigation_id
1029
+ )
1030
+ if is_effective and check.level == Level.NONE:
1031
+ self.apply_level_change(check, Level.INFO, reason="Effective link added")
1032
+
1033
+ self._score_engine._propagate_observable_to_checks(observable_key)
1034
+
1035
+ return check
1036
+
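+ # Hedged sketch: linking a check to an observable drives score propagation between them.
+ # PropagationMode.LOCAL_ONLY is the default; GLOBAL links also count as "effective" for
+ # checks that originated in another investigation. `my_check` is an assumed pre-built Check.
+ #
+ #     check = inv.add_check(my_check)
+ #     inv.link_check_observable(check.key, inv.get_root().key,
+ #                               propagation_mode=PropagationMode.GLOBAL)
+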
1037
+ def add_check_to_tag(self, tag_key: str, check_key: str) -> Tag:
1038
+ """
1039
+ Add a check to a tag.
1040
+
1041
+ Args:
1042
+ tag_key: Key of the tag
1043
+ check_key: Key of the check
1044
+
1045
+ Returns:
1046
+ The tag
1047
+
1048
+ Raises:
1049
+ KeyError: If the tag or check does not exist
1050
+ """
1051
+ tag = self._tags.get(tag_key)
1052
+ check = self._checks.get(check_key)
1053
+
1054
+ if tag is None:
1055
+ raise KeyError(f"tag '{tag_key}' not found in investigation.")
1056
+ if check is None:
1057
+ raise KeyError(f"check '{check_key}' not found in investigation.")
1058
+
1059
+ if tag and check:
1060
+ self._link_check_to_tag(tag, check)
1061
+ self._record_event(
1062
+ event_type="TAG_CHECK_ADDED",
1063
+ object_type="tag",
1064
+ object_key=tag.key,
1065
+ details={"check_key": check.key},
1066
+ )
1067
+
1068
+ return tag
1069
+
1070
+ def get_root(self) -> Observable:
1071
+ """Get the root observable."""
1072
+ return self._root_observable
1073
+
1074
+ def get_observable(self, key: str) -> Observable | None:
1075
+ """Get observable by full key string."""
1076
+ return self._observables.get(key)
1077
+
1078
+ def get_check(self, key: str) -> Check | None:
1079
+ """Get check by full key string."""
1080
+ return self._checks.get(key)
1081
+
1082
+ def get_tag(self, key: str) -> Tag | None:
1083
+ """Get a tag by key."""
1084
+ return self._tags.get(key)
1085
+
1086
+ def get_tag_children(self, tag_name: str) -> list[Tag]:
1087
+ """
1088
+ Get direct child tags of a tag.
1089
+
1090
+ Args:
1091
+ tag_name: Name of the parent tag
1092
+
1093
+ Returns:
1094
+ List of direct child Tag objects
1095
+ """
1096
+ return [t for t in self._tags.values() if keys.is_tag_child_of(t.name, tag_name)]
1097
+
1098
+ def get_tag_descendants(self, tag_name: str) -> list[Tag]:
1099
+ """
1100
+ Get all descendant tags of a tag.
1101
+
1102
+ Args:
1103
+ tag_name: Name of the ancestor tag
1104
+
1105
+ Returns:
1106
+ List of all descendant Tag objects
1107
+ """
1108
+ return [t for t in self._tags.values() if keys.is_tag_descendant_of(t.name, tag_name)]
1109
+
1110
+ def get_tag_ancestors(self, tag_name: str) -> list[Tag]:
1111
+ """
1112
+ Get all ancestor tags of a tag.
1113
+
1114
+ Args:
1115
+ tag_name: Name of the descendant tag
1116
+
1117
+ Returns:
1118
+ List of ancestor Tag objects (in order from root to immediate parent)
1119
+ """
1120
+ ancestor_names = keys.get_tag_ancestors(tag_name)
1121
+ result = []
1122
+ for name in ancestor_names:
1123
+ tag_key = keys.generate_tag_key(name)
1124
+ if tag_key in self._tags:
1125
+ result.append(self._tags[tag_key])
1126
+ return result
1127
+
1128
+ def get_tag_aggregated_score(self, tag_name: str) -> Decimal:
1129
+ """
1130
+ Get aggregated score for a tag including all descendants.
1131
+
1132
+ Args:
1133
+ tag_name: Name of the tag
1134
+
1135
+ Returns:
1136
+ Total score from direct checks and all descendant tag checks
1137
+ """
1138
+ tag_key = keys.generate_tag_key(tag_name)
1139
+ tag = self._tags.get(tag_key)
1140
+ if not tag:
1141
+ return Decimal("0")
1142
+
1143
+ total = tag.get_direct_score()
1144
+
1145
+ # Add scores from direct children only (they will recursively add their children)
1146
+ for child in self.get_tag_children(tag_name):
1147
+ total += self.get_tag_aggregated_score(child.name)
1148
+
1149
+ return total
1150
+
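+ # Hedged arithmetic example of the recursion above, with direct scores assumed to come
+ # from Tag.get_direct_score():
+ #
+ #     "header"            direct score 1.0
+ #     "header:auth"       direct score 2.0
+ #     "header:auth:dkim"  direct score 4.0
+ #     get_tag_aggregated_score("header")       -> 1.0 + 2.0 + 4.0 = 7.0
+ #     get_tag_aggregated_score("header:auth")  -> 2.0 + 4.0 = 6.0
+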
1151
+ def get_tag_aggregated_level(self, tag_name: str) -> Level:
1152
+ """
1153
+ Get aggregated level for a tag including all descendants.
1154
+
1155
+ Args:
1156
+ tag_name: Name of the tag
1157
+
1158
+ Returns:
1159
+ Level based on aggregated score
1160
+ """
1161
+ from cyvest.levels import get_level_from_score
1162
+
1163
+ return get_level_from_score(self.get_tag_aggregated_score(tag_name))
1164
+
1165
+ def get_enrichment(self, key: str) -> Enrichment | None:
1166
+ """Get an enrichment by key."""
1167
+ return self._enrichments.get(key)
1168
+
1169
+ def get_threat_intel(self, key: str) -> ThreatIntel | None:
1170
+ """Get a threat intel by key."""
1171
+ return self._threat_intels.get(key)
1172
+
1173
+ def update_model_metadata(
1174
+ self,
1175
+ model_type: Literal["observable", "check", "threat_intel", "enrichment", "tag"],
1176
+ key: str,
1177
+ updates: dict[str, Any],
1178
+ *,
1179
+ dict_merge: dict[str, bool] | None = None,
1180
+ ):
1181
+ """
1182
+ Update mutable metadata fields for a stored model instance.
1183
+
1184
+ Args:
1185
+ model_type: Model family to update.
1186
+ key: Key of the target object.
1187
+ updates: Mapping of field names to new values. ``None`` values are ignored.
1188
+ dict_merge: Optional overrides for dict fields (True=merge, False=replace).
1189
+
1190
+ Returns:
1191
+ The updated model instance.
1192
+
1193
+ Raises:
1194
+ KeyError: If the key cannot be found.
1195
+ ValueError: If an unsupported field is requested.
1196
+ TypeError: If a dict field receives a non-dict value.
1197
+ """
1198
+ store_lookup: dict[str, dict[str, Any]] = {
1199
+ "observable": self._observables,
1200
+ "check": self._checks,
1201
+ "threat_intel": self._threat_intels,
1202
+ "enrichment": self._enrichments,
1203
+ "tag": self._tags,
1204
+ }
1205
+ store = store_lookup[model_type]
1206
+ target = store.get(key)
1207
+ if target is None:
1208
+ raise KeyError(f"{model_type} '{key}' not found in investigation.")
1209
+
1210
+ if not updates:
1211
+ return target
1212
+
1213
+ rules = self._MODEL_METADATA_RULES[model_type]
1214
+ allowed_fields = rules["fields"]
1215
+ dict_fields = rules["dict_fields"]
1216
+
1217
+ changes: dict[str, dict[str, Any]] = {}
1218
+
1219
+ for field, value in updates.items():
1220
+ if field not in allowed_fields:
1221
+ raise ValueError(f"Field '{field}' is not mutable on {model_type}.")
1222
+ if value is None:
1223
+ continue
1224
+ old_value = deepcopy(getattr(target, field, None))
1225
+ if field == "level":
1226
+ value = normalize_level(value)
1227
+ if model_type == "threat_intel" and field == "taxonomies":
1228
+ value = self._normalize_taxonomies(value)
1229
+ if field in dict_fields:
1230
+ if not isinstance(value, dict):
1231
+ raise TypeError(f"Field '{field}' on {model_type} expects a dict value.")
1232
+ merge = dict_merge.get(field, True) if dict_merge else True
1233
+ if merge:
1234
+ current_value = getattr(target, field, None)
1235
+ if current_value is None:
1236
+ setattr(target, field, deepcopy(value))
1237
+ else:
1238
+ current_value.update(value)
1239
+ else:
1240
+ setattr(target, field, deepcopy(value))
1241
+ else:
1242
+ setattr(target, field, value)
1243
+ new_value = deepcopy(getattr(target, field, None))
1244
+ if old_value != new_value:
1245
+ changes[field] = {"old": old_value, "new": new_value}
1246
+
1247
+ if changes:
1248
+ self._record_event(
1249
+ event_type="METADATA_UPDATED",
1250
+ object_type=model_type,
1251
+ object_key=key,
1252
+ details={"changes": changes},
1253
+ )
1254
+ return target
1255
+
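+ # Hedged sketch: only fields whitelisted in _MODEL_METADATA_RULES are mutable, None values
+ # are skipped, and dict fields such as "extra" merge by default unless dict_merge disables
+ # it for that field. `obs` is an assumed previously added Observable.
+ #
+ #     inv.update_model_metadata("observable", obs.key,
+ #                               {"comment": "triaged", "extra": {"source": "sandbox"}})
+ #     inv.update_model_metadata("observable", obs.key,
+ #                               {"extra": {"source": "analyst"}}, dict_merge={"extra": False})
+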
1256
+ def get_all_observables(self) -> dict[str, Observable]:
1257
+ """Get all observables."""
1258
+ return self._observables.copy()
1259
+
1260
+ def get_all_checks(self) -> dict[str, Check]:
1261
+ """Get all checks."""
1262
+ return self._checks.copy()
1263
+
1264
+ def get_all_threat_intels(self) -> dict[str, ThreatIntel]:
1265
+ """Get all threat intels."""
1266
+ return self._threat_intels.copy()
1267
+
1268
+ def get_all_enrichments(self) -> dict[str, Enrichment]:
1269
+ """Get all enrichments."""
1270
+ return self._enrichments.copy()
1271
+
1272
+ def get_all_tags(self) -> dict[str, Tag]:
1273
+ """Get all tags."""
1274
+ return self._tags.copy()
1275
+
1276
+ def get_global_score(self) -> Decimal:
1277
+ """Get the global investigation score."""
1278
+ return self._score_engine.get_global_score()
1279
+
1280
+ def get_global_level(self) -> Level:
1281
+ """Get the global investigation level."""
1282
+ return self._score_engine.get_global_level()
1283
+
1284
+ def is_whitelisted(self) -> bool:
1285
+ """Return whether the investigation has any whitelist entries."""
1286
+ return bool(self._whitelists)
1287
+
1288
+ def add_whitelist(self, identifier: str, name: str, justification: str | None = None) -> InvestigationWhitelist:
1289
+ """
1290
+ Add or update a whitelist entry.
1291
+
1292
+ Args:
1293
+ identifier: Unique identifier for this whitelist entry.
1294
+ name: Human-readable name for the whitelist entry.
1295
+ justification: Optional markdown justification.
1296
+
1297
+ Returns:
1298
+ The stored whitelist entry.
1299
+ """
1300
+ identifier = str(identifier).strip()
1301
+ name = str(name).strip()
1302
+ if not identifier:
1303
+ raise ValueError("Whitelist identifier must be provided.")
1304
+ if not name:
1305
+ raise ValueError("Whitelist name must be provided.")
1306
+ if justification is not None:
1307
+ justification = str(justification)
1308
+
1309
+ entry = InvestigationWhitelist(identifier=identifier, name=name, justification=justification)
1310
+ self._whitelists[identifier] = entry
1311
+ self._record_event(
1312
+ event_type="WHITELIST_APPLIED",
1313
+ object_type="investigation",
1314
+ object_key=self.investigation_id,
1315
+ details={
1316
+ "identifier": identifier,
1317
+ "name": name,
1318
+ "justification": justification,
1319
+ },
1320
+ )
1321
+ return entry
1322
+
1323
+ def remove_whitelist(self, identifier: str) -> bool:
1324
+ """
1325
+ Remove a whitelist entry by identifier.
1326
+
1327
+ Returns:
1328
+ True if removed, False if it did not exist.
1329
+ """
1330
+ removed = self._whitelists.pop(identifier, None)
1331
+ if removed:
1332
+ self._record_event(
1333
+ event_type="WHITELIST_REMOVED",
1334
+ object_type="investigation",
1335
+ object_key=self.investigation_id,
1336
+ details={"identifier": identifier},
1337
+ )
1338
+ return removed is not None
1339
+
1340
+ def clear_whitelists(self) -> None:
1341
+ """Remove all whitelist entries."""
1342
+ if not self._whitelists:
1343
+ return
1344
+ removed = list(self._whitelists.keys())
1345
+ self._whitelists.clear()
1346
+ self._record_event(
1347
+ event_type="WHITELIST_CLEARED",
1348
+ object_type="investigation",
1349
+ object_key=self.investigation_id,
1350
+ details={"identifiers": removed},
1351
+ )
1352
+
1353
+ def get_whitelists(self) -> list[InvestigationWhitelist]:
1354
+ """Return a copy of all whitelist entries."""
1355
+ return [w.model_copy(deep=True) for w in self._whitelists.values()]
1356
+
1357
+ def get_statistics(self) -> StatisticsSchema:
1358
+ """Get comprehensive investigation statistics."""
1359
+ return self._stats.get_summary()
1360
+
1361
+ def finalize_relationships(self) -> None:
1362
+ """
1363
+ Finalize observable relationships by linking orphans to root.
1364
+
1365
+ Detects orphan sub-graphs (connected components not linked to root) and links
1366
+ the most appropriate starting node of each sub-graph to root.
1367
+ """
1368
+ root_key = self._root_observable.key
1369
+
1370
+ # Build adjacency lists for graph traversal
1371
+ graph = {key: set() for key in self._observables.keys()}
1372
+ incoming = {key: set() for key in self._observables.keys()}
1373
+
1374
+ for obs_key, obs in self._observables.items():
1375
+ for rel in obs.relationships:
1376
+ if rel.target_key in self._observables:
1377
+ graph[obs_key].add(rel.target_key)
1378
+ incoming[rel.target_key].add(obs_key)
1379
+
1380
+ # Find all connected components using BFS
1381
+ visited = set()
1382
+ components = []
1383
+
1384
+ def bfs(start_key: str) -> set[str]:
1385
+ """Breadth-first search to find connected component."""
1386
+ component = set()
1387
+ queue = [start_key]
1388
+ component.add(start_key)
1389
+
1390
+ while queue:
1391
+ current = queue.pop(0)
1392
+ # Check both outgoing and incoming edges for connectivity
1393
+ neighbors = graph[current] | incoming[current]
1394
+ for neighbor in neighbors:
1395
+ if neighbor not in component:
1396
+ component.add(neighbor)
1397
+ queue.append(neighbor)
1398
+
1399
+ return component
1400
+
1401
+ # Find all connected components
1402
+ for obs_key in self._observables.keys():
1403
+ if obs_key not in visited:
1404
+ component = bfs(obs_key)
1405
+ visited.update(component)
1406
+ components.append(component)
1407
+
1408
+ # Process each component that doesn't include root
1409
+ for component in components:
1410
+ if root_key in component:
1411
+ continue # This component is already connected to root
1412
+
1413
+ # Find the best starting node in this orphan sub-graph
1414
+ # Prioritize nodes with:
1415
+ # 1. No incoming edges (true source nodes)
1416
+ # 2. Most outgoing edges (central nodes)
1417
+ best_node = None
1418
+ best_score: tuple[int, int] | None = None  # (negative incoming count, outgoing count); None so the first candidate always wins
1419
+
1420
+ for node_key in component:
1421
+ incoming_count = len(incoming[node_key] & component)
1422
+ outgoing_count = len(graph[node_key] & component)
1423
+ score = (-incoming_count, outgoing_count)
1424
+
1425
+ if best_score is None or score > best_score:
1426
+ best_score = score
1427
+ best_node = node_key
1428
+
1429
+ # Link the best starting node to root
1430
+ if best_node:
1431
+ self._create_relationship(self._root_observable, best_node, RelationshipType.RELATED_TO)
1432
+ self._record_event(
1433
+ event_type="RELATIONSHIP_CREATED",
1434
+ object_type="observable",
1435
+ object_key=self._root_observable.key,
1436
+ reason="Finalize relationships",
1437
+ details={
1438
+ "target_key": best_node,
1439
+ "relationship_type": RelationshipType.RELATED_TO.value,
1440
+ "direction": RelationshipType.RELATED_TO.get_default_direction().value,
1441
+ },
1442
+ )
1443
+ self._score_engine.recalculate_all()
1444
+
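+ # Hedged illustration of the orphan-linking heuristic above: within each component that
+ # cannot reach root, the node with the fewest in-component incoming edges (ties broken by
+ # the most outgoing edges) is linked to root with RELATED_TO.
+ #
+ #     component:  A -> B -> C      (A has 0 incoming edges, 1 outgoing edge)
+ #     result:     root -[RELATED_TO]-> A, making the whole chain reachable from root
+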
1445
+ def merge_investigation(self, other: Investigation) -> None:
1446
+ """
1447
+ Merge another investigation into this one.
1448
+
1449
+ Uses a two-pass approach to handle relationship dependencies:
1450
+ - Pass 1: Merge all observables, collecting deferred relationships
1451
+ - Pass 2: Add deferred relationships now that all observables exist
1452
+
1453
+ Args:
1454
+ other: The investigation to merge
1455
+ """
1456
+
1457
+ def _diff_fields(before: dict[str, Any], after: dict[str, Any]) -> list[str]:
1458
+ return [field for field, value in before.items() if value != after.get(field)]
1459
+
1460
+ def _snapshot_observable(obs: Observable) -> dict[str, Any]:
1461
+ relationships = [
1462
+ (
1463
+ rel.target_key,
1464
+ rel.relationship_type_name,
1465
+ rel.direction.value,
1466
+ )
1467
+ for rel in obs.relationships
1468
+ ]
1469
+ return {
1470
+ "score": obs.score,
1471
+ "level": obs.level,
1472
+ "comment": obs.comment,
1473
+ "extra": deepcopy(obs.extra),
1474
+ "internal": obs.internal,
1475
+ "whitelisted": obs.whitelisted,
1476
+ "threat_intels": sorted(ti.key for ti in obs.threat_intels),
1477
+ "relationships": sorted(relationships),
1478
+ }
1479
+
1480
+ def _snapshot_check(check: Check) -> dict[str, Any]:
1481
+ links = [
1482
+ (
1483
+ link.observable_key,
1484
+ link.propagation_mode.value,
1485
+ )
1486
+ for link in check.observable_links
1487
+ ]
1488
+ return {
1489
+ "score": check.score,
1490
+ "level": check.level,
1491
+ "comment": check.comment,
1492
+ "description": check.description,
1493
+ "extra": deepcopy(check.extra),
1494
+ "origin_investigation_id": check.origin_investigation_id,
1495
+ "observable_links": sorted(links),
1496
+ }
1497
+
1498
+ def _snapshot_threat_intel(ti: ThreatIntel) -> dict[str, Any]:
1499
+ return {
1500
+ "score": ti.score,
1501
+ "level": ti.level,
1502
+ "comment": ti.comment,
1503
+ "extra": deepcopy(ti.extra),
1504
+ "taxonomies": deepcopy(ti.taxonomies),
1505
+ }
1506
+
1507
+ def _snapshot_enrichment(enrichment: Enrichment) -> dict[str, Any]:
1508
+ return {
1509
+ "context": enrichment.context,
1510
+ "data": deepcopy(enrichment.data),
1511
+ }
1512
+
1513
+ def _snapshot_tag(tag: Tag) -> dict[str, Any]:
1514
+ return {
1515
+ "description": tag.description,
1516
+ "checks": sorted(check.key for check in tag.checks),
1517
+ }
1518
+
1519
+ merge_summary: list[dict[str, Any]] = []
1520
+
1521
+ (
1522
+ incoming_observables,
1523
+ incoming_threat_intels,
1524
+ incoming_checks,
1525
+ incoming_enrichments,
1526
+ incoming_tags,
1527
+ ) = self._clone_for_merge(other)
1528
+
1529
+ # PASS 1: Merge observables and collect deferred relationships
1530
+ all_deferred_relationships = []
1531
+ for obs in incoming_observables.values():
1532
+ existing = self._observables.get(obs.key)
1533
+ before = _snapshot_observable(existing) if existing else None
1534
+ _, deferred = self.add_observable(obs)
1535
+ all_deferred_relationships.extend(deferred)
1536
+ if existing:
1537
+ after = _snapshot_observable(existing)
1538
+ changed_fields = _diff_fields(before, after) if before else []
1539
+ action = "merged" if changed_fields else "skipped"
1540
+ merge_summary.append(
1541
+ {
1542
+ "object_type": "observable",
1543
+ "object_key": obs.key,
1544
+ "action": action,
1545
+ "changed_fields": changed_fields,
1546
+ }
1547
+ )
1548
+ else:
1549
+ merge_summary.append(
1550
+ {
1551
+ "object_type": "observable",
1552
+ "object_key": obs.key,
1553
+ "action": "created",
1554
+ "changed_fields": [],
1555
+ }
1556
+ )
1557
+
1558
+ # PASS 2: Process deferred relationships now that all observables exist
1559
+ for source_key, rel in all_deferred_relationships:
1560
+ source_obs = self._observables.get(source_key)
1561
+ if source_obs and rel.target_key in self._observables:
1562
+ # Both source and target exist - add relationship
1563
+ self._create_relationship(source_obs, rel.target_key, rel.relationship_type, rel.direction)
1564
+ else:
1565
+ # Genuine error - target still doesn't exist after Pass 2
1566
+ logger.critical(
1567
+ "Relationship target '{}' not found after merge completion for observable '{}'. "
1568
+ "This indicates corrupted data or a bug in the merge logic.",
1569
+ rel.target_key,
1570
+ source_key,
1571
+ )
1572
+
1573
+ # Merge threat intels (need to link to observables)
1574
+ for ti in incoming_threat_intels.values():
1575
+ existing_ti = self._threat_intels.get(ti.key)
1576
+ before = _snapshot_threat_intel(existing_ti) if existing_ti else None
1577
+ # Find the observable this TI belongs to
1578
+ observable = self._observables.get(ti.observable_key)
1579
+ if observable:
1580
+ self.add_threat_intel(ti, observable)
1581
+ if existing_ti:
1582
+ after = _snapshot_threat_intel(existing_ti)
1583
+ changed_fields = _diff_fields(before, after) if before else []
1584
+ action = "merged" if changed_fields else "skipped"
1585
+ else:
1586
+ changed_fields = []
1587
+ action = "created"
1588
+ merge_summary.append(
1589
+ {
1590
+ "object_type": "threat_intel",
1591
+ "object_key": ti.key,
1592
+ "action": action,
1593
+ "changed_fields": changed_fields,
1594
+ }
1595
+ )
1596
+
1597
+ # Merge checks
1598
+ for check in incoming_checks.values():
1599
+ existing_check = self._checks.get(check.key)
1600
+ before = _snapshot_check(existing_check) if existing_check else None
1601
+ self.add_check(check)
1602
+ if existing_check:
1603
+ after = _snapshot_check(existing_check)
1604
+ changed_fields = _diff_fields(before, after) if before else []
1605
+ action = "merged" if changed_fields else "skipped"
1606
+ else:
1607
+ changed_fields = []
1608
+ action = "created"
1609
+ merge_summary.append(
1610
+ {
1611
+ "object_type": "check",
1612
+ "object_key": check.key,
1613
+ "action": action,
1614
+ "changed_fields": changed_fields,
1615
+ }
1616
+ )
1617
+
1618
+ # Merge enrichments
1619
+ for enrichment in incoming_enrichments.values():
1620
+ existing_enrichment = self._enrichments.get(enrichment.key)
1621
+ before = _snapshot_enrichment(existing_enrichment) if existing_enrichment else None
1622
+ self.add_enrichment(enrichment)
1623
+ if existing_enrichment:
1624
+ after = _snapshot_enrichment(existing_enrichment)
1625
+ changed_fields = _diff_fields(before, after) if before else []
1626
+ action = "merged" if changed_fields else "skipped"
1627
+ else:
1628
+ changed_fields = []
1629
+ action = "created"
1630
+ merge_summary.append(
1631
+ {
1632
+ "object_type": "enrichment",
1633
+ "object_key": enrichment.key,
1634
+ "action": action,
1635
+ "changed_fields": changed_fields,
1636
+ }
1637
+ )
1638
+
1639
+ # Merge tags
1640
+ for tag in incoming_tags.values():
1641
+ existing_tag = self._tags.get(tag.key)
1642
+ before = _snapshot_tag(existing_tag) if existing_tag else None
1643
+ self.add_tag(tag)
1644
+ if existing_tag:
1645
+ after = _snapshot_tag(existing_tag)
1646
+ changed_fields = _diff_fields(before, after) if before else []
1647
+ action = "merged" if changed_fields else "skipped"
1648
+ else:
1649
+ changed_fields = []
1650
+ action = "created"
1651
+ merge_summary.append(
1652
+ {
1653
+ "object_type": "tag",
1654
+ "object_key": tag.key,
1655
+ "action": action,
1656
+ "changed_fields": changed_fields,
1657
+ }
1658
+ )
1659
+
1660
+ # Merge whitelists (other investigation overrides on identifier conflicts)
1661
+ for entry in other.get_whitelists():
1662
+ self.add_whitelist(entry.identifier, entry.name, entry.justification)
1663
+
1664
+ # Rebuild link index after merges
1665
+ self._score_engine.rebuild_link_index()
1666
+ self._rebuild_all_check_links()
1667
+
1668
+ # Final score recalculation
1669
+ self._score_engine.recalculate_all()
1670
+
1671
+ self._record_event(
1672
+ event_type="INVESTIGATION_MERGED",
1673
+ object_type="investigation",
1674
+ object_key=self.investigation_id,
1675
+ details={
1676
+ "from_investigation_id": other.investigation_id,
1677
+ "into_investigation_id": self.investigation_id,
1678
+ "from_investigation_name": other.investigation_name,
1679
+ "into_investigation_name": self.investigation_name,
1680
+ "object_changes": merge_summary,
1681
+ },
1682
+ )
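+
+ # Hedged end-to-end sketch of merge_investigation(): objects from `other` are deep-copied,
+ # observables are merged first so deferred relationships can be resolved in pass 2, and a
+ # single INVESTIGATION_MERGED audit event summarises created/merged/skipped objects.
+ #
+ #     primary = Investigation(investigation_name="primary")
+ #     secondary = Investigation(investigation_name="sandbox run")
+ #     ...                                   # populate both investigations
+ #     primary.merge_investigation(secondary)
+ #     primary.get_audit_events(event_type="INVESTIGATION_MERGED")[-1].details["object_changes"]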