contentctl-4.3.1-py3-none-any.whl → contentctl-4.3.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. contentctl/actions/detection_testing/infrastructures/DetectionTestingInfrastructure.py +35 -21
  2. contentctl/actions/detection_testing/views/DetectionTestingView.py +64 -38
  3. contentctl/actions/detection_testing/views/DetectionTestingViewCLI.py +1 -0
  4. contentctl/actions/detection_testing/views/DetectionTestingViewFile.py +3 -5
  5. contentctl/actions/test.py +55 -32
  6. contentctl/contentctl.py +3 -6
  7. contentctl/enrichments/attack_enrichment.py +24 -11
  8. contentctl/objects/abstract_security_content_objects/detection_abstract.py +180 -88
  9. contentctl/objects/abstract_security_content_objects/security_content_object_abstract.py +1 -0
  10. contentctl/objects/base_test.py +1 -0
  11. contentctl/objects/base_test_result.py +1 -0
  12. contentctl/objects/config.py +24 -9
  13. contentctl/objects/detection_tags.py +3 -0
  14. contentctl/objects/integration_test.py +3 -5
  15. contentctl/objects/integration_test_result.py +1 -5
  16. contentctl/objects/investigation.py +1 -0
  17. contentctl/objects/manual_test.py +32 -0
  18. contentctl/objects/manual_test_result.py +8 -0
  19. contentctl/objects/mitre_attack_enrichment.py +67 -3
  20. contentctl/objects/ssa_detection.py +1 -0
  21. contentctl/objects/story_tags.py +2 -0
  22. contentctl/objects/{unit_test_attack_data.py → test_attack_data.py} +4 -5
  23. contentctl/objects/test_group.py +3 -3
  24. contentctl/objects/unit_test.py +4 -11
  25. contentctl/output/templates/savedsearches_detections.j2 +1 -1
  26. {contentctl-4.3.1.dist-info → contentctl-4.3.3.dist-info}/METADATA +14 -14
  27. {contentctl-4.3.1.dist-info → contentctl-4.3.3.dist-info}/RECORD +30 -28
  28. {contentctl-4.3.1.dist-info → contentctl-4.3.3.dist-info}/LICENSE.md +0 -0
  29. {contentctl-4.3.1.dist-info → contentctl-4.3.3.dist-info}/WHEEL +0 -0
  30. {contentctl-4.3.1.dist-info → contentctl-4.3.3.dist-info}/entry_points.txt +0 -0
contentctl/actions/detection_testing/infrastructures/DetectionTestingInfrastructure.py CHANGED
@@ -10,7 +10,6 @@ import pathlib
 from tempfile import TemporaryDirectory, mktemp
 from ssl import SSLEOFError, SSLZeroReturnError
 from sys import stdout
-#from dataclasses import dataclass
 from shutil import copyfile
 from typing import Union, Optional
 
@@ -29,7 +28,7 @@ from contentctl.objects.detection import Detection
 from contentctl.objects.base_test import BaseTest
 from contentctl.objects.unit_test import UnitTest
 from contentctl.objects.integration_test import IntegrationTest
-from contentctl.objects.unit_test_attack_data import UnitTestAttackData
+from contentctl.objects.test_attack_data import TestAttackData
 from contentctl.objects.unit_test_result import UnitTestResult
 from contentctl.objects.integration_test_result import IntegrationTestResult
 from contentctl.objects.test_group import TestGroup
@@ -61,13 +60,19 @@ class CleanupTestGroupResults(BaseModel):
 
 class ContainerStoppedException(Exception):
     pass
+class CannotRunBaselineException(Exception):
+    # Support for testing detections with baselines
+    # does not currently exist in contentctl.
+    # As such, whenever we encounter a detection
+    # with baselines we should generate a descriptive
+    # exception
+    pass
 
 
 @dataclasses.dataclass(frozen=False)
 class DetectionTestingManagerOutputDto():
     inputQueue: list[Detection] = Field(default_factory=list)
     outputQueue: list[Detection] = Field(default_factory=list)
-    skippedQueue: list[Detection] = Field(default_factory=list)
     currentTestingQueue: dict[str, Union[Detection, None]] = Field(default_factory=dict)
     start_time: Union[datetime.datetime, None] = None
     replay_index: str = "CONTENTCTL_TESTING_INDEX"
@@ -647,11 +652,7 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
         # Set the mode and timeframe, if required
         kwargs = {"exec_mode": "blocking"}
 
-        # Iterate over baselines (if any)
-        for baseline in test.baselines:
-            # TODO: this is executing the test, not the baseline...
-            # TODO: should this be in a try/except if the later call is?
-            self.retry_search_until_timeout(detection, test, kwargs, test_start_time)
+
 
         # Set earliest_time and latest_time appropriately if FORCE_ALL_TIME is False
         if not FORCE_ALL_TIME:
@@ -662,7 +663,23 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
 
         # Run the detection's search query
         try:
+            # Iterate over baselines (if any)
+            for baseline in detection.baselines:
+                raise CannotRunBaselineException("Detection requires Execution of a Baseline, "
+                                                 "however Baseline execution is not "
+                                                 "currently supported in contentctl. Mark "
+                                                 "this as manual_test.")
             self.retry_search_until_timeout(detection, test, kwargs, test_start_time)
+        except CannotRunBaselineException as e:
+            # Init the test result and record a failure if there was an issue during the search
+            test.result = UnitTestResult()
+            test.result.set_job_content(
+                None,
+                self.infrastructure,
+                TestResultStatus.ERROR,
+                exception=e,
+                duration=time.time() - test_start_time
+            )
         except ContainerStoppedException as e:
             raise e
         except Exception as e:
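Note on the two hunks above: in 4.3.1 the baseline loop ran before the try block and (per its own TODO) re-executed the detection search rather than the baseline; in 4.3.3 any baseline raises inside the try block and is converted into an ERROR result, so the rest of the run continues. A minimal sketch of the same pattern, using stand-in types rather than contentctl's real models:

import time


class CannotRunBaselineException(Exception):
    # Stand-in for the exception class introduced in the hunk above
    pass


def run_detection_test(baselines: list[str], run_search) -> dict:
    start = time.time()
    try:
        # Any baseline at all aborts the test before the search ever runs
        for _baseline in baselines:
            raise CannotRunBaselineException(
                "Baseline execution is not currently supported; mark this as manual_test."
            )
        run_search()
        return {"status": "PASS", "duration": time.time() - start}
    except CannotRunBaselineException as e:
        # Recorded as a result rather than propagated, so later detections still run
        return {"status": "ERROR", "exception": str(e), "duration": time.time() - start}


print(run_detection_test(["baseline.yml"], run_search=lambda: None))  # status: ERROR
print(run_detection_test([], run_search=lambda: None))                # status: PASS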
@@ -1015,18 +1032,15 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
         """
         # Get the start time and compute the timeout
         search_start_time = time.time()
-        search_stop_time = time.time() + self.sync_obj.timeout_seconds
-
-        # We will default to ensuring at least one result exists
-        if test.pass_condition is None:
-            search = detection.search
-        else:
-            # Else, use the explicit pass condition
-            search = f"{detection.search} {test.pass_condition}"
+        search_stop_time = time.time() + self.sync_obj.timeout_seconds
+
+        # Make a copy of the search string since we may
+        # need to make some small changes to it below
+        search = detection.search
 
         # Ensure searches that do not begin with '|' must begin with 'search '
-        if not search.strip().startswith("|"):  # type: ignore
-            if not search.strip().startswith("search "):  # type: ignore
+        if not search.strip().startswith("|"):
+            if not search.strip().startswith("search "):
                 search = f"search {search}"
 
         # exponential backoff for wait time
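The hunk above also drops per-test pass_condition handling: the query now always comes straight from detection.search. The '|' / 'search ' prefix rule is unchanged; restated as a standalone helper (the function name is ours, not contentctl's):

def normalize_spl(search: str) -> str:
    # SPL that does not start with a generating command ('| ...') must be
    # prefixed with the explicit 'search' command before being dispatched
    if not search.strip().startswith("|"):
        if not search.strip().startswith("search "):
            search = f"search {search}"
    return search


assert normalize_spl("index=main foo") == "search index=main foo"
assert normalize_spl("search index=main foo") == "search index=main foo"
assert normalize_spl("| tstats count from datamodel=Endpoint").startswith("| tstats")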
@@ -1179,7 +1193,7 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
 
         return
 
-    def delete_attack_data(self, attack_data_files: list[UnitTestAttackData]):
+    def delete_attack_data(self, attack_data_files: list[TestAttackData]):
         for attack_data_file in attack_data_files:
             index = attack_data_file.custom_index or self.sync_obj.replay_index
             host = attack_data_file.host or self.sync_obj.replay_host
@@ -1212,7 +1226,7 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
 
     def replay_attack_data_file(
         self,
-        attack_data_file: UnitTestAttackData,
+        attack_data_file: TestAttackData,
         tmp_dir: str,
         test_group: TestGroup,
         test_group_start_time: float,
@@ -1280,7 +1294,7 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
     def hec_raw_replay(
         self,
         tempfile: str,
-        attack_data_file: UnitTestAttackData,
+        attack_data_file: TestAttackData,
         verify_ssl: bool = False,
     ):
         if verify_ssl is False:
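The three hunks above are the call-site half of the unit_test_attack_data.py → test_attack_data.py rename (file 22 in the list), reflecting that attack data is no longer tied to unit tests alone. Only custom_index and host are confirmed by these call sites; a hypothetical sketch of the renamed model, for orientation only:

from typing import Optional

from pydantic import BaseModel


class TestAttackData(BaseModel):
    # Hypothetical field set: data/source/sourcetype are assumed from typical
    # replay inputs; custom_index and host appear in delete_attack_data above.
    data: str
    source: str
    sourcetype: str
    custom_index: Optional[str] = None
    host: Optional[str] = None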
contentctl/actions/detection_testing/views/DetectionTestingView.py CHANGED
@@ -1,5 +1,6 @@
 import abc
 import datetime
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -10,6 +11,7 @@ from contentctl.actions.detection_testing.infrastructures.DetectionTestingInfras
 )
 from contentctl.helper.utils import Utils
 from contentctl.objects.enums import DetectionStatus
+from contentctl.objects.base_test_result import TestResultStatus
 
 
 class DetectionTestingView(BaseModel, abc.ABC):
@@ -74,18 +76,23 @@ class DetectionTestingView(BaseModel, abc.ABC):
         self,
         test_result_fields: list[str] = ["success", "message", "exception", "status", "duration", "wait_duration"],
         test_job_fields: list[str] = ["resultCount", "runDuration"],
-    ) -> dict:
+    ) -> dict[str, dict[str, Any] | list[dict[str, Any]] | str]:
         """
         Iterates over detections, consolidating results into a single dict and aggregating metrics
         :param test_result_fields: fields to pull from the test result
         :param test_job_fields: fields to pull from the job content of the test result
         :returns: summary dict
         """
-        # Init the list of tested detections, and some metrics aggregate counters
-        tested_detections = []
+        # Init the list of tested and skipped detections, and some metrics aggregate counters
+        tested_detections: list[dict[str, Any]] = []
+        skipped_detections: list[dict[str, Any]] = []
         total_pass = 0
         total_fail = 0
         total_skipped = 0
+        total_production = 0
+        total_experimental = 0
+        total_deprecated = 0
+        total_manual = 0
 
         # Iterate the detections tested (anything in the output queue was tested)
         for detection in self.sync_obj.outputQueue:
@@ -95,46 +102,59 @@ class DetectionTestingView(BaseModel, abc.ABC):
             )
 
             # Aggregate detection pass/fail metrics
-            if summary["success"] is False:
+            if detection.test_status == TestResultStatus.FAIL:
                 total_fail += 1
+            elif detection.test_status == TestResultStatus.PASS:
+                total_pass += 1
+            elif detection.test_status == TestResultStatus.SKIP:
+                total_skipped += 1
+
+            # Aggregate production status metrics
+            if detection.status == DetectionStatus.production.value:  # type: ignore
+                total_production += 1
+            elif detection.status == DetectionStatus.experimental.value:  # type: ignore
+                total_experimental += 1
+            elif detection.status == DetectionStatus.deprecated.value:  # type: ignore
+                total_deprecated += 1
+
+            # Check if the detection is manual_test
+            if detection.tags.manual_test is not None:
+                total_manual += 1
+
+            # Append to our list (skipped or tested)
+            if detection.test_status == TestResultStatus.SKIP:
+                skipped_detections.append(summary)
             else:
-                #Test is marked as a success, but we need to determine if there were skipped unit tests
-                #SKIPPED tests still show a success in this field, but we want to count them differently
-                pass_increment = 1
-                for test in summary.get("tests"):
-                    if test.get("test_type") == "unit" and test.get("status") == "skip":
-                        total_skipped += 1
-                        #Test should not count as a pass, so do not increment the count
-                        pass_increment = 0
-                        break
-                total_pass += pass_increment
-
-
-            # Append to our list
-            tested_detections.append(summary)
-
-        # Sort s.t. all failures appear first (then by name)
-        #Second short condition is a hack to get detections with unit skipped tests to appear above pass tests
-        tested_detections.sort(key=lambda x: (x["success"], 0 if x.get("tests",[{}])[0].get("status","status_missing")=="skip" else 1, x["name"]))
+                tested_detections.append(summary)
 
+        # Sort tested detections s.t. all failures appear first, then by name
+        tested_detections.sort(
+            key=lambda x: (
+                x["success"],
+                x["name"]
+            )
+        )
+
+        # Sort skipped detections s.t. detections w/ tests appear before those w/o, then by name
+        skipped_detections.sort(
+            key=lambda x: (
+                0 if len(x["tests"]) > 0 else 1,
+                x["name"]
+            )
+        )
+
+        # TODO (#267): Align test reporting more closely w/ status enums (as it relates to
+        # "untested")
         # Aggregate summaries for the untested detections (anything still in the input queue was untested)
         total_untested = len(self.sync_obj.inputQueue)
-        untested_detections = []
+        untested_detections: list[dict[str, Any]] = []
         for detection in self.sync_obj.inputQueue:
             untested_detections.append(detection.get_summary())
 
         # Sort by detection name
         untested_detections.sort(key=lambda x: x["name"])
 
-        # Get lists of detections (name only) that were skipped due to their status (experimental or deprecated)
-        experimental_detections = sorted([
-            detection.name for detection in self.sync_obj.skippedQueue if detection.status == DetectionStatus.experimental.value
-        ])
-        deprecated_detections = sorted([
-            detection.name for detection in self.sync_obj.skippedQueue if detection.status == DetectionStatus.deprecated.value
-        ])
-
-        # If any detection failed, the overall success is False
+        # If any detection failed, or if there are untested detections, the overall success is False
         if (total_fail + len(untested_detections)) == 0:
             overall_success = True
         else:
@@ -143,33 +163,39 @@ class DetectionTestingView(BaseModel, abc.ABC):
         # Compute total detections
         total_detections = total_fail + total_pass + total_untested + total_skipped
 
+        # Compute total detections actually tested (at least one test not skipped)
+        total_tested_detections = total_fail + total_pass
 
         # Compute the percentage of completion for testing, as well as the success rate
         percent_complete = Utils.getPercent(
             len(tested_detections), len(untested_detections), 1
         )
         success_rate = Utils.getPercent(
-            total_pass, total_detections-total_skipped, 1
+            total_pass, total_tested_detections, 1
        )
 
-        # TODO (#230): expand testing metrics reported
+        # TODO (#230): expand testing metrics reported (and make nested)
         # Construct and return the larger results dict
         result_dict = {
             "summary": {
+                "mode": self.config.getModeName(),
+                "enable_integration_testing": self.config.enable_integration_testing,
                 "success": overall_success,
                 "total_detections": total_detections,
+                "total_tested_detections": total_tested_detections,
                 "total_pass": total_pass,
                 "total_fail": total_fail,
                 "total_skipped": total_skipped,
                 "total_untested": total_untested,
-                "total_experimental_or_deprecated": len(deprecated_detections+experimental_detections),
+                "total_production": total_production,
+                "total_experimental": total_experimental,
+                "total_deprecated": total_deprecated,
+                "total_manual": total_manual,
                 "success_rate": success_rate,
             },
             "tested_detections": tested_detections,
+            "skipped_detections": skipped_detections,
             "untested_detections": untested_detections,
             "percent_complete": percent_complete,
-            "deprecated_detections": deprecated_detections,
-            "experimental_detections": experimental_detections
-
         }
         return result_dict
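Taken together, the getSummaryObject() hunks flatten and extend the report: skips are no longer counted as passes, the per-status name lists are replaced by counters, and a skipped_detections list joins the output. An illustrative shape of the resulting dict (keys taken from result_dict above; values invented, and the exact success_rate/percent_complete types depend on Utils.getPercent):

summary_object = {
    "summary": {
        "mode": "changes",
        "enable_integration_testing": False,
        "success": False,
        "total_detections": 12,
        "total_tested_detections": 9,
        "total_pass": 8,
        "total_fail": 1,
        "total_skipped": 3,
        "total_untested": 0,
        "total_production": 10,
        "total_experimental": 1,
        "total_deprecated": 1,
        "total_manual": 2,
        "success_rate": 88.9,
    },
    "tested_detections": [],   # per-detection summaries, failures sorted first, then by name
    "skipped_detections": [],  # new in 4.3.3
    "untested_detections": [],
    "percent_complete": 100.0,
}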
contentctl/actions/detection_testing/views/DetectionTestingViewCLI.py CHANGED
@@ -45,6 +45,7 @@ class DetectionTestingViewCLI(DetectionTestingView, arbitrary_types_allowed=True
 
         self.showStatus()
 
+    # TODO (#267): Align test reporting more closely w/ status enums (as it relates to "untested")
     def showStatus(self, interval: int = 1):
 
         while True:
contentctl/actions/detection_testing/views/DetectionTestingViewFile.py CHANGED
@@ -13,7 +13,6 @@ class DetectionTestingViewFile(DetectionTestingView):
     output_filename: str = OUTPUT_FILENAME
 
     def getOutputFilePath(self) -> pathlib.Path:
-
         folder_path = pathlib.Path('.') / self.output_folder
         output_file = folder_path / self.output_filename
 
@@ -27,13 +26,12 @@ class DetectionTestingViewFile(DetectionTestingView):
         output_file = self.getOutputFilePath()
 
         folder_path.mkdir(parents=True, exist_ok=True)
-
-
+
         result_dict = self.getSummaryObject()
-
+
         # use the yaml writer class
         with open(output_file, "w") as res:
-            res.write(yaml.safe_dump(result_dict,sort_keys=False))
+            res.write(yaml.safe_dump(result_dict, sort_keys=False))
 
     def showStatus(self, interval: int = 60):
         pass
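Both sides of the hunk pass sort_keys=False (only spacing changed), which is what keeps summary.yml in insertion order with the summary block at the top. A quick illustration of the difference with PyYAML:

import yaml

data = {"summary": {"success": True}, "tested_detections": [], "untested_detections": []}

# sort_keys=False preserves insertion order: summary, tested, untested
print(yaml.safe_dump(data, sort_keys=False))

# The default (sort_keys=True) would alphabetize the top-level keys instead
print(yaml.safe_dump(data))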
contentctl/actions/test.py CHANGED
@@ -44,35 +44,25 @@ class TestInputDto:
 
 
 class Test:
+    def filter_tests(self, input_dto: TestInputDto) -> None:
+        """
+        If integration testing has NOT been enabled, then skip
+        all of the integration tests. Otherwise, do nothing
+
+        Args:
+            input_dto (TestInputDto): A configuration of the test and all of the
+                tests to be run.
+        """
 
-    def filter_detections(self, input_dto: TestInputDto)->TestInputDto:
-
         if not input_dto.config.enable_integration_testing:
-            #Skip all integraiton tests if integration testing is not enabled:
+            # Skip all integraiton tests if integration testing is not enabled:
             for detection in input_dto.detections:
                 for test in detection.tests:
                     if isinstance(test, IntegrationTest):
                         test.skip("TEST SKIPPED: Skipping all integration tests")
-
-        list_after_filtering:List[Detection] = []
-        #extra filtering which may be removed/modified in the future
-        for detection in input_dto.detections:
-            if (detection.status != DetectionStatus.production.value):
-                #print(f"{detection.name} - Not testing because [STATUS: {detection.status}]")
-                pass
-            elif detection.type == AnalyticsType.Correlation:
-                #print(f"{detection.name} - Not testing because [ TYPE: {detection.type}]")
-                pass
-            else:
-                list_after_filtering.append(detection)
-
-        return TestInputDto(list_after_filtering, input_dto.config)
-
-
-    def execute(self, input_dto: TestInputDto) -> bool:
 
 
-
+    def execute(self, input_dto: TestInputDto) -> bool:
         output_dto = DetectionTestingManagerOutputDto()
 
         web = DetectionTestingViewWeb(config=input_dto.config, sync_obj=output_dto)
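Design note on the hunk above: filter_detections returned a new, smaller TestInputDto, silently dropping non-production and Correlation detections, whereas filter_tests mutates the DTO in place and only marks IntegrationTest objects as skipped, so skipped detections now flow through to the report instead of vanishing. A condensed restatement of the in-place pattern with stand-in classes:

class IntegrationTest:
    def __init__(self) -> None:
        self.skip_message: str | None = None

    def skip(self, message: str) -> None:
        self.skip_message = message


class UnitTest:
    pass


def filter_tests(tests: list, enable_integration_testing: bool) -> None:
    # Nothing is removed from the list; integration tests are merely marked
    if not enable_integration_testing:
        for test in tests:
            if isinstance(test, IntegrationTest):
                test.skip("TEST SKIPPED: Skipping all integration tests")


tests = [UnitTest(), IntegrationTest()]
filter_tests(tests, enable_integration_testing=False)
print(tests[1].skip_message)  # TEST SKIPPED: Skipping all integration tests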
@@ -87,26 +77,33 @@ class Test:
         manager = DetectionTestingManager(
             input_dto=manager_input_dto, output_dto=output_dto
         )
-
+
+        mode = input_dto.config.getModeName()
         if len(input_dto.detections) == 0:
-            print(f"With Detection Testing Mode '{input_dto.config.getModeName()}', there were [0] detections found to test.\nAs such, we will quit immediately.")
-            # Directly call stop so that the summary.yml will be generated. Of course it will not have any test results, but we still want it to contain
-            # a summary showing that now detections were tested.
+            print(
+                f"With Detection Testing Mode '{mode}', there were [0] detections found to test."
+                "\nAs such, we will quit immediately."
+            )
+            # Directly call stop so that the summary.yml will be generated. Of course it will not
+            # have any test results, but we still want it to contain a summary showing that now
+            # detections were tested.
             file.stop()
         else:
-            print(f"MODE: [{input_dto.config.getModeName()}] - Test [{len(input_dto.detections)}] detections")
-            if input_dto.config.mode in [DetectionTestingMode.changes, DetectionTestingMode.selected]:
-                files_string = '\n- '.join([str(pathlib.Path(detection.file_path).relative_to(input_dto.config.path)) for detection in input_dto.detections])
+            print(f"MODE: [{mode}] - Test [{len(input_dto.detections)}] detections")
+            if mode in [DetectionTestingMode.changes.value, DetectionTestingMode.selected.value]:
+                files_string = '\n- '.join(
+                    [str(pathlib.Path(detection.file_path).relative_to(input_dto.config.path)) for detection in input_dto.detections]
+                )
                 print(f"Detections:\n- {files_string}")
 
             manager.setup()
             manager.execute()
-
+
         try:
             summary_results = file.getSummaryObject()
             summary = summary_results.get("summary", {})
 
-            print("Test Summary")
+            print(f"Test Summary (mode: {summary.get('mode','Error')})")
             print(f"\tSuccess : {summary.get('success',False)}")
             print(
                 f"\tSuccess Rate : {summary.get('success_rate','ERROR')}"
@@ -115,15 +112,41 @@ class Test:
                 f"\tTotal Detections : {summary.get('total_detections','ERROR')}"
             )
             print(
-                f"\tPassed Detections : {summary.get('total_pass','ERROR')}"
+                f"\tTotal Tested Detections : {summary.get('total_tested_detections','ERROR')}"
             )
             print(
-                f"\tFailed Detections : {summary.get('total_fail','ERROR')}"
+                f"\t  Passed Detections : {summary.get('total_pass','ERROR')}"
+            )
+            print(
+                f"\t  Failed Detections : {summary.get('total_fail','ERROR')}"
+            )
+            print(
+                f"\tSkipped Detections : {summary.get('total_skipped','ERROR')}"
+            )
+            print(
+                "\tProduction Status :"
+            )
+            print(
+                f"\t  Production Detections : {summary.get('total_production','ERROR')}"
+            )
+            print(
+                f"\t  Experimental Detections : {summary.get('total_experimental','ERROR')}"
+            )
+            print(
+                f"\t  Deprecated Detections : {summary.get('total_deprecated','ERROR')}"
+            )
+            print(
+                f"\tManually Tested Detections : {summary.get('total_manual','ERROR')}"
             )
             print(
                 f"\tUntested Detections : {summary.get('total_untested','ERROR')}"
             )
             print(f"\tTest Results File : {file.getOutputFilePath()}")
+            print(
+                "\nNOTE: skipped detections include non-production, manually tested, and certain\n"
+                "detection types (e.g. Correlation), but there may be overlap between these\n"
+                "categories."
+            )
             return summary_results.get("summary", {}).get("success", False)
 
         except Exception as e:
contentctl/contentctl.py CHANGED
@@ -113,17 +113,14 @@ def test_common_func(config:test_common):
     test_input_dto = TestInputDto(detections_to_test, config)
 
     t = Test()
-
-    # Remove detections that we do not want to test because they are
-    # not production, the correct type, or manual_test only
-    filted_test_input_dto = t.filter_detections(test_input_dto)
+    t.filter_tests(test_input_dto)
 
     if config.plan_only:
         #Emit the test plan and quit. Do not actually run the test
-        config.dumpCICDPlanAndQuit(gitServer.getHash(),filted_test_input_dto.detections)
+        config.dumpCICDPlanAndQuit(gitServer.getHash(),test_input_dto.detections)
         return
 
-    success = t.execute(filted_test_input_dto)
+    success = t.execute(test_input_dto)
 
     if success:
         #Everything passed!
contentctl/enrichments/attack_enrichment.py CHANGED
@@ -7,7 +7,7 @@ from attackcti import attack_client
 import logging
 from pydantic import BaseModel, Field
 from dataclasses import field
-from typing import Annotated
+from typing import Annotated,Any
 from contentctl.objects.mitre_attack_enrichment import MitreAttackEnrichment
 from contentctl.objects.config import validate
 logging.getLogger('taxii2client').setLevel(logging.CRITICAL)
@@ -33,21 +33,33 @@ class AttackEnrichment(BaseModel):
         else:
             raise Exception(f"Error, Unable to find Mitre Enrichment for MitreID {mitre_id}")
 
-
-    def addMitreID(self, technique:dict, tactics:list[str], groups:list[str])->None:
-
+    def addMitreIDViaGroupNames(self, technique:dict, tactics:list[str], groupNames:list[str])->None:
         technique_id = technique['technique_id']
         technique_obj = technique['technique']
         tactics.sort()
-        groups.sort()
-
+
         if technique_id in self.data:
             raise Exception(f"Error, trying to redefine MITRE ID '{technique_id}'")
+        self.data[technique_id] = MitreAttackEnrichment(mitre_attack_id=technique_id,
+                                                        mitre_attack_technique=technique_obj,
+                                                        mitre_attack_tactics=tactics,
+                                                        mitre_attack_groups=groupNames,
+                                                        mitre_attack_group_objects=[])
+
+    def addMitreIDViaGroupObjects(self, technique:dict, tactics:list[str], groupObjects:list[dict[str,Any]])->None:
+        technique_id = technique['technique_id']
+        technique_obj = technique['technique']
+        tactics.sort()
 
+        groupNames:list[str] = sorted([group['group'] for group in groupObjects])
+
+        if technique_id in self.data:
+            raise Exception(f"Error, trying to redefine MITRE ID '{technique_id}'")
         self.data[technique_id] = MitreAttackEnrichment(mitre_attack_id=technique_id,
                                                         mitre_attack_technique=technique_obj,
                                                         mitre_attack_tactics=tactics,
-                                                        mitre_attack_groups=groups)
+                                                        mitre_attack_groups=groupNames,
+                                                        mitre_attack_group_objects=groupObjects)
 
 
     def get_attack_lookup(self, input_path: str, store_csv: bool = False, force_cached_or_offline: bool = False, skip_enrichment:bool = False) -> dict:
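The old addMitreID is split in two: addMitreIDViaGroupObjects for callers holding full intrusion-set dicts (the live attackcti path in the next hunk) and addMitreIDViaGroupNames for callers with bare names (the cached-CSV path in the last hunk). The flat name list is derived from the objects exactly as in the + lines above; a small grounded sketch:

from typing import Any

# Group objects as collected per intrusion-set; only the 'group' key matters
# here (ids shortened for illustration)
group_objects: list[dict[str, Any]] = [
    {"group": "APT29", "id": "intrusion-set--0001"},
    {"group": "APT1", "id": "intrusion-set--0002"},
]

groupNames: list[str] = sorted([group["group"] for group in group_objects])
print(groupNames)  # ['APT1', 'APT29']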
@@ -86,19 +98,20 @@ class AttackEnrichment(BaseModel):
                     progress_percent = ((index+1)/len(all_enterprise_techniques)) * 100
                     if (sys.stdout.isatty() and sys.stdin.isatty() and sys.stderr.isatty()):
                         print(f"\r\t{'MITRE Technique Progress'.rjust(23)}: [{progress_percent:3.0f}%]...", end="", flush=True)
-                    apt_groups = []
+                    apt_groups:list[dict[str,Any]] = []
                     for relationship in enterprise_relationships:
                         if (relationship['target_object'] == technique['id']) and relationship['source_object'].startswith('intrusion-set'):
                             for group in enterprise_groups:
                                 if relationship['source_object'] == group['id']:
-                                    apt_groups.append(group['group'])
+                                    apt_groups.append(group)
+                                    #apt_groups.append(group['group'])
 
                     tactics = []
                     if ('tactic' in technique):
                         for tactic in technique['tactic']:
                             tactics.append(tactic.replace('-',' ').title())
 
-                    self.addMitreID(technique, tactics, apt_groups)
+                    self.addMitreIDViaGroupObjects(technique, tactics, apt_groups)
                     attack_lookup[technique['technique_id']] = {'technique': technique['technique'], 'tactics': tactics, 'groups': apt_groups}
 
                 if store_csv:
@@ -131,7 +144,7 @@ class AttackEnrichment(BaseModel):
                     technique_input = {'technique_id': key , 'technique': attack_lookup[key]['technique'] }
                     tactics_input = attack_lookup[key]['tactics']
                     groups_input = attack_lookup[key]['groups']
-                    self.addMitreID(technique=technique_input, tactics=tactics_input, groups=groups_input)
+                    self.addMitreIDViaGroupNames(technique=technique_input, tactics=tactics_input, groups=groups_input)
 
 