janus-llm 2.1.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,614 @@
1
+ import functools
2
+ import json
3
+ import math
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from langchain.output_parsers import RetryWithErrorOutputParser
9
+ from langchain.output_parsers.fix import OutputFixingParser
10
+ from langchain_core.exceptions import OutputParserException
11
+ from langchain_core.language_models import BaseLanguageModel
12
+ from langchain_core.output_parsers import BaseOutputParser
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain_core.runnables import RunnableLambda, RunnableParallel
15
+ from openai import BadRequestError, RateLimitError
16
+ from pydantic import ValidationError
17
+
18
+ from janus.embedding.vectorize import ChromaDBVectorizer
19
+ from janus.language.block import CodeBlock, TranslatedCodeBlock
20
+ from janus.language.combine import Combiner
21
+ from janus.language.naive.registry import CUSTOM_SPLITTERS
22
+ from janus.language.splitter import (
23
+ EmptyTreeError,
24
+ FileSizeError,
25
+ Splitter,
26
+ TokenLimitError,
27
+ )
28
+ from janus.llm import load_model
29
+ from janus.llm.model_callbacks import get_model_callback
30
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES
31
+ from janus.parsers.code_parser import GenericParser
32
+ from janus.utils.enums import LANGUAGES
33
+ from janus.utils.logger import create_logger
34
+
35
# Module-level logger used by the `Converter` methods in this file.
log = create_logger(__name__)
36
+
37
+
38
def run_if_changed(*tracked_vars):
    """Skip calls to the decorated method unless tracked attributes have changed.

    The decorated method runs only if no attribute names are tracked at all, or
    if at least one of ``tracked_vars`` is present in ``self._changed_attrs``.
    The owning class must create the ``_changed_attrs`` set and override
    ``__setattr__`` so that attribute updates are recorded in it. This decorator
    never clears ``_changed_attrs``; that is left to the owning class.

    Arguments:
        tracked_vars: Names of the instance attributes whose modification should
            allow the decorated method to run.

    Returns:
        A decorator. The wrapped method returns whatever the original method
        returns when it runs, and ``None`` when the call is skipped.
    """

    def wrapper(func):
        @functools.wraps(func)
        def wrapped(self, *args, **kwargs):
            # If there is overlap between the tracked variables and the changed
            # ones (or nothing is tracked), call the function as normal and
            # propagate its result instead of silently dropping it.
            if not tracked_vars or self._changed_attrs.intersection(tracked_vars):
                return func(self, *args, **kwargs)
            return None

        return wrapped

    return wrapper
55
+
56
+
57
class Converter:
    """Parent class that converts code into something else.

    Children will determine what the code gets converted into, whether that's
    translated into another language, into pseudocode, requirements,
    documentation, etc., or converted into embeddings.

    Configuration is applied lazily. The ``set_*`` methods only record values,
    and ``_load_parameters`` rebuilds the model, prompt, splitter and vectorizer
    only when attributes they depend on have changed. Changes are tracked by
    ``__setattr__`` in ``_changed_attrs``. Child classes must call
    ``_load_parameters`` themselves.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0125",
        model_arguments: dict[str, Any] | None = None,
        source_language: str = "fortran",
        max_prompts: int = 10,
        max_tokens: int | None = None,
        prompt_template: str = "simple",
        db_path: str | None = None,
        db_config: dict[str, Any] | None = None,
        protected_node_types: tuple[str, ...] = (),
        prune_node_types: tuple[str, ...] = (),
        splitter_type: str = "file",
    ) -> None:
        """Initialize a Converter instance.

        Arguments:
            model: The name of the model to use (see ``set_model``).
            model_arguments: Additional arguments recorded for the model
                constructor. ``None`` (the default) means no extra arguments.
            source_language: The source programming language. Must be a key of
                ``janus.utils.enums.LANGUAGES``.
            max_prompts: Rough budget of LLM calls per block (see ``_run_chain``).
            max_tokens: Maximum tokens per split block. If ``None``, it is
                derived from the model's token limit when the model is loaded.
            prompt_template: Name of a prompt template directory (see
                janus/prompts/templates) or a path to a directory.
            db_path: Database path handed to the ``ChromaDBVectorizer`` factory.
                No vectorizer is created unless ``db_config`` is also set.
            db_config: Database configuration handed to the vectorizer factory.
            protected_node_types: Node types that must not be merged.
            prune_node_types: Node types to prune before prompting the LLM.
            splitter_type: Key into ``CUSTOM_SPLITTERS`` selecting the splitter.
        """
        # Must exist before any other attribute so __setattr__ can track changes
        self._changed_attrs: set[str] = set()

        self.max_prompts: int = max_prompts
        self._max_tokens: int | None = max_tokens
        # A user-supplied max_tokens takes precedence over the model-derived one
        self.override_token_limit: bool = max_tokens is not None

        self._model_name: str
        self._custom_model_arguments: dict[str, Any]

        self._source_language: str
        self._source_suffix: str

        self._target_language = "json"
        self._target_suffix = ".json"

        self._protected_node_types: tuple[str, ...] = ()
        self._prune_node_types: tuple[str, ...] = ()
        self._prompt_template_name: str
        self._splitter_type: str
        self._db_path: str | None
        self._db_config: dict[str, Any] | None

        self._splitter: Splitter
        self._llm: BaseLanguageModel
        self._prompt: ChatPromptTemplate

        self._parser: BaseOutputParser = GenericParser()
        self._combiner: Combiner = Combiner()

        self.set_splitter(splitter_type=splitter_type)
        self.set_model(model_name=model, **(model_arguments or {}))
        self.set_prompt(prompt_template=prompt_template)
        self.set_source_language(source_language)
        self.set_protected_node_types(protected_node_types)
        self.set_prune_node_types(prune_node_types)
        self.set_db_path(db_path=db_path)
        self.set_db_config(db_config=db_config)

        # Child class must call this. Should we enforce somehow?
        # self._load_parameters()

    def __setattr__(self, key: Any, value: Any) -> None:
        """Set an attribute, recording its name in ``_changed_attrs`` if it changed.

        ``run_if_changed``-decorated loaders consult ``_changed_attrs`` to decide
        whether anything needs rebuilding.
        """
        if hasattr(self, "_changed_attrs"):
            if not hasattr(self, key) or getattr(self, key) != value:
                self._changed_attrs.add(key)
        # Lazily create the tracking set. The key check stops the nested
        # assignment of _changed_attrs itself from recursing.
        elif key != "_changed_attrs":
            self._changed_attrs = set()
        super().__setattr__(key, value)

    def _load_parameters(self) -> None:
        """(Re)build every derived object whose inputs changed, then reset tracking.

        The model is loaded first because the splitter depends on ``_llm`` and
        ``_max_tokens``.
        """
        self._load_model()
        self._load_prompt()
        self._load_splitter()
        self._load_vectorizer()
        self._changed_attrs.clear()

    def set_model(self, model_name: str, **custom_arguments: Any) -> None:
        """Validate and set the model name.

        The affected objects will not be updated until ``_load_parameters()`` is
        called (``translate_file()`` does this).

        Arguments:
            model_name: The name of the model to use. Valid models are found in
                `janus.llm.models_info.MODEL_CONSTRUCTORS`.
            custom_arguments: Additional arguments to pass to the model constructor.
        """
        self._model_name = model_name
        self._custom_model_arguments = custom_arguments

    def set_prompt(self, prompt_template: str) -> None:
        """Validate and set the prompt template name.

        The affected objects will not be updated until ``_load_parameters()`` is
        called (``translate_file()`` does this).

        Arguments:
            prompt_template: name of prompt template directory
                (see janus/prompts/templates) or path to a directory.
        """
        self._prompt_template_name = prompt_template

    def set_splitter(self, splitter_type: str) -> None:
        """Set the splitter type.

        The affected objects will not be updated until ``_load_parameters()`` is
        called (``translate_file()`` does this).

        Arguments:
            splitter_type: Key into ``CUSTOM_SPLITTERS`` selecting the splitter
                class used to break input files into blocks.
        """
        self._splitter_type = splitter_type

    def set_source_language(self, source_language: str) -> None:
        """Validate and set the source language.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            source_language: The source programming language (case-insensitive).

        Raises:
            ValueError: If the language is not a key of ``LANGUAGES``.
        """
        source_language = source_language.lower()
        if source_language not in LANGUAGES:
            raise ValueError(
                f"Invalid source language: {source_language}. "
                "Valid source languages are found in `janus.utils.enums.LANGUAGES`."
            )

        ext = LANGUAGES[source_language]["suffix"]
        self._source_suffix = f".{ext}"
        self._source_language = source_language

    def set_protected_node_types(self, protected_node_types: tuple[str, ...]) -> None:
        """Set the protected (non-mergeable) node types. This will often be structures
        like functions, classes, or modules which you might want to keep separate.

        Duplicates are removed and the result is sorted. This keeps it independent
        of set iteration order, which varies with the hash seed, so equal inputs
        compare equal and don't spuriously trigger a splitter rebuild.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            protected_node_types: Node types that aren't to be merged (may be None)
        """
        self._protected_node_types = tuple(sorted(set(protected_node_types or [])))

    def set_prune_node_types(self, prune_node_types: tuple[str, ...]) -> None:
        """Set the node types to prune. This will often be structures
        like comments or whitespace which you might want to keep out of the LLM.

        Duplicates are removed and the result is sorted, for deterministic
        comparisons.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            prune_node_types: Node types which should be pruned (may be None)
        """
        self._prune_node_types = tuple(sorted(set(prune_node_types or [])))

    def set_db_path(self, db_path: str | None) -> None:
        """Set the vectorizer database path (``None`` disables vectorization)."""
        self._db_path = db_path

    def set_db_config(self, db_config: dict[str, Any] | None) -> None:
        """Set the vectorizer database config (``None`` disables vectorization)."""
        self._db_config = db_config

    @run_if_changed(
        "_source_language",
        "_max_tokens",
        "_llm",
        "_protected_node_types",
        "_prune_node_types",
        # The splitter class is chosen by _splitter_type, so changing it
        # must trigger a rebuild
        "_splitter_type",
    )
    def _load_splitter(self) -> None:
        """Load the splitter according to this instance's attributes.

        If the relevant fields have not been changed since the last time this method was
        called, nothing happens.
        """
        kwargs: dict[str, Any] = dict(
            language=self._source_language,
            max_tokens=self._max_tokens,
            model=self._llm,
            protected_node_types=self._protected_node_types,
            prune_node_types=self._prune_node_types,
        )

        if self._splitter_type == "tag":
            kwargs["tag"] = "<ITMOD_ALC_SPLIT>"

        self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)

    @run_if_changed("_model_name", "_custom_model_arguments")
    def _load_model(self) -> None:
        """Load the model according to this instance's attributes.

        If the relevant fields have not been changed since the last time this method was
        called, nothing happens.
        """
        # NOTE(review): self._custom_model_arguments is recorded by set_model but
        # is not currently forwarded to load_model.
        self._llm, token_limit, self.model_cost = load_model(self._model_name)
        # Set the max_tokens to less than half the model's limit to allow for enough
        # tokens at output. Only modify max_tokens if it is not specified by user.
        if not self.override_token_limit:
            self._max_tokens = int(token_limit // 2.5)

    @run_if_changed(
        "_prompt_template_name",
        "_source_language",
        "_model_name",
    )
    def _load_prompt(self) -> None:
        """Load the prompt according to this instance's attributes.

        If the relevant fields have not been changed since the last time this
        method was called, nothing happens.
        """
        prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            prompt_template=self._prompt_template_name,
        )
        self._prompt = prompt_engine.prompt

    @run_if_changed("_db_path", "_db_config")
    def _load_vectorizer(self) -> None:
        """Create the vectorizer, or set it to ``None`` if the DB isn't fully configured."""
        if self._db_path is None or self._db_config is None:
            self._vectorizer = None
            return
        vectorizer_factory = ChromaDBVectorizer()
        self._vectorizer = vectorizer_factory.create_vectorizer(
            self._db_path, self._db_config
        )

    def translate(
        self,
        input_directory: str | Path,
        output_directory: str | Path | None = None,
        overwrite: bool = False,
        collection_name: str | None = None,
    ) -> None:
        """Convert code in the input directory from the source language to the target
        language, and write the resulting files to the output directory.

        Source files are found recursively by the source-language suffix. Files
        that fail for a recoverable reason (rate limit, unparseable output,
        oversized or empty input, ...) are logged and skipped.

        Arguments:
            input_directory: The directory containing the code to translate.
            output_directory: The directory to write the translated code to. If
                ``None``, nothing is written to disk.
            overwrite: Whether to overwrite existing files (vs skip them)
            collection_name: Collection to add to. Requires a vectorizer, i.e.
                both ``db_path`` and ``db_config`` must be set.

        Raises:
            ValueError: If ``collection_name`` is given but no vectorizer exists.
        """
        # Convert paths to pathlib Paths if needed
        input_directory = Path(input_directory)
        if output_directory is not None:
            output_directory = Path(output_directory)

        # Make sure the output directory exists
        if output_directory is not None:
            output_directory.mkdir(parents=True, exist_ok=True)

        # rglob already recurses, so no "**/" prefix is needed. Directories that
        # happen to carry the source suffix are not translatable files.
        input_paths = [
            p for p in input_directory.rglob(f"*{self._source_suffix}") if p.is_file()
        ]
        # Count files over the whole tree (iterdir() only sees the top level),
        # so the "skipped" figure is comparable to input_paths and never negative.
        n_all_files = sum(1 for p in input_directory.rglob("*") if p.is_file())

        log.info(f"Input directory: {input_directory.absolute()}")
        log.info(
            f"{self._source_language} '*{self._source_suffix}' files: "
            f"{len(input_paths)}"
        )
        log.info(f"Other files (skipped): {n_all_files - len(input_paths)}\n")
        if output_directory is not None:
            output_paths = [
                output_directory
                / p.relative_to(input_directory).with_suffix(self._target_suffix)
                for p in input_paths
            ]
            in_out_pairs = list(zip(input_paths, output_paths))
            if not overwrite:
                n_files = len(in_out_pairs)
                in_out_pairs = [
                    (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
                ]
                log.info(
                    f"Skipping {n_files - len(in_out_pairs)} existing "
                    f"'*{self._source_suffix}' files"
                )
        else:
            in_out_pairs = [(f, None) for f in input_paths]
        log.info(f"Translating {len(in_out_pairs)} '*{self._source_suffix}' files")

        # Loop through each input file, convert and save it
        total_cost = 0.0
        for in_path, out_path in in_out_pairs:
            # Translate the file; skip it on recoverable errors
            try:
                out_block = self.translate_file(in_path)
                total_cost += out_block.total_cost
            except RateLimitError:
                log.warning(f"Skipping {in_path.name}, hit the model's rate limit")
                continue
            except OutputParserException as e:
                log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
                continue
            except BadRequestError as e:
                if str(e).startswith("Detected an error in the prompt"):
                    log.warning("Malformed input, skipping")
                    continue
                raise
            except ValidationError:
                # Only allow ValidationError to pass if token limit is manually set
                if self.override_token_limit:
                    log.warning(
                        "Current file and manually set token "
                        "limit is too large for this model, skipping"
                    )
                    continue
                raise
            except TokenLimitError:
                log.warning("Ran into irreducible node too large for context, skipping")
                continue
            except EmptyTreeError:
                log.warning(
                    f'Input file "{in_path.name}" has no nodes of interest, skipping'
                )
                continue
            except FileSizeError:
                log.warning("Current file is too large for basic splitter, skipping")
                continue

            # Don't attempt to write files for which translation failed
            if not out_block.translated:
                continue

            if collection_name is not None:
                # _load_vectorizer sets this to None unless db_path and db_config
                # are both set; fail with a clear message instead of AttributeError.
                vectorizer = getattr(self, "_vectorizer", None)
                if vectorizer is None:
                    raise ValueError(
                        f"Cannot add to collection '{collection_name}': no vectorizer "
                        "is configured (set both db_path and db_config)"
                    )
                vectorizer.add_nodes_recursively(
                    out_block,
                    collection_name,
                    in_path.name,
                )

            # Make sure the tree's code has been consolidated at the top level
            # before writing to file
            self._combiner.combine(out_block)
            if out_path is not None and (overwrite or not out_path.exists()):
                self._save_to_file(out_block, out_path)

        log.info(f"Total cost: ${total_cost:,.2f}")

    def translate_file(self, file: Path) -> TranslatedCodeBlock:
        """Translate a single file.

        Reloads any changed parameters first (see ``_load_parameters``).

        Arguments:
            file: Input path to file

        Returns:
            A `TranslatedCodeBlock` object. This block does not have a path set, and its
            code is not guaranteed to be consolidated. To amend this, run
            `Combiner.combine` on the block (as ``translate`` does).
        """
        self._load_parameters()
        filename = file.name

        input_block = self._split_file(file)
        t0 = time.time()
        output_block = self._iterative_translate(input_block)
        output_block.processing_time = time.time() - t0
        if output_block.translated:
            completeness = output_block.translation_completeness
            # The input structure is already logged by _split_file; show the output here
            log.info(
                f"[{filename}] Translation complete\n"
                f" {completeness:.2%} of input successfully translated\n"
                f" Total cost: ${output_block.total_cost:,.2f}\n"
                f" Total retries: {output_block.total_retries:,d}\n"
                f" Output CodeBlock Structure:\n{output_block.tree_str()}\n"
            )

        else:
            log.error(
                f"[{filename}] Translation failed\n"
                f" Total cost: ${output_block.total_cost:,.2f}\n"
                f" Total retries: {output_block.total_retries:,d}\n"
            )
        return output_block

    def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
        """Translate the passed CodeBlock representing a full file.

        Blocks are translated parents-first using an explicit stack. Children of
        a block whose translation failed are skipped. Progress is logged
        roughly every 10% of completeness.

        Arguments:
            root: A root block representing the top-level block of a file

        Returns:
            A `TranslatedCodeBlock`
        """
        translated_root = TranslatedCodeBlock(root, self._target_language)
        last_prog, prog_delta = 0, 0.1
        stack = [translated_root]
        while stack:
            translated_block = stack.pop()

            self._add_translation(translated_block)

            # If translating this block was unsuccessful, don't bother with its
            # children (they wouldn't show up in the final text anyway)
            if not translated_block.translated:
                continue

            stack.extend(translated_block.children)

            progress = translated_root.translation_completeness
            if progress - last_prog > prog_delta:
                # Snap to the last crossed multiple of prog_delta
                last_prog = int(progress / prog_delta) * prog_delta
                log.info(f"[{root.name}] progress: {progress:.2%}")

        return translated_root

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If no model has been loaded.
        """
        if block.translated:
            return

        # Nothing to translate; mark done so its children are still visited
        if block.original.text is None:
            block.translated = True
            return

        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        # Track the cost of translating this block
        # TODO: If non-OpenAI models with prices are added, this will need
        # to be updated.
        with get_model_callback() as cb:
            t0 = time.time()
            block.text = self._run_chain(block)
            block.processing_time = time.time() - t0
            block.cost = cb.total_cost
            # The first successful request is the translation itself; the rest
            # are retries
            block.retries = max(0, cb.successful_requests - 1)

        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    def _split_file(self, file: Path) -> CodeBlock:
        """Split ``file`` into a tree of code blocks with the configured splitter."""
        filename = file.name
        log.info(f"[{filename}] Splitting file")
        root = self._splitter.split(file)
        log.info(
            f"[{filename}] File split into {root.n_descendents:,} blocks, "
            f"tree of height {root.height}"
        )
        log.info(f"[{filename}] Input CodeBlock Structure:\n{root.tree_str()}")
        return root

    def _run_chain(self, block: TranslatedCodeBlock) -> str:
        """Run the model with three nested error fixing schemes.

        First, try to fix simple formatting errors by giving the model just
        the output and the parsing error. After a number of attempts, try
        giving the model the output, the parsing error, and the original
        input. Again check/retry this output to solve for formatting errors.
        If we still haven't succeeded after several attempts, the model may
        be getting thrown off by a bad initial output; start from scratch
        and try again.

        The number of tries for each layer of this scheme is roughly equal
        to the cube root of self.max_prompts, so the total calls to the
        LLM will be roughly as expected (up to sqrt(self.max_prompts) over).
        Each layer is tried at least once, even if max_prompts < 1.

        Raises:
            OutputParserException: If every attempt failed to parse.
        """
        self._parser.set_reference(block.original)

        # Guard against max_prompts <= 0, which would otherwise divide by zero
        # (or take the cube root of a negative number) below
        budget = max(1, self.max_prompts)

        # Retries with just the output and the error
        n1 = max(1, round(budget ** (1 / 3)))

        # Retries with the input, output, and error
        n2 = max(1, round((budget // n1) ** (1 / 2)))

        # Retries with just the input
        n3 = max(1, math.ceil(budget / (n1 * n2)))

        fix_format = OutputFixingParser.from_llm(
            llm=self._llm,
            parser=self._parser,
            max_retries=n1,
        )
        retry = RetryWithErrorOutputParser.from_llm(
            llm=self._llm,
            parser=fix_format,
            max_retries=n2,
        )

        completion_chain = self._prompt | self._llm
        chain = RunnableParallel(
            completion=completion_chain, prompt_value=self._prompt
        ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))

        for attempt in range(1, n3 + 1):
            try:
                return chain.invoke({"SOURCE_CODE": block.original.text})
            except OutputParserException as e:
                log.debug(f"[{block.name}] Attempt {attempt}/{n3} failed to parse: {e}")

        raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")

    def _get_output_obj(
        self, block: TranslatedCodeBlock
    ) -> dict[str, int | float | str | dict[str, str]]:
        """Build the JSON-serializable record written for ``block``.

        The combined output is decoded as JSON when possible, otherwise kept
        as the raw string.
        """
        output_str = self._parser.parse_combined_output(block.complete_text)

        output: Any
        try:
            output = json.loads(output_str)
        except json.JSONDecodeError:
            output = output_str

        return dict(
            input=block.original.text,
            metadata=dict(
                retries=block.total_retries,
                cost=block.total_cost,
                processing_time=block.processing_time,
            ),
            output=output,
        )

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save a file to disk.

        Arguments:
            block: The `TranslatedCodeBlock` to save to a file.
            out_path: Destination path; missing parent directories are created.
        """
        obj = self._get_output_obj(block)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")