cyvest 0.1.0__py3-none-any.whl → 5.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cyvest/shared.py ADDED
@@ -0,0 +1,508 @@
1
+ """
2
+ Shared investigation context for concurrent execution.
3
+
4
+ This module provides a single implementation that supports both:
5
+ - synchronous usage (threads / thread pools)
6
+ - asynchronous usage (asyncio)
7
+
8
+ Key design goals:
9
+ - All state mutation and reads go through a single shared implementation.
10
+ - Async APIs never block the event loop: they run the critical section in a worker thread.
11
+ - Returned objects are deep-copied snapshots (read-only-by-convention) to avoid shared mutable state.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import threading
18
+ from copy import deepcopy
19
+ from decimal import Decimal
20
+ from pathlib import Path
21
+ from typing import TYPE_CHECKING, Any, Literal
22
+
23
+ from logurich import logger
24
+
25
+ from cyvest import keys
26
+ from cyvest.cyvest import Cyvest
27
+ from cyvest.io_serialization import (
28
+ generate_markdown_report,
29
+ save_investigation_json,
30
+ save_investigation_markdown,
31
+ serialize_investigation,
32
+ )
33
+ from cyvest.levels import Level
34
+ from cyvest.model import Check, Enrichment, Observable, ObservableType
35
+
36
+ if TYPE_CHECKING:
37
+ from cyvest.investigation import Investigation
38
+ from cyvest.model_schema import InvestigationSchema
39
+
40
+
41
class _SharedLock:
    """
    Dual-mode lock adapter with a single canonical lock.

    - Sync path: acquires a single `threading.RLock` around the critical section.
    - Async path: runs the entire critical section in a worker thread via `asyncio.to_thread(...)`
      so the event loop is never blocked.

    Notes:
    - Optionally limits concurrent async callers via one `asyncio.Semaphore(max_async_workers)`
      per event loop (asyncio primitives must only be used from the loop they are bound to).
    - Cancelling an awaiting async caller releases its semaphore slot, but the worker thread it
      already started cannot be interrupted and keeps running until `fn` returns.
    """

    def __init__(
        self,
        thread_lock: threading.RLock | None = None,
        *,
        max_async_workers: int | None = None,
    ) -> None:
        """
        Args:
            thread_lock: Lock guarding every critical section; a new `threading.RLock` is used when None.
            max_async_workers: Maximum number of concurrent async callers per event loop, or None
                for no limit.

        Raises:
            ValueError: If `max_async_workers` is less than 1. A zero-permit semaphore would make every
                async call wait forever.
        """
        if max_async_workers is not None and max_async_workers < 1:
            raise ValueError(f"max_async_workers must be >= 1 or None, got {max_async_workers!r}")
        self._thread_lock = thread_lock if thread_lock is not None else threading.RLock()
        self._max_async_workers = max_async_workers
        # One semaphore per event loop, keyed by id(loop). The loop object is stored next to its
        # semaphore so closed loops can be pruned. Holding the loop also means its id() cannot be
        # recycled while the entry exists.
        self._async_semaphores: dict[int, tuple[asyncio.AbstractEventLoop, asyncio.Semaphore]] = {}
        # Guards the cache itself: loops running in different threads may register concurrently.
        # It is held only for dict bookkeeping, never around `fn`.
        self._async_semaphores_guard = threading.Lock()

    def run(self, fn, /, *args, **kwargs):
        """Call `fn(*args, **kwargs)` while holding the thread lock and return its result."""
        with self._thread_lock:
            return fn(*args, **kwargs)

    async def arun(self, fn, /, *args, **kwargs):
        """Run `self.run(fn, ...)` in a worker thread, optionally throttled per event loop."""
        if self._max_async_workers is None:
            return await asyncio.to_thread(self.run, fn, *args, **kwargs)

        semaphore = self._semaphore_for(asyncio.get_running_loop())
        async with semaphore:
            return await asyncio.to_thread(self.run, fn, *args, **kwargs)

    def _semaphore_for(self, loop: asyncio.AbstractEventLoop) -> asyncio.Semaphore:
        """Return the semaphore for `loop`, creating it (and pruning closed loops) on first use."""
        loop_id = id(loop)
        with self._async_semaphores_guard:
            entry = self._async_semaphores.get(loop_id)
            if entry is not None and entry[0] is loop:
                return entry[1]
            # First call from this loop: forget loops that have been closed. Their semaphores may
            # reference the loop, and would otherwise keep it alive indefinitely.
            stale_ids = [key for key, (cached_loop, _) in self._async_semaphores.items() if cached_loop.is_closed()]
            for stale_id in stale_ids:
                del self._async_semaphores[stale_id]
            semaphore = asyncio.Semaphore(self._max_async_workers)
            self._async_semaphores[loop_id] = (loop, semaphore)
            return semaphore
81
+
82
+
83
class SharedInvestigationContext:
    """
    Shared context for cross-task observable/check/enrichment sharing.

    Initialize with a Cyvest instance; the canonical state is its investigation.

    Invariants:
    - The canonical state lives in `_main_investigation`.
    - All merges are atomic: merge + registry refresh happen in a single critical section.
    - Registries only contain deep-copied snapshots; callers never get live references.
    - Async APIs never block the event loop: all critical sections run in a worker thread.
    """

    def __init__(
        self,
        root_cyvest: Cyvest,
        *,
        lock: threading.RLock | None = None,
        max_async_workers: int | None = None,
    ) -> None:
        """
        Wrap `root_cyvest` so several tasks can build fragments and merge them into it.

        Args:
            root_cyvest: Cyvest instance whose investigation becomes the canonical state.
            lock: Optional lock guarding the canonical state (forwarded to `_SharedLock`).
            max_async_workers: Optional cap on concurrent async callers (forwarded to `_SharedLock`).

        Raises:
            TypeError: If `root_cyvest` is not a `Cyvest` instance.
        """
        if not isinstance(root_cyvest, Cyvest):
            raise TypeError("SharedInvestigationContext expects a Cyvest instance. Use Cyvest.shared_context().")
        self._lock = _SharedLock(lock, max_async_workers=max_async_workers)
        self._main_cyvest = root_cyvest
        self._main_investigation = root_cyvest._investigation

        # Task-local fragments are created with this root type (ARTIFACT is kept, anything else
        # becomes FILE) and with the canonical observable score mode.
        self._root_type = (
            ObservableType.ARTIFACT
            if self._main_investigation._root_observable.obs_type == ObservableType.ARTIFACT
            else ObservableType.FILE
        )
        self._score_mode_obs = self._main_investigation._score_engine._score_mode_obs

        self._observable_registry: dict[str, Observable] = {}
        self._check_registry: dict[str, Check] = {}
        self._enrichment_registry: dict[str, Enrichment] = {}

        # Initialize registries from the provided canonical investigation so lookups
        # work immediately (even before the first reconcile).
        self._lock.run(self._refresh_registries_unlocked)

    # ---------------------------------------------------------------------
    # Task creation (local fragment builder)
    # ---------------------------------------------------------------------

    def create_cyvest(
        self,
        root_data: Any | None = None,
        investigation_id: str | None = None,
        investigation_name: str | None = None,
    ):
        """
        Return a context manager for a task-local Cyvest instance.

        - `with shared.create_cyvest() as cy:` auto-reconciles on successful exit.
        - `async with shared.create_cyvest() as cy:` also works (reconciles via `areconcile()`).
        - If the body raises, the fragment is NOT reconciled and the exception propagates.

        Args:
            root_data: Task data. If None, a deep copy of the canonical root observable's `extra`
                is used; otherwise the given data is deep-copied.
            investigation_id: Optional deterministic investigation ID for the fragment.
            investigation_name: Optional human-readable name for the fragment.
        """
        return self._CyvestContextManager(
            shared_context=self,
            root_data=root_data,
            investigation_id=investigation_id,
            investigation_name=investigation_name,
        )

    def acreate_cyvest(
        self,
        root_data: Any | None = None,
        investigation_id: str | None = None,
        investigation_name: str | None = None,
    ):
        """Async-friendly alias for `create_cyvest` (supports `async with`)."""
        return self.create_cyvest(
            root_data=root_data,
            investigation_id=investigation_id,
            investigation_name=investigation_name,
        )

    class _CyvestContextManager:
        """Sync/async context manager yielding a task-local Cyvest and reconciling it on clean exit."""

        def __init__(
            self,
            *,
            shared_context: SharedInvestigationContext,
            root_data: Any | None,
            investigation_id: str | None = None,
            investigation_name: str | None = None,
        ) -> None:
            self._shared_context = shared_context
            self._root_data = root_data
            self._investigation_id = investigation_id
            self._investigation_name = investigation_name
            self._cyvest: Cyvest | None = None

        def __enter__(self):
            self._cyvest = self._shared_context._create_task_cyvest_sync(
                self._root_data, self._investigation_id, self._investigation_name
            )
            return self._cyvest

        def __exit__(self, exc_type, _exc_val, _exc_tb) -> Literal[False]:
            # Only merge fragments whose task body completed; never suppress exceptions.
            if exc_type is None and self._cyvest is not None:
                self._shared_context.reconcile(self._cyvest)
            return False

        async def __aenter__(self):
            self._cyvest = await self._shared_context._create_task_cyvest_async(
                self._root_data, self._investigation_id, self._investigation_name
            )
            return self._cyvest

        async def __aexit__(self, exc_type, _exc_val, _exc_tb) -> Literal[False]:
            if exc_type is None and self._cyvest is not None:
                await self._shared_context.areconcile(self._cyvest)
            return False

    def _create_task_cyvest_sync(
        self,
        root_data: Any | None,
        investigation_id: str | None = None,
        investigation_name: str | None = None,
    ):
        """Build a task-local Cyvest; reads the canonical root data under the thread lock if needed."""
        if root_data is None:
            root_data = self._lock.run(self._get_root_data_copy_unlocked)
        else:
            root_data = deepcopy(root_data)
        return self._build_task_cyvest(root_data, investigation_id, investigation_name)

    async def _create_task_cyvest_async(
        self,
        root_data: Any | None,
        investigation_id: str | None = None,
        investigation_name: str | None = None,
    ):
        """Async counterpart of `_create_task_cyvest_sync` (the locked read runs in a worker thread)."""
        if root_data is None:
            root_data = await self._lock.arun(self._get_root_data_copy_unlocked)
        else:
            root_data = deepcopy(root_data)
        return self._build_task_cyvest(root_data, investigation_id, investigation_name)

    def _build_task_cyvest(
        self,
        root_data: Any,
        investigation_id: str | None,
        investigation_name: str | None,
    ) -> Cyvest:
        """Create a task-local Cyvest using the canonical root type and observable score mode."""
        return Cyvest(
            root_data,
            root_type=self._root_type,
            score_mode_obs=self._score_mode_obs,
            investigation_id=investigation_id,
            investigation_name=investigation_name,
        )

    def _get_root_data_copy_unlocked(self) -> Any:
        """Deep copy of the canonical root observable's `extra` (caller must hold the lock)."""
        return deepcopy(self._main_investigation._root_observable.extra)

    # ---------------------------------------------------------------------
    # Reconciliation (atomic merge into canonical)
    # ---------------------------------------------------------------------

    def reconcile(self, source: Cyvest | Investigation) -> None:
        """Merge `source` into the canonical investigation and refresh registries atomically."""
        task_investigation = self._extract_investigation(source)
        self._lock.run(self._reconcile_unlocked, task_investigation)

    async def areconcile(self, source: Cyvest | Investigation) -> None:
        """Async counterpart of `reconcile` (the merge runs in a worker thread)."""
        task_investigation = self._extract_investigation(source)
        await self._lock.arun(self._reconcile_unlocked, task_investigation)

    def _extract_investigation(self, source: Cyvest | Investigation) -> Investigation:
        """Return the underlying investigation of a Cyvest, or `source` itself otherwise."""
        if isinstance(source, Cyvest):
            return source._investigation
        return source

    def _reconcile_unlocked(self, task_investigation: Investigation) -> None:
        """Merge a task investigation and rebuild the snapshot registries (caller must hold the lock)."""
        logger.debug("Reconciling task investigation into shared context")
        self._main_investigation.merge_investigation(task_investigation)
        self._refresh_registries_unlocked()
        # `logger` comes from `logurich`. If it is loguru-based (TODO confirm), `%d` placeholders
        # are never interpolated. A pre-formatted message renders correctly with any backend.
        logger.debug(
            f"Reconciliation complete. Registry: {len(self._observable_registry)} observables, "
            f"{len(self._check_registry)} checks, {len(self._enrichment_registry)} enrichments"
        )

    def _refresh_registries_unlocked(self) -> None:
        """Rebuild all registries from deep copies of the canonical state (caller must hold the lock)."""
        observable_registry: dict[str, Observable] = {
            obs.key: self._snapshot_observable(obs)
            for obs in self._main_investigation.get_all_observables().values()
        }
        check_registry = {
            check.key: check.model_copy(deep=True) for check in self._main_investigation.get_all_checks().values()
        }
        enrichment_registry = {
            enrichment.key: enrichment.model_copy(deep=True)
            for enrichment in self._main_investigation.get_all_enrichments().values()
        }
        # Swap in complete dicts so a registry is never observed half-built.
        self._observable_registry = observable_registry
        self._check_registry = check_registry
        self._enrichment_registry = enrichment_registry

    @staticmethod
    def _snapshot_observable(obs: Observable) -> Observable:
        """Deep-copy `obs` and mark the copy with `_from_shared_context = True`."""
        snapshot = obs.model_copy(deep=True)
        snapshot._from_shared_context = True
        return snapshot

    # ---------------------------------------------------------------------
    # Lookups (deep-copied snapshots only)
    # ---------------------------------------------------------------------

    def observable_get(self, obs_type: ObservableType, value: str) -> Observable | None:
        """Return a snapshot of the observable `(obs_type, value)`, or None if unknown."""
        key = self._observable_key(obs_type, value)
        return self._lock.run(self._get_observable_by_key_unlocked, key)

    async def observable_aget(self, obs_type: ObservableType, value: str) -> Observable | None:
        """Async counterpart of `observable_get`."""
        key = self._observable_key(obs_type, value)
        return await self._lock.arun(self._get_observable_by_key_unlocked, key)

    def _get_observable_by_key_unlocked(self, key: str) -> Observable | None:
        obs = self._observable_registry.get(key)
        if obs is None:
            return None
        return self._snapshot_observable(obs)

    def _observable_key(self, obs_type: ObservableType, value: str) -> str:
        """Compute the registry key for an observable.

        Raises:
            ValueError: If key generation fails for any reason (the original error is chained).
        """
        try:
            return keys.generate_observable_key(obs_type.value, value)
        except Exception as e:
            raise ValueError(f"Failed to generate observable key for type='{obs_type}', value='{value}': {e}") from e

    def check_get(self, check_name: str) -> Check | None:
        """Return a snapshot of the check named `check_name`, or None if unknown."""
        key = self._check_key(check_name)
        return self._lock.run(self._get_check_by_key_unlocked, key)

    async def check_aget(self, check_name: str) -> Check | None:
        """Async counterpart of `check_get`."""
        key = self._check_key(check_name)
        return await self._lock.arun(self._get_check_by_key_unlocked, key)

    def _get_check_by_key_unlocked(self, key: str) -> Check | None:
        check = self._check_registry.get(key)
        # Test for absence explicitly: a registered model must be returned even if it is falsy.
        return check.model_copy(deep=True) if check is not None else None

    def _check_key(self, check_name: str) -> str:
        """Compute the registry key for a check.

        Raises:
            ValueError: If key generation fails for any reason (the original error is chained).
        """
        try:
            return keys.generate_check_key(check_name)
        except Exception as e:
            raise ValueError(f"Failed to generate check key for check_name='{check_name}': {e}") from e

    def enrichment_get(self, name: str, context: str = "") -> Enrichment | None:
        """Return a snapshot of the enrichment `(name, context)`, or None if unknown."""
        key = self._enrichment_key(name, context)
        return self._lock.run(self._get_enrichment_by_key_unlocked, key)

    async def enrichment_aget(self, name: str, context: str = "") -> Enrichment | None:
        """Async counterpart of `enrichment_get`."""
        key = self._enrichment_key(name, context)
        return await self._lock.arun(self._get_enrichment_by_key_unlocked, key)

    def _get_enrichment_by_key_unlocked(self, key: str) -> Enrichment | None:
        enrichment = self._enrichment_registry.get(key)
        return enrichment.model_copy(deep=True) if enrichment is not None else None

    def _enrichment_key(self, name: str, context: str = "") -> str:
        """Compute the registry key for an enrichment.

        Raises:
            ValueError: If key generation fails for any reason (the original error is chained).
        """
        try:
            return keys.generate_enrichment_key(name, context)
        except Exception as e:
            raise ValueError(f"Failed to generate enrichment key for name='{name}', context='{context}': {e}") from e

    # ---------------------------------------------------------------------
    # Lightweight state reads
    # ---------------------------------------------------------------------

    def get_global_score(self) -> Decimal:
        return self._lock.run(self._main_investigation.get_global_score)

    async def aget_global_score(self) -> Decimal:
        return await self._lock.arun(self._main_investigation.get_global_score)

    def is_whitelisted(self) -> bool:
        return self._lock.run(self._main_investigation.is_whitelisted)

    async def ais_whitelisted(self) -> bool:
        return await self._lock.arun(self._main_investigation.is_whitelisted)

    def get_global_level(self) -> Level:
        return self._lock.run(self._main_investigation.get_global_level)

    async def aget_global_level(self) -> Level:
        return await self._lock.arun(self._main_investigation.get_global_level)

    def observables_list_by_type(self, obs_type: ObservableType) -> list[Observable]:
        """Return snapshots of every registered observable whose `obs_type` equals `obs_type`."""
        return self._lock.run(self._observables_list_by_type_unlocked, obs_type)

    async def observables_alist_by_type(self, obs_type: ObservableType) -> list[Observable]:
        """Async counterpart of `observables_list_by_type`."""
        return await self._lock.arun(self._observables_list_by_type_unlocked, obs_type)

    def _observables_list_by_type_unlocked(self, obs_type: ObservableType) -> list[Observable]:
        return [
            self._snapshot_observable(obs) for obs in self._observable_registry.values() if obs.obs_type == obs_type
        ]

    # Intentionally minimal: prefer `observable_get()` / `check_get()` and user-side filtering.

    # ---------------------------------------------------------------------
    # Serialization helpers (sync + async wrappers)
    # ---------------------------------------------------------------------

    def io_to_markdown(
        self,
        include_tags: bool = False,
        include_enrichments: bool = False,
        include_observables: bool = True,
        exclude_levels: set[Level] | None = None,
    ) -> str:
        """Render the canonical investigation as a Markdown report."""
        return self._lock.run(
            self._io_to_markdown_unlocked,
            include_tags,
            include_enrichments,
            include_observables,
            exclude_levels,
        )

    async def aio_to_markdown(
        self,
        include_tags: bool = False,
        include_enrichments: bool = False,
        include_observables: bool = True,
        exclude_levels: set[Level] | None = None,
    ) -> str:
        """Async counterpart of `io_to_markdown`."""
        return await self._lock.arun(
            self._io_to_markdown_unlocked,
            include_tags,
            include_enrichments,
            include_observables,
            exclude_levels,
        )

    def _io_to_markdown_unlocked(
        self,
        include_tags: bool,
        include_enrichments: bool,
        include_observables: bool,
        exclude_levels: set[Level] | None,
    ) -> str:
        return generate_markdown_report(
            self._main_investigation,
            include_tags,
            include_enrichments,
            include_observables,
            exclude_levels,
        )

    def io_save_markdown(
        self,
        filepath: str | Path,
        include_tags: bool = False,
        include_enrichments: bool = False,
        include_observables: bool = True,
        exclude_levels: set[Level] | None = None,
    ) -> str:
        """Write the Markdown report to `filepath` and return its resolved path as a string."""
        return self._lock.run(
            self._io_save_markdown_unlocked,
            filepath,
            include_tags,
            include_enrichments,
            include_observables,
            exclude_levels,
        )

    async def aio_save_markdown(
        self,
        filepath: str | Path,
        include_tags: bool = False,
        include_enrichments: bool = False,
        include_observables: bool = True,
        exclude_levels: set[Level] | None = None,
    ) -> str:
        """Async counterpart of `io_save_markdown`."""
        return await self._lock.arun(
            self._io_save_markdown_unlocked,
            filepath,
            include_tags,
            include_enrichments,
            include_observables,
            exclude_levels,
        )

    def _io_save_markdown_unlocked(
        self,
        filepath: str | Path,
        include_tags: bool,
        include_enrichments: bool,
        include_observables: bool,
        exclude_levels: set[Level] | None,
    ) -> str:
        save_investigation_markdown(
            self._main_investigation,
            filepath,
            include_tags,
            include_enrichments,
            include_observables,
            exclude_levels,
        )
        return str(Path(filepath).resolve())

    def io_to_invest(self, *, include_audit_log: bool = True) -> InvestigationSchema:
        """Serialize the canonical investigation to its schema object."""
        return self._lock.run(self._io_to_invest_unlocked, include_audit_log)

    async def aio_to_invest(self, *, include_audit_log: bool = True) -> InvestigationSchema:
        """Async counterpart of `io_to_invest`."""
        return await self._lock.arun(self._io_to_invest_unlocked, include_audit_log)

    def _io_to_invest_unlocked(self, include_audit_log: bool = True) -> InvestigationSchema:
        return serialize_investigation(self._main_investigation, include_audit_log=include_audit_log)

    def io_save_json(self, filepath: str | Path, *, include_audit_log: bool = True) -> str:
        """Write the investigation as JSON to `filepath` and return its resolved path as a string."""
        return self._lock.run(self._io_save_json_unlocked, filepath, include_audit_log)

    async def aio_save_json(self, filepath: str | Path, *, include_audit_log: bool = True) -> str:
        """Async counterpart of `io_save_json`."""
        return await self._lock.arun(self._io_save_json_unlocked, filepath, include_audit_log)

    def _io_save_json_unlocked(self, filepath: str | Path, include_audit_log: bool = True) -> str:
        save_investigation_json(self._main_investigation, filepath, include_audit_log=include_audit_log)
        return str(Path(filepath).resolve())