txt2detection 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of txt2detection might be problematic.
- txt2detection/__main__.py +219 -68
- txt2detection/ai_extractor/base.py +41 -13
- txt2detection/ai_extractor/models.py +34 -0
- txt2detection/ai_extractor/openai.py +1 -3
- txt2detection/ai_extractor/openrouter.py +4 -4
- txt2detection/ai_extractor/prompts.py +130 -3
- txt2detection/attack_flow.py +233 -0
- txt2detection/bundler.py +174 -87
- txt2detection/credential_checker.py +11 -9
- txt2detection/models.py +86 -49
- txt2detection/observables.py +0 -1
- txt2detection/utils.py +24 -12
- {txt2detection-1.0.7.dist-info → txt2detection-1.0.9.dist-info}/METADATA +7 -8
- txt2detection-1.0.9.dist-info/RECORD +24 -0
- txt2detection-1.0.7.dist-info/RECORD +0 -22
- {txt2detection-1.0.7.dist-info → txt2detection-1.0.9.dist-info}/WHEEL +0 -0
- {txt2detection-1.0.7.dist-info → txt2detection-1.0.9.dist-info}/entry_points.txt +0 -0
- {txt2detection-1.0.7.dist-info → txt2detection-1.0.9.dist-info}/licenses/LICENSE +0 -0
txt2detection/__main__.py
CHANGED
@@ -13,11 +13,17 @@ import uuid
 from stix2 import Identity
 import yaml
 
-from txt2detection import credential_checker
+from txt2detection import attack_flow, credential_checker
 from txt2detection.ai_extractor.base import BaseAIExtractor
-from txt2detection.models import
+from txt2detection.models import (
+    TAG_PATTERN,
+    DetectionContainer,
+    Level,
+    SigmaRuleDetection,
+)
 from txt2detection.utils import validate_token_count
 
+
 def configureLogging():
     # Configure logging
     stream_handler = logging.StreamHandler()  # Log to stdout and stderr
@@ -26,30 +32,38 @@ def configureLogging():
         level=logging.DEBUG,  # Set the desired logging level
         format=f"%(asctime)s [%(levelname)s] %(message)s",
         handlers=[stream_handler],
-        datefmt=
+        datefmt="%d-%b-%y %H:%M:%S",
     )
 
     return logging.root
+
+
 configureLogging()
 
+
 def setLogFile(logger, file: Path):
     file.parent.mkdir(parents=True, exist_ok=True)
     logger.info(f"Saving log to `{file.absolute()}`")
     handler = logging.FileHandler(file, "w")
-    handler.formatter = logging.Formatter(
+    handler.formatter = logging.Formatter(
+        fmt="%(levelname)s %(asctime)s - %(message)s", datefmt="%d-%b-%y %H:%M:%S"
+    )
     handler.setLevel(logging.DEBUG)
     logger.addHandler(handler)
     logger.info("=====================txt2detection======================")
 
+
 from .bundler import Bundler
 import shutil
 
 
 from .utils import STATUSES, as_date, make_identity, valid_licenses, parse_model
 
+
 def parse_identity(str):
     return Identity(**json.loads(str))
 
+
 @dataclass
 class Args:
     input_file: str
@@ -64,57 +78,146 @@ class Args:
     external_refs: dict[str, str]
     reference_urls: list[str]
 
+
 def parse_created(value):
     """Convert the created timestamp to a datetime object."""
     try:
-        return datetime.strptime(value,
+        return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=UTC)
     except ValueError:
-        raise argparse.ArgumentTypeError(
+        raise argparse.ArgumentTypeError(
+            "Invalid date format. Use YYYY-MM-DDTHH:MM:SS."
+        )
 
 
 def parse_ref(value):
-    m = re.compile(r
+    m = re.compile(r"(.+?)=(.+)").match(value)
     if not m:
         raise argparse.ArgumentTypeError("must be in format key=value")
     return dict(source_name=m.group(1), external_id=m.group(2))
 
+
 def parse_label(label: str):
     if not TAG_PATTERN.match(label):
-        raise argparse.ArgumentTypeError(
-
-
+        raise argparse.ArgumentTypeError(
+            "Invalid label format. Must follow sigma tag format {namespace}.{label}"
+        )
+    namespace, _, _ = label.partition(".")
+    if namespace in ["tlp"]:
         raise argparse.ArgumentTypeError(f"Unsupported tag namespace `{namespace}`")
     return label
 
+
 def parse_args():
-    parser = argparse.ArgumentParser(
-
-
-
-
-
+    parser = argparse.ArgumentParser(
+        description="Convert text file to detection format."
+    )
+    mode = parser.add_subparsers(
+        title="process-mode", dest="mode", description="mode to use"
+    )
+    file = mode.add_parser("file", help="process a file input using ai")
+    text = mode.add_parser("text", help="process a text argument using ai")
+    sigma = mode.add_parser("sigma", help="process a sigma file without ai")
+    check_credentials = mode.add_parser(
+        "check-credentials",
+        help="show status of external services with respect to credentials",
+    )
 
     for mode_parser in [file, text, sigma]:
-        mode_parser.add_argument(
-
-
-
-
-
-
-
-        mode_parser.add_argument(
-
-
-
-
-        mode_parser.add_argument(
-
-
-
-
-
-
+        mode_parser.add_argument(
+            "--report_id", type=uuid.UUID, help="report_id to use for generated report"
+        )
+        mode_parser.add_argument(
+            "--name",
+            required=True,
+            help="Name of file, max 72 chars. Will be used in the STIX Report Object created.",
+        )
+        mode_parser.add_argument(
+            "--tlp_level",
+            choices=["clear", "green", "amber", "amber_strict", "red"],
+            help="Options are clear, green, amber, amber_strict, red. Default is clear if not passed.",
+        )
+        mode_parser.add_argument(
+            "--labels",
+            type=parse_label,
+            action="extend",
+            nargs="+",
+            help="Comma-separated list of labels. Case-insensitive (will be converted to lower-case). Allowed a-z, 0-9.",
+        )
+        mode_parser.add_argument(
+            "--created",
+            type=parse_created,
+            help="Explicitly set created time in format YYYY-MM-DDTHH:MM:SS.sssZ. Default is current time.",
+        )
+        mode_parser.add_argument(
+            "--use_identity",
+            type=parse_identity,
+            help="Pass a full STIX 2.1 identity object (properly escaped). Validated by the STIX2 library. Default is SIEM Rules identity.",
+        )
+        mode_parser.add_argument(
+            "--ai_provider",
+            required=False,
+            type=parse_model,
+            help="(required): defines the `provider:model` to be used. Select one option.",
+            metavar="provider[:model]",
+        )
+        mode_parser.add_argument(
+            "--external_refs",
+            type=parse_ref,
+            help="pass additional `external_references` entry (or entries) to the report object created. e.g --external_ref author=dogesec link=https://dkjjadhdaj.net",
+            default=[],
+            metavar="{source_name}={external_id}",
+            action="extend",
+            nargs="+",
+        )
+        mode_parser.add_argument(
+            "--reference_urls",
+            help="pass additional `external_references` url entry (or entries) to the report object created.",
+            default=[],
+            metavar="{url}",
+            action="extend",
+            nargs="+",
+        )
+        mode_parser.add_argument(
+            "--license",
+            help="Valid SPDX license for the rule",
+            default=None,
+            metavar="[LICENSE]",
+            choices=valid_licenses(),
+        )
+        mode_parser.add_argument(
+            "--ai_create_attack_navigator_layer",
+            help="Create navigator layer",
+            action="store_true",
+            default=False,
+        )
+        mode_parser.add_argument(
+            "--ai_create_attack_flow",
+            help="Create attack flow",
+            action="store_true",
+            default=False,
+        )
+
+    file.add_argument(
+        "--input_file",
+        help="The file to be converted. Must be .txt",
+        type=lambda x: Path(x).read_text(),
+    )
+    text.add_argument("--input_text", help="The text to be converted")
+    sigma.add_argument(
+        "--sigma_file",
+        help="The sigma file to be converted. Must be .yml",
+        type=lambda x: Path(x).read_text(),
+    )
+    sigma.add_argument(
+        "--status",
+        help="If passed, will overwrite any existing `status` recorded in the rule",
+        choices=STATUSES,
+    )
+    sigma.add_argument(
+        "--level",
+        help="If passed, will overwrite any existing `level` recorded in the rule",
+        choices=Level._member_names_,
+    )
 
     args: Args = parser.parse_args()
     if args.mode == "check-credentials":
@@ -122,75 +225,123 @@ def parse_args():
         credential_checker.format_statuses(statuses)
         sys.exit(0)
 
-    if args.mode !=
+    if args.mode != "sigma":
         assert args.ai_provider, "--ai_provider is required in file or txt mode"
-
+
+    if args.ai_create_attack_navigator_layer or args.ai_create_attack_flow:
+        assert (
+            args.ai_provider
+        ), "--ai_provider is required when --ai_create_attack_navigator_layer/--ai_create_attack_flow is passed"
+
+    if args.mode == "file":
         args.input_text = args.input_file
 
-    args.input_text = getattr(args,
+    args.input_text = getattr(args, "input_text", "")
     if not args.report_id:
-        args.report_id = Bundler.generate_report_id(
+        args.report_id = Bundler.generate_report_id(
+            args.use_identity.id if args.use_identity else None, args.created, args.name
+        )
 
     return args
 
 
+def run_txt2detection(
+    name,
+    identity,
+    tlp_level,
+    input_text: str,
+    labels: list[str],
+    report_id: str | uuid.UUID,
+    ai_provider: BaseAIExtractor,
+    ai_create_attack_flow=False,
+    ai_create_attack_navigator_layer=False,
+    **kwargs,
+) -> Bundler:
+    if (
+        kwargs.get("sigma_file") != "sigma_file"
+        or ai_create_attack_flow
+        or ai_create_attack_navigator_layer
+    ):
+        validate_token_count(
+            int(os.getenv("INPUT_TOKEN_LIMIT", 0)), input_text, ai_provider
+        )
 
-
-def run_txt2detection(name, identity, tlp_level, input_text: str, labels: list[str], report_id: str|uuid.UUID, ai_provider: BaseAIExtractor, **kwargs) -> Bundler:
-    if sigma := kwargs.get('sigma_file'):
+    if sigma := kwargs.get("sigma_file"):
         detection = get_sigma_detections(sigma)
         if not identity and detection.author:
             identity = make_identity(detection.author)
-        kwargs.update(
-
-
+        kwargs.update(
+            reference_urls=kwargs.setdefault("reference_urls", [])
+            + detection.references
+        )
+        if not kwargs.get("created"):
+            # only consider rule.date and rule.modified if user does not pass --created
             kwargs.update(
                 created=detection.date,
                 modified=detection.modified,
             )
-        detection.level = kwargs.get(
-        detection.status = kwargs.get(
-        detection.date = as_date(kwargs.get(
-        detection.modified = as_date(kwargs.get(
-        detection.references = kwargs[
-        detection.detection_id = str(report_id).removeprefix(
-        bundler = Bundler(
+        detection.level = kwargs.get("level", detection.level)
+        detection.status = kwargs.get("status", detection.status)
+        detection.date = as_date(kwargs.get("created"))
+        detection.modified = as_date(kwargs.get("modified"))
+        detection.references = kwargs["reference_urls"]
+        detection.detection_id = str(report_id).removeprefix("report--")
+        bundler = Bundler(
+            name or detection.title,
+            identity,
+            tlp_level or detection.tlp_level or "clear",
+            detection.description,
+            (labels or []) + detection.tags,
+            report_id=report_id,
+            **kwargs,
+        )
         detections = DetectionContainer(success=True, detections=[])
         detections.detections.append(detection)
     else:
-
-
+        bundler = Bundler(
+            name, identity, tlp_level, input_text, labels, report_id=report_id, **kwargs
+        )
        detections = ai_provider.get_detections(input_text)
     bundler.bundle_detections(detections)
+
+    if ai_create_attack_flow or ai_create_attack_navigator_layer:
+        bundler.data.attack_flow, bundler.data.navigator_layer = (
+            attack_flow.extract_attack_flow_and_navigator(
+                bundler,
+                bundler.report.description,
+                ai_create_attack_flow,
+                ai_create_attack_navigator_layer,
+                ai_provider,
+            )
+        )
     return bundler
 
+
 def get_sigma_detections(sigma: str) -> SigmaRuleDetection:
     obj = yaml.safe_load(io.StringIO(sigma))
     return SigmaRuleDetection.model_validate(obj)
-
 
-
+
 def main(args: Args):
 
     setLogFile(logging.root, Path(f"logs/log-{args.report_id}.log"))
     logging.info(f"starting argument: {json.dumps(sys.argv[1:])}")
     kwargs = args.__dict__
-    kwargs[
+    kwargs["identity"] = args.use_identity
     bundler = run_txt2detection(**kwargs)
 
-    output_dir = Path("./output")/str(bundler.bundle.id)
+    output_dir = Path("./output") / str(bundler.bundle.id)
     shutil.rmtree(output_dir, ignore_errors=True)
-    rules_dir = output_dir/"rules"
+    rules_dir = output_dir / "rules"
     rules_dir.mkdir(exist_ok=True, parents=True)
 
-
-
-    data_path = output_dir/f"data.json"
+    output_path = output_dir / "bundle.json"
+    data_path = output_dir / f"data.json"
     output_path.write_text(bundler.to_json())
-    data_path.write_text(bundler.
-    for obj in bundler.bundle[
-        if obj[
+    data_path.write_text(bundler.data.model_dump_json(indent=4))
+    for obj in bundler.bundle["objects"]:
+        if obj["type"] != "indicator" or obj["pattern_type"] != "sigma":
             continue
-        name = obj[
-        (rules_dir/name).write_text(obj[
+        name = obj["id"].replace("indicator", "rule") + ".yml"
+        (rules_dir / name).write_text(obj["pattern"])
     logging.info(f"Writing bundle output to `{output_path}`")
txt2detection/ai_extractor/base.py
CHANGED
@@ -7,15 +7,19 @@ from llama_index.core.llms.llm import LLM
 
 from txt2detection.ai_extractor import prompts
 
+from txt2detection.ai_extractor.models import AttackFlowList
 from txt2detection.ai_extractor.utils import ParserWithLogging
 from txt2detection.models import DetectionContainer, DetectionContainer
 from llama_index.core.utils import get_tokenizer
 
 
-_ai_extractor_registry: dict[str,
-
+_ai_extractor_registry: dict[str, "Type[BaseAIExtractor]"] = {}
+
+
+class BaseAIExtractor:
     llm: LLM
-    system_prompt =
+    system_prompt = textwrap.dedent(
+        """
     <persona>
 
     You are a cyber-security detection engineering tool responsible for analysing intelligence reports provided in text files and writing SIGMA detection rules to detect the content being described in the reports.
@@ -25,12 +29,11 @@ class BaseAIExtractor():
     IMPORTANT: You must always deliver your work as a computer-parsable output in JSON format. All output from you will be parsed with pydantic for further processing.
 
     </persona>
-    """
-
+    """
+    )
 
     def get_detections(self, input_text) -> DetectionContainer:
-        logging.info(
-
+        logging.info("getting detections")
 
         return LLMTextCompletionProgram.from_defaults(
             output_parser=ParserWithLogging(DetectionContainer),
@@ -38,14 +41,17 @@ class BaseAIExtractor():
             verbose=True,
             llm=self.llm,
         )(document=input_text)
-
+
     def __init__(self, *args, **kwargs) -> None:
         pass
 
     def count_tokens(self, input_text):
-        logging.info(
+        logging.info(
+            "unsupported model `%s`, estimating using llama-index's default tokenizer",
+            self.extractor_name,
+        )
         return len(get_tokenizer()(input_text))
-
+
     def __init_subclass__(cls, /, provider, register=True, **kwargs):
         super().__init_subclass__(**kwargs)
         if register:
@@ -55,13 +61,35 @@ class BaseAIExtractor():
     @property
     def extractor_name(self):
         return f"{self.provider}:{self.llm.model}"
-
+
+    def _get_attack_flow_program(self):
+        return LLMTextCompletionProgram.from_defaults(
+            output_parser=ParserWithLogging(AttackFlowList),
+            prompt=prompts.ATTACK_FLOW_PROMPT_TEMPL,
+            verbose=True,
+            llm=self.llm,
+        )
+
+    def extract_attack_flow(self, input_text, techniques) -> AttackFlowList:
+        extracted_techniques = []
+        for t in techniques.values():
+            extracted_techniques.append(
+                dict(
+                    id=t["id"],
+                    name=t["name"],
+                    possible_tactics=list(t["possible_tactics"].keys()),
+                )
+            )
+        return self._get_attack_flow_program()(
+            document=input_text, extracted_techniques=extracted_techniques
+        )
+
     def check_credential(self):
         try:
             return "authorized" if self._check_credential() else "unauthorized"
         except:
             return "unknown"
-
+
     def _check_credential(self):
         self.llm.complete("say 'hi'")
-        return True
+        return True
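
`__init_subclass__` registers every subclass in `_ai_extractor_registry` under its `provider` key, which is how `parse_model` resolves a `provider:model` string to an extractor. A sketch of what a third-party extractor could look like under that pattern; the Anthropic backend is an assumption for illustration only, as just the OpenAI and OpenRouter extractors ship with the package:

import os

from llama_index.llms.anthropic import Anthropic  # assumed backend, not in the package

from txt2detection.ai_extractor.base import BaseAIExtractor


class AnthropicExtractor(BaseAIExtractor, provider="anthropic"):
    def __init__(self, **kwargs) -> None:
        # mirror the temperature/system-prompt handling of the shipped extractors
        kwargs.setdefault("temperature", float(os.environ.get("TEMPERATURE", 0.0)))
        self.llm = Anthropic(system_prompt=self.system_prompt, max_tokens=4096, **kwargs)
        super().__init__()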
txt2detection/ai_extractor/models.py
ADDED
@@ -0,0 +1,34 @@
+import io
+import json
+import logging
+
+import dotenv
+import textwrap
+
+from pydantic import BaseModel, Field, RootModel
+from llama_index.core.output_parsers import PydanticOutputParser
+
+
+class AttackFlowItem(BaseModel):
+    position: int = Field(description="order of object starting at 0")
+    attack_technique_id: str
+    name: str
+    description: str
+
+
+class AttackFlowList(BaseModel):
+    tactic_selection: list[tuple[str, str]] = Field(
+        description="attack technique id to attack tactic id mapping using possible_tactics"
+    )
+    # additional_tactic_mapping: list[tuple[str, str]] = Field(description="the rest of tactic_mapping")
+    items: list[AttackFlowItem]
+    success: bool = Field(
+        description="determines if there's any valid flow in <extractions>"
+    )
+
+    def model_post_init(self, context):
+        return super().model_post_init(context)
+
+    @property
+    def tactic_mapping(self):
+        return dict(self.tactic_selection)
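
`tactic_selection` is a list of `(technique_id, tactic_id)` pairs rather than a dict, presumably to keep the LLM-facing JSON schema simple; the `tactic_mapping` property collapses the pairs afterwards. A quick illustration with made-up values:

from txt2detection.ai_extractor.models import AttackFlowItem, AttackFlowList

flow = AttackFlowList(
    tactic_selection=[("T1059", "TA0002"), ("T1547.001", "TA0003")],
    items=[
        AttackFlowItem(
            position=0,
            attack_technique_id="T1059",
            name="Command and Scripting Interpreter",
            description="Initial script execution described in the report",
        ),
    ],
    success=True,
)
print(flow.tactic_mapping)  # {'T1059': 'TA0002', 'T1547.001': 'TA0003'}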
txt2detection/ai_extractor/openai.py
CHANGED
@@ -1,4 +1,3 @@
-
 import logging
 import os
 from .base import BaseAIExtractor
@@ -7,7 +6,7 @@ from llama_index.llms.openai import OpenAI
 
 class OpenAIExtractor(BaseAIExtractor, provider="openai"):
     def __init__(self, **kwargs) -> None:
-        kwargs.setdefault(
+        kwargs.setdefault("temperature", float(os.environ.get("TEMPERATURE", 0.0)))
         self.llm = OpenAI(system_prompt=self.system_prompt, **kwargs, max_tokens=4096)
         super().__init__()
 
@@ -17,4 +16,3 @@ class OpenAIExtractor(BaseAIExtractor, provider="openai"):
         except Exception as e:
             logging.warning(e)
             return super().count_tokens(text)
-
txt2detection/ai_extractor/openrouter.py
CHANGED
@@ -1,4 +1,3 @@
-
 import logging
 import os
 from .base import BaseAIExtractor
@@ -7,8 +6,10 @@ from llama_index.llms.openrouter import OpenRouter
 
 class OpenRouterExtractor(BaseAIExtractor, provider="openrouter"):
     def __init__(self, **kwargs) -> None:
-        kwargs.setdefault(
-        self.llm = OpenRouter(
+        kwargs.setdefault("temperature", float(os.environ.get("TEMPERATURE", 0.0)))
+        self.llm = OpenRouter(
+            system_prompt=self.system_prompt, max_tokens=4096, **kwargs
+        )
         super().__init__()
 
     def count_tokens(self, text):
@@ -17,4 +18,3 @@ class OpenRouterExtractor(BaseAIExtractor, provider="openrouter"):
         except Exception as e:
             logging.warning(e)
             return super().count_tokens(text)
-
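
Both extractors now seed `temperature` from the `TEMPERATURE` environment variable (default 0.0) and pass the shared system prompt with a 4096-token completion cap. A hypothetical usage sketch; the model slug and API key are placeholders:

import os

from txt2detection.ai_extractor.openrouter import OpenRouterExtractor

os.environ.setdefault("TEMPERATURE", "0.2")  # override the 0.0 default before init

extractor = OpenRouterExtractor(
    model="openai/gpt-4o",                     # placeholder model slug
    api_key=os.environ["OPENROUTER_API_KEY"],  # placeholder credential
)
print(extractor.check_credential())  # "authorized", "unauthorized", or "unknown"
print(extractor.count_tokens("suspicious powershell -enc payload"))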