cyvest 4.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cyvest might be problematic. Click here for more details.

cyvest/io_rich.py ADDED
@@ -0,0 +1,579 @@
1
+ """
2
+ Rich console output for Cyvest investigations.
3
+
4
+ Provides formatted display of investigation results using the Rich library.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections.abc import Callable, Iterable
10
+ from datetime import datetime, timezone
11
+ from decimal import Decimal, InvalidOperation
12
+ from typing import TYPE_CHECKING, Any
13
+
14
+ from rich.align import Align
15
+ from rich.markup import escape
16
+ from rich.rule import Rule
17
+ from rich.table import Table
18
+ from rich.tree import Tree
19
+
20
+ from cyvest.levels import Level, get_color_level, get_color_score, normalize_level
21
+ from cyvest.model import Observable, Relationship, RelationshipDirection, _format_score_decimal
22
+
23
+ if TYPE_CHECKING:
24
+ from cyvest.cyvest import Cyvest
25
+
26
+
27
def _normalize_exclude_levels(levels: Level | Iterable[Level]) -> set[Level]:
    """
    Resolve the ``exclude_levels`` argument into a concrete set of levels.

    Accepted inputs are ``None``, a single ``Level``, a single string (run
    through ``normalize_level``), or an iterable mixing levels and strings.
    ``Level.NONE`` is added to every result except one: an explicitly empty
    iterable yields an empty set.

    Args:
        levels: Level(s) requested for exclusion

    Returns:
        Set of levels to leave out of the report
    """
    always_hidden: set[Level] = {Level.NONE}

    # Scalar inputs first; strings are iterable, so they must be handled
    # before the generic iterable path below.
    if levels is None:
        return always_hidden
    if isinstance(levels, Level):
        return always_hidden | {levels}
    if isinstance(levels, str):
        return always_hidden | {normalize_level(levels)}

    requested = list(levels)
    if not requested:
        # An empty iterable returns nothing at all, not even Level.NONE.
        return set()

    resolved: set[Level] = {
        normalize_level(entry) if isinstance(entry, str) else entry for entry in requested
    }
    return always_hidden | resolved
44
+
45
+
46
+ def _sort_key_by_score(item: Any) -> tuple[Decimal, str]:
47
+ score = getattr(item, "score", 0)
48
+ try:
49
+ decimal_score = Decimal(score)
50
+ except (TypeError, ValueError, InvalidOperation):
51
+ decimal_score = Decimal(0)
52
+
53
+ item_id = getattr(item, "check_id", "")
54
+ return (-decimal_score, item_id)
55
+
56
+
57
def _get_direction_symbol(rel: Relationship, reversed_edge: bool) -> str:
    """
    Pick the arrow glyph for a relationship as seen from the traversal side.

    A string ``rel.direction`` is converted to ``RelationshipDirection``; a
    string that is not a valid member is treated as OUTBOUND. When the edge is
    walked backwards the arrow is flipped, except for bidirectional edges.

    Args:
        rel: Relationship whose ``direction`` is inspected
        reversed_edge: True when the edge is traversed from target to source

    Returns:
        One of "→", "←" or "↔"
    """
    heading = rel.direction
    if isinstance(heading, str):
        try:
            heading = RelationshipDirection(heading)
        except ValueError:
            heading = RelationshipDirection.OUTBOUND

    arrows = {
        RelationshipDirection.OUTBOUND: "→",
        RelationshipDirection.INBOUND: "←",
        RelationshipDirection.BIDIRECTIONAL: "↔",
    }
    forward = arrows.get(heading, "→")

    # Forward traversal and bidirectional edges keep the mapped glyph.
    if not reversed_edge or heading == RelationshipDirection.BIDIRECTIONAL:
        return forward

    # Reversed traversal: an outbound edge points back at us, anything else points away.
    if heading == RelationshipDirection.OUTBOUND:
        return "←"
    return "→"
75
+
76
+
77
def _build_observable_tree(
    parent_tree: Tree,
    obs: Any,
    *,
    all_observables: dict[str, Any],
    reverse_relationships: dict[str, list[tuple[Any, Relationship]]],
    visited: set[str],
    rel_info: str = "",
) -> None:
    """
    Recursively add an observable and its neighbours to a Rich tree.

    Each observable is rendered once: its key is recorded in ``visited`` and
    later encounters return immediately, which also stops cycles. Outbound
    relationships are followed first, then inbound ones taken from
    ``reverse_relationships``.

    Args:
        parent_tree: Tree node that receives the new child node
        obs: Observable to render (reads key, level, score, score_display,
            check_links, whitelisted and relationships)
        all_observables: Lookup of observables by key, used to resolve targets
        reverse_relationships: Target key -> list of (source observable, relationship)
        visited: Keys already rendered; mutated in place
        rel_info: Markup prefix describing the edge that led to this node
    """
    if obs.key in visited:
        return
    visited.add(obs.key)

    color_level = get_color_level(obs.level)
    color_score = get_color_score(obs.score)

    linked_checks = ""
    if obs.check_links:
        checks_str = "[cyan], [/cyan]".join(escape(check_id) for check_id in obs.check_links)
        linked_checks = f"[cyan][[/cyan]{checks_str}[cyan]][/cyan] "

    whitelisted_str = " [green]WHITELISTED[/green]" if obs.whitelisted else ""

    # The key is escaped like the check ids above, so bracket sequences in it
    # are shown literally instead of being parsed as Rich markup.
    safe_key = escape(f"{obs.key}")

    obs_info = (
        f"{rel_info}{linked_checks}[bold]{safe_key}[/bold] "
        f"[{color_score}]{obs.score_display}[/{color_score}] "
        f"[{color_level}]{obs.level.name}[/{color_level}]"
        f"{whitelisted_str}"
    )

    child_tree = parent_tree.add(obs_info)

    def _edge_label(rel: Relationship, *, reversed_edge: bool) -> str:
        # Relationship type names are escaped for the same reason as the key.
        direction_symbol = _get_direction_symbol(rel, reversed_edge=reversed_edge)
        type_name = escape(f"{rel.relationship_type_name}")
        return f"[dim]{type_name}[/dim] {direction_symbol} "

    # Add outbound children
    for rel in obs.relationships:
        child_obs = all_observables.get(rel.target_key)
        if child_obs:
            _build_observable_tree(
                child_tree,
                child_obs,
                all_observables=all_observables,
                reverse_relationships=reverse_relationships,
                visited=visited,
                rel_info=_edge_label(rel, reversed_edge=False),
            )

    # Add inbound children (observables pointing to this one)
    for source_obs, rel in reverse_relationships.get(obs.key, []):
        if source_obs.key == obs.key:
            # Self-referencing edge: nothing new to show.
            continue
        _build_observable_tree(
            child_tree,
            source_obs,
            all_observables=all_observables,
            reverse_relationships=reverse_relationships,
            visited=visited,
            rel_info=_edge_label(rel, reversed_edge=True),
        )
138
+
139
+
140
def _render_audit_log_table(
    *,
    rich_print: Callable[[Any], None],
    title: str,
    events: Iterable[Any],
    started_at: datetime | None,
) -> None:
    """
    Render audit events as a Rich table and pass it to ``rich_print``.

    Events are ordered by timestamp, then grouped by ``object_key`` in order of
    first appearance; groups are separated by table sections. The "Elapsed"
    column is relative to ``started_at`` or, when that is None, to the earliest
    event. Naive datetimes are treated as UTC, both for ordering and for the
    elapsed computation, so naive and aware timestamps can be mixed.

    Args:
        rich_print: A rich renderable handler that is called with the table
        title: Table title
        events: Audit events (reads timestamp, event_type, object_key, reason
            and, if present, details)
        started_at: Reference time for the elapsed column, or None
    """
    placeholder = "[dim]-[/dim]"

    def _colored_score(value: Any) -> str:
        # Scores may arrive as Decimal or as anything str() can feed to Decimal.
        score = value if isinstance(value, Decimal) else Decimal(str(value))
        color = get_color_score(score)
        return f"[{color}]{_format_score_decimal(score)}[/{color}]"

    def _colored_level(value: Any) -> str:
        level = normalize_level(value)
        color = get_color_level(level)
        return f"[{color}]{level.name}[/{color}]"

    def _render_score_change(details: dict[str, Any]) -> str:
        old_score = details.get("old_score")
        new_score = details.get("new_score")
        old_level = details.get("old_level")
        new_level = details.get("new_level")

        parts: list[str] = []
        if old_score is not None and new_score is not None:
            parts.append(f"Score: {_colored_score(old_score)} → {_colored_score(new_score)}")
        if old_level is not None and new_level is not None:
            parts.append(f"Level: {_colored_level(old_level)} → {_colored_level(new_level)}")

        return " | ".join(parts) if parts else placeholder

    def _render_level_change(details: dict[str, Any]) -> str:
        old_level = details.get("old_level")
        new_level = details.get("new_level")
        score = details.get("score")
        if old_level is None or new_level is None:
            return placeholder
        level_str = f"{_colored_level(old_level)} → {_colored_level(new_level)}"
        if score is None:
            return f"Level: {level_str}"
        return f"Level: {level_str} | Score: {_colored_score(score)}"

    def _render_merge_event(details: dict[str, Any]) -> str:
        from_name = details.get("from_investigation_name")
        into_name = details.get("into_investigation_name")
        from_id = details.get("from_investigation_id")
        into_id = details.get("into_investigation_id")
        # Prefer the investigation name, fall back to its id.
        from_label = escape(str(from_name)) if from_name else escape(str(from_id))
        into_label = escape(str(into_name)) if into_name else escape(str(into_id))
        # str(None) yields "None" when neither name nor id is available.
        if not from_label or from_label == "None":
            from_label = placeholder
        if not into_label or into_label == "None":
            into_label = placeholder

        # Tally object changes per action; entries without an action are skipped.
        counts: dict[str, int] = {}
        for change in details.get("object_changes") or []:
            action = change.get("action")
            if not action:
                continue
            counts[action] = counts.get(action, 0) + 1

        if counts:
            summary = ", ".join(f"{key}={value}" for key, value in sorted(counts.items()))
            return f"Merge: {from_label} → {into_label} | Changes: {summary}"

        return f"Merge: {from_label} → {into_label}"

    def _render_threat_intel_attached(details: dict[str, Any]) -> str:
        source = details.get("source")
        score = details.get("score")
        level = details.get("level")
        parts: list[str] = []
        if source:
            parts.append(f"Source: [cyan]{escape(str(source))}[/cyan]")
        if level is not None:
            parts.append(f"Level: {_colored_level(level)}")
        if score is not None:
            parts.append(f"Score: {_colored_score(score)}")
        return " | ".join(parts) if parts else placeholder

    # Event types without an entry here get no details column content.
    detail_renderers: dict[str, Callable[[dict[str, Any]], str]] = {
        "SCORE_CHANGED": _render_score_change,
        "SCORE_RECALCULATED": _render_score_change,
        "LEVEL_UPDATED": _render_level_change,
        "INVESTIGATION_MERGED": _render_merge_event,
        "THREAT_INTEL_ATTACHED": _render_threat_intel_attached,
    }

    def _coerce_utc(value: datetime) -> datetime:
        # Naive datetimes are interpreted as UTC; aware ones are converted.
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    def _format_elapsed(total_seconds: float) -> str:
        # Negative offsets (event before the reference time) are clamped to zero.
        total_ms = max(0, int(round(total_seconds * 1000)))
        hours, rem_ms = divmod(total_ms, 3_600_000)
        minutes, rem_ms = divmod(rem_ms, 60_000)
        seconds, ms = divmod(rem_ms, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{ms:03d}"

    table = Table(title=title, show_lines=False)
    table.add_column("#", justify="right")
    table.add_column("Elapsed", style="dim")
    table.add_column("Event")
    table.add_column("Object")
    table.add_column("Context")

    # Sort on the UTC-coerced timestamp: comparing a naive datetime with an
    # aware one raises TypeError, and the elapsed column already coerces.
    events_sorted = sorted(events, key=lambda evt: _coerce_utc(evt.timestamp))
    effective_start = _coerce_utc(started_at) if started_at is not None else None
    if effective_start is None and events_sorted:
        effective_start = _coerce_utc(events_sorted[0].timestamp)

    # Group by object key; dicts keep insertion order, so groups appear in the
    # order their first event occurs.
    grouped_events: dict[str, list[Any]] = {}
    for event in events_sorted:
        grouped_events.setdefault(event.object_key or "", []).append(event)

    row_idx = 1
    for group in grouped_events.values():
        if row_idx > 1:
            table.add_section()
        for event in group:
            event_timestamp = _coerce_utc(event.timestamp)
            elapsed = ""
            if effective_start is not None:
                elapsed = _format_elapsed((event_timestamp - effective_start).total_seconds())

            event_type = escape(event.event_type)
            object_label = escape(event.object_key) if event.object_key else placeholder
            reason = escape(event.reason) if event.reason else placeholder
            details = placeholder
            renderer = detail_renderers.get(event.event_type)
            if renderer:
                details = renderer(getattr(event, "details", {}) or {})

            # Rendered details win; the reason is shown only when there are none.
            # NOTE(review): when both exist the reason is dropped, as in the
            # original branch structure - confirm this is intended.
            context = reason if details == placeholder else details

            table.add_row(
                str(row_idx),
                elapsed,
                event_type,
                object_label,
                context,
            )
            row_idx += 1

    table.caption = "No audit events recorded." if not events_sorted else ""
    rich_print(table)
331
+
332
+
333
def display_summary(
    cv: Cyvest,
    rich_print: Callable[[Any], None],
    show_graph: bool = True,
    exclude_levels: Level | Iterable[Level] = Level.NONE,
    show_audit_log: bool = False,
) -> None:
    """
    Display a comprehensive summary of the investigation using Rich.

    Prints one report table (checks by scope, containers, checks by level,
    enrichments, statistics, global score), then optionally the observable
    graph and the audit log. Check ids, scope names, container paths and
    enrichment names are escaped before being placed into Rich markup.

    Args:
        cv: Cyvest investigation to display
        rich_print: A rich renderable handler that is called with renderables for output
        show_graph: Whether to display the observable graph
        exclude_levels: Level(s) to omit from the report (default: Level.NONE)
        show_audit_log: Whether to display the investigation audit log (default: False)
    """

    resolved_excluded_levels = _normalize_exclude_levels(exclude_levels)

    # Fetch the checks once and reuse the filtered list for every section.
    checks_by_id = cv.check_get_all()
    filtered_checks = [c for c in checks_by_id.values() if c.level not in resolved_excluded_levels]
    applied_checks = sum(1 for c in filtered_checks if c.level != Level.NONE)

    excluded_caption = ""
    if resolved_excluded_levels:
        excluded_names = ", ".join(level.name for level in sorted(resolved_excluded_levels, key=lambda lvl: lvl.value))
        excluded_caption = f" (excluding: {excluded_names})"

    caption_parts = [
        f"Total Checks: {len(checks_by_id)}",
        f"Displayed: {len(filtered_checks)}{excluded_caption}",
        f"Applied: {applied_checks}",
    ]

    table = Table(
        title="Investigation Report",
        caption=" | ".join(caption_parts),
    )
    table.add_column("Name")
    table.add_column("Score", justify="right")
    table.add_column("Level", justify="center")

    # Checks section
    rule = Rule("[bold magenta]CHECKS[/bold magenta]")
    table.add_row(rule, "-", "-")

    # Organize checks by scope
    checks_by_scope: dict[str, list[Any]] = {}
    for check in filtered_checks:
        checks_by_scope.setdefault(check.scope, []).append(check)

    for scope_name, checks in checks_by_scope.items():
        # Escape so a scope name containing brackets is not parsed as markup.
        scope_label = escape(f"{scope_name}")
        scope_rule = Align(f"[bold magenta]{scope_label}[/bold magenta]", align="left")
        table.add_row(scope_rule, "-", "-")
        for check in sorted(checks, key=_sort_key_by_score):
            color_level = get_color_level(check.level)
            color_score = get_color_score(check.score)
            check_label = escape(f"{check.check_id}")
            name = f" {check_label}"
            score = f"[{color_score}]{check.score_display}[/{color_score}]"
            level = f"[{color_level}]{check.level.name}[/{color_level}]"
            table.add_row(name, score, level)

    # Containers section (if any)
    containers = cv.container_get_all()
    if containers:
        table.add_section()
        rule = Rule("[bold magenta]CONTAINERS[/bold magenta]")
        table.add_row(rule, "-", "-")

        for container in containers.values():
            agg_score = container.get_aggregated_score()
            agg_level = container.get_aggregated_level()
            color_level = get_color_level(agg_level)
            color_score = get_color_score(agg_score)

            path_label = escape(f"{container.path}")
            name = f" {path_label}"
            score = f"[{color_score}]{agg_score:.2f}[/{color_score}]"
            level = f"[{color_level}]{agg_level.name}[/{color_level}]"
            table.add_row(name, score, level)

    # Checks by level section
    table.add_section()
    rule = Rule("[bold magenta]BY LEVEL[/bold magenta]")
    table.add_row(rule, "-", "-")

    for level_enum in [Level.MALICIOUS, Level.SUSPICIOUS, Level.NOTABLE, Level.SAFE, Level.INFO, Level.TRUSTED]:
        if level_enum in resolved_excluded_levels:
            continue
        # filtered_checks already excludes the hidden levels.
        checks = sorted((c for c in filtered_checks if c.level == level_enum), key=_sort_key_by_score)
        if checks:
            color_level = get_color_level(level_enum)
            level_rule = Align(
                f"[bold {color_level}]{level_enum.name}: {len(checks)} check(s)[/bold {color_level}]",
                align="center",
            )
            table.add_row(level_rule, "-", "-")

            for check in checks:
                color_score = get_color_score(check.score)
                check_label = escape(f"{check.check_id}")
                name = f" {check_label}"
                score = f"[{color_score}]{check.score_display}[/{color_score}]"
                level = f"[{color_level}]{check.level.name}[/{color_level}]"
                table.add_row(name, score, level)

    # Enrichments section (if any)
    enrichments = cv.enrichment_get_all()
    if enrichments:
        table.add_section()
        rule = Rule(f"[bold magenta]ENRICHMENTS[/bold magenta]: {len(enrichments)} enrichments")
        table.add_row(rule, "-", "-")

        for enr in enrichments.values():
            enr_label = escape(f"{enr.name}")
            table.add_row(f" {enr_label}", "-", "-")

    # Statistics section
    table.add_section()
    rule = Rule("[bold magenta]STATISTICS[/bold magenta]")
    table.add_row(rule, "-", "-")

    stats = cv.get_statistics()
    stat_items = [
        ("Total Observables", stats.total_observables),
        ("Internal Observables", stats.internal_observables),
        ("External Observables", stats.external_observables),
        ("Whitelisted Observables", stats.whitelisted_observables),
        ("Total Threat Intel", stats.total_threat_intel),
    ]

    for stat_name, stat_value in stat_items:
        table.add_row(f" {stat_name}", str(stat_value), "-")

    # Global score footer
    global_score = cv.get_global_score()
    global_level = cv.get_global_level()
    color_level = get_color_level(global_level)
    color_score = get_color_score(global_score)

    table.add_section()
    table.add_row(
        Align("[bold]GLOBAL SCORE[/bold]", align="center"),
        f"[{color_score}]{global_score:.2f}[/{color_score}]",
        f"[{color_level}]{global_level.name}[/{color_level}]",
    )

    # Print table
    rich_print(table)

    # Observable graph (if requested)
    all_observables = cv.observable_get_all() if show_graph else None
    if all_observables:
        tree = Tree("Observables", hide_root=True)

        # Precompute reverse relationships to traverse observables that only
        # appear as targets (e.g., child → parent links).
        reverse_relationships: dict[str, list[tuple[Observable, Relationship]]] = {}
        for source_obs in all_observables.values():
            for rel in source_obs.relationships:
                reverse_relationships.setdefault(rel.target_key, []).append((source_obs, rel))

        # Start from root
        root = cv.observable_get_root()
        if root:
            _build_observable_tree(
                tree,
                root,
                all_observables=all_observables,
                reverse_relationships=reverse_relationships,
                visited=set(),
            )

        rich_print(tree)

    if show_audit_log:
        investigation = getattr(cv, "_investigation", None)
        events = investigation.get_audit_log() if investigation else []
        if events:
            # A non-empty event list implies the investigation object exists.
            _render_audit_log_table(
                rich_print=rich_print,
                title="Audit Log",
                events=events,
                started_at=investigation.started_at,
            )
523
+
524
+
525
def display_statistics(cv: Cyvest, rich_print: Callable[[Any], None]) -> None:
    """
    Print the investigation statistics as a series of Rich tables.

    Outputs an observable table (per type: total plus INFO / NOTABLE /
    SUSPICIOUS / MALICIOUS counts), a per-scope check table and, when at least
    one threat intel entry exists, a per-source threat intel table. An empty
    string is sent to ``rich_print`` before the second and third tables.

    Args:
        cv: Cyvest investigation
        rich_print: A rich renderable handler that is called with renderables for output
    """
    stats = cv.get_statistics()

    def _count_table(title: str, label_header: str, rows: Iterable[tuple[Any, Any]]) -> Table:
        # Two-column "label / count" table shared by the check and threat intel output.
        result = Table(title=title)
        result.add_column(label_header, style="cyan")
        result.add_column("Count", justify="right")
        for label, amount in rows:
            result.add_row(label, str(amount))
        return result

    # Observable statistics: one styled column per reported level.
    level_columns = (
        ("INFO", "cyan"),
        ("NOTABLE", "yellow"),
        ("SUSPICIOUS", "orange3"),
        ("MALICIOUS", "red"),
    )
    obs_table = Table(title="Observable Statistics")
    obs_table.add_column("Type", style="cyan")
    obs_table.add_column("Total", justify="right")
    for level_name, level_style in level_columns:
        obs_table.add_column(level_name, justify="right", style=level_style)

    per_type_levels = stats.observables_by_type_and_level
    for obs_type, total in stats.observables_by_type.items():
        level_counts = per_type_levels.get(obs_type, {})
        obs_table.add_row(
            obs_type.upper(),
            str(total),
            *(str(level_counts.get(level_name, 0)) for level_name, _ in level_columns),
        )
    rich_print(obs_table)

    # Check statistics
    rich_print("")
    rich_print(_count_table("Check Statistics", "Scope", stats.checks_by_scope.items()))

    # Threat intel statistics
    if stats.total_threat_intel > 0:
        rich_print("")
        rich_print(
            _count_table("Threat Intelligence Statistics", "Source", stats.threat_intel_by_source.items())
        )
cyvest/io_schema.py ADDED
@@ -0,0 +1,35 @@
1
+ """
2
+ JSON Schema definition for serialized Cyvest investigations.
3
+
4
+ The schema mirrors the structure emitted by `serialize_investigation` in
5
+ `cyvest.io_serialization` so consumers can validate exports or generate
6
+ typed bindings.
7
+
8
+ This module uses Pydantic's `model_json_schema(mode='serialization')` to generate
9
+ schemas that match the actual serialized output (respecting field_serializer decorators).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Any
15
+
16
+ from cyvest.model_schema import InvestigationSchema
17
+
18
+
19
def get_investigation_schema() -> dict[str, Any]:
    """
    Build the JSON Schema describing a serialized investigation.

    Delegates to ``InvestigationSchema.model_json_schema`` in serialization
    mode with aliases enabled, so the schema follows the serialized output
    (including field serializers) rather than the validation input.

    Per the module's own documentation, the result describes the output of
    `serialize_investigation()`, is compliant with JSON Schema Draft 2020-12,
    and lists the referenced entity types (Observable, Check, ThreatIntel,
    Enrichment, Container, InvestigationWhitelist) under `$defs`.

    Returns:
        dict[str, Any]: The generated schema dictionary.
    """
    schema: dict[str, Any] = InvestigationSchema.model_json_schema(
        by_alias=True,
        mode="serialization",
    )
    return schema