txt2detection 1.0.8__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

txt2detection/__main__.py CHANGED
@@ -13,11 +13,17 @@ import uuid
 from stix2 import Identity
 import yaml
 
-from txt2detection import credential_checker
+from txt2detection import attack_flow, credential_checker
 from txt2detection.ai_extractor.base import BaseAIExtractor
-from txt2detection.models import TAG_PATTERN, DetectionContainer, Level, SigmaRuleDetection
+from txt2detection.models import (
+    TAG_PATTERN,
+    DetectionContainer,
+    Level,
+    SigmaRuleDetection,
+)
 from txt2detection.utils import validate_token_count
 
+
 def configureLogging():
     # Configure logging
     stream_handler = logging.StreamHandler() # Log to stdout and stderr
@@ -26,30 +32,38 @@ def configureLogging():
         level=logging.DEBUG, # Set the desired logging level
         format=f"%(asctime)s [%(levelname)s] %(message)s",
         handlers=[stream_handler],
-        datefmt='%d-%b-%y %H:%M:%S'
+        datefmt="%d-%b-%y %H:%M:%S",
     )
 
     return logging.root
+
+
 configureLogging()
 
+
 def setLogFile(logger, file: Path):
     file.parent.mkdir(parents=True, exist_ok=True)
     logger.info(f"Saving log to `{file.absolute()}`")
     handler = logging.FileHandler(file, "w")
-    handler.formatter = logging.Formatter(fmt='%(levelname)s %(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
+    handler.formatter = logging.Formatter(
+        fmt="%(levelname)s %(asctime)s - %(message)s", datefmt="%d-%b-%y %H:%M:%S"
+    )
     handler.setLevel(logging.DEBUG)
     logger.addHandler(handler)
     logger.info("=====================txt2detection======================")
 
+
 from .bundler import Bundler
 import shutil
 
 
 from .utils import STATUSES, as_date, make_identity, valid_licenses, parse_model
 
+
 def parse_identity(str):
     return Identity(**json.loads(str))
 
+
 @dataclass
 class Args:
     input_file: str
@@ -64,57 +78,146 @@ class Args:
     external_refs: dict[str, str]
     reference_urls: list[str]
 
+
 def parse_created(value):
     """Convert the created timestamp to a datetime object."""
     try:
-        return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S').replace(tzinfo=UTC)
+        return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=UTC)
     except ValueError:
-        raise argparse.ArgumentTypeError("Invalid date format. Use YYYY-MM-DDTHH:MM:SS.")
+        raise argparse.ArgumentTypeError(
+            "Invalid date format. Use YYYY-MM-DDTHH:MM:SS."
+        )
 
 
 def parse_ref(value):
-    m = re.compile(r'(.+?)=(.+)').match(value)
+    m = re.compile(r"(.+?)=(.+)").match(value)
     if not m:
         raise argparse.ArgumentTypeError("must be in format key=value")
     return dict(source_name=m.group(1), external_id=m.group(2))
 
+
 def parse_label(label: str):
     if not TAG_PATTERN.match(label):
-        raise argparse.ArgumentTypeError("Invalid label format. Must follow sigma tag format {namespace}.{label}")
-    namespace, _, _ = label.partition('.')
-    if namespace in ['tlp']:
+        raise argparse.ArgumentTypeError(
+            "Invalid label format. Must follow sigma tag format {namespace}.{label}"
+        )
+    namespace, _, _ = label.partition(".")
+    if namespace in ["tlp"]:
         raise argparse.ArgumentTypeError(f"Unsupported tag namespace `{namespace}`")
     return label
 
+
 def parse_args():
-    parser = argparse.ArgumentParser(description='Convert text file to detection format.')
-    mode = parser.add_subparsers(title='process-mode', dest='mode', description="mode to use")
-    file = mode.add_parser('file', help="process a file input using ai")
-    text = mode.add_parser('text', help="process a text argument using ai")
-    sigma = mode.add_parser('sigma', help="process a sigma file without ai")
-    check_credentials = mode.add_parser('check-credentials', help="show status of external services with respect to credentials")
+    parser = argparse.ArgumentParser(
+        description="Convert text file to detection format."
+    )
+    mode = parser.add_subparsers(
+        title="process-mode", dest="mode", description="mode to use"
+    )
+    file = mode.add_parser("file", help="process a file input using ai")
+    text = mode.add_parser("text", help="process a text argument using ai")
+    sigma = mode.add_parser("sigma", help="process a sigma file without ai")
+    check_credentials = mode.add_parser(
+        "check-credentials",
+        help="show status of external services with respect to credentials",
+    )
 
     for mode_parser in [file, text, sigma]:
-        mode_parser.add_argument('--report_id', type=uuid.UUID, help='report_id to use for generated report')
-        mode_parser.add_argument('--name', required=True, help='Name of file, max 72 chars. Will be used in the STIX Report Object created.')
-        mode_parser.add_argument('--tlp_level', choices=['clear', 'green', 'amber', 'amber_strict', 'red'],
-                                 help='Options are clear, green, amber, amber_strict, red. Default is clear if not passed.')
-        mode_parser.add_argument('--labels', type=parse_label, action="extend", nargs='+',
-                                 help='Comma-separated list of labels. Case-insensitive (will be converted to lower-case). Allowed a-z, 0-9.')
-        mode_parser.add_argument('--created', type=parse_created,
-                                 help='Explicitly set created time in format YYYY-MM-DDTHH:MM:SS.sssZ. Default is current time.')
-        mode_parser.add_argument('--use_identity', type=parse_identity,
-                                 help='Pass a full STIX 2.1 identity object (properly escaped). Validated by the STIX2 library. Default is SIEM Rules identity.')
-        mode_parser.add_argument("--ai_provider", required=False, type=parse_model, help="(required): defines the `provider:model` to be used. Select one option.", metavar="provider[:model]")
-        mode_parser.add_argument("--external_refs", type=parse_ref, help="pass additional `external_references` entry (or entries) to the report object created. e.g --external_ref author=dogesec link=https://dkjjadhdaj.net", default=[], metavar="{source_name}={external_id}", action="extend", nargs='+')
-        mode_parser.add_argument("--reference_urls", help="pass additional `external_references` url entry (or entries) to the report object created.", default=[], metavar="{url}", action="extend", nargs='+')
-        mode_parser.add_argument("--license", help="Valid SPDX license for the rule", default=None, metavar="[LICENSE]", choices=valid_licenses())
-
-    file.add_argument('--input_file', help='The file to be converted. Must be .txt', type=lambda x: Path(x).read_text())
-    text.add_argument('--input_text', help='The text to be converted')
-    sigma.add_argument('--sigma_file', help='The sigma file to be converted. Must be .yml', type=lambda x: Path(x).read_text())
-    sigma.add_argument('--status', help="If passed, will overwrite any existing `status` recorded in the rule", choices=STATUSES)
-    sigma.add_argument('--level', help="If passed, will overwrite any existing `level` recorded in the rule", choices=Level._member_names_)
+        mode_parser.add_argument(
+            "--report_id", type=uuid.UUID, help="report_id to use for generated report"
+        )
+        mode_parser.add_argument(
+            "--name",
+            required=True,
+            help="Name of file, max 72 chars. Will be used in the STIX Report Object created.",
+        )
+        mode_parser.add_argument(
+            "--tlp_level",
+            choices=["clear", "green", "amber", "amber_strict", "red"],
+            help="Options are clear, green, amber, amber_strict, red. Default is clear if not passed.",
+        )
+        mode_parser.add_argument(
+            "--labels",
+            type=parse_label,
+            action="extend",
+            nargs="+",
+            help="Comma-separated list of labels. Case-insensitive (will be converted to lower-case). Allowed a-z, 0-9.",
+        )
+        mode_parser.add_argument(
+            "--created",
+            type=parse_created,
+            help="Explicitly set created time in format YYYY-MM-DDTHH:MM:SS.sssZ. Default is current time.",
+        )
+        mode_parser.add_argument(
+            "--use_identity",
+            type=parse_identity,
+            help="Pass a full STIX 2.1 identity object (properly escaped). Validated by the STIX2 library. Default is SIEM Rules identity.",
+        )
+        mode_parser.add_argument(
+            "--ai_provider",
+            required=False,
+            type=parse_model,
+            help="(required): defines the `provider:model` to be used. Select one option.",
+            metavar="provider[:model]",
+        )
+        mode_parser.add_argument(
+            "--external_refs",
+            type=parse_ref,
+            help="pass additional `external_references` entry (or entries) to the report object created. e.g --external_ref author=dogesec link=https://dkjjadhdaj.net",
+            default=[],
+            metavar="{source_name}={external_id}",
+            action="extend",
+            nargs="+",
+        )
+        mode_parser.add_argument(
+            "--reference_urls",
+            help="pass additional `external_references` url entry (or entries) to the report object created.",
+            default=[],
+            metavar="{url}",
+            action="extend",
+            nargs="+",
+        )
+        mode_parser.add_argument(
+            "--license",
+            help="Valid SPDX license for the rule",
+            default=None,
+            metavar="[LICENSE]",
+            choices=valid_licenses(),
+        )
+        mode_parser.add_argument(
+            "--ai_create_attack_navigator_layer",
+            help="Create navigator layer",
+            action="store_true",
+            default=False,
+        )
+        mode_parser.add_argument(
+            "--ai_create_attack_flow",
+            help="Create attack flow",
+            action="store_true",
+            default=False,
+        )
+
+    file.add_argument(
+        "--input_file",
+        help="The file to be converted. Must be .txt",
+        type=lambda x: Path(x).read_text(),
+    )
+    text.add_argument("--input_text", help="The text to be converted")
+    sigma.add_argument(
+        "--sigma_file",
+        help="The sigma file to be converted. Must be .yml",
+        type=lambda x: Path(x).read_text(),
+    )
+    sigma.add_argument(
+        "--status",
+        help="If passed, will overwrite any existing `status` recorded in the rule",
+        choices=STATUSES,
+    )
+    sigma.add_argument(
+        "--level",
+        help="If passed, will overwrite any existing `level` recorded in the rule",
+        choices=Level._member_names_,
+    )
 
     args: Args = parser.parse_args()
     if args.mode == "check-credentials":
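
For reference, the two converters wired in above are pure functions, so their behavior is easy to pin down. A minimal sketch (example values are hypothetical, and `parse_label` additionally assumes the value matches `TAG_PATTERN`):

# Minimal sketch of the argparse converters above; example values are hypothetical.
from txt2detection.__main__ import parse_ref, parse_label

print(parse_ref("capec=CAPEC-110"))
# {'source_name': 'capec', 'external_id': 'CAPEC-110'}

print(parse_label("attack.t1059"))
# 'attack.t1059' (any namespace other than `tlp` passes through)

parse_label("tlp.red")
# raises argparse.ArgumentTypeError: Unsupported tag namespace `tlp`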
@@ -122,75 +225,123 @@ def parse_args():
         credential_checker.format_statuses(statuses)
         sys.exit(0)
 
-    if args.mode != 'sigma':
+    if args.mode != "sigma":
         assert args.ai_provider, "--ai_provider is required in file or txt mode"
-    if args.mode == 'file':
+
+    if args.ai_create_attack_navigator_layer or args.ai_create_attack_flow:
+        assert (
+            args.ai_provider
+        ), "--ai_provider is required when --ai_create_attack_navigator_layer/--ai_create_attack_flow is passed"
+
+    if args.mode == "file":
         args.input_text = args.input_file
 
-    args.input_text = getattr(args, 'input_text', "")
+    args.input_text = getattr(args, "input_text", "")
     if not args.report_id:
-        args.report_id = Bundler.generate_report_id(args.use_identity.id if args.use_identity else None, args.created, args.name)
+        args.report_id = Bundler.generate_report_id(
+            args.use_identity.id if args.use_identity else None, args.created, args.name
+        )
 
     return args
 
+def run_txt2detection(
+    name,
+    identity,
+    tlp_level,
+    input_text: str,
+    labels: list[str],
+    report_id: str | uuid.UUID,
+    ai_provider: BaseAIExtractor,
+    ai_create_attack_flow=False,
+    ai_create_attack_navigator_layer=False,
+    **kwargs,
+) -> Bundler:
+    if (
+        kwargs.get("sigma_file") != "sigma_file"
+        or ai_create_attack_flow
+        or ai_create_attack_navigator_layer
+    ):
+        validate_token_count(
+            int(os.getenv("INPUT_TOKEN_LIMIT", 0)), input_text, ai_provider
+        )
 
-
-def run_txt2detection(name, identity, tlp_level, input_text: str, labels: list[str], report_id: str|uuid.UUID, ai_provider: BaseAIExtractor, **kwargs) -> Bundler:
-    if sigma := kwargs.get('sigma_file'):
+    if sigma := kwargs.get("sigma_file"):
         detection = get_sigma_detections(sigma)
         if not identity and detection.author:
             identity = make_identity(detection.author)
-        kwargs.update(reference_urls=kwargs.setdefault('reference_urls', [])+detection.references)
-        if not kwargs.get('created'):
-            #only consider rule.date and rule.modified if user does not pass --created
+        kwargs.update(
+            reference_urls=kwargs.setdefault("reference_urls", [])
+            + detection.references
+        )
+        if not kwargs.get("created"):
+            # only consider rule.date and rule.modified if user does not pass --created
             kwargs.update(
                 created=detection.date,
                 modified=detection.modified,
             )
-        detection.level = kwargs.get('level', detection.level)
-        detection.status = kwargs.get('status', detection.status)
-        detection.date = as_date(kwargs.get('created'))
-        detection.modified = as_date(kwargs.get('modified'))
-        detection.references = kwargs['reference_urls']
-        detection.detection_id = str(report_id).removeprefix('report--')
-        bundler = Bundler(name or detection.title, identity, tlp_level or detection.tlp_level or 'clear', detection.description, (labels or [])+detection.tags, report_id=report_id, **kwargs)
+        detection.level = kwargs.get("level", detection.level)
+        detection.status = kwargs.get("status", detection.status)
+        detection.date = as_date(kwargs.get("created"))
+        detection.modified = as_date(kwargs.get("modified"))
+        detection.references = kwargs["reference_urls"]
+        detection.detection_id = str(report_id).removeprefix("report--")
+        bundler = Bundler(
+            name or detection.title,
+            identity,
+            tlp_level or detection.tlp_level or "clear",
+            detection.description,
+            (labels or []) + detection.tags,
+            report_id=report_id,
+            **kwargs,
+        )
         detections = DetectionContainer(success=True, detections=[])
         detections.detections.append(detection)
     else:
-        validate_token_count(int(os.getenv('INPUT_TOKEN_LIMIT', 0)), input_text, ai_provider)
-        bundler = Bundler(name, identity, tlp_level, input_text, labels, report_id=report_id, **kwargs)
+        bundler = Bundler(
+            name, identity, tlp_level, input_text, labels, report_id=report_id, **kwargs
+        )
         detections = ai_provider.get_detections(input_text)
     bundler.bundle_detections(detections)
+
+    if ai_create_attack_flow or ai_create_attack_navigator_layer:
+        bundler.data.attack_flow, bundler.data.navigator_layer = (
+            attack_flow.extract_attack_flow_and_navigator(
+                bundler,
+                bundler.report.description,
+                ai_create_attack_flow,
+                ai_create_attack_navigator_layer,
+                ai_provider,
+            )
+        )
     return bundler
 
+
 def get_sigma_detections(sigma: str) -> SigmaRuleDetection:
     obj = yaml.safe_load(io.StringIO(sigma))
     return SigmaRuleDetection.model_validate(obj)
-
 
-
+
 def main(args: Args):
 
     setLogFile(logging.root, Path(f"logs/log-{args.report_id}.log"))
     logging.info(f"starting argument: {json.dumps(sys.argv[1:])}")
     kwargs = args.__dict__
-    kwargs['identity'] = args.use_identity
+    kwargs["identity"] = args.use_identity
     bundler = run_txt2detection(**kwargs)
 
-    output_dir = Path("./output")/str(bundler.bundle.id)
+    output_dir = Path("./output") / str(bundler.bundle.id)
     shutil.rmtree(output_dir, ignore_errors=True)
-    rules_dir = output_dir/"rules"
+    rules_dir = output_dir / "rules"
     rules_dir.mkdir(exist_ok=True, parents=True)
 
-
-    output_path = output_dir/'bundle.json'
-    data_path = output_dir/f"data.json"
+    output_path = output_dir / "bundle.json"
+    data_path = output_dir / f"data.json"
     output_path.write_text(bundler.to_json())
-    data_path.write_text(bundler.detections.model_dump_json(indent=4))
-    for obj in bundler.bundle['objects']:
-        if obj['type'] != 'indicator' or obj['pattern_type'] != 'sigma':
+    data_path.write_text(bundler.data.model_dump_json(indent=4))
+    for obj in bundler.bundle["objects"]:
+        if obj["type"] != "indicator" or obj["pattern_type"] != "sigma":
             continue
-        name = obj['id'].replace('indicator', 'rule') + '.yml'
-        (rules_dir/name).write_text(obj['pattern'])
+        name = obj["id"].replace("indicator", "rule") + ".yml"
+        (rules_dir / name).write_text(obj["pattern"])
     logging.info(f"Writing bundle output to `{output_path}`")

txt2detection/ai_extractor/base.py CHANGED
@@ -7,15 +7,19 @@ from llama_index.core.llms.llm import LLM
 
 from txt2detection.ai_extractor import prompts
 
+from txt2detection.ai_extractor.models import AttackFlowList
 from txt2detection.ai_extractor.utils import ParserWithLogging
 from txt2detection.models import DetectionContainer, DetectionContainer
 from llama_index.core.utils import get_tokenizer
 
 
-_ai_extractor_registry: dict[str, 'Type[BaseAIExtractor]'] = {}
-class BaseAIExtractor():
+_ai_extractor_registry: dict[str, "Type[BaseAIExtractor]"] = {}
+
+
+class BaseAIExtractor:
     llm: LLM
-    system_prompt = (textwrap.dedent("""
+    system_prompt = textwrap.dedent(
+        """
     <persona>
 
     You are a cyber-security detection engineering tool responsible for analysing intelligence reports provided in text files and writing SIGMA detection rules to detect the content being described in the reports.
@@ -25,12 +29,11 @@ class BaseAIExtractor():
     IMPORTANT: You must always deliver your work as a computer-parsable output in JSON format. All output from you will be parsed with pydantic for further processing.
 
     </persona>
-    """))
-
+    """
+    )
 
     def get_detections(self, input_text) -> DetectionContainer:
-        logging.info('getting detections')
-
+        logging.info("getting detections")
 
         return LLMTextCompletionProgram.from_defaults(
             output_parser=ParserWithLogging(DetectionContainer),
@@ -38,14 +41,17 @@ class BaseAIExtractor():
             verbose=True,
             llm=self.llm,
         )(document=input_text)
-
+
     def __init__(self, *args, **kwargs) -> None:
         pass
 
     def count_tokens(self, input_text):
-        logging.info("unsupported model `%s`, estimating using llama-index's default tokenizer", self.extractor_name)
+        logging.info(
+            "unsupported model `%s`, estimating using llama-index's default tokenizer",
+            self.extractor_name,
+        )
         return len(get_tokenizer()(input_text))
-
+
     def __init_subclass__(cls, /, provider, register=True, **kwargs):
         super().__init_subclass__(**kwargs)
         if register:
@@ -55,13 +61,35 @@ class BaseAIExtractor():
     @property
     def extractor_name(self):
         return f"{self.provider}:{self.llm.model}"
-
+
+    def _get_attack_flow_program(self):
+        return LLMTextCompletionProgram.from_defaults(
+            output_parser=ParserWithLogging(AttackFlowList),
+            prompt=prompts.ATTACK_FLOW_PROMPT_TEMPL,
+            verbose=True,
+            llm=self.llm,
+        )
+
+    def extract_attack_flow(self, input_text, techniques) -> AttackFlowList:
+        extracted_techniques = []
+        for t in techniques.values():
+            extracted_techniques.append(
+                dict(
+                    id=t["id"],
+                    name=t["name"],
+                    possible_tactics=list(t["possible_tactics"].keys()),
+                )
+            )
+        return self._get_attack_flow_program()(
+            document=input_text, extracted_techniques=extracted_techniques
+        )
+
     def check_credential(self):
         try:
             return "authorized" if self._check_credential() else "unauthorized"
         except:
             return "unknown"
-
+
     def _check_credential(self):
         self.llm.complete("say 'hi'")
-        return True
+        return True
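
The new `extract_attack_flow` hook expects `techniques` to be a mapping whose values carry `id`, `name`, and a `possible_tactics` dict, of which only the keys are forwarded to the prompt. A hypothetical call against a concrete extractor (`extractor`, `report_text`, and the technique data are illustrative):

# Hypothetical input shape for extract_attack_flow; the technique data and
# report_text are illustrative, and `extractor` is any registered subclass.
techniques = {
    "T1059.001": {
        "id": "T1059.001",
        "name": "PowerShell",
        "possible_tactics": {"TA0002": "Execution"},  # only the keys are sent
    }
}
flow = extractor.extract_attack_flow(report_text, techniques)
print(flow.tactic_mapping)  # e.g. {"T1059.001": "TA0002"}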

txt2detection/ai_extractor/models.py ADDED
@@ -0,0 +1,34 @@
+import io
+import json
+import logging
+
+import dotenv
+import textwrap
+
+from pydantic import BaseModel, Field, RootModel
+from llama_index.core.output_parsers import PydanticOutputParser
+
+
+class AttackFlowItem(BaseModel):
+    position: int = Field(description="order of object starting at 0")
+    attack_technique_id: str
+    name: str
+    description: str
+
+
+class AttackFlowList(BaseModel):
+    tactic_selection: list[tuple[str, str]] = Field(
+        description="attack technique id to attack tactic id mapping using possible_tactics"
+    )
+    # additional_tactic_mapping: list[tuple[str, str]] = Field(description="the rest of tactic_mapping")
+    items: list[AttackFlowItem]
+    success: bool = Field(
+        description="determines if there's any valid flow in <extractions>"
+    )
+
+    def model_post_init(self, context):
+        return super().model_post_init(context)
+
+    @property
+    def tactic_mapping(self):
+        return dict(self.tactic_selection)
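
Because `AttackFlowList` is an ordinary pydantic model, the JSON shape the parser expects from the LLM can be sketched directly (all values are hypothetical):

# Hypothetical instance showing the shape ParserWithLogging(AttackFlowList)
# will validate; the technique and tactic values are made up for illustration.
flow = AttackFlowList(
    tactic_selection=[("T1059.001", "TA0002")],
    items=[
        AttackFlowItem(
            position=0,
            attack_technique_id="T1059.001",
            name="PowerShell execution",
            description="Script launched via powershell.exe",
        )
    ],
    success=True,
)
assert flow.tactic_mapping == {"T1059.001": "TA0002"}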

txt2detection/ai_extractor/openai.py CHANGED
@@ -1,4 +1,3 @@
-
 import logging
 import os
 from .base import BaseAIExtractor
@@ -7,7 +6,7 @@ from llama_index.llms.openai import OpenAI
 
 class OpenAIExtractor(BaseAIExtractor, provider="openai"):
     def __init__(self, **kwargs) -> None:
-        kwargs.setdefault('temperature', float(os.environ.get('TEMPERATURE', 0.0)))
+        kwargs.setdefault("temperature", float(os.environ.get("TEMPERATURE", 0.0)))
         self.llm = OpenAI(system_prompt=self.system_prompt, **kwargs, max_tokens=4096)
         super().__init__()
 
@@ -17,4 +16,3 @@ class OpenAIExtractor(BaseAIExtractor, provider="openai"):
         except Exception as e:
             logging.warning(e)
         return super().count_tokens(text)
-

txt2detection/ai_extractor/openrouter.py CHANGED
@@ -1,4 +1,3 @@
-
 import logging
 import os
 from .base import BaseAIExtractor
@@ -7,8 +6,10 @@ from llama_index.llms.openrouter import OpenRouter
 
 class OpenRouterExtractor(BaseAIExtractor, provider="openrouter"):
     def __init__(self, **kwargs) -> None:
-        kwargs.setdefault('temperature', float(os.environ.get('TEMPERATURE', 0.0)))
-        self.llm = OpenRouter(system_prompt=self.system_prompt, max_tokens=4096, **kwargs)
+        kwargs.setdefault("temperature", float(os.environ.get("TEMPERATURE", 0.0)))
+        self.llm = OpenRouter(
+            system_prompt=self.system_prompt, max_tokens=4096, **kwargs
+        )
         super().__init__()
 
     def count_tokens(self, text):
@@ -17,4 +18,3 @@ class OpenRouterExtractor(BaseAIExtractor, provider="openrouter"):
         except Exception as e:
             logging.warning(e)
         return super().count_tokens(text)
-
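
Both extractor subclasses share the same pattern: a `TEMPERATURE` environment default applied via `kwargs.setdefault`, so an explicit keyword still wins. A brief sketch (the model name is a placeholder and an API key is assumed in the environment):

# Sketch of the shared temperature default; the model name is a placeholder
# and OPENROUTER_API_KEY is assumed to be set in the environment.
import os
os.environ.setdefault("TEMPERATURE", "0.2")

from txt2detection.ai_extractor.openrouter import OpenRouterExtractor

llm = OpenRouterExtractor(model="openai/gpt-4o-mini")  # temperature=0.2 from env
pinned = OpenRouterExtractor(model="openai/gpt-4o-mini", temperature=0.0)  # kwarg overrides env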