janus-llm 2.1.0__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,614 @@
1
+ import functools
2
+ import json
3
+ import math
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from langchain.output_parsers import RetryWithErrorOutputParser
9
+ from langchain.output_parsers.fix import OutputFixingParser
10
+ from langchain_core.exceptions import OutputParserException
11
+ from langchain_core.language_models import BaseLanguageModel
12
+ from langchain_core.output_parsers import BaseOutputParser
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain_core.runnables import RunnableLambda, RunnableParallel
15
+ from openai import BadRequestError, RateLimitError
16
+ from pydantic import ValidationError
17
+
18
+ from janus.embedding.vectorize import ChromaDBVectorizer
19
+ from janus.language.block import CodeBlock, TranslatedCodeBlock
20
+ from janus.language.combine import Combiner
21
+ from janus.language.naive.registry import CUSTOM_SPLITTERS
22
+ from janus.language.splitter import (
23
+ EmptyTreeError,
24
+ FileSizeError,
25
+ Splitter,
26
+ TokenLimitError,
27
+ )
28
+ from janus.llm import load_model
29
+ from janus.llm.model_callbacks import get_model_callback
30
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES
31
+ from janus.parsers.code_parser import GenericParser
32
+ from janus.utils.enums import LANGUAGES
33
+ from janus.utils.logger import create_logger
34
+
35
+ log = create_logger(__name__)
36
+
37
+
38
def run_if_changed(*tracked_vars):
    """Build a decorator that skips a method unless tracked attributes changed.

    The decorated method is only executed when at least one of the attribute
    names in ``tracked_vars`` is present in the instance's ``_changed_attrs``
    set, or when no names are given at all (in which case it always runs).

    The instance must expose a ``_changed_attrs`` set; keeping that set up to
    date (typically via an overridden ``__setattr__``) is the responsibility
    of the class using this decorator.

    Arguments:
        tracked_vars: Names of the instance attributes to watch.

    Returns:
        A decorator. The wrapped method returns whatever the original method
        returns when it runs, and `None` when the call is skipped.
    """

    def wrapper(func):
        @functools.wraps(func)
        def wrapped(self, *args, **kwargs):
            # With no tracked variables the method is unconditional; otherwise
            # it runs only if a tracked attribute is flagged as changed.
            if not tracked_vars or self._changed_attrs.intersection(tracked_vars):
                # Propagate the result instead of silently dropping it
                return func(self, *args, **kwargs)
            return None

        return wrapped

    return wrapper
55
+
56
+
57
class Converter:
    """Parent class that converts code into something else.

    Children will determine what the code gets converted into. Whether that's translated
    into another language, into pseudocode, requirements, documentation, etc., or
    converted into embeddings.

    Configuration is stored through the `set_*` methods and only turned into
    concrete objects (model, prompt, splitter, vectorizer) by
    `_load_parameters()`, which `translate_file()` calls before doing any work.
    Attribute assignments are recorded in `_changed_attrs` so that each
    `_load_*` method can skip itself when nothing it depends on has changed.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0125",
        model_arguments: dict[str, Any] | None = None,
        source_language: str = "fortran",
        max_prompts: int = 10,
        max_tokens: int | None = None,
        prompt_template: str = "simple",
        db_path: str | None = None,
        db_config: dict[str, Any] | None = None,
        protected_node_types: tuple[str, ...] = (),
        prune_node_types: tuple[str, ...] = (),
        splitter_type: str = "file",
    ) -> None:
        """Initialize a Converter instance.

        Arguments:
            model: The name of the model to use.
            model_arguments: Additional keyword arguments stored alongside the
                model name. `None` (default) is treated as no extra arguments.
            source_language: The source programming language.
            max_prompts: The prompt budget per block; used to derive the retry
                counts in `_run_chain`.
            max_tokens: The maximum number of tokens per block. If `None`, the
                value is derived from the model's token limit when the model is
                loaded.
            prompt_template: Name of the prompt template directory, or a path to
                a directory.
            db_path: Path passed to the vectorizer factory. If this or
                `db_config` is `None`, no vectorizer is created.
            db_config: Configuration passed to the vectorizer factory.
            protected_node_types: Node types that aren't to be merged.
            prune_node_types: Node types which should be pruned.
            splitter_type: Key into `CUSTOM_SPLITTERS` selecting the splitter.
        """
        self._changed_attrs: set = set()

        self.max_prompts: int = max_prompts
        self._max_tokens: int | None = max_tokens
        # A user-supplied token limit must not be overwritten by _load_model()
        self.override_token_limit: bool = max_tokens is not None

        self._model_name: str
        self._custom_model_arguments: dict[str, Any]

        self._source_language: str
        self._source_suffix: str

        self._target_language = "json"
        self._target_suffix = ".json"

        self._protected_node_types: tuple[str, ...] = ()
        self._prune_node_types: tuple[str, ...] = ()
        self._prompt_template_name: str
        self._splitter_type: str
        self._db_path: str | None
        self._db_config: dict[str, Any] | None

        self._splitter: Splitter
        self._llm: BaseLanguageModel
        self._prompt: ChatPromptTemplate

        self._parser: BaseOutputParser = GenericParser()
        self._combiner: Combiner = Combiner()

        self.set_splitter(splitter_type=splitter_type)
        self.set_model(model_name=model, **(model_arguments or {}))
        self.set_prompt(prompt_template=prompt_template)
        self.set_source_language(source_language)
        self.set_protected_node_types(protected_node_types)
        self.set_prune_node_types(prune_node_types)
        self.set_db_path(db_path=db_path)
        self.set_db_config(db_config=db_config)

        # The heavyweight objects are not built here: translate_file() calls
        # _load_parameters() before it does any work.

    def __setattr__(self, key: Any, value: Any) -> None:
        """Set an attribute, recording its name if the value is new or changed.

        Once `_changed_attrs` exists, every assignment that creates an
        attribute or changes its value (by `!=`) adds the attribute name to
        that set, which is what `run_if_changed` consults.
        """
        if hasattr(self, "_changed_attrs"):
            if not hasattr(self, key) or getattr(self, key) != value:
                self._changed_attrs.add(key)
        # Lazily create the tracking set; the key check avoids infinite
        # recursion when the attribute being assigned is the set itself.
        elif key != "_changed_attrs":
            self._changed_attrs = set()
        super().__setattr__(key, value)

    def _load_parameters(self) -> None:
        """Run every loader, then mark all attributes as unchanged.

        Each loader decides for itself (via `run_if_changed`) whether it needs
        to run. The model is loaded first because loading it may update
        `_llm` and `_max_tokens`, which the splitter depends on.
        """
        self._load_model()
        self._load_prompt()
        self._load_splitter()
        self._load_vectorizer()
        self._changed_attrs.clear()

    def set_model(self, model_name: str, **custom_arguments: Any):
        """Set the model name and any custom model arguments.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            model_name: The name of the model to use. Valid models are found in
                `janus.llm.models_info.MODEL_CONSTRUCTORS`.
            custom_arguments: Additional arguments to pass to the model constructor.
        """
        self._model_name = model_name
        self._custom_model_arguments = custom_arguments

    def set_prompt(self, prompt_template: str) -> None:
        """Set the prompt template name.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            prompt_template: name of prompt template directory
                (see janus/prompts/templates) or path to a directory.
        """
        self._prompt_template_name = prompt_template

    def set_splitter(self, splitter_type: str) -> None:
        """Set the splitter type.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            splitter_type: Key into `CUSTOM_SPLITTERS` selecting which splitter
                class to instantiate.
        """
        self._splitter_type = splitter_type

    def set_source_language(self, source_language: str) -> None:
        """Validate and set the source language.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            source_language: The source programming language.

        Raises:
            ValueError: If the (lower-cased) language is not a key of
                `janus.utils.enums.LANGUAGES`.
        """
        source_language = source_language.lower()
        if source_language not in LANGUAGES:
            raise ValueError(
                f"Invalid source language: {source_language}. "
                "Valid source languages are found in `janus.utils.enums.LANGUAGES`."
            )

        ext = LANGUAGES[source_language]["suffix"]
        self._source_suffix = f".{ext}"
        self._source_language = source_language

    def set_protected_node_types(self, protected_node_types: tuple[str, ...]) -> None:
        """Set the protected (non-mergeable) node types. This will often be structures
        like functions, classes, or modules which you might want to keep separate

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            protected_node_types: A set of node types that aren't to be merged
        """
        # dict.fromkeys de-duplicates while keeping a stable, input-defined order
        self._protected_node_types = tuple(dict.fromkeys(protected_node_types or ()))

    def set_prune_node_types(self, prune_node_types: tuple[str, ...]) -> None:
        """Set the node types to prune. This will often be structures
        like comments or whitespace which you might want to keep out of the LLM

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            prune_node_types: A set of node types which should be pruned
        """
        # dict.fromkeys de-duplicates while keeping a stable, input-defined order
        self._prune_node_types = tuple(dict.fromkeys(prune_node_types or ()))

    def set_db_path(self, db_path: str | None) -> None:
        """Set the database path handed to the vectorizer factory.

        Arguments:
            db_path: The database path, or `None` to disable the vectorizer.
        """
        self._db_path = db_path

    def set_db_config(self, db_config: dict[str, Any] | None) -> None:
        """Set the database configuration handed to the vectorizer factory.

        Arguments:
            db_config: The database config, or `None` to disable the vectorizer.
        """
        self._db_config = db_config

    @run_if_changed(
        "_source_language",
        "_max_tokens",
        "_llm",
        "_protected_node_types",
        "_prune_node_types",
        "_splitter_type",
    )
    def _load_splitter(self) -> None:
        """Load the splitter according to this instance's attributes.

        If the relevant fields have not been changed since the last time this method was
        called, nothing happens.
        """
        kwargs: dict[str, Any] = dict(
            language=self._source_language,
            max_tokens=self._max_tokens,
            model=self._llm,
            protected_node_types=self._protected_node_types,
            prune_node_types=self._prune_node_types,
        )

        # The tag splitter additionally needs the marker it splits on
        if self._splitter_type == "tag":
            kwargs["tag"] = "<ITMOD_ALC_SPLIT>"

        self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)

    @run_if_changed("_model_name", "_custom_model_arguments")
    def _load_model(self) -> None:
        """Load the model according to this instance's attributes.

        If the relevant fields have not been changed since the last time this method was
        called, nothing happens.
        """
        # NOTE(review): self._custom_model_arguments is stored by set_model()
        # but is not forwarded to load_model() here - confirm whether
        # load_model() is expected to pick the arguments up some other way.
        self._llm, token_limit, self.model_cost = load_model(self._model_name)

        # Set the max_tokens to less than half the model's limit to allow for enough
        # tokens at output
        # Only modify max_tokens if it is not specified by user
        if not self.override_token_limit:
            self._max_tokens = int(token_limit // 2.5)

    @run_if_changed(
        "_prompt_template_name",
        "_source_language",
        "_model_name",
    )
    def _load_prompt(self) -> None:
        """Load the prompt according to this instance's attributes.

        If the relevant fields have not been changed since the last time this
        method was called, nothing happens.
        """
        prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            prompt_template=self._prompt_template_name,
        )
        self._prompt = prompt_engine.prompt

    @run_if_changed("_db_path", "_db_config")
    def _load_vectorizer(self) -> None:
        """Create the vectorizer, or set it to `None` if the DB is not configured.

        If the relevant fields have not been changed since the last time this
        method was called, nothing happens.
        """
        if self._db_path is None or self._db_config is None:
            self._vectorizer = None
            return
        vectorizer_factory = ChromaDBVectorizer()
        self._vectorizer = vectorizer_factory.create_vectorizer(
            self._db_path, self._db_config
        )

    def translate(
        self,
        input_directory: str | Path,
        output_directory: str | Path | None = None,
        overwrite: bool = False,
        collection_name: str | None = None,
    ) -> None:
        """Convert code in the input directory from the source language to the target
        language, and write the resulting files to the output directory.

        Files whose conversion fails with one of the handled errors are logged
        and skipped; the remaining files are still processed.

        Arguments:
            input_directory: The directory containing the code to translate.
            output_directory: The directory to write the translated code to.
                If `None`, nothing is written to disk.
            overwrite: Whether to overwrite existing files (vs skip them)
            collection_name: Collection to add to. Requires a configured
                vectorizer (see `set_db_path` / `set_db_config`).

        Raises:
            BadRequestError: If the request is rejected for a reason other than
                a detected error in the prompt.
            ValidationError: If validation fails while the token limit was not
                manually set.
        """
        # Convert paths to pathlib Paths if needed
        input_directory = Path(input_directory)
        if output_directory is not None:
            output_directory = Path(output_directory)
            # Make sure the output directory exists
            output_directory.mkdir(parents=True, exist_ok=True)

        # rglob() already searches recursively, no "**/" prefix needed
        input_paths = list(input_directory.rglob(f"*{self._source_suffix}"))

        # Count the other files recursively as well, so that the two numbers
        # are comparable and the difference cannot go negative.
        n_all_files = sum(1 for p in input_directory.rglob("*") if p.is_file())
        n_other_files = max(0, n_all_files - len(input_paths))

        log.info(f"Input directory: {input_directory.absolute()}")
        log.info(
            f"{self._source_language} '*{self._source_suffix}' files: "
            f"{len(input_paths)}"
        )
        log.info(f"Other files (skipped): {n_other_files}\n")

        in_out_pairs: list[tuple[Path, Path | None]]
        if output_directory is not None:
            in_out_pairs = [
                (
                    p,
                    output_directory
                    / p.relative_to(input_directory).with_suffix(self._target_suffix),
                )
                for p in input_paths
            ]
            if not overwrite:
                n_files = len(in_out_pairs)
                in_out_pairs = [
                    (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
                ]
                log.info(
                    f"Skipping {n_files - len(in_out_pairs)} existing "
                    f"'*{self._source_suffix}' files"
                )
        else:
            in_out_pairs = [(f, None) for f in input_paths]
        log.info(f"Translating {len(in_out_pairs)} '*{self._source_suffix}' files")

        # Loop through each input file, convert and save it
        total_cost = 0.0
        for in_path, out_path in in_out_pairs:
            # Translate the file, skip it if a recoverable error occurs
            try:
                out_block = self.translate_file(in_path)
            except RateLimitError:
                log.warning(f"Skipping {in_path.name}, rate limit reached")
                continue
            except OutputParserException as e:
                log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
                continue
            except BadRequestError as e:
                if str(e).startswith("Detected an error in the prompt"):
                    log.warning("Malformed input, skipping")
                    continue
                raise
            except ValidationError:
                # Only allow ValidationError to pass if token limit is manually set
                if self.override_token_limit:
                    log.warning(
                        "Current file and manually set token "
                        "limit is too large for this model, skipping"
                    )
                    continue
                raise
            except TokenLimitError:
                log.warning("Ran into irreducible node too large for context, skipping")
                continue
            except EmptyTreeError:
                log.warning(
                    f'Input file "{in_path.name}" has no nodes of interest, skipping'
                )
                continue
            except FileSizeError:
                log.warning("Current file is too large for basic splitter, skipping")
                continue

            total_cost += out_block.total_cost

            # Don't attempt to write files for which translation failed
            if not out_block.translated:
                continue

            if collection_name is not None:
                # The vectorizer is None when db_path or db_config is unset
                vectorizer = getattr(self, "_vectorizer", None)
                if vectorizer is None:
                    log.warning(
                        f"No vectorizer configured, cannot add {in_path.name} "
                        f"to collection '{collection_name}'"
                    )
                else:
                    vectorizer.add_nodes_recursively(
                        out_block,
                        collection_name,
                        in_path.name,
                    )

            # Make sure the tree's code has been consolidated at the top level
            # before writing to file
            self._combiner.combine(out_block)
            if out_path is not None and (overwrite or not out_path.exists()):
                self._save_to_file(out_block, out_path)

        log.info(f"Total cost: ${total_cost:,.2f}")

    def translate_file(self, file: Path) -> TranslatedCodeBlock:
        """Translate a single file.

        Arguments:
            file: Input path to file

        Returns:
            A `TranslatedCodeBlock` object. This block does not have a path set, and its
            code is not guaranteed to be consolidated. To amend this, run
            `Combiner.combine` on the block.
        """
        self._load_parameters()
        filename = file.name

        input_block = self._split_file(file)
        t0 = time.time()
        output_block = self._iterative_translate(input_block)
        output_block.processing_time = time.time() - t0
        if output_block.translated:
            completeness = output_block.translation_completeness
            # NOTE(review): the structure logged under "Output CodeBlock
            # Structure" is the *input* tree - confirm whether the output
            # block's tree was intended here.
            log.info(
                f"[{filename}] Translation complete\n"
                f" {completeness:.2%} of input successfully translated\n"
                f" Total cost: ${output_block.total_cost:,.2f}\n"
                f" Total retries: {output_block.total_retries:,d}\n"
                f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
            )

        else:
            log.error(
                f"[{filename}] Translation failed\n"
                f" Total cost: ${output_block.total_cost:,.2f}\n"
                f" Total retries: {output_block.total_retries:,d}\n"
            )
        return output_block

    def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
        """Translate the passed CodeBlock representing a full file.

        The tree is walked depth-first with an explicit stack. Progress is
        logged each time completeness advances by more than `prog_delta`.

        Arguments:
            root: A root block representing the top-level block of a file

        Returns:
            A `TranslatedCodeBlock`
        """
        translated_root = TranslatedCodeBlock(root, self._target_language)
        last_prog, prog_delta = 0, 0.1
        stack = [translated_root]
        while stack:
            translated_block = stack.pop()

            self._add_translation(translated_block)

            # If translating this block was unsuccessful, don't bother with its
            # children (they wouldn't show up in the final text anyway)
            if not translated_block.translated:
                continue

            stack.extend(translated_block.children)

            progress = translated_root.translation_completeness
            if progress - last_prog > prog_delta:
                # Snap to the last multiple of prog_delta that was passed
                last_prog = int(progress / prog_delta) * prog_delta
                log.info(f"[{root.name}] progress: {progress:.2%}")

        return translated_root

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If no model has been configured.
        """
        if block.translated:
            return

        # Nothing to send to the model; mark the block as done
        if block.original.text is None:
            block.translated = True
            return

        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        # Track the cost of translating this block
        # TODO: If non-OpenAI models with prices are added, this will need
        # to be updated.
        with get_model_callback() as cb:
            t0 = time.time()
            block.text = self._run_chain(block)
            block.processing_time = time.time() - t0
            block.cost = cb.total_cost
            # Every successful request beyond the first counts as a retry
            block.retries = max(0, cb.successful_requests - 1)

        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    def _split_file(self, file: Path) -> CodeBlock:
        """Split a file into a tree of blocks using the loaded splitter.

        Arguments:
            file: Input path to file

        Returns:
            The root `CodeBlock` produced by the splitter.
        """
        filename = file.name
        log.info(f"[{filename}] Splitting file")
        root = self._splitter.split(file)
        log.info(
            f"[{filename}] File split into {root.n_descendents:,} blocks, "
            f"tree of height {root.height}"
        )
        log.info(f"[{filename}] Input CodeBlock Structure:\n{root.tree_str()}")
        return root

    def _run_chain(self, block: TranslatedCodeBlock) -> str:
        """Run the model with three nested error fixing schemes.
        First, try to fix simple formatting errors by giving the model just
        the output and the parsing error. After a number of attempts, try
        giving the model the output, the parsing error, and the original
        input. Again check/retry this output to solve for formatting errors.
        If we still haven't succeeded after several attempts, the model may
        be getting thrown off by a bad initial output; start from scratch
        and try again.

        The number of tries for each layer of this scheme is roughly equal
        to the cube root of self.max_prompts, so the total calls to the
        LLM will be roughly as expected (up to sqrt(self.max_prompts) over)

        Arguments:
            block: The block whose `original.text` is sent to the model.

        Returns:
            The parsed model output.

        Raises:
            OutputParserException: If no attempt produced parseable output.
        """
        self._parser.set_reference(block.original)

        # A budget below one prompt would otherwise divide by zero below (or
        # take the root of a negative number); always allow a single attempt.
        max_prompts = max(1, self.max_prompts)

        # Retries with just the output and the error
        n1 = max(1, round(max_prompts ** (1 / 3)))

        # Retries with the input, output, and error
        n2 = max(1, round((max_prompts // n1) ** (1 / 2)))

        # Retries with just the input
        n3 = max(1, math.ceil(max_prompts / (n1 * n2)))

        fix_format = OutputFixingParser.from_llm(
            llm=self._llm,
            parser=self._parser,
            max_retries=n1,
        )
        retry = RetryWithErrorOutputParser.from_llm(
            llm=self._llm,
            parser=fix_format,
            max_retries=n2,
        )

        # The retry parser needs both the completion and the prompt that
        # produced it, so both are computed side by side.
        completion_chain = self._prompt | self._llm
        chain = RunnableParallel(
            completion=completion_chain, prompt_value=self._prompt
        ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))

        last_error: OutputParserException | None = None
        for _ in range(n3):
            try:
                return chain.invoke({"SOURCE_CODE": block.original.text})
            except OutputParserException as e:
                # Start over from scratch on the next iteration
                last_error = e

        raise OutputParserException(
            f"Failed to parse after {n1*n2*n3} retries"
        ) from last_error

    def _get_output_obj(
        self, block: TranslatedCodeBlock
    ) -> dict[str, int | float | str | dict[str, str]]:
        """Build the JSON-serializable record written for a translated block.

        Arguments:
            block: The `TranslatedCodeBlock` to describe.

        Returns:
            A dict with the input text, a metadata dict (retries, cost,
            processing time) and the output. The output is the decoded JSON
            value if the combined output is valid JSON, else the raw string.
        """
        output_str = self._parser.parse_combined_output(block.complete_text)

        output: str | dict[str, str]
        try:
            output = json.loads(output_str)
        except json.JSONDecodeError:
            # Not JSON; store the text as-is
            output = output_str

        return dict(
            input=block.original.text,
            metadata=dict(
                retries=block.total_retries,
                cost=block.total_cost,
                processing_time=block.processing_time,
            ),
            output=output,
        )

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save a file to disk.

        Arguments:
            block: The `TranslatedCodeBlock` to save to a file.
            out_path: The path to write to; missing parent directories are
                created.
        """
        obj = self._get_output_obj(block)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")