cyvest 4.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cyvest might be problematic. Click here for more details.

@@ -0,0 +1,1644 @@
1
+ """
2
+ Investigation core - central state management for cybersecurity investigations.
3
+
4
+ Handles all object storage, merging, scoring, and statistics in a unified way.
5
+ Provides automatic merge-on-create for all object types.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from copy import deepcopy
11
+ from datetime import datetime, timezone
12
+ from decimal import Decimal
13
+ from typing import TYPE_CHECKING, Any, Literal
14
+
15
+ from logurich import logger
16
+
17
+ from cyvest.level_score_rules import recalculate_level_for_score
18
+ from cyvest.levels import Level, normalize_level
19
+ from cyvest.model import (
20
+ AuditEvent,
21
+ Check,
22
+ Container,
23
+ Enrichment,
24
+ InvestigationWhitelist,
25
+ Observable,
26
+ ObservableLink,
27
+ ObservableType,
28
+ Relationship,
29
+ Taxonomy,
30
+ ThreatIntel,
31
+ )
32
+ from cyvest.model_enums import PropagationMode, RelationshipDirection, RelationshipType
33
+ from cyvest.score import ScoreEngine, ScoreMode
34
+ from cyvest.stats import InvestigationStats
35
+ from cyvest.ulid import generate_ulid
36
+
37
+ if TYPE_CHECKING:
38
+ from cyvest.model_schema import StatisticsSchema
39
+
40
+
41
class Investigation:
    """
    Core investigation state and operations.

    Manages all investigation objects (observables, checks, threat intel, etc.),
    handles automatic merging on creation, score propagation, and statistics tracking.
    """

    # Per-object-type metadata rules: "fields" lists attributes that may be
    # updated through metadata updates, "dict_fields" the subset that is
    # dict-merged rather than replaced.
    # NOTE(review): consumed by update_model_metadata(), which is referenced
    # elsewhere in this file but not visible in this chunk — confirm the exact
    # merge-vs-replace semantics there.
    _MODEL_METADATA_RULES: dict[str, dict[str, set[str]]] = {
        "observable": {
            "fields": {"comment", "extra", "internal", "whitelisted"},
            "dict_fields": {"extra"},
        },
        "check": {
            "fields": {"comment", "extra", "description"},
            "dict_fields": {"extra"},
        },
        "threat_intel": {
            "fields": {"comment", "extra", "level", "taxonomies"},
            "dict_fields": {"extra"},
        },
        "enrichment": {
            "fields": {"context", "data"},
            "dict_fields": {"data"},
        },
        "container": {
            "fields": {"description"},
            "dict_fields": set(),
        },
    }
71
+
72
    def __init__(
        self,
        root_data: Any = None,
        root_type: ObservableType | Literal["file", "artifact"] = ObservableType.FILE,
        score_mode_obs: ScoreMode | Literal["max", "sum"] = ScoreMode.MAX,
        *,
        investigation_id: str | None = None,
        investigation_name: str | None = None,
    ) -> None:
        """
        Initialize a new investigation.

        Args:
            root_data: Data stored on the root observable (optional)
            root_type: Root observable type (ObservableType.FILE or ObservableType.ARTIFACT)
            score_mode_obs: Observable score calculation mode (MAX or SUM)
            investigation_id: Explicit investigation id; a new ULID is generated when None
            investigation_name: Optional human-readable investigation name
        """
        self.investigation_id = investigation_id or generate_ulid()
        self.investigation_name = investigation_name
        # The audit log must exist before the first _record_event() call below.
        self._audit_log: list[AuditEvent] = []
        self._audit_enabled = True

        # Record investigation start as the first event; the started_at
        # property reads this event back.
        self._record_event(
            event_type="INVESTIGATION_STARTED",
            object_type="investigation",
            object_key=self.investigation_id,
        )

        # Object collections, keyed by each object's stable key.
        self._observables: dict[str, Observable] = {}
        self._checks: dict[str, Check] = {}
        self._threat_intels: dict[str, ThreatIntel] = {}
        self._enrichments: dict[str, Enrichment] = {}
        self._containers: dict[str, Container] = {}

        # Internal components.  sink=self: presumably the engine reports
        # score/level changes back via apply_score_change/apply_level_change —
        # confirm against ScoreEngine.
        normalized_score_mode_obs = ScoreMode.normalize(score_mode_obs)
        self._score_engine = ScoreEngine(score_mode_obs=normalized_score_mode_obs, sink=self)
        self._stats = InvestigationStats()
        self._whitelists: dict[str, InvestigationWhitelist] = {}

        # Create root observable; normalize_root_type accepts the
        # "file"/"artifact" string aliases as well as the enum.
        obj_type = ObservableType.normalize_root_type(root_type)

        self._root_observable = Observable(
            obs_type=obj_type,
            value="root",
            internal=False,
            whitelisted=False,
            comment="Root observable for investigation",
            extra=root_data,
            score=Decimal("0"),
            level=Level.INFO,
        )
        # Register in the collection, engine, and stats before auditing creation.
        self._observables[self._root_observable.key] = self._root_observable
        self._score_engine.register_observable(self._root_observable)
        self._stats.register_observable(self._root_observable)
        self._record_event(
            event_type="OBSERVABLE_CREATED",
            object_type="observable",
            object_key=self._root_observable.key,
        )
136
+
137
+ def _record_event(
138
+ self,
139
+ *,
140
+ event_type: str,
141
+ object_type: str | None = None,
142
+ object_key: str | None = None,
143
+ reason: str | None = None,
144
+ actor: str | None = None,
145
+ tool: str | None = None,
146
+ details: dict[str, Any] | None = None,
147
+ timestamp: datetime | None = None,
148
+ ) -> AuditEvent | None:
149
+ if not self._audit_enabled:
150
+ return None
151
+
152
+ event = AuditEvent(
153
+ event_id=generate_ulid(),
154
+ timestamp=timestamp or datetime.now(timezone.utc),
155
+ event_type=event_type,
156
+ actor=actor,
157
+ reason=reason,
158
+ tool=tool,
159
+ object_type=object_type,
160
+ object_key=object_key,
161
+ details=deepcopy(details) if details else {},
162
+ )
163
+ self._audit_log.append(event)
164
+ return event
165
+
166
+ @property
167
+ def started_at(self) -> datetime:
168
+ """Return the investigation start time from the first event in the audit log."""
169
+ for event in self._audit_log:
170
+ if event.event_type == "INVESTIGATION_STARTED":
171
+ return event.timestamp
172
+ # Fallback if no INVESTIGATION_STARTED event (shouldn't happen)
173
+ return datetime.now(timezone.utc)
174
+
175
+ def _add_threat_intel_to_observable(self, observable: Observable, ti: ThreatIntel) -> None:
176
+ if any(existing.key == ti.key for existing in observable.threat_intels):
177
+ return
178
+ observable.threat_intels.append(ti)
179
+
180
+ def _add_relationship_internal(
181
+ self,
182
+ source_obs: Observable,
183
+ target_key: str,
184
+ relationship_type: RelationshipType | str,
185
+ direction: RelationshipDirection | str | None = None,
186
+ ) -> None:
187
+ rel = Relationship(target_key=target_key, relationship_type=relationship_type, direction=direction)
188
+ rel_tuple = (rel.target_key, rel.relationship_type, rel.direction)
189
+ existing_rels = {(r.target_key, r.relationship_type, r.direction) for r in source_obs.relationships}
190
+ if rel_tuple not in existing_rels:
191
+ source_obs.relationships.append(rel)
192
+
193
+ def _add_check_observable_link(self, check: Check, link: ObservableLink) -> bool:
194
+ existing: dict[tuple[str, PropagationMode], int] = {}
195
+ for idx, existing_link in enumerate(check.observable_links):
196
+ existing[(existing_link.observable_key, existing_link.propagation_mode)] = idx
197
+ link_tuple = (link.observable_key, link.propagation_mode)
198
+ if link_tuple in existing:
199
+ return False
200
+
201
+ check.observable_links.append(link)
202
+ return True
203
+
204
+ def _container_add_check(self, container: Container, check: Check) -> None:
205
+ if any(existing.key == check.key for existing in container.checks):
206
+ return
207
+ container.checks.append(check)
208
+
209
    def _container_add_sub_container(self, parent: Container, child: Container) -> None:
        # Insert or overwrite: an existing sub-container with the same key is replaced.
        parent.sub_containers[child.key] = child
211
+
212
+ def _infer_object_type(self, obj: Any) -> str | None:
213
+ if isinstance(obj, Observable):
214
+ return "observable"
215
+ if isinstance(obj, Check):
216
+ return "check"
217
+ if isinstance(obj, ThreatIntel):
218
+ return "threat_intel"
219
+ if isinstance(obj, Enrichment):
220
+ return "enrichment"
221
+ if isinstance(obj, Container):
222
+ return "container"
223
+ return None
224
+
225
+ def apply_score_change(
226
+ self,
227
+ obj: Any,
228
+ new_score: Decimal,
229
+ *,
230
+ reason: str = "",
231
+ event_type: str = "SCORE_CHANGED",
232
+ contributing_investigation_ids: set[str] | None = None,
233
+ ) -> bool:
234
+ """Apply a score change and emit an audit event."""
235
+ if not isinstance(new_score, Decimal):
236
+ new_score = Decimal(str(new_score))
237
+
238
+ old_score = obj.score
239
+ old_level = obj.level
240
+ new_level = recalculate_level_for_score(old_level, new_score)
241
+
242
+ if new_score == old_score and new_level == old_level:
243
+ return False
244
+
245
+ obj.score = new_score
246
+ obj.level = new_level
247
+
248
+ if event_type == "SCORE_RECALCULATED":
249
+ # Skip audit log entry for recalculated scores.
250
+ return True
251
+
252
+ details = {
253
+ "old_score": float(old_score),
254
+ "new_score": float(new_score),
255
+ "old_level": old_level.value,
256
+ "new_level": new_level.value,
257
+ }
258
+ if contributing_investigation_ids:
259
+ details["contributing_investigation_ids"] = sorted(contributing_investigation_ids)
260
+
261
+ self._record_event(
262
+ event_type=event_type,
263
+ object_type=self._infer_object_type(obj),
264
+ object_key=getattr(obj, "key", None),
265
+ reason=reason,
266
+ details=details,
267
+ )
268
+ return True
269
+
270
+ def apply_level_change(
271
+ self,
272
+ obj: Any,
273
+ level: Level | str,
274
+ *,
275
+ reason: str = "",
276
+ event_type: str = "LEVEL_UPDATED",
277
+ ) -> bool:
278
+ """Apply a level change and emit an audit event."""
279
+ new_level = normalize_level(level)
280
+ old_level = obj.level
281
+ if new_level == old_level:
282
+ return False
283
+
284
+ obj.level = new_level
285
+ self._record_event(
286
+ event_type=event_type,
287
+ object_type=self._infer_object_type(obj),
288
+ object_key=getattr(obj, "key", None),
289
+ reason=reason,
290
+ details={
291
+ "old_level": old_level.value,
292
+ "new_level": new_level.value,
293
+ "score": float(obj.score),
294
+ },
295
+ )
296
+ return True
297
+
298
    def _sync_check_links_for_observable(self, observable_key: str) -> None:
        # Refresh the observable's cached check-link index from the score
        # engine (the authoritative source).  Unknown keys are ignored.
        obs = self._observables.get(observable_key)
        if not obs:
            return
        check_keys = self._score_engine.get_check_links_for_observable(observable_key)
        # NOTE(review): writes a private attribute of Observable — assumed to
        # be the cache read by Observable's public API; confirm in model.py.
        obs._check_links = check_keys
304
+
305
+ def _refresh_check_links(self) -> None:
306
+ for observable_key in self._observables:
307
+ self._sync_check_links_for_observable(observable_key)
308
+
309
+ def get_audit_log(self) -> list[AuditEvent]:
310
+ """Return a deep copy of the audit log."""
311
+ return [event.model_copy(deep=True) for event in self._audit_log]
312
+
313
+ def get_audit_events(
314
+ self,
315
+ *,
316
+ object_type: str | None = None,
317
+ object_key: str | None = None,
318
+ event_type: str | None = None,
319
+ ) -> list[AuditEvent]:
320
+ """Filter audit events by optional object type/key and event type."""
321
+ events = self._audit_log
322
+ if object_type is not None:
323
+ events = [event for event in events if event.object_type == object_type]
324
+ if object_key is not None:
325
+ events = [event for event in events if event.object_key == object_key]
326
+ if event_type is not None:
327
+ events = [event for event in events if event.event_type == event_type]
328
+ return [event.model_copy(deep=True) for event in events]
329
+
330
+ def set_investigation_name(self, name: str | None, *, reason: str | None = None) -> None:
331
+ """Set or clear the human-readable investigation name."""
332
+ name = str(name).strip() if name is not None else None
333
+ if name == self.investigation_name:
334
+ return
335
+ old_name = self.investigation_name
336
+ self.investigation_name = name
337
+ self._record_event(
338
+ event_type="INVESTIGATION_NAME_UPDATED",
339
+ object_type="investigation",
340
+ object_key=self.investigation_id,
341
+ reason=reason,
342
+ details={"old_name": old_name, "new_name": name},
343
+ )
344
+
345
+ # Private merge methods
346
+
347
+ def _merge_observable(self, existing: Observable, incoming: Observable) -> tuple[Observable, list]:
348
+ """
349
+ Merge an incoming observable into an existing observable.
350
+
351
+ Strategy:
352
+ - Update score (take maximum)
353
+ - Update level (take maximum)
354
+ - Update extra (merge dicts)
355
+ - Concatenate comments
356
+ - Merge threat intels
357
+ - Merge relationships (defer if target missing)
358
+ - Preserve provenance metadata
359
+
360
+ Args:
361
+ existing: The existing observable
362
+ incoming: The incoming observable to merge
363
+
364
+ Returns:
365
+ Tuple of (merged observable, deferred relationships)
366
+ """
367
+ # Normal merge logic for scores and levels (SAFE level protection in Observable.update_score)
368
+ # Take the higher score
369
+ if incoming.score > existing.score:
370
+ self.apply_score_change(
371
+ existing,
372
+ incoming.score,
373
+ reason=f"Merged from {incoming.key}",
374
+ )
375
+
376
+ # Take the higher level
377
+ if incoming.level > existing.level:
378
+ self.apply_level_change(existing, incoming.level, reason=f"Merged from {incoming.key}")
379
+
380
+ # Update extra (merge dictionaries)
381
+ if existing.extra:
382
+ existing.extra.update(incoming.extra)
383
+ elif incoming.extra:
384
+ existing.extra = dict(incoming.extra)
385
+
386
+ # Concatenate comments
387
+ if incoming.comment:
388
+ if existing.comment:
389
+ existing.comment += "\n\n" + incoming.comment
390
+ else:
391
+ existing.comment = incoming.comment
392
+
393
+ # Merge whitelisted status (if either is whitelisted, result is whitelisted)
394
+ existing.whitelisted = existing.whitelisted or incoming.whitelisted
395
+
396
+ # Merge internal status (if either is external, result is external)
397
+ existing.internal = existing.internal and incoming.internal
398
+
399
+ # Merge threat intels (avoid duplicates by key)
400
+ existing_ti_keys = {ti.key for ti in existing.threat_intels}
401
+ for ti in incoming.threat_intels:
402
+ if ti.key not in existing_ti_keys:
403
+ self._add_threat_intel_to_observable(existing, ti)
404
+ existing_ti_keys.add(ti.key)
405
+
406
+ # Merge relationships (defer if target not yet available)
407
+ deferred_relationships = []
408
+ for rel in incoming.relationships:
409
+ if rel.target_key in self._observables:
410
+ # Target exists - add relationship immediately
411
+ self._add_relationship_internal(existing, rel.target_key, rel.relationship_type, rel.direction)
412
+ else:
413
+ # Target doesn't exist yet - defer for Pass 2 of merge_investigation()
414
+ deferred_relationships.append((existing.key, rel))
415
+
416
+ return existing, deferred_relationships
417
+
418
+ def _merge_check(self, existing: Check, incoming: Check) -> Check:
419
+ """
420
+ Merge an incoming check into an existing check.
421
+
422
+ Strategy:
423
+ - Update score (take maximum)
424
+ - Update level (take maximum)
425
+ - Update extra (merge dicts)
426
+ - Concatenate comments
427
+ - Merge observable links (tuple-based deduplication, provenance-preserving)
428
+
429
+ Args:
430
+ existing: The existing check
431
+ incoming: The incoming check to merge
432
+
433
+ Returns:
434
+ The merged check (existing is modified in place)
435
+ """
436
+ if not incoming.origin_investigation_id:
437
+ incoming.origin_investigation_id = self.investigation_id
438
+ if not existing.origin_investigation_id:
439
+ existing.origin_investigation_id = incoming.origin_investigation_id
440
+
441
+ # Take the higher score
442
+ if incoming.score > existing.score:
443
+ self.apply_score_change(
444
+ existing,
445
+ incoming.score,
446
+ reason=f"Merged from {incoming.key}",
447
+ )
448
+
449
+ # Take the higher level
450
+ if incoming.level > existing.level:
451
+ self.apply_level_change(existing, incoming.level, reason=f"Merged from {incoming.key}")
452
+
453
+ # Update extra (merge dictionaries)
454
+ existing.extra.update(incoming.extra)
455
+
456
+ # Concatenate comments
457
+ if incoming.comment:
458
+ if existing.comment:
459
+ existing.comment += "\n\n" + incoming.comment
460
+ else:
461
+ existing.comment = incoming.comment
462
+
463
+ existing_by_tuple: dict[tuple[str, PropagationMode], int] = {}
464
+ for idx, existing_link in enumerate(existing.observable_links):
465
+ existing_by_tuple[(existing_link.observable_key, existing_link.propagation_mode)] = idx
466
+ for incoming_link in incoming.observable_links:
467
+ link_tuple = (incoming_link.observable_key, incoming_link.propagation_mode)
468
+ existing_idx = existing_by_tuple.get(link_tuple)
469
+ if existing_idx is None:
470
+ existing.observable_links.append(incoming_link)
471
+ existing_by_tuple[link_tuple] = len(existing.observable_links) - 1
472
+ continue
473
+
474
+ return existing
475
+
476
+ def _merge_threat_intel(self, existing: ThreatIntel, incoming: ThreatIntel) -> ThreatIntel:
477
+ """
478
+ Merge an incoming threat intel into an existing threat intel.
479
+
480
+ Strategy:
481
+ - Update score (take maximum)
482
+ - Update level (take maximum)
483
+ - Update extra (merge dicts)
484
+ - Concatenate comments
485
+ - Merge taxonomies
486
+
487
+ Args:
488
+ existing: The existing threat intel
489
+ incoming: The incoming threat intel to merge
490
+
491
+ Returns:
492
+ The merged threat intel (existing is modified in place)
493
+ """
494
+ # Take the higher score
495
+ if incoming.score > existing.score:
496
+ self.apply_score_change(
497
+ existing,
498
+ incoming.score,
499
+ reason=f"Merged from {incoming.key}",
500
+ )
501
+
502
+ # Take the higher level
503
+ if incoming.level > existing.level:
504
+ self.apply_level_change(existing, incoming.level, reason=f"Merged from {incoming.key}")
505
+
506
+ # Update extra (merge dictionaries)
507
+ existing.extra.update(incoming.extra)
508
+
509
+ # Concatenate comments
510
+ if incoming.comment:
511
+ if existing.comment:
512
+ existing.comment += "\n\n" + incoming.comment
513
+ else:
514
+ existing.comment = incoming.comment
515
+
516
+ # Merge taxonomies (ensure unique names)
517
+ existing_by_name: dict[str, int] = {taxonomy.name: idx for idx, taxonomy in enumerate(existing.taxonomies)}
518
+ for taxonomy in incoming.taxonomies:
519
+ existing_idx = existing_by_name.get(taxonomy.name)
520
+ if existing_idx is None:
521
+ existing.taxonomies.append(taxonomy)
522
+ existing_by_name[taxonomy.name] = len(existing.taxonomies) - 1
523
+ else:
524
+ existing.taxonomies[existing_idx] = taxonomy
525
+
526
+ return existing
527
+
528
+ def _merge_enrichment(self, existing: Enrichment, incoming: Enrichment) -> Enrichment:
529
+ """
530
+ Merge an incoming enrichment into an existing enrichment.
531
+
532
+ Strategy:
533
+ - Deep merge data structure (merge dictionaries recursively)
534
+
535
+ Args:
536
+ existing: The existing enrichment
537
+ incoming: The incoming enrichment to merge
538
+
539
+ Returns:
540
+ The merged enrichment (existing is modified in place)
541
+ """
542
+
543
+ def deep_merge(base: dict, update: dict) -> dict:
544
+ """Recursively merge dictionaries."""
545
+ for key, value in update.items():
546
+ if key in base and isinstance(base[key], dict) and isinstance(value, dict):
547
+ deep_merge(base[key], value)
548
+ else:
549
+ base[key] = value
550
+ return base
551
+
552
+ # Deep merge data structures
553
+ if isinstance(existing.data, dict) and isinstance(incoming.data, dict):
554
+ deep_merge(existing.data, incoming.data)
555
+ else:
556
+ existing.data = deepcopy(incoming.data)
557
+
558
+ # Update context if incoming has one
559
+ if incoming.context:
560
+ existing.context = incoming.context
561
+
562
+ return existing
563
+
564
+ def _merge_container(self, existing: Container, incoming: Container) -> Container:
565
+ """
566
+ Merge an incoming container into an existing container.
567
+
568
+ Strategy:
569
+ - Merge checks (dict-based lookup for efficiency)
570
+ - Merge sub-containers recursively
571
+
572
+ Args:
573
+ existing: The existing container
574
+ incoming: The incoming container to merge
575
+
576
+ Returns:
577
+ The merged container (existing is modified in place)
578
+ """
579
+ # Update description if incoming has one
580
+ if incoming.description:
581
+ existing.description = incoming.description
582
+
583
+ # Merge checks using dict-based lookup (more efficient)
584
+ existing_checks_dict = {check.key: check for check in existing.checks}
585
+
586
+ for incoming_check in incoming.checks:
587
+ if incoming_check.key in existing_checks_dict:
588
+ # Merge existing check
589
+ self._merge_check(existing_checks_dict[incoming_check.key], incoming_check)
590
+ else:
591
+ # Add new check
592
+ self._container_add_check(existing, incoming_check)
593
+
594
+ # Merge sub-containers recursively
595
+ for sub_key, incoming_sub in incoming.sub_containers.items():
596
+ if sub_key in existing.sub_containers:
597
+ # Merge existing sub-container
598
+ self._merge_container(existing.sub_containers[sub_key], incoming_sub)
599
+ else:
600
+ # Add new sub-container
601
+ self._container_add_sub_container(existing, incoming_sub)
602
+
603
+ return existing
604
+
605
+ # Public add methods with merge-on-create
606
+
607
+ def add_observable(self, obs: Observable) -> tuple[Observable, list]:
608
+ """
609
+ Add or merge an observable.
610
+
611
+ Args:
612
+ obs: Observable to add or merge
613
+
614
+ Returns:
615
+ Tuple of (resulting observable, deferred relationships)
616
+ """
617
+ if obs.key in self._observables:
618
+ r = self._merge_observable(self._observables[obs.key], obs)
619
+ self._score_engine.recalculate_all()
620
+ return r
621
+
622
+ # Register new observable
623
+ self._observables[obs.key] = obs
624
+ self._score_engine.register_observable(obs)
625
+ self._stats.register_observable(obs)
626
+ self._sync_check_links_for_observable(obs.key)
627
+ self._record_event(
628
+ event_type="OBSERVABLE_CREATED",
629
+ object_type="observable",
630
+ object_key=obs.key,
631
+ )
632
+ return obs, []
633
+
634
+ def add_check(self, check: Check) -> Check:
635
+ """
636
+ Add or merge a check.
637
+
638
+ Args:
639
+ check: Check to add or merge
640
+
641
+ Returns:
642
+ The resulting check (either new or merged)
643
+ """
644
+ if check.key in self._checks:
645
+ r = self._merge_check(self._checks[check.key], check)
646
+ self._score_engine.rebuild_link_index()
647
+ self._score_engine.recalculate_all()
648
+ for link in r.observable_links:
649
+ self._sync_check_links_for_observable(link.observable_key)
650
+ return r
651
+
652
+ if not getattr(check, "origin_investigation_id", None):
653
+ check.origin_investigation_id = self.investigation_id
654
+
655
+ # Register new check
656
+ self._checks[check.key] = check
657
+ self._score_engine.register_check(check)
658
+ self._stats.register_check(check)
659
+ for link in check.observable_links:
660
+ self._sync_check_links_for_observable(link.observable_key)
661
+ self._record_event(
662
+ event_type="CHECK_CREATED",
663
+ object_type="check",
664
+ object_key=check.key,
665
+ )
666
+ return check
667
+
668
+ def add_threat_intel(self, ti: ThreatIntel, observable: Observable) -> ThreatIntel:
669
+ """
670
+ Add or merge threat intel and link to observable.
671
+
672
+ Args:
673
+ ti: Threat intel to add or merge
674
+ observable: Observable to link to
675
+
676
+ Returns:
677
+ The resulting threat intel (either new or merged)
678
+ """
679
+ if ti.key in self._threat_intels:
680
+ merged_ti = self._merge_threat_intel(self._threat_intels[ti.key], ti)
681
+ # Propagate score to observable
682
+ self._score_engine.propagate_threat_intel_to_observable(merged_ti, observable)
683
+ self._record_event(
684
+ event_type="THREAT_INTEL_ATTACHED",
685
+ object_type="observable",
686
+ object_key=observable.key,
687
+ details={
688
+ "threat_intel_key": merged_ti.key,
689
+ "source": merged_ti.source,
690
+ "score": merged_ti.score,
691
+ "level": merged_ti.level,
692
+ },
693
+ )
694
+ return merged_ti
695
+
696
+ # Register new threat intel
697
+ self._threat_intels[ti.key] = ti
698
+ self._stats.register_threat_intel(ti)
699
+
700
+ # Add to observable
701
+ self._add_threat_intel_to_observable(observable, ti)
702
+
703
+ # Propagate score
704
+ self._score_engine.propagate_threat_intel_to_observable(ti, observable)
705
+
706
+ self._record_event(
707
+ event_type="THREAT_INTEL_ATTACHED",
708
+ object_type="observable",
709
+ object_key=observable.key,
710
+ details={
711
+ "threat_intel_key": ti.key,
712
+ "source": ti.source,
713
+ "score": ti.score,
714
+ "level": ti.level,
715
+ },
716
+ )
717
+ return ti
718
+
719
+ def add_threat_intel_taxonomy(self, threat_intel_key: str, taxonomy: Taxonomy) -> ThreatIntel:
720
+ """
721
+ Add or replace a taxonomy entry on a threat intel by name.
722
+
723
+ Args:
724
+ threat_intel_key: Threat intel key
725
+ taxonomy: Taxonomy entry to add or replace
726
+
727
+ Returns:
728
+ The updated threat intel
729
+ """
730
+ ti = self._threat_intels.get(threat_intel_key)
731
+ if ti is None:
732
+ raise KeyError(f"threat_intel '{threat_intel_key}' not found in investigation.")
733
+
734
+ updated_taxonomies = list(ti.taxonomies)
735
+ replaced = False
736
+ for idx, existing in enumerate(updated_taxonomies):
737
+ if existing.name == taxonomy.name:
738
+ updated_taxonomies[idx] = taxonomy
739
+ replaced = True
740
+ break
741
+
742
+ if not replaced:
743
+ updated_taxonomies.append(taxonomy)
744
+
745
+ return self.update_model_metadata("threat_intel", threat_intel_key, {"taxonomies": updated_taxonomies})
746
+
747
+ def remove_threat_intel_taxonomy(self, threat_intel_key: str, name: str) -> ThreatIntel:
748
+ """
749
+ Remove a taxonomy entry from a threat intel by name.
750
+
751
+ Args:
752
+ threat_intel_key: Threat intel key
753
+ name: Taxonomy name to remove
754
+
755
+ Returns:
756
+ The updated threat intel
757
+ """
758
+ ti = self._threat_intels.get(threat_intel_key)
759
+ if ti is None:
760
+ raise KeyError(f"threat_intel '{threat_intel_key}' not found in investigation.")
761
+
762
+ updated_taxonomies = [taxonomy for taxonomy in ti.taxonomies if taxonomy.name != name]
763
+ if len(updated_taxonomies) == len(ti.taxonomies):
764
+ return ti
765
+
766
+ return self.update_model_metadata("threat_intel", threat_intel_key, {"taxonomies": updated_taxonomies})
767
+
768
+ def add_enrichment(self, enrichment: Enrichment) -> Enrichment:
769
+ """
770
+ Add or merge enrichment.
771
+
772
+ Args:
773
+ enrichment: Enrichment to add or merge
774
+
775
+ Returns:
776
+ The resulting enrichment (either new or merged)
777
+ """
778
+ if enrichment.key in self._enrichments:
779
+ return self._merge_enrichment(self._enrichments[enrichment.key], enrichment)
780
+
781
+ # Register new enrichment
782
+ self._enrichments[enrichment.key] = enrichment
783
+ self._record_event(
784
+ event_type="ENRICHMENT_CREATED",
785
+ object_type="enrichment",
786
+ object_key=enrichment.key,
787
+ )
788
+ return enrichment
789
+
790
+ def add_container(self, container: Container) -> Container:
791
+ """
792
+ Add or merge container.
793
+
794
+ Args:
795
+ container: Container to add or merge
796
+
797
+ Returns:
798
+ The resulting container (either new or merged)
799
+ """
800
+ if container.key in self._containers:
801
+ r = self._merge_container(self._containers[container.key], container)
802
+ self._score_engine.recalculate_all()
803
+ return r
804
+
805
+ # Register new container
806
+ self._containers[container.key] = container
807
+ self._stats.register_container(container)
808
+ self._record_event(
809
+ event_type="CONTAINER_CREATED",
810
+ object_type="container",
811
+ object_key=container.key,
812
+ )
813
+ return container
814
+
815
+ # Relationship and linking methods
816
+
817
    def add_relationship(
        self,
        source: Observable | str,
        target: Observable | str,
        relationship_type: RelationshipType | str,
        direction: RelationshipDirection | str | None = None,
    ) -> Observable:
        """
        Add a relationship between observables.

        Args:
            source: Source observable or its key
            target: Target observable or its key
            relationship_type: Type of relationship
            direction: Direction of the relationship (None = use semantic default)

        Returns:
            The source observable

        Raises:
            KeyError: If the source or target observable does not exist
            ValueError: If target is a read-only copy from a shared context
        """

        # Extract keys from Observable objects if needed
        source_key = source.key if isinstance(source, Observable) else source
        target_key = target.key if isinstance(target, Observable) else target

        # Reject read-only copies from shared context (anti-pattern): they are
        # not registered in this investigation, so linking to them would create
        # a dangling edge.  The error text shows the corrected call pattern.
        if isinstance(target, Observable) and getattr(target, "_from_shared_context", False):
            obs_type_name = target.obs_type.name
            raise ValueError(
                f"Cannot use observable from shared_context.observable_get() directly in relationships.\n"
                f"Observable '{target_key}' is a read-only copy not registered in this investigation.\n\n"
                f"Incorrect pattern:\n"
                f"  source.relate_to(shared_context.observable_get(...), RelationshipType.{relationship_type})\n\n"
                f"Correct pattern (and use reconcile or merge):\n"
                f"  # Use cy.observable() to create/get observable in local investigation\n"
                f"  source.relate_to(\n"
                f"    cy.observable(ObservableType.{obs_type_name}, '{target.value}'),\n"
                f"    RelationshipType.{relationship_type}\n"
                f"  )"
            )

        # Validate both source and target exist before mutating anything
        source_obs = self._observables.get(source_key)
        target_obs = self._observables.get(target_key)

        if not source_obs:
            raise KeyError(f"observable '{source_key}' not found in investigation.")

        if not target_obs:
            raise KeyError(f"observable '{target_key}' not found in investigation.")

        # Add relationship using internal method (dedupes identical edges)
        self._add_relationship_internal(source_obs, target_key, relationship_type, direction)

        # NOTE(review): relationship_type/direction are stored raw here while
        # other events store primitive values — confirm downstream
        # serialization of the audit log handles enum instances.
        self._record_event(
            event_type="RELATIONSHIP_CREATED",
            object_type="observable",
            object_key=source_obs.key,
            details={
                "target_key": target_key,
                "relationship_type": relationship_type,
                "direction": direction,
            },
        )

        # Recalculate scores after adding relationship
        self._score_engine.recalculate_all()

        return source_obs
888
+
889
    def link_check_observable(
        self,
        check_key: str,
        observable_key: str,
        propagation_mode: PropagationMode | str = PropagationMode.LOCAL_ONLY,
    ) -> Check:
        """
        Link an observable to a check.

        Args:
            check_key: Key of the check
            observable_key: Key of the observable
            propagation_mode: Propagation behavior for this link

        Returns:
            The check

        Raises:
            KeyError: If the check or observable does not exist
        """
        check = self._checks.get(check_key)
        observable = self._observables.get(observable_key)

        if check is None:
            raise KeyError(f"check '{check_key}' not found in investigation.")
        if observable is None:
            raise KeyError(f"observable '{observable_key}' not found in investigation.")

        # NOTE(review): both names are guaranteed non-None here (the checks
        # above raise), so this guard is always true and could be removed.
        if check and observable:
            # Accept both enum and raw string propagation modes.
            propagation_mode = PropagationMode(propagation_mode)
            link = ObservableLink(
                observable_key=observable_key,
                propagation_mode=propagation_mode,
            )
            created = self._add_check_observable_link(check, link)
            if created:
                # New link: index it in the engine, refresh the observable's
                # cached link list, and audit the creation.
                self._score_engine.register_check_observable_link(check_key=check.key, observable_key=observable_key)
                self._sync_check_links_for_observable(observable_key)
                self._record_event(
                    event_type="CHECK_LINKED_TO_OBSERVABLE",
                    object_type="check",
                    object_key=check.key,
                    details={
                        "observable_key": observable_key,
                        "propagation_mode": propagation_mode.value,
                    },
                )
            # A link is "effective" when it propagates globally or the check
            # originated in this investigation.
            is_effective = (
                propagation_mode == PropagationMode.GLOBAL or self.investigation_id == check.origin_investigation_id
            )
            # Promote a never-evaluated check to INFO once it has an effective link.
            if is_effective and check.level == Level.NONE:
                self.apply_level_change(check, Level.INFO, reason="Effective link added")

            # Re-propagate the observable's current score into linked checks.
            # NOTE(review): calls a private ScoreEngine method — consider
            # exposing a public wrapper on ScoreEngine.
            self._score_engine._propagate_observable_to_checks(observable_key)

        return check
945
+
946
+ def add_check_to_container(self, container_key: str, check_key: str) -> Container:
947
+ """
948
+ Add a check to a container.
949
+
950
+ Args:
951
+ container_key: Key of the container
952
+ check_key: Key of the check
953
+
954
+ Returns:
955
+ The container
956
+
957
+ Raises:
958
+ KeyError: If the container or check does not exist
959
+ """
960
+ container = self._containers.get(container_key)
961
+ check = self._checks.get(check_key)
962
+
963
+ if container is None:
964
+ raise KeyError(f"container '{container_key}' not found in investigation.")
965
+ if check is None:
966
+ raise KeyError(f"check '{check_key}' not found in investigation.")
967
+
968
+ if container and check:
969
+ self._container_add_check(container, check)
970
+ self._record_event(
971
+ event_type="CONTAINER_CHECK_ADDED",
972
+ object_type="container",
973
+ object_key=container.key,
974
+ details={"check_key": check.key},
975
+ )
976
+
977
+ return container
978
+
979
+ def add_sub_container(self, parent_key: str, child_key: str) -> Container:
980
+ """
981
+ Add a sub-container to a container.
982
+
983
+ Args:
984
+ parent_key: Key of the parent container
985
+ child_key: Key of the child container
986
+
987
+ Returns:
988
+ The parent container
989
+
990
+ Raises:
991
+ KeyError: If the parent or child container does not exist
992
+ """
993
+ parent = self._containers.get(parent_key)
994
+ child = self._containers.get(child_key)
995
+
996
+ if parent is None:
997
+ raise KeyError(f"container '{parent_key}' not found in investigation.")
998
+ if child is None:
999
+ raise KeyError(f"container '{child_key}' not found in investigation.")
1000
+
1001
+ if parent and child:
1002
+ self._container_add_sub_container(parent, child)
1003
+ self._record_event(
1004
+ event_type="CONTAINER_SUBCONTAINER_ADDED",
1005
+ object_type="container",
1006
+ object_key=parent.key,
1007
+ details={"child_container_key": child.key},
1008
+ )
1009
+
1010
+ return parent
1011
+
1012
+ # Query methods
1013
+
1014
+ def get_observable(self, key: str) -> Observable | None:
1015
+ """Get observable by full key string."""
1016
+ return self._observables.get(key)
1017
+
1018
+ def get_check(self, key: str) -> Check | None:
1019
+ """Get check by full key string."""
1020
+ return self._checks.get(key)
1021
+
1022
+ def get_container(self, key: str) -> Container | None:
1023
+ """Get a container by key."""
1024
+ return self._containers.get(key)
1025
+
1026
+ def get_enrichment(self, key: str) -> Enrichment | None:
1027
+ """Get an enrichment by key."""
1028
+ return self._enrichments.get(key)
1029
+
1030
+ def get_threat_intel(self, key: str) -> ThreatIntel | None:
1031
+ """Get a threat intel by key."""
1032
+ return self._threat_intels.get(key)
1033
+
1034
+ @staticmethod
1035
+ def _normalize_taxonomies(value: Any) -> list[Taxonomy]:
1036
+ if value is None:
1037
+ return []
1038
+ if not isinstance(value, list):
1039
+ raise TypeError("taxonomies must be a list of taxonomy objects.")
1040
+ taxonomies = [Taxonomy.model_validate(item) for item in value]
1041
+ seen: set[str] = set()
1042
+ duplicates: set[str] = set()
1043
+ for taxonomy in taxonomies:
1044
+ if taxonomy.name in seen:
1045
+ duplicates.add(taxonomy.name)
1046
+ seen.add(taxonomy.name)
1047
+ if duplicates:
1048
+ dupes = ", ".join(sorted(duplicates))
1049
+ raise ValueError(f"Duplicate taxonomy name(s): {dupes}")
1050
+ return taxonomies
1051
+
1052
+ def get_root(self) -> Observable:
1053
+ """Get the root observable."""
1054
+ return self._root_observable
1055
+
1056
    def update_model_metadata(
        self,
        model_type: Literal["observable", "check", "threat_intel", "enrichment", "container"],
        key: str,
        updates: dict[str, Any],
        *,
        dict_merge: dict[str, bool] | None = None,
    ):
        """
        Update mutable metadata fields for a stored model instance.

        Args:
            model_type: Model family to update.
            key: Key of the target object.
            updates: Mapping of field names to new values. ``None`` values are ignored.
            dict_merge: Optional overrides for dict fields (True=merge, False=replace).

        Returns:
            The updated model instance.

        Raises:
            KeyError: If the key cannot be found.
            ValueError: If an unsupported field is requested.
            TypeError: If a dict field receives a non-dict value.
        """
        # Route the lookup to the store matching the requested model family.
        store_lookup: dict[str, dict[str, Any]] = {
            "observable": self._observables,
            "check": self._checks,
            "threat_intel": self._threat_intels,
            "enrichment": self._enrichments,
            "container": self._containers,
        }
        store = store_lookup[model_type]
        target = store.get(key)
        if target is None:
            raise KeyError(f"{model_type} '{key}' not found in investigation.")

        if not updates:
            # Nothing to apply; return the stored instance untouched.
            return target

        # Per-family whitelist of mutable fields and which of them are dicts.
        rules = self._MODEL_METADATA_RULES[model_type]
        allowed_fields = rules["fields"]
        dict_fields = rules["dict_fields"]

        # field -> {"old": ..., "new": ...}, fed into the audit event below.
        changes: dict[str, dict[str, Any]] = {}

        for field, value in updates.items():
            if field not in allowed_fields:
                raise ValueError(f"Field '{field}' is not mutable on {model_type}.")
            if value is None:
                # None means "leave unchanged" — a field cannot be reset to None here.
                continue
            # Deep-copied snapshot so later in-place mutation of the target
            # cannot alias the "old" value recorded in the audit trail.
            old_value = deepcopy(getattr(target, field, None))
            if field == "level":
                value = normalize_level(value)
            if model_type == "threat_intel" and field == "taxonomies":
                value = self._normalize_taxonomies(value)
            if field in dict_fields:
                if not isinstance(value, dict):
                    raise TypeError(f"Field '{field}' on {model_type} expects a dict value.")
                # Dict fields merge by default; dict_merge can force replacement.
                merge = dict_merge.get(field, True) if dict_merge else True
                if merge:
                    current_value = getattr(target, field, None)
                    if current_value is None:
                        setattr(target, field, deepcopy(value))
                    else:
                        # In-place update keeps existing keys not present in `value`.
                        current_value.update(value)
                else:
                    setattr(target, field, deepcopy(value))
            else:
                setattr(target, field, value)
            new_value = deepcopy(getattr(target, field, None))
            if old_value != new_value:
                changes[field] = {"old": old_value, "new": new_value}

        if changes:
            # Only emit an audit event when something actually changed.
            self._record_event(
                event_type="METADATA_UPDATED",
                object_type=model_type,
                object_key=key,
                details={"changes": changes},
            )
        return target
1138
+
1139
+ def get_all_observables(self) -> dict[str, Observable]:
1140
+ """Get all observables."""
1141
+ return self._observables.copy()
1142
+
1143
+ def get_all_checks(self) -> dict[str, Check]:
1144
+ """Get all checks."""
1145
+ return self._checks.copy()
1146
+
1147
+ def get_all_threat_intels(self) -> dict[str, ThreatIntel]:
1148
+ """Get all threat intels."""
1149
+ return self._threat_intels.copy()
1150
+
1151
+ def get_all_enrichments(self) -> dict[str, Enrichment]:
1152
+ """Get all enrichments."""
1153
+ return self._enrichments.copy()
1154
+
1155
+ def get_all_containers(self) -> dict[str, Container]:
1156
+ """Get all containers."""
1157
+ return self._containers.copy()
1158
+
1159
+ # Scoring and statistics
1160
+
1161
+ def get_global_score(self) -> Decimal:
1162
+ """Get the global investigation score."""
1163
+ return self._score_engine.get_global_score()
1164
+
1165
+ def get_global_level(self) -> Level:
1166
+ """Get the global investigation level."""
1167
+ return self._score_engine.get_global_level()
1168
+
1169
+ def is_whitelisted(self) -> bool:
1170
+ """Return whether the investigation has any whitelist entries."""
1171
+ return bool(self._whitelists)
1172
+
1173
+ def add_whitelist(self, identifier: str, name: str, justification: str | None = None) -> InvestigationWhitelist:
1174
+ """
1175
+ Add or update a whitelist entry.
1176
+
1177
+ Args:
1178
+ identifier: Unique identifier for this whitelist entry.
1179
+ name: Human-readable name for the whitelist entry.
1180
+ justification: Optional markdown justification.
1181
+
1182
+ Returns:
1183
+ The stored whitelist entry.
1184
+ """
1185
+ identifier = str(identifier).strip()
1186
+ name = str(name).strip()
1187
+ if not identifier:
1188
+ raise ValueError("Whitelist identifier must be provided.")
1189
+ if not name:
1190
+ raise ValueError("Whitelist name must be provided.")
1191
+ if justification is not None:
1192
+ justification = str(justification)
1193
+
1194
+ entry = InvestigationWhitelist(identifier=identifier, name=name, justification=justification)
1195
+ self._whitelists[identifier] = entry
1196
+ self._record_event(
1197
+ event_type="WHITELIST_APPLIED",
1198
+ object_type="investigation",
1199
+ object_key=self.investigation_id,
1200
+ details={
1201
+ "identifier": identifier,
1202
+ "name": name,
1203
+ "justification": justification,
1204
+ },
1205
+ )
1206
+ return entry
1207
+
1208
+ def remove_whitelist(self, identifier: str) -> bool:
1209
+ """
1210
+ Remove a whitelist entry by identifier.
1211
+
1212
+ Returns:
1213
+ True if removed, False if it did not exist.
1214
+ """
1215
+ removed = self._whitelists.pop(identifier, None)
1216
+ if removed:
1217
+ self._record_event(
1218
+ event_type="WHITELIST_REMOVED",
1219
+ object_type="investigation",
1220
+ object_key=self.investigation_id,
1221
+ details={"identifier": identifier},
1222
+ )
1223
+ return removed is not None
1224
+
1225
+ def clear_whitelists(self) -> None:
1226
+ """Remove all whitelist entries."""
1227
+ if not self._whitelists:
1228
+ return
1229
+ removed = list(self._whitelists.keys())
1230
+ self._whitelists.clear()
1231
+ self._record_event(
1232
+ event_type="WHITELIST_CLEARED",
1233
+ object_type="investigation",
1234
+ object_key=self.investigation_id,
1235
+ details={"identifiers": removed},
1236
+ )
1237
+
1238
+ def get_whitelists(self) -> list[InvestigationWhitelist]:
1239
+ """Return a copy of all whitelist entries."""
1240
+ return [w.model_copy(deep=True) for w in self._whitelists.values()]
1241
+
1242
+ def get_statistics(self) -> StatisticsSchema:
1243
+ """Get comprehensive investigation statistics."""
1244
+ return self._stats.get_summary()
1245
+
1246
+ def finalize_relationships(self) -> None:
1247
+ """
1248
+ Finalize observable relationships by linking orphans to root.
1249
+
1250
+ Detects orphan sub-graphs (connected components not linked to root) and links
1251
+ the most appropriate starting node of each sub-graph to root.
1252
+ """
1253
+ root_key = self._root_observable.key
1254
+
1255
+ # Build adjacency lists for graph traversal
1256
+ graph = {key: set() for key in self._observables.keys()}
1257
+ incoming = {key: set() for key in self._observables.keys()}
1258
+
1259
+ for obs_key, obs in self._observables.items():
1260
+ for rel in obs.relationships:
1261
+ if rel.target_key in self._observables:
1262
+ graph[obs_key].add(rel.target_key)
1263
+ incoming[rel.target_key].add(obs_key)
1264
+
1265
+ # Find all connected components using BFS
1266
+ visited = set()
1267
+ components = []
1268
+
1269
+ def bfs(start_key: str) -> set[str]:
1270
+ """Breadth-first search to find connected component."""
1271
+ component = set()
1272
+ queue = [start_key]
1273
+ component.add(start_key)
1274
+
1275
+ while queue:
1276
+ current = queue.pop(0)
1277
+ # Check both outgoing and incoming edges for connectivity
1278
+ neighbors = graph[current] | incoming[current]
1279
+ for neighbor in neighbors:
1280
+ if neighbor not in component:
1281
+ component.add(neighbor)
1282
+ queue.append(neighbor)
1283
+
1284
+ return component
1285
+
1286
+ # Find all connected components
1287
+ for obs_key in self._observables.keys():
1288
+ if obs_key not in visited:
1289
+ component = bfs(obs_key)
1290
+ visited.update(component)
1291
+ components.append(component)
1292
+
1293
+ # Process each component that doesn't include root
1294
+ for component in components:
1295
+ if root_key in component:
1296
+ continue # This component is already connected to root
1297
+
1298
+ # Find the best starting node in this orphan sub-graph
1299
+ # Prioritize nodes with:
1300
+ # 1. No incoming edges (true source nodes)
1301
+ # 2. Most outgoing edges (central nodes)
1302
+ best_node = None
1303
+ best_score = (-1, -1) # (negative incoming count, outgoing count)
1304
+
1305
+ for node_key in component:
1306
+ incoming_count = len(incoming[node_key] & component)
1307
+ outgoing_count = len(graph[node_key] & component)
1308
+ score = (-incoming_count, outgoing_count)
1309
+
1310
+ if score > best_score:
1311
+ best_score = score
1312
+ best_node = node_key
1313
+
1314
+ # Link the best starting node to root
1315
+ if best_node:
1316
+ self._add_relationship_internal(self._root_observable, best_node, RelationshipType.RELATED_TO)
1317
+ self._record_event(
1318
+ event_type="RELATIONSHIP_CREATED",
1319
+ object_type="observable",
1320
+ object_key=self._root_observable.key,
1321
+ reason="Finalize relationships",
1322
+ details={
1323
+ "target_key": best_node,
1324
+ "relationship_type": RelationshipType.RELATED_TO.value,
1325
+ "direction": RelationshipType.RELATED_TO.get_default_direction().value,
1326
+ },
1327
+ )
1328
+ self._score_engine.recalculate_all()
1329
+
1330
+ # Investigation merging
1331
+
1332
    def merge_investigation(self, other: Investigation) -> None:
        """
        Merge another investigation into this one.

        Uses a two-pass approach to handle relationship dependencies:
        - Pass 1: Merge all observables, collecting deferred relationships
        - Pass 2: Add deferred relationships now that all observables exist

        Args:
            other: The investigation to merge
        """

        def _diff_fields(before: dict[str, Any], after: dict[str, Any]) -> list[str]:
            # Names of the fields whose value changed between two snapshots.
            return [field for field, value in before.items() if value != after.get(field)]

        def _snapshot_observable(obs: Observable) -> dict[str, Any]:
            # Comparable picture of an observable's mutable state; deep copies
            # and sorting keep it stable against later in-place mutation.
            relationships = [
                (
                    rel.target_key,
                    rel.relationship_type_name,
                    rel.direction.value,
                )
                for rel in obs.relationships
            ]
            return {
                "score": obs.score,
                "level": obs.level,
                "comment": obs.comment,
                "extra": deepcopy(obs.extra),
                "internal": obs.internal,
                "whitelisted": obs.whitelisted,
                "threat_intels": sorted(ti.key for ti in obs.threat_intels),
                "relationships": sorted(relationships),
            }

        def _snapshot_check(check: Check) -> dict[str, Any]:
            # Comparable picture of a check's mutable state.
            links = [
                (
                    link.observable_key,
                    link.propagation_mode.value,
                )
                for link in check.observable_links
            ]
            return {
                "score": check.score,
                "level": check.level,
                "comment": check.comment,
                "description": check.description,
                "extra": deepcopy(check.extra),
                "origin_investigation_id": check.origin_investigation_id,
                "observable_links": sorted(links),
            }

        def _snapshot_threat_intel(ti: ThreatIntel) -> dict[str, Any]:
            # Comparable picture of a threat intel's mutable state.
            return {
                "score": ti.score,
                "level": ti.level,
                "comment": ti.comment,
                "extra": deepcopy(ti.extra),
                "taxonomies": deepcopy(ti.taxonomies),
            }

        def _snapshot_enrichment(enrichment: Enrichment) -> dict[str, Any]:
            # Comparable picture of an enrichment's mutable state.
            return {
                "context": enrichment.context,
                "data": deepcopy(enrichment.data),
            }

        def _snapshot_container(container: Container) -> dict[str, Any]:
            # Comparable picture of a container's mutable state.
            return {
                "description": container.description,
                "checks": sorted(check.key for check in container.checks),
                "sub_containers": sorted(container.sub_containers.keys()),
            }

        # One entry per incoming object: created / merged / skipped, plus the
        # list of changed fields; recorded in the final audit event.
        merge_summary: list[dict[str, Any]] = []

        (
            incoming_observables,
            incoming_threat_intels,
            incoming_checks,
            incoming_enrichments,
            incoming_containers,
        ) = self._clone_for_merge(other)

        # PASS 1: Merge observables and collect deferred relationships
        all_deferred_relationships = []
        for obs in incoming_observables.values():
            existing = self._observables.get(obs.key)
            before = _snapshot_observable(existing) if existing else None
            _, deferred = self.add_observable(obs)
            all_deferred_relationships.extend(deferred)
            if existing:
                after = _snapshot_observable(existing)
                changed_fields = _diff_fields(before, after) if before else []
                action = "merged" if changed_fields else "skipped"
                merge_summary.append(
                    {
                        "object_type": "observable",
                        "object_key": obs.key,
                        "action": action,
                        "changed_fields": changed_fields,
                    }
                )
            else:
                merge_summary.append(
                    {
                        "object_type": "observable",
                        "object_key": obs.key,
                        "action": "created",
                        "changed_fields": [],
                    }
                )

        # PASS 2: Process deferred relationships now that all observables exist
        for source_key, rel in all_deferred_relationships:
            source_obs = self._observables.get(source_key)
            if source_obs and rel.target_key in self._observables:
                # Both source and target exist - add relationship
                self._add_relationship_internal(source_obs, rel.target_key, rel.relationship_type, rel.direction)
            else:
                # Genuine error - target still doesn't exist after Pass 2
                logger.critical(
                    "Relationship target '{}' not found after merge completion for observable '{}'. "
                    "This indicates corrupted data or a bug in the merge logic.",
                    rel.target_key,
                    source_key,
                )

        # Merge threat intels (need to link to observables)
        for ti in incoming_threat_intels.values():
            existing_ti = self._threat_intels.get(ti.key)
            before = _snapshot_threat_intel(existing_ti) if existing_ti else None
            # Find the observable this TI belongs to
            observable = self._observables.get(ti.observable_key)
            if observable:
                self.add_threat_intel(ti, observable)
                if existing_ti:
                    after = _snapshot_threat_intel(existing_ti)
                    changed_fields = _diff_fields(before, after) if before else []
                    action = "merged" if changed_fields else "skipped"
                else:
                    changed_fields = []
                    action = "created"
                merge_summary.append(
                    {
                        "object_type": "threat_intel",
                        "object_key": ti.key,
                        "action": action,
                        "changed_fields": changed_fields,
                    }
                )

        # Merge checks
        for check in incoming_checks.values():
            existing_check = self._checks.get(check.key)
            before = _snapshot_check(existing_check) if existing_check else None
            self.add_check(check)
            if existing_check:
                after = _snapshot_check(existing_check)
                changed_fields = _diff_fields(before, after) if before else []
                action = "merged" if changed_fields else "skipped"
            else:
                changed_fields = []
                action = "created"
            merge_summary.append(
                {
                    "object_type": "check",
                    "object_key": check.key,
                    "action": action,
                    "changed_fields": changed_fields,
                }
            )

        # Merge enrichments
        for enrichment in incoming_enrichments.values():
            existing_enrichment = self._enrichments.get(enrichment.key)
            before = _snapshot_enrichment(existing_enrichment) if existing_enrichment else None
            self.add_enrichment(enrichment)
            if existing_enrichment:
                after = _snapshot_enrichment(existing_enrichment)
                changed_fields = _diff_fields(before, after) if before else []
                action = "merged" if changed_fields else "skipped"
            else:
                changed_fields = []
                action = "created"
            merge_summary.append(
                {
                    "object_type": "enrichment",
                    "object_key": enrichment.key,
                    "action": action,
                    "changed_fields": changed_fields,
                }
            )

        # Merge containers
        for container in incoming_containers.values():
            existing_container = self._containers.get(container.key)
            before = _snapshot_container(existing_container) if existing_container else None
            self.add_container(container)
            if existing_container:
                after = _snapshot_container(existing_container)
                changed_fields = _diff_fields(before, after) if before else []
                action = "merged" if changed_fields else "skipped"
            else:
                changed_fields = []
                action = "created"
            merge_summary.append(
                {
                    "object_type": "container",
                    "object_key": container.key,
                    "action": action,
                    "changed_fields": changed_fields,
                }
            )

        # Merge whitelists (other investigation overrides on identifier conflicts)
        for entry in other.get_whitelists():
            self.add_whitelist(entry.identifier, entry.name, entry.justification)

        # Rebuild link index after merges
        self._score_engine.rebuild_link_index()
        self._refresh_check_links()

        # Final score recalculation
        self._score_engine.recalculate_all()

        self._record_event(
            event_type="INVESTIGATION_MERGED",
            object_type="investigation",
            object_key=self.investigation_id,
            details={
                "from_investigation_id": other.investigation_id,
                "into_investigation_id": self.investigation_id,
                "from_investigation_name": other.investigation_name,
                "into_investigation_name": self.investigation_name,
                "object_changes": merge_summary,
            },
        )
1571
+
1572
    def _clone_for_merge(
        self, other: Investigation
    ) -> tuple[
        dict[str, Observable],
        dict[str, ThreatIntel],
        dict[str, Check],
        dict[str, Enrichment],
        dict[str, Container],
    ]:
        """Clone incoming models while preserving shared object references."""
        # Deep-copy the primary stores first; the memoizing helpers below
        # re-point embedded references at these copies so an object shared
        # between stores stays a single shared instance after cloning.
        incoming_threat_intels = {key: ti.model_copy(deep=True) for key, ti in other._threat_intels.items()}
        incoming_checks = {key: check.model_copy(deep=True) for key, check in other._checks.items()}
        incoming_enrichments = {key: enrichment.model_copy(deep=True) for key, enrichment in other._enrichments.items()}

        # Threat intels reachable only through observables, not the main store.
        orphan_threat_intels: dict[str, ThreatIntel] = {}

        def _copy_threat_intel(ti: ThreatIntel) -> ThreatIntel:
            # Prefer the store copy, then a previously made orphan copy, so
            # every reference to the same key resolves to one instance.
            if ti.key in incoming_threat_intels:
                return incoming_threat_intels[ti.key]
            existing = orphan_threat_intels.get(ti.key)
            if existing:
                return existing
            copied = ti.model_copy(deep=True)
            orphan_threat_intels[ti.key] = copied
            return copied

        incoming_observables: dict[str, Observable] = {}
        for obs in other._observables.values():
            copied_obs = obs.model_copy(deep=True)
            if obs.threat_intels:
                # Replace the deep-copied duplicates with the shared copies.
                copied_obs.threat_intels = [_copy_threat_intel(ti) for ti in obs.threat_intels]
            incoming_observables[obs.key] = copied_obs

        # Checks reachable only through containers, not the main store.
        orphan_checks: dict[str, Check] = {}

        def _copy_check(check: Check) -> Check:
            # Same memoization strategy as _copy_threat_intel.
            if check.key in incoming_checks:
                return incoming_checks[check.key]
            existing = orphan_checks.get(check.key)
            if existing:
                return existing
            copied = check.model_copy(deep=True)
            orphan_checks[check.key] = copied
            return copied

        incoming_containers: dict[str, Container] = {}

        def _copy_container(container: Container) -> Container:
            # Rebuild the container tree, memoizing on key so a container
            # reachable from several parents is copied exactly once. The copy
            # is registered before recursing into sub-containers, which also
            # guards against cycles in the container graph.
            existing = incoming_containers.get(container.key)
            if existing:
                return existing
            copied = Container(
                path=container.path,
                description=container.description,
                checks=[_copy_check(check) for check in container.checks],
                sub_containers={},
                key=container.key,
            )
            incoming_containers[container.key] = copied
            for sub_key, sub in container.sub_containers.items():
                copied.sub_containers[sub_key] = _copy_container(sub)
            return copied

        for container in other._containers.values():
            _copy_container(container)

        return (
            incoming_observables,
            incoming_threat_intels,
            incoming_checks,
            incoming_enrichments,
            incoming_containers,
        )