cyvest 0.1.0__py3-none-any.whl → 5.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cyvest/model.py ADDED
@@ -0,0 +1,595 @@
1
+ """
2
+ Core data models for Cyvest investigation framework.
3
+
4
+ Defines the base classes for Check, Observable, ThreatIntel, Enrichment, Tag,
5
+ and InvestigationWhitelist using Pydantic BaseModel.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from datetime import datetime
11
+ from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
12
+ from typing import Annotated, Any
13
+
14
+ from pydantic import (
15
+ BaseModel,
16
+ ConfigDict,
17
+ Field,
18
+ PrivateAttr,
19
+ StrictStr,
20
+ computed_field,
21
+ field_serializer,
22
+ field_validator,
23
+ model_validator,
24
+ )
25
+ from typing_extensions import Self
26
+
27
+ from cyvest import keys
28
+ from cyvest.level_score_rules import apply_creation_score_level_defaults
29
+ from cyvest.levels import Level, get_level_from_score, normalize_level
30
+ from cyvest.model_enums import (
31
+ ObservableType,
32
+ PropagationMode,
33
+ RelationshipDirection,
34
+ RelationshipType,
35
+ )
36
+
37
+ _DEFAULT_SCORE_PLACES = 2
38
+
39
+
40
class AliasDumpModel(BaseModel):
    """Pydantic base class whose dump helpers emit field aliases by default.

    Subclasses declare aliases (e.g. ``Field(..., alias="type")``) so that the
    serialized form uses JSON-friendly keys. Callers may still pass
    ``by_alias=False`` explicitly to get attribute names instead.
    """

    def model_dump(self, *, by_alias: bool = True, **kwargs: Any) -> dict[str, Any]:
        """Return the model as a ``dict``, keyed by alias unless ``by_alias=False``.

        Every other keyword argument is forwarded unchanged to
        ``BaseModel.model_dump``.
        """
        kwargs["by_alias"] = by_alias
        return super().model_dump(**kwargs)

    def model_dump_json(self, *, by_alias: bool = True, **kwargs: Any) -> str:
        """Return the model as a JSON string, keyed by alias unless ``by_alias=False``.

        Every other keyword argument is forwarded unchanged to
        ``BaseModel.model_dump_json``.
        """
        kwargs["by_alias"] = by_alias
        return super().model_dump_json(**kwargs)
50
+
51
+
52
def _format_score_decimal(value: Decimal | None, *, places: int = _DEFAULT_SCORE_PLACES) -> str:
    """Render a score for display with a fixed number of decimal places.

    Args:
        value: Score to render; ``None`` is shown as ``"-"``.
        places: Digits after the decimal point (rounded half-up).

    Returns:
        The fixed-point string (e.g. ``"1.01"`` for ``Decimal("1.005")``).
        Results that round to zero are normalized so no ``"-0.00"`` appears.
        If ``Decimal.quantize`` cannot represent the value (for example
        ``Infinity``, or more digits than the context precision allows), the
        plain ``str(value)`` is returned instead.

    Raises:
        ValueError: if ``places`` is negative (checked only for non-None values).
    """
    if value is None:
        return "-"
    if places < 0:
        raise ValueError("places must be >= 0")

    step = Decimal(1).scaleb(-places)
    try:
        rounded = value.quantize(step, rounding=ROUND_HALF_UP)
        if rounded == 0:
            # Replace a possibly negative zero with a positive one of the same scale.
            rounded = Decimal(0).quantize(step)
    except InvalidOperation:
        return str(value)
    return format(rounded, "f")
65
+
66
+
67
class AuditEvent(BaseModel):
    """A single investigation-level change recorded in the central audit trail.

    Keys beyond the declared fields are retained on the instance
    (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Required: no defaults are declared for these three.
    event_id: str
    timestamp: datetime
    event_type: str
    # Optional provenance and target of the change.
    actor: str | None = Field(default=None)
    reason: str | None = Field(default=None)
    tool: str | None = Field(default=None)
    object_type: str | None = Field(default=None)
    object_key: str | None = Field(default=None)
    # Free-form payload; default_factory gives each instance its own dict.
    details: dict[str, Any] = Field(default_factory=dict)
81
+
82
+
83
class InvestigationWhitelist(BaseModel):
    """Immutable whitelist entry attached to an investigation.

    String fields have surrounding whitespace stripped, and ``identifier`` and
    ``name`` carry a ``min_length=1`` constraint.
    """

    model_config = ConfigDict(frozen=True, str_strip_whitespace=True)

    # Identifier of the whitelisted item.
    identifier: Annotated[str, Field(min_length=1)]
    # Human-readable label for the entry.
    name: Annotated[str, Field(min_length=1)]
    # Optional free-text rationale for whitelisting.
    justification: str | None = Field(default=None)
91
+
92
+
93
class Relationship(BaseModel):
    """Immutable link from the owning observable to another observable.

    ``target_key`` names the other end of the link. ``relationship_type`` is
    stored as a ``RelationshipType`` when the value is recognized and as the
    raw string otherwise. When ``direction`` is omitted or ``None``, it defaults
    to the relationship type's own default direction, or to
    ``RelationshipDirection.OUTBOUND`` for unrecognized types.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)

    target_key: str = Field(...)
    relationship_type: RelationshipType | str = Field(...)
    direction: RelationshipDirection = Field(...)

    @model_validator(mode="before")
    @classmethod
    def ensure_defaults(cls, values: Any) -> Any:
        """Fill in ``direction`` from the relationship type when it is missing.

        Pydantic passes the caller's input object straight to ``mode="before"``
        validators. Edits are therefore made on a shallow copy, so that
        ``Relationship.model_validate(data)`` never mutates ``data``.
        Non-dict inputs and inputs that already carry a direction are returned
        unchanged.
        """
        if not isinstance(values, dict):
            return values
        if values.get("direction") is not None:
            return values

        values = dict(values)
        rel_type = values.get("relationship_type")

        # Use the semantic default when the type is known; otherwise fall back to outbound.
        default_direction = RelationshipDirection.OUTBOUND
        if isinstance(rel_type, RelationshipType):
            default_direction = rel_type.get_default_direction()
        else:
            try:
                rel_enum = RelationshipType(rel_type)
            except (ValueError, TypeError):
                # Unknown or missing type: keep the outbound fallback and let
                # the field validators deal with the raw value.
                pass
            else:
                default_direction = rel_enum.get_default_direction()
                values["relationship_type"] = rel_enum

        values["direction"] = default_direction
        return values

    @field_validator("relationship_type", mode="before")
    @classmethod
    def coerce_relationship_type(cls, v: Any) -> RelationshipType | str:
        """Normalize a string relationship type to the enum when it matches a member.

        Unrecognized strings, and values of other types, are returned unchanged.
        """
        if isinstance(v, RelationshipType):
            return v
        if isinstance(v, str):
            try:
                return RelationshipType(v)
            except ValueError:
                # Keep as string if not a recognized relationship type
                return v
        return v

    @field_serializer("relationship_type")
    def serialize_relationship_type(self, v: RelationshipType | str) -> str:
        """Serialize the relationship type as its plain string value."""
        return v.value if isinstance(v, RelationshipType) else v

    @field_validator("direction", mode="before")
    @classmethod
    def coerce_direction(cls, v: Any) -> RelationshipDirection:
        """Coerce ``direction`` to ``RelationshipDirection``.

        ``None`` becomes ``OUTBOUND``. Strings are looked up by enum value; an
        unknown string raises ``ValueError``, which pydantic reports as a
        validation error. Any other type raises ``TypeError``.
        """
        if v is None:
            return RelationshipDirection.OUTBOUND
        if isinstance(v, RelationshipDirection):
            return v
        if isinstance(v, str):
            return RelationshipDirection(v)
        raise TypeError("Invalid direction type")

    @property
    def relationship_type_name(self) -> str:
        """The relationship type as a plain string, whether enum-backed or raw."""
        return (
            self.relationship_type.value
            if isinstance(self.relationship_type, RelationshipType)
            else self.relationship_type
        )
162
+
163
+
164
class Taxonomy(BaseModel):
    """One structured classification entry (level plus name/value) for threat intel.

    Unknown keys are rejected (``extra="forbid"``). ``name`` and ``value`` are
    ``StrictStr``, so non-string inputs are not coerced.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    level: Level = Field(...)
    name: StrictStr = Field(...)
    value: StrictStr = Field(...)

    @field_validator("level", mode="before")
    @classmethod
    def coerce_level(cls, v: Any) -> Level:
        """Run the raw level through ``normalize_level`` before type validation."""
        normalized = normalize_level(v)
        return normalized
177
+
178
+
179
class ThreatIntel(BaseModel):
    """
    Represents threat intelligence from an external source.

    Threat intelligence provides verdicts about observables from sources
    like VirusTotal, URLScan.io, etc. When no ``key`` is supplied, one is
    derived from ``source`` and ``observable_key`` (provided the latter is
    non-empty).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    source: str = Field(...)
    observable_key: str = Field(...)
    comment: str = Field(...)
    extra: dict[str, Any] = Field(...)
    score: Decimal = Field(...)
    level: Level = Field(...)
    taxonomies: list[Taxonomy] = Field(...)
    key: str = Field(...)

    @field_validator("extra", mode="before")
    @classmethod
    def coerce_extra(cls, v: Any) -> dict[str, Any]:
        """Treat an explicit ``None`` as an empty ``extra`` mapping."""
        if v is None:
            return {}
        return v

    @field_validator("score", mode="before")
    @classmethod
    def coerce_score(cls, v: Any) -> Decimal:
        """Convert the incoming score to ``Decimal`` via its string form.

        Routing through ``str`` keeps floats such as ``0.1`` at their short
        decimal representation. Unparsable values are re-raised as
        ``ValueError``. Pydantic turns that into a ``ValidationError``,
        whereas a raw ``decimal.InvalidOperation`` would escape unwrapped.
        """
        if isinstance(v, Decimal):
            return v
        try:
            return Decimal(str(v))
        except InvalidOperation as exc:
            raise ValueError(f"score must be a decimal number, got {v!r}") from exc

    @field_validator("level", mode="before")
    @classmethod
    def coerce_level(cls, v: Any) -> Level:
        """Run the raw level through ``normalize_level`` before type validation."""
        return normalize_level(v)

    @field_validator("taxonomies")
    @classmethod
    def ensure_unique_taxonomy_names(cls, v: list[Taxonomy]) -> list[Taxonomy]:
        """Reject taxonomy lists in which any ``name`` appears more than once."""
        seen: set[str] = set()
        duplicates: set[str] = set()
        for taxonomy in v:
            if taxonomy.name in seen:
                duplicates.add(taxonomy.name)
            seen.add(taxonomy.name)
        if duplicates:
            dupes = ", ".join(sorted(duplicates))
            raise ValueError(f"Duplicate taxonomy name(s): {dupes}")
        return v

    @model_validator(mode="before")
    @classmethod
    def ensure_defaults(cls, values: Any) -> Any:
        """Apply score/level creation defaults and fill optional fields.

        The input dict is shallow-copied first. Pydantic passes the caller's
        object directly to ``mode="before"`` validators, and the defaults
        below must not leak back into it.
        """
        if isinstance(values, dict):
            values = dict(values)
        values = apply_creation_score_level_defaults(
            values,
            default_level_no_score=Level.INFO,
            require_score=True,
        )
        if not isinstance(values, dict):
            return values

        if values.get("observable_key") is None:
            values["observable_key"] = ""
        values.setdefault("extra", {})
        values.setdefault("comment", "")
        if values.get("taxonomies") is None:
            values["taxonomies"] = []
        # Empty key is a sentinel: generate_key fills it after validation.
        values.setdefault("key", "")
        return values

    @model_validator(mode="after")
    def generate_key(self) -> Self:
        """Derive ``key`` from source and observable key when it was left empty."""
        if not self.key and self.observable_key:
            self.key = keys.generate_threat_intel_key(self.source, self.observable_key)

        return self

    @field_serializer("score")
    def serialize_score(self, v: Decimal) -> float:
        """Serialize the score as a float."""
        return float(v)

    @computed_field(return_type=str)
    @property
    def score_display(self) -> str:
        """Score formatted for display by ``_format_score_decimal``."""
        return _format_score_decimal(self.score)
270
+
271
+
272
class Observable(AliasDumpModel):
    """
    Represents a cyber observable (IP, URL, domain, hash, etc.).

    Observables can be linked to threat intelligence, checks, and other
    observables through relationships. ``obs_type`` is aliased to ``"type"``,
    and ``populate_by_name`` allows either spelling on input. When no ``key``
    is supplied, one is derived from the type and value.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)

    obs_type: ObservableType | str = Field(..., alias="type")
    value: str = Field(...)
    internal: bool = Field(...)
    whitelisted: bool = Field(...)
    comment: str = Field(...)
    extra: dict[str, Any] = Field(...)
    score: Decimal = Field(...)
    level: Level = Field(...)
    threat_intels: list[ThreatIntel] = Field(...)
    relationships: list[Relationship] = Field(...)
    key: str = Field(...)
    # Keys of checks linking to this observable; exposed read-only via check_links.
    _check_links: list[str] = PrivateAttr(default_factory=list)
    _from_shared_context: bool = PrivateAttr(default=False)

    @field_validator("obs_type", mode="before")
    @classmethod
    def coerce_obs_type(cls, v: Any) -> ObservableType | str:
        """Map a type string to ``ObservableType``, matching case-insensitively.

        Strings that match no member are kept as-is, so custom types remain
        allowed.
        """
        if isinstance(v, ObservableType):
            return v
        if isinstance(v, str):
            try:
                return ObservableType(v.lower())
            except ValueError:
                # Keep as string if not a recognized observable type
                return v
        return v

    @field_validator("extra", mode="before")
    @classmethod
    def coerce_extra(cls, v: Any) -> dict[str, Any]:
        """Treat an explicit ``None`` as an empty ``extra`` mapping."""
        if v is None:
            return {}
        return v

    @field_validator("score", mode="before")
    @classmethod
    def coerce_score(cls, v: Any) -> Decimal:
        """Convert the incoming score to ``Decimal`` via its string form.

        Unparsable values are re-raised as ``ValueError``. Pydantic turns that
        into a ``ValidationError``, whereas a raw ``decimal.InvalidOperation``
        would escape unwrapped.
        """
        if isinstance(v, Decimal):
            return v
        try:
            return Decimal(str(v))
        except InvalidOperation as exc:
            raise ValueError(f"score must be a decimal number, got {v!r}") from exc

    @field_validator("level", mode="before")
    @classmethod
    def coerce_level(cls, v: Any) -> Level:
        """Run the raw level through ``normalize_level`` before type validation."""
        return normalize_level(v)

    @model_validator(mode="before")
    @classmethod
    def ensure_defaults(cls, values: Any) -> Any:
        """Apply score/level creation defaults and fill optional fields.

        The input dict is shallow-copied first so the caller's mapping is not
        mutated. Pydantic hands the raw input object to ``mode="before"``
        validators.
        """
        if isinstance(values, dict):
            values = dict(values)
        values = apply_creation_score_level_defaults(values, default_level_no_score=Level.INFO)
        if not isinstance(values, dict):
            return values

        values.setdefault("extra", {})
        values.setdefault("comment", "")
        values.setdefault("internal", True)
        values.setdefault("whitelisted", False)
        values.setdefault("threat_intels", [])
        values.setdefault("relationships", [])
        # Empty key is a sentinel: generate_key fills it after validation.
        values.setdefault("key", "")
        return values

    @model_validator(mode="after")
    def generate_key(self) -> Self:
        """Derive ``key`` from the type string and value when it was left empty."""
        if not self.key:
            # Use string value of obs_type for key generation
            obs_type_str = self.obs_type.value if isinstance(self.obs_type, ObservableType) else self.obs_type
            self.key = keys.generate_observable_key(obs_type_str, self.value)

        return self

    @field_serializer("obs_type")
    def serialize_obs_type(self, v: ObservableType | str) -> str:
        """Serialize the observable type as its plain string value."""
        return v.value if isinstance(v, ObservableType) else v

    @field_serializer("score")
    def serialize_score(self, v: Decimal) -> float:
        """Serialize the score as a float."""
        return float(v)

    @field_serializer("threat_intels")
    def serialize_threat_intels(self, value: list[ThreatIntel]) -> list[str]:
        """Serialize threat intels as keys only."""
        return [ti.key for ti in value]

    @computed_field
    @property
    def check_links(self) -> list[str]:
        """Checks that currently link to this observable (navigation-only)."""
        return list(self._check_links)

    @computed_field(return_type=str)
    @property
    def score_display(self) -> str:
        """Score formatted for display by ``_format_score_decimal``."""
        return _format_score_decimal(self.score)
385
+
386
+
387
class ObservableLink(BaseModel):
    """Immutable edge metadata describing one Check-to-Observable association.

    Unknown keys are rejected. ``propagation_mode`` defaults to
    ``PropagationMode.LOCAL_ONLY``.
    """

    model_config = ConfigDict(frozen=True, extra="forbid")

    # Key of the linked observable.
    observable_key: str = Field(...)
    # How this link propagates across merged investigations.
    propagation_mode: PropagationMode = Field(default=PropagationMode.LOCAL_ONLY)
394
+
395
+
396
class Check(BaseModel):
    """
    Represents a verification step in the investigation.

    A check validates a specific aspect of the data under investigation and
    contributes to the overall investigation score. ``observable_links``
    records the observables it is associated with. When no ``key`` is
    supplied, one is derived from ``check_name``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    check_name: str = Field(...)
    description: str = Field(...)
    comment: str = Field(...)
    extra: dict[str, Any] = Field(...)
    score: Decimal = Field(...)
    level: Level = Field(...)
    origin_investigation_id: str = Field(...)
    observable_links: list[ObservableLink] = Field(...)
    key: str = Field(...)

    @field_validator("extra", mode="before")
    @classmethod
    def coerce_extra(cls, v: Any) -> dict[str, Any]:
        """Treat an explicit ``None`` as an empty ``extra`` mapping."""
        if v is None:
            return {}
        return v

    @field_validator("score", mode="before")
    @classmethod
    def coerce_score(cls, v: Any) -> Decimal:
        """Convert the incoming score to ``Decimal`` via its string form.

        Unparsable values are re-raised as ``ValueError``. Pydantic turns that
        into a ``ValidationError``, whereas a raw ``decimal.InvalidOperation``
        would escape unwrapped.
        """
        if isinstance(v, Decimal):
            return v
        try:
            return Decimal(str(v))
        except InvalidOperation as exc:
            raise ValueError(f"score must be a decimal number, got {v!r}") from exc

    @field_validator("level", mode="before")
    @classmethod
    def coerce_level(cls, v: Any) -> Level:
        """Run the raw level through ``normalize_level`` before type validation."""
        return normalize_level(v)

    @model_validator(mode="before")
    @classmethod
    def ensure_defaults(cls, values: Any) -> Any:
        """Apply score/level creation defaults and fill optional fields.

        The input dict is shallow-copied first so the caller's mapping is not
        mutated. Pydantic hands the raw input object to ``mode="before"``
        validators.
        """
        if isinstance(values, dict):
            values = dict(values)
        values = apply_creation_score_level_defaults(values, default_level_no_score=Level.NONE)
        if not isinstance(values, dict):
            return values

        values.setdefault("extra", {})
        values.setdefault("comment", "")
        values.setdefault("observable_links", [])
        # Empty key is a sentinel: generate_key fills it after validation.
        values.setdefault("key", "")
        return values

    @model_validator(mode="after")
    def generate_key(self) -> Self:
        """Derive ``key`` from ``check_name`` when it was left empty."""
        if not self.key:
            self.key = keys.generate_check_key(self.check_name)
        return self

    @field_serializer("score")
    def serialize_score(self, v: Decimal) -> float:
        """Serialize the score as a float."""
        return float(v)

    @computed_field(return_type=str)
    @property
    def score_display(self) -> str:
        """Score formatted for display by ``_format_score_decimal``."""
        return _format_score_decimal(self.score)
467
+
468
+
469
class Enrichment(BaseModel):
    """
    Represents structured data enrichment for the investigation.

    Enrichments store arbitrary structured data that provides additional
    context but doesn't directly contribute to scoring. When no ``key`` is
    supplied, one is derived from ``name`` and ``context``.
    """

    model_config = ConfigDict()

    name: str = Field(...)
    data: Any = Field(...)
    context: str = Field(...)
    key: str = Field(...)

    @model_validator(mode="after")
    def generate_key(self) -> Self:
        """Derive ``key`` from name and context when it was left empty."""
        if not self.key:
            self.key = keys.generate_enrichment_key(self.name, self.context)
        return self

    @model_validator(mode="before")
    @classmethod
    def ensure_defaults(cls, values: Any) -> Any:
        """Default ``data`` to ``{}`` and ``context``/``key`` to ``""`` when absent.

        Works on a shallow copy, because pydantic passes the caller's input
        object to ``mode="before"`` validators and mutating it would leak these
        defaults back to the caller.
        """
        if not isinstance(values, dict):
            return values
        values = dict(values)
        values.setdefault("data", {})
        values.setdefault("context", "")
        values.setdefault("key", "")
        return values
503
+
504
+
505
class Tag(BaseModel):
    """
    Groups checks for categorical organization.

    Tags allow structuring the investigation into logical sections
    with aggregated scores and levels. Hierarchy is automatic based on
    the ":" delimiter in tag names (e.g., "header:auth:dkim"); this model
    itself only computes values over its direct checks.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    name: str
    description: str = ""
    checks: list[Check] = Field(...)
    key: str = Field(...)

    @model_validator(mode="after")
    def generate_key(self) -> Self:
        """Derive ``key`` from ``name`` when it was left empty."""
        if not self.key:
            self.key = keys.generate_tag_key(self.name)
        return self

    @model_validator(mode="before")
    @classmethod
    def ensure_defaults(cls, values: Any) -> Any:
        """Default ``checks`` to ``[]`` and ``key`` to ``""`` when absent.

        Works on a shallow copy so the caller's input mapping is not mutated.
        Pydantic passes the raw input object to ``mode="before"`` validators.
        """
        if not isinstance(values, dict):
            return values
        values = dict(values)
        values.setdefault("checks", [])
        values.setdefault("key", "")
        return values

    @field_serializer("checks")
    def serialize_checks(self, value: list[Check]) -> list[str]:
        """Serialize checks as keys only."""
        return [check.key for check in value]

    @computed_field(return_type=Decimal)
    @property
    def direct_score(self) -> Decimal:
        """
        Calculate the score from direct checks only (no hierarchy).

        For hierarchical aggregation (including descendant tags), use
        Investigation.get_tag_aggregated_score() or TagProxy.get_aggregated_score().

        Returns:
            Total score from direct checks
        """
        return self.get_direct_score()

    @field_serializer("direct_score")
    def serialize_direct_score(self, v: Decimal) -> float:
        """Serialize the direct score as a float."""
        return float(v)

    def get_direct_score(self) -> Decimal:
        """
        Calculate the score from direct checks only.

        Returns:
            Sum of the direct checks' scores; ``Decimal("0")`` when there are none.
        """
        return sum((check.score for check in self.checks), Decimal("0"))

    @computed_field(return_type=Level)
    @property
    def direct_level(self) -> Level:
        """
        Calculate the level from direct checks only (no hierarchy).

        For hierarchical aggregation (including descendant tags), use
        Investigation.get_tag_aggregated_level() or TagProxy.get_aggregated_level().

        Returns:
            Level based on direct score
        """
        return self.get_direct_level()

    def get_direct_level(self) -> Level:
        """
        Calculate the level from direct score only.

        Returns:
            Level mapped from get_direct_score() by get_level_from_score
        """
        return get_level_from_score(self.get_direct_score())
cyvest/model_enums.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ Shared enum types for Cyvest models.
3
+
4
+ This module intentionally contains only enums (no Pydantic models) so it can be
5
+ imported by both ``cyvest.model`` and ``cyvest.score`` without creating circular
6
+ import dependencies.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from enum import Enum
12
+
13
+
14
class ObservableType(str, Enum):
    """Kinds of cyber observables understood by Cyvest."""

    IPV4 = "ipv4"
    IPV6 = "ipv6"
    DOMAIN = "domain"
    URL = "url"
    HASH = "hash"
    EMAIL = "email"
    FILE = "file"
    ARTIFACT = "artifact"

    @classmethod
    def normalize_root_type(cls, root_type: ObservableType | str | None) -> ObservableType:
        """Resolve ``root_type`` to one of the two permitted root kinds.

        ``None`` selects ``FILE``. Strings are lower-cased and matched against
        member values.

        Raises:
            ValueError: if the value names an unknown type or a type other
                than ``FILE``/``ARTIFACT``.
            TypeError: if ``root_type`` is not ``None``, a member, or a string.
        """
        message = "root_type must be ObservableType.FILE or ObservableType.ARTIFACT"
        if root_type is None:
            return cls.FILE

        if isinstance(root_type, cls):
            candidate = root_type
        elif isinstance(root_type, str):
            try:
                candidate = cls(root_type.lower())
            except ValueError as exc:
                raise ValueError(message) from exc
        else:
            raise TypeError(message)

        if candidate in (cls.FILE, cls.ARTIFACT):
            return candidate
        raise ValueError(message)
43
+
44
+
45
class RelationshipDirection(str, Enum):
    """Which way a relationship between two observables points."""

    # source -> target
    OUTBOUND = "outbound"
    # source <- target
    INBOUND = "inbound"
    # source <-> target
    BIDIRECTIONAL = "bidirectional"
51
+
52
+
53
class RelationshipType(str, Enum):
    """Relationship kinds that may link two observables in Cyvest."""

    RELATED_TO = "related-to"

    def get_default_direction(self) -> RelationshipDirection:
        """Return the direction to use when a relationship does not specify one.

        Every currently defined type maps to ``BIDIRECTIONAL``.
        """
        direction = RelationshipDirection.BIDIRECTIONAL
        return direction
63
+
64
+
65
class PropagationMode(str, Enum):
    """How a Check-to-Observable link propagates when investigations are merged."""

    # Link stays within the investigation that created it.
    LOCAL_ONLY = "LOCAL_ONLY"
    # Link is carried across merged investigations.
    GLOBAL = "GLOBAL"